mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-28 11:32:22 +00:00
Merge remote-tracking branch 'origin/main' into bb/pets-merge
# Conflicts: # hermes_cli/commands.py # tui_gateway/server.py
This commit is contained in:
commit
e495b33bf1
251 changed files with 23395 additions and 2720 deletions
|
|
@ -331,6 +331,131 @@ class TestResolveAnthropicToken:
|
|||
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
|
||||
assert resolve_anthropic_token() == "cc-auto-token"
|
||||
|
||||
def test_falls_back_to_anthropic_credential_pool_oauth(self, monkeypatch, tmp_path):
    """The pool's OAuth token is returned when no env or Claude Code credential exists."""
    # No env-based credential may answer before the pool is reached.
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    # Source #3 (Claude Code creds, including the macOS keychain read that
    # patching Path.home does not cover) is forced to yield nothing, so only
    # source #4 (credential_pool) can answer -- a Hermes-PKCE-only setup.
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [oauth_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    resolved = resolve_anthropic_token()
    assert resolved == "pool-oauth-token"
|
||||
|
||||
def test_prefers_anthropic_credential_pool_oauth_over_api_key(self, monkeypatch, tmp_path):
    """A pool OAuth entry (source #4) outranks ANTHROPIC_API_KEY (source #5)."""
    for env_name in ("ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant...ykey")
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    # Source #3 is stubbed out as well, so a machine-local Claude Code creds /
    # keychain entry cannot short-circuit before the pool is consulted.
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [oauth_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    resolved = resolve_anthropic_token()
    assert resolved == "pool-oauth-token"
|
||||
|
||||
def test_pool_entry_with_null_access_token_does_not_crash(self, monkeypatch, tmp_path):
    """An OAuth pool entry whose access_token is None is skipped, not fatal.

    A ``None.strip()`` would escape the helper's try/excepts and take down the
    whole resolver, including the ANTHROPIC_API_KEY fallback. The broken entry
    must be ignored so that the api-key fallback (source #5) wins.
    """
    for env_name in ("ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant...ykey")
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    tokenless_entry = SimpleNamespace(auth_type="oauth", access_token=None)
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [tokenless_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    # Expect a clean fall-through to source #5 (ANTHROPIC_API_KEY), no exception.
    resolved = resolve_anthropic_token()
    assert resolved == "sk-ant...ykey"
|
||||
|
||||
def test_pool_api_key_only_entry_is_not_returned_as_token(self, monkeypatch, tmp_path):
    """Pool entries with auth_type ``api_key`` are never returned by the pool path.

    resolve_anthropic_token() yields an OAuth bearer token; api_key pool
    entries are consumed through the aux client's _pool_runtime_api_key lane,
    which is a separate resolution concern.
    """
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    key_only_entry = SimpleNamespace(auth_type="api_key", access_token="sk-pool-apikey")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [key_only_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    # Without an OAuth entry or any other source the answer is None; the
    # api_key entry is ignored on this path.
    resolved = resolve_anthropic_token()
    assert resolved is None
|
||||
|
||||
def test_pool_is_not_consulted_when_env_token_present(self, monkeypatch, tmp_path):
    """ANTHROPIC_TOKEN (source #1) short-circuits before the pool (source #4).

    With the env token set, load_pool must never be invoked.
    """
    for env_name in ("ANTHROPIC_API_KEY", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setenv("ANTHROPIC_TOKEN", "env-token")
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    # Every load_pool invocation is recorded here, so a call is still
    # detected even if the raised error gets swallowed by the resolver.
    seen_providers = []

    def _forbidden_load_pool(provider):
        seen_providers.append(provider)
        raise AssertionError("load_pool must not be called when source #1 wins")

    monkeypatch.setattr("agent.credential_pool.load_pool", _forbidden_load_pool)

    resolved = resolve_anthropic_token()
    assert resolved == "env-token"
    assert seen_providers == []
|
||||
|
||||
def test_pool_resolution_is_read_only(self, monkeypatch, tmp_path):
    """The pool is enumerated with clear_expired=False and refresh=False.

    A bare resolve must never write auth.json or trigger a network refresh
    from diagnostic call sites (#50108 MED).
    """
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")
    seen_kwargs = {}

    def _recording_available_entries(**kwargs):
        # Remember exactly which keyword flags the resolver passed in.
        seen_kwargs.update(kwargs)
        return [oauth_entry]

    fake_pool = SimpleNamespace(_available_entries=_recording_available_entries)
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    resolved = resolve_anthropic_token()
    assert resolved == "pool-oauth-token"
    assert seen_kwargs == {"clear_expired": False, "refresh": False}
|
||||
|
||||
def test_prefers_refreshable_claude_code_credentials_over_static_anthropic_token(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-static-token")
|
||||
|
|
|
|||
|
|
@ -206,6 +206,35 @@ class TestProjectFacts:
|
|||
assert "Project: package.json" in block
|
||||
assert "Verify:" not in block
|
||||
|
||||
def test_detect_project_facts_structured(self, tmp_path):
    """detect_project_facts reports manifest, package manager and verify commands."""
    manifest_body = json.dumps({"scripts": {"test": "vitest", "dev": "vite"}})
    (tmp_path / "package.json").write_text(manifest_body)
    (tmp_path / "pnpm-lock.yaml").write_text("")

    detected = cc.detect_project_facts(tmp_path)

    assert detected.manifests == ["package.json"]
    assert detected.package_managers == ["pnpm"]
    # Only the "test" script shows up; the "dev" script is excluded.
    assert detected.verify_commands == ["pnpm run test"]
    assert detected.context_files == []
|
||||
|
||||
def test_project_facts_for_matches_prompt_block(self, tmp_path):
    """Structured facts and the rendered prompt block must list the same commands.

    Invariant: the structured facts the UI consumes must not drift from the
    commands the prompt snapshot renders -- one detector feeds both.
    """
    _git_init(tmp_path)
    (tmp_path / "package.json").write_text(
        json.dumps({"scripts": {"test": "vitest", "lint": "eslint ."}})
    )
    (tmp_path / "pnpm-lock.yaml").write_text("")

    facts = cc.project_facts_for(tmp_path)
    assert facts is not None

    block = cc.build_coding_workspace_block(tmp_path)
    # Fail with a readable message (including the rendered block) instead of a
    # bare IndexError when the "Verify:" line is missing or has nothing after it.
    _, marker, after_marker = block.partition("Verify:")
    assert marker, f"workspace block has no 'Verify:' line:\n{block}"
    verify_line = next(iter(after_marker.splitlines()), "")

    assert facts["verifyCommands"]
    for cmd in facts["verifyCommands"]:
        assert cmd in verify_line
|
||||
|
||||
def test_project_facts_for_none_outside_workspace(self, tmp_path):
    """project_facts_for returns None for an untouched temporary directory."""
    facts = cc.project_facts_for(tmp_path)
    assert facts is None
|
||||
|
||||
|
||||
# ── $HOME dotfiles guard ────────────────────────────────────────────────────
|
||||
|
||||
|
|
|
|||
86
tests/agent/test_compression_progress.py
Normal file
86
tests/agent/test_compression_progress.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""Regression: detect compression progress by tokens, not just rows.
|
||||
|
||||
Issue #39548: preflight compression in the turn prologue was checking
|
||||
``len(messages) >= _orig_len`` to decide "Cannot compress further". This
|
||||
false-positives when a pass summarises message contents — reducing the
|
||||
estimated request token count without removing any rows — and surfaces a
|
||||
spurious ``Context length exceeded`` failure followed by an auto-reset of
|
||||
an otherwise healthy session.
|
||||
|
||||
These tests pin the contract of ``_compression_made_progress``: a
|
||||
row-count reduction OR a *material* (>5%) token-count reduction counts as
|
||||
progress.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agent.turn_context import _compression_made_progress
|
||||
|
||||
|
||||
class TestCompressionMadeProgress:
    """Pins the contract of ``_compression_made_progress``."""

    def test_rows_reduced_counts_as_progress(self):
        """Fewer message rows is the obvious progress signal."""
        outcome = _compression_made_progress(
            orig_len=10, new_len=5, orig_tokens=1000, new_tokens=1000
        )
        assert outcome is True

    def test_tokens_reduced_without_row_change_counts_as_progress(self):
        """Issue #39548: 220 -> 220 rows with 288k -> 183k tokens IS progress."""
        outcome = _compression_made_progress(
            orig_len=220, new_len=220, orig_tokens=288_028, new_tokens=183_180
        )
        assert outcome is True

    def test_both_reduced_counts_as_progress(self):
        """Common case: summarising drops some rows and shrinks the rest."""
        outcome = _compression_made_progress(
            orig_len=220, new_len=180, orig_tokens=288_028, new_tokens=150_000
        )
        assert outcome is True

    def test_neither_moved_means_no_progress(self):
        """The genuine "stuck" case: identical rows and identical tokens."""
        outcome = _compression_made_progress(
            orig_len=10, new_len=10, orig_tokens=1000, new_tokens=1000
        )
        assert outcome is False

    def test_rows_grew_and_tokens_grew_means_no_progress(self):
        """Pathological: the pass made the request larger, so it is stuck."""
        outcome = _compression_made_progress(
            orig_len=10, new_len=12, orig_tokens=1000, new_tokens=1200
        )
        assert outcome is False

    def test_rows_grew_but_tokens_dropped_is_progress(self):
        """Edge: summary rows may add a row while shrinking the token count.

        A token reduction on its own is enough to keep the loop going.
        """
        outcome = _compression_made_progress(
            orig_len=10, new_len=11, orig_tokens=1000, new_tokens=600
        )
        assert outcome is True

    def test_tokens_grew_but_rows_dropped_is_progress(self):
        """Edge: a row reduction alone suffices even if tokens creep up.

        Tokens may nominally grow (e.g. summary verbosity); a lower row count
        is a hard signal that the transcript actually shrank.
        """
        outcome = _compression_made_progress(
            orig_len=10, new_len=5, orig_tokens=1000, new_tokens=1100
        )
        assert outcome is True

    def test_sub_5pct_token_drop_is_not_progress(self):
        """Token drops under the 5% material floor do not count as progress.

        This matches the overflow-handler retry path (#39550), so a marginal
        wobble cannot keep the multi-pass loop spinning.
        """
        # 1000 -> 970 is a 3% drop: under the floor.
        below_floor = _compression_made_progress(
            orig_len=10, new_len=10, orig_tokens=1000, new_tokens=970
        )
        assert below_floor is False
        # 1000 -> 940 is a 6% drop: over the floor.
        above_floor = _compression_made_progress(
            orig_len=10, new_len=10, orig_tokens=1000, new_tokens=940
        )
        assert above_floor is True

    def test_zero_orig_tokens_is_not_progress(self):
        """A degenerate 0-token estimate must not be read as a token win."""
        outcome = _compression_made_progress(
            orig_len=10, new_len=10, orig_tokens=0, new_tokens=0
        )
        assert outcome is False
|
||||
107
tests/agent/test_compressor_tool_call_budget.py
Normal file
107
tests/agent/test_compressor_tool_call_budget.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Regression tests for tool_call envelope accounting in the compression
|
||||
tail-protection budget walks (issue #28053).
|
||||
|
||||
The budget walks used to estimate an assistant message's tokens from
|
||||
content + ``function.arguments`` only, dropping each ``tool_call``'s ``id``,
|
||||
``type`` and ``function.name`` (plus JSON structure). For assistant turns
|
||||
that fan out into parallel tool calls this undercounted by 2-15x, so the
|
||||
protected tail overshot ``tail_token_budget`` and compression became
|
||||
ineffective. The fix routes all three walks through
|
||||
``_estimate_msg_budget_tokens``, which counts the full envelope.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.context_compressor import (
|
||||
ContextCompressor,
|
||||
_CHARS_PER_TOKEN,
|
||||
_estimate_msg_budget_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _assistant_with_tool_calls(n_calls: int, *, args: str = '{"path":"a"}') -> dict:
|
||||
"""An assistant turn fanning into ``n_calls`` parallel tool calls with
|
||||
realistic id/name overhead but a small arguments string."""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": f"call_{i:02d}_{'a' * 24}", # ~32 chars, UUID-ish id
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": args},
|
||||
}
|
||||
for i in range(n_calls)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _args_only_estimate(msg: dict) -> int:
    """Mirror the OLD (buggy) estimate that counts content and arguments only.

    Kept purely as a comparison baseline: tool_call ids, types and function
    names are deliberately left out of the count.
    """
    text = msg.get("content") or ""
    base_tokens = len(text) // _CHARS_PER_TOKEN + 10
    argument_tokens = sum(
        len(call.get("function", {}).get("arguments", "")) // _CHARS_PER_TOKEN
        for call in (msg.get("tool_calls") or [])
        if isinstance(call, dict)  # non-dict entries contribute nothing
    )
    return base_tokens + argument_tokens
|
||||
|
||||
|
||||
class TestToolCallEnvelopeEstimate:
    """Checks that ``_estimate_msg_budget_tokens`` counts whole tool_call envelopes."""

    def test_envelope_counted_not_just_arguments(self):
        """The new estimate is well above the arguments-only baseline."""
        message = _assistant_with_tool_calls(4)
        full_estimate = _estimate_msg_budget_tokens(message)
        legacy_estimate = _args_only_estimate(message)
        # id/type/name plus JSON structure dwarf the tiny arguments string.
        assert full_estimate > legacy_estimate * 3, (full_estimate, legacy_estimate)
        # The estimate spans the full serialized tool_call envelope.
        serialized_chars = sum(len(str(call)) for call in message["tool_calls"])
        assert full_estimate >= serialized_chars // _CHARS_PER_TOKEN

    def test_scales_with_number_of_parallel_calls(self):
        """Five parallel calls cost more than three times a single call."""
        single = _estimate_msg_budget_tokens(_assistant_with_tool_calls(1))
        quintuple = _estimate_msg_budget_tokens(_assistant_with_tool_calls(5))
        assert quintuple > single * 3

    def test_no_tool_calls_matches_content_estimate(self):
        """A plain message is still content // chars-per-token plus 10 overhead."""
        plain = {"role": "user", "content": "x" * 400}
        expected = 400 // _CHARS_PER_TOKEN + 10
        assert _estimate_msg_budget_tokens(plain) == expected

    def test_non_dict_tool_calls_do_not_crash(self):
        """Non-dict tool_calls entries are ignored without raising."""
        odd = {"role": "assistant", "content": "hi", "tool_calls": ["weird", None]}
        expected = len("hi") // _CHARS_PER_TOKEN + 10
        assert _estimate_msg_budget_tokens(odd) == expected
|
||||
|
||||
|
||||
@pytest.fixture()
def compressor():
    """Provide a quiet ContextCompressor built against a patched 100000-token context length."""
    settings = {
        "model": "test/model",
        "threshold_percent": 0.85,
        "protect_first_n": 2,
        "protect_last_n": 2,
        "quiet_mode": True,
    }
    # The context-length lookup is patched only while the instance is built.
    with patch("agent.context_compressor.get_model_context_length", return_value=100000):
        instance = ContextCompressor(**settings)
    return instance
|
||||
|
||||
|
||||
class TestTailCutAccountsForToolCalls:
    """Checks that the tail-cut walk sees the real cost of tool-call-heavy turns."""

    def test_tail_cut_stops_on_tool_call_heavy_tail(self, compressor):
        """The protected tail stays well short of the full 20-turn transcript."""
        # 20 assistant turns, each fanning into 5 short-arg tool calls.
        heavy_turns = [_assistant_with_tool_calls(5) for _ in range(20)]
        transcript = [{"role": "user", "content": "start"}, *heavy_turns]

        tokens_per_turn = _estimate_msg_budget_tokens(transcript[-1])
        # Sanity: a heavy turn is non-trivial once the envelope counts.
        assert tokens_per_turn > 30

        # Budget sized so ~6 heavy turns fit under the 1.5x soft ceiling.
        budget = int(tokens_per_turn * 6 / 1.5)
        cut_index = compressor._find_tail_cut_by_tokens(
            transcript, head_end=1, token_budget=budget
        )
        protected_count = len(transcript) - cut_index

        # With the envelope counted, the walk stops well before protecting all
        # 20 turns. The old arguments-only estimate (~25 tokens/turn) never
        # reaches the ceiling and would protect the entire transcript.
        assert protected_count < len(heavy_turns)
        assert 3 <= protected_count <= 12
|
||||
|
|
@ -86,6 +86,28 @@ class TestPreflightDeferral:
|
|||
|
||||
assert compressor.should_defer_preflight_to_real_usage(93_000) is False
|
||||
|
||||
def test_defers_immediately_after_compaction_with_stale_real_prompt(self, compressor):
    """#36718: the awaiting flag forces deferral right after a compaction.

    At that point last_real_prompt_tokens still holds the stale
    pre-compression value (above threshold). Without the flag, preflight
    would fire a SECOND compaction before real post-compaction usage arrives.
    """
    compressor.threshold_tokens = 85_000
    # Stale pre-compression reading: on its own it would trip the
    # `>= threshold => False` short-circuit and defeat deferral.
    compressor.last_real_prompt_tokens = 120_000
    compressor.awaiting_real_usage_after_compression = True

    deferred = compressor.should_defer_preflight_to_real_usage(95_000)
    assert deferred is True
|
||||
|
||||
def test_resumes_normal_deferral_after_flag_cleared(self, compressor):
    """With the awaiting flag cleared, the normal deferral logic governs again.

    Once update_from_response() clears the flag, the baseline/growth logic
    applies, so deferral is not permanent.
    """
    compressor.threshold_tokens = 85_000
    compressor.last_real_prompt_tokens = 120_000
    compressor.awaiting_real_usage_after_compression = False

    # A stale-high real prompt with the flag cleared hits the >= threshold
    # short-circuit, which means no deferral.
    deferred = compressor.should_defer_preflight_to_real_usage(95_000)
    assert deferred is False
|
||||
|
||||
|
||||
|
||||
class TestCompress:
|
||||
|
|
@ -242,6 +264,59 @@ class TestCompress:
|
|||
assert c.should_compress(55000) is True
|
||||
assert c.should_compress(40000) is False
|
||||
|
||||
def test_max_tokens_reservation_lowers_threshold(self):
    """#43547: the threshold is based on (context_length - max_tokens).

    The provider reserves max_tokens out of the window. A 200K model
    reserving 65536 output tokens has a ~134K input budget; at 50% that is
    ~67K, NOT 100K.
    """
    compute = ContextCompressor._compute_threshold_tokens

    # No reservation (provider default): full-window behavior, unchanged.
    assert compute(200000, 0.50) == 100000
    assert compute(200000, 0.50, None) == 100000
    # 65536 reserved: effective input budget 134464, and 50% of that is 67232.
    assert compute(200000, 0.50, 65536) == 67232
|
||||
|
||||
def test_max_tokens_reservation_with_small_window_floors(self):
    """A big reservation on a small window triggers the degenerate-window guard.

    The effective budget can drop near or below the minimum floor; the guard
    then fires at 85% of the EFFECTIVE budget, never of the raw window.
    """
    # 128K window with 65536 reserved leaves 62464 (< MINIMUM 64000).
    # Floor (64000) >= effective window (62464), so 85% of effective applies.
    effective_window = 62464
    threshold = ContextCompressor._compute_threshold_tokens(128000, 0.50, 65536)
    assert threshold == int(62464 * 0.85)  # 53094
    assert threshold < effective_window
|
||||
|
||||
def test_max_tokens_exceeding_window_falls_back_to_full(self):
    """Pathological max_tokens >= context_length falls back to the full window.

    The effective budget would otherwise be <= 0 and produce a non-positive
    threshold.
    """
    threshold = ContextCompressor._compute_threshold_tokens(64000, 0.50, 70000)
    # effective_window <= 0: full context (64000) is used, then the 85% guard
    # applies, the same result as the no-reservation small-context case.
    assert threshold == 54400
    assert threshold > 0
|
||||
|
||||
def test_max_tokens_coercion_treats_non_int_as_no_reservation(self):
    """Non-int or non-positive max_tokens coerce safely and never raise.

    Guards the path where a mocked parent agent forwards a MagicMock
    max_tokens into a child ContextCompressor (regression for the
    delegate-test TypeError: '<=' not supported between MagicMock and int).
    """
    from unittest.mock import MagicMock

    coerce = ContextCompressor._coerce_max_tokens
    for rejected in (None, 0, -5, "nope"):
        assert coerce(rejected) is None
    assert coerce(65536) == 65536

    # The actual regression: constructing a compressor with a MagicMock
    # max_tokens must NOT raise (the unguarded code did `ctx - MagicMock`
    # followed by `MagicMock <= 0`). int(MagicMock()) returns 1, so coercion
    # produces a harmless positive int and the threshold is computed cleanly
    # with a 1-token reservation.
    with patch("agent.context_compressor.get_model_context_length", return_value=200000):
        built = ContextCompressor(model="m", quiet_mode=True, max_tokens=MagicMock())
    assert isinstance(built.max_tokens, int)
    assert isinstance(built.threshold_tokens, int)
    assert built.threshold_tokens > 0  # no crash, sane value
|
||||
|
||||
def test_compression_increments_count(self, compressor):
|
||||
msgs = self._make_messages(10)
|
||||
# Default config (abort_on_summary_failure=False) — fallback path
|
||||
|
|
|
|||
73
tests/agent/test_learn_prompt.py
Normal file
73
tests/agent/test_learn_prompt.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""Tests for /learn — open-ended skill distillation.
|
||||
|
||||
Covers the shared prompt builder (agent.learn_prompt.build_learn_prompt) and
|
||||
the slash-command registry wiring. /learn has no engine and no model tool: it
|
||||
builds a standards-guided prompt that the live agent runs as a normal turn, so
|
||||
these are the load-bearing behavior contracts.
|
||||
"""
|
||||
|
||||
from agent.learn_prompt import build_learn_prompt, _AUTHORING_STANDARDS
|
||||
|
||||
|
||||
class TestBuildLearnPrompt:
    """Behavior contracts for the prompt produced by ``build_learn_prompt``."""

    def test_embeds_the_user_request_verbatim(self):
        """The request text appears unchanged inside the prompt."""
        request = "the REST client in ~/projects/acme-sdk, focus on auth"
        assert request in build_learn_prompt(request)

    def test_always_includes_the_authoring_standards(self):
        """Every prompt carries the authoring standards, whatever the input."""
        # The standards are what make distilled skills match house style.
        requests = ("", "a url https://x/y", "what we just did")
        prompts = [build_learn_prompt(request) for request in requests]
        for prompt in prompts:
            assert _AUTHORING_STANDARDS in prompt

    def test_instructs_saving_via_skill_manage_not_a_raw_file(self):
        """The prompt names skill_manage as the way to save."""
        rendered = build_learn_prompt("learn the thing")
        assert "skill_manage" in rendered

    def test_references_gather_tools_for_open_ended_sourcing(self):
        """The agent's own gathering tools are named in the prompt."""
        # Open-ended sourcing relies on existing tools being named, so that
        # dirs/URLs/conversation/paste all route through them.
        rendered = build_learn_prompt("learn from somewhere")
        for tool_name in ("read_file", "search_files", "web_extract"):
            assert tool_name in rendered

    def test_empty_request_falls_back_to_the_conversation(self):
        """A bare /learn distills "what we just did" instead of erroring."""
        rendered = build_learn_prompt("")
        assert "conversation" in rendered.lower()
        # The save instruction still travels with the fallback prompt.
        assert "skill_manage" in rendered

    def test_whitespace_only_request_is_treated_as_empty(self):
        """Whitespace-only input yields the same prompt as empty input."""
        blank_prompt = build_learn_prompt(" \n ")
        empty_prompt = build_learn_prompt("")
        assert blank_prompt == empty_prompt

    def test_description_length_rule_is_in_the_standards(self):
        """The single most-violated rule is spelled out in the standards."""
        assert "60" in _AUTHORING_STANDARDS
|
||||
|
||||
|
||||
class TestLearnRegistryWiring:
    """Slash-command registry wiring for /learn."""

    def test_learn_is_registered_and_resolves(self):
        """resolve_command finds a command named "learn"."""
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("learn")
        assert resolved is not None
        assert resolved.name == "learn"

    def test_learn_is_in_tools_and_skills_category(self):
        """/learn is filed under the "Tools & Skills" category."""
        from hermes_cli.commands import resolve_command

        category = resolve_command("learn").category
        assert category == "Tools & Skills"

    def test_learn_works_on_the_gateway(self):
        """/learn is a both-surfaces command, so the gateway must know it."""
        from hermes_cli.commands import GATEWAY_KNOWN_COMMANDS

        assert "learn" in GATEWAY_KNOWN_COMMANDS

    def test_learn_is_not_cli_only(self):
        """The resolved command is not flagged cli_only."""
        from hermes_cli.commands import resolve_command

        is_cli_only = resolve_command("learn").cli_only
        assert not is_cli_only
|
||||
|
|
@ -1172,16 +1172,12 @@ class TestOnMemoryWriteBridge:
|
|||
mgr.on_memory_write("replace", "user", "updated pref")
|
||||
assert p.memory_writes == [("replace", "user", "updated pref")]
|
||||
|
||||
def test_on_memory_write_remove_not_bridged(self):
|
||||
"""The bridge intentionally skips 'remove' — only add/replace notify."""
|
||||
# This tests the contract that run_agent.py checks:
|
||||
# function_args.get("action") in ("add", "replace")
|
||||
def test_on_memory_write_remove_supported_by_manager(self):
|
||||
"""The manager forwards remove actions when a caller elects to bridge them."""
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
# Manager itself doesn't filter — run_agent.py does.
|
||||
# But providers should handle remove gracefully.
|
||||
mgr.on_memory_write("remove", "memory", "old fact")
|
||||
assert p.memory_writes == [("remove", "memory", "old fact")]
|
||||
|
||||
|
|
|
|||
145
tests/agent/test_memory_write_bridge.py
Normal file
145
tests/agent/test_memory_write_bridge.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
"""Behavior tests for the built-in memory → external provider bridge.
|
||||
|
||||
The bridge lives behind the MemoryManager interface
|
||||
(``MemoryManager.notify_memory_tool_write``): the agent loop hands over the raw
|
||||
built-in memory tool result + args, and the manager decides whether/what to
|
||||
mirror to external providers. These tests drive that method with a fake
|
||||
external provider and assert which ``on_memory_write`` calls land.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
|
||||
class _RecordingProvider(MemoryProvider):
    """Bare-bones external provider that remembers each on_memory_write call."""

    def __init__(self) -> None:
        # One dict per on_memory_write call, in arrival order.
        self.calls = []

    @property
    def name(self) -> str:
        return "recording"

    def is_available(self) -> bool:
        return True

    def initialize(self, session_id: str, **kwargs) -> None:
        return None

    def get_tool_schemas(self):
        return []

    def shutdown(self) -> None:
        return None

    def on_memory_write(self, action, target, content, metadata=None):
        """Record the call; metadata is copied and a missing value becomes ``{}``."""
        record = {
            "action": action,
            "target": target,
            "content": content,
            "metadata": dict(metadata or {}),
        }
        self.calls.append(record)
|
||||
|
||||
|
||||
def _manager_with_provider():
    """Return a fresh ``(MemoryManager, _RecordingProvider)`` pair, already wired together."""
    recorder = _RecordingProvider()
    manager = MemoryManager()
    manager.add_provider(recorder)
    return manager, recorder
|
||||
|
||||
|
||||
def test_notifies_remove_with_old_text_after_success():
    """A successful remove is mirrored with empty content and old_text in metadata."""
    mgr, provider = _manager_with_provider()
    tool_result = json.dumps({"success": True})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    mgr.notify_memory_tool_write(tool_result, tool_args)

    expected_call = {
        "action": "remove",
        "target": "memory",
        "content": "",
        "metadata": {"old_text": "stale preference entry"},
    }
    assert provider.calls == [expected_call]
|
||||
|
||||
|
||||
def test_skips_failed_memory_write():
    """A tool result reporting success=False is not mirrored to the provider."""
    mgr, provider = _manager_with_provider()
    tool_result = json.dumps({"success": False, "error": "No entry matched"})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    mgr.notify_memory_tool_write(tool_result, tool_args)

    assert provider.calls == []
|
||||
|
||||
|
||||
def test_skips_staged_memory_write():
    """A result flagged as staged is not mirrored, even though success is True."""
    mgr, provider = _manager_with_provider()
    tool_result = json.dumps({"success": True, "staged": True, "pending_id": "abc123"})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    mgr.notify_memory_tool_write(tool_result, tool_args)

    assert provider.calls == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_result", [None, [], object(), "not-json"])
def test_skips_unrecognized_tool_result_shape(tool_result):
    """Tool results that are not the expected JSON payload produce no mirror call."""
    mgr, provider = _manager_with_provider()
    tool_args = {"action": "add", "target": "memory", "content": "new fact"}

    mgr.notify_memory_tool_write(tool_result, tool_args)

    assert provider.calls == []
|
||||
|
||||
|
||||
def test_preserves_old_text_for_replace_and_remove_batch():
    """Each batched operation is mirrored in order, keeping old_text where present."""
    mgr, provider = _manager_with_provider()
    operations = [
        {"action": "replace", "old_text": "old preference", "content": "updated"},
        {"action": "remove", "old_text": "obsolete preference"},
        {"action": "add", "content": "new fact"},
    ]

    mgr.notify_memory_tool_write(
        json.dumps({"success": True}),
        {"target": "user", "operations": operations},
    )

    expected_calls = [
        {
            "action": "replace",
            "target": "user",
            "content": "updated",
            "metadata": {"old_text": "old preference"},
        },
        {
            "action": "remove",
            "target": "user",
            "content": "",
            "metadata": {"old_text": "obsolete preference"},
        },
        {"action": "add", "target": "user", "content": "new fact", "metadata": {}},
    ]
    assert provider.calls == expected_calls
|
||||
|
||||
|
||||
def test_non_mutating_actions_are_not_mirrored():
    """A successful "read" action produces no on_memory_write call."""
    mgr, provider = _manager_with_provider()
    tool_result = json.dumps({"success": True})
    tool_args = {"action": "read", "target": "memory"}

    mgr.notify_memory_tool_write(tool_result, tool_args)

    assert provider.calls == []
|
||||
|
||||
|
||||
def test_build_metadata_callback_is_merged_per_op():
    """The dict returned by ``build_metadata`` shows up as the mirrored call's metadata."""

    def _metadata():
        return {"session_id": "s1", "tool_name": "memory"}

    manager, fake_provider = _manager_with_provider()

    manager.notify_memory_tool_write(
        json.dumps({"success": True}),
        {"action": "add", "target": "memory", "content": "fact"},
        build_metadata=_metadata,
    )

    expected_call = {
        "action": "add",
        "target": "memory",
        "content": "fact",
        "metadata": {"session_id": "s1", "tool_name": "memory"},
    }
    assert fake_provider.calls == [expected_call]
|
||||
110
tests/agent/test_oneshot.py
Normal file
110
tests/agent/test_oneshot.py
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
"""Tests for agent.oneshot — shared one-off (stateless) LLM requests."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.oneshot import (
|
||||
PROMPT_TEMPLATES,
|
||||
render_template,
|
||||
run_oneshot,
|
||||
_strip_code_fence,
|
||||
_truncate,
|
||||
)
|
||||
|
||||
|
||||
class TestRenderTemplate:
    """Checks for ``render_template`` and the registered ``commit_message`` template."""

    def test_unknown_template_raises(self):
        with pytest.raises(KeyError):
            render_template("does-not-exist", {})

    def test_commit_message_template_is_registered(self):
        assert "commit_message" in PROMPT_TEMPLATES

    def test_commit_message_includes_diff_and_recent(self):
        variables = {
            "diff": "diff --git a/x b/x\n+new",
            "recent_commits": "feat: a\nfix: b",
        }
        system_text, user_text = render_template("commit_message", variables)
        # The instructions carry the contract (conventional commits), not a snapshot.
        assert "Conventional Commits" in system_text
        assert "diff --git a/x b/x" in user_text
        assert "feat: a" in user_text

    def test_commit_message_diff_with_braces_passes_through(self):
        # Code payloads contain literal { }, so templates must not rely on str.format.
        braces_payload = {"diff": "x = {a: 1}"}
        _, user_text = render_template("commit_message", braces_payload)
        assert "x = {a: 1}" in user_text

    def test_commit_message_handles_missing_variables(self):
        system_text, user_text = render_template("commit_message", {})
        assert system_text
        assert "no textual diff available" in user_text

    def test_commit_message_avoid_forces_new_message(self):
        # Supplying the previous message must tell the model not to repeat it, so a
        # "regenerate" request gives a different result even on greedy models.
        _, without_avoid = render_template("commit_message", {"diff": "d"})
        _, with_avoid = render_template(
            "commit_message", {"diff": "d", "avoid": "feat: prior"}
        )
        assert "feat: prior" in with_avoid
        assert "do not repeat" in with_avoid
        assert "feat: prior" not in without_avoid
|
||||
|
||||
|
||||
class TestRunOneshot:
    """Checks for ``run_oneshot`` with ``agent.oneshot.call_llm`` patched out."""

    def _mock_response(self, content):
        """Build a mock LLM response with one choice and no reasoning fields set."""
        message = MagicMock()
        message.content = content
        message.reasoning = None
        message.reasoning_content = None
        message.reasoning_details = None

        choice = MagicMock()
        choice.message = message

        response = MagicMock()
        response.choices = [choice]
        return response

    def test_template_path_calls_llm_with_rendered_prompt(self):
        canned = self._mock_response("feat: add thing")
        with patch("agent.oneshot.call_llm", return_value=canned) as llm:
            result = run_oneshot(template="commit_message", variables={"diff": "d"})

        assert result == "feat: add thing"
        sent = llm.call_args.kwargs["messages"]
        assert sent[0]["role"] == "system"
        assert sent[1]["role"] == "user"

    def test_explicit_instructions_path(self):
        canned = self._mock_response("hello")
        with patch("agent.oneshot.call_llm", return_value=canned) as llm:
            result = run_oneshot(instructions="be brief", user_input="say hi")

        assert result == "hello"
        sent = llm.call_args.kwargs["messages"]
        assert sent[0]["content"] == "be brief"
        assert sent[1]["content"] == "say hi"

    def test_requires_template_or_prompt(self):
        with pytest.raises(ValueError):
            run_oneshot()

    def test_strips_wrapping_code_fence(self):
        fenced = self._mock_response("```\nfix: bug\n```")
        with patch("agent.oneshot.call_llm", return_value=fenced):
            assert run_oneshot(instructions="x", user_input="y") == "fix: bug"
|
||||
|
||||
|
||||
class TestHelpers:
    """Checks for the private ``_truncate`` and ``_strip_code_fence`` helpers."""

    def test_truncate_under_limit_unchanged(self):
        short_text = "short"
        assert _truncate(short_text, 100) == "short"

    def test_truncate_over_limit_marks_truncation(self):
        long_text = "x" * 200
        shortened = _truncate(long_text, 50)
        assert shortened.endswith("…(truncated)")
        assert len(shortened) < 200

    def test_strip_code_fence_without_fence_is_noop(self):
        unfenced = "plain text"
        assert _strip_code_fence(unfenced) == "plain text"
|
||||
|
|
@ -100,7 +100,13 @@ class _StubAgent:
|
|||
pass
|
||||
|
||||
|
||||
def _run(agent):
|
||||
def _run(
|
||||
agent,
|
||||
*,
|
||||
final_response=None,
|
||||
api_call_count=3,
|
||||
turn_exit_reason="unknown",
|
||||
):
|
||||
messages = [
|
||||
{"role": "user", "content": "do a thing"},
|
||||
{
|
||||
|
|
@ -114,8 +120,8 @@ def _run(agent):
|
|||
]
|
||||
return finalize_turn(
|
||||
agent,
|
||||
final_response=None, # forces the max-iterations summary path
|
||||
api_call_count=3,
|
||||
final_response=final_response,
|
||||
api_call_count=api_call_count,
|
||||
interrupted=False,
|
||||
failed=False,
|
||||
messages=messages,
|
||||
|
|
@ -125,7 +131,7 @@ def _run(agent):
|
|||
user_message="do a thing",
|
||||
original_user_message="do a thing",
|
||||
_should_review_memory=False,
|
||||
_turn_exit_reason="unknown",
|
||||
_turn_exit_reason=turn_exit_reason,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -162,4 +168,17 @@ def test_clean_turn_has_no_cleanup_errors_key():
|
|||
agent = _StubAgent(raise_in=())
|
||||
result = _run(agent)
|
||||
assert result["final_response"] == "PARTIAL SUMMARY FROM MODEL"
|
||||
assert result["completed"] is False
|
||||
assert "cleanup_errors" not in result
|
||||
|
||||
|
||||
def test_text_response_on_last_allowed_call_is_completed():
    """A text response on the call that reaches ``max_iterations`` is reported as completed."""
    stub = _StubAgent(raise_in=())
    run_kwargs = {
        "final_response": "final report",
        "api_call_count": stub.max_iterations,
        "turn_exit_reason": "text_response(finish_reason=stop)",
    }

    outcome = _run(stub, **run_kwargs)

    assert outcome["final_response"] == "final report"
    assert outcome["completed"] is True
|
||||
|
|
|
|||
85
tests/ci/test_classify_changes.py
Normal file
85
tests/ci/test_classify_changes.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""Tests for scripts/ci/classify_changes.py.
|
||||
|
||||
Check some common patterns of file modifications and the CI lanes they should run.
|
||||
We should always fail open. We may run a lane we didn't need, never skip one a
|
||||
change could have broken.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Load scripts/ci/classify_changes.py straight from its file path and expose its
# ``classify`` function to the tests below.
_PATH = Path(__file__).resolve().parents[2].joinpath("scripts", "ci", "classify_changes.py")
_spec = importlib.util.spec_from_file_location("classify_changes", _PATH)
if not (_spec is not None and _spec.loader is not None):
    raise ImportError("Failed to load classify_changes.py")
_mod = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_mod)
classify = _mod.classify
|
||||
|
||||
# Lane map used as the expected result of the "→ all" cases below: every lane
# is on except mcp_catalog.
DEFAULT = dict(
    python=True,
    frontend=True,
    docker_meta=True,
    site=True,
    scan=True,
    deps=True,
    mcp_catalog=False,
)
|
||||
|
||||
|
||||
def _lanes(python=False, frontend=False, site=False, scan=False, deps=False, mcp_catalog=False, docker_meta=False) -> dict[str, bool]:
|
||||
return {
|
||||
"python": python,
|
||||
"frontend": frontend,
|
||||
"docker_meta": docker_meta,
|
||||
"site": site,
|
||||
"scan": scan,
|
||||
"deps": deps,
|
||||
"mcp_catalog": mcp_catalog,
|
||||
}
|
||||
|
||||
|
||||
# Maps a case id to (changed files, expected lane map). Keys double as pytest ids.
CASES = {
    "docs-only → nothing heavy": (
        ["README.md", "docs/guide.md"],
        _lanes(),
    ),
    "python source → python": (
        ["run_agent.py"],
        _lanes(python=True, scan=True),
    ),
    "dep manifest → python": (
        ["pyproject.toml"],
        _lanes(python=True, scan=True, deps=True),
    ),
    "uv.lock → python": (
        ["uv.lock"],
        _lanes(python=True),
    ),
    "ts package → frontend": (
        ["apps/desktop/src/app.tsx"],
        _lanes(frontend=True),
    ),
    "ui-tui → frontend": (
        ["ui-tui/src/entry.ts"],
        _lanes(frontend=True),
    ),
    # A lockfile bump moves every TS package's tree but leaves the Python suite alone.
    "root lockfile → frontend, not python": (
        ["package-lock.json"],
        _lanes(frontend=True),
    ),
    "website → site": (
        ["website/docs/intro.md"],
        _lanes(site=True),
    ),
    # SKILL.md looks like documentation, yet the skill-doc tests read skills/,
    # so editing a skill still has to run Python.
    "skill md → python + site": (
        ["skills/github/SKILL.md"],
        _lanes(python=True, site=True),
    ),
    "dockerfile → docker meta": (
        ["Dockerfile"],
        _lanes(docker_meta=True),
    ),
    # An unknown top-level file keeps Python enabled instead of risking a silent skip.
    "unknown toplevel → python": (
        ["Makefile"],
        _lanes(python=True),
    ),
    "mixed docs+python → python": (
        ["README.md", "agent/x.py"],
        _lanes(python=True, scan=True),
    ),
    "mixed docs+frontend → frontend": (
        ["README.md", "apps/x.tsx"],
        _lanes(frontend=True),
    ),
    # Supply-chain lanes.
    ".pth file → scan": (
        ["evil.pth"],
        _lanes(python=True, scan=True),
    ),
    "setup.py → scan": (
        ["setup.py"],
        _lanes(python=True, scan=True),
    ),
    "mcp catalog manifest → mcp_catalog": (
        ["optional-mcps/foo/manifest.yaml"],
        _lanes(python=True, mcp_catalog=True),
    ),
    "mcp_catalog.py → mcp_catalog": (
        ["hermes_cli/mcp_catalog.py"],
        _lanes(python=True, scan=True, mcp_catalog=True),
    ),
    # Fail open: CI-config, empty and blank diffs all run everything.
    ".github change → all": (
        [".github/workflows/tests.yml"],
        DEFAULT,
    ),
    "action change → all": (
        [".github/actions/detect-changes/action.yml"],
        DEFAULT,
    ),
    "empty diff → all": (
        [],
        DEFAULT,
    ),
    "blank lines → all": (
        ["", " "],
        DEFAULT,
    ),
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("files,expected", CASES.values(), ids=CASES.keys())
def test_classify(files, expected):
    """``classify`` returns exactly the expected lane map for each changed-file list."""
    lanes = classify(files)
    assert lanes == expected
|
||||
|
|
@ -189,3 +189,82 @@ def test_indicators_independent_agents_and_processes(monkeypatch):
|
|||
rendered = "".join(text for _style, text in frags)
|
||||
assert "▶ 1" in rendered
|
||||
assert "⚙ 2" in rendered
|
||||
|
||||
|
||||
# ── Background/async subagent indicator (⛓ N) ─────────────────────────────
|
||||
# Source of truth is tools.async_delegation.active_count() — the count of
|
||||
# delegate_task delegations (batch + background single) still in the
|
||||
# "running" state. Distinct from ▶ (/background agent threads) and ⚙ (shell
|
||||
# processes); all three can be active at once.
|
||||
|
||||
|
||||
def _patch_async_active(monkeypatch, count: int) -> None:
    """Replace ``tools.async_delegation.active_count`` with a stub returning *count*."""
    from tools import async_delegation

    def _fixed_count():
        return count

    monkeypatch.setattr(async_delegation, "active_count", _fixed_count)
|
||||
|
||||
|
||||
def test_snapshot_reports_zero_when_no_background_subagents(monkeypatch):
    """With ``active_count()`` stubbed to 0 the snapshot reports 0 subagents."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 0)

    snapshot = cli._get_status_bar_snapshot()

    assert snapshot["active_background_subagents"] == 0
|
||||
|
||||
|
||||
def test_snapshot_counts_live_background_subagents(monkeypatch):
    """The snapshot reports the value returned by the stubbed ``active_count()``."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 4)

    snapshot = cli._get_status_bar_snapshot()

    assert snapshot["active_background_subagents"] == 4
|
||||
|
||||
|
||||
def test_snapshot_safe_when_async_active_count_raises(monkeypatch):
    """A raising ``active_count()`` does not propagate; the snapshot reports 0."""
    cli = _make_cli()
    from tools import async_delegation

    def _raise_runtime_error():
        raise RuntimeError("boom")

    monkeypatch.setattr(async_delegation, "active_count", _raise_runtime_error)

    snapshot = cli._get_status_bar_snapshot()

    assert snapshot["active_background_subagents"] == 0
|
||||
|
||||
|
||||
def test_plain_text_status_shows_subagent_indicator_when_active(monkeypatch):
    """Three active subagents render as ``⛓ 3`` in the plain-text status bar."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 3)

    status_line = cli._build_status_bar_text(width=80)

    assert "⛓ 3" in status_line
|
||||
|
||||
|
||||
def test_plain_text_status_omits_subagent_indicator_when_idle(monkeypatch):
    """With zero active subagents the ``⛓`` glyph is absent from the status bar."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 0)

    status_line = cli._build_status_bar_text(width=80)

    assert "⛓" not in status_line
|
||||
|
||||
|
||||
def test_fragments_include_subagent_segment_when_active(monkeypatch):
    """The fragment-based status bar contains ``⛓ 2`` when two subagents are active."""
    cli = _make_cli()
    _patch_async_active(monkeypatch, 2)
    cli._status_bar_visible = True
    cli._get_tui_terminal_width = lambda: 120  # type: ignore[method-assign]

    fragments = cli._get_status_bar_fragments()
    pieces = [text for _style, text in fragments]

    assert "⛓ 2" in "".join(pieces)
|
||||
|
||||
|
||||
def test_all_three_background_indicators_independent(monkeypatch):
    """The ▶ (agent tasks), ⚙ (shell processes) and ⛓ (subagents) segments coexist."""
    cli = _make_cli()
    cli._background_tasks = {"bg_a": _stub_thread()}
    _patch_process_registry(monkeypatch, 2)
    _patch_async_active(monkeypatch, 5)
    cli._status_bar_visible = True
    cli._get_tui_terminal_width = lambda: 120  # type: ignore[method-assign]

    fragments = cli._get_status_bar_fragments()
    pieces = [text for _style, text in fragments]
    bar_text = "".join(pieces)

    assert "▶ 1" in bar_text
    assert "⚙ 2" in bar_text
    assert "⛓ 5" in bar_text
|
||||
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ class TestHealthyTurnStillRuns:
|
|||
# Force the judge to say "continue" without touching the network.
|
||||
with patch(
|
||||
"hermes_cli.goals.judge_goal",
|
||||
return_value=("continue", "needs more steps", False),
|
||||
return_value=("continue", "needs more steps", False, None),
|
||||
):
|
||||
cli._maybe_continue_goal_after_turn()
|
||||
|
||||
|
|
@ -189,7 +189,7 @@ class TestHealthyTurnStillRuns:
|
|||
|
||||
with patch(
|
||||
"hermes_cli.goals.judge_goal",
|
||||
return_value=("done", "goal satisfied", False),
|
||||
return_value=("done", "goal satisfied", False, None),
|
||||
):
|
||||
cli._maybe_continue_goal_after_turn()
|
||||
|
||||
|
|
|
|||
80
tests/computer_use/test_cua_telemetry.py
Normal file
80
tests/computer_use/test_cua_telemetry.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""Tests for the cua-driver telemetry opt-in policy.
|
||||
|
||||
cua-driver ships anonymous PostHog telemetry ENABLED by default upstream.
|
||||
Hermes disables it unless the user opts in via
|
||||
``computer_use.cua_telemetry: true``. The policy is applied by injecting
|
||||
``CUA_DRIVER_RS_TELEMETRY_ENABLED=0`` into every cua-driver child env.
|
||||
|
||||
These assert the behavior contract (default disables, opt-in leaves the var
|
||||
untouched, config failure fails safe toward disabled), not specific config
|
||||
snapshots.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from tools.computer_use import cua_backend
|
||||
|
||||
|
||||
_VAR = "CUA_DRIVER_RS_TELEMETRY_ENABLED"
|
||||
|
||||
|
||||
class TestTelemetryDisabledFlag:
    """``_cua_telemetry_disabled()`` under different ``load_config`` outcomes."""

    @staticmethod
    def _flag_with(**load_config_mock):
        """Patch ``hermes_cli.config.load_config`` as given and return the flag."""
        with patch("hermes_cli.config.load_config", **load_config_mock):
            return cua_backend._cua_telemetry_disabled()

    def test_default_config_disables(self):
        # No cua_telemetry key at all => telemetry disabled.
        assert self._flag_with(return_value={}) is True

    def test_explicit_false_disables(self):
        config = {"computer_use": {"cua_telemetry": False}}
        assert self._flag_with(return_value=config) is True

    def test_opt_in_true_does_not_disable(self):
        config = {"computer_use": {"cua_telemetry": True}}
        assert self._flag_with(return_value=config) is False

    def test_config_load_failure_fails_safe(self):
        # A config that cannot be loaded falls back to disabling telemetry (privacy-safe).
        assert self._flag_with(side_effect=RuntimeError("boom")) is True

    def test_missing_section_disables(self):
        assert self._flag_with(return_value={"other": {}}) is True
|
||||
|
||||
|
||||
class TestChildEnv:
    """``cua_driver_child_env()`` with the telemetry-disabled flag forced either way."""

    @staticmethod
    def _child_env(disabled, *base_env):
        """Force the disabled flag to *disabled* and build the child environment."""
        with patch.object(cua_backend, "_cua_telemetry_disabled", return_value=disabled):
            return cua_backend.cua_driver_child_env(*base_env)

    def test_disabled_injects_var_zero(self):
        child = self._child_env(True, {"PATH": "/usr/bin"})
        assert child[_VAR] == "0"
        # The base environment is carried over.
        assert child["PATH"] == "/usr/bin"

    def test_opt_in_leaves_var_untouched(self):
        # After opt-in the variable must NOT be set, so the driver falls back to
        # its own default. A value already present in the base env is preserved.
        child = self._child_env(False, {"PATH": "/usr/bin"})
        assert _VAR not in child

    def test_opt_in_preserves_user_set_var(self):
        child = self._child_env(False, {_VAR: "1", "PATH": "/usr/bin"})
        # The user opted in and set the variable explicitly; it is not clobbered.
        assert child[_VAR] == "1"

    def test_disabled_overrides_inherited_enabled(self):
        # Even when the parent process had telemetry enabled, the default policy
        # forces it off in the child.
        child = self._child_env(True, {_VAR: "1"})
        assert child[_VAR] == "0"

    def test_defaults_to_os_environ_when_no_base(self):
        with patch.object(cua_backend, "_cua_telemetry_disabled", return_value=True):
            with patch.dict("os.environ", {"SOME_MARKER": "yes"}, clear=False):
                child = cua_backend.cua_driver_child_env()
        assert child.get("SOME_MARKER") == "yes"
        assert child[_VAR] == "0"
|
||||
325
tests/computer_use/test_doctor.py
Normal file
325
tests/computer_use/test_doctor.py
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
"""Tests for ``tools.computer_use.doctor``.
|
||||
|
||||
The doctor module drives cua-driver's stable ``health_report`` MCP tool over
|
||||
stdio JSON-RPC and renders the structured response. Most of the surface is
|
||||
about parsing what cua-driver hands back, plus the exit-code contract
|
||||
downstream consumers (CI / `hermes update`) rely on:
|
||||
|
||||
* Exit 0 when overall == "ok"
|
||||
* Exit 1 when overall in ("degraded", "failed") — at least one check
|
||||
failed but the tool itself ran successfully
|
||||
* Exit 2 when the cua-driver binary is missing or the protocol breaks
|
||||
|
||||
We do NOT spin up a real cua-driver — that lives in the cua-driver
|
||||
integration test suite (libs/cua-driver/rust/tests/integration/
|
||||
test_health_report_mcp.py). Here we mock the subprocess and assert the
|
||||
Hermes-side adapter behaves correctly against the documented response
|
||||
shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
# ── helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _fake_proc_with_responses(*responses: dict) -> MagicMock:
|
||||
"""Build a MagicMock subprocess.Popen handle that yields one JSON-RPC
|
||||
response per `readline()` call, then returns "" (EOF)."""
|
||||
lines = [json.dumps(r) + "\n" for r in responses] + [""]
|
||||
proc = MagicMock()
|
||||
proc.stdin = MagicMock()
|
||||
proc.stdout = MagicMock()
|
||||
proc.stdout.readline = MagicMock(side_effect=lines)
|
||||
proc.stderr = MagicMock()
|
||||
proc.stderr.read = MagicMock(return_value="")
|
||||
proc.wait = MagicMock(return_value=0)
|
||||
proc.kill = MagicMock()
|
||||
return proc
|
||||
|
||||
|
||||
def _ok_report() -> dict:
|
||||
"""Minimal well-formed health_report response."""
|
||||
return {
|
||||
"schema_version": "1",
|
||||
"platform": "darwin",
|
||||
"driver_version": "0.5.8",
|
||||
"overall": "ok",
|
||||
"checks": [
|
||||
{"name": "binary_version", "status": "pass", "message": "cua-driver 0.5.8"},
|
||||
{"name": "tcc_accessibility", "status": "pass", "message": "Accessibility is granted."},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _degraded_report() -> dict:
|
||||
"""Report with one failing check — overall=degraded."""
|
||||
return {
|
||||
"schema_version": "1",
|
||||
"platform": "darwin",
|
||||
"driver_version": "0.5.8",
|
||||
"overall": "degraded",
|
||||
"checks": [
|
||||
{"name": "binary_version", "status": "pass", "message": "cua-driver 0.5.8"},
|
||||
{
|
||||
"name": "bundle_identity",
|
||||
"status": "fail",
|
||||
"message": "Process has no CFBundleIdentifier.",
|
||||
"hint": "Run inside CuaDriver.app",
|
||||
"data": {"executable_path": "/tmp/cua-driver"},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ── exit codes ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDoctorExitCodes:
    """Exit codes returned by ``doctor.run_doctor()`` against a mocked cua-driver."""

    @staticmethod
    def _exit_code_for(report):
        """Run the doctor against a fake driver that answers with *report*."""
        from tools.computer_use import doctor

        proc = _fake_proc_with_responses(
            {"jsonrpc": "2.0", "id": 1, "result": {}},
            {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": report}},
        )
        with patch("shutil.which", return_value="/fake/cua-driver"):
            with patch("subprocess.Popen", return_value=proc):
                with patch("sys.stdout", new_callable=StringIO):
                    return doctor.run_doctor()

    def test_ok_exits_0(self):
        assert self._exit_code_for(_ok_report()) == 0

    def test_degraded_exits_1(self):
        assert self._exit_code_for(_degraded_report()) == 1

    def test_failed_overall_exits_1(self):
        """An overall of ``failed`` also exits 1, not 2: the tool itself ran
        successfully and only the diagnosis was bad."""
        failed = _degraded_report()
        failed["overall"] = "failed"
        assert self._exit_code_for(failed) == 1

    def test_missing_binary_exits_2(self):
        from tools.computer_use import doctor

        with patch("shutil.which", return_value=None):
            with patch("sys.stdout", new_callable=StringIO):
                exit_code = doctor.run_doctor()
        assert exit_code == 2

    def test_protocol_error_exits_2(self, capsys):
        """Empty stdout during the handshake (the driver crashed) counts as a
        protocol failure and exits 2."""
        from tools.computer_use import doctor

        crashed = MagicMock()
        crashed.stdin = MagicMock()
        crashed.stdout = MagicMock()
        crashed.stdout.readline = MagicMock(return_value="")  # EOF on initialize
        crashed.stderr = MagicMock()
        crashed.stderr.read = MagicMock(return_value="boom\n")
        crashed.wait = MagicMock(return_value=0)
        crashed.kill = MagicMock()

        with patch("shutil.which", return_value="/fake/cua-driver"):
            with patch("subprocess.Popen", return_value=crashed):
                exit_code = doctor.run_doctor()
        assert exit_code == 2
        # The failure has to be mentioned on stderr.
        err_text = capsys.readouterr().err.lower()
        assert "cua-driver" in err_text or "health_report" in err_text
|
||||
|
||||
|
||||
# ── response-shape parsing ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResponseShapeParsing:
    """How ``run_doctor()`` reads the different tools/call response shapes."""

    @staticmethod
    def _run_capturing_stdout(proc):
        """Run the doctor against *proc*; return ``(exit_code, captured stdout)``."""
        from tools.computer_use import doctor

        with patch("shutil.which", return_value="/fake/cua-driver"):
            with patch("subprocess.Popen", return_value=proc):
                with patch("sys.stdout", new_callable=StringIO) as out:
                    exit_code = doctor.run_doctor()
        return exit_code, out.getvalue()

    def test_prefers_structuredContent(self):
        proc = _fake_proc_with_responses(
            {"jsonrpc": "2.0", "id": 1, "result": {}},
            {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}},
        )
        _, printed = self._run_capturing_stdout(proc)
        # The header line carries driver version + platform + overall.
        assert "darwin" in printed
        assert "ok" in printed

    def test_falls_back_to_text_content_when_structuredContent_absent(self):
        """Older cua-driver builds may deliver health_report as a text content
        item holding the JSON; the doctor must still parse it."""
        text_item = {"type": "text", "text": json.dumps(_ok_report())}
        proc = _fake_proc_with_responses(
            {"jsonrpc": "2.0", "id": 1, "result": {}},
            {"jsonrpc": "2.0", "id": 2, "result": {"content": [text_item]}},
        )
        exit_code, printed = self._run_capturing_stdout(proc)
        assert exit_code == 0
        assert "ok" in printed

    def test_jsonrpc_error_response_exits_2(self, capsys):
        from tools.computer_use import doctor

        rpc_error = {"code": -32601, "message": "method not found"}
        proc = _fake_proc_with_responses(
            {"jsonrpc": "2.0", "id": 1, "result": {}},
            {"jsonrpc": "2.0", "id": 2, "error": rpc_error},
        )
        with patch("shutil.which", return_value="/fake/cua-driver"):
            with patch("subprocess.Popen", return_value=proc):
                exit_code = doctor.run_doctor()
        assert exit_code == 2
        assert "method not found" in capsys.readouterr().err
|
||||
|
||||
|
||||
# ── args / arg passthrough ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestArgPassthrough:
    """``include`` / ``skip`` filters are forwarded in the tools/call arguments."""

    @staticmethod
    def _tools_call_payload(**doctor_kwargs):
        """Run the doctor with *doctor_kwargs* and return the decoded tools/call
        request that was written to the fake driver's stdin."""
        from tools.computer_use import doctor

        proc = _fake_proc_with_responses(
            {"jsonrpc": "2.0", "id": 1, "result": {}},
            {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}},
        )
        with patch("shutil.which", return_value="/fake/cua-driver"):
            with patch("subprocess.Popen", return_value=proc):
                with patch("sys.stdout", new_callable=StringIO):
                    doctor.run_doctor(**doctor_kwargs)

        # Look through everything written to stdin for the tools/call request.
        written = [call.args[0] for call in proc.stdin.write.call_args_list]
        return next(json.loads(line) for line in written if "tools/call" in line)

    def test_include_passed_through_to_tools_call(self):
        payload = self._tools_call_payload(
            include=["binary_version", "tcc_accessibility"]
        )
        assert payload["params"]["arguments"]["include"] == [
            "binary_version",
            "tcc_accessibility",
        ]

    def test_skip_passed_through(self):
        payload = self._tools_call_payload(skip=["bundle_identity"])
        assert payload["params"]["arguments"]["skip"] == ["bundle_identity"]

    def test_no_filters_sends_empty_arguments(self):
        """With neither include nor skip given, the arguments object is empty
        rather than present-but-null, so the driver's default 'run every
        check' branch fires."""
        payload = self._tools_call_payload()
        assert payload["params"]["arguments"] == {}
|
||||
|
||||
|
||||
# ── json output ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestJsonOutput:
    """``run_doctor(json_output=True)`` prints the report as JSON."""

    def test_json_output_is_parseable_round_trip(self):
        from tools.computer_use import doctor

        handshake = {"jsonrpc": "2.0", "id": 1, "result": {}}
        report_reply = {
            "jsonrpc": "2.0",
            "id": 2,
            "result": {"structuredContent": _ok_report()},
        }
        proc = _fake_proc_with_responses(handshake, report_reply)

        with patch("shutil.which", return_value="/fake/cua-driver"):
            with patch("subprocess.Popen", return_value=proc):
                with patch("sys.stdout", new_callable=StringIO) as out:
                    doctor.run_doctor(json_output=True)

        # The captured text must survive json.loads and equal the input report:
        # --json hands the structured payload through unchanged so downstream
        # tooling can consume it directly.
        decoded = json.loads(out.getvalue())
        assert decoded == _ok_report()
|
||||
|
||||
|
||||
# ── HERMES_CUA_DRIVER_CMD resolution ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestDriverCmdResolution:
    """Which command string ``run_doctor()`` hands to ``shutil.which``."""

    @staticmethod
    def _ok_proc():
        """Fake driver process that completes the handshake and reports ``ok``."""
        return _fake_proc_with_responses(
            {"jsonrpc": "2.0", "id": 1, "result": {}},
            {"jsonrpc": "2.0", "id": 2, "result": {"structuredContent": _ok_report()}},
        )

    def test_explicit_driver_cmd_arg_wins(self):
        from tools.computer_use import doctor

        proc = self._ok_proc()
        with patch("shutil.which", return_value="/fake/explicit-binary") as which_mock:
            with patch("subprocess.Popen", return_value=proc):
                with patch("sys.stdout", new_callable=StringIO):
                    doctor.run_doctor(driver_cmd="/custom/path/cua-driver")
        # shutil.which must receive the explicit argument rather than the
        # env-var / default resolution.
        which_mock.assert_called_with("/custom/path/cua-driver")

    def test_env_var_used_when_no_arg_given(self, monkeypatch):
        from tools.computer_use import doctor

        monkeypatch.setenv("HERMES_CUA_DRIVER_CMD", "/env/path/cua-driver")
        proc = self._ok_proc()
        with patch("shutil.which", return_value="/env/path/cua-driver") as which_mock:
            with patch("subprocess.Popen", return_value=proc):
                with patch("sys.stdout", new_callable=StringIO):
                    doctor.run_doctor()
        # The which lookup must have used the value from the env var.
        which_mock.assert_called_with("/env/path/cua-driver")
|
||||
|
|
@ -14,10 +14,7 @@ import pytest
|
|||
def temp_home(tmp_path, monkeypatch):
    """Isolated HERMES_HOME so jobs.json doesn't touch the real store.

    Sets ``HERMES_HOME`` to ``tmp_path`` for the duration of the test and
    yields that same path.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    # cron.jobs caches no home at import; get_hermes_home() reads the env live.
    yield tmp_path
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,105 +0,0 @@
|
|||
"""Regression tests for #32091 — profile-scoped cron jobs orphaned.
|
||||
|
||||
Cron storage (CRON_DIR/JOBS_FILE) must anchor at the *default root* Hermes
|
||||
home, not the active profile's home. Otherwise a job created from a
|
||||
profile-scoped agent session writes to ~/.hermes/profiles/<p>/cron/jobs.json,
|
||||
while the profile-less gateway reads only ~/.hermes/cron/jobs.json — the job
|
||||
is silently orphaned (looks healthy in `list`, never fires).
|
||||
"""
|
||||
import importlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_cron_storage_anchors_at_root_under_profile(tmp_path, monkeypatch):
    """With HERMES_HOME set to <root>/profiles/<name>, the cron store must be
    <root>/cron and not <root>/profiles/<name>/cron."""
    root = tmp_path / "hermes_home"
    profile_home = root.joinpath("profiles", "myprofile")
    profile_home.mkdir(parents=True)

    import hermes_constants

    # The #32091 scenario: the platform default root is our tmp root and the
    # active HERMES_HOME is a profile directory beneath it.
    monkeypatch.setattr(
        hermes_constants, "_get_platform_default_hermes_home", lambda: root
    )
    monkeypatch.setenv("HERMES_HOME", str(profile_home))

    # The default root is the ROOT ...
    assert hermes_constants.get_default_hermes_root().resolve() == root.resolve()
    # ... while get_hermes_home follows the profile override.
    assert hermes_constants.get_hermes_home().resolve() == profile_home.resolve()

    # Reload cron.jobs under this env so its import-time paths are recomputed.
    import cron.jobs as jobs

    importlib.reload(jobs)
    try:
        assert jobs.HERMES_DIR.resolve() == root.resolve()
        store = jobs.JOBS_FILE.resolve()
        assert store == (root / "cron" / "jobs.json").resolve()
        # The orphan location under the profile must not be the store.
        assert store != (profile_home / "cron" / "jobs.json").resolve()
    finally:
        # Undo the patches first, then reload so other tests see real state.
        monkeypatch.undo()
        importlib.reload(jobs)
|
||||
|
||||
|
||||
def test_cron_storage_unaffected_when_no_profile(tmp_path, monkeypatch):
    """When HERMES_HOME is the root itself (no profile), the store is still
    <root>/cron."""
    import hermes_constants

    root = tmp_path / "hermes_home"
    root.mkdir(parents=True)
    monkeypatch.setattr(
        hermes_constants, "_get_platform_default_hermes_home", lambda: root
    )
    monkeypatch.setenv("HERMES_HOME", str(root))

    import cron.jobs as jobs

    importlib.reload(jobs)
    try:
        expected_store = (root / "cron" / "jobs.json").resolve()
        assert jobs.JOBS_FILE.resolve() == expected_store
    finally:
        # Restore the real environment before re-importing for other tests.
        monkeypatch.undo()
        importlib.reload(jobs)
|
||||
|
||||
|
||||
def test_tick_lock_anchors_at_root_under_profile(tmp_path, monkeypatch):
    """The tick lock belongs at <root>/cron/.tick.lock, not under the profile
    dir; otherwise tickers running under different profiles would take
    different locks and double-fire the shared jobs store (#32091)."""
    import importlib

    import hermes_constants

    root = tmp_path / "hermes_home"
    profile_home = root.joinpath("profiles", "p")
    profile_home.mkdir(parents=True)
    monkeypatch.setattr(
        hermes_constants, "_get_platform_default_hermes_home", lambda: root
    )
    monkeypatch.setenv("HERMES_HOME", str(profile_home))

    import cron.scheduler as sched

    importlib.reload(sched)
    try:
        # With no _hermes_home override the scheduler falls back to
        # get_default_hermes_root().
        sched._hermes_home = None
        lock_dir, lock_file = sched._get_lock_paths()
        cron_dir = root / "cron"
        assert lock_dir.resolve() == cron_dir.resolve()
        assert lock_file.resolve() == (cron_dir / ".tick.lock").resolve()
        assert lock_dir.resolve() != (profile_home / "cron").resolve()
    finally:
        monkeypatch.undo()
        importlib.reload(sched)
|
||||
|
||||
|
||||
def test_get_default_hermes_root_docker_layouts(tmp_path, monkeypatch):
    """For a Docker/custom HERMES_HOME outside the native home,
    get_default_hermes_root still resolves the root, so cron storage works in
    containers."""
    import hermes_constants

    native = tmp_path / "native_home"
    monkeypatch.setattr(
        hermes_constants, "_get_platform_default_hermes_home", lambda: native
    )

    # First a custom root outside the native home (HERMES_HOME IS the root),
    # then the profile layout <custom>/profiles/<name>, which maps to <custom>.
    for hermes_home in ("/opt/data", "/opt/data/profiles/coder"):
        monkeypatch.setenv("HERMES_HOME", hermes_home)
        assert hermes_constants.get_default_hermes_root() == Path("/opt/data")
|
||||
|
|
@ -7,11 +7,75 @@ from unittest.mock import AsyncMock, patch, MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, _send_media_via_adapter, run_job, SILENT_MARKER, _build_job_prompt
|
||||
from cron.scheduler import _resolve_origin, _resolve_delivery_target, _deliver_result, _send_media_via_adapter, run_job, SILENT_MARKER, _build_job_prompt, _resolve_cron_enabled_toolsets, _merge_mcp_into_per_job_toolsets
|
||||
from tools.env_passthrough import clear_env_passthrough
|
||||
from tools.credential_files import clear_credential_files
|
||||
|
||||
|
||||
class TestPerJobToolsetMcpMerge:
    """A per-job enabled_toolsets allowlist must not silently drop MCP servers."""

    # Fixture config: two servers enabled with a bool, one disabled, one
    # enabled via the string "true", and one entry that is not a dict.
    CFG = {
        "mcp_servers": {
            "finnhub": {"enabled": True},
            "playwright": {"enabled": True},
            "disabled_one": {"enabled": False},
            "string_enabled": {"enabled": "true"},
            "not_a_dict": "ignored",
        }
    }

    def _enabled_names(self):
        """Names from ``CFG`` that the merge is expected to treat as enabled."""
        return {"finnhub", "playwright", "string_enabled"}

    def test_native_only_list_gets_all_enabled_mcp_servers(self):
        merged = _merge_mcp_into_per_job_toolsets(["web", "terminal"], self.CFG)
        # The native entries stay first, in their original order.
        assert merged[:2] == ["web", "terminal"]
        assert set(merged) == {"web", "terminal"} | self._enabled_names()

    def test_disabled_servers_are_not_added(self):
        merged = _merge_mcp_into_per_job_toolsets(["web"], self.CFG)
        assert "disabled_one" not in merged

    def test_explicit_mcp_name_is_treated_as_allowlist(self):
        # Naming one server explicitly means nothing else is added.
        merged = _merge_mcp_into_per_job_toolsets(["web", "finnhub"], self.CFG)
        assert merged == ["web", "finnhub"]
        assert "playwright" not in merged

    def test_no_mcp_sentinel_opts_out_and_is_stripped(self):
        merged = _merge_mcp_into_per_job_toolsets(["web", "no_mcp"], self.CFG)
        assert merged == ["web"]
        assert not (set(merged) & self._enabled_names())

    def test_no_mcp_config_adds_nothing(self):
        merged = _merge_mcp_into_per_job_toolsets(["web"], {})
        assert merged == ["web"]

    def test_no_duplicate_when_listed_name_also_globally_enabled(self):
        merged = _merge_mcp_into_per_job_toolsets(["finnhub", "finnhub"], self.CFG)
        # Duplicates from the input are kept as-is; the merge adds no more.
        assert merged.count("finnhub") == 2

    def test_resolver_uses_merge_for_per_job_lists(self):
        job = {"enabled_toolsets": ["web", "terminal"]}
        resolved = _resolve_cron_enabled_toolsets(job, self.CFG)
        assert set(resolved) == {"web", "terminal"} | self._enabled_names()

    def test_resolver_empty_per_job_falls_through_to_platform(self):
        # Without a per-job list the resolver must go through
        # _get_platform_tools (the platform fallback) instead of the per-job
        # merge. The platform resolver is stubbed so we can check both that it
        # was the path taken and that its result is what comes back.
        job = {"enabled_toolsets": None}
        sentinel = ["web", "finnhub"]
        with patch(
            "hermes_cli.tools_config._get_platform_tools",
            return_value=set(sentinel),
        ) as m_platform:
            resolved = _resolve_cron_enabled_toolsets(job, self.CFG)
        m_platform.assert_called_once()
        # Positional args to _get_platform_tools: (cfg, "cron")
        positional_args = m_platform.call_args[0]
        assert positional_args[1] == "cron"
        assert set(resolved) == set(sentinel)
|
||||
|
||||
|
||||
class TestResolveOrigin:
|
||||
def test_full_origin(self):
|
||||
job = {
|
||||
|
|
@ -1330,6 +1394,52 @@ class TestRunJobSessionPersistence:
|
|||
assert error is None
|
||||
assert final_response == "all good"
|
||||
|
||||
def test_run_job_delivers_max_iteration_fallback_summary(self, tmp_path):
    """Cron should deliver a usable max-iteration fallback summary.

    A run may use up its iteration budget, obtain a final text summary from
    the no-tools fallback call, and still carry ``completed=False`` in the
    generic agent result. Cron must not turn that report text into a
    RuntimeError.
    """
    job = {"id": "summary-job", "name": "summary", "prompt": "finish the report"}
    fake_db = MagicMock()
    runtime_provider = {
        "api_key": "***",
        "base_url": "https://example.invalid/v1",
        "provider": "openrouter",
        "api_mode": "chat_completions",
    }
    # Not completed, not failed, but with a final response present.
    conversation_result = {
        "final_response": "final fallback report",
        "completed": False,
        "failed": False,
        "turn_exit_reason": "max_iterations_reached(60/60)",
    }

    with patch("cron.scheduler._hermes_home", tmp_path), \
         patch("cron.scheduler._resolve_origin", return_value=None), \
         patch("dotenv.load_dotenv"), \
         patch("hermes_state.SessionDB", return_value=fake_db), \
         patch(
             "hermes_cli.runtime_provider.resolve_runtime_provider",
             return_value=runtime_provider,
         ), \
         patch("run_agent.AIAgent") as mock_agent_cls:
        mock_agent = MagicMock()
        mock_agent.run_conversation.return_value = conversation_result
        mock_agent_cls.return_value = mock_agent

        success, output, final_response, error = run_job(job)

    assert success is True
    assert error is None
    assert final_response == "final fallback report"
    assert "final fallback report" in output
    assert "(FAILED)" not in output
|
||||
|
||||
def test_tick_marks_empty_response_as_error(self, tmp_path):
|
||||
"""When run_job returns success=True but final_response is empty,
|
||||
tick() should mark the job as error so last_status != 'ok'.
|
||||
|
|
|
|||
243
tests/gateway/relay/test_relay_going_idle.py
Normal file
243
tests/gateway/relay/test_relay_going_idle.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
"""Phase 5 §5.3 — going-idle / buffered-flip primitive (gateway side).
|
||||
|
||||
Exercises the WebSocketRelayTransport's going_idle/ack handshake, the
|
||||
buffered-inbound ack (a bufferId-carrying inbound is acked after the handler
|
||||
runs), the NET-NEW reconnect loop (re-dial + re-handshake after an unexpected
|
||||
close), and the RelayAdapter emitting going_idle from its existing drain
|
||||
(disconnect) transition. All against a real in-process websockets server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from gateway.relay.ws_transport import WebSocketRelayTransport, WEBSOCKETS_AVAILABLE
|
||||
|
||||
pytestmark = pytest.mark.skipif(not WEBSOCKETS_AVAILABLE, reason="websockets not installed")
|
||||
|
||||
if WEBSOCKETS_AVAILABLE:
|
||||
import websockets
|
||||
|
||||
|
||||
# Capability descriptor the stub connector sends back in reply to "hello".
DESCRIPTOR = dict(
    contract_version=1,
    platform="discord",
    label="Discord",
    max_message_length=2000,
    supports_draft_streaming=False,
    supports_edit=True,
    supports_threads=True,
    markdown_dialect="discord",
    len_unit="chars",
)
|
||||
|
||||
|
||||
class _IdleAwareServer:
|
||||
"""Connector stub: descriptor on hello, acks going_idle, records inbound_acks,
|
||||
and can push buffered inbound frames (with bufferId) after handshake."""
|
||||
|
||||
def __init__(self):
|
||||
self.received: list[dict] = []
|
||||
self.inbound_acks: list[str] = []
|
||||
self.going_idle_count = 0
|
||||
self._server = None
|
||||
self.url = ""
|
||||
# Frames to push right after each handshake (e.g. buffered backlog replay).
|
||||
self._to_push: list[dict] = []
|
||||
self.connections = 0
|
||||
|
||||
async def start(self):
|
||||
self._server = await websockets.serve(self._handle, "127.0.0.1", 0)
|
||||
sock = next(iter(self._server.sockets))
|
||||
self.url = f"ws://127.0.0.1:{sock.getsockname()[1]}"
|
||||
|
||||
async def stop(self):
|
||||
if self._server is not None:
|
||||
self._server.close()
|
||||
await self._server.wait_closed()
|
||||
|
||||
async def _handle(self, ws):
|
||||
self.connections += 1
|
||||
try:
|
||||
async for raw in ws:
|
||||
for line in str(raw).split("\n"):
|
||||
if not line.strip():
|
||||
continue
|
||||
frame = json.loads(line)
|
||||
self.received.append(frame)
|
||||
await self._on_frame(ws, frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _on_frame(self, ws, frame):
|
||||
ftype = frame.get("type")
|
||||
if ftype == "hello":
|
||||
await ws.send(json.dumps({"type": "descriptor", "descriptor": DESCRIPTOR}) + "\n")
|
||||
for f in self._to_push:
|
||||
await ws.send(json.dumps(f) + "\n")
|
||||
elif ftype == "going_idle":
|
||||
self.going_idle_count += 1
|
||||
await ws.send(json.dumps({"type": "going_idle_ack"}) + "\n")
|
||||
elif ftype == "inbound_ack":
|
||||
self.inbound_acks.append(frame.get("bufferId"))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def server():
    """Yield a started ``_IdleAwareServer`` and stop it afterwards."""
    stub = _IdleAwareServer()
    await stub.start()
    yield stub
    await stub.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_go_idle_awaits_ack(server):
    """go_idle() returns True once the connector acknowledges going_idle."""
    transport = WebSocketRelayTransport(server.url, "discord", "appShared")
    await transport.connect()
    try:
        await transport.handshake()
        assert await transport.go_idle(timeout_s=2) is True
        assert server.going_idle_count == 1
        assert any(frame["type"] == "going_idle" for frame in server.received)
    finally:
        await transport.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_go_idle_returns_false_on_timeout(server):
    """When the server never acks going_idle, go_idle() returns False (the
    caller closes anyway)."""

    async def descriptor_only(ws, frame):
        # Answer the handshake but deliberately ignore going_idle.
        if frame.get("type") != "hello":
            return
        await ws.send(json.dumps({"type": "descriptor", "descriptor": DESCRIPTOR}) + "\n")

    server._on_frame = descriptor_only  # type: ignore[assignment]
    transport = WebSocketRelayTransport(server.url, "discord", "appShared")
    await transport.connect()
    try:
        await transport.handshake()
        assert await transport.go_idle(timeout_s=0.3) is False
    finally:
        await transport.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_buffered_inbound_is_acked_after_handler(server):
    """A buffered delivery (carrying a bufferId) is acked after the handler
    runs; a live delivery (no bufferId) is not acked."""

    def inbound_frame(text, **extra):
        return {
            "type": "inbound",
            "event": {
                "text": text,
                "message_type": "text",
                "source": {"platform": "discord", "chat_id": "c1", "chat_type": "dm"},
            },
            **extra,
        }

    server._to_push = [
        inbound_frame("buffered", bufferId="buf-42"),
        inbound_frame("live"),
    ]
    handled = []

    async def handler(ev):
        handled.append(ev.text)

    transport = WebSocketRelayTransport(server.url, "discord", "appShared")
    transport.set_inbound_handler(handler)
    await transport.connect()
    try:
        await transport.handshake()
        # Give the pushed frames time to arrive and be handled.
        await asyncio.sleep(0.1)
        assert "buffered" in handled
        assert "live" in handled
        # Only the delivery that carried a bufferId was acked.
        assert server.inbound_acks == ["buf-42"]
    finally:
        await transport.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reconnect_redials_after_unexpected_close():
    """A transport created with reconnect=True re-dials after the server drops
    the first connection right after the handshake."""
    drops = {"n": 0}
    srv = _IdleAwareServer()

    async def handle(ws):
        srv.connections += 1
        async for raw in ws:
            for line in str(raw).split("\n"):
                if not line.strip():
                    continue
                frame = json.loads(line)
                if frame.get("type") == "hello":
                    await ws.send(json.dumps({"type": "descriptor", "descriptor": DESCRIPTOR}) + "\n")
                    if drops["n"] == 0:
                        drops["n"] += 1
                        await ws.close()  # force an unexpected close on the first connection
                        return

    srv._server = await websockets.serve(handle, "127.0.0.1", 0)
    try:
        sock = next(iter(srv._server.sockets))
        srv.url = f"ws://127.0.0.1:{sock.getsockname()[1]}"
        t = WebSocketRelayTransport(srv.url, "discord", "appShared", reconnect=True, reconnect_backoff_s=0.05)
        try:
            await t.connect()
            await t.handshake()
            # The first connection is dropped server-side and the reconnect
            # loop re-dials. Poll with a deadline instead of one fixed sleep:
            # the connection counter only ever grows, so this finishes as soon
            # as the re-dial lands and allows more slack on slow machines.
            loop = asyncio.get_running_loop()
            deadline = loop.time() + 2.0
            while srv.connections < 2 and loop.time() < deadline:
                await asyncio.sleep(0.02)
            assert srv.connections >= 2
        finally:
            await t.disconnect()
    finally:
        # Always close the server, even if the transport could not be built
        # or its disconnect raised.
        await srv.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_reconnect_after_deliberate_disconnect(server):
    """A deliberate disconnect() must not trigger the reconnect loop."""
    t = WebSocketRelayTransport(server.url, "discord", "appShared", reconnect=True, reconnect_backoff_s=0.05)
    await t.connect()
    try:
        await t.handshake()
        before = server.connections
    finally:
        # Same cleanup pattern as the other tests in this file: the transport
        # is torn down even if the handshake raises.
        await t.disconnect()
    # Wait several backoff periods (0.05s each) so a wrongly scheduled re-dial
    # would have time to show up.
    await asyncio.sleep(0.3)
    assert server.connections == before
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_adapter_emits_going_idle_on_disconnect(server):
    """RelayAdapter.disconnect() (its existing drain transition) emits
    going_idle before tearing down the transport."""
    from gateway.config import PlatformConfig
    from gateway.relay.adapter import RelayAdapter
    from gateway.relay.descriptor import CONTRACT_VERSION, CapabilityDescriptor

    descriptor_fields = {
        "contract_version": CONTRACT_VERSION,
        "platform": "discord",
        "label": "Relay",
        "max_message_length": 4096,
        "supports_draft_streaming": False,
        "supports_edit": True,
        "supports_threads": False,
        "markdown_dialect": "plain",
        "len_unit": "chars",
    }
    placeholder = CapabilityDescriptor(**descriptor_fields)
    transport = WebSocketRelayTransport(server.url, "discord", "appShared")
    adapter = RelayAdapter(PlatformConfig(), placeholder, transport=transport)

    await adapter.connect()
    await adapter.disconnect()

    assert server.going_idle_count == 1
|
||||
192
tests/gateway/relay/test_relay_policy_send.py
Normal file
192
tests/gateway/relay/test_relay_policy_send.py
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
"""Unit tests for the gateway-side relay relevance-policy declaration (Phase 6 ζ).
|
||||
|
||||
Covers gateway.relay.relay_relevance_policy() (the projection of the agent's
|
||||
mention-gating / free-response / allow-bots config into the connector's generic
|
||||
vocabulary) and send_relay_policy() (the boot-time POST to /relay/policy). The
|
||||
connector HTTP POST is monkeypatched; the cross-repo E2E (connector repo,
|
||||
gateway_policy_driver.py) exercises the real route. These prove the PROJECTION
|
||||
mapping, the auth/skip logic, and the fail-soft boot behaviour.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import gateway.relay as relay
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
    """Remove relay-related env vars and stub the gateway config loader to
    return an empty dict for every test."""
    relay_env_vars = (
        "GATEWAY_RELAY_URL",
        "GATEWAY_RELAY_ID",
        "GATEWAY_RELAY_SECRET",
        "GATEWAY_RELAY_PLATFORM",
        "GATEWAY_RELAY_BOT_ID",
        "DISCORD_ALLOW_BOTS",
    )
    for name in relay_env_vars:
        monkeypatch.delenv(name, raising=False)
    monkeypatch.setattr("gateway.run._load_gateway_config", lambda: {}, raising=False)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# relay_relevance_policy() — the projection
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
def test_projection_maps_require_mention_and_free_response(monkeypatch):
    """require_mention and free_response_channels under the platform block are
    projected to requireAddress and freeResponseScopes."""

    def _gateway_config():
        return {"discord": {"require_mention": True, "free_response_channels": ["c-support", "c-help"]}}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)

    expected = {
        "platform": "discord",
        "requireAddress": True,
        "freeResponseScopes": ["c-support", "c-help"],
        "allowOtherBots": False,
    }
    assert relay.relay_relevance_policy() == expected
|
||||
|
||||
|
||||
def test_projection_allow_other_bots_from_env(monkeypatch):
    """DISCORD_ALLOW_BOTS=all in the environment yields allowOtherBots True."""

    def _gateway_config():
        return {"discord": {"require_mention": True}}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setenv("DISCORD_ALLOW_BOTS", "all")
    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is not None
    assert policy["allowOtherBots"] is True
|
||||
|
||||
|
||||
def test_projection_comma_string_free_response(monkeypatch):
    """A comma-separated free_response_channels string is split into a list
    with surrounding whitespace removed."""

    def _gateway_config():
        return {"discord": {"free_response_channels": "c1, c2 ,c3"}}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is not None
    assert policy["freeResponseScopes"] == ["c1", "c2", "c3"]
|
||||
|
||||
|
||||
def test_projection_falls_back_to_top_level_require_mention(monkeypatch):
    """With no ``discord:`` block, a top-level require_mention still sets
    requireAddress."""

    def _gateway_config():
        # top-level key only, no discord: block
        return {"require_mention": True}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is not None
    assert policy["requireAddress"] is True
|
||||
|
||||
|
||||
def test_projection_none_when_all_default(monkeypatch):
    """No require_mention, free-response or allow-bots setting means there is
    nothing to declare (the connector's quiet default already matches)."""

    def _gateway_config():
        return {"discord": {}}

    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is None
|
||||
|
||||
|
||||
def test_projection_none_when_platform_unresolved(monkeypatch):
    """Without GATEWAY_RELAY_PLATFORM the platform stays at the default
    "relay": no concrete fronted platform, so nothing is projected."""

    def _gateway_config():
        return {"discord": {"require_mention": True}}

    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)

    policy = relay.relay_relevance_policy()
    assert policy is None
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# send_relay_policy() — the boot-time declaration
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
def _arm(monkeypatch, *, url="wss://connector.example/relay"):
|
||||
monkeypatch.setenv("GATEWAY_RELAY_URL", url)
|
||||
monkeypatch.setenv("GATEWAY_RELAY_ID", "gw-x")
|
||||
monkeypatch.setenv("GATEWAY_RELAY_SECRET", "s" * 48)
|
||||
monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
|
||||
|
||||
|
||||
def test_send_posts_projected_policy_with_token(monkeypatch):
    """send_relay_policy() posts the projected policy to the /relay/policy URL
    with a non-empty token."""
    _arm(monkeypatch)

    def _gateway_config():
        return {"discord": {"require_mention": True, "free_response_channels": ["c-support"]}}

    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)
    captured = {}

    def _fake_post(*, policy_url, token, policy, timeout=15.0):
        captured.update(policy_url=policy_url, token=token, policy=policy)
        return 200

    monkeypatch.setattr(relay, "_post_policy", _fake_post)

    assert relay.send_relay_policy() is True
    assert captured["policy_url"] == "https://connector.example/relay/policy"
    assert captured["token"]  # a real upgrade token was minted
    sent_policy = captured["policy"]
    assert sent_policy["requireAddress"] is True
    assert sent_policy["freeResponseScopes"] == ["c-support"]
|
||||
|
||||
|
||||
def test_send_skips_when_no_secret(monkeypatch):
    """Without GATEWAY_RELAY_ID / SECRET the POST is never attempted."""
    monkeypatch.setenv("GATEWAY_RELAY_URL", "wss://connector.example/relay")
    monkeypatch.setenv("GATEWAY_RELAY_PLATFORM", "discord")
    # GATEWAY_RELAY_ID / GATEWAY_RELAY_SECRET are intentionally left unset.

    def _gateway_config():
        return {"discord": {"require_mention": True}}

    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)
    called = {"n": 0}

    def _counting_post(**k):
        called["n"] = called["n"] + 1
        return 200

    monkeypatch.setattr(relay, "_post_policy", _counting_post)

    assert relay.send_relay_policy() is False
    assert called["n"] == 0  # never attempted without a secret to auth with
|
||||
|
||||
|
||||
def test_send_skips_when_nothing_to_declare(monkeypatch):
    """With an all-default policy, nothing is posted."""
    _arm(monkeypatch)

    def _gateway_config():
        return {"discord": {}}

    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)
    called = {"n": 0}

    def _counting_post(**k):
        called["n"] = called["n"] + 1
        return 200

    monkeypatch.setattr(relay, "_post_policy", _counting_post)

    assert relay.send_relay_policy() is False
    assert called["n"] == 0  # no redundant write of the default
|
||||
|
||||
|
||||
def test_send_fail_soft_on_transport_error(monkeypatch):
    """An exception from the POST does not propagate; the call returns False
    so boot proceeds."""
    _arm(monkeypatch)

    def _gateway_config():
        return {"discord": {"require_mention": True}}

    def _boom(**kwargs):
        raise RuntimeError("connector unreachable")

    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)
    monkeypatch.setattr(relay, "_post_policy", _boom)

    outcome = relay.send_relay_policy()
    assert outcome is False
|
||||
|
||||
|
||||
def test_send_fail_soft_on_non_200(monkeypatch):
    """A 401 status from the POST makes send_relay_policy() return False."""
    _arm(monkeypatch)

    def _gateway_config():
        return {"discord": {"require_mention": True}}

    def _unauthorized(**k):
        return 401

    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)
    monkeypatch.setattr(relay, "_post_policy", _unauthorized)

    outcome = relay.send_relay_policy()
    assert outcome is False
|
||||
|
||||
|
||||
def test_send_skips_when_relay_unconfigured(monkeypatch):
    """With no GATEWAY_RELAY_URL the relay is not configured, so the call is
    a no-op that returns False."""

    def _ok(**k):
        return 200

    monkeypatch.setattr(relay, "_post_policy", _ok)

    outcome = relay.send_relay_policy()
    assert outcome is False
|
||||
|
|
@ -30,6 +30,7 @@ def _clean_env(monkeypatch):
|
|||
"GATEWAY_RELAY_ROUTE_KEYS",
|
||||
"GATEWAY_RELAY_PLATFORM",
|
||||
"GATEWAY_RELAY_BOT_ID",
|
||||
"GATEWAY_RELAY_INSTANCE_ID",
|
||||
):
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
# Never read config.yaml off disk in these tests.
|
||||
|
|
@ -83,6 +84,24 @@ def test_relay_route_keys_empty():
|
|||
assert relay.relay_route_keys() == []
|
||||
|
||||
|
||||
def test_relay_instance_id_from_env(monkeypatch):
    """The env value is returned with surrounding whitespace stripped."""
    padded_value = "  inst-abc  "
    monkeypatch.setenv("GATEWAY_RELAY_INSTANCE_ID", padded_value)

    instance_id = relay.relay_instance_id()
    assert instance_id == "inst-abc"
|
||||
|
||||
|
||||
def test_relay_instance_id_absent_is_none():
    """With nothing set, relay_instance_id() returns None."""
    instance_id = relay.relay_instance_id()
    assert instance_id is None
|
||||
|
||||
|
||||
def test_relay_instance_id_from_config(monkeypatch):
    """The id is read from ``gateway.relay_instance_id`` in the gateway config."""

    def _gateway_config():
        return {"gateway": {"relay_instance_id": "inst-from-config"}}

    monkeypatch.setattr("gateway.run._load_gateway_config", _gateway_config, raising=False)

    instance_id = relay.relay_instance_id()
    assert instance_id == "inst-from-config"
|
||||
|
||||
|
||||
def test_provision_url_maps_ws_to_http():
|
||||
assert relay._provision_url("wss://c.example/relay") == "https://c.example/relay/provision"
|
||||
assert relay._provision_url("ws://c.example/relay") == "http://c.example/relay/provision"
|
||||
|
|
@ -161,6 +180,81 @@ def test_outbound_only_when_no_endpoint(monkeypatch):
|
|||
assert relay.relay_connection_auth()[1] == "a" * 64
|
||||
|
||||
|
||||
# ─────────────────── instance-id forwarding (Phase 6 Unit α) ───────────────────
|
||||
|
||||
def test_forwards_instance_id_to_provision(monkeypatch):
    """An agent stamped with GATEWAY_RELAY_INSTANCE_ID forwards that id to the
    connector, which can then bind gatewayId -> instanceId (per-instance
    routing)."""
    _arm(monkeypatch)
    monkeypatch.setenv("GATEWAY_RELAY_INSTANCE_ID", "inst-abc")
    captured: dict = {}
    monkeypatch.setattr(relay, "_post_provision", _stub_post(captured))

    provisioned = relay.self_provision_relay()

    assert provisioned is True
    assert captured["instance_id"] == "inst-abc"
|
||||
|
||||
|
||||
def test_instance_id_absent_forwards_none(monkeypatch):
    """With no stamp (self-hosted / pre-Phase-6) instance_id is None; the
    connector stores null and per-instance routing has no binding yet."""
    _arm(monkeypatch)
    captured: dict = {}
    monkeypatch.setattr(relay, "_post_provision", _stub_post(captured))

    provisioned = relay.self_provision_relay()

    assert provisioned is True
    assert captured["instance_id"] is None
|
||||
|
||||
|
||||
def test_post_provision_body_includes_instanceId_only_when_set(monkeypatch):
    """The real _post_provision puts `instanceId` in the JSON body ONLY when a
    value is supplied. Leaving it out lets the connector store null
    (back-compat) instead of binding an empty string."""
    import json

    sent: dict = {}
    canned_reply = {"secret": "a" * 64, "deliveryKey": "b" * 64, "tenant": "t", "gatewayId": "gw-1"}

    class _Resp:
        """Minimal stand-in for the urlopen response context manager."""

        def __enter__(self):
            return self

        def __exit__(self, *a):
            return False

        def read(self):
            return json.dumps(canned_reply).encode()

    def _fake_urlopen(req, timeout=None):  # noqa: ANN001
        # Record the decoded JSON body of the most recent request.
        sent["body"] = json.loads(req.data.decode())
        return _Resp()

    def _common_kwargs():
        # Fresh dict (and fresh route_keys list) per call.
        return {
            "provision_url": "https://c.example/relay/provision",
            "access_token": "tok",
            "gateway_id": "gw-1",
            "platform": "discord",
            "bot_id": "app",
            "gateway_endpoint": None,
            "route_keys": [],
        }

    monkeypatch.setattr("urllib.request.urlopen", _fake_urlopen)

    # With an instance id the key is present in the body.
    relay._post_provision(**_common_kwargs(), instance_id="inst-abc")
    assert sent["body"]["instanceId"] == "inst-abc"

    # Without one the key is absent entirely (not "").
    relay._post_provision(**_common_kwargs())
    assert "instanceId" not in sent["body"]
|
||||
|
||||
|
||||
# ─────────────────────────── fail-soft ───────────────────────────
|
||||
|
||||
def test_no_nas_token_is_non_fatal(monkeypatch):
|
||||
|
|
|
|||
128
tests/gateway/test_approval_prompt_redaction.py
Normal file
128
tests/gateway/test_approval_prompt_redaction.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
"""Regression test for approval prompt credential redaction (issue #48456).
|
||||
|
||||
When Tirith flags a command for containing a credential-shaped pattern, the
|
||||
gateway approval prompt must redact the credential from the command text
|
||||
before sending it to the chat platform. Without this fix, the raw command
|
||||
(with the credential in plaintext) is sent verbatim to Telegram/Discord/etc.,
|
||||
undoing Tirith's redaction one layer up.
|
||||
|
||||
The redaction is wired through the module-level ``_redact_approval_command``
|
||||
seam. These tests bind that seam -- the production wiring -- not just the
|
||||
underlying ``redact_sensitive_text`` helper, so they fail if the redaction
|
||||
call is removed from either approval path.
|
||||
|
||||
Credential fixtures are built at runtime from a benign prefix + a run of
|
||||
``X`` characters (the same trick tests/agent/test_redact.py uses): they match
|
||||
the redactor regexes so the assertions stay meaningful, but contain no real
|
||||
or real-looking key, so secret scanners do not flag this file.
|
||||
"""
|
||||
|
||||
from gateway.run import _redact_approval_command
|
||||
|
||||
# Synthetic, scanner-safe credential fixtures. Each matches its redactor
|
||||
# regex (ghp_/sk-/JWT) but is unmistakably fake -- a run of X's, never a
|
||||
# real or real-format key.
|
||||
_FAKE_GHP = "ghp_" + "X" * 36
|
||||
_FAKE_OPENAI = "sk-proj-" + "X" * 40
|
||||
_FAKE_JWT = "eyJ" + "X" * 20 + "." + "eyJ" + "X" * 24 + "." + "X" * 30
|
||||
|
||||
|
||||
class TestRedactApprovalCommand:
    """Behavioural contract of the gateway's approval-prompt redaction seam."""

    def test_redacts_github_pat(self):
        """A GitHub-PAT-shaped token is removed; the rest of the command survives."""
        command = f"curl -H 'Authorization: token {_FAKE_GHP}' https://api.github.com/user"
        redacted = _redact_approval_command(command)
        assert _FAKE_GHP not in redacted
        # The operator still needs to see what the command does.
        assert "curl" in redacted
        assert "github.com" in redacted

    def test_redacts_openai_key(self):
        """An OpenAI-style key in an ``export`` is removed, the tail is kept."""
        command = f"export OPENAI_API_KEY={_FAKE_OPENAI} && python s.py"
        redacted = _redact_approval_command(command)
        assert _FAKE_OPENAI not in redacted
        assert "python s.py" in redacted

    def test_redacts_bearer_token(self):
        """A JWT-shaped bearer token does not survive redaction."""
        command = f"curl -H 'Authorization: Bearer {_FAKE_JWT}' https://api.example.com"
        redacted = _redact_approval_command(command)
        assert _FAKE_JWT not in redacted

    def test_clean_command_passes_through_unchanged(self):
        """A command without credential-shaped text comes back verbatim."""
        command = "ls -la /tmp && echo hello"
        assert _redact_approval_command(command) == command

    def test_forces_redaction_even_when_disabled(self, monkeypatch):
        """Redaction must still happen when the global redaction flag is off.

        Per the original note, the seam is expected to force redaction
        (force=True) because the approval prompt is a hard secret-egress
        boundary regardless of the ``security.redact_secrets`` config.
        """
        command = f"curl -H 'Authorization: token {_FAKE_GHP}' https://api.github.com"
        # Switch the module-level flag off; the seam must redact anyway.
        monkeypatch.setattr("agent.redact._REDACT_ENABLED", False, raising=False)
        redacted = _redact_approval_command(command)
        assert _FAKE_GHP not in redacted

    def test_handles_none_and_empty(self):
        """Both the empty string and ``None`` map to the empty string."""
        assert _redact_approval_command("") == ""
        assert _redact_approval_command(None) == ""
|
||||
|
||||
|
||||
class TestApprovalCommandWiring:
|
||||
"""Guard the production wiring on BOTH approval-notify transports:
|
||||
1. the chat-platform path (_approval_notify_sync in gateway/run.py), and
|
||||
2. the SSE/API path (_approval_notify in gateway/platforms/api_server.py),
|
||||
each of which must route the command through _redact_approval_command and
|
||||
REASSIGN the redacted value before any send/enqueue (so the raw command
|
||||
cannot reach a client). Uses AST (not char-offset string slicing) so a
|
||||
benign refactor doesn't cause a false failure, and so a discarded-result
|
||||
call (`_redact(cmd); send(cmd)`) does NOT pass."""
|
||||
|
||||
def _assert_redacts_then_uses(self, module, func_name: str, sink_substr: str):
|
||||
"""Parse `module`'s full AST, locate the (possibly nested) function
|
||||
`func_name`, and assert it contains an assignment
|
||||
`<x> = _redact_approval_command(...)` whose result is then used by a
|
||||
statement matching `sink_substr` on a LATER line. Walking the real AST
|
||||
(not a source slice) is refactor-robust and rejects discarded-result
|
||||
calls (the call must be an assignment, not a bare expression)."""
|
||||
import ast
|
||||
import inspect
|
||||
|
||||
source = inspect.getsource(module)
|
||||
tree = ast.parse(source)
|
||||
target_fn = None
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func_name:
|
||||
target_fn = node
|
||||
break
|
||||
assert target_fn is not None, f"function {func_name} not found in {module.__name__}"
|
||||
|
||||
redact_line = None
|
||||
for node in ast.walk(target_fn):
|
||||
if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
|
||||
fn = node.value.func
|
||||
if isinstance(fn, ast.Name) and fn.id == "_redact_approval_command":
|
||||
redact_line = node.lineno
|
||||
assert redact_line is not None, (
|
||||
f"{func_name} must assign the result of _redact_approval_command(...) "
|
||||
"(a discarded-result call would still leak the raw command)"
|
||||
)
|
||||
|
||||
sink_line = None
|
||||
for node in ast.walk(target_fn):
|
||||
seg = ast.get_source_segment(source, node)
|
||||
if seg and sink_substr in seg and getattr(node, "lineno", 0) > redact_line:
|
||||
sink_line = node.lineno
|
||||
break
|
||||
assert sink_line is not None, (
|
||||
f"`{sink_substr}` sink not found after the redaction in {func_name}"
|
||||
)
|
||||
|
||||
def test_chat_platform_path_redacts_before_send(self):
|
||||
import gateway.run as run
|
||||
|
||||
self._assert_redacts_then_uses(run, "_approval_notify_sync", "send_exec_approval")
|
||||
|
||||
def test_sse_api_path_redacts_before_enqueue(self):
|
||||
from gateway.platforms import api_server
|
||||
|
||||
self._assert_redacts_then_uses(api_server, "_approval_notify", "put_nowait")
|
||||
|
|
@ -281,3 +281,143 @@ async def test_platform_send_failure_raises_for_delivery_result(tmp_path, monkey
|
|||
|
||||
with pytest.raises(RuntimeError, match="route failed"):
|
||||
await router._deliver_to_platform(target, "hello", metadata={"telegram_reply_to_message_id": "9001"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cron output truncation / adapter-aware chunking (issue #50126)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class ChunkingAdapter:
    """Recording test double that advertises ``splits_long_messages = True``.

    The original note likens this to the Discord/Telegram adapters.
    """

    splits_long_messages = True

    def __init__(self):
        # One dict per send() call, in call order.
        self.calls = []

    async def send(self, chat_id, content, metadata=None):
        """Record the call and report success."""
        record = dict(chat_id=chat_id, content=content, metadata=metadata)
        self.calls.append(record)
        return {"success": True}
|
||||
|
||||
|
||||
class NonChunkingAdapter:
    """Recording test double that deliberately defines no ``splits_long_messages``.

    Per the original note, the missing attribute stands for the legacy
    default (False).
    """

    def __init__(self):
        # One dict per send() call, in call order.
        self.calls = []

    async def send(self, chat_id, content, metadata=None):
        """Record the call and report success."""
        record = dict(chat_id=chat_id, content=content, metadata=metadata)
        self.calls.append(record)
        return {"success": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_long_output_truncated_for_non_chunking_adapter(tmp_path, monkeypatch):
    """A non-chunking adapter gets shortened content plus a footer; the full text is saved."""
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)
    sink = NonChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: sink})
    destination = DeliveryTarget.parse("discord:123")
    payload = "x" * 5000

    await router._deliver_to_platform(destination, payload, metadata={"job_id": "job1"})

    delivered = sink.calls[0]["content"]
    assert len(delivered) < 5000  # was truncated
    assert "truncated" in delivered.lower()
    assert "full output saved to" in delivered

    # The untruncated output must exist on disk, in exactly one file.
    archived = list(tmp_path.glob("cron/output/job1_*.txt"))
    assert len(archived) == 1
    assert archived[0].read_text() == payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_long_output_preserved_for_chunking_adapter(tmp_path, monkeypatch):
    """An adapter with ``splits_long_messages=True`` is handed the complete content."""
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)
    sink = ChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: sink})
    destination = DeliveryTarget.parse("discord:123")
    payload = "x" * 5000

    await router._deliver_to_platform(destination, payload, metadata={"job_id": "job2"})

    delivered = sink.calls[0]["content"]
    # No truncation: splitting is left to the adapter.
    assert delivered == payload
    assert "truncated" not in delivered.lower()

    # A copy is still written to disk as an audit trail.
    archived = list(tmp_path.glob("cron/output/job2_*.txt"))
    assert len(archived) == 1
    assert archived[0].read_text() == payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_short_output_never_truncated(tmp_path, monkeypatch):
    """Content below the limit is delivered verbatim and nothing is written to disk."""
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)
    sink = NonChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: sink})
    destination = DeliveryTarget.parse("discord:123")
    payload = "x" * 100

    await router._deliver_to_platform(destination, payload, metadata={"job_id": "job3"})

    assert sink.calls[0]["content"] == payload
    # No output file may have been created for a short message.
    assert next(tmp_path.glob("cron/output/*.txt"), None) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_audit_save_failure_does_not_break_chunking_delivery(tmp_path, monkeypatch):
    """A failing audit save must not stop delivery to a chunking adapter.

    The save is treated as best-effort: even when it raises (disk full,
    permissions), the adapter still receives the complete content.
    """
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)

    sink = ChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: sink})
    destination = DeliveryTarget.parse("discord:123")
    payload = "x" * 5000

    save_attempts = []

    def failing_save(content, job_id):
        save_attempts.append(job_id)
        raise OSError("No space left on device")

    monkeypatch.setattr(router, "_save_full_output", failing_save)

    # Must complete without raising: the audit failure is expected to be swallowed.
    await router._deliver_to_platform(destination, payload, metadata={"job_id": "job6"})

    # The adapter still received everything ...
    assert sink.calls[0]["content"] == payload
    # ... and the save was tried exactly once.
    assert len(save_attempts) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_save_failure_during_truncation_raises_for_non_chunking_adapter(tmp_path, monkeypatch):
    """A failing save must surface as an error when the adapter cannot chunk.

    The truncation footer needs the path of the saved file, so for a
    non-chunking adapter a save failure is a real delivery problem and the
    OSError is expected to propagate instead of being swallowed.
    """
    monkeypatch.setattr("gateway.delivery.get_hermes_home", lambda: tmp_path)

    sink = NonChunkingAdapter()
    router = DeliveryRouter(GatewayConfig(), adapters={Platform.DISCORD: sink})
    destination = DeliveryTarget.parse("discord:123")
    payload = "x" * 5000

    def failing_save(content, job_id):
        raise OSError("No space left on device")

    monkeypatch.setattr(router, "_save_full_output", failing_save)

    # Per the original author's note: the first (best-effort) save attempt is
    # swallowed, but the retry made while building the footer re-raises.
    with pytest.raises(OSError, match="No space left on device"):
        await router._deliver_to_platform(destination, payload, metadata={"job_id": "job7"})
|
||||
|
||||
|
||||
|
|
|
|||
516
tests/gateway/test_discord_double_dispatch.py
Normal file
516
tests/gateway/test_discord_double_dispatch.py
Normal file
|
|
@ -0,0 +1,516 @@
|
|||
"""Tests for Discord double-dispatch prevention (#51057).
|
||||
|
||||
When _auto_create_thread() creates a thread from a user message via
|
||||
message.create_thread(), Discord fires a second MESSAGE_CREATE event for
|
||||
the "thread starter message". That starter message carries
|
||||
``message.id == thread.id`` and may arrive with ``type=default``
|
||||
(instead of ``type=21 / thread_starter_message``), so the type filter
|
||||
does NOT catch it — resulting in two agent runs and two responses.
|
||||
|
||||
Fix: after _auto_create_thread succeeds, pre-seed the dedup cache with
|
||||
``str(thread.id)`` so the duplicate starter-message event is dropped.
|
||||
|
||||
Two sub-scenarios are tested:
|
||||
1. Thread-starter as a duplicate MESSAGE_CREATE (the primary bug).
|
||||
2. When text_batch_delay=0 the dispatch path is direct (no batching).
|
||||
The same dedup pre-seed must still protect against the duplicate.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord mock setup
|
||||
# The tests/gateway/conftest.py already installs a comprehensive discord
|
||||
# mock at collection time. We import the adapter AFTER that is done.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import plugins.platforms.discord.adapter as discord_platform # noqa: E402
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake channel/thread helpers
|
||||
#
|
||||
# IMPORTANT: _TextChannel must NOT be the same class as discord.DMChannel
|
||||
# or discord.Thread (those are set up by conftest). We give it a neutral name
|
||||
# and do NOT monkeypatch discord.DMChannel to it.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _TextChannel:
|
||||
"""Fake Discord text channel (not a DM, not a Thread)."""
|
||||
|
||||
def __init__(self, channel_id: int = 100, name: str = "general",
|
||||
guild_name: str = "Test Server"):
|
||||
self.id = channel_id
|
||||
self.name = name
|
||||
self.guild = SimpleNamespace(name=guild_name, id=1)
|
||||
self.topic = None
|
||||
|
||||
def history(self, *, limit, before, after=None, oldest_first=None):
|
||||
async def _empty():
|
||||
return
|
||||
yield
|
||||
return _empty()
|
||||
|
||||
|
||||
class _Thread:
|
||||
"""Fake Discord thread (not a DM, not a top-level channel)."""
|
||||
|
||||
def __init__(self, thread_id: int, name: str = "thread",
|
||||
parent=None, guild_name: str = "Test Server"):
|
||||
self.id = thread_id
|
||||
self.name = name
|
||||
self.parent = parent
|
||||
self.parent_id = getattr(parent, "id", None)
|
||||
self.guild = getattr(parent, "guild", None) or SimpleNamespace(
|
||||
name=guild_name, id=1
|
||||
)
|
||||
self.topic = None
|
||||
|
||||
def history(self, *, limit, before, after=None, oldest_first=None):
|
||||
async def _empty():
|
||||
return
|
||||
yield
|
||||
return _empty()
|
||||
|
||||
|
||||
def _make_message(
|
||||
*,
|
||||
msg_id: int = 42,
|
||||
channel,
|
||||
content: str = "hello",
|
||||
mentions=None,
|
||||
author=None,
|
||||
msg_type=None,
|
||||
attachments=None,
|
||||
reference=None,
|
||||
message_snapshots=None,
|
||||
):
|
||||
if author is None:
|
||||
author = SimpleNamespace(id=7, display_name="Alice", name="Alice", bot=False)
|
||||
return SimpleNamespace(
|
||||
id=msg_id,
|
||||
content=content,
|
||||
mentions=list(mentions or []),
|
||||
attachments=list(attachments or []),
|
||||
reference=reference,
|
||||
message_snapshots=message_snapshots,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
channel=channel,
|
||||
author=author,
|
||||
type=(
|
||||
msg_type
|
||||
if msg_type is not None
|
||||
else discord_platform.discord.MessageType.default
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
def adapter(monkeypatch):
    """Provide a ``DiscordAdapter`` with a stub client and a mocked ``handle_message``."""
    # Drop every Discord-related env var these tests care about so the host
    # environment cannot influence the outcome.
    hermetic_env_vars = (
        "DISCORD_REQUIRE_MENTION",
        "DISCORD_AUTO_THREAD",
        "DISCORD_NO_THREAD_CHANNELS",
        "DISCORD_FREE_RESPONSE_CHANNELS",
        "DISCORD_ALLOWED_CHANNELS",
        "DISCORD_IGNORED_CHANNELS",
        "DISCORD_HISTORY_BACKFILL",
        "DISCORD_ALLOW_BOTS",
        "DISCORD_IGNORE_NO_MENTION",
    )
    for env_name in hermetic_env_vars:
        monkeypatch.delenv(env_name, raising=False)

    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
    # Minimal client: only the bot's own user identity is provided.
    discord_adapter._client = SimpleNamespace(user=SimpleNamespace(id=999, bot=True))
    # Zero delay disables batching so dispatch is synchronous (original note).
    discord_adapter._text_batch_delay_seconds = 0
    discord_adapter.handle_message = AsyncMock()
    return discord_adapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 1 — thread-starter message duplicate via on_message (the main bug)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestThreadStarterDedup:
    """Seeding the dedup cache with ``thread.id`` prevents a second dispatch.

    Covers the thread-starter message arriving as a duplicate MESSAGE_CREATE
    event after a thread was auto-created.
    """

    @staticmethod
    def _configure(monkeypatch, *, auto_thread: str) -> None:
        """Disable mention-gating and set DISCORD_AUTO_THREAD to ``auto_thread``."""
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
        monkeypatch.setenv("DISCORD_AUTO_THREAD", auto_thread)

    @staticmethod
    def _stub_thread_creation(adapter, monkeypatch, result) -> None:
        """Swap ``adapter._auto_create_thread`` for a coroutine returning ``result``."""

        async def _fake_auto_create_thread(message):
            return result

        monkeypatch.setattr(adapter, "_auto_create_thread", _fake_auto_create_thread)

    @pytest.mark.asyncio
    async def test_thread_starter_duplicate_dropped(self, adapter, monkeypatch):
        """The starter-message replay (id == thread.id) is reported as a duplicate.

        Reproduces the Discord quirk described by the original author: after a
        thread is created, MESSAGE_CREATE fires again with
        ``message.id == thread.id``. The adapter's on_message guard is said to
        call ``_dedup.is_duplicate(str(message.id))`` before dispatching, so
        with the fix the replay is dropped rather than triggering a second
        agent run.
        """
        self._configure(monkeypatch, auto_thread="true")

        parent_channel = _TextChannel(channel_id=100)
        thread_id = 55555  # thread.id == starter-message.id on Discord
        self._stub_thread_creation(
            adapter, monkeypatch, _Thread(thread_id=thread_id, parent=parent_channel)
        )

        # Step 1: the real user message -> thread creation + one dispatch.
        await adapter._handle_message(
            _make_message(msg_id=42, channel=parent_channel, content="hello bot")
        )
        assert adapter.handle_message.call_count == 1, (
            "Expected handle_message to be called exactly once for the user message"
        )

        # Step 2: emulate the guard's lookup for the replayed starter message.
        # Per the original note, _handle_message already marked thread.id as
        # seen, so asking again with the same id must answer True.
        replay_seen = adapter._dedup.is_duplicate(str(thread_id))
        assert replay_seen is True, (
            "Thread starter message (id == thread.id) should be in dedup cache "
            "after _auto_create_thread returns, so the duplicate event is dropped"
        )

        # Still exactly one dispatch overall.
        assert adapter.handle_message.call_count == 1, (
            "handle_message should only be called once — duplicate starter dropped"
        )

    @pytest.mark.asyncio
    async def test_thread_id_pre_seeded_in_dedup_cache(self, adapter, monkeypatch):
        """Handling a message that auto-creates a thread puts thread.id in ``_dedup._seen``."""
        self._configure(monkeypatch, auto_thread="true")

        parent_channel = _TextChannel(channel_id=100)
        thread_id = 55555
        self._stub_thread_creation(
            adapter, monkeypatch, _Thread(thread_id=thread_id, parent=parent_channel)
        )

        await adapter._handle_message(
            _make_message(msg_id=42, channel=parent_channel, content="hello")
        )

        # Look directly at the dedup structure's internal store.
        assert str(thread_id) in adapter._dedup._seen, (
            f"thread.id={thread_id} should be pre-seeded in _dedup._seen "
            "after _auto_create_thread returns a thread"
        )

    @pytest.mark.asyncio
    async def test_no_dedup_seed_when_thread_creation_fails(self, adapter, monkeypatch):
        """A ``None`` result from ``_auto_create_thread`` seeds nothing."""
        self._configure(monkeypatch, auto_thread="true")

        parent_channel = _TextChannel(channel_id=100)
        phantom_thread_id = 55555
        # None models a failed thread creation.
        self._stub_thread_creation(adapter, monkeypatch, None)

        await adapter._handle_message(
            _make_message(msg_id=42, channel=parent_channel, content="hello")
        )

        # The message is dispatched even though no thread exists.
        adapter.handle_message.assert_awaited_once()

        # The id the thread would have had must not be in the cache.
        assert str(phantom_thread_id) not in adapter._dedup._seen, (
            "thread.id should NOT be pre-seeded when thread creation fails"
        )

    @pytest.mark.asyncio
    async def test_no_dedup_seed_when_auto_thread_disabled(self, adapter, monkeypatch):
        """With DISCORD_AUTO_THREAD=false no thread is created and nothing is seeded."""
        self._configure(monkeypatch, auto_thread="false")

        parent_channel = _TextChannel(channel_id=100)
        creation_calls = []

        async def fake_auto_create_thread(message):
            creation_calls.append(True)
            return _Thread(thread_id=55555, parent=parent_channel)

        monkeypatch.setattr(adapter, "_auto_create_thread", fake_auto_create_thread)

        await adapter._handle_message(
            _make_message(msg_id=42, channel=parent_channel, content="hello")
        )

        # The stub must never have been invoked ...
        assert not creation_calls, "_auto_create_thread should not run when disabled"
        # ... and its would-be thread id must be absent from the cache.
        assert "55555" not in adapter._dedup._seen, (
            "thread.id should not be in dedup when auto-threading is disabled"
        )

    @pytest.mark.asyncio
    async def test_dedup_seed_with_text_batch_delay_zero(self, adapter, monkeypatch):
        """Pre-seeding also happens on the direct (text_batch_delay=0) dispatch path."""
        self._configure(monkeypatch, auto_thread="true")

        # The fixture already sets the delay to 0; make that assumption explicit.
        assert adapter._text_batch_delay_seconds == 0

        parent_channel = _TextChannel(channel_id=100)
        thread_id = 77777
        self._stub_thread_creation(
            adapter, monkeypatch, _Thread(thread_id=thread_id, parent=parent_channel)
        )

        await adapter._handle_message(
            _make_message(msg_id=42, channel=parent_channel, content="hello")
        )

        # Exactly one dispatch.
        adapter.handle_message.assert_awaited_once()

        # The thread id is in the cache despite the direct dispatch path.
        assert str(thread_id) in adapter._dedup._seen, (
            "thread.id must be pre-seeded regardless of text_batch_delay setting"
        )

    @pytest.mark.asyncio
    async def test_thread_id_different_from_message_id_both_tracked(
        self, adapter, monkeypatch
    ):
        """thread.id is tracked on its own when it differs from the user message id."""
        self._configure(monkeypatch, auto_thread="true")

        parent_channel = _TextChannel(channel_id=100)
        user_msg_id = 12345
        thread_id = 99999  # always different in practice
        self._stub_thread_creation(
            adapter, monkeypatch, _Thread(thread_id=thread_id, parent=parent_channel)
        )

        await adapter._handle_message(
            _make_message(msg_id=user_msg_id, channel=parent_channel, content="hello")
        )

        # 99999 was seeded ...
        assert str(thread_id) in adapter._dedup._seen, (
            f"thread.id={thread_id} must be pre-seeded after auto-thread creation"
        )

        # ... so a later MESSAGE_CREATE carrying that id counts as a duplicate ...
        assert adapter._dedup.is_duplicate(str(thread_id)) is True, (
            "Subsequent is_duplicate(thread.id) must return True"
        )

        # ... while a brand-new, unrelated id does not.
        assert adapter._dedup.is_duplicate("11111") is False, (
            "An unrelated new message id must not be blocked"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 2 — direct double-call to _handle_message with same message id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDirectDoubleDispatch:
    """The dedup check that precedes ``_handle_message`` stops double dispatch.

    Per the original notes, the on_message guard consults
    ``_dedup.is_duplicate`` before calling ``_handle_message``. These tests
    replay that sequence by hand and verify the adapter's ``_dedup`` marks ids
    as seen, so a repeated delivery of the same MESSAGE_CREATE would be dropped.
    """

    @pytest.mark.asyncio
    async def test_same_message_id_not_dispatched_twice_via_dedup(
        self, adapter, monkeypatch
    ):
        """Two dedup lookups for one id: first is new, second is a duplicate; one dispatch."""
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")

        incoming = _make_message(
            msg_id=42, channel=_TextChannel(channel_id=100), content="hello"
        )
        dedup_key = str(incoming.id)

        # First delivery: the guard sees a new id, then dispatches.
        first_lookup = adapter._dedup.is_duplicate(dedup_key)
        assert first_lookup is False
        await adapter._handle_message(incoming)
        assert adapter.handle_message.call_count == 1

        # Second delivery (RESUME replay): the guard sees a known id.
        second_lookup = adapter._dedup.is_duplicate(dedup_key)
        assert second_lookup is True
        # on_message would return early here, so _handle_message is NOT called again.

        assert adapter.handle_message.call_count == 1, (
            "Second delivery with same message.id must be dropped by dedup"
        )

    @pytest.mark.asyncio
    async def test_different_message_ids_both_dispatched(self, adapter, monkeypatch):
        """Messages with distinct ids are each dispatched."""
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")

        shared_channel = _TextChannel(channel_id=100)
        first = _make_message(msg_id=1, channel=shared_channel, content="first")
        second = _make_message(msg_id=2, channel=shared_channel, content="second")

        assert adapter._dedup.is_duplicate(str(first.id)) is False
        await adapter._handle_message(first)
        assert adapter._dedup.is_duplicate(str(second.id)) is False
        await adapter._handle_message(second)

        assert adapter.handle_message.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 3 — message_type=thread_starter filtered by type guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestThreadStarterTypeFilter:
    """Type-based filtering of thread-starter messages, separate from dedup.

    Discord sometimes delivers the starter message with the proper type
    (21, ``thread_starter_message``). These tests check that such a type lies
    outside the set the on_message guard is described as accepting.
    """

    def test_thread_starter_message_type_not_in_allowed_set(self):
        """``MessageType.thread_starter_message`` is not one of the accepted types."""
        message_types = sys.modules["discord"].MessageType

        # According to the original note the adapter's guard is:
        #   if message.type not in {discord.MessageType.default, discord.MessageType.reply}
        accepted_types = {message_types.default, message_types.reply}

        # Real discord.py gives thread_starter_message the value 21. In the test
        # mock (MessageType is a MagicMock, per the original note) each attribute
        # NAME resolves to its own child mock, so `thread_starter_message` is an
        # object distinct from `default` and `reply`.
        starter_type = message_types.thread_starter_message
        assert starter_type not in accepted_types, (
            "thread_starter_message type should not be in the allowed types set"
        )

    @pytest.mark.asyncio
    async def test_message_type_default_passes_type_filter(self, adapter, monkeypatch):
        """A ``MessageType.default`` message handed to ``_handle_message`` is dispatched."""
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "false")

        default_typed = _make_message(
            msg_id=42,
            channel=_TextChannel(channel_id=100),
            content="hello",
            msg_type=discord_platform.discord.MessageType.default,
        )
        await adapter._handle_message(default_typed)
        adapter.handle_message.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 4 — dedup cache integrity after thread pre-seeding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDedupCacheIntegrity:
    """Checks on the adapter's dedup cache once thread ids have been pre-seeded."""

    @pytest.mark.asyncio
    async def test_preseed_does_not_block_legitimate_new_messages(
        self, adapter, monkeypatch
    ):
        """Seeding the created thread's id leaves an unrelated message id unblocked."""
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")

        parent_channel = _TextChannel(channel_id=100)
        seeded_thread_id = 22222
        stub_thread = _Thread(thread_id=seeded_thread_id, parent=parent_channel)

        async def fake_auto_create_thread(message):
            # Every call hands back the same pre-built thread.
            return stub_thread

        monkeypatch.setattr(adapter, "_auto_create_thread", fake_auto_create_thread)

        # The opening message goes through the stubbed thread creation, which
        # is the step expected to pre-seed the dedup cache.
        opener = _make_message(msg_id=10, channel=parent_channel, content="first")
        await adapter._handle_message(opener)
        assert adapter.handle_message.call_count == 1

        # An id that has nothing to do with the thread must still be "new".
        unrelated_id = str(20)
        assert unrelated_id != str(seeded_thread_id)  # sanity check
        flagged = adapter._dedup.is_duplicate(unrelated_id)
        assert flagged is False, (
            "A new message with a different ID should not be blocked"
        )

    @pytest.mark.asyncio
    async def test_multiple_thread_creations_each_preseeded(
        self, adapter, monkeypatch
    ):
        """Every created thread gets its own id recorded in the dedup cache."""
        monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "false")
        monkeypatch.setenv("DISCORD_AUTO_THREAD", "true")

        channel = _TextChannel(channel_id=100)
        thread_ids = [33333, 44444, 55555]
        created_so_far = 0

        async def fake_auto_create_thread(message):
            # Cycle through ``thread_ids``, one per call.
            nonlocal created_so_far
            next_id = thread_ids[created_so_far % len(thread_ids)]
            created_so_far += 1
            return _Thread(thread_id=next_id, parent=channel)

        monkeypatch.setattr(adapter, "_auto_create_thread", fake_auto_create_thread)

        # One inbound message per expected thread.
        for offset in range(len(thread_ids)):
            inbound = _make_message(
                msg_id=100 + offset, channel=channel, content=f"msg {offset}"
            )
            await adapter._handle_message(inbound)

        for tid in thread_ids:
            key = str(tid)
            # Present in the raw cache ...
            assert key in adapter._dedup._seen, (
                f"thread.id={tid} should be pre-seeded in _dedup._seen "
                "after its thread was created"
            )
            # ... and reported as a duplicate through the public check.
            assert adapter._dedup.is_duplicate(key) is True, (
                f"thread.id={tid} should be treated as duplicate"
            )
|
||||
140
tests/gateway/test_discord_sync_limit.py
Normal file
140
tests/gateway/test_discord_sync_limit.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
"""Test Discord slash command sync respects the 100-command hard limit."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
    """Install ``MagicMock`` stand-ins for ``discord`` when none is registered.

    If ``sys.modules`` already holds a non-``None`` ``discord`` entry (a real
    module exposing ``__file__`` or anything else), nothing is changed.
    Otherwise ``discord``, ``discord.ext`` and ``discord.ext.commands`` are
    registered as mocks.
    """
    if sys.modules.get("discord") is not None:
        # Something is already registered under that name: leave it alone.
        return

    stub = MagicMock()
    stub.Intents.default.return_value = MagicMock()
    sys.modules["discord"] = stub
    for submodule in ("discord.ext", "discord.ext.commands"):
        sys.modules[submodule] = MagicMock()
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from plugins.platforms.discord.adapter import DiscordAdapter
|
||||
|
||||
|
||||
class _FakeTreeCommand:
    """Stand-in for a discord.py tree command: a name, a type and ``to_dict``."""

    def __init__(self, name: str, command_type: int = 1) -> None:
        # ``type`` mirrors the attribute name the payload uses.
        self.name, self.type = name, command_type

    def to_dict(self, _tree) -> dict:
        """Return the ``name``/``type`` payload; the tree argument is ignored."""
        return dict(name=self.name, type=self.type)
|
||||
|
||||
|
||||
@pytest.fixture
def adapter():
    """Return a ``DiscordAdapter`` whose client and payload helpers are mocked.

    The client gets a ``MagicMock`` tree and an ``AsyncMock`` http layer, the
    inter-mutation sleep becomes an ``AsyncMock``, and the payload helpers are
    replaced with trivial functions.
    """
    _ensure_discord_mock()

    def _name_only(cmd):
        # An existing command is reduced to its name.
        return {"name": cmd.name}

    def _passthrough(payload):
        # Canonicalising / patch-filtering leave the payload untouched.
        return payload

    discord_adapter = DiscordAdapter(PlatformConfig(enabled=True, token="fake-token"))

    # Mocked Discord client: tree, HTTP layer and application id.
    client = discord_adapter._client = MagicMock()
    client.tree = MagicMock()
    client.http = AsyncMock()
    client.application_id = "test_app_id"

    discord_adapter._sleep_between_command_sync_mutations = AsyncMock()
    discord_adapter._existing_command_to_payload = MagicMock(side_effect=_name_only)
    discord_adapter._canonicalize_app_command_payload = MagicMock(side_effect=_passthrough)
    discord_adapter._patchable_app_command_payload = MagicMock(side_effect=_passthrough)

    return discord_adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_safe_sync_deletes_before_creating(adapter):
    """Sync must delete obsolete commands BEFORE creating new ones.

    Discord's 100-command limit is enforced when trying to upsert. If we
    have 100 commands on Discord, try to add 1 new one, and haven't deleted
    any yet, Discord rejects with error 30032.

    The fix: identify and delete obsolete commands first, then create/update.
    This ensures we never temporarily exceed 100 during the sync operation.

    This is a regression guard for the samuraiheart bug where sync would fail
    with error 30032 even though the registration code properly capped at 100.

    The mocked adapter comes from this module's ``adapter`` fixture instead of
    being rebuilt inline with identical statements.
    """
    tree = adapter._client.tree
    http = adapter._client.http

    # Scenario: Discord is already at the cap, one command is obsolete and one
    # is new.
    #   Existing on Discord: cmd_0, cmd_1, ..., cmd_99      (100 total)
    #   Desired locally:     cmd_1, ..., cmd_99, cmd_new    (100 total)
    # So: delete cmd_0 (1 deletion), create cmd_new (1 creation).
    existing_commands = [
        SimpleNamespace(id=f"id_{i}", name=f"cmd_{i}", type=1)
        for i in range(100)
    ]
    tree.fetch_commands = AsyncMock(return_value=existing_commands)

    desired_commands = [
        _FakeTreeCommand(name=f"cmd_{i}", command_type=1)
        for i in range(1, 100)
    ]
    desired_commands.append(_FakeTreeCommand(name="cmd_new", command_type=1))
    tree.get_commands = MagicMock(return_value=desired_commands)

    # Records every delete/create in the order the adapter issues them.
    mutation_log = []

    async def mock_delete(*args):
        mutation_log.append(("delete", args[-1]))

    async def mock_upsert(*args):
        mutation_log.append(("create", args[-1].get("name")))

    http.delete_global_command = mock_delete
    http.upsert_global_command = mock_upsert
    http.edit_global_command = AsyncMock()

    await adapter._safe_sync_slash_commands()

    kinds = [kind for kind, _ in mutation_log]

    assert "delete" in kinds, "At least one command should be deleted"
    assert "create" in kinds, "At least one command should be created"

    # The key assertion: every deletion precedes every creation, i.e. the
    # last delete sits before the first create.
    last_delete_idx = max(i for i, kind in enumerate(kinds) if kind == "delete")
    first_create_idx = kinds.index("create")

    assert last_delete_idx < first_create_idx, (
        f"Deletions must happen before creations to avoid exceeding 100-command limit. "
        f"Last delete at index {last_delete_idx}, first create at index {first_create_idx}"
    )
|
||||
|
|
@ -510,3 +510,48 @@ class TestToolProgressGrouping:
|
|||
resolve_display_setting(config, "telegram", "tool_progress_grouping")
|
||||
== "separate"
|
||||
)
|
||||
|
||||
|
||||
class TestReasoningStyle:
    """Resolution of the ``reasoning_style`` display setting (code | blockquote | subtext)."""

    def test_discord_defaults_to_subtext(self):
        """With an empty config, Discord resolves to ``subtext``."""
        from gateway.display_config import resolve_display_setting

        resolved = resolve_display_setting({}, "discord", "reasoning_style")
        assert resolved == "subtext"

    def test_other_platforms_default_to_code(self):
        """With an empty config, the listed non-Discord platforms resolve to ``code``."""
        from gateway.display_config import resolve_display_setting

        for plat in ("telegram", "slack", "matrix", "api_server"):
            resolved = resolve_display_setting({}, plat, "reasoning_style")
            # The platform name is the failure message.
            assert resolved == "code", plat

    def test_platform_override_wins(self):
        """A per-platform entry under ``display.platforms`` is what gets returned."""
        from gateway.display_config import resolve_display_setting

        discord_overrides = {"reasoning_style": "blockquote"}
        config = {"display": {"platforms": {"discord": discord_overrides}}}
        resolved = resolve_display_setting(config, "discord", "reasoning_style")
        assert resolved == "blockquote"

    def test_global_override(self):
        """A value directly under ``display`` applies to a platform with no own entry."""
        from gateway.display_config import resolve_display_setting

        config = {"display": {"reasoning_style": "subtext"}}
        resolved = resolve_display_setting(config, "telegram", "reasoning_style")
        assert resolved == "subtext"

    def test_invalid_value_falls_back_to_code(self):
        """An unrecognised style string resolves to ``code``."""
        from gateway.display_config import resolve_display_setting

        config = {"display": {"reasoning_style": "bogus"}}
        resolved = resolve_display_setting(config, "telegram", "reasoning_style")
        assert resolved == "code"

    def test_case_insensitive(self):
        """An upper-case style string resolves to its lower-case form."""
        from gateway.display_config import resolve_display_setting

        config = {"display": {"reasoning_style": "SUBTEXT"}}
        resolved = resolve_display_setting(config, "telegram", "reasoning_style")
        assert resolved == "subtext"
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ async def test_goal_verdict_done_sent_via_adapter_send(hermes_home):
|
|||
mgr = GoalManager(session_entry.session_id)
|
||||
mgr.set("ship the feature")
|
||||
|
||||
with patch("hermes_cli.goals.judge_goal", return_value=("done", "the feature shipped", False)):
|
||||
with patch("hermes_cli.goals.judge_goal", return_value=("done", "the feature shipped", False, None)):
|
||||
await runner._post_turn_goal_continuation(
|
||||
session_entry=session_entry,
|
||||
source=src,
|
||||
|
|
@ -136,7 +136,7 @@ async def test_goal_verdict_continue_enqueues_continuation(hermes_home):
|
|||
mgr = GoalManager(session_entry.session_id)
|
||||
mgr.set("polish the docs")
|
||||
|
||||
with patch("hermes_cli.goals.judge_goal", return_value=("continue", "still needs work", False)):
|
||||
with patch("hermes_cli.goals.judge_goal", return_value=("continue", "still needs work", False, None)):
|
||||
await runner._post_turn_goal_continuation(
|
||||
session_entry=session_entry,
|
||||
source=src,
|
||||
|
|
@ -164,7 +164,7 @@ async def test_goal_verdict_budget_exhausted_sends_pause(hermes_home):
|
|||
state.turns_used = 2
|
||||
save_goal(session_entry.session_id, state)
|
||||
|
||||
with patch("hermes_cli.goals.judge_goal", return_value=("continue", "keep going", False)):
|
||||
with patch("hermes_cli.goals.judge_goal", return_value=("continue", "keep going", False, None)):
|
||||
await runner._post_turn_goal_continuation(
|
||||
session_entry=session_entry,
|
||||
source=src,
|
||||
|
|
@ -211,7 +211,7 @@ async def test_goal_verdict_survives_adapter_without_send(hermes_home):
|
|||
|
||||
runner.adapters[Platform.TELEGRAM] = _NoSendAdapter()
|
||||
|
||||
with patch("hermes_cli.goals.judge_goal", return_value=("done", "ok", False)):
|
||||
with patch("hermes_cli.goals.judge_goal", return_value=("done", "ok", False, None)):
|
||||
# must not raise
|
||||
await runner._post_turn_goal_continuation(
|
||||
session_entry=session_entry,
|
||||
|
|
|
|||
|
|
@ -967,6 +967,105 @@ class TestMediaDeliveryDefaultMode:
|
|||
|
||||
assert BasePlatformAdapter.validate_media_delivery_path(str(config_file)) is None
|
||||
|
||||
def test_denylist_blocks_google_token_default_mode(self, tmp_path, monkeypatch):
    """``google_token.json`` at the HERMES_HOME root is never deliverable.

    Integration credentials stored there must be rejected even though they
    are not among the historically enumerated .env/auth.json/config.yaml
    files. Regression for a refreshed google_token.json being auto-attached
    to a Slack reply (#50912).
    """
    self._patch_roots(monkeypatch)

    home_root = tmp_path / "home"
    hermes_home = home_root / ".hermes"
    hermes_home.mkdir(parents=True)
    credential = hermes_home / "google_token.json"
    credential.write_text('{"access_token": "***", "refresh_token": "***"}')

    # Point both HOME and the module-level Hermes roots at the fake tree.
    monkeypatch.setenv("HOME", str(home_root))
    for target in (
        "gateway.platforms.base._HERMES_HOME",
        "gateway.platforms.base._HERMES_ROOT",
    ):
        monkeypatch.setattr(target, hermes_home)

    verdict = BasePlatformAdapter.validate_media_delivery_path(str(credential))
    assert verdict is None
|
||||
|
||||
def test_denylist_blocks_google_token_even_when_freshly_refreshed(self, tmp_path, monkeypatch):
    """A just-written ``google_token.json`` is rejected despite recency trust.

    The exploit: the Google integration rewrites google_token.json every
    turn, bumping its mtime to ~now, so the strict-mode recency window
    (trust_recent_files) kept re-trusting it and it was re-sent on every
    reply. The explicit denylist entry must win over recency trust.
    """
    self._patch_roots(monkeypatch)  # zero cache allowlist, strict mode on
    recency_env = {
        "HERMES_MEDIA_TRUST_RECENT_FILES": "1",
        "HERMES_MEDIA_TRUST_RECENT_SECONDS": "600",
    }
    for env_name, env_value in recency_env.items():
        monkeypatch.setenv(env_name, env_value)

    home_root = tmp_path / "home"
    hermes_home = home_root / ".hermes"
    hermes_home.mkdir(parents=True)
    credential = hermes_home / "google_token.json"
    credential.write_text('{"access_token": "***"}')  # mtime = now → "recent"

    monkeypatch.setenv("HOME", str(home_root))
    for target in (
        "gateway.platforms.base._HERMES_HOME",
        "gateway.platforms.base._HERMES_ROOT",
    ):
        monkeypatch.setattr(target, hermes_home)

    verdict = BasePlatformAdapter.validate_media_delivery_path(str(credential))
    assert verdict is None
|
||||
|
||||
def test_denylist_blocks_pairing_directory_contents(self, tmp_path, monkeypatch):
    """A file under ``~/.hermes/pairing/`` is rejected for delivery.

    Platform pairing tokens are credential material and must not be
    deliverable.
    """
    self._patch_roots(monkeypatch)

    home_root = tmp_path / "home"
    hermes_home = home_root / ".hermes"
    pairing_dir = hermes_home / "pairing"
    pairing_dir.mkdir(parents=True)  # also creates ~/.hermes
    approval_file = pairing_dir / "telegram-approved.json"
    approval_file.write_text('{"approved": ["123"]}')

    monkeypatch.setenv("HOME", str(home_root))
    for target in (
        "gateway.platforms.base._HERMES_HOME",
        "gateway.platforms.base._HERMES_ROOT",
    ):
        monkeypatch.setattr(target, hermes_home)

    verdict = BasePlatformAdapter.validate_media_delivery_path(str(approval_file))
    assert verdict is None
|
||||
|
||||
def test_hermes_cache_still_delivers_under_denied_home(self, tmp_path, monkeypatch):
    """An artifact under the allowlisted cache root is still delivered.

    The targeted credential denylist must not break legitimate cache
    deliveries: the cache root is matched before the denylist, so the
    validator returns the artifact's resolved path.
    """
    home_root = tmp_path / "home"
    hermes_home = home_root / ".hermes"
    documents_cache = hermes_home / "cache" / "documents"
    documents_cache.mkdir(parents=True)
    report = documents_cache / "report.pdf"
    report.write_bytes(b"%PDF-1.4")

    # The cache directory is passed to the roots patch as the allowed root.
    self._patch_roots(monkeypatch, documents_cache)
    monkeypatch.setenv("HOME", str(home_root))
    for target in (
        "gateway.platforms.base._HERMES_HOME",
        "gateway.platforms.base._HERMES_ROOT",
    ):
        monkeypatch.setattr(target, hermes_home)

    verdict = BasePlatformAdapter.validate_media_delivery_path(str(report))
    assert verdict == str(report.resolve())
|
||||
|
||||
def test_denylist_blocks_non_cache_file_under_hermes_home(self, tmp_path, monkeypatch):
    """A fresh non-credential file directly under ``~/.hermes`` is delivered.

    NOTE: despite the ``denylist_blocks`` prefix in the test name, this
    asserts that the file IS deliverable (via recency trust). The tree was
    deliberately NOT blanket-denied (per #32090/#34425); this guards against
    accidentally re-introducing the rejected whole-tree deny.
    """
    self._patch_roots(monkeypatch)  # strict mode on
    recency_env = {
        "HERMES_MEDIA_TRUST_RECENT_FILES": "1",
        "HERMES_MEDIA_TRUST_RECENT_SECONDS": "600",
    }
    for env_name, env_value in recency_env.items():
        monkeypatch.setenv(env_name, env_value)

    home_root = tmp_path / "home"
    hermes_home = home_root / ".hermes"
    hermes_home.mkdir(parents=True)
    report = hermes_home / "adhoc_report.pdf"
    report.write_bytes(b"%PDF-1.4")  # fresh mtime

    monkeypatch.setenv("HOME", str(home_root))
    for target in (
        "gateway.platforms.base._HERMES_HOME",
        "gateway.platforms.base._HERMES_ROOT",
    ):
        monkeypatch.setattr(target, hermes_home)

    verdict = BasePlatformAdapter.validate_media_delivery_path(str(report))
    assert verdict == str(report.resolve())
|
||||
|
||||
def test_strict_mode_envvar_restores_legacy_behavior(self, tmp_path, monkeypatch):
|
||||
"""Setting HERMES_MEDIA_DELIVERY_STRICT=1 reactivates the older
|
||||
allowlist+recency logic. A stale file outside the allowlist is
|
||||
|
|
|
|||
|
|
@ -299,6 +299,78 @@ class TestStaleSessionLockSelfHeal:
|
|||
assert sk in adapter._active_sessions
|
||||
assert sk in adapter._session_tasks
|
||||
|
||||
@pytest.mark.asyncio
async def test_guard_mismatch_preserves_session_task_for_stale_detection(self):
    """A guard mismatch during cleanup leaves ``_session_tasks`` intact.

    This is the core of the production split-brain fix: the finally block
    only deletes ``_session_tasks[key]`` if ``_active_sessions[key]`` was
    actually released. If the guard was swapped (e.g., by a reset command),
    the task entry remains so ``_session_task_is_stale`` can detect the
    finished task and heal the lock on the next inbound message.
    """
    adapter = _make_adapter()
    key = _session_key()

    async def _noop():
        return None

    # A task that has already completed, recorded under the original guard.
    original_guard = asyncio.Event()
    finished_task = asyncio.create_task(_noop())
    await finished_task

    adapter._active_sessions[key] = original_guard
    adapter._session_tasks[key] = finished_task

    # Swap the guard, as a reset/new command would.
    replacement_guard = asyncio.Event()
    adapter._active_sessions[key] = replacement_guard

    # Run the REAL cleanup helper (not a copy of its logic) with the old
    # guard: it no longer matches the active one, so nothing is released
    # and the task entry has to stay for stale detection.
    adapter._cleanup_finished_session_task(key, original_guard)

    assert key in adapter._session_tasks, (
        "_session_tasks entry must survive guard mismatch so stale detection works"
    )
    assert adapter._session_tasks[key] is finished_task

    # The finished task behind a stale guard is detected ...
    assert adapter._session_task_is_stale(key) is True

    # ... and healing removes both bookkeeping entries.
    assert adapter._heal_stale_session_lock(key) is True
    assert key not in adapter._active_sessions
    assert key not in adapter._session_tasks
|
||||
|
||||
@pytest.mark.asyncio
async def test_cleanup_releases_and_deletes_when_guard_matches(self):
    """Positive path for #48300: a matching guard is released and the task dropped.

    On normal completion the guard still matches, so the helper must release
    it AND remove the task entry — the release-then-conditional-delete must
    not strand a healthy session.
    """
    adapter = _make_adapter()
    key = _session_key()

    async def _noop():
        return None

    guard = asyncio.Event()
    finished_task = asyncio.create_task(_noop())
    await finished_task

    adapter._active_sessions[key] = guard
    adapter._session_tasks[key] = finished_task

    # The guard was never swapped, so cleanup sees a match and releases.
    adapter._cleanup_finished_session_task(key, guard)

    assert key not in adapter._active_sessions, "guard must be released on match"
    assert key not in adapter._session_tasks, "task entry must be dropped after release"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Layer 3: Runner-side generation guard on slot promotion + release
|
||||
|
|
|
|||
|
|
@ -1754,6 +1754,193 @@ class TestIncomingDocumentHandling:
|
|||
assert "> /deploy now" in msg_event.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestIncomingAudioHandling — Slack voice messages (regression)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlackAudioExtResolution:
    """Unit coverage for the inbound-audio extension resolver.

    Regression for: Slack in-app voice messages are MP4/AAC containers
    (``audio/mp4``, filename ``audio_message*.mp4``) that the old code cached
    as ``.ogg`` (the catch-all fallback), so OpenAI STT — which sniffs the
    container from the filename extension — rejected them. WhatsApp ``.ogg``
    and uploaded ``.m4a`` worked because their extension happened to match.
    """

    @staticmethod
    def _resolve(name, mimetype):
        """Run the resolver on a file dict built from *name* and *mimetype*."""
        file_info = {"name": name, "mimetype": mimetype}
        return _slack_mod._resolve_slack_audio_ext(file_info, file_info["mimetype"])

    def test_slack_voice_message_mp4_keeps_real_extension(self):
        """The core bug: audio/mp4 voice message must NOT become .ogg."""
        resolved = self._resolve("audio_message.mp4", "audio/mp4")
        assert resolved != ".ogg", "regression: MP4 voice message mislabeled as .ogg"
        assert resolved in {".mp4", ".m4a"}
        assert resolved in _slack_mod._SLACK_STT_SUPPORTED_EXTS

    def test_whatsapp_ogg_preserved(self):
        """An ``.ogg`` file with ``audio/ogg`` stays ``.ogg``."""
        assert self._resolve("voice.ogg", "audio/ogg") == ".ogg"

    def test_m4a_upload_preserved(self):
        """An ``.m4a`` upload with ``audio/x-m4a`` stays ``.m4a``."""
        assert self._resolve("clip.m4a", "audio/x-m4a") == ".m4a"

    def test_mp3_upload_preserved(self):
        """An ``.mp3`` upload with ``audio/mpeg`` stays ``.mp3``."""
        assert self._resolve("song.mp3", "audio/mpeg") == ".mp3"

    def test_mimetype_used_when_filename_extension_missing(self):
        """No usable filename ext → fall back to the mime map, not .ogg."""
        assert self._resolve("", "audio/mp4") == ".m4a"

    def test_unknown_audio_defaults_to_m4a_not_ogg(self):
        """A truly unknown audio type defaults to the broadly-decodable .m4a."""
        resolved = self._resolve("weird", "audio/x-some-future-codec")
        assert resolved == ".m4a"
        assert resolved != ".ogg"
|
||||
|
||||
|
||||
class TestSlackVoiceClipDetection:
    """Unit coverage for the detector of voice clips mislabeled as video/mp4."""

    def test_audio_message_filename_detected(self):
        """An ``audio_message.mp4`` filename is recognised as a voice clip."""
        clip = {"name": "audio_message.mp4", "mimetype": "video/mp4"}
        assert _slack_mod._is_slack_voice_clip(clip)

    def test_slack_audio_subtype_detected(self):
        """The ``slack_audio`` subtype is recognised whatever the filename."""
        clip = {"name": "clip.mp4", "subtype": "slack_audio", "mimetype": "video/mp4"}
        assert _slack_mod._is_slack_voice_clip(clip)

    def test_real_video_not_detected(self):
        """A genuine uploaded video must NOT be hijacked into the audio path."""
        upload = {"name": "vacation.mp4", "mimetype": "video/mp4"}
        assert not _slack_mod._is_slack_voice_clip(upload)

    def test_slack_video_clip_not_detected(self):
        """slack_video clips carry a real video track — leave them as video."""
        recording = {"name": "screen_recording.mp4", "subtype": "slack_video"}
        assert not _slack_mod._is_slack_voice_clip(recording)
|
||||
|
||||
|
||||
class TestIncomingAudioHandling:
    """Routing of inbound Slack audio/video files through ``_handle_slack_message``."""

    @staticmethod
    def _voice_clip_file(mimetype):
        """Return the file dict of an ``audio_message.mp4`` voice clip with *mimetype*."""
        return {
            "mimetype": mimetype,
            "name": "audio_message.mp4",
            "subtype": "slack_audio",
            "url_private_download": "https://files.slack.com/audio_message.mp4",
            "size": 2048,
        }

    def _make_event(self, files=None, text="hello"):
        """Build a message event dict (``channel_type`` "im") with *files* and *text*."""
        event = {
            "text": text,
            "user": "U_USER",
            "channel": "D123",
            "channel_type": "im",
            "ts": "1234567890.000001",
        }
        event["files"] = files or []
        event["blocks"] = []
        event["attachments"] = []
        return event

    @pytest.mark.asyncio
    async def test_voice_message_cached_with_correct_extension(self, adapter, tmp_path):
        """audio/mp4 voice message is cached with an STT-acceptable extension,
        not the old .ogg fallback, and routed as audio."""
        download_call = {}

        async def _fake_download(url, ext, audio=False, team_id=""):
            # Record how the adapter asked for the file, then fake the cache.
            download_call["ext"] = ext
            download_call["audio"] = audio
            cached = tmp_path / f"cached{ext}"
            cached.write_bytes(b"\x00\x00\x00\x18ftypmp42fake mp4 bytes")
            return str(cached)

        event = self._make_event(files=[self._voice_clip_file("audio/mp4")])
        with patch.object(adapter, "_download_slack_file", side_effect=_fake_download):
            await adapter._handle_slack_message(event)

        assert download_call.get("audio") is True
        assert download_call["ext"] != ".ogg", "regression: voice message cached as .ogg"
        assert download_call["ext"] in {".mp4", ".m4a"}

        forwarded = adapter.handle_message.call_args[0][0]
        assert len(forwarded.media_urls) == 1
        # media_type stays audio/* so the gateway routes it to STT
        assert forwarded.media_types[0].startswith("audio/")

    @pytest.mark.asyncio
    async def test_video_mp4_voice_clip_rerouted_to_audio(self, adapter, tmp_path):
        """A voice clip mislabeled video/mp4 is rerouted to the audio path
        (cached as audio, reported as audio/*) instead of video understanding."""
        download_call = {}

        async def _fake_download(url, ext, audio=False, team_id=""):
            download_call["ext"] = ext
            download_call["audio"] = audio
            cached = tmp_path / f"cached{ext}"
            cached.write_bytes(b"\x00\x00\x00\x18ftypmp42fake mp4 bytes")
            return str(cached)

        event = self._make_event(files=[self._voice_clip_file("video/mp4")])
        with patch.object(adapter, "_download_slack_file", side_effect=_fake_download):
            await adapter._handle_slack_message(event)

        assert download_call.get("audio") is True
        assert download_call["ext"] in {".mp4", ".m4a"}
        forwarded = adapter.handle_message.call_args[0][0]
        assert len(forwarded.media_urls) == 1
        assert forwarded.media_types[0].startswith("audio/"), (
            "voice clip should route to STT, not video understanding"
        )

    @pytest.mark.asyncio
    async def test_real_video_still_routed_as_video(self, adapter, tmp_path):
        """A genuine uploaded video must remain on the video path."""

        async def _fake_download_bytes(url, team_id=""):
            return b"\x00\x00\x00\x18ftypisomfake real video"

        video_file = {
            "mimetype": "video/mp4",
            "name": "vacation.mp4",
            "url_private_download": "https://files.slack.com/vacation.mp4",
            "size": 4096,
        }
        event = self._make_event(files=[video_file])
        with patch.object(
            adapter, "_download_slack_file_bytes", side_effect=_fake_download_bytes
        ):
            await adapter._handle_slack_message(event)

        forwarded = adapter.handle_message.call_args[0][0]
        assert len(forwarded.media_urls) == 1
        assert forwarded.media_types[0].startswith("video/"), (
            "a real video must not be hijacked into the audio path"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMessageRouting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -55,7 +55,8 @@ CHANNEL_ID = "C0AQWDLHY9M"
|
|||
OTHER_CHANNEL_ID = "C9999999999"
|
||||
|
||||
|
||||
def _make_adapter(require_mention=None, strict_mention=None, free_response_channels=None, allowed_channels=None):
|
||||
def _make_adapter(require_mention=None, strict_mention=None, free_response_channels=None,
|
||||
allowed_channels=None, mention_patterns=None):
|
||||
extra = {}
|
||||
if require_mention is not None:
|
||||
extra["require_mention"] = require_mention
|
||||
|
|
@ -65,6 +66,8 @@ def _make_adapter(require_mention=None, strict_mention=None, free_response_chann
|
|||
extra["free_response_channels"] = free_response_channels
|
||||
if allowed_channels is not None:
|
||||
extra["allowed_channels"] = allowed_channels
|
||||
if mention_patterns is not None:
|
||||
extra["mention_patterns"] = mention_patterns
|
||||
|
||||
adapter = object.__new__(SlackAdapter)
|
||||
adapter.platform = Platform.SLACK
|
||||
|
|
@ -249,7 +252,10 @@ def _would_process(adapter, *, is_dm=False, channel_id=CHANNEL_ID,
|
|||
bot_uid = adapter._team_bot_user_ids.get("T1", adapter._bot_user_id)
|
||||
if mentioned:
|
||||
text = f"<@{bot_uid}> {text}"
|
||||
is_mentioned = bot_uid and f"<@{bot_uid}>" in text
|
||||
is_mentioned = bool(
|
||||
(bot_uid and f"<@{bot_uid}>" in text)
|
||||
or adapter._slack_message_matches_mention_patterns(text)
|
||||
)
|
||||
|
||||
if not is_dm and bot_uid:
|
||||
# allowed_channels check (whitelist — must pass before other gating)
|
||||
|
|
@ -687,3 +693,61 @@ def test_config_bridges_slack_allowed_channels_env_takes_precedence(monkeypatch,
|
|||
import os as _os
|
||||
# env var must not be overwritten by config.yaml
|
||||
assert _os.environ["SLACK_ALLOWED_CHANNELS"] == OTHER_CHANNEL_ID
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: mention_patterns (wake words) — parity with other adapters (#50732)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_mention_patterns_default_no_match(monkeypatch):
    """With no config value and no env var, there are no patterns and no match."""
    monkeypatch.delenv("SLACK_MENTION_PATTERNS", raising=False)
    slack = _make_adapter()

    compiled = slack._slack_mention_patterns()
    assert compiled == []

    matched = slack._slack_message_matches_mention_patterns("hello there")
    assert matched is False
|
||||
|
||||
|
||||
def test_mention_patterns_list_matches():
    """A list of patterns matches text containing one of them and nothing else."""
    slack = _make_adapter(mention_patterns=["hey hermes", "hermes,"])
    matches = slack._slack_message_matches_mention_patterns

    assert matches("hey hermes, you there?") is True
    assert matches("just chatting") is False
|
||||
|
||||
|
||||
def test_mention_patterns_case_insensitive():
    """A lower-case pattern still matches an upper-case message."""
    slack = _make_adapter(mention_patterns=["hey hermes"])
    shouted = "HEY HERMES!"
    assert slack._slack_message_matches_mention_patterns(shouted) is True
|
||||
|
||||
|
||||
def test_mention_patterns_single_string():
    """A bare string (not a list) is accepted as a single pattern."""
    slack = _make_adapter(mention_patterns="^hermes")
    matches = slack._slack_message_matches_mention_patterns

    # The pattern is anchored, so only a message starting with it matches.
    assert matches("hermes do this") is True
    assert matches("ok hermes") is False
|
||||
|
||||
|
||||
def test_mention_patterns_invalid_regex_skipped_without_crash():
    """An uncompilable pattern must not break matching of its valid siblings."""
    configured = ["(unclosed", "hey hermes"]  # first entry is not a valid regex
    slack = _make_adapter(mention_patterns=configured)

    assert slack._slack_message_matches_mention_patterns("hey hermes") is True
|
||||
|
||||
|
||||
def test_mention_patterns_env_var_fallback(monkeypatch):
    """With no config value, a JSON list in SLACK_MENTION_PATTERNS is used."""
    monkeypatch.setenv("SLACK_MENTION_PATTERNS", '["hey hermes", "hermes,"]')
    # No mention_patterns argument, so the adapter has to consult the env var.
    slack = _make_adapter()

    assert slack._slack_message_matches_mention_patterns("hey hermes") is True
|
||||
|
||||
|
||||
def test_mention_patterns_env_var_csv_fallback_splits_patterns(monkeypatch):
    """A comma-separated env value is split; the trailing empty item is dropped."""
    monkeypatch.setenv("SLACK_MENTION_PATTERNS", "hey hermes,hermes,")
    # No mention_patterns argument, so the adapter has to consult the env var.
    slack = _make_adapter()

    compiled = slack._slack_mention_patterns()
    raw_patterns = [regex.pattern for regex in compiled]

    assert raw_patterns == ["hey hermes", "hermes"]
    assert slack._slack_message_matches_mention_patterns("hey hermes") is True
|
||||
|
||||
|
||||
def test_mention_patterns_trigger_in_channel_without_literal_mention():
    """A wake word triggers the bot in a channel even with require_mention on."""
    slack = _make_adapter(require_mention=True, mention_patterns=["hey hermes"])

    wake_word_message = "hey hermes what's the status"
    assert _would_process(slack, text=wake_word_message) is True

    # Channel chatter without the wake word is still ignored.
    unrelated_message = "lunch anyone?"
    assert _would_process(slack, text=unrelated_message) is False
|
||||
|
|
|
|||
177
tests/gateway/test_telegram_closewait_limits_31599.py
Normal file
177
tests/gateway/test_telegram_closewait_limits_31599.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""Regression test for #31599 — Telegram general-pool CLOSE_WAIT fd leak.
|
||||
|
||||
Background
|
||||
----------
|
||||
PTB's ``telegram.request.HTTPXRequest`` builds the underlying
|
||||
``httpx.AsyncClient`` with ``limits = httpx.Limits(max_connections=...)``
|
||||
and *no* keepalive tuning, so httpx's default ``keepalive_expiry=5.0``
|
||||
applies. Behind an HTTP proxy (Cloudflare Warp etc.) a peer-initiated
|
||||
FIN can sit in ``CLOSE_WAIT`` longer than that, leaking fds in the
|
||||
general request pool (``_request[1]`` — the pool that routes
|
||||
``bot.send_message`` / ``set_my_commands``), which
|
||||
``_drain_polling_connections`` never resets.
|
||||
|
||||
The fix wires the shared ``gateway.platforms._http_client_limits``
|
||||
``platform_httpx_limits()`` helper into *every* HTTPXRequest the adapter
|
||||
builds — the fallback-transport branch, the proxy branch, and the plain
|
||||
branch — so idle keepalive sockets drain aggressively.
|
||||
|
||||
Contract asserted here (mutation-survivable)
|
||||
---------------------------------------------
|
||||
Every ``HTTPXRequest`` constructed by ``TelegramAdapter.connect()`` must
|
||||
receive ``httpx_kwargs["limits"]`` that is an ``httpx.Limits`` with a
|
||||
``keepalive_expiry`` strictly below httpx's 5.0 default and a positive,
|
||||
bounded ``max_keepalive_connections``. Reverting the limits wiring (so
|
||||
HTTPXRequest falls back to PTB's default 5.0s keepalive) fails this test.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_telegram_mock():
    """Install a MagicMock stand-in for the ``telegram`` package when needed.

    Does nothing if ``sys.modules["telegram"]`` already holds an object with a
    ``__file__`` attribute. Otherwise one shared MagicMock is registered (via
    ``setdefault``, so existing entries are kept) under ``telegram`` and the
    submodule names the adapter imports.
    """
    loaded = sys.modules.get("telegram")
    if loaded is not None and hasattr(loaded, "__file__"):
        return

    stub = MagicMock()
    stub.ext.ContextTypes.DEFAULT_TYPE = type(None)
    stub.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"

    # Chat-type constants are given real string values.
    chat_type = stub.constants.ChatType
    for attr, value in (
        ("GROUP", "group"),
        ("SUPERGROUP", "supergroup"),
        ("CHANNEL", "channel"),
        ("PRIVATE", "private"),
    ):
        setattr(chat_type, attr, value)

    module_names = ("telegram", "telegram.ext", "telegram.constants", "telegram.request")
    for module_name in module_names:
        sys.modules.setdefault(module_name, stub)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from plugins.platforms.telegram import adapter as tg_adapter # noqa: E402
|
||||
from plugins.platforms.telegram.adapter import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
class _StopConnect(Exception):
    """Sentinel exception used to cut ``connect()`` short.

    Raised from the mocked application builder once the HTTPXRequest
    objects have been constructed.
    """
|
||||
|
||||
|
||||
class _RecordingHTTPXRequest:
    """Test double for PTB's HTTPXRequest that remembers how it was built."""

    # Shared, class-level log of every instance constructed so far.
    instances: list = []

    def __init__(self, *args, **kwargs):
        self.args, self.kwargs = args, kwargs
        _RecordingHTTPXRequest.instances.append(self)
|
||||
|
||||
|
||||
def _make_adapter() -> TelegramAdapter:
    """Build a TelegramAdapter from an enabled config with a dummy token."""
    config = PlatformConfig(enabled=True, token="test-token")
    return TelegramAdapter(config)
|
||||
|
||||
|
||||
def _drive_connect(monkeypatch, *, proxy_url):
    """Run connect() far enough to build the HTTPXRequests, then abort.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture; every patch applied here
            is undone automatically when the calling test finishes.
        proxy_url: Value the patched ``resolve_proxy_url`` returns.

    Returns the list of recorded _RecordingHTTPXRequest instances.
    """
    _RecordingHTTPXRequest.instances = []

    # No DoH auto-discovery → exercise the proxy / plain branches, not fallback.
    async def _no_fallback():
        return []

    monkeypatch.setattr(tg_adapter, "discover_fallback_ips", _no_fallback)
    monkeypatch.setattr(
        tg_adapter, "resolve_proxy_url", lambda *a, **k: proxy_url
    )
    # Replace the real HTTPXRequest with our recorder.
    monkeypatch.setattr(tg_adapter, "HTTPXRequest", _RecordingHTTPXRequest)

    adapter = _make_adapter()
    # Skip the cross-process token lock.
    monkeypatch.setattr(adapter, "_acquire_platform_lock", lambda *a, **k: True)
    # Ensure the adapter reports no statically-configured fallback IPs.
    monkeypatch.setattr(adapter, "_fallback_ips", lambda: [])

    # Every fluent builder step returns the same mock, so the chain
    # builder().token(...).request(...).get_updates_request(...) is harmless.
    # build() raises the sentinel so connect() stops right after the
    # HTTPXRequests are constructed (before any real network/init).
    chainable = MagicMock()
    for step in (
        "token",
        "base_url",
        "base_file_url",
        "local_mode",
        "request",
        "get_updates_request",
    ):
        getattr(chainable, step).return_value = chainable
    chainable.build.side_effect = _StopConnect

    builder_root = MagicMock()
    builder_root.builder.return_value = chainable
    monkeypatch.setattr(tg_adapter, "Application", builder_root)

    try:
        asyncio.run(adapter.connect())
    except Exception:
        # Deliberately best-effort. This covers the expected _StopConnect
        # sentinel as well as anything connect() raises after swallowing the
        # sentinel itself; either way the recorded instances are still valid,
        # and callers assert that the list is non-empty.
        pass

    return list(_RecordingHTTPXRequest.instances)
|
||||
|
||||
|
||||
def _assert_keepalive_tight(instances):
    """Assert every recorded request was built with tightened httpx limits.

    Args:
        instances: Recorded ``_RecordingHTTPXRequest`` objects; each exposes
            the constructor keyword arguments as ``.kwargs``.

    Raises:
        AssertionError: if no request was recorded, or any request lacks an
            ``httpx.Limits`` under ``httpx_kwargs["limits"]`` with
            ``keepalive_expiry`` below 5.0, ``max_keepalive_connections`` in
            1..50 and a positive ``max_connections``.
    """
    assert instances, "connect() built no HTTPXRequest — test setup is wrong"
    for inst in instances:
        # ``or {}`` also covers an explicit ``httpx_kwargs=None``, so that case
        # reaches the descriptive assertion below instead of AttributeError.
        httpx_kwargs = inst.kwargs.get("httpx_kwargs") or {}
        limits = httpx_kwargs.get("limits")
        assert isinstance(limits, httpx.Limits), (
            "HTTPXRequest must receive httpx_kwargs['limits'] = httpx.Limits "
            "wired from platform_httpx_limits() (#31599). Missing → PTB falls "
            "back to default keepalive_expiry=5.0 and leaks CLOSE_WAIT fds."
        )
        # The whole point: keepalive must be tighter than httpx's 5.0 default.
        assert limits.keepalive_expiry is not None
        assert limits.keepalive_expiry < 5.0, (
            "keepalive_expiry must be < httpx default 5.0 so idle/CLOSE_WAIT "
            "sockets drain promptly behind a proxy (#31599)."
        )
        assert limits.max_keepalive_connections is not None
        assert 1 <= limits.max_keepalive_connections <= 50
        # PTB's connection_pool_size (max_connections) must be preserved.
        assert limits.max_connections is not None and limits.max_connections > 0
|
||||
|
||||
|
||||
def test_proxy_branch_general_pool_has_tight_keepalive(monkeypatch):
    """With a proxy URL resolved, every built request carries tuned limits (#31599)."""
    recorded = _drive_connect(monkeypatch, proxy_url="http://127.0.0.1:9/")

    # The general request pool and the get_updates pool are both built here.
    assert len(recorded) >= 2
    _assert_keepalive_tight(recorded)

    # Sanity: at least one request was handed the proxy, i.e. this really is
    # the proxy branch.
    proxies_seen = (request.kwargs.get("proxy") for request in recorded)
    assert any(proxy == "http://127.0.0.1:9/" for proxy in proxies_seen)
|
||||
|
||||
|
||||
def test_plain_branch_general_pool_has_tight_keepalive(monkeypatch):
    """Without proxy or fallback IPs the plain branch also wires tuned limits."""
    recorded = _drive_connect(monkeypatch, proxy_url=None)

    assert len(recorded) >= 2
    _assert_keepalive_tight(recorded)
|
||||
|
||||
|
||||
def test_limits_keepalive_below_ptb_default_is_the_contract():
    """Pin the invariant on the shared helper itself, independent of adapter
    wiring: ``platform_httpx_limits()`` must return limits whose keepalive
    expiry is set and below httpx's 5.0 default."""
    from gateway.platforms._http_client_limits import platform_httpx_limits

    tuned = platform_httpx_limits()

    assert isinstance(tuned, httpx.Limits)
    expiry = tuned.keepalive_expiry
    assert expiry is not None and expiry < 5.0
|
||||
459
tests/gateway/test_telegram_prune_stale_topic_binding_31501.py
Normal file
459
tests/gateway/test_telegram_prune_stale_topic_binding_31501.py
Normal file
|
|
@ -0,0 +1,459 @@
|
|||
"""Regression tests for #31501 — prune stale Telegram DM topic bindings.
|
||||
|
||||
When a Telegram user deletes a DM topic in the client, the Bot API
|
||||
responds to the gateway's next send with ``Thread not found``. The
|
||||
adapter falls back to a plain send (no ``message_thread_id``), but
|
||||
prior to this fix it left the corresponding row in
|
||||
``telegram_dm_topic_bindings`` untouched.
|
||||
``gateway.run._recover_telegram_topic_thread_id`` then walked the
|
||||
user's bindings newest-first on every later inbound message and
|
||||
cheerfully redirected them back to the deleted topic — tool
|
||||
progress, approvals and replies all silently landed in the wrong
|
||||
place until the operator manually ran ``DELETE`` on ``state.db``.
|
||||
|
||||
The fix has three pieces — these tests pin all three:
|
||||
|
||||
1. ``SessionDB.delete_telegram_topic_binding`` — the targeted
|
||||
prune helper (new public API).
|
||||
2. ``TelegramAdapter._prune_stale_dm_topic_binding`` — the
|
||||
adapter glue that calls the helper from a send-fallback hot
|
||||
path without raising on cleanup failure.
|
||||
3. The two "Thread not found" call sites in the streaming send
|
||||
loop and the control-message helper now invoke (2) — we pin
|
||||
this with a source-level guard rather than spinning the full
|
||||
send pipeline.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_state import SessionDB
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionDB.delete_telegram_topic_binding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seed_binding(
    db: SessionDB,
    *,
    chat_id: str = "5595856929",
    thread_id: str = "15287",
    user_id: str = "5595856929",
    session_id: str = "sess-target",
) -> None:
    """Create a telegram session and bind it to a DM topic in *db*.

    The session key is derived from *chat_id* and *thread_id* using the
    ``agent:main:telegram:dm:<chat>:<thread>`` layout.
    """
    topic_session_key = f"agent:main:telegram:dm:{chat_id}:{thread_id}"

    db.create_session(session_id=session_id, source="telegram", user_id=user_id)
    db.bind_telegram_topic(
        chat_id=chat_id,
        thread_id=thread_id,
        user_id=user_id,
        session_key=topic_session_key,
        session_id=session_id,
    )
|
||||
|
||||
|
||||
class TestDeleteTelegramTopicBinding:
    """Pin the behaviour of ``SessionDB.delete_telegram_topic_binding``.

    Each test closes its database in a ``finally`` block so that a failing
    assertion cannot leave the SQLite handle on ``tmp_path`` open.
    """

    def test_removes_matching_row_and_returns_count(self, tmp_path):
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            _seed_binding(db, thread_id="15287")
            # Sanity check — binding present before prune.
            assert db.get_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            ) is not None

            removed = db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            )

            assert removed == 1
            assert db.get_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            ) is None
        finally:
            db.close()

    def test_does_not_touch_unrelated_bindings(self, tmp_path):
        # Critical for the fix: a chat with multiple topics must
        # only lose the one Telegram confirmed deleted, never the
        # rest.  Otherwise the user's healthy topics also vanish
        # from recovery's view.
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            _seed_binding(db, thread_id="15287", session_id="sess-stale")
            _seed_binding(db, thread_id="15418", session_id="sess-fresh")

            removed = db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            )
            assert removed == 1

            # Stale binding is gone; the fresh one survives.
            assert db.get_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            ) is None
            assert db.get_telegram_topic_binding(
                chat_id="5595856929", thread_id="15418",
            ) is not None
        finally:
            db.close()

    def test_missing_row_returns_zero_silently(self, tmp_path):
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            _seed_binding(db, thread_id="15287")

            # Different thread_id — must not raise, just report 0.
            removed = db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="99999",
            )
            assert removed == 0
            # Original binding still intact.
            assert db.get_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            ) is not None
        finally:
            db.close()

    def test_pristine_database_with_no_topic_tables_is_silent_noop(self, tmp_path):
        # Fresh profile that has never run /topic — the topic-mode
        # tables don't exist yet.  The send-fallback hot path can
        # still hit this code, so we must not crash.
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            # Confirm precondition: tables really aren't there.
            tables = {
                row[0]
                for row in db._conn.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' "
                    "AND name LIKE 'telegram_dm%'"
                ).fetchall()
            }
            assert "telegram_dm_topic_bindings" not in tables

            removed = db.delete_telegram_topic_binding(
                chat_id="any", thread_id="any",
            )
            assert removed == 0
        finally:
            db.close()

    def test_idempotent_under_repeated_calls(self, tmp_path):
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            _seed_binding(db, thread_id="15287")

            first = db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            )
            second = db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            )

            assert first == 1
            assert second == 0  # already gone, no spurious "1"
        finally:
            db.close()
|
||||
|
||||
|
||||
class TestPruneClearsTopicModeWhenLastBindingGone:
    """Proactive cleanup (#31501 follow-up): pruning the chat's final
    binding must also flip ``telegram_dm_topic_mode.enabled`` to 0 so
    recovery fully stands down — covers the user who disabled topics in
    the Telegram client without ever running ``/topic off``.

    Each test closes its database in a ``finally`` block so a failing
    assertion cannot leak the handle.
    """

    def test_clears_enabled_when_last_binding_pruned(self, tmp_path):
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.enable_telegram_topic_mode(
                chat_id="5595856929", user_id="5595856929",
            )
            _seed_binding(db, thread_id="15287")
            assert db.is_telegram_topic_mode_enabled(
                chat_id="5595856929", user_id="5595856929",
            ) is True

            removed = db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            )

            assert removed == 1
            assert db.is_telegram_topic_mode_enabled(
                chat_id="5595856929", user_id="5595856929",
            ) is False
        finally:
            db.close()

    def test_keeps_enabled_while_other_bindings_remain(self, tmp_path):
        # Deleting one of several topics must NOT disable topic mode —
        # the chat still has healthy lanes that recovery should serve.
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.enable_telegram_topic_mode(
                chat_id="5595856929", user_id="5595856929",
            )
            _seed_binding(db, thread_id="15287", session_id="sess-stale")
            _seed_binding(db, thread_id="15418", session_id="sess-fresh")

            db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            )

            assert db.is_telegram_topic_mode_enabled(
                chat_id="5595856929", user_id="5595856929",
            ) is True
        finally:
            db.close()

    def test_noop_prune_leaves_enabled_untouched(self, tmp_path):
        # A prune that matches no row must not flip the flag — there's
        # still a live binding the (wrong) thread_id didn't match.
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            db.enable_telegram_topic_mode(
                chat_id="5595856929", user_id="5595856929",
            )
            _seed_binding(db, thread_id="15287")

            removed = db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="99999",
            )

            assert removed == 0
            assert db.is_telegram_topic_mode_enabled(
                chat_id="5595856929", user_id="5595856929",
            ) is True
        finally:
            db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter glue — _prune_stale_dm_topic_binding
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _bare_adapter(db: SessionDB | None = None):
    """Return a TelegramAdapter created without running ``__init__``.

    Only the surface the prune helper touches is populated:

    * ``platform`` is set rather than ``name`` (per the original note,
      ``name`` is a property delegating to ``platform.value.title()``).
    * ``_session_store`` is a namespace exposing ``_db`` and is attached only
      when *db* is given (per the original note, the adapter reaches the
      SessionDB through ``self._session_store._db``, which GatewayRunner sets
      via ``set_session_store``).
    """
    from gateway.config import Platform
    from plugins.platforms.telegram.adapter import TelegramAdapter

    bare = object.__new__(TelegramAdapter)
    bare.platform = Platform.TELEGRAM
    if db is None:
        return bare

    bare._session_store = SimpleNamespace(_db=db)
    return bare
|
||||
|
||||
|
||||
class TestPruneStaleDmTopicBindingHelper:
    """Pin ``TelegramAdapter._prune_stale_dm_topic_binding``.

    Tests that open a real database close it in a ``finally`` block so a
    failing assertion cannot leak the handle.
    """

    def test_drops_binding_when_session_store_db_is_present(self, tmp_path):
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            _seed_binding(db, thread_id="15287")

            adapter = _bare_adapter(db)
            adapter._prune_stale_dm_topic_binding("5595856929", 15287)

            assert db.get_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            ) is None
        finally:
            db.close()

    def test_silent_when_session_store_unavailable(self):
        # No ``_session_store`` attribute — the helper must not
        # explode (the streaming send path hits this in tests
        # that bypass the gateway runner).
        adapter = _bare_adapter()
        adapter._prune_stale_dm_topic_binding("123", "456")

    def test_silent_when_db_lacks_helper(self):
        # Old SessionDB without the new method (e.g. running
        # against an older state.db schema).  Must be a no-op
        # rather than AttributeError.
        adapter = _bare_adapter()
        adapter._session_store = SimpleNamespace(
            _db=SimpleNamespace(),  # no methods at all
        )
        adapter._prune_stale_dm_topic_binding("123", "456")

    def test_swallows_db_exceptions_so_send_continues(self):
        class ExplodingDb:
            def delete_telegram_topic_binding(self, **_):
                raise RuntimeError("disk full or whatever")

        adapter = _bare_adapter()
        adapter._session_store = SimpleNamespace(_db=ExplodingDb())

        # The point of the helper is that a failed cleanup must
        # NEVER turn into a failed user-facing send.  No exception
        # should escape.
        adapter._prune_stale_dm_topic_binding("123", "456")

    def test_skips_when_chat_or_thread_missing(self, tmp_path):
        # Defensive — control-message paths sometimes call us
        # with chat_id=None when kwargs lack the key.  We must
        # not produce a spurious DELETE that matches every row
        # with a NULL chat_id.
        db = SessionDB(db_path=tmp_path / "state.db")
        try:
            _seed_binding(db, thread_id="15287")

            adapter = _bare_adapter(db)

            adapter._prune_stale_dm_topic_binding(None, "15287")
            adapter._prune_stale_dm_topic_binding("5595856929", None)

            # Still there — neither call generated a DELETE.
            assert db.get_telegram_topic_binding(
                chat_id="5595856929", thread_id="15287",
            ) is not None
        finally:
            db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Source-level wiring guards — both fallback sites must call the helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestThreadNotFoundFallbackSitesPruneBinding:
    """Source-level guards for the two ``Thread not found`` fallback sites.

    Both sites in the Telegram adapter must call
    ``_prune_stale_dm_topic_binding``; checking the source text keeps a
    later refactor from silently dropping the cleanup and re-opening #31501.
    """

    def test_streaming_send_fallback_calls_prune(self):
        from plugins.platforms.telegram import adapter as telegram_mod

        send_source = inspect.getsource(telegram_mod.TelegramAdapter.send)

        # Anchor on the log line of the second-failure branch (the one that
        # flips ``used_thread_fallback``).
        anchor = send_source.find("retrying without message_thread_id")
        assert anchor != -1, (
            "Streaming send must keep its 'thread not found' "
            "fallback log line — the prune wiring is anchored "
            "next to it."
        )

        # A 600-character window is wide enough to span the warning, the
        # prune call and the ``used_thread_fallback = True`` assignment.
        nearby = send_source[anchor:anchor + 600]
        assert "_prune_stale_dm_topic_binding" in nearby, (
            "Streaming send 'Thread not found' fallback must call "
            "_prune_stale_dm_topic_binding so the stale row in "
            "telegram_dm_topic_bindings doesn't keep redirecting "
            "future inbound messages to the deleted topic (#31501)."
        )

    def test_control_message_helper_calls_prune(self):
        from plugins.platforms.telegram import adapter as telegram_mod

        helper = telegram_mod.TelegramAdapter._send_message_with_thread_fallback
        helper_source = inspect.getsource(helper)

        # The helper has one retry path and the prune call belongs to it.
        assert "_prune_stale_dm_topic_binding" in helper_source, (
            "_send_message_with_thread_fallback must call "
            "_prune_stale_dm_topic_binding when Telegram returns "
            "BadRequest('Thread not found') for a control message "
            "(#31501)."
        )

        # Belt-and-braces: the prune must appear before the retry
        # ``send_message`` so it happens whether or not the retry succeeds.
        prune_at = helper_source.find("_prune_stale_dm_topic_binding")
        retry_at = helper_source.find("send_message(**retry_kwargs)")
        assert 0 <= prune_at < retry_at, (
            "_prune_stale_dm_topic_binding must run before the "
            "fallback send_message retry."
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end semantic — prune + recovery returns None for deleted topic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecoveryAfterPrune:
    """The whole point of the fix: once a topic is pruned, the
    GatewayRunner's ``_recover_telegram_topic_thread_id`` must no
    longer steer future inbound messages to it.
    """

    def test_recovery_no_longer_returns_pruned_topic(self, tmp_path):
        # Build the same fixture used elsewhere: two topic bindings
        # for the same user, then prune the most-recent one.
        # ``_recover_telegram_topic_thread_id`` walks bindings
        # newest-first, so without the prune it would pick the
        # one we just removed.
        from gateway.config import GatewayConfig, Platform, PlatformConfig
        from gateway.run import GatewayRunner
        from gateway.session import SessionSource, build_session_key

        def dm_source(thread_id):
            # Same user/chat every time; only the thread differs.
            return SessionSource(
                platform=Platform.TELEGRAM,
                user_id="5595856929",
                chat_id="5595856929",
                user_name="tester",
                chat_type="dm",
                thread_id=thread_id,
            )

        db = SessionDB(db_path=tmp_path / "state.db")
        # Close the database even when an assertion below fails.
        try:
            db.enable_telegram_topic_mode(
                chat_id="5595856929", user_id="5595856929",
            )

            for sid, thread in (("sess-A", "111"), ("sess-B", "222")):
                db.create_session(
                    session_id=sid, source="telegram",
                    user_id="5595856929",
                )
                db.bind_telegram_topic(
                    chat_id="5595856929",
                    thread_id=thread,
                    user_id="5595856929",
                    session_key=build_session_key(dm_source(thread)),
                    session_id=sid,
                )

            runner = object.__new__(GatewayRunner)
            runner.config = GatewayConfig(
                platforms={
                    Platform.TELEGRAM: PlatformConfig(enabled=True, token="***"),
                }
            )
            runner._session_db = db
            runner._telegram_topic_mode_enabled = lambda _src: True

            # Sanity: before the prune, recovery picks "222" (newest).
            # Recovery only fires for a lobby-shaped inbound (omitted
            # message_thread_id or General topic "1"); a non-lobby
            # unknown thread is preserved as a brand-new topic.  Use the
            # General topic id so the recovery walk actually runs.
            before = runner._recover_telegram_topic_thread_id(dm_source("1"))
            assert before == "222"

            # User deletes topic 222 in Telegram → adapter prunes.
            db.delete_telegram_topic_binding(
                chat_id="5595856929", thread_id="222",
            )

            # Now recovery falls back to topic 111 (the surviving
            # binding) instead of the dead one.  This is the exact
            # behaviour change the bug report asks for.
            after = runner._recover_telegram_topic_thread_id(dm_source("1"))
            assert after == "111"
        finally:
            db.close()
|
||||
66
tests/gateway/test_tui_approval_redaction.py
Normal file
66
tests/gateway/test_tui_approval_redaction.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Regression test for TUI approval-prompt credential redaction (#48456).
|
||||
|
||||
Follow-up to #50767, which redacted the chat-platform and SSE/API approval
|
||||
transports. The TUI JSON-RPC transport is the third egress: three
|
||||
`register_gateway_notify` callbacks in `tui_gateway/server.py` emit the raw
|
||||
`approval_data` (with an unredacted `command`) to the TUI client. They now
|
||||
route through the module-level `_emit_approval_request` helper, which redacts
|
||||
`payload["command"]` via the shared `gateway.run._redact_approval_command` seam
|
||||
before emitting.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestTuiApprovalEmitRedaction:
    """Pin that TUI approval requests are emitted through the redacting
    ``_emit_approval_request`` helper (#48456)."""

    def test_emit_approval_request_redacts_command_in_payload(self, monkeypatch):
        from tui_gateway import server as tui_server

        # Capture what the helper hands to the low-level ``_emit``.
        emitted = {}
        monkeypatch.setattr(
            tui_server, "_emit",
            lambda event, sid, payload=None: emitted.update(
                {"event": event, "sid": sid, "payload": payload}
            ),
        )
        raw = "curl -H 'Authorization: token ghp_01...6789' https://api.github.com"
        tui_server._emit_approval_request("sess-1", {"command": raw, "description": "x"})

        assert emitted["event"] == "approval.request"
        # credential removed, non-command field + command structure preserved
        assert "ghp_01...6789" not in emitted["payload"]["command"]
        assert emitted["payload"]["description"] == "x"
        assert "github.com" in emitted["payload"]["command"]

    def test_emit_approval_request_handles_missing_command(self, monkeypatch):
        from tui_gateway import server as tui_server

        emitted = {}
        monkeypatch.setattr(
            tui_server, "_emit",
            lambda event, sid, payload=None: emitted.update({"payload": payload}),
        )
        # A payload without "command" passes through unchanged.
        tui_server._emit_approval_request("s", {"description": "no command here"})
        assert emitted["payload"] == {"description": "no command here"}
        # ``None`` is emitted as an empty payload.
        tui_server._emit_approval_request("s", None)
        assert emitted["payload"] == {}

    def test_no_raw_command_emit_in_approval_registrations(self):
        """Every register_gateway_notify approval callback must route through the
        redacting `_emit_approval_request` helper — no registration may emit the
        raw payload via `_emit("approval.request", ...)` directly.  The ONLY
        allowed raw emit is inside the helper itself."""
        from tui_gateway import server as tui_server

        src = inspect.getsource(tui_server)
        raw_emits = src.count('_emit("approval.request"')
        # Only the middle fragment interpolates, so only it is an f-string.
        assert raw_emits == 1, (
            'expected exactly 1 raw _emit("approval.request") (inside the '
            f"redacting helper), found {raw_emits} — a registration may be "
            "emitting the unredacted command"
        )
        assert "_emit_approval_request(sid, data)" in src, (
            "registration lambdas must route through _emit_approval_request"
        )
|
||||
|
|
@ -113,6 +113,33 @@ def test_active_session_registry_prunes_dead_pids(tmp_path, monkeypatch):
|
|||
lease.release()
|
||||
|
||||
|
||||
def test_transfer_active_session_reanchors_existing_lease(tmp_path, monkeypatch):
|
||||
home = tmp_path / ".hermes"
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
lease, message = active_sessions.try_acquire_active_session(
|
||||
session_id="session-old",
|
||||
surface="tui",
|
||||
config={"max_concurrent_sessions": 1},
|
||||
metadata={"live_session_id": "ui-1"},
|
||||
)
|
||||
|
||||
assert message is None
|
||||
assert lease is not None
|
||||
assert active_sessions.transfer_active_session(
|
||||
lease,
|
||||
session_id="session-new",
|
||||
metadata={"live_session_id": "ui-1"},
|
||||
)
|
||||
|
||||
snapshot = active_sessions.active_session_registry_snapshot()
|
||||
assert lease.session_id == "session-new"
|
||||
assert len(snapshot) == 1
|
||||
assert snapshot[0]["session_id"] == "session-new"
|
||||
assert snapshot[0]["metadata"] == {"live_session_id": "ui-1"}
|
||||
lease.release()
|
||||
|
||||
|
||||
def test_pid_alive_uses_safe_pid_exists_without_signalling(monkeypatch):
|
||||
checked: list[int] = []
|
||||
|
||||
|
|
|
|||
|
|
@ -190,7 +190,11 @@ def _arrange_startup_fallback(monkeypatch, tmp_path, running_pids):
|
|||
|
||||
def test_gateway_cmd_script_uses_pythonw_without_replace_or_start_churn(monkeypatch):
|
||||
"""Scheduled Task wrapper should launch pythonw once and avoid replace loops."""
|
||||
monkeypatch.setattr(gateway_windows, "_derive_venv_pythonw", lambda exe: exe.replace("python.exe", "pythonw.exe"))
|
||||
monkeypatch.setattr(
|
||||
gateway_windows,
|
||||
"_resolve_detached_python",
|
||||
lambda exe: (exe.replace("python.exe", "pythonw.exe"), r"C:\\Hermes\\hermes-agent\\venv", []),
|
||||
)
|
||||
|
||||
content = gateway_windows._build_gateway_cmd_script(
|
||||
r"C:\\Hermes\\hermes-agent\\venv\\Scripts\\python.exe",
|
||||
|
|
@ -206,6 +210,41 @@ def test_gateway_cmd_script_uses_pythonw_without_replace_or_start_churn(monkeypa
|
|||
assert "exit /b 0" in content
|
||||
|
||||
|
||||
def test_gateway_cmd_script_uses_uv_safe_base_pythonw(monkeypatch, tmp_path):
|
||||
"""Scheduled Task wrapper should share the detached uv-venv workaround."""
|
||||
project = tmp_path / "project"
|
||||
scripts = project / "venv" / "Scripts"
|
||||
site_packages = project / "venv" / "Lib" / "site-packages"
|
||||
hermes_home = tmp_path / "hermes-home"
|
||||
base = tmp_path / "uv" / "python" / "cpython-3.11-windows-x86_64-none"
|
||||
scripts.mkdir(parents=True)
|
||||
site_packages.mkdir(parents=True)
|
||||
hermes_home.mkdir()
|
||||
base.mkdir(parents=True)
|
||||
|
||||
venv_python = scripts / "python.exe"
|
||||
venv_pythonw = scripts / "pythonw.exe"
|
||||
base_pythonw = base / "pythonw.exe"
|
||||
for exe in (venv_python, venv_pythonw, base_pythonw):
|
||||
exe.write_text("", encoding="utf-8")
|
||||
(project / "venv" / "pyvenv.cfg").write_text(
|
||||
f"home = {base}\nimplementation = CPython\nuv = 0.11.14\nversion_info = 3.11.15\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
content = gateway_windows._build_gateway_cmd_script(
|
||||
str(venv_python),
|
||||
str(hermes_home),
|
||||
str(hermes_home),
|
||||
"",
|
||||
)
|
||||
|
||||
assert str(base_pythonw) in content
|
||||
assert f'set "VIRTUAL_ENV={project / "venv"}"' in content
|
||||
assert str(site_packages) in content
|
||||
assert str(venv_pythonw) not in content
|
||||
|
||||
|
||||
def test_elevated_gateway_command_uses_pythonw_hidden_console(monkeypatch):
|
||||
"""UAC handoff should not leave a second elevated cmd.exe window open."""
|
||||
calls = []
|
||||
|
|
@ -239,14 +278,18 @@ def test_install_scheduled_task_recreates_instead_of_change(monkeypatch, tmp_pat
|
|||
"""Install must delete+create so stale minute-repeat task settings are not preserved."""
|
||||
calls = []
|
||||
script_path = tmp_path / "Hermes_Gateway_alice.cmd"
|
||||
xml_seen = {}
|
||||
|
||||
monkeypatch.setattr(gateway_windows, "_assert_windows", lambda: None)
|
||||
monkeypatch.setattr(gateway_windows, "_resolve_task_user", lambda: r"DOMAIN\\alice")
|
||||
|
||||
def fake_schtasks(args):
|
||||
calls.append(tuple(args))
|
||||
if args[0] == "/Delete":
|
||||
return (0, "SUCCESS", "")
|
||||
if args[0] == "/Create":
|
||||
xml_path = Path(args[args.index("/XML") + 1])
|
||||
xml_seen["text"] = xml_path.read_text(encoding="utf-16")
|
||||
return (0, "SUCCESS", "")
|
||||
raise AssertionError(f"unexpected schtasks args: {args}")
|
||||
|
||||
|
|
@ -257,8 +300,88 @@ def test_install_scheduled_task_recreates_instead_of_change(monkeypatch, tmp_pat
|
|||
assert "/Change" not in [arg for call in calls for arg in call]
|
||||
assert calls[0][:4] == ("/Delete", "/F", "/TN", "Hermes_Gateway_alice")
|
||||
assert calls[1][0] == "/Create"
|
||||
assert "/SC" in calls[1]
|
||||
assert "ONLOGON" in calls[1]
|
||||
assert "/XML" in calls[1]
|
||||
assert "/SC" not in calls[1]
|
||||
assert "<Delay>PT30S</Delay>" in xml_seen["text"]
|
||||
assert "<StartWhenAvailable>true</StartWhenAvailable>" in xml_seen["text"]
|
||||
assert "<StopOnIdleEnd>false</StopOnIdleEnd>" in xml_seen["text"]
|
||||
assert "<DisallowStartIfOnBatteries>false</DisallowStartIfOnBatteries>" in xml_seen["text"]
|
||||
assert "<StopIfGoingOnBatteries>false</StopIfGoingOnBatteries>" in xml_seen["text"]
|
||||
assert "<ExecutionTimeLimit>PT0S</ExecutionTimeLimit>" in xml_seen["text"]
|
||||
assert "<RestartOnFailure>" in xml_seen["text"]
|
||||
assert "<Count>999</Count>" in xml_seen["text"]
|
||||
# Scheduled Task launches the console-less .vbs via wscript.exe, never cmd.exe
|
||||
# (issue #45599 fix A: no console -> no logon CTRL_CLOSE_EVENT / 0xC000013A).
|
||||
assert "<Command>wscript.exe</Command>" in xml_seen["text"]
|
||||
assert "//B //Nologo" in xml_seen["text"]
|
||||
assert "Hermes_Gateway_alice.vbs" in xml_seen["text"]
|
||||
assert "cmd.exe" not in xml_seen["text"]
|
||||
|
||||
|
||||
def test_gateway_vbs_script_is_console_less(monkeypatch):
|
||||
"""The .vbs launcher must avoid cmd.exe entirely and Run pythonw hidden
|
||||
(issue #45599 fix A: no console -> no logon CTRL_CLOSE_EVENT / 0xC000013A)."""
|
||||
monkeypatch.setattr(
|
||||
gateway_windows,
|
||||
"_resolve_detached_python",
|
||||
lambda exe: (r"C:\venv\Scripts\pythonw.exe", Path(r"C:\venv"), []),
|
||||
)
|
||||
content = gateway_windows._build_gateway_vbs_script(
|
||||
r"C:\venv\Scripts\python.exe",
|
||||
r"C:\Hermes",
|
||||
r"C:\Hermes",
|
||||
"--profile work",
|
||||
)
|
||||
assert "cmd.exe" not in content.lower()
|
||||
assert 'CreateObject("WScript.Shell")' in content
|
||||
assert "pythonw.exe" in content
|
||||
assert "hermes_cli.main" in content
|
||||
assert "gateway run" in content
|
||||
assert ", 0, False" in content # hidden window, detached/async
|
||||
for var in ("HERMES_HOME", "PYTHONIOENCODING", "HERMES_GATEWAY_DETACHED", "VIRTUAL_ENV", "PYTHONPATH"):
|
||||
assert var in content
|
||||
assert "--profile" in content and "work" in content
|
||||
assert content.endswith("\r\n")
|
||||
|
||||
|
||||
def test_gateway_vbs_script_quotes_spaced_paths(monkeypatch):
|
||||
"""Spaced exe/dir paths stay correctly quoted through the VBScript literal."""
|
||||
monkeypatch.setattr(
|
||||
gateway_windows,
|
||||
"_resolve_detached_python",
|
||||
lambda exe: (r"C:\Program Files\Py\pythonw.exe", Path(r"C:\v env"), []),
|
||||
)
|
||||
content = gateway_windows._build_gateway_vbs_script(
|
||||
r"C:\Program Files\Py\python.exe",
|
||||
r"C:\work dir",
|
||||
r"C:\h home",
|
||||
"",
|
||||
)
|
||||
# list2cmdline quotes the spaced exe; _quote_vbs_string doubles those quotes.
|
||||
assert '""C:\\Program Files\\Py\\pythonw.exe""' in content
|
||||
assert 'sh.CurrentDirectory = "C:\\work dir"' in content
|
||||
|
||||
|
||||
def test_gateway_vbs_script_pythonpath_chains_runtime_value(monkeypatch):
|
||||
"""PYTHONPATH chains onto the task env's existing value, like ;%PYTHONPATH%."""
|
||||
monkeypatch.setattr(
|
||||
gateway_windows,
|
||||
"_resolve_detached_python",
|
||||
lambda exe: (r"C:\v\pythonw.exe", Path(r"C:\v"), [r"C:\v\Lib\site-packages"]),
|
||||
)
|
||||
content = gateway_windows._build_gateway_vbs_script(
|
||||
r"C:\v\python.exe", r"C:\w", r"C:\h", "",
|
||||
)
|
||||
assert 'existing_pp = env.Item("PYTHONPATH")' in content
|
||||
assert "If Len(existing_pp) > 0 Then" in content
|
||||
assert r"C:\v\Lib\site-packages" in content
|
||||
|
||||
|
||||
def test_quote_vbs_string_doubles_quotes_and_rejects_newlines():
|
||||
assert gateway_windows._quote_vbs_string("plain") == '"plain"'
|
||||
assert gateway_windows._quote_vbs_string('a"b') == '"a""b"'
|
||||
with pytest.raises(ValueError):
|
||||
gateway_windows._quote_vbs_string("line1\nline2")
|
||||
|
||||
|
||||
def test_install_scheduled_task_success_start_now_uses_direct_spawn_not_task_run(monkeypatch, tmp_path, capsys):
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for ``install_cua_driver`` upgrade semantics and architecture pre-check.
|
||||
"""Tests for ``install_cua_driver`` upgrade semantics.
|
||||
|
||||
The cua-driver upstream installer always pulls the latest release tag, so
|
||||
re-running it is the canonical upgrade path. ``install_cua_driver(upgrade=True)``
|
||||
|
|
@ -10,30 +10,34 @@ must:
|
|||
fix for the "we only pulled cua-driver once on enable" complaint).
|
||||
* Preserve original ``upgrade=False`` behaviour for the toolset-enable flow:
|
||||
skip if installed, install otherwise, warn on non-macOS.
|
||||
* Pre-check architecture compatibility before downloading to avoid raw 404
|
||||
errors on Intel macOS when the upstream release lacks x86_64 assets.
|
||||
|
||||
The pre-install arch probe that used to live alongside this function was
|
||||
deleted (see top-of-file comment in tools_config.py) — the upstream
|
||||
installer has CUA_DRIVER_RS_BAKED_VERSION baked in by CD and errors
|
||||
cleanly on missing-arch assets, and the upgrade path uses
|
||||
``cua_driver_update_check()`` (which shells `cua-driver check-update
|
||||
--json` against the already-installed binary).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
class TestInstallCuaDriverUpgrade:
|
||||
def test_upgrade_on_non_macos_is_silent_noop(self):
|
||||
def test_upgrade_on_unsupported_platform_is_silent_noop(self):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
with patch.object(tools_config, "_print_warning") as warn, \
|
||||
patch("platform.system", return_value="Linux"):
|
||||
patch("platform.system", return_value="FreeBSD"):
|
||||
assert tools_config.install_cua_driver(upgrade=True) is False
|
||||
warn.assert_not_called()
|
||||
|
||||
def test_non_upgrade_on_non_macos_warns(self):
|
||||
def test_non_upgrade_on_unsupported_platform_warns(self):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
with patch.object(tools_config, "_print_warning") as warn, \
|
||||
patch("platform.system", return_value="Linux"):
|
||||
patch("platform.system", return_value="FreeBSD"):
|
||||
assert tools_config.install_cua_driver(upgrade=False) is False
|
||||
warn.assert_called()
|
||||
|
||||
|
|
@ -44,8 +48,6 @@ class TestInstallCuaDriverUpgrade:
|
|||
patch.object(tools_config.shutil, "which",
|
||||
side_effect=lambda n: "/usr/local/bin/" + n
|
||||
if n in {"cua-driver", "curl"} else None), \
|
||||
patch.object(tools_config, "_check_cua_driver_asset_for_arch",
|
||||
return_value=True), \
|
||||
patch.object(tools_config, "_run_cua_driver_installer",
|
||||
return_value=True) as runner, \
|
||||
patch("subprocess.run"):
|
||||
|
|
@ -60,8 +62,6 @@ class TestInstallCuaDriverUpgrade:
|
|||
with patch("platform.system", return_value="Darwin"), \
|
||||
patch.object(tools_config.shutil, "which",
|
||||
side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \
|
||||
patch.object(tools_config, "_check_cua_driver_asset_for_arch",
|
||||
return_value=True), \
|
||||
patch.object(tools_config, "_run_cua_driver_installer",
|
||||
return_value=True) as runner:
|
||||
assert tools_config.install_cua_driver(upgrade=True) is True
|
||||
|
|
@ -85,128 +85,75 @@ class TestInstallCuaDriverUpgrade:
|
|||
with patch("platform.system", return_value="Darwin"), \
|
||||
patch.object(tools_config.shutil, "which",
|
||||
side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \
|
||||
patch.object(tools_config, "_check_cua_driver_asset_for_arch",
|
||||
return_value=True), \
|
||||
patch.object(tools_config, "_run_cua_driver_installer",
|
||||
return_value=True) as runner:
|
||||
assert tools_config.install_cua_driver(upgrade=False) is True
|
||||
runner.assert_called_once()
|
||||
|
||||
|
||||
class TestCheckCuaDriverAssetForArch:
|
||||
def test_arm64_always_returns_true(self):
|
||||
class TestArchProbeRemoval:
|
||||
"""Regression tests for the deletion of `_check_cua_driver_asset_for_arch`.
|
||||
|
||||
The old probe queried ``/releases/latest`` on trycua/cua and inspected
|
||||
asset names. That was wrong in two ways:
|
||||
|
||||
1. cua-driver-rs releases are marked **prerelease** on every cut, so
|
||||
``/releases/latest`` returns the Python ``cua-agent`` / ``cua-computer``
|
||||
package instead — a release with zero binary assets. The probe then
|
||||
reported "no asset for $arch" on Linux x86_64, Windows, macOS Intel,
|
||||
Linux arm64 — every non-Apple-Silicon host.
|
||||
2. Even with the right endpoint, it duplicated tag-resolution the upstream
|
||||
installer already does correctly via ``CUA_DRIVER_RS_BAKED_VERSION``
|
||||
(auto-baked by CD on every release).
|
||||
|
||||
The fix: stop probing. Trust the upstream installer for fresh installs
|
||||
(it has the baked version + correct API fallback) and the
|
||||
``cua-driver check-update --json`` MCP-binary native command for the
|
||||
upgrade path.
|
||||
"""
|
||||
|
||||
def test_probe_function_is_gone(self):
|
||||
from hermes_cli import tools_config
|
||||
assert not hasattr(tools_config, "_check_cua_driver_asset_for_arch")
|
||||
assert not hasattr(tools_config, "_latest_cua_driver_rs_release")
|
||||
|
||||
with patch("platform.machine", return_value="arm64"):
|
||||
assert tools_config._check_cua_driver_asset_for_arch() is True
|
||||
|
||||
def test_x86_64_with_asset_returns_true(self):
|
||||
def test_fresh_install_does_not_call_github_api(self):
|
||||
"""Pre-install no longer probes the GitHub API — the upstream
|
||||
``install.sh`` resolves the tag from its baked CUA_DRIVER_RS_BAKED_VERSION
|
||||
line. install.sh errors cleanly when the arch has no asset, so the
|
||||
probe was duplicate gatekeeping.
|
||||
"""
|
||||
from hermes_cli import tools_config
|
||||
|
||||
release = {
|
||||
"tag_name": "cua-driver-v0.1.6",
|
||||
"assets": [
|
||||
{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"},
|
||||
{"name": "cua-driver-0.1.6-darwin-x86_64.tar.gz"},
|
||||
],
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = json.dumps(release).encode()
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("platform.machine", return_value="x86_64"), \
|
||||
patch("urllib.request.urlopen", return_value=mock_resp):
|
||||
assert tools_config._check_cua_driver_asset_for_arch() is True
|
||||
|
||||
def test_x86_64_without_asset_returns_false(self):
|
||||
from hermes_cli import tools_config
|
||||
|
||||
release = {
|
||||
"tag_name": "cua-driver-v0.1.6",
|
||||
"assets": [
|
||||
{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"},
|
||||
{"name": "cua-driver.tar.gz"},
|
||||
],
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = json.dumps(release).encode()
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("platform.machine", return_value="x86_64"), \
|
||||
patch("urllib.request.urlopen", return_value=mock_resp), \
|
||||
patch.object(tools_config, "_print_warning") as warn, \
|
||||
patch.object(tools_config, "_print_info"):
|
||||
assert tools_config._check_cua_driver_asset_for_arch() is False
|
||||
warn.assert_called_once()
|
||||
assert "no Intel" in warn.call_args[0][0].lower() or "x86_64" in warn.call_args[0][0]
|
||||
|
||||
def test_x86_64_api_failure_returns_true(self):
|
||||
"""Network failure should fail open — let the installer handle it."""
|
||||
from hermes_cli import tools_config
|
||||
|
||||
with patch("platform.machine", return_value="x86_64"), \
|
||||
patch("urllib.request.urlopen", side_effect=Exception("timeout")):
|
||||
assert tools_config._check_cua_driver_asset_for_arch() is True
|
||||
|
||||
def test_fresh_install_x86_64_no_asset_skips_installer(self):
|
||||
"""When the latest release has no Intel asset, skip the installer."""
|
||||
from hermes_cli import tools_config
|
||||
|
||||
release = {
|
||||
"tag_name": "cua-driver-v0.1.6",
|
||||
"assets": [{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"}],
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = json.dumps(release).encode()
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("platform.system", return_value="Darwin"), \
|
||||
patch.object(tools_config.shutil, "which",
|
||||
side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \
|
||||
patch("platform.machine", return_value="x86_64"), \
|
||||
patch("urllib.request.urlopen", return_value=mock_resp), \
|
||||
patch.object(tools_config, "_print_warning"), \
|
||||
patch.object(tools_config, "_print_info"), \
|
||||
patch.object(tools_config, "_run_cua_driver_installer") as runner:
|
||||
assert tools_config.install_cua_driver(upgrade=False) is False
|
||||
runner.assert_not_called()
|
||||
patch("urllib.request.urlopen") as urlopen, \
|
||||
patch.object(tools_config, "_run_cua_driver_installer",
|
||||
return_value=True) as runner:
|
||||
assert tools_config.install_cua_driver(upgrade=False) is True
|
||||
runner.assert_called_once()
|
||||
urlopen.assert_not_called()
|
||||
|
||||
def test_upgrade_x86_64_no_asset_returns_existing_status(self):
|
||||
"""On upgrade with no Intel asset, return whether binary existed."""
|
||||
def test_upgrade_with_binary_does_not_call_github_api_directly(self):
|
||||
"""The upgrade path no longer hits GitHub from Python — it delegates
|
||||
to the upstream ``install.sh`` (which has the baked release tag and
|
||||
the proper API fallback). When cua-driver is already installed,
|
||||
``cua_driver_update_check()`` (added in a separate change) further
|
||||
short-circuits the network re-install via the binary's native
|
||||
``check-update --json`` verb.
|
||||
"""
|
||||
from hermes_cli import tools_config
|
||||
|
||||
release = {
|
||||
"tag_name": "cua-driver-v0.1.6",
|
||||
"assets": [{"name": "cua-driver-0.1.6-darwin-arm64.tar.gz"}],
|
||||
}
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.read.return_value = json.dumps(release).encode()
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# With binary installed — returns True (binary exists)
|
||||
with patch("platform.system", return_value="Darwin"), \
|
||||
patch.object(tools_config.shutil, "which",
|
||||
side_effect=lambda n: "/usr/local/bin/" + n
|
||||
if n in ("cua-driver", "curl") else None), \
|
||||
patch("platform.machine", return_value="x86_64"), \
|
||||
patch("urllib.request.urlopen", return_value=mock_resp), \
|
||||
patch.object(tools_config, "_print_warning"), \
|
||||
patch.object(tools_config, "_print_info"), \
|
||||
patch.object(tools_config, "_run_cua_driver_installer") as runner:
|
||||
patch("urllib.request.urlopen") as urlopen, \
|
||||
patch("subprocess.run"), \
|
||||
patch.object(tools_config, "_run_cua_driver_installer",
|
||||
return_value=True) as runner:
|
||||
assert tools_config.install_cua_driver(upgrade=True) is True
|
||||
runner.assert_not_called()
|
||||
|
||||
# Without binary — returns False
|
||||
with patch("platform.system", return_value="Darwin"), \
|
||||
patch.object(tools_config.shutil, "which",
|
||||
side_effect=lambda n: "/usr/bin/curl" if n == "curl" else None), \
|
||||
patch("platform.machine", return_value="x86_64"), \
|
||||
patch("urllib.request.urlopen", return_value=mock_resp), \
|
||||
patch.object(tools_config, "_print_warning"), \
|
||||
patch.object(tools_config, "_print_info"), \
|
||||
patch.object(tools_config, "_run_cua_driver_installer") as runner:
|
||||
assert tools_config.install_cua_driver(upgrade=True) is False
|
||||
runner.assert_not_called()
|
||||
runner.assert_called_once()
|
||||
# Probe deleted — no direct GitHub API call from Python.
|
||||
urlopen.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -639,6 +639,46 @@ def test_aggregator_dedup_does_not_empty_user_defined_custom_provider():
|
|||
assert or_row["total_models"] == 1
|
||||
|
||||
|
||||
def test_flat_namespace_reseller_keeps_first_party_models_overlapping_user_proxy():
|
||||
"""opencode-go / opencode-zen are flagged ``is_aggregator=True`` (their
|
||||
flat ``/v1/models`` returns bare IDs the model-switch resolver searches),
|
||||
but they are NOT routing aggregators — every model they list is a
|
||||
first-party model under the user's subscription. When a user also runs a
|
||||
custom proxy that happens to serve a same-named model, the picker dedup
|
||||
must NOT strip the reseller's own catalog. Regression for #47077, where
|
||||
opencode-go showed only 13 of 19 models because minimax-m3/m2.7/m2.5,
|
||||
glm-5/5.1, and deepseek-v4-flash were deduped against an overlapping
|
||||
custom provider.
|
||||
"""
|
||||
rows = [
|
||||
_user_provider_row("custom:my-proxy", [
|
||||
"minimax-m3", "minimax-m2.7", "glm-5", "deepseek-v4-flash",
|
||||
]),
|
||||
_aggregator_row("opencode-go", [
|
||||
"kimi-k2.6", "minimax-m3", "minimax-m2.7", "glm-5",
|
||||
"deepseek-v4-flash", "qwen3.7-max",
|
||||
]),
|
||||
_aggregator_row("openrouter", ["minimax-m3", "anthropic/claude-sonnet-4.6"]),
|
||||
]
|
||||
ctx = _empty_ctx()
|
||||
with _list_auth_returning(rows):
|
||||
payload = build_models_payload(ctx)
|
||||
|
||||
go_row = next(r for r in payload["providers"] if r["slug"] == "opencode-go")
|
||||
or_row = next(r for r in payload["providers"] if r["slug"] == "openrouter")
|
||||
|
||||
# The reseller keeps ALL of its first-party models — nothing stripped.
|
||||
assert go_row["models"] == [
|
||||
"kimi-k2.6", "minimax-m3", "minimax-m2.7", "glm-5",
|
||||
"deepseek-v4-flash", "qwen3.7-max",
|
||||
]
|
||||
assert go_row["total_models"] == 6
|
||||
|
||||
# A TRUE routing aggregator is still deduped against the user's models.
|
||||
assert "minimax-m3" not in or_row["models"]
|
||||
assert "anthropic/claude-sonnet-4.6" in or_row["models"]
|
||||
|
||||
|
||||
def test_two_custom_providers_with_overlap_both_survive():
|
||||
"""Two user-defined custom endpoints that happen to expose an
|
||||
overlapping model must each keep their full catalog. Neither is the
|
||||
|
|
|
|||
|
|
@ -179,9 +179,10 @@ def _patch_judge(monkeypatch, verdicts):
|
|||
"""Make judge_goal return a scripted sequence of verdicts."""
|
||||
seq = list(verdicts)
|
||||
|
||||
def _fake_judge(goal, response, subgoals=None):
|
||||
def _fake_judge(goal, response, subgoals=None, background_processes=None, **_kw):
|
||||
v = seq.pop(0) if seq else "done"
|
||||
return v, f"scripted:{v}", False
|
||||
# 4-tuple contract: (verdict, reason, parse_failed, wait_directive)
|
||||
return v, f"scripted:{v}", False, None
|
||||
|
||||
monkeypatch.setattr(goals, "judge_goal", _fake_judge)
|
||||
|
||||
|
|
|
|||
|
|
@ -129,6 +129,23 @@ def test_is_aggregator_leaves_unknown_provider_non_aggregator():
|
|||
assert providers_mod.is_aggregator("not-a-provider") is False
|
||||
|
||||
|
||||
def test_is_routing_aggregator_excludes_flat_namespace_resellers():
|
||||
"""opencode-go / opencode-zen stay ``is_aggregator=True`` (model-switch
|
||||
relies on it to search their flat bare-name catalog), but they are NOT
|
||||
routing aggregators — their models are first-party, so the picker dedup
|
||||
must not strip them. (#47077)"""
|
||||
# Still aggregators for model-switch flat-catalog resolution.
|
||||
assert providers_mod.is_aggregator("opencode-go") is True
|
||||
assert providers_mod.is_aggregator("opencode-zen") is True
|
||||
# But NOT routing aggregators for picker-dedup purposes.
|
||||
assert providers_mod.is_routing_aggregator("opencode-go") is False
|
||||
assert providers_mod.is_routing_aggregator("opencode-zen") is False
|
||||
# True routers and custom proxies remain routing aggregators.
|
||||
assert providers_mod.is_routing_aggregator("openrouter") is True
|
||||
assert providers_mod.is_routing_aggregator("custom:litellm") is True
|
||||
assert providers_mod.is_routing_aggregator("not-a-provider") is False
|
||||
|
||||
|
||||
def test_switch_model_accepts_explicit_named_custom_provider(monkeypatch):
|
||||
"""Shared /model switch pipeline should accept --provider for custom_providers."""
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ These tests pin each layer of the new defence:
|
|||
* ``_safe_plugin_api_relpath`` rejects absolute paths, ``..``
|
||||
traversal, and non-string / empty values.
|
||||
* ``_mount_plugin_api_routes`` re-validates at import time and
|
||||
refuses project-source plugins outright.
|
||||
refuses user/project-source plugin backend code outright.
|
||||
* End-to-end the original PoC manifest no longer triggers
|
||||
``importlib`` for ``/tmp/payload.py``.
|
||||
"""
|
||||
|
|
@ -216,7 +216,7 @@ class TestDiscoveryScrubsApiField:
|
|||
assert entry["_api_file"] is None
|
||||
assert entry["has_api"] is False
|
||||
|
||||
def test_safe_api_path_survives(self, user_plugin_factory, tmp_path):
|
||||
def test_user_safe_api_path_is_scrubbed(self, user_plugin_factory, tmp_path):
|
||||
user_plugin_factory("safe", {
|
||||
"name": "safe",
|
||||
"label": "Safe",
|
||||
|
|
@ -230,6 +230,86 @@ class TestDiscoveryScrubsApiField:
|
|||
)
|
||||
plugins = web_server._get_dashboard_plugins(force_rescan=True)
|
||||
entry = next(p for p in plugins if p["name"] == "safe")
|
||||
assert entry["_api_file"] is None
|
||||
assert entry["has_api"] is False
|
||||
|
||||
def test_project_safe_api_path_is_scrubbed(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "home"))
|
||||
(tmp_path / "home").mkdir()
|
||||
monkeypatch.setenv("HERMES_ENABLE_PROJECT_PLUGINS", "1")
|
||||
cwd = tmp_path / "project"
|
||||
cwd.mkdir()
|
||||
monkeypatch.chdir(cwd)
|
||||
dashboard = _write_plugin_manifest(
|
||||
cwd / ".hermes" / "plugins",
|
||||
"safe-project",
|
||||
{
|
||||
"name": "safe-project",
|
||||
"label": "Safe Project",
|
||||
"api": "api.py",
|
||||
"entry": "dist/index.js",
|
||||
},
|
||||
)
|
||||
(dashboard / "api.py").write_text("router = None\n")
|
||||
|
||||
plugins = web_server._get_dashboard_plugins(force_rescan=True)
|
||||
entry = next(p for p in plugins if p["name"] == "safe-project")
|
||||
assert entry["_api_file"] is None
|
||||
assert entry["has_api"] is False
|
||||
|
||||
def test_bundled_safe_api_path_survives(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "home"
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_BUNDLED_PLUGINS", str(tmp_path / "bundled"))
|
||||
dashboard = _write_plugin_manifest(
|
||||
tmp_path / "bundled",
|
||||
"safe-bundled",
|
||||
{
|
||||
"name": "safe-bundled",
|
||||
"label": "Safe Bundled",
|
||||
"api": "api.py",
|
||||
"entry": "dist/index.js",
|
||||
},
|
||||
)
|
||||
(dashboard / "api.py").write_text("router = None\n")
|
||||
|
||||
plugins = web_server._get_dashboard_plugins(force_rescan=True)
|
||||
entry = next(p for p in plugins if p["name"] == "safe-bundled")
|
||||
assert entry["_api_file"] == "api.py"
|
||||
assert entry["has_api"] is True
|
||||
|
||||
def test_user_plugin_does_not_shadow_bundled_backend(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "home"
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_BUNDLED_PLUGINS", str(tmp_path / "bundled"))
|
||||
|
||||
bundled_dashboard = _write_plugin_manifest(
|
||||
tmp_path / "bundled",
|
||||
"shadowed",
|
||||
{
|
||||
"name": "shadowed",
|
||||
"label": "Bundled Shadowed",
|
||||
"api": "api.py",
|
||||
"entry": "dist/index.js",
|
||||
},
|
||||
)
|
||||
(bundled_dashboard / "api.py").write_text("router = None\n")
|
||||
_write_plugin_manifest(
|
||||
hermes_home / "plugins",
|
||||
"shadowed",
|
||||
{
|
||||
"name": "shadowed",
|
||||
"label": "User Shadowed",
|
||||
"api": "api.py",
|
||||
"entry": "dist/index.js",
|
||||
},
|
||||
)
|
||||
|
||||
plugins = web_server._get_dashboard_plugins(force_rescan=True)
|
||||
entry = next(p for p in plugins if p["name"] == "shadowed")
|
||||
assert entry["source"] == "bundled"
|
||||
assert entry["_api_file"] == "api.py"
|
||||
assert entry["has_api"] is True
|
||||
|
||||
|
|
@ -276,6 +356,16 @@ class TestMountApiRoutesRefusesUntrusted:
|
|||
"GHSA-5qr3-c538-wm9j defence-in-depth regression"
|
||||
)
|
||||
|
||||
def test_user_source_api_is_not_imported(self, tmp_path):
|
||||
plugin = self._payload_plugin(tmp_path, source="user")
|
||||
web_server._dashboard_plugins_cache = [plugin]
|
||||
with patch("importlib.util.spec_from_file_location") as spec:
|
||||
web_server._mount_plugin_api_routes()
|
||||
assert spec.call_count == 0, (
|
||||
"user-installed plugin api file was imported — "
|
||||
"third-party dashboard plugin backend code must stay inert"
|
||||
)
|
||||
|
||||
def test_bundled_source_api_imports_normally(self, tmp_path):
|
||||
plugin = self._payload_plugin(tmp_path, source="bundled")
|
||||
web_server._dashboard_plugins_cache = [plugin]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,30 @@
|
|||
"""Tests for Slack CLI helpers."""
|
||||
|
||||
import argparse
|
||||
|
||||
from hermes_cli.slack_cli import _build_full_manifest
|
||||
from hermes_cli.subcommands.slack import build_slack_parser
|
||||
|
||||
|
||||
def _parse_slack_args(argv):
|
||||
"""Build the real `hermes slack` parser and parse argv against it."""
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="command")
|
||||
build_slack_parser(subparsers, cmd_slack=lambda _args: 0)
|
||||
return parser.parse_args(argv)
|
||||
|
||||
|
||||
class TestSlackManifestArgparse:
|
||||
"""The `--no-assistant` flag wires through argparse to `no_assistant`."""
|
||||
|
||||
def test_no_assistant_flag_defaults_false(self):
|
||||
args = _parse_slack_args(["slack", "manifest"])
|
||||
assert getattr(args, "no_assistant", False) is False
|
||||
|
||||
def test_no_assistant_flag_sets_true(self):
|
||||
args = _parse_slack_args(["slack", "manifest", "--no-assistant"])
|
||||
assert args.no_assistant is True
|
||||
|
||||
|
||||
|
||||
class TestSlackFullManifest:
|
||||
|
|
@ -28,3 +52,35 @@ class TestSlackFullManifest:
|
|||
assert "assistant:write" in manifest["oauth_config"]["scopes"]["bot"]
|
||||
bot_events = manifest["settings"]["event_subscriptions"]["bot_events"]
|
||||
assert "assistant_thread_started" in bot_events
|
||||
|
||||
def test_no_assistant_omits_assistant_pieces(self):
|
||||
manifest = _build_full_manifest(
|
||||
"Hermes", "Your Hermes agent on Slack", include_assistant=False
|
||||
)
|
||||
|
||||
# assistant_view feature is gone -> Slack renders a flat DM, not the
|
||||
# Assistant thread pane (where bare slash commands don't dispatch).
|
||||
assert "assistant_view" not in manifest["features"]
|
||||
assert "assistant:write" not in manifest["oauth_config"]["scopes"]["bot"]
|
||||
bot_events = manifest["settings"]["event_subscriptions"]["bot_events"]
|
||||
assert "assistant_thread_started" not in bot_events
|
||||
assert "assistant_thread_context_changed" not in bot_events
|
||||
|
||||
def test_no_assistant_preserves_core_surface(self):
    """Dropping assistant mode must NOT strip the regular messaging surface."""
    built = _build_full_manifest(
        "Hermes",
        "Your Hermes agent on Slack",
        include_assistant=False,
    )

    # The flat DM still requires a writable Messages tab.
    assert built["features"]["app_home"]["messages_tab_enabled"] is True
    # Slash commands and Socket Mode do not depend on assistant mode.
    assert built["features"]["slash_commands"]
    assert built["settings"]["socket_mode_enabled"] is True

    # Channel + DM scopes/events must survive so the bot keeps working everywhere.
    granted = built["oauth_config"]["scopes"]["bot"]
    required_scopes = ("commands", "channels:history", "groups:read", "im:history")
    missing_scopes = [name for name in required_scopes if name not in granted]
    assert not missing_scopes

    subscribed = built["settings"]["event_subscriptions"]["bot_events"]
    required_events = ("message.im", "message.channels", "message.groups", "app_mention")
    missing_events = [name for name in required_events if name not in subscribed]
    assert not missing_events
|
||||
|
|
|
|||
|
|
@ -93,7 +93,8 @@ def test_check_for_updates_expired_cache(tmp_path, monkeypatch):
|
|||
result = check_for_updates()
|
||||
|
||||
assert result == 5
|
||||
assert mock_run.call_count == 3 # origin probe + git fetch + git rev-list
|
||||
# origin probe + is-shallow probe + git fetch + git rev-list
|
||||
assert mock_run.call_count == 4
|
||||
|
||||
|
||||
def test_check_for_updates_official_ssh_origin_uses_https_probe(tmp_path):
|
||||
|
|
@ -128,6 +129,99 @@ def test_check_for_updates_official_ssh_origin_uses_https_probe(tmp_path):
|
|||
assert ["git", "fetch", "origin", "--quiet"] not in calls
|
||||
|
||||
|
||||
def test_check_via_local_git_shallow_clone_behind_reports_no_count(tmp_path):
    """A shallow clone that is behind reports presence-only, never a commit count.

    A ``git clone --depth 1`` checkout has a one-commit history, so counting
    ``HEAD..origin/main`` across the shallow boundary produces a huge nonsense
    number (the "12492 commits behind" banner). The shallow path must instead
    compare tip SHAs, return UPDATE_AVAILABLE_NO_COUNT, and never invoke
    ``git rev-list --count``.
    """
    import hermes_cli.banner as banner

    checkout = tmp_path / "hermes-agent"
    (checkout / ".git").mkdir(parents=True)

    issued = []
    # Exact-match commands and the stdout each one should produce.
    canned = [
        (["git", "remote", "get-url", "origin"], "https://github.com/NousResearch/hermes-agent.git\n"),
        (["git", "rev-parse", "--is-shallow-repository"], "true\n"),
        (["git", "rev-parse", "HEAD"], "local-sha\n"),
        (["git", "rev-parse", "FETCH_HEAD"], "upstream-sha\n"),
    ]

    def fake_run(cmd, **kwargs):
        issued.append(cmd)
        for argv, stdout in canned:
            if cmd == argv:
                return MagicMock(returncode=0, stdout=stdout)
        if cmd[:2] == ["git", "fetch"]:
            return MagicMock(returncode=0, stdout="")
        if cmd[:3] == ["git", "rev-list", "--count"]:
            raise AssertionError("shallow path must not count across the boundary")
        raise AssertionError(f"unexpected git command: {cmd!r}")

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        outcome = banner._check_via_local_git(checkout)

    assert outcome == banner.UPDATE_AVAILABLE_NO_COUNT
    # The fetch must keep the shallow boundary (--depth 1) instead of unshallowing.
    assert ["git", "fetch", "origin", "--depth", "1", "--quiet"] in issued
|
||||
|
||||
|
||||
def test_check_via_local_git_shallow_clone_up_to_date(tmp_path):
    """A shallow clone whose tip SHA equals the upstream tip reports 0 (up to date)."""
    import hermes_cli.banner as banner

    checkout = tmp_path / "hermes-agent"
    (checkout / ".git").mkdir(parents=True)

    # Exact-match commands and the stdout each one should produce.
    canned = [
        (["git", "remote", "get-url", "origin"], "https://github.com/NousResearch/hermes-agent.git\n"),
        (["git", "rev-parse", "--is-shallow-repository"], "true\n"),
        (["git", "rev-parse", "HEAD"], "same-sha\n"),
        (["git", "rev-parse", "FETCH_HEAD"], "same-sha\n"),
    ]

    def fake_run(cmd, **kwargs):
        for argv, stdout in canned:
            if cmd == argv:
                return MagicMock(returncode=0, stdout=stdout)
        if cmd[:2] == ["git", "fetch"]:
            return MagicMock(returncode=0, stdout="")
        raise AssertionError(f"unexpected git command: {cmd!r}")

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        outcome = banner._check_via_local_git(checkout)

    assert outcome == 0
|
||||
|
||||
|
||||
def test_check_via_local_git_full_clone_keeps_exact_count(tmp_path):
    """A non-shallow clone still returns the exact count from ``git rev-list --count``."""
    import hermes_cli.banner as banner

    checkout = tmp_path / "hermes-agent"
    (checkout / ".git").mkdir(parents=True)

    # Exact-match commands and the stdout each one should produce.
    canned = [
        (["git", "remote", "get-url", "origin"], "https://github.com/NousResearch/hermes-agent.git\n"),
        (["git", "rev-parse", "--is-shallow-repository"], "false\n"),
    ]

    def fake_run(cmd, **kwargs):
        for argv, stdout in canned:
            if cmd == argv:
                return MagicMock(returncode=0, stdout=stdout)
        if cmd[:2] == ["git", "fetch"]:
            return MagicMock(returncode=0, stdout="")
        if cmd[:3] == ["git", "rev-list", "--count"]:
            return MagicMock(returncode=0, stdout="7\n")
        raise AssertionError(f"unexpected git command: {cmd!r}")

    with patch("hermes_cli.banner.subprocess.run", side_effect=fake_run):
        outcome = banner._check_via_local_git(checkout)

    assert outcome == 7
|
||||
|
||||
|
||||
def test_check_for_updates_no_git_dir(tmp_path, monkeypatch):
|
||||
"""Falls back to PyPI check when .git directory doesn't exist anywhere."""
|
||||
import hermes_cli.banner as banner
|
||||
|
|
|
|||
|
|
@ -597,6 +597,120 @@ def test_resume_windows_gateways_after_update_respawns_unmapped_by_cmdline(
|
|||
assert "Restarting 1 unmapped Windows gateway process(es)" in out
|
||||
|
||||
|
||||
@patch.object(cli_main, "_is_windows", return_value=True)
def test_pause_returns_cold_start_token_when_installed_but_none_running(
    _winp,
    monkeypatch,
):
    """No gateway running + autostart entry installed → cold-start token.

    When a gateway died between updates (its spawning terminal/TUI was closed)
    the resume path has nothing to relaunch, yet the installed autostart entry
    is an explicit "I want a gateway" signal. The pause step must hand back a
    token telling resume to cold-start one.
    """
    import hermes_cli.gateway as gateway_mod
    from hermes_cli import gateway_windows

    def no_running_gateways(**_k):
        return []

    def autostart_installed():
        return True

    monkeypatch.setattr(gateway_mod, "find_gateway_pids", no_running_gateways)
    monkeypatch.setattr(gateway_windows, "is_installed", autostart_installed)

    expected_token = dict(
        resume_needed=True,
        profiles={},
        unmapped_pids=[],
        unmapped=[],
        cold_start_if_installed=True,
    )
    assert cli_main._pause_windows_gateways_for_update() == expected_token
|
||||
|
||||
|
||||
@patch.object(cli_main, "_is_windows", return_value=True)
def test_pause_returns_none_when_nothing_running_and_not_installed(
    _winp,
    monkeypatch,
):
    """No gateway running + no autostart entry → no token (gateway-less user).

    An update must not force a gateway onto users who deliberately run
    without one.
    """
    import hermes_cli.gateway as gateway_mod
    from hermes_cli import gateway_windows

    def no_running_gateways(**_k):
        return []

    def autostart_missing():
        return False

    monkeypatch.setattr(gateway_mod, "find_gateway_pids", no_running_gateways)
    monkeypatch.setattr(gateway_windows, "is_installed", autostart_missing)

    pause_token = cli_main._pause_windows_gateways_for_update()
    assert pause_token is None
|
||||
|
||||
|
||||
@patch.object(cli_main, "_is_windows", return_value=True)
def test_resume_cold_starts_gateway_when_token_requests_it(
    _winp,
    monkeypatch,
    capsys,
):
    """cold_start_if_installed token + nothing running → fresh detached spawn."""
    import hermes_cli.gateway as gateway_mod
    from hermes_cli import gateway_windows

    spawned = []

    def no_running_gateways(**_k):
        return []

    def record_spawn():
        # Note the spawn and report a fixed PID for the printed message.
        spawned.append(True)
        return 4242

    monkeypatch.setattr(gateway_mod, "find_gateway_pids", no_running_gateways)
    monkeypatch.setattr(gateway_windows, "_spawn_detached", record_spawn)

    token = dict(
        resume_needed=True,
        profiles={},
        unmapped_pids=[],
        unmapped=[],
        cold_start_if_installed=True,
    )

    cli_main._resume_windows_gateways_after_update(token)

    assert token["resume_needed"] is False
    assert spawned == [True]
    printed = capsys.readouterr().out
    assert "Starting Windows gateway after update (PID 4242)" in printed
|
||||
|
||||
|
||||
@patch.object(cli_main, "_is_windows", return_value=True)
def test_resume_cold_start_skips_when_gateway_already_running(
    _winp,
    monkeypatch,
    capsys,
):
    """Don't double-start a gateway.

    If one came up between pause and resume (e.g. the autostart entry fired),
    the cold-start must be a no-op.
    """
    import hermes_cli.gateway as gateway_mod
    from hermes_cli import gateway_windows

    spawned = []

    def one_running_gateway(**_k):
        return [9001]

    def record_spawn():
        # Would note a spawn; the test expects this never to be reached.
        spawned.append(True)
        return 4242

    monkeypatch.setattr(gateway_mod, "find_gateway_pids", one_running_gateway)
    monkeypatch.setattr(gateway_windows, "_spawn_detached", record_spawn)

    token = dict(
        resume_needed=True,
        profiles={},
        unmapped_pids=[],
        unmapped=[],
        cold_start_if_installed=True,
    )

    cli_main._resume_windows_gateways_after_update(token)

    assert spawned == []
    printed = capsys.readouterr().out
    assert "Starting Windows gateway after update" not in printed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cmd_update integration — concurrent-instance gate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -263,6 +263,29 @@ class TestWebServerEndpoints:
|
|||
import hermes_cli.web_server as web_server
|
||||
|
||||
monkeypatch.setattr(hermes_constants, "is_container", lambda: True)
|
||||
# A docker install inside a container should be managed externally.
|
||||
monkeypatch.setattr(web_server, "detect_install_method", lambda _root: "docker")
|
||||
|
||||
assert web_server._dashboard_local_update_managed_externally() is True
|
||||
|
||||
def test_dashboard_update_capability_allows_git_in_container(self, monkeypatch):
    """A git checkout inside a container is not treated as externally managed.

    Such a checkout (e.g. bind-mounted in hermes-webui) is self-managed, so
    dashboard updates should still be offered.
    """
    import hermes_constants
    import hermes_cli.web_server as web_server

    def inside_container():
        return True

    def git_install(_root):
        return "git"

    monkeypatch.setattr(hermes_constants, "is_container", inside_container)
    monkeypatch.setattr(web_server, "detect_install_method", git_install)

    managed_externally = web_server._dashboard_local_update_managed_externally()
    assert managed_externally is False
|
||||
|
||||
def test_dashboard_update_capability_blocks_pip_in_container(self, monkeypatch):
    """A pip install inside a container is still reported as externally managed."""
    import hermes_constants
    import hermes_cli.web_server as web_server

    def inside_container():
        return True

    def pip_install(_root):
        return "pip"

    monkeypatch.setattr(hermes_constants, "is_container", inside_container)
    monkeypatch.setattr(web_server, "detect_install_method", pip_install)

    managed_externally = web_server._dashboard_local_update_managed_externally()
    assert managed_externally is True
|
||||
|
||||
|
|
@ -1011,6 +1034,8 @@ class TestWebServerEndpoints:
|
|||
spawned = True
|
||||
raise AssertionError("docker update guard should not spawn hermes update")
|
||||
|
||||
# Bypass the managed-externally gate so we reach the docker install check.
|
||||
monkeypatch.setattr(web_server, "_dashboard_local_update_managed_externally", lambda: False)
|
||||
monkeypatch.setattr(web_server, "detect_install_method", lambda _root: "docker")
|
||||
monkeypatch.setattr(web_server, "_spawn_hermes_action", fail_spawn)
|
||||
web_server._ACTION_PROCS.pop("hermes-update", None)
|
||||
|
|
@ -5070,14 +5095,8 @@ class TestPluginAPIAuth:
|
|||
"""Tests that plugin API routes require the session token (issue #19533)."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_test_client(self, monkeypatch, _isolate_hermes_home, _install_example_plugin):
|
||||
"""Create a TestClient without the session token header.
|
||||
|
||||
Pulls in ``_install_example_plugin`` so ``test_plugin_route_allows_auth``
|
||||
has the ``/api/plugins/example/hello`` endpoint available — the
|
||||
example plugin is no longer a bundled plugin, so the fixture
|
||||
installs it into the per-test ``HERMES_HOME``.
|
||||
"""
|
||||
def _setup_test_client(self, monkeypatch, _isolate_hermes_home):
|
||||
"""Create TestClients with and without the session token header."""
|
||||
try:
|
||||
from starlette.testclient import TestClient
|
||||
except ImportError:
|
||||
|
|
@ -5102,19 +5121,15 @@ class TestPluginAPIAuth:
|
|||
def test_plugin_route_allows_auth(self):
|
||||
"""Plugin API routes should work with a valid session token.
|
||||
|
||||
Uses ``/api/plugins/example/hello`` from the example-dashboard
|
||||
test fixture (installed into HERMES_HOME by the class-level
|
||||
``_install_example_plugin`` fixture) — a stable, side-effect-free
|
||||
GET that's only loaded for tests. With a valid token the handler
|
||||
should run (200); without one the middleware should 401 before
|
||||
the handler is reached.
|
||||
Uses a bundled plugin route so the test covers authenticated plugin
|
||||
API access without relying on user-installed plugin backend imports.
|
||||
"""
|
||||
# Without auth: middleware blocks before reaching the handler.
|
||||
resp = self.client.get("/api/plugins/example/hello")
|
||||
resp = self.client.get("/api/plugins/kanban/board")
|
||||
assert resp.status_code == 401
|
||||
|
||||
# With auth: handler runs.
|
||||
resp = self.auth_client.get("/api/plugins/example/hello")
|
||||
resp = self.auth_client.get("/api/plugins/kanban/board")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_plugin_post_requires_auth(self):
|
||||
|
|
|
|||
|
|
@ -155,15 +155,31 @@ class TestResolveSessionNameTitle:
|
|||
result = cfg.resolve_session_name("/some/dir", session_id=None)
|
||||
assert result == "dir"
|
||||
|
||||
def test_title_beats_session_id(self):
|
||||
def test_per_session_id_beats_title(self):
    """In per-session mode the run's session_id is returned even when a title is given."""
    # The session_id is authoritative here: an (auto-)generated title must NOT
    # remap a live conversation onto a second Honcho session.
    run_id = "20260309_175514_9797dd"
    config = HonchoClientConfig(session_strategy="per-session")
    resolved = config.resolve_session_name(
        "/some/dir",
        session_title="my-title",
        session_id=run_id,
    )
    assert resolved == run_id
|
||||
|
||||
def test_per_session_id_beats_manual_map(self):
    """In per-session mode the session_id is returned even when the cwd has a map entry."""
    # A stale cwd map entry (e.g. the desktop launching from a mapped home
    # dir) must not override the run's session_id.
    run_id = "20260309_175514_9797dd"
    config = HonchoClientConfig(
        session_strategy="per-session",
        sessions={"/some/dir": "pinned"},
    )
    resolved = config.resolve_session_name("/some/dir", session_id=run_id)
    assert resolved == run_id
|
||||
|
||||
def test_title_still_applies_for_non_per_session(self):
    """Under per-directory strategy, /title still names the Honcho session."""
    config = HonchoClientConfig(session_strategy="per-directory")
    resolved = config.resolve_session_name(
        "/some/dir",
        session_title="my-title",
        session_id="20260309_175514_9797dd",
    )
    assert resolved == "my-title"
|
||||
|
||||
def test_manual_beats_session_id(self):
|
||||
cfg = HonchoClientConfig(session_strategy="per-session", sessions={"/some/dir": "pinned"})
|
||||
result = cfg.resolve_session_name("/some/dir", session_id="20260309_175514_9797dd")
|
||||
assert result == "pinned"
|
||||
def test_gateway_key_beats_per_session_id(self):
    """A gateway session key takes precedence over session_id in per-session mode."""
    # Gateways keep per-chat isolation even in per-session.
    config = HonchoClientConfig(session_strategy="per-session")
    resolved = config.resolve_session_name(
        "/some/dir",
        gateway_session_key="agent:main:telegram:dm:42",
        session_id="20260309_175514_9797dd",
    )
    assert resolved == "agent-main-telegram-dm-42"
|
||||
|
||||
def test_global_strategy_returns_workspace(self):
|
||||
cfg = HonchoClientConfig(session_strategy="global", workspace_id="my-workspace")
|
||||
|
|
|
|||
|
|
@ -234,6 +234,66 @@ class TestCmdStatus:
|
|||
assert "FAILED (Invalid API key)" in out
|
||||
assert "Connection... OK" not in out
|
||||
|
||||
def test_auth_line_detects_oauth_grant(self, monkeypatch, capsys, tmp_path):
    """``cmd_status`` prints an OAuth auth line, and no API-key line, for an OAuth grant."""
    import sys

    import plugins.memory.honcho.cli as honcho_cli

    config_file = tmp_path / "honcho.json"
    config_file.write_text("{}")

    access_token = "hch-at-deadbeef"
    oauth_grant = {
        "refreshToken": "hch-rt-x",
        "clientId": "hermes-agent",
        "tokenEndpoint": "https://api.honcho.dev/oauth/token",
        "expiresAt": 9999999999,
    }

    class FakeConfig:
        # Stand-in for the resolved Honcho client config read by cmd_status.
        enabled = True
        api_key = access_token
        workspace_id = "claude-code"
        host = "hermes"
        base_url = None
        ai_peer = "hermes"
        peer_name = "eri"
        recall_mode = "hybrid"
        user_observe_me = True
        user_observe_others = False
        ai_observe_me = False
        ai_observe_others = True
        write_frequency = "async"
        session_strategy = "per-session"
        context_tokens = None
        dialectic_reasoning_level = "low"
        reasoning_level_cap = "high"
        reasoning_heuristic = True
        # Raw config whose host block carries both the access token and the grant.
        raw = {
            "hosts": {
                "hermes": {
                    "apiKey": access_token,
                    "oauth": oauth_grant,
                }
            }
        }

        def resolve_session_name(self):
            return "hermes"

    # Point every config lookup at the temp file and stub out the client.
    monkeypatch.setattr(honcho_cli, "_read_config", lambda: {})
    monkeypatch.setattr(honcho_cli, "_config_path", lambda: config_file)
    monkeypatch.setattr(honcho_cli, "_local_config_path", lambda: config_file)
    monkeypatch.setattr(honcho_cli, "_active_profile_name", lambda: "default")
    monkeypatch.setattr(
        "plugins.memory.honcho.client.HonchoClientConfig.from_global_config",
        lambda host=None: FakeConfig(),
    )
    monkeypatch.setattr(
        "plugins.memory.honcho.client.get_honcho_client",
        lambda cfg: object(),
    )
    monkeypatch.setattr(honcho_cli, "_show_peer_cards", lambda hcfg, client: None)
    monkeypatch.setitem(sys.modules, "honcho", SimpleNamespace())

    honcho_cli.cmd_status(SimpleNamespace(all=False))

    printed = capsys.readouterr().out
    assert "Auth: OAuth (hermes-agent" in printed
    assert "API key:" not in printed
|
||||
|
||||
|
||||
class TestCloneHonchoForProfile:
|
||||
"""Identity-key carryover during profile cloning.
|
||||
|
|
@ -389,6 +449,9 @@ class TestSetupWizardDeploymentShape:
|
|||
# Scripted _prompt: pop answers in order. Default-return for unconsumed prompts.
|
||||
answer_iter = iter(answers)
|
||||
def _scripted_prompt(label, default=None, secret=False):
|
||||
# Auth-method prompt is orthogonal to shape; auto-answer apikey so the answer lists stay shape-only.
|
||||
if "OAuth" in label:
|
||||
return "apikey"
|
||||
try:
|
||||
return next(answer_iter)
|
||||
except StopIteration:
|
||||
|
|
|
|||
|
|
@ -711,15 +711,17 @@ class TestResolveSessionNameGatewayKey:
|
|||
)
|
||||
assert result == "agent-main-telegram-dm-8439114563"
|
||||
|
||||
def test_session_title_still_wins_over_gateway_key(self):
|
||||
"""Explicit /title remap takes priority over gateway_session_key."""
|
||||
def test_gateway_key_not_remapped_by_title(self):
|
||||
"""A title never remaps a stable identifier — the gateway per-chat key
|
||||
wins over the title so a generated title can't split a live conversation
|
||||
onto a new Honcho session."""
|
||||
config = HonchoClientConfig(session_strategy="per-session")
|
||||
result = config.resolve_session_name(
|
||||
session_title="my-custom-title",
|
||||
session_id="20260412_171002_69bb38",
|
||||
gateway_session_key="agent:main:telegram:dm:8439114563",
|
||||
)
|
||||
assert result == "my-custom-title"
|
||||
assert result == "agent-main-telegram-dm-8439114563"
|
||||
|
||||
def test_per_session_fallback_without_gateway_key(self):
|
||||
"""Without gateway_session_key, per-session returns session_id (CLI path)."""
|
||||
|
|
|
|||
254
tests/honcho_plugin/test_oauth.py
Normal file
254
tests/honcho_plugin/test_oauth.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""Tests for plugins/memory/honcho/oauth.py — OAuth grant storage + refresh."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from plugins.memory.honcho import oauth
|
||||
from plugins.memory.honcho.oauth import OAuthCredential
|
||||
|
||||
|
||||
def _host_block(refresh="hch-rt-old", expires_at=10_000):
    """Return a host config block with a fixed ``hch-at-old`` access token and an OAuth grant.

    ``refresh`` and ``expires_at`` fill the grant's ``refreshToken`` and
    ``expiresAt`` fields; every other value is constant.
    """
    grant = dict(
        refreshToken=refresh,
        expiresAt=expires_at,
        clientId="hermes-desktop",
        tokenEndpoint="http://localhost:8000/oauth/token",
        scope="write",
        tokenType="Bearer",
    )
    return {"apiKey": "hch-at-old", "oauth": grant}
|
||||
|
||||
|
||||
def _write(path: Path, raw: dict) -> None:
    """Serialize ``raw`` as JSON and write it to ``path`` as UTF-8 text."""
    serialized = json.dumps(raw)
    path.write_text(serialized, encoding="utf-8")
|
||||
|
||||
|
||||
class TestTokenDetection:
    """Checks which token strings ``is_oauth_access_token`` accepts."""

    def test_access_token_prefix(self):
        """Only the ``hch-at-`` sample is accepted; the other samples and None are rejected."""
        assert oauth.is_oauth_access_token("hch-at-abc")
        for rejected in ("hch-v3-abc", "hch-rt-abc", None):
            assert not oauth.is_oauth_access_token(rejected)
|
||||
|
||||
|
||||
class TestCredentialModel:
    """Parsing a host block into an ``OAuthCredential`` and reading it back."""

    def test_roundtrip(self):
        """Fields read from a host block reappear unchanged in ``oauth_block()``."""
        credential = OAuthCredential.from_host_block(_host_block())
        assert credential is not None
        emitted = credential.oauth_block()
        expected_fields = {
            "refreshToken": "hch-rt-old",
            "expiresAt": 10_000,
            "clientId": "hermes-desktop",
        }
        for field, value in expected_fields.items():
            assert emitted[field] == value

    def test_incomplete_block_returns_none(self):
        """Host blocks lacking a usable OAuth grant yield no credential."""
        # A plain API key with no oauth sub-block.
        api_key_only = {"apiKey": "hch-v3-x"}
        assert OAuthCredential.from_host_block(api_key_only) is None
        # An oauth block with its refreshToken removed.
        truncated = _host_block()
        truncated["oauth"].pop("refreshToken")
        assert OAuthCredential.from_host_block(truncated) is None

    def test_is_expired_respects_skew(self):
        """Expiry is judged against ``expiresAt`` minus the skew."""
        credential = OAuthCredential.from_host_block(_host_block(expires_at=1000))
        early = credential.is_expired(now=800, skew=120)  # 1000-120=880 > 800
        late = credential.is_expired(now=900, skew=120)  # 900 >= 880
        assert not early
        assert late
|
||||
|
||||
|
||||
class TestEnsureFreshToken:
    """Behaviour of ``oauth.ensure_fresh_token`` across fresh, expired and failing cases."""

    def test_no_oauth_credential_is_noop(self, tmp_path):
        """A host holding only a static API key yields ``(None, False)``."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": {"apiKey": "hch-v3-static"}}})
        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=0)
        assert token is None
        assert refreshed is False

    def test_fresh_token_skips_refresh(self, tmp_path, monkeypatch):
        """A token far from expiry is returned without any HTTP refresh call."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=10_000)}})

        def forbid_refresh(*a, **k):
            return pytest.fail("refresh must not be called when fresh")

        monkeypatch.setattr(oauth, "_http_post_form", forbid_refresh)
        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=0)
        assert token == "hch-at-old"
        assert refreshed is False

    def test_fresh_token_served_from_cache_without_disk(self, tmp_path, monkeypatch):
        """Once seeded, a still-fresh token is returned without reading the config again."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=10_000)}})
        oauth._expiry_cache.clear()
        # The first call seeds the cache from disk.
        oauth.ensure_fresh_token(config_file, "hermes", now=0)

        # The second call must stay off disk while the token is well clear of expiry.
        def forbid_disk_read(*a, **k):
            return pytest.fail("disk must not be read while token is fresh")

        monkeypatch.setattr(oauth, "_read_config", forbid_disk_read)
        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=100)
        assert token == "hch-at-old"
        assert refreshed is False

    def test_expired_token_refreshes_and_persists_rotation(self, tmp_path, monkeypatch):
        """An expired token triggers a refresh whose rotated values are written to disk."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=100)}})

        def fake_post(url, data, timeout):
            # The refresh request must carry the stored grant details.
            sent = (data["grant_type"], data["refresh_token"], data["client_id"])
            assert sent[0] == "refresh_token"
            assert sent[1] == "hch-rt-old"
            assert sent[2] == "hermes-desktop"
            return dict(
                access_token="hch-at-new",
                refresh_token="hch-rt-new",
                expires_in=3600,
                scope="write",
                token_type="Bearer",
            )

        monkeypatch.setattr(oauth, "_http_post_form", fake_post)
        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=1000)
        assert token == "hch-at-new"
        assert refreshed is True

        # Rotated refresh token, new access token and absolute expiry are all persisted.
        persisted = json.loads(config_file.read_text())["hosts"]["hermes"]
        assert persisted["apiKey"] == "hch-at-new"
        assert persisted["oauth"]["refreshToken"] == "hch-rt-new"
        assert persisted["oauth"]["expiresAt"] == 1000 + 3600

    def test_refresh_failure_fails_open(self, tmp_path, monkeypatch):
        """A failing refresh returns the old token and leaves the file unchanged."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=100)}})

        def boom(*a, **k):
            raise RuntimeError("network down")

        monkeypatch.setattr(oauth, "_http_post_form", boom)
        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=1000)
        # Stale token comes back, nothing crashes, and the file is untouched.
        assert token == "hch-at-old"
        assert refreshed is False
        on_disk = json.loads(config_file.read_text())
        assert on_disk["hosts"]["hermes"]["apiKey"] == "hch-at-old"

    def test_double_check_uses_disk_when_already_rotated(self, tmp_path, monkeypatch):
        """A stale in-memory snapshot is overridden by an already-fresh credential on disk."""
        # Simulates a concurrent thread that rotated the token on disk after our
        # stale in-memory snapshot was taken: the locked re-read must skip the
        # HTTP call.
        config_file = tmp_path / "honcho.json"
        fresh_host = _host_block(refresh="hch-rt-fresh", expires_at=10_000)
        _write(config_file, {"hosts": {"hermes": fresh_host}})

        stale_host = _host_block(refresh="hch-rt-old", expires_at=100)
        stale_host["apiKey"] = "hch-at-stale"
        stale_snapshot = {"hosts": {"hermes": stale_host}}

        def forbid_refresh(*a, **k):
            return pytest.fail("must not refresh; disk token is fresh")

        monkeypatch.setattr(oauth, "_http_post_form", forbid_refresh)
        token, _refreshed = oauth.ensure_fresh_token(
            config_file, "hermes", stale_snapshot, now=1000
        )
        assert token == "hch-at-old"  # access token of the fresh on-disk credential

    def test_refresh_holds_cross_process_lock(self, tmp_path, monkeypatch):
        """The ``<config>.lock`` file is exclusively locked during refresh and free afterwards."""
        # A second opener must not be able to grab <config>.lock mid-refresh —
        # proving the rotation is serialized machine-wide so peers can't replay
        # the token.
        fcntl = pytest.importorskip("fcntl")
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=100)}})
        observed = {}
        exclusive_nonblocking = fcntl.LOCK_EX | fcntl.LOCK_NB

        def fake_post(url, data, timeout):
            # Probe the lock from a second file handle while the refresh runs.
            with open(f"{config_file}.lock", "a+b") as probe:
                try:
                    fcntl.flock(probe.fileno(), exclusive_nonblocking)
                    fcntl.flock(probe.fileno(), fcntl.LOCK_UN)
                    observed["held"] = False
                except OSError:
                    observed["held"] = True
            return {
                "access_token": "hch-at-new",
                "refresh_token": "hch-rt-new",
                "expires_in": 3600,
                "scope": "write",
                "token_type": "Bearer",
            }

        monkeypatch.setattr(oauth, "_http_post_form", fake_post)
        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=1000)
        assert refreshed is True
        assert observed.get("held") is True
        # Released afterwards: a non-blocking acquire now succeeds.
        with open(f"{config_file}.lock", "a+b") as handle:
            fcntl.flock(handle.fileno(), exclusive_nonblocking)
            fcntl.flock(handle.fileno(), fcntl.LOCK_UN)

    def test_refresh_degrades_when_lock_unavailable(self, tmp_path, monkeypatch):
        """Refresh still succeeds when every ``flock`` call raises ``OSError``."""
        # Missing flock (unsupported FS/platform) must not block refresh — it
        # falls back to in-process serialization only.
        fcntl = pytest.importorskip("fcntl")
        config_file = tmp_path / "honcho.json"
        _write(config_file, {"hosts": {"hermes": _host_block(expires_at=100)}})

        def no_flock(*a, **k):
            raise OSError("flock unsupported")

        def fake_post(*a, **k):
            return {
                "access_token": "hch-at-new",
                "refresh_token": "hch-rt-new",
                "expires_in": 3600,
                "scope": "write",
                "token_type": "Bearer",
            }

        monkeypatch.setattr(fcntl, "flock", no_flock)
        monkeypatch.setattr(oauth, "_http_post_form", fake_post)
        token, refreshed = oauth.ensure_fresh_token(config_file, "hermes", now=1000)
        assert token == "hch-at-new"
        assert refreshed is True
|
||||
|
||||
|
||||
class TestInstallGrant:
    """Behaviour of ``oauth.install_grant`` when writing a grant into the config file."""

    def test_deep_merges_config_and_preserves_other_hosts(self, tmp_path):
        """Grant config is merged into the target host; unrelated data is left alone."""
        config_file = tmp_path / "honcho.json"
        existing = {
            "apiKey": "hch-v3-root",  # root static key, expected to survive
            "hosts": {
                "obsidian": {"workspace": "obsidian"},
                "hermes": {"workspace": "hermes", "saveMessages": False},
            },
        }
        _write(config_file, existing)

        grant_config = {
            "environment": "production",
            "hosts": {"hermes": {"saveMessages": True, "recallMode": "hybrid"}},
        }
        grant = {
            "access_token": "hch-at-fresh",
            "refresh_token": "hch-rt-fresh",
            "expires_in": 3600,
            "scope": "write",
            "config": grant_config,
        }
        credential = oauth.install_grant(
            config_file,
            "hermes",
            grant,
            client_id="hermes-desktop",
            token_endpoint="http://localhost:8000/oauth/token",
            now=1000,
        )
        assert credential.expires_at == 1000 + 3600

        persisted = json.loads(config_file.read_text())
        assert persisted["apiKey"] == "hch-v3-root"  # untouched
        assert persisted["hosts"]["obsidian"] == {"workspace": "obsidian"}  # untouched
        hermes_host = persisted["hosts"]["hermes"]
        assert hermes_host["apiKey"] == "hch-at-fresh"
        assert hermes_host["oauth"]["refreshToken"] == "hch-rt-fresh"
        assert hermes_host["saveMessages"] is True  # grant config won the deep-merge
        assert hermes_host["recallMode"] == "hybrid"  # new key added
        assert hermes_host["workspace"] == "hermes"  # pre-existing key preserved
        assert persisted["environment"] == "production"  # root key from grant

    def test_rejects_grant_without_tokens(self, tmp_path):
        """A grant lacking ``refresh_token`` raises ``ValueError``."""
        config_file = tmp_path / "honcho.json"
        _write(config_file, {})
        access_only = {"access_token": "hch-at-x"}  # no refresh_token
        with pytest.raises(ValueError):
            oauth.install_grant(
                config_file,
                "hermes",
                access_only,
                client_id="c",
                token_endpoint="e",
            )
|
||||
|
||||
|
||||
class TestApplyTokenToClient:
    """Behaviour of ``oauth.apply_token_to_client`` on known and unknown client shapes."""

    def test_mutates_live_bearer(self):
        """A client exposing ``_http.api_key`` has that attribute replaced and True returned."""

        class FakeHttp:
            api_key = "hch-at-old"

        class FakeClient:
            _http = FakeHttp()

        live_client = FakeClient()
        applied = oauth.apply_token_to_client(live_client, "hch-at-new")
        assert applied is True
        assert live_client._http.api_key == "hch-at-new"

    def test_returns_false_when_shape_unknown(self):
        """A bare ``object()`` is reported as not updated."""
        applied = oauth.apply_token_to_client(object(), "hch-at-new")
        assert applied is False
|
||||
347
tests/honcho_plugin/test_oauth_flow.py
Normal file
347
tests/honcho_plugin/test_oauth_flow.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
"""End-to-end test for the zero-CLI Honcho OAuth flow against a fake AS.
|
||||
|
||||
Stands up a real local authorization server (no network, no browser) and drives
|
||||
the full path: begin → /authorize 302 → loopback :8765 callback → token
|
||||
exchange → install_grant → forced-expiry refresh with rotation. This is the
|
||||
deterministic "real smoke test" for the consumer flow.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from plugins.memory.honcho import oauth, oauth_flow
|
||||
|
||||
|
||||
class _FakeAS(BaseHTTPRequestHandler):
    """Minimal OAuth 2.1 AS: /authorize 302s to the callback; /oauth/token mints."""

    # Rotation counter shared across requests so refresh returns a new token.
    issued = {"n": 0}

    def _reply_not_found(self):
        """Answer with a bare 404 for any path this fake does not serve."""
        self.send_response(404)
        self.end_headers()

    def do_GET(self):  # noqa: N802
        """Serve GET /authorize: check the query, then redirect to the callback."""
        url = urlparse(self.path)
        if url.path != "/authorize":
            self._reply_not_found()
            return

        params = parse_qs(url.query)

        # The redirect must be the IP literal matching the bound host — a
        # `localhost` redirect can resolve to ::1 and miss the IPv4 listener.
        # Host must be the IP literal (port may fall back off :8765).
        redirect_uri = params["redirect_uri"][0]
        on_loopback_ip = redirect_uri.startswith("http://127.0.0.1:")
        assert on_loopback_ip and "/callback" in redirect_uri, redirect_uri

        # Consent shows a home-relative display path — never an absolute path
        # that would leak the username / home layout off the machine.
        shown_path = params["config_path"][0]
        assert shown_path.endswith("honcho.json"), params.get("config_path")
        assert not shown_path.startswith("/"), shown_path

        state = params["state"][0]
        self.send_response(302)
        self.send_header("Location", f"{redirect_uri}?code=test-auth-code&state={state}")
        self.end_headers()

    def do_POST(self):  # noqa: N802
        """Serve POST /oauth/token: mint a numbered access/refresh token pair."""
        url = urlparse(self.path)
        if url.path != "/oauth/token":
            self._reply_not_found()
            return

        body_len = int(self.headers.get("Content-Length", 0))
        form = parse_qs(self.rfile.read(body_len).decode())
        grant_type = form["grant_type"][0]

        # Every successful token request bumps the shared counter, so a
        # refresh hands back a pair that differs from the previous one.
        self.issued["n"] += 1
        serial = self.issued["n"]
        token_response = {
            "access_token": f"hch-at-{serial}",
            "refresh_token": f"hch-rt-{serial}",
            "token_type": "Bearer",
            "expires_in": 3600,
            "scope": "write",
        }
        # Only the initial code exchange carries a config payload.
        if grant_type == "authorization_code":
            token_response["config"] = {
                "peerName": "lyra",
                "environment": "production",
                "hosts": {"hermes": {"saveMessages": True, "recallMode": "hybrid"}},
            }

        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(json.dumps(token_response).encode())

    def log_message(self, *args):
        """Discard the per-request log line the base handler would emit."""
        return None
|
||||
|
||||
|
||||
@pytest.fixture
def fake_as(monkeypatch):
    """Run ``_FakeAS`` on an ephemeral loopback port and point the OAuth env vars at it.

    Yields the server's base URL; the server is shut down and closed on teardown.
    """
    _FakeAS.issued["n"] = 0
    server = HTTPServer(("127.0.0.1", 0), _FakeAS)  # port 0: let the OS pick one
    bound_port = server.server_address[1]
    worker = threading.Thread(target=server.serve_forever, daemon=True)
    worker.start()

    base = f"http://127.0.0.1:{bound_port}"
    for variable, value in (
        ("HONCHO_OAUTH_AUTHORIZE_URL", f"{base}/authorize"),
        ("HONCHO_OAUTH_TOKEN_URL", f"{base}/oauth/token"),
        ("HONCHO_OAUTH_CLIENT_ID", "hermes-desktop"),
    ):
        monkeypatch.setenv(variable, value)

    try:
        yield base
    finally:
        server.shutdown()
        server.server_close()
|
||||
|
||||
|
||||
def _browser_driver(authorize_url: str) -> None:
    """Stand in for the user's browser: follow /authorize's 302 into the callback.

    Retries the callback GET so it can't lose the race to the loopback bind.
    Raises RuntimeError if every attempt ends in a connection error.
    """
    authorize_reply = httpx.get(authorize_url, follow_redirects=False)
    callback_url = authorize_reply.headers["Location"]

    attempts_left = 50
    while attempts_left:
        attempts_left -= 1
        try:
            httpx.get(callback_url, timeout=2)
        except httpx.ConnectError:
            # Listener not accepting yet; pause briefly before the next try.
            time.sleep(0.05)
        else:
            return
    raise RuntimeError("loopback callback never came up")
|
||||
|
||||
|
||||
def test_full_loopback_flow_then_refresh(tmp_path, fake_as):
    """Run the whole loopback authorization, then force an expiry-driven refresh."""
    config_path = tmp_path / "honcho.json"
    preexisting = {"hosts": {"obsidian": {"workspace": "obsidian"}}}
    config_path.write_text(json.dumps(preexisting))

    cred = oauth_flow.authorize_via_loopback(
        config_path=config_path,
        host="hermes",
        open_url=lambda url: _browser_driver(url),
        timeout=10,
    )

    # Grant installed: token stored, config deep-merged, other host preserved.
    assert cred.access_token == "hch-at-1"
    saved = json.loads(config_path.read_text())
    hermes = saved["hosts"]["hermes"]
    assert hermes["apiKey"] == "hch-at-1"
    assert hermes["oauth"]["refreshToken"] == "hch-rt-1"
    assert hermes["recallMode"] == "hybrid"
    assert saved["environment"] == "production"
    assert saved["hosts"]["obsidian"] == {"workspace": "obsidian"}

    # Force expiry; ensure_fresh_token refreshes against the same AS and rotates.
    past_expiry = hermes["oauth"]["expiresAt"] + 10
    token, refreshed = oauth.ensure_fresh_token(config_path, "hermes", now=past_expiry)
    assert refreshed is True
    assert token == "hch-at-2"
    rotated = json.loads(config_path.read_text())["hosts"]["hermes"]["oauth"]
    assert rotated["refreshToken"] == "hch-rt-2"
|
||||
|
||||
|
||||
def test_state_mismatch_is_rejected(fake_as, tmp_path):
    """Completing with a state that was never issued raises ValueError."""
    endpoints = oauth_flow.resolve_endpoints()
    # Begin a genuine authorization, then answer with a different state value.
    _url, _issued_state = oauth_flow.begin_authorization(endpoints)
    target = tmp_path / "honcho.json"
    with pytest.raises(ValueError, match="unknown or expired"):
        oauth_flow.complete_authorization(
            endpoints,
            "code",
            "not-the-real-state",
            config_path=target,
            host="hermes",
        )
|
||||
|
||||
|
||||
def test_source_tags_the_authorize_link(fake_as):
    """``source=`` appears in the authorize URL only when a source is passed."""
    endpoints = oauth_flow.resolve_endpoints()

    tagged_url, _ = oauth_flow.begin_authorization(endpoints, source="hermes-cli")
    assert "source=hermes-cli" in tagged_url

    plain_url, _ = oauth_flow.begin_authorization(endpoints)
    assert "source=" not in plain_url
|
||||
|
||||
|
||||
def test_client_id_defaults_to_hermes_agent(monkeypatch):
    """client_id is ``hermes-agent`` unless HONCHO_OAUTH_CLIENT_ID is set."""
    # One client for every surface; the env var overrides for unusual deployments.
    monkeypatch.delenv("HONCHO_OAUTH_CLIENT_ID", raising=False)
    by_default = oauth_flow.resolve_endpoints(
        environment="production", base_url="https://api.honcho.dev"
    )
    assert by_default.client_id == "hermes-agent"

    monkeypatch.setenv("HONCHO_OAUTH_CLIENT_ID", "custom-id")
    overridden = oauth_flow.resolve_endpoints(
        environment="production", base_url="https://api.honcho.dev"
    )
    assert overridden.client_id == "custom-id"
|
||||
|
||||
|
||||
def test_grant_persists_default_client_id(tmp_path, fake_as, monkeypatch):
    """With no env override, the stored grant records clientId ``hermes-agent``."""
    # Drop the fixture's override so the default takes effect; the grant must
    # store client_id=hermes-agent so refresh reuses the right client.
    monkeypatch.delenv("HONCHO_OAUTH_CLIENT_ID", raising=False)
    config_path = tmp_path / "honcho.json"
    config_path.write_text(json.dumps({"hosts": {}}))

    oauth_flow.authorize_via_loopback(
        config_path=config_path,
        host="hermes",
        source="hermes-cli",
        apply_config=False,
        open_url=lambda url: _browser_driver(url),
        timeout=10,
    )

    stored_oauth = json.loads(config_path.read_text())["hosts"]["hermes"]["oauth"]
    assert stored_oauth["clientId"] == "hermes-agent"
|
||||
|
||||
|
||||
def test_config_path_rides_the_authorize_link(fake_as):
    """``config_path`` is carried as a query parameter only when one is supplied."""
    endpoints = oauth_flow.resolve_endpoints()

    with_path, _ = oauth_flow.begin_authorization(endpoints, config_path="~/.hermes/honcho.json")
    query = parse_qs(urlparse(with_path).query)
    assert query["config_path"][0] == "~/.hermes/honcho.json"

    without_path, _ = oauth_flow.begin_authorization(endpoints)
    assert "config_path=" not in without_path
|
||||
|
||||
|
||||
def test_display_config_path_never_leaks_absolute_path():
    """_display_config_path never hands back an absolute filesystem path.

    Relies on the module-level ``from pathlib import Path``; the redundant
    function-local re-import was dropped.
    """
    # Under home → collapsed to ~/…; outside home → bare filename only.
    under_home = Path.home() / ".hermes" / "profiles" / "work" / "honcho.json"
    assert oauth_flow._display_config_path(under_home) == "~/.hermes/profiles/work/honcho.json"
    assert oauth_flow._display_config_path("/var/folders/tmp/honcho.json") == "honcho.json"
|
||||
|
||||
|
||||
def test_cli_flow_stores_tokens_without_applying_config(tmp_path, fake_as):
    """With apply_config=False the tokens are stored but grant config is not merged."""
    # apply_config=False (the CLI path): grant config must NOT touch settings.
    config_path = tmp_path / "honcho.json"
    wizard_settings = {"hosts": {"hermes": {"saveMessages": False}}}
    config_path.write_text(json.dumps(wizard_settings))

    cred = oauth_flow.authorize_via_loopback(
        config_path=config_path,
        host="hermes",
        source="hermes-cli",
        apply_config=False,
        open_url=lambda url: _browser_driver(url),
        timeout=10,
    )

    saved = json.loads(config_path.read_text())
    host = saved["hosts"]["hermes"]
    assert host["apiKey"] == cred.access_token
    assert host["oauth"]["refreshToken"] == cred.refresh_token
    # Wizard-owned setting untouched; grant config keys absent.
    assert host["saveMessages"] is False
    assert "recallMode" not in host
    assert "environment" not in saved
    # consent peer name still surfaced (seeds the CLI wizard prompt) despite no merge
    assert cred.consent_peer_name == "lyra"
|
||||
|
||||
|
||||
# ── Desktop "Connect" button path: background launcher, status, dispatch ──
|
||||
|
||||
|
||||
@pytest.fixture
def reset_flow():
    """Reset oauth_flow's module-level status and thread handle around a test."""

    def _fresh_state():
        oauth_flow._status = oauth_flow.FlowStatus()
        oauth_flow._flow_thread = None

    _fresh_state()
    yield
    _fresh_state()
|
||||
|
||||
|
||||
def _wait_until(predicate, timeout=2.0):
    """Poll *predicate* until it is truthy or *timeout* seconds have elapsed.

    Args:
        predicate: Zero-argument callable evaluated on every poll.
        timeout: Maximum time to keep polling, in seconds.

    Returns:
        True as soon as the predicate is truthy, otherwise False.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(0.02)
    # One last look after the deadline: the condition may have flipped during
    # the final sleep (or the loop never ran because timeout was 0), and
    # reporting that as a timeout would make callers fail spuriously.
    return bool(predicate())
|
||||
|
||||
|
||||
def test_launcher_runs_flow_in_background_and_reports_connected(monkeypatch, reset_flow):
    """The launcher returns ``pending`` at once and ends ``connected`` when the flow returns."""
    captured = {}
    release = threading.Event()

    def held_flow(**kwargs):
        captured.update(kwargs)  # captures source default + eagerly-resolved path/host
        release.wait(2)  # hold the flow open so the launcher returns while pending

    monkeypatch.setattr(oauth_flow, "authorize_via_loopback", held_flow)
    monkeypatch.setattr(oauth_flow, "_detect_connection", lambda: (True, "oauth"))

    status = oauth_flow.start_loopback_flow_background(
        config_path=Path("/t/honcho.json"), host="hermes"
    )
    assert status["state"] == "pending"  # returns immediately, before the flow finishes
    assert _wait_until(lambda: captured.get("source") == "hermes-desktop")  # default source tag
    assert captured["host"] == "hermes"

    release.set()
    assert _wait_until(lambda: oauth_flow.get_flow_status()["state"] == "connected")
|
||||
|
||||
|
||||
def test_launcher_reports_error_on_flow_failure(monkeypatch, reset_flow):
    """An exception raised by the flow surfaces as state ``error`` with its message."""

    def failing_flow(**kwargs):
        raise RuntimeError("loopback bind failed")

    monkeypatch.setattr(oauth_flow, "authorize_via_loopback", failing_flow)
    monkeypatch.setattr(oauth_flow, "_detect_connection", lambda: (False, None))

    oauth_flow.start_loopback_flow_background(config_path=Path("/t/honcho.json"), host="hermes")

    assert _wait_until(lambda: oauth_flow.get_flow_status()["state"] == "error")
    final_status = oauth_flow.get_flow_status()
    assert "loopback bind failed" in final_status["detail"]
|
||||
|
||||
|
||||
def test_launcher_is_idempotent_while_pending(monkeypatch, reset_flow):
    """Starting again while a flow is pending does not run the flow a second time."""
    release = threading.Event()
    calls = []

    def held_flow(**kwargs):
        calls.append(1)
        release.wait(2)

    monkeypatch.setattr(oauth_flow, "authorize_via_loopback", held_flow)
    monkeypatch.setattr(oauth_flow, "_detect_connection", lambda: (False, None))

    target = Path("/t/h.json")
    first = oauth_flow.start_loopback_flow_background(config_path=target, host="hermes")
    assert _wait_until(lambda: len(calls) == 1)  # first flow is running
    second = oauth_flow.start_loopback_flow_background(config_path=target, host="hermes")
    release.set()

    assert first["state"] == "pending" and second["state"] == "pending"
    assert _wait_until(lambda: oauth_flow.get_flow_status()["state"] == "connected")
    assert calls == [1]  # the second call did not spawn a second flow
|
||||
|
||||
|
||||
def test_get_flow_status_reports_stored_connection(tmp_path, monkeypatch, reset_flow):
    """get_flow_status reports what is stored for the host: nothing, an API key, or OAuth."""
    from plugins.memory.honcho import client as honcho_client

    cfgfile = tmp_path / "honcho.json"
    monkeypatch.setattr(honcho_client, "resolve_config_path", lambda: cfgfile)
    monkeypatch.setattr(honcho_client, "resolve_active_host", lambda: "hermes")
    monkeypatch.delenv("HONCHO_API_KEY", raising=False)

    def store_host(host_block):
        # Rewrite the config file so the "hermes" host holds exactly host_block.
        cfgfile.write_text(json.dumps({"hosts": {"hermes": host_block}}))

    store_host({})
    assert oauth_flow.get_flow_status()["connected"] is False

    store_host({"apiKey": "hch-v3-static"})
    status = oauth_flow.get_flow_status()
    assert status["connected"] is True and status["auth"] == "apikey"

    store_host({
        "apiKey": "hch-at-tok",
        "oauth": {
            "refreshToken": "hch-rt-x",
            "expiresAt": 9_999_999_999,
            "clientId": "hermes-desktop",
            "tokenEndpoint": "http://x/oauth/token",
        },
    })
    status = oauth_flow.get_flow_status()
    assert status["connected"] is True and status["auth"] == "oauth"
|
||||
|
||||
|
||||
def test_memory_oauth_router_dispatches_by_provider_convention():
    """_resolve_flow returns the honcho flow module and answers 404 for other names."""
    # The generic seam behind the two routes: provider → plugins.memory.<p>.oauth_flow.
    from fastapi import HTTPException

    from hermes_cli.memory_oauth import _resolve_flow

    flow_module = _resolve_flow("honcho")
    assert hasattr(flow_module, "start_loopback_flow_background") and hasattr(flow_module, "get_flow_status")

    rejected_names = ("builtin", "no-such-provider", "../etc")
    for name in rejected_names:
        with pytest.raises(HTTPException) as exc:
            _resolve_flow(name)
        assert exc.value.status_code == 404
|
||||
209
tests/plugins/memory/test_mem0_backend.py
Normal file
209
tests/plugins/memory/test_mem0_backend.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
"""Tests for Mem0Backend abstraction — PlatformBackend and OSSBackend."""
|
||||
|
||||
import pytest
|
||||
|
||||
from plugins.memory.mem0._backend import Mem0Backend, PlatformBackend, OSSBackend
|
||||
|
||||
|
||||
class FakePlatformClient:
    """Fake MemoryClient for PlatformBackend tests."""

    def __init__(self):
        # Every call is appended here as a tuple: (method name, *arguments).
        self.calls = []

    def _log(self, *entry):
        """Append one call record to ``calls``."""
        self.calls.append(entry)

    def search(self, query, **kwargs):
        """Record the search and return one canned scored hit."""
        self._log("search", query, kwargs)
        hit = {"id": "m1", "memory": "fact1", "score": 0.9}
        return {"results": [hit]}

    def get_all(self, **kwargs):
        """Record the listing call and return a one-item page envelope."""
        self._log("get_all", kwargs)
        item = {"id": "m1", "memory": "fact1"}
        return {"count": 1, "next": None, "results": [item]}

    def add(self, messages, **kwargs):
        """Record the add and return a canned pending-event reply."""
        self._log("add", messages, kwargs)
        return {"status": "PENDING", "event_id": "evt-1"}

    def update(self, **kwargs):
        """Record the update and echo ``memory_id``/``text`` back."""
        self._log("update", kwargs)
        return {"id": kwargs["memory_id"], "text": kwargs["text"]}

    def delete(self, **kwargs):
        """Record the delete; returns nothing."""
        self._log("delete", kwargs)
|
||||
|
||||
|
||||
class TestPlatformBackend:
    """PlatformBackend forwarding checks, run against FakePlatformClient."""

    def _make(self):
        """Return ``(backend, client)`` with the fake client attached directly.

        ``__new__`` is used so PlatformBackend's ``__init__`` never runs.
        """
        client = FakePlatformClient()
        backend = PlatformBackend.__new__(PlatformBackend)
        backend._client = client
        return backend, client

    def test_search_forwards_params(self):
        backend, client = self._make()
        backend.search("test query", filters={"user_id": "u1"}, top_k=5)
        method, query, kwargs = client.calls[0]
        assert method == "search"
        assert query == "test query"
        assert kwargs["filters"] == {"user_id": "u1"}
        assert kwargs["top_k"] == 5

    def test_search_forwards_rerank(self):
        backend, client = self._make()
        backend.search("q", filters={}, rerank=False)
        assert client.calls[0][2]["rerank"] is False

    def test_search_rerank_default_true(self):
        backend, client = self._make()
        backend.search("q", filters={})
        assert client.calls[0][2]["rerank"] is True

    def test_search_returns_list(self):
        backend, _ = self._make()
        result = backend.search("q", filters={})
        assert isinstance(result, list)
        assert result[0]["id"] == "m1"

    def test_get_all_forwards_pagination(self):
        backend, client = self._make()
        result = backend.get_all(filters={"user_id": "u1"}, page=2, page_size=50)
        forwarded = client.calls[0][1]
        assert forwarded["page"] == 2
        assert forwarded["page_size"] == 50
        assert "count" in result

    def test_add_forwards_kwargs(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(msgs, user_id="u1", agent_id="hermes", infer=False)
        forwarded = client.calls[0][2]
        assert forwarded["user_id"] == "u1"
        assert forwarded["infer"] is False
        # metadata kwarg should be omitted entirely when not provided so we
        # don't surprise older mem0 client versions with an unknown kwarg.
        assert "metadata" not in forwarded

    def test_add_forwards_metadata_when_present(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(
            msgs,
            user_id="u1",
            agent_id="hermes",
            infer=False,
            metadata={"channel": "telegram"},
        )
        assert client.calls[0][2]["metadata"] == {"channel": "telegram"}

    def test_add_omits_empty_metadata(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(msgs, user_id="u1", agent_id="hermes", infer=False, metadata={})
        assert "metadata" not in client.calls[0][2]

    def test_update_forwards(self):
        backend, client = self._make()
        backend.update("m1", "new text")
        assert client.calls[0][1] == {"memory_id": "m1", "text": "new text"}

    def test_delete_forwards(self):
        backend, client = self._make()
        backend.delete("m1")
        assert client.calls[0][1] == {"memory_id": "m1"}
|
||||
|
||||
|
||||
class FakeOSSMemory:
    """Fake mem0.Memory for OSSBackend tests."""

    def __init__(self):
        # Every call is appended here as a tuple: (method name, *arguments).
        self.calls = []

    def _log(self, *entry):
        """Append one call record to ``calls``."""
        self.calls.append(entry)

    def search(self, query, **kwargs):
        """Record the search and return one canned scored hit."""
        self._log("search", query, kwargs)
        hit = {"id": "m1", "memory": "fact1", "score": 0.8}
        return {"results": [hit]}

    def get_all(self, **kwargs):
        """Record the listing call and return a results-only envelope."""
        self._log("get_all", kwargs)
        item = {"id": "m1", "memory": "fact1"}
        return {"results": [item]}

    def add(self, messages, **kwargs):
        """Record the add and return one canned ADD event."""
        self._log("add", messages, kwargs)
        event = {"id": "m1", "memory": "fact1", "event": "ADD"}
        return {"results": [event]}

    def update(self, memory_id, **kwargs):
        """Record the update and return a canned confirmation message."""
        self._log("update", memory_id, kwargs)
        return {"message": "Memory updated successfully!"}

    def delete(self, memory_id):
        """Record the delete and return a canned confirmation message."""
        self._log("delete", memory_id)
        return {"message": "Memory deleted successfully!"}
|
||||
|
||||
|
||||
class TestOSSBackend:
    """OSSBackend forwarding and normalisation checks, run against FakeOSSMemory."""

    def _make(self):
        """Return ``(backend, memory)`` with the fake attached directly.

        ``__new__`` is used so OSSBackend's ``__init__`` never runs.
        """
        memory = FakeOSSMemory()
        backend = OSSBackend.__new__(OSSBackend)
        backend._memory = memory
        return backend, memory

    def test_search_returns_list(self):
        backend, _ = self._make()
        hits = backend.search("test", filters={"user_id": "u1"})
        assert isinstance(hits, list)
        assert hits[0]["id"] == "m1"

    def test_search_passes_filters(self):
        backend, memory = self._make()
        backend.search("q", filters={"user_id": "u1"}, top_k=3)
        forwarded = memory.calls[0][2]
        assert forwarded["filters"] == {"user_id": "u1"}
        assert forwarded["top_k"] == 3

    def test_search_ignores_rerank(self):
        """OSS backend accepts rerank param but does not forward it to Memory."""
        backend, memory = self._make()
        backend.search("q", filters={}, rerank=True)
        assert "rerank" not in memory.calls[0][2]

    def test_get_all_ignores_pagination(self):
        """OSSBackend accepts page/page_size but does NOT forward to Memory.get_all()."""
        backend, memory = self._make()
        envelope = backend.get_all(filters={"user_id": "u1"}, page=2, page_size=50)
        forwarded = memory.calls[0][1]
        assert "page" not in forwarded
        assert "page_size" not in forwarded
        assert envelope["count"] == 1

    def test_get_all_returns_envelope(self):
        backend, _ = self._make()
        envelope = backend.get_all(filters={"user_id": "u1"})
        assert "results" in envelope
        assert "count" in envelope

    def test_add_forwards_kwargs(self):
        backend, memory = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(msgs, user_id="u1", agent_id="hermes", infer=False)
        forwarded = memory.calls[0][2]
        assert forwarded["user_id"] == "u1"
        assert forwarded["infer"] is False

    def test_update_maps_text_to_data(self):
        """OSS Memory.update uses `data=` param, not `text=`."""
        backend, memory = self._make()
        backend.update("m1", "new text")
        method, memory_id, kwargs = memory.calls[0]
        assert method == "update"
        assert memory_id == "m1"
        assert kwargs == {"data": "new text"}

    def test_delete_positional_arg(self):
        backend, memory = self._make()
        backend.delete("m1")
        assert memory.calls[0] == ("delete", "m1")

    def test_update_normalizes_response(self):
        backend, _ = self._make()
        reply = backend.update("m1", "text")
        assert reply == {"result": "Memory updated.", "memory_id": "m1"}

    def test_delete_normalizes_response(self):
        backend, _ = self._make()
        reply = backend.delete("m1")
        assert reply == {"result": "Memory deleted.", "memory_id": "m1"}
|
||||
107
tests/plugins/memory/test_mem0_providers.py
Normal file
107
tests/plugins/memory/test_mem0_providers.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
"""Tests for OSS provider definitions and validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from plugins.memory.mem0._oss_providers import (
|
||||
LLM_PROVIDERS,
|
||||
EMBEDDER_PROVIDERS,
|
||||
VECTOR_PROVIDERS,
|
||||
KNOWN_DIMS,
|
||||
validate_oss_config,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderDefinitions:
    """Shape checks on the OSS provider registries.

    Each failure message names the provider id, so a malformed registry
    entry can be found without re-running under a debugger.
    """

    @staticmethod
    def _assert_has_keys(registry, required):
        """Assert that every entry of *registry* contains all *required* keys."""
        for provider_id, spec in registry.items():
            missing = sorted(set(required) - set(spec))
            assert not missing, f"{provider_id} is missing {missing}"

    def test_llm_providers_have_required_keys(self):
        self._assert_has_keys(LLM_PROVIDERS, ("label", "needs_key", "default_model"))

    def test_embedder_providers_have_required_keys(self):
        self._assert_has_keys(
            EMBEDDER_PROVIDERS, ("label", "needs_key", "default_model", "dims")
        )

    def test_embedder_provider_ids(self):
        assert set(EMBEDDER_PROVIDERS.keys()) == {"openai", "ollama"}

    def test_vector_providers_have_required_keys(self):
        self._assert_has_keys(VECTOR_PROVIDERS, ("label", "default_config"))

    def test_vector_provider_ids(self):
        assert set(VECTOR_PROVIDERS.keys()) == {"qdrant", "pgvector"}

    def test_known_dims_covers_defaults(self):
        for provider_id, spec in EMBEDDER_PROVIDERS.items():
            assert spec["default_model"] in KNOWN_DIMS, provider_id
|
||||
|
||||
|
||||
class TestValidation:
    """validate_oss_config: accepted configs and the error each bad config yields."""

    @staticmethod
    def _section(provider, **config):
        """Build one ``{"provider": ..., "config": {...}}`` section."""
        return {"provider": provider, "config": config}

    def test_valid_openai_config(self):
        cfg = {
            "llm": self._section("openai", model="gpt-4o-mini"),
            "embedder": self._section("openai", model="text-embedding-3-small"),
            "vector_store": self._section("qdrant", path="/tmp/test"),
        }
        assert validate_oss_config(cfg) == []

    def test_unknown_llm_provider(self):
        cfg = {
            "llm": self._section("gemini"),
            "embedder": self._section("openai"),
            "vector_store": self._section("qdrant"),
        }
        errors = validate_oss_config(cfg)
        assert any("llm" in message.lower() for message in errors)

    def test_unknown_embedder_provider(self):
        cfg = {
            "llm": self._section("openai"),
            "embedder": self._section("cohere"),
            "vector_store": self._section("qdrant"),
        }
        errors = validate_oss_config(cfg)
        assert any("embedder" in message.lower() for message in errors)

    def test_unknown_vector_provider(self):
        cfg = {
            "llm": self._section("openai"),
            "embedder": self._section("openai"),
            "vector_store": self._section("redis"),
        }
        errors = validate_oss_config(cfg)
        assert any("vector" in message.lower() for message in errors)

    def test_missing_llm_section(self):
        cfg = {
            "embedder": self._section("openai"),
            "vector_store": self._section("qdrant"),
        }
        errors = validate_oss_config(cfg)
        assert any("llm" in message.lower() for message in errors)

    def test_pgvector_needs_user(self):
        cfg = {
            "llm": self._section("openai"),
            "embedder": self._section("openai"),
            "vector_store": self._section("pgvector", host="localhost"),
        }
        errors = validate_oss_config(cfg)
        assert any("user" in message.lower() for message in errors)

    def test_pgvector_with_user_valid(self):
        cfg = {
            "llm": self._section("openai"),
            "embedder": self._section("openai"),
            "vector_store": self._section("pgvector", host="localhost", user="pg"),
        }
        assert validate_oss_config(cfg) == []
|
||||
251
tests/plugins/memory/test_mem0_setup.py
Normal file
251
tests/plugins/memory/test_mem0_setup.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
"""Tests for Mem0 setup wizard — flag parsing, config building, validation."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from plugins.memory.mem0._setup import (
|
||||
parse_flags,
|
||||
build_oss_config,
|
||||
_write_env,
|
||||
post_setup,
|
||||
_check_qdrant_path,
|
||||
_check_ollama,
|
||||
_check_pgvector,
|
||||
)
|
||||
|
||||
|
||||
def _inject_fake_hermes_cli(monkeypatch):
    """Inject fake hermes_cli modules so yaml/curses aren't required.

    Registers stub ``hermes_cli``, ``hermes_cli.config`` and
    ``hermes_cli.memory_setup`` modules in ``sys.modules`` and patches the two
    prompt helpers on the mem0 setup module. Returns the stub config module.
    """
    config_stub = types.ModuleType("hermes_cli.config")
    config_stub.save_config = lambda c: None

    wizard_stub = types.ModuleType("hermes_cli.memory_setup")
    wizard_stub._curses_select = lambda *a, **kw: 0
    wizard_stub._prompt = lambda label, default=None, secret=False: default or ""

    package_stub = types.ModuleType("hermes_cli")
    package_stub.config = config_stub
    package_stub.memory_setup = wizard_stub

    for dotted_name, module in (
        ("hermes_cli", package_stub),
        ("hermes_cli.config", config_stub),
        ("hermes_cli.memory_setup", wizard_stub),
    ):
        monkeypatch.setitem(sys.modules, dotted_name, module)

    setup_module = "plugins.memory.mem0._setup"
    monkeypatch.setattr(f"{setup_module}._curses_select", lambda *a, **kw: 0)
    monkeypatch.setattr(
        f"{setup_module}._prompt",
        lambda label, default=None, secret=False: default or "",
    )
    return config_stub
|
||||
|
||||
|
||||
class TestParseFlags:
    """parse_flags: command-line tokens in, flag dictionary out."""

    def test_mode_platform(self):
        parsed = parse_flags(["--mode", "platform", "--api-key", "sk-test"])
        assert parsed["mode"] == "platform"
        assert parsed["api_key"] == "sk-test"

    def test_mode_oss_defaults(self):
        parsed = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        assert parsed["mode"] == "oss"
        assert parsed["oss_llm"] == "openai"
        assert parsed["oss_embedder"] == "openai"
        assert parsed["oss_vector"] == "qdrant"

    def test_mode_oss_all_flags(self):
        argv = [
            "--mode", "oss",
            "--oss-llm", "ollama",
            "--oss-llm-model", "llama3:latest",
            "--oss-embedder", "ollama",
            "--oss-embedder-model", "nomic-embed-text",
            "--oss-vector", "pgvector",
            "--oss-vector-host", "db.local",
            "--oss-vector-port", "5433",
            "--oss-vector-user", "pguser",
            "--oss-vector-password", "secret",
            "--oss-vector-dbname", "memdb",
            "--user-id", "my-user",
        ]
        parsed = parse_flags(argv)
        assert parsed["oss_llm"] == "ollama"
        assert parsed["oss_llm_model"] == "llama3:latest"
        assert parsed["oss_vector"] == "pgvector"
        assert parsed["oss_vector_user"] == "pguser"
        assert parsed["user_id"] == "my-user"

    def test_no_flags_returns_empty_mode(self):
        parsed = parse_flags([])
        assert parsed["mode"] == ""

    def test_oss_vector_path_flag(self):
        parsed = parse_flags(["--mode", "oss", "--oss-vector-path", "/data/qdrant"])
        assert parsed["oss_vector_path"] == "/data/qdrant"
|
||||
|
||||
|
||||
class TestBuildOSSConfig:
    """build_oss_config: parsed OSS flags in, ``(oss config, env writes)`` out."""

    @staticmethod
    def _build(*extra):
        """Parse ``--mode oss`` plus *extra* tokens and build the OSS config."""
        return build_oss_config(parse_flags(["--mode", "oss", *extra]))

    def test_openai_defaults(self):
        oss, env_writes = self._build("--oss-llm-key", "sk-oai")
        llm, embedder = oss["llm"], oss["embedder"]
        assert llm["provider"] == "openai"
        assert llm["config"]["model"] == "gpt-5-mini"
        assert embedder["provider"] == "openai"
        assert embedder["config"]["model"] == "text-embedding-3-small"
        assert oss["vector_store"]["provider"] == "qdrant"
        assert env_writes["OPENAI_API_KEY"] == "sk-oai"

    def test_ollama_no_key_needed(self):
        oss, env_writes = self._build("--oss-llm", "ollama", "--oss-embedder", "ollama")
        assert oss["llm"]["provider"] == "ollama"
        assert "model" in oss["llm"]["config"]
        assert env_writes == {}

    def test_embedder_reuses_llm_key(self):
        """When LLM and embedder share same provider, key written once."""
        _, env_writes = self._build("--oss-llm-key", "sk-oai")
        assert env_writes == {"OPENAI_API_KEY": "sk-oai"}

    def test_different_embedder_needs_separate_key(self):
        _, env_writes = self._build(
            "--oss-llm", "ollama",
            "--oss-embedder", "openai", "--oss-embedder-key", "sk-oai",
        )
        assert env_writes == {"OPENAI_API_KEY": "sk-oai"}

    def test_pgvector_config(self):
        oss, _ = self._build(
            "--oss-llm-key", "sk-oai",
            "--oss-vector", "pgvector",
            "--oss-vector-host", "db.local", "--oss-vector-port", "5433",
            "--oss-vector-user", "pg", "--oss-vector-dbname", "memdb",
        )
        store = oss["vector_store"]
        assert store["provider"] == "pgvector"
        assert store["config"]["host"] == "db.local"
        assert store["config"]["port"] == 5433
        assert store["config"]["user"] == "pg"

    def test_known_dims_auto_set(self):
        oss, _ = self._build("--oss-llm-key", "sk-oai")
        assert oss["embedder"]["config"].get("embedding_dims") == 1536

    def test_custom_qdrant_path(self):
        oss, _ = self._build("--oss-llm-key", "sk-oai", "--oss-vector-path", "/data/qdrant")
        assert oss["vector_store"]["config"]["path"] == "/data/qdrant"
|
||||
|
||||
|
||||
class TestWriteEnv:
|
||||
|
||||
def test_write_new_vars(self, tmp_path):
|
||||
env_path = tmp_path / ".env"
|
||||
_write_env(env_path, {"OPENAI_API_KEY": "sk-test"})
|
||||
content = env_path.read_text()
|
||||
assert "OPENAI_API_KEY=sk-test" in content
|
||||
|
||||
def test_update_existing_var(self, tmp_path):
|
||||
env_path = tmp_path / ".env"
|
||||
env_path.write_text("OPENAI_API_KEY=old\nOTHER=keep\n")
|
||||
_write_env(env_path, {"OPENAI_API_KEY": "new"})
|
||||
content = env_path.read_text()
|
||||
assert "OPENAI_API_KEY=new" in content
|
||||
assert "OTHER=keep" in content
|
||||
assert "old" not in content
|
||||
|
||||
|
||||
class TestPostSetup:
|
||||
|
||||
def test_platform_flag_mode(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("sys.argv", ["hermes", "--mode", "platform", "--api-key", "sk-test"])
|
||||
monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
|
||||
_inject_fake_hermes_cli(monkeypatch)
|
||||
config = {"memory": {}}
|
||||
post_setup(str(tmp_path), config)
|
||||
assert config["memory"]["provider"] == "mem0"
|
||||
env_content = (tmp_path / ".env").read_text()
|
||||
assert "MEM0_API_KEY=sk-test" in env_content
|
||||
mem0_json = json.loads((tmp_path / "mem0.json").read_text())
|
||||
assert mem0_json["mode"] == "platform"
|
||||
|
||||
def test_oss_flag_mode(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("sys.argv", [
|
||||
"hermes", "--mode", "oss", "--oss-llm-key", "sk-oai",
|
||||
])
|
||||
monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
|
||||
_inject_fake_hermes_cli(monkeypatch)
|
||||
monkeypatch.setattr("plugins.memory.mem0._setup._install_provider_deps", lambda l, e, v: None)
|
||||
config = {"memory": {}}
|
||||
post_setup(str(tmp_path), config)
|
||||
assert config["memory"]["provider"] == "mem0"
|
||||
mem0_json = json.loads((tmp_path / "mem0.json").read_text())
|
||||
assert mem0_json["mode"] == "oss"
|
||||
assert mem0_json["oss"]["llm"]["provider"] == "openai"
|
||||
|
||||
|
||||
class TestDryRun:
|
||||
|
||||
def test_dry_run_flag_parsed(self):
|
||||
flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai", "--dry-run"])
|
||||
assert flags["dry_run"] is True
|
||||
|
||||
def test_dry_run_not_set_by_default(self):
|
||||
flags = parse_flags(["--mode", "oss"])
|
||||
assert flags["dry_run"] is False
|
||||
|
||||
def test_dry_run_platform_no_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("sys.argv", ["hermes", "--mode", "platform", "--api-key", "sk-test", "--dry-run"])
|
||||
monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
|
||||
_inject_fake_hermes_cli(monkeypatch)
|
||||
config = {"memory": {}}
|
||||
post_setup(str(tmp_path), config)
|
||||
assert not (tmp_path / ".env").exists()
|
||||
assert not (tmp_path / "mem0.json").exists()
|
||||
assert "provider" not in config["memory"]
|
||||
|
||||
def test_dry_run_oss_no_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("sys.argv", [
|
||||
"hermes", "--mode", "oss", "--oss-llm-key", "sk-oai", "--dry-run",
|
||||
])
|
||||
monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
|
||||
_inject_fake_hermes_cli(monkeypatch)
|
||||
monkeypatch.setattr("plugins.memory.mem0._setup._install_provider_deps", lambda l, e, v: None)
|
||||
config = {"memory": {}}
|
||||
post_setup(str(tmp_path), config)
|
||||
assert not (tmp_path / ".env").exists()
|
||||
assert not (tmp_path / "mem0.json").exists()
|
||||
assert "provider" not in config["memory"]
|
||||
|
||||
|
||||
class TestConnectivityChecks:
|
||||
|
||||
def test_qdrant_path_writable(self, tmp_path):
|
||||
ok, msg = _check_qdrant_path(str(tmp_path / "qdrant"))
|
||||
assert ok is True
|
||||
|
||||
def test_qdrant_path_not_writable(self, tmp_path, monkeypatch):
|
||||
def _raise_oserror(*a, **kw):
|
||||
raise OSError("Permission denied")
|
||||
monkeypatch.setattr(Path, "mkdir", _raise_oserror)
|
||||
ok, msg = _check_qdrant_path(str(tmp_path / "qdrant"))
|
||||
assert ok is False
|
||||
assert "Permission denied" in msg
|
||||
|
||||
def test_ollama_unreachable(self):
|
||||
ok, msg = _check_ollama("http://localhost:1")
|
||||
assert ok is False
|
||||
|
||||
def test_pgvector_unreachable(self):
|
||||
ok, msg = _check_pgvector("localhost", 1)
|
||||
assert ok is False
|
||||
|
|
@ -1,241 +0,0 @@
|
|||
"""Tests for Mem0 API v2 compatibility — filters param and dict response unwrapping.
|
||||
|
||||
Salvaged from PRs #5301 (qaqcvc) and #5117 (vvvanguards).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
|
||||
import pytest
|
||||
|
||||
from plugins.memory.mem0 import Mem0MemoryProvider
|
||||
|
||||
|
||||
class FakeClientV2:
|
||||
"""Fake Mem0 client that returns v2-style dict responses and captures call kwargs."""
|
||||
|
||||
def __init__(self, search_results=None, all_results=None):
|
||||
self._search_results = search_results or {"results": []}
|
||||
self._all_results = all_results or {"results": []}
|
||||
self.captured_search = {}
|
||||
self.captured_get_all = {}
|
||||
self.captured_add = []
|
||||
|
||||
def search(self, **kwargs):
|
||||
self.captured_search = kwargs
|
||||
return self._search_results
|
||||
|
||||
def get_all(self, **kwargs):
|
||||
self.captured_get_all = kwargs
|
||||
return self._all_results
|
||||
|
||||
def add(self, messages, **kwargs):
|
||||
self.captured_add.append({"messages": messages, **kwargs})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter migration: bare user_id= -> filters={}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMem0FiltersV2:
|
||||
"""All API calls must use filters={} instead of bare user_id= kwargs."""
|
||||
|
||||
def _make_provider(self, monkeypatch, client):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test-session")
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
monkeypatch.setattr(provider, "_get_client", lambda: client)
|
||||
return provider
|
||||
|
||||
def test_search_uses_filters(self, monkeypatch):
|
||||
client = FakeClientV2()
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
provider.handle_tool_call("mem0_search", {"query": "hello", "top_k": 3, "rerank": False})
|
||||
|
||||
assert client.captured_search["query"] == "hello"
|
||||
assert client.captured_search["top_k"] == 3
|
||||
assert client.captured_search["rerank"] is False
|
||||
assert client.captured_search["filters"] == {"user_id": "u123"}
|
||||
# Must NOT have bare user_id kwarg
|
||||
assert "user_id" not in {k for k in client.captured_search if k != "filters"}
|
||||
|
||||
def test_profile_uses_filters(self, monkeypatch):
|
||||
client = FakeClientV2()
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
provider.handle_tool_call("mem0_profile", {})
|
||||
|
||||
assert client.captured_get_all["filters"] == {"user_id": "u123"}
|
||||
assert "user_id" not in {k for k in client.captured_get_all if k != "filters"}
|
||||
|
||||
def test_prefetch_uses_filters(self, monkeypatch):
|
||||
client = FakeClientV2()
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
provider.queue_prefetch("hello")
|
||||
provider._prefetch_thread.join(timeout=2)
|
||||
|
||||
assert client.captured_search["query"] == "hello"
|
||||
assert client.captured_search["filters"] == {"user_id": "u123"}
|
||||
assert "user_id" not in {k for k in client.captured_search if k != "filters"}
|
||||
|
||||
def test_sync_turn_uses_write_filters(self, monkeypatch):
|
||||
client = FakeClientV2()
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
provider.sync_turn("user said this", "assistant replied", session_id="s1")
|
||||
provider._sync_thread.join(timeout=2)
|
||||
|
||||
assert len(client.captured_add) == 1
|
||||
call = client.captured_add[0]
|
||||
assert call["user_id"] == "u123"
|
||||
assert call["agent_id"] == "hermes"
|
||||
|
||||
def test_conclude_uses_write_filters(self, monkeypatch):
|
||||
client = FakeClientV2()
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
provider.handle_tool_call("mem0_conclude", {"conclusion": "user likes dark mode"})
|
||||
|
||||
assert len(client.captured_add) == 1
|
||||
call = client.captured_add[0]
|
||||
assert call["user_id"] == "u123"
|
||||
assert call["agent_id"] == "hermes"
|
||||
assert call["infer"] is False
|
||||
|
||||
def test_read_filters_no_agent_id(self):
|
||||
"""Read filters should use user_id only — cross-session recall across agents."""
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
assert provider._read_filters() == {"user_id": "u123"}
|
||||
|
||||
def test_write_filters_include_agent_id(self):
|
||||
"""Write filters should include agent_id for attribution."""
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
assert provider._write_filters() == {"user_id": "u123", "agent_id": "hermes"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dict response unwrapping (API v2 wraps in {"results": [...]})
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMem0ResponseUnwrapping:
|
||||
"""API v2 returns {"results": [...]} dicts; we must extract the list."""
|
||||
|
||||
def _make_provider(self, monkeypatch, client):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test-session")
|
||||
monkeypatch.setattr(provider, "_get_client", lambda: client)
|
||||
return provider
|
||||
|
||||
def test_profile_dict_response(self, monkeypatch):
|
||||
client = FakeClientV2(all_results={"results": [{"memory": "alpha"}, {"memory": "beta"}]})
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
result = json.loads(provider.handle_tool_call("mem0_profile", {}))
|
||||
|
||||
assert result["count"] == 2
|
||||
assert "alpha" in result["result"]
|
||||
assert "beta" in result["result"]
|
||||
|
||||
def test_profile_list_response_backward_compat(self, monkeypatch):
|
||||
"""Old API returned bare lists — still works."""
|
||||
client = FakeClientV2(all_results=[{"memory": "gamma"}])
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
result = json.loads(provider.handle_tool_call("mem0_profile", {}))
|
||||
assert result["count"] == 1
|
||||
assert "gamma" in result["result"]
|
||||
|
||||
def test_search_dict_response(self, monkeypatch):
|
||||
client = FakeClientV2(search_results={
|
||||
"results": [{"memory": "foo", "score": 0.9}, {"memory": "bar", "score": 0.7}]
|
||||
})
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
result = json.loads(provider.handle_tool_call(
|
||||
"mem0_search", {"query": "test", "top_k": 5}
|
||||
))
|
||||
|
||||
assert result["count"] == 2
|
||||
assert result["results"][0]["memory"] == "foo"
|
||||
|
||||
def test_search_list_response_backward_compat(self, monkeypatch):
|
||||
"""Old API returned bare lists — still works."""
|
||||
client = FakeClientV2(search_results=[{"memory": "baz", "score": 0.8}])
|
||||
provider = self._make_provider(monkeypatch, client)
|
||||
|
||||
result = json.loads(provider.handle_tool_call(
|
||||
"mem0_search", {"query": "test"}
|
||||
))
|
||||
assert result["count"] == 1
|
||||
|
||||
def test_unwrap_results_edge_cases(self):
|
||||
"""_unwrap_results handles all shapes gracefully."""
|
||||
assert Mem0MemoryProvider._unwrap_results({"results": [1, 2]}) == [1, 2]
|
||||
assert Mem0MemoryProvider._unwrap_results([3, 4]) == [3, 4]
|
||||
assert Mem0MemoryProvider._unwrap_results({}) == []
|
||||
assert Mem0MemoryProvider._unwrap_results(None) == []
|
||||
assert Mem0MemoryProvider._unwrap_results("unexpected") == []
|
||||
|
||||
def test_prefetch_dict_response(self, monkeypatch):
|
||||
client = FakeClientV2(search_results={
|
||||
"results": [{"memory": "user prefers dark mode"}]
|
||||
})
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test-session")
|
||||
monkeypatch.setattr(provider, "_get_client", lambda: client)
|
||||
|
||||
provider.queue_prefetch("preferences")
|
||||
provider._prefetch_thread.join(timeout=2)
|
||||
result = provider.prefetch("preferences")
|
||||
|
||||
assert "dark mode" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits not enforced on Windows")
|
||||
def test_save_config_sets_owner_only_permissions(tmp_path):
|
||||
"""mem0.json must be written with 0o600 so API key is not world-readable."""
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.save_config({"api_key": "m0-test-key"}, str(tmp_path))
|
||||
config_file = tmp_path / "mem0.json"
|
||||
assert config_file.exists()
|
||||
mode = stat.S_IMODE(config_file.stat().st_mode)
|
||||
assert mode == 0o600, f"Expected 0o600 (owner-only), got {oct(mode)}"
|
||||
|
||||
|
||||
class TestMem0Defaults:
|
||||
"""Ensure we don't break existing users' defaults."""
|
||||
|
||||
def test_default_user_id_hermes_user(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("MEM0_API_KEY", "test-key")
|
||||
monkeypatch.delenv("MEM0_USER_ID", raising=False)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test")
|
||||
|
||||
assert provider._user_id == "hermes-user"
|
||||
|
||||
def test_default_agent_id_hermes(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("MEM0_API_KEY", "test-key")
|
||||
monkeypatch.delenv("MEM0_AGENT_ID", raising=False)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test")
|
||||
|
||||
assert provider._agent_id == "hermes"
|
||||
463
tests/plugins/memory/test_mem0_v3.py
Normal file
463
tests/plugins/memory/test_mem0_v3.py
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
"""Tests for Mem0 v3 API — new tool names, paginated responses, update/delete tools."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from plugins.memory.mem0 import Mem0MemoryProvider
|
||||
|
||||
|
||||
class FakeBackend:
|
||||
"""Fake Mem0Backend for provider-level tests."""
|
||||
|
||||
def __init__(self, search_results=None, all_results=None):
|
||||
self._search_results = search_results or []
|
||||
self._all_results = all_results or {"results": [], "count": 0}
|
||||
self.captured = []
|
||||
|
||||
def search(self, query, *, filters, top_k=10, rerank=True):
|
||||
self.captured.append(("search", query, {"filters": filters, "top_k": top_k, "rerank": rerank}))
|
||||
return self._search_results
|
||||
|
||||
def get_all(self, *, filters, page=1, page_size=100):
|
||||
self.captured.append(("get_all", {"filters": filters, "page": page, "page_size": page_size}))
|
||||
return self._all_results
|
||||
|
||||
def add(self, messages, *, user_id, agent_id, infer=False, metadata=None):
|
||||
self.captured.append((
|
||||
"add",
|
||||
messages,
|
||||
{"user_id": user_id, "agent_id": agent_id, "infer": infer, "metadata": metadata},
|
||||
))
|
||||
return {"status": "PENDING", "event_id": "evt-test-123"}
|
||||
|
||||
def update(self, memory_id, text):
|
||||
self.captured.append(("update", memory_id, text))
|
||||
return {"result": "Memory updated.", "memory_id": memory_id}
|
||||
|
||||
def delete(self, memory_id):
|
||||
self.captured.append(("delete", memory_id))
|
||||
return {"result": "Memory deleted.", "memory_id": memory_id}
|
||||
|
||||
|
||||
class TestMem0V3Tools:
|
||||
"""Test v3 tool names and response handling."""
|
||||
|
||||
def _make_provider(self, monkeypatch, backend):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test-session")
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
provider._backend = backend
|
||||
return provider
|
||||
|
||||
def test_list_returns_paginated_with_ids(self, monkeypatch):
|
||||
backend = FakeBackend(all_results={
|
||||
"count": 2,
|
||||
"results": [
|
||||
{"id": "mem-1", "memory": "alpha"},
|
||||
{"id": "mem-2", "memory": "beta"},
|
||||
]
|
||||
})
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_list", {}))
|
||||
assert result["count"] == 2
|
||||
assert result["results"][0]["id"] == "mem-1"
|
||||
assert result["results"][0]["memory"] == "alpha"
|
||||
|
||||
def test_list_pagination_params(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
provider.handle_tool_call("mem0_list", {"page": 2, "page_size": 50})
|
||||
assert backend.captured[0][1]["page"] == 2
|
||||
assert backend.captured[0][1]["page_size"] == 50
|
||||
|
||||
def test_list_empty(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_list", {}))
|
||||
assert result["result"] == "No memories stored yet."
|
||||
|
||||
def test_search_returns_ids(self, monkeypatch):
|
||||
backend = FakeBackend(search_results=[{"id": "mem-1", "memory": "foo", "score": 0.9}])
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_search", {"query": "test"}))
|
||||
assert result["results"][0]["id"] == "mem-1"
|
||||
|
||||
def test_search_uses_filters(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
provider.handle_tool_call("mem0_search", {"query": "hello", "top_k": 3})
|
||||
assert backend.captured[0][2]["filters"] == {"user_id": "u123"}
|
||||
assert backend.captured[0][2]["top_k"] == 3
|
||||
|
||||
def test_search_rerank_default_true(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
provider.handle_tool_call("mem0_search", {"query": "test"})
|
||||
assert backend.captured[0][2]["rerank"] is True
|
||||
|
||||
def test_search_rerank_override_false(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
provider.handle_tool_call("mem0_search", {"query": "test", "rerank": False})
|
||||
assert backend.captured[0][2]["rerank"] is False
|
||||
|
||||
def test_add_uses_content_param(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_add", {"content": "user likes dark mode"}))
|
||||
assert len(backend.captured) == 1
|
||||
call = backend.captured[0]
|
||||
assert call[2]["infer"] is False
|
||||
assert call[2]["user_id"] == "u123"
|
||||
assert call[2]["agent_id"] == "hermes"
|
||||
assert "event_id" in result
|
||||
|
||||
def test_add_returns_event_id(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_add", {"content": "test"}))
|
||||
assert result["event_id"] == "evt-test-123"
|
||||
|
||||
def test_add_missing_content(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_add", {}))
|
||||
assert "error" in result
|
||||
|
||||
def test_old_tool_names_return_unknown(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_profile", {}))
|
||||
assert "error" in result
|
||||
result = json.loads(provider.handle_tool_call("mem0_conclude", {}))
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestMem0UpdateDelete:
|
||||
|
||||
def _make_provider(self, monkeypatch, backend):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test-session")
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
provider._backend = backend
|
||||
return provider
|
||||
|
||||
def test_update_calls_sdk(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call(
|
||||
"mem0_update", {"memory_id": "mem-1", "text": "updated fact"}
|
||||
))
|
||||
assert backend.captured[0][1] == "mem-1"
|
||||
assert backend.captured[0][2] == "updated fact"
|
||||
assert result["result"] == "Memory updated."
|
||||
assert result["memory_id"] == "mem-1"
|
||||
|
||||
def test_update_missing_memory_id(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_update", {"text": "no id"}))
|
||||
assert "error" in result
|
||||
|
||||
def test_update_missing_text(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_update", {"memory_id": "mem-1"}))
|
||||
assert "error" in result
|
||||
|
||||
def test_delete_calls_sdk(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call(
|
||||
"mem0_delete", {"memory_id": "mem-1"}
|
||||
))
|
||||
assert backend.captured[0][1] == "mem-1"
|
||||
assert result["result"] == "Memory deleted."
|
||||
|
||||
def test_delete_missing_memory_id(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_delete", {}))
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestMem0ErrorHandling:
|
||||
|
||||
def _make_provider(self, monkeypatch, backend):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test-session")
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
provider._backend = backend
|
||||
return provider
|
||||
|
||||
def test_update_404_no_circuit_breaker(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
backend.update = lambda mid, text: (_ for _ in ()).throw(Exception("404 Not Found"))
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call(
|
||||
"mem0_update", {"memory_id": "bad-id", "text": "x"}
|
||||
))
|
||||
assert "error" in result
|
||||
assert provider._consecutive_failures == 0
|
||||
|
||||
def test_delete_404_no_circuit_breaker(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
backend.delete = lambda mid: (_ for _ in ()).throw(Exception("404 not found"))
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call(
|
||||
"mem0_delete", {"memory_id": "bad-id"}
|
||||
))
|
||||
assert "error" in result
|
||||
assert provider._consecutive_failures == 0
|
||||
|
||||
def test_update_validation_error_no_circuit_breaker(self, monkeypatch):
|
||||
"""ValidationError (bad UUID format) should not trip circuit breaker."""
|
||||
class ValidationError(Exception):
|
||||
pass
|
||||
backend = FakeBackend()
|
||||
backend.update = lambda mid, text: (_ for _ in ()).throw(
|
||||
ValidationError('{"error":"memory_id should be a valid UUID"}')
|
||||
)
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call(
|
||||
"mem0_update", {"memory_id": "not-a-uuid", "text": "x"}
|
||||
))
|
||||
assert "error" in result
|
||||
assert provider._consecutive_failures == 0
|
||||
|
||||
def test_delete_validation_error_no_circuit_breaker(self, monkeypatch):
|
||||
class ValidationError(Exception):
|
||||
pass
|
||||
backend = FakeBackend()
|
||||
backend.delete = lambda mid: (_ for _ in ()).throw(
|
||||
ValidationError('{"error":"memory_id should be a valid UUID"}')
|
||||
)
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call(
|
||||
"mem0_delete", {"memory_id": "not-a-uuid"}
|
||||
))
|
||||
assert "error" in result
|
||||
assert provider._consecutive_failures == 0
|
||||
|
||||
def test_update_5xx_trips_circuit_breaker(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
backend.update = lambda mid, text: (_ for _ in ()).throw(Exception("500 Internal Server Error"))
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
provider.handle_tool_call("mem0_update", {"memory_id": "mem-1", "text": "x"})
|
||||
assert provider._consecutive_failures == 1
|
||||
|
||||
|
||||
class TestMem0V3Internal:
|
||||
|
||||
def _make_provider(self, monkeypatch, backend):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test-session")
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
provider._backend = backend
|
||||
return provider
|
||||
|
||||
def test_sync_turn_explicit_kwargs(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
provider.sync_turn("user said", "assistant replied", session_id="s1")
|
||||
provider._sync_thread.join(timeout=2)
|
||||
assert len(backend.captured) == 1
|
||||
call = backend.captured[0]
|
||||
assert call[2]["user_id"] == "u123"
|
||||
assert call[2]["agent_id"] == "hermes"
|
||||
assert call[2]["infer"] is True
|
||||
|
||||
def test_old_tool_names_return_unknown(self, monkeypatch):
|
||||
backend = FakeBackend()
|
||||
provider = self._make_provider(monkeypatch, backend)
|
||||
result = json.loads(provider.handle_tool_call("mem0_profile", {}))
|
||||
assert "error" in result
|
||||
result = json.loads(provider.handle_tool_call("mem0_conclude", {}))
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestMem0V3Config:
|
||||
|
||||
def test_tool_schemas_five_tools(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
schemas = provider.get_tool_schemas()
|
||||
names = [s["name"] for s in schemas]
|
||||
assert names == ["mem0_list", "mem0_search", "mem0_add", "mem0_update", "mem0_delete"]
|
||||
|
||||
def test_system_prompt_new_tool_names(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "test"
|
||||
block = provider.system_prompt_block()
|
||||
assert "mem0_search" in block
|
||||
assert "mem0_add" in block
|
||||
assert "mem0_list" in block
|
||||
assert "mem0_update" in block
|
||||
assert "mem0_delete" in block
|
||||
assert "mem0_profile" not in block
|
||||
assert "mem0_conclude" not in block
|
||||
|
||||
def test_system_prompt_shows_platform_mode(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "test"
|
||||
provider._mode = "platform"
|
||||
block = provider.system_prompt_block()
|
||||
assert "platform" in block
|
||||
assert "Rerank" in block
|
||||
|
||||
def test_system_prompt_shows_oss_mode(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "test"
|
||||
provider._mode = "oss"
|
||||
block = provider.system_prompt_block()
|
||||
assert "OSS" in block
|
||||
assert "Rerank" not in block
|
||||
|
||||
def test_search_schema_has_rerank(self):
|
||||
"""rerank property available in SEARCH_SCHEMA for platform mode."""
|
||||
provider = Mem0MemoryProvider()
|
||||
schemas = provider.get_tool_schemas()
|
||||
search = next(s for s in schemas if s["name"] == "mem0_search")
|
||||
assert "rerank" in search["parameters"]["properties"]
|
||||
assert search["parameters"]["properties"]["rerank"]["type"] == "boolean"
|
||||
|
||||
|
||||
class TestMem0ModeSwitch:
|
||||
|
||||
def test_default_mode_is_platform(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("MEM0_API_KEY", "test-key")
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test")
|
||||
assert provider._mode == "platform"
|
||||
|
||||
def test_missing_mode_key_defaults_platform(self, monkeypatch, tmp_path):
|
||||
"""Backward compat: old mem0.json without mode key works."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config_path = tmp_path / "mem0.json"
|
||||
config_path.write_text('{"user_id": "old-user"}')
|
||||
monkeypatch.setenv("MEM0_API_KEY", "test-key")
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test")
|
||||
assert provider._mode == "platform"
|
||||
assert provider._user_id == "old-user"
|
||||
|
||||
def test_is_available_platform_needs_key(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.delenv("MEM0_API_KEY", raising=False)
|
||||
provider = Mem0MemoryProvider()
|
||||
assert provider.is_available() is False
|
||||
|
||||
def test_is_available_oss_needs_vector(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config_path = tmp_path / "mem0.json"
|
||||
config_path.write_text('{"mode": "oss", "oss": {"vector_store": {"provider": "qdrant"}}}')
|
||||
provider = Mem0MemoryProvider()
|
||||
assert provider.is_available() is True
|
||||
|
||||
def test_is_available_oss_no_vector(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config_path = tmp_path / "mem0.json"
|
||||
config_path.write_text('{"mode": "oss", "oss": {}}')
|
||||
provider = Mem0MemoryProvider()
|
||||
assert provider.is_available() is False
|
||||
|
||||
def test_tool_schemas_unchanged(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
schemas = provider.get_tool_schemas()
|
||||
names = [s["name"] for s in schemas]
|
||||
assert names == ["mem0_list", "mem0_search", "mem0_add", "mem0_update", "mem0_delete"]
|
||||
|
||||
def test_system_prompt_includes_mode(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "test"
|
||||
provider._mode = "oss"
|
||||
block = provider.system_prompt_block()
|
||||
assert "mem0_search" in block
|
||||
assert "mem0_list" in block
|
||||
assert "OSS" in block
|
||||
|
||||
|
||||
class TestMem0UserIdResolution:
    """Pin the precedence used to resolve the mem0 ``user_id``.

    A configured override (``MEM0_USER_ID`` env var or ``user_id`` in
    ``mem0.json``) beats the id handed in by the gateway; the gateway id is
    used when nothing is configured; ``"hermes-user"`` is the last resort.
    The intent is that one person reaching the agent through several
    channels can share a single memory store by configuring one id.
    """

    def _provider(self, monkeypatch, tmp_path):
        """Return a provider rooted at ``tmp_path`` whose backend factory is a no-op."""
        for name, value in (("HERMES_HOME", str(tmp_path)), ("MEM0_API_KEY", "test-key")):
            monkeypatch.setenv(name, value)
        instance = Mem0MemoryProvider()
        # Only identity resolution is under test, so no backend is ever built.
        instance._create_backend = lambda: None  # type: ignore[method-assign]
        return instance

    def test_env_override_beats_gateway_native_id(self, monkeypatch, tmp_path):
        """``MEM0_USER_ID`` wins over the ``user_id`` passed to ``initialize``."""
        monkeypatch.setenv("MEM0_USER_ID", "ryan@example.com")
        mem0 = self._provider(monkeypatch, tmp_path)

        mem0.initialize("test", user_id="123456789", platform="telegram")

        assert mem0._user_id == "ryan@example.com"

    def test_file_override_beats_gateway_native_id(self, monkeypatch, tmp_path):
        """A ``user_id`` stored in ``mem0.json`` wins over the gateway id."""
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        config_file = tmp_path / "mem0.json"
        config_file.write_text('{"user_id": "ryan@example.com"}')
        mem0 = self._provider(monkeypatch, tmp_path)

        mem0.initialize("test", user_id="123456789", platform="telegram")

        assert mem0._user_id == "ryan@example.com"

    def test_unset_falls_back_to_gateway_native_id(self, monkeypatch, tmp_path):
        """With no override configured, the gateway-supplied id is used."""
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        mem0 = self._provider(monkeypatch, tmp_path)

        mem0.initialize("test", user_id="123456789", platform="telegram")

        assert mem0._user_id == "123456789"

    def test_unset_and_no_kwargs_falls_back_to_default(self, monkeypatch, tmp_path):
        """With no override and no gateway id, the placeholder id is used."""
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        mem0 = self._provider(monkeypatch, tmp_path)

        mem0.initialize("test")

        assert mem0._user_id == "hermes-user"

    def test_legacy_placeholder_in_config_does_not_override_kwargs(self, monkeypatch, tmp_path):
        """A config that still holds the placeholder id counts as "not configured"."""
        # Older setup-wizard runs wrote {"user_id": "hermes-user"} as the
        # suggested default.  Honouring it would give every gateway user the
        # same id, so the gateway-native id must still win here.
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        config_file = tmp_path / "mem0.json"
        config_file.write_text('{"user_id": "hermes-user"}')
        mem0 = self._provider(monkeypatch, tmp_path)

        mem0.initialize("test", user_id="123456789", platform="telegram")

        assert mem0._user_id == "123456789"
|
||||
|
||||
|
||||
class TestMem0WriteMetadata:
    """Every write must carry ``metadata.channel``.

    Tagging writes with the channel allows per-channel filtered views while
    keeping the user identity itself independent of the channel.
    """

    def _make_provider(self, channel: str = "cli"):
        """Return a provider wired to a ``FakeBackend`` and the given channel."""
        instance = Mem0MemoryProvider()
        for attr, value in (
            ("_user_id", "u123"),
            ("_agent_id", "hermes"),
            ("_channel", channel),
            ("_backend", FakeBackend()),
        ):
            setattr(instance, attr, value)
        return instance

    def test_add_tool_passes_channel_metadata(self):
        """The ``mem0_add`` tool forwards the provider's channel as metadata."""
        mem0 = self._make_provider("telegram")

        mem0.handle_tool_call("mem0_add", {"content": "user likes dark mode"})

        last_call = mem0._backend.captured[-1]
        assert last_call[2]["metadata"] == {"channel": "telegram"}

    def test_sync_turn_passes_channel_metadata(self):
        """``sync_turn`` produces an ``add`` call tagged with the channel."""
        mem0 = self._make_provider("discord")

        mem0.sync_turn("hi", "hello", session_id="s")

        # The sync runs on a background thread; wait for it before inspecting.
        worker = mem0._sync_thread
        if worker:
            worker.join(timeout=5.0)

        add_calls = []
        for recorded in mem0._backend.captured:
            if recorded[0] == "add":
                add_calls.append(recorded)

        assert add_calls, "expected an add call from sync_turn"
        assert add_calls[-1][2]["metadata"] == {"channel": "discord"}
|
||||
|
|
@ -1459,6 +1459,137 @@ def test_tool_add_resource_sends_git_remote_sources_as_path(url):
|
|||
})
|
||||
|
||||
|
||||
def test_get_tool_schemas_includes_narrow_forget_tool():
    """``viking_forget`` must be among the tool names the provider advertises."""
    schemas = OpenVikingMemoryProvider().get_tool_schemas()

    advertised = []
    for entry in schemas:
        advertised.append(entry["name"])

    assert "viking_forget" in advertised
|
||||
|
||||
|
||||
def test_handle_tool_call_forget_deletes_exact_memory_file_uri():
    """Forgetting a peer-scoped memory file issues one non-recursive delete for that URI."""
    uri = "viking://user/peers/hermes/memories/preferences/mem_abc123.md"
    server_reply = {
        "status": "ok",
        "result": {"uri": uri, "estimated_deleted_count": 1},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._client.delete.return_value = server_reply

    raw = provider.handle_tool_call("viking_forget", {"uri": uri})
    result = json.loads(raw)

    expected_params = {"uri": uri, "recursive": False}
    provider._client.delete.assert_called_once_with("/api/v1/fs", params=expected_params)
    assert result == {"status": "deleted", "uri": uri, "estimated_deleted_count": 1}
|
||||
|
||||
|
||||
def test_handle_tool_call_forget_deletes_exact_memory_file_under_memories_root():
    """A file sitting directly under ``memories/`` is deleted the same way."""
    uri = "viking://user/default/memories/profile.md"
    server_reply = {
        "status": "ok",
        "result": {"uri": uri, "estimated_deleted_count": 1},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._client.delete.return_value = server_reply

    raw = provider.handle_tool_call("viking_forget", {"uri": uri})
    result = json.loads(raw)

    expected_params = {"uri": uri, "recursive": False}
    provider._client.delete.assert_called_once_with("/api/v1/fs", params=expected_params)
    assert result == {"status": "deleted", "uri": uri, "estimated_deleted_count": 1}
|
||||
|
||||
|
||||
def test_handle_tool_call_forget_allows_non_generated_dot_md_memory_file():
    """A dot-prefixed ``.full.md`` memory file is accepted and deleted."""
    uri = "viking://user/default/memories/preferences/.full.md"
    server_reply = {
        "status": "ok",
        "result": {"uri": uri, "estimated_deleted_count": 1},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._client.delete.return_value = server_reply

    raw = provider.handle_tool_call("viking_forget", {"uri": uri})
    result = json.loads(raw)

    expected_params = {"uri": uri, "recursive": False}
    provider._client.delete.assert_called_once_with("/api/v1/fs", params=expected_params)
    assert result == {"status": "deleted", "uri": uri, "estimated_deleted_count": 1}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "uri",
    (
        # Empty, foreign-scheme and single-slash inputs.
        "",
        "https://example.com/mem.md",
        "viking:/user/memories/preferences/mem_abc123.md",
        # URIs whose first path segment is not ``user``.
        "viking://resources/project/doc.md",
        "viking://resources/project/memories/mem_abc123.md",
        "viking://memories/preferences/mem_abc123.md",
        "viking://agent/hermes/memories/preferences/mem_abc123.md",
        # ``user`` URIs outside a ``memories`` tree.
        "viking://user/skills/example/SKILL.md",
        "viking://user/sessions/session-1/messages.jsonl",
        # A directory, two specific dot-files, and a URI with a query string.
        "viking://user/memories/preferences/",
        "viking://user/memories/preferences/.overview.md",
        "viking://user/memories/preferences/.abstract.md",
        "viking://user/memories/preferences/mem_abc123.md?recursive=true",
    ),
)
def test_handle_tool_call_forget_rejects_non_memory_file_uris(uri):
    """Each listed URI must produce an error payload and no delete call."""
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()

    payload = json.loads(provider.handle_tool_call("viking_forget", {"uri": uri}))

    assert "error" in payload
    provider._client.delete.assert_not_called()
|
||||
|
||||
|
||||
def test_viking_client_delete_uses_identity_headers(monkeypatch):
    """``_VikingClient.delete`` hits the joined URL with params and identity headers."""
    memory_uri = "viking://user/memories/x.md"
    client = _VikingClient(
        "https://example.com",
        api_key="test-key",
        account="acct",
        user="alice",
        agent="hermes",
    )
    recorded = {}

    def capture_delete(url, **kwargs):
        # Remember what the client sent and answer with a canned OK response.
        recorded.update(url=url, kwargs=kwargs)
        return SimpleNamespace(
            status_code=200,
            text="",
            json=lambda: {"status": "ok", "result": {"uri": memory_uri}},
            raise_for_status=lambda: None,
        )

    monkeypatch.setattr(client._httpx, "delete", capture_delete)

    reply = client.delete("/api/v1/fs", params={"uri": memory_uri})

    assert reply == {"status": "ok", "result": {"uri": memory_uri}}
    assert recorded["url"] == "https://example.com/api/v1/fs"
    sent = recorded["kwargs"]
    assert sent["params"] == {"uri": memory_uri}
    assert sent["headers"]["Authorization"] == "Bearer test-key"
    assert sent["headers"]["X-OpenViking-Actor-Peer"] == "hermes"
|
||||
|
||||
|
||||
def test_viking_client_upload_temp_file_uses_multipart_identity_headers(tmp_path, monkeypatch):
|
||||
sample = tmp_path / "sample.md"
|
||||
sample.write_text("# Local resource\n", encoding="utf-8")
|
||||
|
|
@ -2637,6 +2768,94 @@ def test_on_memory_write_uses_content_write_independent_of_session_rotation():
|
|||
)
|
||||
|
||||
|
||||
def test_shutdown_waits_for_memory_write_worker(monkeypatch):
    """``shutdown()`` must not return while a memory-write worker is still inside ``post``."""
    import threading

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"

    entered_post = threading.Event()
    may_finish = threading.Event()
    post_done = threading.Event()
    shutdown_done = threading.Event()

    class StubClient:
        """Client double whose ``post`` parks until the test releases it."""

        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            assert path == "/api/v1/content/write"
            entered_post.set()
            may_finish.wait(timeout=2.0)
            post_done.set()
            return {}

    monkeypatch.setattr(openviking_module, "_VikingClient", StubClient)

    provider.on_memory_write("add", "user", "remember this")
    assert entered_post.wait(timeout=2.0), "worker never entered post()"

    def _run_shutdown():
        provider.shutdown()
        shutdown_done.set()

    shutdown_thread = threading.Thread(target=_run_shutdown, daemon=True)
    shutdown_thread.start()

    # Sample early: while the worker is parked, shutdown should still be blocked.
    returned_early = shutdown_done.wait(timeout=0.1)
    may_finish.set()
    assert shutdown_done.wait(timeout=2.0), "shutdown did not return after worker finished"
    shutdown_thread.join(timeout=2.0)

    assert not returned_early
    assert post_done.is_set()
    assert provider._memory_write_threads == set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("action", "content"),
    [
        ("replace", "updated memory"),
        ("remove", ""),
        ("forget", ""),
        ("delete", ""),
    ],
)
def test_on_memory_write_ignores_non_add_actions(action, content, monkeypatch):
    """``on_memory_write`` must not create a mirror thread for any non-``add`` action.

    The call carries a memory URI and ``old_text`` in its metadata, so this
    also covers the case where enough information to act is available.
    """
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    uri = "viking://user/peers/hermes/memories/preferences/mem_abc123.md"
    spawned = []

    class StubThread:
        """Thread stand-in that records construction and refuses to start."""

        def __init__(self, *args, **kwargs):
            spawned.append((args, kwargs))

        def start(self):
            # The message previously said "non-URI remove", which did not match
            # this test: a URI is supplied and four different actions are run.
            raise AssertionError("non-add actions should not spawn a mirror thread")

    import plugins.memory.openviking as _mod

    # NOTE(review): this replaces ``Thread`` on whatever object the plugin binds
    # as ``threading``; if that is the stdlib module the patch is process-wide
    # until monkeypatch restores it at teardown.
    monkeypatch.setattr(_mod.threading, "Thread", StubThread)

    provider.on_memory_write(
        action,
        "memory",
        content,
        metadata={"uri": uri, "old_text": "stale fact"},
    )

    # Checked separately from ``start`` so a swallowed AssertionError inside the
    # provider cannot hide a constructed thread.
    assert spawned == [], f"{action!r} constructed a mirror thread"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prefetch staleness: a prefetch worker that finishes AFTER a session switch
|
||||
# must drop its result instead of repopulating the new session with stale
|
||||
|
|
|
|||
153
tests/plugins/model_providers/test_ollama_cloud_profile.py
Normal file
153
tests/plugins/model_providers/test_ollama_cloud_profile.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
"""Unit tests for the Ollama Cloud provider profile's reasoning-effort wiring.
|
||||
|
||||
Ollama Cloud's ``/v1/chat/completions`` endpoint supports top-level
|
||||
``reasoning_effort`` with values ``none``, ``low``, ``medium``, ``high``,
|
||||
and (undocumented but empirically confirmed) ``max``. The profile maps
|
||||
Hermes's ``xhigh`` → ``max`` to unlock DeepSeek V4's "Max thinking" tier
|
||||
and passes the standard levels through unchanged.
|
||||
|
||||
These tests pin the profile's wire-shape contract so Ollama Cloud
|
||||
requests carry the correct ``reasoning_effort`` field.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def ollama_cloud_profile():
    """Return the Ollama Cloud profile as registered in the provider registry.

    The lookup goes through ``providers.get_provider_profile`` rather than
    importing the class, so the tests exercise whatever is actually
    registered; swapping in a plain ``ProviderProfile`` would break them.
    """
    # Importing ``model_tools`` runs plugin discovery, which registers the
    # Ollama Cloud profile.  It must be imported before the lookup below.
    import model_tools  # noqa: F401
    import providers

    resolved = providers.get_provider_profile("ollama-cloud")
    assert resolved is not None, "ollama-cloud provider profile must be registered"
    return resolved
|
||||
|
||||
|
||||
class TestOllamaCloudReasoningEffort:
    """``build_api_kwargs_extras`` must emit the right top-level ``reasoning_effort``."""

    # xhigh / max (any case, padded) -> "max"

    @pytest.mark.parametrize("effort", ["xhigh", "max", "MAX", " Max "])
    def test_xhigh_and_max_normalize_to_max(self, ollama_cloud_profile, effort):
        config = {"enabled": True, "effort": effort}
        body_extras, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert body_extras == {}
        assert request_extras == {"reasoning_effort": "max"}

    # low / medium / high are forwarded untouched

    @pytest.mark.parametrize("effort", ["low", "medium", "high"])
    def test_standard_efforts_pass_through(self, ollama_cloud_profile, effort):
        config = {"enabled": True, "effort": effort}
        _, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert request_extras == {"reasoning_effort": effort}

    # disabled -> nothing emitted

    def test_explicitly_disabled_emits_nothing(self, ollama_cloud_profile):
        config = {"enabled": False}
        body_extras, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert body_extras == {}
        assert request_extras == {}

    def test_disabled_ignores_effort_field(self, ollama_cloud_profile):
        """An effort value is dropped when reasoning is disabled."""
        config = {"enabled": False, "effort": "high"}
        _, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert request_extras == {}

    # effort "none" -> nothing emitted

    def test_none_effort_emits_nothing(self, ollama_cloud_profile):
        config = {"enabled": True, "effort": "none"}
        body_extras, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert body_extras == {}
        assert request_extras == {}

    # missing / empty effort -> leave the model on its default

    def test_no_reasoning_config_emits_nothing(self, ollama_cloud_profile):
        body_extras, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=None,
        )
        assert body_extras == {}
        assert request_extras == {}

    def test_empty_effort_emits_nothing(self, ollama_cloud_profile):
        config = {"enabled": True, "effort": ""}
        _, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert request_extras == {}

    def test_no_effort_key_emits_nothing(self, ollama_cloud_profile):
        """Without an ``effort`` key nothing is sent, leaving the model default."""
        config = {"enabled": True}
        _, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert request_extras == {}

    # unrecognised effort -> forwarded verbatim

    def test_unknown_effort_forwarded(self, ollama_cloud_profile):
        config = {"enabled": True, "effort": "ultra"}
        _, request_extras = ollama_cloud_profile.build_api_kwargs_extras(
            reasoning_config=config,
        )
        assert request_extras == {"reasoning_effort": "ultra"}
|
||||
|
||||
|
||||
class TestOllamaCloudFullKwargsIntegration:
    """End-to-end: kwargs built by the chat-completions transport carry ``reasoning_effort``."""

    @staticmethod
    def _build_kwargs(profile, reasoning_config):
        """Run ``ChatCompletionsTransport.build_kwargs`` for the fixed Ollama Cloud request.

        Args:
            profile: The provider profile under test.
            reasoning_config: The reasoning configuration dict to forward.

        Returns:
            Whatever ``build_kwargs`` returns (indexed like a dict by the tests).
        """
        from agent.transports.chat_completions import ChatCompletionsTransport

        return ChatCompletionsTransport().build_kwargs(
            model="deepseek-v4-pro:cloud",
            messages=[{"role": "user", "content": "ping"}],
            tools=None,
            provider_profile=profile,
            reasoning_config=reasoning_config,
            base_url="https://ollama.com/v1",
            provider_name="ollama-cloud",
        )

    def test_full_kwargs_with_xhigh(self, ollama_cloud_profile):
        """``xhigh`` reaches the wire as top-level ``reasoning_effort="max"``."""
        kwargs = self._build_kwargs(
            ollama_cloud_profile, {"enabled": True, "effort": "xhigh"}
        )

        assert kwargs["model"] == "deepseek-v4-pro:cloud"
        assert kwargs["reasoning_effort"] == "max"
        # Ollama Cloud uses top-level reasoning_effort, so no ``reasoning``
        # entry may appear in extra_body.  ``or {}`` also covers a key that is
        # present but None, which previously raised TypeError in the ``in`` test.
        extra_body = kwargs.get("extra_body") or {}
        assert "reasoning" not in extra_body

    def test_full_kwargs_with_disabled(self, ollama_cloud_profile):
        """Disabled reasoning leaves ``reasoning_effort`` out of the kwargs."""
        kwargs = self._build_kwargs(ollama_cloud_profile, {"enabled": False})

        assert "reasoning_effort" not in kwargs
|
||||
|
||||
|
||||
class TestOllamaCloudAuxModel:
    """The profile names a default auxiliary model."""

    def test_profile_advertises_aux_model(self, ollama_cloud_profile):
        advertised = ollama_cloud_profile.default_aux_model
        assert advertised == "nemotron-3-nano:30b"
|
||||
138
tests/run_agent/test_background_review_cost_controls.py
Normal file
138
tests/run_agent/test_background_review_cost_controls.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""Unit coverage for the background-review aux-model selector + routed digest.
|
||||
|
||||
Covers the two behaviors this change adds:
|
||||
• _resolve_review_runtime — auto/same-model → not routed (main model, warm
|
||||
cache); a configured different model → routed with resolved credentials.
|
||||
• _digest_history — compact replay used ONLY on the routed path (recent tail
|
||||
verbatim + a digest of older turns), preserving role alternation.
|
||||
|
||||
Pure-function / config-driven; no live model calls.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent import background_review as br
|
||||
|
||||
|
||||
def _msg(role, content, tool_calls=None):
|
||||
m = {"role": role, "content": content}
|
||||
if tool_calls:
|
||||
m["tool_calls"] = tool_calls
|
||||
return m
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_review_runtime — the aux-model selector
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakeAgent:
|
||||
def __init__(self, provider="openai-codex", model="gpt-5.5"):
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
|
||||
def _current_main_runtime(self):
|
||||
return {
|
||||
"api_key": "parent-key",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_mode": "codex_app_server",
|
||||
}
|
||||
|
||||
|
||||
def test_routing_auto_inherits_parent_and_downgrades_codex_app_server():
    """``provider: auto`` keeps the parent's provider/model and is not routed."""
    review_cfg = {"provider": "auto", "model": ""}
    cfg = {"auxiliary": {"background_review": review_cfg}}
    parent = _FakeAgent()

    with patch("hermes_cli.config.load_config", return_value=cfg):
        runtime = br._resolve_review_runtime(parent)

    assert runtime["routed"] is False
    assert runtime["provider"] == "openai-codex"
    assert runtime["model"] == "gpt-5.5"
    # The parent's codex_app_server mode is downgraded so agent-loop tools dispatch.
    assert runtime["api_mode"] == "codex_responses"
|
||||
|
||||
|
||||
def test_routing_to_different_model_marks_routed_and_resolves_credentials():
    """A configured model that differs from the parent is routed with resolved credentials."""
    parent = _FakeAgent()
    review_cfg = {"provider": "openrouter", "model": "google/gemini-3-flash-preview"}
    cfg = {"auxiliary": {"background_review": review_cfg}}
    resolved_provider = {
        "provider": "openrouter",
        "api_key": "or-key",
        "base_url": "https://openrouter.ai/api/v1",
        "api_mode": "chat_completions",
    }

    with patch("hermes_cli.config.load_config", return_value=cfg):
        with patch(
            "hermes_cli.runtime_provider.resolve_runtime_provider",
            return_value=resolved_provider,
        ):
            runtime = br._resolve_review_runtime(parent)

    assert runtime["routed"] is True
    assert runtime["provider"] == "openrouter"
    assert runtime["model"] == "google/gemini-3-flash-preview"
    assert runtime["api_key"] == "or-key"
|
||||
|
||||
|
||||
def test_routing_same_model_as_parent_is_not_routed():
    """Configuring the parent's own provider/model keeps the non-routed path."""
    parent = _FakeAgent(provider="openrouter", model="anthropic/claude-opus-4.8")
    review_cfg = {"provider": "openrouter", "model": "anthropic/claude-opus-4.8"}
    cfg = {"auxiliary": {"background_review": review_cfg}}

    with patch("hermes_cli.config.load_config", return_value=cfg):
        runtime = br._resolve_review_runtime(parent)

    # Same model and provider as the parent, so the full-replay path is kept.
    assert runtime["routed"] is False
|
||||
|
||||
|
||||
def test_routing_resolution_failure_falls_back_to_parent():
    """If resolving the configured provider raises, the parent runtime is used."""
    parent = _FakeAgent()
    review_cfg = {"provider": "openrouter", "model": "google/gemini-3-flash-preview"}
    cfg = {"auxiliary": {"background_review": review_cfg}}

    with patch("hermes_cli.config.load_config", return_value=cfg):
        with patch(
            "hermes_cli.runtime_provider.resolve_runtime_provider",
            side_effect=RuntimeError("boom"),
        ):
            runtime = br._resolve_review_runtime(parent)

    assert runtime["routed"] is False
    assert runtime["provider"] == "openai-codex"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _digest_history — routed-path compact replay
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_digest_under_tail_returns_full():
    """A history shorter than the tail comes back equal to the input."""
    history = [_msg("user", "hi"), _msg("assistant", "hello")]

    digested = br._digest_history(history, tail=24)

    assert digested == history
|
||||
|
||||
|
||||
def test_digest_collapses_old_keeps_tail_verbatim():
    """Older turns collapse into one digest message; the last ``tail`` stay as-is."""
    history = [
        message
        for i in range(60)
        for message in (
            _msg("user", f"u{i} " + "x" * 50),
            _msg("assistant", f"a{i} " + "y" * 50),
        )
    ]

    digested = br._digest_history(history, tail=10)

    # The synthetic digest leads and has the user role, preserving alternation.
    head = digested[0]
    assert head["role"] == "user"
    assert head["content"].startswith("[Earlier conversation digest")
    # The most recent message is carried over verbatim.
    assert digested[-1] == history[-1]
    # One digest message plus the ten tail messages.
    assert len(digested) == 11
|
||||
|
||||
|
||||
def test_digest_does_not_open_tail_on_a_tool_message():
    """The message right after the digest must never be a ``tool`` message."""
    history = []
    for _ in range(40):
        history.extend(
            (
                _msg("user", "u" + "x" * 50),
                _msg(
                    "assistant",
                    "",
                    tool_calls=[{"function": {"name": "terminal", "arguments": "{}"}}],
                ),
                {"role": "tool", "content": "result " + "w" * 50},
            )
        )

    digested = br._digest_history(history, tail=2)

    # Index 1 is where the verbatim tail begins, immediately after the digest.
    first_tail_message = digested[1]
    assert first_tail_message["role"] != "tool"
|
||||
|
||||
|
||||
def test_digest_records_tool_names_in_arc():
    """The digest text names the user turn and the tools called in collapsed turns."""
    tool_turn = _msg(
        "assistant",
        "",
        tool_calls=[
            {"function": {"name": "skill_view", "arguments": "{}"}},
            {"function": {"name": "patch", "arguments": "{}"}},
        ],
    )
    history = [_msg("user", "do the thing"), tool_turn]
    for i in range(30):
        history.append(_msg("user", f"tail{i}"))

    digested = br._digest_history(history, tail=10)

    digest_text = digested[0]["content"]
    assert "USER: do the thing" in digest_text
    assert "tools: skill_view, patch" in digest_text
|
||||
|
|
@ -260,6 +260,52 @@ class TestShrinkImagePartsHelper:
|
|||
assert seen["max_dimension"] == 2000
|
||||
assert msgs[0]["content"][0]["image_url"]["url"] == shrunk
|
||||
|
||||
def test_anthropic_base64_image_source_rewritten(self, monkeypatch):
    """Anthropic-native ``image`` blocks with a base64 source are shrunk in place."""
    agent = _make_agent()
    _install_fake_pillow(monkeypatch, (2501, 100), shrunk_size=(1500, 60))
    # Keep only the payload after the comma of the data URL.
    original_data = _big_png_data_url(100).partition(",")[2]
    shrunk = "data:image/jpeg;base64," + "N" * 1000
    resize_args = {}

    def _fake_resize(path, mime_type=None, max_base64_bytes=None, max_dimension=None):
        # Record the arguments asserted on below and hand back the canned URL.
        resize_args["mime_type"] = mime_type
        resize_args["max_dimension"] = max_dimension
        return shrunk

    monkeypatch.setattr(
        "tools.vision_tools._resize_image_for_vision",
        _fake_resize,
        raising=False,
    )

    image_block = {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": "image/png",
            "data": original_data,
        },
    }
    msgs = [{"role": "user", "content": [image_block]}]

    changed = agent._try_shrink_image_parts_in_messages(msgs, max_dimension=2000)

    rewritten = msgs[0]["content"][0]["source"]
    assert changed is True
    assert resize_args["mime_type"] == "image/png"
    assert resize_args["max_dimension"] == 2000
    assert rewritten["type"] == "base64"
    assert rewritten["media_type"] == "image/jpeg"
    assert rewritten["data"] == "N" * 1000
|
||||
|
||||
def test_oversized_input_image_string_shape_rewritten(self, monkeypatch):
|
||||
"""OpenAI Responses shape: {type: input_image, image_url: "data:..."}."""
|
||||
agent = _make_agent()
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from agent.codex_responses_adapter import _normalize_codex_response
|
|||
import run_agent
|
||||
from run_agent import AIAgent
|
||||
from agent.error_classifier import FailoverReason
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.prompt_builder import DEFAULT_AGENT_IDENTITY
|
||||
|
||||
|
||||
|
|
@ -2082,6 +2083,41 @@ class TestExecuteToolCalls:
|
|||
assert messages[0]["role"] == "tool"
|
||||
assert "search result" in messages[0]["content"]
|
||||
|
||||
def test_sequential_memory_remove_notifies_provider_with_tool_result(self, agent):
    """Sequential path: a successful ``remove`` notifies the manager exactly once."""
    old_text = "stale preference entry"
    tool_args = json.dumps({
        "action": "remove",
        "target": "memory",
        "old_text": old_text,
    })
    tc = _mock_tool_call(name="memory", arguments=tool_args, call_id="mem-1")
    mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
    messages = []
    notifications = []

    class FakeMemoryManager(MemoryManager):
        """Manager that owns no tools and records every write notification."""

        def has_tool(self, tool_name):
            return False

        def on_memory_write(self, action, target, content, metadata=None):
            notifications.append((action, target, content, metadata or {}))

    agent._memory_manager = FakeMemoryManager()
    agent._memory_store = object()

    tool_reply = json.dumps({"success": True})
    with patch("tools.memory_tool.memory_tool", return_value=tool_reply):
        agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")

    assert len(notifications) == 1
    action, target, content, metadata = notifications[0]
    assert (action, target, content) == ("remove", "memory", "")
    assert metadata["old_text"] == old_text
    assert metadata["tool_call_id"] == "mem-1"
    assert messages[-1]["tool_call_id"] == "mem-1"
|
||||
|
||||
def test_keyboard_interrupt_emits_cancelled_post_tool_hook(self, agent, monkeypatch):
|
||||
tc = _mock_tool_call(name="web_search", arguments='{"q":"test"}', call_id="c1")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc])
|
||||
|
|
@ -2457,6 +2493,35 @@ class TestConcurrentToolExecution:
|
|||
assert messages[1]["tool_call_id"] == "c2"
|
||||
assert "success" in messages[1]["content"]
|
||||
|
||||
def test_concurrent_submit_shutdown_error_returns_tool_errors(self, agent):
    """A submit-time interpreter-shutdown error becomes one tool error per call."""

    class ShutdownExecutor:
        """Executor double whose ``submit`` always fails as during interpreter shutdown."""

        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            # Never suppress: the caller must deal with the error itself.
            return False

        def submit(self, *args, **kwargs):
            raise RuntimeError("cannot schedule new futures after interpreter shutdown")

    first = _mock_tool_call(name="web_search", arguments='{"q": "alpha"}', call_id="c1")
    second = _mock_tool_call(name="web_search", arguments='{"q": "beta"}', call_id="c2")
    mock_msg = _mock_assistant_msg(content="", tool_calls=[first, second])
    messages = []

    executor_path = "agent.tool_executor.concurrent.futures.ThreadPoolExecutor"
    with patch(executor_path, ShutdownExecutor):
        agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

    assert len(messages) == 2
    assert messages[0]["tool_call_id"] == "c1"
    assert messages[1]["tool_call_id"] == "c2"
    assert all("Python interpreter is shutting down" in m["content"] for m in messages)
|
||||
|
||||
def test_concurrent_interrupt_before_start(self, agent):
|
||||
"""If interrupt is requested before concurrent execution, all tools are skipped."""
|
||||
tc1 = _mock_tool_call(name="web_search", arguments='{}', call_id="c1")
|
||||
|
|
@ -2797,6 +2862,68 @@ class TestConcurrentToolExecution:
|
|||
assert json.loads(result) == {"error": "Blocked"}
|
||||
assert agent._turns_since_memory == 5
|
||||
|
||||
def test_invoke_tool_memory_remove_notifies_provider_with_old_text(self, agent, monkeypatch):
    """``_invoke_tool``: a successful ``remove`` notifies the manager with ``old_text``."""
    # Make the pre-tool-call hook report "not blocked" for every call.
    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        lambda *args, **kwargs: None,
    )
    notifications = []

    class FakeMemoryManager(MemoryManager):
        """Manager that owns no tools and records every write notification."""

        def has_tool(self, tool_name):
            return False

        def on_memory_write(self, action, target, content, metadata=None):
            notifications.append((action, target, content, metadata or {}))

    old_text = "stale preference entry"
    agent._memory_manager = FakeMemoryManager()
    agent._memory_store = object()

    tool_reply = json.dumps({"success": True})
    tool_args = {"action": "remove", "target": "memory", "old_text": old_text}
    with patch("tools.memory_tool.memory_tool", return_value=tool_reply):
        agent._invoke_tool("memory", tool_args, "task-1", tool_call_id="mem-1")

    assert len(notifications) == 1
    action, target, content, metadata = notifications[0]
    assert (action, target, content) == ("remove", "memory", "")
    assert metadata["old_text"] == old_text
    assert metadata["tool_call_id"] == "mem-1"
|
||||
|
||||
def test_invoke_tool_memory_failed_remove_skips_provider_notification(self, agent, monkeypatch):
    """``_invoke_tool``: a ``remove`` that reports failure must not notify the manager."""
    # Make the pre-tool-call hook report "not blocked" for every call.
    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        lambda *args, **kwargs: None,
    )
    notify = MagicMock(side_effect=AssertionError("should not notify"))

    class FakeMemoryManager(MemoryManager):
        """Manager that owns no tools and whose write hook is the mock above."""

        def has_tool(self, tool_name):
            return False

        on_memory_write = notify

    manager = FakeMemoryManager()
    agent._memory_manager = manager
    agent._memory_store = object()

    failed_reply = json.dumps({"success": False, "error": "No entry matched"})
    tool_args = {"action": "remove", "target": "memory", "old_text": "missing"}
    with patch("tools.memory_tool.memory_tool", return_value=failed_reply):
        agent._invoke_tool("memory", tool_args, "task-1", tool_call_id="mem-1")

    notify.assert_not_called()
|
||||
|
||||
def test_concurrent_blocked_write_skips_checkpoint(self, agent, monkeypatch):
|
||||
"""Concurrent path: blocked write_file should not trigger checkpoint."""
|
||||
tc1 = _mock_tool_call(name="write_file",
|
||||
|
|
|
|||
252
tests/run_agent/test_tool_call_incremental_persistence.py
Normal file
252
tests/run_agent/test_tool_call_incremental_persistence.py
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
"""Behavior contracts for incremental tool-call persistence (#49045).
|
||||
|
||||
A destructive or process-terminating tool that runs during tool execution
|
||||
must not lose the just-executed assistant(tool_calls) block or the tool
|
||||
results that were produced before it fired. These tests pin the contract:
|
||||
|
||||
1. run_conversation flushes the assistant tool-call turn to the session
|
||||
DB BEFORE handing control to _execute_tool_calls (so a tool that
|
||||
restarts/kills the process never orphans the tool-call block).
|
||||
2. The SEQUENTIAL tool path flushes each tool result to the session DB
|
||||
immediately after appending it — BEFORE the next tool dispatches.
|
||||
3. The CONCURRENT tool path flushes each tool result in append order.
|
||||
|
||||
These exercise the REAL production dispatch surfaces:
|
||||
|
||||
* sequential -> ``run_agent.handle_function_call`` (tool_executor ~1256/1298)
|
||||
* concurrent -> ``agent._invoke_tool`` (tool_executor ~539)
|
||||
|
||||
Mocking the genuine dispatch surface keeps the tests deterministic (no real
|
||||
``web_search`` / network) AND mutation-survivable: the ordering assertions
|
||||
read snapshots captured at flush time, so removing any production flush call
|
||||
makes the corresponding assertion fail.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from types import SimpleNamespace
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.tool_dispatch_helpers import make_tool_result_message
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_tool_defs(*names: str) -> list:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": f"{name} tool",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
for name in names
|
||||
]
|
||||
|
||||
|
||||
def _make_agent():
    """Construct an ``AIAgent`` wired for offline, deterministic tests.

    Tool definitions, toolset requirement checks, the OpenAI client class,
    the Hermes home directory and model-metadata fetching are all patched
    while the agent is constructed. The agent's client is then replaced by a
    ``MagicMock`` and prompt caching, tool delay, compression and trajectory
    saving are switched off.

    Returns:
        The configured ``AIAgent`` instance.
    """
    import atexit
    import shutil

    hermes_home = Path(tempfile.mkdtemp(prefix="hermes-test-home-"))
    # mkdtemp() never cleans up after itself; without this every call leaves
    # a stray directory behind. Removal is deferred to interpreter exit so the
    # directory stays valid for the whole lifetime of the returned agent.
    atexit.register(shutil.rmtree, hermes_home, ignore_errors=True)
    (hermes_home / "logs").mkdir(parents=True, exist_ok=True)
    with (
        patch(
            "run_agent.get_tool_definitions",
            return_value=_make_tool_defs("web_search"),
        ),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
        patch("run_agent._hermes_home", hermes_home),
        patch("agent.model_metadata.fetch_model_metadata", return_value={}),
    ):
        agent = AIAgent(
            api_key="test-key",
            base_url="https://openrouter.ai/api/v1",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
    agent.client = MagicMock()
    agent._cached_system_prompt = "You are helpful."
    agent._use_prompt_caching = False
    agent.tool_delay = 0
    agent.compression_enabled = False
    agent.save_trajectories = False
    return agent
|
||||
|
||||
|
||||
def _mock_tool_call(name="web_search", arguments="{}", call_id="call_1"):
|
||||
return SimpleNamespace(
|
||||
id=call_id,
|
||||
type="function",
|
||||
function=SimpleNamespace(name=name, arguments=arguments),
|
||||
)
|
||||
|
||||
|
||||
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
|
||||
msg = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
|
||||
return SimpleNamespace(choices=[choice], model="test/model", usage=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Contract 1: run_conversation persists the assistant tool-call block BEFORE
|
||||
# tool execution begins.
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_run_conversation_flushes_assistant_tool_call_before_execution():
    """The assistant tool-call turn is flushed before ``_execute_tool_calls`` runs.

    The model is scripted to answer with one tool call and then a final
    ``"done"`` message. Every flush records a deep copy of the message list,
    and the stubbed ``_execute_tool_calls`` records the most recent flush it
    can see. Assertions run after ``run_conversation`` returns, against those
    recordings.
    """
    agent = _make_agent()
    agent.client.chat.completions.create.side_effect = [
        _mock_response(
            content="",
            finish_reason="tool_calls",
            tool_calls=[_mock_tool_call(call_id="c1")],
        ),
        _mock_response(content="done", finish_reason="stop"),
    ]

    # Deep copies taken at flush time, so later mutation of the live message
    # list cannot change what the assertions see.
    flush_snapshots: list[list] = []

    def _record_flush(messages, conversation_history=None):
        flush_snapshots.append(copy.deepcopy(messages))

    agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)

    # Observations are collected in this enclosing-scope list instead of being
    # asserted inside the stub: the original notes that run_conversation's
    # outer loop swallows exceptions, so an in-callback assert would be lost.
    # One entry is appended per _execute_tool_calls invocation.
    seen_at_execute: list = []

    def _fake_execute(assistant_message, messages, effective_task_id, api_call_count=0):
        # What had been flushed at the moment tool execution begins.
        if flush_snapshots:
            seen_at_execute.append(copy.deepcopy(flush_snapshots[-1]))
        else:
            seen_at_execute.append(None)
        # Mimic the tool producing a result.
        messages.append(make_tool_result_message("web_search", "search result", "c1"))

    with (
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
        patch.object(agent, "_execute_tool_calls", side_effect=_fake_execute),
    ):
        result = agent.run_conversation("search something")

    assert len(seen_at_execute) == 1, "_execute_tool_calls was never reached"
    # The assistant tool-call block MUST have been flushed before execution.
    flushed = seen_at_execute[0]
    assert flushed is not None, "no flush occurred before tool execution"
    tail = flushed[-1]
    assert tail["role"] == "assistant"
    assert tail["tool_calls"][0]["id"] == "c1"
    assert result["final_response"] == "done"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Contract 2: the SEQUENTIAL path flushes each tool result immediately, BEFORE
|
||||
# the next tool dispatches. Dispatch goes through run_agent.handle_function_call
|
||||
# (the real production surface), which we mock for determinism.
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_execute_tool_calls_sequential_flushes_each_tool_result_before_next_dispatch():
    """Sequential path: each tool result is flushed before the next dispatch.

    ``run_agent.handle_function_call`` and the session-DB flush both append to
    one shared event log, so the final assertion can check the exact
    interleaving ``dispatch c1 -> flush c1 -> dispatch c2 -> flush c2``.
    """
    agent = _make_agent()
    assistant_message = SimpleNamespace(
        content="",
        tool_calls=[
            _mock_tool_call(name="web_search", call_id="c1"),
            _mock_tool_call(name="web_search", call_id="c2"),
        ],
    )
    messages: list = []

    # Single ordered log shared by the dispatch stub and the flush recorder.
    events: list = []

    def _fake_dispatch(function_name, function_args, effective_task_id, **kwargs):
        call_id = kwargs.get("tool_call_id")
        events.append(("dispatch", call_id))
        return f"result-{call_id}"

    def _record_flush(flush_messages, conversation_history=None):
        # Log the tail message, i.e. the one that triggered this flush.
        newest = flush_messages[-1]
        events.append(("flush", newest.get("role"), newest.get("tool_call_id")))

    def _keep_content(**kwargs):
        # Stand-in for maybe_persist_tool_result: hand the content back as-is.
        return kwargs["content"]

    agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)

    with (
        patch("run_agent.handle_function_call", side_effect=_fake_dispatch) as disp,
        patch(
            "agent.tool_executor.maybe_persist_tool_result",
            side_effect=_keep_content,
        ),
    ):
        agent._execute_tool_calls_sequential(assistant_message, messages, "task-1")

    # The patched dispatcher was hit once per tool call.
    assert disp.call_count == 2, "sequential path did not dispatch via handle_function_call"

    # Both tool results landed, in order.
    roles = [m["role"] for m in messages]
    call_ids = [m["tool_call_id"] for m in messages]
    assert roles == ["tool", "tool"]
    assert call_ids == ["c1", "c2"]

    # Ordering contract: every result is flushed after its own dispatch and
    # before the following dispatch.
    expected_events = [
        ("dispatch", "c1"),
        ("flush", "tool", "c1"),
        ("dispatch", "c2"),
        ("flush", "tool", "c2"),
    ]
    assert events == expected_events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Contract 3: the CONCURRENT path flushes each collected tool result in append
|
||||
# order. Dispatch goes through agent._invoke_tool (the real concurrent
|
||||
# surface), which we mock for determinism.
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_execute_tool_calls_concurrent_flushes_each_tool_result_in_order():
    """Concurrent path: tool results are flushed one at a time, in append order.

    ``agent._invoke_tool`` is stubbed; each flush records the tail message's
    ``tool_call_id`` and how many tool messages were present, which must come
    out as ``c1, c2`` and ``1, 2`` respectively.
    """
    agent = _make_agent()
    assistant_message = SimpleNamespace(
        content="",
        tool_calls=[
            _mock_tool_call(name="web_search", call_id="c1"),
            _mock_tool_call(name="web_search", call_id="c2"),
        ],
    )
    messages: list = []
    invoked_ids: list = []

    def _fake_invoke(function_name, function_args, effective_task_id, tool_call_id, **kwargs):
        invoked_ids.append(tool_call_id)
        return f"result-{tool_call_id}"

    # Per flush: id of the tail message and the running count of tool messages.
    flushed_tool_ids: list = []
    flush_lengths: list = []

    def _record_flush(flush_messages, conversation_history=None):
        flushed_tool_ids.append(flush_messages[-1]["tool_call_id"])
        tool_count = sum(1 for m in flush_messages if m.get("role") == "tool")
        flush_lengths.append(tool_count)

    def _keep_content(**kwargs):
        # Stand-in for maybe_persist_tool_result: hand the content back as-is.
        return kwargs["content"]

    agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)

    with (
        patch.object(agent, "_invoke_tool", side_effect=_fake_invoke) as inv,
        patch(
            "agent.tool_executor.maybe_persist_tool_result",
            side_effect=_keep_content,
        ),
    ):
        agent._execute_tool_calls_concurrent(assistant_message, messages, "task-1")

    # The patched _invoke_tool was hit once per call; invocation order is not
    # asserted, only the set of ids.
    assert inv.call_count == 2, "concurrent path did not dispatch via _invoke_tool"
    assert sorted(invoked_ids) == ["c1", "c2"]

    # Results appended in deterministic order.
    appended_ids = [m["tool_call_id"] for m in messages]
    assert appended_ids == ["c1", "c2"]

    # One flush per result, in append order, with the tool-message count
    # growing by one each time.
    assert flushed_tool_ids == ["c1", "c2"]
    assert flush_lengths == [1, 2]
|
||||
164
tests/skills/test_cloudflare_temporary_deploy_skill.py
Normal file
164
tests/skills/test_cloudflare_temporary_deploy_skill.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""Tests for optional-skills/web-development/cloudflare-temporary-deploy/scripts/parse_deploy_output.py"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
SCRIPTS_DIR = (
|
||||
Path(__file__).resolve().parents[2]
|
||||
/ "optional-skills"
|
||||
/ "web-development"
|
||||
/ "cloudflare-temporary-deploy"
|
||||
/ "scripts"
|
||||
)
|
||||
sys.path.insert(0, str(SCRIPTS_DIR))
|
||||
|
||||
import parse_deploy_output as pdo
|
||||
|
||||
|
||||
CREATED = """\
|
||||
Continuing means you accept Cloudflare's Terms of Service and Privacy Policy.
|
||||
|
||||
Temporary account ready:
|
||||
Account: swift-otter (created)
|
||||
Claim within: 60 minutes
|
||||
Claim URL: https://dash.cloudflare.com/claim-preview?claimToken=TOKEN_AAA
|
||||
|
||||
Uploaded my-worker
|
||||
Deployed my-worker triggers
|
||||
https://my-worker.swift-otter.workers.dev
|
||||
"""
|
||||
|
||||
REUSED = """\
|
||||
Temporary account ready:
|
||||
Account: swift-otter (reused)
|
||||
Claim within: 17 minutes
|
||||
Claim URL: https://dash.cloudflare.com/claim-preview?claimToken=TOKEN_BBB
|
||||
Deployed my-worker triggers
|
||||
https://my-worker.swift-otter.workers.dev
|
||||
"""
|
||||
|
||||
NOT_LOGGED_IN = """\
|
||||
✘ [ERROR] You are not logged in.
|
||||
|
||||
To continue without logging in, rerun this command with `--temporary`.
|
||||
"""
|
||||
|
||||
AUTH_PRESENT_ERROR = """\
|
||||
✘ [ERROR] The --temporary flag cannot be used while Wrangler is authenticated.
|
||||
Run `wrangler logout` first, or remove CLOUDFLARE_API_TOKEN.
|
||||
"""
|
||||
|
||||
|
||||
class TestParseCreated:
    """Fields extracted from the ``CREATED`` sample output."""

    def test_live_url(self):
        parsed = pdo.parse(CREATED)
        assert parsed["live_url"] == "https://my-worker.swift-otter.workers.dev"

    def test_claim_url(self):
        expected = "https://dash.cloudflare.com/claim-preview?claimToken=TOKEN_AAA"
        assert pdo.parse(CREATED)["claim_url"] == expected

    def test_account_and_state(self):
        parsed = pdo.parse(CREATED)
        assert parsed["account"] == "swift-otter"
        assert parsed["account_state"] == "created"

    def test_expiry_and_deployed(self):
        parsed = pdo.parse(CREATED)
        assert parsed["expires_minutes"] == 60
        assert parsed["deployed"] is True
|
||||
|
||||
|
||||
class TestParseReused:
    """Fields extracted from the ``REUSED`` sample output."""

    def test_state_is_reused(self):
        parsed = pdo.parse(REUSED)
        assert parsed["account_state"] == "reused"

    def test_expiry_window_can_shrink(self):
        # The REUSED sample advertises a 17-minute claim window.
        parsed = pdo.parse(REUSED)
        assert parsed["expires_minutes"] == 17

    def test_live_url_stable(self):
        parsed = pdo.parse(REUSED)
        assert parsed["live_url"] == "https://my-worker.swift-otter.workers.dev"
|
||||
|
||||
|
||||
class TestNoDeploy:
    """Error outputs must yield no URLs and ``deployed`` set to ``False``."""

    def test_not_logged_in_has_no_urls(self):
        parsed = pdo.parse(NOT_LOGGED_IN)
        for field in ("live_url", "claim_url", "account"):
            assert parsed[field] is None
        assert parsed["deployed"] is False

    def test_auth_present_error_has_no_urls(self):
        parsed = pdo.parse(AUTH_PRESENT_ERROR)
        for field in ("live_url", "claim_url"):
            assert parsed[field] is None
        assert parsed["deployed"] is False
|
||||
|
||||
|
||||
class TestRealWorldOutput:
    """Regression sample: tab-indented fields and a multi-word account name."""

    # Sample transcript; the account fields are tab-indented and the account
    # name ("Serene Temple") contains a space.
    REAL = (
        "⛅️ wrangler 4.103.0\n"
        "Continuing means you accept Cloudflare's Terms of Service and Privacy Policy.\n"
        "Solving proof-of-work challenge…\n"
        "Temporary account ready:\n"
        "\tAccount: Serene Temple (created)\n"
        "\tClaim within: 60 minutes\n"
        "\tClaim URL: https://dash.cloudflare.com/claim-preview?claimToken=fxLzyAD-vlTzMQmClpg\n"
        "Total Upload: 0.19 KiB / gzip: 0.16 KiB\n"
        "Uploaded hermes-temp-hello (0.74 sec)\n"
        "Deployed hermes-temp-hello triggers (0.42 sec)\n"
        " https://hermes-temp-hello.serene-temple.workers.dev\n"
    )

    def test_multiword_account_name(self):
        parsed = pdo.parse(self.REAL)
        assert parsed["account"] == "Serene Temple"
        assert parsed["account_state"] == "created"

    def test_all_fields_from_real_output(self):
        parsed = pdo.parse(self.REAL)
        assert parsed["live_url"] == "https://hermes-temp-hello.serene-temple.workers.dev"
        assert parsed["claim_url"].endswith("claimToken=fxLzyAD-vlTzMQmClpg")
        assert parsed["expires_minutes"] == 60
        assert parsed["deployed"] is True
|
||||
|
||||
|
||||
class TestUrlHygiene:
    """URL extraction edge cases: trailing punctuation and unrelated links."""

    def test_trailing_punctuation_stripped(self):
        # The sentence-ending "." after the URL must not be part of live_url.
        sample = "Deployed\n see https://w.acct.workers.dev. for details"
        parsed = pdo.parse(sample)
        assert parsed["live_url"] == "https://w.acct.workers.dev"

    def test_does_not_match_plain_cloudflare_com(self):
        # A generic cloudflare.com link without a claimToken must not be
        # taken as the claim URL.
        sample = "Privacy Policy: https://www.cloudflare.com/privacypolicy/\nDeployed x"
        parsed = pdo.parse(sample)
        assert parsed["claim_url"] is None
|
||||
|
||||
|
||||
class TestCli:
    """Exit codes and JSON output of ``pdo.main``."""

    def test_selftest_exits_zero(self):
        exit_code = pdo.main(["--selftest"])
        assert exit_code == 0

    def test_main_prints_json_and_exit_zero_on_live(self, capsys):
        # Feed the CREATED sample through stdin.
        with mock.patch.object(sys.stdin, "read", return_value=CREATED):
            exit_code = pdo.main([])
        printed = json.loads(capsys.readouterr().out)
        assert exit_code == 0
        assert printed["live_url"] == "https://my-worker.swift-otter.workers.dev"

    def test_main_exit_one_when_no_live_url(self, capsys):
        # Feed the NOT_LOGGED_IN sample through stdin.
        with mock.patch.object(sys.stdin, "read", return_value=NOT_LOGGED_IN):
            exit_code = pdo.main([])
        printed = json.loads(capsys.readouterr().out)
        assert exit_code == 1
        assert printed["live_url"] is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(pytest.main([__file__, "-q"]))
|
||||
79
tests/test_code_skew.py
Normal file
79
tests/test_code_skew.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Tests for gateway code-skew detection (stale-checkout guard).
|
||||
|
||||
Companion to ``tests/test_stale_utils_module_import.py``: that test proves the
|
||||
crash; these prove the guard that turns it into a clear "restart the gateway"
|
||||
message before a model switch can hit it.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway import code_skew
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_boot_fingerprint(monkeypatch):
    """Clear ``code_skew._boot_fingerprint`` before every test in this module.

    Uses ``monkeypatch`` so the previous value is restored afterwards.
    """
    monkeypatch.setattr(code_skew, "_boot_fingerprint", None)
|
||||
|
||||
|
||||
class TestDetectCodeSkew:
    """``record_boot_fingerprint`` / ``detect_code_skew`` with a stubbed fingerprint."""

    @staticmethod
    def _pin_fingerprint(monkeypatch, value):
        """Make ``code_skew._fingerprint()`` return ``value`` from now on."""
        monkeypatch.setattr(code_skew, "_fingerprint", lambda: value)

    def test_no_boot_fingerprint_means_no_skew(self, monkeypatch):
        # Nothing recorded (e.g. non-git install) -> never a false positive.
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:def456")
        assert code_skew.detect_code_skew() is None

    def test_unchanged_checkout_is_not_skew(self, monkeypatch):
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:abc1234567890")
        code_skew.record_boot_fingerprint()
        assert code_skew.detect_code_skew() is None

    def test_drift_is_detected_with_short_revs(self, monkeypatch):
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:abc1234567890")
        code_skew.record_boot_fingerprint()

        # The checkout moves after boot; the result is a pair of shortened revs.
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:def4567890123")
        detected = code_skew.detect_code_skew()
        assert detected == ("abc1234567", "def4567890")

    def test_unreadable_current_rev_does_not_false_positive(self, monkeypatch):
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:abc1234567890")
        code_skew.record_boot_fingerprint()

        # The current fingerprint cannot be determined.
        self._pin_fingerprint(monkeypatch, None)
        assert code_skew.detect_code_skew() is None

    def test_record_is_idempotent(self, monkeypatch):
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:first")
        code_skew.record_boot_fingerprint()
        self._pin_fingerprint(monkeypatch, "git:refs/heads/main:second")
        # A second record call must not overwrite the boot snapshot.
        code_skew.record_boot_fingerprint()
        assert code_skew._boot_fingerprint == "git:refs/heads/main:first"
|
||||
|
||||
|
||||
class TestShort:
    """``code_skew._short`` reduces a fingerprint to a short revision label."""

    def test_shortens_long_sha(self):
        shortened = code_skew._short("git:refs/heads/main:abcdef0123456789")
        assert shortened == "abcdef0123"

    def test_keeps_unresolved_marker(self):
        shortened = code_skew._short("git:refs/heads/main:unresolved")
        assert shortened == "unresolved"

    def test_passes_short_sha_through_untruncated(self):
        shortened = code_skew._short("git:HEAD:abc1234")
        assert shortened == "abc1234"
|
||||
|
||||
|
||||
class TestModelSwitchSkewGuard:
    """``slash_commands._model_switch_skew_guard`` with ``detect_code_skew`` stubbed."""

    def test_guard_returns_none_without_skew(self, monkeypatch):
        from gateway import slash_commands

        monkeypatch.setattr(code_skew, "detect_code_skew", lambda: None)
        assert slash_commands._model_switch_skew_guard() is None

    def test_guard_message_names_revs_and_restart(self, monkeypatch):
        from gateway import slash_commands

        boot_rev, current_rev = "abc1234567", "def4567890"
        monkeypatch.setattr(
            code_skew, "detect_code_skew", lambda: (boot_rev, current_rev)
        )
        message = slash_commands._model_switch_skew_guard()
        assert message is not None
        # Both revisions and the remediation command must appear in the text.
        for fragment in (boot_rev, current_rev, "hermes gateway restart"):
            assert fragment in message
|
||||
|
|
@ -12,19 +12,47 @@ REPO_ROOT = Path(__file__).resolve().parent.parent
|
|||
INSTALL_SH = REPO_ROOT / "scripts" / "install.sh"
|
||||
|
||||
|
||||
def test_install_script_skips_playwright_download_when_system_browser_exists() -> None:
|
||||
def test_install_script_does_not_autodetect_system_browser_on_path() -> None:
|
||||
"""The installer must not scan PATH/well-known locations for a browser.
|
||||
|
||||
Auto-detection silently bound the install to whatever ``command -v
|
||||
chromium`` resolved to — most damagingly a Snap Chromium, whose sandbox
|
||||
blocks agent-browser's control socket and hangs every browser_navigate. The
|
||||
fallback was dropped in favor of always using the bundled Playwright
|
||||
Chromium, so the old PATH-scan and "use the system browser" path are gone.
|
||||
"""
|
||||
text = INSTALL_SH.read_text()
|
||||
|
||||
assert "find_system_browser()" in text
|
||||
assert "google-chrome google-chrome-stable chromium chromium-browser chrome" in text
|
||||
assert "Skipping Playwright browser download; Hermes will use the system browser." in text
|
||||
assert "google-chrome google-chrome-stable chromium chromium-browser chrome" not in text
|
||||
assert "Skipping Playwright browser download; Hermes will use the system browser." not in text
|
||||
|
||||
|
||||
def test_install_script_persists_system_browser_for_agent_browser() -> None:
|
||||
def test_install_script_honors_explicit_browser_override_only() -> None:
|
||||
"""find_system_browser consults only an explicit AGENT_BROWSER_EXECUTABLE_PATH."""
|
||||
text = INSTALL_SH.read_text()
|
||||
|
||||
assert "configure_browser_env_from_system_browser()" in text
|
||||
assert "AGENT_BROWSER_EXECUTABLE_PATH=$browser_path" in text
|
||||
assert 'override="${AGENT_BROWSER_EXECUTABLE_PATH:-}"' in text
|
||||
# An explicit override still skips the bundled download (override, not fallback).
|
||||
assert "Skipping bundled Chromium download" in text
|
||||
|
||||
|
||||
def test_install_script_strips_stale_snap_browser_override() -> None:
|
||||
"""Already-affected installs must auto-recover.
|
||||
|
||||
A pre-existing AGENT_BROWSER_EXECUTABLE_PATH pointing at a Snap Chromium is
|
||||
the exact value that hangs the browser tool, and the runtime reads it from
|
||||
.env — so the installer strips it (and a Snap override is rejected even when
|
||||
set explicitly) so the bundled Chromium download runs on update.
|
||||
"""
|
||||
text = INSTALL_SH.read_text()
|
||||
|
||||
assert "strip_snap_browser_override()" in text
|
||||
assert "^AGENT_BROWSER_EXECUTABLE_PATH=/snap/" in text
|
||||
# Both install paths invoke the migration before resolving a browser.
|
||||
assert text.count("strip_snap_browser_override") >= 3
|
||||
# A snap path is rejected by find_system_browser itself.
|
||||
assert "/snap/*) return 1 ;;" in text
|
||||
|
||||
|
||||
def test_playwright_installs_are_timeout_guarded() -> None:
|
||||
|
|
|
|||
90
tests/test_stale_utils_module_import.py
Normal file
90
tests/test_stale_utils_module_import.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""Regression for the stale-``utils``-module ImportError after a hot ``git pull``.
|
||||
|
||||
Real incident (gateway session 1518671026962174144)::
|
||||
|
||||
Sorry, I encountered an error (ImportError).
|
||||
cannot import name 'env_float' from 'utils' (~/.hermes/hermes-agent/utils.py)
|
||||
|
||||
Mechanism:
|
||||
|
||||
1. A long-running gateway/agent process imported ``utils`` BEFORE ``env_float``
|
||||
existed (added in 06ca1e99, 2026-06-20 14:00). The cached module object in
|
||||
``sys.modules`` therefore has no ``env_float`` attribute.
|
||||
2. ``hermes update`` ran ``git pull``, updating ``utils.py`` (now defining
|
||||
``env_float``) and ~22 consumer modules (now doing ``from utils import
|
||||
env_float``) on disk -- WITHOUT restarting the process.
|
||||
3. Switching the live session's model (anthropic/opus -> opencode/glm) forced the
|
||||
FIRST import of a consumer module on the new provider's code path. Its
|
||||
top-level ``from utils import env_float`` resolved against the STALE cached
|
||||
``utils`` -> ImportError. The path in parentheses is the consumer-reported
|
||||
``utils.__file__`` on disk (which *does* define ``env_float``), which is why
|
||||
the error is so confusing: the file on disk is fine, the in-memory module is not.
|
||||
|
||||
``hermes_cli/main.py`` (the ``hermes update`` flow, ~line 9326) already
|
||||
acknowledges this exact hazard -- "source files on disk are newer than cached
|
||||
Python modules in this process" -- and reloads ``hermes_constants`` after the
|
||||
pull, but NOT ``utils``. Any ``utils`` consumer added in the same release stays
|
||||
exposed until the process restarts.
|
||||
|
||||
The messaging client (Discord/Telegram/Feishu/...) is incidental: the trigger is
|
||||
a fresh import on a stale process, not the platform. We assert that below by
|
||||
reproducing the failure with the Discord adapter's exact import line.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _import_fresh_consumer(name: str, source: str) -> types.ModuleType:
|
||||
"""Import a brand-new module whose body runs ``source`` -- mimicking a
|
||||
consumer module being imported for the first time on the model-switch path."""
|
||||
mod = types.ModuleType(name)
|
||||
mod.__file__ = f"{name}.py"
|
||||
sys.modules.pop(name, None)
|
||||
exec(compile(source, mod.__file__, "exec"), mod.__dict__)
|
||||
sys.modules[name] = mod
|
||||
return mod
|
||||
|
||||
|
||||
class TestStaleUtilsModuleImport:
    """Reproduce the ImportError a fresh consumer hits against a stale ``utils``."""

    # Message expected when ``env_float`` is missing from the cached module.
    _MISSING_ENV_FLOAT = r"cannot import name 'env_float' from 'utils'"

    def test_fresh_consumer_import_fails_against_stale_utils(self, monkeypatch):
        """The bug: stale in-memory ``utils`` + fresh ``from utils import env_float``."""
        import utils

        # Sanity: the module as currently imported defines env_float.
        assert hasattr(utils, "env_float")

        # Simulate the stale cached module; monkeypatch restores it afterwards.
        monkeypatch.delattr(utils, "env_float")

        consumer_source = "from utils import env_float\n"
        with pytest.raises(ImportError, match=self._MISSING_ENV_FLOAT):
            _import_fresh_consumer("stale_switch_path_consumer", consumer_source)

    def test_client_is_incidental_discord_import_line_fails_identically(self, monkeypatch):
        """Same failure using the Discord adapter's import line: the stale
        process, not the client, determines the bug."""
        import utils

        monkeypatch.delattr(utils, "env_float")

        # plugins/platforms/discord/adapter.py:106
        consumer_source = "from utils import atomic_json_write, env_float\n"
        with pytest.raises(ImportError, match=self._MISSING_ENV_FLOAT):
            _import_fresh_consumer("stale_discord_consumer", consumer_source)

    def test_healthy_process_imports_consumer_fine(self):
        """Control: with ``env_float`` present, the same consumer import
        succeeds, so the harness isolates staleness rather than some
        unrelated import error."""
        import utils

        assert hasattr(utils, "env_float")
        consumer_source = (
            "from utils import env_float\nVALUE = env_float('UNSET_FOR_TEST', 1.5)\n"
        )
        consumer = _import_fresh_consumer("healthy_consumer", consumer_source)
        assert consumer.VALUE == 1.5
|
||||
|
|
@ -7946,3 +7946,45 @@ def test_start_agent_build_passes_session_model_override(monkeypatch):
|
|||
assert session["agent"].model == "claude-sonnet-4.6"
|
||||
finally:
|
||||
server._sessions.clear()
|
||||
|
||||
|
||||
# ── _get_usage active_subagents (TUI status-bar ⛓ indicator) ──────────────
|
||||
# Mirrors the classic CLI status bar: _get_usage embeds a live count of
|
||||
# background/async subagents from tools.async_delegation.active_count() so the
|
||||
# Ink status bar can render ⛓ N. Source of truth is the same registry the CLI
|
||||
# reads; the field rides the existing per-update `usage` payload.
|
||||
|
||||
|
||||
class _BareAgent:
|
||||
"""Agent stub with no compressor — exercises the active_subagents path
|
||||
independent of the `if comp:` context-percent block."""
|
||||
|
||||
model = "x"
|
||||
|
||||
|
||||
def test_get_usage_includes_active_subagents(monkeypatch):
    """``_get_usage`` reports the value returned by ``active_count()``."""
    import tools.async_delegation as ad_mod

    running = 4
    monkeypatch.setattr(ad_mod, "active_count", lambda: running)
    payload = server._get_usage(_BareAgent())
    assert payload["active_subagents"] == 4
|
||||
|
||||
|
||||
def test_get_usage_active_subagents_zero(monkeypatch):
    """A count of zero is still present in the payload as ``0``."""
    import tools.async_delegation as ad_mod

    running = 0
    monkeypatch.setattr(ad_mod, "active_count", lambda: running)
    payload = server._get_usage(_BareAgent())
    assert payload["active_subagents"] == 0
|
||||
|
||||
|
||||
def test_get_usage_safe_when_active_count_raises(monkeypatch):
    """A raising active_count() must not break the usage payload."""
    import tools.async_delegation as ad_mod

    def _explode():
        raise RuntimeError("boom")

    monkeypatch.setattr(ad_mod, "active_count", _explode)
    payload = server._get_usage(_BareAgent())

    # The field is left out, while the rest of the payload is still produced.
    assert "active_subagents" not in payload
    assert payload["model"] == "x"
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -204,7 +204,7 @@ class TestCaptureResponseRoutedToAuxVision:
|
|||
args, _kwargs = fake_vat.call_args
|
||||
path_arg, prompt_arg = args[0], args[1]
|
||||
assert str(tmp_cache_dir) in path_arg
|
||||
assert "macOS application screenshot" in prompt_arg
|
||||
assert "desktop application screenshot" in prompt_arg
|
||||
# AX summary is included so the aux model can ground its description
|
||||
# against the same set-of-mark index the agent will see.
|
||||
assert "Sign in" in prompt_arg
|
||||
|
|
@ -298,15 +298,17 @@ class TestCaptureResponseRoutedToAuxVision:
|
|||
new_callable=lambda: fake_vat):
|
||||
resp = cu_tool._capture_response(cap)
|
||||
|
||||
# Aux failure → fall back to multimodal envelope (so the user still
|
||||
# gets *something* useful even if vision is broken).
|
||||
assert isinstance(resp, dict)
|
||||
assert resp.get("_multimodal") is True
|
||||
# Aux failure with routing requested degrades to the AX/SOM text
|
||||
# payload. Falling through to a multimodal envelope can hand pixels to
|
||||
# a text-only model and fail the provider request.
|
||||
assert isinstance(resp, str)
|
||||
body = json.loads(resp)
|
||||
assert body.get("vision_unavailable") is True
|
||||
# Temp file must still be cleaned up.
|
||||
assert observed_path["path"]
|
||||
assert not os.path.exists(observed_path["path"])
|
||||
|
||||
def test_empty_aux_analysis_falls_back_to_multimodal(self, tmp_cache_dir):
|
||||
def test_empty_aux_analysis_degrades_to_text_payload(self, tmp_cache_dir):
|
||||
from tools.computer_use import tool as cu_tool
|
||||
|
||||
cap = _make_capture(mode="som")
|
||||
|
|
@ -323,12 +325,15 @@ class TestCaptureResponseRoutedToAuxVision:
|
|||
new_callable=lambda: fake_vat):
|
||||
resp = cu_tool._capture_response(cap)
|
||||
|
||||
# Empty analysis is treated as failure — we'd rather show pixels
|
||||
# than embed an empty 'vision_analysis' string into the result.
|
||||
assert isinstance(resp, dict)
|
||||
assert resp.get("_multimodal") is True
|
||||
# Empty analysis is treated as failure; with routing requested the
|
||||
# capture degrades to the AX/SOM text payload (elements stay usable)
|
||||
# rather than embedding an empty 'vision_analysis' string.
|
||||
assert isinstance(resp, str)
|
||||
body = json.loads(resp)
|
||||
assert body.get("vision_unavailable") is True
|
||||
assert body.get("elements") is not None
|
||||
|
||||
def test_invalid_aux_response_falls_back_to_multimodal(self, tmp_cache_dir):
|
||||
def test_invalid_aux_response_degrades_to_text_payload(self, tmp_cache_dir):
|
||||
from tools.computer_use import tool as cu_tool
|
||||
|
||||
cap = _make_capture(mode="som")
|
||||
|
|
@ -345,8 +350,9 @@ class TestCaptureResponseRoutedToAuxVision:
|
|||
new_callable=lambda: fake_vat):
|
||||
resp = cu_tool._capture_response(cap)
|
||||
|
||||
assert isinstance(resp, dict)
|
||||
assert resp.get("_multimodal") is True
|
||||
assert isinstance(resp, str)
|
||||
body = json.loads(resp)
|
||||
assert body.get("vision_unavailable") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
109
tests/tools/test_file_tools_tilde_profile.py
Normal file
109
tests/tools/test_file_tools_tilde_profile.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Regression tests for profile-aware tilde expansion in file tools.
|
||||
|
||||
The bug (#48552): in-process file tools (write_file, read_file, patch,
|
||||
search_files) resolved ``~`` via ``os.path.expanduser()``, which reads the
|
||||
gateway process's ``HOME``. In profile mode (Docker, systemd, s6) the gateway
|
||||
``HOME`` differs from the profile ``HOME`` that interactive sessions use, so
|
||||
``~`` expanded to the wrong directory and file operations failed with
|
||||
"no such file or directory".
|
||||
|
||||
The fix adds ``_expand_tilde()`` which delegates to
|
||||
``hermes_constants.get_subprocess_home()`` — the same policy the terminal tool
|
||||
uses for subprocess environments.
|
||||
|
||||
See: https://github.com/NousResearch/hermes-agent/issues/48552
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import tools.file_tools as ft
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _expand_tilde() unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExpandTilde:
    """Unit tests for ``ft._expand_tilde()`` with the profile-home lookup patched."""

    _HOME_LOOKUP = "hermes_constants.get_subprocess_home"
    _PROFILE_HOME = "/opt/data/profiles/coder/home"

    def _expand(self, path, home=_PROFILE_HOME):
        """Call ``ft._expand_tilde(path)`` while the home lookup returns *home*."""
        with patch(self._HOME_LOOKUP, return_value=home):
            return ft._expand_tilde(path)

    def test_tilde_expands_to_profile_home(self):
        """``~/path`` is rooted at the value returned by get_subprocess_home."""
        expanded = self._expand("~/scratch/file.txt")
        assert expanded == "/opt/data/profiles/coder/home/scratch/file.txt"

    def test_bare_tilde_expands_to_profile_home(self):
        """A lone ``~`` becomes the profile home itself."""
        expanded = self._expand("~")
        assert expanded == "/opt/data/profiles/coder/home"

    def test_falls_back_when_no_profile_home(self):
        """With no profile home (None), the result matches os.path.expanduser."""
        expanded = self._expand("~/Documents", home=None)
        assert expanded == os.path.expanduser("~/Documents")

    def test_other_user_tilde_not_overridden(self):
        """``~user/path`` names another user, so the profile home must not appear."""
        expanded = self._expand("~root/file.txt")
        # Expansion of ~root is left to os.path.expanduser, not the profile home.
        assert self._PROFILE_HOME not in expanded

    def test_no_tilde_unchanged(self):
        """A path containing no ``~`` comes back as given."""
        expanded = self._expand("/etc/passwd")
        assert expanded == "/etc/passwd"

    def test_empty_path_unchanged(self):
        """The empty string is returned as the empty string."""
        assert self._expand("") == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: _resolve_path_for_task uses profile home
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolvePathUsesProfileHome:
    """Integration checks: ``ft._resolve_path_for_task`` expands ``~`` to the profile home."""

    @staticmethod
    def _prepare_homes(tmp_path, monkeypatch):
        """Create distinct profile/process home dirs and point ``HOME`` at the latter.

        Also stubs ``ft._get_live_tracking_cwd`` to return None. Returns the
        ``(profile_home, process_home)`` pair.
        """
        profile_home = tmp_path / "profile_home"
        profile_home.mkdir()
        process_home = tmp_path / "process_home"
        process_home.mkdir()

        monkeypatch.setenv("HOME", str(process_home))
        monkeypatch.setattr(ft, "_get_live_tracking_cwd", lambda task_id="default": None)
        return profile_home, process_home

    def test_relative_tilde_resolves_to_profile_home(self, tmp_path, monkeypatch):
        """``~/file`` lands under the profile home rather than the process HOME."""
        profile_home, _process_home = self._prepare_homes(tmp_path, monkeypatch)

        with patch("hermes_constants.get_subprocess_home", return_value=str(profile_home)):
            resolved = ft._resolve_path_for_task("~/test_file.txt", task_id="test")

        resolved_text = str(resolved)
        assert resolved_text.startswith(str(profile_home))
        assert "process_home" not in resolved_text

    def test_absolute_tilde_in_workspace_root(self, tmp_path, monkeypatch):
        """A nested ``~/dir/file`` path contains the profile home, never the process HOME.

        NOTE(review): despite the test name, no workspace root is configured
        here; only the path argument carries the ``~``.
        """
        profile_home, process_home = self._prepare_homes(tmp_path, monkeypatch)

        with patch("hermes_constants.get_subprocess_home", return_value=str(profile_home)):
            resolved = ft._resolve_path_for_task("~/data/config.json", task_id="test")

        resolved_text = str(resolved)
        assert str(profile_home) in resolved_text
        assert str(process_home) not in resolved_text
|
||||
|
|
@ -260,6 +260,56 @@ class TestStdioPgroupReaping:
|
|||
assert fake_pid not in _orphan_stdio_pids
|
||||
assert fake_pid not in _stdio_pgids
|
||||
|
||||
def test_killpg_skipped_when_pgid_matches_gateway_own_pgroup(self, monkeypatch):
    """Regression for #47134: never killpg the gateway's own process group.

    A tracked MCP child whose recorded pgid equals ``os.getpgrp()`` must be
    reaped with per-pid ``os.kill`` calls, because ``killpg`` on that group
    would signal the gateway itself. A second child in a different group is
    registered as a negative control and must still go through ``killpg``.
    """
    from tools.mcp_tool import (
        _kill_orphaned_mcp_children,
        _orphan_stdio_pids,
        _stdio_pgids,
        _lock,
    )

    if not (hasattr(os, "killpg") and hasattr(os, "getpgrp")):
        pytest.skip("os.killpg/os.getpgrp not available on this platform")

    self._reset_state()

    gateway_pgid = 424242
    shared_pid = 717171  # child recorded under the gateway's own pgid
    isolated_pid = 818181  # ordinary child that leads its own group
    isolated_pgid = 818181

    tracked = (
        (shared_pid, gateway_pgid),  # same group as the gateway
        (isolated_pid, isolated_pgid),  # distinct group, killpg is fine
    )
    with _lock:
        for pid, pgid in tracked:
            _orphan_stdio_pids.add(pid)
            _stdio_pgids[pid] = pgid

    fake_sigkill = 9
    monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False)

    with patch("tools.mcp_tool.os.getpgrp", return_value=gateway_pgid), \
         patch("tools.mcp_tool.os.killpg") as mock_killpg, \
         patch("tools.mcp_tool.os.kill") as mock_kill, \
         patch("gateway.status._pid_exists", return_value=True), \
         patch("time.sleep"):
        _kill_orphaned_mcp_children()

    signalled_groups = [recorded.args[0] for recorded in mock_killpg.call_args_list]

    # Signalling the gateway's own group would take the gateway down with it.
    assert gateway_pgid not in signalled_groups, (
        "killpg was called with the gateway's own pgid — self-kill (#47134)"
    )

    # The child sharing that group is handled pid-by-pid: SIGTERM, then SIGKILL.
    mock_kill.assert_any_call(shared_pid, signal.SIGTERM)
    mock_kill.assert_any_call(shared_pid, fake_sigkill)

    # Negative control: the guard must be limited to the gateway's group.
    assert isolated_pgid in signalled_groups, (
        "killpg must still be used for a non-gateway pgid (guard too broad)"
    )
|
||||
|
||||
def test_killpg_failure_falls_back_to_kill(self, monkeypatch):
|
||||
"""If killpg raises ProcessLookupError (pgroup gone), try os.kill."""
|
||||
from tools.mcp_tool import (
|
||||
|
|
|
|||
|
|
@ -107,6 +107,63 @@ def test_memory_gate_on_then_apply(hermes_home):
|
|||
assert "approved entry" in store.user_entries[0]
|
||||
|
||||
|
||||
def test_cli_memory_approve_without_live_agent_uses_fresh_store(hermes_home, capsys):
    """Regression for #46783: ``/memory approve`` works with no live agent.

    When the CLI handler has ``agent = None`` (e.g. the Desktop GUI) it used
    to hand ``memory_store=None`` to the shared handler, which answered
    "memory store unavailable" and applied nothing. The handler is expected to
    fall back to a store freshly loaded from disk, as the gateway path does.
    """
    import json
    from tools.memory_tool import memory_tool, MemoryStore
    from tools import write_approval as wa
    from hermes_cli.cli_commands_mixin import CLICommandsMixin

    _set_approval("memory", True)

    # Stage one pending write through a throwaway store.
    staging_store = MemoryStore()
    staging_store.load_from_disk()
    staged = json.loads(
        memory_tool("add", "memory", "remember the launch date", store=staging_store)
    )
    assert staged.get("pending_id"), staged
    assert wa.pending_count("memory") == 1

    # Build the mixin without __init__ and leave it agent-less.
    cli = CLICommandsMixin.__new__(CLICommandsMixin)
    cli.agent = None
    cli._handle_memory_command("/memory approve all")

    printed = capsys.readouterr().out
    assert "memory store unavailable" not in printed, printed
    assert "Approved 1" in printed, printed
    assert wa.pending_count("memory") == 0

    # A brand-new store read from disk must contain the approved entry.
    on_disk = MemoryStore()
    on_disk.load_from_disk()
    assert any("remember the launch date" in entry for entry in on_disk.memory_entries)
|
||||
|
||||
|
||||
def test_load_on_disk_store_honors_configured_char_limits(hermes_home, monkeypatch):
    """load_on_disk_store() takes its character caps from the loaded config.

    With ``memory.memory_char_limit`` / ``memory.user_char_limit`` present in
    the config the returned store carries those values; when loading the
    config raises, the store carries 2200 / 1375 and no exception escapes.
    """
    from tools.memory_tool import load_on_disk_store

    def _config_with_limits():
        return {"memory": {"memory_char_limit": 999, "user_char_limit": 444}}

    def _config_unavailable():
        raise RuntimeError("no config")

    # Configured limits are picked up.
    monkeypatch.setattr("hermes_cli.config.load_config", _config_with_limits)
    configured = load_on_disk_store()
    assert configured.memory_char_limit == 999
    assert configured.user_char_limit == 444

    # A failing config loader degrades to the default limits.
    monkeypatch.setattr("hermes_cli.config.load_config", _config_unavailable)
    defaulted = load_on_disk_store()
    assert defaulted.memory_char_limit == 2200
    assert defaulted.user_char_limit == 1375
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Skill gate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -734,6 +734,100 @@ def test_session_resume_reuses_existing_live_session(server, monkeypatch):
|
|||
assert all(sid == winner for sid in server._sessions)
|
||||
|
||||
|
||||
def test_session_resume_reuses_live_agent_after_compression_rotation(server, monkeypatch):
    """session.resume matches on the live agent's session_id, not a stale key.

    The live entry still stores the parent id under ``session_key`` while its
    agent reports the child id. Resuming the child id must hand back that same
    live entry (no extra session created) with ``session_key`` set to the
    child id.
    """
    child_id = "20260409_020202_child"
    parent_id = "20260409_010101_parent"
    live_sid = "live-rotated"

    rotated_agent = types.SimpleNamespace(model="test/model", session_id=child_id)
    server._sessions[live_sid] = {
        "agent": rotated_agent,
        "created_at": 123.0,
        "display_history_prefix": [],
        "history": [{"role": "assistant", "content": "live child"}],
        "history_lock": threading.RLock(),
        "last_active": 123.0,
        "running": False,
        "session_key": parent_id,
        "transport": server._stdio_transport,
    }

    class _StubDB:
        """Database double that always resolves to the child session id."""

        def get_session(self, _sid):
            return {"id": child_id}

        def get_session_by_title(self, _title):
            return None

        def resolve_resume_session_id(self, _target):
            return child_id

    def _discard(*_args, **_kwargs):
        return None

    def _fake_session_info(_agent, _session=None):
        return {"model": "test/model"}

    monkeypatch.setattr(server, "_get_db", _StubDB)
    monkeypatch.setattr(server, "_emit", _discard)
    monkeypatch.setattr(server, "_session_info", _fake_session_info)

    request = {
        "id": "r1",
        "method": "session.resume",
        "params": {"session_id": child_id, "cols": 100},
    }
    response = server.handle_request(request)

    assert "error" not in response
    resumed = response["result"]
    assert resumed["session_id"] == live_sid
    assert resumed["session_key"] == child_id
    assert len(server._sessions) == 1
|
||||
|
||||
|
||||
def test_sync_session_key_after_compress_reanchors_active_session_lease(
    server, monkeypatch, tmp_path
):
    """_sync_session_key_after_compress moves the active-session lease to the new id.

    A lease is acquired for ``session-old`` while the agent already reports
    ``session-new``. After the sync, the session's ``session_key``, the
    lease's ``session_id`` and the registry snapshot must all show
    ``session-new``.
    """
    home = tmp_path / ".hermes"
    monkeypatch.setenv("HERMES_HOME", str(home))

    from hermes_cli.active_sessions import (
        active_session_registry_snapshot,
        try_acquire_active_session,
    )

    lease, message = try_acquire_active_session(
        session_id="session-old",
        surface="tui",
        config={"max_concurrent_sessions": 1},
        metadata={"live_session_id": "ui-1"},
    )
    assert message is None
    assert lease is not None

    # Release the lease even when an assertion below fails, so a failing run
    # does not leave it held for whatever executes next.
    try:
        session = {
            "active_session_lease": lease,
            "agent": types.SimpleNamespace(session_id="session-new"),
            "session_key": "session-old",
        }
        # Stand-in for tools.approval: every hook is a no-op and yolo is off.
        fake_approval = types.SimpleNamespace(
            disable_session_yolo=lambda *_args, **_kwargs: None,
            enable_session_yolo=lambda *_args, **_kwargs: None,
            is_session_yolo_enabled=lambda *_args, **_kwargs: False,
            register_gateway_notify=lambda *_args, **_kwargs: None,
            unregister_gateway_notify=lambda *_args, **_kwargs: None,
        )
        monkeypatch.setattr(server, "_restart_slash_worker", lambda *_args, **_kwargs: None)

        with patch.dict(sys.modules, {"tools.approval": fake_approval}):
            server._sync_session_key_after_compress("ui-1", session)

        snapshot = active_session_registry_snapshot()
        assert session["session_key"] == "session-new"
        assert lease.session_id == "session-new"
        assert [entry["session_id"] for entry in snapshot] == ["session-new"]
    finally:
        lease.release()
|
||||
|
||||
|
||||
def test_session_resume_live_payload_uses_current_history_with_ancestors(server, monkeypatch):
|
||||
"""Live resume should not reuse a stale ancestor-inclusive snapshot."""
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue