Merge remote-tracking branch 'origin/main' into bb/pets-merge

# Conflicts:
#	hermes_cli/commands.py
#	tui_gateway/server.py
This commit is contained in:
Brooklyn Nicholson 2026-06-23 19:05:22 -05:00
commit e495b33bf1
251 changed files with 23395 additions and 2720 deletions

View file

@ -331,6 +331,131 @@ class TestResolveAnthropicToken:
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
assert resolve_anthropic_token() == "cc-auto-token"
def test_falls_back_to_anthropic_credential_pool_oauth(self, monkeypatch, tmp_path):
    """With no env tokens and no Claude Code creds, the pool's OAuth token is resolved."""
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    # Isolate source #4 (credential_pool): source #3 (Claude Code creds,
    # incl. the macOS keychain read which Path.home does not cover) is
    # stubbed to return nothing, mirroring a Hermes-PKCE-only setup.
    monkeypatch.setattr(
        "agent.anthropic_adapter.read_claude_code_credentials", lambda: None
    )

    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [oauth_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    assert resolve_anthropic_token() == "pool-oauth-token"
def test_prefers_anthropic_credential_pool_oauth_over_api_key(self, monkeypatch, tmp_path):
    """A pool OAuth token is returned even though ANTHROPIC_API_KEY is set."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant...ykey")
    for env_name in ("ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    # Pool (source #4) must win over ANTHROPIC_API_KEY (source #5). Source #3
    # is stubbed out as well so a machine-local Claude Code creds / keychain
    # entry can't short-circuit before the pool.
    monkeypatch.setattr(
        "agent.anthropic_adapter.read_claude_code_credentials", lambda: None
    )

    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [oauth_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    assert resolve_anthropic_token() == "pool-oauth-token"
def test_pool_entry_with_null_access_token_does_not_crash(self, monkeypatch, tmp_path):
    """An OAuth pool entry with ``access_token=None`` is skipped, not fatal.

    The concern being pinned: ``None.strip()`` would escape the helper's
    try/excepts and take down the whole resolver, including the
    ANTHROPIC_API_KEY fallback. Here the api-key fallback (source #5) must
    be the value returned.
    """
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant...ykey")
    for env_name in ("ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr(
        "agent.anthropic_adapter.read_claude_code_credentials", lambda: None
    )

    tokenless_entry = SimpleNamespace(auth_type="oauth", access_token=None)
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [tokenless_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    # Must fall through to source #5 (ANTHROPIC_API_KEY), not raise.
    resolved = resolve_anthropic_token()
    assert resolved == "sk-ant...ykey"
def test_pool_api_key_only_entry_is_not_returned_as_token(self, monkeypatch, tmp_path):
    """A pool entry with ``auth_type="api_key"`` is never returned by the resolver.

    resolve_anthropic_token() returns an OAuth bearer token; api_key entries
    are consumed via the aux client's _pool_runtime_api_key lane, which is a
    different resolution concern.
    """
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr(
        "agent.anthropic_adapter.read_claude_code_credentials", lambda: None
    )

    key_entry = SimpleNamespace(auth_type="api_key", access_token="sk-pool-apikey")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [key_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    # No OAuth entry and no other source → None (the api_key entry is ignored here).
    resolved = resolve_anthropic_token()
    assert resolved is None
def test_pool_is_not_consulted_when_env_token_present(self, monkeypatch, tmp_path):
    """ANTHROPIC_TOKEN (source #1) wins without ``load_pool`` ever being called.

    Pins the ordering contract #1 → #4: the pool loader is replaced by a
    stub that records the call and raises, and the recorded list must stay
    empty.
    """
    monkeypatch.setenv("ANTHROPIC_TOKEN", "env-token")
    for env_name in ("ANTHROPIC_API_KEY", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr(
        "agent.anthropic_adapter.read_claude_code_credentials", lambda: None
    )

    seen_providers = []

    def _forbidden_load_pool(provider):
        # Record first so the call is still visible after the raise below.
        seen_providers.append(provider)
        raise AssertionError("load_pool must not be called when source #1 wins")

    monkeypatch.setattr("agent.credential_pool.load_pool", _forbidden_load_pool)

    assert resolve_anthropic_token() == "env-token"
    assert seen_providers == []
def test_pool_resolution_is_read_only(self, monkeypatch, tmp_path):
    """The pool is enumerated with ``clear_expired=False`` and ``refresh=False``.

    A bare resolve must never write auth.json or trigger a network refresh
    from diagnostic call sites (#50108 MED), so the keyword arguments passed
    to ``_available_entries`` are captured and compared exactly.
    """
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr(
        "agent.anthropic_adapter.read_claude_code_credentials", lambda: None
    )

    seen_kwargs = {}
    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")

    def _record_kwargs(**kwargs):
        seen_kwargs.update(kwargs)
        return [oauth_entry]

    fake_pool = SimpleNamespace(_available_entries=_record_kwargs)
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    assert resolve_anthropic_token() == "pool-oauth-token"
    assert seen_kwargs == {"clear_expired": False, "refresh": False}
def test_prefers_refreshable_claude_code_credentials_over_static_anthropic_token(self, monkeypatch, tmp_path):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-static-token")

View file

@ -206,6 +206,35 @@ class TestProjectFacts:
assert "Project: package.json" in block
assert "Verify:" not in block
def test_detect_project_facts_structured(self, tmp_path):
    """A package.json + pnpm lockfile yields the expected structured facts."""
    manifest_body = json.dumps({"scripts": {"test": "vitest", "dev": "vite"}})
    manifest = tmp_path / "package.json"
    manifest.write_text(manifest_body)
    lockfile = tmp_path / "pnpm-lock.yaml"
    lockfile.write_text("")

    detected = cc.detect_project_facts(tmp_path)

    assert detected.manifests == ["package.json"]
    assert detected.package_managers == ["pnpm"]
    # The "dev" script is excluded from the verify commands.
    assert detected.verify_commands == ["pnpm run test"]
    assert detected.context_files == []
def test_project_facts_for_matches_prompt_block(self, tmp_path):
    """Every structured verify command appears on the prompt block's Verify line."""
    # Invariant: the structured facts the UI consumes must not drift from the
    # commands the prompt snapshot renders — one detector feeds both.
    _git_init(tmp_path)
    manifest_body = json.dumps({"scripts": {"test": "vitest", "lint": "eslint ."}})
    (tmp_path / "package.json").write_text(manifest_body)
    (tmp_path / "pnpm-lock.yaml").write_text("")

    facts = cc.project_facts_for(tmp_path)
    assert facts is not None

    prompt_block = cc.build_coding_workspace_block(tmp_path)
    after_verify = prompt_block.split("Verify:")[1]
    verify_line = after_verify.splitlines()[0]

    assert facts["verifyCommands"]
    for command in facts["verifyCommands"]:
        assert command in verify_line
def test_project_facts_for_none_outside_workspace(self, tmp_path):
    """``project_facts_for`` returns None for a bare temporary directory."""
    facts = cc.project_facts_for(tmp_path)
    assert facts is None
# ── $HOME dotfiles guard ────────────────────────────────────────────────────

View file

@ -0,0 +1,86 @@
"""Regression: detect compression progress by tokens, not just rows.
Issue #39548: preflight compression in the turn prologue was checking
``len(messages) >= _orig_len`` to decide "Cannot compress further". This
false-positives when a pass summarises message contents — reducing the
estimated request token count without removing any rows — and surfaces a
spurious ``Context length exceeded`` failure followed by an auto-reset of
an otherwise healthy session.
These tests pin the contract of ``_compression_made_progress``: a
row-count reduction OR a *material* (>5%) token-count reduction counts as
progress.
"""
from __future__ import annotations
from agent.turn_context import _compression_made_progress
class TestCompressionMadeProgress:
    """Pins ``_compression_made_progress`` across row/token change combinations."""

    def test_rows_reduced_counts_as_progress(self):
        """Fewer rows with unchanged tokens is progress."""
        outcome = _compression_made_progress(
            orig_len=10,
            new_len=5,
            orig_tokens=1000,
            new_tokens=1000,
        )
        assert outcome is True

    def test_tokens_reduced_without_row_change_counts_as_progress(self):
        """Issue #39548: 220 → 220 rows with 288k → 183k tokens is progress."""
        outcome = _compression_made_progress(
            orig_len=220,
            new_len=220,
            orig_tokens=288_028,
            new_tokens=183_180,
        )
        assert outcome is True

    def test_both_reduced_counts_as_progress(self):
        """Rows and tokens both shrinking (the common case) is progress."""
        outcome = _compression_made_progress(
            orig_len=220,
            new_len=180,
            orig_tokens=288_028,
            new_tokens=150_000,
        )
        assert outcome is True

    def test_neither_moved_means_no_progress(self):
        """Same rows and same tokens is the genuine "stuck" case."""
        outcome = _compression_made_progress(
            orig_len=10,
            new_len=10,
            orig_tokens=1000,
            new_tokens=1000,
        )
        assert outcome is False

    def test_rows_grew_and_tokens_grew_means_no_progress(self):
        """A pass that enlarges both rows and tokens is not progress."""
        outcome = _compression_made_progress(
            orig_len=10,
            new_len=12,
            orig_tokens=1000,
            new_tokens=1200,
        )
        assert outcome is False

    def test_rows_grew_but_tokens_dropped_is_progress(self):
        """Token reduction alone suffices, even when the row count expands.

        Summary rows may add to the row count while shrinking the tokens.
        """
        outcome = _compression_made_progress(
            orig_len=10,
            new_len=11,
            orig_tokens=1000,
            new_tokens=600,
        )
        assert outcome is True

    def test_tokens_grew_but_rows_dropped_is_progress(self):
        """Row reduction alone suffices, even when tokens nominally creep up.

        The creep can come from e.g. summary verbosity; a lower row count is
        treated as a hard signal that the transcript shrank.
        """
        outcome = _compression_made_progress(
            orig_len=10,
            new_len=5,
            orig_tokens=1000,
            new_tokens=1100,
        )
        assert outcome is True

    def test_sub_5pct_token_drop_is_not_progress(self):
        """A token drop below the 5% material floor is not progress.

        This matches the overflow-handler retry path (#39550), so a marginal
        wobble can't keep the multi-pass loop spinning.
        """
        # 1000 -> 970 is a 3% drop, below the 5% floor.
        below_floor = _compression_made_progress(
            orig_len=10,
            new_len=10,
            orig_tokens=1000,
            new_tokens=970,
        )
        assert below_floor is False
        # 1000 -> 940 is a 6% drop, above the floor.
        above_floor = _compression_made_progress(
            orig_len=10,
            new_len=10,
            orig_tokens=1000,
            new_tokens=940,
        )
        assert above_floor is True

    def test_zero_orig_tokens_is_not_progress(self):
        """A degenerate 0-token estimate must not be read as a token win."""
        outcome = _compression_made_progress(
            orig_len=10,
            new_len=10,
            orig_tokens=0,
            new_tokens=0,
        )
        assert outcome is False

View file

@ -0,0 +1,107 @@
"""Regression tests for tool_call envelope accounting in the compression
tail-protection budget walks (issue #28053).
The budget walks used to estimate an assistant message's tokens from
content + ``function.arguments`` only, dropping each ``tool_call``'s ``id``,
``type`` and ``function.name`` (plus JSON structure). For assistant turns
that fan out into parallel tool calls this undercounted by 2-15x, so the
protected tail overshot ``tail_token_budget`` and compression became
ineffective. The fix routes all three walks through
``_estimate_msg_budget_tokens``, which counts the full envelope.
"""
import pytest
from unittest.mock import patch
from agent.context_compressor import (
ContextCompressor,
_CHARS_PER_TOKEN,
_estimate_msg_budget_tokens,
)
def _assistant_with_tool_calls(n_calls: int, *, args: str = '{"path":"a"}') -> dict:
"""An assistant turn fanning into ``n_calls`` parallel tool calls with
realistic id/name overhead but a small arguments string."""
return {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": f"call_{i:02d}_{'a' * 24}", # ~32 chars, UUID-ish id
"type": "function",
"function": {"name": "read_file", "arguments": args},
}
for i in range(n_calls)
],
}
def _args_only_estimate(msg: dict) -> int:
    """Reproduce the OLD (buggy) arguments-only walk for comparison.

    Counts the content length plus each dict tool call's
    ``function.arguments`` length (both divided by ``_CHARS_PER_TOKEN``) and
    a flat overhead of 10; non-dict tool-call entries are ignored.
    """
    text = msg.get("content") or ""
    base = len(text) // _CHARS_PER_TOKEN + 10
    argument_tokens = sum(
        len(call.get("function", {}).get("arguments", "")) // _CHARS_PER_TOKEN
        for call in (msg.get("tool_calls") or [])
        if isinstance(call, dict)
    )
    return base + argument_tokens
class TestToolCallEnvelopeEstimate:
    """Checks ``_estimate_msg_budget_tokens`` on messages with and without tool calls."""

    def test_envelope_counted_not_just_arguments(self):
        """The estimate is over 3x the arguments-only figure and covers each call's str()."""
        message = _assistant_with_tool_calls(4)
        envelope_aware = _estimate_msg_budget_tokens(message)
        args_only = _args_only_estimate(message)
        # id/type/name + JSON structure dwarf the tiny arguments string.
        assert envelope_aware > args_only * 3, (envelope_aware, args_only)
        # The estimate covers the full serialized tool_call envelope.
        serialized_chars = sum(len(str(call)) for call in message["tool_calls"])
        assert envelope_aware >= serialized_chars // _CHARS_PER_TOKEN

    def test_scales_with_number_of_parallel_calls(self):
        """Five parallel calls estimate at more than 3x a single call."""
        single = _estimate_msg_budget_tokens(_assistant_with_tool_calls(1))
        quintuple = _estimate_msg_budget_tokens(_assistant_with_tool_calls(5))
        assert quintuple > single * 3

    def test_no_tool_calls_matches_content_estimate(self):
        """A plain message estimates as content // chars-per-token + 10 overhead."""
        plain = {"role": "user", "content": "x" * 400}
        expected = 400 // _CHARS_PER_TOKEN + 10
        assert _estimate_msg_budget_tokens(plain) == expected

    def test_non_dict_tool_calls_do_not_crash(self):
        """Non-dict ``tool_calls`` entries are ignored without raising."""
        odd = {"role": "assistant", "content": "hi", "tool_calls": ["weird", None]}
        expected = len("hi") // _CHARS_PER_TOKEN + 10
        assert _estimate_msg_budget_tokens(odd) == expected
@pytest.fixture()
def compressor():
    """Provide a quiet ContextCompressor built while the context length is patched to 100000."""
    settings = {
        "model": "test/model",
        "threshold_percent": 0.85,
        "protect_first_n": 2,
        "protect_last_n": 2,
        "quiet_mode": True,
    }
    with patch("agent.context_compressor.get_model_context_length", return_value=100000):
        return ContextCompressor(**settings)
class TestTailCutAccountsForToolCalls:
    """Checks ``_find_tail_cut_by_tokens`` on a tail made of tool-call-heavy turns."""

    def test_tail_cut_stops_on_tool_call_heavy_tail(self, compressor):
        """The protected tail stays between 3 and 12 turns instead of all 20."""
        # 20 assistant turns, each fanning into 5 short-arg tool calls.
        heavy_turns = [_assistant_with_tool_calls(5) for _ in range(20)]
        transcript = [{"role": "user", "content": "start"}, *heavy_turns]

        tokens_per_turn = _estimate_msg_budget_tokens(transcript[-1])
        # Sanity: a heavy turn is non-trivial once the envelope counts.
        assert tokens_per_turn > 30

        # Budget sized so ~6 heavy turns fit under the 1.5x soft ceiling.
        budget = int(tokens_per_turn * 6 / 1.5)
        cut_index = compressor._find_tail_cut_by_tokens(
            transcript, head_end=1, token_budget=budget
        )
        protected_count = len(transcript) - cut_index

        # With the envelope counted, the walk stops well short of protecting all
        # 20 turns. The old arguments-only estimate (~25 tokens/turn) never
        # reaches the ceiling and would protect the entire transcript.
        assert protected_count < len(heavy_turns)
        assert 3 <= protected_count <= 12

View file

@ -86,6 +86,28 @@ class TestPreflightDeferral:
assert compressor.should_defer_preflight_to_real_usage(93_000) is False
def test_defers_immediately_after_compaction_with_stale_real_prompt(self, compressor):
    """#36718: the awaiting flag forces deferral despite a stale real-prompt value.

    Right after a compaction, last_real_prompt_tokens still holds the
    pre-compression value (above threshold). Deferral must still happen so
    preflight doesn't fire a SECOND compaction before real post-compaction
    usage arrives.
    """
    compressor.threshold_tokens = 85_000
    # Stale pre-compression value — would hit the `>= threshold => False`
    # short-circuit and defeat deferral without the flag guard.
    compressor.last_real_prompt_tokens = 120_000
    compressor.awaiting_real_usage_after_compression = True

    deferred = compressor.should_defer_preflight_to_real_usage(95_000)
    assert deferred is True
def test_resumes_normal_deferral_after_flag_cleared(self, compressor):
    """With the awaiting flag cleared, deferral is not permanent.

    Once update_from_response() clears the flag, the normal baseline/growth
    deferral logic governs again.
    """
    compressor.threshold_tokens = 85_000
    compressor.last_real_prompt_tokens = 120_000
    compressor.awaiting_real_usage_after_compression = False

    # Stale-high real prompt with the flag cleared => the >= threshold
    # short-circuit applies => no deferral.
    deferred = compressor.should_defer_preflight_to_real_usage(95_000)
    assert deferred is False
class TestCompress:
@ -242,6 +264,59 @@ class TestCompress:
assert c.should_compress(55000) is True
assert c.should_compress(40000) is False
def test_max_tokens_reservation_lowers_threshold(self):
    """#43547: the threshold is based on (context_length - max_tokens).

    The provider reserves max_tokens out of the window, so a 200K model
    reserving 65536 output tokens has a ~134K input budget; at 50% that is
    ~67K, NOT 100K.
    """
    compute = ContextCompressor._compute_threshold_tokens
    # No reservation (provider default) → full-window behavior, unchanged.
    assert compute(200000, 0.50) == 100000
    assert compute(200000, 0.50, None) == 100000
    # 65536 reserved → effective input budget 134464; 50% = 67232.
    assert compute(200000, 0.50, 65536) == 67232
def test_max_tokens_reservation_with_small_window_floors(self):
    """A large reservation on a small window triggers the 85%-of-effective guard.

    The effective budget can drop near/below the minimum floor; the
    degenerate-window guard then applies 85% of the EFFECTIVE budget, never
    of the raw window.
    """
    # 128K window, 65536 reserved → effective 62464 (< MINIMUM 64000).
    # Floor (64000) >= effective window (62464) → 85% of effective.
    threshold = ContextCompressor._compute_threshold_tokens(128000, 0.50, 65536)
    assert threshold == int(62464 * 0.85)  # 53094
    assert threshold < 62464
def test_max_tokens_exceeding_window_falls_back_to_full(self):
    """max_tokens >= context_length falls back to the full window.

    Such a reservation would make the effective budget <= 0, so the full
    window is used rather than producing a non-positive threshold.
    """
    threshold = ContextCompressor._compute_threshold_tokens(64000, 0.50, 70000)
    # effective_window <= 0 → fall back to full context (64000) → 85% guard.
    # 54400 is 85% of 64000, same as the no-reservation small-ctx case.
    assert threshold == 54400
    assert threshold > 0
def test_max_tokens_coercion_treats_non_int_as_no_reservation(self):
    """Non-int / non-positive max_tokens coerce safely instead of raising.

    Guards the path where a mocked parent agent forwards a MagicMock
    max_tokens into a child ContextCompressor (regression for the
    delegate-test TypeError: '<=' not supported between MagicMock and int).
    """
    from unittest.mock import MagicMock

    coerce = ContextCompressor._coerce_max_tokens
    for rejected in (None, 0, -5, "nope"):
        assert coerce(rejected) is None
    assert coerce(65536) == 65536

    # The actual regression: building a compressor with a MagicMock
    # max_tokens must NOT raise (the unmocked code did `ctx - MagicMock`
    # then `MagicMock <= 0`). int(MagicMock()) returns 1, so coercion
    # yields a harmless positive int rather than crashing — the threshold
    # is computed cleanly with a 1-token reservation.
    with patch("agent.context_compressor.get_model_context_length", return_value=200000):
        built = ContextCompressor(model="m", quiet_mode=True, max_tokens=MagicMock())
    assert isinstance(built.max_tokens, int)
    assert isinstance(built.threshold_tokens, int)
    assert built.threshold_tokens > 0  # no crash, sane value
def test_compression_increments_count(self, compressor):
msgs = self._make_messages(10)
# Default config (abort_on_summary_failure=False) — fallback path

View file

@ -0,0 +1,73 @@
"""Tests for /learn — open-ended skill distillation.
Covers the shared prompt builder (agent.learn_prompt.build_learn_prompt) and
the slash-command registry wiring. /learn has no engine and no model tool: it
builds a standards-guided prompt that the live agent runs as a normal turn, so
these are the load-bearing behavior contracts.
"""
from agent.learn_prompt import build_learn_prompt, _AUTHORING_STANDARDS
class TestBuildLearnPrompt:
    """Contracts of the prompt text returned by ``build_learn_prompt``."""

    def test_embeds_the_user_request_verbatim(self):
        """The request string appears unchanged inside the prompt."""
        request = "the REST client in ~/projects/acme-sdk, focus on auth"
        assert request in build_learn_prompt(request)

    def test_always_includes_the_authoring_standards(self):
        """The standards text is present for every kind of request."""
        # The standards are what make distilled skills match house style;
        # they must travel with every prompt regardless of input.
        sample_requests = ["", "a url https://x/y", "what we just did"]
        for request in sample_requests:
            assert _AUTHORING_STANDARDS in build_learn_prompt(request)

    def test_instructs_saving_via_skill_manage_not_a_raw_file(self):
        """The prompt names ``skill_manage``."""
        rendered = build_learn_prompt("learn the thing")
        assert "skill_manage" in rendered

    def test_references_gather_tools_for_open_ended_sourcing(self):
        """The prompt names the read_file, search_files and web_extract tools."""
        # Open-ended sourcing relies on the agent's own tools, named so it
        # knows dirs/URLs/conversation/paste all route through existing tools.
        rendered = build_learn_prompt("learn from somewhere")
        for tool_name in ("read_file", "search_files", "web_extract"):
            assert tool_name in rendered

    def test_empty_request_falls_back_to_the_conversation(self):
        """An empty request produces a prompt that mentions the conversation."""
        # Bare /learn should distill "what we just did", not error.
        rendered = build_learn_prompt("")
        assert "conversation" in rendered.lower()
        # And still carries the standards + save instruction.
        assert "skill_manage" in rendered

    def test_whitespace_only_request_is_treated_as_empty(self):
        """A whitespace-only request renders the same prompt as an empty one."""
        blank_prompt = build_learn_prompt("")
        assert build_learn_prompt(" \n ") == blank_prompt

    def test_description_length_rule_is_in_the_standards(self):
        """The standards text contains the literal "60"."""
        # The single most-violated rule must be explicit in the prompt.
        assert "60" in _AUTHORING_STANDARDS
class TestLearnRegistryWiring:
    """Checks how the ``learn`` slash command is registered in hermes_cli.commands."""

    def test_learn_is_registered_and_resolves(self):
        """``resolve_command("learn")`` returns a command named "learn"."""
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("learn")
        assert resolved is not None
        assert resolved.name == "learn"

    def test_learn_is_in_tools_and_skills_category(self):
        """The command's category is "Tools & Skills"."""
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("learn")
        assert resolved.category == "Tools & Skills"

    def test_learn_works_on_the_gateway(self):
        """"learn" is listed in GATEWAY_KNOWN_COMMANDS."""
        # /learn must reach the gateway runner (it's a both-surfaces command),
        # not be CLI-only.
        from hermes_cli.commands import GATEWAY_KNOWN_COMMANDS

        assert "learn" in GATEWAY_KNOWN_COMMANDS

    def test_learn_is_not_cli_only(self):
        """The command's ``cli_only`` flag is falsy."""
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("learn")
        assert not resolved.cli_only

View file

@ -1172,16 +1172,12 @@ class TestOnMemoryWriteBridge:
mgr.on_memory_write("replace", "user", "updated pref")
assert p.memory_writes == [("replace", "user", "updated pref")]
def test_on_memory_write_remove_not_bridged(self):
"""The bridge intentionally skips 'remove' — only add/replace notify."""
# This tests the contract that run_agent.py checks:
# function_args.get("action") in ("add", "replace")
def test_on_memory_write_remove_supported_by_manager(self):
"""The manager forwards remove actions when a caller elects to bridge them."""
mgr = MemoryManager()
p = FakeMemoryProvider("ext")
mgr.add_provider(p)
# Manager itself doesn't filter — run_agent.py does.
# But providers should handle remove gracefully.
mgr.on_memory_write("remove", "memory", "old fact")
assert p.memory_writes == [("remove", "memory", "old fact")]

View file

@ -0,0 +1,145 @@
"""Behavior tests for the built-in memory → external provider bridge.
The bridge lives behind the MemoryManager interface
(``MemoryManager.notify_memory_tool_write``): the agent loop hands over the raw
built-in memory tool result + args, and the manager decides whether/what to
mirror to external providers. These tests drive that method with a fake
external provider and assert which ``on_memory_write`` calls land.
"""
import json
import pytest
from agent.memory_manager import MemoryManager
from agent.memory_provider import MemoryProvider
class _RecordingProvider(MemoryProvider):
    """Bare-bones external provider that keeps a log of on_memory_write calls."""

    def __init__(self) -> None:
        # One dict per on_memory_write call, in call order.
        self.calls = []

    @property
    def name(self) -> str:
        """Fixed provider name."""
        return "recording"

    def is_available(self) -> bool:
        """Always report the provider as available."""
        return True

    def initialize(self, session_id: str, **kwargs) -> None:
        """No-op."""
        pass

    def get_tool_schemas(self):
        """Expose no tools."""
        return []

    def shutdown(self) -> None:
        """No-op."""
        pass

    def on_memory_write(self, action, target, content, metadata=None):
        """Append the call's arguments to ``self.calls``, copying the metadata."""
        # dict() copies the metadata; a missing/empty value becomes {}.
        record = {
            "action": action,
            "target": target,
            "content": content,
            "metadata": dict(metadata or {}),
        }
        self.calls.append(record)
def _manager_with_provider():
    """Return a fresh ``(MemoryManager, _RecordingProvider)`` pair, already wired together."""
    recorder = _RecordingProvider()
    manager = MemoryManager()
    manager.add_provider(recorder)
    return manager, recorder
def test_notifies_remove_with_old_text_after_success():
    """A successful remove is mirrored with empty content and old_text in metadata."""
    manager, recorder = _manager_with_provider()
    tool_result = json.dumps({"success": True})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    manager.notify_memory_tool_write(tool_result, tool_args)

    expected_call = {
        "action": "remove",
        "target": "memory",
        "content": "",
        "metadata": {"old_text": "stale preference entry"},
    }
    assert recorder.calls == [expected_call]
def test_skips_failed_memory_write():
    """A tool result with ``success: False`` is not mirrored to the provider."""
    manager, recorder = _manager_with_provider()
    tool_result = json.dumps({"success": False, "error": "No entry matched"})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    manager.notify_memory_tool_write(tool_result, tool_args)

    assert recorder.calls == []
def test_skips_staged_memory_write():
    """A successful result marked ``staged`` is not mirrored to the provider."""
    manager, recorder = _manager_with_provider()
    tool_result = json.dumps({"success": True, "staged": True, "pending_id": "abc123"})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    manager.notify_memory_tool_write(tool_result, tool_args)

    assert recorder.calls == []
@pytest.mark.parametrize("tool_result", [None, [], object(), "not-json"])
def test_skips_unrecognized_tool_result_shape(tool_result):
    """None, a list, a bare object and a non-JSON string all produce no mirror call."""
    manager, recorder = _manager_with_provider()
    tool_args = {"action": "add", "target": "memory", "content": "new fact"}

    manager.notify_memory_tool_write(tool_result, tool_args)

    assert recorder.calls == []
def test_preserves_old_text_for_replace_and_remove_batch():
    """A batch of operations is mirrored one call per op, in order, under the shared target."""
    manager, recorder = _manager_with_provider()
    batch_ops = [
        {"action": "replace", "old_text": "old preference", "content": "updated"},
        {"action": "remove", "old_text": "obsolete preference"},
        {"action": "add", "content": "new fact"},
    ]
    tool_args = {"target": "user", "operations": batch_ops}

    manager.notify_memory_tool_write(json.dumps({"success": True}), tool_args)

    expected_calls = [
        {
            "action": "replace",
            "target": "user",
            "content": "updated",
            "metadata": {"old_text": "old preference"},
        },
        {
            "action": "remove",
            "target": "user",
            "content": "",
            "metadata": {"old_text": "obsolete preference"},
        },
        {"action": "add", "target": "user", "content": "new fact", "metadata": {}},
    ]
    assert recorder.calls == expected_calls
def test_non_mutating_actions_are_not_mirrored():
    """A successful ``read`` action produces no provider call."""
    manager, recorder = _manager_with_provider()
    tool_result = json.dumps({"success": True})
    tool_args = {"action": "read", "target": "memory"}

    manager.notify_memory_tool_write(tool_result, tool_args)

    assert recorder.calls == []
def test_build_metadata_callback_is_merged_per_op():
    """The dict returned by ``build_metadata`` shows up as the mirrored call's metadata."""
    manager, recorder = _manager_with_provider()
    tool_result = json.dumps({"success": True})
    tool_args = {"action": "add", "target": "memory", "content": "fact"}

    manager.notify_memory_tool_write(
        tool_result,
        tool_args,
        build_metadata=lambda: {"session_id": "s1", "tool_name": "memory"},
    )

    expected_call = {
        "action": "add",
        "target": "memory",
        "content": "fact",
        "metadata": {"session_id": "s1", "tool_name": "memory"},
    }
    assert recorder.calls == [expected_call]

110
tests/agent/test_oneshot.py Normal file
View file

@ -0,0 +1,110 @@
"""Tests for agent.oneshot — shared one-off (stateless) LLM requests."""
from unittest.mock import MagicMock, patch
import pytest
from agent.oneshot import (
PROMPT_TEMPLATES,
render_template,
run_oneshot,
_strip_code_fence,
_truncate,
)
class TestRenderTemplate:
    """Checks of ``render_template`` and the ``commit_message`` template."""

    def test_unknown_template_raises(self):
        """An unregistered template name raises KeyError."""
        with pytest.raises(KeyError):
            render_template("does-not-exist", {})

    def test_commit_message_template_is_registered(self):
        """``commit_message`` is a key of PROMPT_TEMPLATES."""
        assert "commit_message" in PROMPT_TEMPLATES

    def test_commit_message_includes_diff_and_recent(self):
        """The diff and recent commits land in the user part of the prompt."""
        variables = {
            "diff": "diff --git a/x b/x\n+new",
            "recent_commits": "feat: a\nfix: b",
        }
        instructions, user = render_template("commit_message", variables)
        # Instructions describe the contract (conventional commits), not a snapshot.
        assert "Conventional Commits" in instructions
        assert "diff --git a/x b/x" in user
        assert "feat: a" in user

    def test_commit_message_diff_with_braces_passes_through(self):
        """Literal braces in the diff survive rendering unchanged."""
        # Templates must not use str.format — code payloads carry literal { }.
        braced_diff = "x = {a: 1}"
        _, user = render_template("commit_message", {"diff": braced_diff})
        assert "x = {a: 1}" in user

    def test_commit_message_handles_missing_variables(self):
        """With no variables, instructions are non-empty and the user part notes the missing diff."""
        instructions, user = render_template("commit_message", {})
        assert instructions
        assert "no textual diff available" in user

    def test_commit_message_avoid_forces_new_message(self):
        """Passing ``avoid`` adds the prior message and a do-not-repeat instruction."""
        # Passing the previous message must instruct the model not to repeat it,
        # so "regenerate" yields a different result even on greedy models.
        _, without_avoid = render_template("commit_message", {"diff": "d"})
        _, with_avoid = render_template(
            "commit_message", {"diff": "d", "avoid": "feat: prior"}
        )
        assert "feat: prior" in with_avoid
        assert "do not repeat" in with_avoid
        assert "feat: prior" not in without_avoid
class TestRunOneshot:
    """Checks of ``run_oneshot`` with ``agent.oneshot.call_llm`` patched out."""

    def _mock_response(self, content):
        """Build a MagicMock response with one choice carrying ``content`` and no reasoning fields."""
        choice = MagicMock()
        choice.message.content = content
        choice.message.reasoning = None
        choice.message.reasoning_content = None
        choice.message.reasoning_details = None
        response = MagicMock()
        response.choices = [choice]
        return response

    def test_template_path_calls_llm_with_rendered_prompt(self):
        """The template path sends a system message followed by a user message."""
        canned = self._mock_response("feat: add thing")
        with patch("agent.oneshot.call_llm", return_value=canned) as llm:
            result = run_oneshot(template="commit_message", variables={"diff": "d"})
        assert result == "feat: add thing"
        sent = llm.call_args.kwargs["messages"]
        assert sent[0]["role"] == "system"
        assert sent[1]["role"] == "user"

    def test_explicit_instructions_path(self):
        """Explicit instructions and user input are forwarded as the two message contents."""
        canned = self._mock_response("hello")
        with patch("agent.oneshot.call_llm", return_value=canned) as llm:
            result = run_oneshot(instructions="be brief", user_input="say hi")
        assert result == "hello"
        sent = llm.call_args.kwargs["messages"]
        assert sent[0]["content"] == "be brief"
        assert sent[1]["content"] == "say hi"

    def test_requires_template_or_prompt(self):
        """Calling with no arguments raises ValueError."""
        with pytest.raises(ValueError):
            run_oneshot()

    def test_strips_wrapping_code_fence(self):
        """A response wrapped in a code fence is returned without the fence."""
        canned = self._mock_response("```\nfix: bug\n```")
        with patch("agent.oneshot.call_llm", return_value=canned):
            result = run_oneshot(instructions="x", user_input="y")
            assert result == "fix: bug"
class TestHelpers:
    """Checks of the private ``_truncate`` and ``_strip_code_fence`` helpers."""

    def test_truncate_under_limit_unchanged(self):
        """Text shorter than the limit comes back untouched."""
        result = _truncate("short", 100)
        assert result == "short"

    def test_truncate_over_limit_marks_truncation(self):
        """Over-limit text is shortened and ends with the truncation marker."""
        long_text = "x" * 200
        result = _truncate(long_text, 50)
        assert result.endswith("…(truncated)")
        assert len(result) < 200

    def test_strip_code_fence_without_fence_is_noop(self):
        """Text with no fence is returned as-is."""
        result = _strip_code_fence("plain text")
        assert result == "plain text"

View file

@ -100,7 +100,13 @@ class _StubAgent:
pass
def _run(agent):
def _run(
agent,
*,
final_response=None,
api_call_count=3,
turn_exit_reason="unknown",
):
messages = [
{"role": "user", "content": "do a thing"},
{
@ -114,8 +120,8 @@ def _run(agent):
]
return finalize_turn(
agent,
final_response=None, # forces the max-iterations summary path
api_call_count=3,
final_response=final_response,
api_call_count=api_call_count,
interrupted=False,
failed=False,
messages=messages,
@ -125,7 +131,7 @@ def _run(agent):
user_message="do a thing",
original_user_message="do a thing",
_should_review_memory=False,
_turn_exit_reason="unknown",
_turn_exit_reason=turn_exit_reason,
)
@ -162,4 +168,17 @@ def test_clean_turn_has_no_cleanup_errors_key():
agent = _StubAgent(raise_in=())
result = _run(agent)
assert result["final_response"] == "PARTIAL SUMMARY FROM MODEL"
assert result["completed"] is False
assert "cleanup_errors" not in result
def test_text_response_on_last_allowed_call_is_completed():
    """A text response arriving when the call count equals max_iterations is completed."""
    stub_agent = _StubAgent(raise_in=())
    run_options = {
        "final_response": "final report",
        "api_call_count": stub_agent.max_iterations,
        "turn_exit_reason": "text_response(finish_reason=stop)",
    }

    outcome = _run(stub_agent, **run_options)

    assert outcome["final_response"] == "final report"
    assert outcome["completed"] is True

View file

@ -0,0 +1,85 @@
"""Tests for scripts/ci/classify_changes.py.
Check some common patterns of file modifications and the CI lanes they should run.
We should always fail open: we may run a lane we didn't need, but we must
never skip one that a change could have broken.
"""
from __future__ import annotations
import importlib.util
from pathlib import Path
import pytest
# Load scripts/ci/classify_changes.py directly from its file path (two
# directories above this file's directory) and expose its ``classify``
# function to the tests below.
_PATH = Path(__file__).resolve().parents[2] / "scripts" / "ci" / "classify_changes.py"
_spec = importlib.util.spec_from_file_location("classify_changes", _PATH)
# spec_from_file_location may return None (or a spec without a loader); fail
# loudly at import time instead of with an AttributeError further down.
if _spec is None or _spec.loader is None:
    raise ImportError("Failed to load classify_changes.py")
_mod = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_mod)
classify = _mod.classify
# Expected result for the fail-open cases in CASES (CI-config changes, empty
# or blank diffs): every lane on except ``mcp_catalog``.
DEFAULT = {
    "python": True,
    "frontend": True,
    "docker_meta": True,
    "site": True,
    "scan": True,
    "deps": True,
    "mcp_catalog": False,
}
def _lanes(python=False, frontend=False, site=False, scan=False, deps=False, mcp_catalog=False, docker_meta=False) -> dict[str, bool]:
return {
"python": python,
"frontend": frontend,
"docker_meta": docker_meta,
"site": site,
"scan": scan,
"deps": deps,
"mcp_catalog": mcp_catalog,
}
# Table of classification scenarios. Each key is the pytest id; each value is
# a ``(changed_files, expected_lanes)`` pair fed to ``test_classify``.
CASES = {
    "docs-only → nothing heavy": (["README.md", "docs/guide.md"], _lanes()),
    "python source → python": (["run_agent.py"], _lanes(python=True, scan=True)),
    "dep manifest → python": (["pyproject.toml"], _lanes(python=True, scan=True, deps=True)),
    "uv.lock → python": (["uv.lock"], _lanes(python=True)),
    "ts package → frontend": (["apps/desktop/src/app.tsx"], _lanes(frontend=True)),
    "ui-tui → frontend": (["ui-tui/src/entry.ts"], _lanes(frontend=True)),
    # Lockfile bump shifts every TS package's tree, but not the Python suite.
    "root lockfile → frontend, not python": (["package-lock.json"], _lanes(frontend=True)),
    "website → site": (["website/docs/intro.md"], _lanes(site=True)),
    # SKILL.md reads like docs, but the skill-doc tests read skills/, so a
    # skill edit must still run Python.
    "skill md → python + site": (["skills/github/SKILL.md"], _lanes(python=True, site=True)),
    "dockerfile → docker meta": (["Dockerfile"], _lanes(docker_meta=True)),
    # Unknown top-level file keeps Python on rather than risk a silent skip.
    "unknown toplevel → python": (["Makefile"], _lanes(python=True)),
    "mixed docs+python → python": (["README.md", "agent/x.py"], _lanes(python=True, scan=True)),
    "mixed docs+frontend → frontend": (["README.md", "apps/x.tsx"], _lanes(frontend=True)),
    # Supply-chain lanes
    ".pth file → scan": (["evil.pth"], _lanes(python=True, scan=True)),
    "setup.py → scan": (["setup.py"], _lanes(python=True, scan=True)),
    "mcp catalog manifest → mcp_catalog": (
        ["optional-mcps/foo/manifest.yaml"],
        _lanes(python=True, mcp_catalog=True),
    ),
    "mcp_catalog.py → mcp_catalog": (
        ["hermes_cli/mcp_catalog.py"],
        _lanes(python=True, scan=True, mcp_catalog=True),
    ),
    # Fail open: CI-config / empty / blank diffs run everything.
    ".github change → all": ([".github/workflows/tests.yml"], DEFAULT),
    "action change → all": ([".github/actions/detect-changes/action.yml"], DEFAULT),
    "empty diff → all": ([], DEFAULT),
    "blank lines → all": (["", " "], DEFAULT),
}
@pytest.mark.parametrize("files,expected", CASES.values(), ids=CASES.keys())
def test_classify(files, expected):
    """Each changed-file list must classify to exactly the expected lane flags."""
    lanes = classify(files)
    assert lanes == expected

View file

@ -189,3 +189,82 @@ def test_indicators_independent_agents_and_processes(monkeypatch):
rendered = "".join(text for _style, text in frags)
assert "▶ 1" in rendered
assert "⚙ 2" in rendered
# ── Background/async subagent indicator (⛓ N) ─────────────────────────────
# Source of truth is tools.async_delegation.active_count() — the count of
# delegate_task delegations (batch + background single) still in the
# "running" state. Distinct from ▶ (/background agent threads) and ⚙ (shell
# processes); all three can be active at once.
def _patch_async_active(monkeypatch, count: int) -> None:
    """Make ``tools.async_delegation.active_count`` report ``count``."""
    import tools.async_delegation as async_delegation

    def _fixed_count():
        return count

    monkeypatch.setattr(async_delegation, "active_count", _fixed_count)
def test_snapshot_reports_zero_when_no_background_subagents(monkeypatch):
    """With an active count of 0 the snapshot reports zero background subagents."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 0)
    snapshot = cli._get_status_bar_snapshot()
    assert snapshot["active_background_subagents"] == 0
def test_snapshot_counts_live_background_subagents(monkeypatch):
    """The snapshot mirrors whatever ``active_count()`` reports (here, 4)."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 4)
    snapshot = cli._get_status_bar_snapshot()
    assert snapshot["active_background_subagents"] == 4
def test_snapshot_safe_when_async_active_count_raises(monkeypatch):
    """A raising ``active_count()`` must not propagate; the count stays at 0."""
    cli = _make_cli()
    import tools.async_delegation as async_delegation

    def _raise_runtime_error():
        raise RuntimeError("boom")

    monkeypatch.setattr(async_delegation, "active_count", _raise_runtime_error)
    snapshot = cli._get_status_bar_snapshot()
    assert snapshot["active_background_subagents"] == 0
def test_plain_text_status_shows_subagent_indicator_when_active(monkeypatch):
    """Three running delegations render as ``⛓ 3`` in the plain-text bar."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 3)
    status_line = cli._build_status_bar_text(width=80)
    assert "⛓ 3" in status_line
def test_plain_text_status_omits_subagent_indicator_when_idle(monkeypatch):
    """With no running delegations the plain-text bar carries no ⛓ marker."""
    cli_obj = _make_cli()
    _patch_async_active(monkeypatch, 0)
    text = cli_obj._build_status_bar_text(width=80)
    # The glyph had been lost here, leaving `"" not in text`. An empty string
    # is a substring of every string, so that assertion could never pass.
    # Check for the subagent glyph itself, matching the "⛓ N" assertions in
    # the sibling tests.
    assert "⛓" not in text
def test_fragments_include_subagent_segment_when_active(monkeypatch):
    """The styled-fragment status bar shows ``⛓ 2`` for two running delegations."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 2)
    cli._status_bar_visible = True
    # Pin the terminal width the bar is rendered against.
    cli._get_tui_terminal_width = lambda: 120  # type: ignore[method-assign]
    fragments = cli._get_status_bar_fragments()
    pieces = [text for _style, text in fragments]
    rendered = "".join(pieces)
    assert "⛓ 2" in rendered
def test_all_three_background_indicators_independent(monkeypatch):
    """▶ (agent tasks), ⚙ (shell processes) and ⛓ (subagents) render side by side."""
    cli = _make_cli()
    # One background agent thread, two shell processes, five subagents.
    cli._background_tasks = {"bg_a": _stub_thread()}
    _patch_process_registry(monkeypatch, 2)
    _patch_async_active(monkeypatch, 5)
    cli._status_bar_visible = True
    cli._get_tui_terminal_width = lambda: 120  # type: ignore[method-assign]
    fragments = cli._get_status_bar_fragments()
    pieces = [text for _style, text in fragments]
    rendered = "".join(pieces)
    assert "▶ 1" in rendered
    assert "⚙ 2" in rendered
    assert "⛓ 5" in rendered

View file

@ -169,7 +169,7 @@ class TestHealthyTurnStillRuns:
# Force the judge to say "continue" without touching the network.
with patch(
"hermes_cli.goals.judge_goal",
return_value=("continue", "needs more steps", False),
return_value=("continue", "needs more steps", False, None),
):
cli._maybe_continue_goal_after_turn()
@ -189,7 +189,7 @@ class TestHealthyTurnStillRuns:
with patch(
"hermes_cli.goals.judge_goal",
return_value=("done", "goal satisfied", False),
return_value=("done", "goal satisfied", False, None),
):
cli._maybe_continue_goal_after_turn()

View file

@ -0,0 +1,80 @@
"""Tests for the cua-driver telemetry opt-in policy.
cua-driver ships anonymous PostHog telemetry ENABLED by default upstream.
Hermes disables it unless the user opts in via
``computer_use.cua_telemetry: true``. The policy is applied by injecting
``CUA_DRIVER_RS_TELEMETRY_ENABLED=0`` into every cua-driver child env.
These assert the behavior contract (default disables, opt-in leaves the var
untouched, config failure fails safe toward disabled), not specific config
snapshots.
"""
from unittest.mock import patch
from tools.computer_use import cua_backend
_VAR = "CUA_DRIVER_RS_TELEMETRY_ENABLED"
class TestTelemetryDisabledFlag:
    """``_cua_telemetry_disabled`` is True unless the config explicitly opts in."""

    def test_default_config_disables(self):
        # cua_telemetry absent / False => telemetry disabled.
        empty_config = patch("hermes_cli.config.load_config", return_value={})
        with empty_config:
            assert cua_backend._cua_telemetry_disabled() is True

    def test_explicit_false_disables(self):
        config = {"computer_use": {"cua_telemetry": False}}
        with patch("hermes_cli.config.load_config", return_value=config):
            assert cua_backend._cua_telemetry_disabled() is True

    def test_opt_in_true_does_not_disable(self):
        config = {"computer_use": {"cua_telemetry": True}}
        with patch("hermes_cli.config.load_config", return_value=config):
            assert cua_backend._cua_telemetry_disabled() is False

    def test_config_load_failure_fails_safe(self):
        # Unreadable config => default to disabling telemetry (privacy-safe).
        failure = RuntimeError("boom")
        with patch("hermes_cli.config.load_config", side_effect=failure):
            assert cua_backend._cua_telemetry_disabled() is True

    def test_missing_section_disables(self):
        # A config with no computer_use section behaves like the default.
        config = {"other": {}}
        with patch("hermes_cli.config.load_config", return_value=config):
            assert cua_backend._cua_telemetry_disabled() is True
class TestChildEnv:
    """``cua_driver_child_env`` sets or leaves the telemetry variable per policy."""

    def test_disabled_injects_var_zero(self):
        base_env = {"PATH": "/usr/bin"}
        with patch.object(cua_backend, "_cua_telemetry_disabled", return_value=True):
            child_env = cua_backend.cua_driver_child_env(base_env)
        assert child_env[_VAR] == "0"
        # base env is preserved
        assert child_env["PATH"] == "/usr/bin"

    def test_opt_in_leaves_var_untouched(self):
        # When the user opts in, we must NOT set the var — the driver uses its
        # own default. If the base env already has a value, it is preserved.
        base_env = {"PATH": "/usr/bin"}
        with patch.object(cua_backend, "_cua_telemetry_disabled", return_value=False):
            child_env = cua_backend.cua_driver_child_env(base_env)
        assert _VAR not in child_env

    def test_opt_in_preserves_user_set_var(self):
        base_env = {_VAR: "1", "PATH": "/usr/bin"}
        with patch.object(cua_backend, "_cua_telemetry_disabled", return_value=False):
            child_env = cua_backend.cua_driver_child_env(base_env)
        # user opted in and explicitly set it — don't clobber.
        assert child_env[_VAR] == "1"

    def test_disabled_overrides_inherited_enabled(self):
        # Even if the parent process had telemetry enabled, the default policy
        # forces it off in the child.
        inherited = {_VAR: "1"}
        with patch.object(cua_backend, "_cua_telemetry_disabled", return_value=True):
            child_env = cua_backend.cua_driver_child_env(inherited)
        assert child_env[_VAR] == "0"

    def test_defaults_to_os_environ_when_no_base(self):
        # No base env passed: the marker placed in os.environ must show up in
        # the child env, alongside the forced-off telemetry variable.
        with patch.object(cua_backend, "_cua_telemetry_disabled", return_value=True), \
                patch.dict("os.environ", {"SOME_MARKER": "yes"}, clear=False):
            child_env = cua_backend.cua_driver_child_env()
        assert child_env.get("SOME_MARKER") == "yes"
        assert child_env[_VAR] == "0"

View file

@ -0,0 +1,325 @@
"""Tests for ``tools.computer_use.doctor``.
The doctor module drives cua-driver's stable ``health_report`` MCP tool over
stdio JSON-RPC and renders the structured response. Most of the surface is
about parsing what cua-driver hands back, plus the exit-code contract
downstream consumers (CI / `hermes update`) rely on:
* Exit 0 when overall == "ok"
* Exit 1 when overall in ("degraded", "failed") — at least one check
failed, but the tool itself ran successfully
* Exit 2 when the cua-driver binary is missing or the protocol breaks
We do NOT spin up a real cua-driver — that lives in the cua-driver
integration test suite (libs/cua-driver/rust/tests/integration/
test_health_report_mcp.py). Here we mock the subprocess and assert that the
Hermes-side adapter behaves correctly against the documented response
shape.
"""
from __future__ import annotations
import json
from io import StringIO
from unittest.mock import MagicMock, patch
# ── helpers ────────────────────────────────────────────────────────────────
def _fake_proc_with_responses(*responses: dict) -> MagicMock:
"""Build a MagicMock subprocess.Popen handle that yields one JSON-RPC
response per `readline()` call, then returns "" (EOF)."""
lines = [json.dumps(r) + "\n" for r in responses] + [""]
proc = MagicMock()
proc.stdin = MagicMock()
proc.stdout = MagicMock()
proc.stdout.readline = MagicMock(side_effect=lines)
proc.stderr = MagicMock()
proc.stderr.read = MagicMock(return_value="")
proc.wait = MagicMock(return_value=0)
proc.kill = MagicMock()
return proc
def _ok_report() -> dict:
"""Minimal well-formed health_report response."""
return {
"schema_version": "1",
"platform": "darwin",
"driver_version": "0.5.8",
"overall": "ok",
"checks": [
{"name": "binary_version", "status": "pass", "message": "cua-driver 0.5.8"},
{"name": "tcc_accessibility", "status": "pass", "message": "Accessibility is granted."},
],
}
def _degraded_report() -> dict:
"""Report with one failing check — overall=degraded."""
return {
"schema_version": "1",
"platform": "darwin",
"driver_version": "0.5.8",
"overall": "degraded",
"checks": [
{"name": "binary_version", "status": "pass", "message": "cua-driver 0.5.8"},
{
"name": "bundle_identity",
"status": "fail",
"message": "Process has no CFBundleIdentifier.",
"hint": "Run inside CuaDriver.app",
"data": {"executable_path": "/tmp/cua-driver"},
},
],
}
# ── exit codes ─────────────────────────────────────────────────────────────
class TestDoctorExitCodes:
    """Exit-code contract of ``doctor.run_doctor``.

    0 for an "ok" report, 1 for a "degraded"/"failed" report, 2 when the
    binary is missing or the JSON-RPC exchange breaks.
    """

    def test_ok_exits_0(self):
        from tools.computer_use import doctor

        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO):
            exit_code = doctor.run_doctor()
        assert exit_code == 0

    def test_degraded_exits_1(self):
        from tools.computer_use import doctor

        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _degraded_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO):
            exit_code = doctor.run_doctor()
        assert exit_code == 1

    def test_failed_overall_exits_1(self):
        """`failed` overall (every check failed) is also exit 1, not 2 —
        the tool ran successfully; the diagnosis was bad."""
        from tools.computer_use import doctor

        failed_report = _degraded_report()
        failed_report["overall"] = "failed"
        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": failed_report}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO):
            exit_code = doctor.run_doctor()
        assert exit_code == 1

    def test_missing_binary_exits_2(self):
        from tools.computer_use import doctor

        # shutil.which finding nothing means there is no driver to launch.
        with patch("shutil.which", return_value=None), \
                patch("sys.stdout", new_callable=StringIO):
            exit_code = doctor.run_doctor()
        assert exit_code == 2

    def test_protocol_error_exits_2(self, capsys):
        """An empty stdout response (driver crashed during handshake) is a
        protocol failure — exit 2."""
        from tools.computer_use import doctor

        crashed = MagicMock()
        crashed.stdin = MagicMock()
        crashed.stdout = MagicMock()
        crashed.stdout.readline = MagicMock(return_value="")  # EOF on initialize
        crashed.stderr = MagicMock()
        crashed.stderr.read = MagicMock(return_value="boom\n")
        crashed.wait = MagicMock(return_value=0)
        crashed.kill = MagicMock()
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=crashed):
            exit_code = doctor.run_doctor()
        assert exit_code == 2
        # stderr should mention the failure
        err_text = capsys.readouterr().err.lower()
        assert "cua-driver" in err_text or "health_report" in err_text
# ── response-shape parsing ─────────────────────────────────────────────────
class TestResponseShapeParsing:
    """How ``run_doctor`` extracts the report from the tools/call response."""

    def test_prefers_structuredContent(self):
        from tools.computer_use import doctor

        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO) as captured:
            doctor.run_doctor()
        # Header line includes driver version + platform + overall.
        rendered = captured.getvalue()
        assert "darwin" in rendered
        assert "ok" in rendered

    def test_falls_back_to_text_content_when_structuredContent_absent(self):
        """Older cua-driver builds may emit health_report as a text content
        item carrying the JSON — the doctor should still parse it."""
        from tools.computer_use import doctor

        text_item = {"type": "text", "text": json.dumps(_ok_report())}
        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"content": [text_item]}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO) as captured:
            exit_code = doctor.run_doctor()
        assert exit_code == 0
        assert "ok" in captured.getvalue()

    def test_jsonrpc_error_response_exits_2(self, capsys):
        from tools.computer_use import doctor

        # The tools/call reply is a JSON-RPC error object instead of a result.
        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        error_reply = {"jsonrpc": "2.0", "id": 2, "error": {"code": -32601, "message": "method not found"}}
        fake_proc = _fake_proc_with_responses(init_reply, error_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc):
            exit_code = doctor.run_doctor()
        assert exit_code == 2
        assert "method not found" in capsys.readouterr().err
# ── args / arg passthrough ─────────────────────────────────────────────────
class TestArgPassthrough:
    """``include`` / ``skip`` filters must reach the tools/call arguments."""

    @staticmethod
    def _tools_call_payload(proc):
        """Decode the first stdin write that mentions ``tools/call``."""
        written = [call.args[0] for call in proc.stdin.write.call_args_list]
        return next(json.loads(chunk) for chunk in written if "tools/call" in chunk)

    def test_include_passed_through_to_tools_call(self):
        from tools.computer_use import doctor

        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO):
            doctor.run_doctor(include=["binary_version", "tcc_accessibility"])
        # Inspect the stdin write carrying the tools/call payload.
        payload = self._tools_call_payload(fake_proc)
        assert payload["params"]["arguments"]["include"] == [
            "binary_version", "tcc_accessibility",
        ]

    def test_skip_passed_through(self):
        from tools.computer_use import doctor

        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO):
            doctor.run_doctor(skip=["bundle_identity"])
        payload = self._tools_call_payload(fake_proc)
        assert payload["params"]["arguments"]["skip"] == ["bundle_identity"]

    def test_no_filters_sends_empty_arguments(self):
        """When neither include nor skip is given, the arguments object is
        empty — not present-but-null — so the driver's default 'run every
        check' branch fires."""
        from tools.computer_use import doctor

        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO):
            doctor.run_doctor()
        payload = self._tools_call_payload(fake_proc)
        assert payload["params"]["arguments"] == {}
# ── json output ────────────────────────────────────────────────────────────
class TestJsonOutput:
    """``json_output=True`` prints the structured report as parseable JSON."""

    def test_json_output_is_parseable_round_trip(self):
        from tools.computer_use import doctor

        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/cua-driver"), \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO) as captured:
            doctor.run_doctor(json_output=True)
        # Verify the captured text round-trips through json.loads and matches
        # the input report (the contract: --json passes the structured payload
        # through unchanged so downstream tooling can consume it directly).
        round_tripped = json.loads(captured.getvalue())
        assert round_tripped == _ok_report()
# ── HERMES_CUA_DRIVER_CMD resolution ───────────────────────────────────────
class TestDriverCmdResolution:
    """Which command string ``run_doctor`` hands to ``shutil.which``."""

    def test_explicit_driver_cmd_arg_wins(self):
        from tools.computer_use import doctor

        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/fake/explicit-binary") as which_spy, \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO):
            doctor.run_doctor(driver_cmd="/custom/path/cua-driver")
        # shutil.which should have been called with the explicit arg, not
        # the env-var / default resolver.
        which_spy.assert_called_with("/custom/path/cua-driver")

    def test_env_var_used_when_no_arg_given(self, monkeypatch):
        from tools.computer_use import doctor

        monkeypatch.setenv("HERMES_CUA_DRIVER_CMD", "/env/path/cua-driver")
        init_reply = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}}
        fake_proc = _fake_proc_with_responses(init_reply, report_reply)
        with patch("shutil.which", return_value="/env/path/cua-driver") as which_spy, \
                patch("subprocess.Popen", return_value=fake_proc), \
                patch("sys.stdout", new_callable=StringIO):
            doctor.run_doctor()
        # The most recent which call should have used the env var.
        which_spy.assert_called_with("/env/path/cua-driver")

View file

@ -14,10 +14,7 @@ import pytest
def temp_home(tmp_path, monkeypatch):
"""Isolated HERMES_HOME so jobs.json doesn't touch the real store."""
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
# NOTE: cron.jobs resolves its store paths (JOBS_FILE, CRON_DIR) from
# get_default_hermes_root() at IMPORT time, so setting HERMES_HOME here does
# not re-point an already-imported module's store. These tests exercise the
# claim logic on in-memory job dicts and don't depend on the on-disk path.
# cron.jobs caches no home at import; get_hermes_home() reads the env live.
yield tmp_path

View file

@ -1,105 +0,0 @@
"""Regression tests for #32091 — profile-scoped cron jobs orphaned.
Cron storage (CRON_DIR/JOBS_FILE) must anchor at the *default root* Hermes
home, not the active profile's home. Otherwise a job created from a
profile-scoped agent session writes to ~/.hermes/profiles/<p>/cron/jobs.json,
while the profile-less gateway reads only ~/.hermes/cron/jobs.json — so the job
is silently orphaned (looks healthy in `list`, never fires).
"""
import importlib
import os
from pathlib import Path
def test_cron_storage_anchors_at_root_under_profile(tmp_path, monkeypatch):
    """With HERMES_HOME set to <root>/profiles/<name>, the cron store must
    resolve to <root>/cron, NOT <root>/profiles/<name>/cron."""
    root_dir = tmp_path / "hermes_home"
    profile_dir = root_dir / "profiles" / "myprofile"
    profile_dir.mkdir(parents=True)

    # Pretend the platform default root IS our tmp root, and the active
    # HERMES_HOME is a profile under it (the #32091 scenario).
    import hermes_constants

    monkeypatch.setattr(
        hermes_constants, "_get_platform_default_hermes_home", lambda: root_dir
    )
    monkeypatch.setenv("HERMES_HOME", str(profile_dir))

    # get_default_hermes_root must return the ROOT, not the profile dir...
    assert hermes_constants.get_default_hermes_root().resolve() == root_dir.resolve()
    # ...while get_hermes_home (used elsewhere) follows the profile override.
    assert hermes_constants.get_hermes_home().resolve() == profile_dir.resolve()

    # Reload cron.jobs under this environment so its module-level paths are
    # recomputed before they are inspected.
    import cron.jobs as jobs

    importlib.reload(jobs)
    try:
        assert jobs.HERMES_DIR.resolve() == root_dir.resolve()
        root_store = (root_dir / "cron" / "jobs.json").resolve()
        assert jobs.JOBS_FILE.resolve() == root_store
        # The orphan path (<profile>/cron/jobs.json) must NOT be the store.
        orphan_store = (profile_dir / "cron" / "jobs.json").resolve()
        assert jobs.JOBS_FILE.resolve() != orphan_store
    finally:
        # Restore module state for other tests (reload under the real env).
        monkeypatch.undo()
        importlib.reload(jobs)
def test_cron_storage_unaffected_when_no_profile(tmp_path, monkeypatch):
    """Without a profile (HERMES_HOME == root) the store stays at <root>/cron."""
    root_dir = tmp_path / "hermes_home"
    root_dir.mkdir(parents=True)

    import hermes_constants

    monkeypatch.setattr(
        hermes_constants, "_get_platform_default_hermes_home", lambda: root_dir
    )
    monkeypatch.setenv("HERMES_HOME", str(root_dir))

    import cron.jobs as jobs

    importlib.reload(jobs)
    try:
        expected_store = (root_dir / "cron" / "jobs.json").resolve()
        assert jobs.JOBS_FILE.resolve() == expected_store
    finally:
        # Undo the patches first so the second reload sees the real env.
        monkeypatch.undo()
        importlib.reload(jobs)
def test_tick_lock_anchors_at_root_under_profile(tmp_path, monkeypatch):
    """The cron tick lock must live at <root>/cron/.tick.lock, NOT the profile
    dir — otherwise tickers under different profiles grab different locks and
    double-fire the (now root-anchored) jobs store (#32091)."""
    import importlib

    root_dir = tmp_path / "hermes_home"
    profile_dir = root_dir / "profiles" / "p"
    profile_dir.mkdir(parents=True)

    import hermes_constants

    monkeypatch.setattr(
        hermes_constants, "_get_platform_default_hermes_home", lambda: root_dir
    )
    monkeypatch.setenv("HERMES_HOME", str(profile_dir))

    import cron.scheduler as sched

    importlib.reload(sched)
    try:
        # _hermes_home override is None -> uses get_default_hermes_root()
        sched._hermes_home = None
        lock_dir, lock_file = sched._get_lock_paths()
        root_cron = root_dir / "cron"
        assert lock_dir.resolve() == root_cron.resolve()
        assert lock_file.resolve() == (root_cron / ".tick.lock").resolve()
        assert lock_dir.resolve() != (profile_dir / "cron").resolve()
    finally:
        monkeypatch.undo()
        importlib.reload(sched)
def test_get_default_hermes_root_docker_layouts(tmp_path, monkeypatch):
    """get_default_hermes_root resolves the root for Docker/custom HERMES_HOME
    values (outside ~/.hermes), so cron storage works in containers."""
    import hermes_constants

    native_home = tmp_path / "native_home"
    monkeypatch.setattr(
        hermes_constants, "_get_platform_default_hermes_home", lambda: native_home
    )
    custom_root = Path("/opt/data")

    # Docker custom root (outside native): HERMES_HOME itself IS the root.
    monkeypatch.setenv("HERMES_HOME", "/opt/data")
    assert hermes_constants.get_default_hermes_root() == custom_root

    # Docker profile layout: <custom>/profiles/<name> -> <custom>.
    monkeypatch.setenv("HERMES_HOME", "/opt/data/profiles/coder")
    assert hermes_constants.get_default_hermes_root() == custom_root

View file

@ -7,11 +7,75 @@ from unittest.mock import AsyncMock, patch, MagicMock
import pytest
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, _send_media_via_adapter, run_job, SILENT_MARKER, _build_job_prompt
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, _send_media_via_adapter, run_job, SILENT_MARKER, _build_job_prompt, _resolve_cron_enabled_toolsets, _merge_mcp_into_per_job_toolsets
from tools.env_passthrough import clear_env_passthrough
from tools.credential_files import clear_credential_files
class TestPerJobToolsetMcpMerge:
    """A per-job enabled_toolsets allowlist must not silently drop MCP servers."""

    # Config fixture: two plainly enabled servers, one disabled, one enabled
    # via the string "true", and one entry that is not a dict at all.
    CFG = {
        "mcp_servers": {
            "finnhub": {"enabled": True},
            "playwright": {"enabled": True},
            "disabled_one": {"enabled": False},
            "string_enabled": {"enabled": "true"},
            "not_a_dict": "ignored",
        }
    }

    def _enabled_names(self):
        # Server names from CFG that the tests expect to count as enabled.
        return {"finnhub", "playwright", "string_enabled"}

    def test_native_only_list_gets_all_enabled_mcp_servers(self):
        merged = _merge_mcp_into_per_job_toolsets(["web", "terminal"], self.CFG)
        # The caller's entries stay at the front, in order.
        assert merged[:2] == ["web", "terminal"]
        assert set(merged) == {"web", "terminal"} | self._enabled_names()

    def test_disabled_servers_are_not_added(self):
        merged = _merge_mcp_into_per_job_toolsets(["web"], self.CFG)
        assert "disabled_one" not in merged

    def test_explicit_mcp_name_is_treated_as_allowlist(self):
        # User named one server -> add nothing further.
        merged = _merge_mcp_into_per_job_toolsets(["web", "finnhub"], self.CFG)
        assert merged == ["web", "finnhub"]
        assert "playwright" not in merged

    def test_no_mcp_sentinel_opts_out_and_is_stripped(self):
        merged = _merge_mcp_into_per_job_toolsets(["web", "no_mcp"], self.CFG)
        assert merged == ["web"]
        assert not (set(merged) & self._enabled_names())

    def test_no_mcp_config_adds_nothing(self):
        merged = _merge_mcp_into_per_job_toolsets(["web"], {})
        assert merged == ["web"]

    def test_no_duplicate_when_listed_name_also_globally_enabled(self):
        merged = _merge_mcp_into_per_job_toolsets(["finnhub", "finnhub"], self.CFG)
        assert merged.count("finnhub") == 2  # input dups preserved, none added

    def test_resolver_uses_merge_for_per_job_lists(self):
        job = {"enabled_toolsets": ["web", "terminal"]}
        resolved = _resolve_cron_enabled_toolsets(job, self.CFG)
        assert set(resolved) == {"web", "terminal"} | self._enabled_names()

    def test_resolver_empty_per_job_falls_through_to_platform(self):
        # No per-job list -> must delegate to _get_platform_tools (the platform
        # fallback), NOT the per-job merge. Stub the platform resolver and assert
        # it is the path taken and its result is returned.
        job = {"enabled_toolsets": None}
        platform_tools = ["web", "finnhub"]
        with patch("hermes_cli.tools_config._get_platform_tools",
                   return_value=set(platform_tools)) as platform_spy:
            resolved = _resolve_cron_enabled_toolsets(job, self.CFG)
        platform_spy.assert_called_once()
        # _get_platform_tools args: (cfg, "cron")
        positional_args = platform_spy.call_args[0]
        assert positional_args[1] == "cron"
        assert set(resolved) == set(platform_tools)
class TestResolveOrigin:
def test_full_origin(self):
job = {
@ -1330,6 +1394,52 @@ class TestRunJobSessionPersistence:
assert error is None
assert final_response == "all good"
def test_run_job_delivers_max_iteration_fallback_summary(self, tmp_path):
    """Cron should deliver a usable max-iteration fallback summary.

    A cron run can exhaust the iteration budget, receive a final text summary
    from the no-tools fallback call, and still carry ``completed=False`` in
    the generic agent result. Cron must not raise that report text as a
    RuntimeError.
    """
    job = {"id": "summary-job", "name": "summary", "prompt": "finish the report"}
    runtime = {
        "api_key": "***",
        "base_url": "https://example.invalid/v1",
        "provider": "openrouter",
        "api_mode": "chat_completions",
    }
    # Budget exhausted, not failed, but with a usable final text.
    agent_result = {
        "final_response": "final fallback report",
        "completed": False,
        "failed": False,
        "turn_exit_reason": "max_iterations_reached(60/60)",
    }
    fake_db = MagicMock()

    with patch("cron.scheduler._hermes_home", tmp_path), \
            patch("cron.scheduler._resolve_origin", return_value=None), \
            patch("dotenv.load_dotenv"), \
            patch("hermes_state.SessionDB", return_value=fake_db), \
            patch(
                "hermes_cli.runtime_provider.resolve_runtime_provider",
                return_value=runtime,
            ), \
            patch("run_agent.AIAgent") as mock_agent_cls:
        mock_agent = MagicMock()
        mock_agent.run_conversation.return_value = agent_result
        mock_agent_cls.return_value = mock_agent
        success, output, final_response, error = run_job(job)

    assert success is True
    assert error is None
    assert final_response == "final fallback report"
    assert "final fallback report" in output
    assert "(FAILED)" not in output
def test_tick_marks_empty_response_as_error(self, tmp_path):
"""When run_job returns success=True but final_response is empty,
tick() should mark the job as error so last_status != 'ok'.

View file

@ -0,0 +1,243 @@
"""Phase 5 §5.3 — going-idle / buffered-flip primitive (gateway side).
Exercises the WebSocketRelayTransport's going_idle/ack handshake, the
buffered-inbound ack (a bufferId-carrying inbound is acked after the handler
runs), the NET-NEW reconnect loop (re-dial + re-handshake after an unexpected
close), and the RelayAdapter emitting going_idle from its existing drain
(disconnect) transition. All against a real in-process websockets server.
"""
from __future__ import annotations
import asyncio
import json
import pytest
import pytest_asyncio
from gateway.relay.ws_transport import WebSocketRelayTransport, WEBSOCKETS_AVAILABLE
pytestmark = pytest.mark.skipif(not WEBSOCKETS_AVAILABLE, reason="websockets not installed")
if WEBSOCKETS_AVAILABLE:
import websockets
# Descriptor payload the stub server sends back inside its "descriptor" frame
# whenever it receives a "hello" frame.
DESCRIPTOR = {
    "contract_version": 1,
    "platform": "discord",
    "label": "Discord",
    "max_message_length": 2000,
    "supports_draft_streaming": False,
    "supports_edit": True,
    "supports_threads": True,
    "markdown_dialect": "discord",
    "len_unit": "chars",
}
class _IdleAwareServer:
    """In-process connector stub used by the going-idle tests.

    Replies to ``hello`` with the descriptor (followed by any frames queued in
    ``_to_push``), acks ``going_idle``, and records every ``inbound_ack``'s
    ``bufferId``. Every decoded frame is also appended to ``received``.
    """

    def __init__(self):
        self._server = None
        self.url = ""
        self.connections = 0
        self.going_idle_count = 0
        self.received: list[dict] = []
        self.inbound_acks: list[str] = []
        # Frames to push right after each handshake (e.g. buffered backlog replay).
        self._to_push: list[dict] = []

    async def start(self):
        """Bind to an ephemeral localhost port and publish the ws:// URL."""
        self._server = await websockets.serve(self._handle, "127.0.0.1", 0)
        port = next(iter(self._server.sockets)).getsockname()[1]
        self.url = f"ws://127.0.0.1:{port}"

    async def stop(self):
        """Close the listening server, if one was started, and wait for it."""
        if self._server is None:
            return
        self._server.close()
        await self._server.wait_closed()

    async def _handle(self, ws):
        """Per-connection loop: decode newline-delimited JSON frames and dispatch."""
        self.connections += 1
        try:
            async for raw in ws:
                non_blank = (chunk for chunk in str(raw).split("\n") if chunk.strip())
                for chunk in non_blank:
                    frame = json.loads(chunk)
                    self.received.append(frame)
                    await self._on_frame(ws, frame)
        except Exception:
            # Best-effort stub: any error on the connection just ends this handler.
            pass

    async def _on_frame(self, ws, frame):
        """React to one decoded frame according to its ``type`` field."""
        kind = frame.get("type")
        if kind == "going_idle":
            self.going_idle_count += 1
            await ws.send(json.dumps({"type": "going_idle_ack"}) + "\n")
            return
        if kind == "inbound_ack":
            self.inbound_acks.append(frame.get("bufferId"))
            return
        if kind != "hello":
            return
        # Handshake: descriptor first, then replay whatever the test queued.
        await ws.send(json.dumps({"type": "descriptor", "descriptor": DESCRIPTOR}) + "\n")
        for queued in self._to_push:
            await ws.send(json.dumps(queued) + "\n")
@pytest_asyncio.fixture
async def server():
    """Yield a started ``_IdleAwareServer``; stop it after the test finishes."""
    stub = _IdleAwareServer()
    await stub.start()
    yield stub
    await stub.stop()
@pytest.mark.asyncio
async def test_go_idle_awaits_ack(server):
    """go_idle() returns True once the stub connector acks the going_idle frame."""
    transport = WebSocketRelayTransport(server.url, "discord", "appShared")
    await transport.connect()
    try:
        await transport.handshake()
        result = await transport.go_idle(timeout_s=2)
        assert result is True
        # The stub both counted the frame and recorded it verbatim.
        assert server.going_idle_count == 1
        assert any(frame["type"] == "going_idle" for frame in server.received)
    finally:
        await transport.disconnect()
@pytest.mark.asyncio
async def test_go_idle_returns_false_on_timeout(server):
    """go_idle() returns False when no going_idle_ack arrives within the timeout."""

    async def _descriptor_only(ws, frame):
        # Complete the handshake but deliberately stay silent on going_idle.
        if frame.get("type") != "hello":
            return
        await ws.send(json.dumps({"type": "descriptor", "descriptor": DESCRIPTOR}) + "\n")

    # Swap the stub's frame handler for one that never acks (caller closes anyway).
    server._on_frame = _descriptor_only  # type: ignore[assignment]

    transport = WebSocketRelayTransport(server.url, "discord", "appShared")
    await transport.connect()
    try:
        await transport.handshake()
        result = await transport.go_idle(timeout_s=0.3)
        assert result is False
    finally:
        await transport.disconnect()
@pytest.mark.asyncio
async def test_buffered_inbound_is_acked_after_handler(server):
    """Only the inbound frame carrying a bufferId produces an inbound_ack.

    The stub replays one buffered delivery (with ``bufferId``) and one live
    delivery (without) right after the handshake; both must reach the handler,
    but the stub must record exactly one ack.
    """

    def _inbound(text, **extra):
        # Build an inbound frame; ``extra`` keys (e.g. bufferId) go after "event".
        frame = {
            "type": "inbound",
            "event": {
                "text": text,
                "message_type": "text",
                "source": {"platform": "discord", "chat_id": "c1", "chat_type": "dm"},
            },
        }
        frame.update(extra)
        return frame

    server._to_push = [_inbound("buffered", bufferId="buf-42"), _inbound("live")]

    delivered = []

    async def _record(event):
        delivered.append(event.text)

    transport = WebSocketRelayTransport(server.url, "discord", "appShared")
    transport.set_inbound_handler(_record)
    await transport.connect()
    try:
        await transport.handshake()
        # Give the pushed frames (and any resulting ack) time to round-trip.
        await asyncio.sleep(0.1)
        assert "buffered" in delivered
        assert "live" in delivered
        # Only the buffered (bufferId) delivery was acked.
        assert server.inbound_acks == ["buf-42"]
    finally:
        await transport.disconnect()
@pytest.mark.asyncio
async def test_reconnect_redials_after_unexpected_close():
    """A transport built with reconnect=True re-dials after a server-side drop.

    The server closes the FIRST connection right after answering the handshake;
    the test then waits for a second connection to show up. Instead of a fixed
    sleep it polls with a deadline, so it finishes as soon as the re-dial lands
    and tolerates a slow machine without flaking.
    """
    drops = {"n": 0}
    srv = _IdleAwareServer()

    async def handle(ws):
        srv.connections += 1
        async for raw in ws:
            for line in str(raw).split("\n"):
                if not line.strip():
                    continue
                frame = json.loads(line)
                if frame.get("type") == "hello":
                    await ws.send(json.dumps({"type": "descriptor", "descriptor": DESCRIPTOR}) + "\n")
                    if drops["n"] == 0:
                        drops["n"] += 1
                        await ws.close()  # force an unexpected close on the first connection
                        return

    srv._server = await websockets.serve(handle, "127.0.0.1", 0)
    sock = next(iter(srv._server.sockets))
    srv.url = f"ws://127.0.0.1:{sock.getsockname()[1]}"

    t = WebSocketRelayTransport(srv.url, "discord", "appShared", reconnect=True, reconnect_backoff_s=0.05)
    try:
        await t.connect()
        await t.handshake()
        # First connection is dropped server-side; the reconnect loop re-dials.
        # Poll (bounded) rather than sleeping a fixed interval.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + 2.0
        while srv.connections < 2 and loop.time() < deadline:
            await asyncio.sleep(0.02)
        assert srv.connections >= 2
    finally:
        await t.disconnect()
        srv._server.close()
        await srv._server.wait_closed()
@pytest.mark.asyncio
async def test_no_reconnect_after_deliberate_disconnect(server):
    """Calling disconnect() on a reconnect-enabled transport opens no new connection."""
    transport = WebSocketRelayTransport(
        server.url,
        "discord",
        "appShared",
        reconnect=True,
        reconnect_backoff_s=0.05,
    )
    await transport.connect()
    await transport.handshake()
    connections_before = server.connections

    await transport.disconnect()
    # Wait several backoff periods so a (wrongly) scheduled re-dial would show up.
    await asyncio.sleep(0.3)

    # A deliberate disconnect must NOT trigger the reconnect loop.
    assert server.connections == connections_before
@pytest.mark.asyncio
async def test_adapter_emits_going_idle_on_disconnect(server):
    """RelayAdapter.disconnect() results in exactly one going_idle frame at the stub."""
    from gateway.config import PlatformConfig
    from gateway.relay.adapter import RelayAdapter
    from gateway.relay.descriptor import CONTRACT_VERSION, CapabilityDescriptor

    # Placeholder descriptor handed to the adapter at construction time.
    descriptor_fields = {
        "contract_version": CONTRACT_VERSION,
        "platform": "discord",
        "label": "Relay",
        "max_message_length": 4096,
        "supports_draft_streaming": False,
        "supports_edit": True,
        "supports_threads": False,
        "markdown_dialect": "plain",
        "len_unit": "chars",
    }
    placeholder = CapabilityDescriptor(**descriptor_fields)

    ws_transport = WebSocketRelayTransport(server.url, "discord", "appShared")
    relay_adapter = RelayAdapter(PlatformConfig(), placeholder, transport=ws_transport)

    await relay_adapter.connect()
    await relay_adapter.disconnect()

    assert server.going_idle_count == 1

View file

@ -0,0 +1,192 @@
"""Unit tests for the gateway-side relay relevance-policy declaration (Phase 6 ζ).
Covers gateway.relay.relay_relevance_policy() (the projection of the agent's
mention-gating / free-response / allow-bots config into the connector's generic
vocabulary) and send_relay_policy() (the boot-time POST to /relay/policy). The
connector HTTP POST is monkeypatched; the cross-repo E2E (connector repo,
gateway_policy_driver.py) exercises the real route. These prove the PROJECTION
mapping, the auth/skip logic, and the fail-soft boot behaviour.
"""
from __future__ import annotations
import pytest
import gateway.relay as relay
@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
    """Remove relay-related env vars and stub the gateway config loader to ``{}``."""
    relay_env_vars = (
        "GATEWAY_RELAY_URL",
        "GATEWAY_RELAY_ID",
        "GATEWAY_RELAY_SECRET",
        "GATEWAY_RELAY_PLATFORM",
        "GATEWAY_RELAY_BOT_ID",
        "DISCORD_ALLOW_BOTS",
    )
    for name in relay_env_vars:
        monkeypatch.delenv(name, raising=False)

    def _empty_config():
        return {}

    monkeypatch.setattr("gateway.run._load_gateway_config", _empty_config, raising=False)
# --------------------------------------------------------------------------
# relay_relevance_policy() — the projection
# --------------------------------------------------------------------------
def test_projection_maps_require_mention_and_free_response(monkeypatch):
    """require_mention / free_response_channels project onto the connector vocabulary."""

    def _config():
        return {"discord": {"require_mention": True, "free_response_channels": ["c-support", "c-help"]}}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)

    expected = {
        "platform": "discord",
        "requireAddress": True,
        "freeResponseScopes": ["c-support", "c-help"],
        "allowOtherBots": False,
    }
    assert relay.relay_relevance_policy() == expected
def test_projection_allow_other_bots_from_env(monkeypatch):
    """DISCORD_ALLOW_BOTS=all in the environment yields allowOtherBots=True."""

    def _config():
        return {"discord": {"require_mention": True}}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setenv("DISCORD_ALLOW_BOTS", "all")
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is not None
    assert policy["allowOtherBots"] is True
def test_projection_comma_string_free_response(monkeypatch):
    """A comma-separated free_response_channels string is split and stripped."""

    def _config():
        return {"discord": {"free_response_channels": "c1, c2 ,c3"}}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is not None
    assert policy["freeResponseScopes"] == ["c1", "c2", "c3"]
def test_projection_falls_back_to_top_level_require_mention(monkeypatch):
    """A top-level require_mention is honoured when there is no ``discord:`` block."""

    def _config():
        # top-level, no discord: block
        return {"require_mention": True}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is not None
    assert policy["requireAddress"] is True
def test_projection_none_when_all_default(monkeypatch):
    """An empty discord block projects to None.

    No require_mention, no free-response, no allow-bots means there is nothing
    to declare (the connector's quiet default already matches).
    """

    def _config():
        return {"discord": {}}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is None
def test_projection_none_when_platform_unresolved(monkeypatch):
    """Without GATEWAY_RELAY_PLATFORM set, the projection is None.

    Default platform "relay" means no concrete fronted platform, so there is
    nothing to project even though a discord block is configured.
    """

    def _config():
        return {"discord": {"require_mention": True}}

    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is None
# --------------------------------------------------------------------------
# send_relay_policy() — the boot-time declaration
# --------------------------------------------------------------------------
def _arm(monkeypatch, *, url="wss://connector.example/relay"):
    """Set the relay URL, gateway id, secret and platform env vars for a test."""
    relay_env = {
        "GATEWAY_RELAY_URL": url,
        "GATEWAY_RELAY_ID": "gw-x",
        "GATEWAY_RELAY_SECRET": "s" * 48,
        "GATEWAY_RELAY_PLATFORM": "discord",
    }
    for name, value in relay_env.items():
        monkeypatch.setenv(name, value)
def test_send_posts_projected_policy_with_token(monkeypatch):
    """send_relay_policy() POSTs the projected policy to the https /relay/policy URL."""

    def _config():
        return {"discord": {"require_mention": True, "free_response_channels": ["c-support"]}}

    captured = {}

    def _fake_post(*, policy_url, token, policy, timeout=15.0):
        captured.update(policy_url=policy_url, token=token, policy=policy)
        return 200

    _arm(monkeypatch)
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)
    monkeypatch.setattr(relay, "_post_policy", _fake_post)

    assert relay.send_relay_policy() is True
    assert captured["policy_url"] == "https://connector.example/relay/policy"
    assert captured["token"]  # a real upgrade token was minted
    sent_policy = captured["policy"]
    assert sent_policy["requireAddress"] is True
    assert sent_policy["freeResponseScopes"] == ["c-support"]
def test_send_skips_when_no_secret(monkeypatch):
    """Without GATEWAY_RELAY_ID / SECRET the policy POST is never attempted."""

    def _config():
        return {"discord": {"require_mention": True}}

    called = {"n": 0}

    def _counting_post(**_kwargs):
        called["n"] += 1
        return 200

    monkeypatch.setenv("GATEWAY_RELAY_URL", "wss://connector.example/relay")
    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    # no GATEWAY_RELAY_ID / SECRET
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)
    monkeypatch.setattr(relay, "_post_policy", _counting_post)

    assert relay.send_relay_policy() is False
    assert called["n"] == 0  # never attempted without a secret to auth with
def test_send_skips_when_nothing_to_declare(monkeypatch):
    """An all-default policy is not POSTed even when the relay is fully configured."""

    def _config():
        return {"discord": {}}

    called = {"n": 0}

    def _counting_post(**_kwargs):
        called["n"] += 1
        return 200

    _arm(monkeypatch)
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)
    monkeypatch.setattr(relay, "_post_policy", _counting_post)

    assert relay.send_relay_policy() is False
    assert called["n"] == 0  # no redundant write of the default
def test_send_fail_soft_on_transport_error(monkeypatch):
    """An exception raised by the POST helper is swallowed; the call returns False."""

    def _config():
        return {"discord": {"require_mention": True}}

    def _unreachable(**kwargs):
        raise RuntimeError("connector unreachable")

    _arm(monkeypatch)
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)
    monkeypatch.setattr(relay, "_post_policy", _unreachable)

    # Never raises; returns False so boot proceeds.
    outcome = relay.send_relay_policy()
    assert outcome is False
def test_send_fail_soft_on_non_200(monkeypatch):
    """A non-200 status from the POST helper (here 401) makes the call return False."""

    def _config():
        return {"discord": {"require_mention": True}}

    def _unauthorized(**_kwargs):
        return 401

    _arm(monkeypatch)
    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)
    monkeypatch.setattr(relay, "_post_policy", _unauthorized)

    outcome = relay.send_relay_policy()
    assert outcome is False
def test_send_skips_when_relay_unconfigured(monkeypatch):
    """With no GATEWAY_RELAY_URL the relay is unconfigured and the call is a no-op."""

    def _always_ok(**_kwargs):
        return 200

    monkeypatch.setattr(relay, "_post_policy", _always_ok)

    outcome = relay.send_relay_policy()
    assert outcome is False

View file

@ -30,6 +30,7 @@ def _clean_env(monkeypatch):
"GATEWAY_RELAY_ROUTE_KEYS",
"GATEWAY_RELAY_PLATFORM",
"GATEWAY_RELAY_BOT_ID",
"GATEWAY_RELAY_INSTANCE_ID",
):
monkeypatch.delenv(k, raising=False)
# Never read config.yaml off disk in these tests.
@ -83,6 +84,24 @@ def test_relay_route_keys_empty():
assert relay.relay_route_keys() == []
def test_relay_instance_id_from_env(monkeypatch):
    """GATEWAY_RELAY_INSTANCE_ID is read from the environment with whitespace stripped."""
    monkeypatch.setenv("GATEWAY_RELAY_INSTANCE_ID", "  inst-abc  ")
    resolved = relay.relay_instance_id()
    assert resolved == "inst-abc"
def test_relay_instance_id_absent_is_none():
    """With no env var and no config value, relay_instance_id() returns None."""
    resolved = relay.relay_instance_id()
    assert resolved is None
def test_relay_instance_id_from_config(monkeypatch):
    """gateway.relay_instance_id from the loaded config is used when the env var is unset."""

    def _config():
        return {"gateway": {"relay_instance_id": "inst-from-config"}}

    monkeypatch.setattr("gateway.run._load_gateway_config", _config, raising=False)

    resolved = relay.relay_instance_id()
    assert resolved == "inst-from-config"
def test_provision_url_maps_ws_to_http():
    """_provision_url maps wss->https and ws->http and appends /provision."""
    secure = relay._provision_url("wss://c.example/relay")
    plain = relay._provision_url("ws://c.example/relay")
    assert secure == "https://c.example/relay/provision"
    assert plain == "http://c.example/relay/provision"
@ -161,6 +180,81 @@ def test_outbound_only_when_no_endpoint(monkeypatch):
assert relay.relay_connection_auth()[1] == "a" * 64
# ─────────────────── instance-id forwarding (Phase 6 Unit α) ───────────────────
def test_forwards_instance_id_to_provision(monkeypatch):
    """GATEWAY_RELAY_INSTANCE_ID is passed through to the provision POST.

    A managed agent stamped with an instance id forwards it to the connector so
    it can bind gatewayId -> instanceId (per-instance routing).
    """
    captured: dict = {}

    _arm(monkeypatch)
    monkeypatch.setenv("GATEWAY_RELAY_INSTANCE_ID", "inst-abc")
    monkeypatch.setattr(relay, "_post_provision", _stub_post(captured))

    provisioned = relay.self_provision_relay()
    assert provisioned is True
    assert captured["instance_id"] == "inst-abc"
def test_instance_id_absent_forwards_none(monkeypatch):
    """With no instance-id stamp, the provision POST receives instance_id=None.

    No stamp (self-hosted / pre-Phase-6) -> the connector stores null and
    per-instance routing simply has no binding yet.
    """
    captured: dict = {}

    _arm(monkeypatch)
    monkeypatch.setattr(relay, "_post_provision", _stub_post(captured))

    provisioned = relay.self_provision_relay()
    assert provisioned is True
    assert captured["instance_id"] is None
def test_post_provision_body_includes_instanceId_only_when_set(monkeypatch):
    """The real _post_provision adds `instanceId` to the JSON body ONLY when a
    value is supplied — omitting it lets the connector store null (back-compat),
    rather than binding an empty string."""
    import json

    sent: dict = {}
    canned_reply = {"secret": "a" * 64, "deliveryKey": "b" * 64, "tenant": "t", "gatewayId": "gw-1"}

    class _Resp:
        """Minimal context-manager response returning the canned provision reply."""

        def __enter__(self):
            return self

        def __exit__(self, *a):
            return False

        def read(self):
            return json.dumps(canned_reply).encode()

    def _fake_urlopen(req, timeout=None):  # noqa: ANN001
        # Capture the decoded request body for the assertions below.
        sent["body"] = json.loads(req.data.decode())
        return _Resp()

    monkeypatch.setattr("urllib.request.urlopen", _fake_urlopen)

    def _provision(**extra):
        # Fresh argument values on every call, exactly as two literal calls would pass.
        return relay._post_provision(
            provision_url="https://c.example/relay/provision",
            access_token="tok",
            gateway_id="gw-1",
            platform="discord",
            bot_id="app",
            gateway_endpoint=None,
            route_keys=[],
            **extra,
        )

    # With an instance id -> present in the body.
    _provision(instance_id="inst-abc")
    assert sent["body"]["instanceId"] == "inst-abc"

    # Without one -> the key is absent entirely (not "" ).
    _provision()
    assert "instanceId" not in sent["body"]
# ─────────────────────────── fail-soft ───────────────────────────
def test_no_nas_token_is_non_fatal(monkeypatch):

View file

@ -0,0 +1,128 @@
"""Regression test for approval prompt credential redaction (issue #48456).
When Tirith flags a command for containing a credential-shaped pattern, the
gateway approval prompt must redact the credential from the command text
before sending it to the chat platform. Without this fix, the raw command
(with the credential in plaintext) is sent verbatim to Telegram/Discord/etc.,
undoing Tirith's redaction one layer up.
The redaction is wired through the module-level ``_redact_approval_command``
seam. These tests bind that seam -- the production wiring -- not just the
underlying ``redact_sensitive_text`` helper, so they fail if the redaction
call is removed from either approval path.
Credential fixtures are built at runtime from a benign prefix + a run of
``X`` characters (the same trick tests/agent/test_redact.py uses): they match
the redactor regexes so the assertions stay meaningful, but contain no real
or real-looking key, so secret scanners do not flag this file.
"""
from gateway.run import _redact_approval_command
# Synthetic, scanner-safe credential fixtures. Each matches its redactor
# regex (ghp_/sk-/JWT) but is unmistakably fake -- a run of X's, never a
# real or real-format key.
# "ghp_" prefix followed by 36 X's.
_FAKE_GHP = "ghp_" + "X" * 36
# "sk-proj-" prefix followed by 40 X's.
_FAKE_OPENAI = "sk-proj-" + "X" * 40
# Three dot-separated segments; the first two start with "eyJ".
_FAKE_JWT = "eyJ" + "X" * 20 + "." + "eyJ" + "X" * 24 + "." + "X" * 30
class TestRedactApprovalCommand:
    """Behavioural contract of ``_redact_approval_command``, the gateway's
    approval-prompt redaction seam."""

    def test_redacts_github_pat(self):
        """A GitHub-PAT-shaped token is removed while the command stays readable."""
        command = f"curl -H 'Authorization: token {_FAKE_GHP}' https://api.github.com/user"
        redacted = _redact_approval_command(command)
        assert _FAKE_GHP not in redacted
        # command structure preserved so the operator can still judge the action
        for fragment in ("curl", "github.com"):
            assert fragment in redacted

    def test_redacts_openai_key(self):
        """An sk-proj-shaped key is removed; the rest of the command survives."""
        command = f"export OPENAI_API_KEY={_FAKE_OPENAI} && python s.py"
        redacted = _redact_approval_command(command)
        assert _FAKE_OPENAI not in redacted
        assert "python s.py" in redacted

    def test_redacts_bearer_token(self):
        """A JWT-shaped bearer token does not survive redaction."""
        command = f"curl -H 'Authorization: Bearer {_FAKE_JWT}' https://api.example.com"
        redacted = _redact_approval_command(command)
        assert _FAKE_JWT not in redacted

    def test_clean_command_passes_through_unchanged(self):
        """A command with nothing credential-shaped is returned verbatim."""
        command = "ls -la /tmp && echo hello"
        redacted = _redact_approval_command(command)
        assert redacted == command

    def test_forces_redaction_even_when_disabled(self, monkeypatch):
        """force=True must redact even if security.redact_secrets is off -- the
        approval prompt is a hard secret-egress boundary regardless of config."""
        command = f"curl -H 'Authorization: token {_FAKE_GHP}' https://api.github.com"
        # With redaction globally disabled, the seam must STILL redact (force=True).
        monkeypatch.setattr("agent.redact._REDACT_ENABLED", False, raising=False)
        redacted = _redact_approval_command(command)
        assert _FAKE_GHP not in redacted

    def test_handles_none_and_empty(self):
        """Both the empty string and None come back as the empty string."""
        assert _redact_approval_command("") == ""
        assert _redact_approval_command(None) == ""
class TestApprovalCommandWiring:
    """Guard the production wiring on BOTH approval-notify transports:

    1. the chat-platform path (_approval_notify_sync in gateway/run.py), and
    2. the SSE/API path (_approval_notify in gateway/platforms/api_server.py).

    Each must assign the result of _redact_approval_command(...) and only reach
    its send/enqueue sink on a later line. The check walks the module AST
    rather than slicing source text, so a benign refactor does not cause a
    false failure and a discarded-result call (`_redact(cmd); send(cmd)`) does
    NOT pass.
    """

    def _assert_redacts_then_uses(self, module, func_name: str, sink_substr: str):
        """Assert ``func_name`` in ``module`` redacts before reaching its sink.

        Finds the first function (sync or async, possibly nested) named
        ``func_name`` in the module's AST, requires an assignment of the form
        ``<x> = _redact_approval_command(...)`` inside it, and then requires a
        node whose source text contains ``sink_substr`` and which starts on a
        line after that assignment. A bare-expression call does not count as
        the redaction (it must be an assignment).
        """
        import ast
        import inspect

        source = inspect.getsource(module)
        function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
        target_fn = next(
            (
                candidate
                for candidate in ast.walk(ast.parse(source))
                if isinstance(candidate, function_types) and candidate.name == func_name
            ),
            None,
        )
        assert target_fn is not None, f"function {func_name} not found in {module.__name__}"

        redaction_lines = [
            stmt.lineno
            for stmt in ast.walk(target_fn)
            if isinstance(stmt, ast.Assign)
            and isinstance(stmt.value, ast.Call)
            and isinstance(stmt.value.func, ast.Name)
            and stmt.value.func.id == "_redact_approval_command"
        ]
        assert redaction_lines, (
            f"{func_name} must assign the result of _redact_approval_command(...) "
            "(a discarded-result call would still leak the raw command)"
        )
        # When several assignments match, the last one in walk order is the reference.
        redact_line = redaction_lines[-1]

        def _is_later_sink(candidate):
            if getattr(candidate, "lineno", 0) <= redact_line:
                return False
            segment = ast.get_source_segment(source, candidate)
            return bool(segment) and sink_substr in segment

        sink_line = next(
            (candidate.lineno for candidate in ast.walk(target_fn) if _is_later_sink(candidate)),
            None,
        )
        assert sink_line is not None, (
            f"`{sink_substr}` sink not found after the redaction in {func_name}"
        )

    def test_chat_platform_path_redacts_before_send(self):
        """_approval_notify_sync redacts before it reaches send_exec_approval."""
        import gateway.run as run

        self._assert_redacts_then_uses(run, "_approval_notify_sync", "send_exec_approval")

    def test_sse_api_path_redacts_before_enqueue(self):
        """_approval_notify redacts before it reaches put_nowait."""
        from gateway.platforms import api_server

        self._assert_redacts_then_uses(api_server, "_approval_notify", "put_nowait")

View file

@ -281,3 +281,143 @@ async def test_platform_send_failure_raises_for_delivery_result(tmp_path, monkey
with pytest.raises(RuntimeError, match="route failed"):
await router._deliver_to_platform(target, "hello", metadata={"telegram_reply_to_message_id": "9001"})
# ---------------------------------------------------------------------------
# Cron output truncation / adapter-aware chunking (issue #50126)
# ---------------------------------------------------------------------------
class ChunkingAdapter:
    """Recording fake adapter that advertises ``splits_long_messages = True``
    (like Discord/Telegram)."""

    splits_long_messages = True

    def __init__(self):
        # One dict per send() call, in call order.
        self.calls = []

    async def send(self, chat_id, content, metadata=None):
        """Record the call and report success."""
        record = dict(chat_id=chat_id, content=content, metadata=metadata)
        self.calls.append(record)
        return {"success": True}
class NonChunkingAdapter:
    """Recording fake adapter that does not define ``splits_long_messages``
    (default False — legacy behavior)."""

    def __init__(self):
        # One dict per send() call, in call order.
        self.calls = []

    async def send(self, chat_id, content, metadata=None):
        """Record the call and report success."""
        record = dict(chat_id=chat_id, content=content, metadata=metadata)
        self.calls.append(record)
        return {"success": True}
@pytest.mark.asyncio
async def test_long_output_truncated_for_non_chunking_adapter(tmp_path, monkeypatch):
    """A non-chunking adapter gets truncated content plus a footer; the full
    output is written under cron/output/."""
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)
    recorder = NonChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: recorder})
    full_text = "x" * 5000

    await router._deliver_to_platform(
        DeliveryTarget.parse("discord:123"),
        full_text,
        metadata={"job_id": "job1"},
    )

    sent = recorder.calls[0]["content"]
    assert len(sent) < 5000  # was truncated
    assert "truncated" in sent.lower()
    assert "full output saved to" in sent

    # Full output was saved to disk
    on_disk = list(tmp_path.glob("cron/output/job1_*.txt"))
    assert len(on_disk) == 1
    assert on_disk[0].read_text() == full_text
@pytest.mark.asyncio
async def test_long_output_preserved_for_chunking_adapter(tmp_path, monkeypatch):
    """An adapter with splits_long_messages=True receives the FULL content,
    and the output is still written under cron/output/."""
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)
    recorder = ChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: recorder})
    full_text = "x" * 5000

    await router._deliver_to_platform(
        DeliveryTarget.parse("discord:123"),
        full_text,
        metadata={"job_id": "job2"},
    )

    sent = recorder.calls[0]["content"]
    assert sent == full_text  # NOT truncated — adapter handles chunking
    assert "truncated" not in sent.lower()

    # Full output still saved to disk as audit trail
    on_disk = list(tmp_path.glob("cron/output/job2_*.txt"))
    assert len(on_disk) == 1
    assert on_disk[0].read_text() == full_text
@pytest.mark.asyncio
async def test_short_output_never_truncated(tmp_path, monkeypatch):
    """A 100-character output is delivered unchanged and nothing is written to disk."""
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)
    recorder = NonChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: recorder})
    brief_text = "x" * 100

    await router._deliver_to_platform(
        DeliveryTarget.parse("discord:123"),
        brief_text,
        metadata={"job_id": "job3"},
    )

    assert recorder.calls[0]["content"] == brief_text
    # Nothing saved to disk
    on_disk = list(tmp_path.glob("cron/output/*.txt"))
    assert not on_disk
@pytest.mark.asyncio
async def test_audit_save_failure_does_not_break_chunking_delivery(tmp_path, monkeypatch):
    """If the audit save fails (disk full, permissions), chunking adapters
    still receive the full content — the save is best-effort."""
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)
    recorder = ChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: recorder})
    full_text = "x" * 5000
    save_attempts = []

    def failing_save(content, job_id):
        save_attempts.append(job_id)
        raise OSError("No space left on device")

    monkeypatch.setattr(router, "_save_full_output", failing_save)

    # Should NOT raise — audit failure is caught for chunking adapters
    await router._deliver_to_platform(
        DeliveryTarget.parse("discord:123"),
        full_text,
        metadata={"job_id": "job6"},
    )

    # Adapter still got the full content
    assert recorder.calls[0]["content"] == full_text
    # Save was attempted (best-effort, swallowed)
    assert len(save_attempts) == 1
@pytest.mark.asyncio
async def test_save_failure_during_truncation_raises_for_non_chunking_adapter(tmp_path, monkeypatch):
    """For a non-chunking adapter the truncation footer needs a valid saved
    path, so a failing save is a real delivery problem: the OSError propagates
    instead of being swallowed like the chunking best-effort save."""
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)
    recorder = NonChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: recorder})
    destination = DeliveryTarget.parse("discord:123")
    full_text = "x" * 5000

    def failing_save(content, job_id):
        raise OSError("No space left on device")

    monkeypatch.setattr(router, "_save_full_output", failing_save)

    # Non-chunking adapter must truncate → needs a valid saved path → the
    # Step 1 best-effort catch swallows the first attempt, but the Step 2
    # retry (footer needs the path) re-raises.
    with pytest.raises(OSError, match="No space left on device"):
        await router._deliver_to_platform(destination, full_text, metadata={"job_id": "job7"})

View file

@ -0,0 +1,516 @@
"""Tests for Discord double-dispatch prevention (#51057).
When _auto_create_thread() creates a thread from a user message via
message.create_thread(), Discord fires a second MESSAGE_CREATE event for
the "thread starter message". That starter message carries
``message.id == thread.id`` and may arrive with ``type=default``
(instead of ``type=21 / thread_starter_message``), so the type filter
does NOT catch it — resulting in two agent runs and two responses.
Fix: after _auto_create_thread succeeds, pre-seed the dedup cache with
``str(thread.id)`` so the duplicate starter-message event is dropped.
Two sub-scenarios are tested:
1. Thread-starter as a duplicate MESSAGE_CREATE (the primary bug).
2. When text_batch_delay=0 the dispatch path is direct (no batching).
The same dedup pre-seed must still protect against the duplicate.
"""
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import sys
import pytest
from gateway.config import PlatformConfig
# ---------------------------------------------------------------------------
# Discord mock setup
# The tests/gateway/conftest.py already installs a comprehensive discord
# mock at collection time. We import the adapter AFTER that is done.
# ---------------------------------------------------------------------------
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
# ---------------------------------------------------------------------------
# Fake channel/thread helpers
#
# IMPORTANT: FakeTextChannel must NOT be the same class as discord.DMChannel
# or discord.Thread (those are set up by conftest). We give it a neutral name
# and do NOT monkeypatch discord.DMChannel to it.
# ---------------------------------------------------------------------------
class _TextChannel:
    """Fake Discord text channel (not a DM, not a Thread)."""

    def __init__(self, channel_id: int = 100, name: str = "general",
                 guild_name: str = "Test Server"):
        self.id, self.name = channel_id, name
        self.topic = None
        self.guild = SimpleNamespace(name=guild_name, id=1)

    def history(self, *, limit, before, after=None, oldest_first=None):
        """Return an async iterator that yields no messages."""

        async def _no_messages():
            for _nothing in ():
                yield _nothing

        return _no_messages()
class _Thread:
    """Fake Discord thread (not a DM, not a top-level channel)."""

    def __init__(self, thread_id: int, name: str = "thread",
                 parent=None, guild_name: str = "Test Server"):
        self.id, self.name = thread_id, name
        self.parent = parent
        self.parent_id = getattr(parent, "id", None)
        # Use the parent's guild when it has a truthy one; otherwise build a stand-in.
        parent_guild = getattr(parent, "guild", None)
        if parent_guild:
            self.guild = parent_guild
        else:
            self.guild = SimpleNamespace(name=guild_name, id=1)
        self.topic = None

    def history(self, *, limit, before, after=None, oldest_first=None):
        """Return an async iterator that yields no messages."""

        async def _no_messages():
            for _nothing in ():
                yield _nothing

        return _no_messages()
def _make_message(
    *,
    msg_id: int = 42,
    channel,
    content: str = "hello",
    mentions=None,
    author=None,
    msg_type=None,
    attachments=None,
    reference=None,
    message_snapshots=None,
):
    """Build a SimpleNamespace standing in for a Discord message.

    Defaults: a non-bot author "Alice" (id 7), empty mention/attachment lists,
    the mocked ``MessageType.default`` as the type, and a timezone-aware UTC
    ``created_at`` taken at call time.
    """
    if author is None:
        author = SimpleNamespace(id=7, display_name="Alice", name="Alice", bot=False)
    if msg_type is None:
        msg_type = discord_platform.discord.MessageType.default

    fields = {
        "id": msg_id,
        "content": content,
        "mentions": list(mentions or []),
        "attachments": list(attachments or []),
        "reference": reference,
        "message_snapshots": message_snapshots,
        "created_at": datetime.now(timezone.utc),
        "channel": channel,
        "author": author,
        "type": msg_type,
    }
    return SimpleNamespace(**fields)
# ---------------------------------------------------------------------------
# Adapter fixture
# ---------------------------------------------------------------------------
@pytest.fixture
def adapter(monkeypatch):
    """DiscordAdapter with a stub client, batching disabled, and handle_message mocked."""
    # Clear relevant env vars so tests are hermetic
    discord_env_vars = (
        "DISCORD_REQUIRE_MENTION",
        "DISCORD_AUTO_THREAD",
        "DISCORD_NO_THREAD_CHANNELS",
        "DISCORD_FREE_RESPONSE_CHANNELS",
        "DISCORD_ALLOWED_CHANNELS",
        "DISCORD_IGNORED_CHANNELS",
        "DISCORD_HISTORY_BACKFILL",
        "DISCORD_ALLOW_BOTS",
        "DISCORD_IGNORE_NO_MENTION",
    )
    for name in discord_env_vars:
        monkeypatch.delenv(name, raising=False)

    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    bot_user = SimpleNamespace(id=999, bot=True)
    discord_adapter._client = SimpleNamespace(user=bot_user)
    # disable batching so dispatch is synchronous
    discord_adapter._text_batch_delay_seconds = 0
    discord_adapter.handle_message = AsyncMock()
    return discord_adapter
# ---------------------------------------------------------------------------
# Scenario 1 — thread-starter message duplicate via on_message (the main bug)
# ---------------------------------------------------------------------------
class TestThreadStarterDedup:
"""Pre-seeding dedup with thread.id prevents a second dispatch when the
thread-starter message arrives as a duplicate MESSAGE_CREATE event."""
@pytest.mark.asyncio
async def test_thread_starter_duplicate_dropped(self, adapter, monkeypatch):
    """After auto-thread creation, thread.id is already in the dedup cache.

    Reproduces the Discord quirk: once a thread has been created from a user
    message, a second MESSAGE_CREATE arrives whose message.id equals the
    thread id. on_message checks _dedup.is_duplicate(str(message.id)) before
    dispatching, so a pre-seeded thread id means the duplicate is dropped
    instead of starting a second agent run.
    """
    monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
    monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")

    parent_channel = _TextChannel(channel_id=100)
    thread_id = 55555  # thread.id == starter-message.id on Discord
    created_thread = _Thread(thread_id=thread_id, parent=parent_channel)

    async def _return_created_thread(message):
        return created_thread

    monkeypatch.setattr(adapter, "_auto_create_thread", _return_created_thread)

    # 1) Original user message arrives → triggers thread creation + dispatch
    await adapter._handle_message(
        _make_message(msg_id=42, channel=parent_channel, content="hello bot")
    )

    # One dispatch for the user message
    assert adapter.handle_message.call_count == 1, (
        "Expected handle_message to be called exactly once for the user message"
    )

    # 2) Discord fires a second MESSAGE_CREATE for the thread starter, whose
    #    message.id == thread.id. Mimic on_message's guard by asking the dedup
    #    cache directly: the fix already marked thread.id as seen inside
    #    _handle_message, so this lookup must report a duplicate.
    already_seen = adapter._dedup.is_duplicate(str(thread_id))
    assert already_seen is True, (
        "Thread starter message (id == thread.id) should be in dedup cache "
        "after _auto_create_thread returns, so the duplicate event is dropped"
    )

    # Confirm: handle_message was only called once total
    assert adapter.handle_message.call_count == 1, (
        "handle_message should only be called once — duplicate starter dropped"
    )
@pytest.mark.asyncio
async def test_thread_id_pre_seeded_in_dedup_cache(self, adapter, monkeypatch):
"""After _handle_message with auto-thread, thread.id is in _dedup._seen."""
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
channel = _TextChannel(channel_id=100)
thread_id = 55555
fake_thread = _Thread(thread_id=thread_id, parent=channel)
async def fake_auto_create_thread(message):
return fake_thread
monkeypatch.setattr(adapter, "_auto_create_thread", fake_auto_create_thread)
user_msg = _make_message(msg_id=42, channel=channel, content="hello")
await adapter._handle_message(user_msg)
# Thread id must be in the dedup internal cache
assert str(thread_id) in adapter._dedup._seen, (
f"thread.id={thread_id} should be pre-seeded in _dedup._seen "
"after _auto_create_thread returns a thread"
)
@pytest.mark.asyncio
async def test_no_dedup_seed_when_thread_creation_fails(self, adapter, monkeypatch):
"""When _auto_create_thread returns None, no pre-seeding occurs."""
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
channel = _TextChannel(channel_id=100)
phantom_thread_id = 55555
async def fake_auto_create_thread_fail(message):
return None # thread creation failed
monkeypatch.setattr(
adapter, "_auto_create_thread", fake_auto_create_thread_fail
)
user_msg = _make_message(msg_id=42, channel=channel, content="hello")
await adapter._handle_message(user_msg)
# The message was still dispatched (no thread, but message goes through)
adapter.handle_message.assert_awaited_once()
# The phantom thread id should NOT be in the dedup cache
assert str(phantom_thread_id) not in adapter._dedup._seen, (
"thread.id should NOT be pre-seeded when thread creation fails"
)
@pytest.mark.asyncio
async def test_no_dedup_seed_when_auto_thread_disabled(self, adapter, monkeypatch):
"""When DISCORD_AUTO_THREAD=false, no thread is created and no pre-seeding."""
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
channel = _TextChannel(channel_id=100)
auto_create_called = []
async def fake_auto_create_thread(message):
auto_create_called.append(True)
return _Thread(thread_id=55555, parent=channel)
monkeypatch.setattr(adapter, "_auto_create_thread", fake_auto_create_thread)
user_msg = _make_message(msg_id=42, channel=channel, content="hello")
await adapter._handle_message(user_msg)
# _auto_create_thread should NOT have been called
assert not auto_create_called, "_auto_create_thread should not run when disabled"
# thread.id should NOT be pre-seeded
assert "55555" not in adapter._dedup._seen, (
"thread.id should not be in dedup when auto-threading is disabled"
)
@pytest.mark.asyncio
async def test_dedup_seed_with_text_batch_delay_zero(self, adapter, monkeypatch):
"""With text_batch_delay=0 (direct dispatch path), pre-seeding still works."""
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
# text_batch_delay_seconds is already 0 in the fixture
assert adapter._text_batch_delay_seconds == 0
channel = _TextChannel(channel_id=100)
thread_id = 77777
fake_thread = _Thread(thread_id=thread_id, parent=channel)
async def fake_auto_create_thread(message):
return fake_thread
monkeypatch.setattr(adapter, "_auto_create_thread", fake_auto_create_thread)
user_msg = _make_message(msg_id=42, channel=channel, content="hello")
await adapter._handle_message(user_msg)
# Dispatched once
adapter.handle_message.assert_awaited_once()
# Thread id IS pre-seeded even with direct dispatch path
assert str(thread_id) in adapter._dedup._seen, (
"thread.id must be pre-seeded regardless of text_batch_delay setting"
)
@pytest.mark.asyncio
async def test_thread_id_different_from_message_id_both_tracked(
self, adapter, monkeypatch
):
"""Verify thread.id is tracked independently when it differs from message.id."""
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
channel = _TextChannel(channel_id=100)
user_msg_id = 12345
thread_id = 99999 # always different in practice
fake_thread = _Thread(thread_id=thread_id, parent=channel)
async def fake_auto_create_thread(message):
return fake_thread
monkeypatch.setattr(adapter, "_auto_create_thread", fake_auto_create_thread)
user_msg = _make_message(msg_id=user_msg_id, channel=channel, content="hello")
await adapter._handle_message(user_msg)
# The thread.id (99999) is pre-seeded
assert str(thread_id) in adapter._dedup._seen, (
f"thread.id={thread_id} must be pre-seeded after auto-thread creation"
)
# A second MESSAGE_CREATE with message.id=thread.id is caught as duplicate
assert adapter._dedup.is_duplicate(str(thread_id)) is True, (
"Subsequent is_duplicate(thread.id) must return True"
)
# A hypothetical NEW message with a different id is not a duplicate
assert adapter._dedup.is_duplicate("11111") is False, (
"An unrelated new message id must not be blocked"
)
# ---------------------------------------------------------------------------
# Scenario 2 — direct double-call to _handle_message with same message id
# ---------------------------------------------------------------------------
class TestDirectDoubleDispatch:
    """The adapter's ``_dedup`` marks ids as seen so replays can be dropped.

    These tests replay the lookup-then-dispatch sequence by hand: a
    ``_dedup.is_duplicate`` check followed by ``_handle_message`` only when
    the check says the id is new.
    """

    @pytest.mark.asyncio
    async def test_same_message_id_not_dispatched_twice_via_dedup(
        self, adapter, monkeypatch
    ):
        """The second lookup for one id reports a duplicate; one dispatch total."""
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

        message = _make_message(
            msg_id=42, channel=_TextChannel(channel_id=100), content="hello"
        )
        message_key = str(message.id)

        # First delivery: unseen id, so it is dispatched.
        first_lookup = adapter._dedup.is_duplicate(message_key)
        assert first_lookup is False
        await adapter._handle_message(message)
        assert adapter.handle_message.call_count == 1

        # Second delivery (e.g. a replay): the id is now known.
        second_lookup = adapter._dedup.is_duplicate(message_key)
        assert second_lookup is True

        # The guard would bail out here, so _handle_message is not re-invoked.
        assert adapter.handle_message.call_count == 1, (
            "Second delivery with same message.id must be dropped by dedup"
        )

    @pytest.mark.asyncio
    async def test_different_message_ids_both_dispatched(self, adapter, monkeypatch):
        """Messages with distinct ids are each dispatched."""
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

        shared_channel = _TextChannel(channel_id=100)
        first = _make_message(msg_id=1, channel=shared_channel, content="first")
        second = _make_message(msg_id=2, channel=shared_channel, content="second")

        assert adapter._dedup.is_duplicate(str(first.id)) is False
        await adapter._handle_message(first)

        assert adapter._dedup.is_duplicate(str(second.id)) is False
        await adapter._handle_message(second)

        assert adapter.handle_message.call_count == 2
# ---------------------------------------------------------------------------
# Scenario 3 — message_type=thread_starter filtered by type guard
# ---------------------------------------------------------------------------
class TestThreadStarterTypeFilter:
    """Message-type filtering, independent of the dedup path.

    Thread-starter events may carry their own message type
    (thread_starter_message, value 21 in discord.py); such a type must fall
    outside the set of types the on_message guard lets through.
    """

    def test_thread_starter_message_type_not_in_allowed_set(self):
        """``MessageType.thread_starter_message`` is not an accepted type."""
        message_types = sys.modules["discord"].MessageType

        # Mirrors the guard described for on_message:
        #   if message.type not in {MessageType.default, MessageType.reply}
        accepted_types = {message_types.default, message_types.reply}

        # With real discord.py this is the enum member with value 21. When
        # the module is a MagicMock stand-in, each attribute name resolves to
        # its own distinct child mock, so this one differs from both
        # `default` and `reply`.
        starter_type = message_types.thread_starter_message
        assert starter_type not in accepted_types, (
            "thread_starter_message type should not be in the allowed types set"
        )

    @pytest.mark.asyncio
    async def test_message_type_default_passes_type_filter(self, adapter, monkeypatch):
        """A ``MessageType.default`` message is dispatched by ``_handle_message``."""
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

        default_type = discord_platform.discord.MessageType.default
        message = _make_message(
            msg_id=42,
            channel=_TextChannel(channel_id=100),
            content="hello",
            msg_type=default_type,
        )

        await adapter._handle_message(message)

        adapter.handle_message.assert_awaited_once()
# ---------------------------------------------------------------------------
# Scenario 4 — dedup cache integrity after thread pre-seeding
# ---------------------------------------------------------------------------
class TestDedupCacheIntegrity:
    """State of the dedup cache once thread ids have been seeded."""

    @pytest.mark.asyncio
    async def test_preseed_does_not_block_legitimate_new_messages(
        self, adapter, monkeypatch
    ):
        """Seeding a thread id leaves unrelated message ids unaffected."""
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

        parent_channel = _TextChannel(channel_id=100)
        thread_id = 22222
        created_thread = _Thread(thread_id=thread_id, parent=parent_channel)

        async def _return_created_thread(message):
            return created_thread

        monkeypatch.setattr(adapter, "_auto_create_thread", _return_created_thread)

        # The first message creates the thread (and with it the seed entry).
        await adapter._handle_message(
            _make_message(msg_id=10, channel=parent_channel, content="first")
        )
        assert adapter.handle_message.call_count == 1

        # An id with no relation to the thread must still count as new.
        msg2_id = 20
        assert str(msg2_id) != str(thread_id)  # sanity check
        assert adapter._dedup.is_duplicate(str(msg2_id)) is False, (
            "A new message with a different ID should not be blocked"
        )

    @pytest.mark.asyncio
    async def test_multiple_thread_creations_each_preseeded(
        self, adapter, monkeypatch
    ):
        """Every created thread contributes its own id to the cache."""
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")

        parent_channel = _TextChannel(channel_id=100)
        thread_ids = [33333, 44444, 55555]
        created_so_far = 0

        async def _next_thread(message):
            # Hand out the configured ids in order, wrapping around if needed.
            nonlocal created_so_far
            chosen = thread_ids[created_so_far % len(thread_ids)]
            created_so_far += 1
            return _Thread(thread_id=chosen, parent=parent_channel)

        monkeypatch.setattr(adapter, "_auto_create_thread", _next_thread)

        # One inbound message per thread id, each with a distinct message id.
        for offset in range(len(thread_ids)):
            inbound = _make_message(
                msg_id=100 + offset, channel=parent_channel, content=f"msg {offset}"
            )
            await adapter._handle_message(inbound)

        for tid in thread_ids:
            # Present in the cache...
            assert str(tid) in adapter._dedup._seen, (
                f"thread.id={tid} should be pre-seeded in _dedup._seen "
                "after its thread was created"
            )
            # ...and reported as already seen on lookup.
            assert adapter._dedup.is_duplicate(str(tid)) is True, (
                f"thread.id={tid} should be treated as duplicate"
            )

View file

@ -0,0 +1,140 @@
"""Test Discord slash command sync respects the 100-command hard limit."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import sys
import pytest
from gateway.config import PlatformConfig
def _ensure_discord_mock():
    """Install MagicMock stand-ins for ``discord`` when nothing is registered.

    Any non-None ``sys.modules["discord"]`` entry is left alone, whether it
    is the real package (which exposes ``__file__``) or a stand-in put there
    earlier. Only a missing or ``None`` entry triggers installation.
    """
    if sys.modules.get("discord") is not None:
        return

    stub = MagicMock()
    stub.Intents.default.return_value = MagicMock()
    sys.modules.update(
        {
            "discord": stub,
            "discord.ext": MagicMock(),
            "discord.ext.commands": MagicMock(),
        }
    )
_ensure_discord_mock()
from plugins.platforms.discord.adapter import DiscordAdapter
class _FakeTreeCommand:
    """Small stand-in for a command-tree entry: a name, a type, and ``to_dict``."""

    def __init__(self, name: str, command_type: int = 1):
        self.name, self.type = name, command_type

    def to_dict(self, _tree):
        """Return the payload for this command; the tree argument is ignored."""
        return dict(name=self.name, type=self.type)
@pytest.fixture
def adapter():
    """Provide a DiscordAdapter whose client, HTTP layer and payload helpers are mocked."""
    _ensure_discord_mock()
    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))

    # Replace the Discord client, its command tree and its HTTP layer.
    client = MagicMock()
    client.tree = MagicMock()
    client.http = AsyncMock()
    client.application_id = "test_app_id"
    discord_adapter._client = client

    # No real pauses between sync mutations.
    discord_adapter._sleep_between_command_sync_mutations = AsyncMock()

    # Payload helpers: existing commands reduce to their name; the other two
    # hand the payload back untouched.
    discord_adapter._existing_command_to_payload = MagicMock(
        side_effect=lambda cmd: {"name": cmd.name}
    )
    discord_adapter._canonicalize_app_command_payload = MagicMock(
        side_effect=lambda p: p
    )
    discord_adapter._patchable_app_command_payload = MagicMock(
        side_effect=lambda p: p
    )
    return discord_adapter
@pytest.mark.asyncio
async def test_safe_sync_deletes_before_creating(adapter):
    """Sync must delete obsolete commands BEFORE creating new ones.

    Discord enforces its 100-command limit at upsert time. If 100 commands
    are already registered and a new one is upserted before any obsolete one
    has been removed, Discord rejects the request with error 30032.
    Deleting first keeps the registered count at or below 100 throughout.

    Regression guard for the "samuraiheart" report, where sync failed with
    error 30032 even though registration itself was capped at 100.

    The mocked adapter comes from the module's ``adapter`` fixture rather
    than being rebuilt inline, so the setup exists in exactly one place.
    """
    tree = adapter._client.tree
    http = adapter._client.http

    # Registered on Discord: cmd_0 .. cmd_99 (100 commands).
    tree.fetch_commands = AsyncMock(
        return_value=[
            SimpleNamespace(id=f"id_{i}", name=f"cmd_{i}", type=1)
            for i in range(100)
        ]
    )
    # Desired locally: cmd_1 .. cmd_99 plus cmd_new (100 commands), so
    # cmd_0 is obsolete and cmd_new is missing remotely.
    desired_commands = [
        _FakeTreeCommand(name=f"cmd_{i}", command_type=1) for i in range(1, 100)
    ]
    desired_commands.append(_FakeTreeCommand(name="cmd_new", command_type=1))
    tree.get_commands = MagicMock(return_value=desired_commands)

    # Record mutations in the order the adapter issues them.
    mutation_log = []

    async def mock_delete(*args):
        mutation_log.append(("delete", args[-1]))

    async def mock_upsert(*args):
        mutation_log.append(("create", args[-1].get("name")))

    http.delete_global_command = mock_delete
    http.upsert_global_command = mock_upsert
    http.edit_global_command = AsyncMock()

    await adapter._safe_sync_slash_commands()

    delete_positions = [
        idx for idx, (kind, _) in enumerate(mutation_log) if kind == "delete"
    ]
    create_positions = [
        idx for idx, (kind, _) in enumerate(mutation_log) if kind == "create"
    ]

    assert len(delete_positions) >= 1, "At least one command should be deleted"
    assert len(create_positions) >= 1, "At least one command should be created"

    # Key assertion: the last deletion precedes the first creation.
    last_delete_idx = max(delete_positions)
    first_create_idx = min(create_positions)
    assert last_delete_idx < first_create_idx, (
        f"Deletions must happen before creations to avoid exceeding 100-command limit. "
        f"Last delete at index {last_delete_idx}, first create at index {first_create_idx}"
    )

View file

@ -510,3 +510,48 @@ class TestToolProgressGrouping:
resolve_display_setting(config, "telegram", "tool_progress_grouping")
== "separate"
)
class TestReasoningStyle:
    """Resolution of the per-platform ``reasoning_style`` display setting.

    The values exercised here are ``code``, ``blockquote`` and ``subtext``.
    """

    @staticmethod
    def _style_for(config, platform):
        """Resolve ``reasoning_style`` for *platform* under *config*."""
        from gateway.display_config import resolve_display_setting

        return resolve_display_setting(config, platform, "reasoning_style")

    def test_discord_defaults_to_subtext(self):
        """With an empty config, Discord resolves to ``subtext``."""
        assert self._style_for({}, "discord") == "subtext"

    def test_other_platforms_default_to_code(self):
        """With an empty config, the listed non-Discord platforms resolve to ``code``."""
        for plat in ("telegram", "slack", "matrix", "api_server"):
            resolved = self._style_for({}, plat)
            assert resolved == "code", plat

    def test_platform_override_wins(self):
        """A platform-scoped value replaces the platform default."""
        config = {
            "display": {"platforms": {"discord": {"reasoning_style": "blockquote"}}}
        }
        assert self._style_for(config, "discord") == "blockquote"

    def test_global_override(self):
        """A display-level value applies to a platform without its own entry."""
        config = {"display": {"reasoning_style": "subtext"}}
        assert self._style_for(config, "telegram") == "subtext"

    def test_invalid_value_falls_back_to_code(self):
        """An unrecognised value resolves to ``code``."""
        config = {"display": {"reasoning_style": "bogus"}}
        assert self._style_for(config, "telegram") == "code"

    def test_case_insensitive(self):
        """An upper-case value resolves to its lower-case form."""
        config = {"display": {"reasoning_style": "SUBTEXT"}}
        assert self._style_for(config, "telegram") == "subtext"

View file

@ -107,7 +107,7 @@ async def test_goal_verdict_done_sent_via_adapter_send(hermes_home):
mgr = GoalManager(session_entry.session_id)
mgr.set("ship the feature")
with patch("hermes_cli.goals.judge_goal", return_value=("done", "the feature shipped", False)):
with patch("hermes_cli.goals.judge_goal", return_value=("done", "the feature shipped", False, None)):
await runner._post_turn_goal_continuation(
session_entry=session_entry,
source=src,
@ -136,7 +136,7 @@ async def test_goal_verdict_continue_enqueues_continuation(hermes_home):
mgr = GoalManager(session_entry.session_id)
mgr.set("polish the docs")
with patch("hermes_cli.goals.judge_goal", return_value=("continue", "still needs work", False)):
with patch("hermes_cli.goals.judge_goal", return_value=("continue", "still needs work", False, None)):
await runner._post_turn_goal_continuation(
session_entry=session_entry,
source=src,
@ -164,7 +164,7 @@ async def test_goal_verdict_budget_exhausted_sends_pause(hermes_home):
state.turns_used = 2
save_goal(session_entry.session_id, state)
with patch("hermes_cli.goals.judge_goal", return_value=("continue", "keep going", False)):
with patch("hermes_cli.goals.judge_goal", return_value=("continue", "keep going", False, None)):
await runner._post_turn_goal_continuation(
session_entry=session_entry,
source=src,
@ -211,7 +211,7 @@ async def test_goal_verdict_survives_adapter_without_send(hermes_home):
runner.adapters[Platform.TELEGRAM] = _NoSendAdapter()
with patch("hermes_cli.goals.judge_goal", return_value=("done", "ok", False)):
with patch("hermes_cli.goals.judge_goal", return_value=("done", "ok", False, None)):
# must not raise
await runner._post_turn_goal_continuation(
session_entry=session_entry,

View file

@ -967,6 +967,105 @@ class TestMediaDeliveryDefaultMode:
assert BasePlatformAdapter.validate_media_delivery_path(str(config_file)) is None
def test_denylist_blocks_google_token_default_mode(self, tmp_path, monkeypatch):
    """``google_token.json`` at the HERMES_HOME root is never deliverable.

    It is an integration credential even though it is not one of the
    historically enumerated files (.env / auth.json / config.yaml).
    Regression for a refreshed google_token.json being auto-attached to a
    Slack reply (#50912).
    """
    self._patch_roots(monkeypatch)

    home_dir = tmp_path / "home"
    hermes_root = home_dir / ".hermes"
    hermes_root.mkdir(parents=True)
    credential = hermes_root / "google_token.json"
    credential.write_text('{"access_token": "***", "refresh_token": "***"}')

    # Point both HOME and the module-level roots at the temporary tree.
    monkeypatch.setenv("HOME", str(home_dir))
    for attr in ("_HERMES_HOME", "_HERMES_ROOT"):
        monkeypatch.setattr(f"gateway.platforms.base.{attr}", hermes_root)

    verdict = BasePlatformAdapter.validate_media_delivery_path(str(credential))
    assert verdict is None
def test_denylist_blocks_google_token_even_when_freshly_refreshed(self, tmp_path, monkeypatch):
    """The denylist entry outranks recency trust for ``google_token.json``.

    The exploit: the Google integration rewrites google_token.json every
    turn, so its mtime is always about "now". The recency window
    (trust_recent_files) therefore kept re-trusting the file and it was
    re-sent on every reply. An explicit denylist entry must win.
    """
    self._patch_roots(monkeypatch)  # zero cache allowlist, strict mode on
    monkeypatch.setenv("HERMES_MEDIA_TRUST_RECENT_FILES", "1")
    monkeypatch.setenv("HERMES_MEDIA_TRUST_RECENT_SECONDS", "600")

    home_dir = tmp_path / "home"
    hermes_root = home_dir / ".hermes"
    hermes_root.mkdir(parents=True)
    credential = hermes_root / "google_token.json"
    # Written just now, so the file counts as "recent".
    credential.write_text('{"access_token": "***"}')

    monkeypatch.setenv("HOME", str(home_dir))
    for attr in ("_HERMES_HOME", "_HERMES_ROOT"):
        monkeypatch.setattr(f"gateway.platforms.base.{attr}", hermes_root)

    verdict = BasePlatformAdapter.validate_media_delivery_path(str(credential))
    assert verdict is None
def test_denylist_blocks_pairing_directory_contents(self, tmp_path, monkeypatch):
    """Files below ``~/.hermes/pairing/`` are not deliverable.

    They hold platform pairing tokens and are treated as credential material.
    """
    self._patch_roots(monkeypatch)

    home_dir = tmp_path / "home"
    hermes_root = home_dir / ".hermes"
    pairing_dir = hermes_root / "pairing"
    pairing_dir.mkdir(parents=True)
    pairing_file = pairing_dir / "telegram-approved.json"
    pairing_file.write_text('{"approved": ["123"]}')

    monkeypatch.setenv("HOME", str(home_dir))
    for attr in ("_HERMES_HOME", "_HERMES_ROOT"):
        monkeypatch.setattr(f"gateway.platforms.base.{attr}", hermes_root)

    verdict = BasePlatformAdapter.validate_media_delivery_path(str(pairing_file))
    assert verdict is None
def test_hermes_cache_still_delivers_under_denied_home(self, tmp_path, monkeypatch):
    """Cache artifacts stay deliverable despite the credential denylist.

    A generated file under the allowlisted cache root is matched before the
    denylist applies, so validation returns its resolved path.
    """
    home_dir = tmp_path / "home"
    hermes_root = home_dir / ".hermes"
    documents_cache = hermes_root / "cache" / "documents"
    documents_cache.mkdir(parents=True)
    report = documents_cache / "report.pdf"
    report.write_bytes(b"%PDF-1.4")

    # The cache directory is handed to _patch_roots as the allowlisted root.
    self._patch_roots(monkeypatch, documents_cache)
    monkeypatch.setenv("HOME", str(home_dir))
    for attr in ("_HERMES_HOME", "_HERMES_ROOT"):
        monkeypatch.setattr(f"gateway.platforms.base.{attr}", hermes_root)

    expected = str(report.resolve())
    assert BasePlatformAdapter.validate_media_delivery_path(str(report)) == expected
def test_denylist_blocks_non_cache_file_under_hermes_home(self, tmp_path, monkeypatch):
    """A non-credential file directly under ``~/.hermes`` is still delivered.

    Despite the test's name, the expectation is delivery. The file sits
    outside any cache subdirectory and is accepted through recency trust,
    because the ~/.hermes tree as a whole is NOT blanket-denied
    (per #32090/#34425). This guards against re-introducing the rejected
    whole-tree deny.
    """
    self._patch_roots(monkeypatch)  # strict mode on
    monkeypatch.setenv("HERMES_MEDIA_TRUST_RECENT_FILES", "1")
    monkeypatch.setenv("HERMES_MEDIA_TRUST_RECENT_SECONDS", "600")

    home_dir = tmp_path / "home"
    hermes_root = home_dir / ".hermes"
    hermes_root.mkdir(parents=True)
    report = hermes_root / "adhoc_report.pdf"
    report.write_bytes(b"%PDF-1.4")  # fresh mtime

    monkeypatch.setenv("HOME", str(home_dir))
    for attr in ("_HERMES_HOME", "_HERMES_ROOT"):
        monkeypatch.setattr(f"gateway.platforms.base.{attr}", hermes_root)

    expected = str(report.resolve())
    assert BasePlatformAdapter.validate_media_delivery_path(str(report)) == expected
def test_strict_mode_envvar_restores_legacy_behavior(self, tmp_path, monkeypatch):
"""Setting HERMES_MEDIA_DELIVERY_STRICT=1 reactivates the older
allowlist+recency logic. A stale file outside the allowlist is

View file

@ -299,6 +299,78 @@ class TestStaleSessionLockSelfHeal:
assert sk in adapter._active_sessions
assert sk in adapter._session_tasks
@pytest.mark.asyncio
async def test_guard_mismatch_preserves_session_task_for_stale_detection(self):
    """A swapped guard must leave ``_session_tasks`` intact for stale detection.

    Scenario: a task was recorded while the guard was ``original_guard``,
    then the guard for the same key was replaced (as a reset/new command
    would do). Running the cleanup helper with the old guard must keep the
    task entry, so the finished task can later be detected as stale and
    healed.
    """
    adapter = _make_adapter()
    session_key = _session_key()

    async def _finish_immediately():
        return None

    # A task that has already completed by the time cleanup runs.
    finished_task = asyncio.create_task(_finish_immediately())
    await finished_task

    original_guard = asyncio.Event()
    adapter._active_sessions[session_key] = original_guard
    adapter._session_tasks[session_key] = finished_task

    # Swap the guard out from under the recorded task.
    replacement_guard = asyncio.Event()
    adapter._active_sessions[session_key] = replacement_guard

    # Exercise the real cleanup helper with the guard the task started under.
    adapter._cleanup_finished_session_task(session_key, original_guard)

    # The task entry survives because the guards no longer match.
    assert session_key in adapter._session_tasks, (
        "_session_tasks entry must survive guard mismatch so stale detection works"
    )
    assert adapter._session_tasks[session_key] is finished_task

    # The finished task is now reported as stale...
    assert adapter._session_task_is_stale(session_key) is True

    # ...and healing removes both bookkeeping entries.
    assert adapter._heal_stale_session_lock(session_key) is True
    assert session_key not in adapter._active_sessions
    assert session_key not in adapter._session_tasks
@pytest.mark.asyncio
async def test_cleanup_releases_and_deletes_when_guard_matches(self):
    """Positive path for #48300: a matching guard is released and the task dropped.

    On normal completion the guard is unchanged, so the cleanup helper must
    clear both ``_active_sessions`` and ``_session_tasks`` for the key. The
    conditional delete must not strand a healthy session.
    """
    adapter = _make_adapter()
    session_key = _session_key()

    async def _finish_immediately():
        return None

    finished_task = asyncio.create_task(_finish_immediately())
    await finished_task

    guard = asyncio.Event()
    adapter._active_sessions[session_key] = guard
    adapter._session_tasks[session_key] = finished_task

    # The guard was never swapped, so cleanup sees the one it expects.
    adapter._cleanup_finished_session_task(session_key, guard)

    assert session_key not in adapter._active_sessions, "guard must be released on match"
    assert session_key not in adapter._session_tasks, "task entry must be dropped after release"
# ===========================================================================
# Layer 3: Runner-side generation guard on slot promotion + release

View file

@ -1754,6 +1754,193 @@ class TestIncomingDocumentHandling:
assert "> /deploy now" in msg_event.text
# ---------------------------------------------------------------------------
# TestIncomingAudioHandling — Slack voice messages (regression)
# ---------------------------------------------------------------------------
class TestSlackAudioExtResolution:
    """Unit tests for ``_resolve_slack_audio_ext``.

    Regression background: Slack in-app voice messages are MP4/AAC
    containers (``audio/mp4``, filename ``audio_message*.mp4``). The old code
    cached them as ``.ogg`` (the catch-all fallback), and OpenAI STT, which
    sniffs the container from the filename extension, rejected them.
    WhatsApp ``.ogg`` and uploaded ``.m4a`` files worked only because their
    extension happened to match.
    """

    def test_slack_voice_message_mp4_keeps_real_extension(self):
        """Core regression: an audio/mp4 voice message must not resolve to .ogg."""
        file_info = {"name": "audio_message.mp4", "mimetype": "audio/mp4"}
        resolved = _slack_mod._resolve_slack_audio_ext(file_info, file_info["mimetype"])
        assert resolved != ".ogg", "regression: MP4 voice message mislabeled as .ogg"
        assert resolved in {".mp4", ".m4a"}
        assert resolved in _slack_mod._SLACK_STT_SUPPORTED_EXTS

    def test_whatsapp_ogg_preserved(self):
        """An .ogg file with audio/ogg keeps its .ogg extension."""
        file_info = {"name": "voice.ogg", "mimetype": "audio/ogg"}
        resolved = _slack_mod._resolve_slack_audio_ext(file_info, file_info["mimetype"])
        assert resolved == ".ogg"

    def test_m4a_upload_preserved(self):
        """An .m4a upload with audio/x-m4a keeps its .m4a extension."""
        file_info = {"name": "clip.m4a", "mimetype": "audio/x-m4a"}
        resolved = _slack_mod._resolve_slack_audio_ext(file_info, file_info["mimetype"])
        assert resolved == ".m4a"

    def test_mp3_upload_preserved(self):
        """An .mp3 upload with audio/mpeg keeps its .mp3 extension."""
        file_info = {"name": "song.mp3", "mimetype": "audio/mpeg"}
        resolved = _slack_mod._resolve_slack_audio_ext(file_info, file_info["mimetype"])
        assert resolved == ".mp3"

    def test_mimetype_used_when_filename_extension_missing(self):
        """With no filename extension, the mimetype decides (.m4a), not .ogg."""
        file_info = {"name": "", "mimetype": "audio/mp4"}
        resolved = _slack_mod._resolve_slack_audio_ext(file_info, file_info["mimetype"])
        assert resolved == ".m4a"

    def test_unknown_audio_defaults_to_m4a_not_ogg(self):
        """An unrecognised audio type falls back to .m4a rather than .ogg."""
        file_info = {"name": "weird", "mimetype": "audio/x-some-future-codec"}
        resolved = _slack_mod._resolve_slack_audio_ext(file_info, file_info["mimetype"])
        assert resolved == ".m4a"
        assert resolved != ".ogg"
class TestSlackVoiceClipDetection:
    """Unit tests for ``_is_slack_voice_clip``, which spots voice clips labeled video/mp4."""

    def test_audio_message_filename_detected(self):
        """An ``audio_message.mp4`` filename marks the file as a voice clip."""
        file_info = {"name": "audio_message.mp4", "mimetype": "video/mp4"}
        assert _slack_mod._is_slack_voice_clip(file_info)

    def test_slack_audio_subtype_detected(self):
        """The ``slack_audio`` subtype marks the file as a voice clip."""
        file_info = {
            "name": "clip.mp4",
            "subtype": "slack_audio",
            "mimetype": "video/mp4",
        }
        assert _slack_mod._is_slack_voice_clip(file_info)

    def test_real_video_not_detected(self):
        """A genuine uploaded video must NOT be pulled into the audio path."""
        file_info = {"name": "vacation.mp4", "mimetype": "video/mp4"}
        assert not _slack_mod._is_slack_voice_clip(file_info)

    def test_slack_video_clip_not_detected(self):
        """``slack_video`` clips carry a real video track and stay video."""
        file_info = {"name": "screen_recording.mp4", "subtype": "slack_video"}
        assert not _slack_mod._is_slack_voice_clip(file_info)
class TestIncomingAudioHandling:
def _make_event(self, files=None, text="hello"):
return {
"text": text,
"user": "U_USER",
"channel": "D123",
"channel_type": "im",
"ts": "1234567890.000001",
"files": files or [],
"blocks": [],
"attachments": [],
}
@pytest.mark.asyncio
async def test_voice_message_cached_with_correct_extension(self, adapter, tmp_path):
    """An audio/mp4 voice message is downloaded as audio with a non-.ogg extension.

    The download helper is replaced so the test can capture the extension and
    the ``audio`` flag the adapter passes, then the dispatched event is
    checked for a single audio/* attachment.
    """
    captured = {}

    async def _fake_download(url, ext, audio=False, team_id=""):
        # Remember what the adapter asked for and hand back a real file path.
        captured.update(ext=ext, audio=audio)
        cached_file = tmp_path / f"cached{ext}"
        cached_file.write_bytes(b"\x00\x00\x00\x18ftypmp42fake mp4 bytes")
        return str(cached_file)

    voice_file = {
        "mimetype": "audio/mp4",
        "name": "audio_message.mp4",
        "subtype": "slack_audio",
        "url_private_download": "https://files.slack.com/audio_message.mp4",
        "size": 2048,
    }

    with patch.object(adapter, "_download_slack_file", side_effect=_fake_download):
        await adapter._handle_slack_message(self._make_event(files=[voice_file]))

    assert captured.get("audio") is True
    assert captured["ext"] != ".ogg", "regression: voice message cached as .ogg"
    assert captured["ext"] in {".mp4", ".m4a"}

    dispatched = adapter.handle_message.call_args[0][0]
    assert len(dispatched.media_urls) == 1
    # The media type must remain audio/* so the gateway routes it to STT.
    assert dispatched.media_types[0].startswith("audio/")
@pytest.mark.asyncio
async def test_video_mp4_voice_clip_rerouted_to_audio(self, adapter, tmp_path):
"""A voice clip mislabeled video/mp4 is rerouted to the audio path
(cached as audio, reported as audio/*) instead of video understanding."""
captured = {}
async def _fake_download(url, ext, audio=False, team_id=""):
captured["ext"] = ext
captured["audio"] = audio
path = tmp_path / f"cached{ext}"
path.write_bytes(b"\x00\x00\x00\x18ftypmp42fake mp4 bytes")
return str(path)
with patch.object(adapter, "_download_slack_file", side_effect=_fake_download):
event = self._make_event(
files=[
{
"mimetype": "video/mp4",
"name": "audio_message.mp4",
"subtype": "slack_audio",
"url_private_download": "https://files.slack.com/audio_message.mp4",
"size": 2048,
}
]
)
await adapter._handle_slack_message(event)
assert captured.get("audio") is True
assert captured["ext"] in {".mp4", ".m4a"}
msg_event = adapter.handle_message.call_args[0][0]
assert len(msg_event.media_urls) == 1
assert msg_event.media_types[0].startswith("audio/"), (
"voice clip should route to STT, not video understanding"
)
@pytest.mark.asyncio
async def test_real_video_still_routed_as_video(self, adapter, tmp_path):
"""A genuine uploaded video must remain on the video path."""
async def _fake_download_bytes(url, team_id=""):
return b"\x00\x00\x00\x18ftypisomfake real video"
with patch.object(
adapter, "_download_slack_file_bytes", side_effect=_fake_download_bytes
):
event = self._make_event(
files=[
{
"mimetype": "video/mp4",
"name": "vacation.mp4",
"url_private_download": "https://files.slack.com/vacation.mp4",
"size": 4096,
}
]
)
await adapter._handle_slack_message(event)
msg_event = adapter.handle_message.call_args[0][0]
assert len(msg_event.media_urls) == 1
assert msg_event.media_types[0].startswith("video/"), (
"a real video must not be hijacked into the audio path"
)
# ---------------------------------------------------------------------------
# TestMessageRouting
# ---------------------------------------------------------------------------

View file

@ -55,7 +55,8 @@ CHANNEL_ID = "C0AQWDLHY9M"
OTHER_CHANNEL_ID = "C9999999999"
def _make_adapter(require_mention=None, strict_mention=None, free_response_channels=None, allowed_channels=None):
def _make_adapter(require_mention=None, strict_mention=None, free_response_channels=None,
allowed_channels=None, mention_patterns=None):
extra = {}
if require_mention is not None:
extra["require_mention"] = require_mention
@ -65,6 +66,8 @@ def _make_adapter(require_mention=None, strict_mention=None, free_response_chann
extra["free_response_channels"] = free_response_channels
if allowed_channels is not None:
extra["allowed_channels"] = allowed_channels
if mention_patterns is not None:
extra["mention_patterns"] = mention_patterns
adapter = object.__new__(SlackAdapter)
adapter.platform = Platform.SLACK
@ -249,7 +252,10 @@ def _would_process(adapter, *, is_dm=False, channel_id=CHANNEL_ID,
bot_uid = adapter._team_bot_user_ids.get("T1", adapter._bot_user_id)
if mentioned:
text = f"<@{bot_uid}> {text}"
is_mentioned = bot_uid and f"<@{bot_uid}>" in text
is_mentioned = bool(
(bot_uid and f"<@{bot_uid}>" in text)
or adapter._slack_message_matches_mention_patterns(text)
)
if not is_dm and bot_uid:
# allowed_channels check (whitelist — must pass before other gating)
@ -687,3 +693,61 @@ def test_config_bridges_slack_allowed_channels_env_takes_precedence(monkeypatch,
import os as _os
# env var must not be overwritten by config.yaml
assert _os.environ["SLACK_ALLOWED_CHANNELS"] == OTHER_CHANNEL_ID
# ---------------------------------------------------------------------------
# Tests: mention_patterns (wake words) — parity with other adapters (#50732)
# ---------------------------------------------------------------------------
def test_mention_patterns_default_no_match(monkeypatch):
    """With no config value and no env var there are no patterns and no match."""
    monkeypatch.delenv("SLACK_MENTION_PATTERNS", raising=False)
    bare = _make_adapter()
    compiled = bare._slack_mention_patterns()
    assert compiled == []
    matched = bare._slack_message_matches_mention_patterns("hello there")
    assert matched is False
def test_mention_patterns_list_matches():
    """A list of patterns matches text containing one of them, and nothing else."""
    wake_words = ["hey hermes", "hermes,"]
    matches = _make_adapter(
        mention_patterns=wake_words
    )._slack_message_matches_mention_patterns
    assert matches("hey hermes, you there?") is True
    assert matches("just chatting") is False
def test_mention_patterns_case_insensitive():
    """A lower-case pattern matches upper-case message text."""
    shouting = "HEY HERMES!"
    adapter = _make_adapter(mention_patterns=["hey hermes"])
    assert adapter._slack_message_matches_mention_patterns(shouting) is True
def test_mention_patterns_single_string():
    """A bare string (not a list) is accepted as one anchored pattern."""
    matches = _make_adapter(
        mention_patterns="^hermes"
    )._slack_message_matches_mention_patterns
    assert matches("hermes do this") is True
    assert matches("ok hermes") is False
def test_mention_patterns_invalid_regex_skipped_without_crash():
    """A malformed pattern next to a valid one does not stop the valid one matching."""
    mixed = ["(unclosed", "hey hermes"]
    adapter = _make_adapter(mention_patterns=mixed)
    outcome = adapter._slack_message_matches_mention_patterns("hey hermes")
    assert outcome is True
def test_mention_patterns_env_var_fallback(monkeypatch):
    """A JSON list in ``SLACK_MENTION_PATTERNS`` is used when config sets nothing."""
    env_value = '["hey hermes", "hermes,"]'
    monkeypatch.setenv("SLACK_MENTION_PATTERNS", env_value)
    # No mention_patterns argument, so the adapter has only the env var to go on.
    adapter = _make_adapter()
    outcome = adapter._slack_message_matches_mention_patterns("hey hermes")
    assert outcome is True
def test_mention_patterns_env_var_csv_fallback_splits_patterns(monkeypatch):
    """A comma-separated env value is split into patterns; the trailing
    comma yields no empty entry."""
    monkeypatch.setenv("SLACK_MENTION_PATTERNS", "hey hermes,hermes,")
    # No mention_patterns argument, so the adapter has only the env var to go on.
    adapter = _make_adapter()
    sources = [compiled.pattern for compiled in adapter._slack_mention_patterns()]
    assert sources == ["hey hermes", "hermes"]
    outcome = adapter._slack_message_matches_mention_patterns("hey hermes")
    assert outcome is True
def test_mention_patterns_trigger_in_channel_without_literal_mention():
    """With ``require_mention`` on, a wake word alone is enough to be processed."""
    adapter = _make_adapter(require_mention=True, mention_patterns=["hey hermes"])
    wake_word_result = _would_process(adapter, text="hey hermes what's the status")
    assert wake_word_result is True
    # Text without the wake word (and without an @-mention) stays ignored.
    chatter_result = _would_process(adapter, text="lunch anyone?")
    assert chatter_result is False

View file

@ -0,0 +1,177 @@
"""Regression test for #31599 — Telegram general-pool CLOSE_WAIT fd leak.
Background
----------
PTB's ``telegram.request.HTTPXRequest`` builds the underlying
``httpx.AsyncClient`` with ``limits = httpx.Limits(max_connections=...)``
and *no* keepalive tuning, so httpx's default ``keepalive_expiry=5.0``
applies. Behind an HTTP proxy (Cloudflare Warp etc.) a peer-initiated
FIN can sit in ``CLOSE_WAIT`` longer than that, leaking fds in the
general request pool (``_request[1]`` — the pool that routes
``bot.send_message`` / ``set_my_commands``), which
``_drain_polling_connections`` never resets.
The fix wires the shared ``gateway.platforms._http_client_limits``
``platform_httpx_limits()`` helper into *every* HTTPXRequest the adapter
builds — the fallback-transport branch, the proxy branch, and the plain
branch — so idle keepalive sockets drain aggressively.
Contract asserted here (mutation-survivable)
---------------------------------------------
Every ``HTTPXRequest`` constructed by ``TelegramAdapter.connect()`` must
receive ``httpx_kwargs["limits"]`` that is an ``httpx.Limits`` with a
``keepalive_expiry`` strictly below httpx's 5.0 default and a positive,
bounded ``max_keepalive_connections``. Reverting the limits wiring (so
HTTPXRequest falls back to PTB's default 5.0s keepalive) fails this test.
"""
import asyncio
import sys
from unittest.mock import MagicMock, patch
import httpx
import pytest
from gateway.config import PlatformConfig
def _ensure_telegram_mock():
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
return
telegram_mod = MagicMock()
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
telegram_mod.constants.ChatType.GROUP = "group"
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
telegram_mod.constants.ChatType.CHANNEL = "channel"
telegram_mod.constants.ChatType.PRIVATE = "private"
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
sys.modules.setdefault(name, telegram_mod)
_ensure_telegram_mock()
from plugins.platforms.telegram import adapter as tg_adapter # noqa: E402
from plugins.platforms.telegram.adapter import TelegramAdapter # noqa: E402
class _StopConnect(Exception):
"""Sentinel raised to abort connect() once requests are built."""
class _RecordingHTTPXRequest:
"""Stand-in for PTB's HTTPXRequest that records constructor kwargs."""
instances: list = []
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
_RecordingHTTPXRequest.instances.append(self)
def _make_adapter() -> TelegramAdapter:
    """Build a ``TelegramAdapter`` from an enabled config with a dummy token."""
    config = PlatformConfig(enabled=True, token="test-token")
    return TelegramAdapter(config)
def _drive_connect(monkeypatch, *, proxy_url):
    """Run ``connect()`` far enough to build the HTTPXRequests, then abort.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture, used to swap the
            adapter module's collaborators for inert stand-ins.
        proxy_url: Value the patched ``resolve_proxy_url`` returns; ``None``
            selects the no-proxy path.

    Returns:
        A list copy of the ``_RecordingHTTPXRequest`` instances created
        during the ``connect()`` call.
    """
    _RecordingHTTPXRequest.instances = []

    # No DoH auto-discovery → exercise the proxy / plain branches, not fallback.
    async def _no_fallback():
        return []

    monkeypatch.setattr(tg_adapter, "discover_fallback_ips", _no_fallback)
    monkeypatch.setattr(tg_adapter, "resolve_proxy_url", lambda *a, **k: proxy_url)
    # Replace the real HTTPXRequest with our recorder.
    monkeypatch.setattr(tg_adapter, "HTTPXRequest", _RecordingHTTPXRequest)

    adapter = _make_adapter()
    # Skip the cross-process token lock.
    monkeypatch.setattr(adapter, "_acquire_platform_lock", lambda *a, **k: True)
    # Ensure the adapter reports no statically-configured fallback IPs.
    monkeypatch.setattr(adapter, "_fallback_ips", lambda: [])

    # Every builder step returns the same mock so the fluent chain keeps
    # working; build() raises the sentinel, so no application object is
    # ever produced and nothing past request construction runs.
    chainable = MagicMock()
    chainable.token.return_value = chainable
    chainable.base_url.return_value = chainable
    chainable.base_file_url.return_value = chainable
    chainable.local_mode.return_value = chainable
    chainable.request.return_value = chainable
    chainable.get_updates_request.return_value = chainable
    chainable.build.side_effect = _StopConnect

    builder_root = MagicMock()
    builder_root.builder.return_value = chainable
    monkeypatch.setattr(tg_adapter, "Application", builder_root)

    try:
        asyncio.run(adapter.connect())
    except _StopConnect:
        # Expected exit: build() raised the sentinel.
        pass
    except Exception:
        # Deliberately best-effort: connect() wraps work in a try; if it
        # swallows the sentinel and continues to real init, the recorded
        # instances are still valid.
        pass
    return list(_RecordingHTTPXRequest.instances)
def _assert_keepalive_tight(instances):
    """Assert every recorded request was given tuned ``httpx.Limits``.

    For each instance, ``kwargs["httpx_kwargs"]["limits"]`` must be an
    ``httpx.Limits`` with ``keepalive_expiry`` below 5.0,
    ``max_keepalive_connections`` within 1..50 and a positive
    ``max_connections``. An empty ``instances`` also fails.
    """
    assert instances, "connect() built no HTTPXRequest — test setup is wrong"
    for request in instances:
        httpx_kwargs = request.kwargs.get("httpx_kwargs", {})
        limits = httpx_kwargs.get("limits")
        assert isinstance(limits, httpx.Limits), (
            "HTTPXRequest must receive httpx_kwargs['limits'] = httpx.Limits "
            "wired from platform_httpx_limits() (#31599). Missing → PTB falls "
            "back to default keepalive_expiry=5.0 and leaks CLOSE_WAIT fds."
        )

        # Core requirement: keepalive strictly tighter than the 5.0 default.
        expiry = limits.keepalive_expiry
        assert expiry is not None
        assert expiry < 5.0, (
            "keepalive_expiry must be < httpx default 5.0 so idle/CLOSE_WAIT "
            "sockets drain promptly behind a proxy (#31599)."
        )

        keepalive_slots = limits.max_keepalive_connections
        assert keepalive_slots is not None
        assert 1 <= keepalive_slots <= 50

        # max_connections must still be set to a positive value.
        assert limits.max_connections is not None and limits.max_connections > 0
def test_proxy_branch_general_pool_has_tight_keepalive(monkeypatch):
    """With a proxy URL resolved, every built request carries tuned limits."""
    proxy = "http://127.0.0.1:9/"
    recorded = _drive_connect(monkeypatch, proxy_url=proxy)
    # Expect at least two requests: the general pool and the get_updates pool.
    assert len(recorded) >= 2
    _assert_keepalive_tight(recorded)
    # Confirms we really took the proxy path: some request received the URL.
    assert any(request.kwargs.get("proxy") == proxy for request in recorded)
def test_plain_branch_general_pool_has_tight_keepalive(monkeypatch):
    """Without a proxy or fallback IPs, built requests still carry tuned limits."""
    recorded = _drive_connect(monkeypatch, proxy_url=None)
    request_count = len(recorded)
    assert request_count >= 2
    _assert_keepalive_tight(recorded)
def test_limits_keepalive_below_ptb_default_is_the_contract():
    """Independently of the adapter, ``platform_httpx_limits()`` must return
    an ``httpx.Limits`` whose ``keepalive_expiry`` is set and below 5.0."""
    from gateway.platforms._http_client_limits import platform_httpx_limits

    shared_limits = platform_httpx_limits()
    assert isinstance(shared_limits, httpx.Limits)
    expiry = shared_limits.keepalive_expiry
    assert expiry is not None and expiry < 5.0

View file

@ -0,0 +1,459 @@
"""Regression tests for #31501 — prune stale Telegram DM topic bindings.
When a Telegram user deletes a DM topic in the client, the Bot API
responds to the gateway's next send with ``Thread not found``. The
adapter falls back to a plain send (no ``message_thread_id``), but
prior to this fix it left the corresponding row in
``telegram_dm_topic_bindings`` untouched.
``gateway.run._recover_telegram_topic_thread_id`` then walked the
user's bindings newest-first on every later inbound message and
cheerfully redirected them back to the deleted topic — tool
progress, approvals and replies all silently landed in the wrong
place until the operator manually ran ``DELETE`` on ``state.db``.
The fix has three pieces — these tests pin all three:
1. ``SessionDB.delete_telegram_topic_binding`` — the targeted
prune helper (new public API).
2. ``TelegramAdapter._prune_stale_dm_topic_binding`` — the
adapter glue that calls the helper from a send-fallback hot
path without raising on cleanup failure.
3. The two "Thread not found" call sites in the streaming send
loop and the control-message helper now invoke (2) — we pin
this with a source-level guard rather than spinning the full
send pipeline.
"""
from __future__ import annotations
import inspect
from types import SimpleNamespace
import pytest
from hermes_state import SessionDB
# ---------------------------------------------------------------------------
# SessionDB.delete_telegram_topic_binding
# ---------------------------------------------------------------------------
def _seed_binding(
db: SessionDB,
*,
chat_id: str = "5595856929",
thread_id: str = "15287",
user_id: str = "5595856929",
session_id: str = "sess-target",
) -> None:
db.create_session(
session_id=session_id,
source="telegram",
user_id=user_id,
)
db.bind_telegram_topic(
chat_id=chat_id,
thread_id=thread_id,
user_id=user_id,
session_key=f"agent:main:telegram:dm:{chat_id}:{thread_id}",
session_id=session_id,
)
class TestDeleteTelegramTopicBinding:
    """Behaviour of ``SessionDB.delete_telegram_topic_binding``."""

    def test_removes_matching_row_and_returns_count(self, tmp_path):
        target = {"chat_id": "5595856929", "thread_id": "15287"}
        db = SessionDB(db_path=tmp_path / "state.db")
        _seed_binding(db, thread_id="15287")
        # Precondition: the binding exists before we prune it.
        assert db.get_telegram_topic_binding(**target) is not None

        removed = db.delete_telegram_topic_binding(**target)

        assert removed == 1
        assert db.get_telegram_topic_binding(**target) is None
        db.close()

    def test_does_not_touch_unrelated_bindings(self, tmp_path):
        # One chat, two topics: deleting one binding must leave the other
        # in place so it stays visible to later lookups.
        stale = {"chat_id": "5595856929", "thread_id": "15287"}
        fresh = {"chat_id": "5595856929", "thread_id": "15418"}
        db = SessionDB(db_path=tmp_path / "state.db")
        _seed_binding(db, thread_id="15287", session_id="sess-stale")
        _seed_binding(db, thread_id="15418", session_id="sess-fresh")

        removed = db.delete_telegram_topic_binding(**stale)

        assert removed == 1
        # Only the targeted row disappeared.
        assert db.get_telegram_topic_binding(**stale) is None
        assert db.get_telegram_topic_binding(**fresh) is not None
        db.close()

    def test_missing_row_returns_zero_silently(self, tmp_path):
        db = SessionDB(db_path=tmp_path / "state.db")
        _seed_binding(db, thread_id="15287")

        # No binding has this thread_id: expect a 0 count, not an error.
        removed = db.delete_telegram_topic_binding(
            chat_id="5595856929", thread_id="99999",
        )

        assert removed == 0
        # The seeded binding is unaffected.
        surviving = db.get_telegram_topic_binding(
            chat_id="5595856929", thread_id="15287",
        )
        assert surviving is not None
        db.close()

    def test_pristine_database_with_no_topic_tables_is_silent_noop(self, tmp_path):
        # A brand-new database with nothing seeded: the delete must cope
        # with the topic tables not existing at all.
        db = SessionDB(db_path=tmp_path / "state.db")
        # Verify the precondition rather than assume it.
        rows = db._conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' "
            "AND name LIKE 'telegram_dm%'"
        ).fetchall()
        table_names = {row[0] for row in rows}
        assert "telegram_dm_topic_bindings" not in table_names

        removed = db.delete_telegram_topic_binding(chat_id="any", thread_id="any")

        assert removed == 0
        db.close()

    def test_idempotent_under_repeated_calls(self, tmp_path):
        target = {"chat_id": "5595856929", "thread_id": "15287"}
        db = SessionDB(db_path=tmp_path / "state.db")
        _seed_binding(db, thread_id="15287")

        first = db.delete_telegram_topic_binding(**target)
        second = db.delete_telegram_topic_binding(**target)

        assert first == 1
        assert second == 0  # nothing left to delete the second time
        db.close()
class TestPruneClearsTopicModeWhenLastBindingGone:
    """Follow-up to #31501: deleting a chat's final binding must also turn
    topic mode off for that chat, while deleting one of several bindings
    (or matching nothing) must leave it on."""

    def test_clears_enabled_when_last_binding_pruned(self, tmp_path):
        owner = {"chat_id": "5595856929", "user_id": "5595856929"}
        db = SessionDB(db_path=tmp_path / "state.db")
        db.enable_telegram_topic_mode(**owner)
        _seed_binding(db, thread_id="15287")
        assert db.is_telegram_topic_mode_enabled(**owner) is True

        removed = db.delete_telegram_topic_binding(
            chat_id="5595856929", thread_id="15287",
        )

        assert removed == 1
        assert db.is_telegram_topic_mode_enabled(**owner) is False
        db.close()

    def test_keeps_enabled_while_other_bindings_remain(self, tmp_path):
        # Two bindings, one deleted: topic mode has to stay enabled because
        # the chat still has a binding left.
        owner = {"chat_id": "5595856929", "user_id": "5595856929"}
        db = SessionDB(db_path=tmp_path / "state.db")
        db.enable_telegram_topic_mode(**owner)
        _seed_binding(db, thread_id="15287", session_id="sess-stale")
        _seed_binding(db, thread_id="15418", session_id="sess-fresh")

        db.delete_telegram_topic_binding(chat_id="5595856929", thread_id="15287")

        assert db.is_telegram_topic_mode_enabled(**owner) is True
        db.close()

    def test_noop_prune_leaves_enabled_untouched(self, tmp_path):
        # The delete targets a thread_id that has no row, so the seeded
        # binding survives and the flag must not change.
        owner = {"chat_id": "5595856929", "user_id": "5595856929"}
        db = SessionDB(db_path=tmp_path / "state.db")
        db.enable_telegram_topic_mode(**owner)
        _seed_binding(db, thread_id="15287")

        removed = db.delete_telegram_topic_binding(
            chat_id="5595856929", thread_id="99999",
        )

        assert removed == 0
        assert db.is_telegram_topic_mode_enabled(**owner) is True
        db.close()
# ---------------------------------------------------------------------------
# Adapter glue — _prune_stale_dm_topic_binding
# ---------------------------------------------------------------------------
def _bare_adapter(db: SessionDB | None = None):
    """Return a ``TelegramAdapter`` created without running ``__init__``.

    Only ``platform`` is set, plus (when ``db`` is given) a
    ``_session_store`` stand-in exposing ``_db``.
    """
    # The prune helper reaches the database through
    # ``self._session_store._db``, so a SimpleNamespace with that one
    # attribute is all it needs. ``platform`` is assigned instead of
    # ``name`` because ``name`` is derived from ``platform``.
    from gateway.config import Platform
    from plugins.platforms.telegram.adapter import TelegramAdapter

    bare = object.__new__(TelegramAdapter)
    bare.platform = Platform.TELEGRAM
    if db is None:
        return bare
    bare._session_store = SimpleNamespace(_db=db)
    return bare
class TestPruneStaleDmTopicBindingHelper:
    """Behaviour of ``TelegramAdapter._prune_stale_dm_topic_binding``."""

    def test_drops_binding_when_session_store_db_is_present(self, tmp_path):
        db = SessionDB(db_path=tmp_path / "state.db")
        _seed_binding(db, thread_id="15287")

        _bare_adapter(db)._prune_stale_dm_topic_binding("5595856929", 15287)

        remaining = db.get_telegram_topic_binding(
            chat_id="5595856929", thread_id="15287",
        )
        assert remaining is None
        db.close()

    def test_silent_when_session_store_unavailable(self):
        # The adapter has no ``_session_store`` attribute at all; the call
        # must simply return.
        _bare_adapter()._prune_stale_dm_topic_binding("123", "456")

    def test_silent_when_db_lacks_helper(self):
        # The db object does not provide delete_telegram_topic_binding;
        # expect a no-op instead of an AttributeError.
        adapter = _bare_adapter()
        methodless_db = SimpleNamespace()
        adapter._session_store = SimpleNamespace(_db=methodless_db)
        adapter._prune_stale_dm_topic_binding("123", "456")

    def test_swallows_db_exceptions_so_send_continues(self):
        class ExplodingDb:
            def delete_telegram_topic_binding(self, **_):
                raise RuntimeError("disk full or whatever")

        adapter = _bare_adapter()
        adapter._session_store = SimpleNamespace(_db=ExplodingDb())
        # A failing cleanup must not propagate: the test passes only if
        # this call returns without raising.
        adapter._prune_stale_dm_topic_binding("123", "456")

    def test_skips_when_chat_or_thread_missing(self, tmp_path):
        # With either identifier missing, the helper must not delete
        # anything.
        db = SessionDB(db_path=tmp_path / "state.db")
        _seed_binding(db, thread_id="15287")
        adapter = _bare_adapter(db)

        adapter._prune_stale_dm_topic_binding(None, "15287")
        adapter._prune_stale_dm_topic_binding("5595856929", None)

        # The seeded binding survived both calls.
        surviving = db.get_telegram_topic_binding(
            chat_id="5595856929", thread_id="15287",
        )
        assert surviving is not None
        db.close()
# ---------------------------------------------------------------------------
# Source-level wiring guards — both fallback sites must call the helper
# ---------------------------------------------------------------------------
class TestThreadNotFoundFallbackSitesPruneBinding:
    """Source-level guards: ``TelegramAdapter.send`` and
    ``TelegramAdapter._send_message_with_thread_fallback`` must both mention
    ``_prune_stale_dm_topic_binding``.

    They exist so a later refactor cannot silently drop the cleanup call
    and re-open #31501.
    """

    def test_streaming_send_fallback_calls_prune(self):
        from plugins.platforms.telegram import adapter as telegram_mod

        send_source = inspect.getsource(telegram_mod.TelegramAdapter.send)
        # Anchor on the fallback log line, then look for the prune call in
        # the text that follows it.
        anchor = send_source.find("retrying without message_thread_id")
        assert anchor != -1, (
            "Streaming send must keep its 'thread not found' "
            "fallback log line — the prune wiring is anchored "
            "next to it."
        )
        # Search 600 characters, starting at the log line.
        nearby = send_source[anchor:anchor + 600]
        assert "_prune_stale_dm_topic_binding" in nearby, (
            "Streaming send 'Thread not found' fallback must call "
            "_prune_stale_dm_topic_binding so the stale row in "
            "telegram_dm_topic_bindings doesn't keep redirecting "
            "future inbound messages to the deleted topic (#31501)."
        )

    def test_control_message_helper_calls_prune(self):
        from plugins.platforms.telegram import adapter as telegram_mod

        helper_source = inspect.getsource(
            telegram_mod.TelegramAdapter._send_message_with_thread_fallback
        )
        # First requirement: the helper's source mentions the prune call.
        assert "_prune_stale_dm_topic_binding" in helper_source, (
            "_send_message_with_thread_fallback must call "
            "_prune_stale_dm_topic_binding when Telegram returns "
            "BadRequest('Thread not found') for a control message "
            "(#31501)."
        )
        # Second requirement: it appears textually before the retry send,
        # so pruning does not depend on the retry's outcome.
        prune_at = helper_source.find("_prune_stale_dm_topic_binding")
        retry_at = helper_source.find("send_message(**retry_kwargs)")
        assert 0 <= prune_at < retry_at, (
            "_prune_stale_dm_topic_binding must run before the "
            "fallback send_message retry."
        )
# ---------------------------------------------------------------------------
# End-to-end semantic — prune + recovery returns None for deleted topic
# ---------------------------------------------------------------------------
class TestRecoveryAfterPrune:
    """End-to-end check: after a binding is deleted,
    ``GatewayRunner._recover_telegram_topic_thread_id`` stops returning
    that topic."""

    def test_recovery_no_longer_returns_pruned_topic(self, tmp_path):
        # Two bindings for one user ("111" then "222"). Recovery is
        # expected to return "222" first; after "222" is deleted it must
        # return "111".
        from gateway.config import GatewayConfig, Platform, PlatformConfig
        from gateway.run import GatewayRunner
        from gateway.session import SessionSource, build_session_key

        def _dm_source(thread):
            return SessionSource(
                platform=Platform.TELEGRAM,
                user_id="5595856929",
                chat_id="5595856929",
                user_name="tester",
                chat_type="dm",
                thread_id=thread,
            )

        db = SessionDB(db_path=tmp_path / "state.db")
        db.enable_telegram_topic_mode(chat_id="5595856929", user_id="5595856929")
        for sid, thread in (("sess-A", "111"), ("sess-B", "222")):
            db.create_session(
                session_id=sid, source="telegram", user_id="5595856929",
            )
            db.bind_telegram_topic(
                chat_id="5595856929",
                thread_id=thread,
                user_id="5595856929",
                session_key=build_session_key(_dm_source(thread)),
                session_id=sid,
            )

        runner = object.__new__(GatewayRunner)
        telegram_config = PlatformConfig(enabled=True, token="***")
        runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_config})
        runner._session_db = db
        runner._telegram_topic_mode_enabled = lambda _src: True

        # Baseline: recovery returns "222". The inbound source uses thread
        # id "1" (the General topic) so that the recovery walk runs.
        before = runner._recover_telegram_topic_thread_id(_dm_source("1"))
        assert before == "222"

        # Simulate the adapter pruning the deleted topic.
        db.delete_telegram_topic_binding(chat_id="5595856929", thread_id="222")

        # Recovery now lands on the surviving binding.
        after = runner._recover_telegram_topic_thread_id(_dm_source("1"))
        assert after == "111"
        db.close()

View file

@ -0,0 +1,66 @@
"""Regression test for TUI approval-prompt credential redaction (#48456).
Follow-up to #50767, which redacted the chat-platform and SSE/API approval
transports. The TUI JSON-RPC transport is the third egress: three
`register_gateway_notify` callbacks in `tui_gateway/server.py` emit the raw
`approval_data` (with an unredacted `command`) to the TUI client. They now
route through the module-level `_emit_approval_request` helper, which redacts
`payload["command"]` via the shared `gateway.run._redact_approval_command` seam
before emitting.
"""
import inspect
import pytest
class TestTuiApprovalEmitRedaction:
    """Checks that ``tui_gateway.server._emit_approval_request`` redacts the
    command before calling ``_emit``."""

    def test_emit_approval_request_redacts_command_in_payload(self, monkeypatch):
        from tui_gateway import server as tui_server

        emitted = {}

        def _capture(event, sid, payload=None):
            emitted.update({"event": event, "sid": sid, "payload": payload})

        monkeypatch.setattr(tui_server, "_emit", _capture)
        raw = "curl -H 'Authorization: token ghp_01...6789' https://api.github.com"

        tui_server._emit_approval_request("sess-1", {"command": raw, "description": "x"})

        assert emitted["event"] == "approval.request"
        sent = emitted["payload"]
        # The token is gone; the description and the rest of the command stay.
        assert "ghp_01...6789" not in sent["command"]
        assert sent["description"] == "x"
        assert "github.com" in sent["command"]

    def test_emit_approval_request_handles_missing_command(self, monkeypatch):
        from tui_gateway import server as tui_server

        emitted = {}

        def _capture(event, sid, payload=None):
            emitted.update({"payload": payload})

        monkeypatch.setattr(tui_server, "_emit", _capture)

        tui_server._emit_approval_request("s", {"description": "no command here"})
        assert emitted["payload"] == {"description": "no command here"}

        tui_server._emit_approval_request("s", None)
        assert emitted["payload"] == {}

    def test_no_raw_command_emit_in_approval_registrations(self):
        """The module source may contain exactly one direct
        ``_emit("approval.request"`` call (the one inside the redacting
        helper) and must contain ``_emit_approval_request(sid, data)``."""
        from tui_gateway import server as tui_server

        module_source = inspect.getsource(tui_server)
        raw_emits = module_source.count('_emit("approval.request"')
        assert raw_emits == 1, (
            f'expected exactly 1 raw _emit("approval.request") (inside the '
            f"redacting helper), found {raw_emits} — a registration may be "
            f"emitting the unredacted command"
        )
        assert "_emit_approval_request(sid, data)" in module_source, (
            "registration lambdas must route through _emit_approval_request"
        )

View file

@ -113,6 +113,33 @@ def test_active_session_registry_prunes_dead_pids(tmp_path, monkeypatch):
lease.release()
def test_transfer_active_session_reanchors_existing_lease(tmp_path, monkeypatch):
    """Transferring a held lease to a new session id updates the lease and
    leaves a single registry entry under the new id."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
    ui_metadata = {"live_session_id": "ui-1"}

    lease, message = active_sessions.try_acquire_active_session(
        session_id="session-old",
        surface="tui",
        config={"max_concurrent_sessions": 1},
        metadata={"live_session_id": "ui-1"},
    )
    assert message is None
    assert lease is not None

    transferred = active_sessions.transfer_active_session(
        lease,
        session_id="session-new",
        metadata={"live_session_id": "ui-1"},
    )
    assert transferred

    snapshot = active_sessions.active_session_registry_snapshot()
    assert lease.session_id == "session-new"
    assert len(snapshot) == 1
    only_entry = snapshot[0]
    assert only_entry["session_id"] == "session-new"
    assert only_entry["metadata"] == ui_metadata
    lease.release()
def test_pid_alive_uses_safe_pid_exists_without_signalling(monkeypatch):
checked: list[int] = []

View file

@ -190,7 +190,11 @@ def _arrange_startup_fallback(monkeypatch, tmp_path, running_pids):
def test_gateway_cmd_script_uses_pythonw_without_replace_or_start_churn(monkeypatch):
"""Scheduled Task wrapper should launch pythonw once and avoid replace loops."""
monkeypatch.setattr(gateway_windows, "_derive_venv_pythonw", lambda exe: exe.replace("python.exe", "pythonw.exe"))
monkeypatch.setattr(
gateway_windows,
"_resolve_detached_python",
lambda exe: (exe.replace("python.exe", "pythonw.exe"), r"C:\\Hermes\\hermes-agent\\venv", []),
)
content = gateway_windows._build_gateway_cmd_script(
r"C:\\Hermes\\hermes-agent\\venv\\Scripts\\python.exe",
@ -206,6 +210,41 @@ def test_gateway_cmd_script_uses_pythonw_without_replace_or_start_churn(monkeypa
assert "exit /b 0" in content
def test_gateway_cmd_script_uses_uv_safe_base_pythonw(monkeypatch, tmp_path):
    """Scheduled Task wrapper should share the detached uv-venv workaround."""
    venv_dir = tmp_path / "project" / "venv"
    scripts_dir = venv_dir / "Scripts"
    site_packages = venv_dir / "Lib" / "site-packages"
    hermes_home = tmp_path / "hermes-home"
    base_dir = tmp_path / "uv" / "python" / "cpython-3.11-windows-x86_64-none"

    # Build an on-disk venv whose pyvenv.cfg names ``base_dir`` as its home
    # and carries a ``uv`` version line.
    for directory in (scripts_dir, site_packages):
        directory.mkdir(parents=True)
    hermes_home.mkdir()
    base_dir.mkdir(parents=True)

    venv_python = scripts_dir / "python.exe"
    venv_pythonw = scripts_dir / "pythonw.exe"
    base_pythonw = base_dir / "pythonw.exe"
    venv_python.write_text("", encoding="utf-8")
    venv_pythonw.write_text("", encoding="utf-8")
    base_pythonw.write_text("", encoding="utf-8")

    pyvenv_cfg = venv_dir / "pyvenv.cfg"
    pyvenv_cfg.write_text(
        f"home = {base_dir}\nimplementation = CPython\nuv = 0.11.14\nversion_info = 3.11.15\n",
        encoding="utf-8",
    )

    script = gateway_windows._build_gateway_cmd_script(
        str(venv_python),
        str(hermes_home),
        str(hermes_home),
        "",
    )

    # The base interpreter is launched while the venv is still advertised via
    # VIRTUAL_ENV and its site-packages; the venv's own pythonw must not appear.
    virtual_env_line = f'set "VIRTUAL_ENV={venv_dir}"'
    assert str(base_pythonw) in script
    assert virtual_env_line in script
    assert str(site_packages) in script
    assert str(venv_pythonw) not in script
def test_elevated_gateway_command_uses_pythonw_hidden_console(monkeypatch):
"""UAC handoff should not leave a second elevated cmd.exe window open."""
calls = []
@ -239,14 +278,18 @@ def test_install_scheduled_task_recreates_instead_of_change(monkeypatch, tmp_pat
"""Install must delete+create so stale minute-repeat task settings are not preserved."""
calls = []
script_path = tmp_path / "Hermes_Gateway_alice.cmd"
xml_seen = {}
monkeypatch.setattr(gateway_windows, "_assert_windows", lambda: None)
monkeypatch.setattr(gateway_windows, "_resolve_task_user", lambda: r"DOMAIN\\alice")
def fake_schtasks(args):
calls.append(tuple(args))
if args[0] == "/Delete":
return (0, "SUCCESS", "")
if args[0] == "/Create":
xml_path = Path(args[args.index("/XML") + 1])
xml_seen["text"] = xml_path.read_text(encoding="utf-16")
return (0, "SUCCESS", "")
raise AssertionError(f"unexpected schtasks args: {args}")
@ -257,8 +300,88 @@ def test_install_scheduled_task_recreates_instead_of_change(monkeypatch, tmp_pat
assert "/Change" not in [arg for call in calls for arg in call]
assert calls[0][:4] == ("/Delete", "/F", "/TN", "Hermes_Gateway_alice")
assert calls[1][0] == "/Create"
assert "/SC" in calls[1]
assert "ONLOGON" in calls[1]
assert "/XML" in calls[1]
assert "/SC" not in calls[1]
assert "<Delay>PT30S</Delay>" in xml_seen["text"]
assert "<StartWhenAvailable>true</StartWhenAvailable>" in xml_seen["text"]
assert "<StopOnIdleEnd>false</StopOnIdleEnd>" in xml_seen["text"]
assert "<DisallowStartIfOnBatteries>false</DisallowStartIfOnBatteries>" in xml_seen["text"]
assert "<StopIfGoingOnBatteries>false</StopIfGoingOnBatteries>" in xml_seen["text"]
assert "<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>" in xml_seen["text"]
assert "<RestartOnFailure>" in xml_seen["text"]
assert "<Count>999</Count>" in xml_seen["text"]
# Scheduled Task launches the console-less .vbs via wscript.exe, never cmd.exe
# (issue #45599 fix A: no console -> no logon CTRL_CLOSE_EVENT / 0xC000013A).
assert "<Command>wscript.exe</Command>" in xml_seen["text"]
assert "//B //Nologo" in xml_seen["text"]
assert "Hermes_Gateway_alice.vbs" in xml_seen["text"]
assert "cmd.exe" not in xml_seen["text"]
def test_gateway_vbs_script_is_console_less(monkeypatch):
    """The .vbs launcher must avoid cmd.exe entirely and Run pythonw hidden
    (issue #45599 fix A: no console -> no logon CTRL_CLOSE_EVENT / 0xC000013A)."""

    def fake_resolve(exe):
        # Fixed (pythonw path, venv dir, extra site-packages) triple.
        return r"C:\venv\Scripts\pythonw.exe", Path(r"C:\venv"), []

    monkeypatch.setattr(gateway_windows, "_resolve_detached_python", fake_resolve)

    script = gateway_windows._build_gateway_vbs_script(
        r"C:\venv\Scripts\python.exe",
        r"C:\Hermes",
        r"C:\Hermes",
        "--profile work",
    )

    assert "cmd.exe" not in script.lower()
    assert 'CreateObject("WScript.Shell")' in script
    assert "pythonw.exe" in script
    assert "hermes_cli.main" in script
    assert "gateway run" in script
    assert ", 0, False" in script  # hidden window, detached/async

    expected_env_vars = (
        "HERMES_HOME",
        "PYTHONIOENCODING",
        "HERMES_GATEWAY_DETACHED",
        "VIRTUAL_ENV",
        "PYTHONPATH",
    )
    for env_name in expected_env_vars:
        assert env_name in script

    # The extra CLI arguments must be carried into the generated script.
    assert "--profile" in script
    assert "work" in script
    assert script.endswith("\r\n")
def test_gateway_vbs_script_quotes_spaced_paths(monkeypatch):
    """Spaced exe/dir paths stay correctly quoted through the VBScript literal."""

    def fake_resolve(exe):
        # Interpreter and venv paths both contain spaces on purpose.
        return r"C:\Program Files\Py\pythonw.exe", Path(r"C:\v env"), []

    monkeypatch.setattr(gateway_windows, "_resolve_detached_python", fake_resolve)

    script = gateway_windows._build_gateway_vbs_script(
        r"C:\Program Files\Py\python.exe",
        r"C:\work dir",
        r"C:\h home",
        "",
    )

    # list2cmdline quotes the spaced exe; _quote_vbs_string doubles those quotes.
    doubled_quote_exe = '""C:\\Program Files\\Py\\pythonw.exe""'
    working_dir_line = 'sh.CurrentDirectory = "C:\\work dir"'
    assert doubled_quote_exe in script
    assert working_dir_line in script
def test_gateway_vbs_script_pythonpath_chains_runtime_value(monkeypatch):
    """PYTHONPATH chains onto the task env's existing value, like ;%PYTHONPATH%."""

    def fake_resolve(exe):
        # Report one extra site-packages entry so PYTHONPATH handling is emitted.
        return r"C:\v\pythonw.exe", Path(r"C:\v"), [r"C:\v\Lib\site-packages"]

    monkeypatch.setattr(gateway_windows, "_resolve_detached_python", fake_resolve)

    script = gateway_windows._build_gateway_vbs_script(
        r"C:\v\python.exe",
        r"C:\w",
        r"C:\h",
        "",
    )

    expected_fragments = (
        'existing_pp = env.Item("PYTHONPATH")',
        "If Len(existing_pp) > 0 Then",
        r"C:\v\Lib\site-packages",
    )
    for fragment in expected_fragments:
        assert fragment in script
def test_quote_vbs_string_doubles_quotes_and_rejects_newlines():
    """Quoting wraps in double quotes, doubles embedded quotes, rejects newlines."""
    expected_by_input = {
        "plain": '"plain"',
        'a"b': '"a""b"',
    }
    for raw, quoted in expected_by_input.items():
        assert gateway_windows._quote_vbs_string(raw) == quoted

    with pytest.raises(ValueError):
        gateway_windows._quote_vbs_string("line1\nline2")
def test_install_scheduled_task_success_start_now_uses_direct_spawn_not_task_run(monkeypatch, tmp_path, capsys):

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
"""Tests for ``install_cua_driver`` upgrade semantics and architecture pre-check.
"""Tests for ``install_cua_driver`` upgrade semantics.
The cua-driver upstream installer always pulls the latest release tag, so
re-running it is the canonical upgrade path. ``install_cua_driver(upgrade=True)``
@ -10,30 +10,34 @@ must:
fix for the "we only pulled cua-driver once on enable" complaint).
* Preserve original ``upgrade=False`` behaviour for the toolset-enable flow:
skip if installed, install otherwise, warn on non-macOS.
* Pre-check architecture compatibility before downloading to avoid raw 404
errors on Intel macOS when the upstream release lacks x86_64 assets.
The pre-install arch probe that used to live alongside this function was
deleted (see top-of-file comment in tools_config.py) — the upstream
installer has CUA_DRIVER_RS_BAKED_VERSION baked in by CD and errors
cleanly on missing-arch assets, and the upgrade path uses
``cua_driver_update_check()`` (which shells `cua-driver check-update
--json` against the already-installed binary).
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock, patch
from unittest.mock import patch
class TestInstallCuaDriverUpgrade:
def test_upgrade_on_non_macos_is_silent_noop(self):
def test_upgrade_on_unsupported_platform_is_silent_noop(self):
from hermes_cli import tools_config
with patch.object(tools_config, "_print_warning") as warn, \
patch("platform.system", return_value="Linux"):
patch("platform.system", return_value="FreeBSD"):
assert tools_config.install_cua_driver(upgrade=True) is False
warn.assert_not_called()
def test_non_upgrade_on_non_macos_warns(self):
def test_non_upgrade_on_unsupported_platform_warns(self):
from hermes_cli import tools_config
with patch.object(tools_config, "_print_warning") as warn, \
patch("platform.system", return_value="Linux"):
patch("platform.system", return_value="FreeBSD"):
assert tools_config.install_cua_driver(upgrade=False) is False
warn.assert_called()
@ -44,8 +48,6 @@ class TestInstallCuaDriverUpgrade:
patch.object(tools_config.shutil, "which",
side_effect=lambda n: "/usr/local/bin/" + n
if n in {"cua-driver", "curl"} else None), \
patch.object(tools_config, "_check_cua_driver_asset_for_arch",
return_value=True), \
patch.object(tools_config, "_run_cua_driver_installer",
return_value=True) as runner, \
patch("subprocess.run"):
@ -60,8 +62,6 @@ class TestInstallCuaDriverUpgrade:
with patch("platform.system", return_value="Darwin"), \
patch.object(tools_config.shutil, "which",
side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \
patch.object(tools_config, "_check_cua_driver_asset_for_arch",
return_value=True), \
patch.object(tools_config, "_run_cua_driver_installer",
return_value=True) as runner:
assert tools_config.install_cua_driver(upgrade=True) is True
@ -85,128 +85,75 @@ class TestInstallCuaDriverUpgrade:
with patch("platform.system", return_value="Darwin"), \
patch.object(tools_config.shutil, "which",
side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \
patch.object(tools_config, "_check_cua_driver_asset_for_arch",
return_value=True), \
patch.object(tools_config, "_run_cua_driver_installer",
return_value=True) as runner:
assert tools_config.install_cua_driver(upgrade=False) is True
runner.assert_called_once()
class TestCheckCuaDriverAssetForArch:
def test_arm64_always_returns_true(self):
class TestArchProbeRemoval:
"""Regression tests for the deletion of `_check_cua_driver_asset_for_arch`.
The old probe queried ``/releases/latest`` on trycua/cua and inspected
asset names. That was wrong in two ways:
1. cua-driver-rs releases are marked **prerelease** on every cut, so
``/releases/latest`` returns the Python ``cua-agent`` / ``cua-computer``
package instead — a release with zero binary assets. The probe then
reported "no asset for $arch" on Linux x86_64, Windows, macOS Intel,
Linux arm64 — every non-Apple-Silicon host.
2. Even with the right endpoint, it duplicated tag-resolution the upstream
installer already does correctly via ``CUA_DRIVER_RS_BAKED_VERSION``
(auto-baked by CD on every release).
The fix: stop probing. Trust the upstream installer for fresh installs
(it has the baked version + correct API fallback) and the
``cua-driver check-update --json`` MCP-binary native command for the
upgrade path.
"""
def test_probe_function_is_gone(self):
from hermes_cli import tools_config
assert not hasattr(tools_config, "_check_cua_driver_asset_for_arch")
assert not hasattr(tools_config, "_latest_cua_driver_rs_release")
with patch("platform.machine", return_value="arm64"):
assert tools_config._check_cua_driver_asset_for_arch() is True
def test_x86_64_with_asset_returns_true(self):
def test_fresh_install_does_not_call_github_api(self):
"""Pre-install no longer probes the GitHub API — the upstream
``install.sh`` resolves the tag from its baked CUA_DRIVER_RS_BAKED_VERSION
line. install.sh errors cleanly when the arch has no asset, so the
probe was duplicate gatekeeping.
"""
from hermes_cli import tools_config
release = {
"tag_name": "cua-driver-v0.1.6",
"assets": [
{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"},
{"name": "cua-driver-0.1.6-darwin-x86_64.tar.gz"},
],
}
mock_resp = MagicMock()
mock_resp.read.return_value = json.dumps(release).encode()
mock_resp.__enter__ = lambda s: s
mock_resp.__exit__ = MagicMock(return_value=False)
with patch("platform.machine", return_value="x86_64"), \
patch("urllib.request.urlopen", return_value=mock_resp):
assert tools_config._check_cua_driver_asset_for_arch() is True
def test_x86_64_without_asset_returns_false(self):
from hermes_cli import tools_config
release = {
"tag_name": "cua-driver-v0.1.6",
"assets": [
{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"},
{"name": "cua-driver.tar.gz"},
],
}
mock_resp = MagicMock()
mock_resp.read.return_value = json.dumps(release).encode()
mock_resp.__enter__ = lambda s: s
mock_resp.__exit__ = MagicMock(return_value=False)
with patch("platform.machine", return_value="x86_64"), \
patch("urllib.request.urlopen", return_value=mock_resp), \
patch.object(tools_config, "_print_warning") as warn, \
patch.object(tools_config, "_print_info"):
assert tools_config._check_cua_driver_asset_for_arch() is False
warn.assert_called_once()
assert "no Intel" in warn.call_args[0][0].lower() or "x86_64" in warn.call_args[0][0]
def test_x86_64_api_failure_returns_true(self):
"""Network failure should fail open — let the installer handle it."""
from hermes_cli import tools_config
with patch("platform.machine", return_value="x86_64"), \
patch("urllib.request.urlopen", side_effect=Exception("timeout")):
assert tools_config._check_cua_driver_asset_for_arch() is True
def test_fresh_install_x86_64_no_asset_skips_installer(self):
"""When the latest release has no Intel asset, skip the installer."""
from hermes_cli import tools_config
release = {
"tag_name": "cua-driver-v0.1.6",
"assets": [{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"}],
}
mock_resp = MagicMock()
mock_resp.read.return_value = json.dumps(release).encode()
mock_resp.__enter__ = lambda s: s
mock_resp.__exit__ = MagicMock(return_value=False)
with patch("platform.system", return_value="Darwin"), \
patch.object(tools_config.shutil, "which",
side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \
patch("platform.machine", return_value="x86_64"), \
patch("urllib.request.urlopen", return_value=mock_resp), \
patch.object(tools_config, "_print_warning"), \
patch.object(tools_config, "_print_info"), \
patch.object(tools_config, "_run_cua_driver_installer") as runner:
assert tools_config.install_cua_driver(upgrade=False) is False
runner.assert_not_called()
patch("urllib.request.urlopen") as urlopen, \
patch.object(tools_config, "_run_cua_driver_installer",
return_value=True) as runner:
assert tools_config.install_cua_driver(upgrade=False) is True
runner.assert_called_once()
urlopen.assert_not_called()
def test_upgrade_x86_64_no_asset_returns_existing_status(self):
"""On upgrade with no Intel asset, return whether binary existed."""
def test_upgrade_with_binary_does_not_call_github_api_directly(self):
"""The upgrade path no longer hits GitHub from Python — it delegates
to the upstream ``install.sh`` (which has the baked release tag and
the proper API fallback). When cua-driver is already installed,
``cua_driver_update_check()`` (added in a separate change) further
short-circuits the network re-install via the binary's native
``check-update --json`` verb.
"""
from hermes_cli import tools_config
release = {
"tag_name": "cua-driver-v0.1.6",
"assets": [{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"}],
}
mock_resp = MagicMock()
mock_resp.read.return_value = json.dumps(release).encode()
mock_resp.__enter__ = lambda s: s
mock_resp.__exit__ = MagicMock(return_value=False)
# With binary installed — returns True (binary exists)
with patch("platform.system", return_value="Darwin"), \
patch.object(tools_config.shutil, "which",
side_effect=lambda n: "/usr/local/bin/" + n
if n in ("cua-driver", "curl") else None), \
patch("platform.machine", return_value="x86_64"), \
patch("urllib.request.urlopen", return_value=mock_resp), \
patch.object(tools_config, "_print_warning"), \
patch.object(tools_config, "_print_info"), \
patch.object(tools_config, "_run_cua_driver_installer") as runner:
patch("urllib.request.urlopen") as urlopen, \
patch("subprocess.run"), \
patch.object(tools_config, "_run_cua_driver_installer",
return_value=True) as runner:
assert tools_config.install_cua_driver(upgrade=True) is True
runner.assert_not_called()
# Without binary — returns False
with patch("platform.system", return_value="Darwin"), \
patch.object(tools_config.shutil, "which",
side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \
patch("platform.machine", return_value="x86_64"), \
patch("urllib.request.urlopen", return_value=mock_resp), \
patch.object(tools_config, "_print_warning"), \
patch.object(tools_config, "_print_info"), \
patch.object(tools_config, "_run_cua_driver_installer") as runner:
assert tools_config.install_cua_driver(upgrade=True) is False
runner.assert_not_called()
runner.assert_called_once()
# Probe deleted — no direct GitHub API call from Python.
urlopen.assert_not_called()

View file

@ -639,6 +639,46 @@ def test_aggregator_dedup_does_not_empty_user_defined_custom_provider():
assert or_row["total_models"] == 1
def test_flat_namespace_reseller_keeps_first_party_models_overlapping_user_proxy():
    """Flat-namespace resellers keep their full catalog despite user overlap.

    opencode-go / opencode-zen are flagged ``is_aggregator=True`` (their
    flat ``/v1/models`` returns bare IDs the model-switch resolver searches),
    but they are NOT routing aggregators — every model they list is a
    first-party model under the user's subscription. When a user also runs a
    custom proxy that happens to serve a same-named model, the picker dedup
    must NOT strip the reseller's own catalog. Regression for #47077, where
    opencode-go showed only 13 of 19 models because minimax-m3/m2.7/m2.5,
    glm-5/5.1, and deepseek-v4-flash were deduped against an overlapping
    custom provider.
    """
    reseller_catalog = [
        "kimi-k2.6", "minimax-m3", "minimax-m2.7", "glm-5",
        "deepseek-v4-flash", "qwen3.7-max",
    ]
    user_proxy_models = [
        "minimax-m3", "minimax-m2.7", "glm-5", "deepseek-v4-flash",
    ]
    rows = [
        _user_provider_row("custom:my-proxy", user_proxy_models),
        # Pass a copy so the expected list below cannot be altered in place.
        _aggregator_row("opencode-go", list(reseller_catalog)),
        _aggregator_row("openrouter", ["minimax-m3", "anthropic/claude-sonnet-4.6"]),
    ]

    ctx = _empty_ctx()
    with _list_auth_returning(rows):
        payload = build_models_payload(ctx)

    def provider_row(slug):
        return next(r for r in payload["providers"] if r["slug"] == slug)

    go_row = provider_row("opencode-go")
    or_row = provider_row("openrouter")

    # The reseller keeps ALL of its first-party models — nothing stripped.
    assert go_row["models"] == reseller_catalog
    assert go_row["total_models"] == 6

    # A TRUE routing aggregator is still deduped against the user's models.
    assert "minimax-m3" not in or_row["models"]
    assert "anthropic/claude-sonnet-4.6" in or_row["models"]
def test_two_custom_providers_with_overlap_both_survive():
"""Two user-defined custom endpoints that happen to expose an
overlapping model must each keep their full catalog. Neither is the

View file

@ -179,9 +179,10 @@ def _patch_judge(monkeypatch, verdicts):
"""Make judge_goal return a scripted sequence of verdicts."""
seq = list(verdicts)
def _fake_judge(goal, response, subgoals=None):
def _fake_judge(goal, response, subgoals=None, background_processes=None, **_kw):
v = seq.pop(0) if seq else "done"
return v, f"scripted:{v}", False
# 4-tuple contract: (verdict, reason, parse_failed, wait_directive)
return v, f"scripted:{v}", False, None
monkeypatch.setattr(goals, "judge_goal", _fake_judge)

View file

@ -129,6 +129,23 @@ def test_is_aggregator_leaves_unknown_provider_non_aggregator():
assert providers_mod.is_aggregator("not-a-provider") is False
def test_is_routing_aggregator_excludes_flat_namespace_resellers():
    """opencode-go / opencode-zen stay ``is_aggregator=True`` (model-switch
    relies on it to search their flat bare-name catalog), but they are NOT
    routing aggregators — their models are first-party, so the picker dedup
    must not strip them. (#47077)"""
    flat_resellers = ("opencode-go", "opencode-zen")

    # Still aggregators for model-switch flat-catalog resolution.
    for slug in flat_resellers:
        assert providers_mod.is_aggregator(slug) is True

    # But NOT routing aggregators for picker-dedup purposes.
    for slug in flat_resellers:
        assert providers_mod.is_routing_aggregator(slug) is False

    # True routers and custom proxies remain routing aggregators.
    for slug in ("openrouter", "custom:litellm"):
        assert providers_mod.is_routing_aggregator(slug) is True

    # An unrecognised slug is not treated as a routing aggregator.
    assert providers_mod.is_routing_aggregator("not-a-provider") is False
def test_switch_model_accepts_explicit_named_custom_provider(monkeypatch):
"""Shared /model switch pipeline should accept --provider for custom_providers."""
monkeypatch.setattr(

View file

@ -24,7 +24,7 @@ These tests pin each layer of the new defence:
* ``_safe_plugin_api_relpath`` rejects absolute paths, ``..``
traversal, and non-string / empty values.
* ``_mount_plugin_api_routes`` re-validates at import time and
refuses project-source plugins outright.
refuses user/project-source plugin backend code outright.
* End-to-end — the original PoC manifest no longer triggers
``importlib`` for ``/tmp/payload.py``.
"""
@ -216,7 +216,7 @@ class TestDiscoveryScrubsApiField:
assert entry["_api_file"] is None
assert entry["has_api"] is False
def test_safe_api_path_survives(self, user_plugin_factory, tmp_path):
def test_user_safe_api_path_is_scrubbed(self, user_plugin_factory, tmp_path):
user_plugin_factory("safe", {
"name": "safe",
"label": "Safe",
@ -230,6 +230,86 @@ class TestDiscoveryScrubsApiField:
)
plugins = web_server._get_dashboard_plugins(force_rescan=True)
entry = next(p for p in plugins if p["name"] == "safe")
assert entry["_api_file"] is None
assert entry["has_api"] is False
def test_project_safe_api_path_is_scrubbed(self, tmp_path, monkeypatch):
    """A project-local plugin's ``api`` entry is scrubbed from discovery."""
    home_dir = tmp_path / "home"
    monkeypatch.setenv("HERMES_HOME", str(home_dir))
    home_dir.mkdir()
    monkeypatch.setenv("HERMES_ENABLE_PROJECT_PLUGINS", "1")

    # Discovery is exercised from inside the project directory.
    project_dir = tmp_path / "project"
    project_dir.mkdir()
    monkeypatch.chdir(project_dir)

    manifest = {
        "name": "safe-project",
        "label": "Safe Project",
        "api": "api.py",
        "entry": "dist/index.js",
    }
    dashboard_dir = _write_plugin_manifest(
        project_dir / ".hermes" / "plugins",
        "safe-project",
        manifest,
    )
    (dashboard_dir / "api.py").write_text("router = None\n")

    discovered = web_server._get_dashboard_plugins(force_rescan=True)
    entry = next(p for p in discovered if p["name"] == "safe-project")

    assert entry["_api_file"] is None
    assert entry["has_api"] is False
def test_bundled_safe_api_path_survives(self, tmp_path, monkeypatch):
    """A bundled plugin keeps its ``api`` entry through discovery."""
    hermes_home = tmp_path / "home"
    bundled_root = tmp_path / "bundled"
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    hermes_home.mkdir()
    monkeypatch.setenv("HERMES_BUNDLED_PLUGINS", str(bundled_root))

    manifest = {
        "name": "safe-bundled",
        "label": "Safe Bundled",
        "api": "api.py",
        "entry": "dist/index.js",
    }
    dashboard_dir = _write_plugin_manifest(bundled_root, "safe-bundled", manifest)
    (dashboard_dir / "api.py").write_text("router = None\n")

    discovered = web_server._get_dashboard_plugins(force_rescan=True)
    entry = next(p for p in discovered if p["name"] == "safe-bundled")

    assert entry["_api_file"] == "api.py"
    assert entry["has_api"] is True
def test_user_plugin_does_not_shadow_bundled_backend(self, tmp_path, monkeypatch):
    """With a same-named user plugin present, the bundled entry is reported."""
    hermes_home = tmp_path / "home"
    bundled_root = tmp_path / "bundled"
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    hermes_home.mkdir()
    monkeypatch.setenv("HERMES_BUNDLED_PLUGINS", str(bundled_root))

    def manifest(label):
        # Both copies share everything except the human-readable label.
        return {
            "name": "shadowed",
            "label": label,
            "api": "api.py",
            "entry": "dist/index.js",
        }

    bundled_dashboard = _write_plugin_manifest(
        bundled_root,
        "shadowed",
        manifest("Bundled Shadowed"),
    )
    (bundled_dashboard / "api.py").write_text("router = None\n")

    # The user-level copy only gets a manifest; no api.py is written for it.
    _write_plugin_manifest(
        hermes_home / "plugins",
        "shadowed",
        manifest("User Shadowed"),
    )

    discovered = web_server._get_dashboard_plugins(force_rescan=True)
    entry = next(p for p in discovered if p["name"] == "shadowed")

    assert entry["source"] == "bundled"
    assert entry["_api_file"] == "api.py"
    assert entry["has_api"] is True
@ -276,6 +356,16 @@ class TestMountApiRoutesRefusesUntrusted:
"GHSA-5qr3-c538-wm9j defence-in-depth regression"
)
def test_user_source_api_is_not_imported(self, tmp_path):
    """Mounting API routes must not import a user-source plugin's api file."""
    user_plugin = self._payload_plugin(tmp_path, source="user")
    web_server._dashboard_plugins_cache = [user_plugin]

    # Patch the import entry point so any attempt to load the file is counted.
    with patch("importlib.util.spec_from_file_location") as spec_factory:
        web_server._mount_plugin_api_routes()

    failure_message = (
        "user-installed plugin api file was imported — "
        "third-party dashboard plugin backend code must stay inert"
    )
    assert spec_factory.call_count == 0, failure_message
def test_bundled_source_api_imports_normally(self, tmp_path):
plugin = self._payload_plugin(tmp_path, source="bundled")
web_server._dashboard_plugins_cache = [plugin]

View file

@ -1,6 +1,30 @@
"""Tests for Slack CLI helpers."""
import argparse
from hermes_cli.slack_cli import _build_full_manifest
from hermes_cli.subcommands.slack import build_slack_parser
def _parse_slack_args(argv):
    """Build the real `hermes slack` parser and parse argv against it."""

    def _noop_handler(_args):
        # Stand-in command handler; these tests only inspect parsed arguments.
        return 0

    root_parser = argparse.ArgumentParser()
    command_subparsers = root_parser.add_subparsers(dest="command")
    build_slack_parser(command_subparsers, cmd_slack=_noop_handler)
    return root_parser.parse_args(argv)
class TestSlackManifestArgparse:
    """The `--no-assistant` flag wires through argparse to `no_assistant`."""

    def test_no_assistant_flag_defaults_false(self):
        """Without the flag, ``no_assistant`` reads as False."""
        parsed = _parse_slack_args(["slack", "manifest"])
        flag_value = getattr(parsed, "no_assistant", False)
        assert flag_value is False

    def test_no_assistant_flag_sets_true(self):
        """Passing ``--no-assistant`` sets ``no_assistant`` to True."""
        argv = ["slack", "manifest", "--no-assistant"]
        parsed = _parse_slack_args(argv)
        assert parsed.no_assistant is True
class TestSlackFullManifest:
@ -28,3 +52,35 @@ class TestSlackFullManifest:
assert "assistant:write" in manifest["oauth_config"]["scopes"]["bot"]
bot_events = manifest["settings"]["event_subscriptions"]["bot_events"]
assert "assistant_thread_started" in bot_events
def test_no_assistant_omits_assistant_pieces(self):
    """With ``include_assistant=False`` no assistant-only item is emitted."""
    manifest = _build_full_manifest(
        "Hermes",
        "Your Hermes agent on Slack",
        include_assistant=False,
    )
    bot_scopes = manifest["oauth_config"]["scopes"]["bot"]
    bot_events = manifest["settings"]["event_subscriptions"]["bot_events"]

    # assistant_view feature is gone -> Slack renders a flat DM, not the
    # Assistant thread pane (where bare slash commands don't dispatch).
    assert "assistant_view" not in manifest["features"]
    assert "assistant:write" not in bot_scopes
    for assistant_event in (
        "assistant_thread_started",
        "assistant_thread_context_changed",
    ):
        assert assistant_event not in bot_events
def test_no_assistant_preserves_core_surface(self):
    """Dropping assistant mode must NOT strip the regular messaging surface."""
    manifest = _build_full_manifest(
        "Hermes",
        "Your Hermes agent on Slack",
        include_assistant=False,
    )
    features = manifest["features"]
    settings = manifest["settings"]

    # Flat DM still needs the Messages tab writable.
    assert features["app_home"]["messages_tab_enabled"] is True
    # Slash commands and Socket Mode are independent of assistant mode.
    assert features["slash_commands"]
    assert settings["socket_mode_enabled"] is True

    # Channel + DM scopes/events survive so the bot still works everywhere.
    bot_scopes = manifest["oauth_config"]["scopes"]["bot"]
    required_scopes = ("commands", "channels:history", "groups:read", "im:history")
    for scope in required_scopes:
        assert scope in bot_scopes

    bot_events = settings["event_subscriptions"]["bot_events"]
    required_events = ("message.im", "message.channels", "message.groups", "app_mention")
    for event in required_events:
        assert event in bot_events

View file

@ -93,7 +93,8 @@ def test_check_for_updates_expired_cache(tmp_path, monkeypatch):
result = check_for_updates()
assert result == 5
assert mock_run.call_count == 3 # origin probe + git fetch + git rev-list
# origin probe + is-shallow probe + git fetch + git rev-list
assert mock_run.call_count == 4
def test_check_for_updates_official_ssh_origin_uses_https_probe(tmp_path):
@ -128,6 +129,99 @@ def test_check_for_updates_official_ssh_origin_uses_https_probe(tmp_path):
assert ["git", "fetch", "origin", "--quiet"] not in calls
def test_check_via_local_git_shallow_clone_behind_reports_no_count(tmp_path):
    """Shallow installer clones must report presence-only, never a bogus count.

    On a ``git clone --depth 1`` checkout the history stops at one commit, so
    counting ``HEAD..origin/main`` across the shallow boundary yields a huge
    nonsense number (the "12492 commits behind" banner). The shallow path must
    compare tip SHAs and return UPDATE_AVAILABLE_NO_COUNT instead, and must
    never run ``git rev-list --count``.
    """
    import hermes_cli.banner as banner

    repo_dir = tmp_path / "hermes-agent"
    repo_dir.mkdir()
    (repo_dir / ".git").mkdir()

    # Commands answered by exact match, paired with the stdout they return.
    exact_replies = [
        (
            ["git", "remote", "get-url", "origin"],
            "https://github.com/NousResearch/hermes-agent.git\n",
        ),
        (["git", "rev-parse", "--is-shallow-repository"], "true\n"),
        (["git", "rev-parse", "HEAD"], "local-sha\n"),
        (["git", "rev-parse", "FETCH_HEAD"], "upstream-sha\n"),
    ]
    calls = []

    def fake_run(cmd, **kwargs):
        calls.append(cmd)
        for expected_cmd, stdout in exact_replies:
            if cmd == expected_cmd:
                return MagicMock(returncode=0, stdout=stdout)
        if cmd[:2] == ["git", "fetch"]:
            return MagicMock(returncode=0, stdout="")
        if cmd[:3] == ["git", "rev-list", "--count"]:
            raise AssertionError("shallow path must not count across the boundary")
        raise AssertionError(f"unexpected git command: {cmd!r}")

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        outcome = banner._check_via_local_git(repo_dir)

    assert outcome == banner.UPDATE_AVAILABLE_NO_COUNT
    # The shallow fetch must preserve the boundary (--depth 1), not unshallow.
    shallow_fetch = ["git", "fetch", "origin", "--depth", "1", "--quiet"]
    assert shallow_fetch in calls
def test_check_via_local_git_shallow_clone_up_to_date(tmp_path):
    """Shallow clone whose tip matches upstream reports up-to-date (0)."""
    import hermes_cli.banner as banner

    repo_dir = tmp_path / "hermes-agent"
    repo_dir.mkdir()
    (repo_dir / ".git").mkdir()

    # HEAD and FETCH_HEAD resolve to the same SHA, so no update is pending.
    exact_replies = [
        (
            ["git", "remote", "get-url", "origin"],
            "https://github.com/NousResearch/hermes-agent.git\n",
        ),
        (["git", "rev-parse", "--is-shallow-repository"], "true\n"),
        (["git", "rev-parse", "HEAD"], "same-sha\n"),
        (["git", "rev-parse", "FETCH_HEAD"], "same-sha\n"),
    ]

    def fake_run(cmd, **kwargs):
        for expected_cmd, stdout in exact_replies:
            if cmd == expected_cmd:
                return MagicMock(returncode=0, stdout=stdout)
        if cmd[:2] == ["git", "fetch"]:
            return MagicMock(returncode=0, stdout="")
        raise AssertionError(f"unexpected git command: {cmd!r}")

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        outcome = banner._check_via_local_git(repo_dir)

    assert outcome == 0
def test_check_via_local_git_full_clone_keeps_exact_count(tmp_path):
    """Full (non-shallow) clones keep the exact rev-list count path."""
    import hermes_cli.banner as banner

    repo_dir = tmp_path / "hermes-agent"
    repo_dir.mkdir()
    (repo_dir / ".git").mkdir()

    exact_replies = [
        (
            ["git", "remote", "get-url", "origin"],
            "https://github.com/NousResearch/hermes-agent.git\n",
        ),
        (["git", "rev-parse", "--is-shallow-repository"], "false\n"),
    ]

    def fake_run(cmd, **kwargs):
        for expected_cmd, stdout in exact_replies:
            if cmd == expected_cmd:
                return MagicMock(returncode=0, stdout=stdout)
        if cmd[:2] == ["git", "fetch"]:
            return MagicMock(returncode=0, stdout="")
        if cmd[:3] == ["git", "rev-list", "--count"]:
            # The scripted count is what the check is expected to return.
            return MagicMock(returncode=0, stdout="7\n")
        raise AssertionError(f"unexpected git command: {cmd!r}")

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        outcome = banner._check_via_local_git(repo_dir)

    assert outcome == 7
def test_check_for_updates_no_git_dir(tmp_path, monkeypatch):
"""Falls back to PyPI check when .git directory doesn't exist anywhere."""
import hermes_cli.banner as banner

View file

@ -597,6 +597,120 @@ def test_resume_windows_gateways_after_update_respawns_unmapped_by_cmdline(
assert "Restarting 1 unmapped Windows gateway process(es)" in out
@patch.object(cli_main, "_is_windows", return_value=True)
def test_pause_returns_cold_start_token_when_installed_but_none_running(
    _winp,
    monkeypatch,
):
    """No gateway running + autostart entry installed → cold-start token.

    A gateway that died between updates (spawning terminal/TUI closed) leaves
    nothing for the resume path to relaunch, but the installed autostart entry
    is an explicit "I want a gateway" signal. The pause step must return a
    token that tells resume to cold-start one.
    """
    import hermes_cli.gateway as gateway_mod
    from hermes_cli import gateway_windows

    def no_gateway_pids(**_k):
        return []

    def autostart_installed():
        return True

    monkeypatch.setattr(gateway_mod, "find_gateway_pids", no_gateway_pids)
    monkeypatch.setattr(gateway_windows, "is_installed", autostart_installed)

    expected_token = {
        "resume_needed": True,
        "profiles": {},
        "unmapped_pids": [],
        "unmapped": [],
        "cold_start_if_installed": True,
    }
    assert cli_main._pause_windows_gateways_for_update() == expected_token
@patch.object(cli_main, "_is_windows", return_value=True)
def test_pause_returns_none_when_nothing_running_and_not_installed(
    _winp,
    monkeypatch,
):
    """With no gateway running and no autostart entry, pause returns ``None``.

    A user who runs without a gateway must not have one started for them by
    an update, so no resume token is produced.
    """
    import hermes_cli.gateway as gateway_mod
    from hermes_cli import gateway_windows

    monkeypatch.setattr(gateway_windows, "is_installed", lambda: False)
    monkeypatch.setattr(gateway_mod, "find_gateway_pids", lambda **_k: [])

    token = cli_main._pause_windows_gateways_for_update()
    assert token is None
@patch.object(cli_main, "_is_windows", return_value=True)
def test_resume_cold_starts_gateway_when_token_requests_it(
    _winp,
    monkeypatch,
    capsys,
):
    """A cold-start token with no gateway running triggers one detached spawn.

    Resume must call the spawn helper exactly once, clear ``resume_needed``
    on the token, and print the start message with the spawned PID.
    """
    import hermes_cli.gateway as gateway_mod
    from hermes_cli import gateway_windows

    spawned = []

    def fake_spawn_detached():
        # Record the call and report a fixed PID.
        spawned.append(True)
        return 4242

    monkeypatch.setattr(gateway_mod, "find_gateway_pids", lambda **_k: [])
    monkeypatch.setattr(gateway_windows, "_spawn_detached", fake_spawn_detached)

    token = {
        "resume_needed": True,
        "profiles": {},
        "unmapped_pids": [],
        "unmapped": [],
        "cold_start_if_installed": True,
    }
    cli_main._resume_windows_gateways_after_update(token)

    assert token["resume_needed"] is False
    assert spawned == [True]
    printed = capsys.readouterr().out
    assert "Starting Windows gateway after update (PID 4242)" in printed
@patch.object(cli_main, "_is_windows", return_value=True)
def test_resume_cold_start_skips_when_gateway_already_running(
    _winp,
    monkeypatch,
    capsys,
):
    """Cold-start is a no-op when a gateway is already running at resume time.

    The PID lookup reports a live gateway (as if it came up between pause
    and resume), so the spawn helper must not be called and no start
    message may be printed.
    """
    import hermes_cli.gateway as gateway_mod
    from hermes_cli import gateway_windows

    spawned = []

    def fake_spawn_detached():
        # Must never run in this scenario; recorded so the test can prove it.
        spawned.append(True)
        return 4242

    monkeypatch.setattr(gateway_mod, "find_gateway_pids", lambda **_k: [9001])
    monkeypatch.setattr(gateway_windows, "_spawn_detached", fake_spawn_detached)

    token = {
        "resume_needed": True,
        "profiles": {},
        "unmapped_pids": [],
        "unmapped": [],
        "cold_start_if_installed": True,
    }
    cli_main._resume_windows_gateways_after_update(token)

    assert spawned == []
    printed = capsys.readouterr().out
    assert "Starting Windows gateway after update" not in printed
# ---------------------------------------------------------------------------
# cmd_update integration — concurrent-instance gate
# ---------------------------------------------------------------------------

View file

@ -263,6 +263,29 @@ class TestWebServerEndpoints:
import hermes_cli.web_server as web_server
monkeypatch.setattr(hermes_constants, "is_container", lambda: True)
# A docker install inside a container should be managed externally.
monkeypatch.setattr(web_server, "detect_install_method", lambda _root: "docker")
assert web_server._dashboard_local_update_managed_externally() is True
def test_dashboard_update_capability_allows_git_in_container(self, monkeypatch):
    """A git checkout inside a container still offers dashboard updates.

    Even when running in a container (e.g. a bind-mounted checkout), a
    ``git`` install method is self-managed, so the update must not be
    reported as managed externally.
    """
    import hermes_constants
    import hermes_cli.web_server as web_server

    monkeypatch.setattr(web_server, "detect_install_method", lambda _root: "git")
    monkeypatch.setattr(hermes_constants, "is_container", lambda: True)

    managed_externally = web_server._dashboard_local_update_managed_externally()
    assert managed_externally is False
def test_dashboard_update_capability_blocks_pip_in_container(self, monkeypatch):
    """A ``pip`` install inside a container is reported as managed externally."""
    import hermes_constants
    import hermes_cli.web_server as web_server

    monkeypatch.setattr(web_server, "detect_install_method", lambda _root: "pip")
    monkeypatch.setattr(hermes_constants, "is_container", lambda: True)

    managed_externally = web_server._dashboard_local_update_managed_externally()
    assert managed_externally is True
@ -1011,6 +1034,8 @@ class TestWebServerEndpoints:
spawned = True
raise AssertionError("docker update guard should not spawn hermes update")
# Bypass the managed-externally gate so we reach the docker install check.
monkeypatch.setattr(web_server, "_dashboard_local_update_managed_externally", lambda: False)
monkeypatch.setattr(web_server, "detect_install_method", lambda _root: "docker")
monkeypatch.setattr(web_server, "_spawn_hermes_action", fail_spawn)
web_server._ACTION_PROCS.pop("hermes-update", None)
@ -5070,14 +5095,8 @@ class TestPluginAPIAuth:
"""Tests that plugin API routes require the session token (issue #19533)."""
@pytest.fixture(autouse=True)
def _setup_test_client(self, monkeypatch, _isolate_hermes_home, _install_example_plugin):
"""Create a TestClient without the session token header.
Pulls in ``_install_example_plugin`` so ``test_plugin_route_allows_auth``
has the ``/api/plugins/example/hello`` endpoint available the
example plugin is no longer a bundled plugin, so the fixture
installs it into the per-test ``HERMES_HOME``.
"""
def _setup_test_client(self, monkeypatch, _isolate_hermes_home):
"""Create TestClients with and without the session token header."""
try:
from starlette.testclient import TestClient
except ImportError:
@ -5102,19 +5121,15 @@ class TestPluginAPIAuth:
def test_plugin_route_allows_auth(self):
    """Plugin API routes should work with a valid session token.

    Uses a bundled plugin route so the test covers authenticated plugin
    API access without relying on user-installed plugin backend imports.
    Without a token the middleware must answer 401 before the handler is
    reached; with a valid token the handler runs and answers 200.
    """
    # Without auth: middleware blocks before reaching the handler.
    resp = self.client.get("/api/plugins/kanban/board")
    assert resp.status_code == 401

    # With auth: handler runs.
    resp = self.auth_client.get("/api/plugins/kanban/board")
    assert resp.status_code == 200
def test_plugin_post_requires_auth(self):

View file

@ -155,15 +155,31 @@ class TestResolveSessionNameTitle:
result = cfg.resolve_session_name("/some/dir", session_id=None)
assert result == "dir"
def test_title_beats_session_id(self):
def test_per_session_id_beats_title(self):
    """Under ``per-session`` the run's session_id wins over a session title.

    A title (possibly auto-generated) must not remap a live conversation
    onto a second Honcho session.
    """
    session_id = "20260309_175514_9797dd"
    cfg = HonchoClientConfig(session_strategy="per-session")

    resolved = cfg.resolve_session_name(
        "/some/dir",
        session_title="my-title",
        session_id=session_id,
    )

    assert resolved == "20260309_175514_9797dd"
def test_per_session_id_beats_manual_map(self):
    """Under ``per-session`` the session_id also wins over a cwd map entry.

    Covers a stale directory mapping (e.g. launching from a mapped home
    directory) that would otherwise pin the session name.
    """
    pinned_sessions = {"/some/dir": "pinned"}
    cfg = HonchoClientConfig(session_strategy="per-session", sessions=pinned_sessions)

    resolved = cfg.resolve_session_name("/some/dir", session_id="20260309_175514_9797dd")

    assert resolved == "20260309_175514_9797dd"
def test_title_still_applies_for_non_per_session(self):
    """Outside ``per-session`` (here ``per-directory``) the title names the session."""
    cfg = HonchoClientConfig(session_strategy="per-directory")

    resolved = cfg.resolve_session_name(
        "/some/dir",
        session_title="my-title",
        session_id="20260309_175514_9797dd",
    )

    assert resolved == "my-title"
def test_manual_beats_session_id(self):
cfg = HonchoClientConfig(session_strategy="per-session", sessions={"/some/dir": "pinned"})
result = cfg.resolve_session_name("/some/dir", session_id="20260309_175514_9797dd")
assert result == "pinned"
def test_gateway_key_beats_per_session_id(self):
    """A gateway session key wins over session_id, keeping per-chat isolation.

    The expected name is the gateway key with its ``:`` separators
    replaced by ``-``.
    """
    cfg = HonchoClientConfig(session_strategy="per-session")

    resolved = cfg.resolve_session_name(
        "/some/dir",
        gateway_session_key="agent:main:telegram:dm:42",
        session_id="20260309_175514_9797dd",
    )

    assert resolved == "agent-main-telegram-dm-42"
def test_global_strategy_returns_workspace(self):
cfg = HonchoClientConfig(session_strategy="global", workspace_id="my-workspace")

View file

@ -234,6 +234,66 @@ class TestCmdStatus:
assert "FAILED (Invalid API key)" in out
assert "Connection... OK" not in out
def test_auth_line_detects_oauth_grant(self, monkeypatch, capsys, tmp_path):
    """``cmd_status`` prints an OAuth auth line when the host has an oauth grant.

    The fake config carries an ``oauth`` sub-block next to its ``apiKey``;
    the status output must then show the OAuth line and no ``API key:`` line.
    """
    import sys

    import plugins.memory.honcho.cli as honcho_cli

    cfg_path = tmp_path / "honcho.json"
    cfg_path.write_text("{}")

    class FakeConfig:
        # Stand-in for the resolved Honcho client config read by cmd_status.
        enabled = True
        api_key = "hch-at-deadbeef"
        workspace_id = "claude-code"
        host = "hermes"
        base_url = None
        ai_peer = "hermes"
        peer_name = "eri"
        recall_mode = "hybrid"
        user_observe_me = True
        user_observe_others = False
        ai_observe_me = False
        ai_observe_others = True
        write_frequency = "async"
        session_strategy = "per-session"
        context_tokens = None
        dialectic_reasoning_level = "low"
        reasoning_level_cap = "high"
        reasoning_heuristic = True
        # Raw file contents: the host block pairs the access token with its grant.
        raw = {
            "hosts": {
                "hermes": {
                    "apiKey": "hch-at-deadbeef",
                    "oauth": {
                        "refreshToken": "hch-rt-x",
                        "clientId": "hermes-agent",
                        "tokenEndpoint": "https://api.honcho.dev/oauth/token",
                        "expiresAt": 9999999999,
                    },
                }
            }
        }

        def resolve_session_name(self):
            return "hermes"

    # Route every config/path lookup in the CLI module at the temp file.
    monkeypatch.setattr(honcho_cli, "_read_config", lambda: {})
    monkeypatch.setattr(honcho_cli, "_config_path", lambda: cfg_path)
    monkeypatch.setattr(honcho_cli, "_local_config_path", lambda: cfg_path)
    monkeypatch.setattr(honcho_cli, "_active_profile_name", lambda: "default")
    monkeypatch.setattr(honcho_cli, "_show_peer_cards", lambda hcfg, client: None)

    # Replace the client layer so no real Honcho connection is attempted.
    monkeypatch.setattr(
        "plugins.memory.honcho.client.HonchoClientConfig.from_global_config",
        lambda host=None: FakeConfig(),
    )
    monkeypatch.setattr(
        "plugins.memory.honcho.client.get_honcho_client",
        lambda cfg: object(),
    )
    monkeypatch.setitem(sys.modules, "honcho", SimpleNamespace())

    honcho_cli.cmd_status(SimpleNamespace(all=False))

    printed = capsys.readouterr().out
    assert "Auth: OAuth (hermes-agent" in printed
    assert "API key:" not in printed
class TestCloneHonchoForProfile:
"""Identity-key carryover during profile cloning.
@ -389,6 +449,9 @@ class TestSetupWizardDeploymentShape:
# Scripted _prompt: pop answers in order. Default-return for unconsumed prompts.
answer_iter = iter(answers)
def _scripted_prompt(label, default=None, secret=False):
# Auth-method prompt is orthogonal to shape; auto-answer apikey so the answer lists stay shape-only.
if "OAuth" in label:
return "apikey"
try:
return next(answer_iter)
except StopIteration:

View file

@ -711,15 +711,17 @@ class TestResolveSessionNameGatewayKey:
)
assert result == "agent-main-telegram-dm-8439114563"
def test_session_title_still_wins_over_gateway_key(self):
"""Explicit /title remap takes priority over gateway_session_key."""
def test_gateway_key_not_remapped_by_title(self):
    """A title never remaps a stable identifier — the gateway per-chat key
    wins over the title so a generated title can't split a live conversation
    onto a new Honcho session."""
    config = HonchoClientConfig(session_strategy="per-session")
    result = config.resolve_session_name(
        session_title="my-custom-title",
        session_id="20260412_171002_69bb38",
        gateway_session_key="agent:main:telegram:dm:8439114563",
    )
    # Only the gateway-derived name is asserted; the previous title-wins
    # expectation contradicted it and could never hold at the same time.
    assert result == "agent-main-telegram-dm-8439114563"
def test_per_session_fallback_without_gateway_key(self):
"""Without gateway_session_key, per-session returns session_id (CLI path)."""

View file

@ -0,0 +1,254 @@
"""Tests for plugins/memory/honcho/oauth.py — OAuth grant storage + refresh."""
import json
from pathlib import Path
import pytest
from plugins.memory.honcho import oauth
from plugins.memory.honcho.oauth import OAuthCredential
def _host_block(refresh="hch-rt-old", expires_at=10_000):
    """Build a host config block: an access token plus its OAuth grant.

    Args:
        refresh: Value stored as the grant's ``refreshToken``.
        expires_at: Value stored as the grant's ``expiresAt``.

    Returns:
        A fresh dict with ``apiKey`` and an ``oauth`` sub-block.
    """
    grant = {
        "refreshToken": refresh,
        "expiresAt": expires_at,
        "clientId": "hermes-desktop",
        "tokenEndpoint": "http://localhost:8000/oauth/token",
        "scope": "write",
        "tokenType": "Bearer",
    }
    return {"apiKey": "hch-at-old", "oauth": grant}
def _write(path: Path, raw: dict) -> None:
    """Serialize ``raw`` as JSON and write it to ``path`` as UTF-8 text."""
    serialized = json.dumps(raw)
    path.write_text(serialized, encoding="utf-8")
class TestTokenDetection:
    """``oauth.is_oauth_access_token`` accepts only ``hch-at-`` style tokens."""

    def test_access_token_prefix(self):
        """The access-token prefix matches; other prefixes and ``None`` do not."""
        assert oauth.is_oauth_access_token("hch-at-abc")

        for not_an_access_token in ("hch-v3-abc", "hch-rt-abc", None):
            assert not oauth.is_oauth_access_token(not_an_access_token)
class TestCredentialModel:
    """Parsing, serializing and expiry checks on ``OAuthCredential``."""

    def test_roundtrip(self):
        """A host block parses into a credential whose oauth block keeps its fields."""
        cred = OAuthCredential.from_host_block(_host_block())
        assert cred is not None

        serialized = cred.oauth_block()
        assert serialized["refreshToken"] == "hch-rt-old"
        assert serialized["expiresAt"] == 10_000
        assert serialized["clientId"] == "hermes-desktop"

    def test_incomplete_block_returns_none(self):
        """Blocks lacking a usable oauth grant parse to ``None``."""
        # Plain API key: there is no oauth sub-block at all.
        api_key_only = {"apiKey": "hch-v3-x"}
        assert OAuthCredential.from_host_block(api_key_only) is None

        # oauth sub-block present but without its refreshToken.
        missing_refresh = _host_block()
        del missing_refresh["oauth"]["refreshToken"]
        assert OAuthCredential.from_host_block(missing_refresh) is None

    def test_is_expired_respects_skew(self):
        """Expiry is judged against ``expires_at - skew`` rather than ``expires_at``."""
        cred = OAuthCredential.from_host_block(_host_block(expires_at=1000))

        # Threshold is 1000 - 120 = 880: 800 is before it, 900 is past it.
        assert not cred.is_expired(now=800, skew=120)
        assert cred.is_expired(now=900, skew=120)
class TestEnsureFreshToken:
    """``oauth.ensure_fresh_token`` for fresh, expired, failing and locked refreshes."""

    def test_no_oauth_credential_is_noop(self, tmp_path):
        """A host with only a static API key yields ``(None, False)``."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": {"apiKey": "hch-v3-static"}}})

        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=0)

        assert token is None
        assert refreshed is False

    def test_fresh_token_skips_refresh(self, tmp_path, monkeypatch):
        """A token well before expiry is returned without any HTTP refresh."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=10_000)}})

        def refresh_forbidden(*a, **k):
            pytest.fail("refresh must not be called when fresh")

        monkeypatch.setattr(oauth, "_http_post_form", refresh_forbidden)

        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=0)

        assert token == "hch-at-old"
        assert refreshed is False

    def test_fresh_token_served_from_cache_without_disk(self, tmp_path, monkeypatch):
        """Once seeded, a still-fresh token is served without re-reading the config."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=10_000)}})
        oauth._expiry_cache.clear()

        # First call seeds the cache from disk.
        oauth.ensure_fresh_token(config_file, "hermes", now=0)

        # From here on any config read fails the test.
        def disk_forbidden(*a, **k):
            pytest.fail("disk must not be read while token is fresh")

        monkeypatch.setattr(oauth, "_read_config", disk_forbidden)

        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=100)

        assert token == "hch-at-old"
        assert refreshed is False

    def test_expired_token_refreshes_and_persists_rotation(self, tmp_path, monkeypatch):
        """An expired token is refreshed and the rotated grant is written to disk."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=100)}})

        def fake_post(url, data, timeout):
            # The refresh request must carry the stored grant details.
            assert data["grant_type"] == "refresh_token"
            assert data["refresh_token"] == "hch-rt-old"
            assert data["client_id"] == "hermes-desktop"
            return {
                "access_token": "hch-at-new",
                "refresh_token": "hch-rt-new",
                "expires_in": 3600,
                "scope": "write",
                "token_type": "Bearer",
            }

        monkeypatch.setattr(oauth, "_http_post_form", fake_post)

        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=1000)

        assert token == "hch-at-new"
        assert refreshed is True

        # New access token, rotated refresh token and absolute expiry persisted.
        saved_host = json.loads(config_file.read_text())["hosts"]["hermes"]
        assert saved_host["apiKey"] == "hch-at-new"
        assert saved_host["oauth"]["refreshToken"] == "hch-rt-new"
        assert saved_host["oauth"]["expiresAt"] == 1000 + 3600

    def test_refresh_failure_fails_open(self, tmp_path, monkeypatch):
        """A failing refresh returns the stale token and leaves the file alone."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=100)}})

        def boom(*a, **k):
            raise RuntimeError("network down")

        monkeypatch.setattr(oauth, "_http_post_form", boom)

        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=1000)

        # Stale token returned, no crash, file untouched.
        assert token == "hch-at-old"
        assert refreshed is False
        on_disk = json.loads(config_file.read_text())
        assert on_disk["hosts"]["hermes"]["apiKey"] == "hch-at-old"

    def test_double_check_uses_disk_when_already_rotated(self, tmp_path, monkeypatch):
        """A stale in-memory snapshot defers to a fresher credential on disk.

        Simulates another thread having rotated the token on disk after the
        caller took its snapshot: the re-read must skip the HTTP call.
        """
        config_file = tmp_path / "honcho.json"
        fresh_host = _host_block(refresh="hch-rt-fresh", expires_at=10_000)
        _write(config_file, {"hosts": {"hermes": fresh_host}})

        stale_host = _host_block(refresh="hch-rt-old", expires_at=100)
        stale_host["apiKey"] = "hch-at-stale"
        stale_raw = {"hosts": {"hermes": stale_host}}

        def refresh_forbidden(*a, **k):
            pytest.fail("must not refresh; disk token is fresh")

        monkeypatch.setattr(oauth, "_http_post_form", refresh_forbidden)

        token, _refreshed = oauth.ensure_fresh_token(
            config_file, "hermes", stale_raw, now=1000
        )

        # The on-disk fresh credential's access token, not the stale snapshot's.
        assert token == "hch-at-old"

    def test_refresh_holds_cross_process_lock(self, tmp_path, monkeypatch):
        """``<config>.lock`` is held during the refresh call and released after.

        A second opener must fail to take the lock while the token request
        is in flight, and must succeed once the refresh has returned.
        """
        fcntl = pytest.importorskip("fcntl")
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=100)}})
        seen = {}

        def fake_post(url, data, timeout):
            # Probe the lock from a second file handle while refresh is running.
            with open(f"{config_file}.lock", "a+b") as other:
                try:
                    fcntl.flock(other.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                    fcntl.flock(other.fileno(), fcntl.LOCK_UN)
                except OSError:
                    seen["held"] = True
                else:
                    seen["held"] = False
            return {
                "access_token": "hch-at-new",
                "refresh_token": "hch-rt-new",
                "expires_in": 3600,
                "scope": "write",
                "token_type": "Bearer",
            }

        monkeypatch.setattr(oauth, "_http_post_form", fake_post)

        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=1000)

        assert refreshed is True
        assert seen.get("held") is True

        # Released afterward: a non-blocking acquire now succeeds.
        with open(f"{config_file}.lock", "a+b") as fh:
            fcntl.flock(fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
            fcntl.flock(fh.fileno(), fcntl.LOCK_UN)

    def test_refresh_degrades_when_lock_unavailable(self, tmp_path, monkeypatch):
        """Refresh still succeeds when ``fcntl.flock`` raises ``OSError``."""
        fcntl = pytest.importorskip("fcntl")
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=100)}})

        def no_flock(*a, **k):
            raise OSError("flock unsupported")

        def fake_post(*a, **k):
            return {
                "access_token": "hch-at-new",
                "refresh_token": "hch-rt-new",
                "expires_in": 3600,
                "scope": "write",
                "token_type": "Bearer",
            }

        monkeypatch.setattr(fcntl, "flock", no_flock)
        monkeypatch.setattr(oauth, "_http_post_form", fake_post)

        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=1000)

        assert token == "hch-at-new"
        assert refreshed is True
class TestInstallGrant:
    """``oauth.install_grant`` stores tokens and merges grant config into the file."""

    def test_deep_merges_config_and_preserves_other_hosts(self, tmp_path):
        """Grant config is deep-merged; unrelated keys and hosts are untouched."""
        config_file = tmp_path / "honcho.json"
        existing = {
            "apiKey": "hch-v3-root",  # root static key preserved
            "hosts": {
                "obsidian": {"workspace": "obsidian"},
                "hermes": {"workspace": "hermes", "saveMessages": False},
            },
        }
        _write(config_file, existing)

        grant = {
            "access_token": "hch-at-fresh",
            "refresh_token": "hch-rt-fresh",
            "expires_in": 3600,
            "scope": "write",
            "config": {
                "environment": "production",
                "hosts": {"hermes": {"saveMessages": True, "recallMode": "hybrid"}},
            },
        }

        cred = oauth.install_grant(
            config_file,
            "hermes",
            grant,
            client_id="hermes-desktop",
            token_endpoint="http://localhost:8000/oauth/token",
            now=1000,
        )

        assert cred.expires_at == 1000 + 3600

        saved = json.loads(config_file.read_text())
        # Root-level key and the other host are untouched.
        assert saved["apiKey"] == "hch-v3-root"
        assert saved["hosts"]["obsidian"] == {"workspace": "obsidian"}
        # Root key supplied by the grant config.
        assert saved["environment"] == "production"

        hermes_host = saved["hosts"]["hermes"]
        assert hermes_host["apiKey"] == "hch-at-fresh"
        assert hermes_host["oauth"]["refreshToken"] == "hch-rt-fresh"
        assert hermes_host["saveMessages"] is True  # grant config won the deep-merge
        assert hermes_host["recallMode"] == "hybrid"  # new key added
        assert hermes_host["workspace"] == "hermes"  # pre-existing key preserved

    def test_rejects_grant_without_tokens(self, tmp_path):
        """A grant lacking ``refresh_token`` raises ``ValueError``."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {})
        access_only_grant = {"access_token": "hch-at-x"}  # no refresh_token

        with pytest.raises(ValueError):
            oauth.install_grant(
                config_file,
                "hermes",
                access_only_grant,
                client_id="c",
                token_endpoint="e",
            )
class TestApplyTokenToClient:
    """``oauth.apply_token_to_client`` swaps the bearer on a live client object."""

    def test_mutates_live_bearer(self):
        """A client exposing ``_http.api_key`` has that key replaced in place."""

        class FakeHttp:
            api_key = "hch-at-old"

        class FakeClient:
            _http = FakeHttp()

        client = FakeClient()

        applied = oauth.apply_token_to_client(client, "hch-at-new")

        assert applied is True
        assert client._http.api_key == "hch-at-new"

    def test_returns_false_when_shape_unknown(self):
        """An object without the expected attributes yields ``False``."""
        shapeless = object()
        assert oauth.apply_token_to_client(shapeless, "hch-at-new") is False

View file

@ -0,0 +1,347 @@
"""End-to-end test for the zero-CLI Honcho OAuth flow against a fake AS.
Stands up a real local authorization server (no network, no browser) and drives
the full path: begin → /authorize 302 → loopback :8765 callback → token
exchange → install_grant → forced-expiry refresh with rotation. This is the
deterministic "real smoke test" for the consumer flow.
"""
import json
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from urllib.parse import parse_qs, urlparse
import httpx
import pytest
from plugins.memory.honcho import oauth, oauth_flow
class _FakeAS(BaseHTTPRequestHandler):
    """Minimal fake OAuth authorization server for the loopback-flow tests.

    ``GET /authorize`` answers with a 302 back to the caller's redirect URI;
    ``POST /oauth/token`` mints a numbered access/refresh token pair. Any
    other path gets a bare 404.
    """

    # Counter shared by all requests so each token response is distinct.
    issued = {"n": 0}

    def do_GET(self):  # noqa: N802
        """Validate the authorize query and redirect to the callback."""
        parsed = urlparse(self.path)
        if parsed.path != "/authorize":
            self.send_response(404)
            self.end_headers()
            return

        q = parse_qs(parsed.query)

        # The redirect host must be the IPv4 literal (a `localhost` redirect
        # can resolve to ::1 and miss the IPv4 listener); the port is free to
        # differ from :8765.
        redirect = q["redirect_uri"][0]
        assert redirect.startswith("http://127.0.0.1:") and "/callback" in redirect, redirect

        # The consent screen gets a display path, never an absolute one that
        # would reveal the username or home layout.
        cp = q["config_path"][0]
        assert cp.endswith("honcho.json"), q.get("config_path")
        assert not cp.startswith("/"), cp

        state = q["state"][0]
        self.send_response(302)
        self.send_header("Location", f"{redirect}?code=test-auth-code&state={state}")
        self.end_headers()

    def do_POST(self):  # noqa: N802
        """Mint the next token pair; add consent config on the code exchange."""
        if urlparse(self.path).path != "/oauth/token":
            self.send_response(404)
            self.end_headers()
            return

        length = int(self.headers.get("Content-Length", 0))
        form = parse_qs(self.rfile.read(length).decode())
        grant_type = form["grant_type"][0]

        self.issued["n"] += 1
        n = self.issued["n"]
        body = {
            "access_token": f"hch-at-{n}",
            "refresh_token": f"hch-rt-{n}",
            "token_type": "Bearer",
            "expires_in": 3600,
            "scope": "write",
        }
        # Only the initial authorization-code exchange carries consent config.
        if grant_type == "authorization_code":
            body["config"] = {
                "peerName": "lyra",
                "environment": "production",
                "hosts": {"hermes": {"saveMessages": True, "recallMode": "hybrid"}},
            }

        payload = json.dumps(body).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, *args):
        """Silence the handler's default per-request logging."""
        return
@pytest.fixture
def fake_as(monkeypatch):
    """Run ``_FakeAS`` on an ephemeral loopback port and point the flow at it.

    Resets the token counter, serves from a daemon thread, sets the
    ``HONCHO_OAUTH_*`` environment variables, and yields the server's base
    URL. The server is shut down and closed once the test finishes.
    """
    _FakeAS.issued["n"] = 0

    server = HTTPServer(("127.0.0.1", 0), _FakeAS)
    port = server.server_address[1]
    threading.Thread(target=server.serve_forever, daemon=True).start()

    base = f"http://127.0.0.1:{port}"
    oauth_env = {
        "HONCHO_OAUTH_AUTHORIZE_URL": f"{base}/authorize",
        "HONCHO_OAUTH_TOKEN_URL": f"{base}/oauth/token",
        "HONCHO_OAUTH_CLIENT_ID": "hermes-desktop",
    }
    for name, value in oauth_env.items():
        monkeypatch.setenv(name, value)

    try:
        yield base
    finally:
        server.shutdown()
        server.server_close()
def _browser_driver(authorize_url: str) -> None:
    """Play the user's browser: fetch /authorize, then hit its redirect target.

    The callback GET is retried up to 50 times, 50 ms apart, on connection
    errors so it cannot lose the race against the loopback listener coming
    up. Raises ``RuntimeError`` if the callback never accepts a connection.
    """
    authorize_resp = httpx.get(authorize_url, follow_redirects=False)
    location = authorize_resp.headers["Location"]

    attempts_left = 50
    while attempts_left > 0:
        attempts_left -= 1
        try:
            httpx.get(location, timeout=2)
        except httpx.ConnectError:
            time.sleep(0.05)
        else:
            return
    raise RuntimeError("loopback callback never came up")
def test_full_loopback_flow_then_refresh(tmp_path, fake_as):
    """Run the whole loopback grant, then force expiry and check rotation."""
    config_path = tmp_path / "honcho.json"
    config_path.write_text(json.dumps({"hosts": {"obsidian": {"workspace": "obsidian"}}}))

    cred = oauth_flow.authorize_via_loopback(
        config_path=config_path,
        host="hermes",
        open_url=lambda url: _browser_driver(url),
        timeout=10,
    )

    # Grant installed: token stored, config deep-merged, other host preserved.
    assert cred.access_token == "hch-at-1"
    saved = json.loads(config_path.read_text())
    hermes_host = saved["hosts"]["hermes"]
    assert hermes_host["apiKey"] == "hch-at-1"
    assert hermes_host["oauth"]["refreshToken"] == "hch-rt-1"
    assert hermes_host["recallMode"] == "hybrid"
    assert saved["environment"] == "production"
    assert saved["hosts"]["obsidian"] == {"workspace": "obsidian"}

    # Force expiry; ensure_fresh_token refreshes against the same AS and rotates.
    past_expiry = hermes_host["oauth"]["expiresAt"] + 10
    token, refreshed = oauth.ensure_fresh_token(config_path, "hermes", now=past_expiry)

    assert refreshed is True
    assert token == "hch-at-2"
    rotated = json.loads(config_path.read_text())["hosts"]["hermes"]["oauth"]
    assert rotated["refreshToken"] == "hch-rt-2"
def test_state_mismatch_is_rejected(fake_as, tmp_path):
    """Completing with a state that was never issued raises ``ValueError``."""
    endpoints = oauth_flow.resolve_endpoints()
    # Begin a real authorization; its state is deliberately not used below.
    _, _issued_state = oauth_flow.begin_authorization(endpoints)

    with pytest.raises(ValueError, match="unknown or expired"):
        oauth_flow.complete_authorization(
            endpoints,
            "code",
            "not-the-real-state",
            config_path=tmp_path / "honcho.json",
            host="hermes",
        )
def test_source_tags_the_authorize_link(fake_as):
    """``source`` appears in the authorize URL only when it is supplied."""
    endpoints = oauth_flow.resolve_endpoints()

    tagged_url, _ = oauth_flow.begin_authorization(endpoints, source="hermes-cli")
    assert "source=hermes-cli" in tagged_url

    untagged_url, _ = oauth_flow.begin_authorization(endpoints)
    assert "source=" not in untagged_url
def test_client_id_defaults_to_hermes_agent(monkeypatch):
    """The client id is ``hermes-agent`` unless ``HONCHO_OAUTH_CLIENT_ID`` overrides it."""
    common = {"environment": "production", "base_url": "https://api.honcho.dev"}

    # No override in the environment: the default applies.
    monkeypatch.delenv("HONCHO_OAUTH_CLIENT_ID", raising=False)
    default_endpoints = oauth_flow.resolve_endpoints(**common)
    assert default_endpoints.client_id == "hermes-agent"

    # The env var takes precedence when set.
    monkeypatch.setenv("HONCHO_OAUTH_CLIENT_ID", "custom-id")
    overridden_endpoints = oauth_flow.resolve_endpoints(**common)
    assert overridden_endpoints.client_id == "custom-id"
def test_grant_persists_default_client_id(tmp_path, fake_as, monkeypatch):
    """With no env override, the stored grant records ``clientId=hermes-agent``.

    The ``fake_as`` fixture sets ``HONCHO_OAUTH_CLIENT_ID``; it is removed
    here so the default client id is the one written next to the tokens.
    """
    monkeypatch.delenv("HONCHO_OAUTH_CLIENT_ID", raising=False)

    config_path = tmp_path / "honcho.json"
    config_path.write_text(json.dumps({"hosts": {}}))

    oauth_flow.authorize_via_loopback(
        config_path=config_path,
        host="hermes",
        source="hermes-cli",
        apply_config=False,
        open_url=lambda url: _browser_driver(url),
        timeout=10,
    )

    stored_grant = json.loads(config_path.read_text())["hosts"]["hermes"]["oauth"]
    assert stored_grant["clientId"] == "hermes-agent"
def test_config_path_rides_the_authorize_link(fake_as):
    """``config_path`` is a query parameter only when it is supplied."""
    endpoints = oauth_flow.resolve_endpoints()

    with_path, _ = oauth_flow.begin_authorization(
        endpoints, config_path="~/.hermes/honcho.json"
    )
    query = parse_qs(urlparse(with_path).query)
    assert query["config_path"][0] == "~/.hermes/honcho.json"

    without_path, _ = oauth_flow.begin_authorization(endpoints)
    assert "config_path=" not in without_path
def test_display_config_path_never_leaks_absolute_path():
    """Display paths are home-relative (``~/…``) or reduced to a bare filename."""
    from pathlib import Path

    # A path under the home directory collapses to a ~/ form.
    under_home = Path.home() / ".hermes" / "profiles" / "work" / "honcho.json"
    shown_under_home = oauth_flow._display_config_path(under_home)
    assert shown_under_home == "~/.hermes/profiles/work/honcho.json"

    # A path outside the home directory keeps only its filename.
    shown_outside = oauth_flow._display_config_path("/var/folders/tmp/honcho.json")
    assert shown_outside == "honcho.json"
def test_cli_flow_stores_tokens_without_applying_config(tmp_path, fake_as):
    """With ``apply_config=False`` tokens are stored but grant config is not merged."""
    config_path = tmp_path / "honcho.json"
    config_path.write_text(json.dumps({"hosts": {"hermes": {"saveMessages": False}}}))

    cred = oauth_flow.authorize_via_loopback(
        config_path=config_path,
        host="hermes",
        source="hermes-cli",
        apply_config=False,
        open_url=lambda url: _browser_driver(url),
        timeout=10,
    )

    saved = json.loads(config_path.read_text())
    hermes_host = saved["hosts"]["hermes"]

    # Tokens land in the host block.
    assert hermes_host["apiKey"] == cred.access_token
    assert hermes_host["oauth"]["refreshToken"] == cred.refresh_token

    # The pre-existing setting is untouched and no grant config keys appear.
    assert hermes_host["saveMessages"] is False
    assert "recallMode" not in hermes_host
    assert "environment" not in saved

    # The consent peer name is still surfaced on the credential despite no merge.
    assert cred.consent_peer_name == "lyra"
# ── Desktop "Connect" button path: background launcher, status, dispatch ──
@pytest.fixture
def reset_flow():
    """Reset the module-level flow status and thread before and after a test."""

    def _clear_flow_state():
        oauth_flow._status = oauth_flow.FlowStatus()
        oauth_flow._flow_thread = None

    _clear_flow_state()
    yield
    _clear_flow_state()
def _wait_until(predicate, timeout=2.0):
    """Poll ``predicate`` every 20 ms until it is truthy or ``timeout`` elapses.

    Returns:
        ``True`` as soon as ``predicate()`` is truthy, ``False`` if the
        deadline passes first.
    """
    give_up_at = time.monotonic() + timeout
    while True:
        if time.monotonic() >= give_up_at:
            return False
        if predicate():
            return True
        time.sleep(0.02)
def test_launcher_runs_flow_in_background_and_reports_connected(monkeypatch, reset_flow):
    """The launcher returns ``pending`` at once and ends up ``connected``.

    The stubbed flow records its keyword arguments and blocks on an event,
    so the launcher's return value can be checked while the flow is open.
    """
    seen = {}
    gate = threading.Event()

    def held_flow(**kwargs):
        seen.update(kwargs)  # record what the launcher passed in
        gate.wait(2)  # keep the flow open until the test releases it

    monkeypatch.setattr(oauth_flow, "authorize_via_loopback", held_flow)
    monkeypatch.setattr(oauth_flow, "_detect_connection", lambda: (True, "oauth"))

    st = oauth_flow.start_loopback_flow_background(
        config_path=Path("/t/honcho.json"), host="hermes"
    )

    # Returned before the flow finished.
    assert st["state"] == "pending"
    # The launcher supplied its default source tag and passed the host through.
    assert _wait_until(lambda: seen.get("source") == "hermes-desktop")
    assert seen["host"] == "hermes"

    gate.set()
    assert _wait_until(lambda: oauth_flow.get_flow_status()["state"] == "connected")
def test_launcher_reports_error_on_flow_failure(monkeypatch, reset_flow):
    """A flow that raises moves the status to ``error`` with the message as detail."""

    def boom(**kwargs):
        raise RuntimeError("loopback bind failed")

    monkeypatch.setattr(oauth_flow, "authorize_via_loopback", boom)
    monkeypatch.setattr(oauth_flow, "_detect_connection", lambda: (False, None))

    oauth_flow.start_loopback_flow_background(
        config_path=Path("/t/honcho.json"), host="hermes"
    )

    assert _wait_until(lambda: oauth_flow.get_flow_status()["state"] == "error")
    assert "loopback bind failed" in oauth_flow.get_flow_status()["detail"]
def test_launcher_is_idempotent_while_pending(monkeypatch, reset_flow):
    """A second launch while one is pending does not start a second flow."""
    block = threading.Event()
    calls = []

    def blocking_flow(**kwargs):
        calls.append(1)
        block.wait(2)

    monkeypatch.setattr(oauth_flow, "authorize_via_loopback", blocking_flow)
    monkeypatch.setattr(oauth_flow, "_detect_connection", lambda: (False, None))

    launch_args = {"config_path": Path("/t/h.json"), "host": "hermes"}
    s1 = oauth_flow.start_loopback_flow_background(**launch_args)
    assert _wait_until(lambda: len(calls) == 1)  # first flow is running
    s2 = oauth_flow.start_loopback_flow_background(**launch_args)
    block.set()

    assert s1["state"] == "pending"
    assert s2["state"] == "pending"
    assert _wait_until(lambda: oauth_flow.get_flow_status()["state"] == "connected")
    assert calls == [1]  # the second call did not spawn a second flow
def test_get_flow_status_reports_stored_connection(tmp_path, monkeypatch, reset_flow):
    """``get_flow_status`` reflects what the config file stores for the host.

    Walks three file states: empty host (not connected), static API key
    (connected via ``apikey``), and token plus grant (connected via ``oauth``).
    """
    from plugins.memory.honcho import client as honcho_client

    cfgfile = tmp_path / "honcho.json"
    monkeypatch.setattr(honcho_client, "resolve_config_path", lambda: cfgfile)
    monkeypatch.setattr(honcho_client, "resolve_active_host", lambda: "hermes")
    monkeypatch.delenv("HONCHO_API_KEY", raising=False)

    # 1. Host block with no credentials.
    cfgfile.write_text(json.dumps({"hosts": {"hermes": {}}}))
    assert oauth_flow.get_flow_status()["connected"] is False

    # 2. Static API key.
    cfgfile.write_text(json.dumps({"hosts": {"hermes": {"apiKey": "hch-v3-static"}}}))
    status = oauth_flow.get_flow_status()
    assert status["connected"] is True
    assert status["auth"] == "apikey"

    # 3. OAuth access token with its grant.
    oauth_host = {
        "apiKey": "hch-at-tok",
        "oauth": {
            "refreshToken": "hch-rt-x",
            "expiresAt": 9_999_999_999,
            "clientId": "hermes-desktop",
            "tokenEndpoint": "http://x/oauth/token",
        },
    }
    cfgfile.write_text(json.dumps({"hosts": {"hermes": oauth_host}}))
    status = oauth_flow.get_flow_status()
    assert status["connected"] is True
    assert status["auth"] == "oauth"
def test_memory_oauth_router_dispatches_by_provider_convention():
    """_resolve_flow returns the honcho flow module and 404s on other names."""
    # The generic seam behind the two routes: provider → plugins.memory.<p>.oauth_flow.
    from fastapi import HTTPException
    from hermes_cli.memory_oauth import _resolve_flow

    flow_module = _resolve_flow("honcho")
    for required in ("start_loopback_flow_background", "get_flow_status"):
        assert hasattr(flow_module, required)

    rejected = ("builtin", "no-such-provider", "../etc")
    for bad in rejected:
        with pytest.raises(HTTPException) as exc:
            _resolve_flow(bad)
        assert exc.value.status_code == 404

View file

@ -0,0 +1,209 @@
"""Tests for Mem0Backend abstraction — PlatformBackend and OSSBackend."""
import pytest
from plugins.memory.mem0._backend import Mem0Backend, PlatformBackend, OSSBackend
class FakePlatformClient:
    """Fake MemoryClient for PlatformBackend tests.

    Every method logs a ``(method_name, *arguments)`` tuple to ``calls`` and
    returns a fixed canned response.
    """

    def __init__(self):
        # Chronological log of the calls made against this fake.
        self.calls = []

    def _record(self, *entry):
        """Append one call record (a tuple) to the log."""
        self.calls.append(entry)

    def search(self, query, **kwargs):
        self._record("search", query, kwargs)
        hit = {"id": "m1", "memory": "fact1", "score": 0.9}
        return {"results": [hit]}

    def get_all(self, **kwargs):
        self._record("get_all", kwargs)
        row = {"id": "m1", "memory": "fact1"}
        return {"count": 1, "next": None, "results": [row]}

    def add(self, messages, **kwargs):
        self._record("add", messages, kwargs)
        return {"status": "PENDING", "event_id": "evt-1"}

    def update(self, **kwargs):
        self._record("update", kwargs)
        return {"id": kwargs["memory_id"], "text": kwargs["text"]}

    def delete(self, **kwargs):
        # Records the call only; there is no return value.
        self._record("delete", kwargs)
class TestPlatformBackend:
    """PlatformBackend is expected to forward each call to the client it wraps."""

    def _make(self):
        """Return ``(backend, client)`` with a recording fake as the client.

        ``__new__`` is used so ``PlatformBackend.__init__`` does not run; the
        fake is attached directly as ``_client``.
        """
        client = FakePlatformClient()
        backend = PlatformBackend.__new__(PlatformBackend)
        backend._client = client
        return backend, client

    def test_search_forwards_params(self):
        backend, client = self._make()
        backend.search("test query", filters={"user_id": "u1"}, top_k=5)
        assert client.calls[0][0] == "search"
        assert client.calls[0][1] == "test query"
        assert client.calls[0][2]["filters"] == {"user_id": "u1"}
        assert client.calls[0][2]["top_k"] == 5

    def test_search_forwards_rerank(self):
        backend, client = self._make()
        backend.search("q", filters={}, rerank=False)
        assert client.calls[0][2]["rerank"] is False

    def test_search_rerank_default_true(self):
        backend, client = self._make()
        backend.search("q", filters={})
        assert client.calls[0][2]["rerank"] is True

    def test_search_returns_list(self):
        backend, _ = self._make()
        result = backend.search("q", filters={})
        assert isinstance(result, list)
        assert result[0]["id"] == "m1"

    def test_get_all_forwards_pagination(self):
        backend, client = self._make()
        result = backend.get_all(filters={"user_id": "u1"}, page=2, page_size=50)
        assert client.calls[0][1]["page"] == 2
        assert client.calls[0][1]["page_size"] == 50
        assert "count" in result

    def test_add_forwards_kwargs(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(msgs, user_id="u1", agent_id="hermes", infer=False)
        call = client.calls[0]
        assert call[2]["user_id"] == "u1"
        assert call[2]["infer"] is False
        # metadata kwarg should be omitted entirely when not provided so we
        # don't surprise older mem0 client versions with an unknown kwarg.
        assert "metadata" not in call[2]

    def test_add_forwards_metadata_when_present(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(
            msgs,
            user_id="u1",
            agent_id="hermes",
            infer=False,
            metadata={"channel": "telegram"},
        )
        assert client.calls[0][2]["metadata"] == {"channel": "telegram"}

    def test_add_omits_empty_metadata(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(msgs, user_id="u1", agent_id="hermes", infer=False, metadata={})
        assert "metadata" not in client.calls[0][2]

    def test_update_forwards(self):
        backend, client = self._make()
        backend.update("m1", "new text")
        assert client.calls[0][1] == {"memory_id": "m1", "text": "new text"}

    def test_delete_forwards(self):
        backend, client = self._make()
        backend.delete("m1")
        assert client.calls[0][1] == {"memory_id": "m1"}
class FakeOSSMemory:
    """Fake mem0.Memory for OSSBackend tests.

    Every method logs a ``(method_name, *arguments)`` tuple to ``calls`` and
    returns a fixed canned response.
    """

    def __init__(self):
        # Chronological log of the calls made against this fake.
        self.calls = []

    def _record(self, *entry):
        """Append one call record (a tuple) to the log."""
        self.calls.append(entry)

    def search(self, query, **kwargs):
        self._record("search", query, kwargs)
        hit = {"id": "m1", "memory": "fact1", "score": 0.8}
        return {"results": [hit]}

    def get_all(self, **kwargs):
        self._record("get_all", kwargs)
        return {"results": [{"id": "m1", "memory": "fact1"}]}

    def add(self, messages, **kwargs):
        self._record("add", messages, kwargs)
        stored = {"id": "m1", "memory": "fact1", "event": "ADD"}
        return {"results": [stored]}

    def update(self, memory_id, **kwargs):
        self._record("update", memory_id, kwargs)
        return {"message": "Memory updated successfully!"}

    def delete(self, memory_id):
        self._record("delete", memory_id)
        return {"message": "Memory deleted successfully!"}
class TestOSSBackend:
    """Expectations for how OSSBackend talks to the mem0 ``Memory`` object."""

    def _make(self):
        """Return ``(backend, memory)`` with a recording fake as ``_memory``.

        ``__new__`` is used so ``OSSBackend.__init__`` does not run.
        """
        fake_memory = FakeOSSMemory()
        oss = OSSBackend.__new__(OSSBackend)
        oss._memory = fake_memory
        return oss, fake_memory

    def test_search_returns_list(self):
        backend, _ = self._make()
        hits = backend.search("test", filters={"user_id": "u1"})
        assert isinstance(hits, list)
        assert hits[0]["id"] == "m1"

    def test_search_passes_filters(self):
        backend, memory = self._make()
        backend.search("q", filters={"user_id": "u1"}, top_k=3)
        forwarded = memory.calls[0][2]
        assert forwarded["filters"] == {"user_id": "u1"}
        assert forwarded["top_k"] == 3

    def test_search_ignores_rerank(self):
        """OSS backend accepts rerank param but does not forward it to Memory."""
        backend, memory = self._make()
        backend.search("q", filters={}, rerank=True)
        forwarded = memory.calls[0][2]
        assert "rerank" not in forwarded

    def test_get_all_ignores_pagination(self):
        """OSSBackend accepts page/page_size but does NOT forward to Memory.get_all()."""
        backend, memory = self._make()
        envelope = backend.get_all(filters={"user_id": "u1"}, page=2, page_size=50)
        forwarded = memory.calls[0][1]
        for dropped in ("page", "page_size"):
            assert dropped not in forwarded
        assert envelope["count"] == 1

    def test_get_all_returns_envelope(self):
        backend, _ = self._make()
        envelope = backend.get_all(filters={"user_id": "u1"})
        for field in ("results", "count"):
            assert field in envelope

    def test_add_forwards_kwargs(self):
        backend, memory = self._make()
        conversation = [{"role": "user", "content": "hi"}]
        backend.add(conversation, user_id="u1", agent_id="hermes", infer=False)
        forwarded = memory.calls[0][2]
        assert forwarded["user_id"] == "u1"
        assert forwarded["infer"] is False

    def test_update_maps_text_to_data(self):
        """OSS Memory.update uses `data=` param, not `text=`."""
        backend, memory = self._make()
        backend.update("m1", "new text")
        method, memory_id, kwargs = memory.calls[0]
        assert method == "update"
        assert memory_id == "m1"
        assert kwargs == {"data": "new text"}

    def test_delete_positional_arg(self):
        backend, memory = self._make()
        backend.delete("m1")
        assert memory.calls[0] == ("delete", "m1")

    def test_update_normalizes_response(self):
        backend, _ = self._make()
        outcome = backend.update("m1", "text")
        assert outcome == {"result": "Memory updated.", "memory_id": "m1"}

    def test_delete_normalizes_response(self):
        backend, _ = self._make()
        outcome = backend.delete("m1")
        assert outcome == {"result": "Memory deleted.", "memory_id": "m1"}

View file

@ -0,0 +1,107 @@
"""Tests for OSS provider definitions and validation."""
import pytest
from plugins.memory.mem0._oss_providers import (
LLM_PROVIDERS,
EMBEDDER_PROVIDERS,
VECTOR_PROVIDERS,
KNOWN_DIMS,
validate_oss_config,
)
class TestProviderDefinitions:
    """Structural checks on the LLM / embedder / vector provider registries.

    Each failing assertion names the offending provider id so a malformed
    registry entry can be located without a debugger.
    """

    def test_llm_providers_have_required_keys(self):
        for pid, p in LLM_PROVIDERS.items():
            for key in ("label", "needs_key", "default_model"):
                assert key in p, f"LLM provider {pid!r} is missing {key!r}"

    def test_embedder_providers_have_required_keys(self):
        for pid, p in EMBEDDER_PROVIDERS.items():
            for key in ("label", "needs_key", "default_model", "dims"):
                assert key in p, f"embedder provider {pid!r} is missing {key!r}"

    def test_embedder_provider_ids(self):
        assert set(EMBEDDER_PROVIDERS.keys()) == {"openai", "ollama"}

    def test_vector_providers_have_required_keys(self):
        for pid, p in VECTOR_PROVIDERS.items():
            for key in ("label", "default_config"):
                assert key in p, f"vector provider {pid!r} is missing {key!r}"

    def test_vector_provider_ids(self):
        assert set(VECTOR_PROVIDERS.keys()) == {"qdrant", "pgvector"}

    def test_known_dims_covers_defaults(self):
        # Every embedder's default model must have an entry in KNOWN_DIMS.
        for pid, p in EMBEDDER_PROVIDERS.items():
            assert p["default_model"] in KNOWN_DIMS, (
                f"default model of embedder {pid!r} has no KNOWN_DIMS entry"
            )
class TestValidation:
    """validate_oss_config returns a list of error strings (empty when valid)."""

    def test_valid_openai_config(self):
        config = {
            "llm": {"provider": "openai", "config": {"model": "gpt-4o-mini"}},
            "embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
            "vector_store": {"provider": "qdrant", "config": {"path": "/tmp/test"}},
        }
        assert validate_oss_config(config) == []

    def test_unknown_llm_provider(self):
        config = {
            "llm": {"provider": "gemini", "config": {}},
            "embedder": {"provider": "openai", "config": {}},
            "vector_store": {"provider": "qdrant", "config": {}},
        }
        problems = validate_oss_config(config)
        assert any("llm" in text.lower() for text in problems)

    def test_unknown_embedder_provider(self):
        config = {
            "llm": {"provider": "openai", "config": {}},
            "embedder": {"provider": "cohere", "config": {}},
            "vector_store": {"provider": "qdrant", "config": {}},
        }
        problems = validate_oss_config(config)
        assert any("embedder" in text.lower() for text in problems)

    def test_unknown_vector_provider(self):
        config = {
            "llm": {"provider": "openai", "config": {}},
            "embedder": {"provider": "openai", "config": {}},
            "vector_store": {"provider": "redis", "config": {}},
        }
        problems = validate_oss_config(config)
        assert any("vector" in text.lower() for text in problems)

    def test_missing_llm_section(self):
        # No "llm" key at all: an error mentioning "llm" is expected.
        config = {
            "embedder": {"provider": "openai", "config": {}},
            "vector_store": {"provider": "qdrant", "config": {}},
        }
        problems = validate_oss_config(config)
        assert any("llm" in text.lower() for text in problems)

    def test_pgvector_needs_user(self):
        # pgvector config without "user": an error mentioning "user" is expected.
        config = {
            "llm": {"provider": "openai", "config": {}},
            "embedder": {"provider": "openai", "config": {}},
            "vector_store": {"provider": "pgvector", "config": {"host": "localhost"}},
        }
        problems = validate_oss_config(config)
        assert any("user" in text.lower() for text in problems)

    def test_pgvector_with_user_valid(self):
        config = {
            "llm": {"provider": "openai", "config": {}},
            "embedder": {"provider": "openai", "config": {}},
            "vector_store": {"provider": "pgvector", "config": {"host": "localhost", "user": "pg"}},
        }
        assert validate_oss_config(config) == []

View file

@ -0,0 +1,251 @@
"""Tests for Mem0 setup wizard — flag parsing, config building, validation."""
import json
import sys
import types
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from plugins.memory.mem0._setup import (
parse_flags,
build_oss_config,
_write_env,
post_setup,
_check_qdrant_path,
_check_ollama,
_check_pgvector,
)
def _inject_fake_hermes_cli(monkeypatch):
"""Inject fake hermes_cli modules so yaml/curses aren't required."""
fake_config_mod = types.ModuleType("hermes_cli.config")
fake_config_mod.save_config = lambda c: None
fake_setup_mod = types.ModuleType("hermes_cli.memory_setup")
fake_setup_mod._curses_select = lambda *a, **kw: 0
fake_setup_mod._prompt = lambda label, default=None, secret=False: default or ""
fake_hermes_cli = types.ModuleType("hermes_cli")
fake_hermes_cli.config = fake_config_mod
fake_hermes_cli.memory_setup = fake_setup_mod
monkeypatch.setitem(sys.modules, "hermes_cli", fake_hermes_cli)
monkeypatch.setitem(sys.modules, "hermes_cli.config", fake_config_mod)
monkeypatch.setitem(sys.modules, "hermes_cli.memory_setup", fake_setup_mod)
monkeypatch.setattr("plugins.memory.mem0._setup._curses_select", lambda *a, **kw: 0)
monkeypatch.setattr("plugins.memory.mem0._setup._prompt", lambda label, default=None, secret=False: default or "")
return fake_config_mod
class TestParseFlags:
    """parse_flags maps a CLI-style argument list to a flags mapping."""

    def test_mode_platform(self):
        parsed = parse_flags(["--mode", "platform", "--api-key", "sk-test"])
        assert parsed["mode"] == "platform"
        assert parsed["api_key"] == "sk-test"

    def test_mode_oss_defaults(self):
        # Only the LLM key is given; provider choices fall back to defaults.
        parsed = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        expected = {
            "mode": "oss",
            "oss_llm": "openai",
            "oss_embedder": "openai",
            "oss_vector": "qdrant",
        }
        for flag, value in expected.items():
            assert parsed[flag] == value

    def test_mode_oss_all_flags(self):
        argv = [
            "--mode", "oss",
            "--oss-llm", "ollama",
            "--oss-llm-model", "llama3:latest",
            "--oss-embedder", "ollama",
            "--oss-embedder-model", "nomic-embed-text",
            "--oss-vector", "pgvector",
            "--oss-vector-host", "db.local",
            "--oss-vector-port", "5433",
            "--oss-vector-user", "pguser",
            "--oss-vector-password", "secret",
            "--oss-vector-dbname", "memdb",
            "--user-id", "my-user",
        ]
        parsed = parse_flags(argv)
        expected = {
            "oss_llm": "ollama",
            "oss_llm_model": "llama3:latest",
            "oss_vector": "pgvector",
            "oss_vector_user": "pguser",
            "user_id": "my-user",
        }
        for flag, value in expected.items():
            assert parsed[flag] == value

    def test_no_flags_returns_empty_mode(self):
        assert parse_flags([])["mode"] == ""

    def test_oss_vector_path_flag(self):
        parsed = parse_flags(["--mode", "oss", "--oss-vector-path", "/data/qdrant"])
        assert parsed["oss_vector_path"] == "/data/qdrant"
class TestBuildOSSConfig:
    """build_oss_config(flags) returns ``(oss_config, env_writes)``."""

    def test_openai_defaults(self):
        parsed = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        oss_cfg, env = build_oss_config(parsed)
        llm, embedder = oss_cfg["llm"], oss_cfg["embedder"]
        assert llm["provider"] == "openai"
        assert llm["config"]["model"] == "gpt-5-mini"
        assert embedder["provider"] == "openai"
        assert embedder["config"]["model"] == "text-embedding-3-small"
        assert oss_cfg["vector_store"]["provider"] == "qdrant"
        assert env["OPENAI_API_KEY"] == "sk-oai"

    def test_ollama_no_key_needed(self):
        argv = ["--mode", "oss", "--oss-llm", "ollama", "--oss-embedder", "ollama"]
        oss_cfg, env = build_oss_config(parse_flags(argv))
        assert oss_cfg["llm"]["provider"] == "ollama"
        assert "model" in oss_cfg["llm"]["config"]
        assert env == {}

    def test_embedder_reuses_llm_key(self):
        """When LLM and embedder share same provider, key written once."""
        parsed = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        _, env = build_oss_config(parsed)
        assert env == {"OPENAI_API_KEY": "sk-oai"}

    def test_different_embedder_needs_separate_key(self):
        argv = [
            "--mode", "oss",
            "--oss-llm", "ollama",
            "--oss-embedder", "openai", "--oss-embedder-key", "sk-oai",
        ]
        _, env = build_oss_config(parse_flags(argv))
        assert env == {"OPENAI_API_KEY": "sk-oai"}

    def test_pgvector_config(self):
        argv = [
            "--mode", "oss", "--oss-llm-key", "sk-oai",
            "--oss-vector", "pgvector",
            "--oss-vector-host", "db.local", "--oss-vector-port", "5433",
            "--oss-vector-user", "pg", "--oss-vector-dbname", "memdb",
        ]
        oss_cfg, _ = build_oss_config(parse_flags(argv))
        store = oss_cfg["vector_store"]
        assert store["provider"] == "pgvector"
        settings = store["config"]
        assert settings["host"] == "db.local"
        assert settings["port"] == 5433  # the flag value "5433" is expected as an int
        assert settings["user"] == "pg"

    def test_known_dims_auto_set(self):
        parsed = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        oss_cfg, _ = build_oss_config(parsed)
        assert oss_cfg["embedder"]["config"].get("embedding_dims") == 1536

    def test_custom_qdrant_path(self):
        argv = [
            "--mode", "oss", "--oss-llm-key", "sk-oai",
            "--oss-vector-path", "/data/qdrant",
        ]
        oss_cfg, _ = build_oss_config(parse_flags(argv))
        assert oss_cfg["vector_store"]["config"]["path"] == "/data/qdrant"
class TestWriteEnv:
    """_write_env(path, mapping) is expected to persist KEY=value lines."""

    def test_write_new_vars(self, tmp_path):
        target = tmp_path / ".env"
        _write_env(target, {"OPENAI_API_KEY": "sk-test"})
        assert "OPENAI_API_KEY=sk-test" in target.read_text()

    def test_update_existing_var(self, tmp_path):
        target = tmp_path / ".env"
        target.write_text("OPENAI_API_KEY=old\nOTHER=keep\n")
        _write_env(target, {"OPENAI_API_KEY": "new"})
        written = target.read_text()
        # The key is replaced in place and unrelated entries survive.
        assert "OPENAI_API_KEY=new" in written
        assert "OTHER=keep" in written
        assert "old" not in written
class TestPostSetup:
    """post_setup driven by ``sys.argv`` flags, with hermes home at ``tmp_path``."""

    def test_platform_flag_mode(self, tmp_path, monkeypatch):
        argv = ["hermes", "--mode", "platform", "--api-key", "sk-test"]
        monkeypatch.setattr("sys.argv", argv)
        monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
        _inject_fake_hermes_cli(monkeypatch)

        config = {"memory": {}}
        post_setup(str(tmp_path), config)

        assert config["memory"]["provider"] == "mem0"
        assert "MEM0_API_KEY=sk-test" in (tmp_path / ".env").read_text()
        saved = json.loads((tmp_path / "mem0.json").read_text())
        assert saved["mode"] == "platform"

    def test_oss_flag_mode(self, tmp_path, monkeypatch):
        argv = [
            "hermes", "--mode", "oss", "--oss-llm-key", "sk-oai",
        ]
        monkeypatch.setattr("sys.argv", argv)
        monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
        _inject_fake_hermes_cli(monkeypatch)
        # Replace the dependency installer with a no-op for this test.
        monkeypatch.setattr("plugins.memory.mem0._setup._install_provider_deps", lambda l, e, v: None)

        config = {"memory": {}}
        post_setup(str(tmp_path), config)

        assert config["memory"]["provider"] == "mem0"
        saved = json.loads((tmp_path / "mem0.json").read_text())
        assert saved["mode"] == "oss"
        assert saved["oss"]["llm"]["provider"] == "openai"
class TestDryRun:
    """``--dry-run`` parsing, and the expectation that a dry run persists nothing."""

    def test_dry_run_flag_parsed(self):
        parsed = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai", "--dry-run"])
        assert parsed["dry_run"] is True

    def test_dry_run_not_set_by_default(self):
        parsed = parse_flags(["--mode", "oss"])
        assert parsed["dry_run"] is False

    def test_dry_run_platform_no_files(self, tmp_path, monkeypatch):
        argv = ["hermes", "--mode", "platform", "--api-key", "sk-test", "--dry-run"]
        monkeypatch.setattr("sys.argv", argv)
        monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
        _inject_fake_hermes_cli(monkeypatch)

        config = {"memory": {}}
        post_setup(str(tmp_path), config)

        # Neither file is created and the in-memory config is left untouched.
        for filename in (".env", "mem0.json"):
            assert not (tmp_path / filename).exists()
        assert "provider" not in config["memory"]

    def test_dry_run_oss_no_files(self, tmp_path, monkeypatch):
        argv = [
            "hermes", "--mode", "oss", "--oss-llm-key", "sk-oai", "--dry-run",
        ]
        monkeypatch.setattr("sys.argv", argv)
        monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
        _inject_fake_hermes_cli(monkeypatch)
        monkeypatch.setattr("plugins.memory.mem0._setup._install_provider_deps", lambda l, e, v: None)

        config = {"memory": {}}
        post_setup(str(tmp_path), config)

        for filename in (".env", "mem0.json"):
            assert not (tmp_path / filename).exists()
        assert "provider" not in config["memory"]
class TestConnectivityChecks:
    """The ``_check_*`` helpers return an ``(ok, message)`` pair."""

    def test_qdrant_path_writable(self, tmp_path):
        ok, _ = _check_qdrant_path(str(tmp_path / "qdrant"))
        assert ok is True

    def test_qdrant_path_not_writable(self, tmp_path, monkeypatch):
        def _raise_oserror(*a, **kw):
            raise OSError("Permission denied")

        # Make every Path.mkdir call fail so the check has to report the error.
        monkeypatch.setattr(Path, "mkdir", _raise_oserror)
        ok, msg = _check_qdrant_path(str(tmp_path / "qdrant"))
        assert ok is False
        assert "Permission denied" in msg

    def test_ollama_unreachable(self):
        # NOTE(review): assumes nothing is listening on localhost port 1.
        ok, _ = _check_ollama("http://localhost:1")
        assert ok is False

    def test_pgvector_unreachable(self):
        # NOTE(review): assumes nothing is listening on localhost port 1.
        ok, _ = _check_pgvector("localhost", 1)
        assert ok is False

View file

@ -1,241 +0,0 @@
"""Tests for Mem0 API v2 compatibility — filters param and dict response unwrapping.
Salvaged from PRs #5301 (qaqcvc) and #5117 (vvvanguards).
"""
import json
import os
import stat
import pytest
from plugins.memory.mem0 import Mem0MemoryProvider
class FakeClientV2:
    """Fake Mem0 client that returns v2-style dict responses and captures call kwargs.

    Args:
        search_results: Value returned by ``search``. ``None`` (the default)
            means the empty v2 envelope ``{"results": []}``.
        all_results: Value returned by ``get_all``, with the same default.

    An explicitly supplied empty value (for example a bare ``[]`` standing in
    for the old list-shaped API) is returned as given rather than being
    replaced by the dict-shaped default.
    """

    def __init__(self, search_results=None, all_results=None):
        # `is None` rather than `or`: a falsy-but-explicit response must survive.
        self._search_results = {"results": []} if search_results is None else search_results
        self._all_results = {"results": []} if all_results is None else all_results
        self.captured_search = {}   # kwargs of the most recent search() call
        self.captured_get_all = {}  # kwargs of the most recent get_all() call
        self.captured_add = []      # one dict per add() call, in call order

    def search(self, **kwargs):
        self.captured_search = kwargs
        return self._search_results

    def get_all(self, **kwargs):
        self.captured_get_all = kwargs
        return self._all_results

    def add(self, messages, **kwargs):
        # Records the call only; there is no return value.
        self.captured_add.append({"messages": messages, **kwargs})
# ---------------------------------------------------------------------------
# Filter migration: bare user_id= -> filters={}
# ---------------------------------------------------------------------------
class TestMem0FiltersV2:
    """All API calls must use filters={} instead of bare user_id= kwargs."""

    def _make_provider(self, monkeypatch, client):
        """Return an initialized provider whose ``_get_client`` yields ``client``."""
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        monkeypatch.setattr(provider, "_get_client", lambda: client)
        return provider

    def test_search_uses_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)
        provider.handle_tool_call("mem0_search", {"query": "hello", "top_k": 3, "rerank": False})
        assert client.captured_search["query"] == "hello"
        assert client.captured_search["top_k"] == 3
        assert client.captured_search["rerank"] is False
        assert client.captured_search["filters"] == {"user_id": "u123"}
        # Must NOT have bare user_id kwarg
        assert "user_id" not in client.captured_search

    def test_profile_uses_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)
        provider.handle_tool_call("mem0_profile", {})
        assert client.captured_get_all["filters"] == {"user_id": "u123"}
        assert "user_id" not in client.captured_get_all

    def test_prefetch_uses_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)
        provider.queue_prefetch("hello")
        # Wait (up to 2 s) for the prefetch thread before inspecting the capture.
        provider._prefetch_thread.join(timeout=2)
        assert client.captured_search["query"] == "hello"
        assert client.captured_search["filters"] == {"user_id": "u123"}
        assert "user_id" not in client.captured_search

    def test_sync_turn_uses_write_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)
        provider.sync_turn("user said this", "assistant replied", session_id="s1")
        # Wait (up to 2 s) for the sync thread before inspecting the capture.
        provider._sync_thread.join(timeout=2)
        assert len(client.captured_add) == 1
        call = client.captured_add[0]
        assert call["user_id"] == "u123"
        assert call["agent_id"] == "hermes"

    def test_conclude_uses_write_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)
        provider.handle_tool_call("mem0_conclude", {"conclusion": "user likes dark mode"})
        assert len(client.captured_add) == 1
        call = client.captured_add[0]
        assert call["user_id"] == "u123"
        assert call["agent_id"] == "hermes"
        assert call["infer"] is False

    def test_read_filters_no_agent_id(self):
        """Read filters should use user_id only — cross-session recall across agents."""
        provider = Mem0MemoryProvider()
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        assert provider._read_filters() == {"user_id": "u123"}

    def test_write_filters_include_agent_id(self):
        """Write filters should include agent_id for attribution."""
        provider = Mem0MemoryProvider()
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        assert provider._write_filters() == {"user_id": "u123", "agent_id": "hermes"}
# ---------------------------------------------------------------------------
# Dict response unwrapping (API v2 wraps in {"results": [...]})
# ---------------------------------------------------------------------------
class TestMem0ResponseUnwrapping:
    """API v2 returns {"results": [...]} dicts; we must extract the list."""

    def _make_provider(self, monkeypatch, client):
        """Return an initialized provider whose ``_get_client`` yields ``client``."""
        subject = Mem0MemoryProvider()
        subject.initialize("test-session")
        monkeypatch.setattr(subject, "_get_client", lambda: client)
        return subject

    def test_profile_dict_response(self, monkeypatch):
        stored = {"results": [{"memory": "alpha"}, {"memory": "beta"}]}
        provider = self._make_provider(monkeypatch, FakeClientV2(all_results=stored))
        payload = json.loads(provider.handle_tool_call("mem0_profile", {}))
        assert payload["count"] == 2
        for fact in ("alpha", "beta"):
            assert fact in payload["result"]

    def test_profile_list_response_backward_compat(self, monkeypatch):
        """Old API returned bare lists — still works."""
        provider = self._make_provider(
            monkeypatch, FakeClientV2(all_results=[{"memory": "gamma"}])
        )
        payload = json.loads(provider.handle_tool_call("mem0_profile", {}))
        assert payload["count"] == 1
        assert "gamma" in payload["result"]

    def test_search_dict_response(self, monkeypatch):
        hits = {
            "results": [{"memory": "foo", "score": 0.9}, {"memory": "bar", "score": 0.7}]
        }
        provider = self._make_provider(monkeypatch, FakeClientV2(search_results=hits))
        raw = provider.handle_tool_call("mem0_search", {"query": "test", "top_k": 5})
        payload = json.loads(raw)
        assert payload["count"] == 2
        assert payload["results"][0]["memory"] == "foo"

    def test_search_list_response_backward_compat(self, monkeypatch):
        """Old API returned bare lists — still works."""
        hits = [{"memory": "baz", "score": 0.8}]
        provider = self._make_provider(monkeypatch, FakeClientV2(search_results=hits))
        raw = provider.handle_tool_call("mem0_search", {"query": "test"})
        assert json.loads(raw)["count"] == 1

    def test_unwrap_results_edge_cases(self):
        """_unwrap_results handles all shapes gracefully."""
        unwrap = Mem0MemoryProvider._unwrap_results
        assert unwrap({"results": [1, 2]}) == [1, 2]
        assert unwrap([3, 4]) == [3, 4]
        assert unwrap({}) == []
        assert unwrap(None) == []
        assert unwrap("unexpected") == []

    def test_prefetch_dict_response(self, monkeypatch):
        client = FakeClientV2(search_results={
            "results": [{"memory": "user prefers dark mode"}]
        })
        provider = self._make_provider(monkeypatch, client)
        provider.queue_prefetch("preferences")
        # Wait (up to 2 s) for the prefetch thread before reading the result.
        provider._prefetch_thread.join(timeout=2)
        recalled = provider.prefetch("preferences")
        assert "dark mode" in recalled
# ---------------------------------------------------------------------------
# Default preservation
# ---------------------------------------------------------------------------
@pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits not enforced on Windows")
def test_save_config_sets_owner_only_permissions(tmp_path):
    """mem0.json must be written with 0o600 so API key is not world-readable."""
    Mem0MemoryProvider().save_config({"api_key": "m0-test-key"}, str(tmp_path))

    written = tmp_path / "mem0.json"
    assert written.exists()

    # Reduce st_mode to its permission bits before comparing.
    mode = stat.S_IMODE(written.stat().st_mode)
    assert mode == 0o600, f"Expected 0o600 (owner-only), got {oct(mode)}"
class TestMem0Defaults:
    """Ensure we don't break existing users' defaults."""

    def test_default_user_id_hermes_user(self, monkeypatch, tmp_path):
        # API key set, MEM0_USER_ID absent, HERMES_HOME pointed at a temp dir.
        monkeypatch.setenv("MEM0_API_KEY", "test-key")
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.delenv("MEM0_USER_ID", raising=False)

        subject = Mem0MemoryProvider()
        subject.initialize("test")
        assert subject._user_id == "hermes-user"

    def test_default_agent_id_hermes(self, monkeypatch, tmp_path):
        # API key set, MEM0_AGENT_ID absent, HERMES_HOME pointed at a temp dir.
        monkeypatch.setenv("MEM0_API_KEY", "test-key")
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.delenv("MEM0_AGENT_ID", raising=False)

        subject = Mem0MemoryProvider()
        subject.initialize("test")
        assert subject._agent_id == "hermes"

View file

@ -0,0 +1,463 @@
"""Tests for Mem0 v3 API — new tool names, paginated responses, update/delete tools."""
import json
import pytest
from plugins.memory.mem0 import Mem0MemoryProvider
class FakeBackend:
    """Fake Mem0Backend for provider-level tests.

    Each method appends a ``(method_name, ...)`` tuple to ``captured`` and
    returns either the canned value given at construction (``search`` /
    ``get_all``) or a fixed response.
    """

    def __init__(self, search_results=None, all_results=None):
        # Falsy arguments fall back to an empty list / empty envelope.
        self._search_results = search_results or []
        self._all_results = all_results or {"results": [], "count": 0}
        self.captured = []

    def search(self, query, *, filters, top_k=10, rerank=True):
        options = {"filters": filters, "top_k": top_k, "rerank": rerank}
        self.captured.append(("search", query, options))
        return self._search_results

    def get_all(self, *, filters, page=1, page_size=100):
        options = {"filters": filters, "page": page, "page_size": page_size}
        self.captured.append(("get_all", options))
        return self._all_results

    def add(self, messages, *, user_id, agent_id, infer=False, metadata=None):
        options = {
            "user_id": user_id,
            "agent_id": agent_id,
            "infer": infer,
            "metadata": metadata,
        }
        self.captured.append(("add", messages, options))
        return {"status": "PENDING", "event_id": "evt-test-123"}

    def update(self, memory_id, text):
        self.captured.append(("update", memory_id, text))
        return {"result": "Memory updated.", "memory_id": memory_id}

    def delete(self, memory_id):
        self.captured.append(("delete", memory_id))
        return {"result": "Memory deleted.", "memory_id": memory_id}
class TestMem0V3Tools:
"""Test v3 tool names and response handling."""
def _make_provider(self, monkeypatch, backend):
provider = Mem0MemoryProvider()
provider.initialize("test-session")
provider._user_id = "u123"
provider._agent_id = "hermes"
provider._backend = backend
return provider
def test_list_returns_paginated_with_ids(self, monkeypatch):
backend = FakeBackend(all_results={
"count": 2,
"results": [
{"id": "mem-1", "memory": "alpha"},
{"id": "mem-2", "memory": "beta"},
]
})
provider = self._make_provider(monkeypatch, backend)
result = json.loads(provider.handle_tool_call("mem0_list", {}))
assert result["count"] == 2
assert result["results"][0]["id"] == "mem-1"
assert result["results"][0]["memory"] == "alpha"
def test_list_pagination_params(self, monkeypatch):
backend = FakeBackend()
provider = self._make_provider(monkeypatch, backend)
provider.handle_tool_call("mem0_list", {"page": 2, "page_size": 50})
assert backend.captured[0][1]["page"] == 2
assert backend.captured[0][1]["page_size"] == 50
def test_list_empty(self, monkeypatch):
backend = FakeBackend()
provider = self._make_provider(monkeypatch, backend)
result = json.loads(provider.handle_tool_call("mem0_list", {}))
assert result["result"] == "No memories stored yet."
def test_search_returns_ids(self, monkeypatch):
backend = FakeBackend(search_results=[{"id": "mem-1", "memory": "foo", "score": 0.9}])
provider = self._make_provider(monkeypatch, backend)
result = json.loads(provider.handle_tool_call("mem0_search", {"query": "test"}))
assert result["results"][0]["id"] == "mem-1"
def test_search_uses_filters(self, monkeypatch):
backend = FakeBackend()
provider = self._make_provider(monkeypatch, backend)
provider.handle_tool_call("mem0_search", {"query": "hello", "top_k": 3})
assert backend.captured[0][2]["filters"] == {"user_id": "u123"}
assert backend.captured[0][2]["top_k"] == 3
def test_search_rerank_default_true(self, monkeypatch):
backend = FakeBackend()
provider = self._make_provider(monkeypatch, backend)
provider.handle_tool_call("mem0_search", {"query": "test"})
assert backend.captured[0][2]["rerank"] is True
def test_search_rerank_override_false(self, monkeypatch):
backend = FakeBackend()
provider = self._make_provider(monkeypatch, backend)
provider.handle_tool_call("mem0_search", {"query": "test", "rerank": False})
assert backend.captured[0][2]["rerank"] is False
def test_add_uses_content_param(self, monkeypatch):
backend = FakeBackend()
provider = self._make_provider(monkeypatch, backend)
result = json.loads(provider.handle_tool_call("mem0_add", {"content": "user likes dark mode"}))
assert len(backend.captured) == 1
call = backend.captured[0]
assert call[2]["infer"] is False
assert call[2]["user_id"] == "u123"
assert call[2]["agent_id"] == "hermes"
assert "event_id" in result
def test_add_returns_event_id(self, monkeypatch):
backend = FakeBackend()
provider = self._make_provider(monkeypatch, backend)
result = json.loads(provider.handle_tool_call("mem0_add", {"content": "test"}))
assert result["event_id"] == "evt-test-123"
def test_add_missing_content(self, monkeypatch):
    """``mem0_add`` without any arguments yields an error payload."""
    provider = self._make_provider(monkeypatch, FakeBackend())
    raw = provider.handle_tool_call("mem0_add", {})
    payload = json.loads(raw)
    assert "error" in payload
def test_old_tool_names_return_unknown(self, monkeypatch):
    """The retired ``mem0_profile`` / ``mem0_conclude`` names produce an error payload."""
    provider = self._make_provider(monkeypatch, FakeBackend())
    for legacy_name in ("mem0_profile", "mem0_conclude"):
        payload = json.loads(provider.handle_tool_call(legacy_name, {}))
        assert "error" in payload
class TestMem0UpdateDelete:
    """Exercise the ``mem0_update`` and ``mem0_delete`` tools against a fake backend."""

    def _make_provider(self, monkeypatch, backend):
        """Build an initialized provider with fixed ids that talks to *backend*."""
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        # Pin identity and swap in the fake only after initialize() has run.
        for attr, value in (
            ("_user_id", "u123"),
            ("_agent_id", "hermes"),
            ("_backend", backend),
        ):
            setattr(provider, attr, value)
        return provider

    def test_update_calls_sdk(self, monkeypatch):
        """A valid update is recorded by the backend and echoes the memory id."""
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        tool_args = {"memory_id": "mem-1", "text": "updated fact"}
        payload = json.loads(provider.handle_tool_call("mem0_update", tool_args))

        recorded = backend.captured[0]
        assert recorded[1] == "mem-1"
        assert recorded[2] == "updated fact"
        assert payload["result"] == "Memory updated."
        assert payload["memory_id"] == "mem-1"

    def test_update_missing_memory_id(self, monkeypatch):
        """An update without ``memory_id`` yields an error payload."""
        provider = self._make_provider(monkeypatch, FakeBackend())
        raw = provider.handle_tool_call("mem0_update", {"text": "no id"})
        assert "error" in json.loads(raw)

    def test_update_missing_text(self, monkeypatch):
        """An update without ``text`` yields an error payload."""
        provider = self._make_provider(monkeypatch, FakeBackend())
        raw = provider.handle_tool_call("mem0_update", {"memory_id": "mem-1"})
        assert "error" in json.loads(raw)

    def test_delete_calls_sdk(self, monkeypatch):
        """A valid delete is recorded by the backend and confirms deletion."""
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        raw = provider.handle_tool_call("mem0_delete", {"memory_id": "mem-1"})
        payload = json.loads(raw)

        assert backend.captured[0][1] == "mem-1"
        assert payload["result"] == "Memory deleted."

    def test_delete_missing_memory_id(self, monkeypatch):
        """A delete without ``memory_id`` yields an error payload."""
        provider = self._make_provider(monkeypatch, FakeBackend())
        raw = provider.handle_tool_call("mem0_delete", {})
        assert "error" in json.loads(raw)
class TestMem0ErrorHandling:
    """Check which backend failures are counted in ``_consecutive_failures``."""

    def _make_provider(self, monkeypatch, backend):
        """Return an initialized provider with fixed ids, wired to *backend*."""
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id, provider._agent_id = "u123", "hermes"
        provider._backend = backend
        return provider

    def test_update_404_no_circuit_breaker(self, monkeypatch):
        """An update failing with a "404" message is reported but not counted."""
        def failing_update(mid, text):
            raise Exception("404 Not Found")

        backend = FakeBackend()
        backend.update = failing_update
        provider = self._make_provider(monkeypatch, backend)

        tool_args = {"memory_id": "bad-id", "text": "x"}
        payload = json.loads(provider.handle_tool_call("mem0_update", tool_args))

        assert "error" in payload
        assert provider._consecutive_failures == 0

    def test_delete_404_no_circuit_breaker(self, monkeypatch):
        """A delete failing with a lowercase "404" message is reported but not counted."""
        def failing_delete(mid):
            raise Exception("404 not found")

        backend = FakeBackend()
        backend.delete = failing_delete
        provider = self._make_provider(monkeypatch, backend)

        payload = json.loads(
            provider.handle_tool_call("mem0_delete", {"memory_id": "bad-id"})
        )

        assert "error" in payload
        assert provider._consecutive_failures == 0

    def test_update_validation_error_no_circuit_breaker(self, monkeypatch):
        """ValidationError (bad UUID format) should not trip circuit breaker."""
        class ValidationError(Exception):
            pass

        def failing_update(mid, text):
            raise ValidationError('{"error":"memory_id should be a valid UUID"}')

        backend = FakeBackend()
        backend.update = failing_update
        provider = self._make_provider(monkeypatch, backend)

        tool_args = {"memory_id": "not-a-uuid", "text": "x"}
        payload = json.loads(provider.handle_tool_call("mem0_update", tool_args))

        assert "error" in payload
        assert provider._consecutive_failures == 0

    def test_delete_validation_error_no_circuit_breaker(self, monkeypatch):
        """A delete raising a locally defined ValidationError is reported but not counted."""
        class ValidationError(Exception):
            pass

        def failing_delete(mid):
            raise ValidationError('{"error":"memory_id should be a valid UUID"}')

        backend = FakeBackend()
        backend.delete = failing_delete
        provider = self._make_provider(monkeypatch, backend)

        payload = json.loads(
            provider.handle_tool_call("mem0_delete", {"memory_id": "not-a-uuid"})
        )

        assert "error" in payload
        assert provider._consecutive_failures == 0

    def test_update_5xx_trips_circuit_breaker(self, monkeypatch):
        """An update failing with a "500" message increments the failure counter."""
        def failing_update(mid, text):
            raise Exception("500 Internal Server Error")

        backend = FakeBackend()
        backend.update = failing_update
        provider = self._make_provider(monkeypatch, backend)

        provider.handle_tool_call("mem0_update", {"memory_id": "mem-1", "text": "x"})

        assert provider._consecutive_failures == 1
class TestMem0V3Internal:
    """Internal provider behavior: background turn sync and retired tool names."""

    def _make_provider(self, monkeypatch, backend):
        """Return an initialized provider with fixed ids, wired to *backend*."""
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        provider._backend = backend
        return provider

    def test_sync_turn_explicit_kwargs(self, monkeypatch):
        """``sync_turn`` records one backend call with explicit ids and infer on."""
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        provider.sync_turn("user said", "assistant replied", session_id="s1")

        worker = provider._sync_thread
        assert worker is not None, "sync_turn did not start a background sync thread"
        worker.join(timeout=2)
        # join() returns None whether or not it timed out, so check liveness
        # explicitly instead of racing the worker in the assertions below.
        assert not worker.is_alive(), "sync thread still running after 2s"

        assert len(backend.captured) == 1
        call = backend.captured[0]
        assert call[2]["user_id"] == "u123"
        assert call[2]["agent_id"] == "hermes"
        assert call[2]["infer"] is True

    def test_old_tool_names_return_unknown(self, monkeypatch):
        """The retired ``mem0_profile`` / ``mem0_conclude`` names produce an error payload."""
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_profile", {}))
        assert "error" in result
        result = json.loads(provider.handle_tool_call("mem0_conclude", {}))
        assert "error" in result
class TestMem0V3Config:
    """Tool schemas and system-prompt text exposed by the provider."""

    def test_tool_schemas_five_tools(self):
        """Exactly the five current tools are advertised, in this order."""
        schemas = Mem0MemoryProvider().get_tool_schemas()
        advertised = [schema["name"] for schema in schemas]
        assert advertised == [
            "mem0_list",
            "mem0_search",
            "mem0_add",
            "mem0_update",
            "mem0_delete",
        ]

    def test_system_prompt_new_tool_names(self):
        """The prompt block names the current tools and none of the retired ones."""
        provider = Mem0MemoryProvider()
        provider._user_id = "test"
        block = provider.system_prompt_block()

        for current in ("mem0_search", "mem0_add", "mem0_list", "mem0_update", "mem0_delete"):
            assert current in block
        for retired in ("mem0_profile", "mem0_conclude"):
            assert retired not in block

    def test_system_prompt_shows_platform_mode(self):
        """In platform mode the prompt mentions the mode and reranking."""
        provider = Mem0MemoryProvider()
        provider._user_id, provider._mode = "test", "platform"
        block = provider.system_prompt_block()
        assert "platform" in block
        assert "Rerank" in block

    def test_system_prompt_shows_oss_mode(self):
        """In OSS mode the prompt mentions OSS and omits reranking."""
        provider = Mem0MemoryProvider()
        provider._user_id, provider._mode = "test", "oss"
        block = provider.system_prompt_block()
        assert "OSS" in block
        assert "Rerank" not in block

    def test_search_schema_has_rerank(self):
        """rerank property available in SEARCH_SCHEMA for platform mode."""
        schemas = Mem0MemoryProvider().get_tool_schemas()
        search_schema = next(s for s in schemas if s["name"] == "mem0_search")
        properties = search_schema["parameters"]["properties"]
        assert "rerank" in properties
        assert properties["rerank"]["type"] == "boolean"
class TestMem0ModeSwitch:
    """Mode selection (platform vs. OSS) and its effect on availability."""

    def test_default_mode_is_platform(self, monkeypatch, tmp_path):
        """With no config file and an API key set, the mode is ``platform``."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.setenv("MEM0_API_KEY", "test-key")
        provider = Mem0MemoryProvider()
        provider.initialize("test")
        assert provider._mode == "platform"

    def test_missing_mode_key_defaults_platform(self, monkeypatch, tmp_path):
        """Backward compat: old mem0.json without mode key works."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        (tmp_path / "mem0.json").write_text('{"user_id": "old-user"}')
        monkeypatch.setenv("MEM0_API_KEY", "test-key")
        provider = Mem0MemoryProvider()
        provider.initialize("test")
        assert provider._mode == "platform"
        assert provider._user_id == "old-user"

    def test_is_available_platform_needs_key(self, monkeypatch, tmp_path):
        """Without ``MEM0_API_KEY`` and without a config file, the provider is unavailable."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.delenv("MEM0_API_KEY", raising=False)
        available = Mem0MemoryProvider().is_available()
        assert available is False

    def test_is_available_oss_needs_vector(self, monkeypatch, tmp_path):
        """OSS mode with a configured vector store is available."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        oss_config = '{"mode": "oss", "oss": {"vector_store": {"provider": "qdrant"}}}'
        (tmp_path / "mem0.json").write_text(oss_config)
        available = Mem0MemoryProvider().is_available()
        assert available is True

    def test_is_available_oss_no_vector(self, monkeypatch, tmp_path):
        """OSS mode with an empty ``oss`` section is unavailable."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        (tmp_path / "mem0.json").write_text('{"mode": "oss", "oss": {}}')
        available = Mem0MemoryProvider().is_available()
        assert available is False

    def test_tool_schemas_unchanged(self):
        """The advertised tool names and their order stay fixed."""
        schemas = Mem0MemoryProvider().get_tool_schemas()
        advertised = [schema["name"] for schema in schemas]
        assert advertised == [
            "mem0_list",
            "mem0_search",
            "mem0_add",
            "mem0_update",
            "mem0_delete",
        ]

    def test_system_prompt_includes_mode(self):
        """In OSS mode the prompt still names the tools and mentions OSS."""
        provider = Mem0MemoryProvider()
        provider._user_id, provider._mode = "test", "oss"
        block = provider.system_prompt_block()
        for expected in ("mem0_search", "mem0_list", "OSS"):
            assert expected in block
class TestMem0UserIdResolution:
    """user_id resolution: configured override > gateway-native id > placeholder.

    The same person reaching Hermes through different gateways (CLI, Telegram,
    Discord, Slack, ...) should land in one memory store when MEM0_USER_ID is
    configured; the gateway-native id is only the fallback when it is not.
    """

    # Keyword arguments a Telegram gateway is simulated to pass to initialize().
    _TELEGRAM = {"user_id": "123456789", "platform": "telegram"}

    def _provider(self, monkeypatch, tmp_path):
        """Return an uninitialized provider rooted at *tmp_path* with a no-op backend factory."""
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.setenv("MEM0_API_KEY", "test-key")
        provider = Mem0MemoryProvider()
        # Identity resolution is all these tests look at, so backend creation is stubbed out.
        provider._create_backend = lambda: None  # type: ignore[method-assign]
        return provider

    def test_env_override_beats_gateway_native_id(self, monkeypatch, tmp_path):
        """MEM0_USER_ID from the environment wins over the gateway id."""
        monkeypatch.setenv("MEM0_USER_ID", "ryan@example.com")
        provider = self._provider(monkeypatch, tmp_path)
        provider.initialize("test", **self._TELEGRAM)
        assert provider._user_id == "ryan@example.com"

    def test_file_override_beats_gateway_native_id(self, monkeypatch, tmp_path):
        """A user_id in mem0.json wins over the gateway id."""
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        (tmp_path / "mem0.json").write_text('{"user_id": "ryan@example.com"}')
        provider = self._provider(monkeypatch, tmp_path)
        provider.initialize("test", **self._TELEGRAM)
        assert provider._user_id == "ryan@example.com"

    def test_unset_falls_back_to_gateway_native_id(self, monkeypatch, tmp_path):
        """With nothing configured, the gateway id is used."""
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        provider = self._provider(monkeypatch, tmp_path)
        provider.initialize("test", **self._TELEGRAM)
        assert provider._user_id == "123456789"

    def test_unset_and_no_kwargs_falls_back_to_default(self, monkeypatch, tmp_path):
        """With nothing configured and no gateway id, the placeholder is used."""
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        provider = self._provider(monkeypatch, tmp_path)
        provider.initialize("test")
        assert provider._user_id == "hermes-user"

    def test_legacy_placeholder_in_config_does_not_override_kwargs(self, monkeypatch, tmp_path):
        """A config file holding only the placeholder id does not beat the gateway id."""
        # The setup wizard used to write {"user_id": "hermes-user"} as its
        # suggested default. That placeholder counts as "unset", so gateway
        # users keep their gateway-native ids instead of silently colliding.
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        (tmp_path / "mem0.json").write_text('{"user_id": "hermes-user"}')
        provider = self._provider(monkeypatch, tmp_path)
        provider.initialize("test", **self._TELEGRAM)
        assert provider._user_id == "123456789"
class TestMem0WriteMetadata:
    """Writes are tagged with ``metadata.channel``.

    The tag allows per-channel filtered views without tying identity to the
    channel.
    """

    def _make_provider(self, channel: str = "cli"):
        """Return a provider (not initialized) with fixed ids, *channel*, and a FakeBackend."""
        provider = Mem0MemoryProvider()
        provider._user_id, provider._agent_id = "u123", "hermes"
        provider._channel = channel
        provider._backend = FakeBackend()
        return provider

    def test_add_tool_passes_channel_metadata(self):
        """``mem0_add`` records the provider's channel in the call metadata."""
        provider = self._make_provider("telegram")
        provider.handle_tool_call("mem0_add", {"content": "user likes dark mode"})
        latest = provider._backend.captured[-1]
        assert latest[2]["metadata"] == {"channel": "telegram"}

    def test_sync_turn_passes_channel_metadata(self):
        """``sync_turn`` records the provider's channel in the add metadata."""
        provider = self._make_provider("discord")
        provider.sync_turn("hi", "hello", session_id="s")
        # sync_turn does its work on a daemon thread; give it time to finish.
        worker = provider._sync_thread
        if worker:
            worker.join(timeout=5.0)
        add_calls = [entry for entry in provider._backend.captured if entry[0] == "add"]
        assert add_calls, "expected an add call from sync_turn"
        assert add_calls[-1][2]["metadata"] == {"channel": "discord"}

View file

@ -1459,6 +1459,137 @@ def test_tool_add_resource_sends_git_remote_sources_as_path(url):
})
def test_get_tool_schemas_includes_narrow_forget_tool():
    """The provider advertises a ``viking_forget`` tool schema."""
    schemas = OpenVikingMemoryProvider().get_tool_schemas()
    advertised = [entry["name"] for entry in schemas]
    assert "viking_forget" in advertised
def test_handle_tool_call_forget_deletes_exact_memory_file_uri():
    """``viking_forget`` on a peer memory file issues one non-recursive delete."""
    uri = "viking://user/peers/hermes/memories/preferences/mem_abc123.md"
    client = MagicMock()
    client.delete.return_value = {
        "status": "ok",
        "result": {"uri": uri, "estimated_deleted_count": 1},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = client

    raw = provider.handle_tool_call("viking_forget", {"uri": uri})
    outcome = json.loads(raw)

    provider._client.delete.assert_called_once_with(
        "/api/v1/fs",
        params={"uri": uri, "recursive": False},
    )
    expected = {"status": "deleted", "uri": uri, "estimated_deleted_count": 1}
    assert outcome == expected
def test_handle_tool_call_forget_deletes_exact_memory_file_under_memories_root():
    """A file directly under ``memories/`` is deleted with one non-recursive call."""
    uri = "viking://user/default/memories/profile.md"
    client = MagicMock()
    client.delete.return_value = {
        "status": "ok",
        "result": {"uri": uri, "estimated_deleted_count": 1},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = client

    raw = provider.handle_tool_call("viking_forget", {"uri": uri})
    outcome = json.loads(raw)

    provider._client.delete.assert_called_once_with(
        "/api/v1/fs",
        params={"uri": uri, "recursive": False},
    )
    expected = {"status": "deleted", "uri": uri, "estimated_deleted_count": 1}
    assert outcome == expected
def test_handle_tool_call_forget_allows_non_generated_dot_md_memory_file():
    """A dot-prefixed ``.full.md`` memory file is still accepted for deletion."""
    uri = "viking://user/default/memories/preferences/.full.md"
    client = MagicMock()
    client.delete.return_value = {
        "status": "ok",
        "result": {"uri": uri, "estimated_deleted_count": 1},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = client

    raw = provider.handle_tool_call("viking_forget", {"uri": uri})
    outcome = json.loads(raw)

    provider._client.delete.assert_called_once_with(
        "/api/v1/fs",
        params={"uri": uri, "recursive": False},
    )
    expected = {"status": "deleted", "uri": uri, "estimated_deleted_count": 1}
    assert outcome == expected
@pytest.mark.parametrize("uri", [
    "",
    "https://example.com/mem.md",
    "viking:/user/memories/preferences/mem_abc123.md",
    "viking://resources/project/doc.md",
    "viking://resources/project/memories/mem_abc123.md",
    "viking://memories/preferences/mem_abc123.md",
    "viking://agent/hermes/memories/preferences/mem_abc123.md",
    "viking://user/skills/example/SKILL.md",
    "viking://user/sessions/session-1/messages.jsonl",
    "viking://user/memories/preferences/",
    "viking://user/memories/preferences/.overview.md",
    "viking://user/memories/preferences/.abstract.md",
    "viking://user/memories/preferences/mem_abc123.md?recursive=true",
])
def test_handle_tool_call_forget_rejects_non_memory_file_uris(uri):
    """Each listed URI yields an error payload and never reaches the client's delete."""
    client = MagicMock()
    provider = OpenVikingMemoryProvider()
    provider._client = client

    raw = provider.handle_tool_call("viking_forget", {"uri": uri})
    outcome = json.loads(raw)

    assert "error" in outcome
    provider._client.delete.assert_not_called()
def test_viking_client_delete_uses_identity_headers(monkeypatch):
    """``_VikingClient.delete`` sends the full URL, params, bearer token and actor header."""
    target_uri = "viking://user/memories/x.md"
    client = _VikingClient(
        "https://example.com",
        api_key="test-key",
        account="acct",
        user="alice",
        agent="hermes",
    )
    captured = {}

    def capture_delete(url, **kwargs):
        # Record what the client passed to httpx and hand back a canned 200 response.
        captured.update(url=url, kwargs=kwargs)
        return SimpleNamespace(
            status_code=200,
            text="",
            json=lambda: {"status": "ok", "result": {"uri": target_uri}},
            raise_for_status=lambda: None,
        )

    monkeypatch.setattr(client._httpx, "delete", capture_delete)

    response = client.delete("/api/v1/fs", params={"uri": target_uri})
    assert response == {
        "status": "ok",
        "result": {"uri": "viking://user/memories/x.md"},
    }

    sent = captured["kwargs"]
    assert captured["url"] == "https://example.com/api/v1/fs"
    assert sent["params"] == {"uri": "viking://user/memories/x.md"}
    assert sent["headers"]["Authorization"] == "Bearer test-key"
    assert sent["headers"]["X-OpenViking-Actor-Peer"] == "hermes"
def test_viking_client_upload_temp_file_uses_multipart_identity_headers(tmp_path, monkeypatch):
sample = tmp_path / "sample.md"
sample.write_text("# Local resource\n", encoding="utf-8")
@ -2637,6 +2768,94 @@ def test_on_memory_write_uses_content_write_independent_of_session_rotation():
)
def test_shutdown_waits_for_memory_write_worker(monkeypatch):
    """``shutdown()`` must not return while a memory-write worker is still in ``post()``."""
    import threading

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"

    # Events that let the test observe and control the worker's progress.
    worker_started = threading.Event()
    release_worker = threading.Event()
    worker_finished = threading.Event()
    shutdown_returned = threading.Event()

    class StubClient:
        """Client stand-in whose post() blocks until the test releases it."""

        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            assert path == "/api/v1/content/write"
            worker_started.set()
            release_worker.wait(timeout=2.0)
            worker_finished.set()
            return {}

    monkeypatch.setattr(openviking_module, "_VikingClient", StubClient)

    provider.on_memory_write("add", "user", "remember this")
    assert worker_started.wait(timeout=2.0), "worker never entered post()"

    def run_shutdown():
        provider.shutdown()
        shutdown_returned.set()

    shutdown_thread = threading.Thread(target=run_shutdown, daemon=True)
    shutdown_thread.start()

    # While the worker is parked in post(), shutdown should still be blocked.
    returned_before_worker_finished = shutdown_returned.wait(timeout=0.1)
    release_worker.set()

    assert shutdown_returned.wait(timeout=2.0), "shutdown did not return after worker finished"
    shutdown_thread.join(timeout=2.0)

    assert not returned_before_worker_finished
    assert worker_finished.is_set()
    assert provider._memory_write_threads == set()
@pytest.mark.parametrize(
    ("action", "content"),
    [
        ("replace", "updated memory"),
        ("remove", ""),
        ("forget", ""),
        ("delete", ""),
    ],
)
def test_on_memory_write_ignores_non_add_actions(action, content, monkeypatch):
    """``on_memory_write`` must not create a mirror thread for any non-"add" action.

    ``threading.Thread`` as seen by the plugin module is replaced with a stub
    that records construction and fails on ``start()``. The test passes only if
    the stub is never even constructed, including when the metadata carries a
    memory URI and the old text.
    """
    import plugins.memory.openviking as _mod

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    uri = "viking://user/peers/hermes/memories/preferences/mem_abc123.md"

    spawned = []

    class StubThread:
        """Thread stand-in: records constructor arguments, refuses to start."""

        def __init__(self, *args, **kwargs):
            spawned.append((args, kwargs))

        def start(self):
            # Name the parametrized action so the failure describes this case.
            raise AssertionError(
                f"on_memory_write({action!r}) should not spawn a mirror thread"
            )

    # NOTE(review): if `_mod.threading` is the stdlib module object itself, this
    # swaps Thread process-wide until monkeypatch undoes it at teardown.
    monkeypatch.setattr(_mod.threading, "Thread", StubThread)

    provider.on_memory_write(
        action,
        "memory",
        content,
        metadata={"uri": uri, "old_text": "stale fact"},
    )

    assert spawned == []
# ---------------------------------------------------------------------------
# Prefetch staleness: a prefetch worker that finishes AFTER a session switch
# must drop its result instead of repopulating the new session with stale

View file

@ -0,0 +1,153 @@
"""Unit tests for the Ollama Cloud provider profile's reasoning-effort wiring.
Ollama Cloud's ``/v1/chat/completions`` endpoint supports top-level
``reasoning_effort`` with values ``none``, ``low``, ``medium``, ``high``,
and (undocumented but empirically confirmed) ``max``. The profile maps
Hermes's ``xhigh`` → ``max`` to unlock DeepSeek V4's "Max thinking" tier
and passes the standard levels through unchanged.
These tests pin the profile's wire-shape contract so Ollama Cloud
requests carry the correct ``reasoning_effort`` field.
"""
from __future__ import annotations
import pytest
@pytest.fixture
def ollama_cloud_profile():
    """Return the Ollama Cloud profile as registered in the provider registry.

    The lookup goes through ``providers.get_provider_profile`` on purpose: if
    the registered class were swapped for a plain ``ProviderProfile``, the
    assertions in this module would fail.
    """
    # Importing ``model_tools`` runs plugin discovery, and that is what puts
    # the Ollama Cloud profile into the global provider registry.
    import model_tools  # noqa: F401
    import providers

    registered = providers.get_provider_profile("ollama-cloud")
    assert registered is not None, "ollama-cloud provider profile must be registered"
    return registered
class TestOllamaCloudReasoningEffort:
    """``build_api_kwargs_extras`` emits correct top-level ``reasoning_effort``."""

    # ── xhigh / max are both sent as "max" ─────────────────────────

    @pytest.mark.parametrize("effort", ["xhigh", "max", "MAX", " Max "])
    def test_xhigh_and_max_normalize_to_max(self, ollama_cloud_profile, effort):
        """Case and surrounding whitespace do not matter for xhigh/max."""
        config = {"enabled": True, "effort": effort}
        extra_body, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert extra_body == {}
        assert top_level == {"reasoning_effort": "max"}

    # ── low / medium / high are sent unchanged ─────────────────────

    @pytest.mark.parametrize("effort", ["low", "medium", "high"])
    def test_standard_efforts_pass_through(self, ollama_cloud_profile, effort):
        """The standard levels reach the wire as given."""
        config = {"enabled": True, "effort": effort}
        _, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert top_level == {"reasoning_effort": effort}

    # ── disabled: nothing is emitted ───────────────────────────────

    def test_explicitly_disabled_emits_nothing(self, ollama_cloud_profile):
        """``enabled: False`` yields empty extras and empty top-level kwargs."""
        extra_body, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config={"enabled": False},
        )
        assert extra_body == {}
        assert top_level == {}

    def test_disabled_ignores_effort_field(self, ollama_cloud_profile):
        """Effort silently dropped when thinking is off."""
        config = {"enabled": False, "effort": "high"}
        _, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert top_level == {}

    # ── effort "none": nothing is emitted ──────────────────────────

    def test_none_effort_emits_nothing(self, ollama_cloud_profile):
        """An effort of ``"none"`` yields empty extras and empty top-level kwargs."""
        config = {"enabled": True, "effort": "none"}
        extra_body, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert extra_body == {}
        assert top_level == {}

    # ── missing / empty effort: leave it to the model default ──────

    def test_no_reasoning_config_emits_nothing(self, ollama_cloud_profile):
        """``reasoning_config=None`` yields empty extras and empty top-level kwargs."""
        extra_body, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=None,
        )
        assert extra_body == {}
        assert top_level == {}

    def test_empty_effort_emits_nothing(self, ollama_cloud_profile):
        """An empty effort string yields no top-level kwargs."""
        config = {"enabled": True, "effort": ""}
        _, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert top_level == {}

    def test_no_effort_key_emits_nothing(self, ollama_cloud_profile):
        """When effort key is absent, let the model use its default."""
        _, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config={"enabled": True},
        )
        assert top_level == {}

    # ── unknown effort: forwarded untouched ────────────────────────

    def test_unknown_effort_forwarded(self, ollama_cloud_profile):
        """An unrecognized effort value is sent as given."""
        config = {"enabled": True, "effort": "ultra"}
        _, top_level = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert top_level == {"reasoning_effort": "ultra"}
class TestOllamaCloudFullKwargsIntegration:
    """End-to-end: the transport's full kwargs include reasoning_effort."""

    @staticmethod
    def _build_kwargs(profile, reasoning_config):
        """Build chat-completions kwargs for a fixed ping request using *profile*."""
        from agent.transports.chat_completions import ChatCompletionsTransport

        return ChatCompletionsTransport().build_kwargs(
            model="deepseek-v4-pro:cloud",
            messages=[{"role": "user", "content": "ping"}],
            tools=None,
            provider_profile=profile,
            reasoning_config=reasoning_config,
            base_url="https://ollama.com/v1",
            provider_name="ollama-cloud",
        )

    def test_full_kwargs_with_xhigh(self, ollama_cloud_profile):
        """``xhigh`` ends up as a top-level ``reasoning_effort`` of ``max``."""
        kwargs = self._build_kwargs(
            ollama_cloud_profile, {"enabled": True, "effort": "xhigh"}
        )
        assert kwargs["model"] == "deepseek-v4-pro:cloud"
        assert kwargs["reasoning_effort"] == "max"
        # Reasoning travels as a top-level field here, so extra_body is either
        # absent or carries no "reasoning" entry.
        assert not (
            "extra_body" in kwargs and "reasoning" in kwargs.get("extra_body", {})
        )

    def test_full_kwargs_with_disabled(self, ollama_cloud_profile):
        """With reasoning disabled no ``reasoning_effort`` key is emitted."""
        kwargs = self._build_kwargs(ollama_cloud_profile, {"enabled": False})
        assert "reasoning_effort" not in kwargs
class TestOllamaCloudAuxModel:
    """Ollama Cloud aux model is set on the profile."""

    def test_profile_advertises_aux_model(self, ollama_cloud_profile):
        """``default_aux_model`` is the pinned auxiliary model name."""
        expected_aux_model = "nemotron-3-nano:30b"
        assert ollama_cloud_profile.default_aux_model == expected_aux_model

View file

@ -0,0 +1,138 @@
"""Unit coverage for the background-review aux-model selector + routed digest.

Covers the two behaviors this change adds:

  _resolve_review_runtime: an "auto" or same-as-parent configuration is NOT
      routed (main model, warm cache); a configured different model IS routed,
      with resolved credentials.
  _digest_history: compact replay used ONLY on the routed path (recent tail
      verbatim + a digest of older turns), preserving role alternation.

Pure-function / config-driven; no live model calls.
"""
from unittest.mock import patch
from agent import background_review as br
def _msg(role, content, tool_calls=None):
m = {"role": role, "content": content}
if tool_calls:
m["tool_calls"] = tool_calls
return m
# ---------------------------------------------------------------------------
# _resolve_review_runtime — the aux-model selector
# ---------------------------------------------------------------------------
class _FakeAgent:
def __init__(self, provider="openai-codex", model="gpt-5.5"):
self.provider = provider
self.model = model
def _current_main_runtime(self):
return {
"api_key": "parent-key",
"base_url": "https://chatgpt.com/backend-api/codex",
"api_mode": "codex_app_server",
}
def test_routing_auto_inherits_parent_and_downgrades_codex_app_server():
    """With provider "auto" the review keeps the parent's provider/model and is not routed."""
    review_cfg = {"provider": "auto", "model": ""}
    cfg = {"auxiliary": {"background_review": review_cfg}}
    agent = _FakeAgent()

    with patch("hermes_cli.config.load_config", return_value=cfg):
        rt = br._resolve_review_runtime(agent)

    assert rt["routed"] is False
    assert rt["provider"] == "openai-codex"
    assert rt["model"] == "gpt-5.5"
    # Downgraded from codex_app_server so agent-loop tools dispatch.
    assert rt["api_mode"] == "codex_responses"
def test_routing_to_different_model_marks_routed_and_resolves_credentials():
    """A configured model that differs from the parent's is routed with resolved credentials."""
    review_cfg = {"provider": "openrouter", "model": "google/gemini-3-flash-preview"}
    cfg = {"auxiliary": {"background_review": review_cfg}}
    resolved = {
        "provider": "openrouter",
        "api_key": "or-key",
        "base_url": "https://openrouter.ai/api/v1",
        "api_mode": "chat_completions",
    }
    agent = _FakeAgent()

    load_patch = patch("hermes_cli.config.load_config", return_value=cfg)
    resolve_patch = patch(
        "hermes_cli.runtime_provider.resolve_runtime_provider",
        return_value=resolved,
    )
    with load_patch, resolve_patch:
        rt = br._resolve_review_runtime(agent)

    assert rt["routed"] is True
    assert rt["provider"] == "openrouter"
    assert rt["model"] == "google/gemini-3-flash-preview"
    assert rt["api_key"] == "or-key"
def test_routing_same_model_as_parent_is_not_routed():
    """Configuring the parent's own provider/model does not count as routing."""
    same = {"provider": "openrouter", "model": "anthropic/claude-opus-4.8"}
    agent = _FakeAgent(provider=same["provider"], model=same["model"])
    cfg = {"auxiliary": {"background_review": dict(same)}}

    with patch("hermes_cli.config.load_config", return_value=cfg):
        rt = br._resolve_review_runtime(agent)

    # Same model/provider → the full-replay path is kept.
    assert rt["routed"] is False
def test_routing_resolution_failure_falls_back_to_parent():
    """If resolving the configured provider raises, the parent runtime is used unrouted."""
    review_cfg = {"provider": "openrouter", "model": "google/gemini-3-flash-preview"}
    cfg = {"auxiliary": {"background_review": review_cfg}}
    agent = _FakeAgent()

    load_patch = patch("hermes_cli.config.load_config", return_value=cfg)
    failing_resolve = patch(
        "hermes_cli.runtime_provider.resolve_runtime_provider",
        side_effect=RuntimeError("boom"),
    )
    with load_patch, failing_resolve:
        rt = br._resolve_review_runtime(agent)

    assert rt["routed"] is False
    assert rt["provider"] == "openai-codex"
# ---------------------------------------------------------------------------
# _digest_history — routed-path compact replay
# ---------------------------------------------------------------------------
def test_digest_under_tail_returns_full():
    """A history shorter than the tail comes back equal to the input."""
    history = [_msg("user", "hi"), _msg("assistant", "hello")]
    digested = br._digest_history(history, tail=24)
    assert digested == history
def test_digest_collapses_old_keeps_tail_verbatim():
    """Older turns collapse into one leading digest message; the tail is untouched."""
    history = []
    for turn in range(60):
        history.extend((
            _msg("user", f"u{turn} " + "x" * 50),
            _msg("assistant", f"a{turn} " + "y" * 50),
        ))

    digested = br._digest_history(history, tail=10)

    # The synthetic digest comes first and uses the user role, which keeps
    # role alternation intact.
    head = digested[0]
    assert head["role"] == "user"
    assert head["content"].startswith("[Earlier conversation digest")
    # The most recent message survives verbatim.
    assert digested[-1] == history[-1]
    assert len(digested) == 11  # 1 digest + 10 tail
def test_digest_does_not_open_tail_on_a_tool_message():
    """The first tail message after the digest is never a ``tool`` message."""
    terminal_call = {"function": {"name": "terminal", "arguments": "{}"}}
    history = []
    for _ in range(40):
        history.append(_msg("user", "u" + "x" * 50))
        history.append(_msg("assistant", "", tool_calls=[terminal_call]))
        history.append({"role": "tool", "content": "result " + "w" * 50})

    digested = br._digest_history(history, tail=2)

    # Index 0 is the digest; index 1 opens the verbatim tail.
    first_tail_message = digested[1]
    assert first_tail_message["role"] != "tool"
def test_digest_records_tool_names_in_arc():
    """The digest text quotes the old user turn and lists the tool names it triggered."""
    tool_calls = [
        {"function": {"name": "skill_view", "arguments": "{}"}},
        {"function": {"name": "patch", "arguments": "{}"}},
    ]
    older_turns = [
        _msg("user", "do the thing"),
        _msg("assistant", "", tool_calls=tool_calls),
    ]
    recent_turns = [_msg("user", f"tail{i}") for i in range(30)]

    digested = br._digest_history(older_turns + recent_turns, tail=10)

    digest_text = digested[0]["content"]
    assert "USER: do the thing" in digest_text
    assert "tools: skill_view, patch" in digest_text

View file

@ -260,6 +260,52 @@ class TestShrinkImagePartsHelper:
assert seen["max_dimension"] == 2000
assert msgs[0]["content"][0]["image_url"]["url"] == shrunk
def test_anthropic_base64_image_source_rewritten(self, monkeypatch):
    """Anthropic-native image blocks are shrinkable after adapter conversion."""
    agent = _make_agent()
    _install_fake_pillow(monkeypatch, (2501, 100), shrunk_size=(1500, 60))

    # Only the base64 payload (after the comma) goes into the Anthropic block.
    original_data = _big_png_data_url(100).partition(",")[2]
    shrunk = "data:image/jpeg;base64," + "N" * 1000
    seen = {}

    def _fake_resize(path, mime_type=None, max_base64_bytes=None, max_dimension=None):
        seen.update(mime_type=mime_type, max_dimension=max_dimension)
        return shrunk

    monkeypatch.setattr(
        "tools.vision_tools._resize_image_for_vision",
        _fake_resize,
        raising=False,
    )

    image_block = {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": "image/png",
            "data": original_data,
        },
    }
    msgs = [{"role": "user", "content": [image_block]}]

    changed = agent._try_shrink_image_parts_in_messages(
        msgs,
        max_dimension=2000,
    )

    # Read the source back through the message list to see what it holds now.
    source = msgs[0]["content"][0]["source"]
    assert changed is True
    assert seen["mime_type"] == "image/png"
    assert seen["max_dimension"] == 2000
    assert source["type"] == "base64"
    assert source["media_type"] == "image/jpeg"
    assert source["data"] == "N" * 1000
def test_oversized_input_image_string_shape_rewritten(self, monkeypatch):
"""OpenAI Responses shape: {type: input_image, image_url: "data:..."}."""
agent = _make_agent()

View file

@ -23,6 +23,7 @@ from agent.codex_responses_adapter import _normalize_codex_response
import run_agent
from run_agent import AIAgent
from agent.error_classifier import FailoverReason
from agent.memory_manager import MemoryManager
from agent.prompt_builder import DEFAULT_AGENT_IDENTITY
@ -2082,6 +2083,41 @@ class TestExecuteToolCalls:
assert messages[0]["role"] == "tool"
assert "search result" in messages[0]["content"]
def test_sequential_memory_remove_notifies_provider_with_tool_result(self, agent):
    """A sequential ``memory`` remove call notifies the memory manager exactly once."""
    old_text = "stale preference entry"
    remove_args = {"action": "remove", "target": "memory", "old_text": old_text}
    tc = _mock_tool_call(
        name="memory",
        arguments=json.dumps(remove_args),
        call_id="mem-1",
    )
    mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
    messages = []
    calls = []

    class FakeMemoryManager(MemoryManager):
        """Manager that claims no tools and records every write notification."""

        def has_tool(self, tool_name):
            return False

        def on_memory_write(self, action, target, content, metadata=None):
            calls.append((action, target, content, metadata or {}))

    agent._memory_manager = FakeMemoryManager()
    agent._memory_store = object()

    tool_result = json.dumps({"success": True})
    with patch("tools.memory_tool.memory_tool", return_value=tool_result):
        agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")

    assert len(calls) == 1
    action, target, content, metadata = calls[0]
    assert (action, target, content) == ("remove", "memory", "")
    assert metadata["old_text"] == old_text
    assert metadata["tool_call_id"] == "mem-1"
    assert messages[-1]["tool_call_id"] == "mem-1"
def test_keyboard_interrupt_emits_cancelled_post_tool_hook(self, agent, monkeypatch):
tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
@ -2457,6 +2493,35 @@ class TestConcurrentToolExecution:
assert messages[1]["tool_call_id"] == "c2"
assert "success" in messages[1]["content"]
def test_concurrent_submit_shutdown_error_returns_tool_errors(self, agent):
    """Submit-time interpreter shutdown should not escape the outer loop.

    The thread pool is swapped for a stand-in whose ``submit`` always raises
    the "interpreter shutdown" RuntimeError. Both tool calls must still get
    a tool message, in call order, whose content mentions the shutdown.
    """

    class ShutdownExecutor:
        """Context-manager executor stub that refuses every submission."""

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def submit(self, *args, **kwargs):
            raise RuntimeError("cannot schedule new futures after interpreter shutdown")

    planned_calls = (("c1", '{"q": "alpha"}'), ("c2", '{"q": "beta"}'))
    tool_calls = [
        _mock_tool_call(name="web_search", arguments=raw_args, call_id=call_id)
        for call_id, raw_args in planned_calls
    ]
    assistant_msg = _mock_assistant_msg(content="", tool_calls=tool_calls)
    transcript = []

    with patch("agent.tool_executor.concurrent.futures.ThreadPoolExecutor", ShutdownExecutor):
        agent._execute_tool_calls_concurrent(assistant_msg, transcript, "task-1")

    assert len(transcript) == 2
    first, second = transcript
    assert first["tool_call_id"] == "c1"
    assert second["tool_call_id"] == "c2"
    for tool_msg in transcript:
        assert "Python interpreter is shutting down" in tool_msg["content"]
def test_concurrent_interrupt_before_start(self, agent):
"""If interrupt is requested before concurrent execution, all tools are skipped."""
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
@ -2797,6 +2862,68 @@ class TestConcurrentToolExecution:
assert json.loads(result) == {"error": "Blocked"}
assert agent._turns_since_memory == 5
def test_invoke_tool_memory_remove_notifies_provider_with_old_text(self, agent, monkeypatch):
    """``_invoke_tool`` path: a successful memory ``remove`` notifies the manager.

    The pre-tool-call block hook is stubbed to return ``None`` and the memory
    tool is patched to report success. Exactly one ``on_memory_write`` call is
    expected, with empty content and metadata holding the removed text and
    the tool call id.
    """

    def _never_block(*args, **kwargs):
        return None

    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        _never_block,
    )

    notifications = []

    class FakeMemoryManager(MemoryManager):
        def has_tool(self, tool_name):
            return False

        def on_memory_write(self, action, target, content, metadata=None):
            notifications.append((action, target, content, metadata or {}))

    stale_entry = "stale preference entry"
    agent._memory_manager = FakeMemoryManager()
    agent._memory_store = object()

    tool_ok = json.dumps({"success": True})
    tool_args = {"action": "remove", "target": "memory", "old_text": stale_entry}
    with patch("tools.memory_tool.memory_tool", return_value=tool_ok):
        agent._invoke_tool("memory", tool_args, "task-1", tool_call_id="mem-1")

    assert len(notifications) == 1
    action, target, content, metadata = notifications[0]
    assert (action, target, content) == ("remove", "memory", "")
    assert metadata["old_text"] == stale_entry
    assert metadata["tool_call_id"] == "mem-1"
def test_invoke_tool_memory_failed_remove_skips_provider_notification(self, agent, monkeypatch):
    """A memory ``remove`` that reports failure must not notify the manager.

    ``on_memory_write`` is a mock that raises if it is ever called. The
    memory tool is patched to return ``success: False``, and the test ends by
    checking the mock was never invoked.
    """

    def _never_block(*args, **kwargs):
        return None

    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        _never_block,
    )

    on_write = MagicMock(side_effect=AssertionError("should not notify"))

    class FakeMemoryManager(MemoryManager):
        def has_tool(self, tool_name):
            return False

        # Class-level mock: any notification attempt is recorded (and raises).
        on_memory_write = on_write

    agent._memory_manager = FakeMemoryManager()
    agent._memory_store = object()

    failed_result = json.dumps({"success": False, "error": "No entry matched"})
    tool_args = {"action": "remove", "target": "memory", "old_text": "missing"}
    with patch("tools.memory_tool.memory_tool", return_value=failed_result):
        agent._invoke_tool("memory", tool_args, "task-1", tool_call_id="mem-1")

    on_write.assert_not_called()
def test_concurrent_blocked_write_skips_checkpoint(self, agent, monkeypatch):
"""Concurrent path: blocked write_file should not trigger checkpoint."""
tc1 = _mock_tool_call(name="write_file",

View file

@ -0,0 +1,252 @@
"""Behavior contracts for incremental tool-call persistence (#49045).
A destructive or process-terminating tool that runs during tool execution
must not lose the just-executed assistant(tool_calls) block or the tool
results that were produced before it fired. These tests pin the contract:
1. run_conversation flushes the assistant tool-call turn to the session
DB BEFORE handing control to _execute_tool_calls (so a tool that
restarts/kills the process never orphans the tool-call block).
2. The SEQUENTIAL tool path flushes each tool result to the session DB
immediately after appending it, BEFORE the next tool dispatches.
3. The CONCURRENT tool path flushes each tool result in append order.
These exercise the REAL production dispatch surfaces:
* sequential -> ``run_agent.handle_function_call`` (tool_executor ~1256/1298)
* concurrent -> ``agent._invoke_tool`` (tool_executor ~539)
Mocking the genuine dispatch surface keeps the tests deterministic (no real
``web_search`` / network) AND mutation-survivable: the ordering assertions
read snapshots captured at flush time, so removing any production flush call
makes the corresponding assertion fail.
"""
import copy
from types import SimpleNamespace
from pathlib import Path
import tempfile
from unittest.mock import MagicMock, patch
from agent.tool_dispatch_helpers import make_tool_result_message
from run_agent import AIAgent
def _make_tool_defs(*names: str) -> list:
    """Build one minimal function-tool definition per name in *names*.

    Each entry has ``type`` ``"function"`` and a ``function`` dict with the
    name, a ``"<name> tool"`` description, and an empty object schema.
    """
    definitions = []
    for tool_name in names:
        spec = {
            "name": tool_name,
            "description": f"{tool_name} tool",
            "parameters": {"type": "object", "properties": {}},
        }
        definitions.append({"type": "function", "function": spec})
    return definitions
def _make_agent():
    """Build an ``AIAgent`` wired for offline, deterministic tests.

    Construction runs with tool discovery, toolset requirement checks, the
    OpenAI client class, the Hermes home directory, and model-metadata
    fetching all patched out. The returned agent has a ``MagicMock`` client,
    a fixed cached system prompt, and prompt caching, tool delay,
    compression, and trajectory saving disabled.
    """
    # Local imports: only this helper needs them.
    import atexit
    import shutil

    hermes_home = Path(tempfile.mkdtemp(prefix="hermes-test-home-"))
    # mkdtemp never cleans up after itself; remove the tree at interpreter
    # exit so repeated test runs do not pile up temp directories. Best-effort
    # (ignore_errors) so a still-open file cannot fail the test session.
    atexit.register(shutil.rmtree, hermes_home, ignore_errors=True)
    (hermes_home / "logs").mkdir(parents=True, exist_ok=True)
    with (
        patch(
            "run_agent.get_tool_definitions",
            return_value=_make_tool_defs("web_search"),
        ),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
        patch("run_agent._hermes_home", hermes_home),
        patch("agent.model_metadata.fetch_model_metadata", return_value={}),
    ):
        agent = AIAgent(
            api_key="test-key",
            base_url="https://openrouter.ai/api/v1",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
    agent.client = MagicMock()
    agent._cached_system_prompt = "You are helpful."
    agent._use_prompt_caching = False
    agent.tool_delay = 0
    agent.compression_enabled = False
    agent.save_trajectories = False
    return agent
def _mock_tool_call(name="web_search", arguments="{}", call_id="call_1"):
    """Return a namespace shaped like a function tool call.

    The result exposes ``id``, ``type`` (always ``"function"``) and a nested
    ``function`` namespace with ``name`` and ``arguments``.
    """
    function_part = SimpleNamespace(name=name, arguments=arguments)
    return SimpleNamespace(id=call_id, type="function", function=function_part)
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
    """Return a namespace shaped like a single-choice chat completion.

    The one choice carries a message (``content``, ``tool_calls``) and the
    given ``finish_reason``; ``model`` is ``"test/model"`` and ``usage`` is
    ``None``.
    """
    only_choice = SimpleNamespace(
        message=SimpleNamespace(content=content, tool_calls=tool_calls),
        finish_reason=finish_reason,
    )
    return SimpleNamespace(choices=[only_choice], model="test/model", usage=None)
# ---------------------------------------------------------------------------
# Contract 1: run_conversation persists the assistant tool-call block BEFORE
# tool execution begins.
# ---------------------------------------------------------------------------
def test_run_conversation_flushes_assistant_tool_call_before_execution():
    """Contract 1: the assistant tool-call turn is flushed before tools run.

    The model is scripted to answer with one tool call and then a final
    reply. ``_execute_tool_calls`` is replaced by a recorder that captures
    the most recent flush snapshot at the moment execution starts; that
    snapshot must already end with the assistant message carrying call
    ``c1``.
    """
    agent = _make_agent()
    scripted_call = _mock_tool_call(call_id="c1")
    agent.client.chat.completions.create.side_effect = [
        _mock_response(content="", finish_reason="tool_calls", tool_calls=[scripted_call]),
        _mock_response(content="done", finish_reason="stop"),
    ]

    # Deep-copy the message list on every flush so later in-place mutation
    # cannot change what an earlier snapshot shows.
    flush_snapshots: list[list] = []

    def _record_flush(messages, conversation_history=None):
        flush_snapshots.append(copy.deepcopy(messages))

    agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)

    # Observations are stored in closure-captured state and asserted after
    # run_conversation returns: its outer loop swallows exceptions, so an
    # assertion raised inside the callback would never surface.
    execute_calls = 0
    snapshot_at_execute: list = []

    def _fake_execute(assistant_message, messages, effective_task_id, api_call_count=0):
        nonlocal execute_calls
        execute_calls += 1
        # What had been flushed at the moment tool execution begins.
        if flush_snapshots:
            seen = copy.deepcopy(flush_snapshots[-1])
        else:
            seen = None
        snapshot_at_execute.append(seen)
        # Simulate the tool producing a result (as the real path would).
        messages.append(make_tool_result_message("web_search", "search result", "c1"))

    with (
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
        patch.object(agent, "_execute_tool_calls", side_effect=_fake_execute),
    ):
        result = agent.run_conversation("search something")

    assert execute_calls == 1, "_execute_tool_calls was never reached"

    # The assistant tool-call block MUST have been flushed before execution.
    flushed_before = snapshot_at_execute[0]
    assert flushed_before is not None, "no flush occurred before tool execution"
    tail_message = flushed_before[-1]
    assert tail_message["role"] == "assistant"
    assert tail_message["tool_calls"][0]["id"] == "c1"
    assert result["final_response"] == "done"
# ---------------------------------------------------------------------------
# Contract 2: the SEQUENTIAL path flushes each tool result immediately, BEFORE
# the next tool dispatches. Dispatch goes through run_agent.handle_function_call
# (the real production surface), which we mock for determinism.
# ---------------------------------------------------------------------------
def test_execute_tool_calls_sequential_flushes_each_tool_result_before_next_dispatch():
    """Contract 2: sequential results are flushed before the next dispatch.

    ``run_agent.handle_function_call`` and the session-DB flush both append
    to one ordered timeline, which must read
    dispatch c1 -> flush c1 -> dispatch c2 -> flush c2.
    """
    agent = _make_agent()
    assistant_message = SimpleNamespace(
        content="",
        tool_calls=[
            _mock_tool_call(name="web_search", call_id=call_id)
            for call_id in ("c1", "c2")
        ],
    )
    messages: list = []

    # Single ordered log shared by the dispatch stub and the flush recorder.
    timeline: list = []

    def _fake_dispatch(function_name, function_args, effective_task_id, **kwargs):
        call_id = kwargs.get("tool_call_id")
        timeline.append(("dispatch", call_id))
        return f"result-{call_id}"

    def _record_flush(flush_messages, conversation_history=None):
        # The newest message is the one whose append triggered this flush.
        newest = flush_messages[-1]
        timeline.append(("flush", newest.get("role"), newest.get("tool_call_id")))

    def _keep_content(**kwargs):
        return kwargs["content"]

    agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)

    with (
        patch("run_agent.handle_function_call", side_effect=_fake_dispatch) as disp,
        patch("agent.tool_executor.maybe_persist_tool_result", side_effect=_keep_content),
    ):
        agent._execute_tool_calls_sequential(assistant_message, messages, "task-1")

    # The mock proves the sequential dispatch surface was exercised.
    assert disp.call_count == 2, "sequential path did not dispatch via handle_function_call"

    # Both tool results landed, in order.
    roles = [m["role"] for m in messages]
    result_ids = [m["tool_call_id"] for m in messages]
    assert roles == ["tool", "tool"]
    assert result_ids == ["c1", "c2"]

    # Each result is flushed after its own dispatch and before the next one.
    expected_timeline = [
        ("dispatch", "c1"),
        ("flush", "tool", "c1"),
        ("dispatch", "c2"),
        ("flush", "tool", "c2"),
    ]
    assert timeline == expected_timeline
# ---------------------------------------------------------------------------
# Contract 3: the CONCURRENT path flushes each collected tool result in append
# order. Dispatch goes through agent._invoke_tool (the real concurrent
# surface), which we mock for determinism.
# ---------------------------------------------------------------------------
def test_execute_tool_calls_concurrent_flushes_each_tool_result_in_order():
    """Contract 3: concurrent results are flushed one by one, in append order.

    ``agent._invoke_tool`` is mocked. Every flush records the tail
    ``tool_call_id`` and how many tool messages were present, which must be
    ``c1``/1 followed by ``c2``/2.
    """
    agent = _make_agent()
    assistant_message = SimpleNamespace(
        content="",
        tool_calls=[
            _mock_tool_call(name="web_search", call_id=call_id)
            for call_id in ("c1", "c2")
        ],
    )
    messages: list = []
    invoked_ids: list = []

    def _fake_invoke(function_name, function_args, effective_task_id, tool_call_id, **kwargs):
        invoked_ids.append(tool_call_id)
        return f"result-{tool_call_id}"

    # Per flush: the tail tool_call_id and the running count of tool results.
    flushed_tool_ids: list = []
    flush_lengths: list = []

    def _record_flush(flush_messages, conversation_history=None):
        flushed_tool_ids.append(flush_messages[-1]["tool_call_id"])
        tool_count = sum(1 for m in flush_messages if m.get("role") == "tool")
        flush_lengths.append(tool_count)

    def _keep_content(**kwargs):
        return kwargs["content"]

    agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)

    with (
        patch.object(agent, "_invoke_tool", side_effect=_fake_invoke) as inv,
        patch("agent.tool_executor.maybe_persist_tool_result", side_effect=_keep_content),
    ):
        agent._execute_tool_calls_concurrent(assistant_message, messages, "task-1")

    # Proves the concurrent dispatch surface was exercised.
    assert inv.call_count == 2, "concurrent path did not dispatch via _invoke_tool"
    assert sorted(invoked_ids) == ["c1", "c2"]

    # Results appended in deterministic order.
    appended_ids = [m["tool_call_id"] for m in messages]
    assert appended_ids == ["c1", "c2"]

    # One flush per result, in append order, with the tool count growing by
    # one each time. Dropping either production flush breaks an assertion.
    assert flushed_tool_ids == ["c1", "c2"]
    assert flush_lengths == [1, 2]

View file

@ -0,0 +1,164 @@
"""Tests for optional-skills/web-development/cloudflare-temporary-deploy/scripts/parse_deploy_output.py"""
import json
import sys
from pathlib import Path
from unittest import mock
import pytest
# Directory holding the skill's helper scripts, located relative to this test
# file: two levels up from the file's directory, then into the skill's
# ``scripts`` folder.
SCRIPTS_DIR = (
Path(__file__).resolve().parents[2]
/ "optional-skills"
/ "web-development"
/ "cloudflare-temporary-deploy"
/ "scripts"
)
# Put the scripts directory first on sys.path so the import below can find
# ``parse_deploy_output`` by its bare module name.
sys.path.insert(0, str(SCRIPTS_DIR))
import parse_deploy_output as pdo
CREATED = """\
Continuing means you accept Cloudflare's Terms of Service and Privacy Policy.
Temporary account ready:
Account: swift-otter (created)
Claim within: 60 minutes
Claim URL: https://dash.cloudflare.com/claim-preview?claimToken=TOKEN_AAA
Uploaded my-worker
Deployed my-worker triggers
https://my-worker.swift-otter.workers.dev
"""
REUSED = """\
Temporary account ready:
Account: swift-otter (reused)
Claim within: 17 minutes
Claim URL: https://dash.cloudflare.com/claim-preview?claimToken=TOKEN_BBB
Deployed my-worker triggers
https://my-worker.swift-otter.workers.dev
"""
NOT_LOGGED_IN = """\
[ERROR] You are not logged in.
To continue without logging in, rerun this command with `--temporary`.
"""
AUTH_PRESENT_ERROR = """\
[ERROR] The --temporary flag cannot be used while Wrangler is authenticated.
Run `wrangler logout` first, or remove CLOUDFLARE_API_TOKEN.
"""
class TestParseCreated:
    """Fields parsed from the ``CREATED`` sample output."""

    def test_live_url(self):
        parsed = pdo.parse(CREATED)
        assert parsed["live_url"] == "https://my-worker.swift-otter.workers.dev"

    def test_claim_url(self):
        expected = "https://dash.cloudflare.com/claim-preview?claimToken=TOKEN_AAA"
        assert pdo.parse(CREATED)["claim_url"] == expected

    def test_account_and_state(self):
        parsed = pdo.parse(CREATED)
        assert parsed["account"] == "swift-otter"
        assert parsed["account_state"] == "created"

    def test_expiry_and_deployed(self):
        parsed = pdo.parse(CREATED)
        assert parsed["expires_minutes"] == 60
        assert parsed["deployed"] is True
class TestParseReused:
    """Fields parsed from the ``REUSED`` sample output."""

    def test_state_is_reused(self):
        parsed = pdo.parse(REUSED)
        assert parsed["account_state"] == "reused"

    def test_expiry_window_can_shrink(self):
        parsed = pdo.parse(REUSED)
        assert parsed["expires_minutes"] == 17

    def test_live_url_stable(self):
        parsed = pdo.parse(REUSED)
        assert parsed["live_url"] == "https://my-worker.swift-otter.workers.dev"
class TestNoDeploy:
    """Error outputs must parse to empty URL fields and ``deployed`` False."""

    def test_not_logged_in_has_no_urls(self):
        parsed = pdo.parse(NOT_LOGGED_IN)
        for field in ("live_url", "claim_url", "account"):
            assert parsed[field] is None
        assert parsed["deployed"] is False

    def test_auth_present_error_has_no_urls(self):
        parsed = pdo.parse(AUTH_PRESENT_ERROR)
        for field in ("live_url", "claim_url"):
            assert parsed[field] is None
        assert parsed["deployed"] is False
class TestRealWorldOutput:
    """Regression: real wrangler output uses tab-indent + multi-word account names."""

    # Captured-style sample: tab-indented account lines and a two-word
    # account name.
    REAL = (
        "⛅️ wrangler 4.103.0\n"
        "Continuing means you accept Cloudflare's Terms of Service and Privacy Policy.\n"
        "Solving proof-of-work challenge…\n"
        "Temporary account ready:\n"
        "\tAccount: Serene Temple (created)\n"
        "\tClaim within: 60 minutes\n"
        "\tClaim URL: https://dash.cloudflare.com/claim-preview?claimToken=fxLzyAD-vlTzMQmClpg\n"
        "Total Upload: 0.19 KiB / gzip: 0.16 KiB\n"
        "Uploaded hermes-temp-hello (0.74 sec)\n"
        "Deployed hermes-temp-hello triggers (0.42 sec)\n"
        " https://hermes-temp-hello.serene-temple.workers.dev\n"
    )

    def test_multiword_account_name(self):
        parsed = pdo.parse(self.REAL)
        assert parsed["account"] == "Serene Temple"
        assert parsed["account_state"] == "created"

    def test_all_fields_from_real_output(self):
        parsed = pdo.parse(self.REAL)
        live_url = parsed["live_url"]
        claim_url = parsed["claim_url"]
        assert live_url == "https://hermes-temp-hello.serene-temple.workers.dev"
        assert claim_url.endswith("claimToken=fxLzyAD-vlTzMQmClpg")
        assert parsed["expires_minutes"] == 60
        assert parsed["deployed"] is True
class TestUrlHygiene:
    """URL extraction edge cases."""

    def test_trailing_punctuation_stripped(self):
        sample = "Deployed\n see https://w.acct.workers.dev. for details"
        parsed = pdo.parse(sample)
        assert parsed["live_url"] == "https://w.acct.workers.dev"

    def test_does_not_match_plain_cloudflare_com(self):
        # A generic cloudflare.com link without a claimToken must not be
        # taken as the claim URL.
        sample = "Privacy Policy: https://www.cloudflare.com/privacypolicy/\nDeployed x"
        parsed = pdo.parse(sample)
        assert parsed["claim_url"] is None
class TestCli:
    """Exit codes and JSON output of ``pdo.main``."""

    @staticmethod
    def _run_main_with_stdin(stdin_text, capsys):
        """Run ``pdo.main([])`` with stdin stubbed; return (exit code, parsed JSON)."""
        with mock.patch.object(sys.stdin, "read", return_value=stdin_text):
            exit_code = pdo.main([])
        payload = json.loads(capsys.readouterr().out)
        return exit_code, payload

    def test_selftest_exits_zero(self):
        exit_code = pdo.main(["--selftest"])
        assert exit_code == 0

    def test_main_prints_json_and_exit_zero_on_live(self, capsys):
        exit_code, payload = self._run_main_with_stdin(CREATED, capsys)
        assert exit_code == 0
        assert payload["live_url"] == "https://my-worker.swift-otter.workers.dev"

    def test_main_exit_one_when_no_live_url(self, capsys):
        exit_code, payload = self._run_main_with_stdin(NOT_LOGGED_IN, capsys)
        assert exit_code == 1
        assert payload["live_url"] is None
# Allow running this file directly: hand it to pytest in quiet mode and exit
# with pytest's return code.
if __name__ == "__main__":
raise SystemExit(pytest.main([__file__, "-q"]))

79
tests/test_code_skew.py Normal file
View file

@ -0,0 +1,79 @@
"""Tests for gateway code-skew detection (stale-checkout guard).
Companion to ``tests/test_stale_utils_module_import.py``: that test proves the
crash; these prove the guard that turns it into a clear "restart the gateway"
message before a model switch can hit it.
"""
import pytest
from gateway import code_skew
@pytest.fixture(autouse=True)
def _reset_boot_fingerprint(monkeypatch):
    """Clear ``code_skew._boot_fingerprint`` before every test in this module.

    ``monkeypatch`` restores the previous value after the test finishes.
    """
    monkeypatch.setattr(code_skew, "_boot_fingerprint", None, raising=True)
class TestDetectCodeSkew:
    """Behavior of ``detect_code_skew`` and ``record_boot_fingerprint``."""

    @staticmethod
    def _pin_fingerprint(monkeypatch, value):
        """Make ``code_skew._fingerprint()`` return *value*."""
        monkeypatch.setattr(code_skew, "_fingerprint", lambda: value)

    def test_no_boot_fingerprint_means_no_skew(self, monkeypatch):
        # Nothing recorded (e.g. non-git install) -> never a false positive.
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:def456")
        assert code_skew.detect_code_skew() is None

    def test_unchanged_checkout_is_not_skew(self, monkeypatch):
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:abc1234567890")
        code_skew.record_boot_fingerprint()
        assert code_skew.detect_code_skew() is None

    def test_drift_is_detected_with_short_revs(self, monkeypatch):
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:abc1234567890")
        code_skew.record_boot_fingerprint()
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:def4567890123")
        detected = code_skew.detect_code_skew()
        assert detected == ("abc1234567", "def4567890")

    def test_unreadable_current_rev_does_not_false_positive(self, monkeypatch):
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:abc1234567890")
        code_skew.record_boot_fingerprint()
        self._pin_fingerprint(monkeypatch, None)
        assert code_skew.detect_code_skew() is None

    def test_record_is_idempotent(self, monkeypatch):
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:first")
        code_skew.record_boot_fingerprint()
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:second")
        # A second record must not overwrite the boot snapshot.
        code_skew.record_boot_fingerprint()
        assert code_skew._boot_fingerprint == "git:refs/heads/main:first"
class TestShort:
    """Expected outputs of ``code_skew._short`` for three fingerprint shapes."""

    def test_shortens_long_sha(self):
        shortened = code_skew._short("git:refs/heads/main:abcdef0123456789")
        assert shortened == "abcdef0123"

    def test_keeps_unresolved_marker(self):
        shortened = code_skew._short("git:refs/heads/main:unresolved")
        assert shortened == "unresolved"

    def test_passes_short_sha_through_untruncated(self):
        shortened = code_skew._short("git:HEAD:abc1234")
        assert shortened == "abc1234"
class TestModelSwitchSkewGuard:
    """``slash_commands._model_switch_skew_guard`` with skew detection stubbed."""

    def test_guard_returns_none_without_skew(self, monkeypatch):
        from gateway import slash_commands

        monkeypatch.setattr(code_skew, "detect_code_skew", lambda: None)
        guard_result = slash_commands._model_switch_skew_guard()
        assert guard_result is None

    def test_guard_message_names_revs_and_restart(self, monkeypatch):
        from gateway import slash_commands

        stubbed_skew = ("abc1234567", "def4567890")
        monkeypatch.setattr(code_skew, "detect_code_skew", lambda: stubbed_skew)
        msg = slash_commands._model_switch_skew_guard()
        assert msg is not None
        # Both revisions and the restart command must appear in the message.
        assert "abc1234567" in msg
        assert "def4567890" in msg
        assert "hermes gateway restart" in msg

View file

@ -12,19 +12,47 @@ REPO_ROOT = Path(__file__).resolve().parent.parent
INSTALL_SH = REPO_ROOT / "scripts" / "install.sh"
def test_install_script_skips_playwright_download_when_system_browser_exists() -> None:
def test_install_script_does_not_autodetect_system_browser_on_path() -> None:
"""The installer must not scan PATH/well-known locations for a browser.
Auto-detection silently bound the install to whatever ``command -v
chromium`` resolved to, most damagingly, a Snap Chromium, whose sandbox
blocks agent-browser's control socket and hangs every browser_navigate. The
fallback was dropped in favor of always using the bundled Playwright
Chromium, so the old PATH-scan and "use the system browser" path are gone.
"""
text = INSTALL_SH.read_text()
assert "find_system_browser()" in text
assert "google-chrome google-chrome-stable chromium chromium-browser chrome" in text
assert "Skipping Playwright browser download; Hermes will use the system browser." in text
assert "google-chrome google-chrome-stable chromium chromium-browser chrome" not in text
assert "Skipping Playwright browser download; Hermes will use the system browser." not in text
def test_install_script_persists_system_browser_for_agent_browser() -> None:
def test_install_script_honors_explicit_browser_override_only() -> None:
"""find_system_browser consults only an explicit AGENT_BROWSER_EXECUTABLE_PATH."""
text = INSTALL_SH.read_text()
assert "configure_browser_env_from_system_browser()" in text
assert "AGENT_BROWSER_EXECUTABLE_PATH=$browser_path" in text
assert 'override="${AGENT_BROWSER_EXECUTABLE_PATH:-}"' in text
# An explicit override still skips the bundled download (override, not fallback).
assert "Skipping bundled Chromium download" in text
def test_install_script_strips_stale_snap_browser_override() -> None:
    """Already-affected installs must auto-recover.

    A pre-existing AGENT_BROWSER_EXECUTABLE_PATH pointing at a Snap Chromium
    is the exact value that hangs the browser tool, and the runtime reads it
    from .env. The installer therefore strips it (and rejects a Snap override
    even when set explicitly) so the bundled Chromium download runs on update.
    This test only checks that the corresponding snippets are present in the
    install script text.
    """
    script = INSTALL_SH.read_text()

    # The migration helper exists and targets /snap/ overrides.
    for snippet in (
        "strip_snap_browser_override()",
        "^AGENT_BROWSER_EXECUTABLE_PATH=/snap/",
    ):
        assert snippet in script

    # Both install paths invoke the migration before resolving a browser.
    mentions = script.count("strip_snap_browser_override")
    assert mentions >= 3

    # A snap path is rejected by find_system_browser itself.
    assert "/snap/*) return 1 ;;" in script
def test_playwright_installs_are_timeout_guarded() -> None:

View file

@ -0,0 +1,90 @@
"""Regression for the stale-``utils``-module ImportError after a hot ``git pull``.
Real incident (gateway session 1518671026962174144)::
Sorry, I encountered an error (ImportError).
cannot import name 'env_float' from 'utils' (~/.hermes/hermes-agent/utils.py)
Mechanism:
1. A long-running gateway/agent process imported ``utils`` BEFORE ``env_float``
existed (added in 06ca1e99, 2026-06-20 14:00). The cached module object in
``sys.modules`` therefore has no ``env_float`` attribute.
2. ``hermes update`` ran ``git pull``, updating ``utils.py`` (now defining
``env_float``) and ~22 consumer modules (now doing ``from utils import
env_float``) on disk -- WITHOUT restarting the process.
3. Switching the live session's model (anthropic/opus -> opencode/glm) forced the
FIRST import of a consumer module on the new provider's code path. Its
top-level ``from utils import env_float`` resolved against the STALE cached
``utils`` -> ImportError. The path in parentheses is the consumer-reported
``utils.__file__`` on disk (which *does* define ``env_float``), which is why
the error is so confusing: the file on disk is fine, the in-memory module is not.
``hermes_cli/main.py`` (the ``hermes update`` flow, ~line 9326) already
acknowledges this exact hazard -- "source files on disk are newer than cached
Python modules in this process" -- and reloads ``hermes_constants`` after the
pull, but NOT ``utils``. Any ``utils`` consumer added in the same release stays
exposed until the process restarts.
The messaging client (Discord/Telegram/Feishu/...) is incidental: the trigger is
a fresh import on a stale process, not the platform. We assert that below by
reproducing the failure with the Discord adapter's exact import line.
"""
import sys
import types
import pytest
def _import_fresh_consumer(name: str, source: str) -> types.ModuleType:
    """Execute *source* as the body of a brand-new module called *name*.

    Mimics a consumer module being imported for the first time on the
    model-switch path. Any existing ``sys.modules`` entry for *name* is
    dropped first; the new module is registered only after its body ran
    without raising, and is then returned.
    """
    filename = f"{name}.py"
    module = types.ModuleType(name)
    module.__file__ = filename
    # Forget any earlier module of the same name before running the body.
    sys.modules.pop(name, None)
    code = compile(source, filename, "exec")
    exec(code, module.__dict__)
    sys.modules[name] = module
    return module
class TestStaleUtilsModuleImport:
    """Reproduce the ImportError from a stale cached ``utils`` module."""

    def test_fresh_consumer_import_fails_against_stale_utils(self, monkeypatch):
        """The bug: stale in-memory ``utils`` + fresh ``from utils import env_float``."""
        import utils

        # Sanity: today's on-disk source is healthy.
        assert hasattr(utils, "env_float")

        # Simulate the pre-06-20 cached module (monkeypatch auto-restores after).
        monkeypatch.delattr(utils, "env_float")

        consumer_source = "from utils import env_float\n"
        with pytest.raises(ImportError, match=r"cannot import name 'env_float' from 'utils'"):
            _import_fresh_consumer("stale_switch_path_consumer", consumer_source)

    def test_client_is_incidental_discord_import_line_fails_identically(self, monkeypatch):
        """Same failure via the Discord adapter's exact import line -- the client
        does not determine the bug, the stale process does."""
        import utils

        monkeypatch.delattr(utils, "env_float")

        # plugins/platforms/discord/adapter.py:106
        discord_import_line = "from utils import atomic_json_write, env_float\n"
        with pytest.raises(ImportError, match=r"cannot import name 'env_float' from 'utils'"):
            _import_fresh_consumer("stale_discord_consumer", discord_import_line)

    def test_healthy_process_imports_consumer_fine(self):
        """Control: when the cached ``utils`` matches disk (env_float present),
        the same consumer import succeeds -- proving the harness isolates the
        staleness, not an unrelated import error."""
        import utils

        assert hasattr(utils, "env_float")

        consumer_source = (
            "from utils import env_float\nVALUE = env_float('UNSET_FOR_TEST', 1.5)\n"
        )
        consumer = _import_fresh_consumer("healthy_consumer", consumer_source)
        assert consumer.VALUE == 1.5

View file

@ -7946,3 +7946,45 @@ def test_start_agent_build_passes_session_model_override(monkeypatch):
assert session["agent"].model == "claude-sonnet-4.6"
finally:
server._sessions.clear()
# ── _get_usage active_subagents (TUI status-bar ⛓ indicator) ──────────────
# Mirrors the classic CLI status bar: _get_usage embeds a live count of
# background/async subagents from tools.async_delegation.active_count() so the
# Ink status bar can render ⛓ N. Source of truth is the same registry the CLI
# reads; the field rides the existing per-update `usage` payload.
class _BareAgent:
    """Minimal agent stub: a ``model`` attribute and nothing else.

    It has no compressor, so the active_subagents path is exercised
    independent of the `if comp:` context-percent block.
    """

    model = "x"
def test_get_usage_includes_active_subagents(monkeypatch):
    """The usage payload carries the count reported by ``active_count()``."""
    from tools import async_delegation

    def _four_running():
        return 4

    monkeypatch.setattr(async_delegation, "active_count", _four_running)
    payload = server._get_usage(_BareAgent())
    assert payload["active_subagents"] == 4
def test_get_usage_active_subagents_zero(monkeypatch):
    """A count of zero is still reported as ``active_subagents == 0``."""
    from tools import async_delegation

    def _none_running():
        return 0

    monkeypatch.setattr(async_delegation, "active_count", _none_running)
    payload = server._get_usage(_BareAgent())
    assert payload["active_subagents"] == 0
def test_get_usage_safe_when_active_count_raises(monkeypatch):
    """A raising active_count() must not break the usage payload."""
    from tools import async_delegation

    def _explode():
        raise RuntimeError("boom")

    monkeypatch.setattr(async_delegation, "active_count", _explode)
    payload = server._get_usage(_BareAgent())

    # The field is omitted, but the rest of the payload is intact.
    assert "active_subagents" not in payload
    assert payload["model"] == "x"

File diff suppressed because it is too large Load diff

View file

@ -204,7 +204,7 @@ class TestCaptureResponseRoutedToAuxVision:
args, _kwargs = fake_vat.call_args
path_arg, prompt_arg = args[0], args[1]
assert str(tmp_cache_dir) in path_arg
assert "macOS application screenshot" in prompt_arg
assert "desktop application screenshot" in prompt_arg
# AX summary is included so the aux model can ground its description
# against the same set-of-mark index the agent will see.
assert "Sign in" in prompt_arg
@ -298,15 +298,17 @@ class TestCaptureResponseRoutedToAuxVision:
new_callable=lambda: fake_vat):
resp = cu_tool._capture_response(cap)
# Aux failure → fall back to multimodal envelope (so the user still
# gets *something* useful even if vision is broken).
assert isinstance(resp, dict)
assert resp.get("_multimodal") is True
# Aux failure with routing requested degrades to the AX/SOM text
# payload. Falling through to a multimodal envelope can hand pixels to
# a text-only model and fail the provider request.
assert isinstance(resp, str)
body = json.loads(resp)
assert body.get("vision_unavailable") is True
# Temp file must still be cleaned up.
assert observed_path["path"]
assert not os.path.exists(observed_path["path"])
def test_empty_aux_analysis_falls_back_to_multimodal(self, tmp_cache_dir):
def test_empty_aux_analysis_degrades_to_text_payload(self, tmp_cache_dir):
from tools.computer_use import tool as cu_tool
cap = _make_capture(mode="som")
@ -323,12 +325,15 @@ class TestCaptureResponseRoutedToAuxVision:
new_callable=lambda: fake_vat):
resp = cu_tool._capture_response(cap)
# Empty analysis is treated as failure — we'd rather show pixels
# than embed an empty 'vision_analysis' string into the result.
assert isinstance(resp, dict)
assert resp.get("_multimodal") is True
# Empty analysis is treated as failure; with routing requested the
# capture degrades to the AX/SOM text payload (elements stay usable)
# rather than embedding an empty 'vision_analysis' string.
assert isinstance(resp, str)
body = json.loads(resp)
assert body.get("vision_unavailable") is True
assert body.get("elements") is not None
def test_invalid_aux_response_falls_back_to_multimodal(self, tmp_cache_dir):
def test_invalid_aux_response_degrades_to_text_payload(self, tmp_cache_dir):
from tools.computer_use import tool as cu_tool
cap = _make_capture(mode="som")
@ -345,8 +350,9 @@ class TestCaptureResponseRoutedToAuxVision:
new_callable=lambda: fake_vat):
resp = cu_tool._capture_response(cap)
assert isinstance(resp, dict)
assert resp.get("_multimodal") is True
assert isinstance(resp, str)
body = json.loads(resp)
assert body.get("vision_unavailable") is True
# ---------------------------------------------------------------------------

View file

@ -0,0 +1,109 @@
"""Regression tests for profile-aware tilde expansion in file tools.
The bug (#48552): in-process file tools (write_file, read_file, patch,
search_files) resolved ``~`` via ``os.path.expanduser()``, which reads the
gateway process's ``HOME``. In profile mode (Docker, systemd, s6) the gateway
``HOME`` differs from the profile ``HOME`` that interactive sessions use, so
``~`` expanded to the wrong directory and file operations failed with
"no such file or directory".
The fix adds ``_expand_tilde()``, which delegates to
``hermes_constants.get_subprocess_home()`` — the same policy the terminal tool
uses for subprocess environments.
See: https://github.com/NousResearch/hermes-agent/issues/48552
"""
import os
from pathlib import Path
from unittest.mock import patch
import pytest
import tools.file_tools as ft
# ---------------------------------------------------------------------------
# _expand_tilde() unit tests
# ---------------------------------------------------------------------------
class TestExpandTilde:
    """Unit tests for ``ft._expand_tilde()`` with ``get_subprocess_home`` patched."""

    # Value the patched ``hermes_constants.get_subprocess_home`` reports in
    # every test except the fallback case.
    _PROFILE_HOME = "/opt/data/profiles/coder/home"

    @staticmethod
    def _expand(path, home=_PROFILE_HOME):
        """Return ``ft._expand_tilde(path)`` while get_subprocess_home yields *home*."""
        with patch("hermes_constants.get_subprocess_home", return_value=home):
            return ft._expand_tilde(path)

    def test_tilde_expands_to_profile_home(self):
        """``~/path`` is rooted at the home reported by get_subprocess_home."""
        expanded = self._expand("~/scratch/file.txt")
        assert expanded == "/opt/data/profiles/coder/home/scratch/file.txt"

    def test_bare_tilde_expands_to_profile_home(self):
        """A lone ``~`` becomes exactly the profile home."""
        expanded = self._expand("~")
        assert expanded == "/opt/data/profiles/coder/home"

    def test_falls_back_when_no_profile_home(self):
        """With get_subprocess_home returning None, the result matches os.path.expanduser."""
        expanded = self._expand("~/Documents", home=None)
        assert expanded == os.path.expanduser("~/Documents")

    def test_other_user_tilde_not_overridden(self):
        """``~user/path`` names another user, so the profile home must not appear."""
        expanded = self._expand("~root/file.txt")
        # Only the negative is pinned here: the profile home is absent from
        # the expansion of another user's tilde path.
        assert "/opt/data/profiles/coder/home" not in expanded

    def test_no_tilde_unchanged(self):
        """A path containing no ``~`` comes back as-is."""
        expanded = self._expand("/etc/passwd")
        assert expanded == "/etc/passwd"

    def test_empty_path_unchanged(self):
        """The empty string maps to the empty string."""
        assert self._expand("") == ""
# ---------------------------------------------------------------------------
# Integration: _resolve_path_for_task uses profile home
# ---------------------------------------------------------------------------
class TestResolvePathUsesProfileHome:
    """Integration checks that ``ft._resolve_path_for_task`` expands ``~`` to the profile home."""

    @staticmethod
    def _make_homes(tmp_path, monkeypatch):
        """Create distinct profile/process homes and point ``HOME`` at the process one.

        Also stubs ``ft._get_live_tracking_cwd`` to return ``None``.
        Returns ``(profile_dir, process_dir)``.
        """
        profile_dir = tmp_path / "profile_home"
        process_dir = tmp_path / "process_home"
        for home_dir in (profile_dir, process_dir):
            home_dir.mkdir()
        monkeypatch.setenv("HOME", str(process_dir))
        monkeypatch.setattr(ft, "_get_live_tracking_cwd", lambda task_id="default": None)
        return profile_dir, process_dir

    def test_relative_tilde_resolves_to_profile_home(self, tmp_path, monkeypatch):
        """``~/test_file.txt`` lands under the profile home, not the process ``HOME``."""
        profile_dir, _process_dir = self._make_homes(tmp_path, monkeypatch)

        with patch("hermes_constants.get_subprocess_home", return_value=str(profile_dir)):
            resolved = ft._resolve_path_for_task("~/test_file.txt", task_id="test")

        resolved_text = str(resolved)
        assert resolved_text.startswith(str(profile_dir))
        assert "process_home" not in resolved_text

    def test_absolute_tilde_in_workspace_root(self, tmp_path, monkeypatch):
        """``~/data/config.json`` resolves inside the profile home and outside the process home."""
        profile_dir, process_dir = self._make_homes(tmp_path, monkeypatch)

        with patch("hermes_constants.get_subprocess_home", return_value=str(profile_dir)):
            # NOTE(review): the original comment says this covers a workspace
            # root containing ``~`` (via _resolve_base_dir); the call below only
            # passes a ``~/``-prefixed path argument — confirm intended coverage.
            resolved = ft._resolve_path_for_task("~/data/config.json", task_id="test")

        resolved_text = str(resolved)
        assert str(profile_dir) in resolved_text
        assert str(process_dir) not in resolved_text

View file

@ -260,6 +260,56 @@ class TestStdioPgroupReaping:
assert fake_pid not in _orphan_stdio_pids
assert fake_pid not in _stdio_pgids
def test_killpg_skipped_when_pgid_matches_gateway_own_pgroup(self, monkeypatch):
    """Regression test for #47134: never ``killpg`` the gateway's own group.

    A tracked MCP child that shares the gateway's process group must be
    reaped with per-pid ``os.kill``; calling ``killpg`` on that pgid would
    signal the gateway itself. A second child in a distinct group acts as a
    negative control and must still be reaped through ``killpg``.
    """
    from tools.mcp_tool import (
        _kill_orphaned_mcp_children,
        _orphan_stdio_pids,
        _stdio_pgids,
        _lock,
    )

    if not (hasattr(os, "killpg") and hasattr(os, "getpgrp")):
        pytest.skip("os.killpg/os.getpgrp not available on this platform")

    self._reset_state()

    own_pgid = 424242
    shared_group_pid = 717171  # child whose recorded pgid equals the gateway's
    isolated_pid = 818181  # ordinary child living in its own group
    isolated_pgid = 818181

    tracked = (
        (shared_group_pid, own_pgid),
        (isolated_pid, isolated_pgid),
    )
    with _lock:
        for pid, pgid in tracked:
            _orphan_stdio_pids.add(pid)
            _stdio_pgids[pid] = pgid

    sigkill_value = 9
    monkeypatch.setattr(signal, "SIGKILL", sigkill_value, raising=False)

    getpgrp_patch = patch("tools.mcp_tool.os.getpgrp", return_value=own_pgid)
    killpg_patch = patch("tools.mcp_tool.os.killpg")
    kill_patch = patch("tools.mcp_tool.os.kill")
    pid_exists_patch = patch("gateway.status._pid_exists", return_value=True)
    sleep_patch = patch("time.sleep")
    with getpgrp_patch, killpg_patch as killpg_mock, kill_patch as kill_mock, \
            pid_exists_patch, sleep_patch:
        _kill_orphaned_mcp_children()

    # First positional argument of every killpg call is the targeted pgid.
    signalled_groups = [invocation.args[0] for invocation in killpg_mock.call_args_list]

    # killpg must NEVER target the gateway's own pgid (would self-kill).
    assert own_pgid not in signalled_groups, (
        "killpg was called with the gateway's own pgid — self-kill (#47134)"
    )

    # The shared-pgid child is reaped through per-pid kill instead.
    kill_mock.assert_any_call(shared_group_pid, signal.SIGTERM)
    kill_mock.assert_any_call(shared_group_pid, sigkill_value)

    # NEGATIVE CONTROL: a child in a DISTINCT group must still use killpg —
    # the guard skips only the gateway's own group, not every pgid.
    assert isolated_pgid in signalled_groups, (
        "killpg must still be used for a non-gateway pgid (guard too broad)"
    )
def test_killpg_failure_falls_back_to_kill(self, monkeypatch):
"""If killpg raises ProcessLookupError (pgroup gone), try os.kill."""
from tools.mcp_tool import (

View file

@ -107,6 +107,63 @@ def test_memory_gate_on_then_apply(hermes_home):
assert "approved entry" in store.user_entries[0]
def test_cli_memory_approve_without_live_agent_uses_fresh_store(hermes_home, capsys):
    """Regression test for #46783: ``/memory approve`` with no live agent.

    From a context without a live agent (e.g. the Desktop GUI) the CLI handler
    used to pass ``memory_store=None`` to the shared handler, which reported
    "memory store unavailable" and applied nothing. The handler must instead
    fall back to a freshly loaded on-disk store, as the gateway path does.
    """
    import json
    from tools.memory_tool import memory_tool, MemoryStore
    from tools import write_approval as wa
    from hermes_cli.cli_commands_mixin import CLICommandsMixin

    note = "remember the launch date"
    _set_approval("memory", True)

    # Stage one pending memory write through the approval gate.
    staging_store = MemoryStore()
    staging_store.load_from_disk()
    staged = json.loads(memory_tool("add", "memory", note, store=staging_store))
    assert staged.get("pending_id"), staged
    assert wa.pending_count("memory") == 1

    # Bare CLI handler with no live agent → store resolves to None pre-fix.
    cli = CLICommandsMixin.__new__(CLICommandsMixin)
    cli.agent = None
    cli._handle_memory_command("/memory approve all")

    printed = capsys.readouterr().out
    assert "memory store unavailable" not in printed, printed
    assert "Approved 1" in printed, printed
    assert wa.pending_count("memory") == 0

    # The approved write landed in a freshly loaded on-disk store (MEMORY.md).
    fresh_store = MemoryStore()
    fresh_store.load_from_disk()
    assert any(note in entry for entry in fresh_store.memory_entries)
def test_load_on_disk_store_honors_configured_char_limits(hermes_home, monkeypatch):
    """``load_on_disk_store()`` takes its char limits from config, else defaults.

    With ``hermes_cli.config.load_config`` returning explicit
    ``memory.memory_char_limit`` / ``memory.user_char_limit`` values, the
    returned store must carry them, so approvals applied without a live agent
    enforce the same caps as the live agent (agent_init.py). When loading the
    config raises, the store must come back with 2200 / 1375 instead of failing.
    """
    from tools.memory_tool import load_on_disk_store

    # Config override path: helper picks up the configured limits.
    configured = {"memory": {"memory_char_limit": 999, "user_char_limit": 444}}
    monkeypatch.setattr("hermes_cli.config.load_config", lambda: configured)
    configured_store = load_on_disk_store()
    assert configured_store.memory_char_limit == 999
    assert configured_store.user_char_limit == 444

    # Failure path: config raises → defaults, never blows up.
    def _raise_no_config():
        raise RuntimeError("no config")

    monkeypatch.setattr("hermes_cli.config.load_config", _raise_no_config)
    default_store = load_on_disk_store()
    assert default_store.memory_char_limit == 2200
    assert default_store.user_char_limit == 1375
# ---------------------------------------------------------------------------
# Skill gate
# ---------------------------------------------------------------------------

View file

@ -734,6 +734,100 @@ def test_session_resume_reuses_existing_live_session(server, monkeypatch):
assert all(sid == winner for sid in server._sessions)
def test_session_resume_reuses_live_agent_after_compression_rotation(server, monkeypatch):
    """Resume matches on the live agent's current ``session_id``, not a stale ``session_key``.

    One live session is registered whose agent reports the child id while its
    ``session_key`` still holds the parent id. Resuming the child id must
    return that same live session, report the child id as ``session_key``,
    and leave exactly one entry in ``server._sessions``.
    """
    current_id = "20260409_020202_child"
    outdated_key = "20260409_010101_parent"
    live_sid = "live-rotated"

    live_agent = types.SimpleNamespace(model="test/model", session_id=current_id)
    server._sessions[live_sid] = {
        "agent": live_agent,
        "created_at": 123.0,
        "display_history_prefix": [],
        "history": [{"role": "assistant", "content": "live child"}],
        "history_lock": threading.RLock(),
        "last_active": 123.0,
        "running": False,
        "session_key": outdated_key,
        "transport": server._stdio_transport,
    }

    class _StubDB:
        """Minimal DB double: every lookup points at the current (child) id."""

        def get_session(self, _sid):
            return {"id": current_id}

        def get_session_by_title(self, _title):
            return None

        def resolve_resume_session_id(self, _target):
            return current_id

    def _discard(*_args, **_kwargs):
        return None

    def _stub_session_info(_agent, _session=None):
        return {"model": "test/model"}

    monkeypatch.setattr(server, "_get_db", _StubDB)
    monkeypatch.setattr(server, "_emit", _discard)
    monkeypatch.setattr(server, "_session_info", _stub_session_info)

    request = {
        "id": "r1",
        "method": "session.resume",
        "params": {"session_id": current_id, "cols": 100},
    }
    response = server.handle_request(request)

    assert "error" not in response
    payload = response["result"]
    assert payload["session_id"] == live_sid
    assert payload["session_key"] == current_id
    assert len(server._sessions) == 1
def test_sync_session_key_after_compress_reanchors_active_session_lease(
    server, monkeypatch, tmp_path
):
    """``_sync_session_key_after_compress`` moves the active-session lease to the new id.

    A lease is acquired for ``"session-old"`` while the session's agent already
    reports ``"session-new"``. After the sync call the session dict, the lease
    object, and the registry snapshot must all carry ``"session-new"``.

    The lease is released in a ``finally`` block so a failing assertion does
    not leave it held.
    """
    home = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(home))

    from hermes_cli.active_sessions import (
        active_session_registry_snapshot,
        try_acquire_active_session,
    )

    lease, message = try_acquire_active_session(
        session_id="session-old",
        surface="tui",
        config={"max_concurrent_sessions": 1},
        metadata={"live_session_id": "ui-1"},
    )
    assert message is None
    assert lease is not None

    try:
        session = {
            "active_session_lease": lease,
            "agent": types.SimpleNamespace(session_id="session-new"),
            "session_key": "session-old",
        }
        # Stand-in for ``tools.approval`` exposing only no-op callables, so the
        # sync call runs without the real approval module.
        fake_approval = types.SimpleNamespace(
            disable_session_yolo=lambda *_args, **_kwargs: None,
            enable_session_yolo=lambda *_args, **_kwargs: None,
            is_session_yolo_enabled=lambda *_args, **_kwargs: False,
            register_gateway_notify=lambda *_args, **_kwargs: None,
            unregister_gateway_notify=lambda *_args, **_kwargs: None,
        )
        monkeypatch.setattr(server, "_restart_slash_worker", lambda *_args, **_kwargs: None)

        with patch.dict(sys.modules, {"tools.approval": fake_approval}):
            server._sync_session_key_after_compress("ui-1", session)

        snapshot = active_session_registry_snapshot()
        assert session["session_key"] == "session-new"
        assert lease.session_id == "session-new"
        assert [entry["session_id"] for entry in snapshot] == ["session-new"]
    finally:
        # Always hand the lease back, even when an assertion above fails.
        lease.release()
def test_session_resume_live_payload_uses_current_history_with_ancestors(server, monkeypatch):
"""Live resume should not reuse a stale ancestor-inclusive snapshot."""