mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-04 02:21:47 +00:00
Merge remote-tracking branch 'origin/main' into hermes/hermes-1f7bfa9e
# Conflicts: # cron/scheduler.py # tools/send_message_tool.py
This commit is contained in:
commit
e7fc6450fc
99 changed files with 9609 additions and 1075 deletions
|
|
@ -576,11 +576,19 @@ class TestSummaryTargetRatio:
|
|||
assert c.summary_target_ratio == 0.80
|
||||
|
||||
def test_default_threshold_is_50_percent(self):
|
||||
"""Default compression threshold should be 50%."""
|
||||
"""Default compression threshold should be 50%, with a 64K floor."""
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100_000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
assert c.threshold_percent == 0.50
|
||||
assert c.threshold_tokens == 50_000
|
||||
# 50% of 100K = 50K, but the floor is 64K
|
||||
assert c.threshold_tokens == 64_000
|
||||
|
||||
def test_threshold_floor_does_not_apply_above_128k(self):
|
||||
"""On large-context models the 50% percentage is used directly."""
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=200_000):
|
||||
c = ContextCompressor(model="test", quiet_mode=True)
|
||||
# 50% of 200K = 100K, which is above the 64K floor
|
||||
assert c.threshold_tokens == 100_000
|
||||
|
||||
def test_default_protect_last_n_is_20(self):
|
||||
"""Default protect_last_n should be 20."""
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ class TestLocalStreamReadTimeout:
|
|||
"http://0.0.0.0:5000",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
"http://host.docker.internal:11434",
|
||||
"http://host.containers.internal:11434",
|
||||
"http://host.lima.internal:11434",
|
||||
])
|
||||
def test_local_endpoint_bumps_read_timeout(self, base_url):
|
||||
"""Local endpoint + default timeout -> bumps to base_timeout."""
|
||||
|
|
@ -68,3 +71,38 @@ class TestLocalStreamReadTimeout:
|
|||
if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url):
|
||||
_stream_read_timeout = _base_timeout
|
||||
assert _stream_read_timeout == 120.0
|
||||
|
||||
|
||||
class TestIsLocalEndpoint:
|
||||
"""Direct unit tests for is_local_endpoint."""
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://localhost:11434",
|
||||
"http://127.0.0.1:8080",
|
||||
"http://0.0.0.0:5000",
|
||||
"http://[::1]:11434",
|
||||
"http://192.168.1.100:8000",
|
||||
"http://10.0.0.5:1234",
|
||||
"http://172.17.0.1:11434",
|
||||
])
|
||||
def test_classic_local_addresses(self, url):
|
||||
assert is_local_endpoint(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"http://host.docker.internal:11434",
|
||||
"http://host.docker.internal:8080/v1",
|
||||
"http://gateway.docker.internal:11434",
|
||||
"http://host.containers.internal:11434",
|
||||
"http://host.lima.internal:11434",
|
||||
])
|
||||
def test_container_dns_names(self, url):
|
||||
assert is_local_endpoint(url) is True
|
||||
|
||||
@pytest.mark.parametrize("url", [
|
||||
"https://api.openai.com",
|
||||
"https://openrouter.ai/api",
|
||||
"https://api.anthropic.com",
|
||||
"https://evil.docker.internal.example.com",
|
||||
])
|
||||
def test_remote_endpoints(self, url):
|
||||
assert is_local_endpoint(url) is False
|
||||
|
|
|
|||
|
|
@ -50,7 +50,8 @@ class TestEstimateTokensRough:
|
|||
assert estimate_tokens_rough("a" * 400) == 100
|
||||
|
||||
def test_short_text(self):
|
||||
assert estimate_tokens_rough("hello") == 1
|
||||
# "hello" = 5 chars → ceil(5/4) = 2
|
||||
assert estimate_tokens_rough("hello") == 2
|
||||
|
||||
def test_proportional(self):
|
||||
short = estimate_tokens_rough("hello world")
|
||||
|
|
@ -68,10 +69,11 @@ class TestEstimateMessagesTokensRough:
|
|||
assert estimate_messages_tokens_rough([]) == 0
|
||||
|
||||
def test_single_message_concrete_value(self):
|
||||
"""Verify against known str(msg) length."""
|
||||
"""Verify against known str(msg) length (ceiling division)."""
|
||||
msg = {"role": "user", "content": "a" * 400}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
expected = len(str(msg)) // 4
|
||||
n = len(str(msg))
|
||||
expected = (n + 3) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_multiple_messages_additive(self):
|
||||
|
|
@ -80,7 +82,8 @@ class TestEstimateMessagesTokensRough:
|
|||
{"role": "assistant", "content": "Hi there, how can I help?"},
|
||||
]
|
||||
result = estimate_messages_tokens_rough(msgs)
|
||||
expected = sum(len(str(m)) for m in msgs) // 4
|
||||
n = sum(len(str(m)) for m in msgs)
|
||||
expected = (n + 3) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_tool_call_message(self):
|
||||
|
|
@ -89,7 +92,7 @@ class TestEstimateMessagesTokensRough:
|
|||
"tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result > 0
|
||||
assert result == len(str(msg)) // 4
|
||||
assert result == (len(str(msg)) + 3) // 4
|
||||
|
||||
def test_message_with_list_content(self):
|
||||
"""Vision messages with multimodal content arrays."""
|
||||
|
|
@ -98,7 +101,7 @@ class TestEstimateMessagesTokensRough:
|
|||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
|
||||
]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result == len(str(msg)) // 4
|
||||
assert result == (len(str(msg)) + 3) // 4
|
||||
|
||||
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -1009,65 +1009,4 @@ class TestOpenAIModelExecutionGuidance:
|
|||
# =========================================================================
|
||||
|
||||
|
||||
class TestStripBudgetWarningsFromHistory:
|
||||
def test_strips_json_budget_warning_key(self):
|
||||
import json
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": json.dumps({
|
||||
"output": "hello",
|
||||
"exit_code": 0,
|
||||
"_budget_warning": "[BUDGET: Iteration 55/60. 5 iterations left. Start consolidating your work.]",
|
||||
})},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
parsed = json.loads(messages[0]["content"])
|
||||
assert "_budget_warning" not in parsed
|
||||
assert parsed["output"] == "hello"
|
||||
assert parsed["exit_code"] == 0
|
||||
|
||||
def test_strips_text_budget_warning(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1",
|
||||
"content": "some result\n\n[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert messages[0]["content"] == "some result"
|
||||
|
||||
def test_leaves_non_tool_messages_unchanged(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "assistant", "content": "[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
original_contents = [m["content"] for m in messages]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert [m["content"] for m in messages] == original_contents
|
||||
|
||||
def test_handles_empty_and_missing_content(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": ""},
|
||||
{"role": "tool", "tool_call_id": "c2"},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert messages[0]["content"] == ""
|
||||
|
||||
def test_strips_caution_variant(self):
|
||||
import json
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": json.dumps({
|
||||
"output": "ok",
|
||||
"_budget_warning": "[BUDGET: Iteration 42/60. 18 iterations left. Start consolidating your work.]",
|
||||
})},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
parsed = json.loads(messages[0]["content"])
|
||||
assert "_budget_warning" not in parsed
|
||||
|
|
|
|||
|
|
@ -74,6 +74,26 @@ class FakeBot:
|
|||
return None
|
||||
|
||||
|
||||
class SlowSyncTree(FakeTree):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.started = asyncio.Event()
|
||||
self.allow_finish = asyncio.Event()
|
||||
|
||||
async def _slow_sync():
|
||||
self.started.set()
|
||||
await self.allow_finish.wait()
|
||||
return []
|
||||
|
||||
self.sync = AsyncMock(side_effect=_slow_sync)
|
||||
|
||||
|
||||
class SlowSyncBot(FakeBot):
|
||||
def __init__(self, *, intents, proxy=None):
|
||||
super().__init__(intents=intents, proxy=proxy)
|
||||
self.tree = SlowSyncTree()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("allowed_users", "expected_members_intent"),
|
||||
|
|
@ -138,3 +158,36 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch):
|
|||
assert ok is False
|
||||
assert released == [("discord-bot-token", "test-token")]
|
||||
assert adapter._platform_lock_identity is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_does_not_wait_for_slash_sync(monkeypatch):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token"))
|
||||
|
||||
monkeypatch.setattr("gateway.status.acquire_scoped_lock", lambda scope, identity, metadata=None: (True, None))
|
||||
monkeypatch.setattr("gateway.status.release_scoped_lock", lambda scope, identity: None)
|
||||
|
||||
intents = SimpleNamespace(message_content=False, dm_messages=False, guild_messages=False, members=False, voice_states=False)
|
||||
monkeypatch.setattr(discord_platform.Intents, "default", lambda: intents)
|
||||
|
||||
created = {}
|
||||
|
||||
def fake_bot_factory(*, command_prefix, intents, proxy=None):
|
||||
bot = SlowSyncBot(intents=intents, proxy=proxy)
|
||||
created["bot"] = bot
|
||||
return bot
|
||||
|
||||
monkeypatch.setattr(discord_platform.commands, "Bot", fake_bot_factory)
|
||||
monkeypatch.setattr(adapter, "_resolve_allowed_usernames", AsyncMock())
|
||||
|
||||
ok = await asyncio.wait_for(adapter.connect(), timeout=1.0)
|
||||
|
||||
assert ok is True
|
||||
assert adapter._ready_event.is_set()
|
||||
|
||||
await asyncio.wait_for(created["bot"].tree.started.wait(), timeout=1.0)
|
||||
assert created["bot"].tree.sync.await_count == 1
|
||||
|
||||
created["bot"].tree.allow_finish.set()
|
||||
await asyncio.sleep(0)
|
||||
await adapter.disconnect()
|
||||
|
|
|
|||
355
tests/gateway/test_display_config.py
Normal file
355
tests/gateway/test_display_config.py
Normal file
|
|
@ -0,0 +1,355 @@
|
|||
"""Tests for gateway.display_config — per-platform display/verbosity resolver."""
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolver: resolution order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveDisplaySetting:
|
||||
"""resolve_display_setting() resolves with correct priority."""
|
||||
|
||||
def test_explicit_platform_override_wins(self):
|
||||
"""display.platforms.<plat>.<key> takes top priority."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"platforms": {
|
||||
"telegram": {"tool_progress": "verbose"},
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose"
|
||||
|
||||
def test_global_setting_when_no_platform_override(self):
|
||||
"""Falls back to display.<key> when no platform override exists."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "new",
|
||||
"platforms": {},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "new"
|
||||
|
||||
def test_platform_default_when_no_user_config(self):
|
||||
"""Falls back to built-in platform default."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
# Empty config — should get built-in defaults
|
||||
config = {}
|
||||
# Telegram defaults to tier_high → "all"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
# Email defaults to tier_minimal → "off"
|
||||
assert resolve_display_setting(config, "email", "tool_progress") == "off"
|
||||
|
||||
def test_global_default_for_unknown_platform(self):
|
||||
"""Unknown platforms get the global defaults."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# Unknown platform, no config → global default "all"
|
||||
assert resolve_display_setting(config, "unknown_platform", "tool_progress") == "all"
|
||||
|
||||
def test_fallback_parameter_used_last(self):
|
||||
"""Explicit fallback is used when nothing else matches."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# "nonexistent_key" isn't in any defaults
|
||||
result = resolve_display_setting(config, "telegram", "nonexistent_key", "my_fallback")
|
||||
assert result == "my_fallback"
|
||||
|
||||
def test_platform_override_only_affects_that_platform(self):
|
||||
"""Other platforms are unaffected by a specific platform override."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"platforms": {
|
||||
"slack": {"tool_progress": "off"},
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "slack", "tool_progress") == "off"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward compatibility: tool_progress_overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackwardCompat:
|
||||
"""Legacy tool_progress_overrides is still respected as a fallback."""
|
||||
|
||||
def test_legacy_overrides_read(self):
|
||||
"""tool_progress_overrides is read when no platforms entry exists."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"tool_progress_overrides": {
|
||||
"signal": "off",
|
||||
"telegram": "verbose",
|
||||
},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "signal", "tool_progress") == "off"
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose"
|
||||
|
||||
def test_new_platforms_takes_precedence_over_legacy(self):
|
||||
"""display.platforms beats tool_progress_overrides."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "all",
|
||||
"tool_progress_overrides": {"telegram": "verbose"},
|
||||
"platforms": {"telegram": {"tool_progress": "new"}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "new"
|
||||
|
||||
def test_legacy_overrides_only_for_tool_progress(self):
|
||||
"""Legacy overrides don't affect other settings."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress_overrides": {"telegram": "verbose"},
|
||||
}
|
||||
}
|
||||
# show_reasoning should NOT read from tool_progress_overrides
|
||||
assert resolve_display_setting(config, "telegram", "show_reasoning") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# YAML normalisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestYAMLNormalisation:
|
||||
"""YAML 1.1 quirks (bare off → False, on → True) are handled."""
|
||||
|
||||
def test_tool_progress_false_normalised_to_off(self):
|
||||
"""YAML's bare `off` parses as False — normalised to 'off' string."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"tool_progress": False}}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "off"
|
||||
|
||||
def test_tool_progress_true_normalised_to_all(self):
|
||||
"""YAML's bare `on` parses as True — normalised to 'all'."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"tool_progress": True}}
|
||||
assert resolve_display_setting(config, "telegram", "tool_progress") == "all"
|
||||
|
||||
def test_show_reasoning_string_true(self):
|
||||
"""String 'true' is normalised to bool True."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"telegram": {"show_reasoning": "true"}}}}
|
||||
assert resolve_display_setting(config, "telegram", "show_reasoning") is True
|
||||
|
||||
def test_tool_preview_length_string(self):
|
||||
"""String numbers are normalised to int."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"slack": {"tool_preview_length": "80"}}}}
|
||||
assert resolve_display_setting(config, "slack", "tool_preview_length") == 80
|
||||
|
||||
def test_platform_override_false_tool_progress(self):
|
||||
"""Per-platform bare off → normalised."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {"display": {"platforms": {"slack": {"tool_progress": False}}}}
|
||||
assert resolve_display_setting(config, "slack", "tool_progress") == "off"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in platform defaults (tier system)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPlatformDefaults:
|
||||
"""Built-in defaults reflect platform capability tiers."""
|
||||
|
||||
def test_high_tier_platforms(self):
|
||||
"""Telegram and Discord default to 'all' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("telegram", "discord"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "all", plat
|
||||
|
||||
def test_medium_tier_platforms(self):
|
||||
"""Slack, Mattermost, Matrix default to 'new' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("slack", "mattermost", "matrix", "feishu"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "new", plat
|
||||
|
||||
def test_low_tier_platforms(self):
|
||||
"""Signal, WhatsApp, etc. default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("signal", "whatsapp", "bluebubbles", "weixin", "wecom", "dingtalk"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_minimal_tier_platforms(self):
|
||||
"""Email, SMS, webhook default to 'off' tool progress."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
for plat in ("email", "sms", "webhook", "homeassistant"):
|
||||
assert resolve_display_setting({}, plat, "tool_progress") == "off", plat
|
||||
|
||||
def test_low_tier_streaming_defaults_to_false(self):
|
||||
"""Low-tier platforms default streaming to False."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
assert resolve_display_setting({}, "signal", "streaming") is False
|
||||
assert resolve_display_setting({}, "email", "streaming") is False
|
||||
|
||||
def test_high_tier_streaming_defaults_to_none(self):
|
||||
"""High-tier platforms default streaming to None (follow global)."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
assert resolve_display_setting({}, "telegram", "streaming") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_display / get_platform_defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHelpers:
|
||||
"""Helper functions return correct composite results."""
|
||||
|
||||
def test_get_effective_display_merges_correctly(self):
|
||||
from gateway.display_config import get_effective_display
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"tool_progress": "new",
|
||||
"show_reasoning": True,
|
||||
"platforms": {
|
||||
"telegram": {"tool_progress": "verbose"},
|
||||
},
|
||||
}
|
||||
}
|
||||
eff = get_effective_display(config, "telegram")
|
||||
assert eff["tool_progress"] == "verbose" # platform override
|
||||
assert eff["show_reasoning"] is True # global
|
||||
assert "tool_preview_length" in eff # default filled in
|
||||
|
||||
def test_get_platform_defaults_returns_dict(self):
|
||||
from gateway.display_config import get_platform_defaults
|
||||
|
||||
defaults = get_platform_defaults("telegram")
|
||||
assert "tool_progress" in defaults
|
||||
assert "show_reasoning" in defaults
|
||||
# Returns a new dict (not the shared tier dict)
|
||||
defaults["tool_progress"] = "changed"
|
||||
assert get_platform_defaults("telegram")["tool_progress"] != "changed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config migration: tool_progress_overrides → display.platforms
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigMigration:
|
||||
"""Version 16 migration moves tool_progress_overrides into display.platforms."""
|
||||
|
||||
def test_migration_creates_platforms_entries(self, tmp_path, monkeypatch):
|
||||
"""Old overrides are migrated into display.platforms.<plat>.tool_progress."""
|
||||
import yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config = {
|
||||
"_config_version": 15,
|
||||
"display": {
|
||||
"tool_progress_overrides": {
|
||||
"signal": "off",
|
||||
"telegram": "all",
|
||||
},
|
||||
},
|
||||
}
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
# Re-import to pick up the new HERMES_HOME
|
||||
import importlib
|
||||
import hermes_cli.config as cfg_mod
|
||||
importlib.reload(cfg_mod)
|
||||
|
||||
result = cfg_mod.migrate_config(interactive=False, quiet=True)
|
||||
# Re-read config
|
||||
updated = yaml.safe_load(config_path.read_text())
|
||||
platforms = updated.get("display", {}).get("platforms", {})
|
||||
assert platforms.get("signal", {}).get("tool_progress") == "off"
|
||||
assert platforms.get("telegram", {}).get("tool_progress") == "all"
|
||||
|
||||
def test_migration_preserves_existing_platforms_entries(self, tmp_path, monkeypatch):
|
||||
"""Existing display.platforms entries are NOT overwritten by migration."""
|
||||
import yaml
|
||||
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config = {
|
||||
"_config_version": 15,
|
||||
"display": {
|
||||
"tool_progress_overrides": {"telegram": "off"},
|
||||
"platforms": {"telegram": {"tool_progress": "verbose"}},
|
||||
},
|
||||
}
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
import importlib
|
||||
import hermes_cli.config as cfg_mod
|
||||
importlib.reload(cfg_mod)
|
||||
|
||||
cfg_mod.migrate_config(interactive=False, quiet=True)
|
||||
updated = yaml.safe_load(config_path.read_text())
|
||||
# Existing "verbose" should NOT be overwritten by legacy "off"
|
||||
assert updated["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming per-platform (None = follow global)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStreamingPerPlatform:
|
||||
"""Streaming per-platform override semantics."""
|
||||
|
||||
def test_none_means_follow_global(self):
|
||||
"""When streaming is None, the caller should use global config."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {}
|
||||
# Telegram has no streaming override in defaults → None
|
||||
result = resolve_display_setting(config, "telegram", "streaming")
|
||||
assert result is None # caller should check global StreamingConfig
|
||||
|
||||
def test_explicit_false_disables(self):
|
||||
"""Explicit False disables streaming for that platform."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"platforms": {"telegram": {"streaming": False}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "telegram", "streaming") is False
|
||||
|
||||
def test_explicit_true_enables(self):
|
||||
"""Explicit True enables streaming for that platform."""
|
||||
from gateway.display_config import resolve_display_setting
|
||||
|
||||
config = {
|
||||
"display": {
|
||||
"platforms": {"email": {"streaming": True}},
|
||||
}
|
||||
}
|
||||
assert resolve_display_setting(config, "email", "streaming") is True
|
||||
|
|
@ -157,12 +157,44 @@ def _make_fake_mautrix():
|
|||
mautrix_crypto_store = types.ModuleType("mautrix.crypto.store")
|
||||
|
||||
class MemoryCryptoStore:
|
||||
def __init__(self, account_id="", pickle_key=""):
|
||||
def __init__(self, account_id="", pickle_key=""): # noqa: S301
|
||||
self.account_id = account_id
|
||||
self.pickle_key = pickle_key
|
||||
|
||||
mautrix_crypto_store.MemoryCryptoStore = MemoryCryptoStore
|
||||
|
||||
# --- mautrix.crypto.store.asyncpg ---
|
||||
mautrix_crypto_store_asyncpg = types.ModuleType("mautrix.crypto.store.asyncpg")
|
||||
|
||||
class PgCryptoStore:
|
||||
upgrade_table = MagicMock()
|
||||
|
||||
def __init__(self, account_id="", pickle_key="", db=None): # noqa: S301
|
||||
self.account_id = account_id
|
||||
self.pickle_key = pickle_key
|
||||
self.db = db
|
||||
|
||||
async def open(self):
|
||||
pass
|
||||
|
||||
mautrix_crypto_store_asyncpg.PgCryptoStore = PgCryptoStore
|
||||
|
||||
# --- mautrix.util ---
|
||||
mautrix_util = types.ModuleType("mautrix.util")
|
||||
|
||||
# --- mautrix.util.async_db ---
|
||||
mautrix_util_async_db = types.ModuleType("mautrix.util.async_db")
|
||||
|
||||
class Database:
|
||||
@classmethod
|
||||
def create(cls, url, upgrade_table=None):
|
||||
db = MagicMock()
|
||||
db.start = AsyncMock()
|
||||
db.stop = AsyncMock()
|
||||
return db
|
||||
|
||||
mautrix_util_async_db.Database = Database
|
||||
|
||||
return {
|
||||
"mautrix": mautrix,
|
||||
"mautrix.api": mautrix_api,
|
||||
|
|
@ -171,6 +203,9 @@ def _make_fake_mautrix():
|
|||
"mautrix.client.state_store": mautrix_client_state_store,
|
||||
"mautrix.crypto": mautrix_crypto,
|
||||
"mautrix.crypto.store": mautrix_crypto_store,
|
||||
"mautrix.crypto.store.asyncpg": mautrix_crypto_store_asyncpg,
|
||||
"mautrix.util": mautrix_util,
|
||||
"mautrix.util.async_db": mautrix_util_async_db,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -740,6 +775,12 @@ class TestMatrixAccessTokenAuth:
|
|||
mock_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"DEV123": {
|
||||
"keys": {"ed25519:DEV123": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_access_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
|
|
@ -751,6 +792,8 @@ class TestMatrixAccessTokenAuth:
|
|||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
# Patch Client constructor to return our mock
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
|
|
@ -924,6 +967,12 @@ class TestMatrixDeviceId:
|
|||
mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="WHOAMI_DEV"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"MY_STABLE_DEVICE": {
|
||||
"keys": {"ed25519:MY_STABLE_DEVICE": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_access_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
|
|
@ -934,6 +983,8 @@ class TestMatrixDeviceId:
|
|||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm)
|
||||
|
|
@ -1030,8 +1081,8 @@ class TestMatrixDeviceIdConfig:
|
|||
|
||||
class TestMatrixSyncLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_loop_shares_keys_when_encryption_enabled(self):
|
||||
"""_sync_loop should call crypto.share_keys() after each sync."""
|
||||
async def test_sync_loop_dispatches_events_and_stores_token(self):
|
||||
"""_sync_loop should call handle_sync() and persist next_batch."""
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = True
|
||||
adapter._closing = False
|
||||
|
|
@ -1046,7 +1097,6 @@ class TestMatrixSyncLoop:
|
|||
return {"rooms": {"join": {"!room:example.org": {}}}, "next_batch": "s1234"}
|
||||
|
||||
mock_crypto = MagicMock()
|
||||
mock_crypto.share_keys = AsyncMock()
|
||||
|
||||
mock_sync_store = MagicMock()
|
||||
mock_sync_store.get_next_batch = AsyncMock(return_value=None)
|
||||
|
|
@ -1062,7 +1112,6 @@ class TestMatrixSyncLoop:
|
|||
await adapter._sync_loop()
|
||||
|
||||
fake_client.sync.assert_awaited_once()
|
||||
mock_crypto.share_keys.assert_awaited_once()
|
||||
fake_client.handle_sync.assert_called_once()
|
||||
mock_sync_store.put_next_batch.assert_awaited_once_with("s1234")
|
||||
|
||||
|
|
@ -1248,6 +1297,12 @@ class TestMatrixEncryptedEventHandler:
|
|||
mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123"))
|
||||
mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}})
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.query_keys = AsyncMock(return_value={
|
||||
"device_keys": {"@bot:example.org": {"DEV123": {
|
||||
"keys": {"ed25519:DEV123": "fake_ed25519_key"},
|
||||
}}},
|
||||
})
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
|
|
@ -1258,6 +1313,8 @@ class TestMatrixEncryptedEventHandler:
|
|||
mock_olm.share_keys = AsyncMock()
|
||||
mock_olm.share_keys_min_trust = None
|
||||
mock_olm.send_keys_min_trust = None
|
||||
mock_olm.account = MagicMock()
|
||||
mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"}
|
||||
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.run import _dequeue_pending_event
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
|
|
@ -79,6 +80,26 @@ class TestQueueMessageStorage:
|
|||
# Should be consumed (cleared)
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:voice"
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="voice-q1",
|
||||
media_urls=["/tmp/voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = _dequeue_pending_event(adapter, session_key)
|
||||
|
||||
assert retrieved is event
|
||||
assert retrieved.media_urls == ["/tmp/voice.ogg"]
|
||||
assert retrieved.media_types == ["audio/ogg"]
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_queue_does_not_set_interrupt_event(self):
|
||||
"""The whole point of /queue — no interrupt signal."""
|
||||
adapter = _StubAdapter()
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from types import SimpleNamespace
|
|||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
from gateway.config import Platform, PlatformConfig, StreamingConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
|
|
@ -104,6 +104,11 @@ def _make_runner(adapter):
|
|||
runner._session_db = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
runner.config = SimpleNamespace(
|
||||
thread_sessions_per_user=False,
|
||||
group_sessions_per_user=False,
|
||||
stt_enabled=False,
|
||||
)
|
||||
return runner
|
||||
|
||||
|
||||
|
|
@ -118,6 +123,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
|||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
|
|
@ -144,7 +150,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
|||
assert adapter.sent == [
|
||||
{
|
||||
"chat_id": "-1001",
|
||||
"content": '⚙️ terminal: "pwd"',
|
||||
"content": '💻 terminal: "pwd"',
|
||||
"reply_to": None,
|
||||
"metadata": {"thread_id": "17585"},
|
||||
}
|
||||
|
|
@ -334,3 +340,238 @@ def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path):
|
|||
content = adapter.sent[0]["content"]
|
||||
# With a 200-char cap, the 165-char command should NOT be truncated
|
||||
assert "..." not in content, f"Preview was truncated when it shouldn't be: {content}"
|
||||
|
||||
|
||||
class CommentaryAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.stream_delta_callback = kwargs.get("stream_delta_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)
|
||||
time.sleep(0.1)
|
||||
if self.stream_delta_callback:
|
||||
self.stream_delta_callback("done")
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class PreviewedResponseAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
if self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("You're welcome.", already_streamed=False)
|
||||
return {
|
||||
"final_response": "You're welcome.",
|
||||
"response_previewed": True,
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class QueuedCommentaryAgent:
|
||||
calls = 0
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
type(self).calls += 1
|
||||
if type(self).calls == 1 and self.interim_assistant_callback:
|
||||
self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False)
|
||||
return {
|
||||
"final_response": f"final response {type(self).calls}",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
async def _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
agent_cls,
|
||||
*,
|
||||
session_id,
|
||||
pending_text=None,
|
||||
config_data=None,
|
||||
):
|
||||
if config_data:
|
||||
import yaml
|
||||
|
||||
(tmp_path / "config.yaml").write_text(yaml.dump(config_data), encoding="utf-8")
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = agent_cls
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter()
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
if config_data and "streaming" in config_data:
|
||||
runner.config.streaming = StreamingConfig.from_dict(config_data["streaming"])
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
session_key = "agent:main:telegram:group:-1001:17585"
|
||||
if pending_text is not None:
|
||||
adapter._pending_messages[session_key] = MessageEvent(
|
||||
text=pending_text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id="queued-1",
|
||||
)
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
)
|
||||
return adapter, result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_surfaces_real_interim_commentary(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_surfaces_interim_commentary_by_default(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-default-on",
|
||||
)
|
||||
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_suppresses_interim_commentary_when_disabled(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-disabled",
|
||||
config_data={"display": {"interim_assistant_messages": False}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_tool_progress_does_not_control_interim_commentary(monkeypatch, tmp_path):
|
||||
"""tool_progress=all with interim_assistant_messages=false should not surface commentary."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-tool-progress",
|
||||
config_data={"display": {"tool_progress": "all", "interim_assistant_messages": False}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_streaming_does_not_enable_completed_interim_commentary(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
"""Streaming alone with interim_assistant_messages=false should not surface commentary."""
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-streaming",
|
||||
config_data={
|
||||
"display": {"tool_progress": "off", "interim_assistant_messages": False},
|
||||
"streaming": {"enabled": True},
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is True
|
||||
assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_interim_commentary_works_with_tool_progress_off(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
CommentaryAgent,
|
||||
session_id="sess-commentary-explicit-on",
|
||||
config_data={
|
||||
"display": {
|
||||
"tool_progress": "off",
|
||||
"interim_assistant_messages": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is not True
|
||||
assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_previewed_final_marks_already_sent(monkeypatch, tmp_path):
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
PreviewedResponseAgent,
|
||||
session_id="sess-previewed",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
assert result.get("already_sent") is True
|
||||
assert [call["content"] for call in adapter.sent] == ["You're welcome."]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path):
|
||||
QueuedCommentaryAgent.calls = 0
|
||||
adapter, result = await _run_with_agent(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
QueuedCommentaryAgent,
|
||||
session_id="sess-queued-commentary",
|
||||
pending_text="queued follow-up",
|
||||
config_data={"display": {"interim_assistant_messages": True}},
|
||||
)
|
||||
|
||||
sent_texts = [call["content"] for call in adapter.sent]
|
||||
assert result["final_response"] == "final response 2"
|
||||
assert "I'll inspect the repo first." in sent_texts
|
||||
assert "final response 1" in sent_texts
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
|
|
@ -45,6 +46,23 @@ class _DisabledAdapter(BasePlatformAdapter):
|
|||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _SuccessfulAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.DISCORD)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
|
@ -65,7 +83,7 @@ async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch,
|
|||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "startup_failed"
|
||||
assert "temporary DNS resolution failure" in state["exit_reason"]
|
||||
assert state["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert state["platforms"]["telegram"]["state"] == "retrying"
|
||||
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
|
||||
|
||||
|
||||
|
|
@ -89,6 +107,31 @@ async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkey
|
|||
assert state["gateway_state"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_records_connected_platform_state_on_success(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.DISCORD: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter())
|
||||
monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None)
|
||||
monkeypatch.setattr(runner.hooks, "emit", AsyncMock())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "running"
|
||||
assert state["platforms"]["discord"]["state"] == "connected"
|
||||
assert state["platforms"]["discord"]["error_code"] is None
|
||||
assert state["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import os
|
||||
|
||||
from gateway.config import Platform
|
||||
|
|
@ -130,3 +131,99 @@ def test_set_session_env_handles_missing_optional_fields():
|
|||
assert get_session_env("HERMES_SESSION_THREAD_ID") == ""
|
||||
|
||||
runner._clear_session_env(tokens)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SESSION_KEY contextvars tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_key_set_via_contextvars(monkeypatch):
|
||||
"""set_session_vars should set HERMES_SESSION_KEY via contextvars."""
|
||||
monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)
|
||||
|
||||
tokens = set_session_vars(
|
||||
platform="telegram",
|
||||
chat_id="-1001",
|
||||
session_key="tg:-1001:17585",
|
||||
)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
|
||||
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_key_falls_back_to_os_environ(monkeypatch):
|
||||
"""get_session_env for SESSION_KEY should fall back to os.environ."""
|
||||
monkeypatch.setenv("HERMES_SESSION_KEY", "env-session-123")
|
||||
|
||||
# No contextvar set — should read from os.environ
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
|
||||
|
||||
# Set contextvar — should prefer it
|
||||
tokens = set_session_vars(session_key="ctx-session-456")
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "ctx-session-456"
|
||||
|
||||
# Restore — should fall back to os.environ
|
||||
clear_session_vars(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "env-session-123"
|
||||
|
||||
|
||||
def test_set_session_env_includes_session_key():
|
||||
"""_set_session_env should propagate session_key from SessionContext."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1001",
|
||||
chat_name="Group",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
)
|
||||
context = SessionContext(
|
||||
source=source,
|
||||
connected_platforms=[],
|
||||
home_channels={},
|
||||
session_key="tg:-1001:17585",
|
||||
)
|
||||
|
||||
tokens = runner._set_session_env(context)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585"
|
||||
runner._clear_session_env(tokens)
|
||||
assert get_session_env("HERMES_SESSION_KEY") == ""
|
||||
|
||||
|
||||
def test_session_key_no_race_condition_with_contextvars(monkeypatch):
|
||||
"""Prove contextvars isolates SESSION_KEY across concurrent async tasks.
|
||||
|
||||
Two tasks set different session keys. With contextvars each task
|
||||
reads back its own value. With os.environ the second task would
|
||||
overwrite the first (the old bug).
|
||||
"""
|
||||
monkeypatch.delenv("HERMES_SESSION_KEY", raising=False)
|
||||
|
||||
results = {}
|
||||
|
||||
async def handler(key: str, delay: float):
|
||||
tokens = set_session_vars(session_key=key)
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
read_back = get_session_env("HERMES_SESSION_KEY")
|
||||
results[key] = read_back
|
||||
finally:
|
||||
clear_session_vars(tokens)
|
||||
|
||||
async def run():
|
||||
task_a = asyncio.create_task(handler("session-A", 0.15))
|
||||
await asyncio.sleep(0.05)
|
||||
task_b = asyncio.create_task(handler("session-B", 0.05))
|
||||
await asyncio.gather(task_a, task_b)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
# Both tasks must read back their own session key
|
||||
assert results["session-A"] == "session-A", (
|
||||
f"Session A got '{results['session-A']}' instead of 'session-A' — race condition!"
|
||||
)
|
||||
assert results["session-B"] == "session-B", (
|
||||
f"Session B got '{results['session-B']}' instead of 'session-B' — race condition!"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -104,6 +104,34 @@ class TestGatewayRuntimeStatus:
|
|||
assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict"
|
||||
assert payload["platforms"]["telegram"]["error_message"] == "another poller is active"
|
||||
|
||||
def test_write_runtime_status_explicit_none_clears_stale_fields(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="startup_failed",
|
||||
exit_reason="stale error",
|
||||
platform="discord",
|
||||
platform_state="fatal",
|
||||
error_code="discord_timeout",
|
||||
error_message="stale platform error",
|
||||
)
|
||||
|
||||
status.write_runtime_status(
|
||||
gateway_state="running",
|
||||
exit_reason=None,
|
||||
platform="discord",
|
||||
platform_state="connected",
|
||||
error_code=None,
|
||||
error_message=None,
|
||||
)
|
||||
|
||||
payload = status.read_runtime_status()
|
||||
assert payload["gateway_state"] == "running"
|
||||
assert payload["exit_reason"] is None
|
||||
assert payload["platforms"]["discord"]["state"] == "connected"
|
||||
assert payload["platforms"]["discord"]["error_code"] is None
|
||||
assert payload["platforms"]["discord"]["error_message"] is None
|
||||
|
||||
|
||||
class TestTerminatePid:
|
||||
def test_force_uses_taskkill_on_windows(self, monkeypatch):
|
||||
|
|
|
|||
|
|
@ -505,3 +505,81 @@ class TestSegmentBreakOnToolBoundary:
|
|||
assert len(sent_texts) == 3
|
||||
assert sent_texts[0].startswith(prefix)
|
||||
assert sum(len(t) for t in sent_texts[1:]) == len(tail)
|
||||
|
||||
|
||||
class TestInterimCommentaryMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_commentary_message_stays_separate_from_final_stream(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
|
||||
)
|
||||
|
||||
consumer.on_commentary("I'll inspect the repository first.")
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["I'll inspect the repository first.", "Done."]
|
||||
assert consumer.final_response_sent is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_final_send_does_not_mark_final_response_sent(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(return_value=SimpleNamespace(success=False, message_id=None))
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5),
|
||||
)
|
||||
|
||||
consumer.on_delta("Done.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
assert consumer.final_response_sent is False
|
||||
assert consumer.already_sent is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_without_message_id_marks_visible_and_sends_only_tail(self):
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id=None),
|
||||
SimpleNamespace(success=True, message_id=None),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(
|
||||
adapter,
|
||||
"chat_123",
|
||||
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉"),
|
||||
)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "world"]
|
||||
assert consumer.already_sent is True
|
||||
assert consumer.final_response_sent is True
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import GatewayConfig, load_gateway_config
|
||||
from gateway.config import GatewayConfig, Platform, load_gateway_config
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def test_gateway_config_stt_disabled_from_dict_nested():
|
||||
|
|
@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag
|
|||
assert "No STT provider is configured" not in result
|
||||
assert "trouble transcribing" in result
|
||||
assert "caption" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=True)
|
||||
runner.adapters = {}
|
||||
runner._model = "test-model"
|
||||
runner._base_url = ""
|
||||
runner._has_setup_skill = lambda: False
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="",
|
||||
message_type=MessageType.VOICE,
|
||||
source=source,
|
||||
media_urls=["/tmp/queued-voice.ogg"],
|
||||
media_types=["audio/ogg"],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
return_value={
|
||||
"success": True,
|
||||
"transcript": "queued voice transcript",
|
||||
"provider": "local_command",
|
||||
},
|
||||
):
|
||||
result = await runner._prepare_inbound_message_text(
|
||||
event=event,
|
||||
source=source,
|
||||
history=[],
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert "queued voice transcript" in result
|
||||
assert "voice message" in result.lower()
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class TestVerboseCommand:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enabled_cycles_mode(self, tmp_path, monkeypatch):
|
||||
"""When enabled, /verbose cycles tool_progress mode."""
|
||||
"""When enabled, /verbose cycles tool_progress mode per-platform."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
|
|
@ -79,10 +79,11 @@ class TestVerboseCommand:
|
|||
|
||||
# all -> verbose
|
||||
assert "VERBOSE" in result
|
||||
assert "telegram" in result.lower() # per-platform feedback
|
||||
|
||||
# Verify config was saved
|
||||
# Verify config was saved to display.platforms.telegram
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == "verbose"
|
||||
assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cycles_through_all_modes(self, tmp_path, monkeypatch):
|
||||
|
|
@ -103,8 +104,9 @@ class TestVerboseCommand:
|
|||
for mode in expected:
|
||||
result = await runner._handle_verbose_command(_make_event())
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == mode, \
|
||||
f"Expected {mode}, got {saved['display']['tool_progress']}"
|
||||
actual = saved["display"]["platforms"]["telegram"]["tool_progress"]
|
||||
assert actual == mode, \
|
||||
f"Expected {mode}, got {actual}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_defaults_to_all_when_no_tool_progress_set(self, tmp_path, monkeypatch):
|
||||
|
|
@ -122,10 +124,45 @@ class TestVerboseCommand:
|
|||
runner = _make_runner()
|
||||
result = await runner._handle_verbose_command(_make_event())
|
||||
|
||||
# default "all" -> verbose
|
||||
# Telegram default is "all" (high tier) → cycles to verbose
|
||||
assert "VERBOSE" in result
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["display"]["tool_progress"] == "verbose"
|
||||
assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_per_platform_isolation(self, tmp_path, monkeypatch):
|
||||
"""Cycling /verbose on Telegram doesn't change Slack's setting.
|
||||
|
||||
Without a global tool_progress, each platform uses its built-in
|
||||
default: Telegram = 'all' (high tier), Slack = 'new' (medium tier).
|
||||
"""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
# No global tool_progress → built-in platform defaults apply
|
||||
config_path.write_text(
|
||||
"display:\n tool_progress_command: true\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
runner = _make_runner()
|
||||
|
||||
# Cycle on Telegram
|
||||
await runner._handle_verbose_command(
|
||||
_make_event(platform=Platform.TELEGRAM)
|
||||
)
|
||||
# Cycle on Slack
|
||||
await runner._handle_verbose_command(
|
||||
_make_event(platform=Platform.SLACK)
|
||||
)
|
||||
|
||||
saved = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
platforms = saved["display"]["platforms"]
|
||||
# Telegram: all -> verbose (high tier default = all)
|
||||
assert platforms["telegram"]["tool_progress"] == "verbose"
|
||||
# Slack: new -> all (medium tier default = new, cycle to all)
|
||||
assert platforms["slack"]["tool_progress"] == "all"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_config_file_returns_disabled(self, tmp_path, monkeypatch):
|
||||
|
|
|
|||
185
tests/gateway/test_wecom_callback.py
Normal file
185
tests/gateway/test_wecom_callback.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""Tests for the WeCom callback-mode adapter."""
|
||||
|
||||
import asyncio
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.wecom_callback import WecomCallbackAdapter
|
||||
from gateway.platforms.wecom_crypto import WXBizMsgCrypt
|
||||
|
||||
|
||||
def _app(name="test-app", corp_id="ww1234567890", agent_id="1000002"):
|
||||
return {
|
||||
"name": name,
|
||||
"corp_id": corp_id,
|
||||
"corp_secret": "test-secret",
|
||||
"agent_id": agent_id,
|
||||
"token": "test-callback-token",
|
||||
"encoding_aes_key": "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG",
|
||||
}
|
||||
|
||||
|
||||
def _config(apps=None):
|
||||
return PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"mode": "callback", "host": "127.0.0.1", "port": 0, "apps": apps or [_app()]},
|
||||
)
|
||||
|
||||
|
||||
class TestWecomCrypto:
|
||||
def test_roundtrip_encrypt_decrypt(self):
|
||||
app = _app()
|
||||
crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"])
|
||||
encrypted_xml = crypt.encrypt(
|
||||
"<xml><Content>hello</Content></xml>", nonce="nonce123", timestamp="123456",
|
||||
)
|
||||
root = ET.fromstring(encrypted_xml)
|
||||
decrypted = crypt.decrypt(
|
||||
root.findtext("MsgSignature", default=""),
|
||||
root.findtext("TimeStamp", default=""),
|
||||
root.findtext("Nonce", default=""),
|
||||
root.findtext("Encrypt", default=""),
|
||||
)
|
||||
assert b"<Content>hello</Content>" in decrypted
|
||||
|
||||
def test_signature_mismatch_raises(self):
|
||||
app = _app()
|
||||
crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"])
|
||||
encrypted_xml = crypt.encrypt("<xml/>", nonce="n", timestamp="1")
|
||||
root = ET.fromstring(encrypted_xml)
|
||||
from gateway.platforms.wecom_crypto import SignatureError
|
||||
with pytest.raises(SignatureError):
|
||||
crypt.decrypt("bad-sig", "1", "n", root.findtext("Encrypt", default=""))
|
||||
|
||||
|
||||
class TestWecomCallbackEventConstruction:
    """Turning raw WeCom callback XML into gateway events."""

    @staticmethod
    def _event_from(xml_payload):
        # Run the adapter's private builder against the fixture app.
        adapter = WecomCallbackAdapter(_config())
        return adapter._build_event(_app(), xml_payload)

    def test_build_event_extracts_text_message(self):
        """A plain text message populates every event field."""
        event = self._event_from(
            """
            <xml>
            <ToUserName>ww1234567890</ToUserName>
            <FromUserName>zhangsan</FromUserName>
            <CreateTime>1710000000</CreateTime>
            <MsgType>text</MsgType>
            <Content>\u4f60\u597d</Content>
            <MsgId>123456789</MsgId>
            </xml>
            """
        )
        assert event is not None
        assert event.source is not None
        assert event.source.user_id == "zhangsan"
        assert event.source.chat_id == "ww1234567890:zhangsan"
        assert event.message_id == "123456789"
        assert event.text == "\u4f60\u597d"

    def test_build_event_returns_none_for_subscribe(self):
        """Non-message callbacks (e.g. subscribe events) yield no event."""
        event = self._event_from(
            """
            <xml>
            <ToUserName>ww1234567890</ToUserName>
            <FromUserName>zhangsan</FromUserName>
            <CreateTime>1710000000</CreateTime>
            <MsgType>event</MsgType>
            <Event>subscribe</Event>
            </xml>
            """
        )
        assert event is None
|
||||
|
||||
class TestWecomCallbackRouting:
    """Routing of outbound sends to the correct corp app and token."""

    def test_user_app_key_scopes_across_corps(self):
        """User keys embed the corp id, so identical user ids never collide."""
        adapter = WecomCallbackAdapter(_config())
        key_a = adapter._user_app_key("corpA", "alice")
        key_b = adapter._user_app_key("corpB", "alice")
        assert key_a == "corpA:alice"
        assert key_b == "corpB:alice"
        assert key_a != key_b

    @pytest.mark.asyncio
    async def test_send_selects_correct_app_for_scoped_chat_id(self):
        """A corp-scoped chat id routes through that corp's app and token."""
        adapter = WecomCallbackAdapter(_config(apps=[
            _app(name="corp-a", corp_id="corpA", agent_id="1001"),
            _app(name="corp-b", corp_id="corpB", agent_id="2002"),
        ]))
        adapter._user_app_map["corpB:alice"] = "corp-b"
        adapter._access_tokens["corp-b"] = {"token": "tok-b", "expires_at": 9999999999}

        recorded = {}

        class StubResponse:
            def json(self):
                return {"errcode": 0, "msgid": "ok1"}

        class StubClient:
            async def post(self, url, json):
                recorded["url"] = url
                recorded["json"] = json
                return StubResponse()

        adapter._http_client = StubClient()
        result = await adapter.send("corpB:alice", "hello")

        assert result.success is True
        assert recorded["json"]["touser"] == "alice"
        assert recorded["json"]["agentid"] == 2002
        assert "tok-b" in recorded["url"]

    @pytest.mark.asyncio
    async def test_send_falls_back_from_bare_user_id_when_unique(self):
        """A bare user id resolves when exactly one corp knows that user."""
        adapter = WecomCallbackAdapter(
            _config(apps=[_app(name="corp-a", corp_id="corpA", agent_id="1001")])
        )
        adapter._user_app_map["corpA:alice"] = "corp-a"
        adapter._access_tokens["corp-a"] = {"token": "tok-a", "expires_at": 9999999999}

        recorded = {}

        class StubResponse:
            def json(self):
                return {"errcode": 0, "msgid": "ok2"}

        class StubClient:
            async def post(self, url, json):
                recorded["url"] = url
                recorded["json"] = json
                return StubResponse()

        adapter._http_client = StubClient()
        result = await adapter.send("alice", "hello")

        assert result.success is True
        assert recorded["json"]["agentid"] == 1001
|
||||
|
||||
class TestWecomCallbackPollLoop:
    """The background poll loop drains queued events into handle_message."""

    @pytest.mark.asyncio
    async def test_poll_loop_dispatches_handle_message(self, monkeypatch):
        adapter = WecomCallbackAdapter(_config())
        seen = []

        async def record_event(event):
            seen.append(event.text)

        monkeypatch.setattr(adapter, "handle_message", record_event)
        event = adapter._build_event(
            _app(),
            """
            <xml>
            <ToUserName>ww1234567890</ToUserName>
            <FromUserName>lisi</FromUserName>
            <CreateTime>1710000000</CreateTime>
            <MsgType>text</MsgType>
            <Content>test</Content>
            <MsgId>m2</MsgId>
            </xml>
            """,
        )
        worker = asyncio.create_task(adapter._poll_loop())
        await adapter._message_queue.put(event)
        await asyncio.sleep(0.05)  # give the loop a tick to drain the queue
        worker.cancel()
        with pytest.raises(asyncio.CancelledError):
            await worker
        assert seen == ["test"]
897
tests/hermes_cli/test_backup.py
Normal file
897
tests/hermes_cli/test_backup.py
Normal file
|
|
@ -0,0 +1,897 @@
|
|||
"""Tests for hermes backup and import commands."""
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_hermes_tree(root: Path) -> None:
|
||||
"""Create a realistic ~/.hermes directory structure for testing."""
|
||||
(root / "config.yaml").write_text("model:\n provider: openrouter\n")
|
||||
(root / ".env").write_text("OPENROUTER_API_KEY=sk-test-123\n")
|
||||
(root / "memory_store.db").write_bytes(b"fake-sqlite")
|
||||
(root / "hermes_state.db").write_bytes(b"fake-state")
|
||||
|
||||
# Sessions
|
||||
(root / "sessions").mkdir(exist_ok=True)
|
||||
(root / "sessions" / "abc123.json").write_text("{}")
|
||||
|
||||
# Skills
|
||||
(root / "skills").mkdir(exist_ok=True)
|
||||
(root / "skills" / "my-skill").mkdir()
|
||||
(root / "skills" / "my-skill" / "SKILL.md").write_text("# My Skill\n")
|
||||
|
||||
# Skins
|
||||
(root / "skins").mkdir(exist_ok=True)
|
||||
(root / "skins" / "cyber.yaml").write_text("name: cyber\n")
|
||||
|
||||
# Cron
|
||||
(root / "cron").mkdir(exist_ok=True)
|
||||
(root / "cron" / "jobs.json").write_text("[]")
|
||||
|
||||
# Memories
|
||||
(root / "memories").mkdir(exist_ok=True)
|
||||
(root / "memories" / "notes.json").write_text("{}")
|
||||
|
||||
# Profiles
|
||||
(root / "profiles").mkdir(exist_ok=True)
|
||||
(root / "profiles" / "coder").mkdir()
|
||||
(root / "profiles" / "coder" / "config.yaml").write_text("model:\n provider: anthropic\n")
|
||||
(root / "profiles" / "coder" / ".env").write_text("ANTHROPIC_API_KEY=sk-ant-123\n")
|
||||
|
||||
# hermes-agent repo (should be EXCLUDED)
|
||||
(root / "hermes-agent").mkdir(exist_ok=True)
|
||||
(root / "hermes-agent" / "run_agent.py").write_text("# big file\n")
|
||||
(root / "hermes-agent" / ".git").mkdir()
|
||||
(root / "hermes-agent" / ".git" / "HEAD").write_text("ref: refs/heads/main\n")
|
||||
|
||||
# __pycache__ (should be EXCLUDED)
|
||||
(root / "plugins").mkdir(exist_ok=True)
|
||||
(root / "plugins" / "__pycache__").mkdir()
|
||||
(root / "plugins" / "__pycache__" / "mod.cpython-312.pyc").write_bytes(b"\x00")
|
||||
|
||||
# PID files (should be EXCLUDED)
|
||||
(root / "gateway.pid").write_text("12345")
|
||||
|
||||
# Logs (should be included)
|
||||
(root / "logs").mkdir(exist_ok=True)
|
||||
(root / "logs" / "agent.log").write_text("log line\n")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _should_exclude tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestShouldExclude:
    """Unit tests for the backup exclusion predicate."""

    @staticmethod
    def _excluded(rel: str) -> bool:
        # Thin wrapper so each test reads as a single assertion.
        from hermes_cli.backup import _should_exclude
        return _should_exclude(Path(rel))

    def test_excludes_hermes_agent(self):
        assert self._excluded("hermes-agent/run_agent.py")
        assert self._excluded("hermes-agent/.git/HEAD")

    def test_excludes_pycache(self):
        assert self._excluded("plugins/__pycache__/mod.cpython-312.pyc")

    def test_excludes_pyc_files(self):
        assert self._excluded("some/module.pyc")

    def test_excludes_pid_files(self):
        assert self._excluded("gateway.pid")
        assert self._excluded("cron.pid")

    def test_includes_config(self):
        assert not self._excluded("config.yaml")

    def test_includes_env(self):
        assert not self._excluded(".env")

    def test_includes_skills(self):
        assert not self._excluded("skills/my-skill/SKILL.md")

    def test_includes_profiles(self):
        assert not self._excluded("profiles/coder/config.yaml")

    def test_includes_sessions(self):
        assert not self._excluded("sessions/abc.json")

    def test_includes_logs(self):
        assert not self._excluded("logs/agent.log")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backup tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackup:
    """Behavioural tests for the `hermes backup` command."""

    def _prepare_home(self, tmp_path, monkeypatch):
        # Build a populated fake ~/.hermes and point the CLI at it.
        home = tmp_path / ".hermes"
        home.mkdir()
        _make_hermes_tree(home)
        monkeypatch.setenv("HERMES_HOME", str(home))
        # get_default_hermes_root needs Path.home() to resolve under tmp_path.
        monkeypatch.setattr(Path, "home", lambda: tmp_path)
        return home

    @staticmethod
    def _run(output):
        from hermes_cli.backup import run_backup
        run_backup(Namespace(output=output))

    def _zip_names(self, tmp_path, monkeypatch):
        # Run a full backup and return the archive's entry list.
        self._prepare_home(tmp_path, monkeypatch)
        target = tmp_path / "backup.zip"
        self._run(str(target))
        with zipfile.ZipFile(target, "r") as zf:
            return zf.namelist()

    def test_creates_zip(self, tmp_path, monkeypatch):
        """Backup creates a valid zip containing expected files."""
        entries = self._zip_names(tmp_path, monkeypatch)
        for expected in (
            "config.yaml",                  # config
            ".env",                         # secrets
            "skills/my-skill/SKILL.md",     # skills
            "profiles/coder/config.yaml",   # profiles
            "profiles/coder/.env",
            "sessions/abc123.json",         # sessions
            "logs/agent.log",               # logs
            "skins/cyber.yaml",             # skins
        ):
            assert expected in entries

    def test_excludes_hermes_agent(self, tmp_path, monkeypatch):
        """Backup does NOT include the hermes-agent/ checkout."""
        leaked = [n for n in self._zip_names(tmp_path, monkeypatch) if "hermes-agent" in n]
        assert leaked == [], f"hermes-agent files leaked into backup: {leaked}"

    def test_excludes_pycache(self, tmp_path, monkeypatch):
        """Backup does NOT include __pycache__ dirs."""
        leaked = [n for n in self._zip_names(tmp_path, monkeypatch) if "__pycache__" in n]
        assert leaked == []

    def test_excludes_pid_files(self, tmp_path, monkeypatch):
        """Backup does NOT include PID files."""
        leaked = [n for n in self._zip_names(tmp_path, monkeypatch) if n.endswith(".pid")]
        assert leaked == []

    def test_default_output_path(self, tmp_path, monkeypatch):
        """When no output path given, zip goes to ~/hermes-backup-*.zip."""
        home = tmp_path / ".hermes"
        home.mkdir()
        (home / "config.yaml").write_text("model: test\n")
        monkeypatch.setenv("HERMES_HOME", str(home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        self._run(None)

        # Exactly one timestamped archive lands in the home directory.
        assert len(list(tmp_path.glob("hermes-backup-*.zip"))) == 1
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestImport:
    """Behavioural tests for the `hermes import` command."""

    def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
        """Create a test zip with the given {name: payload} entries.

        ZipFile.writestr accepts both str and bytes payloads directly, so the
        previous isinstance(content, bytes) branch (whose two arms were
        identical) has been removed.
        """
        with zipfile.ZipFile(zip_path, "w") as zf:
            for name, content in files.items():
                zf.writestr(name, content)

    def test_restores_files(self, tmp_path, monkeypatch):
        """Import extracts files into hermes home."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model:\n provider: openrouter\n",
            ".env": "OPENROUTER_API_KEY=sk-test\n",
            "skills/my-skill/SKILL.md": "# My Skill\n",
            "profiles/coder/config.yaml": "model:\n provider: anthropic\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        assert (hermes_home / "config.yaml").read_text() == "model:\n provider: openrouter\n"
        assert (hermes_home / ".env").read_text() == "OPENROUTER_API_KEY=sk-test\n"
        assert (hermes_home / "skills" / "my-skill" / "SKILL.md").read_text() == "# My Skill\n"
        assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()

    def test_strips_hermes_prefix(self, tmp_path, monkeypatch):
        """Import strips .hermes/ prefix if all entries share it."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            ".hermes/config.yaml": "model: test\n",
            ".hermes/skills/a/SKILL.md": "# A\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        # Entries land at the root of hermes home, not under .hermes/.hermes.
        assert (hermes_home / "config.yaml").read_text() == "model: test\n"
        assert (hermes_home / "skills" / "a" / "SKILL.md").read_text() == "# A\n"

    def test_rejects_empty_zip(self, tmp_path, monkeypatch):
        """Import rejects an empty zip."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "empty.zip"
        with zipfile.ZipFile(zip_path, "w"):
            pass  # empty archive: no entries at all

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        with pytest.raises(SystemExit):
            run_import(args)

    def test_rejects_non_hermes_zip(self, tmp_path, monkeypatch):
        """Import rejects a zip that doesn't look like a hermes backup."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "random.zip"
        self._make_backup_zip(zip_path, {
            "some/random/file.txt": "hello",
            "another/thing.json": "{}",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        with pytest.raises(SystemExit):
            run_import(args)

    def test_blocks_path_traversal(self, tmp_path, monkeypatch):
        """Import blocks zip entries with path traversal."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "evil.zip"
        # Include a marker file so validation passes; the second entry tries
        # to escape the extraction root.
        self._make_backup_zip(zip_path, {
            "config.yaml": "model: test\n",
            "../../etc/passwd": "root:x:0:0\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        # config.yaml should be restored
        assert (hermes_home / "config.yaml").exists()
        # traversal file should NOT exist outside hermes home
        assert not (tmp_path / "etc" / "passwd").exists()

    def test_confirmation_prompt_abort(self, tmp_path, monkeypatch):
        """Import aborts when user says no to confirmation."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        # Pre-existing config triggers the confirmation prompt.
        (hermes_home / "config.yaml").write_text("existing: true\n")
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model: restored\n",
        })

        args = Namespace(zipfile=str(zip_path), force=False)

        from hermes_cli.backup import run_import
        with patch("builtins.input", return_value="n"):
            run_import(args)

        # Original config should be unchanged
        assert (hermes_home / "config.yaml").read_text() == "existing: true\n"

    def test_force_skips_confirmation(self, tmp_path, monkeypatch):
        """Import with --force skips confirmation and overwrites."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        (hermes_home / "config.yaml").write_text("existing: true\n")
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        zip_path = tmp_path / "backup.zip"
        self._make_backup_zip(zip_path, {
            "config.yaml": "model: restored\n",
        })

        args = Namespace(zipfile=str(zip_path), force=True)

        from hermes_cli.backup import run_import
        run_import(args)

        assert (hermes_home / "config.yaml").read_text() == "model: restored\n"

    def test_missing_file_exits(self, tmp_path, monkeypatch):
        """Import exits with error for nonexistent file."""
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))

        args = Namespace(zipfile=str(tmp_path / "nonexistent.zip"), force=True)

        from hermes_cli.backup import run_import
        with pytest.raises(SystemExit):
            run_import(args)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Round-trip test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRoundTrip:
    """Full cycle: back up one home, import into another, verify contents."""

    def test_backup_then_import(self, tmp_path, monkeypatch):
        from hermes_cli.backup import run_backup, run_import

        # --- source home --------------------------------------------------
        src_home = tmp_path / "source" / ".hermes"
        src_home.mkdir(parents=True)
        _make_hermes_tree(src_home)
        monkeypatch.setenv("HERMES_HOME", str(src_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path / "source")

        archive = tmp_path / "roundtrip.zip"
        run_backup(Namespace(output=str(archive)))
        assert archive.exists()

        # --- destination home ---------------------------------------------
        dst_home = tmp_path / "dest" / ".hermes"
        dst_home.mkdir(parents=True)
        monkeypatch.setenv("HERMES_HOME", str(dst_home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path / "dest")

        run_import(Namespace(zipfile=str(archive), force=True))

        # Key user data survives the trip.
        assert (dst_home / "config.yaml").read_text() == "model:\n provider: openrouter\n"
        assert (dst_home / ".env").read_text() == "OPENROUTER_API_KEY=sk-test-123\n"
        for rel in (
            "skills/my-skill/SKILL.md",
            "profiles/coder/config.yaml",
            "sessions/abc123.json",
            "logs/agent.log",
        ):
            assert (dst_home / rel).exists()

        # Excluded artefacts never reappear on the destination side.
        assert not (dst_home / "hermes-agent").exists()
        assert not (dst_home / "plugins" / "__pycache__").exists()
        assert not (dst_home / "gateway.pid").exists()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validate / detect-prefix unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFormatSize:
    """Human-readable size formatting across unit boundaries."""

    @staticmethod
    def _fmt(num_bytes):
        from hermes_cli.backup import _format_size
        return _format_size(num_bytes)

    def test_bytes(self):
        assert self._fmt(512) == "512 B"

    def test_kilobytes(self):
        assert "KB" in self._fmt(2048)

    def test_megabytes(self):
        assert "MB" in self._fmt(5 * 1024 * 1024)

    def test_gigabytes(self):
        assert "GB" in self._fmt(3 * 1024 ** 3)

    def test_terabytes(self):
        assert "TB" in self._fmt(2 * 1024 ** 4)
|
||||
|
||||
class TestValidation:
    """Backup-zip validation and wrapper-prefix detection."""

    @staticmethod
    def _open_zip(entries):
        # Build an in-memory zip from {name: payload} and reopen it readable.
        import io
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            for name, payload in entries.items():
                zf.writestr(name, payload)
        buf.seek(0)
        return zipfile.ZipFile(buf, "r")

    def test_validate_with_config(self):
        """Zip with config.yaml passes validation."""
        from hermes_cli.backup import _validate_backup_zip
        with self._open_zip({"config.yaml": "test"}) as zf:
            ok, reason = _validate_backup_zip(zf)
        assert ok

    def test_validate_with_env(self):
        """Zip with .env passes validation."""
        from hermes_cli.backup import _validate_backup_zip
        with self._open_zip({".env": "KEY=val"}) as zf:
            ok, reason = _validate_backup_zip(zf)
        assert ok

    def test_validate_rejects_random(self):
        """Zip without hermes markers fails validation."""
        from hermes_cli.backup import _validate_backup_zip
        with self._open_zip({"random/file.txt": "hello"}) as zf:
            ok, reason = _validate_backup_zip(zf)
        assert not ok

    def test_detect_prefix_hermes(self):
        """Detects .hermes/ prefix wrapping all entries."""
        from hermes_cli.backup import _detect_prefix
        with self._open_zip({
            ".hermes/config.yaml": "test",
            ".hermes/skills/a/SKILL.md": "skill",
        }) as zf:
            assert _detect_prefix(zf) == ".hermes/"

    def test_detect_prefix_none(self):
        """No prefix when entries are at root."""
        from hermes_cli.backup import _detect_prefix
        with self._open_zip({
            "config.yaml": "test",
            "skills/a/SKILL.md": "skill",
        }) as zf:
            assert _detect_prefix(zf) == ""

    def test_detect_prefix_only_dirs(self):
        """Prefix detection returns empty for zip with only directory entries."""
        from hermes_cli.backup import _detect_prefix
        # Directory entries carry a trailing slash and no content.
        with self._open_zip({".hermes/": "", ".hermes/skills/": ""}) as zf:
            assert _detect_prefix(zf) == ""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge case tests for uncovered paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackupEdgeCases:
    """Edge cases and failure modes for `hermes backup`."""

    @staticmethod
    def _run_backup(output):
        from hermes_cli.backup import run_backup
        run_backup(Namespace(output=output))

    @staticmethod
    def _point_home(monkeypatch, home, parent):
        # Route both HERMES_HOME and Path.home() at the fixture location.
        monkeypatch.setenv("HERMES_HOME", str(home))
        monkeypatch.setattr(Path, "home", lambda: parent)

    def test_nonexistent_hermes_home(self, tmp_path, monkeypatch):
        """Backup exits when hermes home doesn't exist."""
        missing = tmp_path / "nonexistent" / ".hermes"
        self._point_home(monkeypatch, missing, tmp_path / "nonexistent")

        with pytest.raises(SystemExit):
            self._run_backup(str(tmp_path / "out.zip"))

    def test_output_is_directory(self, tmp_path, monkeypatch):
        """When output path is a directory, zip is created inside it."""
        home = tmp_path / ".hermes"
        home.mkdir()
        (home / "config.yaml").write_text("model: test\n")
        self._point_home(monkeypatch, home, tmp_path)

        out_dir = tmp_path / "backups"
        out_dir.mkdir()
        self._run_backup(str(out_dir))

        assert len(list(out_dir.glob("hermes-backup-*.zip"))) == 1

    def test_output_without_zip_suffix(self, tmp_path, monkeypatch):
        """Output path without .zip gets suffix appended."""
        home = tmp_path / ".hermes"
        home.mkdir()
        (home / "config.yaml").write_text("model: test\n")
        self._point_home(monkeypatch, home, tmp_path)

        self._run_backup(str(tmp_path / "mybackup.tar"))

        # The .zip suffix is appended after the existing one.
        assert (tmp_path / "mybackup.tar.zip").exists()

    def test_empty_hermes_home(self, tmp_path, monkeypatch):
        """Backup handles empty hermes home (no files to back up)."""
        home = tmp_path / ".hermes"
        home.mkdir()
        # Only excluded content, so nothing remains to archive.
        (home / "__pycache__").mkdir()
        (home / "__pycache__" / "foo.pyc").write_bytes(b"\x00")
        self._point_home(monkeypatch, home, tmp_path)

        self._run_backup(str(tmp_path / "out.zip"))

        # No zip should be created when there is nothing to back up.
        assert not (tmp_path / "out.zip").exists()

    def test_permission_error_during_backup(self, tmp_path, monkeypatch):
        """Backup handles permission errors gracefully."""
        home = tmp_path / ".hermes"
        home.mkdir()
        (home / "config.yaml").write_text("model: test\n")

        # An unreadable file must not abort the whole backup.
        unreadable = home / "secret.db"
        unreadable.write_text("data")
        unreadable.chmod(0o000)

        self._point_home(monkeypatch, home, tmp_path)

        target = tmp_path / "out.zip"
        try:
            self._run_backup(str(target))
        finally:
            # Restore permissions so tmp_path cleanup succeeds.
            unreadable.chmod(0o644)

        # Zip should still be created with the readable files.
        assert target.exists()

    def test_skips_output_zip_inside_hermes(self, tmp_path, monkeypatch):
        """Backup skips its own output zip if it's inside hermes root."""
        home = tmp_path / ".hermes"
        home.mkdir()
        (home / "config.yaml").write_text("model: test\n")
        self._point_home(monkeypatch, home, tmp_path)

        # Output deliberately placed inside the hermes home itself.
        target = home / "backup.zip"
        self._run_backup(str(target))

        # The zip exists but must not contain itself.
        assert target.exists()
        with zipfile.ZipFile(target, "r") as zf:
            assert "backup.zip" not in zf.namelist()
|
||||
|
||||
class TestImportEdgeCases:
    """Edge cases and failure modes for `hermes import`."""

    @staticmethod
    def _make_backup_zip(zip_path: Path, files: dict[str, str | bytes]) -> None:
        # Build a throwaway archive from {name: payload}.
        with zipfile.ZipFile(zip_path, "w") as zf:
            for name, content in files.items():
                zf.writestr(name, content)

    def test_not_a_zip(self, tmp_path, monkeypatch):
        """Import rejects a non-zip file."""
        home = tmp_path / ".hermes"
        home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(home))

        bogus = tmp_path / "fake.zip"
        bogus.write_text("this is not a zip")

        from hermes_cli.backup import run_import
        with pytest.raises(SystemExit):
            run_import(Namespace(zipfile=str(bogus), force=True))

    def test_eof_during_confirmation(self, tmp_path, monkeypatch):
        """Import handles EOFError during confirmation prompt."""
        home = tmp_path / ".hermes"
        home.mkdir()
        (home / "config.yaml").write_text("existing\n")  # triggers the prompt
        monkeypatch.setenv("HERMES_HOME", str(home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        archive = tmp_path / "backup.zip"
        self._make_backup_zip(archive, {"config.yaml": "new\n"})

        from hermes_cli.backup import run_import
        with patch("builtins.input", side_effect=EOFError):
            with pytest.raises(SystemExit):
                run_import(Namespace(zipfile=str(archive), force=False))

    def test_keyboard_interrupt_during_confirmation(self, tmp_path, monkeypatch):
        """Import handles KeyboardInterrupt during confirmation prompt."""
        home = tmp_path / ".hermes"
        home.mkdir()
        (home / ".env").write_text("KEY=val\n")  # triggers the prompt
        monkeypatch.setenv("HERMES_HOME", str(home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        archive = tmp_path / "backup.zip"
        self._make_backup_zip(archive, {"config.yaml": "new\n"})

        from hermes_cli.backup import run_import
        with patch("builtins.input", side_effect=KeyboardInterrupt):
            with pytest.raises(SystemExit):
                run_import(Namespace(zipfile=str(archive), force=False))

    def test_permission_error_during_import(self, tmp_path, monkeypatch):
        """Import handles permission errors during extraction."""
        home = tmp_path / ".hermes"
        home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        # A read-only directory forces extraction of one entry to fail.
        locked_dir = home / "locked"
        locked_dir.mkdir()
        locked_dir.chmod(0o555)

        archive = tmp_path / "backup.zip"
        self._make_backup_zip(archive, {
            "config.yaml": "model: test\n",
            "locked/secret.txt": "data",
        })

        from hermes_cli.backup import run_import
        try:
            run_import(Namespace(zipfile=str(archive), force=True))
        finally:
            locked_dir.chmod(0o755)  # restore so cleanup succeeds

        # config.yaml should still be restored despite the error.
        assert (home / "config.yaml").exists()

    def test_progress_with_many_files(self, tmp_path, monkeypatch):
        """Import shows progress with 500+ files."""
        home = tmp_path / ".hermes"
        home.mkdir()
        monkeypatch.setenv("HERMES_HOME", str(home))
        monkeypatch.setattr(Path, "home", lambda: tmp_path)

        payload = {"config.yaml": "model: test\n"}
        payload.update({f"sessions/s{i:04d}.json": "{}" for i in range(600)})

        archive = tmp_path / "big.zip"
        self._make_backup_zip(archive, payload)

        from hermes_cli.backup import run_import
        run_import(Namespace(zipfile=str(archive), force=True))

        assert (home / "config.yaml").exists()
        assert (home / "sessions" / "s0599.json").exists()
||||
# ---------------------------------------------------------------------------
|
||||
# Profile restoration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestProfileRestoration:
|
||||
def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
|
||||
with zipfile.ZipFile(zip_path, "w") as zf:
|
||||
for name, content in files.items():
|
||||
zf.writestr(name, content)
|
||||
|
||||
def test_import_creates_profile_wrappers(self, tmp_path, monkeypatch):
|
||||
"""Import auto-creates wrapper scripts for restored profiles."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
# Mock the wrapper dir to be inside tmp_path
|
||||
wrapper_dir = tmp_path / ".local" / "bin"
|
||||
wrapper_dir.mkdir(parents=True)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model:\n provider: openrouter\n",
|
||||
"profiles/coder/config.yaml": "model:\n provider: anthropic\n",
|
||||
"profiles/coder/.env": "ANTHROPIC_API_KEY=sk-test\n",
|
||||
"profiles/researcher/config.yaml": "model:\n provider: deepseek\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
# Profile directories should exist
|
||||
assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
|
||||
assert (hermes_home / "profiles" / "researcher" / "config.yaml").exists()
|
||||
|
||||
# Wrapper scripts should be created
|
||||
assert (wrapper_dir / "coder").exists()
|
||||
assert (wrapper_dir / "researcher").exists()
|
||||
|
||||
# Wrappers should contain the right content
|
||||
coder_wrapper = (wrapper_dir / "coder").read_text()
|
||||
assert "hermes -p coder" in coder_wrapper
|
||||
|
||||
def test_import_skips_profile_dirs_without_config(self, tmp_path, monkeypatch):
|
||||
"""Import doesn't create wrappers for profile dirs without config."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
wrapper_dir = tmp_path / ".local" / "bin"
|
||||
wrapper_dir.mkdir(parents=True)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model: test\n",
|
||||
"profiles/valid/config.yaml": "model: test\n",
|
||||
"profiles/empty/readme.txt": "nothing here\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
run_import(args)
|
||||
|
||||
# Only valid profile should get a wrapper
|
||||
assert (wrapper_dir / "valid").exists()
|
||||
assert not (wrapper_dir / "empty").exists()
|
||||
|
||||
def test_import_without_profiles_module(self, tmp_path, monkeypatch):
|
||||
"""Import gracefully handles missing profiles module (fresh install)."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
zip_path = tmp_path / "backup.zip"
|
||||
self._make_backup_zip(zip_path, {
|
||||
"config.yaml": "model: test\n",
|
||||
"profiles/coder/config.yaml": "model: test\n",
|
||||
})
|
||||
|
||||
args = Namespace(zipfile=str(zip_path), force=True)
|
||||
|
||||
# Simulate profiles module not being available
|
||||
import hermes_cli.backup as backup_mod
|
||||
original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
|
||||
|
||||
def fake_import(name, *a, **kw):
|
||||
if name == "hermes_cli.profiles":
|
||||
raise ImportError("no profiles module")
|
||||
return original_import(name, *a, **kw)
|
||||
|
||||
from hermes_cli.backup import run_import
|
||||
with patch("builtins.__import__", side_effect=fake_import):
|
||||
run_import(args)
|
||||
|
||||
# Files should still be restored even if wrappers can't be created
|
||||
assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
|
||||
254
tests/hermes_cli/test_cli_model_picker.py
Normal file
254
tests/hermes_cli/test_cli_model_picker.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""Tests for the interactive CLI /model picker (provider → model drill-down)."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class _FakeBuffer:
|
||||
def __init__(self, text="draft text"):
|
||||
self.text = text
|
||||
self.cursor_position = len(text)
|
||||
self.reset_calls = []
|
||||
|
||||
def reset(self, append_to_history=False):
|
||||
self.reset_calls.append(append_to_history)
|
||||
self.text = ""
|
||||
self.cursor_position = 0
|
||||
|
||||
|
||||
def _make_providers():
|
||||
return [
|
||||
{
|
||||
"slug": "openrouter",
|
||||
"name": "OpenRouter",
|
||||
"is_current": True,
|
||||
"is_user_defined": False,
|
||||
"models": ["anthropic/claude-opus-4.6", "openai/gpt-5.4"],
|
||||
"total_models": 2,
|
||||
"source": "built-in",
|
||||
},
|
||||
{
|
||||
"slug": "anthropic",
|
||||
"name": "Anthropic",
|
||||
"is_current": False,
|
||||
"is_user_defined": False,
|
||||
"models": ["claude-opus-4.6", "claude-sonnet-4.6"],
|
||||
"total_models": 2,
|
||||
"source": "built-in",
|
||||
},
|
||||
{
|
||||
"slug": "custom:my-ollama",
|
||||
"name": "My Ollama",
|
||||
"is_current": False,
|
||||
"is_user_defined": True,
|
||||
"models": ["llama3", "mistral"],
|
||||
"total_models": 2,
|
||||
"source": "user-config",
|
||||
"api_url": "http://localhost:11434/v1",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _make_picker_cli(picker_return_value):
|
||||
cli = MagicMock()
|
||||
cli._run_curses_picker = MagicMock(return_value=picker_return_value)
|
||||
cli._app = MagicMock()
|
||||
cli._status_bar_visible = True
|
||||
return cli
|
||||
|
||||
|
||||
def _make_modal_cli():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = HermesCLI.__new__(HermesCLI)
|
||||
cli.model = "gpt-5.4"
|
||||
cli.provider = "openrouter"
|
||||
cli.requested_provider = "openrouter"
|
||||
cli.base_url = ""
|
||||
cli.api_key = ""
|
||||
cli.api_mode = ""
|
||||
cli._explicit_api_key = ""
|
||||
cli._explicit_base_url = ""
|
||||
cli._pending_model_switch_note = None
|
||||
cli._model_picker_state = None
|
||||
cli._modal_input_snapshot = None
|
||||
cli._status_bar_visible = True
|
||||
cli._invalidate = MagicMock()
|
||||
cli.agent = None
|
||||
cli.config = {}
|
||||
cli.console = MagicMock()
|
||||
cli._app = SimpleNamespace(
|
||||
current_buffer=_FakeBuffer(),
|
||||
invalidate=MagicMock(),
|
||||
)
|
||||
return cli
|
||||
|
||||
|
||||
def test_provider_selection_returns_slug_on_choice():
|
||||
providers = _make_providers()
|
||||
cli = _make_picker_cli(1)
|
||||
from cli import HermesCLI
|
||||
|
||||
result = HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter")
|
||||
|
||||
assert result == "anthropic"
|
||||
cli._run_curses_picker.assert_called_once()
|
||||
|
||||
|
||||
def test_provider_selection_returns_none_on_cancel():
|
||||
providers = _make_providers()
|
||||
cli = _make_picker_cli(None)
|
||||
from cli import HermesCLI
|
||||
|
||||
result = HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_provider_selection_default_is_current():
|
||||
providers = _make_providers()
|
||||
cli = _make_picker_cli(0)
|
||||
from cli import HermesCLI
|
||||
|
||||
HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter")
|
||||
|
||||
assert cli._run_curses_picker.call_args.kwargs["default_index"] == 0
|
||||
|
||||
|
||||
def test_model_selection_returns_model_on_choice():
|
||||
provider_data = _make_providers()[0]
|
||||
cli = _make_picker_cli(0)
|
||||
from cli import HermesCLI
|
||||
|
||||
result = HermesCLI._interactive_model_selection(cli, provider_data["models"], provider_data)
|
||||
|
||||
assert result == "anthropic/claude-opus-4.6"
|
||||
|
||||
|
||||
def test_model_selection_custom_entry_prompts_for_input():
|
||||
provider_data = _make_providers()[0]
|
||||
cli = _make_picker_cli(2)
|
||||
from cli import HermesCLI
|
||||
|
||||
cli._prompt_text_input = MagicMock(return_value="my-custom-model")
|
||||
result = HermesCLI._interactive_model_selection(cli, provider_data["models"], provider_data)
|
||||
|
||||
assert result == "my-custom-model"
|
||||
cli._prompt_text_input.assert_called_once_with(" Enter model name: ")
|
||||
|
||||
|
||||
def test_model_selection_empty_prompts_for_manual_input():
|
||||
provider_data = {
|
||||
"slug": "custom:empty",
|
||||
"name": "Empty Provider",
|
||||
"models": [],
|
||||
"total_models": 0,
|
||||
}
|
||||
cli = _make_picker_cli(None)
|
||||
from cli import HermesCLI
|
||||
|
||||
cli._prompt_text_input = MagicMock(return_value="my-model")
|
||||
result = HermesCLI._interactive_model_selection(cli, [], provider_data)
|
||||
|
||||
assert result == "my-model"
|
||||
cli._prompt_text_input.assert_called_once_with(" Enter model name manually (or Enter to cancel): ")
|
||||
|
||||
|
||||
def test_prompt_text_input_uses_run_in_terminal_when_app_active():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = _make_modal_cli()
|
||||
|
||||
with (
|
||||
patch("prompt_toolkit.application.run_in_terminal", side_effect=lambda fn: fn()) as run_mock,
|
||||
patch("builtins.input", return_value="manual-value"),
|
||||
):
|
||||
result = HermesCLI._prompt_text_input(cli, "Enter value: ")
|
||||
|
||||
assert result == "manual-value"
|
||||
run_mock.assert_called_once()
|
||||
assert cli._status_bar_visible is True
|
||||
|
||||
|
||||
def test_should_handle_model_command_inline_uses_command_name_resolution():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = _make_modal_cli()
|
||||
|
||||
with patch("hermes_cli.commands.resolve_command", return_value=SimpleNamespace(name="model")):
|
||||
assert HermesCLI._should_handle_model_command_inline(cli, "/model") is True
|
||||
|
||||
with patch("hermes_cli.commands.resolve_command", return_value=SimpleNamespace(name="help")):
|
||||
assert HermesCLI._should_handle_model_command_inline(cli, "/model") is False
|
||||
|
||||
assert HermesCLI._should_handle_model_command_inline(cli, "/model", has_images=True) is False
|
||||
|
||||
|
||||
def test_process_command_model_without_args_opens_modal_picker_and_captures_draft():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = _make_modal_cli()
|
||||
providers = _make_providers()
|
||||
|
||||
with (
|
||||
patch("hermes_cli.model_switch.list_authenticated_providers", return_value=providers),
|
||||
patch("cli._cprint"),
|
||||
):
|
||||
result = cli.process_command("/model")
|
||||
|
||||
assert result is True
|
||||
assert cli._model_picker_state is not None
|
||||
assert cli._model_picker_state["stage"] == "provider"
|
||||
assert cli._model_picker_state["selected"] == 0
|
||||
assert cli._modal_input_snapshot == {"text": "draft text", "cursor_position": len("draft text")}
|
||||
assert cli._app.current_buffer.text == ""
|
||||
|
||||
|
||||
def test_model_picker_provider_then_model_selection_applies_switch_result_and_restores_draft():
|
||||
from cli import HermesCLI
|
||||
|
||||
cli = _make_modal_cli()
|
||||
providers = _make_providers()
|
||||
|
||||
with (
|
||||
patch("hermes_cli.model_switch.list_authenticated_providers", return_value=providers),
|
||||
patch("cli._cprint"),
|
||||
):
|
||||
assert cli.process_command("/model") is True
|
||||
|
||||
cli._model_picker_state["selected"] = 1
|
||||
with patch("hermes_cli.models.provider_model_ids", return_value=["claude-opus-4.6", "claude-sonnet-4.6"]):
|
||||
HermesCLI._handle_model_picker_selection(cli)
|
||||
|
||||
assert cli._model_picker_state["stage"] == "model"
|
||||
assert cli._model_picker_state["provider_data"]["slug"] == "anthropic"
|
||||
assert cli._model_picker_state["model_list"] == ["claude-opus-4.6", "claude-sonnet-4.6"]
|
||||
|
||||
cli._model_picker_state["selected"] = 0
|
||||
switch_result = SimpleNamespace(
|
||||
success=True,
|
||||
error_message=None,
|
||||
new_model="claude-opus-4.6",
|
||||
target_provider="anthropic",
|
||||
api_key="",
|
||||
base_url="",
|
||||
api_mode="anthropic_messages",
|
||||
provider_label="Anthropic",
|
||||
model_info=None,
|
||||
warning_message=None,
|
||||
provider_changed=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch("hermes_cli.model_switch.switch_model", return_value=switch_result) as switch_mock,
|
||||
patch("cli._cprint"),
|
||||
):
|
||||
HermesCLI._handle_model_picker_selection(cli)
|
||||
|
||||
assert cli._model_picker_state is None
|
||||
assert cli.model == "claude-opus-4.6"
|
||||
assert cli.provider == "anthropic"
|
||||
assert cli.requested_provider == "anthropic"
|
||||
assert cli._app.current_buffer.text == "draft text"
|
||||
switch_mock.assert_called_once()
|
||||
assert switch_mock.call_args.kwargs["explicit_provider"] == "anthropic"
|
||||
|
|
@ -68,6 +68,7 @@ class TestLoadConfigDefaults:
|
|||
assert "max_turns" not in config
|
||||
assert "terminal" in config
|
||||
assert config["terminal"]["backend"] == "local"
|
||||
assert config["display"]["interim_assistant_messages"] is True
|
||||
|
||||
def test_legacy_root_level_max_turns_migrates_to_agent_config(self, tmp_path):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
|
|
@ -421,3 +422,25 @@ class TestAnthropicTokenMigration:
|
|||
}):
|
||||
migrate_config(interactive=False, quiet=True)
|
||||
assert load_env().get("ANTHROPIC_TOKEN") == "current-token"
|
||||
|
||||
|
||||
class TestInterimAssistantMessageConfig:
|
||||
"""Test the explicit gateway interim-message config gate."""
|
||||
|
||||
def test_default_config_enables_interim_assistant_messages(self):
|
||||
assert DEFAULT_CONFIG["display"]["interim_assistant_messages"] is True
|
||||
|
||||
def test_migrate_to_v15_adds_interim_assistant_message_gate(self, tmp_path):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(
|
||||
yaml.safe_dump({"_config_version": 14, "display": {"tool_progress": "off"}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
migrate_config(interactive=False, quiet=True)
|
||||
raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert raw["_config_version"] == 16
|
||||
assert raw["display"]["tool_progress"] == "off"
|
||||
assert raw["display"]["interim_assistant_messages"] is True
|
||||
|
|
|
|||
342
tests/hermes_cli/test_container_aware_cli.py
Normal file
342
tests/hermes_cli/test_container_aware_cli.py
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
"""Tests for container-aware CLI routing (NixOS container mode).
|
||||
|
||||
When container.enable = true in the NixOS module, the activation script
|
||||
writes a .container-mode metadata file. The host CLI detects this and
|
||||
execs into the container instead of running locally.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import (
|
||||
_is_inside_container,
|
||||
get_container_exec_info,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _is_inside_container
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_is_inside_container_dockerenv():
|
||||
"""Detects /.dockerenv marker file."""
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
mock_exists.side_effect = lambda p: p == "/.dockerenv"
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_containerenv():
|
||||
"""Detects Podman's /run/.containerenv marker."""
|
||||
with patch("os.path.exists") as mock_exists:
|
||||
mock_exists.side_effect = lambda p: p == "/run/.containerenv"
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_cgroup_docker():
|
||||
"""Detects 'docker' in /proc/1/cgroup."""
|
||||
with patch("os.path.exists", return_value=False), \
|
||||
patch("builtins.open", create=True) as mock_open:
|
||||
mock_open.return_value.__enter__ = lambda s: s
|
||||
mock_open.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_open.return_value.read = MagicMock(
|
||||
return_value="12:memory:/docker/abc123\n"
|
||||
)
|
||||
assert _is_inside_container() is True
|
||||
|
||||
|
||||
def test_is_inside_container_false_on_host():
|
||||
"""Returns False when none of the container indicators are present."""
|
||||
with patch("os.path.exists", return_value=False), \
|
||||
patch("builtins.open", side_effect=OSError("no such file")):
|
||||
assert _is_inside_container() is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# get_container_exec_info
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def container_env(tmp_path, monkeypatch):
|
||||
"""Set up a fake HERMES_HOME with .container-mode file."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("HERMES_DEV", raising=False)
|
||||
|
||||
container_mode = hermes_home / ".container-mode"
|
||||
container_mode.write_text(
|
||||
"# Written by NixOS activation script. Do not edit manually.\n"
|
||||
"backend=podman\n"
|
||||
"container_name=hermes-agent\n"
|
||||
"exec_user=hermes\n"
|
||||
"hermes_bin=/data/current-package/bin/hermes\n"
|
||||
)
|
||||
return hermes_home
|
||||
|
||||
|
||||
def test_get_container_exec_info_returns_metadata(container_env):
|
||||
"""Reads .container-mode and returns all fields including exec_user."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
assert info["backend"] == "podman"
|
||||
assert info["container_name"] == "hermes-agent"
|
||||
assert info["exec_user"] == "hermes"
|
||||
assert info["hermes_bin"] == "/data/current-package/bin/hermes"
|
||||
|
||||
|
||||
def test_get_container_exec_info_none_inside_container(container_env):
|
||||
"""Returns None when we're already inside a container."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=True):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
def test_get_container_exec_info_none_without_file(tmp_path, monkeypatch):
|
||||
"""Returns None when .container-mode doesn't exist (native mode)."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.delenv("HERMES_DEV", raising=False)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
def test_get_container_exec_info_skipped_when_hermes_dev(container_env, monkeypatch):
|
||||
"""Returns None when HERMES_DEV=1 is set (dev mode bypass)."""
|
||||
monkeypatch.setenv("HERMES_DEV", "1")
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
def test_get_container_exec_info_not_skipped_when_hermes_dev_zero(container_env, monkeypatch):
|
||||
"""HERMES_DEV=0 does NOT trigger bypass — only '1' does."""
|
||||
monkeypatch.setenv("HERMES_DEV", "0")
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
|
||||
|
||||
def test_get_container_exec_info_defaults():
|
||||
"""Falls back to defaults for missing keys."""
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
hermes_home = Path(tmpdir) / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / ".container-mode").write_text(
|
||||
"# minimal file with no keys\n"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False), \
|
||||
patch("hermes_cli.config.get_hermes_home", return_value=hermes_home), \
|
||||
patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_DEV", None)
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info is not None
|
||||
assert info["backend"] == "docker"
|
||||
assert info["container_name"] == "hermes-agent"
|
||||
assert info["exec_user"] == "hermes"
|
||||
assert info["hermes_bin"] == "/data/current-package/bin/hermes"
|
||||
|
||||
|
||||
def test_get_container_exec_info_docker_backend(container_env):
|
||||
"""Correctly reads docker backend with custom exec_user."""
|
||||
(container_env / ".container-mode").write_text(
|
||||
"backend=docker\n"
|
||||
"container_name=hermes-custom\n"
|
||||
"exec_user=myuser\n"
|
||||
"hermes_bin=/opt/hermes/bin/hermes\n"
|
||||
)
|
||||
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False):
|
||||
info = get_container_exec_info()
|
||||
|
||||
assert info["backend"] == "docker"
|
||||
assert info["container_name"] == "hermes-custom"
|
||||
assert info["exec_user"] == "myuser"
|
||||
assert info["hermes_bin"] == "/opt/hermes/bin/hermes"
|
||||
|
||||
|
||||
def test_get_container_exec_info_crashes_on_permission_error(container_env):
|
||||
"""PermissionError propagates instead of being silently swallowed."""
|
||||
with patch("hermes_cli.config._is_inside_container", return_value=False), \
|
||||
patch("builtins.open", side_effect=PermissionError("permission denied")):
|
||||
with pytest.raises(PermissionError):
|
||||
get_container_exec_info()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _exec_in_container
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def docker_container_info():
|
||||
return {
|
||||
"backend": "docker",
|
||||
"container_name": "hermes-agent",
|
||||
"exec_user": "hermes",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def podman_container_info():
|
||||
return {
|
||||
"backend": "podman",
|
||||
"container_name": "hermes-agent",
|
||||
"exec_user": "hermes",
|
||||
"hermes_bin": "/data/current-package/bin/hermes",
|
||||
}
|
||||
|
||||
|
||||
def test_exec_in_container_calls_execvp(docker_container_info):
|
||||
"""Verifies os.execvp is called with correct args: runtime, tty flags,
|
||||
user, env vars, container name, binary, and CLI args."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/docker"), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("sys.stdin") as mock_stdin, \
|
||||
patch("os.execvp") as mock_execvp, \
|
||||
patch.dict(os.environ, {"TERM": "xterm-256color", "LANG": "en_US.UTF-8"},
|
||||
clear=False):
|
||||
mock_stdin.isatty.return_value = True
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
_exec_in_container(docker_container_info, ["chat", "-m", "opus"])
|
||||
|
||||
mock_execvp.assert_called_once()
|
||||
cmd = mock_execvp.call_args[0][1]
|
||||
assert cmd[0] == "/usr/bin/docker"
|
||||
assert cmd[1] == "exec"
|
||||
assert "-it" in cmd
|
||||
idx_u = cmd.index("-u")
|
||||
assert cmd[idx_u + 1] == "hermes"
|
||||
e_indices = [i for i, v in enumerate(cmd) if v == "-e"]
|
||||
e_values = [cmd[i + 1] for i in e_indices]
|
||||
assert "TERM=xterm-256color" in e_values
|
||||
assert "LANG=en_US.UTF-8" in e_values
|
||||
assert "hermes-agent" in cmd
|
||||
assert "/data/current-package/bin/hermes" in cmd
|
||||
assert "chat" in cmd
|
||||
|
||||
|
||||
def test_exec_in_container_non_tty_uses_i_only(docker_container_info):
|
||||
"""Non-TTY mode uses -i instead of -it."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/docker"), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("sys.stdin") as mock_stdin, \
|
||||
patch("os.execvp") as mock_execvp:
|
||||
mock_stdin.isatty.return_value = False
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
_exec_in_container(docker_container_info, ["sessions", "list"])
|
||||
|
||||
cmd = mock_execvp.call_args[0][1]
|
||||
assert "-i" in cmd
|
||||
assert "-it" not in cmd
|
||||
|
||||
|
||||
def test_exec_in_container_no_runtime_hard_fails(podman_container_info):
|
||||
"""Hard fails when runtime not found (no fallback)."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
with patch("shutil.which", return_value=None), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("os.execvp") as mock_execvp, \
|
||||
pytest.raises(SystemExit) as exc_info:
|
||||
_exec_in_container(podman_container_info, ["chat"])
|
||||
|
||||
mock_run.assert_not_called()
|
||||
mock_execvp.assert_not_called()
|
||||
assert exc_info.value.code != 0
|
||||
|
||||
|
||||
def test_exec_in_container_sudo_probe_sets_prefix(podman_container_info):
|
||||
"""When first probe fails and sudo probe succeeds, execvp is called
|
||||
with sudo -n prefix."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
def which_side_effect(name):
|
||||
if name == "podman":
|
||||
return "/usr/bin/podman"
|
||||
if name == "sudo":
|
||||
return "/usr/bin/sudo"
|
||||
return None
|
||||
|
||||
with patch("shutil.which", side_effect=which_side_effect), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("sys.stdin") as mock_stdin, \
|
||||
patch("os.execvp") as mock_execvp:
|
||||
mock_stdin.isatty.return_value = True
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=1), # direct probe fails
|
||||
MagicMock(returncode=0), # sudo probe succeeds
|
||||
]
|
||||
|
||||
_exec_in_container(podman_container_info, ["chat"])
|
||||
|
||||
mock_execvp.assert_called_once()
|
||||
cmd = mock_execvp.call_args[0][1]
|
||||
assert cmd[0] == "/usr/bin/sudo"
|
||||
assert cmd[1] == "-n"
|
||||
assert cmd[2] == "/usr/bin/podman"
|
||||
assert cmd[3] == "exec"
|
||||
|
||||
|
||||
def test_exec_in_container_probe_timeout_prints_message(docker_container_info):
|
||||
"""TimeoutExpired from probe produces a human-readable error, not a
|
||||
raw traceback."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
with patch("shutil.which", return_value="/usr/bin/docker"), \
|
||||
patch("subprocess.run", side_effect=subprocess.TimeoutExpired(
|
||||
cmd=["docker", "inspect"], timeout=15)), \
|
||||
patch("os.execvp") as mock_execvp, \
|
||||
pytest.raises(SystemExit) as exc_info:
|
||||
_exec_in_container(docker_container_info, ["chat"])
|
||||
|
||||
mock_execvp.assert_not_called()
|
||||
assert exc_info.value.code == 1
|
||||
|
||||
|
||||
def test_exec_in_container_container_not_running_no_sudo(docker_container_info):
|
||||
"""When runtime exists but container not found and no sudo available,
|
||||
prints helpful error about root containers."""
|
||||
from hermes_cli.main import _exec_in_container
|
||||
|
||||
def which_side_effect(name):
|
||||
if name == "docker":
|
||||
return "/usr/bin/docker"
|
||||
return None
|
||||
|
||||
with patch("shutil.which", side_effect=which_side_effect), \
|
||||
patch("subprocess.run") as mock_run, \
|
||||
patch("os.execvp") as mock_execvp, \
|
||||
pytest.raises(SystemExit) as exc_info:
|
||||
mock_run.return_value = MagicMock(returncode=1)
|
||||
|
||||
_exec_in_container(docker_container_info, ["chat"])
|
||||
|
||||
mock_execvp.assert_not_called()
|
||||
assert exc_info.value.code == 1
|
||||
|
|
@ -260,7 +260,7 @@ class TestWaitForGatewayExit:
|
|||
def test_kill_gateway_processes_force_uses_helper(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None: [11, 22])
|
||||
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None, all_profiles=False: [11, 22])
|
||||
monkeypatch.setattr(gateway, "terminate_pid", lambda pid, force=False: calls.append((pid, force)))
|
||||
|
||||
killed = gateway.kill_gateway_processes(force=True)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tests for gateway service management helpers."""
|
||||
|
||||
import os
|
||||
import pwd
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
|
@ -129,7 +130,7 @@ class TestGatewayStopCleanup:
|
|||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"kill_gateway_processes",
|
||||
lambda force=False: kill_calls.append(force) or 2,
|
||||
lambda force=False, all_profiles=False: kill_calls.append(force) or 2,
|
||||
)
|
||||
|
||||
gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop"))
|
||||
|
|
@ -155,7 +156,7 @@ class TestGatewayStopCleanup:
|
|||
monkeypatch.setattr(
|
||||
gateway_cli,
|
||||
"kill_gateway_processes",
|
||||
lambda force=False: kill_calls.append(force) or 2,
|
||||
lambda force=False, all_profiles=False: kill_calls.append(force) or 2,
|
||||
)
|
||||
|
||||
gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop", **{"all": True}))
|
||||
|
|
@ -924,6 +925,23 @@ class TestProfileArg:
|
|||
assert "<string>--profile</string>" in plist
|
||||
assert "<string>mybot</string>" in plist
|
||||
|
||||
def test_launchd_plist_path_uses_real_user_home_not_profile_home(self, tmp_path, monkeypatch):
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "orcha"
|
||||
profile_dir.mkdir(parents=True)
|
||||
machine_home = tmp_path / "machine-home"
|
||||
machine_home.mkdir()
|
||||
profile_home = profile_dir / "home"
|
||||
profile_home.mkdir()
|
||||
|
||||
monkeypatch.setattr(Path, "home", lambda: profile_home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_dir))
|
||||
monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)
|
||||
monkeypatch.setattr(pwd, "getpwuid", lambda uid: SimpleNamespace(pw_dir=str(machine_home)))
|
||||
|
||||
plist_path = gateway_cli.get_launchd_plist_path()
|
||||
|
||||
assert plist_path == machine_home / "Library" / "LaunchAgents" / "ai.hermes.gateway-orcha.plist"
|
||||
|
||||
|
||||
class TestRemapPathForUser:
|
||||
"""Unit tests for _remap_path_for_user()."""
|
||||
|
|
|
|||
|
|
@ -1,288 +1,255 @@
|
|||
"""Tests for hermes_cli/logs.py — log viewing and filtering."""
|
||||
"""Tests for hermes_cli.logs — log viewing and filtering."""
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
from datetime import datetime, timedelta
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.logs import (
|
||||
LOG_FILES,
|
||||
_extract_level,
|
||||
_extract_logger_name,
|
||||
_line_matches_component,
|
||||
_matches_filters,
|
||||
_parse_line_timestamp,
|
||||
_parse_since,
|
||||
_read_last_n_lines,
|
||||
list_logs,
|
||||
tail_log,
|
||||
_read_tail,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
def log_dir(tmp_path, monkeypatch):
|
||||
"""Create a fake HERMES_HOME with a logs/ directory."""
|
||||
home = Path(os.environ["HERMES_HOME"])
|
||||
logs = home / "logs"
|
||||
logs.mkdir(parents=True, exist_ok=True)
|
||||
return logs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_log(log_dir):
|
||||
"""Write a realistic agent.log with mixed levels and sessions."""
|
||||
lines = textwrap.dedent("""\
|
||||
2026-04-05 10:00:00,000 INFO run_agent: conversation turn: session=sess_aaa model=claude provider=openrouter platform=cli history=0 msg='hello'
|
||||
2026-04-05 10:00:01,000 INFO run_agent: tool terminal completed (0.50s, 200 chars)
|
||||
2026-04-05 10:00:02,000 INFO run_agent: API call #1: model=claude provider=openrouter in=1000 out=200 total=1200 latency=1.5s
|
||||
2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout
|
||||
2026-04-05 10:00:04,000 INFO run_agent: conversation turn: session=sess_bbb model=gpt-5 provider=openai platform=telegram history=5 msg='fix bug'
|
||||
2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited
|
||||
2026-04-05 10:00:06,000 INFO run_agent: tool read_file completed (0.01s, 500 chars)
|
||||
2026-04-05 10:00:07,000 DEBUG run_agent: verbose internal detail
|
||||
2026-04-05 10:00:08,000 INFO credential_pool: credential pool: marking key-1 exhausted (status=429), rotating
|
||||
2026-04-05 10:00:09,000 INFO credential_pool: credential pool: rotated to key-2
|
||||
""")
|
||||
path = log_dir / "agent.log"
|
||||
path.write_text(lines)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_errors_log(log_dir):
|
||||
"""Write a small errors.log."""
|
||||
lines = textwrap.dedent("""\
|
||||
2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout
|
||||
2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited
|
||||
""")
|
||||
path = log_dir / "errors.log"
|
||||
path.write_text(lines)
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_since
|
||||
# Timestamp parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseSince:
|
||||
def test_hours(self):
|
||||
cutoff = _parse_since("2h")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(7200, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 7200) < 2
|
||||
|
||||
def test_minutes(self):
|
||||
cutoff = _parse_since("30m")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(1800, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 1800) < 2
|
||||
|
||||
def test_days(self):
|
||||
cutoff = _parse_since("1d")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(86400, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 86400) < 2
|
||||
|
||||
def test_seconds(self):
|
||||
cutoff = _parse_since("60s")
|
||||
cutoff = _parse_since("120s")
|
||||
assert cutoff is not None
|
||||
assert (datetime.now() - cutoff).total_seconds() == pytest.approx(60, abs=5)
|
||||
assert abs((datetime.now() - cutoff).total_seconds() - 120) < 2
|
||||
|
||||
def test_invalid_returns_none(self):
|
||||
assert _parse_since("abc") is None
|
||||
assert _parse_since("") is None
|
||||
assert _parse_since("10x") is None
|
||||
|
||||
def test_whitespace_handling(self):
|
||||
cutoff = _parse_since(" 1h ")
|
||||
def test_whitespace_tolerance(self):
|
||||
cutoff = _parse_since(" 5m ")
|
||||
assert cutoff is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_line_timestamp
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestParseLineTimestamp:
|
||||
def test_standard_format(self):
|
||||
ts = _parse_line_timestamp("2026-04-05 10:00:00,123 INFO something")
|
||||
assert ts is not None
|
||||
assert ts.year == 2026
|
||||
assert ts.hour == 10
|
||||
ts = _parse_line_timestamp("2026-04-11 10:23:45 INFO gateway.run: msg")
|
||||
assert ts == datetime(2026, 4, 11, 10, 23, 45)
|
||||
|
||||
def test_no_timestamp(self):
|
||||
assert _parse_line_timestamp("just some text") is None
|
||||
assert _parse_line_timestamp("no timestamp here") is None
|
||||
|
||||
def test_continuation_line(self):
|
||||
assert _parse_line_timestamp(" at module.function (line 42)") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_level
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractLevel:
|
||||
def test_info(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 INFO run_agent: something") == "INFO"
|
||||
assert _extract_level("2026-01-01 00:00:00 INFO gateway.run: msg") == "INFO"
|
||||
|
||||
def test_warning(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 WARNING run_agent: bad") == "WARNING"
|
||||
assert _extract_level("2026-01-01 00:00:00 WARNING tools.file: msg") == "WARNING"
|
||||
|
||||
def test_error(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 ERROR run_agent: crash") == "ERROR"
|
||||
assert _extract_level("2026-01-01 00:00:00 ERROR run_agent: msg") == "ERROR"
|
||||
|
||||
def test_debug(self):
|
||||
assert _extract_level("2026-04-05 10:00:00 DEBUG run_agent: detail") == "DEBUG"
|
||||
assert _extract_level("2026-01-01 00:00:00 DEBUG agent.aux: msg") == "DEBUG"
|
||||
|
||||
def test_no_level(self):
|
||||
assert _extract_level("just a plain line") is None
|
||||
assert _extract_level("random text") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _matches_filters
|
||||
# Logger name extraction (new for component filtering)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractLoggerName:
|
||||
def test_standard_line(self):
|
||||
line = "2026-04-11 10:23:45 INFO gateway.run: Starting gateway"
|
||||
assert _extract_logger_name(line) == "gateway.run"
|
||||
|
||||
def test_nested_logger(self):
|
||||
line = "2026-04-11 10:23:45 INFO gateway.platforms.telegram: connected"
|
||||
assert _extract_logger_name(line) == "gateway.platforms.telegram"
|
||||
|
||||
def test_warning_level(self):
|
||||
line = "2026-04-11 10:23:45 WARNING tools.terminal_tool: timeout"
|
||||
assert _extract_logger_name(line) == "tools.terminal_tool"
|
||||
|
||||
def test_with_session_tag(self):
|
||||
line = "2026-04-11 10:23:45 INFO [abc123] tools.file_tools: reading file"
|
||||
assert _extract_logger_name(line) == "tools.file_tools"
|
||||
|
||||
def test_with_session_tag_and_error(self):
|
||||
line = "2026-04-11 10:23:45 ERROR [sess_xyz] agent.context_compressor: failed"
|
||||
assert _extract_logger_name(line) == "agent.context_compressor"
|
||||
|
||||
def test_top_level_module(self):
|
||||
line = "2026-04-11 10:23:45 INFO run_agent: starting conversation"
|
||||
assert _extract_logger_name(line) == "run_agent"
|
||||
|
||||
def test_no_match(self):
|
||||
assert _extract_logger_name("random text") is None
|
||||
|
||||
|
||||
class TestLineMatchesComponent:
|
||||
def test_gateway_component(self):
|
||||
line = "2026-04-11 10:23:45 INFO gateway.run: msg"
|
||||
assert _line_matches_component(line, ("gateway",))
|
||||
|
||||
def test_gateway_nested(self):
|
||||
line = "2026-04-11 10:23:45 INFO gateway.platforms.telegram: msg"
|
||||
assert _line_matches_component(line, ("gateway",))
|
||||
|
||||
def test_tools_component(self):
|
||||
line = "2026-04-11 10:23:45 INFO tools.terminal_tool: msg"
|
||||
assert _line_matches_component(line, ("tools",))
|
||||
|
||||
def test_agent_with_multiple_prefixes(self):
|
||||
prefixes = ("agent", "run_agent", "model_tools")
|
||||
assert _line_matches_component(
|
||||
"2026-04-11 10:23:45 INFO agent.context_compressor: msg", prefixes)
|
||||
assert _line_matches_component(
|
||||
"2026-04-11 10:23:45 INFO run_agent: msg", prefixes)
|
||||
assert _line_matches_component(
|
||||
"2026-04-11 10:23:45 INFO model_tools: msg", prefixes)
|
||||
|
||||
def test_no_match(self):
|
||||
line = "2026-04-11 10:23:45 INFO tools.browser: msg"
|
||||
assert not _line_matches_component(line, ("gateway",))
|
||||
|
||||
def test_with_session_tag(self):
|
||||
line = "2026-04-11 10:23:45 INFO [abc] gateway.run: msg"
|
||||
assert _line_matches_component(line, ("gateway",))
|
||||
|
||||
def test_unparseable_line(self):
|
||||
assert not _line_matches_component("random text", ("gateway",))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Combined filter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatchesFilters:
|
||||
def test_no_filters_always_matches(self):
|
||||
assert _matches_filters("any line") is True
|
||||
def test_no_filters_passes_everything(self):
|
||||
assert _matches_filters("any line")
|
||||
|
||||
def test_level_filter_passes(self):
|
||||
def test_level_filter(self):
|
||||
assert _matches_filters(
|
||||
"2026-04-05 10:00:00 WARNING something",
|
||||
min_level="WARNING",
|
||||
) is True
|
||||
"2026-01-01 00:00:00 WARNING x: msg", min_level="WARNING")
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO x: msg", min_level="WARNING")
|
||||
|
||||
def test_level_filter_rejects(self):
|
||||
def test_session_filter(self):
|
||||
assert _matches_filters(
|
||||
"2026-04-05 10:00:00 INFO something",
|
||||
min_level="WARNING",
|
||||
) is False
|
||||
"2026-01-01 00:00:00 INFO [abc123] x: msg", session_filter="abc123")
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO [xyz789] x: msg", session_filter="abc123")
|
||||
|
||||
def test_session_filter_passes(self):
|
||||
def test_component_filter(self):
|
||||
assert _matches_filters(
|
||||
"session=sess_aaa model=claude",
|
||||
session_filter="sess_aaa",
|
||||
) is True
|
||||
|
||||
def test_session_filter_rejects(self):
|
||||
assert _matches_filters(
|
||||
"session=sess_aaa model=claude",
|
||||
session_filter="sess_bbb",
|
||||
) is False
|
||||
|
||||
def test_since_filter_passes(self):
|
||||
# Line from the future should always pass
|
||||
assert _matches_filters(
|
||||
"2099-01-01 00:00:00 INFO future",
|
||||
since=datetime.now(),
|
||||
) is True
|
||||
|
||||
def test_since_filter_rejects(self):
|
||||
assert _matches_filters(
|
||||
"2020-01-01 00:00:00 INFO past",
|
||||
since=datetime.now(),
|
||||
) is False
|
||||
"2026-01-01 00:00:00 INFO gateway.run: msg",
|
||||
component_prefixes=("gateway",))
|
||||
assert not _matches_filters(
|
||||
"2026-01-01 00:00:00 INFO tools.file: msg",
|
||||
component_prefixes=("gateway",))
|
||||
|
||||
def test_combined_filters(self):
|
||||
line = "2099-01-01 00:00:00 WARNING run_agent: session=abc error"
|
||||
"""All filters must pass for a line to match."""
|
||||
line = "2026-04-11 10:00:00 WARNING [sess_1] gateway.run: connection lost"
|
||||
assert _matches_filters(
|
||||
line, min_level="WARNING", session_filter="abc",
|
||||
since=datetime.now(),
|
||||
) is True
|
||||
# Fails session filter
|
||||
line,
|
||||
min_level="WARNING",
|
||||
session_filter="sess_1",
|
||||
component_prefixes=("gateway",),
|
||||
)
|
||||
# Fails component filter
|
||||
assert not _matches_filters(
|
||||
line,
|
||||
min_level="WARNING",
|
||||
session_filter="sess_1",
|
||||
component_prefixes=("tools",),
|
||||
)
|
||||
|
||||
def test_since_filter(self):
|
||||
# Line with a very old timestamp should be filtered out
|
||||
assert not _matches_filters(
|
||||
"2020-01-01 00:00:00 INFO x: old msg",
|
||||
since=datetime.now() - timedelta(hours=1))
|
||||
# Line with a recent timestamp should pass
|
||||
recent = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
assert _matches_filters(
|
||||
line, min_level="WARNING", session_filter="xyz",
|
||||
) is False
|
||||
f"{recent} INFO x: recent msg",
|
||||
since=datetime.now() - timedelta(hours=1))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_last_n_lines
|
||||
# File reading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadLastNLines:
|
||||
def test_reads_correct_count(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 3)
|
||||
assert len(lines) == 3
|
||||
class TestReadTail:
|
||||
def test_read_small_file(self, tmp_path):
|
||||
log_file = tmp_path / "test.log"
|
||||
lines = [f"2026-01-01 00:00:0{i} INFO x: line {i}\n" for i in range(10)]
|
||||
log_file.write_text("".join(lines))
|
||||
|
||||
def test_reads_all_when_fewer(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 100)
|
||||
assert len(lines) == 10 # sample has 10 lines
|
||||
result = _read_last_n_lines(log_file, 5)
|
||||
assert len(result) == 5
|
||||
assert "line 9" in result[-1]
|
||||
|
||||
def test_empty_file(self, log_dir):
|
||||
empty = log_dir / "empty.log"
|
||||
empty.write_text("")
|
||||
lines = _read_last_n_lines(empty, 10)
|
||||
assert lines == []
|
||||
def test_read_with_component_filter(self, tmp_path):
|
||||
log_file = tmp_path / "test.log"
|
||||
lines = [
|
||||
"2026-01-01 00:00:00 INFO gateway.run: gw msg\n",
|
||||
"2026-01-01 00:00:01 INFO tools.file: tool msg\n",
|
||||
"2026-01-01 00:00:02 INFO gateway.session: session msg\n",
|
||||
"2026-01-01 00:00:03 INFO agent.compressor: agent msg\n",
|
||||
]
|
||||
log_file.write_text("".join(lines))
|
||||
|
||||
def test_last_line_content(self, sample_agent_log):
|
||||
lines = _read_last_n_lines(sample_agent_log, 1)
|
||||
assert "rotated to key-2" in lines[0]
|
||||
result = _read_tail(
|
||||
log_file, 50,
|
||||
has_filters=True,
|
||||
component_prefixes=("gateway",),
|
||||
)
|
||||
assert len(result) == 2
|
||||
assert "gw msg" in result[0]
|
||||
assert "session msg" in result[1]
|
||||
|
||||
def test_empty_file(self, tmp_path):
|
||||
log_file = tmp_path / "empty.log"
|
||||
log_file.write_text("")
|
||||
result = _read_last_n_lines(log_file, 10)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tail_log
|
||||
# LOG_FILES registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTailLog:
|
||||
def test_basic_tail(self, sample_agent_log, capsys):
|
||||
tail_log("agent", num_lines=3)
|
||||
captured = capsys.readouterr()
|
||||
assert "agent.log" in captured.out
|
||||
# Should have the header + 3 lines
|
||||
lines = captured.out.strip().split("\n")
|
||||
assert len(lines) == 4 # 1 header + 3 content
|
||||
|
||||
def test_level_filter(self, sample_agent_log, capsys):
|
||||
tail_log("agent", num_lines=50, level="ERROR")
|
||||
captured = capsys.readouterr()
|
||||
assert "level>=ERROR" in captured.out
|
||||
# Only the ERROR line should appear
|
||||
content_lines = [l for l in captured.out.strip().split("\n") if not l.startswith("---")]
|
||||
assert len(content_lines) == 1
|
||||
assert "API call failed" in content_lines[0]
|
||||
|
||||
def test_session_filter(self, sample_agent_log, capsys):
|
||||
tail_log("agent", num_lines=50, session="sess_bbb")
|
||||
captured = capsys.readouterr()
|
||||
content_lines = [l for l in captured.out.strip().split("\n") if not l.startswith("---")]
|
||||
assert len(content_lines) == 1
|
||||
assert "sess_bbb" in content_lines[0]
|
||||
|
||||
def test_errors_log(self, sample_errors_log, capsys):
|
||||
tail_log("errors", num_lines=10)
|
||||
captured = capsys.readouterr()
|
||||
assert "errors.log" in captured.out
|
||||
assert "WARNING" in captured.out or "ERROR" in captured.out
|
||||
|
||||
def test_unknown_log_exits(self):
|
||||
with pytest.raises(SystemExit):
|
||||
tail_log("nonexistent")
|
||||
|
||||
def test_missing_file_exits(self, log_dir):
|
||||
with pytest.raises(SystemExit):
|
||||
tail_log("agent") # agent.log doesn't exist in clean log_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_logs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListLogs:
|
||||
def test_lists_files(self, sample_agent_log, sample_errors_log, capsys):
|
||||
list_logs()
|
||||
captured = capsys.readouterr()
|
||||
assert "agent.log" in captured.out
|
||||
assert "errors.log" in captured.out
|
||||
|
||||
def test_empty_dir(self, log_dir, capsys):
|
||||
list_logs()
|
||||
captured = capsys.readouterr()
|
||||
assert "no log files yet" in captured.out
|
||||
|
||||
def test_shows_sizes(self, sample_agent_log, capsys):
|
||||
list_logs()
|
||||
captured = capsys.readouterr()
|
||||
# File is small, should show as bytes or KB
|
||||
assert "B" in captured.out or "KB" in captured.out
|
||||
class TestLogFiles:
|
||||
def test_known_log_files(self):
|
||||
assert "agent" in LOG_FILES
|
||||
assert "errors" in LOG_FILES
|
||||
assert "gateway" in LOG_FILES
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ def _make_args(**kwargs):
|
|||
"command": None,
|
||||
"args": None,
|
||||
"auth": None,
|
||||
"preset": None,
|
||||
"env": None,
|
||||
"mcp_action": None,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
|
|
@ -269,6 +271,145 @@ class TestMcpAdd:
|
|||
config = load_config()
|
||||
assert config["mcp_servers"]["broken"]["enabled"] is False
|
||||
|
||||
def test_add_stdio_server_with_env(self, tmp_path, capsys, monkeypatch):
|
||||
"""Stdio servers can persist explicit environment variables."""
|
||||
fake_tools = [FakeTool("search", "Search repos")]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
assert config["env"] == {
|
||||
"MY_API_KEY": "secret123",
|
||||
"DEBUG": "true",
|
||||
}
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda _: "")
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="github",
|
||||
command="npx",
|
||||
args=["@mcp/github"],
|
||||
env=["MY_API_KEY=secret123", "DEBUG=true"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
srv = config["mcp_servers"]["github"]
|
||||
assert srv["env"] == {
|
||||
"MY_API_KEY": "secret123",
|
||||
"DEBUG": "true",
|
||||
}
|
||||
|
||||
def test_add_stdio_server_rejects_invalid_env_name(self, capsys):
|
||||
"""Invalid environment variable names are rejected up front."""
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="github",
|
||||
command="npx",
|
||||
args=["@mcp/github"],
|
||||
env=["BAD-NAME=value"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "Invalid --env variable name" in out
|
||||
|
||||
def test_add_http_server_rejects_env_flag(self, capsys):
|
||||
"""The --env flag is only valid for stdio transports."""
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="ink",
|
||||
url="https://mcp.ml.ink/mcp",
|
||||
env=["DEBUG=true"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "only supported for stdio MCP servers" in out
|
||||
|
||||
def test_add_preset_fills_transport(self, tmp_path, capsys, monkeypatch):
|
||||
"""A preset fills in command/args when no explicit transport given."""
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._MCP_PRESETS",
|
||||
{"testmcp": {"command": "npx", "args": ["-y", "test-mcp-server"], "display_name": "Test MCP"}},
|
||||
)
|
||||
fake_tools = [FakeTool("do_thing", "Does a thing")]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
assert name == "myserver"
|
||||
assert config["command"] == "npx"
|
||||
assert config["args"] == ["-y", "test-mcp-server"]
|
||||
assert "env" not in config
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda _: "")
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
from hermes_cli.config import read_raw_config
|
||||
|
||||
cmd_mcp_add(_make_args(name="myserver", preset="testmcp"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
|
||||
config = read_raw_config()
|
||||
srv = config["mcp_servers"]["myserver"]
|
||||
assert srv["command"] == "npx"
|
||||
assert srv["args"] == ["-y", "test-mcp-server"]
|
||||
assert "env" not in srv
|
||||
|
||||
def test_preset_does_not_override_explicit_command(self, tmp_path, capsys, monkeypatch):
|
||||
"""Explicit transports win over presets."""
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._MCP_PRESETS",
|
||||
{"testmcp": {"command": "npx", "args": ["-y", "test-mcp-server"], "display_name": "Test MCP"}},
|
||||
)
|
||||
fake_tools = [FakeTool("search", "Search repos")]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
assert config["command"] == "uvx"
|
||||
assert config["args"] == ["custom-server"]
|
||||
assert "env" not in config
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda _: "")
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
from hermes_cli.config import read_raw_config
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="custom",
|
||||
preset="testmcp",
|
||||
command="uvx",
|
||||
args=["custom-server"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
|
||||
config = read_raw_config()
|
||||
srv = config["mcp_servers"]["custom"]
|
||||
assert srv["command"] == "uvx"
|
||||
assert srv["args"] == ["custom-server"]
|
||||
assert "env" not in srv
|
||||
|
||||
def test_unknown_preset_rejected(self, capsys):
|
||||
"""An unknown preset name is rejected with a clear error."""
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(name="foo", preset="nonexistent"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Unknown MCP preset" in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_test
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from rich.console import Console
|
||||
|
||||
from cli import ChatConsole
|
||||
from hermes_cli.skills_hub import do_check, do_install, do_list, do_update, handle_skills_slash
|
||||
|
||||
|
||||
|
|
@ -179,6 +181,21 @@ def test_do_update_reinstalls_outdated_skills(monkeypatch):
|
|||
assert "Updated 1 skill" in output
|
||||
|
||||
|
||||
def test_handle_skills_slash_search_accepts_chatconsole_without_status_errors():
|
||||
results = [type("R", (), {
|
||||
"name": "kubernetes",
|
||||
"description": "Cluster orchestration",
|
||||
"source": "skills.sh",
|
||||
"trust_level": "community",
|
||||
"identifier": "skills-sh/example/kubernetes",
|
||||
})()]
|
||||
|
||||
with patch("tools.skills_hub.unified_search", return_value=results), \
|
||||
patch("tools.skills_hub.create_source_router", return_value={}), \
|
||||
patch("tools.skills_hub.GitHubAuth"):
|
||||
handle_skills_slash("/skills search kubernetes", console=ChatConsole())
|
||||
|
||||
|
||||
def test_do_install_scans_with_resolved_identifier(monkeypatch, tmp_path, hub_env):
|
||||
import tools.skills_guard as guard
|
||||
import tools.skills_hub as hub
|
||||
|
|
|
|||
|
|
@ -191,6 +191,19 @@ class TestLaunchdPlistPath:
|
|||
raise AssertionError("PATH key not found in plist")
|
||||
|
||||
|
||||
class TestLaunchdPlistCurrentness:
|
||||
def test_launchd_plist_is_current_ignores_path_drift(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
monkeypatch.setenv("PATH", "/custom/bin:/usr/bin:/bin")
|
||||
plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("PATH", "/opt/homebrew/bin:/usr/local/bin:/usr/bin:/bin")
|
||||
|
||||
assert gateway_cli.launchd_plist_is_current() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cmd_update — macOS launchd detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -536,7 +549,7 @@ class TestServicePidExclusion:
|
|||
gateway_cli, "_get_service_pids", return_value={SERVICE_PID}
|
||||
), patch.object(
|
||||
gateway_cli, "find_gateway_pids",
|
||||
side_effect=lambda exclude_pids=None: (
|
||||
side_effect=lambda exclude_pids=None, all_profiles=False: (
|
||||
[SERVICE_PID] if not exclude_pids else
|
||||
[p for p in [SERVICE_PID] if p not in exclude_pids]
|
||||
),
|
||||
|
|
@ -579,7 +592,7 @@ class TestServicePidExclusion:
|
|||
gateway_cli, "_get_service_pids", return_value={SERVICE_PID}
|
||||
), patch.object(
|
||||
gateway_cli, "find_gateway_pids",
|
||||
side_effect=lambda exclude_pids=None: (
|
||||
side_effect=lambda exclude_pids=None, all_profiles=False: (
|
||||
[SERVICE_PID] if not exclude_pids else
|
||||
[p for p in [SERVICE_PID] if p not in exclude_pids]
|
||||
),
|
||||
|
|
@ -618,7 +631,7 @@ class TestServicePidExclusion:
|
|||
launchctl_loaded=True,
|
||||
)
|
||||
|
||||
def fake_find(exclude_pids=None):
|
||||
def fake_find(exclude_pids=None, all_profiles=False):
|
||||
_exclude = exclude_pids or set()
|
||||
return [p for p in [SERVICE_PID, MANUAL_PID] if p not in _exclude]
|
||||
|
||||
|
|
@ -760,3 +773,28 @@ class TestFindGatewayPidsExclude:
|
|||
pids = gateway_cli.find_gateway_pids()
|
||||
assert 100 in pids
|
||||
assert 200 in pids
|
||||
|
||||
def test_filters_to_current_profile(self, monkeypatch, tmp_path):
|
||||
profile_dir = tmp_path / ".hermes" / "profiles" / "orcha"
|
||||
profile_dir.mkdir(parents=True)
|
||||
monkeypatch.setattr(gateway_cli, "is_windows", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir)
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return subprocess.CompletedProcess(
|
||||
cmd, 0,
|
||||
stdout=(
|
||||
"100 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile orcha gateway run --replace\n"
|
||||
"200 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile other gateway run --replace\n"
|
||||
),
|
||||
stderr="",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
monkeypatch.setattr("os.getpid", lambda: 999)
|
||||
monkeypatch.setattr(gateway_cli, "_get_service_pids", lambda: set())
|
||||
monkeypatch.setattr(gateway_cli, "_profile_arg", lambda hermes_home=None: "--profile orcha")
|
||||
|
||||
pids = gateway_cli.find_gateway_pids()
|
||||
|
||||
assert pids == [100]
|
||||
|
|
|
|||
|
|
@ -2742,74 +2742,12 @@ class TestSystemPromptStability:
|
|||
assert "Hermes Agent" in agent._cached_system_prompt
|
||||
|
||||
class TestBudgetPressure:
|
||||
"""Budget pressure warning system (issue #414)."""
|
||||
"""Budget exhaustion grace call system."""
|
||||
|
||||
def test_no_warning_below_caution(self, agent):
|
||||
agent.max_iterations = 60
|
||||
assert agent._get_budget_warning(30) is None
|
||||
|
||||
def test_caution_at_70_percent(self, agent):
|
||||
agent.max_iterations = 60
|
||||
msg = agent._get_budget_warning(42)
|
||||
assert msg is not None
|
||||
assert "[BUDGET:" in msg
|
||||
assert "18 iterations left" in msg
|
||||
|
||||
def test_warning_at_90_percent(self, agent):
|
||||
agent.max_iterations = 60
|
||||
msg = agent._get_budget_warning(54)
|
||||
assert "[BUDGET WARNING:" in msg
|
||||
assert "Provide your final response NOW" in msg
|
||||
|
||||
def test_last_iteration(self, agent):
|
||||
agent.max_iterations = 60
|
||||
msg = agent._get_budget_warning(59)
|
||||
assert "1 iteration(s) left" in msg
|
||||
|
||||
def test_disabled(self, agent):
|
||||
agent.max_iterations = 60
|
||||
agent._budget_pressure_enabled = False
|
||||
assert agent._get_budget_warning(55) is None
|
||||
|
||||
def test_zero_max_iterations(self, agent):
|
||||
agent.max_iterations = 0
|
||||
assert agent._get_budget_warning(0) is None
|
||||
|
||||
def test_injects_into_json_tool_result(self, agent):
|
||||
"""Warning should be injected as _budget_warning field in JSON tool results."""
|
||||
import json
|
||||
agent.max_iterations = 10
|
||||
messages = [
|
||||
{"role": "tool", "content": json.dumps({"output": "done", "exit_code": 0}), "tool_call_id": "tc1"}
|
||||
]
|
||||
warning = agent._get_budget_warning(9)
|
||||
assert warning is not None
|
||||
# Simulate the injection logic
|
||||
last_content = messages[-1]["content"]
|
||||
parsed = json.loads(last_content)
|
||||
parsed["_budget_warning"] = warning
|
||||
messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False)
|
||||
result = json.loads(messages[-1]["content"])
|
||||
assert "_budget_warning" in result
|
||||
assert "BUDGET WARNING" in result["_budget_warning"]
|
||||
assert result["output"] == "done" # original content preserved
|
||||
|
||||
def test_appends_to_non_json_tool_result(self, agent):
|
||||
"""Warning should be appended as text for non-JSON tool results."""
|
||||
agent.max_iterations = 10
|
||||
messages = [
|
||||
{"role": "tool", "content": "plain text result", "tool_call_id": "tc1"}
|
||||
]
|
||||
warning = agent._get_budget_warning(9)
|
||||
# Simulate injection logic for non-JSON
|
||||
last_content = messages[-1]["content"]
|
||||
try:
|
||||
import json
|
||||
json.loads(last_content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
messages[-1]["content"] = last_content + f"\n\n{warning}"
|
||||
assert "plain text result" in messages[-1]["content"]
|
||||
assert "BUDGET WARNING" in messages[-1]["content"]
|
||||
def test_grace_call_flags_initialized(self, agent):
|
||||
"""Agent should have budget grace call flags."""
|
||||
assert agent._budget_exhausted_injected is False
|
||||
assert agent._budget_grace_call is False
|
||||
|
||||
|
||||
class TestSafeWriter:
|
||||
|
|
|
|||
|
|
@ -744,6 +744,44 @@ def test_normalize_codex_response_marks_commentary_only_message_as_incomplete(mo
|
|||
assert "inspect the repository" in (assistant_message.content or "")
|
||||
|
||||
|
||||
def test_interim_commentary_is_not_marked_already_streamed_without_callbacks(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
observed = {}
|
||||
|
||||
agent._fire_stream_delta("short version: yes")
|
||||
agent.interim_assistant_callback = lambda text, *, already_streamed=False: observed.update(
|
||||
{"text": text, "already_streamed": already_streamed}
|
||||
)
|
||||
|
||||
agent._emit_interim_assistant_message({"role": "assistant", "content": "short version: yes"})
|
||||
|
||||
assert observed == {
|
||||
"text": "short version: yes",
|
||||
"already_streamed": False,
|
||||
}
|
||||
|
||||
|
||||
def test_interim_commentary_is_not_marked_already_streamed_when_stream_callback_fails(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
observed = {}
|
||||
|
||||
def failing_callback(_text):
|
||||
raise RuntimeError("display failed")
|
||||
|
||||
agent.stream_delta_callback = failing_callback
|
||||
agent._fire_stream_delta("short version: yes")
|
||||
agent.interim_assistant_callback = lambda text, *, already_streamed=False: observed.update(
|
||||
{"text": text, "already_streamed": already_streamed}
|
||||
)
|
||||
|
||||
agent._emit_interim_assistant_message({"role": "assistant", "content": "short version: yes"})
|
||||
|
||||
assert observed == {
|
||||
"text": "short version: yes",
|
||||
"already_streamed": False,
|
||||
}
|
||||
|
||||
|
||||
def test_run_conversation_codex_continues_after_commentary_phase_message(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
responses = [
|
||||
|
|
@ -1104,3 +1142,58 @@ def test_duplicate_detection_distinguishes_different_codex_reasoning(monkeypatch
|
|||
]
|
||||
assert "enc_first" in encrypted_contents
|
||||
assert "enc_second" in encrypted_contents
|
||||
|
||||
|
||||
def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch):
|
||||
"""Duplicate reasoning item IDs across multi-turn incomplete responses
|
||||
must be deduplicated so the Responses API doesn't reject with HTTP 400."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
messages = [
|
||||
{"role": "user", "content": "think hard"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"codex_reasoning_items": [
|
||||
{"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"},
|
||||
{"type": "reasoning", "id": "rs_bbb", "encrypted_content": "enc_2"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "partial answer",
|
||||
"codex_reasoning_items": [
|
||||
# rs_aaa is duplicated from the previous turn
|
||||
{"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"},
|
||||
{"type": "reasoning", "id": "rs_ccc", "encrypted_content": "enc_3"},
|
||||
],
|
||||
},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
|
||||
reasoning_ids = [it["id"] for it in items if it.get("type") == "reasoning"]
|
||||
# rs_aaa should appear only once (first occurrence kept)
|
||||
assert reasoning_ids.count("rs_aaa") == 1
|
||||
# rs_bbb and rs_ccc should each appear once
|
||||
assert reasoning_ids.count("rs_bbb") == 1
|
||||
assert reasoning_ids.count("rs_ccc") == 1
|
||||
assert len(reasoning_ids) == 3
|
||||
|
||||
|
||||
def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch):
|
||||
"""_preflight_codex_input_items should also deduplicate reasoning items by ID."""
|
||||
agent = _build_agent(monkeypatch)
|
||||
raw_input = [
|
||||
{"role": "user", "content": [{"type": "input_text", "text": "hello"}]},
|
||||
{"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"},
|
||||
{"role": "assistant", "content": "ok"},
|
||||
{"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"},
|
||||
{"type": "reasoning", "id": "rs_zzz", "encrypted_content": "enc_b"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
normalized = agent._preflight_codex_input_items(raw_input)
|
||||
|
||||
reasoning_items = [it for it in normalized if it.get("type") == "reasoning"]
|
||||
reasoning_ids = [it["id"] for it in reasoning_items]
|
||||
assert reasoning_ids.count("rs_xyz") == 1
|
||||
assert reasoning_ids.count("rs_zzz") == 1
|
||||
assert len(reasoning_items) == 2
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import logging
|
||||
import os
|
||||
import stat
|
||||
import threading
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
|
@ -34,6 +35,8 @@ def _reset_logging_state():
|
|||
h.close()
|
||||
else:
|
||||
pre_existing.append(h)
|
||||
# Ensure the record factory is installed (it's idempotent).
|
||||
hermes_logging._install_session_record_factory()
|
||||
yield
|
||||
# Restore — remove any handlers added during the test.
|
||||
for h in list(root.handlers):
|
||||
|
|
@ -41,6 +44,7 @@ def _reset_logging_state():
|
|||
root.removeHandler(h)
|
||||
h.close()
|
||||
hermes_logging._logging_initialized = False
|
||||
hermes_logging.clear_session_context()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -220,6 +224,294 @@ class TestSetupLogging:
|
|||
]
|
||||
assert agent_handlers[0].level == logging.WARNING
|
||||
|
||||
def test_record_factory_installed(self, hermes_home):
|
||||
"""The custom record factory injects session_tag on all records."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
factory = logging.getLogRecordFactory()
|
||||
assert getattr(factory, "_hermes_session_injector", False), (
|
||||
"Record factory should have _hermes_session_injector marker"
|
||||
)
|
||||
# Verify session_tag exists on a fresh record
|
||||
record = factory("test", logging.INFO, "", 0, "msg", (), None)
|
||||
assert hasattr(record, "session_tag")
|
||||
|
||||
|
||||
class TestGatewayMode:
|
||||
"""setup_logging(mode='gateway') creates a filtered gateway.log."""
|
||||
|
||||
def test_gateway_log_created(self, hermes_home):
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
root = logging.getLogger()
|
||||
|
||||
gw_handlers = [
|
||||
h for h in root.handlers
|
||||
if isinstance(h, RotatingFileHandler)
|
||||
and "gateway.log" in getattr(h, "baseFilename", "")
|
||||
]
|
||||
assert len(gw_handlers) == 1
|
||||
|
||||
def test_gateway_log_not_created_in_cli_mode(self, hermes_home):
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="cli")
|
||||
root = logging.getLogger()
|
||||
|
||||
gw_handlers = [
|
||||
h for h in root.handlers
|
||||
if isinstance(h, RotatingFileHandler)
|
||||
and "gateway.log" in getattr(h, "baseFilename", "")
|
||||
]
|
||||
assert len(gw_handlers) == 0
|
||||
|
||||
def test_gateway_log_receives_gateway_records(self, hermes_home):
|
||||
"""gateway.log captures records from gateway.* loggers."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
|
||||
gw_logger = logging.getLogger("gateway.platforms.telegram")
|
||||
gw_logger.info("telegram connected")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
gw_log = hermes_home / "logs" / "gateway.log"
|
||||
assert gw_log.exists()
|
||||
assert "telegram connected" in gw_log.read_text()
|
||||
|
||||
def test_gateway_log_rejects_non_gateway_records(self, hermes_home):
|
||||
"""gateway.log does NOT capture records from tools.*, agent.*, etc."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
|
||||
tool_logger = logging.getLogger("tools.terminal_tool")
|
||||
tool_logger.info("running command")
|
||||
|
||||
agent_logger = logging.getLogger("agent.context_compressor")
|
||||
agent_logger.info("compressing context")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
gw_log = hermes_home / "logs" / "gateway.log"
|
||||
if gw_log.exists():
|
||||
content = gw_log.read_text()
|
||||
assert "running command" not in content
|
||||
assert "compressing context" not in content
|
||||
|
||||
def test_agent_log_still_receives_all(self, hermes_home):
|
||||
"""agent.log (catch-all) still receives gateway AND tool records."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
|
||||
|
||||
logging.getLogger("gateway.run").info("gateway msg")
|
||||
logging.getLogger("tools.file_tools").info("file msg")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
assert "gateway msg" in content
|
||||
assert "file msg" in content
|
||||
|
||||
|
||||
class TestSessionContext:
|
||||
"""set_session_context / clear_session_context + _SessionFilter."""
|
||||
|
||||
def test_session_tag_in_log_output(self, hermes_home):
|
||||
"""When session context is set, log lines include [session_id]."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
hermes_logging.set_session_context("abc123")
|
||||
|
||||
test_logger = logging.getLogger("test.session_tag")
|
||||
test_logger.info("tagged message")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
assert "[abc123]" in content
|
||||
assert "tagged message" in content
|
||||
|
||||
def test_no_session_tag_without_context(self, hermes_home):
|
||||
"""Without session context, log lines have no session tag."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
hermes_logging.clear_session_context()
|
||||
|
||||
test_logger = logging.getLogger("test.no_session")
|
||||
test_logger.info("untagged message")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
assert "untagged message" in content
|
||||
# Should not have any [xxx] session tag
|
||||
import re
|
||||
for line in content.splitlines():
|
||||
if "untagged message" in line:
|
||||
assert not re.search(r"\[.+?\]", line.split("INFO")[1].split("test.no_session")[0])
|
||||
|
||||
def test_clear_session_context(self, hermes_home):
|
||||
"""After clearing, session tag disappears."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
hermes_logging.set_session_context("xyz789")
|
||||
hermes_logging.clear_session_context()
|
||||
|
||||
test_logger = logging.getLogger("test.cleared")
|
||||
test_logger.info("after clear")
|
||||
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
assert "[xyz789]" not in content
|
||||
|
||||
def test_session_context_thread_isolated(self, hermes_home):
|
||||
"""Session context is per-thread — one thread's context doesn't leak."""
|
||||
hermes_logging.setup_logging(hermes_home=hermes_home)
|
||||
|
||||
results = {}
|
||||
|
||||
def thread_a():
|
||||
hermes_logging.set_session_context("thread_a_session")
|
||||
logging.getLogger("test.thread_a").info("from thread A")
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
def thread_b():
|
||||
hermes_logging.set_session_context("thread_b_session")
|
||||
logging.getLogger("test.thread_b").info("from thread B")
|
||||
for h in logging.getLogger().handlers:
|
||||
h.flush()
|
||||
|
||||
ta = threading.Thread(target=thread_a)
|
||||
tb = threading.Thread(target=thread_b)
|
||||
ta.start()
|
||||
ta.join()
|
||||
tb.start()
|
||||
tb.join()
|
||||
|
||||
agent_log = hermes_home / "logs" / "agent.log"
|
||||
content = agent_log.read_text()
|
||||
|
||||
# Each thread's message should have its own session tag
|
||||
for line in content.splitlines():
|
||||
if "from thread A" in line:
|
||||
assert "[thread_a_session]" in line
|
||||
assert "[thread_b_session]" not in line
|
||||
if "from thread B" in line:
|
||||
assert "[thread_b_session]" in line
|
||||
assert "[thread_a_session]" not in line
|
||||
|
||||
|
||||
class TestRecordFactory:
|
||||
"""Unit tests for the custom LogRecord factory."""
|
||||
|
||||
def test_record_has_session_tag(self):
|
||||
"""Every record gets a session_tag attribute."""
|
||||
factory = logging.getLogRecordFactory()
|
||||
record = factory("test", logging.INFO, "", 0, "msg", (), None)
|
||||
assert hasattr(record, "session_tag")
|
||||
|
||||
def test_empty_tag_without_context(self):
|
||||
hermes_logging.clear_session_context()
|
||||
factory = logging.getLogRecordFactory()
|
||||
record = factory("test", logging.INFO, "", 0, "msg", (), None)
|
||||
assert record.session_tag == ""
|
||||
|
||||
def test_tag_with_context(self):
|
||||
hermes_logging.set_session_context("sess_42")
|
||||
factory = logging.getLogRecordFactory()
|
||||
record = factory("test", logging.INFO, "", 0, "msg", (), None)
|
||||
assert record.session_tag == " [sess_42]"
|
||||
|
||||
def test_idempotent_install(self):
|
||||
"""Calling _install_session_record_factory() twice doesn't double-wrap."""
|
||||
hermes_logging._install_session_record_factory()
|
||||
factory_a = logging.getLogRecordFactory()
|
||||
hermes_logging._install_session_record_factory()
|
||||
factory_b = logging.getLogRecordFactory()
|
||||
assert factory_a is factory_b
|
||||
|
||||
def test_works_with_any_handler(self):
|
||||
"""A handler using %(session_tag)s works even without _SessionFilter."""
|
||||
hermes_logging.set_session_context("any_handler_test")
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter("%(session_tag)s %(message)s"))
|
||||
|
||||
logger = logging.getLogger("_test_any_handler")
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
try:
|
||||
# Should not raise KeyError
|
||||
logger.info("hello")
|
||||
finally:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
|
||||
class TestComponentFilter:
|
||||
"""Unit tests for _ComponentFilter."""
|
||||
|
||||
def test_passes_matching_prefix(self):
|
||||
f = hermes_logging._ComponentFilter(("gateway",))
|
||||
record = logging.LogRecord(
|
||||
"gateway.run", logging.INFO, "", 0, "msg", (), None
|
||||
)
|
||||
assert f.filter(record) is True
|
||||
|
||||
def test_passes_nested_matching_prefix(self):
|
||||
f = hermes_logging._ComponentFilter(("gateway",))
|
||||
record = logging.LogRecord(
|
||||
"gateway.platforms.telegram", logging.INFO, "", 0, "msg", (), None
|
||||
)
|
||||
assert f.filter(record) is True
|
||||
|
||||
def test_blocks_non_matching(self):
|
||||
f = hermes_logging._ComponentFilter(("gateway",))
|
||||
record = logging.LogRecord(
|
||||
"tools.terminal_tool", logging.INFO, "", 0, "msg", (), None
|
||||
)
|
||||
assert f.filter(record) is False
|
||||
|
||||
def test_multiple_prefixes(self):
|
||||
f = hermes_logging._ComponentFilter(("agent", "run_agent", "model_tools"))
|
||||
assert f.filter(logging.LogRecord(
|
||||
"agent.compressor", logging.INFO, "", 0, "", (), None
|
||||
))
|
||||
assert f.filter(logging.LogRecord(
|
||||
"run_agent", logging.INFO, "", 0, "", (), None
|
||||
))
|
||||
assert f.filter(logging.LogRecord(
|
||||
"model_tools", logging.INFO, "", 0, "", (), None
|
||||
))
|
||||
assert not f.filter(logging.LogRecord(
|
||||
"tools.browser", logging.INFO, "", 0, "", (), None
|
||||
))
|
||||
|
||||
|
||||
class TestComponentPrefixes:
|
||||
"""COMPONENT_PREFIXES covers the expected components."""
|
||||
|
||||
def test_gateway_prefix(self):
|
||||
assert "gateway" in hermes_logging.COMPONENT_PREFIXES
|
||||
assert ("gateway",) == hermes_logging.COMPONENT_PREFIXES["gateway"]
|
||||
|
||||
def test_agent_prefix(self):
|
||||
prefixes = hermes_logging.COMPONENT_PREFIXES["agent"]
|
||||
assert "agent" in prefixes
|
||||
assert "run_agent" in prefixes
|
||||
assert "model_tools" in prefixes
|
||||
|
||||
def test_tools_prefix(self):
|
||||
assert ("tools",) == hermes_logging.COMPONENT_PREFIXES["tools"]
|
||||
|
||||
def test_cli_prefix(self):
|
||||
prefixes = hermes_logging.COMPONENT_PREFIXES["cli"]
|
||||
assert "hermes_cli" in prefixes
|
||||
assert "cli" in prefixes
|
||||
|
||||
def test_cron_prefix(self):
|
||||
assert ("cron",) == hermes_logging.COMPONENT_PREFIXES["cron"]
|
||||
|
||||
|
||||
class TestSetupVerboseLogging:
|
||||
"""setup_verbose_logging() adds a DEBUG-level console handler."""
|
||||
|
|
@ -301,6 +593,59 @@ class TestAddRotatingHandler:
|
|||
logger.removeHandler(h)
|
||||
h.close()
|
||||
|
||||
def test_log_filter_attached(self, tmp_path):
|
||||
"""Optional log_filter is attached to the handler."""
|
||||
log_path = tmp_path / "filtered.log"
|
||||
logger = logging.getLogger("_test_rotating_filter")
|
||||
formatter = logging.Formatter("%(message)s")
|
||||
component_filter = hermes_logging._ComponentFilter(("test",))
|
||||
|
||||
hermes_logging._add_rotating_handler(
|
||||
logger, log_path,
|
||||
level=logging.INFO, max_bytes=1024, backup_count=1,
|
||||
formatter=formatter,
|
||||
log_filter=component_filter,
|
||||
)
|
||||
|
||||
handlers = [h for h in logger.handlers if isinstance(h, RotatingFileHandler)]
|
||||
assert len(handlers) == 1
|
||||
assert component_filter in handlers[0].filters
|
||||
# Clean up
|
||||
for h in list(logger.handlers):
|
||||
if isinstance(h, RotatingFileHandler):
|
||||
logger.removeHandler(h)
|
||||
h.close()
|
||||
|
||||
def test_no_session_filter_on_handler(self, tmp_path):
|
||||
"""Handlers rely on record factory, not per-handler _SessionFilter."""
|
||||
log_path = tmp_path / "no_session_filter.log"
|
||||
logger = logging.getLogger("_test_no_session_filter")
|
||||
formatter = logging.Formatter("%(session_tag)s%(message)s")
|
||||
|
||||
hermes_logging._add_rotating_handler(
|
||||
logger, log_path,
|
||||
level=logging.INFO, max_bytes=1024, backup_count=1,
|
||||
formatter=formatter,
|
||||
)
|
||||
|
||||
handlers = [h for h in logger.handlers if isinstance(h, RotatingFileHandler)]
|
||||
assert len(handlers) == 1
|
||||
# No _SessionFilter on the handler — record factory handles it
|
||||
assert len(handlers[0].filters) == 0
|
||||
|
||||
# But session_tag still works (via record factory)
|
||||
hermes_logging.set_session_context("factory_test")
|
||||
logger.info("test msg")
|
||||
handlers[0].flush()
|
||||
content = log_path.read_text()
|
||||
assert "[factory_test]" in content
|
||||
|
||||
# Clean up
|
||||
for h in list(logger.handlers):
|
||||
if isinstance(h, RotatingFileHandler):
|
||||
logger.removeHandler(h)
|
||||
h.close()
|
||||
|
||||
def test_managed_mode_initial_open_sets_group_writable(self, tmp_path):
|
||||
log_path = tmp_path / "managed-open.log"
|
||||
logger = logging.getLogger("_test_rotating_managed_open")
|
||||
|
|
|
|||
|
|
@ -59,8 +59,9 @@ class TestCamofoxConfigDefaults:
|
|||
browser_cfg = DEFAULT_CONFIG["browser"]
|
||||
assert browser_cfg["camofox"]["managed_persistence"] is False
|
||||
|
||||
def test_config_version_unchanged(self):
|
||||
def test_config_version_matches_current_schema(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
# managed_persistence is auto-merged by _deep_merge, no version bump needed
|
||||
assert DEFAULT_CONFIG["_config_version"] == 13
|
||||
# The current schema version is tracked globally; unrelated default
|
||||
# options may bump it after browser defaults are added.
|
||||
assert DEFAULT_CONFIG["_config_version"] == 15
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
"""Tests for tools/checkpoint_manager.py — CheckpointManager."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
|
@ -42,6 +39,19 @@ def checkpoint_base(tmp_path):
|
|||
return tmp_path / "checkpoints"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fake_home(tmp_path, monkeypatch):
|
||||
"""Set a deterministic fake home for expanduser/path-home behavior."""
|
||||
home = tmp_path / "home"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HOME", str(home))
|
||||
monkeypatch.setenv("USERPROFILE", str(home))
|
||||
monkeypatch.delenv("HOMEDRIVE", raising=False)
|
||||
monkeypatch.delenv("HOMEPATH", raising=False)
|
||||
monkeypatch.setattr(Path, "home", classmethod(lambda cls: home))
|
||||
return home
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mgr(work_dir, checkpoint_base, monkeypatch):
|
||||
"""CheckpointManager with redirected checkpoint base."""
|
||||
|
|
@ -78,6 +88,16 @@ class TestShadowRepoPath:
|
|||
p = _shadow_repo_path(str(work_dir))
|
||||
assert str(p).startswith(str(checkpoint_base))
|
||||
|
||||
def test_tilde_and_expanded_home_share_shadow_repo(self, fake_home, checkpoint_base, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
expanded_path = str(project)
|
||||
|
||||
assert _shadow_repo_path(tilde_path) == _shadow_repo_path(expanded_path)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Shadow repo init
|
||||
|
|
@ -221,6 +241,20 @@ class TestListCheckpoints:
|
|||
assert result[0]["reason"] == "third"
|
||||
assert result[2]["reason"] == "first"
|
||||
|
||||
def test_tilde_path_lists_same_checkpoints_as_expanded_path(self, checkpoint_base, fake_home, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
mgr = CheckpointManager(enabled=True, max_snapshots=50)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
(project / "main.py").write_text("v1\n")
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
assert mgr.ensure_checkpoint(tilde_path, "initial") is True
|
||||
|
||||
listed = mgr.list_checkpoints(str(project))
|
||||
assert len(listed) == 1
|
||||
assert listed[0]["reason"] == "initial"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CheckpointManager — restoring
|
||||
|
|
@ -271,6 +305,28 @@ class TestRestore:
|
|||
assert len(all_cps) >= 2
|
||||
assert "pre-rollback" in all_cps[0]["reason"]
|
||||
|
||||
def test_tilde_path_supports_diff_and_restore_flow(self, checkpoint_base, fake_home, monkeypatch):
|
||||
monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base)
|
||||
mgr = CheckpointManager(enabled=True, max_snapshots=50)
|
||||
project = fake_home / "project"
|
||||
project.mkdir()
|
||||
file_path = project / "main.py"
|
||||
file_path.write_text("original\n")
|
||||
|
||||
tilde_path = f"~/{project.name}"
|
||||
assert mgr.ensure_checkpoint(tilde_path, "initial") is True
|
||||
mgr.new_turn()
|
||||
|
||||
file_path.write_text("changed\n")
|
||||
checkpoints = mgr.list_checkpoints(str(project))
|
||||
diff_result = mgr.diff(tilde_path, checkpoints[0]["hash"])
|
||||
assert diff_result["success"] is True
|
||||
assert "main.py" in diff_result["diff"]
|
||||
|
||||
restore_result = mgr.restore(tilde_path, checkpoints[0]["hash"])
|
||||
assert restore_result["success"] is True
|
||||
assert file_path.read_text() == "original\n"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CheckpointManager — working dir resolution
|
||||
|
|
@ -310,6 +366,19 @@ class TestWorkingDirResolution:
|
|||
result = mgr.get_working_dir_for_path(str(filepath))
|
||||
assert result == str(filepath.parent)
|
||||
|
||||
def test_resolves_tilde_path_to_project_root(self, fake_home):
|
||||
mgr = CheckpointManager(enabled=True)
|
||||
project = fake_home / "myproject"
|
||||
project.mkdir()
|
||||
(project / "pyproject.toml").write_text("[project]\n")
|
||||
subdir = project / "src"
|
||||
subdir.mkdir()
|
||||
filepath = subdir / "main.py"
|
||||
filepath.write_text("x\n")
|
||||
|
||||
result = mgr.get_working_dir_for_path(f"~/{project.name}/src/main.py")
|
||||
assert result == str(project)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Git env isolation
|
||||
|
|
@ -333,6 +402,14 @@ class TestGitEnvIsolation:
|
|||
env = _git_env(shadow, str(tmp_path))
|
||||
assert "GIT_INDEX_FILE" not in env
|
||||
|
||||
def test_expands_tilde_in_work_tree(self, fake_home, tmp_path):
|
||||
shadow = tmp_path / "shadow"
|
||||
work = fake_home / "work"
|
||||
work.mkdir()
|
||||
|
||||
env = _git_env(shadow, f"~/{work.name}")
|
||||
assert env["GIT_WORK_TREE"] == str(work.resolve())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# format_checkpoint_list
|
||||
|
|
@ -384,6 +461,8 @@ class TestErrorResilience:
|
|||
assert result is False
|
||||
|
||||
def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog):
|
||||
work = tmp_path / "work"
|
||||
work.mkdir()
|
||||
completed = subprocess.CompletedProcess(
|
||||
args=["git", "diff", "--cached", "--quiet"],
|
||||
returncode=1,
|
||||
|
|
@ -395,7 +474,7 @@ class TestErrorResilience:
|
|||
ok, stdout, stderr = _run_git(
|
||||
["diff", "--cached", "--quiet"],
|
||||
tmp_path / "shadow",
|
||||
str(tmp_path / "work"),
|
||||
str(work),
|
||||
allowed_returncodes={1},
|
||||
)
|
||||
assert ok is False
|
||||
|
|
@ -403,6 +482,38 @@ class TestErrorResilience:
|
|||
assert stderr == ""
|
||||
assert not caplog.records
|
||||
|
||||
def test_run_git_invalid_working_dir_reports_path_error(self, tmp_path, caplog):
|
||||
missing = tmp_path / "missing"
|
||||
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||
ok, stdout, stderr = _run_git(
|
||||
["status"],
|
||||
tmp_path / "shadow",
|
||||
str(missing),
|
||||
)
|
||||
assert ok is False
|
||||
assert stdout == ""
|
||||
assert "working directory not found" in stderr
|
||||
assert not any("Git executable not found" in r.getMessage() for r in caplog.records)
|
||||
|
||||
def test_run_git_missing_git_reports_git_not_found(self, tmp_path, monkeypatch, caplog):
|
||||
work = tmp_path / "work"
|
||||
work.mkdir()
|
||||
|
||||
def raise_missing_git(*args, **kwargs):
|
||||
raise FileNotFoundError(2, "No such file or directory", "git")
|
||||
|
||||
monkeypatch.setattr("tools.checkpoint_manager.subprocess.run", raise_missing_git)
|
||||
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||
ok, stdout, stderr = _run_git(
|
||||
["status"],
|
||||
tmp_path / "shadow",
|
||||
str(work),
|
||||
)
|
||||
assert ok is False
|
||||
assert stdout == ""
|
||||
assert stderr == "git not found"
|
||||
assert any("Git executable not found" in r.getMessage() for r in caplog.records)
|
||||
|
||||
def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch):
|
||||
"""Checkpoint failures should never raise — they're silently logged."""
|
||||
def broken_run_git(*args, **kwargs):
|
||||
|
|
@ -411,3 +522,68 @@ class TestErrorResilience:
|
|||
# Should not raise
|
||||
result = mgr.ensure_checkpoint(str(work_dir), "test")
|
||||
assert result is False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Security / Input validation
|
||||
# =========================================================================
|
||||
|
||||
class TestSecurity:
|
||||
def test_restore_rejects_argument_injection(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Try to pass a git flag as a commit hash
|
||||
result = mgr.restore(str(work_dir), "--patch")
|
||||
assert result["success"] is False
|
||||
assert "Invalid commit hash" in result["error"]
|
||||
assert "must not start with '-'" in result["error"]
|
||||
|
||||
result = mgr.restore(str(work_dir), "-p")
|
||||
assert result["success"] is False
|
||||
assert "Invalid commit hash" in result["error"]
|
||||
|
||||
def test_restore_rejects_invalid_hex_chars(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Git hashes should not contain characters like ;, &, |
|
||||
result = mgr.restore(str(work_dir), "abc; rm -rf /")
|
||||
assert result["success"] is False
|
||||
assert "expected 4-64 hex characters" in result["error"]
|
||||
|
||||
result = mgr.diff(str(work_dir), "abc&def")
|
||||
assert result["success"] is False
|
||||
assert "expected 4-64 hex characters" in result["error"]
|
||||
|
||||
def test_restore_rejects_path_traversal(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
# Real commit hash but malicious path
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
# Absolute path outside
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="/etc/passwd")
|
||||
assert result["success"] is False
|
||||
assert "got absolute path" in result["error"]
|
||||
|
||||
# Relative traversal outside path
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="../outside_file.txt")
|
||||
assert result["success"] is False
|
||||
assert "escapes the working directory" in result["error"]
|
||||
|
||||
def test_restore_accepts_valid_file_path(self, mgr, work_dir):
|
||||
mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
# Valid path inside directory
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="main.py")
|
||||
assert result["success"] is True
|
||||
|
||||
# Another valid path with subdirectories
|
||||
(work_dir / "subdir").mkdir()
|
||||
(work_dir / "subdir" / "test.txt").write_text("hello")
|
||||
mgr.new_turn()
|
||||
mgr.ensure_checkpoint(str(work_dir), "second")
|
||||
checkpoints = mgr.list_checkpoints(str(work_dir))
|
||||
target_hash = checkpoints[0]["hash"]
|
||||
|
||||
result = mgr.restore(str(work_dir), target_hash, file_path="subdir/test.txt")
|
||||
assert result["success"] is True
|
||||
|
|
|
|||
|
|
@ -380,7 +380,7 @@ class TestStubSchemaDrift(unittest.TestCase):
|
|||
# Parameters that are internal (injected by the handler, not user-facing)
|
||||
_INTERNAL_PARAMS = {"task_id", "user_task"}
|
||||
# Parameters intentionally blocked in the sandbox
|
||||
_BLOCKED_TERMINAL_PARAMS = {"background", "check_interval", "pty", "notify_on_complete"}
|
||||
_BLOCKED_TERMINAL_PARAMS = {"background", "pty", "notify_on_complete"}
|
||||
|
||||
def test_stubs_cover_all_schema_params(self):
|
||||
"""Every user-facing parameter in the real schema must appear in the
|
||||
|
|
|
|||
295
tests/tools/test_modal_bulk_upload.py
Normal file
295
tests/tools/test_modal_bulk_upload.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
"""Tests for Modal bulk upload via tar/base64 archive."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import modal as modal_env
|
||||
|
||||
|
||||
def _make_mock_modal_env(monkeypatch, tmp_path):
    """Build a bare ModalEnvironment with mocked sandbox/worker collaborators.

    ``__init__`` is bypassed via ``object.__new__`` because it requires the
    Modal SDK; only the attributes the upload methods touch are populated.
    The *monkeypatch* and *tmp_path* parameters are accepted for fixture
    symmetry with the tests but are not used here.
    """
    env = object.__new__(modal_env.ModalEnvironment)
    # Assign the minimal attribute surface in one declarative pass.
    seeded_attrs = {
        "_sandbox": MagicMock(),
        "_worker": MagicMock(),
        "_persistent": False,
        "_task_id": "test",
        "_sync_manager": None,
    }
    for attr_name, attr_value in seeded_attrs.items():
        setattr(env, attr_name, attr_value)
    return env
|
||||
|
||||
|
||||
def _make_mock_stdin():
|
||||
"""Create a mock stdin that captures written data."""
|
||||
stdin = MagicMock()
|
||||
written_chunks = []
|
||||
|
||||
def mock_write(data):
|
||||
written_chunks.append(data)
|
||||
|
||||
stdin.write = mock_write
|
||||
stdin.write_eof = MagicMock()
|
||||
stdin.drain = MagicMock()
|
||||
stdin.drain.aio = AsyncMock()
|
||||
stdin._written_chunks = written_chunks
|
||||
return stdin
|
||||
|
||||
|
||||
def _wire_async_exec(env, exec_calls=None):
    """Wire mock sandbox.exec.aio and a real run_coroutine on the env.

    Optionally captures exec call args into *exec_calls* list.
    Returns (exec_calls, run_kwargs, stdin_mock).
    """
    if exec_calls is None:
        exec_calls = []
    run_kwargs: dict = {}
    stdin_mock = _make_mock_stdin()

    async def mock_exec_fn(*args, **kwargs):
        # Record positional args so tests can inspect the issued command.
        exec_calls.append(args)
        proc = MagicMock()
        proc.wait = MagicMock()
        proc.wait.aio = AsyncMock(return_value=0)  # simulate a clean exit
        proc.stdin = stdin_mock
        proc.stderr = MagicMock()
        proc.stderr.read = MagicMock()
        proc.stderr.read.aio = AsyncMock(return_value="")
        return proc

    env._sandbox.exec = MagicMock()
    env._sandbox.exec.aio = mock_exec_fn

    def real_run_coroutine(coro, **kwargs):
        # Capture kwargs (e.g. timeout=) passed by the code under test, then
        # actually drive the coroutine on a throwaway event loop so the async
        # upload path executes for real.
        run_kwargs.update(kwargs)
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    env._worker.run_coroutine = real_run_coroutine
    return exec_calls, run_kwargs, stdin_mock
|
||||
|
||||
|
||||
class TestModalBulkUpload:
    """Test _modal_bulk_upload method.

    All tests drive the real upload code against a mocked Modal sandbox
    (see ``_make_mock_modal_env`` / ``_wire_async_exec``); no network or
    Modal SDK is involved.
    """

    def test_empty_files_is_noop(self, monkeypatch, tmp_path):
        """Empty file list should not call worker.run_coroutine."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)
        env._modal_bulk_upload([])
        env._worker.run_coroutine.assert_not_called()

    def test_tar_archive_contains_all_files(self, monkeypatch, tmp_path):
        """The tar archive sent via stdin should contain all files."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src_a = tmp_path / "a.json"
        src_b = tmp_path / "b.py"
        src_a.write_text("cred_content")
        src_b.write_text("skill_content")

        files = [
            (str(src_a), "/root/.hermes/credentials/a.json"),
            (str(src_b), "/root/.hermes/skills/b.py"),
        ]

        exec_calls, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Verify the command reads from stdin (no echo with embedded payload)
        assert len(exec_calls) == 1
        args = exec_calls[0]
        assert args[0] == "bash"
        assert args[1] == "-c"
        cmd = args[2]
        assert "mkdir -p" in cmd
        assert "base64 -d" in cmd
        assert "tar xzf" in cmd
        assert "-C /" in cmd

        # Reassemble the base64 payload from stdin chunks and verify tar contents
        payload = "".join(stdin_mock._written_chunks)
        tar_data = base64.b64decode(payload)
        buf = io.BytesIO(tar_data)
        with tarfile.open(fileobj=buf, mode="r:gz") as tar:
            # Archive members are relative (no leading "/") so extraction
            # with -C / lands them at the absolute remote paths.
            names = sorted(tar.getnames())
            assert "root/.hermes/credentials/a.json" in names
            assert "root/.hermes/skills/b.py" in names

            # Verify content
            a_content = tar.extractfile("root/.hermes/credentials/a.json").read()
            assert a_content == b"cred_content"
            b_content = tar.extractfile("root/.hermes/skills/b.py").read()
            assert b_content == b"skill_content"

        # Verify stdin was closed
        stdin_mock.write_eof.assert_called_once()

    def test_mkdir_includes_all_parents(self, monkeypatch, tmp_path):
        """Remote parent directories should be pre-created in the command."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")

        files = [
            (str(src), "/root/.hermes/credentials/f.txt"),
            (str(src), "/root/.hermes/skills/deep/nested/f.txt"),
        ]

        exec_calls, _, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        cmd = exec_calls[0][2]
        assert "/root/.hermes/credentials" in cmd
        assert "/root/.hermes/skills/deep/nested" in cmd

    def test_single_exec_call(self, monkeypatch, tmp_path):
        """Bulk upload should use exactly one exec call regardless of file count."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        files = []
        for i in range(20):
            src = tmp_path / f"file_{i}.txt"
            src.write_text(f"content_{i}")
            files.append((str(src), f"/root/.hermes/cache/file_{i}.txt"))

        exec_calls, _, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Should be exactly 1 exec call, not 20
        assert len(exec_calls) == 1

    def test_bulk_upload_wired_in_filesyncmanager(self, monkeypatch):
        """Verify ModalEnvironment passes bulk_upload_fn to FileSyncManager."""
        captured_kwargs = {}

        def capture_fsm(**kwargs):
            captured_kwargs.update(kwargs)
            return type("M", (), {"sync": lambda self, **k: None})()

        monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm)

        # Create a minimal env without full __init__
        env = object.__new__(modal_env.ModalEnvironment)
        env._sandbox = MagicMock()
        env._worker = MagicMock()
        env._persistent = False
        env._task_id = "test"

        # Manually call the part of __init__ that wires FileSyncManager
        from tools.environments.file_sync import iter_sync_files
        env._sync_manager = modal_env.FileSyncManager(
            get_files_fn=lambda: iter_sync_files("/root/.hermes"),
            upload_fn=env._modal_upload,
            delete_fn=env._modal_delete,
            bulk_upload_fn=env._modal_bulk_upload,
        )

        assert "bulk_upload_fn" in captured_kwargs
        assert captured_kwargs["bulk_upload_fn"] is not None
        assert callable(captured_kwargs["bulk_upload_fn"])

    def test_timeout_set_to_120(self, monkeypatch, tmp_path):
        """Bulk upload uses a 120s timeout (not the per-file 15s)."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")
        files = [(str(src), "/root/.hermes/f.txt")]

        _, run_kwargs, _ = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # run_kwargs is populated by the real_run_coroutine shim in
        # _wire_async_exec with whatever kwargs the upload passed through.
        assert run_kwargs.get("timeout") == 120

    def test_nonzero_exit_raises(self, monkeypatch, tmp_path):
        """Non-zero exit code from remote exec should raise RuntimeError."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("data")
        files = [(str(src), "/root/.hermes/f.txt")]

        stdin_mock = _make_mock_stdin()

        # Hand-rolled variant of _wire_async_exec that reports failure.
        async def mock_exec_fn(*args, **kwargs):
            proc = MagicMock()
            proc.wait = MagicMock()
            proc.wait.aio = AsyncMock(return_value=1)  # non-zero exit
            proc.stdin = stdin_mock
            proc.stderr = MagicMock()
            proc.stderr.read = MagicMock()
            proc.stderr.read.aio = AsyncMock(return_value="tar: error")
            return proc

        env._sandbox.exec = MagicMock()
        env._sandbox.exec.aio = mock_exec_fn

        def real_run_coroutine(coro, **kwargs):
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro)
            finally:
                loop.close()

        env._worker.run_coroutine = real_run_coroutine

        with pytest.raises(RuntimeError, match="Modal bulk upload failed"):
            env._modal_bulk_upload(files)

    def test_payload_not_in_command_string(self, monkeypatch, tmp_path):
        """The base64 payload must NOT appear in the bash -c argument.

        This is the core ARG_MAX fix: the payload goes through stdin,
        not embedded in the command string.
        """
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        src = tmp_path / "f.txt"
        src.write_text("some data to upload")
        files = [(str(src), "/root/.hermes/f.txt")]

        exec_calls, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # The command should NOT contain an echo with the payload
        cmd = exec_calls[0][2]
        assert "echo" not in cmd
        # The payload should go through stdin
        assert len(stdin_mock._written_chunks) > 0

    def test_stdin_chunked_for_large_payloads(self, monkeypatch, tmp_path):
        """Payloads larger than _STDIN_CHUNK_SIZE should be split into multiple writes."""
        env = _make_mock_modal_env(monkeypatch, tmp_path)

        # Use random bytes so gzip cannot compress them -- ensures the
        # base64 payload exceeds one 1 MB chunk.
        import os as _os
        src = tmp_path / "large.bin"
        src.write_bytes(_os.urandom(1024 * 1024 + 512 * 1024))
        files = [(str(src), "/root/.hermes/large.bin")]

        exec_calls, _, stdin_mock = _wire_async_exec(env)
        env._modal_bulk_upload(files)

        # Should have multiple stdin write chunks
        assert len(stdin_mock._written_chunks) >= 2

        # Reassembled payload should still decode to valid tar
        payload = "".join(stdin_mock._written_chunks)
        tar_data = base64.b64decode(payload)
        buf = io.BytesIO(tar_data)
        with tarfile.open(fileobj=buf, mode="r:gz") as tar:
            names = tar.getnames()
            assert "root/.hermes/large.bin" in names
|
||||
517
tests/tools/test_ssh_bulk_upload.py
Normal file
517
tests/tools/test_ssh_bulk_upload.py
Normal file
|
|
@ -0,0 +1,517 @@
|
|||
"""Tests for SSH bulk upload via tar pipe."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import ssh as ssh_env
|
||||
from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
|
||||
|
||||
def _mock_proc(*, returncode=0, poll_return=0, communicate_return=(b"", b""),
|
||||
stderr_read=b""):
|
||||
"""Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes."""
|
||||
m = MagicMock()
|
||||
m.stdout = MagicMock()
|
||||
m.returncode = returncode
|
||||
m.poll.return_value = poll_return
|
||||
m.communicate.return_value = communicate_return
|
||||
m.stderr = MagicMock()
|
||||
m.stderr.read.return_value = stderr_read
|
||||
return m
|
||||
|
||||
|
||||
@pytest.fixture
def mock_env(monkeypatch):
    """Create an SSHEnvironment with mocked connection/sync.

    Every setup step that would touch the network or the local ssh binary is
    stubbed out, so the constructor completes without a real SSH connection.
    """
    # Pretend the ssh binary exists regardless of the host machine.
    monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser")
    monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
    monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
    # Replace FileSyncManager with a do-nothing stand-in so construction
    # does not trigger an initial sync.
    monkeypatch.setattr(
        ssh_env, "FileSyncManager",
        lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
    )
    return SSHEnvironment(host="example.com", user="testuser")
|
||||
|
||||
|
||||
class TestSSHBulkUpload:
    """Unit tests for _ssh_bulk_upload — tar pipe mechanics.

    subprocess.run (remote mkdir) and subprocess.Popen (local tar | ssh pipe)
    are patched throughout so no real processes are spawned.
    """

    def test_empty_files_is_noop(self, mock_env):
        """Empty file list should not spawn any subprocesses."""
        with patch.object(subprocess, "run") as mock_run, \
                patch.object(subprocess, "Popen") as mock_popen:
            mock_env._ssh_bulk_upload([])
            mock_run.assert_not_called()
            mock_popen.assert_not_called()

    def test_mkdir_batched_into_single_call(self, mock_env, tmp_path):
        """All parent directories should be created in one SSH call."""
        # Create test files
        f1 = tmp_path / "a.txt"
        f1.write_text("aaa")
        f2 = tmp_path / "b.txt"
        f2.write_text("bbb")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/a.txt"),
            (str(f2), "/home/testuser/.hermes/credentials/b.txt"),
        ]

        # Mock subprocess.run for mkdir and Popen for tar pipe
        mock_run = MagicMock(return_value=subprocess.CompletedProcess([], 0))

        def make_proc(cmd, **kwargs):
            m = MagicMock()
            m.stdout = MagicMock()
            m.returncode = 0
            m.poll.return_value = 0
            m.communicate.return_value = (b"", b"")
            m.stderr = MagicMock()
            m.stderr.read.return_value = b""
            return m

        with patch.object(subprocess, "run", mock_run), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            mock_env._ssh_bulk_upload(files)

        # Exactly one subprocess.run call for mkdir
        assert mock_run.call_count == 1
        mkdir_cmd = mock_run.call_args[0][0]
        # Should contain mkdir -p with both parent dirs
        mkdir_str = " ".join(mkdir_cmd)
        assert "mkdir -p" in mkdir_str
        assert "/home/testuser/.hermes/skills" in mkdir_str
        assert "/home/testuser/.hermes/credentials" in mkdir_str

    def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path):
        """Symlinks in staging dir should mirror the remote path structure."""
        f1 = tmp_path / "local_a.txt"
        f1.write_text("content a")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/my_skill.md"),
        ]

        staging_paths = []

        def capture_tar_cmd(cmd, **kwargs):
            # The staging dir only exists while _ssh_bulk_upload runs, so
            # the symlink must be inspected here, inside the Popen hook.
            if cmd[0] == "tar":
                # Capture the staging dir from -C argument
                c_idx = cmd.index("-C")
                staging_dir = cmd[c_idx + 1]
                # Check the symlink exists
                expected = os.path.join(
                    staging_dir, "home/testuser/.hermes/skills/my_skill.md"
                )
                staging_paths.append(expected)
                assert os.path.islink(expected), f"Expected symlink at {expected}"
                assert os.readlink(expected) == os.path.abspath(str(f1))

            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_tar_cmd):
            mock_env._ssh_bulk_upload(files)

        assert len(staging_paths) == 1, "tar command should have been called"

    def test_tar_pipe_commands(self, mock_env, tmp_path):
        """Verify tar and SSH commands are wired correctly."""
        f1 = tmp_path / "x.txt"
        f1.write_text("x")

        files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")]

        popen_cmds = []

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            mock_env._ssh_bulk_upload(files)

        assert len(popen_cmds) == 2, "Should spawn tar + ssh processes"

        tar_cmd = popen_cmds[0]
        ssh_cmd = popen_cmds[1]

        # tar: create, dereference symlinks, to stdout
        assert tar_cmd[0] == "tar"
        assert "-chf" in tar_cmd
        assert "-" in tar_cmd  # stdout
        assert "-C" in tar_cmd

        # ssh: extract from stdin at /
        ssh_str = " ".join(ssh_cmd)
        assert "ssh" in ssh_str
        assert "tar xf - -C /" in ssh_str
        assert "testuser@example.com" in ssh_str

    def test_mkdir_failure_raises(self, mock_env, tmp_path):
        """mkdir failure should raise RuntimeError before tar pipe."""
        f1 = tmp_path / "y.txt"
        f1.write_text("y")
        files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")]

        failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied")
        with patch.object(subprocess, "run", return_value=failed_run):
            with pytest.raises(RuntimeError, match="remote mkdir failed"):
                mock_env._ssh_bulk_upload(files)

    def test_tar_create_failure_raises(self, mock_env, tmp_path):
        """tar create failure should raise RuntimeError."""
        f1 = tmp_path / "z.txt"
        f1.write_text("z")
        files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")]

        mock_tar = MagicMock()
        mock_tar.stdout = MagicMock()
        mock_tar.returncode = 1
        mock_tar.poll.return_value = 1
        mock_tar.communicate.return_value = (b"tar: error", b"")
        mock_tar.stderr = MagicMock()
        mock_tar.stderr.read.return_value = b"tar: error"

        mock_ssh = MagicMock()
        mock_ssh.communicate.return_value = (b"", b"")
        mock_ssh.returncode = 0

        def popen_side_effect(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=popen_side_effect):
            with pytest.raises(RuntimeError, match="tar create failed"):
                mock_env._ssh_bulk_upload(files)

    def test_ssh_extract_failure_raises(self, mock_env, tmp_path):
        """SSH tar extract failure should raise RuntimeError."""
        f1 = tmp_path / "w.txt"
        f1.write_text("w")
        files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")]

        mock_tar = MagicMock()
        mock_tar.stdout = MagicMock()
        mock_tar.returncode = 0
        mock_tar.poll.return_value = 0
        mock_tar.communicate.return_value = (b"", b"")
        mock_tar.stderr = MagicMock()
        mock_tar.stderr.read.return_value = b""

        mock_ssh = MagicMock()
        mock_ssh.communicate.return_value = (b"", b"Permission denied")
        mock_ssh.returncode = 1

        def popen_side_effect(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=popen_side_effect):
            with pytest.raises(RuntimeError, match="tar extract over SSH failed"):
                mock_env._ssh_bulk_upload(files)

    def test_ssh_command_uses_control_socket(self, mock_env, tmp_path):
        """SSH command for tar extract should reuse ControlMaster socket."""
        f1 = tmp_path / "c.txt"
        f1.write_text("c")
        files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")]

        popen_cmds = []

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            mock_env._ssh_bulk_upload(files)

        # The SSH command (second Popen call) should include ControlPath
        ssh_cmd = popen_cmds[1]
        assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd)

    def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path):
        """Bulk upload SSH command should include custom port and key."""
        # Cannot use the mock_env fixture here: this test needs a custom
        # port/key, so it repeats the fixture's monkeypatching inline.
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)
        monkeypatch.setattr(
            ssh_env, "FileSyncManager",
            lambda **kw: type("M", (), {"sync": lambda self, **k: None})(),
        )
        env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key")

        f1 = tmp_path / "d.txt"
        f1.write_text("d")
        files = [(str(f1), "/home/u/.hermes/skills/d.txt")]

        run_cmds = []
        popen_cmds = []

        def capture_run(cmd, **kwargs):
            run_cmds.append(cmd)
            return subprocess.CompletedProcess([], 0)

        def capture_popen(cmd, **kwargs):
            popen_cmds.append(cmd)
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run", side_effect=capture_run), \
                patch.object(subprocess, "Popen", side_effect=capture_popen):
            env._ssh_bulk_upload(files)

        # Check mkdir SSH call includes port and key
        assert len(run_cmds) == 1
        mkdir_cmd = run_cmds[0]
        assert "-p" in mkdir_cmd and "2222" in mkdir_cmd
        assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd

        # Check tar extract SSH call includes port and key
        ssh_cmd = popen_cmds[1]
        assert "-p" in ssh_cmd and "2222" in ssh_cmd
        assert "-i" in ssh_cmd and "/my/key" in ssh_cmd

    def test_parent_dirs_deduplicated(self, mock_env, tmp_path):
        """Multiple files in the same dir should produce one mkdir entry."""
        f1 = tmp_path / "a.txt"
        f1.write_text("a")
        f2 = tmp_path / "b.txt"
        f2.write_text("b")
        f3 = tmp_path / "c.txt"
        f3.write_text("c")

        files = [
            (str(f1), "/home/testuser/.hermes/skills/a.txt"),
            (str(f2), "/home/testuser/.hermes/skills/b.txt"),
            (str(f3), "/home/testuser/.hermes/credentials/c.txt"),
        ]

        run_cmds = []

        def capture_run(cmd, **kwargs):
            run_cmds.append(cmd)
            return subprocess.CompletedProcess([], 0)

        def make_mock_proc(cmd, **kwargs):
            mock = MagicMock()
            mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run", side_effect=capture_run), \
                patch.object(subprocess, "Popen", side_effect=make_mock_proc):
            mock_env._ssh_bulk_upload(files)

        # Only one mkdir call
        assert len(run_cmds) == 1
        mkdir_str = " ".join(run_cmds[0])
        # skills dir should appear exactly once despite two files
        assert mkdir_str.count("/home/testuser/.hermes/skills") == 1
        assert "/home/testuser/.hermes/credentials" in mkdir_str

    def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path):
        """tar_proc.stdout must be closed so SIGPIPE propagates correctly."""
        f1 = tmp_path / "s.txt"
        f1.write_text("s")
        files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")]

        mock_tar_stdout = MagicMock()

        def make_proc(cmd, **kwargs):
            mock = MagicMock()
            if cmd[0] == "tar":
                # Hand the tar process a distinguishable stdout so we can
                # assert it gets closed by the parent.
                mock.stdout = mock_tar_stdout
            else:
                mock.stdout = MagicMock()
            mock.returncode = 0
            mock.poll.return_value = 0
            mock.communicate.return_value = (b"", b"")
            mock.stderr = MagicMock()
            mock.stderr.read.return_value = b""
            return mock

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            mock_env._ssh_bulk_upload(files)

        mock_tar_stdout.close.assert_called_once()

    def test_timeout_kills_both_processes(self, mock_env, tmp_path):
        """TimeoutExpired during communicate should kill both processes."""
        f1 = tmp_path / "t.txt"
        f1.write_text("t")
        files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")]

        mock_tar = MagicMock()
        mock_tar.stdout = MagicMock()
        mock_tar.returncode = None
        mock_tar.poll.return_value = None

        mock_ssh = MagicMock()
        mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120)
        mock_ssh.returncode = None

        def make_proc(cmd, **kwargs):
            if cmd[0] == "tar":
                return mock_tar
            return mock_ssh

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=make_proc):
            with pytest.raises(RuntimeError, match="SSH bulk upload timed out"):
                mock_env._ssh_bulk_upload(files)

        mock_tar.kill.assert_called_once()
        mock_ssh.kill.assert_called_once()
|
||||
|
||||
|
||||
class TestSSHBulkUploadWiring:
    """Verify bulk_upload_fn is wired into FileSyncManager."""

    def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch):
        """SSHEnvironment should pass _ssh_bulk_upload to FileSyncManager."""
        # Stub every connection-touching step so the constructor runs offline.
        monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root")
        monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None)
        monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None)

        captured_kwargs = {}

        class FakeSyncManager:
            # Records the constructor kwargs the environment passes in.
            def __init__(self, **kwargs):
                captured_kwargs.update(kwargs)

            def sync(self, **kw):
                pass

        monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager)

        env = SSHEnvironment(host="h", user="u")

        assert "bulk_upload_fn" in captured_kwargs
        assert captured_kwargs["bulk_upload_fn"] is not None
        # Should be the bound method
        assert callable(captured_kwargs["bulk_upload_fn"])
|
||||
|
||||
|
||||
class TestSharedHelpers:
    """Direct unit tests for the file_sync.py helper functions."""

    def test_quoted_mkdir_command_basic(self):
        assert quoted_mkdir_command(["/a", "/b/c"]) == "mkdir -p /a /b/c"

    def test_quoted_mkdir_command_quotes_special_chars(self):
        command = quoted_mkdir_command(["/path/with spaces", "/path/'quotes'"])
        assert "mkdir -p" in command
        # shlex.quote wraps in single quotes
        assert "'/path/with spaces'" in command

    def test_quoted_mkdir_command_empty(self):
        assert quoted_mkdir_command([]) == "mkdir -p "

    def test_unique_parent_dirs_deduplicates(self):
        pairs = [
            ("/local/a.txt", "/remote/dir/a.txt"),
            ("/local/b.txt", "/remote/dir/b.txt"),
            ("/local/c.txt", "/remote/other/c.txt"),
        ]
        assert unique_parent_dirs(pairs) == ["/remote/dir", "/remote/other"]

    def test_unique_parent_dirs_sorted(self):
        pairs = [
            ("/local/z.txt", "/z/file.txt"),
            ("/local/a.txt", "/a/file.txt"),
        ]
        assert unique_parent_dirs(pairs) == ["/a", "/z"]

    def test_unique_parent_dirs_empty(self):
        assert unique_parent_dirs([]) == []
|
||||
|
||||
|
||||
class TestSSHBulkUploadEdgeCases:
    """Edge cases for _ssh_bulk_upload."""

    def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path):
        """If SSH Popen raises, tar process must be killed and cleaned up."""
        f1 = tmp_path / "e.txt"
        f1.write_text("e")
        files = [(str(f1), "/home/testuser/.hermes/skills/e.txt")]

        mock_tar = _mock_proc()

        # Count Popen invocations: the first is tar, the second is ssh.
        call_count = 0

        def failing_ssh_popen(cmd, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return mock_tar  # tar Popen succeeds
            raise OSError("SSH binary not found")

        with patch.object(subprocess, "run",
                          return_value=subprocess.CompletedProcess([], 0)), \
                patch.object(subprocess, "Popen", side_effect=failing_ssh_popen):
            with pytest.raises(OSError, match="SSH binary not found"):
                mock_env._ssh_bulk_upload(files)

        # The already-started tar process must not be leaked.
        mock_tar.kill.assert_called_once()
        mock_tar.wait.assert_called_once()
|
||||
|
|
@ -24,6 +24,18 @@ class TestWriteAndRead:
|
|||
items[0]["content"] = "MUTATED"
|
||||
assert store.read()[0]["content"] == "Task"
|
||||
|
||||
def test_write_deduplicates_duplicate_ids(self):
    """When the same id appears more than once, only the last entry survives."""
    store = TodoStore()
    written = store.write([
        {"id": "1", "content": "First version", "status": "pending"},
        {"id": "2", "content": "Other task", "status": "pending"},
        {"id": "1", "content": "Latest version", "status": "in_progress"},
    ])
    # The duplicate "1" is dropped from its original position and the
    # latest version is kept.
    expected = [
        {"id": "2", "content": "Other task", "status": "pending"},
        {"id": "1", "content": "Latest version", "status": "in_progress"},
    ]
    assert written == expected
|
||||
|
||||
|
||||
class TestHasItems:
|
||||
def test_empty_store(self):
|
||||
|
|
|
|||
|
|
@ -124,6 +124,34 @@ class TestWriteToSandbox:
|
|||
cmd = env.execute.call_args[0][0]
|
||||
assert "mkdir -p /data/data/com.termux/files/usr/tmp/hermes-results" in cmd
|
||||
|
||||
def test_path_with_spaces_is_quoted(self):
    """Remote paths containing spaces must appear shell-quoted in the command."""
    env = MagicMock()
    env.execute.return_value = {"output": "", "returncode": 0}
    target = "/tmp/hermes results/abc file.txt"
    _write_to_sandbox("content", target, env)
    issued_cmd = env.execute.call_args[0][0]
    # Both the parent directory and the full path should be single-quoted.
    assert "'/tmp/hermes results'" in issued_cmd
    assert "'/tmp/hermes results/abc file.txt'" in issued_cmd
|
||||
|
||||
def test_shell_metacharacters_neutralized(self):
    """Paths with shell metacharacters must be quoted to prevent injection."""
    env = MagicMock()
    env.execute.return_value = {"output": "", "returncode": 0}
    hostile = "/tmp/hermes-results/$(whoami).txt"
    _write_to_sandbox("content", hostile, env)
    issued_cmd = env.execute.call_args[0][0]
    # The $() must not appear unquoted — shlex.quote wraps it
    assert "'/tmp/hermes-results/$(whoami).txt'" in issued_cmd
|
||||
|
||||
def test_semicolon_injection_neutralized(self):
    """Semicolons in a path must stay quoted instead of separating commands."""
    env = MagicMock()
    env.execute.return_value = {"output": "", "returncode": 0}
    hostile = "/tmp/x; rm -rf /; echo .txt"
    _write_to_sandbox("content", hostile, env)
    issued_cmd = env.execute.call_args[0][0]
    # The semicolons must be inside quotes, not acting as command separators
    assert "'/tmp/x; rm -rf /; echo .txt'" in issued_cmd
|
||||
|
||||
|
||||
class TestResolveStorageDir:
|
||||
def test_defaults_to_storage_dir_without_env(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue