Merge remote-tracking branch 'origin/main' into bb/pets-merge

# Conflicts:
#	hermes_cli/commands.py
#	tui_gateway/server.py
This commit is contained in:
Brooklyn Nicholson 2026-06-23 19:05:22 -05:00
commit e495b33bf1
251 changed files with 23395 additions and 2720 deletions

View file

@ -331,6 +331,131 @@ class TestResolveAnthropicToken:
monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
assert resolve_anthropic_token() == "cc-auto-token"
def test_falls_back_to_anthropic_credential_pool_oauth(self, monkeypatch, tmp_path):
    """With every env var unset and no Claude Code credentials, the resolver
    returns the OAuth access token offered by the credential pool."""
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    # Source #3 (Claude Code creds) is stubbed out explicitly: redirecting
    # Path.home does not cover the macOS keychain read, and this test wants
    # source #4 (credential_pool) in isolation, as on a Hermes-PKCE-only setup.
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [oauth_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    assert resolve_anthropic_token() == "pool-oauth-token"
def test_prefers_anthropic_credential_pool_oauth_over_api_key(self, monkeypatch, tmp_path):
    """A pool OAuth entry (source #4) must win over ANTHROPIC_API_KEY (source #5)."""
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant...ykey")
    for env_name in ("ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    # Neutralise source #3 as well, so a machine-local Claude Code creds /
    # keychain entry cannot short-circuit before the pool is reached.
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [oauth_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    assert resolve_anthropic_token() == "pool-oauth-token"
def test_pool_entry_with_null_access_token_does_not_crash(self, monkeypatch, tmp_path):
    """A persisted OAuth pool entry whose access_token is None must be skipped.

    Calling ``None.strip()`` would escape the helper's try/excepts and take
    down the whole resolver, including the ANTHROPIC_API_KEY fallback. The
    expected outcome is that the broken entry is ignored and the api-key
    fallback (source #5) wins.
    """
    fallback_key = "sk-ant...ykey"
    monkeypatch.setenv("ANTHROPIC_API_KEY", fallback_key)
    monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
    monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    tokenless_entry = SimpleNamespace(auth_type="oauth", access_token=None)
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [tokenless_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    # The resolver has to fall through to ANTHROPIC_API_KEY rather than raise.
    assert resolve_anthropic_token() == fallback_key
def test_pool_api_key_only_entry_is_not_returned_as_token(self, monkeypatch, tmp_path):
    """A pool entry with auth_type ``api_key`` is never returned by the pool path.

    resolve_anthropic_token() returns an OAuth bearer token. Entries that are
    api_key (not oauth) must NOT come back from the pool path — those are
    consumed via the aux client's _pool_runtime_api_key lane, a different
    resolution concern.
    """
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    key_only_entry = SimpleNamespace(auth_type="api_key", access_token="sk-pool-apikey")
    fake_pool = SimpleNamespace(_available_entries=lambda **_kwargs: [key_only_entry])
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    # Neither an OAuth entry nor any other source is available, so the
    # resolver yields None; the api_key entry is ignored on this path.
    assert resolve_anthropic_token() is None
def test_pool_is_not_consulted_when_env_token_present(self, monkeypatch, tmp_path):
    """ANTHROPIC_TOKEN (source #1) short-circuits before the pool.

    When the env token is set, load_pool must never be called (ordering
    contract #1 → #4).
    """
    monkeypatch.setenv("ANTHROPIC_TOKEN", "env-token")
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
    monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    seen_providers = []

    def _forbidden_load_pool(provider):
        # Record the call before raising, so the final assertion still fails
        # even if the raised error never propagates out of the resolver.
        seen_providers.append(provider)
        raise AssertionError("load_pool must not be called when source #1 wins")

    monkeypatch.setattr("agent.credential_pool.load_pool", _forbidden_load_pool)

    assert resolve_anthropic_token() == "env-token"
    assert seen_providers == []
def test_pool_resolution_is_read_only(self, monkeypatch, tmp_path):
    """The resolver enumerates the pool with clear_expired=False, refresh=False.

    Both flags must be False so a bare resolve never writes auth.json or
    triggers a network refresh from diagnostic call sites (#50108 MED).
    """
    for env_name in ("ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"):
        monkeypatch.delenv(env_name, raising=False)
    monkeypatch.setattr("agent.anthropic_adapter.Path.home", lambda: tmp_path)
    monkeypatch.setattr("agent.anthropic_adapter.read_claude_code_credentials", lambda: None)

    seen_kwargs = {}
    oauth_entry = SimpleNamespace(auth_type="oauth", access_token="pool-oauth-token")

    def _record_kwargs(**kwargs):
        # Capture exactly which keyword arguments the resolver passes.
        seen_kwargs.update(kwargs)
        return [oauth_entry]

    fake_pool = SimpleNamespace(_available_entries=_record_kwargs)
    monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: fake_pool)

    assert resolve_anthropic_token() == "pool-oauth-token"
    assert seen_kwargs == {"clear_expired": False, "refresh": False}
def test_prefers_refreshable_claude_code_credentials_over_static_anthropic_token(self, monkeypatch, tmp_path):
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-static-token")

View file

@ -206,6 +206,35 @@ class TestProjectFacts:
assert "Project: package.json" in block
assert "Verify:" not in block
def test_detect_project_facts_structured(self, tmp_path):
    """detect_project_facts reports manifest, package manager and verify
    commands for a pnpm project with ``test`` and ``dev`` scripts."""
    manifest = {"scripts": {"test": "vitest", "dev": "vite"}}
    (tmp_path / "package.json").write_text(json.dumps(manifest))
    (tmp_path / "pnpm-lock.yaml").write_text("")

    facts = cc.detect_project_facts(tmp_path)

    assert facts.manifests == ["package.json"]
    assert facts.package_managers == ["pnpm"]
    # Only the test script is expected; "dev" is excluded.
    assert facts.verify_commands == ["pnpm run test"]
    assert facts.context_files == []
def test_project_facts_for_matches_prompt_block(self, tmp_path):
    """Structured facts and the rendered prompt block must agree.

    Invariant: the structured facts the UI consumes must not drift from the
    commands the prompt snapshot renders — one detector feeds both. Every
    entry of ``facts["verifyCommands"]`` has to appear on the first line that
    follows ``Verify:`` in the workspace block.
    """
    _git_init(tmp_path)
    (tmp_path / "package.json").write_text(
        json.dumps({"scripts": {"test": "vitest", "lint": "eslint ."}})
    )
    (tmp_path / "pnpm-lock.yaml").write_text("")

    facts = cc.project_facts_for(tmp_path)
    assert facts is not None

    block = cc.build_coding_workspace_block(tmp_path)
    parts = block.split("Verify:")
    # Fail with a readable message instead of a bare IndexError when the
    # rendered block carries no Verify section at all.
    assert len(parts) > 1, "workspace block has no 'Verify:' section"
    # An empty remainder has no lines; treat that as an empty verify line so
    # the per-command assertions below report the real mismatch.
    verify_line = next(iter(parts[1].splitlines()), "")

    assert facts["verifyCommands"]
    for cmd in facts["verifyCommands"]:
        assert cmd in verify_line
def test_project_facts_for_none_outside_workspace(self, tmp_path):
    """A bare temp directory (no git init, no manifests) yields no facts."""
    facts = cc.project_facts_for(tmp_path)
    assert facts is None
# ── $HOME dotfiles guard ────────────────────────────────────────────────────

View file

@ -0,0 +1,86 @@
"""Regression: detect compression progress by tokens, not just rows.
Issue #39548: preflight compression in the turn prologue was checking
``len(messages) >= _orig_len`` to decide "Cannot compress further". This
false-positives when a pass summarises message contents — reducing the
estimated request token count without removing any rows — and surfaces a
spurious ``Context length exceeded`` failure followed by an auto-reset of
an otherwise healthy session.
These tests pin the contract of ``_compression_made_progress``: a
row-count reduction OR a *material* (>5%) token-count reduction counts as
progress.
"""
from __future__ import annotations
from agent.turn_context import _compression_made_progress
class TestCompressionMadeProgress:
    """Pins the ``_compression_made_progress`` contract: fewer rows, or a
    material token reduction, counts as progress."""

    def test_rows_reduced_counts_as_progress(self):
        """Dropping message rows is the obvious progress signal."""
        outcome = _compression_made_progress(
            orig_len=10, new_len=5, orig_tokens=1000, new_tokens=1000
        )
        assert outcome is True

    def test_tokens_reduced_without_row_change_counts_as_progress(self):
        """Issue #39548: 220 → 220 rows, 288k → 183k tokens IS progress."""
        outcome = _compression_made_progress(
            orig_len=220, new_len=220, orig_tokens=288_028, new_tokens=183_180
        )
        assert outcome is True

    def test_both_reduced_counts_as_progress(self):
        """Common case: summarising drops some rows and shrinks the rest."""
        outcome = _compression_made_progress(
            orig_len=220, new_len=180, orig_tokens=288_028, new_tokens=150_000
        )
        assert outcome is True

    def test_neither_moved_means_no_progress(self):
        """The genuine "stuck" case — same rows, same tokens, give up."""
        outcome = _compression_made_progress(
            orig_len=10, new_len=10, orig_tokens=1000, new_tokens=1000
        )
        assert outcome is False

    def test_rows_grew_and_tokens_grew_means_no_progress(self):
        """Pathological: the pass made the request larger — definitely stuck."""
        outcome = _compression_made_progress(
            orig_len=10, new_len=12, orig_tokens=1000, new_tokens=1200
        )
        assert outcome is False

    def test_rows_grew_but_tokens_dropped_is_progress(self):
        """Edge: summary rows may expand the row count while shrinking tokens.

        Token reduction alone is sufficient to keep the loop going.
        """
        outcome = _compression_made_progress(
            orig_len=10, new_len=11, orig_tokens=1000, new_tokens=600
        )
        assert outcome is True

    def test_tokens_grew_but_rows_dropped_is_progress(self):
        """Edge: row reduction alone is sufficient even if tokens nominally
        creep up (e.g. summary verbosity).

        Row-count reduction is a hard signal that the transcript actually
        shrank.
        """
        outcome = _compression_made_progress(
            orig_len=10, new_len=5, orig_tokens=1000, new_tokens=1100
        )
        assert outcome is True

    def test_sub_5pct_token_drop_is_not_progress(self):
        """A token reduction below the 5% material floor does NOT count as
        progress — matching the overflow-handler retry path (#39550) — so a
        marginal wobble can't keep the multi-pass loop spinning."""
        # 1000 -> 970 is a 3% drop: under the floor.
        below_floor = _compression_made_progress(
            orig_len=10, new_len=10, orig_tokens=1000, new_tokens=970
        )
        assert below_floor is False
        # 1000 -> 940 is a 6% drop: over the floor.
        above_floor = _compression_made_progress(
            orig_len=10, new_len=10, orig_tokens=1000, new_tokens=940
        )
        assert above_floor is True

    def test_zero_orig_tokens_is_not_progress(self):
        """Degenerate estimate (0 tokens) must not be read as a token win."""
        outcome = _compression_made_progress(
            orig_len=10, new_len=10, orig_tokens=0, new_tokens=0
        )
        assert outcome is False

View file

@ -0,0 +1,107 @@
"""Regression tests for tool_call envelope accounting in the compression
tail-protection budget walks (issue #28053).
The budget walks used to estimate an assistant message's tokens from
content + ``function.arguments`` only, dropping each ``tool_call``'s ``id``,
``type`` and ``function.name`` (plus JSON structure). For assistant turns
that fan out into parallel tool calls this undercounted by 2-15x, so the
protected tail overshot ``tail_token_budget`` and compression became
ineffective. The fix routes all three walks through
``_estimate_msg_budget_tokens``, which counts the full envelope.
"""
import pytest
from unittest.mock import patch
from agent.context_compressor import (
ContextCompressor,
_CHARS_PER_TOKEN,
_estimate_msg_budget_tokens,
)
def _assistant_with_tool_calls(n_calls: int, *, args: str = '{"path":"a"}') -> dict:
"""An assistant turn fanning into ``n_calls`` parallel tool calls with
realistic id/name overhead but a small arguments string."""
return {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": f"call_{i:02d}_{'a' * 24}", # ~32 chars, UUID-ish id
"type": "function",
"function": {"name": "read_file", "arguments": args},
}
for i in range(n_calls)
],
}
def _args_only_estimate(msg: dict) -> int:
    """Reproduce the OLD (buggy) arguments-only walk for comparison.

    Counts the message content plus a flat overhead of 10, then adds only the
    ``function.arguments`` length of each dict-shaped tool call. Non-dict
    entries in ``tool_calls`` are skipped.
    """
    text = msg.get("content") or ""
    total = len(text) // _CHARS_PER_TOKEN + 10
    dict_calls = (call for call in (msg.get("tool_calls") or []) if isinstance(call, dict))
    for call in dict_calls:
        arguments = call.get("function", {}).get("arguments", "")
        total += len(arguments) // _CHARS_PER_TOKEN
    return total
class TestToolCallEnvelopeEstimate:
    """Checks that ``_estimate_msg_budget_tokens`` counts the whole tool_call
    envelope rather than just the arguments string."""

    def test_envelope_counted_not_just_arguments(self):
        message = _assistant_with_tool_calls(4)
        new = _estimate_msg_budget_tokens(message)
        old = _args_only_estimate(message)
        # The id/type/name fields plus JSON structure dwarf the tiny arguments.
        assert new > old * 3, (new, old)
        # The estimate must cover at least the serialized tool_call envelope.
        serialized_chars = sum(len(str(call)) for call in message["tool_calls"])
        envelope = serialized_chars // _CHARS_PER_TOKEN
        assert new >= envelope

    def test_scales_with_number_of_parallel_calls(self):
        single = _estimate_msg_budget_tokens(_assistant_with_tool_calls(1))
        quintuple = _estimate_msg_budget_tokens(_assistant_with_tool_calls(5))
        assert quintuple > single * 3

    def test_no_tool_calls_matches_content_estimate(self):
        plain = {"role": "user", "content": "x" * 400}
        # Plain message: content//4 + 10 overhead, behavior unchanged.
        expected = 400 // _CHARS_PER_TOKEN + 10
        assert _estimate_msg_budget_tokens(plain) == expected

    def test_non_dict_tool_calls_do_not_crash(self):
        odd = {"role": "assistant", "content": "hi", "tool_calls": ["weird", None]}
        # Non-dict entries are ignored (as before) without raising.
        expected = len("hi") // _CHARS_PER_TOKEN + 10
        assert _estimate_msg_budget_tokens(odd) == expected
@pytest.fixture()
def compressor():
    """Provide a quiet ContextCompressor built while the model context-length
    lookup is patched to return 100000."""
    length_patch = patch(
        "agent.context_compressor.get_model_context_length", return_value=100000
    )
    # The patch only needs to be active during construction.
    with length_patch:
        instance = ContextCompressor(
            model="test/model",
            threshold_percent=0.85,
            protect_first_n=2,
            protect_last_n=2,
            quiet_mode=True,
        )
    return instance
class TestTailCutAccountsForToolCalls:
    """Exercises ``_find_tail_cut_by_tokens`` on a tail made of tool-call-heavy
    assistant turns."""

    def test_tail_cut_stops_on_tool_call_heavy_tail(self, compressor):
        # Twenty assistant turns, each fanning into five short-arg tool calls.
        heavy = [_assistant_with_tool_calls(5) for _ in range(20)]
        messages = [{"role": "user", "content": "start"}, *heavy]

        per_msg = _estimate_msg_budget_tokens(messages[-1])
        # Sanity: a heavy turn is non-trivial once the envelope counts.
        assert per_msg > 30

        # Size the budget so ~6 heavy turns fit under the 1.5x soft ceiling.
        token_budget = int(per_msg * 6 / 1.5)
        cut = compressor._find_tail_cut_by_tokens(
            messages, head_end=1, token_budget=token_budget
        )
        protected = len(messages) - cut

        # With the envelope counted, the walk stops well short of protecting
        # all 20 turns. The old arguments-only estimate (~25 tokens/turn)
        # never reaches the ceiling and would protect the entire transcript.
        assert protected < len(heavy)
        assert 3 <= protected <= 12

View file

@ -86,6 +86,28 @@ class TestPreflightDeferral:
assert compressor.should_defer_preflight_to_real_usage(93_000) is False
def test_defers_immediately_after_compaction_with_stale_real_prompt(self, compressor):
    """#36718: the awaiting flag forces deferral right after a compaction.

    At that point last_real_prompt_tokens still holds the stale
    pre-compression value (above threshold). The flag must force deferral so
    preflight doesn't fire a SECOND compaction before real post-compaction
    usage arrives.
    """
    compressor.threshold_tokens = 85_000
    # Stale pre-compression value. Without the flag guard it would hit the
    # `>= threshold => False` short-circuit and defeat deferral.
    compressor.last_real_prompt_tokens = 120_000
    compressor.awaiting_real_usage_after_compression = True

    deferred = compressor.should_defer_preflight_to_real_usage(95_000)
    assert deferred is True
def test_resumes_normal_deferral_after_flag_cleared(self, compressor):
    """With the awaiting flag cleared (as update_from_response() does), the
    normal baseline/growth deferral logic governs again — deferral is not
    permanent."""
    compressor.threshold_tokens = 85_000
    compressor.last_real_prompt_tokens = 120_000
    compressor.awaiting_real_usage_after_compression = False

    # A stale-high real prompt with the flag cleared falls into the
    # >= threshold short-circuit, so nothing is deferred.
    deferred = compressor.should_defer_preflight_to_real_usage(95_000)
    assert deferred is False
class TestCompress:
@ -242,6 +264,59 @@ class TestCompress:
assert c.should_compress(55000) is True
assert c.should_compress(40000) is False
def test_max_tokens_reservation_lowers_threshold(self):
    """#43547: the threshold is based on (context_length - max_tokens).

    The provider reserves max_tokens out of the window, so the full window is
    the wrong base. A 200K model reserving 65536 output tokens has a ~134K
    input budget; at 50% that's ~67K, NOT 100K.
    """
    compute = ContextCompressor._compute_threshold_tokens

    # No reservation (provider default): full-window behavior, unchanged.
    assert compute(200000, 0.50) == 100000
    assert compute(200000, 0.50, None) == 100000
    # 65536 reserved: effective input budget 134464, and 50% of it is 67232.
    assert compute(200000, 0.50, 65536) == 67232
def test_max_tokens_reservation_with_small_window_floors(self):
    """A large reservation on a smaller window can push the effective budget
    near/below the minimum floor — the degenerate-window guard then triggers
    at 85% of the EFFECTIVE budget, never the raw window."""
    effective_window = 62464  # 128K window minus 65536 reserved (< MINIMUM 64000)

    # Floor (64000) >= effective window (62464), so 85% of effective applies.
    threshold = ContextCompressor._compute_threshold_tokens(128000, 0.50, 65536)

    assert threshold == int(62464 * 0.85)  # 53094
    assert threshold < effective_window
def test_max_tokens_exceeding_window_falls_back_to_full(self):
    """Pathological: max_tokens >= context_length would make the effective
    budget <= 0, so the full window is used instead of producing a
    non-positive threshold."""
    threshold = ContextCompressor._compute_threshold_tokens(64000, 0.50, 70000)

    # effective_window <= 0 falls back to the full context (64000), where the
    # 85% guard applies: the same value as the no-reservation small-ctx case.
    assert threshold == 54400
    assert threshold > 0
def test_max_tokens_coercion_treats_non_int_as_no_reservation(self):
    """Non-int / non-positive max_tokens coerce safely, so the threshold
    arithmetic never raises.

    Guards the path where a mocked parent agent forwards a MagicMock
    max_tokens into a child ContextCompressor (regression for the
    delegate-test TypeError: '<=' not supported between MagicMock and int).
    """
    from unittest.mock import MagicMock

    coerce = ContextCompressor._coerce_max_tokens
    for rejected in (None, 0, -5, "nope"):
        assert coerce(rejected) is None
    assert coerce(65536) == 65536

    # The actual regression: constructing a compressor with a MagicMock
    # max_tokens must NOT raise (the unguarded code did `ctx - MagicMock`
    # then `MagicMock <= 0`). int(MagicMock()) returns 1, so coercion yields
    # a harmless positive int and the threshold is computed cleanly with a
    # 1-token reservation.
    with patch("agent.context_compressor.get_model_context_length", return_value=200000):
        built = ContextCompressor(model="m", quiet_mode=True, max_tokens=MagicMock())
        assert isinstance(built.max_tokens, int)
        assert isinstance(built.threshold_tokens, int)
        assert built.threshold_tokens > 0  # no crash, sane value
def test_compression_increments_count(self, compressor):
msgs = self._make_messages(10)
# Default config (abort_on_summary_failure=False) — fallback path

View file

@ -0,0 +1,73 @@
"""Tests for /learn — open-ended skill distillation.
Covers the shared prompt builder (agent.learn_prompt.build_learn_prompt) and
the slash-command registry wiring. /learn has no engine and no model tool: it
builds a standards-guided prompt that the live agent runs as a normal turn, so
these are the load-bearing behavior contracts.
"""
from agent.learn_prompt import build_learn_prompt, _AUTHORING_STANDARDS
class TestBuildLearnPrompt:
    """Behavior contracts for the prompt text produced by ``build_learn_prompt``."""

    def test_embeds_the_user_request_verbatim(self):
        request = "the REST client in ~/projects/acme-sdk, focus on auth"
        assert request in build_learn_prompt(request)

    def test_always_includes_the_authoring_standards(self):
        # The standards are what make distilled skills match house style, so
        # they have to travel with every prompt regardless of the input.
        requests = ("", "a url https://x/y", "what we just did")
        for request in requests:
            assert _AUTHORING_STANDARDS in build_learn_prompt(request)

    def test_instructs_saving_via_skill_manage_not_a_raw_file(self):
        rendered = build_learn_prompt("learn the thing")
        assert "skill_manage" in rendered

    def test_references_gather_tools_for_open_ended_sourcing(self):
        # Open-ended sourcing relies on the agent's own tools, named so it
        # knows dirs/URLs/conversation/paste all route through existing tools.
        rendered = build_learn_prompt("learn from somewhere")
        for tool_name in ("read_file", "search_files", "web_extract"):
            assert tool_name in rendered

    def test_empty_request_falls_back_to_the_conversation(self):
        # A bare /learn should distill "what we just did", not error.
        rendered = build_learn_prompt("")
        assert "conversation" in rendered.lower()
        # It still has to carry the standards + save instruction.
        assert "skill_manage" in rendered

    def test_whitespace_only_request_is_treated_as_empty(self):
        blank_prompt = build_learn_prompt("")
        assert build_learn_prompt(" \n ") == blank_prompt

    def test_description_length_rule_is_in_the_standards(self):
        # The single most-violated rule must be explicit in the prompt.
        assert "60" in _AUTHORING_STANDARDS
class TestLearnRegistryWiring:
    """Checks how the ``learn`` slash command is registered in hermes_cli."""

    def test_learn_is_registered_and_resolves(self):
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("learn")
        assert resolved is not None
        assert resolved.name == "learn"

    def test_learn_is_in_tools_and_skills_category(self):
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("learn")
        assert resolved.category == "Tools & Skills"

    def test_learn_works_on_the_gateway(self):
        # /learn must reach the gateway runner (it's a both-surfaces command),
        # not be CLI-only.
        from hermes_cli.commands import GATEWAY_KNOWN_COMMANDS

        assert "learn" in GATEWAY_KNOWN_COMMANDS

    def test_learn_is_not_cli_only(self):
        from hermes_cli.commands import resolve_command

        resolved = resolve_command("learn")
        assert not resolved.cli_only

View file

@ -1172,16 +1172,12 @@ class TestOnMemoryWriteBridge:
mgr.on_memory_write("replace", "user", "updated pref")
assert p.memory_writes == [("replace", "user", "updated pref")]
def test_on_memory_write_remove_not_bridged(self):
"""The bridge intentionally skips 'remove' — only add/replace notify."""
# This tests the contract that run_agent.py checks:
# function_args.get("action") in ("add", "replace")
def test_on_memory_write_remove_supported_by_manager(self):
"""The manager forwards remove actions when a caller elects to bridge them."""
mgr = MemoryManager()
p = FakeMemoryProvider("ext")
mgr.add_provider(p)
# Manager itself doesn't filter — run_agent.py does.
# But providers should handle remove gracefully.
mgr.on_memory_write("remove", "memory", "old fact")
assert p.memory_writes == [("remove", "memory", "old fact")]

View file

@ -0,0 +1,145 @@
"""Behavior tests for the built-in memory → external provider bridge.
The bridge lives behind the MemoryManager interface
(``MemoryManager.notify_memory_tool_write``): the agent loop hands over the raw
built-in memory tool result + args, and the manager decides whether/what to
mirror to external providers. These tests drive that method with a fake
external provider and assert which ``on_memory_write`` calls land.
"""
import json
import pytest
from agent.memory_manager import MemoryManager
from agent.memory_provider import MemoryProvider
class _RecordingProvider(MemoryProvider):
    """Bare-bones external provider that keeps a log of on_memory_write calls.

    Every other provider hook is a no-op or returns a fixed value; only
    ``on_memory_write`` stores anything, appending one dict per call to
    ``self.calls``.
    """

    def __init__(self) -> None:
        self.calls = []

    @property
    def name(self) -> str:
        return "recording"

    def is_available(self) -> bool:
        return True

    def initialize(self, session_id: str, **kwargs) -> None:
        pass

    def get_tool_schemas(self):
        return []

    def shutdown(self) -> None:
        pass

    def on_memory_write(self, action, target, content, metadata=None):
        # Copy the metadata so later mutation by the caller cannot alter what
        # was recorded; a missing value is stored as an empty dict.
        record = {
            "action": action,
            "target": target,
            "content": content,
            "metadata": dict(metadata or {}),
        }
        self.calls.append(record)
def _manager_with_provider():
    """Return a ``(manager, provider)`` pair: a fresh MemoryManager with one
    _RecordingProvider already added."""
    recorder = _RecordingProvider()
    manager = MemoryManager()
    manager.add_provider(recorder)
    return manager, recorder
def test_notifies_remove_with_old_text_after_success():
    """A successful remove is mirrored with empty content and the removed
    text carried in ``metadata["old_text"]``."""
    mgr, provider = _manager_with_provider()
    tool_result = json.dumps({"success": True})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    mgr.notify_memory_tool_write(tool_result, tool_args)

    expected_call = {
        "action": "remove",
        "target": "memory",
        "content": "",
        "metadata": {"old_text": "stale preference entry"},
    }
    assert provider.calls == [expected_call]
def test_skips_failed_memory_write():
    """A tool result reporting ``success: False`` is not mirrored."""
    mgr, provider = _manager_with_provider()
    failed_result = json.dumps({"success": False, "error": "No entry matched"})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    mgr.notify_memory_tool_write(failed_result, tool_args)

    assert provider.calls == []
def test_skips_staged_memory_write():
    """A successful but ``staged`` tool result is not mirrored."""
    mgr, provider = _manager_with_provider()
    staged_result = json.dumps({"success": True, "staged": True, "pending_id": "abc123"})
    tool_args = {"action": "remove", "target": "memory", "old_text": "stale preference entry"}

    mgr.notify_memory_tool_write(staged_result, tool_args)

    assert provider.calls == []
@pytest.mark.parametrize("tool_result", [None, [], object(), "not-json"])
def test_skips_unrecognized_tool_result_shape(tool_result):
    """None, an empty list, an arbitrary object and a non-JSON string are all
    tool results that must produce no provider call."""
    mgr, provider = _manager_with_provider()
    tool_args = {"action": "add", "target": "memory", "content": "new fact"}

    mgr.notify_memory_tool_write(tool_result, tool_args)

    assert provider.calls == []
def test_preserves_old_text_for_replace_and_remove_batch():
    """A batched ``operations`` payload yields one provider call per op, in
    order, with ``old_text`` kept in metadata for replace and remove."""
    mgr, provider = _manager_with_provider()
    operations = [
        {"action": "replace", "old_text": "old preference", "content": "updated"},
        {"action": "remove", "old_text": "obsolete preference"},
        {"action": "add", "content": "new fact"},
    ]

    mgr.notify_memory_tool_write(
        json.dumps({"success": True}),
        {"target": "user", "operations": operations},
    )

    expected_calls = [
        {
            "action": "replace",
            "target": "user",
            "content": "updated",
            "metadata": {"old_text": "old preference"},
        },
        {
            "action": "remove",
            "target": "user",
            "content": "",
            "metadata": {"old_text": "obsolete preference"},
        },
        {"action": "add", "target": "user", "content": "new fact", "metadata": {}},
    ]
    assert provider.calls == expected_calls
def test_non_mutating_actions_are_not_mirrored():
    """A successful ``read`` action produces no provider call."""
    mgr, provider = _manager_with_provider()
    ok_result = json.dumps({"success": True})

    mgr.notify_memory_tool_write(ok_result, {"action": "read", "target": "memory"})

    assert provider.calls == []
def test_build_metadata_callback_is_merged_per_op():
    """The dict returned by ``build_metadata`` ends up in the metadata of the
    mirrored call."""
    mgr, provider = _manager_with_provider()
    tool_args = {"action": "add", "target": "memory", "content": "fact"}

    mgr.notify_memory_tool_write(
        json.dumps({"success": True}),
        tool_args,
        build_metadata=lambda: {"session_id": "s1", "tool_name": "memory"},
    )

    expected_call = {
        "action": "add",
        "target": "memory",
        "content": "fact",
        "metadata": {"session_id": "s1", "tool_name": "memory"},
    }
    assert provider.calls == [expected_call]

110
tests/agent/test_oneshot.py Normal file
View file

@ -0,0 +1,110 @@
"""Tests for agent.oneshot — shared one-off (stateless) LLM requests."""
from unittest.mock import MagicMock, patch
import pytest
from agent.oneshot import (
PROMPT_TEMPLATES,
render_template,
run_oneshot,
_strip_code_fence,
_truncate,
)
class TestRenderTemplate:
    """Covers ``render_template`` and the registered ``commit_message`` template."""

    def test_unknown_template_raises(self):
        with pytest.raises(KeyError):
            render_template("does-not-exist", {})

    def test_commit_message_template_is_registered(self):
        assert "commit_message" in PROMPT_TEMPLATES

    def test_commit_message_includes_diff_and_recent(self):
        variables = {
            "diff": "diff --git a/x b/x\n+new",
            "recent_commits": "feat: a\nfix: b",
        }
        instructions, user = render_template("commit_message", variables)
        # Instructions describe the contract (conventional commits), not a snapshot.
        assert "Conventional Commits" in instructions
        assert "diff --git a/x b/x" in user
        assert "feat: a" in user

    def test_commit_message_diff_with_braces_passes_through(self):
        # Templates must not use str.format — code payloads carry literal { }.
        braced_diff = "x = {a: 1}"
        _, user = render_template("commit_message", {"diff": braced_diff})
        assert "x = {a: 1}" in user

    def test_commit_message_handles_missing_variables(self):
        instructions, user = render_template("commit_message", {})
        assert instructions
        assert "no textual diff available" in user

    def test_commit_message_avoid_forces_new_message(self):
        # Passing the previous message must instruct the model not to repeat
        # it, so "regenerate" yields a different result even on greedy models.
        _, plain = render_template("commit_message", {"diff": "d"})
        _, regen = render_template(
            "commit_message", {"diff": "d", "avoid": "feat: prior"}
        )
        assert "feat: prior" in regen
        assert "do not repeat" in regen
        assert "feat: prior" not in plain
class TestRunOneshot:
    """Drives ``run_oneshot`` with ``agent.oneshot.call_llm`` patched out."""

    def _mock_response(self, content):
        """Build a MagicMock response whose first choice carries ``content``
        and has all three reasoning attributes set to None."""
        choice = MagicMock()
        message = choice.message
        message.content = content
        message.reasoning = None
        message.reasoning_content = None
        message.reasoning_details = None
        resp = MagicMock()
        resp.choices = [choice]
        return resp

    def test_template_path_calls_llm_with_rendered_prompt(self):
        canned = self._mock_response("feat: add thing")
        with patch("agent.oneshot.call_llm", return_value=canned) as llm:
            out = run_oneshot(template="commit_message", variables={"diff": "d"})
        assert out == "feat: add thing"
        messages = llm.call_args.kwargs["messages"]
        assert messages[0]["role"] == "system"
        assert messages[1]["role"] == "user"

    def test_explicit_instructions_path(self):
        canned = self._mock_response("hello")
        with patch("agent.oneshot.call_llm", return_value=canned) as llm:
            out = run_oneshot(instructions="be brief", user_input="say hi")
        assert out == "hello"
        messages = llm.call_args.kwargs["messages"]
        assert messages[0]["content"] == "be brief"
        assert messages[1]["content"] == "say hi"

    def test_requires_template_or_prompt(self):
        with pytest.raises(ValueError):
            run_oneshot()

    def test_strips_wrapping_code_fence(self):
        canned = self._mock_response("```\nfix: bug\n```")
        with patch("agent.oneshot.call_llm", return_value=canned):
            out = run_oneshot(instructions="x", user_input="y")
        assert out == "fix: bug"
class TestHelpers:
    """Checks for the private ``_truncate`` and ``_strip_code_fence`` helpers."""

    def test_truncate_under_limit_unchanged(self):
        short_text = "short"
        assert _truncate(short_text, 100) == "short"

    def test_truncate_over_limit_marks_truncation(self):
        shortened = _truncate("x" * 200, 50)
        assert shortened.endswith("…(truncated)")
        assert len(shortened) < 200

    def test_strip_code_fence_without_fence_is_noop(self):
        unfenced = "plain text"
        assert _strip_code_fence(unfenced) == "plain text"

View file

@ -100,7 +100,13 @@ class _StubAgent:
pass
def _run(agent):
def _run(
agent,
*,
final_response=None,
api_call_count=3,
turn_exit_reason="unknown",
):
messages = [
{"role": "user", "content": "do a thing"},
{
@ -114,8 +120,8 @@ def _run(agent):
]
return finalize_turn(
agent,
final_response=None, # forces the max-iterations summary path
api_call_count=3,
final_response=final_response,
api_call_count=api_call_count,
interrupted=False,
failed=False,
messages=messages,
@ -125,7 +131,7 @@ def _run(agent):
user_message="do a thing",
original_user_message="do a thing",
_should_review_memory=False,
_turn_exit_reason="unknown",
_turn_exit_reason=turn_exit_reason,
)
@ -162,4 +168,17 @@ def test_clean_turn_has_no_cleanup_errors_key():
agent = _StubAgent(raise_in=())
result = _run(agent)
assert result["final_response"] == "PARTIAL SUMMARY FROM MODEL"
assert result["completed"] is False
assert "cleanup_errors" not in result
def test_text_response_on_last_allowed_call_is_completed():
    """A text response delivered on the call that equals ``max_iterations``
    is returned as-is and reported as completed."""
    stub = _StubAgent(raise_in=())
    outcome = _run(
        stub,
        final_response="final report",
        api_call_count=stub.max_iterations,
        turn_exit_reason="text_response(finish_reason=stop)",
    )
    assert outcome["final_response"] == "final report"
    assert outcome["completed"] is True