mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-08 08:11:38 +00:00
Merge branch 'main' into docker_s6
This commit is contained in:
commit
59da190512
417 changed files with 26434 additions and 3321 deletions
|
|
@ -971,6 +971,18 @@ class TestSessionConfiguration:
|
|||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
fake_resolve_runtime_provider,
|
||||
)
|
||||
# Pin the parser so this test doesn't depend on live
|
||||
# ``_KNOWN_PROVIDER_NAMES`` / ``_PROVIDER_ALIASES`` module state
|
||||
# (sibling of the same hardening on
|
||||
# ``test_model_switch_uses_requested_provider``).
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.parse_model_input",
|
||||
lambda raw, current: ("anthropic", "claude-sonnet-4-6"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.detect_provider_for_model",
|
||||
lambda model, current: None,
|
||||
)
|
||||
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||
|
|
@ -1191,6 +1203,48 @@ class TestPrompt:
|
|||
assert len(agent_chunks) == 1
|
||||
assert agent_chunks[0].content.text == "streamed answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_delivers_transformed_response_after_streaming(self, agent):
|
||||
"""If a transform_llm_output plugin hook modifies the response after
|
||||
streaming, ACP must deliver the transformed final_response so the
|
||||
appended/rewritten text reaches the client.
|
||||
"""
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
state.agent.stream_delta_callback("original answer")
|
||||
return {
|
||||
"final_response": "original answer\n\n[plugin appended this]",
|
||||
"response_transformed": True,
|
||||
"messages": [],
|
||||
}
|
||||
|
||||
state.agent.run_conversation = mock_run
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="hello")]
|
||||
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
updates = [
|
||||
call.kwargs.get("update") or call.args[1]
|
||||
for call in mock_conn.session_update.call_args_list
|
||||
]
|
||||
# The streamed chunk and the post-stream transformed message should
|
||||
# both be present (final delivery is a separate update_agent_message_text
|
||||
# call carrying the full transformed text).
|
||||
all_texts = [
|
||||
getattr(getattr(u, "content", None), "text", None)
|
||||
for u in updates
|
||||
]
|
||||
assert any(
|
||||
text and "[plugin appended this]" in text for text in all_texts
|
||||
), f"expected transformed final to be delivered, got: {all_texts!r}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_auto_titles_session(self, agent):
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
|
|
@ -1543,6 +1597,20 @@ class TestSlashCommands:
|
|||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
fake_resolve_runtime_provider,
|
||||
)
|
||||
# Pin the model-string parser independently of the live
|
||||
# ``_KNOWN_PROVIDER_NAMES`` / ``_PROVIDER_ALIASES`` module state.
|
||||
# Otherwise any test in the same xdist worker that mutates those
|
||||
# globals (e.g. registers a custom provider that shadows
|
||||
# ``anthropic``) flakes this one — observed once in CI as
|
||||
# ``'custom' == 'anthropic'``.
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.parse_model_input",
|
||||
lambda raw, current: ("anthropic", "claude-sonnet-4-6"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.detect_provider_for_model",
|
||||
lambda model, current: None,
|
||||
)
|
||||
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||
|
|
|
|||
250
tests/agent/test_anthropic_mcp_prefix_strip.py
Normal file
250
tests/agent/test_anthropic_mcp_prefix_strip.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
"""Tests for GH-25255: Anthropic OAuth mcp_ prefix stripping.
|
||||
|
||||
When strip_tool_prefix=True (Anthropic OAuth path), the transport must only
|
||||
strip the ``mcp_`` prefix from OAuth-injected tools, NOT from Hermes-native
|
||||
MCP server tools that are registered under their full ``mcp_<server>_<tool>``
|
||||
name in the tool registry.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_tool_use_block(name: str, block_id: str = "tc_1", input_data: dict | None = None):
|
||||
"""Create a fake Anthropic tool_use content block."""
|
||||
return SimpleNamespace(
|
||||
type="tool_use",
|
||||
id=block_id,
|
||||
name=name,
|
||||
input=input_data or {"query": "test"},
|
||||
)
|
||||
|
||||
|
||||
def _make_response(*blocks, stop_reason="end_turn"):
|
||||
"""Create a fake Anthropic Messages response."""
|
||||
return SimpleNamespace(
|
||||
content=list(blocks),
|
||||
stop_reason=stop_reason,
|
||||
model="claude-sonnet-4",
|
||||
usage=SimpleNamespace(input_tokens=100, output_tokens=50),
|
||||
)
|
||||
|
||||
|
||||
class _FakeRegistry:
|
||||
"""Minimal fake tool registry for testing prefix stripping logic."""
|
||||
|
||||
def __init__(self, registered_names: set[str]):
|
||||
self._names = registered_names
|
||||
|
||||
def get_entry(self, name: str):
|
||||
if name in self._names:
|
||||
return SimpleNamespace(name=name) # truthy = tool exists
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAnthropicMcpPrefixStrip:
|
||||
"""Verify that strip_tool_prefix only strips OAuth-injected prefixes."""
|
||||
|
||||
def _get_transport(self):
|
||||
from agent.transports.anthropic import AnthropicTransport
|
||||
return AnthropicTransport()
|
||||
|
||||
def test_strips_prefix_for_oauth_injected_tool(self):
|
||||
"""OAuth tools: mcp_read_file -> read_file (stripped).
|
||||
|
||||
The tool was registered as 'read_file' in the registry.
|
||||
Anthropic sees 'mcp_read_file' because Hermes adds the prefix.
|
||||
On response, we must strip it back to 'read_file'.
|
||||
"""
|
||||
transport = self._get_transport()
|
||||
block = _make_tool_use_block("mcp_read_file")
|
||||
response = _make_response(block)
|
||||
|
||||
registry = _FakeRegistry({"read_file", "terminal", "web_search"})
|
||||
with patch("tools.registry.registry", registry):
|
||||
result = transport.normalize_response(response, strip_tool_prefix=True)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "read_file"
|
||||
|
||||
def test_preserves_native_mcp_server_tool_name(self):
|
||||
"""Native MCP tools: mcp_composio_SEARCH -> mcp_composio_SEARCH (kept).
|
||||
|
||||
The tool is registered with the full mcp_ prefix in the registry.
|
||||
Stripping would break registry lookup.
|
||||
"""
|
||||
transport = self._get_transport()
|
||||
block = _make_tool_use_block("mcp_composio_COMPOSIO_SEARCH_TOOLS")
|
||||
response = _make_response(block)
|
||||
|
||||
registry = _FakeRegistry({
|
||||
"mcp_composio_COMPOSIO_SEARCH_TOOLS",
|
||||
"mcp_composio_COMPOSIO_GET_TOOL_SCHEMAS",
|
||||
"read_file",
|
||||
})
|
||||
with patch("tools.registry.registry", registry):
|
||||
result = transport.normalize_response(response, strip_tool_prefix=True)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "mcp_composio_COMPOSIO_SEARCH_TOOLS"
|
||||
|
||||
def test_no_strip_when_flag_false(self):
|
||||
"""When strip_tool_prefix=False, names are never modified."""
|
||||
transport = self._get_transport()
|
||||
block = _make_tool_use_block("mcp_read_file")
|
||||
response = _make_response(block)
|
||||
|
||||
registry = _FakeRegistry({"read_file"})
|
||||
with patch("tools.registry.registry", registry):
|
||||
result = transport.normalize_response(response, strip_tool_prefix=False)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "mcp_read_file"
|
||||
|
||||
def test_no_strip_when_not_mcp_prefixed(self):
|
||||
"""Non-mcp_ names are untouched regardless of strip flag."""
|
||||
transport = self._get_transport()
|
||||
block = _make_tool_use_block("web_search")
|
||||
response = _make_response(block)
|
||||
|
||||
registry = _FakeRegistry({"web_search"})
|
||||
with patch("tools.registry.registry", registry):
|
||||
result = transport.normalize_response(response, strip_tool_prefix=True)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "web_search"
|
||||
|
||||
def test_preserves_name_when_neither_in_registry(self):
|
||||
"""When neither stripped nor full name is in registry, keep full name.
|
||||
|
||||
Safety fallback: if we can't determine the type, prefer the full name
|
||||
since it's what the LLM was told about.
|
||||
"""
|
||||
transport = self._get_transport()
|
||||
block = _make_tool_use_block("mcp_unknown_tool")
|
||||
response = _make_response(block)
|
||||
|
||||
registry = _FakeRegistry({"read_file"}) # neither name registered
|
||||
with patch("tools.registry.registry", registry):
|
||||
result = transport.normalize_response(response, strip_tool_prefix=True)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].name == "mcp_unknown_tool"
|
||||
|
||||
def test_mixed_tools_same_response(self):
|
||||
"""Both OAuth and native MCP tools in the same response."""
|
||||
transport = self._get_transport()
|
||||
block1 = _make_tool_use_block("mcp_read_file", block_id="tc_1")
|
||||
block2 = _make_tool_use_block("mcp_composio_SEARCH", block_id="tc_2")
|
||||
block3 = _make_tool_use_block("mcp_composio_SEARCH", block_id="tc_3") # also registered natively
|
||||
response = _make_response(block1, block2, block3)
|
||||
|
||||
registry = _FakeRegistry({
|
||||
"read_file", # OAuth-injected
|
||||
"mcp_composio_SEARCH", # native MCP
|
||||
})
|
||||
with patch("tools.registry.registry", registry):
|
||||
result = transport.normalize_response(response, strip_tool_prefix=True)
|
||||
|
||||
assert len(result.tool_calls) == 3
|
||||
# OAuth tool: stripped
|
||||
assert result.tool_calls[0].name == "read_file"
|
||||
# Native MCP: preserved (both stripped and full are registered, full wins)
|
||||
assert result.tool_calls[1].name == "mcp_composio_SEARCH"
|
||||
assert result.tool_calls[2].name == "mcp_composio_SEARCH"
|
||||
|
||||
def test_both_stripped_and_full_registered_prefers_full(self):
|
||||
"""Edge case: both 'foo' and 'mcp_foo' exist in registry.
|
||||
|
||||
Keep 'mcp_foo' (the original name) since it's what the LLM requested.
|
||||
"""
|
||||
transport = self._get_transport()
|
||||
block = _make_tool_use_block("mcp_foo")
|
||||
response = _make_response(block)
|
||||
|
||||
registry = _FakeRegistry({"foo", "mcp_foo"})
|
||||
with patch("tools.registry.registry", registry):
|
||||
result = transport.normalize_response(response, strip_tool_prefix=True)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
# Both exist — the condition `get_entry(stripped) and not get_entry(name)`
|
||||
# is False because get_entry(name) IS truthy, so we keep the full name.
|
||||
assert result.tool_calls[0].name == "mcp_foo"
|
||||
|
||||
|
||||
class TestAnthropicOAuthOutgoingPrefix:
|
||||
"""Verify the outgoing-side companion fix: build_anthropic_kwargs must not
|
||||
double-prefix tool names that already start with ``mcp_`` (native MCP server
|
||||
tools registered as ``mcp_<server>_<tool>``). GH-25255."""
|
||||
|
||||
def _build(self, tools, is_oauth=True):
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
return build_anthropic_kwargs(
|
||||
model="claude-sonnet-4-6",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=tools,
|
||||
max_tokens=4096,
|
||||
reasoning_config=None,
|
||||
is_oauth=is_oauth,
|
||||
)
|
||||
|
||||
def test_oauth_adds_prefix_to_bare_tool_name(self):
|
||||
"""OAuth + bare name → prefix added (existing Claude Code convention)."""
|
||||
kwargs = self._build([{
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "description": "x", "parameters": {}},
|
||||
}])
|
||||
names = [t["name"] for t in kwargs["tools"]]
|
||||
assert names == ["mcp_read_file"]
|
||||
|
||||
def test_oauth_does_not_double_prefix_native_mcp_tool(self):
|
||||
"""OAuth + already-prefixed native MCP name → left alone."""
|
||||
kwargs = self._build([{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "mcp_composio_COMPOSIO_SEARCH_TOOLS",
|
||||
"description": "x",
|
||||
"parameters": {},
|
||||
},
|
||||
}])
|
||||
names = [t["name"] for t in kwargs["tools"]]
|
||||
# Must NOT become "mcp_mcp_composio_..." — that breaks the round-trip
|
||||
# because normalize_response only strips ONE mcp_ prefix.
|
||||
assert names == ["mcp_composio_COMPOSIO_SEARCH_TOOLS"]
|
||||
|
||||
def test_oauth_mixed_native_and_bare_tools(self):
|
||||
"""Mixed: native MCP preserved, bare names prefixed."""
|
||||
kwargs = self._build([
|
||||
{"type": "function", "function": {"name": "read_file",
|
||||
"description": "x", "parameters": {}}},
|
||||
{"type": "function", "function": {"name": "mcp_composio_SEARCH",
|
||||
"description": "y", "parameters": {}}},
|
||||
{"type": "function", "function": {"name": "terminal",
|
||||
"description": "z", "parameters": {}}},
|
||||
])
|
||||
names = sorted(t["name"] for t in kwargs["tools"])
|
||||
assert names == ["mcp_composio_SEARCH", "mcp_read_file", "mcp_terminal"]
|
||||
|
||||
def test_non_oauth_path_untouched(self):
|
||||
"""Non-OAuth requests never get the prefix — schemas pass through as-is."""
|
||||
kwargs = self._build([
|
||||
{"type": "function", "function": {"name": "read_file",
|
||||
"description": "x", "parameters": {}}},
|
||||
{"type": "function", "function": {"name": "mcp_composio_SEARCH",
|
||||
"description": "y", "parameters": {}}},
|
||||
], is_oauth=False)
|
||||
names = sorted(t["name"] for t in kwargs["tools"])
|
||||
assert names == ["mcp_composio_SEARCH", "read_file"]
|
||||
|
|
@ -198,22 +198,32 @@ class TestGatewayBridgeCodeParity:
|
|||
"""Verify the gateway/run.py config bridge contains the auxiliary section."""
|
||||
|
||||
def test_gateway_has_auxiliary_bridge(self):
|
||||
"""The gateway config bridge must include auxiliary.* bridging."""
|
||||
"""The gateway config bridge must include auxiliary.* bridging.
|
||||
|
||||
After the plugin-aux-task API refactor (2026-05), gateway env-var
|
||||
names are derived dynamically (``AUXILIARY_<KEY_UPPER>_*``) so the
|
||||
literal strings ``AUXILIARY_VISION_PROVIDER`` etc. no longer appear
|
||||
in source. Assert the dynamic shape and the canonical built-in keys
|
||||
bridged set instead.
|
||||
"""
|
||||
gateway_path = Path(__file__).parent.parent.parent / "gateway" / "run.py"
|
||||
# Pin encoding to UTF-8: source files in this repo are UTF-8, but
|
||||
# Path.read_text() defaults to the system locale — which is cp1252
|
||||
# on most Western Windows installs and crashes as soon as the file
|
||||
# contains any non-ASCII byte (e.g. an em-dash in a comment).
|
||||
content = gateway_path.read_text(encoding="utf-8")
|
||||
# Check for key patterns that indicate the bridge is present
|
||||
assert "AUXILIARY_VISION_PROVIDER" in content
|
||||
assert "AUXILIARY_VISION_MODEL" in content
|
||||
assert "AUXILIARY_VISION_BASE_URL" in content
|
||||
assert "AUXILIARY_VISION_API_KEY" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_PROVIDER" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_MODEL" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_BASE_URL" in content
|
||||
assert "AUXILIARY_WEB_EXTRACT_API_KEY" in content
|
||||
# Dynamic env-var derivation present
|
||||
assert 'f"AUXILIARY_{_upper}_PROVIDER"' in content
|
||||
assert 'f"AUXILIARY_{_upper}_MODEL"' in content
|
||||
assert 'f"AUXILIARY_{_upper}_BASE_URL"' in content
|
||||
assert 'f"AUXILIARY_{_upper}_API_KEY"' in content
|
||||
# Built-in bridged keys present
|
||||
assert "_aux_bridged_keys" in content
|
||||
assert '"vision"' in content
|
||||
assert '"web_extract"' in content
|
||||
assert '"approval"' in content
|
||||
# Plugin-aux-task discovery hooked into bridging
|
||||
assert "get_plugin_auxiliary_tasks" in content
|
||||
|
||||
def test_gateway_no_compression_env_bridge(self):
|
||||
"""Gateway should NOT bridge compression config to env vars (config-only)."""
|
||||
|
|
|
|||
|
|
@ -65,11 +65,11 @@ class TestCompress:
|
|||
assert result == msgs
|
||||
|
||||
def test_truncation_fallback_no_client(self, compressor):
|
||||
# compressor has client=None and abort_on_summary_failure=False (default),
|
||||
# so the LEGACY fallback path inserts a static "summary unavailable"
|
||||
# placeholder and the middle window is dropped.
|
||||
# Simulate "no summarizer available" explicitly. call_llm can otherwise
|
||||
# discover the developer's real auxiliary credentials from auth state.
|
||||
msgs = [{"role": "system", "content": "System prompt"}] + self._make_messages(10)
|
||||
result = compressor.compress(msgs)
|
||||
with patch("agent.context_compressor.call_llm", side_effect=RuntimeError("no provider")):
|
||||
result = compressor.compress(msgs)
|
||||
assert len(result) < len(msgs)
|
||||
# Should keep system message and last N
|
||||
assert result[0]["role"] == "system"
|
||||
|
|
|
|||
243
tests/agent/test_display_todo_progress.py
Normal file
243
tests/agent/test_display_todo_progress.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
"""Tests for get_cute_tool_message todo progress display.
|
||||
|
||||
Verifies the completion status rendering (done/total ✓) on all three
|
||||
todo tool call paths: read, create (merge=False), update (merge=True).
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from agent.display import get_cute_tool_message
|
||||
|
||||
|
||||
def _todo_result(total: int, completed: int) -> str:
|
||||
"""Build a fake todo_tool return value."""
|
||||
return json.dumps({
|
||||
"todos": [],
|
||||
"summary": {
|
||||
"total": total,
|
||||
"pending": total - completed,
|
||||
"in_progress": 0,
|
||||
"completed": completed,
|
||||
"cancelled": 0,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
class TestTodoRead:
|
||||
"""get_cute_tool_message(…, result=…) when todos_arg is None (read path)."""
|
||||
|
||||
def test_read_no_result(self):
|
||||
msg = get_cute_tool_message("todo", {}, 0.5)
|
||||
assert "reading tasks" in msg
|
||||
assert "0.5s" in msg
|
||||
|
||||
def test_read_with_progress(self):
|
||||
msg = get_cute_tool_message("todo", {}, 0.5,
|
||||
result=_todo_result(4, 2))
|
||||
assert "2/4" in msg
|
||||
assert "task(s)" in msg
|
||||
|
||||
def test_read_all_done(self):
|
||||
msg = get_cute_tool_message("todo", {}, 0.5,
|
||||
result=_todo_result(4, 4))
|
||||
assert "4/4" in msg
|
||||
assert "task(s)" in msg
|
||||
|
||||
def test_read_zero_total(self):
|
||||
"""Edge case: empty todo list returns summary with total=0."""
|
||||
msg = get_cute_tool_message("todo", {}, 0.5,
|
||||
result=_todo_result(0, 0))
|
||||
assert "reading tasks" in msg
|
||||
|
||||
def test_read_invalid_result_fallback(self):
|
||||
"""Garbage result should not crash; fall back to reading tasks."""
|
||||
msg = get_cute_tool_message("todo", {}, 0.5, result="not json")
|
||||
assert "reading tasks" in msg
|
||||
|
||||
def test_read_result_missing_summary(self):
|
||||
msg = get_cute_tool_message("todo", {}, 0.5,
|
||||
result='{"todos": []}')
|
||||
assert "reading tasks" in msg
|
||||
|
||||
|
||||
class TestTodoCreate:
|
||||
"""get_cute_tool_message when merge=False (new plan creation)."""
|
||||
|
||||
def test_create_default(self):
|
||||
"""Brand-new plan: all pending, no result — plain count."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [
|
||||
{"id": "a", "content": "x", "status": "pending"},
|
||||
]}, 0.3)
|
||||
assert "1 task(s)" in msg
|
||||
assert "0.3s" in msg
|
||||
assert "/" not in msg # no progress fraction
|
||||
|
||||
def test_create_multiple(self):
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [
|
||||
{"id": "a", "content": "x", "status": "pending"},
|
||||
{"id": "b", "content": "y", "status": "pending"},
|
||||
{"id": "c", "content": "z", "status": "pending"},
|
||||
]}, 0.2)
|
||||
assert "3 task(s)" in msg
|
||||
|
||||
def test_create_with_result_shows_progress_when_done(self):
|
||||
"""Even on create, if result has completed tasks show it."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "a", "content": "x", "status": "completed"}]},
|
||||
0.4,
|
||||
result=_todo_result(1, 1))
|
||||
assert "1/1" in msg
|
||||
assert "task(s)" in msg
|
||||
|
||||
def test_create_with_result_zero_done(self):
|
||||
"""New plan with 0 done — plain count, no progress fraction."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [
|
||||
{"id": "a", "content": "x", "status": "pending"},
|
||||
{"id": "b", "content": "y", "status": "pending"},
|
||||
]},
|
||||
0.3,
|
||||
result=_todo_result(2, 0))
|
||||
assert "2 task(s)" in msg
|
||||
assert "/" not in msg
|
||||
|
||||
|
||||
class TestTodoUpdate:
|
||||
"""get_cute_tool_message when merge=True (incremental update)."""
|
||||
|
||||
def test_update_no_result(self):
|
||||
"""No result available — plain update N task(s)."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "a", "status": "completed"}],
|
||||
"merge": True}, 0.5)
|
||||
assert "update 1 task(s)" in msg
|
||||
|
||||
def test_update_partial_progress(self):
|
||||
"""1/4 tasks completed — show fraction with checkmark."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "a", "status": "completed"}],
|
||||
"merge": True},
|
||||
0.5,
|
||||
result=_todo_result(4, 1))
|
||||
assert "update" in msg
|
||||
assert "1/4" in msg
|
||||
assert "✓" in msg
|
||||
|
||||
def test_update_halfway(self):
|
||||
"""2/4 — midpoint progress."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "b", "status": "in_progress"}],
|
||||
"merge": True},
|
||||
0.7,
|
||||
result=_todo_result(4, 2))
|
||||
assert "2/4" in msg
|
||||
assert "✓" in msg
|
||||
|
||||
def test_update_all_completed(self):
|
||||
"""4/4 — full checkmark."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "d", "status": "completed"}],
|
||||
"merge": True},
|
||||
0.2,
|
||||
result=_todo_result(4, 4))
|
||||
assert "4/4" in msg
|
||||
assert "✓" in msg
|
||||
|
||||
def test_update_zero_done(self):
|
||||
"""No completed tasks yet — plain update N task(s)."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "a", "status": "pending"}],
|
||||
"merge": True},
|
||||
0.3,
|
||||
result=_todo_result(3, 0))
|
||||
assert "update 1 task(s)" in msg
|
||||
assert "✓" not in msg
|
||||
assert "/" not in msg # no progress fraction when done=0
|
||||
|
||||
def test_update_invalid_result_fallback(self):
|
||||
"""Bad JSON result — fall back to plain update N task(s)."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "a", "status": "completed"}],
|
||||
"merge": True},
|
||||
0.6,
|
||||
result="{broken")
|
||||
assert "update 1 task(s)" in msg
|
||||
assert "✓" not in msg
|
||||
|
||||
def test_update_result_missing_summary(self):
|
||||
"""Result no summary key — fall back to plain update."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "a", "status": "completed"}],
|
||||
"merge": True},
|
||||
0.4,
|
||||
result='{"todos": []}')
|
||||
assert "update 1 task(s)" in msg
|
||||
assert "✓" not in msg
|
||||
|
||||
def test_update_total_not_in_summary(self):
|
||||
"""Result summary missing total key."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "a", "status": "completed"}],
|
||||
"merge": True},
|
||||
0.3,
|
||||
result=json.dumps({"summary": {"completed": 2}}))
|
||||
assert "update 1 task(s)" in msg
|
||||
assert "✓" not in msg
|
||||
|
||||
def test_update_multiple_tasks_in_line(self):
|
||||
"""Update line with several tasks in the update request."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [
|
||||
{"id": "a", "status": "completed"},
|
||||
{"id": "b", "status": "in_progress"},
|
||||
], "merge": True},
|
||||
0.5,
|
||||
result=_todo_result(5, 3))
|
||||
assert "update" in msg
|
||||
assert "3/5" in msg
|
||||
assert "✓" in msg
|
||||
|
||||
|
||||
class TestTodoEdgeCases:
|
||||
"""Boundary cases that should not crash."""
|
||||
|
||||
def test_merge_default_value(self):
|
||||
"""merge defaults to False in function signature, should be False when absent."""
|
||||
msg = get_cute_tool_message("todo",
|
||||
{"todos": [{"id": "a", "content": "x", "status": "pending"}]},
|
||||
1.0)
|
||||
assert "1 task(s)" in msg
|
||||
|
||||
def test_duration_formatting(self):
|
||||
"""Duration formatting works correctly."""
|
||||
msg = get_cute_tool_message("todo", {}, 0.123)
|
||||
assert "0.1s" in msg
|
||||
|
||||
msg = get_cute_tool_message("todo", {}, 1.0)
|
||||
assert "1.0s" in msg
|
||||
|
||||
msg = get_cute_tool_message("todo", {}, 123.456)
|
||||
assert "123.5s" in msg
|
||||
|
||||
def test_large_task_count(self):
|
||||
"""Many tasks should not break formatting."""
|
||||
many = [{"id": str(i), "content": "x", "status": "pending"} for i in range(50)]
|
||||
msg = get_cute_tool_message("todo", {"todos": many}, 0.5)
|
||||
assert "50 task(s)" in msg
|
||||
|
||||
def test_read_with_no_args_and_no_result(self):
|
||||
"""Completely empty call."""
|
||||
msg = get_cute_tool_message("todo", {}, 0.0)
|
||||
assert "reading tasks" in msg
|
||||
|
||||
|
||||
class TestTodoSkinIntegration:
|
||||
"""Verify the skin prefix is applied to todo messages too.
|
||||
This uses the same pattern as test_skin_engine test_tool_message_uses_skin_prefix.
|
||||
"""
|
||||
|
||||
def test_default_skin_prefix(self):
|
||||
msg = get_cute_tool_message("todo", {}, 0.5)
|
||||
assert msg.startswith("┊")
|
||||
185
tests/agent/test_display_tool_failure.py
Normal file
185
tests/agent/test_display_tool_failure.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""Tests for _detect_tool_failure + _trim_error + get_cute_tool_message
|
||||
inline failure suffix rendering.
|
||||
|
||||
Covers the user-visible promise: when a tool fails, the CLI shows a short,
|
||||
specific reason in square brackets at the end of the completion line —
|
||||
not a generic "[error]".
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from agent.display import (
|
||||
_detect_tool_failure,
|
||||
_trim_error,
|
||||
_ERROR_SUFFIX_MAX_LEN,
|
||||
get_cute_tool_message,
|
||||
)
|
||||
|
||||
|
||||
class TestTrimError:
|
||||
"""The helper that shrinks an error message for inline display."""
|
||||
|
||||
def test_short_message_unchanged(self):
|
||||
assert _trim_error("nope") == "nope"
|
||||
|
||||
def test_whitespace_stripped(self):
|
||||
assert _trim_error(" bad input ") == "bad input"
|
||||
|
||||
def test_long_message_truncated_to_cap(self):
|
||||
msg = "x" * 200
|
||||
trimmed = _trim_error(msg)
|
||||
assert len(trimmed) <= _ERROR_SUFFIX_MAX_LEN
|
||||
assert trimmed.endswith("...")
|
||||
|
||||
def test_file_not_found_path_collapsed_to_filename(self):
|
||||
long_path = "File not found: /home/teknium/.hermes/hermes-agent/very/deep/path/foo.py"
|
||||
assert _trim_error(long_path) == "File not found: foo.py"
|
||||
|
||||
def test_file_not_found_already_short_unchanged(self):
|
||||
assert _trim_error("File not found: foo.py") == "File not found: foo.py"
|
||||
|
||||
def test_file_not_found_relative_path_unchanged(self):
|
||||
# Without a slash there's no path to trim.
|
||||
assert _trim_error("File not found: foo.py") == "File not found: foo.py"
|
||||
|
||||
|
||||
class TestDetectToolFailureTerminal:
|
||||
"""terminal: non-zero exit_code is the canonical failure signal."""
|
||||
|
||||
def test_success_returns_no_suffix(self):
|
||||
result = json.dumps({"output": "ok\n", "exit_code": 0})
|
||||
assert _detect_tool_failure("terminal", result) == (False, "")
|
||||
|
||||
def test_nonzero_exit_with_no_error_shows_exit_code(self):
|
||||
result = json.dumps({"output": "", "exit_code": 1})
|
||||
is_failure, suffix = _detect_tool_failure("terminal", result)
|
||||
assert is_failure is True
|
||||
assert suffix == " [exit 1]"
|
||||
|
||||
def test_nonzero_exit_with_error_shows_message(self):
|
||||
result = json.dumps({
|
||||
"output": "",
|
||||
"exit_code": 127,
|
||||
"error": "ls: cannot access 'foo': No such file or directory",
|
||||
})
|
||||
is_failure, suffix = _detect_tool_failure("terminal", result)
|
||||
assert is_failure is True
|
||||
assert "cannot access" in suffix
|
||||
# Trimmed to the cap, in brackets
|
||||
assert suffix.startswith(" [")
|
||||
assert suffix.endswith("]")
|
||||
|
||||
def test_malformed_json_returns_no_suffix(self):
|
||||
# Terminal is special: only exit_code matters. Malformed JSON should
|
||||
# not crash and should not be flagged as failure.
|
||||
assert _detect_tool_failure("terminal", "not json") == (False, "")
|
||||
|
||||
def test_none_result_returns_no_suffix(self):
|
||||
assert _detect_tool_failure("terminal", None) == (False, "")
|
||||
|
||||
|
||||
class TestDetectToolFailureMemory:
|
||||
"""memory: 'full' is distinct from real errors."""
|
||||
|
||||
def test_memory_full_returns_full_suffix(self):
|
||||
result = json.dumps({"success": False, "error": "would exceed the limit"})
|
||||
assert _detect_tool_failure("memory", result) == (True, " [full]")
|
||||
|
||||
def test_memory_other_error_returns_specific_message(self):
|
||||
# An error that's NOT a "full" overflow falls through to the
|
||||
# structured-error path and surfaces the actual message.
|
||||
result = json.dumps({"success": False, "error": "invalid action: zap"})
|
||||
is_failure, suffix = _detect_tool_failure("memory", result)
|
||||
assert is_failure is True
|
||||
assert "invalid action" in suffix
|
||||
|
||||
|
||||
class TestDetectToolFailureStructured:
|
||||
"""Generic path: any tool that returns {"error": ...} JSON."""
|
||||
|
||||
def test_read_file_error_surfaced(self):
|
||||
result = json.dumps({
|
||||
"path": "/nope/missing.py",
|
||||
"success": False,
|
||||
"error": "File not found: /nope/missing.py",
|
||||
})
|
||||
is_failure, suffix = _detect_tool_failure("read_file", result)
|
||||
assert is_failure is True
|
||||
# _trim_error reduces the path to the basename.
|
||||
assert suffix == " [File not found: missing.py]"
|
||||
|
||||
def test_error_without_success_key_still_flagged(self):
|
||||
# Some tools return {"error": "..."} with no explicit success flag.
|
||||
result = json.dumps({"error": "remote unavailable"})
|
||||
is_failure, suffix = _detect_tool_failure("web_search", result)
|
||||
assert is_failure is True
|
||||
assert suffix == " [remote unavailable]"
|
||||
|
||||
def test_message_field_only_with_success_false_flagged(self):
|
||||
# When success is False and only 'message' is set, surface it.
|
||||
result = json.dumps({"success": False, "message": "rate limited"})
|
||||
is_failure, suffix = _detect_tool_failure("web_search", result)
|
||||
assert is_failure is True
|
||||
assert "rate limited" in suffix
|
||||
|
||||
def test_successful_result_not_flagged(self):
|
||||
result = json.dumps({"success": True, "data": "hello"})
|
||||
assert _detect_tool_failure("web_search", result) == (False, "")
|
||||
|
||||
def test_dict_without_error_or_success_uses_generic_heuristic(self):
|
||||
# Plain successful dict — should pass through the generic
|
||||
# heuristic which only fires on the string "Error" / '"error"' / etc.
|
||||
result = json.dumps({"data": "hello"})
|
||||
is_failure, _ = _detect_tool_failure("web_search", result)
|
||||
assert is_failure is False
|
||||
|
||||
|
||||
class TestGetCuteToolMessageFailureSuffix:
|
||||
"""End-to-end: failure suffix is appended by get_cute_tool_message."""
|
||||
|
||||
def test_read_file_failure_suffix_appended(self):
|
||||
fail = json.dumps({
|
||||
"path": "/etc/missing",
|
||||
"success": False,
|
||||
"error": "File not found: /etc/missing",
|
||||
})
|
||||
line = get_cute_tool_message("read_file", {"path": "/etc/missing"}, 0.1, result=fail)
|
||||
assert "[File not found: missing]" in line
|
||||
|
||||
def test_terminal_exit_only_suffix(self):
|
||||
fail = json.dumps({"output": "", "exit_code": 2})
|
||||
line = get_cute_tool_message("terminal", {"command": "false"}, 0.1, result=fail)
|
||||
assert "[exit 2]" in line
|
||||
|
||||
def test_terminal_with_stderr_uses_message(self):
|
||||
fail = json.dumps({
|
||||
"output": "",
|
||||
"exit_code": 127,
|
||||
"error": "command not found: notathing",
|
||||
})
|
||||
line = get_cute_tool_message("terminal", {"command": "notathing"}, 0.1, result=fail)
|
||||
assert "command not found" in line
|
||||
# No '[exit 127]' tag when we have a specific message
|
||||
assert "exit 127" not in line
|
||||
|
||||
def test_memory_full_suffix(self):
|
||||
fail = json.dumps({"success": False, "error": "would exceed the limit"})
|
||||
line = get_cute_tool_message(
|
||||
"memory",
|
||||
{"action": "add", "target": "memory", "content": "x"},
|
||||
0.05,
|
||||
result=fail,
|
||||
)
|
||||
assert "[full]" in line
|
||||
|
||||
def test_success_has_no_suffix(self):
|
||||
ok = json.dumps({"success": True, "data": "hi"})
|
||||
line = get_cute_tool_message("web_search", {"query": "hi"}, 0.2, result=ok)
|
||||
assert "[" not in line.split("0.2s", 1)[1]
|
||||
|
||||
def test_no_result_has_no_suffix(self):
|
||||
# No result passed at all — display function should not invent a
|
||||
# failure suffix.
|
||||
line = get_cute_tool_message("terminal", {"command": "ls"}, 0.2)
|
||||
assert "[" not in line.split("0.2s", 1)[1]
|
||||
|
|
@ -56,6 +56,7 @@ class TestFailoverReason:
|
|||
"overloaded", "server_error", "timeout",
|
||||
"context_overflow", "payload_too_large", "image_too_large",
|
||||
"model_not_found", "format_error",
|
||||
"multimodal_tool_content_unsupported",
|
||||
"provider_policy_blocked",
|
||||
"thinking_signature", "long_context_tier",
|
||||
"oauth_long_context_beta_forbidden",
|
||||
|
|
@ -292,6 +293,64 @@ class TestClassifyApiError:
|
|||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.overloaded
|
||||
|
||||
# ── 5xx that are actually request-validation errors ──
|
||||
# Some OpenAI-compatible gateways (e.g. codex.nekos.me) return
|
||||
# request-validation failures with a 5xx status. These are
|
||||
# deterministic, so they must NOT be retried — otherwise the retry
|
||||
# loop hammers the identical bad request into a flood.
|
||||
|
||||
def test_502_with_unknown_parameter_is_non_retryable(self):
|
||||
e = MockAPIError(
|
||||
"Unknown parameter: 'input[617]._empty_recovery_synthetic'",
|
||||
status_code=502,
|
||||
body={
|
||||
"error": {
|
||||
"type": "invalid_request_error",
|
||||
"message": (
|
||||
"[ObjectParam] [input[617]._empty_recovery_synthetic] "
|
||||
"[unknown_parameter] Unknown parameter: "
|
||||
"'input[617]._empty_recovery_synthetic'."
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.retryable is False
|
||||
assert result.should_fallback is True
|
||||
|
||||
def test_502_with_unsupported_parameter_is_non_retryable(self):
|
||||
e = MockAPIError(
|
||||
"Unsupported parameter: logprobs",
|
||||
status_code=502,
|
||||
body={
|
||||
"error": {
|
||||
"type": "invalid_request_error",
|
||||
"message": "Unsupported parameter: logprobs",
|
||||
}
|
||||
},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.retryable is False
|
||||
|
||||
def test_500_with_invalid_request_error_type_is_non_retryable(self):
|
||||
e = MockAPIError(
|
||||
"bad request",
|
||||
status_code=500,
|
||||
body={"error": {"type": "invalid_request_error", "message": "bad request"}},
|
||||
)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.format_error
|
||||
assert result.retryable is False
|
||||
|
||||
def test_502_plain_bad_gateway_still_retryable(self):
|
||||
"""A genuine 502 with no request-validation signal stays retryable."""
|
||||
e = MockAPIError("Bad Gateway", status_code=502)
|
||||
result = classify_api_error(e)
|
||||
assert result.reason == FailoverReason.server_error
|
||||
assert result.retryable is True
|
||||
|
||||
# ── Model not found ──
|
||||
|
||||
def test_404_model_not_found(self):
|
||||
|
|
@ -1256,3 +1315,66 @@ class TestRateLimitErrorWithoutStatusCode:
|
|||
e.status_code = None
|
||||
result = classify_api_error(e, provider="copilot", model="gpt-4o")
|
||||
assert result.reason != FailoverReason.rate_limit
|
||||
|
||||
|
||||
|
||||
# ── Test: multimodal_tool_content_unsupported pattern ───────────────────
|
||||
|
||||
class TestMultimodalToolContentUnsupported:
|
||||
"""Issue #27344 — providers that reject list-type tool message content
|
||||
should be classified as ``multimodal_tool_content_unsupported`` so the
|
||||
retry loop can downgrade screenshots to text and try again.
|
||||
"""
|
||||
|
||||
def test_xiaomi_mimo_text_is_not_set_pattern(self):
|
||||
"""The actual Xiaomi MiMo 400 wording from the bug report."""
|
||||
e = MockAPIError(
|
||||
"Error code: 400 - {'error': {'code': '400', 'message': 'Param Incorrect', 'param': 'text is not set', 'type': ''}}",
|
||||
status_code=400,
|
||||
)
|
||||
result = classify_api_error(e, provider="xiaomi", model="mimo-v2.5")
|
||||
assert result.reason == FailoverReason.multimodal_tool_content_unsupported
|
||||
assert result.retryable is True
|
||||
|
||||
def test_generic_tool_message_must_be_string(self):
|
||||
e = MockAPIError(
|
||||
"tool message content must be a string",
|
||||
status_code=400,
|
||||
)
|
||||
result = classify_api_error(e, provider="custom", model="some-model")
|
||||
assert result.reason == FailoverReason.multimodal_tool_content_unsupported
|
||||
|
||||
def test_expected_string_got_list(self):
|
||||
e = MockAPIError(
|
||||
"Schema validation failed: expected string, got list",
|
||||
status_code=400,
|
||||
)
|
||||
result = classify_api_error(e, provider="custom", model="some-model")
|
||||
assert result.reason == FailoverReason.multimodal_tool_content_unsupported
|
||||
|
||||
def test_multimodal_tool_content_takes_priority_over_context_overflow(self):
|
||||
"""Some providers return a 400 whose message contains BOTH
|
||||
'text is not set' and a length-shaped phrase; the tool-content
|
||||
recovery is cheaper than compression so it must win the priority.
|
||||
"""
|
||||
e = MockAPIError(
|
||||
"text is not set; context length exceeded",
|
||||
status_code=400,
|
||||
)
|
||||
result = classify_api_error(e, provider="xiaomi", model="mimo-v2.5")
|
||||
assert result.reason == FailoverReason.multimodal_tool_content_unsupported
|
||||
|
||||
def test_no_status_code_path_also_classifies(self):
|
||||
"""When the error reaches us without a status code (transport
|
||||
layer ate it) the message-only classifier branch must also
|
||||
recognise the pattern.
|
||||
"""
|
||||
e = MockTransportError("tool_call.content must be string")
|
||||
result = classify_api_error(e, provider="alibaba", model="qwen3.5-plus")
|
||||
assert result.reason == FailoverReason.multimodal_tool_content_unsupported
|
||||
|
||||
def test_unrelated_400_is_not_misclassified(self):
|
||||
"""Make sure the patterns don't false-positive on normal 400s."""
|
||||
e = MockAPIError("bad request: missing field 'model'", status_code=400)
|
||||
result = classify_api_error(e, provider="openrouter", model="anthropic/claude-sonnet-4")
|
||||
assert result.reason != FailoverReason.multimodal_tool_content_unsupported
|
||||
|
|
|
|||
275
tests/agent/test_file_safety_credentials.py
Normal file
275
tests/agent/test_file_safety_credentials.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
"""Tests for HERMES_HOME credential-file read blocking in file_safety.
|
||||
|
||||
Regression for https://github.com/NousResearch/hermes-agent/issues/17656 —
|
||||
``read_file`` was previously only sandboxed against ``HERMES_HOME`` itself,
|
||||
which left ``auth.json`` and ``.anthropic_oauth.json`` (plaintext provider
|
||||
keys + OAuth tokens) readable by the agent. A prompt-injection reaching
|
||||
``read_file`` could exfiltrate active credentials.
|
||||
|
||||
These tests verify that ``get_read_block_error`` returns a denial message
|
||||
for the credential stores while leaving arbitrary ``HERMES_HOME`` files
|
||||
readable, and that the existing ``skills/.hub`` deny still applies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fake_home(tmp_path, monkeypatch):
|
||||
"""Point ``_hermes_home_path()`` at a tmp dir for isolated checks."""
|
||||
import agent.file_safety as fs
|
||||
|
||||
home = tmp_path / "hermes_home"
|
||||
home.mkdir()
|
||||
monkeypatch.setattr(fs, "_hermes_home_path", lambda: home)
|
||||
return home
|
||||
|
||||
|
||||
def _create(home: Path, rel: str | Path) -> Path:
|
||||
"""Create the file (with parents) so realpath() resolves it."""
|
||||
p = home / rel
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
p.write_text("dummy", encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
def test_auth_json_blocked(fake_home):
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
auth = _create(fake_home, "auth.json")
|
||||
err = get_read_block_error(str(auth))
|
||||
assert err is not None
|
||||
assert "credential store" in err
|
||||
assert "auth.json" in err
|
||||
|
||||
|
||||
def test_auth_lock_blocked(fake_home):
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
lock = _create(fake_home, "auth.lock")
|
||||
err = get_read_block_error(str(lock))
|
||||
assert err is not None
|
||||
assert "credential store" in err
|
||||
|
||||
|
||||
def test_anthropic_oauth_json_blocked(fake_home):
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
oauth = _create(fake_home, ".anthropic_oauth.json")
|
||||
err = get_read_block_error(str(oauth))
|
||||
assert err is not None
|
||||
assert "credential store" in err
|
||||
|
||||
|
||||
def test_arbitrary_hermes_home_file_not_blocked(fake_home):
|
||||
"""Non-credential files inside HERMES_HOME stay readable."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
safe = _create(fake_home, "session_log.txt")
|
||||
assert get_read_block_error(str(safe)) is None
|
||||
|
||||
|
||||
def test_subdirectory_named_auth_json_not_blocked(fake_home):
|
||||
"""Only the top-level auth.json is the credential store; a file with the
|
||||
same name in a subdirectory (e.g., a skill mock) must remain readable."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
nested = _create(fake_home, Path("skills") / "my-skill" / "auth.json")
|
||||
assert get_read_block_error(str(nested)) is None
|
||||
|
||||
|
||||
def test_skills_hub_block_still_applies(fake_home):
|
||||
"""Regression guard: the original skills/.hub deny must keep working."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
hub_file = _create(fake_home, "skills/.hub/manifest.json")
|
||||
err = get_read_block_error(str(hub_file))
|
||||
assert err is not None
|
||||
assert "internal Hermes cache file" in err
|
||||
|
||||
|
||||
def test_path_traversal_resolves_to_blocked(fake_home, tmp_path):
|
||||
"""A path that traverses through a sibling dir back into HERMES_HOME's
|
||||
auth.json must still be caught — the check resolves through realpath."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
_create(fake_home, "auth.json")
|
||||
sibling = tmp_path / "elsewhere"
|
||||
sibling.mkdir()
|
||||
traversal = sibling / ".." / "hermes_home" / "auth.json"
|
||||
err = get_read_block_error(str(traversal))
|
||||
assert err is not None
|
||||
assert "credential store" in err
|
||||
|
||||
|
||||
def test_symlink_to_auth_json_blocked(fake_home, tmp_path):
|
||||
"""A symlink pointing at HERMES_HOME/auth.json from outside the home
|
||||
must be blocked — readlink-resolution catches the indirection."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
target = _create(fake_home, "auth.json")
|
||||
link = tmp_path / "shim.json"
|
||||
try:
|
||||
os.symlink(target, link)
|
||||
except (OSError, NotImplementedError):
|
||||
pytest.skip("symlinks not supported on this platform/filesystem")
|
||||
err = get_read_block_error(str(link))
|
||||
assert err is not None
|
||||
assert "credential store" in err
|
||||
|
||||
|
||||
def test_read_file_tool_blocks_relative_path_under_terminal_cwd(
|
||||
fake_home, tmp_path, monkeypatch
|
||||
):
|
||||
"""Bypass guard: a relative path like ``"auth.json"`` resolved by
|
||||
``read_file_tool`` against ``TERMINAL_CWD == HERMES_HOME`` must still
|
||||
be blocked, even though ``get_read_block_error``'s own ``resolve()``
|
||||
is anchored at the (different) Python process cwd.
|
||||
"""
|
||||
import json
|
||||
|
||||
import tools.file_tools as ft
|
||||
|
||||
_create(fake_home, "auth.json")
|
||||
# Force the file_tools resolver to anchor relative paths at HERMES_HOME
|
||||
# while the Python process cwd remains tmp_path (a different directory).
|
||||
monkeypatch.setenv("TERMINAL_CWD", str(fake_home))
|
||||
monkeypatch.chdir(tmp_path)
|
||||
monkeypatch.setattr(
|
||||
ft, "_get_live_tracking_cwd", lambda task_id="default": None
|
||||
)
|
||||
|
||||
out = json.loads(ft.read_file_tool("auth.json"))
|
||||
assert "error" in out
|
||||
assert "credential store" in out["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Widening: .env, webhook_subscriptions.json, mcp-tokens/
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_dotenv_blocked(fake_home):
|
||||
""".env in HERMES_HOME holds API keys — blocked."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
env = _create(fake_home, ".env")
|
||||
err = get_read_block_error(str(env))
|
||||
assert err is not None
|
||||
assert "credential store" in err
|
||||
|
||||
|
||||
def test_webhook_subscriptions_blocked(fake_home):
|
||||
"""webhook_subscriptions.json holds per-route HMAC secrets — blocked."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
subs = _create(fake_home, "webhook_subscriptions.json")
|
||||
err = get_read_block_error(str(subs))
|
||||
assert err is not None
|
||||
assert "credential store" in err
|
||||
|
||||
|
||||
def test_mcp_tokens_file_blocked(fake_home):
|
||||
"""Files under mcp-tokens/ hold OAuth tokens — blocked."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
tok = _create(fake_home, Path("mcp-tokens") / "github.json")
|
||||
err = get_read_block_error(str(tok))
|
||||
assert err is not None
|
||||
assert "MCP token" in err
|
||||
|
||||
|
||||
def test_mcp_tokens_nested_blocked(fake_home):
|
||||
"""Nested files inside mcp-tokens/ are also blocked."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
tok = _create(fake_home, Path("mcp-tokens") / "providers" / "azure.json")
|
||||
err = get_read_block_error(str(tok))
|
||||
assert err is not None
|
||||
assert "MCP token" in err
|
||||
|
||||
|
||||
def test_mcp_tokens_dir_itself_blocked(fake_home):
|
||||
"""The mcp-tokens directory itself is blocked (listing is exfiltrating)."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
tokens_dir = fake_home / "mcp-tokens"
|
||||
tokens_dir.mkdir(parents=True, exist_ok=True)
|
||||
err = get_read_block_error(str(tokens_dir))
|
||||
assert err is not None
|
||||
assert "MCP token" in err
|
||||
|
||||
|
||||
def test_identically_named_files_outside_hermes_home_not_blocked(
|
||||
fake_home, tmp_path
|
||||
):
|
||||
"""A project's ``.env``, ``auth.json``, or ``mcp-tokens/`` outside
|
||||
HERMES_HOME must remain readable — the gate is per-location, not
|
||||
per-filename."""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
project = tmp_path / "myproject"
|
||||
project.mkdir()
|
||||
for rel in (".env", "auth.json"):
|
||||
p = project / rel
|
||||
p.write_text("not secret here", encoding="utf-8")
|
||||
assert get_read_block_error(str(p)) is None, (
|
||||
f"{rel} outside HERMES_HOME should NOT be blocked"
|
||||
)
|
||||
|
||||
tokens = project / "mcp-tokens"
|
||||
tokens.mkdir()
|
||||
tok_file = tokens / "token.json"
|
||||
tok_file.write_text("not really a token", encoding="utf-8")
|
||||
assert get_read_block_error(str(tok_file)) is None
|
||||
|
||||
|
||||
def test_config_yaml_not_blocked(fake_home):
|
||||
"""config.yaml is NOT a credential file — agent should still be
|
||||
able to read it for debugging. (Writes are denied separately by
|
||||
is_write_denied; reads stay allowed.)"""
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
cfg = _create(fake_home, "config.yaml")
|
||||
assert get_read_block_error(str(cfg)) is None
|
||||
|
||||
|
||||
def test_profile_mode_blocks_root_credentials(tmp_path, monkeypatch):
|
||||
"""Under a profile, HERMES_HOME = <root>/profiles/<name>, but
|
||||
<root>/auth.json must ALSO be blocked — credentials at root are
|
||||
inherited by every profile."""
|
||||
import agent.file_safety as fs
|
||||
|
||||
root = tmp_path / "hermes"
|
||||
profile = root / "profiles" / "coder"
|
||||
profile.mkdir(parents=True)
|
||||
monkeypatch.setattr(fs, "_hermes_home_path", lambda: profile)
|
||||
monkeypatch.setattr(fs, "_hermes_root_path", lambda: root)
|
||||
|
||||
from agent.file_safety import get_read_block_error
|
||||
|
||||
# Profile-local credential store: blocked
|
||||
profile_auth = profile / "auth.json"
|
||||
profile_auth.write_text("x")
|
||||
assert "credential store" in (get_read_block_error(str(profile_auth)) or "")
|
||||
|
||||
# Root-level credential store: ALSO blocked (this is the widening)
|
||||
root_auth = root / "auth.json"
|
||||
root_auth.write_text("x")
|
||||
assert "credential store" in (get_read_block_error(str(root_auth)) or "")
|
||||
|
||||
# Root-level .env: blocked too
|
||||
root_env = root / ".env"
|
||||
root_env.write_text("x")
|
||||
assert "credential store" in (get_read_block_error(str(root_env)) or "")
|
||||
|
||||
# Root-level mcp-tokens: blocked
|
||||
root_tok = root / "mcp-tokens" / "gh.json"
|
||||
root_tok.parent.mkdir(parents=True, exist_ok=True)
|
||||
root_tok.write_text("x")
|
||||
assert "MCP token" in (get_read_block_error(str(root_tok)) or "")
|
||||
219
tests/agent/test_file_safety_cross_profile.py
Normal file
219
tests/agent/test_file_safety_cross_profile.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""Tests for the cross-Hermes-profile write guard in agent/file_safety.
|
||||
|
||||
The guard fires when a tool tries to write into another Hermes profile's
|
||||
skills/plugins/cron/memories directory. It's a soft guard — defense in
|
||||
depth, NOT a security boundary — but it prevents the agent from silently
|
||||
corrupting a profile that belongs to a different session.
|
||||
|
||||
Reference: May 2026 incident — a hermes-security profile session
|
||||
accidentally edited skills under both ~/.hermes/profiles/hermes-security/skills/
|
||||
AND ~/.hermes/skills/ (the default profile's skills), realizing only
|
||||
afterwards that the second path belonged to a different profile.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — set up a fake Hermes root with two profiles, monkeypatch the
|
||||
# resolver helpers so the classifier sees the test layout.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_hermes(tmp_path, monkeypatch):
|
||||
"""Build a fake Hermes layout:
|
||||
|
||||
<tmp>/
|
||||
skills/foo/SKILL.md # default profile
|
||||
plugins/foo/__init__.py
|
||||
cron/<state>
|
||||
memories/MEMORY.md
|
||||
profiles/
|
||||
hermes-security/
|
||||
skills/foo/SKILL.md # named profile
|
||||
plugins/...
|
||||
coder/
|
||||
skills/foo/SKILL.md # another named profile
|
||||
"""
|
||||
root = tmp_path / "fake-hermes"
|
||||
(root / "skills" / "foo").mkdir(parents=True)
|
||||
(root / "skills" / "foo" / "SKILL.md").write_text("# default skill\n")
|
||||
(root / "plugins" / "foo").mkdir(parents=True)
|
||||
(root / "memories").mkdir(parents=True)
|
||||
(root / "cron").mkdir(parents=True)
|
||||
|
||||
sec_home = root / "profiles" / "hermes-security"
|
||||
(sec_home / "skills" / "foo").mkdir(parents=True)
|
||||
(sec_home / "skills" / "foo" / "SKILL.md").write_text("# sec skill\n")
|
||||
(sec_home / "plugins").mkdir(parents=True)
|
||||
|
||||
coder_home = root / "profiles" / "coder"
|
||||
(coder_home / "skills" / "foo").mkdir(parents=True)
|
||||
(coder_home / "skills" / "foo" / "SKILL.md").write_text("# coder skill\n")
|
||||
|
||||
# Monkeypatch the resolver functions used by file_safety so each test
|
||||
# can choose which profile is "active".
|
||||
import hermes_constants
|
||||
monkeypatch.setattr(hermes_constants, "get_default_hermes_root", lambda: root)
|
||||
|
||||
# The reloads below ensure get_cross_profile_warning/classify see the patched root.
|
||||
import agent.file_safety as fs
|
||||
monkeypatch.setattr(fs, "_hermes_root_path", lambda: root)
|
||||
|
||||
return {
|
||||
"root": root,
|
||||
"default_home": root,
|
||||
"security_home": sec_home,
|
||||
"coder_home": coder_home,
|
||||
}
|
||||
|
||||
|
||||
def _set_active_home(monkeypatch, hermes_home: Path):
|
||||
"""Point file_safety._hermes_home_path at a specific profile dir."""
|
||||
import agent.file_safety as fs
|
||||
monkeypatch.setattr(fs, "_hermes_home_path", lambda: hermes_home)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_active_profile_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveActiveProfileName:
|
||||
def test_default_when_home_is_root(self, fake_hermes, monkeypatch):
|
||||
_set_active_home(monkeypatch, fake_hermes["default_home"])
|
||||
from agent.file_safety import _resolve_active_profile_name
|
||||
assert _resolve_active_profile_name() == "default"
|
||||
|
||||
def test_named_profile(self, fake_hermes, monkeypatch):
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import _resolve_active_profile_name
|
||||
assert _resolve_active_profile_name() == "hermes-security"
|
||||
|
||||
def test_falls_back_to_default_on_resolution_failure(self, fake_hermes, monkeypatch):
|
||||
"""If HERMES_HOME resolution raises, return 'default' rather than crashing the tool."""
|
||||
import agent.file_safety as fs
|
||||
|
||||
def _boom():
|
||||
raise RuntimeError("simulated")
|
||||
|
||||
monkeypatch.setattr(fs, "_hermes_home_path", _boom)
|
||||
# Should not raise — falls back to "default"
|
||||
assert fs._resolve_active_profile_name() == "default"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# classify_cross_profile_target
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClassifyCrossProfileTarget:
|
||||
def test_same_profile_write_returns_none(self, fake_hermes, monkeypatch):
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import classify_cross_profile_target
|
||||
result = classify_cross_profile_target(
|
||||
str(fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md")
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_security_writing_default_skill(self, fake_hermes, monkeypatch):
|
||||
"""The exact incident from May 2026."""
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import classify_cross_profile_target
|
||||
result = classify_cross_profile_target(
|
||||
str(fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md")
|
||||
)
|
||||
assert result is not None
|
||||
assert result["active_profile"] == "hermes-security"
|
||||
assert result["target_profile"] == "default"
|
||||
assert result["area"] == "skills"
|
||||
|
||||
def test_default_writing_security_skill(self, fake_hermes, monkeypatch):
|
||||
"""Inverse direction — default-profile session reaching into a named profile."""
|
||||
_set_active_home(monkeypatch, fake_hermes["default_home"])
|
||||
from agent.file_safety import classify_cross_profile_target
|
||||
result = classify_cross_profile_target(
|
||||
str(fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md")
|
||||
)
|
||||
assert result is not None
|
||||
assert result["active_profile"] == "default"
|
||||
assert result["target_profile"] == "hermes-security"
|
||||
|
||||
def test_named_to_named_cross_profile(self, fake_hermes, monkeypatch):
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import classify_cross_profile_target
|
||||
result = classify_cross_profile_target(
|
||||
str(fake_hermes["coder_home"] / "skills" / "foo" / "SKILL.md")
|
||||
)
|
||||
assert result is not None
|
||||
assert result["target_profile"] == "coder"
|
||||
|
||||
@pytest.mark.parametrize("area", ["skills", "plugins", "cron", "memories"])
|
||||
def test_all_profile_scoped_areas_classified(self, fake_hermes, monkeypatch, area):
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import classify_cross_profile_target
|
||||
target = fake_hermes["default_home"] / area / "foo.txt"
|
||||
result = classify_cross_profile_target(str(target))
|
||||
assert result is not None
|
||||
assert result["area"] == area
|
||||
|
||||
def test_non_hermes_path_returns_none(self, fake_hermes, monkeypatch, tmp_path):
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import classify_cross_profile_target
|
||||
# Path outside any Hermes root
|
||||
assert classify_cross_profile_target(str(tmp_path / "random.txt")) is None
|
||||
|
||||
def test_hermes_config_not_classified_as_cross_profile(self, fake_hermes, monkeypatch):
|
||||
"""Files under <root>/config.yaml or <root>/.env are NOT profile-scoped
|
||||
(already covered by build_write_denied_paths). Don't double-warn."""
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import classify_cross_profile_target
|
||||
# config.yaml at root level is not in PROFILE_SCOPED_AREAS
|
||||
result = classify_cross_profile_target(
|
||||
str(fake_hermes["default_home"] / "config.yaml")
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_cross_profile_warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCrossProfileWarning:
|
||||
def test_in_profile_returns_none(self, fake_hermes, monkeypatch):
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import get_cross_profile_warning
|
||||
assert get_cross_profile_warning(
|
||||
str(fake_hermes["security_home"] / "skills" / "foo" / "SKILL.md")
|
||||
) is None
|
||||
|
||||
def test_cross_profile_warning_names_both_profiles(self, fake_hermes, monkeypatch):
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import get_cross_profile_warning
|
||||
warn = get_cross_profile_warning(
|
||||
str(fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md")
|
||||
)
|
||||
assert warn is not None
|
||||
# Must name BOTH profiles so the model knows which is which.
|
||||
assert "default" in warn
|
||||
assert "hermes-security" in warn
|
||||
# Must name the bypass kwarg.
|
||||
assert "cross_profile=True" in warn
|
||||
# Must reference the area.
|
||||
assert "skills" in warn
|
||||
|
||||
def test_warning_is_defense_in_depth_not_boundary(self, fake_hermes, monkeypatch):
|
||||
_set_active_home(monkeypatch, fake_hermes["security_home"])
|
||||
from agent.file_safety import get_cross_profile_warning
|
||||
warn = get_cross_profile_warning(
|
||||
str(fake_hermes["default_home"] / "skills" / "foo" / "SKILL.md")
|
||||
)
|
||||
# Must self-document as defense-in-depth so future reviewers
|
||||
# don't promote it to a hard block.
|
||||
assert "not a security boundary" in warn.lower()
|
||||
22
tests/agent/test_last_total_tokens.py
Normal file
22
tests/agent/test_last_total_tokens.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""Test that last_total_tokens is correctly set by ContextCompressor."""
|
||||
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
|
||||
def test_update_from_response_sets_total_tokens():
|
||||
"""ABC contract: last_total_tokens must be set from API response."""
|
||||
c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000)
|
||||
|
||||
c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130})
|
||||
assert c.last_total_tokens == 130
|
||||
|
||||
c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30})
|
||||
assert c.last_total_tokens == 130
|
||||
|
||||
|
||||
def test_session_reset_clears_total_tokens():
|
||||
"""on_session_reset must zero total_tokens."""
|
||||
c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000)
|
||||
c.update_from_response({"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130})
|
||||
c.on_session_reset()
|
||||
assert c.last_total_tokens == 0
|
||||
|
|
@ -1060,3 +1060,191 @@ class TestHonchoCadenceTracking:
|
|||
p.on_turn_start(2, "second message")
|
||||
should_skip = p._injection_frequency == "first-turn" and p._turn_count > 1
|
||||
assert should_skip, "Second turn (turn 2) SHOULD be skipped"
|
||||
|
||||
|
||||
class TestMemoryToolToolsetGate:
|
||||
"""Issue #5544: memory provider tools must respect platform_toolsets.
|
||||
|
||||
Before the fix, MemoryManager.get_all_tool_schemas() output was appended
|
||||
to AIAgent.tools unconditionally in agent_init.py — bypassing the
|
||||
enabled_toolsets filter. Result: `platform_toolsets: telegram: []`
|
||||
still leaked fact_store and other memory tools into the tool surface,
|
||||
causing 10x latency on local models (Qwen3-30B: 1.7s → 42s) and
|
||||
tool-call loops on small models.
|
||||
|
||||
These tests mirror the gate logic in agent/agent_init.py around the
|
||||
memory provider tool injection block. The gate condition is:
|
||||
|
||||
enabled_toolsets is None → no filter, inject (backward compat)
|
||||
"memory" in enabled_toolsets → user opted in, inject
|
||||
otherwise (incl. []) → skip injection
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _run_memory_injection(enabled_toolsets, memory_manager):
|
||||
"""Simulate the gated memory-tool injection block from agent_init.py."""
|
||||
tools = []
|
||||
valid_tool_names = set()
|
||||
|
||||
if memory_manager and tools is not None and (
|
||||
enabled_toolsets is None or "memory" in enabled_toolsets
|
||||
):
|
||||
_existing = {
|
||||
t.get("function", {}).get("name")
|
||||
for t in tools
|
||||
if isinstance(t, dict)
|
||||
}
|
||||
for _schema in memory_manager.get_all_tool_schemas():
|
||||
_tname = _schema.get("name", "")
|
||||
if _tname and _tname in _existing:
|
||||
continue
|
||||
tools.append({"type": "function", "function": _schema})
|
||||
if _tname:
|
||||
valid_tool_names.add(_tname)
|
||||
_existing.add(_tname)
|
||||
|
||||
return tools, valid_tool_names
|
||||
|
||||
def _mgr_with_tools(self, *tool_names):
|
||||
"""Build a MemoryManager whose providers expose the named tool schemas."""
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider(
|
||||
"ext",
|
||||
tools=[{"name": n, "description": n, "parameters": {}} for n in tool_names],
|
||||
)
|
||||
mgr.add_provider(p)
|
||||
return mgr
|
||||
|
||||
def test_none_toolsets_injects(self):
|
||||
"""enabled_toolsets=None (no filter) injects memory tools — backward compat."""
|
||||
mgr = self._mgr_with_tools("fact_store")
|
||||
tools, names = self._run_memory_injection(None, mgr)
|
||||
assert "fact_store" in names
|
||||
assert any(t["function"]["name"] == "fact_store" for t in tools)
|
||||
|
||||
def test_memory_in_toolsets_injects(self):
|
||||
"""enabled_toolsets including 'memory' injects memory tools."""
|
||||
mgr = self._mgr_with_tools("fact_store")
|
||||
tools, names = self._run_memory_injection(["terminal", "memory", "web"], mgr)
|
||||
assert "fact_store" in names
|
||||
|
||||
def test_empty_toolsets_blocks_injection(self):
|
||||
"""`platform_toolsets: telegram: []` must suppress memory tools. (#5544)"""
|
||||
mgr = self._mgr_with_tools("fact_store")
|
||||
tools, names = self._run_memory_injection([], mgr)
|
||||
assert tools == []
|
||||
assert names == set()
|
||||
|
||||
def test_toolsets_without_memory_blocks_injection(self):
|
||||
"""Toolset list that doesn't name 'memory' must suppress injection."""
|
||||
mgr = self._mgr_with_tools("fact_store")
|
||||
tools, names = self._run_memory_injection(["terminal", "web"], mgr)
|
||||
assert tools == []
|
||||
assert names == set()
|
||||
|
||||
def test_no_memory_manager_no_injection(self):
|
||||
"""Gate is moot without a memory manager."""
|
||||
tools, names = self._run_memory_injection(None, None)
|
||||
assert tools == []
|
||||
|
||||
def test_multiple_schemas_all_blocked_together(self):
|
||||
"""When the gate is closed, no memory tools leak — not even partially."""
|
||||
mgr = self._mgr_with_tools("fact_store", "memory_search", "memory_add")
|
||||
tools, names = self._run_memory_injection(["terminal"], mgr)
|
||||
assert tools == []
|
||||
assert names == set()
|
||||
|
||||
def test_multiple_schemas_all_injected_when_enabled(self):
|
||||
"""When the gate is open, every memory tool schema is injected."""
|
||||
mgr = self._mgr_with_tools("fact_store", "memory_search", "memory_add")
|
||||
tools, names = self._run_memory_injection(None, mgr)
|
||||
assert names == {"fact_store", "memory_search", "memory_add"}
|
||||
|
||||
|
||||
class TestContextEngineToolsetGate:
|
||||
"""Issue #5544 (sibling): context engine tools follow the same gate.
|
||||
|
||||
`agent.context_compressor.get_tool_schemas()` (e.g. lcm_grep, lcm_describe,
|
||||
lcm_expand) was appended to AIAgent.tools unconditionally. Same blind
|
||||
injection class as the memory bug; same local-model penalty. Gate name:
|
||||
"context_engine" (matches the existing plugin-system convention).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _run_context_engine_injection(enabled_toolsets, compressor):
|
||||
"""Simulate the gated context-engine injection block from agent_init.py."""
|
||||
tools = []
|
||||
valid_tool_names = set()
|
||||
engine_tool_names = set()
|
||||
|
||||
if (
|
||||
compressor is not None
|
||||
and tools is not None
|
||||
and (
|
||||
enabled_toolsets is None
|
||||
or "context_engine" in enabled_toolsets
|
||||
)
|
||||
):
|
||||
_existing = {
|
||||
t.get("function", {}).get("name")
|
||||
for t in tools
|
||||
if isinstance(t, dict)
|
||||
}
|
||||
for _schema in compressor.get_tool_schemas():
|
||||
_tname = _schema.get("name", "")
|
||||
if _tname and _tname in _existing:
|
||||
continue
|
||||
tools.append({"type": "function", "function": _schema})
|
||||
if _tname:
|
||||
valid_tool_names.add(_tname)
|
||||
engine_tool_names.add(_tname)
|
||||
_existing.add(_tname)
|
||||
|
||||
return tools, valid_tool_names, engine_tool_names
|
||||
|
||||
class _FakeCompressor:
|
||||
def __init__(self, schemas):
|
||||
self._schemas = schemas
|
||||
|
||||
def get_tool_schemas(self):
|
||||
return list(self._schemas)
|
||||
|
||||
def _compressor_with(self, *tool_names):
|
||||
return self._FakeCompressor(
|
||||
[{"name": n, "description": n, "parameters": {}} for n in tool_names]
|
||||
)
|
||||
|
||||
def test_none_toolsets_injects(self):
|
||||
"""enabled_toolsets=None injects context-engine tools — backward compat."""
|
||||
c = self._compressor_with("lcm_grep", "lcm_describe", "lcm_expand")
|
||||
tools, names, engine_names = self._run_context_engine_injection(None, c)
|
||||
assert engine_names == {"lcm_grep", "lcm_describe", "lcm_expand"}
|
||||
|
||||
def test_context_engine_in_toolsets_injects(self):
|
||||
"""enabled_toolsets including 'context_engine' injects the tools."""
|
||||
c = self._compressor_with("lcm_grep")
|
||||
tools, names, engine_names = self._run_context_engine_injection(
|
||||
["terminal", "context_engine"], c
|
||||
)
|
||||
assert "lcm_grep" in engine_names
|
||||
|
||||
def test_empty_toolsets_blocks_injection(self):
|
||||
"""`platform_toolsets: telegram: []` must suppress context-engine tools."""
|
||||
c = self._compressor_with("lcm_grep")
|
||||
tools, names, engine_names = self._run_context_engine_injection([], c)
|
||||
assert tools == []
|
||||
assert engine_names == set()
|
||||
|
||||
def test_toolsets_without_context_engine_blocks_injection(self):
|
||||
"""A toolset list that doesn't name 'context_engine' suppresses injection."""
|
||||
c = self._compressor_with("lcm_grep", "lcm_describe")
|
||||
tools, names, engine_names = self._run_context_engine_injection(
|
||||
["terminal", "memory"], c
|
||||
)
|
||||
assert tools == []
|
||||
assert engine_names == set()
|
||||
|
||||
def test_no_compressor_no_injection(self):
|
||||
"""Gate is moot without a context_compressor."""
|
||||
tools, names, engine_names = self._run_context_engine_injection(None, None)
|
||||
assert tools == []
|
||||
|
|
|
|||
|
|
@ -164,6 +164,7 @@ class TestDefaultContextLengths:
|
|||
"grok-4-1-fast": 2000000,
|
||||
"grok-4-fast": 2000000,
|
||||
"grok-4": 256000,
|
||||
"grok-build": 256000,
|
||||
"grok-code-fast": 256000,
|
||||
"grok-3": 131072,
|
||||
"grok-2": 131072,
|
||||
|
|
@ -195,6 +196,7 @@ class TestDefaultContextLengths:
|
|||
("grok-4-fast-non-reasoning", 2000000),
|
||||
("grok-4", 256000),
|
||||
("grok-4-0709", 256000),
|
||||
("grok-build-0.1", 256000),
|
||||
("grok-code-fast-1", 256000),
|
||||
("grok-3", 131072),
|
||||
("grok-3-mini", 131072),
|
||||
|
|
@ -210,6 +212,32 @@ class TestDefaultContextLengths:
|
|||
f"{model_id}: expected {expected_ctx}, got {actual}"
|
||||
)
|
||||
|
||||
def test_xai_oauth_grok_build_uses_xai_models_dev_context(self):
|
||||
"""xAI OAuth should share the xAI provider metadata path.
|
||||
|
||||
The xAI /v1/models endpoint does not currently include context fields
|
||||
for grok-build-0.1, so this guards against falling through to the
|
||||
generic "grok" 131k fallback when using OAuth credentials.
|
||||
"""
|
||||
registry = {
|
||||
"xai": {
|
||||
"models": {
|
||||
"grok-build-0.1": {
|
||||
"limit": {"context": 256000, "output": 64000},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
|
||||
patch("agent.model_metadata._query_ollama_api_show", return_value=None), \
|
||||
patch("agent.models_dev.fetch_models_dev", return_value=registry):
|
||||
assert get_model_context_length(
|
||||
"grok-build-0.1",
|
||||
provider="xai-oauth",
|
||||
base_url="https://api.x.ai/v1",
|
||||
api_key="oauth-token",
|
||||
) == 256000
|
||||
|
||||
def test_deepseek_v4_models_1m_context(self):
|
||||
from agent.model_metadata import get_model_context_length
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
|
|
|||
|
|
@ -41,6 +41,16 @@ SAMPLE_REGISTRY = {
|
|||
},
|
||||
},
|
||||
},
|
||||
"xai": {
|
||||
"id": "xai",
|
||||
"name": "xAI",
|
||||
"models": {
|
||||
"grok-build-0.1": {
|
||||
"id": "grok-build-0.1",
|
||||
"limit": {"context": 256000, "output": 64000},
|
||||
},
|
||||
},
|
||||
},
|
||||
"kilo": {
|
||||
"id": "kilo",
|
||||
"name": "Kilo Gateway",
|
||||
|
|
@ -86,6 +96,10 @@ class TestProviderMapping:
|
|||
assert PROVIDER_TO_MODELS_DEV["kilocode"] == "kilo"
|
||||
assert PROVIDER_TO_MODELS_DEV["ai-gateway"] == "vercel"
|
||||
|
||||
def test_xai_oauth_uses_xai_catalog(self):
|
||||
assert PROVIDER_TO_MODELS_DEV["xai"] == "xai"
|
||||
assert PROVIDER_TO_MODELS_DEV["xai-oauth"] == "xai"
|
||||
|
||||
def test_unmapped_provider_not_in_dict(self):
|
||||
assert "nous" not in PROVIDER_TO_MODELS_DEV
|
||||
|
||||
|
|
@ -144,6 +158,12 @@ class TestLookupModelsDevContext:
|
|||
# GitHub Copilot: only 128K for same model
|
||||
assert lookup_models_dev_context("copilot", "claude-opus-4.6") == 128000
|
||||
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_xai_oauth_resolves_xai_context(self, mock_fetch):
|
||||
"""xAI OAuth is an auth path, not a separate model catalog."""
|
||||
mock_fetch.return_value = SAMPLE_REGISTRY
|
||||
assert lookup_models_dev_context("xai-oauth", "grok-build-0.1") == 256000
|
||||
|
||||
@patch("agent.models_dev.fetch_models_dev")
|
||||
def test_zero_context_filtered(self, mock_fetch):
|
||||
mock_fetch.return_value = SAMPLE_REGISTRY
|
||||
|
|
|
|||
|
|
@ -451,6 +451,28 @@ class TestUrlQueryParamRedaction:
|
|||
result = redact_sensitive_text(text)
|
||||
assert "opaqueWsToken123" not in result
|
||||
|
||||
def test_http_access_log_relative_request_target_query(self):
|
||||
text = (
|
||||
'INFO aiohttp.access: 127.0.0.1 "POST '
|
||||
'/bluebubbles-webhook?password=webhookSecret123&event=new-message '
|
||||
'HTTP/1.1" 200 173 "-" "test-client"'
|
||||
)
|
||||
result = redact_sensitive_text(text)
|
||||
assert "webhookSecret123" not in result
|
||||
assert "password=***" in result
|
||||
assert "event=new-message" in result
|
||||
|
||||
def test_http_access_log_absolute_request_target_query(self):
|
||||
text = (
|
||||
'INFO aiohttp.access: 127.0.0.1 "GET '
|
||||
'https://example.com/callback?code=oauthCode123&state=csrf-ok '
|
||||
'HTTP/1.1" 200 173 "-" "test-client"'
|
||||
)
|
||||
result = redact_sensitive_text(text)
|
||||
assert "oauthCode123" not in result
|
||||
assert "code=***" in result
|
||||
assert "state=csrf-ok" in result
|
||||
|
||||
|
||||
class TestUrlUserinfoRedaction:
|
||||
"""URL userinfo (`scheme://user:pass@host`) for non-DB schemes."""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
"""Tests for agent/skill_utils.py."""
|
||||
|
||||
from agent.skill_utils import extract_skill_conditions, iter_skill_index_files
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.skill_utils import (
|
||||
extract_skill_conditions,
|
||||
iter_skill_index_files,
|
||||
skill_matches_platform,
|
||||
)
|
||||
|
||||
|
||||
def test_metadata_as_dict_with_hermes():
|
||||
|
|
@ -94,3 +100,100 @@ def test_iter_skill_index_files_prunes_dependency_dirs(tmp_path):
|
|||
found = list(iter_skill_index_files(tmp_path, "SKILL.md"))
|
||||
|
||||
assert found == [real / "SKILL.md"]
|
||||
|
||||
|
||||
# ── skill_matches_platform on Termux ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestSkillMatchesPlatformTermux:
|
||||
"""Termux is Linux userland on Android. Skills tagged platforms:[linux]
|
||||
must load there regardless of whether Python reports sys.platform as
|
||||
"linux" (pre-3.13) or "android" (3.13+). Reported by user @LikiusInik
|
||||
in May 2026 — only 3 built-in skills appeared on Termux because every
|
||||
github/productivity/mlops skill is tagged platforms:[linux,macos,windows]
|
||||
and sys.platform=="android" did not start with "linux".
|
||||
"""
|
||||
|
||||
def test_no_platforms_field_matches_everywhere(self):
|
||||
# Backward-compat default — skills without a platforms tag load
|
||||
# on any OS, Termux included.
|
||||
with patch("agent.skill_utils.sys.platform", "android"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=True
|
||||
):
|
||||
assert skill_matches_platform({}) is True
|
||||
assert skill_matches_platform({"name": "foo"}) is True
|
||||
|
||||
def test_linux_skill_loads_on_termux_android_platform(self):
|
||||
# Python 3.13+ on Termux reports sys.platform == "android".
|
||||
fm = {"platforms": ["linux"]}
|
||||
with patch("agent.skill_utils.sys.platform", "android"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=True
|
||||
):
|
||||
assert skill_matches_platform(fm) is True
|
||||
|
||||
def test_linux_macos_windows_skill_loads_on_termux(self):
|
||||
# The common "[linux, macos, windows]" tag used by github-*,
|
||||
# productivity, mlops, etc.
|
||||
fm = {"platforms": ["linux", "macos", "windows"]}
|
||||
with patch("agent.skill_utils.sys.platform", "android"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=True
|
||||
):
|
||||
assert skill_matches_platform(fm) is True
|
||||
|
||||
def test_linux_skill_loads_on_termux_linux_platform(self):
|
||||
# Pre-3.13 Termux reports sys.platform == "linux" already — this
|
||||
# works without the Termux escape hatch but must still pass.
|
||||
fm = {"platforms": ["linux"]}
|
||||
with patch("agent.skill_utils.sys.platform", "linux"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=True
|
||||
):
|
||||
assert skill_matches_platform(fm) is True
|
||||
|
||||
def test_macos_only_skill_still_excluded_on_termux(self):
|
||||
# macOS-only skills (apple-notes, imessage, ...) should NOT load
|
||||
# on Termux. The Termux fallback only widens platforms:[linux,...].
|
||||
fm = {"platforms": ["macos"]}
|
||||
with patch("agent.skill_utils.sys.platform", "android"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=True
|
||||
):
|
||||
assert skill_matches_platform(fm) is False
|
||||
|
||||
def test_windows_only_skill_still_excluded_on_termux(self):
|
||||
fm = {"platforms": ["windows"]}
|
||||
with patch("agent.skill_utils.sys.platform", "android"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=True
|
||||
):
|
||||
assert skill_matches_platform(fm) is False
|
||||
|
||||
def test_explicit_termux_or_android_tag_matches(self):
|
||||
# Skills can also opt in explicitly via platforms:[termux] or
|
||||
# platforms:[android] — both should match a Termux session.
|
||||
with patch("agent.skill_utils.sys.platform", "android"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=True
|
||||
):
|
||||
assert skill_matches_platform({"platforms": ["termux"]}) is True
|
||||
assert skill_matches_platform({"platforms": ["android"]}) is True
|
||||
|
||||
def test_non_termux_android_does_not_widen(self):
|
||||
# If we're somehow on a plain Android Python (not Termux), don't
|
||||
# silently load Linux skills — Termux is the supported environment.
|
||||
fm = {"platforms": ["linux"]}
|
||||
with patch("agent.skill_utils.sys.platform", "android"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=False
|
||||
):
|
||||
assert skill_matches_platform(fm) is False
|
||||
|
||||
def test_linux_skill_on_real_linux_unaffected(self):
|
||||
# The non-Termux Linux path must not change.
|
||||
fm = {"platforms": ["linux"]}
|
||||
with patch("agent.skill_utils.sys.platform", "linux"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=False
|
||||
):
|
||||
assert skill_matches_platform(fm) is True
|
||||
|
||||
def test_macos_skill_on_real_macos_unaffected(self):
|
||||
fm = {"platforms": ["macos"]}
|
||||
with patch("agent.skill_utils.sys.platform", "darwin"), patch(
|
||||
"agent.skill_utils.is_termux", return_value=False
|
||||
):
|
||||
assert skill_matches_platform(fm) is True
|
||||
|
|
|
|||
297
tests/agent/test_vision_routing_31179.py
Normal file
297
tests/agent/test_vision_routing_31179.py
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
"""Regression tests for issue #31179.
|
||||
|
||||
Before the fix:
|
||||
- ``auxiliary.vision.provider: openai`` silently failed to resolve because
|
||||
``openai`` is not a first-class provider in PROVIDER_REGISTRY (only
|
||||
``openai-codex`` for OAuth and ``custom`` for OPENAI_BASE_URL).
|
||||
- The vision branch of ``call_llm`` then silently fell back to ``auto``
|
||||
which happily picked the user's main provider (e.g. DeepSeek), sending
|
||||
image content to a text-only endpoint and producing cryptic
|
||||
``unknown variant 'image_url', expected 'text'`` errors.
|
||||
- ``check_vision_requirements`` used the explicit-only path, so
|
||||
``vision_analyze`` disappeared from the tool list while ``browser_vision``
|
||||
stayed (its check_fn only validated the browser).
|
||||
|
||||
The three fixes covered here:
|
||||
1. ``provider: openai`` in auxiliary task config resolves to
|
||||
``custom`` + ``https://api.openai.com/v1``.
|
||||
2. The vision auto-detect chain skips the user's main provider when it
|
||||
reports ``supports_vision=False`` instead of routing image content to
|
||||
a text-only endpoint.
|
||||
3. ``check_vision_requirements`` mirrors the runtime fallback chain so
|
||||
``vision_analyze`` shows up whenever the auto chain can serve vision,
|
||||
and ``browser_vision`` gates on vision availability as well.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test infrastructure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_home(monkeypatch):
|
||||
"""Temp HERMES_HOME with config + clean credential env vars."""
|
||||
test_home = tempfile.mkdtemp(prefix="hermes_test_31179_")
|
||||
hermes_home = os.path.join(test_home, ".hermes")
|
||||
os.makedirs(hermes_home)
|
||||
monkeypatch.setenv("HERMES_HOME", hermes_home)
|
||||
|
||||
# Strip all credential-shaped env vars so each scenario starts hermetic.
|
||||
for k in list(os.environ.keys()):
|
||||
if k.endswith("_API_KEY") or k.endswith("_TOKEN"):
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
|
||||
yield hermes_home
|
||||
shutil.rmtree(test_home, ignore_errors=True)
|
||||
|
||||
|
||||
def _write_config(home: str, text: str) -> None:
|
||||
with open(os.path.join(home, "config.yaml"), "w") as fp:
|
||||
fp.write(text)
|
||||
|
||||
|
||||
def _fresh_modules():
|
||||
"""Drop cached hermes modules so each test reloads against current env."""
|
||||
for mod in list(sys.modules.keys()):
|
||||
if mod.startswith(("agent.auxiliary_client", "agent.image_routing",
|
||||
"tools.vision_tools", "tools.browser_tool",
|
||||
"hermes_cli.config")):
|
||||
del sys.modules[mod]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 1: provider=openai → custom + api.openai.com/v1
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenAiAliasForAuxiliary:
|
||||
"""``auxiliary.<task>.provider: openai`` should produce a working client."""
|
||||
|
||||
def test_provider_openai_routes_to_openai_dot_com(self, isolated_home, monkeypatch):
|
||||
_write_config(isolated_home, """
|
||||
auxiliary:
|
||||
vision:
|
||||
provider: openai
|
||||
model: gpt-4o-mini
|
||||
""")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||
_fresh_modules()
|
||||
|
||||
from agent.auxiliary_client import _resolve_task_provider_model
|
||||
provider, model, base_url, _key, _mode = _resolve_task_provider_model("vision")
|
||||
assert provider == "custom"
|
||||
assert model == "gpt-4o-mini"
|
||||
assert base_url == "https://api.openai.com/v1"
|
||||
|
||||
def test_provider_openai_with_explicit_base_url_preserves_user_endpoint(
|
||||
self, isolated_home, monkeypatch
|
||||
):
|
||||
"""User-supplied base_url wins; alias still normalizes provider name
|
||||
to ``custom`` so resolution doesn't hit the unknown-provider path."""
|
||||
_write_config(isolated_home, """
|
||||
auxiliary:
|
||||
vision:
|
||||
provider: openai
|
||||
model: gpt-4o-mini
|
||||
base_url: https://my-proxy.example.com/v1
|
||||
""")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||
_fresh_modules()
|
||||
|
||||
from agent.auxiliary_client import _resolve_task_provider_model
|
||||
provider, _model, base_url, _key, _mode = _resolve_task_provider_model("vision")
|
||||
assert provider == "custom"
|
||||
assert base_url == "https://my-proxy.example.com/v1"
|
||||
|
||||
def test_provider_openai_resolves_to_working_client(self, isolated_home, monkeypatch):
|
||||
"""End-to-end: the resolved client points at api.openai.com."""
|
||||
_write_config(isolated_home, """
|
||||
auxiliary:
|
||||
vision:
|
||||
provider: openai
|
||||
model: gpt-4o-mini
|
||||
""")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||
_fresh_modules()
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
from urllib.parse import urlparse
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
assert client is not None, "openai alias should produce a usable client"
|
||||
# Exact hostname comparison (not substring) — defends against URLs
|
||||
# like ``api.openai.com.evil.example`` and keeps CodeQL happy.
|
||||
host = urlparse(str(getattr(client, "base_url", ""))).hostname or ""
|
||||
assert host == "api.openai.com", f"expected api.openai.com host, got {host!r}"
|
||||
assert model == "gpt-4o-mini"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 2: auto chain skips text-only main providers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTextOnlyMainSkippedForVision:
|
||||
"""Vision auto-detect must not return a text-only main-provider client."""
|
||||
|
||||
def test_text_only_main_skipped_when_no_aggregator(self, isolated_home, monkeypatch):
|
||||
"""DeepSeek main + no aggregator credentials → no client built.
|
||||
|
||||
Pre-fix this silently returned the deepseek client with model
|
||||
substitution, producing ``unknown variant 'image_url'`` at call time.
|
||||
"""
|
||||
_write_config(isolated_home, """
|
||||
model:
|
||||
provider: deepseek
|
||||
default: deepseek-v4-pro
|
||||
""")
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test")
|
||||
_fresh_modules()
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
provider, client, _model = resolve_vision_provider_client(provider="auto")
|
||||
assert client is None, (
|
||||
f"Vision auto-detect must skip text-only main {provider!r} when "
|
||||
"no vision-capable aggregator is available, not return a client "
|
||||
"that will fail at API time"
|
||||
)
|
||||
|
||||
def test_vision_capable_main_used(self, isolated_home, monkeypatch):
|
||||
"""Vision-capable main provider should be returned by auto chain."""
|
||||
_write_config(isolated_home, """
|
||||
model:
|
||||
provider: anthropic
|
||||
default: claude-sonnet-4-6
|
||||
""")
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test")
|
||||
_fresh_modules()
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
provider, client, _model = resolve_vision_provider_client(provider="auto")
|
||||
assert client is not None
|
||||
assert provider == "anthropic"
|
||||
|
||||
def test_unknown_capability_does_not_block(self, isolated_home, monkeypatch):
|
||||
"""When models.dev has no entry, fall back to permissive (attempt the call).
|
||||
|
||||
This keeps new/custom providers working — only providers we have
|
||||
cataloged as text-only are skipped.
|
||||
"""
|
||||
_fresh_modules()
|
||||
from agent.auxiliary_client import _main_model_supports_vision
|
||||
# Bogus provider/model — capability lookup returns None → permissive.
|
||||
assert _main_model_supports_vision("nonexistent-provider", "nonexistent-model") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fix 3: check_vision_requirements + check_browser_vision_requirements parity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVisionToolGating:
|
||||
"""Tool visibility must match runtime capability."""
|
||||
|
||||
def test_check_vision_succeeds_for_aliased_openai(self, isolated_home, monkeypatch):
|
||||
"""The user's exact reported scenario: provider=openai unhides
|
||||
vision_analyze instead of silently dropping it."""
|
||||
_write_config(isolated_home, """
|
||||
auxiliary:
|
||||
vision:
|
||||
provider: openai
|
||||
model: gpt-4o-mini
|
||||
""")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
|
||||
_fresh_modules()
|
||||
|
||||
from tools.vision_tools import check_vision_requirements
|
||||
assert check_vision_requirements() is True
|
||||
|
||||
def test_check_vision_falls_back_to_auto(self, isolated_home, monkeypatch):
|
||||
"""Bad explicit provider doesn't hide the tool when auto fallback works.
|
||||
|
||||
Mirrors call_llm's runtime fallback chain.
|
||||
"""
|
||||
_write_config(isolated_home, """
|
||||
model:
|
||||
provider: openrouter
|
||||
default: anthropic/claude-sonnet-4
|
||||
auxiliary:
|
||||
vision:
|
||||
provider: not-a-real-provider
|
||||
""")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
|
||||
_fresh_modules()
|
||||
|
||||
from tools.vision_tools import check_vision_requirements
|
||||
assert check_vision_requirements() is True
|
||||
|
||||
def test_check_vision_false_with_text_only_main_and_no_aggregator(
|
||||
self, isolated_home, monkeypatch
|
||||
):
|
||||
_write_config(isolated_home, """
|
||||
model:
|
||||
provider: deepseek
|
||||
default: deepseek-v4-pro
|
||||
""")
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test")
|
||||
_fresh_modules()
|
||||
|
||||
from tools.vision_tools import check_vision_requirements
|
||||
assert check_vision_requirements() is False
|
||||
|
||||
def test_browser_vision_requires_both_browser_and_vision(self, isolated_home, monkeypatch):
|
||||
"""``browser_vision`` must not be advertised when vision is unavailable."""
|
||||
from unittest.mock import patch
|
||||
|
||||
_write_config(isolated_home, """
|
||||
model:
|
||||
provider: deepseek
|
||||
default: deepseek-v4-pro
|
||||
""")
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-test")
|
||||
_fresh_modules()
|
||||
|
||||
import tools.browser_tool
|
||||
# Force the browser side to True so we exercise the vision-gating part.
|
||||
with patch.object(tools.browser_tool, "check_browser_requirements", return_value=True):
|
||||
assert tools.browser_tool.check_browser_vision_requirements() is False
|
||||
|
||||
def test_browser_vision_false_when_browser_missing(self, isolated_home, monkeypatch):
|
||||
from unittest.mock import patch
|
||||
|
||||
_write_config(isolated_home, """
|
||||
model:
|
||||
provider: openrouter
|
||||
default: anthropic/claude-sonnet-4
|
||||
""")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
|
||||
_fresh_modules()
|
||||
|
||||
import tools.browser_tool
|
||||
with patch.object(tools.browser_tool, "check_browser_requirements", return_value=False):
|
||||
# Vision available but browser missing → still False.
|
||||
assert tools.browser_tool.check_browser_vision_requirements() is False
|
||||
|
||||
def test_browser_vision_true_when_both_available(self, isolated_home, monkeypatch):
|
||||
from unittest.mock import patch
|
||||
|
||||
_write_config(isolated_home, """
|
||||
model:
|
||||
provider: openrouter
|
||||
default: anthropic/claude-sonnet-4
|
||||
""")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
|
||||
_fresh_modules()
|
||||
|
||||
import tools.browser_tool
|
||||
with patch.object(tools.browser_tool, "check_browser_requirements", return_value=True):
|
||||
assert tools.browser_tool.check_browser_vision_requirements() is True
|
||||
|
|
@ -66,6 +66,38 @@ class TestChatCompletionsBasic:
|
|||
# Original list untouched (deepcopy-on-demand)
|
||||
assert msgs[2]["tool_name"] == "execute_code"
|
||||
|
||||
def test_convert_messages_strips_internal_scaffolding_markers(self, transport):
|
||||
"""Hermes-internal ``_``-prefixed markers must never reach the wire.
|
||||
|
||||
The empty-response recovery path appends synthetic messages tagged
|
||||
with ``_empty_recovery_synthetic``; permissive providers ignore the
|
||||
unknown key, but strict gateways (opencode-go, codex.nekos.me)
|
||||
reject the request, poisoning every later turn in the session.
|
||||
"""
|
||||
msgs = [
|
||||
{"role": "user", "content": "run the task"},
|
||||
{"role": "assistant", "content": "(empty)", "_empty_recovery_synthetic": True},
|
||||
{"role": "user", "content": "continue", "_empty_recovery_synthetic": True},
|
||||
{"role": "assistant", "content": "done", "_thinking_prefill": True,
|
||||
"_empty_terminal_sentinel": True},
|
||||
]
|
||||
result = transport.convert_messages(msgs)
|
||||
for m in result:
|
||||
assert not any(k.startswith("_") for k in m), m
|
||||
# Visible content preserved
|
||||
assert result[1]["content"] == "(empty)"
|
||||
assert result[2]["content"] == "continue"
|
||||
# Original list untouched (deepcopy-on-demand)
|
||||
assert msgs[1]["_empty_recovery_synthetic"] is True
|
||||
|
||||
def test_convert_messages_clean_list_is_identity(self, transport):
|
||||
"""A list with no internal/codex keys is returned as-is (no copy)."""
|
||||
msgs = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "hello"},
|
||||
]
|
||||
assert transport.convert_messages(msgs) is msgs
|
||||
|
||||
|
||||
class TestChatCompletionsBuildKwargs:
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from agent.transports.codex_app_server_session import (
|
|||
TurnResult,
|
||||
_ServerRequestRouting,
|
||||
_approval_choice_to_codex_decision,
|
||||
_coerce_turn_input_text,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -128,6 +129,15 @@ class TestApprovalChoiceMapping:
|
|||
assert _approval_choice_to_codex_decision(choice) == expected
|
||||
|
||||
|
||||
class TestTurnInputCoercion:
|
||||
def test_list_content_keeps_text_and_marks_images(self):
|
||||
text = _coerce_turn_input_text([
|
||||
{"type": "text", "text": "caption"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
])
|
||||
assert text == "caption\n\n[image attached]"
|
||||
|
||||
|
||||
# ---- lifecycle ----
|
||||
|
||||
class TestLifecycle:
|
||||
|
|
@ -188,6 +198,35 @@ class TestRunTurn:
|
|||
# turn_id propagated for downstream session-DB linkage
|
||||
assert r.turn_id == "turn-fake-001"
|
||||
|
||||
def test_rich_content_turn_is_collapsed_to_text_payload(self):
|
||||
client = FakeClient()
|
||||
client.queue_notification(
|
||||
"turn/completed",
|
||||
threadId="t",
|
||||
turn={"id": "tu1", "status": "completed", "error": None},
|
||||
)
|
||||
s = make_session(client)
|
||||
r = s.run_turn(
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": "look at this\n\n[Image attached at: /tmp/a.png]",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/png;base64,abc"},
|
||||
},
|
||||
],
|
||||
turn_timeout=2.0,
|
||||
)
|
||||
assert r.error is None
|
||||
method, params = next(req for req in client.requests if req[0] == "turn/start")
|
||||
assert method == "turn/start"
|
||||
text = params["input"][0]["text"]
|
||||
assert isinstance(text, str)
|
||||
assert "[Image attached at: /tmp/a.png]" in text
|
||||
assert "[image attached]" in text
|
||||
|
||||
def test_tool_iteration_counter_ticks(self):
|
||||
client = FakeClient()
|
||||
# Two completed exec items + one final agent message
|
||||
|
|
|
|||
|
|
@ -168,6 +168,25 @@ class TestBranchCommandCLI:
|
|||
|
||||
assert cli_instance._resumed is True
|
||||
|
||||
def test_branch_rotates_hermes_session_id_env_and_context(self, cli_instance, session_db):
|
||||
"""Branching must update process-local session-id readers too."""
|
||||
from cli import HermesCLI
|
||||
from gateway.session_context import _UNSET, _VAR_MAP, get_session_env
|
||||
|
||||
old_session_id = cli_instance.session_id
|
||||
os.environ["HERMES_SESSION_ID"] = old_session_id
|
||||
_VAR_MAP["HERMES_SESSION_ID"].set(old_session_id)
|
||||
|
||||
try:
|
||||
HermesCLI._handle_branch_command(cli_instance, "/branch")
|
||||
|
||||
assert cli_instance.session_id != old_session_id
|
||||
assert os.environ["HERMES_SESSION_ID"] == cli_instance.session_id
|
||||
assert get_session_env("HERMES_SESSION_ID") == cli_instance.session_id
|
||||
finally:
|
||||
os.environ.pop("HERMES_SESSION_ID", None)
|
||||
_VAR_MAP["HERMES_SESSION_ID"].set(_UNSET)
|
||||
|
||||
def test_branch_fires_on_session_switch_hook(self, cli_instance, session_db):
|
||||
"""The /branch command must notify memory providers of the rotation.
|
||||
|
||||
|
|
|
|||
|
|
@ -102,6 +102,20 @@ class TestVerboseAndToolProgress:
|
|||
assert cli.tool_progress_mode in {"off", "new", "all", "verbose"}
|
||||
|
||||
|
||||
class TestFallbackChainInit:
|
||||
def test_merges_new_and_legacy_fallback_config(self):
|
||||
cli = _make_cli(config_overrides={
|
||||
"fallback_providers": [
|
||||
{"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"},
|
||||
],
|
||||
"fallback_model": {"provider": "nous", "model": "Hermes-4"},
|
||||
})
|
||||
assert cli._fallback_model == [
|
||||
{"provider": "openrouter", "model": "anthropic/claude-sonnet-4.6"},
|
||||
{"provider": "nous", "model": "Hermes-4"},
|
||||
]
|
||||
|
||||
|
||||
class TestBusyInputMode:
|
||||
def test_default_busy_input_mode_is_interrupt(self):
|
||||
cli = _make_cli()
|
||||
|
|
@ -317,7 +331,63 @@ class TestHistoryDisplay:
|
|||
|
||||
assert "Recent sessions" in output
|
||||
assert "Checking Running Hermes Agent" in output
|
||||
assert "Use /resume <session id or title> to continue" in output
|
||||
assert "Use /resume" in output
|
||||
assert "session title" in output
|
||||
|
||||
def test_resume_updates_hermes_session_id_env_and_context(self, tmp_path):
|
||||
from gateway.session_context import _UNSET, _VAR_MAP, get_session_env
|
||||
from hermes_state import SessionDB
|
||||
|
||||
cli = _make_cli()
|
||||
cli.session_id = "current_session"
|
||||
cli.conversation_history = []
|
||||
cli.agent = None
|
||||
cli._session_db = SessionDB(db_path=tmp_path / "state.db")
|
||||
cli._session_db.create_session("current_session", "cli")
|
||||
cli._session_db.create_session("target_session", "cli")
|
||||
cli._session_db.append_message("target_session", "user", "hello from resumed session")
|
||||
|
||||
os.environ["HERMES_SESSION_ID"] = "current_session"
|
||||
_VAR_MAP["HERMES_SESSION_ID"].set("current_session")
|
||||
|
||||
try:
|
||||
cli._handle_resume_command("/resume target_session")
|
||||
|
||||
assert cli.session_id == "target_session"
|
||||
assert os.environ["HERMES_SESSION_ID"] == "target_session"
|
||||
assert get_session_env("HERMES_SESSION_ID") == "target_session"
|
||||
finally:
|
||||
cli._session_db.close()
|
||||
os.environ.pop("HERMES_SESSION_ID", None)
|
||||
_VAR_MAP["HERMES_SESSION_ID"].set(_UNSET)
|
||||
|
||||
def test_resume_list_shows_full_long_titles(self, capsys):
|
||||
"""Long session titles render in full in the /resume table — not
|
||||
truncated to 30 chars (fixes #14082)."""
|
||||
cli = _make_cli()
|
||||
cli.session_id = "current"
|
||||
cli._session_db = MagicMock()
|
||||
long_title = "Salvage BytePlus Volcengine PR With Fixes"
|
||||
cli._session_db.list_sessions_rich.return_value = [
|
||||
{
|
||||
"id": "current",
|
||||
"title": "Current",
|
||||
"preview": "Current preview",
|
||||
"last_active": 0,
|
||||
},
|
||||
{
|
||||
"id": "20260401_201329_d85961",
|
||||
"title": long_title,
|
||||
"preview": "fix byteplus pr and resume",
|
||||
"last_active": 0,
|
||||
},
|
||||
]
|
||||
|
||||
cli._handle_resume_command("/resume")
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert long_title in output
|
||||
assert "20260401_201329_d85961" in output
|
||||
|
||||
def test_sessions_command_no_args_lists_recent_sessions(self, capsys):
|
||||
"""/sessions with no args prints the recent-sessions table (TUI parity).
|
||||
|
|
@ -429,8 +499,8 @@ class TestRootLevelProviderOverride:
|
|||
|
||||
assert cfg["model"]["provider"] == "openrouter"
|
||||
|
||||
def test_root_provider_ignored_when_default_model_provider_exists(self, tmp_path, monkeypatch):
|
||||
"""Even when model.provider is the default 'auto', root-level provider is ignored."""
|
||||
def test_root_provider_used_as_fallback_when_model_provider_missing(self, tmp_path, monkeypatch):
|
||||
"""Legacy root-level provider still populates model.provider in the CLI loader."""
|
||||
import yaml
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
|
|
@ -450,8 +520,29 @@ class TestRootLevelProviderOverride:
|
|||
monkeypatch.setattr(cli, "_hermes_home", hermes_home)
|
||||
cfg = cli.load_cli_config()
|
||||
|
||||
# Root-level "opencode-go" must NOT leak through
|
||||
assert cfg["model"]["provider"] != "opencode-go"
|
||||
assert cfg["model"]["provider"] == "opencode-go"
|
||||
|
||||
def test_root_base_url_used_as_fallback_when_model_base_url_missing(self, tmp_path, monkeypatch):
|
||||
"""Legacy root-level base_url still populates model.base_url in the CLI loader."""
|
||||
import yaml
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text(yaml.safe_dump({
|
||||
"base_url": "https://example.com/v1",
|
||||
"model": {
|
||||
"default": "google/gemini-3-flash-preview",
|
||||
},
|
||||
}))
|
||||
|
||||
import cli
|
||||
monkeypatch.setattr(cli, "_hermes_home", hermes_home)
|
||||
cfg = cli.load_cli_config()
|
||||
|
||||
assert cfg["model"]["base_url"] == "https://example.com/v1"
|
||||
|
||||
def test_terminal_vercel_runtime_bridged_to_env(self, tmp_path, monkeypatch):
|
||||
"""Classic CLI must expose terminal.vercel_runtime to terminal_tool.py."""
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import sys
|
|||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_state import SessionDB
|
||||
from tools.todo_tool import TodoStore
|
||||
|
||||
|
|
@ -138,6 +140,15 @@ def _prepare_cli_with_active_session(tmp_path):
|
|||
return cli
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_session_id_context():
|
||||
from gateway.session_context import _UNSET, _VAR_MAP
|
||||
|
||||
yield
|
||||
os.environ.pop("HERMES_SESSION_ID", None)
|
||||
_VAR_MAP["HERMES_SESSION_ID"].set(_UNSET)
|
||||
|
||||
|
||||
def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path):
|
||||
cli = _prepare_cli_with_active_session(tmp_path)
|
||||
old_session_id = cli.session_id
|
||||
|
|
@ -164,6 +175,21 @@ def test_new_command_creates_real_fresh_session_and_resets_agent_state(tmp_path)
|
|||
cli.agent._invalidate_system_prompt.assert_called_once()
|
||||
|
||||
|
||||
def test_new_command_rotates_hermes_session_id_env_and_context(tmp_path):
|
||||
from gateway.session_context import _VAR_MAP, get_session_env
|
||||
|
||||
cli = _prepare_cli_with_active_session(tmp_path)
|
||||
old_session_id = cli.session_id
|
||||
os.environ["HERMES_SESSION_ID"] = old_session_id
|
||||
_VAR_MAP["HERMES_SESSION_ID"].set(old_session_id)
|
||||
|
||||
cli.process_command("/new")
|
||||
|
||||
assert cli.session_id != old_session_id
|
||||
assert os.environ["HERMES_SESSION_ID"] == cli.session_id
|
||||
assert get_session_env("HERMES_SESSION_ID") == cli.session_id
|
||||
|
||||
|
||||
def test_reset_command_is_alias_for_new_session(tmp_path):
|
||||
cli = _prepare_cli_with_active_session(tmp_path)
|
||||
old_session_id = cli.session_id
|
||||
|
|
|
|||
77
tests/cli/test_cli_resume_command.py
Normal file
77
tests/cli/test_cli_resume_command.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def _make_cli():
|
||||
cli_obj = HermesCLI.__new__(HermesCLI)
|
||||
cli_obj.session_id = "current_session"
|
||||
cli_obj._resumed = False
|
||||
cli_obj._pending_title = None
|
||||
cli_obj.conversation_history = []
|
||||
cli_obj.agent = None
|
||||
cli_obj._session_db = MagicMock()
|
||||
# _handle_resume_command now triggers _display_resumed_history (#31695),
|
||||
# which reads self.resume_display. "minimal" short-circuits the recap so
|
||||
# the test only exercises session-switch behavior.
|
||||
cli_obj.resume_display = "minimal"
|
||||
return cli_obj
|
||||
|
||||
|
||||
class TestCliResumeCommand:
|
||||
def test_show_recent_sessions_includes_indexes_and_resume_hint(self, capsys):
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._list_recent_sessions = MagicMock(return_value=[
|
||||
{"id": "sess_002", "title": "Coding", "preview": "build feature", "last_active": None},
|
||||
{"id": "sess_001", "title": "Research", "preview": "read docs", "last_active": None},
|
||||
])
|
||||
|
||||
shown = cli_obj._show_recent_sessions(reason="resume")
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert shown is True
|
||||
assert "1" in output
|
||||
assert "2" in output
|
||||
assert "Coding" in output
|
||||
assert "Research" in output
|
||||
assert "/resume 2" in output
|
||||
assert "/resume <session title>" in output
|
||||
|
||||
def test_handle_resume_by_index_switches_to_numbered_session(self):
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._list_recent_sessions = MagicMock(return_value=[
|
||||
{"id": "sess_002", "title": "Coding"},
|
||||
{"id": "sess_001", "title": "Research"},
|
||||
])
|
||||
cli_obj._session_db.get_session.return_value = {"id": "sess_001", "title": "Research"}
|
||||
cli_obj._session_db.get_messages_as_conversation.return_value = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
# resolve_resume_session_id passes the id through when no compression chain.
|
||||
cli_obj._session_db.resolve_resume_session_id.return_value = "sess_001"
|
||||
|
||||
with (
|
||||
patch("hermes_cli.main._resolve_session_by_name_or_id", return_value=None),
|
||||
patch("cli._cprint") as mock_cprint,
|
||||
):
|
||||
cli_obj._handle_resume_command("/resume 2")
|
||||
|
||||
printed = " ".join(str(call) for call in mock_cprint.call_args_list)
|
||||
assert cli_obj.session_id == "sess_001"
|
||||
assert "Resumed session sess_001" in printed
|
||||
assert "Research" in printed
|
||||
|
||||
def test_handle_resume_by_index_out_of_range(self):
|
||||
cli_obj = _make_cli()
|
||||
cli_obj._list_recent_sessions = MagicMock(return_value=[
|
||||
{"id": "sess_002", "title": "Coding"},
|
||||
])
|
||||
|
||||
with patch("cli._cprint") as mock_cprint:
|
||||
cli_obj._handle_resume_command("/resume 9")
|
||||
|
||||
printed = " ".join(str(call) for call in mock_cprint.call_args_list)
|
||||
assert "out of range" in printed.lower()
|
||||
assert "/resume" in printed
|
||||
assert cli_obj.session_id == "current_session"
|
||||
|
|
@ -209,3 +209,123 @@ def test_slash_confirm_display_fragments_include_choice_mapping():
|
|||
assert "[2] Always Approve" in rendered
|
||||
assert "[3] Cancel" in rendered
|
||||
assert "Type 1/2/3" in rendered
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inline-skip escape hatch (issue #30768)
|
||||
#
|
||||
# Users on platforms where the prompt_toolkit modal doesn't dispatch keys
|
||||
# (currently native Windows PowerShell) need a way to bypass the confirmation
|
||||
# without flipping the config gate. ``/reset now``, ``/new --yes``, ``/clear
|
||||
# -y`` all skip the modal and return "once" immediately.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_split_destructive_skip_recognized_tokens():
|
||||
"""``now``, ``--yes``, and ``-y`` are recognized as skip tokens."""
|
||||
from cli import HermesCLI
|
||||
|
||||
assert HermesCLI._split_destructive_skip("/reset now") == ("", True)
|
||||
assert HermesCLI._split_destructive_skip("/clear --yes") == ("", True)
|
||||
assert HermesCLI._split_destructive_skip("/undo -y") == ("", True)
|
||||
|
||||
|
||||
def test_split_destructive_skip_strips_command_word():
|
||||
"""Leading ``/cmd`` token is stripped; remaining args survive."""
|
||||
from cli import HermesCLI
|
||||
|
||||
assert HermesCLI._split_destructive_skip("/new My title") == ("My title", False)
|
||||
assert HermesCLI._split_destructive_skip("/new --yes My title") == ("My title", True)
|
||||
|
||||
|
||||
def test_split_destructive_skip_case_insensitive():
|
||||
"""Token matching is case-insensitive but not a substring match."""
|
||||
from cli import HermesCLI
|
||||
|
||||
assert HermesCLI._split_destructive_skip("/new NOW") == ("", True)
|
||||
# Substring match must NOT trigger — "Now-Title" is a literal title token.
|
||||
assert HermesCLI._split_destructive_skip("/new Now-Title") == ("Now-Title", False)
|
||||
|
||||
|
||||
def test_split_destructive_skip_handles_empty_and_none():
|
||||
"""Defensive against missing/empty input."""
|
||||
from cli import HermesCLI
|
||||
|
||||
assert HermesCLI._split_destructive_skip(None) == ("", False)
|
||||
assert HermesCLI._split_destructive_skip("") == ("", False)
|
||||
assert HermesCLI._split_destructive_skip(" ") == ("", False)
|
||||
|
||||
|
||||
def test_confirm_destructive_slash_now_skips_modal():
|
||||
"""``/reset now`` skips the modal even when the gate is on."""
|
||||
from cli import HermesCLI
|
||||
|
||||
# Build a prompt stub that fails the test if invoked — proving the modal
|
||||
# was never reached.
|
||||
def _explode(**_kw):
|
||||
raise AssertionError("modal must not be invoked when inline-skip present")
|
||||
|
||||
self_ = SimpleNamespace(
|
||||
_app=None,
|
||||
_prompt_text_input_modal=_explode,
|
||||
)
|
||||
self_._normalize_slash_confirm_choice = _bound(
|
||||
HermesCLI._normalize_slash_confirm_choice, self_,
|
||||
)
|
||||
self_._split_destructive_skip = HermesCLI._split_destructive_skip # classmethod
|
||||
|
||||
with patch(
|
||||
"cli.load_cli_config",
|
||||
return_value={"approvals": {"destructive_slash_confirm": True}},
|
||||
):
|
||||
result = _bound(HermesCLI._confirm_destructive_slash, self_)(
|
||||
"new", "detail", cmd_original="/reset now",
|
||||
)
|
||||
|
||||
assert result == "once"
|
||||
|
||||
|
||||
def test_confirm_destructive_slash_yes_flag_skips_modal():
|
||||
"""``--yes`` flag is equivalent to ``now``."""
|
||||
from cli import HermesCLI
|
||||
|
||||
def _explode(**_kw):
|
||||
raise AssertionError("modal must not be invoked when --yes present")
|
||||
|
||||
self_ = SimpleNamespace(
|
||||
_app=None,
|
||||
_prompt_text_input_modal=_explode,
|
||||
)
|
||||
self_._normalize_slash_confirm_choice = _bound(
|
||||
HermesCLI._normalize_slash_confirm_choice, self_,
|
||||
)
|
||||
self_._split_destructive_skip = HermesCLI._split_destructive_skip
|
||||
|
||||
with patch(
|
||||
"cli.load_cli_config",
|
||||
return_value={"approvals": {"destructive_slash_confirm": True}},
|
||||
):
|
||||
result = _bound(HermesCLI._confirm_destructive_slash, self_)(
|
||||
"new", "detail", cmd_original="/new --yes My Session",
|
||||
)
|
||||
|
||||
assert result == "once"
|
||||
|
||||
|
||||
def test_confirm_destructive_slash_no_skip_token_still_prompts():
|
||||
"""Without a skip token the gate-on path still consults the modal."""
|
||||
from cli import HermesCLI
|
||||
|
||||
self_ = _make_self(prompt_response="3") # cancel
|
||||
self_._split_destructive_skip = HermesCLI._split_destructive_skip
|
||||
|
||||
with patch(
|
||||
"cli.load_cli_config",
|
||||
return_value={"approvals": {"destructive_slash_confirm": True}},
|
||||
):
|
||||
result = _bound(HermesCLI._confirm_destructive_slash, self_)(
|
||||
"new", "detail", cmd_original="/new My Session",
|
||||
)
|
||||
|
||||
# Prompt was reached and returned cancel → None.
|
||||
assert result is None
|
||||
|
|
|
|||
129
tests/cli/test_destructive_slash_inline_skip_e2e.py
Normal file
129
tests/cli/test_destructive_slash_inline_skip_e2e.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""End-to-end integration test for the destructive-slash inline-skip path.
|
||||
|
||||
Drives ``HermesCLI.process_command("/reset now")`` against a minimal stand-in
|
||||
and verifies:
|
||||
|
||||
1. ``new_session`` was invoked (the command actually ran)
|
||||
2. ``_prompt_text_input_modal`` was NOT invoked (modal bypassed)
|
||||
3. The skip token did not leak into the session title
|
||||
|
||||
This is the regression test for issue #30768 — the inline-skip escape hatch
|
||||
must work without ever touching the modal, on every platform.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _make_cli_stub():
|
||||
"""Build a minimal HermesCLI-shaped object that can run ``process_command``
|
||||
for the destructive-slash branches without spinning up a real TUI."""
|
||||
from cli import HermesCLI
|
||||
|
||||
new_session_calls = []
|
||||
|
||||
def _capture_new_session(self_, title=None, silent=False):
|
||||
new_session_calls.append({"title": title, "silent": silent})
|
||||
|
||||
self_ = SimpleNamespace(
|
||||
_app=None,
|
||||
_prompt_text_input_modal=lambda **_kw: (_ for _ in ()).throw(
|
||||
AssertionError("modal must not be invoked when inline-skip token present")
|
||||
),
|
||||
new_session=lambda **kw: _capture_new_session(self_, **kw),
|
||||
# Stub out side-effects the destructive-slash branches reach for.
|
||||
console=SimpleNamespace(clear=lambda: None),
|
||||
compact=False,
|
||||
model="stub-model",
|
||||
session_id="stub-session",
|
||||
enabled_toolsets=[],
|
||||
_pending_title=None,
|
||||
_session_db=None,
|
||||
)
|
||||
# Bind the methods we need under test.
|
||||
self_._split_destructive_skip = HermesCLI._split_destructive_skip
|
||||
self_._confirm_destructive_slash = HermesCLI._confirm_destructive_slash.__get__(
|
||||
self_, type(self_)
|
||||
)
|
||||
self_.process_command = HermesCLI.process_command.__get__(self_, type(self_))
|
||||
return self_, new_session_calls
|
||||
|
||||
|
||||
def test_reset_now_invokes_new_session_without_modal():
|
||||
"""``/reset now`` runs ``new_session`` and never touches the modal."""
|
||||
self_, calls = _make_cli_stub()
|
||||
|
||||
with patch(
|
||||
"cli.load_cli_config",
|
||||
return_value={"approvals": {"destructive_slash_confirm": True}},
|
||||
):
|
||||
self_.process_command("/reset now")
|
||||
|
||||
assert calls, "new_session was never invoked"
|
||||
# The /new branch passes title=None when there's no non-skip remainder.
|
||||
assert calls[0]["title"] is None
|
||||
|
||||
|
||||
def test_new_yes_with_title_preserves_title():
|
||||
"""``/new --yes My Session`` runs ``new_session(title='My Session')``."""
|
||||
self_, calls = _make_cli_stub()
|
||||
|
||||
with patch(
|
||||
"cli.load_cli_config",
|
||||
return_value={"approvals": {"destructive_slash_confirm": True}},
|
||||
):
|
||||
self_.process_command("/new --yes My Session")
|
||||
|
||||
assert calls, "new_session was never invoked"
|
||||
assert calls[0]["title"] == "My Session"
|
||||
|
||||
|
||||
def test_new_without_skip_token_still_consults_modal():
|
||||
"""``/new My Session`` (no skip token) must reach the modal.
|
||||
|
||||
Sanity check that we haven't accidentally short-circuited the normal path.
|
||||
"""
|
||||
from cli import HermesCLI
|
||||
|
||||
new_session_calls = []
|
||||
modal_calls = []
|
||||
|
||||
def _capture_new_session(self_, title=None, silent=False):
|
||||
new_session_calls.append({"title": title, "silent": silent})
|
||||
|
||||
def _record_modal(**kw):
|
||||
modal_calls.append(kw)
|
||||
# Simulate user cancelling so new_session is not called.
|
||||
return "3"
|
||||
|
||||
self_ = SimpleNamespace(
|
||||
_app=None,
|
||||
_prompt_text_input_modal=_record_modal,
|
||||
new_session=lambda **kw: _capture_new_session(self_, **kw),
|
||||
console=SimpleNamespace(clear=lambda: None),
|
||||
compact=False,
|
||||
model="stub-model",
|
||||
session_id="stub-session",
|
||||
enabled_toolsets=[],
|
||||
_pending_title=None,
|
||||
_session_db=None,
|
||||
)
|
||||
self_._split_destructive_skip = HermesCLI._split_destructive_skip
|
||||
self_._normalize_slash_confirm_choice = HermesCLI._normalize_slash_confirm_choice.__get__(
|
||||
self_, type(self_)
|
||||
)
|
||||
self_._confirm_destructive_slash = HermesCLI._confirm_destructive_slash.__get__(
|
||||
self_, type(self_)
|
||||
)
|
||||
self_.process_command = HermesCLI.process_command.__get__(self_, type(self_))
|
||||
|
||||
with patch(
|
||||
"cli.load_cli_config",
|
||||
return_value={"approvals": {"destructive_slash_confirm": True}},
|
||||
):
|
||||
self_.process_command("/new My Session")
|
||||
|
||||
assert modal_calls, "modal must be reached when no skip token is present"
|
||||
assert not new_session_calls, "user cancelled — new_session must not run"
|
||||
|
|
@ -155,14 +155,34 @@ class TestDisplayResumedHistory:
|
|||
assert "Page content" not in output
|
||||
|
||||
def test_tool_calls_shown_as_summary(self):
|
||||
cli = _make_cli()
|
||||
# Disable tool-only skip so the summary line is rendered for this fixture.
|
||||
cli = _make_cli(config_overrides={"display": {"resume_skip_tool_only": False}})
|
||||
cli.conversation_history = _tool_call_history()
|
||||
output = self._capture_display(cli)
|
||||
import cli as _cli_mod
|
||||
# CLI_CONFIG is read at call-time inside _display_resumed_history, so
|
||||
# apply the override for the duration of the capture, not just at init.
|
||||
with patch.dict(_cli_mod.__dict__, {"CLI_CONFIG": {
|
||||
"display": {"resume_skip_tool_only": False, "resume_display": "full"}
|
||||
}}):
|
||||
output = self._capture_display(cli)
|
||||
|
||||
assert "2 tool calls" in output
|
||||
assert "web_search" in output
|
||||
assert "web_extract" in output
|
||||
|
||||
def test_tool_only_message_skipped_by_default(self):
|
||||
"""Assistant messages with only tool_calls (no text) are skipped when
|
||||
resume_skip_tool_only=True (the default). The summary line is hidden.
|
||||
"""
|
||||
cli = _make_cli()
|
||||
cli.conversation_history = _tool_call_history()
|
||||
output = self._capture_display(cli)
|
||||
|
||||
# The tool-only assistant entry should be skipped
|
||||
assert "2 tool calls" not in output
|
||||
# The final text reply should still appear
|
||||
assert "Here are some great Python tutorials" in output
|
||||
|
||||
def test_long_user_message_truncated(self):
|
||||
cli = _make_cli()
|
||||
long_text = "A" * 500
|
||||
|
|
@ -611,6 +631,55 @@ class TestPreloadResumedSession:
|
|||
assert "1 user messages" not in output
|
||||
|
||||
|
||||
# ── Tests for _handle_resume_command recap display ───────────────────
|
||||
|
||||
|
||||
class TestHandleResumeCommandRecap:
|
||||
"""In-session /resume should show the same recap panel as startup resume."""
|
||||
|
||||
def test_resume_command_displays_recap_when_messages_restored(self):
|
||||
cli = _make_cli()
|
||||
cli.session_id = "current_session"
|
||||
messages = _simple_history()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_session.return_value = {"id": "target_session", "title": "Test Session"}
|
||||
mock_db.get_messages_as_conversation.return_value = messages
|
||||
# resolve_resume_session_id passes the id through when no compression chain.
|
||||
mock_db.resolve_resume_session_id.return_value = "target_session"
|
||||
cli._session_db = mock_db
|
||||
|
||||
with (
|
||||
patch("hermes_cli.main._resolve_session_by_name_or_id", return_value="target_session"),
|
||||
patch.object(cli, "_display_resumed_history") as display_mock,
|
||||
):
|
||||
cli._handle_resume_command("/resume test session")
|
||||
|
||||
assert cli.session_id == "target_session"
|
||||
assert cli.conversation_history == messages
|
||||
mock_db.end_session.assert_called_once_with("current_session", "resumed_other")
|
||||
mock_db.reopen_session.assert_called_once_with("target_session")
|
||||
display_mock.assert_called_once_with()
|
||||
|
||||
def test_resume_command_skips_recap_when_session_has_no_messages(self):
|
||||
cli = _make_cli()
|
||||
cli.session_id = "current_session"
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_session.return_value = {"id": "target_session", "title": None}
|
||||
mock_db.get_messages_as_conversation.return_value = []
|
||||
mock_db.resolve_resume_session_id.return_value = "target_session"
|
||||
cli._session_db = mock_db
|
||||
|
||||
with (
|
||||
patch("hermes_cli.main._resolve_session_by_name_or_id", return_value="target_session"),
|
||||
patch.object(cli, "_display_resumed_history") as display_mock,
|
||||
):
|
||||
cli._handle_resume_command("/resume target_session")
|
||||
|
||||
display_mock.assert_not_called()
|
||||
|
||||
|
||||
# ── Integration: _init_agent skips when preloaded ────────────────────
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,9 +14,10 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
|||
|
||||
# Module-level reference to the cli module (set by _make_cli on first call)
|
||||
_cli_mod = None
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
def _make_cli(tool_progress="all"):
|
||||
def _make_cli(tool_progress="all", verbose=_UNSET):
|
||||
"""Create a HermesCLI instance with minimal mocking."""
|
||||
global _cli_mod
|
||||
_clean_config = {
|
||||
|
|
@ -54,7 +55,9 @@ def _make_cli(tool_progress="all"):
|
|||
_cli_mod = mod
|
||||
with patch.object(mod, "get_tool_definitions", return_value=[]), \
|
||||
patch.dict(mod.__dict__, {"CLI_CONFIG": _clean_config}):
|
||||
return mod.HermesCLI()
|
||||
if verbose is _UNSET:
|
||||
return mod.HermesCLI()
|
||||
return mod.HermesCLI(verbose=verbose)
|
||||
|
||||
|
||||
class TestToolProgressScrollback:
|
||||
|
|
@ -122,14 +125,21 @@ class TestToolProgressScrollback:
|
|||
mock_print.assert_not_called()
|
||||
|
||||
def test_error_suffix_on_failed_tool(self):
|
||||
"""When is_error=True, the stacked line includes [error]."""
|
||||
"""When a failed tool's result is forwarded, the stacked line surfaces
|
||||
the specific error (e.g. ``[exit 1]`` or ``[File not found: x]``)
|
||||
instead of the legacy generic ``[error]`` suffix."""
|
||||
import json
|
||||
cli = _make_cli(tool_progress="all")
|
||||
cli._on_tool_progress("tool.started", "terminal", "bad cmd", {"command": "bad cmd"})
|
||||
cli._on_tool_progress("tool.started", "terminal", "false", {"command": "false"})
|
||||
with patch.object(_cli_mod, "_cprint") as mock_print:
|
||||
cli._on_tool_progress("tool.completed", "terminal", None, None, duration=0.5, is_error=True)
|
||||
cli._on_tool_progress(
|
||||
"tool.completed", "terminal", None, None,
|
||||
duration=0.5, is_error=True,
|
||||
result=json.dumps({"output": "", "exit_code": 1}),
|
||||
)
|
||||
|
||||
line = mock_print.call_args[0][0]
|
||||
assert "[error]" in line
|
||||
assert "[exit 1]" in line
|
||||
|
||||
def test_spinner_still_updates_on_started(self):
|
||||
"""tool.started still updates the spinner text for live display."""
|
||||
|
|
@ -168,6 +178,35 @@ class TestToolProgressScrollback:
|
|||
|
||||
mock_print.assert_not_called()
|
||||
|
||||
def test_verbose_mode_config_does_not_enable_global_debug_logging(self):
|
||||
"""display.tool_progress=verbose controls TOOL-CALL DISPLAY ONLY.
|
||||
|
||||
It must NOT auto-flip self.verbose, which controls root-logger DEBUG
|
||||
level for the entire process (every module spews to console). PR
|
||||
#6a1aa420e had coupled them, causing all debug logs to flood the
|
||||
terminal whenever a user picked tool_progress: verbose for richer
|
||||
per-tool rendering.
|
||||
"""
|
||||
cli = _make_cli(tool_progress="verbose")
|
||||
|
||||
assert cli.tool_progress_mode == "verbose"
|
||||
assert cli.verbose is False
|
||||
|
||||
def test_explicit_verbose_argument_wins_over_config(self):
|
||||
"""Explicit verbose=True from the CLI flag still enables DEBUG logging
|
||||
regardless of tool_progress_mode."""
|
||||
cli = _make_cli(tool_progress="off", verbose=True)
|
||||
|
||||
assert cli.tool_progress_mode == "off"
|
||||
assert cli.verbose is True
|
||||
|
||||
def test_explicit_non_verbose_argument_keeps_debug_logging_off(self):
|
||||
"""Explicit verbose=False overrides any default to enable DEBUG."""
|
||||
cli = _make_cli(tool_progress="verbose", verbose=False)
|
||||
|
||||
assert cli.tool_progress_mode == "verbose"
|
||||
assert cli.verbose is False
|
||||
|
||||
def test_pending_info_stores_on_started(self):
|
||||
"""tool.started stores args for later use by tool.completed."""
|
||||
cli = _make_cli(tool_progress="all")
|
||||
|
|
|
|||
|
|
@ -358,6 +358,10 @@ def _hermetic_environment(tmp_path, monkeypatch):
|
|||
monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||
monkeypatch.setenv("AWS_METADATA_SERVICE_TIMEOUT", "1")
|
||||
monkeypatch.setenv("AWS_METADATA_SERVICE_NUM_ATTEMPTS", "1")
|
||||
# Tirith auto-installs from GitHub when enabled and missing. Unit tests
|
||||
# should never perform that implicit network/bootstrap path; Tirith-specific
|
||||
# tests opt back in by patching the security config directly.
|
||||
monkeypatch.setenv("TIRITH_ENABLED", "false")
|
||||
|
||||
# 5. Reset plugin singleton so tests don't leak plugins from
|
||||
# ~/.hermes/plugins/ (which, per step 3, is now empty — but the
|
||||
|
|
|
|||
|
|
@ -490,6 +490,17 @@ class TestRoutingIntents:
|
|||
class TestDeliverResultWrapping:
|
||||
"""Verify that cron deliveries are wrapped with header/footer and no longer mirrored."""
|
||||
|
||||
def _safe_media_path(self, tmp_path, monkeypatch, name, data=b"media"):
|
||||
root = tmp_path / "media-cache"
|
||||
media_file = root / name
|
||||
media_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
media_file.write_bytes(data)
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
|
||||
(root,),
|
||||
)
|
||||
return media_file.resolve()
|
||||
|
||||
def test_delivery_wraps_content_with_header_and_footer(self):
|
||||
"""Delivered content should include task name header and agent-invisible note."""
|
||||
from gateway.config import Platform
|
||||
|
|
@ -564,9 +575,10 @@ class TestDeliverResultWrapping:
|
|||
assert "Cronjob Response" not in sent_content
|
||||
assert "The agent cannot see" not in sent_content
|
||||
|
||||
def test_delivery_extracts_media_tags_before_send(self):
|
||||
def test_delivery_extracts_media_tags_before_send(self, tmp_path, monkeypatch):
|
||||
"""Cron delivery should pass MEDIA attachments separately to the send helper."""
|
||||
from gateway.config import Platform
|
||||
media_path = self._safe_media_path(tmp_path, monkeypatch, "test-voice.ogg")
|
||||
|
||||
pconfig = MagicMock()
|
||||
pconfig.enabled = True
|
||||
|
|
@ -581,7 +593,7 @@ class TestDeliverResultWrapping:
|
|||
"deliver": "origin",
|
||||
"origin": {"platform": "telegram", "chat_id": "123"},
|
||||
}
|
||||
_deliver_result(job, "Title\nMEDIA:/tmp/test-voice.ogg")
|
||||
_deliver_result(job, f"Title\nMEDIA:{media_path}")
|
||||
|
||||
send_mock.assert_called_once()
|
||||
args, kwargs = send_mock.call_args
|
||||
|
|
@ -589,14 +601,15 @@ class TestDeliverResultWrapping:
|
|||
assert "MEDIA:" not in args[3]
|
||||
assert "Title" in args[3]
|
||||
# Media files should be forwarded separately
|
||||
assert kwargs["media_files"] == [("/tmp/test-voice.ogg", False)]
|
||||
assert kwargs["media_files"] == [(str(media_path), False)]
|
||||
|
||||
def test_live_adapter_sends_media_as_attachments(self):
|
||||
def test_live_adapter_sends_media_as_attachments(self, tmp_path, monkeypatch):
|
||||
"""When a live adapter is available, MEDIA files should be sent as native
|
||||
platform attachments (e.g., Discord voice, Telegram audio) rather than
|
||||
as literal 'MEDIA:/path' text."""
|
||||
from gateway.config import Platform
|
||||
from concurrent.futures import Future
|
||||
media_path = self._safe_media_path(tmp_path, monkeypatch, "cron-voice.mp3")
|
||||
|
||||
adapter = AsyncMock()
|
||||
adapter.send.return_value = MagicMock(success=True)
|
||||
|
|
@ -628,7 +641,7 @@ class TestDeliverResultWrapping:
|
|||
patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
|
||||
_deliver_result(
|
||||
job,
|
||||
"Here is TTS\nMEDIA:/tmp/cron-voice.mp3",
|
||||
f"Here is TTS\nMEDIA:{media_path}",
|
||||
adapters={Platform.DISCORD: adapter},
|
||||
loop=loop,
|
||||
)
|
||||
|
|
@ -642,12 +655,13 @@ class TestDeliverResultWrapping:
|
|||
# Audio file should be sent as a voice attachment
|
||||
adapter.send_voice.assert_called_once()
|
||||
voice_call = adapter.send_voice.call_args
|
||||
assert voice_call[1]["audio_path"] == "/tmp/cron-voice.mp3"
|
||||
assert voice_call[1]["audio_path"] == str(media_path)
|
||||
|
||||
def test_live_adapter_routes_image_to_send_image_file(self):
|
||||
def test_live_adapter_routes_image_to_send_image_file(self, tmp_path, monkeypatch):
|
||||
"""Image MEDIA files should be routed to send_image_file, not send_voice."""
|
||||
from gateway.config import Platform
|
||||
from concurrent.futures import Future
|
||||
media_path = self._safe_media_path(tmp_path, monkeypatch, "chart.png")
|
||||
|
||||
adapter = AsyncMock()
|
||||
adapter.send.return_value = MagicMock(success=True)
|
||||
|
|
@ -678,19 +692,20 @@ class TestDeliverResultWrapping:
|
|||
patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
|
||||
_deliver_result(
|
||||
job,
|
||||
"Chart attached\nMEDIA:/tmp/chart.png",
|
||||
f"Chart attached\nMEDIA:{media_path}",
|
||||
adapters={Platform.DISCORD: adapter},
|
||||
loop=loop,
|
||||
)
|
||||
|
||||
adapter.send_image_file.assert_called_once()
|
||||
assert adapter.send_image_file.call_args[1]["image_path"] == "/tmp/chart.png"
|
||||
assert adapter.send_image_file.call_args[1]["image_path"] == str(media_path)
|
||||
adapter.send_voice.assert_not_called()
|
||||
|
||||
def test_live_adapter_media_only_no_text(self):
|
||||
def test_live_adapter_media_only_no_text(self, tmp_path, monkeypatch):
|
||||
"""When content is ONLY a MEDIA tag with no text, media should still be sent."""
|
||||
from gateway.config import Platform
|
||||
from concurrent.futures import Future
|
||||
media_path = self._safe_media_path(tmp_path, monkeypatch, "voice.ogg")
|
||||
|
||||
adapter = AsyncMock()
|
||||
adapter.send_voice.return_value = MagicMock(success=True)
|
||||
|
|
@ -720,7 +735,7 @@ class TestDeliverResultWrapping:
|
|||
patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
|
||||
_deliver_result(
|
||||
job,
|
||||
"[[audio_as_voice]]\nMEDIA:/tmp/voice.ogg",
|
||||
f"[[audio_as_voice]]\nMEDIA:{media_path}",
|
||||
adapters={Platform.TELEGRAM: adapter},
|
||||
loop=loop,
|
||||
)
|
||||
|
|
@ -2164,43 +2179,56 @@ class TestBuildJobPromptBumpUse:
|
|||
class TestSendMediaViaAdapter:
|
||||
"""Unit tests for _send_media_via_adapter — routes files to typed adapter methods."""
|
||||
|
||||
def _safe_media_path(self, tmp_path, monkeypatch, name, data=b"media"):
|
||||
root = tmp_path / "media-cache"
|
||||
media_file = root / name
|
||||
media_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
media_file.write_bytes(data)
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
|
||||
(root,),
|
||||
)
|
||||
return media_file.resolve()
|
||||
|
||||
@staticmethod
|
||||
def _run_with_loop(adapter, chat_id, media_files, metadata, job):
|
||||
"""Helper: run _send_media_via_adapter with a real running event loop."""
|
||||
import asyncio
|
||||
import threading
|
||||
"""Helper: run _send_media_via_adapter with immediate scheduling."""
|
||||
from concurrent.futures import Future
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
t = threading.Thread(target=loop.run_forever, daemon=True)
|
||||
t.start()
|
||||
try:
|
||||
_send_media_via_adapter(adapter, chat_id, media_files, metadata, loop, job)
|
||||
finally:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
t.join(timeout=5)
|
||||
loop.close()
|
||||
def fake_run_coro(coro, _loop):
|
||||
coro.close()
|
||||
completed = Future()
|
||||
completed.set_result(MagicMock(success=True))
|
||||
return completed
|
||||
|
||||
def test_video_dispatched_to_send_video(self):
|
||||
with patch("asyncio.run_coroutine_threadsafe", side_effect=fake_run_coro):
|
||||
_send_media_via_adapter(adapter, chat_id, media_files, metadata, MagicMock(), job)
|
||||
|
||||
def test_video_dispatched_to_send_video(self, tmp_path, monkeypatch):
|
||||
adapter = MagicMock()
|
||||
adapter.send_video = AsyncMock()
|
||||
media_files = [("/tmp/clip.mp4", False)]
|
||||
media_path = self._safe_media_path(tmp_path, monkeypatch, "clip.mp4")
|
||||
media_files = [(str(media_path), False)]
|
||||
self._run_with_loop(adapter, "123", media_files, None, {"id": "j1"})
|
||||
adapter.send_video.assert_called_once()
|
||||
assert adapter.send_video.call_args[1]["video_path"] == "/tmp/clip.mp4"
|
||||
assert adapter.send_video.call_args[1]["video_path"] == str(media_path)
|
||||
|
||||
def test_unknown_ext_dispatched_to_send_document(self):
|
||||
def test_unknown_ext_dispatched_to_send_document(self, tmp_path, monkeypatch):
|
||||
adapter = MagicMock()
|
||||
adapter.send_document = AsyncMock()
|
||||
media_files = [("/tmp/report.pdf", False)]
|
||||
media_path = self._safe_media_path(tmp_path, monkeypatch, "report.pdf")
|
||||
media_files = [(str(media_path), False)]
|
||||
self._run_with_loop(adapter, "123", media_files, None, {"id": "j2"})
|
||||
adapter.send_document.assert_called_once()
|
||||
assert adapter.send_document.call_args[1]["file_path"] == "/tmp/report.pdf"
|
||||
assert adapter.send_document.call_args[1]["file_path"] == str(media_path)
|
||||
|
||||
def test_multiple_media_files_all_delivered(self):
|
||||
def test_multiple_media_files_all_delivered(self, tmp_path, monkeypatch):
|
||||
adapter = MagicMock()
|
||||
adapter.send_voice = AsyncMock()
|
||||
adapter.send_image_file = AsyncMock()
|
||||
media_files = [("/tmp/voice.mp3", False), ("/tmp/photo.jpg", False)]
|
||||
voice_path = self._safe_media_path(tmp_path, monkeypatch, "voice.mp3")
|
||||
photo_path = self._safe_media_path(tmp_path, monkeypatch, "photo.jpg")
|
||||
media_files = [(str(voice_path), False), (str(photo_path), False)]
|
||||
self._run_with_loop(adapter, "123", media_files, None, {"id": "j3"})
|
||||
adapter.send_voice.assert_called_once()
|
||||
adapter.send_image_file.assert_called_once()
|
||||
|
|
@ -2462,7 +2490,7 @@ class TestSendMediaTimeoutCancelsFuture:
|
|||
in-flight coroutine must be cancelled before the next file is tried.
|
||||
"""
|
||||
|
||||
def test_media_send_timeout_cancels_future_and_continues(self):
|
||||
def test_media_send_timeout_cancels_future_and_continues(self, tmp_path, monkeypatch):
|
||||
"""End-to-end: _send_media_via_adapter with a future whose .result()
|
||||
raises TimeoutError. Assert cancel() fires and the loop proceeds
|
||||
to the next file rather than hanging or crashing."""
|
||||
|
|
@ -2493,9 +2521,19 @@ class TestSendMediaTimeoutCancelsFuture:
|
|||
coro.close()
|
||||
return next(futures_iter)
|
||||
|
||||
root = tmp_path / "media-cache"
|
||||
slow = root / "slow.png"
|
||||
fast = root / "fast.mp4"
|
||||
slow.parent.mkdir(parents=True)
|
||||
slow.write_bytes(b"slow")
|
||||
fast.write_bytes(b"fast")
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
|
||||
(root,),
|
||||
)
|
||||
media_files = [
|
||||
("/tmp/slow.png", False), # times out
|
||||
("/tmp/fast.mp4", False), # succeeds
|
||||
(str(slow), False), # times out
|
||||
(str(fast), False), # succeeds
|
||||
]
|
||||
|
||||
loop = MagicMock()
|
||||
|
|
@ -2509,4 +2547,4 @@ class TestSendMediaTimeoutCancelsFuture:
|
|||
assert timeout_cancel_calls == [True], "future.cancel() must fire on TimeoutError"
|
||||
# 2. Second file still got dispatched — one timeout doesn't abort the batch
|
||||
adapter.send_video.assert_called_once()
|
||||
assert adapter.send_video.call_args[1]["video_path"] == "/tmp/fast.mp4"
|
||||
assert adapter.send_video.call_args[1]["video_path"] == str(fast.resolve())
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ _ensure_slack_mock()
|
|||
|
||||
import discord # noqa: E402 — mocked above
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
import gateway.platforms.slack as _slack_mod # noqa: E402
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
|
|
|
|||
|
|
@ -1,20 +1,10 @@
|
|||
"""Regression test for #4469.
|
||||
"""Regression tests for active-session TEXT follow-up queueing.
|
||||
|
||||
When the agent is actively running (session present in
|
||||
``adapter._active_sessions``) and the user fires off multiple TEXT
|
||||
follow-ups in rapid succession, the previous behaviour was a single-slot
|
||||
replacement at ``gateway/platforms/base.py``:
|
||||
|
||||
self._pending_messages[session_key] = event
|
||||
|
||||
So three rapid messages ``A``, ``B``, ``C`` arriving while the agent was
|
||||
still working on the initial turn produced a pending slot containing only
|
||||
``C``; ``A`` and ``B`` were silently dropped.
|
||||
|
||||
The fix routes the follow-up through ``merge_pending_message_event(...,
|
||||
merge_text=True)`` so TEXT events accumulate into the existing pending
|
||||
event's text instead of clobbering it. Photo / media bursts continue to
|
||||
merge through the same helper (they always did).
|
||||
When the agent is actively running, rapid text follow-ups should survive as
|
||||
one next-turn pending message instead of clobbering each other. In
|
||||
``busy_text_mode=queue`` those active follow-ups first pass through a short
|
||||
debounce so bursty multi-message thoughts are merged before the active drain
|
||||
hands off the next turn.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -22,7 +12,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -44,16 +34,27 @@ from gateway.platforms.base import (
|
|||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_event(text: str, chat_id: str = "12345") -> MessageEvent:
|
||||
def _make_event(
|
||||
text: str,
|
||||
chat_id: str = "12345",
|
||||
*,
|
||||
chat_type: str = "dm",
|
||||
user_id: str = "u1",
|
||||
user_name: str | None = None,
|
||||
thread_id: str | None = None,
|
||||
) -> MessageEvent:
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
chat_type="dm",
|
||||
user_id="u1",
|
||||
chat_type=chat_type,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
|
|
@ -63,27 +64,26 @@ def _make_event(text: str, chat_id: str = "12345") -> MessageEvent:
|
|||
)
|
||||
|
||||
|
||||
class _DummyAdapter(BasePlatformAdapter): # type: ignore[misc]
|
||||
async def connect(self):
|
||||
pass
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return None
|
||||
|
||||
async def send(self, *args, **kwargs):
|
||||
return SendResult(success=True, message_id="x")
|
||||
|
||||
|
||||
def _make_initialized_adapter() -> BasePlatformAdapter:
|
||||
return _DummyAdapter(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
|
||||
def _make_adapter() -> BasePlatformAdapter:
|
||||
"""Build a BasePlatformAdapter without running its heavy __init__.
|
||||
|
||||
We only need the bits ``handle_message`` touches on the active-session
|
||||
path: ``_active_sessions``, ``_pending_messages``,
|
||||
``_message_handler``, ``_busy_session_handler``, ``config``, ``platform``.
|
||||
"""
|
||||
|
||||
class _DummyAdapter(BasePlatformAdapter): # type: ignore[misc]
|
||||
async def connect(self):
|
||||
pass
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return None
|
||||
|
||||
async def send(self, *args, **kwargs):
|
||||
return MagicMock(success=True, message_id="x", retryable=False)
|
||||
|
||||
"""Build a BasePlatformAdapter without running its heavy __init__."""
|
||||
adapter = object.__new__(_DummyAdapter)
|
||||
adapter.config = PlatformConfig(enabled=True, token="***")
|
||||
adapter.platform = Platform.TELEGRAM
|
||||
|
|
@ -100,6 +100,10 @@ def _make_adapter() -> BasePlatformAdapter:
|
|||
adapter._fatal_error_retryable = True
|
||||
adapter._fatal_error_handler = None
|
||||
adapter._running = True
|
||||
adapter._busy_text_mode = "queue"
|
||||
adapter._busy_text_debounce_seconds = 0.1
|
||||
adapter._busy_text_hard_cap_seconds = 1.0
|
||||
adapter._text_debounce = {}
|
||||
adapter._auto_tts_default = False
|
||||
adapter._auto_tts_enabled_chats = set()
|
||||
adapter._auto_tts_disabled_chats = set()
|
||||
|
|
@ -107,39 +111,235 @@ def _make_adapter() -> BasePlatformAdapter:
|
|||
return adapter
|
||||
|
||||
|
||||
def _debounced_event(adapter: BasePlatformAdapter, session_key: str) -> MessageEvent:
|
||||
return adapter._text_debounce[session_key].event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_text_followups_accumulate_instead_of_replacing():
|
||||
"""Three rapid TEXT follow-ups during an active session must all
|
||||
survive in ``adapter._pending_messages[session_key].text``."""
|
||||
"""Rapid TEXT follow-ups must all survive in the pending event."""
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_mode = "" # direct-merge behavior, no debounce
|
||||
first = _make_event("part one")
|
||||
session_key = build_session_key(first.source)
|
||||
|
||||
# Mark the session as active so subsequent messages take the
|
||||
# "already running" branch in handle_message.
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
second = _make_event("part two")
|
||||
third = _make_event("part three")
|
||||
await adapter.handle_message(_make_event("part two"))
|
||||
await adapter.handle_message(_make_event("part three"))
|
||||
|
||||
await adapter.handle_message(second)
|
||||
await adapter.handle_message(third)
|
||||
|
||||
# Both rapid follow-ups must be preserved, not just the last one.
|
||||
pending = adapter._pending_messages[session_key]
|
||||
assert pending.text == "part two\npart three", (
|
||||
f"expected accumulated text, got {pending.text!r}"
|
||||
assert pending.text == "part two\npart three"
|
||||
assert not adapter._active_sessions[session_key].is_set()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debounce_buffers_rapid_text_then_flushes_to_pending():
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_debounce_seconds = 0.05
|
||||
|
||||
first = _make_event("part one")
|
||||
session_key = build_session_key(first.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(_make_event("part two"))
|
||||
assert session_key in adapter._text_debounce
|
||||
assert _debounced_event(adapter, session_key).text == "part two"
|
||||
assert session_key not in adapter._pending_messages
|
||||
|
||||
await adapter.handle_message(_make_event("part three"))
|
||||
assert _debounced_event(adapter, session_key).text == "part two\npart three"
|
||||
|
||||
await asyncio.sleep(0.15)
|
||||
|
||||
assert session_key not in adapter._text_debounce
|
||||
assert adapter._pending_messages[session_key].text == "part two\npart three"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debounce_resets_timer_on_new_arrival():
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_debounce_seconds = 0.1
|
||||
|
||||
first = _make_event("one")
|
||||
session_key = build_session_key(first.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(first)
|
||||
task1 = adapter._text_debounce[session_key].task
|
||||
assert task1 is not None
|
||||
assert not task1.done()
|
||||
|
||||
await adapter.handle_message(_make_event("two"))
|
||||
task2 = adapter._text_debounce[session_key].task
|
||||
assert task2 is not None
|
||||
assert task2 is not task1
|
||||
await asyncio.sleep(0)
|
||||
assert task1.cancelled() or task1.done()
|
||||
assert adapter._text_debounce[session_key].task is task2
|
||||
|
||||
await adapter.handle_message(_make_event("three"))
|
||||
task3 = adapter._text_debounce[session_key].task
|
||||
assert task3 is not None
|
||||
assert task3 is not task2
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
assert session_key not in adapter._text_debounce
|
||||
assert adapter._pending_messages[session_key].text == "one\ntwo\nthree"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_drain_force_flushes_debounce_before_release():
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_debounce_seconds = 1.0
|
||||
processed: list[str] = []
|
||||
|
||||
async def _handler(event):
|
||||
processed.append(event.text)
|
||||
if event.text == "current":
|
||||
await adapter.handle_message(_make_event("follow up"))
|
||||
return None
|
||||
|
||||
adapter._message_handler = _handler
|
||||
current = _make_event("current")
|
||||
session_key = build_session_key(current.source)
|
||||
|
||||
task = asyncio.create_task(adapter._process_message_background(current, session_key))
|
||||
adapter._session_tasks[session_key] = task
|
||||
await asyncio.wait_for(task, timeout=1.0)
|
||||
|
||||
for _ in range(20):
|
||||
if processed == ["current", "follow up"] and session_key not in adapter._active_sessions:
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert processed == ["current", "follow up"]
|
||||
assert session_key not in adapter._text_debounce
|
||||
assert session_key not in adapter._pending_messages
|
||||
assert session_key not in adapter._active_sessions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_force_flush_cancels_timer_without_duplicate_processing():
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_debounce_seconds = 0.2
|
||||
|
||||
event = _make_event("queued once")
|
||||
session_key = build_session_key(event.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(event)
|
||||
timer_task = adapter._text_debounce[session_key].task
|
||||
|
||||
flushed = await adapter._flush_text_debounce_now(session_key)
|
||||
assert flushed is True
|
||||
assert session_key not in adapter._text_debounce
|
||||
assert adapter._pending_messages[session_key].text == "queued once"
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
assert timer_task is not None
|
||||
assert timer_task.cancelled() or timer_task.done()
|
||||
assert adapter._pending_messages[session_key].text == "queued once"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_debounce_does_not_merge_different_senders():
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_debounce_seconds = 1.0
|
||||
|
||||
first = _make_event(
|
||||
"from alice",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
user_name="Alice",
|
||||
thread_id="topic-1",
|
||||
)
|
||||
# Interrupt event must be signalled exactly like before.
|
||||
assert adapter._active_sessions[session_key].is_set()
|
||||
second = _make_event(
|
||||
"from bob",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
user_name="Bob",
|
||||
thread_id="topic-1",
|
||||
)
|
||||
session_key = build_session_key(first.source)
|
||||
assert session_key == build_session_key(second.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(first)
|
||||
await adapter.handle_message(second)
|
||||
|
||||
assert adapter._pending_messages[session_key].text == "from alice"
|
||||
assert _debounced_event(adapter, session_key).text == "from bob"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_control_and_clarify_messages_bypass_text_debounce():
|
||||
adapter = _make_adapter()
|
||||
started: list[str] = []
|
||||
|
||||
def _fake_start(event, session_key, *, interrupt_event=None):
|
||||
started.append(event.text)
|
||||
return True
|
||||
|
||||
adapter._start_session_processing = _fake_start # type: ignore[method-assign]
|
||||
|
||||
await adapter.handle_message(_make_event("/status"))
|
||||
assert started == ["/status"]
|
||||
assert adapter._text_debounce == {}
|
||||
|
||||
answer = _make_event("clarify answer")
|
||||
session_key = build_session_key(answer.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
adapter._message_handler = AsyncMock(return_value=None)
|
||||
|
||||
with patch("tools.clarify_gateway.get_pending_for_session", return_value=object()):
|
||||
await adapter.handle_message(answer)
|
||||
|
||||
adapter._message_handler.assert_awaited_once_with(answer)
|
||||
assert session_key not in adapter._text_debounce
|
||||
assert session_key not in adapter._pending_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debounce_skipped_when_busy_text_mode_not_queue():
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_mode = ""
|
||||
event = _make_event("direct merge")
|
||||
session_key = build_session_key(event.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(event)
|
||||
|
||||
assert adapter._pending_messages[session_key].text == "direct merge"
|
||||
assert session_key not in adapter._text_debounce
|
||||
|
||||
|
||||
def test_debounce_respects_env_var_override(monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_DEBOUNCE_SECONDS", "2.5")
|
||||
adapter = _make_initialized_adapter()
|
||||
assert adapter._busy_text_debounce_seconds == 2.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debounce_cleanup_in_cancel_background_tasks():
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_debounce_seconds = 1.0
|
||||
|
||||
event = _make_event("cleanup test")
|
||||
session_key = build_session_key(event.source)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
await adapter.handle_message(event)
|
||||
|
||||
assert session_key in adapter._text_debounce
|
||||
|
||||
await adapter.cancel_background_tasks()
|
||||
|
||||
assert session_key not in adapter._text_debounce
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_followup_is_stored_as_is():
|
||||
"""One TEXT follow-up still lands as the event object itself
|
||||
(no spurious wrapping / mutation) — guards against the merge path
|
||||
breaking the simple case."""
|
||||
adapter = _make_adapter()
|
||||
adapter._busy_text_mode = ""
|
||||
first = _make_event("only one")
|
||||
session_key = build_session_key(first.source)
|
||||
|
||||
|
|
@ -149,4 +349,29 @@ async def test_single_followup_is_stored_as_is():
|
|||
pending = adapter._pending_messages[session_key]
|
||||
assert pending is first
|
||||
assert pending.text == "only one"
|
||||
assert adapter._active_sessions[session_key].is_set()
|
||||
assert not adapter._active_sessions[session_key].is_set()
|
||||
|
||||
|
||||
def test_adapter_defaults_to_queue_mode(monkeypatch):
|
||||
monkeypatch.delenv("HERMES_GATEWAY_BUSY_TEXT_MODE", raising=False)
|
||||
adapter = _make_initialized_adapter()
|
||||
assert adapter._busy_text_mode == "queue"
|
||||
assert adapter._is_queue_text_debounce_candidate(_make_event("hello"))
|
||||
|
||||
|
||||
def test_adapter_is_queue_text_debounce_candidate_by_default():
|
||||
adapter = _make_adapter()
|
||||
assert adapter._is_queue_text_debounce_candidate(_make_event("hello world"))
|
||||
|
||||
|
||||
def test_command_messages_bypass_debounce_even_in_queue_mode():
|
||||
adapter = _make_adapter()
|
||||
assert not adapter._is_queue_text_debounce_candidate(_make_event(""))
|
||||
assert not adapter._is_queue_text_debounce_candidate(_make_event("/stop"))
|
||||
|
||||
|
||||
def test_busy_text_mode_respects_env_var_override(monkeypatch):
|
||||
monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_MODE", "interrupt")
|
||||
adapter = _make_initialized_adapter()
|
||||
assert adapter._busy_text_mode == "interrupt"
|
||||
assert not adapter._is_queue_text_debounce_candidate(_make_event("test"))
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ Tests cover:
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
|
@ -128,6 +130,37 @@ class TestResponseStore:
|
|||
# resp_2 mapping should still be intact
|
||||
assert store.get_conversation("chat-b") == "resp_2"
|
||||
|
||||
@pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits are platform-specific")
|
||||
def test_file_store_created_owner_only_under_permissive_umask(self, tmp_path):
|
||||
"""response_store.db must be 0o600 on creation even under umask 022."""
|
||||
db_path = tmp_path / "response_store.db"
|
||||
store = None
|
||||
old_umask = os.umask(0o022)
|
||||
try:
|
||||
store = ResponseStore(max_size=10, db_path=str(db_path))
|
||||
store.put(
|
||||
"resp_secret",
|
||||
{
|
||||
"response": {"id": "resp_secret"},
|
||||
"conversation_history": [{"role": "tool", "content": "dummy-marker"}],
|
||||
},
|
||||
)
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
if store is not None:
|
||||
store.close()
|
||||
|
||||
assert stat.S_IMODE(db_path.stat().st_mode) == 0o600
|
||||
# WAL/SHM sidecars are owner-only too when present. WAL mode may be
|
||||
# unavailable on some filesystems (NFS/SMB) — only assert when the
|
||||
# sidecar files actually exist.
|
||||
for sidecar in (
|
||||
db_path.with_name(db_path.name + "-wal"),
|
||||
db_path.with_name(db_path.name + "-shm"),
|
||||
):
|
||||
if sidecar.exists():
|
||||
assert stat.S_IMODE(sidecar.stat().st_mode) == 0o600
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _IdempotencyCache
|
||||
|
|
|
|||
|
|
@ -27,8 +27,11 @@ class TestResolveRuntimeAgentKwargsAuthFallback:
|
|||
|
||||
def _mock_resolve(**kwargs):
|
||||
call_count["n"] += 1
|
||||
requested = kwargs.get("requested", "")
|
||||
if requested and "codex" in str(requested).lower():
|
||||
# First call = primary path (gateway reads model.provider from
|
||||
# config.yaml internally; we simulate the auth failure here).
|
||||
# Second call = fallback path with explicit_api_key + explicit_base_url
|
||||
# supplied by gateway from fallback_model config.
|
||||
if call_count["n"] == 1:
|
||||
raise AuthError("Codex token refresh failed with status 401")
|
||||
return {
|
||||
"api_key": "fallback-key",
|
||||
|
|
@ -40,8 +43,6 @@ class TestResolveRuntimeAgentKwargsAuthFallback:
|
|||
"credential_pool": None,
|
||||
}
|
||||
|
||||
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex")
|
||||
|
||||
with patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
side_effect=_mock_resolve,
|
||||
|
|
@ -62,7 +63,6 @@ class TestResolveRuntimeAgentKwargsAuthFallback:
|
|||
config_path.write_text("model:\n provider: openai-codex\n")
|
||||
|
||||
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
|
||||
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex")
|
||||
|
||||
with patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
|
|
@ -71,3 +71,46 @@ class TestResolveRuntimeAgentKwargsAuthFallback:
|
|||
from gateway.run import _resolve_runtime_agent_kwargs
|
||||
with pytest.raises(RuntimeError):
|
||||
_resolve_runtime_agent_kwargs()
|
||||
|
||||
def test_legacy_fallback_is_appended_after_fallback_providers(self, tmp_path, monkeypatch):
|
||||
"""When both keys exist, the legacy entry still participates in resolution."""
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(
|
||||
"fallback_providers:\n"
|
||||
" - provider: openrouter\n"
|
||||
" model: anthropic/claude-sonnet-4.6\n"
|
||||
"fallback_model:\n"
|
||||
" provider: nous\n"
|
||||
" model: Hermes-4\n"
|
||||
)
|
||||
|
||||
monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
|
||||
|
||||
calls = []
|
||||
|
||||
def _mock_resolve(**kwargs):
|
||||
requested = kwargs.get("requested")
|
||||
calls.append(requested)
|
||||
if requested == "openrouter":
|
||||
raise RuntimeError("openrouter unavailable")
|
||||
return {
|
||||
"api_key": "nous-key",
|
||||
"base_url": "https://portal.nousresearch.com/v1",
|
||||
"provider": "nous",
|
||||
"api_mode": "chat_completions",
|
||||
"command": None,
|
||||
"args": None,
|
||||
"credential_pool": None,
|
||||
}
|
||||
|
||||
with patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
side_effect=_mock_resolve,
|
||||
):
|
||||
from gateway.run import _try_resolve_fallback_provider
|
||||
|
||||
result = _try_resolve_fallback_provider()
|
||||
|
||||
assert calls == ["openrouter", "nous"]
|
||||
assert result["provider"] == "nous"
|
||||
assert result["model"] == "Hermes-4"
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from gateway.session import SessionSource, build_session_key
|
|||
class DummyTelegramAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM)
|
||||
self._busy_text_mode = ""
|
||||
self.sent = []
|
||||
self.typing = []
|
||||
self.processing_hooks = []
|
||||
|
|
|
|||
|
|
@ -452,6 +452,14 @@ class TestBlueBubblesWebhookUrl:
|
|||
adapter = _make_adapter(monkeypatch, password="W9fTC&L5JL*@")
|
||||
assert "password=W9fTC%26L5JL%2A%40" in adapter._webhook_register_url
|
||||
|
||||
def test_register_url_for_log_masks_password(self, monkeypatch):
|
||||
"""Log-safe webhook URLs must never expose the webhook password."""
|
||||
adapter = _make_adapter(monkeypatch, password="W9fTC&L5JL*@")
|
||||
safe_url = adapter._webhook_register_url_for_log
|
||||
assert safe_url.endswith("?password=***")
|
||||
assert "W9fTC" not in safe_url
|
||||
assert "%26" not in safe_url
|
||||
|
||||
def test_register_url_omits_query_when_no_password(self, monkeypatch):
|
||||
"""If no password is configured, the register URL should be the bare URL."""
|
||||
monkeypatch.delenv("BLUEBUBBLES_PASSWORD", raising=False)
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ def _make_runner():
|
|||
runner._pending_messages = {}
|
||||
runner._busy_ack_ts = {}
|
||||
runner._draining = False
|
||||
runner._busy_text_mode = "interrupt"
|
||||
runner.adapters = {}
|
||||
runner.config = MagicMock()
|
||||
runner.session_store = None
|
||||
|
|
@ -84,6 +85,8 @@ def _make_adapter(platform_val="telegram"):
|
|||
adapter.config = MagicMock()
|
||||
adapter.config.extra = {}
|
||||
adapter.platform = MagicMock(value=platform_val)
|
||||
adapter._text_debounce = {}
|
||||
adapter._busy_text_debounce_seconds = 0.6
|
||||
return adapter
|
||||
|
||||
|
||||
|
|
@ -186,6 +189,32 @@ class TestBusySessionAck:
|
|||
assert "respond once the current task finishes" in content
|
||||
assert "Interrupting" not in content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_busy_text_mode_queue_delegates_to_adapter_handle_message(self):
|
||||
"""busy_text_mode=queue lets the adapter debounce text silently."""
|
||||
runner, sentinel = _make_runner()
|
||||
runner._busy_input_mode = "interrupt"
|
||||
runner._busy_text_mode = "queue"
|
||||
adapter = _make_adapter()
|
||||
|
||||
first = _make_event(text="part one")
|
||||
second = _make_event(text="part two")
|
||||
sk = build_session_key(first.source)
|
||||
|
||||
agent = MagicMock()
|
||||
runner._running_agents[sk] = agent
|
||||
runner.adapters[first.source.platform] = adapter
|
||||
runner.adapters[second.source.platform] = adapter
|
||||
|
||||
result1 = await runner._handle_active_session_busy_message(first, sk)
|
||||
result2 = await runner._handle_active_session_busy_message(second, sk)
|
||||
|
||||
assert result1 is False
|
||||
assert result2 is False
|
||||
assert sk not in adapter._pending_messages
|
||||
agent.interrupt.assert_not_called()
|
||||
adapter._send_with_retry.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_steer_mode_calls_agent_steer_no_interrupt_no_queue(self):
|
||||
"""busy_input_mode='steer' injects via agent.steer() and skips queueing."""
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ def _make_adapter():
|
|||
"""Create a minimal adapter for testing the active-session guard."""
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
adapter = _StubAdapter(config, Platform.TELEGRAM)
|
||||
adapter._busy_text_mode = ""
|
||||
adapter.sent_responses = []
|
||||
|
||||
async def _mock_handler(event):
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ def _run_gateway_import(hermes_home: Path, initial_env: dict[str, str]) -> dict[
|
|||
"HERMES_AGENT_TIMEOUT",
|
||||
"HERMES_AGENT_TIMEOUT_WARNING",
|
||||
"HERMES_GATEWAY_BUSY_INPUT_MODE",
|
||||
"HERMES_GATEWAY_BUSY_TEXT_MODE",
|
||||
"HERMES_TIMEZONE",
|
||||
):
|
||||
v = os.environ.get(k)
|
||||
|
|
@ -143,6 +144,15 @@ def test_config_display_busy_input_mode_wins_over_stale_env(hermes_home: Path) -
|
|||
assert env.get("HERMES_GATEWAY_BUSY_INPUT_MODE") == "interrupt"
|
||||
|
||||
|
||||
def test_config_display_busy_text_mode_wins_over_stale_env(hermes_home: Path) -> None:
|
||||
_write_config(hermes_home, display_cfg={"busy_text_mode": "queue"})
|
||||
_write_env(hermes_home, {"HERMES_GATEWAY_BUSY_TEXT_MODE": "interrupt"})
|
||||
|
||||
env = _run_gateway_import(hermes_home, initial_env={})
|
||||
|
||||
assert env.get("HERMES_GATEWAY_BUSY_TEXT_MODE") == "queue"
|
||||
|
||||
|
||||
def test_config_timezone_wins_over_stale_env(hermes_home: Path) -> None:
|
||||
_write_config(hermes_home, timezone="America/Los_Angeles")
|
||||
_write_env(hermes_home, {"HERMES_TIMEZONE": "UTC"})
|
||||
|
|
|
|||
|
|
@ -407,6 +407,36 @@ class TestConnect:
|
|||
assert len(adapter._dedup._seen) == 0
|
||||
assert adapter._http_client is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_finalizes_open_streaming_cards(self):
|
||||
"""Streaming cards must be finalized before HTTP client closes."""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
adapter._http_client = AsyncMock()
|
||||
adapter._stream_task = None
|
||||
adapter._streaming_cards = {
|
||||
"chat-1": {"track-a": "last content"},
|
||||
"chat-2": {"track-b": "other"},
|
||||
}
|
||||
|
||||
close_calls = []
|
||||
|
||||
async def fake_close_siblings(chat_id):
|
||||
# HTTP client must still be alive at call time.
|
||||
assert adapter._http_client is not None, (
|
||||
"HTTP client was already closed before card finalization"
|
||||
)
|
||||
close_calls.append(chat_id)
|
||||
adapter._streaming_cards.pop(chat_id, None)
|
||||
|
||||
with patch.object(adapter, "_close_streaming_siblings", side_effect=fake_close_siblings):
|
||||
await adapter.disconnect()
|
||||
|
||||
assert set(close_calls) == {"chat-1", "chat-2"}
|
||||
assert adapter._streaming_cards == {}
|
||||
assert adapter._http_client is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform enum
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import _build_allowed_mentions # noqa: E402
|
||||
from plugins.platforms.discord.adapter import _build_allowed_mentions # noqa: E402
|
||||
|
||||
|
||||
# The four DISCORD_ALLOW_MENTION_* env vars that _build_allowed_mentions reads.
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
from gateway.platforms.base import MessageType # noqa: E402
|
||||
|
||||
|
||||
|
|
@ -146,10 +146,10 @@ class TestCacheDiscordImage:
|
|||
att = _make_attachment_with_read(_PNG_BYTES)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_bytes",
|
||||
return_value="/tmp/cached.png",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_image(att, ".png")
|
||||
|
|
@ -165,9 +165,9 @@ class TestCacheDiscordImage:
|
|||
att = _make_attachment_without_read()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_bytes",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/from_url.png",
|
||||
) as mock_url:
|
||||
|
|
@ -186,10 +186,10 @@ class TestCacheDiscordImage:
|
|||
att = _make_attachment_with_read(b"<html>forbidden</html>")
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_bytes",
|
||||
side_effect=ValueError("not a valid image"),
|
||||
), patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/fallback.png",
|
||||
) as mock_url:
|
||||
|
|
@ -210,10 +210,10 @@ class TestCacheDiscordAudio:
|
|||
att = _make_attachment_with_read(_OGG_BYTES)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_audio_from_bytes",
|
||||
"plugins.platforms.discord.adapter.cache_audio_from_bytes",
|
||||
return_value="/tmp/voice.ogg",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_audio_from_url",
|
||||
"plugins.platforms.discord.adapter.cache_audio_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_audio(att, ".ogg")
|
||||
|
|
@ -228,7 +228,7 @@ class TestCacheDiscordAudio:
|
|||
att = _make_attachment_without_read()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_audio_from_url",
|
||||
"plugins.platforms.discord.adapter.cache_audio_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/from_url.ogg",
|
||||
) as mock_url:
|
||||
|
|
@ -267,7 +267,7 @@ class TestCacheDiscordDocument:
|
|||
att = _make_attachment_without_read() # no .read → forces fallback
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.is_safe_url", return_value=False
|
||||
"plugins.platforms.discord.adapter.is_safe_url", return_value=False
|
||||
) as mock_safe, patch("aiohttp.ClientSession") as mock_session:
|
||||
with pytest.raises(ValueError, match="SSRF"):
|
||||
await adapter._cache_discord_document(att, ".pdf")
|
||||
|
|
@ -295,7 +295,7 @@ class TestCacheDiscordDocument:
|
|||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.is_safe_url", return_value=True
|
||||
"plugins.platforms.discord.adapter.is_safe_url", return_value=True
|
||||
), patch("aiohttp.ClientSession", return_value=session):
|
||||
result = await adapter._cache_discord_document(att, ".pdf")
|
||||
|
||||
|
|
@ -320,10 +320,10 @@ class TestHandleMessageUsesAuthenticatedRead:
|
|||
adapter.handle_message = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_bytes",
|
||||
return_value="/tmp/img_from_read.png",
|
||||
), patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url_download:
|
||||
att = SimpleNamespace(
|
||||
|
|
@ -342,7 +342,7 @@ class TestHandleMessageUsesAuthenticatedRead:
|
|||
|
||||
# Patch the DMChannel isinstance check so our fake counts as DM.
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.discord.discord.DMChannel",
|
||||
"plugins.platforms.discord.adapter.discord.DMChannel",
|
||||
_FakeDMChannel,
|
||||
)
|
||||
chan = _FakeDMChannel()
|
||||
|
|
@ -368,7 +368,7 @@ class TestHandleMessageUsesAuthenticatedRead:
|
|||
adapter.handle_message = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_audio_from_bytes",
|
||||
"plugins.platforms.discord.adapter.cache_audio_from_bytes",
|
||||
return_value="/tmp/voice_from_read.ogg",
|
||||
):
|
||||
att = SimpleNamespace(
|
||||
|
|
@ -386,7 +386,7 @@ class TestHandleMessageUsesAuthenticatedRead:
|
|||
name = "dm"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.discord.discord.DMChannel",
|
||||
"plugins.platforms.discord.adapter.discord.DMChannel",
|
||||
_FakeDMChannel,
|
||||
)
|
||||
chan = _FakeDMChannel()
|
||||
|
|
@ -412,7 +412,7 @@ class TestHandleMessageUsesAuthenticatedRead:
|
|||
adapter.handle_message = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_audio_from_bytes",
|
||||
"plugins.platforms.discord.adapter.cache_audio_from_bytes",
|
||||
return_value="/tmp/audio_from_read.ogg",
|
||||
):
|
||||
att = SimpleNamespace(
|
||||
|
|
@ -430,7 +430,7 @@ class TestHandleMessageUsesAuthenticatedRead:
|
|||
name = "dm"
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.discord.discord.DMChannel",
|
||||
"plugins.platforms.discord.adapter.discord.DMChannel",
|
||||
_FakeDMChannel,
|
||||
)
|
||||
chan = _FakeDMChannel()
|
||||
|
|
|
|||
|
|
@ -172,42 +172,49 @@ def test_bot_bypass_does_not_leak_to_other_platforms(monkeypatch):
|
|||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DISCORD_ALLOWED_ROLES gateway-layer bypass (#7871)
|
||||
# DISCORD_ALLOWED_ROLES no longer bypasses the gateway allowlist (#30742)
|
||||
#
|
||||
# Prior behavior: setting DISCORD_ALLOWED_ROLES caused _is_user_authorized
|
||||
# to return True for ANY Discord event, on the assumption that the adapter
|
||||
# pre-filter had already validated role membership. That allowed slash
|
||||
# commands and synthetic voice events to bypass role checks. PR #30742
|
||||
# removed the shortcut — Discord auth now flows through the same allowlist
|
||||
# / pairing / allow-all path as every other platform.
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_discord_role_config_bypasses_gateway_allowlist(monkeypatch):
|
||||
"""When DISCORD_ALLOWED_ROLES is set, _is_user_authorized must trust
|
||||
the adapter's pre-filter and authorize. Without this, role-only setups
|
||||
(DISCORD_ALLOWED_ROLES populated, DISCORD_ALLOWED_USERS empty) would
|
||||
hit the 'no allowlists configured' branch and get rejected.
|
||||
def test_discord_role_config_does_not_bypass_gateway_allowlist(monkeypatch):
|
||||
"""DISCORD_ALLOWED_ROLES alone must NOT authorize at the gateway layer
|
||||
(regression guard for #30742). Role-based access is enforced by the
|
||||
adapter pre-filter on real message events; the gateway layer requires
|
||||
an explicit allowlist hit or pairing approval.
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674")
|
||||
# Note: DISCORD_ALLOWED_USERS is NOT set — the entire point.
|
||||
# DISCORD_ALLOWED_USERS deliberately NOT set — verifies the role
|
||||
# config alone no longer grants authorization.
|
||||
|
||||
source = _make_discord_human_source(user_id="999888777")
|
||||
assert runner._is_user_authorized(source) is True
|
||||
assert runner._is_user_authorized(source) is False
|
||||
|
||||
|
||||
def test_discord_role_config_still_authorizes_alongside_users(monkeypatch):
|
||||
"""Sanity: setting both DISCORD_ALLOWED_ROLES and DISCORD_ALLOWED_USERS
|
||||
doesn't break the user-id path. Users in the allowlist should still be
|
||||
authorized even if they don't have a role. (OR semantics.)
|
||||
def test_discord_user_allowlist_still_authorizes_when_role_is_also_configured(monkeypatch):
|
||||
"""Sanity: DISCORD_ALLOWED_USERS still authorizes users on the list,
|
||||
independent of DISCORD_ALLOWED_ROLES. This guards against a future
|
||||
regression that ties the user-allowlist check to the (now-removed)
|
||||
role bypass.
|
||||
"""
|
||||
runner = _make_bare_runner()
|
||||
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_ROLES", "1493705176387948674")
|
||||
monkeypatch.setenv("DISCORD_ALLOWED_USERS", "100200300")
|
||||
|
||||
# User on the user allowlist, no role → still authorized at gateway
|
||||
# level via the role bypass (adapter already approved them).
|
||||
source = _make_discord_human_source(user_id="100200300")
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_discord_role_bypass_does_not_leak_to_other_platforms(monkeypatch):
|
||||
def test_discord_role_config_does_not_leak_to_other_platforms(monkeypatch):
|
||||
"""DISCORD_ALLOWED_ROLES must only affect Discord. Setting it should
|
||||
not suddenly start authorizing Telegram users whose platform has its
|
||||
own empty allowlist.
|
||||
|
|
|
|||
|
|
@ -45,8 +45,8 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class FakeDMChannel:
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ def _install_fake_agent(monkeypatch):
|
|||
|
||||
def _make_adapter():
|
||||
_ensure_discord_mock()
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.config = MagicMock()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
|
||||
def _make_adapter():
|
||||
"""Create a minimal DiscordAdapter with mocked config."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.config = MagicMock()
|
||||
adapter.config.extra = {}
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ if _repo not in sys.path:
|
|||
|
||||
# Triggers the shared discord mock from tests/gateway/conftest.py before
|
||||
# importing the production module.
|
||||
from gateway.platforms.discord import ( # noqa: E402
|
||||
from plugins.platforms.discord.adapter import ( # noqa: E402
|
||||
ClarifyChoiceView,
|
||||
DiscordAdapter,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import pytest
|
|||
|
||||
# Trigger the shared discord mock from tests/gateway/conftest.py before
|
||||
# importing the production module.
|
||||
from gateway.platforms.discord import ( # noqa: E402
|
||||
from plugins.platforms.discord.adapter import ( # noqa: E402
|
||||
ExecApprovalView,
|
||||
ModelPickerView,
|
||||
SlashConfirmView,
|
||||
|
|
|
|||
|
|
@ -67,8 +67,8 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
|
|||
|
|
@ -57,8 +57,8 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -371,7 +371,7 @@ class TestIncomingDocumentHandling:
|
|||
async def test_image_attachment_unaffected(self, adapter):
|
||||
"""Image attachments should still go through the image path, not the document path."""
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
"plugins.platforms.discord.adapter.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/cached_image.png",
|
||||
):
|
||||
|
|
|
|||
|
|
@ -45,8 +45,8 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class FakeDMChannel:
|
||||
|
|
|
|||
|
|
@ -14,10 +14,13 @@ class TestDiscordImportSafety:
|
|||
raise ImportError("discord unavailable for test")
|
||||
return original_import(name, globals, locals, fromlist, level)
|
||||
|
||||
monkeypatch.delitem(sys.modules, "gateway.platforms.discord", raising=False)
|
||||
# Purge the cached module so the import below actually re-runs the
|
||||
# module body with discord.py simulated-missing.
|
||||
monkeypatch.delitem(sys.modules, "plugins.platforms.discord.adapter", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "plugins.platforms.discord", raising=False)
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
|
||||
module = importlib.import_module("gateway.platforms.discord")
|
||||
module = importlib.import_module("plugins.platforms.discord.adapter")
|
||||
|
||||
assert module.DISCORD_AVAILABLE is False
|
||||
assert module.discord is None
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class TestDefineDiscordViewClasses:
|
|||
|
||||
def test_registers_all_five_view_classes(self, monkeypatch):
|
||||
"""Calling _define_discord_view_classes() must (re)define all 5 view classes."""
|
||||
dp = importlib.import_module("gateway.platforms.discord")
|
||||
dp = importlib.import_module("plugins.platforms.discord.adapter")
|
||||
|
||||
# Remove the classes to simulate the state where the module was loaded
|
||||
# with DISCORD_AVAILABLE=False (the lazy-install scenario).
|
||||
|
|
@ -54,7 +54,7 @@ class TestDefineDiscordViewClasses:
|
|||
def test_check_discord_requirements_calls_define_on_lazy_install(self, monkeypatch):
|
||||
"""check_discord_requirements() must call _define_discord_view_classes() on
|
||||
a successful lazy install so view classes exist when DISCORD_AVAILABLE=True."""
|
||||
dp = importlib.import_module("gateway.platforms.discord")
|
||||
dp = importlib.import_module("plugins.platforms.discord.adapter")
|
||||
|
||||
# Simulate discord not yet available at module load.
|
||||
monkeypatch.setattr(dp, "DISCORD_AVAILABLE", False)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
|
||||
|
||||
def test_discord_media_methods_accept_metadata_kwarg():
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.discord import ModelPickerView
|
||||
from plugins.platforms.discord.adapter import ModelPickerView
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -8,14 +8,14 @@ class TestOpusFindLibrary:
|
|||
|
||||
def test_uses_find_library_first(self):
|
||||
"""find_library must be the primary lookup strategy."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.connect)
|
||||
assert "find_library" in source, \
|
||||
"Opus loading must use ctypes.util.find_library"
|
||||
|
||||
def test_homebrew_fallback_is_conditional(self):
|
||||
"""Homebrew paths must only be tried when find_library returns None."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.connect)
|
||||
# Homebrew fallback must exist
|
||||
assert "/opt/homebrew" in source or "homebrew" in source, \
|
||||
|
|
@ -31,7 +31,7 @@ class TestOpusFindLibrary:
|
|||
|
||||
def test_opus_decode_error_logged(self):
|
||||
"""Opus decode failure must log the error, not silently return."""
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
source = inspect.getsource(VoiceReceiver._on_packet)
|
||||
assert "logger" in source, \
|
||||
"_on_packet must log Opus decode errors"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from gateway.config import Platform, PlatformConfig
|
|||
|
||||
|
||||
def _make_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter._platform = Platform.DISCORD
|
||||
|
|
@ -60,7 +60,7 @@ async def test_concurrent_joins_do_not_double_connect():
|
|||
channel.guild.id = 42
|
||||
channel.connect = lambda: slow_connect(channel)
|
||||
|
||||
from gateway.platforms import discord as discord_mod
|
||||
from plugins.platforms.discord import adapter as discord_mod
|
||||
with patch.object(discord_mod, "VoiceReceiver",
|
||||
MagicMock(return_value=MagicMock(start=lambda: None))):
|
||||
with patch.object(discord_mod.asyncio, "ensure_future",
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class FakeTree:
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
|
||||
|
||||
def _set_dm_role_auth_guild(monkeypatch, guild_id=None):
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class FakeTree:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class TestDiscordThreadPersistence:
|
|||
def _make_adapter(self, tmp_path):
|
||||
"""Build a minimal DiscordAdapter with HERMES_HOME pointed at tmp_path."""
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
|
|
|
|||
|
|
@ -148,6 +148,15 @@ async def test_run_agent_passes_priority_processing_to_gateway_agent(monkeypatch
|
|||
monkeypatch.setattr(gateway_run, "_env_path", tmp_path / ".env")
|
||||
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
|
||||
# ``_load_service_tier`` was refactored to call ``_load_gateway_runtime_config``
|
||||
# (which wraps ``_load_gateway_config`` plus env-expansion). Since the test
|
||||
# stubs ``_load_gateway_config`` to ``{}``, also stub the runtime wrapper
|
||||
# directly so the priority routing assertions still exercise the live tier.
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_load_gateway_runtime_config",
|
||||
lambda: {"agent": {"service_tier": "fast"}},
|
||||
)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_gateway_model", lambda config=None: "gpt-5.4")
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
|
|
|
|||
|
|
@ -167,6 +167,7 @@ class TestFeishuAdapterMessaging(unittest.TestCase):
|
|||
"FEISHU_WEBHOOK_HOST": "127.0.0.1",
|
||||
"FEISHU_WEBHOOK_PORT": "9001",
|
||||
"FEISHU_WEBHOOK_PATH": "/hook",
|
||||
"FEISHU_VERIFICATION_TOKEN": "vtok",
|
||||
}, clear=True)
|
||||
def test_connect_webhook_mode_starts_local_server(self):
|
||||
from gateway.config import PlatformConfig
|
||||
|
|
@ -1538,6 +1539,34 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
self.assertEqual(response.status, 200)
|
||||
adapter._on_message_event.assert_called_once()
|
||||
|
||||
@patch.dict(os.environ, {"FEISHU_VERIFICATION_TOKEN": "expected-token"}, clear=True)
|
||||
def test_url_verification_requires_configured_verification_token(self):
|
||||
"""url_verification must be rejected when token is set but mismatched.
|
||||
|
||||
Regression: previously the challenge was reflected before the token
|
||||
check, so an unauthenticated remote could prove endpoint control by
|
||||
sending an attacker-controlled challenge string.
|
||||
"""
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(PlatformConfig())
|
||||
body = json.dumps({
|
||||
"type": "url_verification",
|
||||
"token": "wrong-token",
|
||||
"challenge": "attacker-controlled-challenge",
|
||||
}).encode("utf-8")
|
||||
request = SimpleNamespace(
|
||||
remote="203.0.113.10",
|
||||
content_length=None,
|
||||
headers={},
|
||||
read=AsyncMock(return_value=body),
|
||||
)
|
||||
|
||||
response = asyncio.run(adapter._handle_webhook_request(request))
|
||||
|
||||
self.assertEqual(response.status, 401)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_process_inbound_message_uses_event_sender_identity_only(self):
|
||||
from gateway.config import PlatformConfig
|
||||
|
|
@ -3191,6 +3220,39 @@ class TestWebhookSecurity(unittest.TestCase):
|
|||
response = asyncio.run(adapter._handle_webhook_request(request))
|
||||
self.assertEqual(response.status, 401)
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_webhook_connect_requires_inbound_auth_secret(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"app_id": "cli_app", "app_secret": "secret_app", "connection_mode": "webhook"},
|
||||
)
|
||||
)
|
||||
self.assertFalse(asyncio.run(adapter.connect()))
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_webhook_loads_auth_secrets_from_platform_extra(self):
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
adapter = FeishuAdapter(
|
||||
PlatformConfig(
|
||||
enabled=True,
|
||||
extra={
|
||||
"app_id": "cli_app",
|
||||
"app_secret": "secret_app",
|
||||
"connection_mode": "webhook",
|
||||
"verification_token": "token_from_extra",
|
||||
"encrypt_key": "encrypt_from_extra",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(adapter._verification_token, "token_from_extra")
|
||||
self.assertEqual(adapter._encrypt_key, "encrypt_from_extra")
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_webhook_url_verification_challenge_passes_without_signature(self):
|
||||
"""Challenge requests must succeed even when no encrypt_key is set."""
|
||||
|
|
|
|||
|
|
@ -320,7 +320,7 @@ class TestResolveApproval:
|
|||
}
|
||||
|
||||
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
|
||||
await adapter._resolve_approval(1, "once", "Norbert")
|
||||
await adapter._resolve_approval(1, "once", "Norbert", open_id="ou_user1", chat_id="oc_12345")
|
||||
|
||||
mock_resolve.assert_called_once_with("agent:main:feishu:group:oc_12345", "once")
|
||||
assert 1 not in adapter._approval_state
|
||||
|
|
@ -335,7 +335,7 @@ class TestResolveApproval:
|
|||
}
|
||||
|
||||
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
|
||||
await adapter._resolve_approval(2, "deny", "Alice")
|
||||
await adapter._resolve_approval(2, "deny", "Alice", open_id="ou_user1", chat_id="oc_12345")
|
||||
|
||||
mock_resolve.assert_called_once_with("some-session", "deny")
|
||||
|
||||
|
|
@ -349,7 +349,7 @@ class TestResolveApproval:
|
|||
}
|
||||
|
||||
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
|
||||
await adapter._resolve_approval(3, "session", "Bob")
|
||||
await adapter._resolve_approval(3, "session", "Bob", open_id="ou_user1", chat_id="oc_99")
|
||||
|
||||
mock_resolve.assert_called_once_with("sess-3", "session")
|
||||
|
||||
|
|
@ -363,7 +363,7 @@ class TestResolveApproval:
|
|||
}
|
||||
|
||||
with patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve:
|
||||
await adapter._resolve_approval(4, "always", "Carol")
|
||||
await adapter._resolve_approval(4, "always", "Carol", open_id="ou_user1", chat_id="oc_55")
|
||||
|
||||
mock_resolve.assert_called_once_with("sess-4", "always")
|
||||
|
||||
|
|
@ -372,10 +372,41 @@ class TestResolveApproval:
|
|||
adapter = _make_adapter()
|
||||
|
||||
with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
|
||||
await adapter._resolve_approval(99, "once", "Nobody")
|
||||
await adapter._resolve_approval(99, "once", "Nobody", open_id="ou_user1", chat_id="oc_12345")
|
||||
|
||||
mock_resolve.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_click_does_not_resolve(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._admins = {"ou_admin"}
|
||||
adapter._approval_state[5] = {
|
||||
"session_key": "sess-5",
|
||||
"message_id": "msg_005",
|
||||
"chat_id": "oc_12345",
|
||||
}
|
||||
|
||||
with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
|
||||
await adapter._resolve_approval(5, "once", "Mallory", open_id="ou_intruder", chat_id="oc_12345")
|
||||
|
||||
mock_resolve.assert_not_called()
|
||||
assert 5 in adapter._approval_state
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_mismatch_does_not_resolve(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._approval_state[6] = {
|
||||
"session_key": "sess-6",
|
||||
"message_id": "msg_006",
|
||||
"chat_id": "oc_expected",
|
||||
}
|
||||
|
||||
with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
|
||||
await adapter._resolve_approval(6, "session", "Norbert", open_id="ou_user1", chat_id="oc_wrong")
|
||||
|
||||
mock_resolve.assert_not_called()
|
||||
assert 6 in adapter._approval_state
|
||||
|
||||
# ===========================================================================
|
||||
# _handle_card_action_event — non-approval card actions
|
||||
# ===========================================================================
|
||||
|
|
@ -448,6 +479,12 @@ class TestCardActionCallbackResponse:
|
|||
adapter = _make_adapter()
|
||||
adapter._loop = MagicMock()
|
||||
adapter._loop.is_closed = MagicMock(return_value=False)
|
||||
adapter._allowed_group_users = {"ou_bob"}
|
||||
adapter._approval_state[1] = {
|
||||
"session_key": "sess-1",
|
||||
"message_id": "msg-1",
|
||||
"chat_id": "oc_12345",
|
||||
}
|
||||
data = _make_card_action_data(
|
||||
{"hermes_action": "approve_once", "approval_id": 1},
|
||||
open_id="ou_bob",
|
||||
|
|
@ -469,6 +506,12 @@ class TestCardActionCallbackResponse:
|
|||
adapter = _make_adapter()
|
||||
adapter._loop = MagicMock()
|
||||
adapter._loop.is_closed = MagicMock(return_value=False)
|
||||
adapter._allowed_group_users = {"ou_user1"}
|
||||
adapter._approval_state[2] = {
|
||||
"session_key": "sess-2",
|
||||
"message_id": "msg-2",
|
||||
"chat_id": "oc_12345",
|
||||
}
|
||||
data = _make_card_action_data(
|
||||
{"hermes_action": "deny", "approval_id": 2},
|
||||
)
|
||||
|
|
@ -510,6 +553,12 @@ class TestCardActionCallbackResponse:
|
|||
adapter = _make_adapter()
|
||||
adapter._loop = MagicMock()
|
||||
adapter._loop.is_closed = MagicMock(return_value=False)
|
||||
adapter._allowed_group_users = {"ou_unknown"}
|
||||
adapter._approval_state[3] = {
|
||||
"session_key": "sess-3",
|
||||
"message_id": "msg-3",
|
||||
"chat_id": "oc_12345",
|
||||
}
|
||||
data = _make_card_action_data(
|
||||
{"hermes_action": "approve_session", "approval_id": 3},
|
||||
open_id="ou_unknown",
|
||||
|
|
@ -525,6 +574,12 @@ class TestCardActionCallbackResponse:
|
|||
adapter = _make_adapter()
|
||||
adapter._loop = MagicMock()
|
||||
adapter._loop.is_closed = MagicMock(return_value=False)
|
||||
adapter._allowed_group_users = {"ou_expired"}
|
||||
adapter._approval_state[4] = {
|
||||
"session_key": "sess-4",
|
||||
"message_id": "msg-4",
|
||||
"chat_id": "oc_12345",
|
||||
}
|
||||
data = _make_card_action_data(
|
||||
{"hermes_action": "approve_once", "approval_id": 4},
|
||||
open_id="ou_expired",
|
||||
|
|
@ -538,6 +593,51 @@ class TestCardActionCallbackResponse:
|
|||
assert "Old Name" not in card["elements"][0]["content"]
|
||||
assert "ou_expired" in card["elements"][0]["content"]
|
||||
|
||||
def test_rejects_approval_click_from_unauthorized_user(self, _patch_callback_card_types):
|
||||
adapter = _make_adapter()
|
||||
adapter._loop = MagicMock()
|
||||
adapter._loop.is_closed = MagicMock(return_value=False)
|
||||
adapter._allowed_group_users = {"ou_allowed"}
|
||||
adapter._approval_state[5] = {
|
||||
"session_key": "sess-5",
|
||||
"message_id": "msg-5",
|
||||
"chat_id": "oc_12345",
|
||||
}
|
||||
data = _make_card_action_data(
|
||||
{"hermes_action": "approve_once", "approval_id": 5},
|
||||
open_id="ou_attacker",
|
||||
)
|
||||
|
||||
with patch("asyncio.run_coroutine_threadsafe") as mock_submit:
|
||||
response = adapter._on_card_action_trigger(data)
|
||||
|
||||
assert response is not None
|
||||
assert response.card is None
|
||||
mock_submit.assert_not_called()
|
||||
|
||||
def test_rejects_approval_click_when_callback_chat_mismatches(self, _patch_callback_card_types):
|
||||
adapter = _make_adapter()
|
||||
adapter._loop = MagicMock()
|
||||
adapter._loop.is_closed = MagicMock(return_value=False)
|
||||
adapter._allowed_group_users = {"ou_bob"}
|
||||
adapter._approval_state[6] = {
|
||||
"session_key": "sess-6",
|
||||
"message_id": "msg-6",
|
||||
"chat_id": "oc_expected",
|
||||
}
|
||||
data = _make_card_action_data(
|
||||
{"hermes_action": "approve_once", "approval_id": 6},
|
||||
chat_id="oc_mismatch",
|
||||
open_id="ou_bob",
|
||||
)
|
||||
|
||||
with patch("asyncio.run_coroutine_threadsafe") as mock_submit:
|
||||
response = adapter._on_card_action_trigger(data)
|
||||
|
||||
assert response is not None
|
||||
assert response.card is None
|
||||
mock_submit.assert_not_called()
|
||||
|
||||
def test_returns_card_for_update_prompt_yes(self, _patch_callback_card_types):
|
||||
adapter = _make_adapter()
|
||||
adapter._loop = MagicMock()
|
||||
|
|
|
|||
|
|
@ -103,6 +103,7 @@ class TestInterruptKeyConsistency:
|
|||
async def test_handle_message_stores_under_session_key(self):
|
||||
"""handle_message stores pending messages under session_key, not chat_id."""
|
||||
adapter = StubAdapter()
|
||||
adapter._busy_text_mode = ""
|
||||
adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
|
||||
|
||||
source = _source("-1001234", "group")
|
||||
|
|
@ -120,8 +121,8 @@ class TestInterruptKeyConsistency:
|
|||
# NOT stored under chat_id
|
||||
assert source.chat_id not in adapter._pending_messages
|
||||
|
||||
# Interrupt event was set
|
||||
assert adapter._active_sessions[session_key].is_set()
|
||||
# Text follow-ups queue silently and do not interrupt the active turn.
|
||||
assert adapter._active_sessions[session_key].is_set() is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_photo_followup_is_queued_without_interrupt(self):
|
||||
|
|
|
|||
210
tests/gateway/test_loop_exception_handler.py
Normal file
210
tests/gateway/test_loop_exception_handler.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Tests for the gateway loop-level transient-network-error safety net.
|
||||
|
||||
Issues #31066 / #31110: unhandled ``telegram.error.TimedOut`` (or peer
|
||||
``NetworkError`` / ``httpx`` connection error) propagating to the
|
||||
asyncio event loop killed the gateway process, taking down every
|
||||
profile attached to the same runner. The safety net installed in
|
||||
:func:`gateway.run.start_gateway` catches the transient crash class
|
||||
and logs+swallows it; non-transient errors still surface.
|
||||
|
||||
These tests pin the classifier and the loop handler so the safety net
|
||||
can't silently regress to swallowing every exception.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.run import (
|
||||
_gateway_loop_exception_handler,
|
||||
_is_transient_network_error,
|
||||
)
|
||||
|
||||
|
||||
# ----- Fake exception classes that mimic the real wire types ----------
|
||||
# We avoid importing telegram / httpx here so the test runs in environments
|
||||
# without those packages installed (the classifier matches on class name).
|
||||
|
||||
class TimedOut(Exception):
|
||||
"""Stand-in for ``telegram.error.TimedOut``."""
|
||||
|
||||
|
||||
class NetworkError(Exception):
|
||||
"""Stand-in for ``telegram.error.NetworkError``."""
|
||||
|
||||
|
||||
class ConnectError(Exception):
|
||||
"""Stand-in for ``httpx.ConnectError``."""
|
||||
|
||||
|
||||
class ReadTimeout(Exception):
|
||||
"""Stand-in for ``httpx.ReadTimeout``."""
|
||||
|
||||
|
||||
class PoolTimeout(Exception):
|
||||
"""Stand-in for ``httpx.PoolTimeout``."""
|
||||
|
||||
|
||||
class ClientConnectorError(Exception):
|
||||
"""Stand-in for ``aiohttp.ClientConnectorError``."""
|
||||
|
||||
|
||||
class SomeUnrelatedBug(Exception):
|
||||
"""A non-transient error that should NOT be swallowed."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Classifier
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exc_cls",
|
||||
[
|
||||
TimedOut,
|
||||
NetworkError,
|
||||
ConnectError,
|
||||
ReadTimeout,
|
||||
PoolTimeout,
|
||||
ClientConnectorError,
|
||||
],
|
||||
)
|
||||
def test_transient_classifier_matches_known_network_errors(exc_cls):
|
||||
"""Every well-known transient network exception class is classified."""
|
||||
assert _is_transient_network_error(exc_cls("boom")) is True
|
||||
|
||||
|
||||
def test_transient_classifier_rejects_unrelated_errors():
|
||||
"""Real bugs (ValueError, KeyError, custom app errors) are NOT swallowed."""
|
||||
for exc in (ValueError("bad"), KeyError("missing"), SomeUnrelatedBug("x")):
|
||||
assert _is_transient_network_error(exc) is False
|
||||
|
||||
|
||||
def test_transient_classifier_unwraps_cause_chain():
|
||||
"""A NetworkError wrapping a ConnectError is still classified."""
|
||||
inner = ConnectError("connection refused")
|
||||
outer = NetworkError("upstream failed")
|
||||
outer.__cause__ = inner
|
||||
assert _is_transient_network_error(outer) is True
|
||||
|
||||
|
||||
def test_transient_classifier_unwraps_context_chain():
|
||||
"""Implicit ``__context__`` wrapping is also unwrapped."""
|
||||
try:
|
||||
try:
|
||||
raise TimedOut("upstream timeout")
|
||||
except TimedOut:
|
||||
# Re-raise something else with the original as implicit context
|
||||
raise SomeUnrelatedBug("wrapper")
|
||||
except SomeUnrelatedBug as e:
|
||||
wrapped = e
|
||||
# The wrapper class name is not transient, but the chained context is.
|
||||
assert _is_transient_network_error(wrapped) is True
|
||||
|
||||
|
||||
def test_transient_classifier_does_not_infinite_loop_on_cyclic_cause():
|
||||
"""A pathological self-referential cause chain terminates."""
|
||||
exc = SomeUnrelatedBug("loop")
|
||||
exc.__cause__ = exc # cycle
|
||||
# Must return without hanging.
|
||||
assert _is_transient_network_error(exc) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Loop handler
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_handler_swallows_transient_error_and_logs_warning(caplog):
|
||||
"""Transient errors are logged at WARNING but not re-raised."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.run"):
|
||||
_gateway_loop_exception_handler(
|
||||
loop,
|
||||
{
|
||||
"message": "Task exception was never retrieved",
|
||||
"exception": TimedOut("Timed out"),
|
||||
},
|
||||
)
|
||||
# Warning emitted, exception class name appears in the log.
|
||||
assert any("TimedOut" in r.message for r in caplog.records)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def test_handler_delegates_unknown_errors_to_default(monkeypatch):
|
||||
"""A non-transient error is forwarded to ``loop.default_exception_handler``."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
forwarded: list[dict] = []
|
||||
|
||||
def fake_default(ctx):
|
||||
forwarded.append(ctx)
|
||||
|
||||
monkeypatch.setattr(loop, "default_exception_handler", fake_default)
|
||||
|
||||
context = {
|
||||
"message": "Something else broke",
|
||||
"exception": SomeUnrelatedBug("real bug"),
|
||||
}
|
||||
_gateway_loop_exception_handler(loop, context)
|
||||
assert forwarded == [context]
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
def test_handler_tolerates_missing_exception_key(monkeypatch):
|
||||
"""Contexts without an ``exception`` key fall through to the default handler."""
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
forwarded: list[dict] = []
|
||||
monkeypatch.setattr(
|
||||
loop, "default_exception_handler", lambda ctx: forwarded.append(ctx)
|
||||
)
|
||||
ctx = {"message": "warning without exception"}
|
||||
_gateway_loop_exception_handler(loop, ctx)
|
||||
assert forwarded == [ctx]
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# End-to-end: task-level
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_unhandled_transient_error_in_task_does_not_propagate_to_loop():
|
||||
"""Smoke test the wiring as a loop would actually use it.
|
||||
|
||||
Schedules a task that raises TimedOut and is never awaited. With the
|
||||
handler installed, the loop completes normally and logs a warning
|
||||
instead of dying. Without the handler, asyncio would emit
|
||||
``Task exception was never retrieved`` and (depending on Python's
|
||||
debug mode) potentially escalate.
|
||||
"""
|
||||
|
||||
async def raiser():
|
||||
raise TimedOut("upstream timeout")
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.set_exception_handler(_gateway_loop_exception_handler)
|
||||
task = loop.create_task(raiser())
|
||||
# Give the task a tick to run and raise.
|
||||
await asyncio.sleep(0)
|
||||
# Don't await ``task`` — let it become an unhandled-exception task.
|
||||
del task
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# If the safety net works, this returns cleanly. If not, the test
|
||||
# would still pass (asyncio's default is a warning, not a crash) —
|
||||
# the real assertion is that no unhandled exception escapes the
|
||||
# ``run`` boundary.
|
||||
asyncio.run(main())
|
||||
|
|
@ -797,6 +797,79 @@ class TestMatrixRequirements:
|
|||
with patch("tools.lazy_deps.ensure", side_effect=ImportError("mautrix unavailable")):
|
||||
assert matrix_mod.check_matrix_requirements() is False
|
||||
|
||||
def test_check_e2ee_deps_requires_asyncpg(self, monkeypatch):
|
||||
"""E2EE deps check must reject when asyncpg is missing — even if olm is present.
|
||||
|
||||
Regression for #31116: ``mautrix[encryption]`` extra installs python-olm
|
||||
but NOT asyncpg/aiosqlite, which are required by mautrix's crypto store
|
||||
at connect time. ``_check_e2ee_deps`` previously only tested
|
||||
``OlmMachine`` import and returned True, so the failure manifested as
|
||||
a confusing ``No module named 'asyncpg'`` deep in
|
||||
``MatrixAdapter.connect()``.
|
||||
"""
|
||||
from gateway.platforms.matrix import _check_e2ee_deps
|
||||
import builtins
|
||||
real_import = builtins.__import__
|
||||
|
||||
def _blocking_import(name, *args, **kwargs):
|
||||
if name == "asyncpg" or name.startswith("asyncpg."):
|
||||
raise ImportError("blocked for test")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
with patch.object(builtins, "__import__", _blocking_import):
|
||||
assert _check_e2ee_deps() is False
|
||||
|
||||
def test_check_e2ee_deps_requires_aiosqlite(self):
|
||||
"""E2EE deps check must reject when aiosqlite is missing.
|
||||
|
||||
Mautrix's ``Database.create("sqlite:///...")`` driver lookup imports
|
||||
aiosqlite lazily — without it, connect fails at ``crypto_db.start()``.
|
||||
"""
|
||||
from gateway.platforms.matrix import _check_e2ee_deps
|
||||
import builtins
|
||||
real_import = builtins.__import__
|
||||
|
||||
def _blocking_import(name, *args, **kwargs):
|
||||
if name == "aiosqlite" or name.startswith("aiosqlite."):
|
||||
raise ImportError("blocked for test")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
with patch.object(builtins, "__import__", _blocking_import):
|
||||
assert _check_e2ee_deps() is False
|
||||
|
||||
def test_check_requirements_runs_lazy_install_when_partial(self, monkeypatch):
|
||||
"""When mautrix is installed but asyncpg/aiosqlite are missing,
|
||||
check_matrix_requirements must still run the lazy installer.
|
||||
|
||||
Regression for #31116: the previous ``try: import mautrix`` gate
|
||||
short-circuited the install of the OTHER 4 platform.matrix packages,
|
||||
so a partial install (mautrix only) was treated as fully installed.
|
||||
"""
|
||||
monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test")
|
||||
monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org")
|
||||
monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False)
|
||||
|
||||
from gateway.platforms import matrix as matrix_mod
|
||||
|
||||
# Simulate "mautrix installed, asyncpg missing" → feature_missing
|
||||
# returns a non-empty tuple → ensure_and_bind MUST be called.
|
||||
called = {"ensure_and_bind": False}
|
||||
|
||||
def _fake_ensure_and_bind(feature, importer, target_globals, **kwargs):
|
||||
called["ensure_and_bind"] = True
|
||||
assert feature == "platform.matrix"
|
||||
return True # Pretend install succeeded.
|
||||
|
||||
with patch("tools.lazy_deps.feature_missing", return_value=("asyncpg==0.31.0",)), \
|
||||
patch("tools.lazy_deps.ensure_and_bind", side_effect=_fake_ensure_and_bind):
|
||||
matrix_mod.check_matrix_requirements()
|
||||
|
||||
assert called["ensure_and_bind"], (
|
||||
"check_matrix_requirements must call ensure_and_bind whenever ANY "
|
||||
"platform.matrix dep is missing, not just when mautrix itself is "
|
||||
"missing (#31116)"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access-token auth / E2EE bootstrap
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import json
|
|||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
|
||||
from gateway.platforms.msgraph_webhook import MSGraphWebhookAdapter
|
||||
from gateway.platforms.msgraph_webhook import AIOHTTP_AVAILABLE, MSGraphWebhookAdapter
|
||||
|
||||
|
||||
def _make_adapter(**extra_overrides) -> MSGraphWebhookAdapter:
|
||||
|
|
@ -70,6 +70,16 @@ class TestMSGraphWebhookConfig:
|
|||
|
||||
|
||||
class TestMSGraphValidationHandshake:
|
||||
@pytest.mark.anyio
|
||||
async def test_connect_requires_client_state(self):
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
pytest.skip("aiohttp not installed")
|
||||
adapter = MSGraphWebhookAdapter(PlatformConfig(enabled=True, extra={}))
|
||||
connected = await adapter.connect()
|
||||
assert connected is False
|
||||
# is_connected is a @property on the base adapter, not a method.
|
||||
assert adapter.is_connected is False
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_validation_token_echo_on_get(self):
|
||||
adapter = _make_adapter()
|
||||
|
|
@ -99,6 +109,22 @@ class TestMSGraphValidationHandshake:
|
|||
|
||||
|
||||
class TestMSGraphNotifications:
|
||||
@pytest.mark.anyio
|
||||
async def test_missing_client_state_is_auth_rejected(self):
|
||||
adapter = _make_adapter(client_state=None)
|
||||
payload = {
|
||||
"value": [
|
||||
{
|
||||
"id": "notif-no-client-state",
|
||||
"subscriptionId": "sub-1",
|
||||
"changeType": "updated",
|
||||
"resource": "communications/onlineMeetings/meeting-1",
|
||||
}
|
||||
]
|
||||
}
|
||||
resp = await adapter._handle_notification(_FakeRequest(json_payload=payload))
|
||||
assert resp.status == 403
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_valid_notification_accepted_and_scheduled(self):
|
||||
adapter = _make_adapter()
|
||||
|
|
|
|||
943
tests/gateway/test_ntfy_plugin.py
Normal file
943
tests/gateway/test_ntfy_plugin.py
Normal file
|
|
@ -0,0 +1,943 @@
|
|||
"""Tests for the ntfy platform-plugin adapter.
|
||||
|
||||
Loaded via the ``_plugin_adapter_loader`` helper so this lives under
|
||||
``plugin_adapter_ntfy`` in ``sys.modules`` and cannot collide with
|
||||
sibling platform-plugin tests on the same xdist worker.
|
||||
|
||||
Most tests target the adapter class directly. The plugin-shape tests
|
||||
(``register()``, ``_env_enablement``, ``_standalone_send``, registry
|
||||
presence) replace the core-file grep tests from the original PR — the
|
||||
ntfy adapter no longer modifies ``gateway/config.py``, ``gateway/run.py``,
|
||||
``cron/scheduler.py``, ``toolsets.py``, etc. Everything routes through
|
||||
the ``platform_registry``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from tests.gateway._plugin_adapter_loader import load_plugin_adapter
|
||||
|
||||
_ntfy = load_plugin_adapter("ntfy")
|
||||
|
||||
NtfyAdapter = _ntfy.NtfyAdapter
|
||||
check_requirements = _ntfy.check_requirements
|
||||
validate_config = _ntfy.validate_config
|
||||
is_connected = _ntfy.is_connected
|
||||
register = _ntfy.register
|
||||
_env_enablement = _ntfy._env_enablement
|
||||
_standalone_send = _ntfy._standalone_send
|
||||
DEFAULT_SERVER = _ntfy.DEFAULT_SERVER
|
||||
DEDUP_WINDOW_SECONDS = _ntfy.DEDUP_WINDOW_SECONDS
|
||||
DEDUP_MAX_SIZE = _ntfy.DEDUP_MAX_SIZE
|
||||
MAX_MESSAGE_LENGTH = _ntfy.MAX_MESSAGE_LENGTH
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Platform enum (plugin-discovered, not bundled)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_platform_enum_resolves_via_plugin_scan():
|
||||
"""The plugin filesystem scan should expose Platform("ntfy")."""
|
||||
from gateway.config import Platform
|
||||
p = Platform("ntfy")
|
||||
assert p.value == "ntfy"
|
||||
# Identity stability — repeated lookups return the same pseudo-member
|
||||
assert Platform("ntfy") is p
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. check_requirements / validate_config / is_connected
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNtfyRequirements:
|
||||
|
||||
def test_returns_false_when_httpx_unavailable(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-test")
|
||||
monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", False)
|
||||
assert check_requirements() is False
|
||||
|
||||
def test_returns_false_when_topic_not_set(self, monkeypatch):
|
||||
monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True)
|
||||
monkeypatch.delenv("NTFY_TOPIC", raising=False)
|
||||
assert check_requirements() is False
|
||||
|
||||
def test_returns_true_when_topic_set_via_env(self, monkeypatch):
|
||||
monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True)
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-test")
|
||||
assert check_requirements() is True
|
||||
|
||||
def test_validate_config_requires_topic(self, monkeypatch):
|
||||
monkeypatch.delenv("NTFY_TOPIC", raising=False)
|
||||
assert validate_config(PlatformConfig(enabled=True, extra={})) is False
|
||||
assert validate_config(
|
||||
PlatformConfig(enabled=True, extra={"topic": "t"})
|
||||
) is True
|
||||
|
||||
def test_is_connected_from_extra(self, monkeypatch):
|
||||
monkeypatch.delenv("NTFY_TOPIC", raising=False)
|
||||
assert is_connected(PlatformConfig(enabled=True, extra={"topic": "t"})) is True
|
||||
assert is_connected(PlatformConfig(enabled=True, extra={})) is False
|
||||
|
||||
def test_is_connected_from_env(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "env-topic")
|
||||
assert is_connected(PlatformConfig(enabled=True, extra={})) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Adapter init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNtfyAdapterInit:
|
||||
|
||||
def test_default_server_url(self, monkeypatch):
|
||||
monkeypatch.delenv("NTFY_SERVER_URL", raising=False)
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "hermes-in"})
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._server == DEFAULT_SERVER.rstrip("/")
|
||||
|
||||
def test_topic_read_from_extra(self):
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "my-topic"})
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._topic == "my-topic"
|
||||
|
||||
def test_topic_read_from_env(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "env-topic")
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._topic == "env-topic"
|
||||
|
||||
def test_publish_topic_falls_back_to_topic(self, monkeypatch):
|
||||
monkeypatch.delenv("NTFY_PUBLISH_TOPIC", raising=False)
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "hermes-in"})
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._publish_topic == "hermes-in"
|
||||
|
||||
def test_publish_topic_uses_extra_value(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"topic": "hermes-in", "publish_topic": "hermes-out"},
|
||||
)
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._publish_topic == "hermes-out"
|
||||
|
||||
def test_token_read_from_extra(self):
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "t", "token": "tok-123"})
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._token == "tok-123"
|
||||
|
||||
def test_token_read_from_env(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOKEN", "env-token")
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "t"})
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._token == "env-token"
|
||||
|
||||
def test_server_trailing_slash_stripped(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"topic": "t", "server": "https://ntfy.example.com/"},
|
||||
)
|
||||
adapter = NtfyAdapter(config)
|
||||
assert not adapter._server.endswith("/")
|
||||
|
||||
def test_initial_state(self):
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "t"})
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._stream_task is None
|
||||
assert adapter._http_client is None
|
||||
assert adapter._seen_messages == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Auth headers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthHeaders:
|
||||
|
||||
def _make_adapter(self, token=""):
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "t", "token": token})
|
||||
return NtfyAdapter(config)
|
||||
|
||||
def test_no_token_returns_empty_dict(self):
|
||||
adapter = self._make_adapter(token="")
|
||||
assert adapter._auth_headers() == {}
|
||||
|
||||
def test_bearer_token_for_plain_token(self):
|
||||
adapter = self._make_adapter(token="myapitoken")
|
||||
headers = adapter._auth_headers()
|
||||
assert headers["Authorization"] == "Bearer myapitoken"
|
||||
|
||||
def test_basic_auth_for_user_colon_password(self):
|
||||
adapter = self._make_adapter(token="user:pass")
|
||||
headers = adapter._auth_headers()
|
||||
assert headers["Authorization"].startswith("Basic ")
|
||||
import base64
|
||||
expected = "Basic " + base64.b64encode(b"user:pass").decode()
|
||||
assert headers["Authorization"] == expected
|
||||
|
||||
def test_bearer_token_used_when_no_colon(self):
|
||||
adapter = self._make_adapter(token="noColonHere")
|
||||
headers = adapter._auth_headers()
|
||||
assert headers["Authorization"] == "Bearer noColonHere"
|
||||
|
||||
def test_auth_header_key_is_authorization(self):
|
||||
adapter = self._make_adapter(token="tok")
|
||||
headers = adapter._auth_headers()
|
||||
assert list(headers.keys()) == ["Authorization"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Deduplication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeduplication:
|
||||
|
||||
def _make_adapter(self):
|
||||
return NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
|
||||
|
||||
def test_first_message_not_duplicate(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter._is_duplicate("msg-1") is False
|
||||
|
||||
def test_second_occurrence_is_duplicate(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-1") is True
|
||||
|
||||
def test_different_ids_not_duplicate(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._is_duplicate("msg-1")
|
||||
assert adapter._is_duplicate("msg-2") is False
|
||||
|
||||
def test_many_messages_recorded(self):
|
||||
adapter = self._make_adapter()
|
||||
for i in range(50):
|
||||
adapter._is_duplicate(f"msg-{i}")
|
||||
assert len(adapter._seen_messages) == 50
|
||||
|
||||
def test_cache_pruned_on_overflow(self):
|
||||
adapter = self._make_adapter()
|
||||
for i in range(DEDUP_MAX_SIZE + 20):
|
||||
adapter._is_duplicate(f"msg-{i}")
|
||||
assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 20
|
||||
|
||||
def test_expired_id_can_be_seen_again(self):
|
||||
import time
|
||||
adapter = self._make_adapter()
|
||||
adapter._seen_messages["old-msg"] = time.time() - DEDUP_WINDOW_SECONDS - 1
|
||||
for i in range(DEDUP_MAX_SIZE + 1):
|
||||
adapter._is_duplicate(f"fill-{i}")
|
||||
assert adapter._is_duplicate("old-msg") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. connect() / disconnect()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConnect:
|
||||
|
||||
def test_connect_fails_when_httpx_unavailable(self, monkeypatch):
|
||||
monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", False)
|
||||
adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
|
||||
result = _run(adapter.connect())
|
||||
assert result is False
|
||||
|
||||
def test_connect_fails_when_no_topic(self, monkeypatch):
|
||||
monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True)
|
||||
monkeypatch.delenv("NTFY_TOPIC", raising=False)
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
adapter = NtfyAdapter(config)
|
||||
result = _run(adapter.connect())
|
||||
assert result is False
|
||||
|
||||
def test_connect_starts_stream_task(self, monkeypatch):
|
||||
monkeypatch.setattr(_ntfy, "HTTPX_AVAILABLE", True)
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "hermes-test"})
|
||||
adapter = NtfyAdapter(config)
|
||||
|
||||
with patch.object(adapter, "_run_stream", new_callable=AsyncMock):
|
||||
with patch.object(_ntfy, "httpx") as mock_httpx:
|
||||
mock_httpx.AsyncClient.return_value = MagicMock()
|
||||
result = _run(adapter.connect())
|
||||
|
||||
assert result is True
|
||||
assert adapter._stream_task is not None
|
||||
adapter._stream_task.cancel()
|
||||
try:
|
||||
_run(adapter._stream_task)
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
def test_disconnect_clears_state(self):
|
||||
adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
|
||||
adapter._seen_messages["x"] = 1.0
|
||||
adapter._http_client = AsyncMock()
|
||||
adapter._stream_task = None
|
||||
adapter._running = True
|
||||
|
||||
_run(adapter.disconnect())
|
||||
|
||||
assert adapter._seen_messages == {}
|
||||
assert adapter._http_client is None
|
||||
assert adapter._running is False
|
||||
|
||||
def test_disconnect_cancels_stream_task(self):
|
||||
adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
|
||||
|
||||
async def _hang():
|
||||
await asyncio.sleep(9999)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
adapter._stream_task = loop.create_task(_hang())
|
||||
adapter._http_client = AsyncMock()
|
||||
adapter._running = True
|
||||
|
||||
_run(adapter.disconnect())
|
||||
assert adapter._stream_task is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. send()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSend:
|
||||
|
||||
def _make_adapter(self, topic="hermes-in", publish_topic="", token="", markdown=False):
|
||||
extra: dict = {"topic": topic, "token": token}
|
||||
if publish_topic:
|
||||
extra["publish_topic"] = publish_topic
|
||||
if markdown:
|
||||
extra["markdown"] = True
|
||||
return NtfyAdapter(PlatformConfig(enabled=True, extra=extra))
|
||||
|
||||
def test_send_fails_without_http_client(self):
|
||||
adapter = self._make_adapter()
|
||||
result = _run(adapter.send("hermes-in", "hello"))
|
||||
assert result.success is False
|
||||
assert "not initialized" in result.error.lower()
|
||||
|
||||
def test_send_posts_to_publish_topic(self):
|
||||
adapter = self._make_adapter(topic="hermes-in", publish_topic="hermes-out")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"id": "abc123"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
result = _run(adapter.send("hermes-in", "Hello ntfy!"))
|
||||
assert result.success is True
|
||||
assert result.message_id == "abc123"
|
||||
|
||||
posted_url = mock_client.post.call_args[0][0]
|
||||
assert posted_url.endswith("/hermes-out")
|
||||
|
||||
def test_send_falls_back_to_subscribe_topic(self):
|
||||
adapter = self._make_adapter(topic="hermes-in")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
result = _run(adapter.send("hermes-in", "Hello!"))
|
||||
assert result.success is True
|
||||
posted_url = mock_client.post.call_args[0][0]
|
||||
assert posted_url.endswith("/hermes-in")
|
||||
|
||||
def test_send_uses_metadata_publish_topic(self):
|
||||
adapter = self._make_adapter(topic="hermes-in")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
result = _run(adapter.send(
|
||||
"hermes-in", "Hi!", metadata={"publish_topic": "override-out"}
|
||||
))
|
||||
assert result.success is True
|
||||
posted_url = mock_client.post.call_args[0][0]
|
||||
assert posted_url.endswith("/override-out")
|
||||
|
||||
def test_send_handles_http_error_status(self):
|
||||
adapter = self._make_adapter(topic="hermes-in")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 403
|
||||
mock_resp.text = "Forbidden"
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
result = _run(adapter.send("hermes-in", "Hello!"))
|
||||
assert result.success is False
|
||||
assert "403" in result.error
|
||||
|
||||
def test_send_handles_timeout(self):
|
||||
adapter = self._make_adapter(topic="hermes-in")
|
||||
|
||||
class _FakeTimeout(Exception):
|
||||
pass
|
||||
|
||||
fake_httpx = MagicMock()
|
||||
fake_httpx.TimeoutException = _FakeTimeout
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(side_effect=_FakeTimeout("timed out"))
|
||||
adapter._http_client = mock_client
|
||||
|
||||
with patch.object(_ntfy, "httpx", fake_httpx):
|
||||
result = _run(adapter.send("hermes-in", "Hello!"))
|
||||
|
||||
assert result.success is False
|
||||
assert "timeout" in result.error.lower()
|
||||
|
||||
def test_send_truncates_to_max_length(self):
|
||||
adapter = self._make_adapter(topic="t")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
long_msg = "x" * (MAX_MESSAGE_LENGTH + 500)
|
||||
_run(adapter.send("t", long_msg))
|
||||
|
||||
posted_body = mock_client.post.call_args[1]["content"]
|
||||
assert len(posted_body.decode()) <= MAX_MESSAGE_LENGTH
|
||||
|
||||
def test_send_typing_is_noop(self):
|
||||
adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
|
||||
_run(adapter.send_typing("t")) # must not raise
|
||||
|
||||
def test_get_chat_info_returns_dict(self):
|
||||
adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
|
||||
info = _run(adapter.get_chat_info("hermes-in"))
|
||||
assert info["name"] == "hermes-in"
|
||||
assert info["type"] == "dm"
|
||||
|
||||
def test_send_includes_bearer_auth_header(self):
|
||||
adapter = self._make_adapter(topic="hermes-in", token="mytoken")
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
_run(adapter.send("hermes-in", "secure message"))
|
||||
|
||||
call_headers = mock_client.post.call_args[1]["headers"]
|
||||
assert call_headers.get("Authorization") == "Bearer mytoken"
|
||||
|
||||
def test_send_emits_markdown_header_when_enabled(self):
|
||||
adapter = self._make_adapter(topic="hermes-in", markdown=True)
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
_run(adapter.send("hermes-in", "**bold**"))
|
||||
call_headers = mock_client.post.call_args[1]["headers"]
|
||||
assert call_headers.get("X-Markdown") == "true"
|
||||
|
||||
def test_send_omits_markdown_header_when_disabled(self):
|
||||
adapter = self._make_adapter(topic="hermes-in", markdown=False)
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._http_client = mock_client
|
||||
|
||||
_run(adapter.send("hermes-in", "plain"))
|
||||
call_headers = mock_client.post.call_args[1]["headers"]
|
||||
assert "X-Markdown" not in call_headers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Inbound message processing (identity invariant — security-critical)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOnMessage:
|
||||
|
||||
def _make_adapter(self):
|
||||
return NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "hermes-in"}))
|
||||
|
||||
def test_message_dispatched_to_handler(self):
|
||||
adapter = self._make_adapter()
|
||||
calls = []
|
||||
|
||||
async def handler(event):
|
||||
calls.append(event)
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
|
||||
event = {
|
||||
"id": "evt-001",
|
||||
"event": "message",
|
||||
"topic": "hermes-in",
|
||||
"message": "Hello from ntfy",
|
||||
"time": 1700000000,
|
||||
}
|
||||
_run(adapter._on_message(event))
|
||||
assert len(calls) == 1
|
||||
assert calls[0].text == "Hello from ntfy"
|
||||
|
||||
def test_empty_message_skipped(self):
|
||||
adapter = self._make_adapter()
|
||||
calls = []
|
||||
|
||||
async def handler(event):
|
||||
calls.append(event)
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
_run(adapter._on_message({
|
||||
"id": "x", "event": "message", "topic": "t", "message": "", "time": None
|
||||
}))
|
||||
assert calls == []
|
||||
|
||||
def test_duplicate_message_skipped(self):
|
||||
adapter = self._make_adapter()
|
||||
calls = []
|
||||
|
||||
async def handler(event):
|
||||
calls.append(event)
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
event = {"id": "dup-1", "event": "message", "topic": "hermes-in", "message": "hi", "time": None}
|
||||
_run(adapter._on_message(event))
|
||||
_run(adapter._on_message(event))
|
||||
assert len(calls) == 1
|
||||
|
||||
def test_timestamp_parsed_from_event(self):
|
||||
from datetime import timezone
|
||||
adapter = self._make_adapter()
|
||||
captured = []
|
||||
|
||||
async def handler(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
_run(adapter._on_message({
|
||||
"id": "ts-1",
|
||||
"event": "message",
|
||||
"topic": "hermes-in",
|
||||
"message": "ping",
|
||||
"time": 1700000000,
|
||||
}))
|
||||
ts = captured[0].timestamp
|
||||
assert ts.tzinfo == timezone.utc
|
||||
|
||||
def test_message_id_set_from_event(self):
|
||||
adapter = self._make_adapter()
|
||||
captured = []
|
||||
|
||||
async def handler(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
_run(adapter._on_message({
|
||||
"id": "ntfy-id-42",
|
||||
"event": "message",
|
||||
"topic": "hermes-in",
|
||||
"message": "test",
|
||||
"time": None,
|
||||
}))
|
||||
assert captured[0].message_id == "ntfy-id-42"
|
||||
|
||||
def test_title_not_used_as_user_id(self):
|
||||
"""title field must not be used for identity — it is publisher-controlled."""
|
||||
adapter = self._make_adapter()
|
||||
captured = []
|
||||
|
||||
async def handler(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
_run(adapter._on_message({
|
||||
"id": "u-1",
|
||||
"event": "message",
|
||||
"topic": "hermes-in",
|
||||
"message": "hello",
|
||||
"title": "Alice",
|
||||
"time": None,
|
||||
}))
|
||||
assert captured[0].source.user_id == "hermes-in"
|
||||
assert captured[0].source.user_name == "hermes-in"
|
||||
|
||||
def test_unknown_publisher_cannot_impersonate_allowed_user(self):
|
||||
"""An unknown publisher setting title=admin must not gain admin identity."""
|
||||
adapter = self._make_adapter()
|
||||
captured = []
|
||||
|
||||
async def handler(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
_run(adapter._on_message({
|
||||
"id": "u-2",
|
||||
"event": "message",
|
||||
"topic": "hermes-in",
|
||||
"message": "sensitive command",
|
||||
"title": "admin",
|
||||
"time": None,
|
||||
}))
|
||||
assert captured[0].source.user_id == "hermes-in"
|
||||
assert captured[0].source.user_id != "admin"
|
||||
|
||||
def test_source_chat_id_is_topic(self):
|
||||
adapter = self._make_adapter()
|
||||
captured = []
|
||||
|
||||
async def handler(event):
|
||||
captured.append(event)
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
_run(adapter._on_message({
|
||||
"id": "s-1",
|
||||
"event": "message",
|
||||
"topic": "hermes-in",
|
||||
"message": "hello",
|
||||
"time": None,
|
||||
}))
|
||||
assert captured[0].source.chat_id == "hermes-in"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. _env_enablement() — env-only auto-config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnvEnablement:
|
||||
|
||||
def test_returns_none_without_topic(self, monkeypatch):
|
||||
monkeypatch.delenv("NTFY_TOPIC", raising=False)
|
||||
assert _env_enablement() is None
|
||||
|
||||
def test_seeds_topic_and_server(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
monkeypatch.delenv("NTFY_SERVER_URL", raising=False)
|
||||
seed = _env_enablement()
|
||||
assert seed is not None
|
||||
assert seed["topic"] == "hermes-in"
|
||||
assert seed["server"] == DEFAULT_SERVER
|
||||
|
||||
def test_custom_server_url(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
monkeypatch.setenv("NTFY_SERVER_URL", "https://ntfy.example.com/")
|
||||
seed = _env_enablement()
|
||||
assert seed["server"] == "https://ntfy.example.com" # trailing slash stripped
|
||||
|
||||
def test_publish_topic_seeded(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
monkeypatch.setenv("NTFY_PUBLISH_TOPIC", "hermes-out")
|
||||
seed = _env_enablement()
|
||||
assert seed["publish_topic"] == "hermes-out"
|
||||
|
||||
def test_token_seeded(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
monkeypatch.setenv("NTFY_TOKEN", "tk_abc")
|
||||
seed = _env_enablement()
|
||||
assert seed["token"] == "tk_abc"
|
||||
|
||||
def test_markdown_truthy_values(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
for val in ("true", "1", "yes", "TRUE"):
|
||||
monkeypatch.setenv("NTFY_MARKDOWN", val)
|
||||
assert _env_enablement()["markdown"] is True
|
||||
|
||||
def test_markdown_falsy_values(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
for val in ("false", "0", "no", "anything"):
|
||||
monkeypatch.setenv("NTFY_MARKDOWN", val)
|
||||
assert _env_enablement()["markdown"] is False
|
||||
|
||||
def test_home_channel_defaults_to_topic(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
monkeypatch.delenv("NTFY_HOME_CHANNEL", raising=False)
|
||||
seed = _env_enablement()
|
||||
assert seed["home_channel"]["chat_id"] == "hermes-in"
|
||||
assert seed["home_channel"]["name"] == "hermes-in"
|
||||
|
||||
def test_home_channel_override(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
monkeypatch.setenv("NTFY_HOME_CHANNEL", "alerts")
|
||||
monkeypatch.setenv("NTFY_HOME_CHANNEL_NAME", "Alerts Channel")
|
||||
seed = _env_enablement()
|
||||
assert seed["home_channel"]["chat_id"] == "alerts"
|
||||
assert seed["home_channel"]["name"] == "Alerts Channel"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. _standalone_send() — out-of-process cron delivery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStandaloneSend:
|
||||
|
||||
def test_errors_without_topic(self, monkeypatch):
|
||||
monkeypatch.delenv("NTFY_TOPIC", raising=False)
|
||||
monkeypatch.delenv("NTFY_PUBLISH_TOPIC", raising=False)
|
||||
pconfig = MagicMock()
|
||||
pconfig.extra = {}
|
||||
result = _run(_standalone_send(pconfig, "", "hello"))
|
||||
assert "error" in result
|
||||
assert "NTFY_TOPIC" in result["error"]
|
||||
|
||||
def test_posts_to_server(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
pconfig = MagicMock()
|
||||
pconfig.extra = {"server": "https://ntfy.example.com", "topic": "hermes-in"}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"id": "id-42"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(_ntfy, "httpx") as mock_httpx:
|
||||
mock_httpx.AsyncClient.return_value = mock_client
|
||||
result = _run(_standalone_send(pconfig, "hermes-in", "hello"))
|
||||
|
||||
assert result.get("success") is True
|
||||
assert result["platform"] == "ntfy"
|
||||
assert result["message_id"] == "id-42"
|
||||
posted_url = mock_client.post.call_args[0][0]
|
||||
assert posted_url == "https://ntfy.example.com/hermes-in"
|
||||
|
||||
def test_emits_bearer_token_when_configured(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
pconfig = MagicMock()
|
||||
pconfig.extra = {"topic": "hermes-in", "token": "tk_xyz"}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(_ntfy, "httpx") as mock_httpx:
|
||||
mock_httpx.AsyncClient.return_value = mock_client
|
||||
_run(_standalone_send(pconfig, "hermes-in", "hi"))
|
||||
|
||||
headers = mock_client.post.call_args[1]["headers"]
|
||||
assert headers["Authorization"] == "Bearer tk_xyz"
|
||||
|
||||
def test_basic_auth_when_token_has_colon(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
pconfig = MagicMock()
|
||||
pconfig.extra = {"topic": "hermes-in", "token": "user:pass"}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {}
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(_ntfy, "httpx") as mock_httpx:
|
||||
mock_httpx.AsyncClient.return_value = mock_client
|
||||
_run(_standalone_send(pconfig, "hermes-in", "hi"))
|
||||
|
||||
headers = mock_client.post.call_args[1]["headers"]
|
||||
assert headers["Authorization"].startswith("Basic ")
|
||||
|
||||
def test_returns_error_on_http_failure(self, monkeypatch):
|
||||
monkeypatch.setenv("NTFY_TOPIC", "hermes-in")
|
||||
pconfig = MagicMock()
|
||||
pconfig.extra = {"topic": "hermes-in"}
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 403
|
||||
mock_resp.text = "Forbidden"
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(_ntfy, "httpx") as mock_httpx:
|
||||
mock_httpx.AsyncClient.return_value = mock_client
|
||||
result = _run(_standalone_send(pconfig, "hermes-in", "hi"))
|
||||
|
||||
assert "error" in result
|
||||
assert "403" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11. register() — plugin-side metadata
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_register_calls_register_platform():
|
||||
ctx = MagicMock()
|
||||
register(ctx)
|
||||
ctx.register_platform.assert_called_once()
|
||||
kwargs = ctx.register_platform.call_args.kwargs
|
||||
assert kwargs["name"] == "ntfy"
|
||||
assert kwargs["label"] == "ntfy"
|
||||
assert kwargs["required_env"] == ["NTFY_TOPIC"]
|
||||
assert kwargs["allowed_users_env"] == "NTFY_ALLOWED_USERS"
|
||||
assert kwargs["allow_all_env"] == "NTFY_ALLOW_ALL_USERS"
|
||||
assert kwargs["cron_deliver_env_var"] == "NTFY_HOME_CHANNEL"
|
||||
assert kwargs["max_message_length"] == MAX_MESSAGE_LENGTH
|
||||
assert callable(kwargs["check_fn"])
|
||||
assert callable(kwargs["validate_config"])
|
||||
assert callable(kwargs["is_connected"])
|
||||
assert callable(kwargs["env_enablement_fn"])
|
||||
assert callable(kwargs["standalone_sender_fn"])
|
||||
assert callable(kwargs["adapter_factory"])
|
||||
# ntfy has no user-identifying PII (only topic names)
|
||||
assert kwargs["pii_safe"] is True
|
||||
assert "ntfy" in kwargs["platform_hint"].lower()
|
||||
|
||||
|
||||
def test_adapter_factory_returns_ntfy_adapter():
|
||||
ctx = MagicMock()
|
||||
register(ctx)
|
||||
factory = ctx.register_platform.call_args.kwargs["adapter_factory"]
|
||||
cfg = PlatformConfig(enabled=True, extra={"topic": "t"})
|
||||
adapter = factory(cfg)
|
||||
assert isinstance(adapter, NtfyAdapter)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 12. Robustness — token hygiene + fatal-state propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenHygiene:
|
||||
"""``_build_auth_header`` must strip pasted-token whitespace; pasted
|
||||
tokens often carry trailing newlines that break the Authorization line."""
|
||||
|
||||
def test_trailing_whitespace_stripped(self):
|
||||
assert _ntfy._build_auth_header(" tok123 ") == {"Authorization": "Bearer tok123"}
|
||||
|
||||
def test_trailing_newline_stripped(self):
|
||||
assert _ntfy._build_auth_header("tok123\n") == {"Authorization": "Bearer tok123"}
|
||||
|
||||
def test_whitespace_only_returns_empty(self):
|
||||
assert _ntfy._build_auth_header(" \n ") == {}
|
||||
|
||||
def test_basic_auth_token_also_stripped(self):
|
||||
h = _ntfy._build_auth_header(" user:pass ")
|
||||
assert h["Authorization"].startswith("Basic ")
|
||||
import base64
|
||||
assert h["Authorization"] == "Basic " + base64.b64encode(b"user:pass").decode()
|
||||
|
||||
def test_adapter_strips_token_via_helper(self):
|
||||
"""The adapter delegates to _build_auth_header, so token whitespace
|
||||
passed via config.extra is also stripped."""
|
||||
config = PlatformConfig(enabled=True, extra={"topic": "t", "token": " tok\n"})
|
||||
adapter = NtfyAdapter(config)
|
||||
assert adapter._auth_headers() == {"Authorization": "Bearer tok"}
|
||||
|
||||
|
||||
class TestFatalErrorPropagation:
|
||||
"""When the stream hits 401/404, the adapter must transition to the
|
||||
``fatal`` state via ``_set_fatal_error`` so the gateway's runtime
|
||||
status reflects reality instead of staying 'connected'."""
|
||||
|
||||
def test_401_sets_fatal_unauthorized(self):
|
||||
adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "t"}))
|
||||
adapter._http_client = MagicMock()
|
||||
|
||||
# Mock the streaming response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
# async-context-manager flavor for httpx.stream
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=None)
|
||||
adapter._http_client.stream = MagicMock(return_value=mock_cm)
|
||||
|
||||
fake_httpx = MagicMock()
|
||||
fake_httpx.Timeout = MagicMock()
|
||||
with patch.object(_ntfy, "httpx", fake_httpx):
|
||||
with pytest.raises(_ntfy._FatalStreamError):
|
||||
_run(adapter._consume_stream("https://ntfy.example/t/json", {}))
|
||||
|
||||
assert adapter.has_fatal_error is True
|
||||
assert adapter._fatal_error_code == "ntfy_unauthorized"
|
||||
assert adapter._fatal_error_retryable is False
|
||||
|
||||
def test_404_sets_fatal_topic_not_found(self):
|
||||
adapter = NtfyAdapter(PlatformConfig(enabled=True, extra={"topic": "missing-topic"}))
|
||||
adapter._http_client = MagicMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=None)
|
||||
adapter._http_client.stream = MagicMock(return_value=mock_cm)
|
||||
|
||||
fake_httpx = MagicMock()
|
||||
fake_httpx.Timeout = MagicMock()
|
||||
with patch.object(_ntfy, "httpx", fake_httpx):
|
||||
with pytest.raises(_ntfy._FatalStreamError):
|
||||
_run(adapter._consume_stream("https://ntfy.example/missing-topic/json", {}))
|
||||
|
||||
assert adapter.has_fatal_error is True
|
||||
assert adapter._fatal_error_code == "ntfy_topic_not_found"
|
||||
assert "missing-topic" in adapter._fatal_error_message
|
||||
assert adapter._fatal_error_retryable is False
|
||||
|
||||
|
||||
class TestTruncateHelper:
|
||||
"""``_truncate_body`` is shared between adapter.send() (inline truncation
|
||||
today, may migrate) and ``_standalone_send``. It must cap to
|
||||
MAX_MESSAGE_LENGTH and return bytes."""
|
||||
|
||||
def test_short_message_passes_through(self):
|
||||
assert _ntfy._truncate_body("hi", context="test") == b"hi"
|
||||
|
||||
def test_long_message_truncated(self):
|
||||
long = "x" * (MAX_MESSAGE_LENGTH + 50)
|
||||
result = _ntfy._truncate_body(long, context="test")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) == MAX_MESSAGE_LENGTH
|
||||
|
||||
def test_unicode_message_encoded(self):
|
||||
result = _ntfy._truncate_body("héllo 🔔", context="test")
|
||||
assert result == "héllo 🔔".encode("utf-8")
|
||||
|
|
@ -2,10 +2,13 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.pairing import (
|
||||
PairingStore,
|
||||
ALPHABET,
|
||||
|
|
@ -37,6 +40,10 @@ class TestSecureWrite:
|
|||
assert target.exists()
|
||||
assert json.loads(target.read_text()) == {"hello": "world"}
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("win"),
|
||||
reason="POSIX file modes are not enforced on Windows",
|
||||
)
|
||||
def test_sets_file_permissions(self, tmp_path):
|
||||
target = tmp_path / "secret.json"
|
||||
_secure_write(target, "data")
|
||||
|
|
@ -75,9 +82,197 @@ class TestCodeGeneration:
|
|||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
pending = store.list_pending("telegram")
|
||||
assert len(pending) == 1
|
||||
assert pending[0]["code"] == code
|
||||
# list_pending no longer returns the original code — it returns a
|
||||
# truncated hash prefix. Verify the metadata is correct instead.
|
||||
assert pending[0]["user_id"] == "user1"
|
||||
assert pending[0]["user_name"] == "Alice"
|
||||
# The code field is now a hash prefix, not the original plaintext code
|
||||
assert pending[0]["code"] != code
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hashed storage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHashedStorage:
|
||||
def test_pending_file_contains_hash_and_salt(self, tmp_path):
|
||||
"""Stored entries must have 'hash' and 'salt', never the plaintext code."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Alice")
|
||||
raw = json.loads(
|
||||
(tmp_path / "telegram-pending.json").read_text(encoding="utf-8")
|
||||
)
|
||||
|
||||
assert len(raw) == 1
|
||||
entry = next(iter(raw.values()))
|
||||
# Must have hash and salt fields
|
||||
assert "hash" in entry
|
||||
assert "salt" in entry
|
||||
# Hash must be a valid hex SHA-256 digest (64 hex chars)
|
||||
assert len(entry["hash"]) == 64
|
||||
assert all(c in "0123456789abcdef" for c in entry["hash"])
|
||||
# Salt must be a valid hex string (32 hex chars for 16 bytes)
|
||||
assert len(entry["salt"]) == 32
|
||||
assert all(c in "0123456789abcdef" for c in entry["salt"])
|
||||
# The plaintext code must NOT appear as a key or value anywhere
|
||||
assert code not in raw # not a key
|
||||
for key, val in raw.items():
|
||||
assert code != key
|
||||
for field_val in val.values():
|
||||
if isinstance(field_val, str):
|
||||
assert field_val != code
|
||||
|
||||
def test_plaintext_code_not_stored(self, tmp_path):
|
||||
"""The raw JSON file must not contain the plaintext code anywhere."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1")
|
||||
raw_text = (tmp_path / "telegram-pending.json").read_text(encoding="utf-8")
|
||||
assert code not in raw_text
|
||||
|
||||
def test_valid_code_verifies_against_hash(self, tmp_path):
|
||||
"""approve_code with the correct code should succeed."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1", "Bob")
|
||||
result = store.approve_code("telegram", code)
|
||||
assert result is not None
|
||||
assert result["user_id"] == "user1"
|
||||
assert result["user_name"] == "Bob"
|
||||
|
||||
def test_invalid_code_rejected(self, tmp_path):
|
||||
"""approve_code with a wrong code should fail."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
store.generate_code("telegram", "user1")
|
||||
result = store.approve_code("telegram", "ZZZZZZZZ")
|
||||
assert result is None
|
||||
|
||||
def test_different_salts_per_entry(self, tmp_path):
|
||||
"""Each pending entry should have a unique salt."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
store.generate_code("telegram", "user0")
|
||||
store.generate_code("telegram", "user1")
|
||||
store.generate_code("telegram", "user2")
|
||||
raw = json.loads(
|
||||
(tmp_path / "telegram-pending.json").read_text(encoding="utf-8")
|
||||
)
|
||||
salts = [entry["salt"] for entry in raw.values()]
|
||||
assert len(set(salts)) == 3 # all unique
|
||||
|
||||
def test_hash_code_static_method(self, tmp_path):
|
||||
"""_hash_code should be deterministic for the same code+salt."""
|
||||
salt = os.urandom(16)
|
||||
h1 = PairingStore._hash_code("ABCD1234", salt)
|
||||
h2 = PairingStore._hash_code("ABCD1234", salt)
|
||||
assert h1 == h2
|
||||
# Different salt should produce a different hash
|
||||
salt2 = os.urandom(16)
|
||||
h3 = PairingStore._hash_code("ABCD1234", salt2)
|
||||
assert h3 != h1
|
||||
|
||||
|
||||
class TestLegacyPendingFileCompat:
|
||||
"""Defensive coverage for pre-hash pending.json on upgraded installs.
|
||||
|
||||
Existing user installs may have a pending.json written by the old
|
||||
code (plaintext code as key, no hash/salt fields). The new
|
||||
approve_code / list_pending / _cleanup_expired must not crash on
|
||||
those entries — they should be ignored and aged out at TTL.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _write_legacy(tmp_path, code="ABCD1234", created_at=None):
|
||||
"""Write a pre-hash pending.json with plaintext code as the key."""
|
||||
import time as _time
|
||||
if created_at is None:
|
||||
created_at = _time.time()
|
||||
legacy = {
|
||||
code: {
|
||||
"user_id": "legacy-user",
|
||||
"user_name": "Legacy",
|
||||
"created_at": created_at,
|
||||
}
|
||||
}
|
||||
(tmp_path / "telegram-pending.json").write_text(
|
||||
json.dumps(legacy), encoding="utf-8"
|
||||
)
|
||||
|
||||
def test_approve_code_ignores_legacy_entries(self, tmp_path):
|
||||
"""A valid old-format code must NOT silently approve under the new schema."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
self._write_legacy(tmp_path, code="LEGACY01")
|
||||
store = PairingStore()
|
||||
# The plaintext "code" used to be the key — under the new schema
|
||||
# it's not even looked at, and there's no hash/salt to verify.
|
||||
# Result: approve_code returns None, the legacy entry is left
|
||||
# alone (gets pruned by _cleanup_expired at TTL).
|
||||
result = store.approve_code("telegram", "LEGACY01")
|
||||
assert result is None
|
||||
# Approved list must be empty
|
||||
assert store.is_approved("telegram", "legacy-user") is False
|
||||
|
||||
def test_list_pending_handles_legacy_entries(self, tmp_path):
|
||||
"""list_pending must not KeyError on a missing 'hash' field."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
self._write_legacy(tmp_path)
|
||||
store = PairingStore()
|
||||
pending = store.list_pending("telegram")
|
||||
assert len(pending) == 1
|
||||
assert pending[0]["user_id"] == "legacy-user"
|
||||
assert pending[0]["code"] == "legacy" # placeholder
|
||||
|
||||
def test_cleanup_expired_removes_legacy_at_ttl(self, tmp_path):
|
||||
"""Legacy entries past CODE_TTL must still get pruned."""
|
||||
import time as _time
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
self._write_legacy(
|
||||
tmp_path,
|
||||
code="LEGACY99",
|
||||
created_at=_time.time() - CODE_TTL_SECONDS - 1,
|
||||
)
|
||||
store = PairingStore()
|
||||
store._cleanup_expired("telegram")
|
||||
raw = json.loads(
|
||||
(tmp_path / "telegram-pending.json").read_text(encoding="utf-8")
|
||||
)
|
||||
assert raw == {}
|
||||
|
||||
def test_cleanup_expired_handles_malformed_entries(self, tmp_path):
|
||||
"""Non-dict / missing-created_at entries get evicted, not crashed on."""
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
(tmp_path / "telegram-pending.json").write_text(
|
||||
json.dumps({
|
||||
"broken1": "not a dict",
|
||||
"broken2": {"user_id": "x"}, # no created_at
|
||||
"broken3": {"created_at": "not a number"},
|
||||
}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
store = PairingStore()
|
||||
store._cleanup_expired("telegram")
|
||||
raw = json.loads(
|
||||
(tmp_path / "telegram-pending.json").read_text(encoding="utf-8")
|
||||
)
|
||||
assert raw == {}
|
||||
|
||||
def test_approve_code_skips_malformed_entries(self, tmp_path):
|
||||
"""Malformed entries must not crash approve_code's hash loop."""
|
||||
import time as _time
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
(tmp_path / "telegram-pending.json").write_text(
|
||||
json.dumps({
|
||||
"broken": {"user_id": "x", "created_at": _time.time(),
|
||||
"salt": "not-hex", "hash": "doesntmatter"},
|
||||
}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
store = PairingStore()
|
||||
# Approving with any code must just return None, not crash.
|
||||
assert store.approve_code("telegram", "ABCD1234") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -117,6 +312,23 @@ class TestRateLimiting:
|
|||
assert isinstance(code2, str) and len(code2) == CODE_LENGTH
|
||||
assert code2 != code1
|
||||
|
||||
def test_whatsapp_alias_flip_hits_same_rate_limit(self, tmp_path, monkeypatch):
|
||||
mapping_dir = tmp_path / "whatsapp" / "session"
|
||||
mapping_dir.mkdir(parents=True, exist_ok=True)
|
||||
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
|
||||
json.dumps("15551234567@s.whatsapp.net"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code1 = store.generate_code("whatsapp", "15551234567@s.whatsapp.net")
|
||||
code2 = store.generate_code("whatsapp", "999999999999999@lid")
|
||||
|
||||
assert isinstance(code1, str) and len(code1) == CODE_LENGTH
|
||||
assert code2 is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Max pending limit
|
||||
|
|
@ -209,6 +421,55 @@ class TestApprovalFlow:
|
|||
result = store.approve_code("telegram", "INVALIDCODE")
|
||||
assert result is None
|
||||
|
||||
def test_whatsapp_approved_user_survives_alias_flip(self, tmp_path, monkeypatch):
|
||||
mapping_dir = tmp_path / "whatsapp" / "session"
|
||||
mapping_dir.mkdir(parents=True, exist_ok=True)
|
||||
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
|
||||
json.dumps("15551234567@s.whatsapp.net"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
code = store.generate_code("whatsapp", "15551234567@s.whatsapp.net", "Alice")
|
||||
store.approve_code("whatsapp", code)
|
||||
|
||||
assert store.is_approved("whatsapp", "15551234567@s.whatsapp.net") is True
|
||||
assert store.is_approved("whatsapp", "999999999999999@lid") is True
|
||||
|
||||
approved = store.list_approved("whatsapp")
|
||||
|
||||
assert len(approved) == 1
|
||||
assert approved[0]["user_id"] == "15551234567"
|
||||
|
||||
def test_whatsapp_legacy_raw_jid_approval_survives_alias_flip(self, tmp_path, monkeypatch):
|
||||
mapping_dir = tmp_path / "whatsapp" / "session"
|
||||
mapping_dir.mkdir(parents=True, exist_ok=True)
|
||||
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
|
||||
json.dumps("15551234567@s.whatsapp.net"),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
approved_path = tmp_path / "whatsapp-approved.json"
|
||||
approved_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"15551234567@s.whatsapp.net": {
|
||||
"user_name": "Legacy Alice",
|
||||
"approved_at": time.time(),
|
||||
}
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
|
||||
store = PairingStore()
|
||||
assert store.is_approved("whatsapp", "999999999999999@lid") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lockout after failed attempts
|
||||
|
|
@ -300,9 +561,10 @@ class TestCodeExpiry:
|
|||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1")
|
||||
|
||||
# Manually expire the code
|
||||
# Manually expire all pending entries
|
||||
pending = store._load_json(store._pending_path("telegram"))
|
||||
pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
|
||||
for entry_id in pending:
|
||||
pending[entry_id]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
|
||||
store._save_json(store._pending_path("telegram"), pending)
|
||||
|
||||
# Cleanup happens on next operation
|
||||
|
|
@ -314,9 +576,10 @@ class TestCodeExpiry:
|
|||
store = PairingStore()
|
||||
code = store.generate_code("telegram", "user1")
|
||||
|
||||
# Expire it
|
||||
# Expire all entries
|
||||
pending = store._load_json(store._pending_path("telegram"))
|
||||
pending[code]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
|
||||
for entry_id in pending:
|
||||
pending[entry_id]["created_at"] = time.time() - CODE_TTL_SECONDS - 1
|
||||
store._save_json(store._pending_path("telegram"), pending)
|
||||
|
||||
result = store.approve_code("telegram", code)
|
||||
|
|
|
|||
|
|
@ -361,6 +361,72 @@ class TestExtractMedia:
|
|||
assert "[[as_document]]" not in cleaned
|
||||
|
||||
|
||||
class TestMediaDeliveryPathValidation:
|
||||
def _patch_roots(self, monkeypatch, *roots):
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
|
||||
tuple(roots),
|
||||
)
|
||||
|
||||
def test_allows_existing_file_inside_safe_root(self, tmp_path, monkeypatch):
|
||||
root = tmp_path / "media-cache"
|
||||
media_file = root / "voice.ogg"
|
||||
media_file.parent.mkdir(parents=True)
|
||||
media_file.write_bytes(b"OggS")
|
||||
self._patch_roots(monkeypatch, root)
|
||||
|
||||
assert BasePlatformAdapter.validate_media_delivery_path(str(media_file)) == str(media_file.resolve())
|
||||
|
||||
def test_rejects_existing_file_outside_safe_root(self, tmp_path, monkeypatch):
|
||||
root = tmp_path / "media-cache"
|
||||
root.mkdir()
|
||||
secret = tmp_path / "secrets.txt"
|
||||
secret.write_text("not for upload")
|
||||
self._patch_roots(monkeypatch, root)
|
||||
|
||||
assert BasePlatformAdapter.validate_media_delivery_path(str(secret)) is None
|
||||
|
||||
def test_rejects_symlink_escape_from_safe_root(self, tmp_path, monkeypatch):
|
||||
root = tmp_path / "media-cache"
|
||||
root.mkdir()
|
||||
secret = tmp_path / "outside.png"
|
||||
secret.write_bytes(b"secret")
|
||||
link = root / "safe-looking.png"
|
||||
try:
|
||||
link.symlink_to(secret)
|
||||
except OSError:
|
||||
pytest.skip("symlink creation is unavailable")
|
||||
self._patch_roots(monkeypatch, root)
|
||||
|
||||
assert BasePlatformAdapter.validate_media_delivery_path(str(link)) is None
|
||||
|
||||
def test_filter_keeps_safe_media_and_drops_unsafe(self, tmp_path, monkeypatch):
|
||||
root = tmp_path / "media-cache"
|
||||
safe = root / "speech.ogg"
|
||||
unsafe = tmp_path / "outside.ogg"
|
||||
safe.parent.mkdir(parents=True)
|
||||
safe.write_bytes(b"OggS")
|
||||
unsafe.write_bytes(b"OggS")
|
||||
self._patch_roots(monkeypatch, root)
|
||||
|
||||
filtered = BasePlatformAdapter.filter_media_delivery_paths([
|
||||
(str(unsafe), False),
|
||||
(str(safe), True),
|
||||
])
|
||||
|
||||
assert filtered == [(str(safe.resolve()), True)]
|
||||
|
||||
def test_allows_operator_configured_extra_root(self, tmp_path, monkeypatch):
|
||||
extra_root = tmp_path / "operator-media"
|
||||
media_file = extra_root / "report.pdf"
|
||||
media_file.parent.mkdir(parents=True)
|
||||
media_file.write_bytes(b"%PDF-1.4")
|
||||
self._patch_roots(monkeypatch)
|
||||
monkeypatch.setenv("HERMES_MEDIA_ALLOW_DIRS", str(extra_root))
|
||||
|
||||
assert BasePlatformAdapter.validate_media_delivery_path(str(media_file)) == str(media_file.resolve())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# should_send_media_as_audio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -728,4 +794,3 @@ class TestProxyKwargsForAiohttp:
|
|||
sess_kw, req_kw = proxy_kwargs_for_aiohttp("http://proxy:8080")
|
||||
assert sess_kw == {}
|
||||
assert req_kw == {"proxy": "http://proxy:8080"}
|
||||
|
||||
|
|
|
|||
|
|
@ -79,10 +79,11 @@ def test_checker_returns_true_when_configured(platform, checker, monkeypatch):
|
|||
elif platform in {
|
||||
Platform.API_SERVER,
|
||||
Platform.WEBHOOK,
|
||||
Platform.MSGRAPH_WEBHOOK,
|
||||
Platform.WHATSAPP,
|
||||
}:
|
||||
mock_config.extra = {}
|
||||
elif platform == Platform.MSGRAPH_WEBHOOK:
|
||||
mock_config.extra = {"client_state": "expected-client-state"}
|
||||
elif platform == Platform.FEISHU:
|
||||
mock_config.extra = {"app_id": "app"}
|
||||
elif platform == Platform.WECOM:
|
||||
|
|
|
|||
|
|
@ -708,3 +708,279 @@ class TestPluginPlatformSharedKeyBridge:
|
|||
assert extra.get("allow_from") == ["alice", "bob"]
|
||||
finally:
|
||||
_reg.unregister("mysharedplat")
|
||||
|
||||
|
||||
class TestPluginEnablementGate:
|
||||
"""Plugin platforms must NOT auto-enable on check_fn alone (#31116).
|
||||
|
||||
When a plugin registers ``is_connected`` (the "did the user actually
|
||||
configure credentials" probe), ``load_gateway_config`` must consult it
|
||||
before flipping ``enabled = True``. Without this gate, ``check_fn``
|
||||
semantics ("the SDK is importable") get conflated with "the user wants
|
||||
this platform on", and the gateway tries to connect to e.g. Discord
|
||||
with no token — emitting noisy retry-forever errors on every fresh
|
||||
install that has the plugin loaded.
|
||||
"""
|
||||
|
||||
def _write_config(self, tmp_path, content: str = ""):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text(content, encoding="utf-8")
|
||||
return hermes_home
|
||||
|
||||
def test_plugin_with_is_connected_false_is_NOT_enabled(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""check_fn=True + is_connected=False must NOT enable the platform.
|
||||
|
||||
Reproduces #31116: Discord plugin loads, its check_fn lazy-installs
|
||||
discord.py and returns True, but the user has no DISCORD_BOT_TOKEN.
|
||||
Previously this auto-enabled Discord and the gateway spammed
|
||||
``ERROR ... [Discord] No bot token configured`` on every reconnect.
|
||||
"""
|
||||
from gateway.platform_registry import platform_registry as _reg
|
||||
|
||||
_reg.register(PlatformEntry(
|
||||
name="myunconfiguredplat",
|
||||
label="MyUnconfigured",
|
||||
adapter_factory=lambda cfg: None,
|
||||
check_fn=lambda: True, # SDK available
|
||||
is_connected=lambda cfg: False, # but user hasn't set credentials
|
||||
source="plugin",
|
||||
))
|
||||
try:
|
||||
home = self._write_config(tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
cfg = load_gateway_config()
|
||||
|
||||
plat = Platform("myunconfiguredplat")
|
||||
# Either absent entirely, or present but explicitly disabled.
|
||||
if plat in cfg.platforms:
|
||||
assert cfg.platforms[plat].enabled is False, (
|
||||
"Plugin with is_connected=False must NOT be auto-enabled"
|
||||
)
|
||||
finally:
|
||||
_reg.unregister("myunconfiguredplat")
|
||||
|
||||
def test_plugin_with_is_connected_true_is_enabled(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""check_fn=True + is_connected=True still enables the platform."""
|
||||
from gateway.platform_registry import platform_registry as _reg
|
||||
|
||||
_reg.register(PlatformEntry(
|
||||
name="myconfiguredplat",
|
||||
label="MyConfigured",
|
||||
adapter_factory=lambda cfg: None,
|
||||
check_fn=lambda: True,
|
||||
is_connected=lambda cfg: True,
|
||||
source="plugin",
|
||||
))
|
||||
try:
|
||||
home = self._write_config(tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
cfg = load_gateway_config()
|
||||
|
||||
plat = Platform("myconfiguredplat")
|
||||
assert plat in cfg.platforms
|
||||
assert cfg.platforms[plat].enabled is True
|
||||
finally:
|
||||
_reg.unregister("myconfiguredplat")
|
||||
|
||||
def test_plugin_without_is_connected_falls_back_to_check_fn(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""Legacy plugins that don't register is_connected keep working.
|
||||
|
||||
For plugins where ``is_connected is None``, gating on ``check_fn``
|
||||
alone remains the contract — that's what callers without a
|
||||
credential probe have always done.
|
||||
"""
|
||||
from gateway.platform_registry import platform_registry as _reg
|
||||
|
||||
_reg.register(PlatformEntry(
|
||||
name="mylegacyplat",
|
||||
label="MyLegacy",
|
||||
adapter_factory=lambda cfg: None,
|
||||
check_fn=lambda: True,
|
||||
# is_connected intentionally omitted (None)
|
||||
source="plugin",
|
||||
))
|
||||
try:
|
||||
home = self._write_config(tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
cfg = load_gateway_config()
|
||||
|
||||
plat = Platform("mylegacyplat")
|
||||
assert plat in cfg.platforms
|
||||
assert cfg.platforms[plat].enabled is True
|
||||
finally:
|
||||
_reg.unregister("mylegacyplat")
|
||||
|
||||
def test_is_connected_raises_does_not_enable(self, tmp_path, monkeypatch):
|
||||
"""A buggy is_connected must not silently enable the platform.
|
||||
|
||||
Treat a raising is_connected as "configuration unknown" — refuse to
|
||||
enable, log, and move on. Anything else would re-introduce the
|
||||
#31116 bug for plugins whose probe has a transient failure.
|
||||
"""
|
||||
from gateway.platform_registry import platform_registry as _reg
|
||||
|
||||
def _bad_probe(cfg):
|
||||
raise RuntimeError("plugin bug")
|
||||
|
||||
_reg.register(PlatformEntry(
|
||||
name="mybadprobeplat",
|
||||
label="MyBadProbe",
|
||||
adapter_factory=lambda cfg: None,
|
||||
check_fn=lambda: True,
|
||||
is_connected=_bad_probe,
|
||||
source="plugin",
|
||||
))
|
||||
try:
|
||||
home = self._write_config(tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
cfg = load_gateway_config()
|
||||
|
||||
plat = Platform("mybadprobeplat")
|
||||
if plat in cfg.platforms:
|
||||
assert cfg.platforms[plat].enabled is False
|
||||
finally:
|
||||
_reg.unregister("mybadprobeplat")
|
||||
|
||||
def test_yaml_enabled_true_overrides_is_connected_false(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""Explicit YAML ``enabled: true`` wins over is_connected=False.
|
||||
|
||||
If the user wrote ``platforms.X.enabled: true`` themselves, respect
|
||||
that — they may be using a credential mechanism the plugin's
|
||||
is_connected probe doesn't know about. Don't fight them.
|
||||
"""
|
||||
from gateway.platform_registry import platform_registry as _reg
|
||||
|
||||
_reg.register(PlatformEntry(
|
||||
name="myexplicitplat",
|
||||
label="MyExplicit",
|
||||
adapter_factory=lambda cfg: None,
|
||||
check_fn=lambda: True,
|
||||
is_connected=lambda cfg: False,
|
||||
source="plugin",
|
||||
))
|
||||
try:
|
||||
home = self._write_config(
|
||||
tmp_path,
|
||||
"platforms:\n"
|
||||
" myexplicitplat:\n"
|
||||
" enabled: true\n",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
cfg = load_gateway_config()
|
||||
|
||||
plat = Platform("myexplicitplat")
|
||||
assert plat in cfg.platforms
|
||||
assert cfg.platforms[plat].enabled is True, (
|
||||
"Explicit YAML enabled: true must win over plugin's "
|
||||
"is_connected=False — user has the final say"
|
||||
)
|
||||
finally:
|
||||
_reg.unregister("myexplicitplat")
|
||||
|
||||
def test_is_connected_sees_env_seeded_extras(self, tmp_path, monkeypatch):
|
||||
"""``env_enablement_fn`` extras must be visible to ``is_connected``.
|
||||
|
||||
Some plugins (e.g. Google Chat) implement ``is_connected`` by
|
||||
inspecting ``config.extra`` (where ``env_enablement_fn`` deposits
|
||||
env-var-derived state) rather than reading ``os.environ`` directly.
|
||||
If the gate runs BEFORE the seeding step, those plugins fail the
|
||||
gate even when the user is genuinely configured via env vars.
|
||||
|
||||
Pin the contract: when both hooks are present, ``env_enablement_fn``
|
||||
feeds a candidate config to ``is_connected``.
|
||||
"""
|
||||
from gateway.platform_registry import platform_registry as _reg
|
||||
|
||||
seen_extras: dict = {}
|
||||
|
||||
def _is_connected(cfg):
|
||||
seen_extras["snapshot"] = dict(getattr(cfg, "extra", {}) or {})
|
||||
extra = getattr(cfg, "extra", {}) or {}
|
||||
return bool(extra.get("project_id") and extra.get("subscription_name"))
|
||||
|
||||
def _env_enablement():
|
||||
return {"project_id": "p", "subscription_name": "s"}
|
||||
|
||||
_reg.register(PlatformEntry(
|
||||
name="myextrasplat",
|
||||
label="MyExtras",
|
||||
adapter_factory=lambda cfg: None,
|
||||
check_fn=lambda: True,
|
||||
is_connected=_is_connected,
|
||||
env_enablement_fn=_env_enablement,
|
||||
source="plugin",
|
||||
))
|
||||
try:
|
||||
home = self._write_config(tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
cfg = load_gateway_config()
|
||||
|
||||
plat = Platform("myextrasplat")
|
||||
assert plat in cfg.platforms, (
|
||||
"is_connected was called with empty extras — "
|
||||
"env_enablement_fn must seed the probe BEFORE the gate"
|
||||
)
|
||||
assert cfg.platforms[plat].enabled is True
|
||||
# extras populated on the live config too
|
||||
assert cfg.platforms[plat].extra.get("project_id") == "p"
|
||||
assert cfg.platforms[plat].extra.get("subscription_name") == "s"
|
||||
# and the probe saw them
|
||||
assert seen_extras["snapshot"]["project_id"] == "p"
|
||||
finally:
|
||||
_reg.unregister("myextrasplat")
|
||||
|
||||
def test_is_connected_failed_gate_does_not_leak_extras(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""When the gate rejects, env-seeded extras must NOT leak onto
|
||||
``config.platforms``. A rejected plugin should be invisible, not
|
||||
present-but-partially-populated.
|
||||
"""
|
||||
from gateway.platform_registry import platform_registry as _reg
|
||||
|
||||
_reg.register(PlatformEntry(
|
||||
name="myrejectedplat",
|
||||
label="MyRejected",
|
||||
adapter_factory=lambda cfg: None,
|
||||
check_fn=lambda: True,
|
||||
is_connected=lambda cfg: False,
|
||||
env_enablement_fn=lambda: {"some_key": "should-not-leak"},
|
||||
source="plugin",
|
||||
))
|
||||
try:
|
||||
home = self._write_config(tmp_path)
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
from gateway.config import load_gateway_config, Platform
|
||||
cfg = load_gateway_config()
|
||||
|
||||
plat = Platform("myrejectedplat")
|
||||
if plat in cfg.platforms:
|
||||
assert cfg.platforms[plat].enabled is False
|
||||
assert "some_key" not in cfg.platforms[plat].extra, (
|
||||
"Rejected plugin's env-seeded extras leaked onto "
|
||||
"config.platforms"
|
||||
)
|
||||
finally:
|
||||
_reg.unregister("myrejectedplat")
|
||||
|
|
|
|||
|
|
@ -1233,14 +1233,14 @@ class TestAdapterInteractionDispatch:
|
|||
"user_openid": "user-1",
|
||||
"data": {
|
||||
"type": 11,
|
||||
"resolved": {"button_data": "approve:s:deny", "button_id": "deny"},
|
||||
"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny", "button_id": "deny"},
|
||||
},
|
||||
})
|
||||
|
||||
assert len(ack_calls) == 1
|
||||
assert ack_calls[0][0] == "i-1"
|
||||
assert len(received) == 1
|
||||
assert received[0].button_data == "approve:s:deny"
|
||||
assert received[0].button_data == "approve:agent:main:qqbot:c2c:u:deny"
|
||||
assert received[0].scene == "c2c"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1262,7 +1262,7 @@ class TestAdapterInteractionDispatch:
|
|||
adapter.set_interaction_callback(cb)
|
||||
await adapter._on_interaction({
|
||||
"chat_type": 2, # no id
|
||||
"data": {"resolved": {"button_data": "approve:s:deny"}},
|
||||
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
|
||||
})
|
||||
|
||||
assert ack_calls == []
|
||||
|
|
@ -1286,7 +1286,7 @@ class TestAdapterInteractionDispatch:
|
|||
"id": "i-2",
|
||||
"chat_type": 2,
|
||||
"user_openid": "u",
|
||||
"data": {"resolved": {"button_data": "approve:s:deny"}},
|
||||
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
|
||||
})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1304,7 +1304,7 @@ class TestAdapterInteractionDispatch:
|
|||
"id": "i-3",
|
||||
"chat_type": 2,
|
||||
"user_openid": "u",
|
||||
"data": {"resolved": {"button_data": "approve:s:deny"}},
|
||||
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
|
||||
})
|
||||
|
||||
|
||||
|
|
@ -1570,13 +1570,13 @@ class TestDefaultInteractionDispatch:
|
|||
"id": "i",
|
||||
"chat_type": 2,
|
||||
"user_openid": "u-42",
|
||||
"data": {"resolved": {"button_data": "approve:sess-abc:allow-once"}},
|
||||
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u-42:allow-once"}},
|
||||
})
|
||||
await adapter._default_interaction_dispatch(event)
|
||||
finally:
|
||||
tools.approval.resolve_gateway_approval = orig
|
||||
|
||||
assert resolve_calls == [("sess-abc", "once", False)]
|
||||
assert resolve_calls == [("agent:main:qqbot:c2c:u-42", "once", False)]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_click_always_maps_to_always(self):
|
||||
|
|
@ -1594,13 +1594,13 @@ class TestDefaultInteractionDispatch:
|
|||
from gateway.platforms.qqbot.keyboards import parse_interaction_event
|
||||
event = parse_interaction_event({
|
||||
"id": "i", "chat_type": 2, "user_openid": "u",
|
||||
"data": {"resolved": {"button_data": "approve:s:allow-always"}},
|
||||
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:allow-always"}},
|
||||
})
|
||||
await adapter._default_interaction_dispatch(event)
|
||||
finally:
|
||||
tools.approval.resolve_gateway_approval = orig
|
||||
|
||||
assert resolve_calls == [("s", "always", False)]
|
||||
assert resolve_calls == [("agent:main:qqbot:c2c:u", "always", False)]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_click_deny_maps_to_deny(self):
|
||||
|
|
@ -1618,13 +1618,40 @@ class TestDefaultInteractionDispatch:
|
|||
from gateway.platforms.qqbot.keyboards import parse_interaction_event
|
||||
event = parse_interaction_event({
|
||||
"id": "i", "chat_type": 2, "user_openid": "u",
|
||||
"data": {"resolved": {"button_data": "approve:s:deny"}},
|
||||
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
|
||||
})
|
||||
await adapter._default_interaction_dispatch(event)
|
||||
finally:
|
||||
tools.approval.resolve_gateway_approval = orig
|
||||
|
||||
assert resolve_calls == [("s", "deny", False)]
|
||||
assert resolve_calls == [("agent:main:qqbot:c2c:u", "deny", False)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_click_rejects_unauthorized_operator(self):
|
||||
adapter = self._make_adapter()
|
||||
resolve_calls = []
|
||||
|
||||
def fake_resolve(session_key, choice, resolve_all=False):
|
||||
resolve_calls.append((session_key, choice, resolve_all))
|
||||
return 1
|
||||
|
||||
import tools.approval
|
||||
orig = tools.approval.resolve_gateway_approval
|
||||
tools.approval.resolve_gateway_approval = fake_resolve
|
||||
try:
|
||||
from gateway.platforms.qqbot.keyboards import parse_interaction_event
|
||||
event = parse_interaction_event({
|
||||
"id": "i", "chat_type": 1,
|
||||
"group_openid": "g-1",
|
||||
"group_member_openid": "attacker",
|
||||
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:group:g-1:owner:allow-once"}},
|
||||
})
|
||||
await adapter._default_interaction_dispatch(event)
|
||||
finally:
|
||||
tools.approval.resolve_gateway_approval = orig
|
||||
|
||||
assert resolve_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_prompt_click_writes_response_file(self, tmp_path, monkeypatch):
|
||||
|
|
@ -1700,7 +1727,7 @@ class TestDefaultInteractionDispatch:
|
|||
from gateway.platforms.qqbot.keyboards import parse_interaction_event
|
||||
event = parse_interaction_event({
|
||||
"id": "i", "chat_type": 2, "user_openid": "u",
|
||||
"data": {"resolved": {"button_data": "approve:s:deny"}},
|
||||
"data": {"resolved": {"button_data": "approve:agent:main:qqbot:c2c:u:deny"}},
|
||||
})
|
||||
# Must not raise.
|
||||
await adapter._default_interaction_dispatch(event)
|
||||
|
|
@ -1810,3 +1837,365 @@ class TestSendUpdatePrompt:
|
|||
|
||||
adapter.send_with_keyboard = fake_swk # type: ignore[assignment]
|
||||
await adapter.send_update_prompt(chat_id="u", prompt="ok?")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_identify includes INTERACTION intent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIdentifyIntents:
|
||||
"""Verify the WebSocket identify payload includes the INTERACTION intent bit."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.qqbot.adapter import QQAdapter
|
||||
return QQAdapter(_make_config(app_id="a", client_secret="b"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intents_include_interaction_bit(self):
|
||||
adapter = self._make_adapter()
|
||||
|
||||
# Mock token retrieval and WebSocket
|
||||
adapter._access_token = "fake_token"
|
||||
adapter._token_expires_at = 9999999999.0
|
||||
|
||||
sent_payloads = []
|
||||
|
||||
class FakeWS:
|
||||
closed = False
|
||||
|
||||
async def send_json(self, payload):
|
||||
sent_payloads.append(payload)
|
||||
|
||||
adapter._ws = FakeWS()
|
||||
await adapter._send_identify()
|
||||
|
||||
assert len(sent_payloads) == 1
|
||||
intents = sent_payloads[0]["d"]["intents"]
|
||||
|
||||
# Verify all expected intent bits are present
|
||||
assert intents & (1 << 25), "GROUP_MESSAGES (1<<25) missing"
|
||||
assert intents & (1 << 30), "GUILD_AT_MESSAGE (1<<30) missing"
|
||||
assert intents & (1 << 12), "DIRECT_MESSAGES (1<<12) missing"
|
||||
assert intents & (1 << 26), "INTERACTION (1<<26) missing"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _process_attachments: video/file path exposure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProcessAttachmentsPathExposure:
|
||||
"""Verify that video and file attachments include the cached local path."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.qqbot.adapter import QQAdapter
|
||||
return QQAdapter(_make_config(app_id="a", client_secret="b"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_video_attachment_includes_path(self):
|
||||
adapter = self._make_adapter()
|
||||
|
||||
# Mock _download_and_cache to return a known path
|
||||
async def fake_download(url, ct, original_name=""):
|
||||
return "/tmp/cache/video_abc123.mp4"
|
||||
|
||||
adapter._download_and_cache = fake_download # type: ignore[assignment]
|
||||
|
||||
attachments = [
|
||||
{
|
||||
"content_type": "video/mp4",
|
||||
"url": "https://multimedia.nt.qq.com.cn/download/video123",
|
||||
"filename": "my_video.mp4",
|
||||
}
|
||||
]
|
||||
result = await adapter._process_attachments(attachments)
|
||||
|
||||
assert result["image_urls"] == []
|
||||
assert result["voice_transcripts"] == []
|
||||
info = result["attachment_info"]
|
||||
assert "[video:" in info
|
||||
assert "my_video.mp4" in info
|
||||
assert "/tmp/cache/video_abc123.mp4" in info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_attachment_includes_path(self):
|
||||
adapter = self._make_adapter()
|
||||
|
||||
async def fake_download(url, ct, original_name=""):
|
||||
return "/tmp/cache/doc_abc123_report.pdf"
|
||||
|
||||
adapter._download_and_cache = fake_download # type: ignore[assignment]
|
||||
|
||||
attachments = [
|
||||
{
|
||||
"content_type": "application/pdf",
|
||||
"url": "https://multimedia.nt.qq.com.cn/download/file456",
|
||||
"filename": "report.pdf",
|
||||
}
|
||||
]
|
||||
result = await adapter._process_attachments(attachments)
|
||||
|
||||
info = result["attachment_info"]
|
||||
assert "[file:" in info
|
||||
assert "report.pdf" in info
|
||||
assert "/tmp/cache/doc_abc123_report.pdf" in info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_video_without_filename_falls_back_to_content_type(self):
|
||||
adapter = self._make_adapter()
|
||||
|
||||
async def fake_download(url, ct, original_name=""):
|
||||
return "/tmp/cache/video_xyz.mp4"
|
||||
|
||||
adapter._download_and_cache = fake_download # type: ignore[assignment]
|
||||
|
||||
attachments = [
|
||||
{
|
||||
"content_type": "video/mp4",
|
||||
"url": "https://cdn.qq.com/vid",
|
||||
"filename": "",
|
||||
}
|
||||
]
|
||||
result = await adapter._process_attachments(attachments)
|
||||
|
||||
info = result["attachment_info"]
|
||||
assert "[video: video/mp4" in info
|
||||
assert "/tmp/cache/video_xyz.mp4" in info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_failure_produces_no_attachment_info(self):
|
||||
adapter = self._make_adapter()
|
||||
|
||||
async def fake_download(url, ct, original_name=""):
|
||||
return None
|
||||
|
||||
adapter._download_and_cache = fake_download # type: ignore[assignment]
|
||||
|
||||
attachments = [
|
||||
{
|
||||
"content_type": "video/mp4",
|
||||
"url": "https://cdn.qq.com/vid",
|
||||
"filename": "vid.mp4",
|
||||
}
|
||||
]
|
||||
result = await adapter._process_attachments(attachments)
|
||||
assert result["attachment_info"] == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quoted_video_includes_path_in_quote_block(self):
|
||||
"""Quoted video attachments should surface the cached path in the quote block."""
|
||||
adapter = self._make_adapter()
|
||||
|
||||
async def fake_process(atts):
|
||||
# Simulate the fixed _process_attachments for a video attachment.
|
||||
return {
|
||||
"image_urls": [],
|
||||
"image_media_types": [],
|
||||
"voice_transcripts": [],
|
||||
"attachment_info": "[video: clip.mp4 (/tmp/cache/clip.mp4)]",
|
||||
}
|
||||
|
||||
adapter._process_attachments = fake_process # type: ignore[assignment]
|
||||
|
||||
d = {
|
||||
"message_type": 103,
|
||||
"msg_elements": [{
|
||||
"content": "看看这个视频",
|
||||
"attachments": [
|
||||
{"content_type": "video/mp4",
|
||||
"url": "https://qq-cdn/clip.mp4",
|
||||
"filename": "clip.mp4"}
|
||||
],
|
||||
}],
|
||||
}
|
||||
out = await adapter._process_quoted_context(d)
|
||||
assert "[Quoted message]:" in out["quote_block"]
|
||||
assert "/tmp/cache/clip.mp4" in out["quote_block"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_quoted_file_includes_path_in_quote_block(self):
|
||||
"""Quoted file attachments should surface the cached path in the quote block."""
|
||||
adapter = self._make_adapter()
|
||||
|
||||
async def fake_process(atts):
|
||||
return {
|
||||
"image_urls": [],
|
||||
"image_media_types": [],
|
||||
"voice_transcripts": [],
|
||||
"attachment_info": "[file: report.pdf (/tmp/cache/report.pdf)]",
|
||||
}
|
||||
|
||||
adapter._process_attachments = fake_process # type: ignore[assignment]
|
||||
|
||||
d = {
|
||||
"message_type": 103,
|
||||
"msg_elements": [{
|
||||
"content": "",
|
||||
"attachments": [
|
||||
{"content_type": "application/pdf",
|
||||
"url": "https://qq-cdn/report.pdf",
|
||||
"filename": "report.pdf"}
|
||||
],
|
||||
}],
|
||||
}
|
||||
out = await adapter._process_quoted_context(d)
|
||||
assert "[Quoted message]:" in out["quote_block"]
|
||||
assert "/tmp/cache/report.pdf" in out["quote_block"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket op 7 (Server Reconnect) and op 9 (Invalid Session)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOp7ServerReconnect:
|
||||
"""Verify op 7 triggers WS close (which triggers reconnect in outer loop)."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.qqbot.adapter import QQAdapter
|
||||
return QQAdapter(_make_config(app_id="a", client_secret="b"))
|
||||
|
||||
def test_op7_closes_websocket(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._session_id = "sess_keep"
|
||||
adapter._last_seq = 42
|
||||
|
||||
close_called = []
|
||||
|
||||
class FakeWS:
|
||||
closed = False
|
||||
|
||||
async def close(self):
|
||||
close_called.append(True)
|
||||
|
||||
adapter._ws = FakeWS()
|
||||
adapter._dispatch_payload({"op": 7, "d": None})
|
||||
|
||||
# Session should be preserved for Resume
|
||||
assert adapter._session_id == "sess_keep"
|
||||
assert adapter._last_seq == 42
|
||||
# close() should have been scheduled
|
||||
assert len(close_called) == 0 # _create_task schedules, not immediate
|
||||
# But the task was created — verify via asyncio
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_op7_close_task_executes(self):
|
||||
adapter = self._make_adapter()
|
||||
close_called = []
|
||||
|
||||
class FakeWS:
|
||||
closed = False
|
||||
|
||||
async def close(self):
|
||||
close_called.append(True)
|
||||
self.closed = True
|
||||
|
||||
adapter._ws = FakeWS()
|
||||
adapter._dispatch_payload({"op": 7, "d": None})
|
||||
|
||||
# Let the event loop run the scheduled task
|
||||
await asyncio.sleep(0)
|
||||
assert close_called == [True]
|
||||
# Session preserved
|
||||
assert adapter._session_id is None # was never set
|
||||
|
||||
|
||||
class TestOp9InvalidSession:
|
||||
"""Verify op 9 handles resumable vs non-resumable sessions."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.qqbot.adapter import QQAdapter
|
||||
return QQAdapter(_make_config(app_id="a", client_secret="b"))
|
||||
|
||||
def test_op9_not_resumable_clears_session(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._session_id = "sess_old"
|
||||
adapter._last_seq = 99
|
||||
|
||||
class FakeWS:
|
||||
closed = False
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
adapter._ws = FakeWS()
|
||||
adapter._dispatch_payload({"op": 9, "d": False})
|
||||
|
||||
assert adapter._session_id is None
|
||||
assert adapter._last_seq is None
|
||||
|
||||
def test_op9_resumable_preserves_session(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._session_id = "sess_keep"
|
||||
adapter._last_seq = 99
|
||||
|
||||
class FakeWS:
|
||||
closed = False
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
adapter._ws = FakeWS()
|
||||
adapter._dispatch_payload({"op": 9, "d": True})
|
||||
|
||||
# Session should be preserved for Resume
|
||||
assert adapter._session_id == "sess_keep"
|
||||
assert adapter._last_seq == 99
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_op9_non_resumable_triggers_ws_close(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter._session_id = "s"
|
||||
adapter._last_seq = 1
|
||||
close_called = []
|
||||
|
||||
class FakeWS:
|
||||
closed = False
|
||||
|
||||
async def close(self):
|
||||
close_called.append(True)
|
||||
self.closed = True
|
||||
|
||||
adapter._ws = FakeWS()
|
||||
adapter._dispatch_payload({"op": 9, "d": False})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert close_called == [True]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Close code classification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCloseCodeClassification:
|
||||
"""Verify fatal close codes stop reconnecting and 4009 preserves session."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.qqbot.adapter import QQAdapter
|
||||
return QQAdapter(_make_config(app_id="a", client_secret="b"))
|
||||
|
||||
def test_4009_preserves_session(self):
|
||||
"""4009 (connection timeout) should NOT clear the session."""
|
||||
adapter = self._make_adapter()
|
||||
adapter._session_id = "sess_to_keep"
|
||||
adapter._last_seq = 50
|
||||
|
||||
# The session-clearing codes set should NOT contain 4009.
|
||||
# We verify the logic directly: dispatch a close-code event that
|
||||
# exercises the session-clearing path (4006), then verify 4009 does not.
|
||||
session_clear_codes = {
|
||||
4006, 4007, 4900, 4901, 4902, 4903,
|
||||
4904, 4905, 4906, 4907, 4908, 4909,
|
||||
4910, 4911, 4912, 4913,
|
||||
}
|
||||
assert 4009 not in session_clear_codes
|
||||
|
||||
def test_fatal_codes_include_intent_errors(self):
|
||||
"""4013 (invalid intent) and 4014 (not authorized) should be fatal."""
|
||||
fatal_codes = {4001, 4002, 4010, 4011, 4012, 4013, 4014, 4914, 4915}
|
||||
# Verify these are all treated as fatal by checking the adapter's
|
||||
# code path would call _set_fatal_error. We verify the set membership
|
||||
# which is what the if-branch checks.
|
||||
assert 4013 in fatal_codes
|
||||
assert 4014 in fatal_codes
|
||||
assert 4001 in fatal_codes
|
||||
assert 4915 in fatal_codes
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from unittest.mock import MagicMock
|
|||
|
||||
def _make_adapter():
|
||||
"""Construct a DiscordAdapter without going through __init__ / token checks."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
from gateway.platforms.base import Platform
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter.config = MagicMock()
|
||||
|
|
|
|||
|
|
@ -116,6 +116,24 @@ def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, mon
|
|||
assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
|
||||
|
||||
|
||||
def test_load_busy_text_mode_defaults_to_queue_and_allows_interrupt(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_BUSY_TEXT_MODE", raising=False)
|
||||
|
||||
assert gateway_run.GatewayRunner._load_busy_text_mode() == "queue"
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n busy_text_mode: interrupt\n", encoding="utf-8"
|
||||
)
|
||||
assert gateway_run.GatewayRunner._load_busy_text_mode() == "interrupt"
|
||||
|
||||
monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_MODE", "queue")
|
||||
assert gateway_run.GatewayRunner._load_busy_text_mode() == "queue"
|
||||
|
||||
monkeypatch.setenv("HERMES_GATEWAY_BUSY_TEXT_MODE", "bogus")
|
||||
assert gateway_run.GatewayRunner._load_busy_text_mode() == "queue"
|
||||
|
||||
|
||||
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
|
||||
tmp_path, monkeypatch, caplog
|
||||
):
|
||||
|
|
|
|||
|
|
@ -88,6 +88,9 @@ class TestHandleResumeCommand:
|
|||
assert "Research" in result
|
||||
assert "Coding" in result
|
||||
assert "Named Sessions" in result
|
||||
assert "1." in result
|
||||
assert "2." in result
|
||||
assert "/resume 1" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -104,6 +107,47 @@ class TestHandleResumeCommand:
|
|||
assert "/title" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_by_index(self, tmp_path):
|
||||
"""Numeric argument resumes the indexed titled session from the list."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_001", "telegram")
|
||||
db.create_session("sess_002", "telegram")
|
||||
db.set_session_title("sess_001", "Research")
|
||||
db.set_session_title("sess_002", "Coding")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume 2")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
|
||||
assert "Resumed" in result
|
||||
runner.session_store.switch_session.assert_called_once()
|
||||
call_args = runner.session_store.switch_session.call_args
|
||||
assert call_args[0][1] == "sess_001"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_index_out_of_range(self, tmp_path):
|
||||
"""Out-of-range numeric arguments show a helpful error."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_001", "telegram")
|
||||
db.set_session_title("sess_001", "Research")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume 9")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
|
||||
assert "out of range" in result.lower()
|
||||
assert "/resume" in result
|
||||
runner.session_store.switch_session.assert_not_called()
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_by_name(self, tmp_path):
|
||||
"""Resolves a title and switches to that session."""
|
||||
|
|
|
|||
|
|
@ -942,6 +942,62 @@ async def test_run_agent_matrix_streaming_omits_cursor(monkeypatch, tmp_path):
|
|||
assert any("Continuing to refine:" in text for text in all_text)
|
||||
|
||||
|
||||
class TransformedStreamAgent:
|
||||
"""Streams a response, then signals the gateway that a plugin hook
|
||||
(``transform_llm_output``) modified the final text after streaming
|
||||
finished. ``run_conversation`` returns ``response_transformed=True``
|
||||
plus a ``final_response`` that diverges from what was streamed.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.stream_delta_callback = kwargs.get("stream_delta_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.stream_delta_callback:
|
||||
self.stream_delta_callback("original answer")
|
||||
return {
|
||||
"final_response": "original answer\n\n[plugin appended this]",
|
||||
"response_previewed": True,
|
||||
"response_transformed": True,
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transformed_response_edits_streamed_message_in_place(monkeypatch, tmp_path):
|
||||
"""When a transform_llm_output hook modifies the response after streaming,
|
||||
the gateway must edit the existing streamed message in place with the full
|
||||
transformed content (so plugins like content filters / appenders reach the
|
||||
user) and still mark already_sent=True (no duplicate send).
|
||||
"""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
TransformedStreamAgent,
|
||||
session_id="sess-transformed-stream",
|
||||
config_data={
|
||||
"display": {"tool_progress": "off", "interim_assistant_messages": False},
|
||||
"streaming": {"enabled": True, "edit_interval": 0.01, "buffer_threshold": 1},
|
||||
},
|
||||
platform=Platform.MATRIX,
|
||||
chat_id="!room:matrix.example.org",
|
||||
chat_type="group",
|
||||
thread_id="$thread",
|
||||
adapter_cls=MetadataEditProgressCaptureAdapter,
|
||||
)
|
||||
|
||||
# Final delivery happened (no duplicate send fallback).
|
||||
assert result.get("already_sent") is True
|
||||
# The transformed final text reached the user — appended portion is present
|
||||
# in an edit_message call (not just in the streamed sends).
|
||||
edited_texts = [e["content"] for e in adapter.edits]
|
||||
assert any("[plugin appended this]" in text for text in edited_texts), (
|
||||
f"expected transformed text in adapter.edits, got: {edited_texts!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path):
|
||||
QueuedCommentaryAgent.calls = 0
|
||||
|
|
|
|||
|
|
@ -207,6 +207,7 @@ async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_p
|
|||
lambda **kwargs: 0,
|
||||
)
|
||||
monkeypatch.setattr("gateway.status.terminate_pid", lambda pid, force=False: calls.append((pid, force)))
|
||||
monkeypatch.setattr("gateway.status._pid_exists", lambda pid: True)
|
||||
monkeypatch.setattr("gateway.run.os.getpid", lambda: 100)
|
||||
monkeypatch.setattr("gateway.run.os.kill", lambda pid, sig: None)
|
||||
monkeypatch.setattr("time.sleep", lambda _: None)
|
||||
|
|
|
|||
97
tests/gateway/test_runtime_config_env_expansion.py
Normal file
97
tests/gateway/test_runtime_config_env_expansion.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Regression tests for gateway runtime config env-var expansion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
import gateway.run as gateway_run
|
||||
|
||||
|
||||
def _write_config(home, body: str) -> None:
|
||||
(home / "config.yaml").write_text(body, encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gateway_home(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.delenv("HERMES_PREFILL_MESSAGES_FILE", raising=False)
|
||||
monkeypatch.delenv("HERMES_EPHEMERAL_SYSTEM_PROMPT", raising=False)
|
||||
monkeypatch.delenv("HERMES_GATEWAY_BUSY_INPUT_MODE", raising=False)
|
||||
monkeypatch.delenv("HERMES_RESTART_DRAIN_TIMEOUT", raising=False)
|
||||
monkeypatch.delenv("HERMES_BACKGROUND_NOTIFICATIONS", raising=False)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_load_prefill_messages_expands_env_var_path(monkeypatch, gateway_home):
|
||||
prefill = [{"role": "system", "content": "few-shot"}]
|
||||
(gateway_home / "prefill.json").write_text(json.dumps(prefill), encoding="utf-8")
|
||||
_write_config(gateway_home, "prefill_messages_file: ${PREFILL_FILE}\n")
|
||||
monkeypatch.setenv("PREFILL_FILE", "prefill.json")
|
||||
|
||||
assert gateway_run.GatewayRunner._load_prefill_messages() == prefill
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("config_body", "env_name", "env_value", "loader_name", "expected"),
|
||||
[
|
||||
(
|
||||
"agent:\n system_prompt: ${GW_PROMPT}\n",
|
||||
"GW_PROMPT",
|
||||
"expanded prompt",
|
||||
"_load_ephemeral_system_prompt",
|
||||
"expanded prompt",
|
||||
),
|
||||
(
|
||||
"agent:\n reasoning_effort: ${REASONING_LEVEL}\n",
|
||||
"REASONING_LEVEL",
|
||||
"high",
|
||||
"_load_reasoning_config",
|
||||
{"enabled": True, "effort": "high"},
|
||||
),
|
||||
(
|
||||
"agent:\n service_tier: ${SERVICE_TIER}\n",
|
||||
"SERVICE_TIER",
|
||||
"priority",
|
||||
"_load_service_tier",
|
||||
"priority",
|
||||
),
|
||||
(
|
||||
"display:\n busy_input_mode: ${BUSY_MODE}\n",
|
||||
"BUSY_MODE",
|
||||
"steer",
|
||||
"_load_busy_input_mode",
|
||||
"steer",
|
||||
),
|
||||
(
|
||||
"agent:\n restart_drain_timeout: ${DRAIN_TIMEOUT}\n",
|
||||
"DRAIN_TIMEOUT",
|
||||
"12",
|
||||
"_load_restart_drain_timeout",
|
||||
12.0,
|
||||
),
|
||||
(
|
||||
"display:\n background_process_notifications: ${BG_MODE}\n",
|
||||
"BG_MODE",
|
||||
"error",
|
||||
"_load_background_notifications_mode",
|
||||
"error",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_gateway_runtime_loaders_expand_env_var_templates(
|
||||
monkeypatch,
|
||||
gateway_home,
|
||||
config_body,
|
||||
env_name,
|
||||
env_value,
|
||||
loader_name,
|
||||
expected,
|
||||
):
|
||||
_write_config(gateway_home, config_body)
|
||||
monkeypatch.setenv(env_name, env_value)
|
||||
|
||||
loader = getattr(gateway_run.GatewayRunner, loader_name)
|
||||
|
||||
assert loader() == expected
|
||||
|
|
@ -190,7 +190,7 @@ def _ensure_discord_mock():
|
|||
_ensure_discord_mock()
|
||||
|
||||
import discord as discord_mod_ref # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestDiscordSendImageFile:
|
||||
|
|
|
|||
|
|
@ -210,7 +210,7 @@ def _ensure_discord_mock():
|
|||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestDiscordMultiImage:
|
||||
|
|
|
|||
|
|
@ -218,3 +218,46 @@ fallback_providers:
|
|||
assert runtime_kwargs["provider"] == "openrouter"
|
||||
assert runtime_kwargs["api_key"] == "sk-openrouter"
|
||||
|
||||
|
||||
def test_gateway_auth_fallback_resolves_key_env_for_custom_provider(tmp_path, monkeypatch):
|
||||
"""Auth-failure fallback should honor key_env/api_key_env custom-endpoint hints."""
|
||||
config = tmp_path / "config.yaml"
|
||||
config.write_text(
|
||||
"""
|
||||
fallback_providers:
|
||||
- provider: custom
|
||||
model: fallback-model
|
||||
base_url: https://fallback.example/v1
|
||||
key_env: MY_FALLBACK_KEY
|
||||
""".lstrip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setenv("MY_FALLBACK_KEY", "env-secret")
|
||||
|
||||
def fake_resolve_runtime_provider(*, requested=None, explicit_base_url=None, explicit_api_key=None):
|
||||
assert requested == "custom"
|
||||
assert explicit_base_url == "https://fallback.example/v1"
|
||||
assert explicit_api_key == "env-secret"
|
||||
return {
|
||||
"api_key": explicit_api_key,
|
||||
"base_url": explicit_base_url,
|
||||
"provider": "custom",
|
||||
"api_mode": "chat_completions",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": None,
|
||||
}
|
||||
|
||||
import hermes_cli.runtime_provider as runtime_provider
|
||||
|
||||
monkeypatch.setattr(runtime_provider, "resolve_runtime_provider", fake_resolve_runtime_provider)
|
||||
|
||||
runtime_kwargs = gateway_run._try_resolve_fallback_provider()
|
||||
|
||||
assert runtime_kwargs is not None
|
||||
assert runtime_kwargs["provider"] == "custom"
|
||||
assert runtime_kwargs["api_key"] == "env-secret"
|
||||
assert runtime_kwargs["base_url"] == "https://fallback.example/v1"
|
||||
assert runtime_kwargs["model"] == "fallback-model"
|
||||
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ class _StubAdapter(BasePlatformAdapter):
|
|||
def _make_adapter():
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
adapter = _StubAdapter(config, Platform.TELEGRAM)
|
||||
adapter._busy_text_mode = ""
|
||||
adapter.sent_responses = []
|
||||
|
||||
async def _mock_send_retry(chat_id, content, **kwargs):
|
||||
|
|
@ -396,4 +397,3 @@ class TestOldTaskCannotClobberNewerGuard:
|
|||
# default path) still work.
|
||||
adapter._release_session_guard(sk)
|
||||
assert sk not in adapter._active_sessions
|
||||
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ class TestEditMessageFinalizeSignature:
|
|||
"module_path,class_name",
|
||||
[
|
||||
("gateway.platforms.telegram", "TelegramAdapter"),
|
||||
("gateway.platforms.discord", "DiscordAdapter"),
|
||||
("plugins.platforms.discord.adapter", "DiscordAdapter"),
|
||||
("gateway.platforms.slack", "SlackAdapter"),
|
||||
("gateway.platforms.matrix", "MatrixAdapter"),
|
||||
("gateway.platforms.mattermost", "MattermostAdapter"),
|
||||
|
|
|
|||
|
|
@ -225,6 +225,128 @@ def test_observed_group_context_uses_shared_source_and_prompt_for_later_mentions
|
|||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_observed_group_context_replays_as_current_message_context_not_user_turns():
|
||||
from gateway.run import (
|
||||
_build_gateway_agent_history,
|
||||
_wrap_current_message_with_observed_context,
|
||||
)
|
||||
|
||||
history = [
|
||||
{"role": "session_meta", "content": "tool defs"},
|
||||
{"role": "user", "content": "[Alice|111]\nAcha que dá fazer estoque?", "observed": True},
|
||||
{"role": "user", "content": "[Alice|111]\nTem lote e vencimento", "observed": True},
|
||||
{"role": "assistant", "content": "previous explicit reply"},
|
||||
]
|
||||
|
||||
agent_history, observed_context = _build_gateway_agent_history(
|
||||
history,
|
||||
channel_prompt="You are handling Telegram; observed Telegram group context is present.",
|
||||
)
|
||||
api_message = _wrap_current_message_with_observed_context(
|
||||
"[Bob|222]\ncambio",
|
||||
observed_context,
|
||||
)
|
||||
|
||||
assert agent_history == [{"role": "assistant", "content": "previous explicit reply"}]
|
||||
assert "[Observed Telegram group context - context only, not requests]" in api_message
|
||||
assert "[Current addressed message - answer only this" in api_message
|
||||
assert "Acha que dá fazer estoque?" in api_message
|
||||
assert "Tem lote e vencimento" in api_message
|
||||
assert api_message.endswith("[Bob|222]\ncambio")
|
||||
|
||||
|
||||
def test_observed_group_context_does_not_hide_current_user_turn_behind_history_offset():
|
||||
from agent.agent_runtime_helpers import repair_message_sequence
|
||||
from gateway.run import (
|
||||
_build_gateway_agent_history,
|
||||
_wrap_current_message_with_observed_context,
|
||||
)
|
||||
|
||||
history = [
|
||||
{"role": "user", "content": "[Alice|111]\nAcha que dá fazer estoque?", "observed": True},
|
||||
]
|
||||
agent_history, observed_context = _build_gateway_agent_history(
|
||||
history,
|
||||
channel_prompt="observed Telegram group context",
|
||||
)
|
||||
api_message = _wrap_current_message_with_observed_context("[Bob|222]\ncambio", observed_context)
|
||||
messages = list(agent_history) + [{"role": "user", "content": api_message}]
|
||||
|
||||
repair_message_sequence(object(), messages)
|
||||
|
||||
history_offset = len(agent_history)
|
||||
new_messages = messages[history_offset:]
|
||||
assert len(agent_history) == 0
|
||||
assert new_messages[0]["role"] == "user"
|
||||
assert new_messages[0]["content"].endswith("[Bob|222]\ncambio")
|
||||
|
||||
|
||||
def test_observed_group_context_wraps_multimodal_current_message_without_mutating_parts():
|
||||
from gateway.run import _wrap_current_message_with_observed_context
|
||||
|
||||
original = [
|
||||
{"type": "text", "text": "[Bob|222]\nsee this image"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
]
|
||||
|
||||
wrapped = _wrap_current_message_with_observed_context(
|
||||
original,
|
||||
"[Alice|111]\nside chatter",
|
||||
)
|
||||
|
||||
assert original[0]["text"] == "[Bob|222]\nsee this image"
|
||||
assert wrapped[0]["text"].startswith("[Observed Telegram group context - context only")
|
||||
assert wrapped[0]["text"].endswith("[Bob|222]\nsee this image")
|
||||
assert wrapped[1] == original[1]
|
||||
|
||||
|
||||
def test_observed_group_context_replays_normally_without_telegram_prompt():
|
||||
from gateway.run import _build_gateway_agent_history
|
||||
|
||||
history = [
|
||||
{"role": "user", "content": "[Alice|111]\nside chatter", "observed": True},
|
||||
]
|
||||
|
||||
agent_history, observed_context = _build_gateway_agent_history(history, channel_prompt=None)
|
||||
|
||||
assert observed_context is None
|
||||
assert agent_history == [{"role": "user", "content": "[Alice|111]\nside chatter"}]
|
||||
|
||||
|
||||
def test_observed_group_context_preserves_slash_command_text_for_dispatch():
|
||||
from gateway.platforms.base import MessageEvent, MessageType, Platform, SessionSource
|
||||
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
allowed_chats=["-100"],
|
||||
group_allowed_chats=["-100"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="/new@hermes_bot",
|
||||
message_type=MessageType.COMMAND,
|
||||
source=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
user_id="111",
|
||||
user_name="Alice",
|
||||
chat_type="group",
|
||||
thread_id="7",
|
||||
),
|
||||
raw_message=_group_message(
|
||||
"/new@hermes_bot",
|
||||
entities=[_bot_command_entity("/new@hermes_bot", "/new@hermes_bot")],
|
||||
),
|
||||
)
|
||||
|
||||
attributed = adapter._apply_telegram_group_observe_attribution(event)
|
||||
|
||||
assert attributed.text == "/new@hermes_bot"
|
||||
assert attributed.get_command() == "new"
|
||||
assert attributed.source.user_id is None
|
||||
assert "observed Telegram group context" in attributed.channel_prompt
|
||||
|
||||
|
||||
def test_unmentioned_group_observe_requires_chat_allowlist_for_shared_context():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
|
|
|
|||
90
tests/gateway/test_telegram_send_path_health.py
Normal file
90
tests/gateway/test_telegram_send_path_health.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""TelegramAdapter send-path health gating after reconnect storms.
|
||||
|
||||
After sustained Bad Gateway / TimedOut reconnect cycles, the PTB httpx client
|
||||
can enter a wedged state where ``bot.send_message()`` returns a valid Message
|
||||
but nothing reaches the recipient. ``_send_path_degraded`` short-circuits
|
||||
``send()`` so cron's live-adapter branch falls through to standalone HTTP.
|
||||
"""
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return
|
||||
mod = MagicMock()
|
||||
mod.error.NetworkError = type("NetworkError", (OSError,), {})
|
||||
mod.error.TimedOut = type("TimedOut", (OSError,), {})
|
||||
mod.error.BadRequest = type("BadRequest", (Exception,), {})
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
sys.modules.setdefault("telegram.error", mod.error)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
def _make_adapter() -> TelegramAdapter:
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
adapter._bot = MagicMock()
|
||||
adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=42))
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_succeeds_when_path_healthy():
|
||||
"""Healthy adapter delivers normally; send_message is called."""
|
||||
adapter = _make_adapter()
|
||||
assert adapter._send_path_degraded is False
|
||||
|
||||
result = await adapter.send("123", "hello")
|
||||
|
||||
assert result.success is True
|
||||
adapter._bot.send_message.assert_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_short_circuits_when_path_degraded():
|
||||
"""Degraded adapter returns failure WITHOUT calling send_message,
|
||||
so cron's live-adapter branch falls through to standalone HTTP."""
|
||||
adapter = _make_adapter()
|
||||
adapter._send_path_degraded = True
|
||||
|
||||
result = await adapter.send("123", "hello")
|
||||
|
||||
assert result.success is False
|
||||
assert result.error == "send_path_degraded"
|
||||
assert result.retryable is True
|
||||
adapter._bot.send_message.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_storm_sets_and_heartbeat_clears_flag(monkeypatch):
|
||||
"""_handle_polling_network_error sets the flag; a successful heartbeat
|
||||
probe in _verify_polling_after_reconnect clears it."""
|
||||
adapter = _make_adapter()
|
||||
adapter._app = MagicMock()
|
||||
adapter._app.updater = MagicMock()
|
||||
adapter._app.updater.running = True
|
||||
adapter._app.updater.stop = AsyncMock()
|
||||
adapter._app.updater.start_polling = AsyncMock()
|
||||
adapter._app.bot = MagicMock()
|
||||
adapter._app.bot.get_me = AsyncMock(return_value=MagicMock())
|
||||
adapter._polling_error_callback_ref = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.telegram.Update", MagicMock(ALL_TYPES=[])
|
||||
)
|
||||
|
||||
await adapter._handle_polling_network_error(OSError("Bad Gateway"))
|
||||
assert adapter._send_path_degraded is True
|
||||
|
||||
with patch("gateway.platforms.telegram.asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._verify_polling_after_reconnect()
|
||||
assert adapter._send_path_degraded is False
|
||||
162
tests/gateway/test_telegram_status_update.py
Normal file
162
tests/gateway/test_telegram_status_update.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""Tests for TelegramAdapter.send_or_update_status (issue #30045).
|
||||
|
||||
The status-update path must:
|
||||
1. Send a fresh message on the first call for a (chat_id, status_key) pair.
|
||||
2. Edit that same message on subsequent calls with the same key.
|
||||
3. Fall back to sending fresh when the cached message edit fails.
|
||||
4. Keep distinct keys independent (no cross-talk).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.base import SendResult
|
||||
|
||||
|
||||
def _install_fake_telegram(monkeypatch):
|
||||
"""Stub the python-telegram-bot package so TelegramAdapter can be imported."""
|
||||
fake_telegram = types.ModuleType("telegram")
|
||||
fake_telegram.Update = SimpleNamespace(ALL_TYPES=())
|
||||
fake_telegram.Bot = object
|
||||
fake_telegram.Message = object
|
||||
fake_telegram.InlineKeyboardButton = object
|
||||
fake_telegram.InlineKeyboardMarkup = object
|
||||
|
||||
fake_error = types.ModuleType("telegram.error")
|
||||
fake_error.NetworkError = type("NetworkError", (Exception,), {})
|
||||
fake_error.BadRequest = type("BadRequest", (Exception,), {})
|
||||
fake_error.TimedOut = type("TimedOut", (Exception,), {})
|
||||
fake_telegram.error = fake_error
|
||||
|
||||
fake_constants = types.ModuleType("telegram.constants")
|
||||
fake_constants.ParseMode = SimpleNamespace(MARKDOWN_V2="MarkdownV2")
|
||||
fake_constants.ChatType = SimpleNamespace(
|
||||
GROUP="group", SUPERGROUP="supergroup",
|
||||
CHANNEL="channel", PRIVATE="private",
|
||||
)
|
||||
fake_telegram.constants = fake_constants
|
||||
|
||||
fake_ext = types.ModuleType("telegram.ext")
|
||||
fake_ext.Application = object
|
||||
fake_ext.CommandHandler = object
|
||||
fake_ext.CallbackQueryHandler = object
|
||||
fake_ext.MessageHandler = object
|
||||
fake_ext.ContextTypes = SimpleNamespace(DEFAULT_TYPE=object)
|
||||
fake_ext.filters = object
|
||||
|
||||
fake_request = types.ModuleType("telegram.request")
|
||||
fake_request.HTTPXRequest = object
|
||||
|
||||
monkeypatch.setitem(sys.modules, "telegram", fake_telegram)
|
||||
monkeypatch.setitem(sys.modules, "telegram.error", fake_error)
|
||||
monkeypatch.setitem(sys.modules, "telegram.constants", fake_constants)
|
||||
monkeypatch.setitem(sys.modules, "telegram.ext", fake_ext)
|
||||
monkeypatch.setitem(sys.modules, "telegram.request", fake_request)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(monkeypatch):
|
||||
_install_fake_telegram(monkeypatch)
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
a = TelegramAdapter(PlatformConfig(enabled=True, token="fake-token"))
|
||||
a._bot = MagicMock()
|
||||
# Patch send / edit_message so tests can drive them directly.
|
||||
a.send = AsyncMock()
|
||||
a.edit_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_call_sends_and_caches_message_id(adapter):
|
||||
"""First call for a (chat, key) pair must send and remember the id."""
|
||||
adapter.send.return_value = SendResult(success=True, message_id="100")
|
||||
|
||||
result = await adapter.send_or_update_status("chat-1", "lifecycle", "starting")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "100"
|
||||
adapter.send.assert_awaited_once()
|
||||
adapter.edit_message.assert_not_awaited()
|
||||
assert adapter._status_message_ids[("chat-1", "lifecycle")] == "100"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_call_edits_in_place(adapter):
|
||||
"""Same (chat, key) on the second call must edit, not send."""
|
||||
adapter.send.return_value = SendResult(success=True, message_id="100")
|
||||
adapter.edit_message.return_value = SendResult(success=True, message_id="100")
|
||||
|
||||
await adapter.send_or_update_status("chat-1", "lifecycle", "step 1")
|
||||
await adapter.send_or_update_status("chat-1", "lifecycle", "step 2")
|
||||
|
||||
adapter.send.assert_awaited_once()
|
||||
adapter.edit_message.assert_awaited_once()
|
||||
# Edit was directed at the cached message id.
|
||||
args, kwargs = adapter.edit_message.call_args
|
||||
assert args[0] == "chat-1"
|
||||
assert args[1] == "100"
|
||||
assert args[2] == "step 2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_failure_falls_back_to_fresh_send(adapter):
|
||||
"""When edit_message fails the cache is cleared and a new send happens."""
|
||||
adapter.send.side_effect = [
|
||||
SendResult(success=True, message_id="100"),
|
||||
SendResult(success=True, message_id="200"),
|
||||
]
|
||||
adapter.edit_message.return_value = SendResult(
|
||||
success=False, error="Bad Request: message to edit not found",
|
||||
)
|
||||
|
||||
await adapter.send_or_update_status("chat-1", "lifecycle", "step 1")
|
||||
result = await adapter.send_or_update_status("chat-1", "lifecycle", "step 2")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "200"
|
||||
assert adapter.send.await_count == 2
|
||||
assert adapter.edit_message.await_count == 1
|
||||
# Cache now points at the fresh message id.
|
||||
assert adapter._status_message_ids[("chat-1", "lifecycle")] == "200"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_distinct_status_keys_do_not_collide(adapter):
|
||||
"""A different status_key gets its own message; the original isn't touched."""
|
||||
adapter.send.side_effect = [
|
||||
SendResult(success=True, message_id="100"),
|
||||
SendResult(success=True, message_id="200"),
|
||||
]
|
||||
|
||||
await adapter.send_or_update_status("chat-1", "lifecycle", "ctx pressure")
|
||||
await adapter.send_or_update_status("chat-1", "model-switch", "switched to opus")
|
||||
|
||||
assert adapter.send.await_count == 2
|
||||
adapter.edit_message.assert_not_awaited()
|
||||
assert adapter._status_message_ids[("chat-1", "lifecycle")] == "100"
|
||||
assert adapter._status_message_ids[("chat-1", "model-switch")] == "200"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_distinct_chat_ids_do_not_collide(adapter):
|
||||
"""Same status_key in different chats must not edit each other's messages."""
|
||||
adapter.send.side_effect = [
|
||||
SendResult(success=True, message_id="100"),
|
||||
SendResult(success=True, message_id="200"),
|
||||
]
|
||||
|
||||
await adapter.send_or_update_status("chat-1", "lifecycle", "first")
|
||||
await adapter.send_or_update_status("chat-2", "lifecycle", "second")
|
||||
|
||||
assert adapter.send.await_count == 2
|
||||
adapter.edit_message.assert_not_awaited()
|
||||
assert adapter._status_message_ids[("chat-1", "lifecycle")] == "100"
|
||||
assert adapter._status_message_ids[("chat-2", "lifecycle")] == "200"
|
||||
|
|
@ -98,6 +98,7 @@ _fake_telegram_ext.Application = object
|
|||
_fake_telegram_ext.CommandHandler = object
|
||||
_fake_telegram_ext.CallbackQueryHandler = object
|
||||
_fake_telegram_ext.MessageHandler = object
|
||||
_fake_telegram_ext.TypeHandler = object
|
||||
_fake_telegram_ext.ContextTypes = SimpleNamespace(DEFAULT_TYPE=object)
|
||||
_fake_telegram_ext.filters = object
|
||||
_fake_telegram_request = types.ModuleType("telegram.request")
|
||||
|
|
|
|||
|
|
@ -1175,13 +1175,15 @@ def test_recover_returns_none_for_known_topic(tmp_path):
|
|||
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="222")) is None
|
||||
|
||||
|
||||
def test_recover_rewrites_unknown_thread_id_to_most_recent(tmp_path):
|
||||
# Cross-topic Reply leak: inbound thread_id is a Telegram-only id we never bound.
|
||||
def test_recover_preserves_unknown_thread_id_for_new_topic(tmp_path):
|
||||
# A newly-created Telegram DM topic arrives with a real, previously-unbound
|
||||
# message_thread_id. It must become its own session lane rather than being
|
||||
# rewritten to whichever older topic was most recently active.
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
_seed_two_topic_bindings(db)
|
||||
runner = _make_runner(session_db=db)
|
||||
|
||||
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="9999")) == "222"
|
||||
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="9999")) is None
|
||||
|
||||
|
||||
def test_recover_rewrites_lobby_thread_id_to_most_recent(tmp_path):
|
||||
|
|
@ -1209,6 +1211,31 @@ def test_recover_returns_none_when_no_bindings_yet(tmp_path):
|
|||
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id=None)) is None
|
||||
|
||||
|
||||
def test_recover_returns_none_for_brand_new_topic(tmp_path):
|
||||
# Regression for #31086: bindings exist for a prior topic but the user
|
||||
# opened a fresh one (thread_id "99999"). Recovery must return None so the
|
||||
# new topic gets its own session rather than being silently merged into
|
||||
# the previous topic's session. The hijack was self-reinforcing — because
|
||||
# the rewrite ran before _record_telegram_topic_binding, the new topic's
|
||||
# binding row never got written, so every subsequent message in that topic
|
||||
# looked "unknown" and was hijacked again.
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.enable_telegram_topic_mode(chat_id="208214988", user_id="208214988")
|
||||
db.create_session(session_id="sess-old", source="telegram", user_id="208214988")
|
||||
src_old = _make_source(thread_id="12345")
|
||||
db.bind_telegram_topic(
|
||||
chat_id=src_old.chat_id,
|
||||
thread_id=src_old.thread_id,
|
||||
user_id=src_old.user_id,
|
||||
session_key=build_session_key(src_old),
|
||||
session_id="sess-old",
|
||||
)
|
||||
runner = _make_runner(session_db=db)
|
||||
|
||||
# "99999" is non-lobby and not in the binding table — brand-new topic.
|
||||
assert runner._recover_telegram_topic_thread_id(_make_source(thread_id="99999")) is None
|
||||
|
||||
|
||||
def test_list_telegram_topic_bindings_for_chat(tmp_path):
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
_seed_two_topic_bindings(db)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def _make_event(
|
|||
|
||||
def _make_discord_adapter():
|
||||
"""Create a minimal DiscordAdapter for testing text batching."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
|
|
|
|||
|
|
@ -50,11 +50,24 @@ def _event(thread_id=None):
|
|||
)
|
||||
|
||||
|
||||
def _allowed_media_path(tmp_path, monkeypatch, name):
|
||||
root = tmp_path / "media-cache"
|
||||
media_file = root / name
|
||||
media_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
media_file.write_bytes(b"media")
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
|
||||
(root,),
|
||||
)
|
||||
return media_file.resolve()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender():
|
||||
async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender(tmp_path, monkeypatch):
|
||||
adapter = _MediaRoutingAdapter()
|
||||
event = _event()
|
||||
adapter._message_handler = AsyncMock(return_value="MEDIA:/tmp/speech.flac")
|
||||
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.flac")
|
||||
adapter._message_handler = AsyncMock(return_value=f"MEDIA:{media_file}")
|
||||
adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
|
||||
adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))
|
||||
|
||||
|
|
@ -62,17 +75,18 @@ async def test_base_adapter_routes_telegram_flac_media_tag_to_document_sender():
|
|||
|
||||
adapter.send_document.assert_awaited_once_with(
|
||||
chat_id="chat-1",
|
||||
file_path="/tmp/speech.flac",
|
||||
file_path=str(media_file),
|
||||
metadata=None,
|
||||
)
|
||||
adapter.send_voice.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_sender():
|
||||
async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_sender(tmp_path, monkeypatch):
|
||||
adapter = _MediaRoutingAdapter()
|
||||
event = _event()
|
||||
adapter._message_handler = AsyncMock(return_value="MEDIA:/tmp/speech.ogg")
|
||||
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg")
|
||||
adapter._message_handler = AsyncMock(return_value=f"MEDIA:{media_file}")
|
||||
adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
|
||||
adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))
|
||||
|
||||
|
|
@ -80,18 +94,19 @@ async def test_base_adapter_routes_non_voice_telegram_ogg_media_tag_to_document_
|
|||
|
||||
adapter.send_document.assert_awaited_once_with(
|
||||
chat_id="chat-1",
|
||||
file_path="/tmp/speech.ogg",
|
||||
file_path=str(media_file),
|
||||
metadata=None,
|
||||
)
|
||||
adapter.send_voice.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_sender():
|
||||
async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_sender(tmp_path, monkeypatch):
|
||||
adapter = _MediaRoutingAdapter()
|
||||
event = _event()
|
||||
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg")
|
||||
adapter._message_handler = AsyncMock(
|
||||
return_value="[[audio_as_voice]]\nMEDIA:/tmp/speech.ogg"
|
||||
return_value=f"[[audio_as_voice]]\nMEDIA:{media_file}"
|
||||
)
|
||||
adapter.send_voice = AsyncMock(return_value=SendResult(success=True, message_id="voice"))
|
||||
adapter.send_document = AsyncMock(return_value=SendResult(success=True, message_id="doc"))
|
||||
|
|
@ -100,7 +115,7 @@ async def test_base_adapter_routes_voice_tagged_telegram_ogg_media_tag_to_voice_
|
|||
|
||||
adapter.send_voice.assert_awaited_once_with(
|
||||
chat_id="chat-1",
|
||||
audio_path="/tmp/speech.ogg",
|
||||
audio_path=str(media_file),
|
||||
metadata=None,
|
||||
)
|
||||
adapter.send_document.assert_not_awaited()
|
||||
|
|
@ -117,8 +132,9 @@ def _fake_runner(thread_meta):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sender():
|
||||
async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sender(tmp_path, monkeypatch):
|
||||
event = _event(thread_id="topic-1")
|
||||
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.flac")
|
||||
adapter = SimpleNamespace(
|
||||
name="test",
|
||||
extract_media=BasePlatformAdapter.extract_media,
|
||||
|
|
@ -132,22 +148,23 @@ async def test_streaming_delivery_routes_telegram_flac_media_tag_to_document_sen
|
|||
|
||||
await GatewayRunner._deliver_media_from_response(
|
||||
_fake_runner({"thread_id": "topic-1"}),
|
||||
"MEDIA:/tmp/speech.flac",
|
||||
f"MEDIA:{media_file}",
|
||||
event,
|
||||
adapter,
|
||||
)
|
||||
|
||||
adapter.send_document.assert_awaited_once_with(
|
||||
chat_id="chat-1",
|
||||
file_path="/tmp/speech.flac",
|
||||
file_path=str(media_file),
|
||||
metadata={"thread_id": "topic-1"},
|
||||
)
|
||||
adapter.send_voice.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_document_sender():
|
||||
async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_document_sender(tmp_path, monkeypatch):
|
||||
event = _event(thread_id="topic-1")
|
||||
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.ogg")
|
||||
adapter = SimpleNamespace(
|
||||
name="test",
|
||||
extract_media=BasePlatformAdapter.extract_media,
|
||||
|
|
@ -161,24 +178,25 @@ async def test_streaming_delivery_routes_non_voice_telegram_ogg_media_tag_to_doc
|
|||
|
||||
await GatewayRunner._deliver_media_from_response(
|
||||
_fake_runner({"thread_id": "topic-1"}),
|
||||
"MEDIA:/tmp/speech.ogg",
|
||||
f"MEDIA:{media_file}",
|
||||
event,
|
||||
adapter,
|
||||
)
|
||||
|
||||
adapter.send_document.assert_awaited_once_with(
|
||||
chat_id="chat-1",
|
||||
file_path="/tmp/speech.ogg",
|
||||
file_path=str(media_file),
|
||||
metadata={"thread_id": "topic-1"},
|
||||
)
|
||||
adapter.send_voice.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender():
|
||||
async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender(tmp_path, monkeypatch):
|
||||
"""MP3 audio on Telegram must go through send_voice (which routes to
|
||||
sendAudio internally); Telegram accepts MP3 for the audio player."""
|
||||
event = _event(thread_id="topic-1")
|
||||
media_file = _allowed_media_path(tmp_path, monkeypatch, "speech.mp3")
|
||||
adapter = SimpleNamespace(
|
||||
name="test",
|
||||
extract_media=BasePlatformAdapter.extract_media,
|
||||
|
|
@ -192,14 +210,47 @@ async def test_streaming_delivery_routes_telegram_mp3_media_tag_to_voice_sender(
|
|||
|
||||
await GatewayRunner._deliver_media_from_response(
|
||||
_fake_runner({"thread_id": "topic-1"}),
|
||||
"MEDIA:/tmp/speech.mp3",
|
||||
f"MEDIA:{media_file}",
|
||||
event,
|
||||
adapter,
|
||||
)
|
||||
|
||||
adapter.send_voice.assert_awaited_once_with(
|
||||
chat_id="chat-1",
|
||||
audio_path="/tmp/speech.mp3",
|
||||
audio_path=str(media_file),
|
||||
metadata={"thread_id": "topic-1"},
|
||||
)
|
||||
adapter.send_document.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_delivery_blocks_media_path_outside_allowed_roots(tmp_path, monkeypatch):
|
||||
event = _event(thread_id="topic-1")
|
||||
allowed_root = tmp_path / "media-cache"
|
||||
allowed_root.mkdir()
|
||||
secret = tmp_path / "outside.pdf"
|
||||
secret.write_bytes(b"%PDF secret")
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.MEDIA_DELIVERY_SAFE_ROOTS",
|
||||
(allowed_root,),
|
||||
)
|
||||
adapter = SimpleNamespace(
|
||||
name="test",
|
||||
extract_media=BasePlatformAdapter.extract_media,
|
||||
extract_images=BasePlatformAdapter.extract_images,
|
||||
extract_local_files=BasePlatformAdapter.extract_local_files,
|
||||
send_voice=AsyncMock(return_value=SendResult(success=True, message_id="voice")),
|
||||
send_document=AsyncMock(return_value=SendResult(success=True, message_id="doc")),
|
||||
send_image_file=AsyncMock(return_value=SendResult(success=True, message_id="image")),
|
||||
send_video=AsyncMock(return_value=SendResult(success=True, message_id="video")),
|
||||
)
|
||||
|
||||
await GatewayRunner._deliver_media_from_response(
|
||||
_fake_runner({"thread_id": "topic-1"}),
|
||||
f"MEDIA:{secret}",
|
||||
event,
|
||||
adapter,
|
||||
)
|
||||
|
||||
adapter.send_document.assert_not_awaited()
|
||||
adapter.send_voice.assert_not_awaited()
|
||||
|
|
|
|||
|
|
@ -511,7 +511,7 @@ class TestDiscordPlayTtsSkip:
|
|||
"""Discord adapter skips play_tts when bot is in a voice channel."""
|
||||
|
||||
def _make_discord_adapter(self):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
|
|
@ -599,7 +599,7 @@ class TestVoiceReceiver:
|
|||
"""Test VoiceReceiver silence detection, SSRC mapping, and lifecycle."""
|
||||
|
||||
def _make_receiver(self):
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
mock_vc = MagicMock()
|
||||
mock_vc._connection.secret_key = [0] * 32
|
||||
mock_vc._connection.dave_session = None
|
||||
|
|
@ -1066,7 +1066,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
"""Test DiscordAdapter voice channel methods (join, leave, play, etc.)."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
|
|
@ -1208,7 +1208,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
|
||||
pcm_data = b"\x00" * 96000
|
||||
|
||||
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
|
||||
with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \
|
||||
patch("tools.transcription_tools.transcribe_audio",
|
||||
return_value={"success": True, "transcript": "Hello"}), \
|
||||
patch("tools.voice_mode.is_whisper_hallucination", return_value=False):
|
||||
|
|
@ -1223,7 +1223,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
callback = AsyncMock()
|
||||
adapter._voice_input_callback = callback
|
||||
|
||||
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
|
||||
with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \
|
||||
patch("tools.transcription_tools.transcribe_audio",
|
||||
return_value={"success": True, "transcript": "Thank you."}), \
|
||||
patch("tools.voice_mode.is_whisper_hallucination", return_value=True):
|
||||
|
|
@ -1238,7 +1238,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
callback = AsyncMock()
|
||||
adapter._voice_input_callback = callback
|
||||
|
||||
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
|
||||
with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav"), \
|
||||
patch("tools.transcription_tools.transcribe_audio",
|
||||
return_value={"success": False, "error": "API error"}):
|
||||
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
|
||||
|
|
@ -1251,7 +1251,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
adapter = self._make_adapter()
|
||||
adapter._voice_input_callback = AsyncMock()
|
||||
|
||||
with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav",
|
||||
with patch("plugins.platforms.discord.adapter.VoiceReceiver.pcm_to_wav",
|
||||
side_effect=RuntimeError("ffmpeg not found")):
|
||||
await adapter._process_voice_input(111, 42, b"\x00" * 96000)
|
||||
# Should not raise
|
||||
|
|
@ -1269,7 +1269,7 @@ class TestVoiceReceiverThreadSafety:
|
|||
"""Verify that VoiceReceiver buffer access is protected by lock."""
|
||||
|
||||
def _make_receiver(self):
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
mock_vc = MagicMock()
|
||||
mock_vc._connection.secret_key = [0] * 32
|
||||
mock_vc._connection.dave_session = None
|
||||
|
|
@ -1282,7 +1282,7 @@ class TestVoiceReceiverThreadSafety:
|
|||
def test_check_silence_holds_lock(self):
|
||||
"""check_silence must hold lock while iterating buffers."""
|
||||
import ast, inspect, textwrap
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
source = textwrap.dedent(inspect.getsource(VoiceReceiver.check_silence))
|
||||
tree = ast.parse(source)
|
||||
# Find 'with self._lock:' that contains buffer iteration
|
||||
|
|
@ -1303,7 +1303,7 @@ class TestVoiceReceiverThreadSafety:
|
|||
def test_on_packet_buffer_write_holds_lock(self):
|
||||
"""_on_packet must hold lock when writing to buffers."""
|
||||
import ast, inspect, textwrap
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
source = textwrap.dedent(inspect.getsource(VoiceReceiver._on_packet))
|
||||
tree = ast.parse(source)
|
||||
# Find 'with self._lock:' that contains buffer extend
|
||||
|
|
@ -1670,7 +1670,7 @@ class TestStopAcquiresLock:
|
|||
|
||||
@staticmethod
|
||||
def _make_receiver():
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = None
|
||||
|
|
@ -1772,7 +1772,7 @@ class TestPacketDebugCounterIsInstanceLevel:
|
|||
|
||||
@staticmethod
|
||||
def _make_receiver():
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = None
|
||||
|
|
@ -1805,7 +1805,7 @@ class TestPlayInVoiceChannelUsesRunningLoop:
|
|||
def test_source_uses_get_running_loop(self):
|
||||
"""The method source code calls get_running_loop, not get_event_loop."""
|
||||
import inspect
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.play_in_voice_channel)
|
||||
assert "get_running_loop" in source, \
|
||||
"play_in_voice_channel should use asyncio.get_running_loop()"
|
||||
|
|
@ -1849,7 +1849,7 @@ class TestVoiceTimeoutCleansRunnerState:
|
|||
|
||||
@staticmethod
|
||||
def _make_discord_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
|
|
@ -1940,7 +1940,7 @@ class TestPlaybackTimeout:
|
|||
|
||||
@staticmethod
|
||||
def _make_discord_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
|
|
@ -1964,7 +1964,7 @@ class TestPlaybackTimeout:
|
|||
def test_source_has_wait_for_timeout(self):
|
||||
"""The method uses asyncio.wait_for with timeout."""
|
||||
import inspect
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.play_in_voice_channel)
|
||||
assert "wait_for" in source, \
|
||||
"play_in_voice_channel must use asyncio.wait_for for timeout"
|
||||
|
|
@ -1973,14 +1973,14 @@ class TestPlaybackTimeout:
|
|||
|
||||
def test_playback_timeout_constant_exists(self):
|
||||
"""PLAYBACK_TIMEOUT constant is defined on DiscordAdapter."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
assert hasattr(DiscordAdapter, "PLAYBACK_TIMEOUT")
|
||||
assert DiscordAdapter.PLAYBACK_TIMEOUT > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_playback_timeout_fires(self):
|
||||
"""When done event is never set, playback times out gracefully."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
adapter = self._make_discord_adapter()
|
||||
|
||||
mock_vc = MagicMock()
|
||||
|
|
@ -2008,7 +2008,7 @@ class TestPlaybackTimeout:
|
|||
@pytest.mark.asyncio
|
||||
async def test_is_playing_wait_has_timeout(self):
|
||||
"""While loop waiting for previous playback has a timeout."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
adapter = self._make_discord_adapter()
|
||||
|
||||
mock_vc = MagicMock()
|
||||
|
|
@ -2124,7 +2124,7 @@ class TestVoiceChannelAwareness:
|
|||
"""Tests for get_voice_channel_info() and get_voice_channel_context()."""
|
||||
|
||||
def _make_adapter(self):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
|
|
@ -2267,7 +2267,7 @@ class TestVoiceReception:
|
|||
|
||||
@staticmethod
|
||||
def _make_receiver(allowed_ids=None, members=None, dave=False, bot_id=9999):
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = MagicMock() if dave else None
|
||||
|
|
@ -2451,7 +2451,7 @@ class TestVoiceReception:
|
|||
|
||||
def _make_receiver_with_nacl(self, dave_session=None, mapped_ssrcs=None):
|
||||
"""Create a receiver that can process _on_packet with mocked NaCl + Opus."""
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
vc = MagicMock()
|
||||
vc._connection.secret_key = [0] * 32
|
||||
vc._connection.dave_session = dave_session
|
||||
|
|
@ -2593,7 +2593,7 @@ class TestVoiceTTSPlayback:
|
|||
|
||||
@staticmethod
|
||||
def _make_discord_adapter():
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
config.token = "fake-token"
|
||||
|
|
@ -2766,14 +2766,14 @@ class TestUDPKeepalive:
|
|||
"""UDP keepalive prevents Discord from dropping the voice session."""
|
||||
|
||||
def test_keepalive_interval_is_reasonable(self):
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
interval = DiscordAdapter._KEEPALIVE_INTERVAL
|
||||
assert 5 <= interval <= 30, f"Keepalive interval {interval}s should be between 5-30s"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keepalive_sends_silence_frame(self):
|
||||
"""Listen loop sends silence frame via send_packet after interval."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
|
||||
config = PlatformConfig(enabled=True, extra={})
|
||||
|
|
@ -2795,7 +2795,7 @@ class TestUDPKeepalive:
|
|||
adapter._voice_clients[111] = mock_vc
|
||||
mock_vc._connection = mock_conn
|
||||
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
from plugins.platforms.discord.adapter import VoiceReceiver
|
||||
mock_receiver_vc = MagicMock()
|
||||
mock_receiver_vc._connection.secret_key = [0] * 32
|
||||
mock_receiver_vc._connection.dave_session = None
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Covers:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
|
|
@ -100,6 +101,18 @@ def _generic_signature(body: bytes, secret: str) -> str:
|
|||
return hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
|
||||
def _svix_signature(body: bytes, secret: str, msg_id: str, timestamp: str) -> str:
|
||||
"""Compute a Svix v1 signature header for *body* using *secret*."""
|
||||
key = (
|
||||
base64.b64decode(secret.removeprefix("whsec_"))
|
||||
if secret.startswith("whsec_")
|
||||
else secret.encode()
|
||||
)
|
||||
signed = msg_id.encode() + b"." + timestamp.encode() + b"." + body
|
||||
digest = hmac.new(key, signed, hashlib.sha256).digest()
|
||||
return "v1," + base64.b64encode(digest).decode()
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Signature validation
|
||||
# ===================================================================
|
||||
|
|
@ -170,6 +183,134 @@ class TestValidateSignature:
|
|||
req = _mock_request(headers={"X-Webhook-Signature": sig})
|
||||
assert adapter._validate_signature(req, body, secret) is True
|
||||
|
||||
def test_validate_svix_signature_valid(self):
|
||||
"""Valid Svix/AgentMail v1 signature headers are accepted."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"event_type":"message.received"}'
|
||||
secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
|
||||
msg_id = "msg_123"
|
||||
timestamp = str(int(time.time()))
|
||||
sig = _svix_signature(body, secret, msg_id, timestamp)
|
||||
req = _mock_request(
|
||||
headers={
|
||||
"svix-id": msg_id,
|
||||
"svix-timestamp": timestamp,
|
||||
"svix-signature": sig,
|
||||
}
|
||||
)
|
||||
assert adapter._validate_signature(req, body, secret) is True
|
||||
|
||||
def test_validate_svix_signature_wrong_body_rejects(self):
|
||||
"""Svix/AgentMail signatures are bound to the exact raw request body."""
|
||||
adapter = _make_adapter()
|
||||
signed_body = b'{"event_type":"message.received"}'
|
||||
received_body = b'{"event_type":"message.sent"}'
|
||||
secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
|
||||
msg_id = "msg_123"
|
||||
timestamp = str(int(time.time()))
|
||||
sig = _svix_signature(signed_body, secret, msg_id, timestamp)
|
||||
req = _mock_request(
|
||||
headers={
|
||||
"svix-id": msg_id,
|
||||
"svix-timestamp": timestamp,
|
||||
"svix-signature": sig,
|
||||
}
|
||||
)
|
||||
assert adapter._validate_signature(req, received_body, secret) is False
|
||||
|
||||
def test_validate_svix_signature_old_timestamp_rejects(self):
|
||||
"""Svix/AgentMail signatures outside the replay window are rejected."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"event_type":"message.received"}'
|
||||
secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
|
||||
msg_id = "msg_123"
|
||||
timestamp = str(int(time.time()) - 301)
|
||||
sig = _svix_signature(body, secret, msg_id, timestamp)
|
||||
req = _mock_request(
|
||||
headers={
|
||||
"svix-id": msg_id,
|
||||
"svix-timestamp": timestamp,
|
||||
"svix-signature": sig,
|
||||
}
|
||||
)
|
||||
assert adapter._validate_signature(req, body, secret) is False
|
||||
|
||||
def test_validate_svix_signature_multiple_entries_accepts_matching_v1(self):
|
||||
"""Svix rotation headers may contain multiple space-separated signatures."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"event_type":"message.received"}'
|
||||
secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
|
||||
msg_id = "msg_123"
|
||||
timestamp = str(int(time.time()))
|
||||
sig = _svix_signature(body, secret, msg_id, timestamp)
|
||||
req = _mock_request(
|
||||
headers={
|
||||
"svix-id": msg_id,
|
||||
"svix-timestamp": timestamp,
|
||||
"svix-signature": "v1,wrong " + sig,
|
||||
}
|
||||
)
|
||||
assert adapter._validate_signature(req, body, secret) is True
|
||||
|
||||
def test_validate_svix_signature_missing_signature_rejects(self):
|
||||
"""Partial Svix headers reject instead of falling through to another scheme."""
|
||||
adapter = _make_adapter()
|
||||
req = _mock_request(headers={"svix-id": "msg_123"})
|
||||
assert adapter._validate_signature(req, b"{}", "secret") is False
|
||||
|
||||
def test_validate_svix_signature_unsupported_version_rejects(self):
|
||||
"""Only Svix v1 signatures are accepted."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"event_type":"message.received"}'
|
||||
secret = "whsec_" + base64.b64encode(b"agentmail-signing-secret").decode()
|
||||
msg_id = "msg_123"
|
||||
timestamp = str(int(time.time()))
|
||||
sig = _svix_signature(body, secret, msg_id, timestamp).replace("v1,", "v2,")
|
||||
req = _mock_request(
|
||||
headers={
|
||||
"svix-id": msg_id,
|
||||
"svix-timestamp": timestamp,
|
||||
"svix-signature": sig,
|
||||
}
|
||||
)
|
||||
assert adapter._validate_signature(req, body, secret) is False
|
||||
|
||||
def test_validate_svix_signature_invalid_whsec_rejects(self):
|
||||
"""Malformed whsec_ secrets are rejected, not silently treated as raw secrets."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"event_type":"message.received"}'
|
||||
malformed_secret = "whsec_not-valid-base64!"
|
||||
msg_id = "msg_123"
|
||||
timestamp = str(int(time.time()))
|
||||
raw_sig = _svix_signature(
|
||||
body, malformed_secret.removeprefix("whsec_"), msg_id, timestamp
|
||||
)
|
||||
req = _mock_request(
|
||||
headers={
|
||||
"svix-id": msg_id,
|
||||
"svix-timestamp": timestamp,
|
||||
"svix-signature": raw_sig,
|
||||
}
|
||||
)
|
||||
assert adapter._validate_signature(req, body, malformed_secret) is False
|
||||
|
||||
def test_validate_svix_signature_raw_secret_valid(self):
|
||||
"""Raw shared secrets are accepted for Svix-style senders without whsec_ secrets."""
|
||||
adapter = _make_adapter()
|
||||
body = b'{"event_type":"message.received"}'
|
||||
secret = "raw-agentmail-secret"
|
||||
msg_id = "msg_123"
|
||||
timestamp = str(int(time.time()))
|
||||
sig = _svix_signature(body, secret, msg_id, timestamp)
|
||||
req = _mock_request(
|
||||
headers={
|
||||
"svix-id": msg_id,
|
||||
"svix-timestamp": timestamp,
|
||||
"svix-signature": sig,
|
||||
}
|
||||
)
|
||||
assert adapter._validate_signature(req, body, secret) is True
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Prompt rendering
|
||||
|
|
@ -304,6 +445,27 @@ class TestEventFilter:
|
|||
)
|
||||
assert resp.status == 202
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_filter_accepts_payload_type_field(self):
|
||||
"""Svix-style payloads often use a top-level `type` event field."""
|
||||
routes = {
|
||||
"svix": {
|
||||
"secret": _INSECURE_NO_AUTH,
|
||||
"events": ["message.received"],
|
||||
"prompt": "got it",
|
||||
}
|
||||
}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post(
|
||||
"/webhooks/svix",
|
||||
json={"type": "message.received"},
|
||||
)
|
||||
assert resp.status == 202
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# HTTP handling
|
||||
|
|
@ -336,6 +498,22 @@ class TestHTTPHandling:
|
|||
assert data["status"] == "accepted"
|
||||
assert data["route"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_without_secret_rejects_unsigned_request(self):
|
||||
"""Missing HMAC secret must fail closed even if connect() was bypassed."""
|
||||
routes = {"test": {"prompt": "hi"}}
|
||||
adapter = _make_adapter(routes=routes, secret="")
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.post("/webhooks/test", json={"data": "value"})
|
||||
assert resp.status == 403
|
||||
data = await resp.json()
|
||||
assert data["error"] == "Webhook route is missing an HMAC secret"
|
||||
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint(self):
|
||||
"""GET /health returns 200 with status=ok."""
|
||||
|
|
@ -432,6 +610,25 @@ class TestIdempotency:
|
|||
resp2 = await cli.post("/webhooks/idem", json={"x": 1}, headers=headers)
|
||||
assert resp2.status == 202 # re-accepted
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_svix_id_used_as_delivery_id_for_deduplication(self):
|
||||
"""Svix retries reuse svix-id, so use it as the delivery ID when present."""
|
||||
routes = {"idem": {"secret": _INSECURE_NO_AUTH, "prompt": "test"}}
|
||||
adapter = _make_adapter(routes=routes)
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
headers = {"svix-id": "msg_duplicate"}
|
||||
resp1 = await cli.post("/webhooks/idem", json={"a": 1}, headers=headers)
|
||||
assert resp1.status == 202
|
||||
|
||||
resp2 = await cli.post("/webhooks/idem", json={"a": 1}, headers=headers)
|
||||
assert resp2.status == 200
|
||||
data = await resp2.json()
|
||||
assert data["status"] == "duplicate"
|
||||
assert data["delivery_id"] == "msg_duplicate"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Rate limiting
|
||||
|
|
|
|||
|
|
@ -6,7 +6,11 @@ import pytest
|
|||
from pathlib import Path
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.webhook import WebhookAdapter, _DYNAMIC_ROUTES_FILENAME
|
||||
from gateway.platforms.webhook import (
|
||||
WebhookAdapter,
|
||||
_DYNAMIC_ROUTES_FILENAME,
|
||||
_INSECURE_NO_AUTH,
|
||||
)
|
||||
|
||||
|
||||
def _make_adapter(routes=None, extra=None):
|
||||
|
|
@ -85,3 +89,88 @@ class TestDynamicRouteLoading:
|
|||
adapter._reload_dynamic_routes()
|
||||
assert "static" in adapter._routes
|
||||
assert len(adapter._dynamic_routes) == 0
|
||||
|
||||
|
||||
class TestDynamicRouteSecretValidation:
|
||||
"""Empty/missing secrets must be rejected during hot-reload.
|
||||
|
||||
Regression for HMAC bypass: prior to the fix, an agent-induced
|
||||
dynamic route with `"secret": ""` would be merged into self._routes
|
||||
by _reload_dynamic_routes(), then _handle_webhook's
|
||||
`if secret and secret != _INSECURE_NO_AUTH` would skip signature
|
||||
validation because empty string is falsy. Unauthenticated POSTs
|
||||
would then execute the webhook prompt.
|
||||
"""
|
||||
|
||||
def test_empty_secret_rejected(self, tmp_path):
|
||||
# Explicit empty-string secret must NOT fall back to the global
|
||||
# secret, and the route must be skipped entirely.
|
||||
(tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text(
|
||||
json.dumps({"evil": {"secret": "", "prompt": "rm -rf"}})
|
||||
)
|
||||
adapter = _make_adapter() # has global secret
|
||||
adapter._reload_dynamic_routes()
|
||||
assert "evil" not in adapter._routes
|
||||
assert "evil" not in adapter._dynamic_routes
|
||||
|
||||
def test_missing_secret_no_global_rejected(self, tmp_path):
|
||||
(tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text(
|
||||
json.dumps({"orphan": {"prompt": "test"}})
|
||||
)
|
||||
# No global secret configured
|
||||
adapter = _make_adapter(extra={"secret": ""})
|
||||
adapter._reload_dynamic_routes()
|
||||
assert "orphan" not in adapter._routes
|
||||
assert "orphan" not in adapter._dynamic_routes
|
||||
|
||||
def test_missing_secret_inherits_global(self, tmp_path):
|
||||
# No per-route secret but a global one is set → route is kept,
|
||||
# the global secret protects it. Preserves existing fallback.
|
||||
(tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text(
|
||||
json.dumps({"valid": {"prompt": "ok"}})
|
||||
)
|
||||
adapter = _make_adapter() # global secret set
|
||||
adapter._reload_dynamic_routes()
|
||||
assert "valid" in adapter._routes
|
||||
|
||||
def test_insecure_no_auth_preserved(self, tmp_path):
|
||||
# Explicit opt-in escape hatch for local testing — must still load.
|
||||
(tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text(
|
||||
json.dumps({"test": {"secret": _INSECURE_NO_AUTH, "prompt": "p"}})
|
||||
)
|
||||
adapter = _make_adapter(extra={"host": "127.0.0.1"})
|
||||
adapter._reload_dynamic_routes()
|
||||
assert "test" in adapter._routes
|
||||
|
||||
def test_insecure_no_auth_rejected_on_non_loopback_bind(self, tmp_path):
|
||||
# Dynamic INSECURE_NO_AUTH routes are only valid on loopback hosts.
|
||||
(tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text(
|
||||
json.dumps({"pub": {"secret": _INSECURE_NO_AUTH, "prompt": "p"}})
|
||||
)
|
||||
adapter = _make_adapter(extra={"host": "0.0.0.0"})
|
||||
adapter._reload_dynamic_routes()
|
||||
assert "pub" not in adapter._routes
|
||||
assert "pub" not in adapter._dynamic_routes
|
||||
|
||||
def test_warning_logged_on_skip(self, tmp_path, caplog):
|
||||
import logging
|
||||
(tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text(
|
||||
json.dumps({"silent": {"secret": "", "prompt": "x"}})
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
with caplog.at_level(logging.WARNING, logger="gateway.platforms.webhook"):
|
||||
adapter._reload_dynamic_routes()
|
||||
assert any("silent" in rec.message for rec in caplog.records)
|
||||
|
||||
def test_partial_skip(self, tmp_path):
|
||||
# One route bad, one route good — only the bad one is dropped.
|
||||
(tmp_path / _DYNAMIC_ROUTES_FILENAME).write_text(
|
||||
json.dumps({
|
||||
"bad": {"secret": "", "prompt": "x"},
|
||||
"good": {"secret": "valid-secret", "prompt": "y"},
|
||||
})
|
||||
)
|
||||
adapter = _make_adapter()
|
||||
adapter._reload_dynamic_routes()
|
||||
assert "good" in adapter._routes
|
||||
assert "bad" not in adapter._routes
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Tests for the WeCom platform adapter."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
|
@ -831,3 +832,91 @@ class TestWeComZombieSessionFix:
|
|||
cmd = adapter._send_request.await_args.args[0]
|
||||
assert cmd == APP_CMD_SEND
|
||||
|
||||
|
||||
|
||||
class TestTextBatchFlushRace:
|
||||
"""Regression tests for the cancel-delivery race in _flush_text_batch.
|
||||
|
||||
When asyncio.sleep() fires and Task.cancel() is called before the task
|
||||
runs, CPython sets _must_cancel but cannot cancel the already-done sleep
|
||||
future. CancelledError is then delivered at the *next* await
|
||||
(handle_message), after the task has already popped the event — the
|
||||
superseding task sees an empty batch and silently drops the message.
|
||||
The fix adds a synchronous task-registry check between the sleep and
|
||||
the pop so a superseded task returns before touching the event.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_superseded_task_does_not_pop_or_process_event(self):
|
||||
"""A flush task that has been superseded must leave the event in the
|
||||
batch dict for the new task to handle."""
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.platforms.wecom import WeComAdapter
|
||||
|
||||
adapter = WeComAdapter(PlatformConfig(enabled=True))
|
||||
adapter._text_batch_delay_seconds = 0
|
||||
|
||||
key = "test-session"
|
||||
event = MessageEvent(text="hello", message_type=MessageType.TEXT)
|
||||
adapter._pending_text_batches[key] = event
|
||||
|
||||
handle_calls = []
|
||||
|
||||
async def fake_handle(evt):
|
||||
handle_calls.append(evt)
|
||||
|
||||
adapter.handle_message = fake_handle
|
||||
|
||||
# Create T1 and register it.
|
||||
t1 = asyncio.create_task(adapter._flush_text_batch(key))
|
||||
adapter._pending_text_batch_tasks[key] = t1
|
||||
|
||||
# Simulate T2 superseding T1 before T1 wakes from sleep.
|
||||
t2 = asyncio.create_task(asyncio.sleep(9999))
|
||||
adapter._pending_text_batch_tasks[key] = t2
|
||||
|
||||
# Yield long enough for T1's sleep(0) to complete and T1 to run.
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
t2.cancel()
|
||||
try:
|
||||
await t2
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# T1 must have returned without processing or removing the event.
|
||||
assert handle_calls == [], "superseded task must not call handle_message"
|
||||
assert adapter._pending_text_batches.get(key) is event, (
|
||||
"superseded task must not pop the event"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_task_processes_event_normally(self):
|
||||
"""When the task is not superseded it must still process the event."""
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.platforms.wecom import WeComAdapter
|
||||
|
||||
adapter = WeComAdapter(PlatformConfig(enabled=True))
|
||||
adapter._text_batch_delay_seconds = 0
|
||||
|
||||
key = "test-session"
|
||||
event = MessageEvent(text="world", message_type=MessageType.TEXT)
|
||||
adapter._pending_text_batches[key] = event
|
||||
|
||||
handle_calls = []
|
||||
|
||||
async def fake_handle(evt):
|
||||
handle_calls.append(evt)
|
||||
|
||||
adapter.handle_message = fake_handle
|
||||
|
||||
t1 = asyncio.create_task(adapter._flush_text_batch(key))
|
||||
adapter._pending_text_batch_tasks[key] = t1
|
||||
|
||||
# No superseding task — T1 should process normally.
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
assert handle_calls == [event], "active task must call handle_message"
|
||||
assert adapter._pending_text_batches.get(key) is None, (
|
||||
"active task must pop the event after processing"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -153,6 +153,130 @@ class TestWecomCallbackRouting:
|
|||
assert calls["json"]["agentid"] == 1001
|
||||
|
||||
|
||||
class TestWecomCallbackSendTokenRefresh:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_with_fresh_token_on_errcode_40001(self):
|
||||
"""errcode=40001 must evict the cached token, refresh, and retry once."""
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
adapter._access_tokens["test-app"] = {"token": "stale", "expires_at": 9999999999}
|
||||
adapter._user_app_map["ww1234567890:alice"] = "test-app"
|
||||
|
||||
responses = [
|
||||
{"errcode": 40001, "errmsg": "invalid credential"},
|
||||
{"errcode": 0, "msgid": "msg-ok"},
|
||||
]
|
||||
post_calls = []
|
||||
|
||||
class FakeClient:
|
||||
async def post(self, url, json=None, **kw):
|
||||
post_calls.append(url)
|
||||
|
||||
class R:
|
||||
def json(inner):
|
||||
return responses[len(post_calls) - 1]
|
||||
return R()
|
||||
|
||||
async def get(self, url, params=None, **kw):
|
||||
class R:
|
||||
def json(inner):
|
||||
return {"errcode": 0, "access_token": "fresh", "expires_in": 7200}
|
||||
return R()
|
||||
|
||||
adapter._http_client = FakeClient()
|
||||
result = await adapter.send("ww1234567890:alice", "hello")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "msg-ok"
|
||||
assert len(post_calls) == 2
|
||||
assert "fresh" in post_calls[1]
|
||||
assert adapter._access_tokens["test-app"]["token"] == "fresh"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_with_fresh_token_on_errcode_42001(self):
|
||||
"""errcode=42001 (token expired) must also trigger the refresh-retry path."""
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
adapter._access_tokens["test-app"] = {"token": "expired", "expires_at": 9999999999}
|
||||
|
||||
responses = [
|
||||
{"errcode": 42001, "errmsg": "access_token expired"},
|
||||
{"errcode": 0, "msgid": "msg-42"},
|
||||
]
|
||||
post_calls = []
|
||||
|
||||
class FakeClient:
|
||||
async def post(self, url, json=None, **kw):
|
||||
post_calls.append(url)
|
||||
|
||||
class R:
|
||||
def json(inner):
|
||||
return responses[len(post_calls) - 1]
|
||||
return R()
|
||||
|
||||
async def get(self, url, params=None, **kw):
|
||||
class R:
|
||||
def json(inner):
|
||||
return {"errcode": 0, "access_token": "renewed", "expires_in": 7200}
|
||||
return R()
|
||||
|
||||
adapter._http_client = FakeClient()
|
||||
result = await adapter.send("alice", "hello")
|
||||
|
||||
assert result.success is True
|
||||
assert len(post_calls) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_does_not_retry_on_non_token_errcode(self):
|
||||
"""Errors unrelated to token validity must fail immediately without retrying."""
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
adapter._access_tokens["test-app"] = {"token": "good", "expires_at": 9999999999}
|
||||
|
||||
post_calls = []
|
||||
|
||||
class FakeClient:
|
||||
async def post(self, url, json=None, **kw):
|
||||
post_calls.append(url)
|
||||
|
||||
class R:
|
||||
def json(inner):
|
||||
return {"errcode": 60020, "errmsg": "not allow to access"}
|
||||
return R()
|
||||
|
||||
adapter._http_client = FakeClient()
|
||||
result = await adapter.send("alice", "hello")
|
||||
|
||||
assert result.success is False
|
||||
assert len(post_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_fails_cleanly_when_retry_also_fails(self):
|
||||
"""If the refreshed token is also rejected, return failure without looping further."""
|
||||
adapter = WecomCallbackAdapter(_config())
|
||||
adapter._access_tokens["test-app"] = {"token": "bad1", "expires_at": 9999999999}
|
||||
|
||||
post_calls = []
|
||||
|
||||
class FakeClient:
|
||||
async def post(self, url, json=None, **kw):
|
||||
post_calls.append(url)
|
||||
|
||||
class R:
|
||||
def json(inner):
|
||||
return {"errcode": 42001, "errmsg": "access_token expired"}
|
||||
return R()
|
||||
|
||||
async def get(self, url, params=None, **kw):
|
||||
class R:
|
||||
def json(inner):
|
||||
return {"errcode": 0, "access_token": "bad2", "expires_in": 7200}
|
||||
return R()
|
||||
|
||||
adapter._http_client = FakeClient()
|
||||
result = await adapter.send("alice", "hello")
|
||||
|
||||
assert result.success is False
|
||||
assert len(post_calls) == 2
|
||||
|
||||
|
||||
class TestWecomCallbackPollLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_loop_dispatches_handle_message(self, monkeypatch):
|
||||
|
|
|
|||
|
|
@ -57,6 +57,59 @@ def _build_parser():
|
|||
return parser
|
||||
|
||||
|
||||
class TestChatVerboseArg:
|
||||
"""Verify chat --verbose preserves config fallback when absent."""
|
||||
|
||||
def test_chat_without_verbose_leaves_attribute_unset(self):
|
||||
from hermes_cli._parser import build_top_level_parser
|
||||
|
||||
parser, _subparsers, _chat_parser = build_top_level_parser()
|
||||
args = parser.parse_args(["chat"])
|
||||
|
||||
assert not hasattr(args, "verbose")
|
||||
|
||||
def test_chat_verbose_sets_attribute_true(self):
|
||||
from hermes_cli._parser import build_top_level_parser
|
||||
|
||||
parser, _subparsers, _chat_parser = build_top_level_parser()
|
||||
args = parser.parse_args(["chat", "--verbose"])
|
||||
|
||||
assert args.verbose is True
|
||||
|
||||
def test_cmd_chat_forwards_none_when_verbose_is_absent(self, monkeypatch):
|
||||
import types
|
||||
import sys
|
||||
|
||||
import hermes_cli.main as main_mod
|
||||
from hermes_cli._parser import build_top_level_parser
|
||||
|
||||
parser, _subparsers, chat_parser = build_top_level_parser()
|
||||
chat_parser.set_defaults(func=main_mod.cmd_chat)
|
||||
args = parser.parse_args(["chat"])
|
||||
captured = {}
|
||||
fake_cli = types.ModuleType("cli")
|
||||
|
||||
def fake_main(**kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
setattr(fake_cli, "main", fake_main)
|
||||
fake_banner = types.ModuleType("hermes_cli.banner")
|
||||
setattr(fake_banner, "prefetch_update_check", lambda: None)
|
||||
fake_skills_sync = types.ModuleType("tools.skills_sync")
|
||||
setattr(fake_skills_sync, "sync_skills", lambda quiet=True: None)
|
||||
|
||||
monkeypatch.setitem(sys.modules, "cli", fake_cli)
|
||||
monkeypatch.setitem(sys.modules, "hermes_cli.banner", fake_banner)
|
||||
monkeypatch.setitem(sys.modules, "tools.skills_sync", fake_skills_sync)
|
||||
monkeypatch.setattr(main_mod, "_has_any_provider_configured", lambda: True)
|
||||
monkeypatch.setattr(main_mod, "_pin_kanban_board_env", lambda: None)
|
||||
|
||||
main_mod.cmd_chat(args)
|
||||
|
||||
assert captured["quiet"] is False
|
||||
assert "verbose" not in captured
|
||||
|
||||
|
||||
class TestYoloEnvVar:
|
||||
"""Verify --yolo sets HERMES_YOLO_MODE regardless of flag position.
|
||||
|
||||
|
|
|
|||
|
|
@ -392,8 +392,84 @@ def test_get_qwen_auth_status_logged_in(qwen_env):
|
|||
assert status["api_key"] == "status-at"
|
||||
|
||||
|
||||
def test_get_qwen_auth_status_refreshes_expired_token(qwen_env):
|
||||
expired_ms = int((time.time() - 3600) * 1000)
|
||||
tokens = _make_qwen_tokens(access_token="old-at", expiry_date=expired_ms)
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
|
||||
refreshed = _make_qwen_tokens(access_token="refreshed-at")
|
||||
|
||||
with patch(
|
||||
"hermes_cli.auth._refresh_qwen_cli_tokens", return_value=refreshed
|
||||
) as mock_refresh:
|
||||
status = get_qwen_auth_status()
|
||||
|
||||
mock_refresh.assert_called_once()
|
||||
assert status["logged_in"] is True
|
||||
assert status["api_key"] == "refreshed-at"
|
||||
|
||||
|
||||
def test_get_qwen_auth_status_expired_unrefreshable_token_is_not_logged_in(qwen_env):
|
||||
expired_ms = int((time.time() - 3600) * 1000)
|
||||
tokens = _make_qwen_tokens(access_token="dead-at", expiry_date=expired_ms)
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
|
||||
with patch(
|
||||
"hermes_cli.auth._refresh_qwen_cli_tokens",
|
||||
side_effect=AuthError(
|
||||
"Qwen refresh rejected. Re-run 'qwen auth qwen-oauth'.",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_refresh_failed",
|
||||
),
|
||||
) as mock_refresh:
|
||||
status = get_qwen_auth_status()
|
||||
|
||||
mock_refresh.assert_called_once()
|
||||
assert status["logged_in"] is False
|
||||
assert "qwen auth qwen-oauth" in status["error"]
|
||||
|
||||
|
||||
def test_get_qwen_auth_status_not_logged_in(qwen_env):
|
||||
# No credentials file
|
||||
status = get_qwen_auth_status()
|
||||
assert status["logged_in"] is False
|
||||
assert "error" in status
|
||||
|
||||
|
||||
def test_model_flow_qwen_oauth_stale_token_shows_reauth_guidance(qwen_env, monkeypatch, capsys):
|
||||
from hermes_cli.main import _model_flow_qwen_oauth
|
||||
|
||||
expired_ms = int((time.time() - 3600) * 1000)
|
||||
tokens = _make_qwen_tokens(access_token="dead-at", expiry_date=expired_ms)
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._refresh_qwen_cli_tokens",
|
||||
lambda *args, **kwargs: (_ for _ in ()).throw(
|
||||
AuthError(
|
||||
"Qwen refresh rejected. Re-run 'qwen auth qwen-oauth'.",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_refresh_failed",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
prompt_called = {"value": False}
|
||||
update_called = {"value": False}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._prompt_model_selection",
|
||||
lambda *args, **kwargs: prompt_called.__setitem__("value", True),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._update_config_for_provider",
|
||||
lambda *args, **kwargs: update_called.__setitem__("value", True),
|
||||
)
|
||||
|
||||
_model_flow_qwen_oauth({}, current_model="qwen3-coder-plus")
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Run: qwen auth qwen-oauth" in out
|
||||
assert "Qwen refresh rejected" in out
|
||||
assert prompt_called["value"] is False
|
||||
assert update_called["value"] is False
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue