mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
Merge remote-tracking branch 'origin/main' into sid/types-and-lints
# Conflicts: # gateway/run.py # tools/delegate_tool.py
This commit is contained in:
commit
847ffca715
171 changed files with 15125 additions and 1675 deletions
|
|
@ -476,6 +476,133 @@ class TestGetTextAuxiliaryClient:
|
|||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
class TestNousAuxiliaryRefresh:
|
||||
def test_try_nous_prefers_runtime_credentials(self):
|
||||
fresh_base = "https://inference-api.nousresearch.com/v1"
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "stale-token"}),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
|
||||
patch("hermes_cli.models.get_nous_recommended_aux_model", return_value=None),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = _try_nous()
|
||||
|
||||
assert client is not None
|
||||
# No Portal recommendation → falls back to the hardcoded default.
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "fresh-agent-key"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == fresh_base
|
||||
|
||||
def test_try_nous_uses_portal_recommendation_for_text(self):
|
||||
"""When the Portal recommends a compaction model, _try_nous honors it."""
|
||||
fresh_base = "https://inference-api.nousresearch.com/v1"
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
|
||||
patch("hermes_cli.models.get_nous_recommended_aux_model", return_value="minimax/minimax-m2.7") as mock_rec,
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = _try_nous(vision=False)
|
||||
|
||||
assert client is not None
|
||||
assert model == "minimax/minimax-m2.7"
|
||||
assert mock_rec.call_args.kwargs["vision"] is False
|
||||
|
||||
def test_try_nous_uses_portal_recommendation_for_vision(self):
|
||||
"""Vision tasks should ask for the vision-specific recommendation."""
|
||||
fresh_base = "https://inference-api.nousresearch.com/v1"
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
|
||||
patch("hermes_cli.models.get_nous_recommended_aux_model", return_value="google/gemini-3-flash-preview") as mock_rec,
|
||||
patch("agent.auxiliary_client.OpenAI"),
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
client, model = _try_nous(vision=True)
|
||||
|
||||
assert client is not None
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert mock_rec.call_args.kwargs["vision"] is True
|
||||
|
||||
def test_try_nous_falls_back_when_recommendation_lookup_raises(self):
|
||||
"""If the Portal lookup throws, we must still return a usable model."""
|
||||
fresh_base = "https://inference-api.nousresearch.com/v1"
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value={"access_token": "***"}),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", fresh_base)),
|
||||
patch("hermes_cli.models.get_nous_recommended_aux_model", side_effect=RuntimeError("portal down")),
|
||||
patch("agent.auxiliary_client.OpenAI"),
|
||||
):
|
||||
from agent.auxiliary_client import _try_nous
|
||||
client, model = _try_nous()
|
||||
|
||||
assert client is not None
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_call_llm_retries_nous_after_401(self):
|
||||
class _Auth401(Exception):
|
||||
status_code = 401
|
||||
|
||||
stale_client = MagicMock()
|
||||
stale_client.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
stale_client.chat.completions.create.side_effect = _Auth401("stale nous key")
|
||||
|
||||
fresh_client = MagicMock()
|
||||
fresh_client.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
fresh_client.chat.completions.create.return_value = {"ok": True}
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("nous", "nous-model", None, None, None)),
|
||||
patch("agent.auxiliary_client._get_cached_client", return_value=(stale_client, "nous-model")),
|
||||
patch("agent.auxiliary_client.OpenAI", return_value=fresh_client),
|
||||
patch("agent.auxiliary_client._validate_llm_response", side_effect=lambda resp, _task: resp),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", "https://inference-api.nousresearch.com/v1")),
|
||||
):
|
||||
result = call_llm(
|
||||
task="compression",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert stale_client.chat.completions.create.call_count == 1
|
||||
assert fresh_client.chat.completions.create.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_call_llm_retries_nous_after_401(self):
|
||||
class _Auth401(Exception):
|
||||
status_code = 401
|
||||
|
||||
stale_client = MagicMock()
|
||||
stale_client.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
stale_client.chat.completions.create = AsyncMock(side_effect=_Auth401("stale nous key"))
|
||||
|
||||
fresh_async_client = MagicMock()
|
||||
fresh_async_client.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
fresh_async_client.chat.completions.create = AsyncMock(return_value={"ok": True})
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("nous", "nous-model", None, None, None)),
|
||||
patch("agent.auxiliary_client._get_cached_client", return_value=(stale_client, "nous-model")),
|
||||
patch("agent.auxiliary_client._to_async_client", return_value=(fresh_async_client, "nous-model")),
|
||||
patch("agent.auxiliary_client._validate_llm_response", side_effect=lambda resp, _task: resp),
|
||||
patch("agent.auxiliary_client._resolve_nous_runtime_api", return_value=("fresh-agent-key", "https://inference-api.nousresearch.com/v1")),
|
||||
):
|
||||
result = await async_call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
)
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert stale_client.chat.completions.create.await_count == 1
|
||||
assert fresh_async_client.chat.completions.create.await_count == 1
|
||||
|
||||
# ── Payment / credit exhaustion fallback ─────────────────────────────────
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ class TestResolveAutoMainFirst:
|
|||
|
||||
|
||||
class TestResolveVisionMainFirst:
|
||||
"""Vision auto-detection prefers main provider + main model first."""
|
||||
"""Vision auto-detection prefers the main provider first."""
|
||||
|
||||
def test_openrouter_main_vision_uses_main_model(self, monkeypatch):
|
||||
"""OpenRouter main with vision-capable model → aux vision uses main model."""
|
||||
|
|
@ -200,28 +200,49 @@ class TestResolveVisionMainFirst:
|
|||
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
def test_nous_main_vision_uses_main_model(self):
|
||||
"""Nous Portal main → aux vision uses main model, not free-tier MiMo-V2-Omni."""
|
||||
def test_nous_main_vision_uses_paid_nous_vision_backend(self):
|
||||
"""Paid Nous main → aux vision uses the dedicated Nous vision backend."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="openai/gpt-5",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve, patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend",
|
||||
return_value=(MagicMock(), "google/gemini-3-flash-preview"),
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "openai/gpt-5")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "nous"
|
||||
assert model == "openai/gpt-5"
|
||||
assert client is not None
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_nous_main_vision_uses_free_tier_nous_vision_backend(self):
|
||||
"""Free-tier Nous main → aux vision uses MiMo omni, not the text main model."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="xiaomi/mimo-v2-pro",
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend",
|
||||
return_value=(MagicMock(), "xiaomi/mimo-v2-omni"),
|
||||
):
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "nous"
|
||||
assert client is not None
|
||||
assert model == "xiaomi/mimo-v2-omni"
|
||||
|
||||
def test_exotic_provider_with_vision_override_preserved(self):
|
||||
"""xiaomi → mimo-v2-omni override still wins over main_model."""
|
||||
|
|
|
|||
111
tests/agent/test_image_gen_registry.py
Normal file
111
tests/agent/test_image_gen_registry.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
"""Tests for agent/image_gen_registry.py — provider registration & active lookup."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import image_gen_registry
|
||||
from agent.image_gen_provider import ImageGenProvider
|
||||
|
||||
|
||||
class _FakeProvider(ImageGenProvider):
|
||||
def __init__(self, name: str, available: bool = True):
|
||||
self._name = name
|
||||
self._available = available
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._available
|
||||
|
||||
def generate(self, prompt, aspect_ratio="landscape", **kw):
|
||||
return {"success": True, "image": f"{self._name}://{prompt}"}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
image_gen_registry._reset_for_tests()
|
||||
yield
|
||||
image_gen_registry._reset_for_tests()
|
||||
|
||||
|
||||
class TestRegisterProvider:
|
||||
def test_register_and_lookup(self):
|
||||
provider = _FakeProvider("fake")
|
||||
image_gen_registry.register_provider(provider)
|
||||
assert image_gen_registry.get_provider("fake") is provider
|
||||
|
||||
def test_rejects_non_provider(self):
|
||||
with pytest.raises(TypeError):
|
||||
image_gen_registry.register_provider("not a provider") # type: ignore[arg-type]
|
||||
|
||||
def test_rejects_empty_name(self):
|
||||
class Empty(ImageGenProvider):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return ""
|
||||
|
||||
def generate(self, prompt, aspect_ratio="landscape", **kw):
|
||||
return {}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
image_gen_registry.register_provider(Empty())
|
||||
|
||||
def test_reregister_overwrites(self):
|
||||
a = _FakeProvider("same")
|
||||
b = _FakeProvider("same")
|
||||
image_gen_registry.register_provider(a)
|
||||
image_gen_registry.register_provider(b)
|
||||
assert image_gen_registry.get_provider("same") is b
|
||||
|
||||
def test_list_is_sorted(self):
|
||||
image_gen_registry.register_provider(_FakeProvider("zeta"))
|
||||
image_gen_registry.register_provider(_FakeProvider("alpha"))
|
||||
names = [p.name for p in image_gen_registry.list_providers()]
|
||||
assert names == ["alpha", "zeta"]
|
||||
|
||||
|
||||
class TestGetActiveProvider:
|
||||
def test_single_provider_autoresolves(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
image_gen_registry.register_provider(_FakeProvider("solo"))
|
||||
active = image_gen_registry.get_active_provider()
|
||||
assert active is not None and active.name == "solo"
|
||||
|
||||
def test_fal_preferred_on_multi_without_config(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
image_gen_registry.register_provider(_FakeProvider("fal"))
|
||||
image_gen_registry.register_provider(_FakeProvider("openai"))
|
||||
active = image_gen_registry.get_active_provider()
|
||||
assert active is not None and active.name == "fal"
|
||||
|
||||
def test_explicit_config_wins(self, tmp_path, monkeypatch):
|
||||
import yaml
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.safe_dump({"image_gen": {"provider": "openai"}})
|
||||
)
|
||||
image_gen_registry.register_provider(_FakeProvider("fal"))
|
||||
image_gen_registry.register_provider(_FakeProvider("openai"))
|
||||
active = image_gen_registry.get_active_provider()
|
||||
assert active is not None and active.name == "openai"
|
||||
|
||||
def test_missing_configured_provider_falls_back(self, tmp_path, monkeypatch):
|
||||
import yaml
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.safe_dump({"image_gen": {"provider": "replicate"}})
|
||||
)
|
||||
# Only FAL is registered — configured provider doesn't exist
|
||||
image_gen_registry.register_provider(_FakeProvider("fal"))
|
||||
active = image_gen_registry.get_active_provider()
|
||||
# Falls back to FAL preference (legacy default) rather than None
|
||||
assert active is not None and active.name == "fal"
|
||||
|
||||
def test_none_when_empty(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
assert image_gen_registry.get_active_provider() is None
|
||||
115
tests/agent/test_kimi_coding_anthropic_thinking.py
Normal file
115
tests/agent/test_kimi_coding_anthropic_thinking.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""Regression guard: don't send Anthropic ``thinking`` to Kimi's /coding endpoint.
|
||||
|
||||
Kimi's ``api.kimi.com/coding`` endpoint speaks the Anthropic Messages protocol
|
||||
but has its own thinking semantics. When ``thinking.enabled`` is present in
|
||||
the request, Kimi validates the message history and requires every prior
|
||||
assistant tool-call message to carry OpenAI-style ``reasoning_content``.
|
||||
|
||||
The Anthropic path never populates that field, and
|
||||
``convert_messages_to_anthropic`` strips Anthropic thinking blocks on
|
||||
third-party endpoints — so after one turn with tool calls the next request
|
||||
fails with HTTP 400::
|
||||
|
||||
thinking is enabled but reasoning_content is missing in assistant
|
||||
tool call message at index N
|
||||
|
||||
Kimi on the chat_completions route handles ``thinking`` via ``extra_body`` in
|
||||
``ChatCompletionsTransport`` (#13503). On the Anthropic route the right
|
||||
thing to do is drop the parameter entirely and let Kimi drive reasoning
|
||||
server-side.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestKimiCodingSkipsAnthropicThinking:
|
||||
"""build_anthropic_kwargs must not inject ``thinking`` for Kimi /coding."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"base_url",
|
||||
[
|
||||
"https://api.kimi.com/coding",
|
||||
"https://api.kimi.com/coding/v1",
|
||||
"https://api.kimi.com/coding/anthropic",
|
||||
"https://api.kimi.com/coding/",
|
||||
],
|
||||
)
|
||||
def test_kimi_coding_endpoint_omits_thinking(self, base_url: str) -> None:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="kimi-k2.5",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
base_url=base_url,
|
||||
)
|
||||
assert "thinking" not in kwargs, (
|
||||
"Anthropic thinking must not be sent to Kimi /coding — "
|
||||
"endpoint requires reasoning_content on history we don't preserve."
|
||||
)
|
||||
assert "output_config" not in kwargs
|
||||
|
||||
def test_kimi_coding_with_explicit_disabled_also_omits(self) -> None:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="kimi-k2.5",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config={"enabled": False},
|
||||
base_url="https://api.kimi.com/coding",
|
||||
)
|
||||
assert "thinking" not in kwargs
|
||||
|
||||
def test_non_kimi_third_party_still_gets_thinking(self) -> None:
|
||||
"""MiniMax and other third-party Anthropic endpoints must retain thinking."""
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="MiniMax-M2.7",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
base_url="https://api.minimax.io/anthropic",
|
||||
)
|
||||
assert "thinking" in kwargs
|
||||
assert kwargs["thinking"]["type"] == "enabled"
|
||||
|
||||
def test_native_anthropic_still_gets_thinking(self) -> None:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-sonnet-4-20250514",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
base_url=None,
|
||||
)
|
||||
assert "thinking" in kwargs
|
||||
|
||||
def test_kimi_root_endpoint_unaffected(self) -> None:
|
||||
"""Only the /coding route is special-cased — plain api.kimi.com is not.
|
||||
|
||||
``api.kimi.com`` without ``/coding`` uses the chat_completions transport
|
||||
(see runtime_provider._detect_api_mode_for_url); build_anthropic_kwargs
|
||||
should never see it, but if it somehow does we should not suppress
|
||||
thinking there — that path has different semantics.
|
||||
"""
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="kimi-k2.5",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
base_url="https://api.kimi.com/v1",
|
||||
)
|
||||
assert "thinking" in kwargs
|
||||
|
|
@ -789,6 +789,24 @@ class TestPromptBuilderConstants:
|
|||
assert "cron" in PLATFORM_HINTS
|
||||
assert "cli" in PLATFORM_HINTS
|
||||
|
||||
def test_cli_hint_does_not_suggest_media_tags(self):
|
||||
# Regression: MEDIA:/path tags are intercepted only by messaging
|
||||
# gateway platforms. On the CLI they render as literal text and
|
||||
# confuse users. The CLI hint must steer the agent away from them.
|
||||
cli_hint = PLATFORM_HINTS["cli"]
|
||||
assert "MEDIA:" in cli_hint, (
|
||||
"CLI hint should mention MEDIA: in order to tell the agent "
|
||||
"NOT to use it (negative guidance)."
|
||||
)
|
||||
# Must contain explicit "don't" language near the MEDIA reference.
|
||||
assert any(
|
||||
marker in cli_hint.lower()
|
||||
for marker in ("do not emit media", "not intercepted", "do not", "don't")
|
||||
), "CLI hint should explicitly discourage MEDIA: tags."
|
||||
# Messaging hints should still advertise MEDIA: positively (sanity
|
||||
# check that this test is calibrated correctly).
|
||||
assert "include MEDIA:" in PLATFORM_HINTS["telegram"]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Environment hints
|
||||
|
|
|
|||
|
|
@ -193,7 +193,7 @@ class TestBuildChildProgressCallback:
|
|||
|
||||
# task_index=0 in a batch of 3 → prefix "[1]"
|
||||
cb0 = _build_child_progress_callback(0, "test goal", parent, task_count=3)
|
||||
cb0("web_search", "test")
|
||||
cb0("tool.started", "web_search", "test", {})
|
||||
output = buf.getvalue()
|
||||
assert "[1]" in output
|
||||
|
||||
|
|
@ -201,7 +201,7 @@ class TestBuildChildProgressCallback:
|
|||
buf.truncate(0)
|
||||
buf.seek(0)
|
||||
cb2 = _build_child_progress_callback(2, "test goal", parent, task_count=3)
|
||||
cb2("web_search", "test")
|
||||
cb2("tool.started", "web_search", "test", {})
|
||||
output = buf.getvalue()
|
||||
assert "[3]" in output
|
||||
|
||||
|
|
|
|||
164
tests/agent/transports/test_bedrock_transport.py
Normal file
164
tests/agent/transports/test_bedrock_transport.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"""Tests for the BedrockTransport."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent.transports import get_transport
|
||||
from agent.transports.types import NormalizedResponse, ToolCall
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transport():
|
||||
import agent.transports.bedrock # noqa: F401
|
||||
return get_transport("bedrock_converse")
|
||||
|
||||
|
||||
class TestBedrockBasic:
|
||||
|
||||
def test_api_mode(self, transport):
|
||||
assert transport.api_mode == "bedrock_converse"
|
||||
|
||||
def test_registered(self, transport):
|
||||
assert transport is not None
|
||||
|
||||
|
||||
class TestBedrockBuildKwargs:
|
||||
|
||||
def test_basic_kwargs(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hello"}]
|
||||
kw = transport.build_kwargs(model="anthropic.claude-3-5-sonnet-20241022-v2:0", messages=msgs)
|
||||
assert kw["modelId"] == "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
assert kw["__bedrock_converse__"] is True
|
||||
assert kw["__bedrock_region__"] == "us-east-1"
|
||||
assert "messages" in kw
|
||||
|
||||
def test_custom_region(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
messages=msgs,
|
||||
region="eu-west-1",
|
||||
)
|
||||
assert kw["__bedrock_region__"] == "eu-west-1"
|
||||
|
||||
def test_max_tokens(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
messages=msgs,
|
||||
max_tokens=8192,
|
||||
)
|
||||
assert kw["inferenceConfig"]["maxTokens"] == 8192
|
||||
|
||||
|
||||
class TestBedrockConvertTools:
|
||||
|
||||
def test_convert_tools(self, transport):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "terminal",
|
||||
"description": "Run commands",
|
||||
"parameters": {"type": "object", "properties": {"command": {"type": "string"}}},
|
||||
}
|
||||
}]
|
||||
result = transport.convert_tools(tools)
|
||||
assert len(result) == 1
|
||||
assert result[0]["toolSpec"]["name"] == "terminal"
|
||||
|
||||
|
||||
class TestBedrockValidate:
|
||||
|
||||
def test_none(self, transport):
|
||||
assert transport.validate_response(None) is False
|
||||
|
||||
def test_raw_dict_valid(self, transport):
|
||||
assert transport.validate_response({"output": {"message": {}}}) is True
|
||||
|
||||
def test_raw_dict_invalid(self, transport):
|
||||
assert transport.validate_response({"error": "fail"}) is False
|
||||
|
||||
def test_normalized_valid(self, transport):
|
||||
r = SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content="hi"))])
|
||||
assert transport.validate_response(r) is True
|
||||
|
||||
|
||||
class TestBedrockMapFinishReason:
|
||||
|
||||
def test_end_turn(self, transport):
|
||||
assert transport.map_finish_reason("end_turn") == "stop"
|
||||
|
||||
def test_tool_use(self, transport):
|
||||
assert transport.map_finish_reason("tool_use") == "tool_calls"
|
||||
|
||||
def test_max_tokens(self, transport):
|
||||
assert transport.map_finish_reason("max_tokens") == "length"
|
||||
|
||||
def test_guardrail(self, transport):
|
||||
assert transport.map_finish_reason("guardrail_intervened") == "content_filter"
|
||||
|
||||
def test_unknown(self, transport):
|
||||
assert transport.map_finish_reason("unknown") == "stop"
|
||||
|
||||
|
||||
class TestBedrockNormalize:
|
||||
|
||||
def _make_bedrock_response(self, text="Hello", tool_calls=None, stop_reason="end_turn"):
|
||||
"""Build a raw Bedrock converse response dict."""
|
||||
content = []
|
||||
if text:
|
||||
content.append({"text": text})
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
content.append({
|
||||
"toolUse": {
|
||||
"toolUseId": tc["id"],
|
||||
"name": tc["name"],
|
||||
"input": tc["input"],
|
||||
}
|
||||
})
|
||||
return {
|
||||
"output": {"message": {"role": "assistant", "content": content}},
|
||||
"stopReason": stop_reason,
|
||||
"usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
|
||||
}
|
||||
|
||||
def test_text_response(self, transport):
|
||||
raw = self._make_bedrock_response(text="Hello world")
|
||||
nr = transport.normalize_response(raw)
|
||||
assert isinstance(nr, NormalizedResponse)
|
||||
assert nr.content == "Hello world"
|
||||
assert nr.finish_reason == "stop"
|
||||
|
||||
def test_tool_call_response(self, transport):
|
||||
raw = self._make_bedrock_response(
|
||||
text=None,
|
||||
tool_calls=[{"id": "tool_1", "name": "terminal", "input": {"command": "ls"}}],
|
||||
stop_reason="tool_use",
|
||||
)
|
||||
nr = transport.normalize_response(raw)
|
||||
assert nr.finish_reason == "tool_calls"
|
||||
assert len(nr.tool_calls) == 1
|
||||
assert nr.tool_calls[0].name == "terminal"
|
||||
|
||||
def test_already_normalized_response(self, transport):
|
||||
"""Test normalize_response handles already-normalized SimpleNamespace (from dispatch site)."""
|
||||
pre_normalized = SimpleNamespace(
|
||||
choices=[SimpleNamespace(
|
||||
message=SimpleNamespace(
|
||||
content="Hello from Bedrock",
|
||||
tool_calls=None,
|
||||
reasoning=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
)
|
||||
nr = transport.normalize_response(pre_normalized)
|
||||
assert isinstance(nr, NormalizedResponse)
|
||||
assert nr.content == "Hello from Bedrock"
|
||||
assert nr.finish_reason == "stop"
|
||||
assert nr.usage is not None
|
||||
assert nr.usage.prompt_tokens == 10
|
||||
349
tests/agent/transports/test_chat_completions.py
Normal file
349
tests/agent/transports/test_chat_completions.py
Normal file
|
|
@ -0,0 +1,349 @@
|
|||
"""Tests for the ChatCompletionsTransport."""
|
||||
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent.transports import get_transport
|
||||
from agent.transports.types import NormalizedResponse, ToolCall
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def transport():
|
||||
import agent.transports.chat_completions # noqa: F401
|
||||
return get_transport("chat_completions")
|
||||
|
||||
|
||||
class TestChatCompletionsBasic:
|
||||
|
||||
def test_api_mode(self, transport):
|
||||
assert transport.api_mode == "chat_completions"
|
||||
|
||||
def test_registered(self, transport):
|
||||
assert transport is not None
|
||||
|
||||
def test_convert_tools_identity(self, transport):
|
||||
tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}]
|
||||
assert transport.convert_tools(tools) is tools
|
||||
|
||||
def test_convert_messages_no_codex_leaks(self, transport):
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
result = transport.convert_messages(msgs)
|
||||
assert result is msgs # no copy needed
|
||||
|
||||
def test_convert_messages_strips_codex_fields(self, transport):
|
||||
msgs = [
|
||||
{"role": "assistant", "content": "ok", "codex_reasoning_items": [{"id": "rs_1"}],
|
||||
"tool_calls": [{"id": "call_1", "call_id": "call_1", "response_item_id": "fc_1",
|
||||
"type": "function", "function": {"name": "t", "arguments": "{}"}}]},
|
||||
]
|
||||
result = transport.convert_messages(msgs)
|
||||
assert "codex_reasoning_items" not in result[0]
|
||||
assert "call_id" not in result[0]["tool_calls"][0]
|
||||
assert "response_item_id" not in result[0]["tool_calls"][0]
|
||||
# Original list untouched (deepcopy-on-demand)
|
||||
assert "codex_reasoning_items" in msgs[0]
|
||||
|
||||
|
||||
class TestChatCompletionsBuildKwargs:
|
||||
|
||||
def test_basic_kwargs(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hello"}]
|
||||
kw = transport.build_kwargs(model="gpt-4o", messages=msgs, timeout=30.0)
|
||||
assert kw["model"] == "gpt-4o"
|
||||
assert kw["messages"][0]["content"] == "Hello"
|
||||
assert kw["timeout"] == 30.0
|
||||
|
||||
def test_developer_role_swap(self, transport):
|
||||
msgs = [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(model="gpt-5.4", messages=msgs, model_lower="gpt-5.4")
|
||||
assert kw["messages"][0]["role"] == "developer"
|
||||
|
||||
def test_no_developer_swap_for_non_gpt5(self, transport):
|
||||
msgs = [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(model="claude-sonnet-4", messages=msgs, model_lower="claude-sonnet-4")
|
||||
assert kw["messages"][0]["role"] == "system"
|
||||
|
||||
def test_tools_included(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
tools = [{"type": "function", "function": {"name": "test", "parameters": {}}}]
|
||||
kw = transport.build_kwargs(model="gpt-4o", messages=msgs, tools=tools)
|
||||
assert kw["tools"] == tools
|
||||
|
||||
def test_openrouter_provider_prefs(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="gpt-4o", messages=msgs,
|
||||
is_openrouter=True,
|
||||
provider_preferences={"only": ["openai"]},
|
||||
)
|
||||
assert kw["extra_body"]["provider"] == {"only": ["openai"]}
|
||||
|
||||
def test_nous_tags(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(model="gpt-4o", messages=msgs, is_nous=True)
|
||||
assert kw["extra_body"]["tags"] == ["product=hermes-agent"]
|
||||
|
||||
def test_reasoning_default(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="gpt-4o", messages=msgs,
|
||||
supports_reasoning=True,
|
||||
)
|
||||
assert kw["extra_body"]["reasoning"] == {"enabled": True, "effort": "medium"}
|
||||
|
||||
def test_nous_omits_disabled_reasoning(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="gpt-4o", messages=msgs,
|
||||
supports_reasoning=True,
|
||||
is_nous=True,
|
||||
reasoning_config={"enabled": False},
|
||||
)
|
||||
# Nous rejects enabled=false; reasoning omitted entirely
|
||||
assert "reasoning" not in kw.get("extra_body", {})
|
||||
|
||||
def test_ollama_num_ctx(self, transport):
|
||||
msgs = [{"role": "user", "content": "Hi"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="llama3", messages=msgs,
|
||||
ollama_num_ctx=32768,
|
||||
)
|
||||
assert kw["extra_body"]["options"]["num_ctx"] == 32768
|
||||
|
||||
def test_custom_think_false(self, transport):
    """Custom providers with effort 'none' get think=False in extra_body."""
    built = transport.build_kwargs(
        model="qwen3",
        messages=[{"role": "user", "content": "Hi"}],
        is_custom_provider=True,
        reasoning_config={"effort": "none"},
    )
    assert built["extra_body"]["think"] is False
||||
def test_max_tokens_with_fn(self, transport):
    """An explicit max_tokens is routed through max_tokens_param_fn."""
    built = transport.build_kwargs(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hi"}],
        max_tokens=4096,
        max_tokens_param_fn=lambda n: {"max_tokens": n},
    )
    assert built["max_tokens"] == 4096
||||
def test_ephemeral_overrides_max_tokens(self, transport):
    """A per-request ephemeral cap must win over the configured max_tokens."""
    built = transport.build_kwargs(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hi"}],
        max_tokens=4096,
        ephemeral_max_output_tokens=2048,
        max_tokens_param_fn=lambda n: {"max_tokens": n},
    )
    assert built["max_tokens"] == 2048
||||
def test_nvidia_default_max_tokens(self, transport):
    """NVIDIA NIM requests without an explicit cap default to 16384 tokens."""
    built = transport.build_kwargs(
        model="glm-4.7",
        messages=[{"role": "user", "content": "Hi"}],
        is_nvidia_nim=True,
        max_tokens_param_fn=lambda n: {"max_tokens": n},
    )
    assert built["max_tokens"] == 16384
||||
def test_qwen_default_max_tokens(self, transport):
    """Qwen portal requests without an explicit cap default to 65536 tokens."""
    built = transport.build_kwargs(
        model="qwen3-coder-plus",
        messages=[{"role": "user", "content": "Hi"}],
        is_qwen_portal=True,
        max_tokens_param_fn=lambda n: {"max_tokens": n},
    )
    assert built["max_tokens"] == 65536
||||
def test_anthropic_max_output_for_claude_on_aggregator(self, transport):
    """Claude routed through an aggregator needs max_tokens set directly.

    The aggregator proxies to the Anthropic Messages API, which makes the
    field mandatory — so it is set as a plain max_tokens, not via the fn.
    """
    built = transport.build_kwargs(
        model="anthropic/claude-sonnet-4.6",
        messages=[{"role": "user", "content": "Hi"}],
        is_openrouter=True,
        anthropic_max_output=64000,
    )
    assert built["max_tokens"] == 64000
||||
def test_request_overrides_last(self, transport):
    """request_overrides are applied last, so arbitrary keys pass through."""
    built = transport.build_kwargs(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hi"}],
        request_overrides={"service_tier": "priority"},
    )
    assert built["service_tier"] == "priority"
||||
def test_fixed_temperature(self, transport):
    """A fixed_temperature value is copied straight into the kwargs."""
    built = transport.build_kwargs(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hi"}],
        fixed_temperature=0.6,
    )
    assert built["temperature"] == 0.6
||||
def test_omit_temperature(self, transport):
    """omit_temperature beats fixed_temperature: no temperature key at all."""
    built = transport.build_kwargs(
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hi"}],
        omit_temperature=True,
        fixed_temperature=0.5,
    )
    assert "temperature" not in built
|
||||
|
||||
class TestChatCompletionsKimi:
    """Regression tests for the Kimi/Moonshot quirks migrated into the transport."""

    def _kimi_kwargs(self, transport, **overrides):
        # Shared minimal Kimi request; per-test options merge on top.
        return transport.build_kwargs(
            model="kimi-k2",
            messages=[{"role": "user", "content": "Hi"}],
            is_kimi=True,
            max_tokens_param_fn=lambda n: {"max_tokens": n},
            **overrides,
        )

    def test_kimi_max_tokens_default(self, transport):
        # Kimi CLI defaults to 32000 output tokens when none is requested.
        built = self._kimi_kwargs(transport)
        assert built["max_tokens"] == 32000

    def test_kimi_reasoning_effort_top_level(self, transport):
        # Kimi requires reasoning_effort as a top-level parameter, not nested.
        built = self._kimi_kwargs(transport, reasoning_config={"effort": "high"})
        assert built["reasoning_effort"] == "high"

    def test_kimi_reasoning_effort_omitted_when_thinking_disabled(self, transport):
        # Mirror Kimi CLI: drop reasoning_effort entirely when thinking is off.
        built = self._kimi_kwargs(transport, reasoning_config={"enabled": False})
        assert "reasoning_effort" not in built

    def test_kimi_thinking_enabled_extra_body(self, transport):
        built = self._kimi_kwargs(transport)
        assert built["extra_body"]["thinking"] == {"type": "enabled"}

    def test_kimi_thinking_disabled_extra_body(self, transport):
        built = self._kimi_kwargs(transport, reasoning_config={"enabled": False})
        assert built["extra_body"]["thinking"] == {"type": "disabled"}
|
||||
class TestChatCompletionsValidate:
    """validate_response must reject anything without a non-empty choices list."""

    def test_none(self, transport):
        assert transport.validate_response(None) is False

    def test_no_choices(self, transport):
        resp = SimpleNamespace(choices=None)
        assert transport.validate_response(resp) is False

    def test_empty_choices(self, transport):
        resp = SimpleNamespace(choices=[])
        assert transport.validate_response(resp) is False

    def test_valid(self, transport):
        message = SimpleNamespace(content="hi")
        resp = SimpleNamespace(choices=[SimpleNamespace(message=message)])
        assert transport.validate_response(resp) is True
|
||||
class TestChatCompletionsNormalize:
    """normalize_response must map raw chat-completions payloads to NormalizedResponse."""

    def test_text_response(self, transport):
        message = SimpleNamespace(content="Hello", tool_calls=None, reasoning_content=None)
        raw = SimpleNamespace(
            choices=[SimpleNamespace(message=message, finish_reason="stop")],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
        )
        normalized = transport.normalize_response(raw)
        assert isinstance(normalized, NormalizedResponse)
        assert normalized.content == "Hello"
        assert normalized.finish_reason == "stop"
        assert normalized.tool_calls is None

    def test_tool_call_response(self, transport):
        call = SimpleNamespace(
            id="call_123",
            function=SimpleNamespace(name="terminal", arguments='{"command": "ls"}'),
        )
        message = SimpleNamespace(content=None, tool_calls=[call], reasoning_content=None)
        raw = SimpleNamespace(
            choices=[SimpleNamespace(message=message, finish_reason="tool_calls")],
            usage=SimpleNamespace(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
        normalized = transport.normalize_response(raw)
        assert len(normalized.tool_calls) == 1
        assert normalized.tool_calls[0].name == "terminal"
        assert normalized.tool_calls[0].id == "call_123"

    def test_tool_call_extra_content_preserved(self, transport):
        """Gemini 3 thinking models attach extra_content with thought_signature
        on tool_calls. Without this replay on the next turn, the API rejects
        the request with 400. The transport MUST surface extra_content so the
        agent loop can write it back into the assistant message."""
        call = SimpleNamespace(
            id="call_gem",
            function=SimpleNamespace(name="terminal", arguments='{"command": "ls"}'),
            extra_content={"google": {"thought_signature": "SIG_ABC123"}},
        )
        message = SimpleNamespace(content=None, tool_calls=[call], reasoning_content=None)
        raw = SimpleNamespace(
            choices=[SimpleNamespace(message=message, finish_reason="tool_calls")],
            usage=None,
        )
        normalized = transport.normalize_response(raw)
        assert normalized.tool_calls[0].provider_data == {
            "extra_content": {"google": {"thought_signature": "SIG_ABC123"}}
        }

    def test_reasoning_content_preserved_separately(self, transport):
        """DeepSeek/Moonshot use reasoning_content distinct from reasoning.
        Don't merge them — the thinking-prefill retry check reads each field
        separately."""
        message = SimpleNamespace(
            content=None,
            tool_calls=None,
            reasoning="summary text",
            reasoning_content="detailed scratchpad",
        )
        raw = SimpleNamespace(
            choices=[SimpleNamespace(message=message, finish_reason="stop")],
            usage=None,
        )
        normalized = transport.normalize_response(raw)
        assert normalized.reasoning == "summary text"
        assert normalized.provider_data == {"reasoning_content": "detailed scratchpad"}
|
||||
class TestChatCompletionsCacheStats:
    """extract_cache_stats must tolerate missing usage/details and map token fields."""

    def test_no_usage(self, transport):
        assert transport.extract_cache_stats(SimpleNamespace(usage=None)) is None

    def test_no_details(self, transport):
        resp = SimpleNamespace(usage=SimpleNamespace(prompt_tokens_details=None))
        assert transport.extract_cache_stats(resp) is None

    def test_with_cache(self, transport):
        token_details = SimpleNamespace(cached_tokens=500, cache_write_tokens=100)
        resp = SimpleNamespace(usage=SimpleNamespace(prompt_tokens_details=token_details))
        stats = transport.extract_cache_stats(resp)
        # Read side maps to cached_tokens; write side maps to creation_tokens.
        assert stats == {"cached_tokens": 500, "creation_tokens": 100}
||||
220
tests/agent/transports/test_codex_transport.py
Normal file
220
tests/agent/transports/test_codex_transport.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Tests for the ResponsesApiTransport (Codex)."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from agent.transports import get_transport
|
||||
from agent.transports.types import NormalizedResponse, ToolCall
|
||||
|
||||
|
||||
@pytest.fixture
def transport():
    """Yield the registered Codex Responses transport.

    Importing the module is what registers the transport in the global
    registry, hence the side-effect-only import.
    """
    import agent.transports.codex  # noqa: F401

    return get_transport("codex_responses")
|
||||
class TestCodexTransportBasic:
    """Smoke tests: registration, api_mode, and tool-schema conversion."""

    def test_api_mode(self, transport):
        assert transport.api_mode == "codex_responses"

    def test_registered_on_import(self, transport):
        # The fixture returns None only if registration failed.
        assert transport is not None

    def test_convert_tools(self, transport):
        chat_style_tools = [{
            "type": "function",
            "function": {
                "name": "terminal",
                "description": "Run a command",
                "parameters": {"type": "object", "properties": {"command": {"type": "string"}}},
            }
        }]
        converted = transport.convert_tools(chat_style_tools)
        # Responses API flattens the nested "function" wrapper.
        assert len(converted) == 1
        assert converted[0]["type"] == "function"
        assert converted[0]["name"] == "terminal"
|
||||
class TestCodexBuildKwargs:
    """build_kwargs must translate agent-level options into Responses-API kwargs."""

    @staticmethod
    def _user_only():
        # Fresh list per call so build_kwargs never sees shared state.
        return [{"role": "user", "content": "Hi"}]

    def test_basic_kwargs(self, transport):
        conversation = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
        ]
        built = transport.build_kwargs(model="gpt-5.4", messages=conversation, tools=[])
        assert built["model"] == "gpt-5.4"
        assert built["instructions"] == "You are helpful."
        assert "input" in built
        assert built["store"] is False

    def test_system_extracted_from_messages(self, transport):
        conversation = [
            {"role": "system", "content": "Custom system prompt"},
            {"role": "user", "content": "Hi"},
        ]
        built = transport.build_kwargs(model="gpt-5.4", messages=conversation, tools=[])
        assert built["instructions"] == "Custom system prompt"

    def test_no_system_uses_default(self, transport):
        built = transport.build_kwargs(model="gpt-5.4", messages=self._user_only(), tools=[])
        # With no system message, a non-empty default instruction is used.
        assert built["instructions"]

    def test_reasoning_config(self, transport):
        built = transport.build_kwargs(
            model="gpt-5.4", messages=self._user_only(), tools=[],
            reasoning_config={"effort": "high"},
        )
        assert built.get("reasoning", {}).get("effort") == "high"

    def test_reasoning_disabled(self, transport):
        built = transport.build_kwargs(
            model="gpt-5.4", messages=self._user_only(), tools=[],
            reasoning_config={"enabled": False},
        )
        assert "reasoning" not in built or built.get("include") == []

    def test_session_id_sets_cache_key(self, transport):
        built = transport.build_kwargs(
            model="gpt-5.4", messages=self._user_only(), tools=[],
            session_id="test-session-123",
        )
        assert built.get("prompt_cache_key") == "test-session-123"

    def test_github_responses_no_cache_key(self, transport):
        built = transport.build_kwargs(
            model="gpt-5.4", messages=self._user_only(), tools=[],
            session_id="test-session",
            is_github_responses=True,
        )
        # GitHub's Responses proxy does not accept prompt_cache_key.
        assert "prompt_cache_key" not in built

    def test_max_tokens(self, transport):
        built = transport.build_kwargs(
            model="gpt-5.4", messages=self._user_only(), tools=[],
            max_tokens=4096,
        )
        assert built.get("max_output_tokens") == 4096

    def test_codex_backend_no_max_output_tokens(self, transport):
        built = transport.build_kwargs(
            model="gpt-5.4", messages=self._user_only(), tools=[],
            max_tokens=4096,
            is_codex_backend=True,
        )
        assert "max_output_tokens" not in built

    def test_xai_headers(self, transport):
        built = transport.build_kwargs(
            model="grok-3", messages=self._user_only(), tools=[],
            session_id="conv-123",
            is_xai_responses=True,
        )
        # xAI threads the conversation id through a custom header.
        assert built.get("extra_headers", {}).get("x-grok-conv-id") == "conv-123"

    def test_minimal_effort_clamped(self, transport):
        built = transport.build_kwargs(
            model="gpt-5.4", messages=self._user_only(), tools=[],
            reasoning_config={"effort": "minimal"},
        )
        # "minimal" should be clamped to "low".
        assert built.get("reasoning", {}).get("effort") == "low"
||||
|
||||
class TestCodexValidateResponse:
    """Strict validation: only a non-empty structured output list counts."""

    def test_none_response(self, transport):
        assert transport.validate_response(None) is False

    def test_empty_output(self, transport):
        resp = SimpleNamespace(output=[], output_text=None)
        assert transport.validate_response(resp) is False

    def test_valid_output(self, transport):
        resp = SimpleNamespace(output=[{"type": "message", "content": []}])
        assert transport.validate_response(resp) is True

    def test_output_text_fallback_not_valid(self, transport):
        """validate_response is strict — output_text doesn't make it valid.
        The caller handles output_text fallback with diagnostic logging."""
        resp = SimpleNamespace(output=None, output_text="Some text")
        assert transport.validate_response(resp) is False
|
||||
class TestCodexMapFinishReason:
    """Responses statuses map onto chat-completions finish reasons."""

    def test_completed(self, transport):
        assert transport.map_finish_reason("completed") == "stop"

    def test_incomplete(self, transport):
        # "incomplete" means the model hit a limit — report as "length".
        assert transport.map_finish_reason("incomplete") == "length"

    def test_failed(self, transport):
        assert transport.map_finish_reason("failed") == "stop"

    def test_unknown(self, transport):
        # Anything unrecognized falls back to "stop".
        assert transport.map_finish_reason("unknown_status") == "stop"
|
||||
class TestCodexNormalizeResponse:
    """normalize_response must flatten Responses output items into NormalizedResponse."""

    def test_text_response(self, transport):
        """Normalize a simple text Codex response."""
        text_part = SimpleNamespace(type="output_text", text="Hello world")
        raw = SimpleNamespace(
            output=[
                SimpleNamespace(
                    type="message",
                    role="assistant",
                    content=[text_part],
                    status="completed",
                ),
            ],
            status="completed",
            incomplete_details=None,
            usage=SimpleNamespace(
                input_tokens=10,
                output_tokens=5,
                input_tokens_details=None,
                output_tokens_details=None,
            ),
        )
        normalized = transport.normalize_response(raw)
        assert isinstance(normalized, NormalizedResponse)
        assert normalized.content == "Hello world"
        assert normalized.finish_reason == "stop"

    def test_tool_call_response(self, transport):
        """Normalize a Codex response with tool calls."""
        raw = SimpleNamespace(
            output=[
                SimpleNamespace(
                    type="function_call",
                    call_id="call_abc123",
                    name="terminal",
                    arguments=json.dumps({"command": "ls"}),
                    id="fc_abc123",
                    status="completed",
                ),
            ],
            status="completed",
            incomplete_details=None,
            usage=SimpleNamespace(
                input_tokens=10,
                output_tokens=20,
                input_tokens_details=None,
                output_tokens_details=None,
            ),
        )
        normalized = transport.normalize_response(raw)
        assert normalized.finish_reason == "tool_calls"
        assert len(normalized.tool_calls) == 1
        first_call = normalized.tool_calls[0]
        assert first_call.name == "terminal"
        assert '"command"' in first_call.arguments
|
|
@ -254,3 +254,88 @@ class TestCliApprovalUi:
|
|||
|
||||
# Command got truncated with a marker.
|
||||
assert "(command truncated" in rendered
|
||||
|
||||
|
||||
class TestApprovalCallbackThreadLocalWiring:
    """Regression guard for the thread-local callback freeze (#13617 / #13618).

    After 62348cff made _approval_callback / _sudo_password_callback thread-local
    (ACP GHSA-qg5c-hvr5-hjgr), the CLI agent thread could no longer see callbacks
    registered in the main thread — the dangerous-command prompt silently fell
    back to stdin input() and deadlocked against prompt_toolkit. The fix is to
    register the callbacks INSIDE the agent worker thread (matching the ACP
    pattern). These tests lock in that invariant.
    """

    def test_main_thread_registration_is_invisible_to_child_thread(self):
        """Confirms the underlying threading.local semantics that drove the bug.

        If this ever starts passing as "visible", the thread-local isolation
        is gone and the ACP race GHSA-qg5c-hvr5-hjgr may be back.
        """
        from tools.terminal_tool import (
            set_approval_callback,
            _get_approval_callback,
        )

        def main_cb(_cmd, _desc):
            return "once"

        set_approval_callback(main_cb)
        try:
            observed = {}

            def probe():
                observed["value"] = _get_approval_callback()

            worker = threading.Thread(target=probe, daemon=True)
            worker.start()
            worker.join(timeout=2)
            # Registered in the main thread → invisible from the child.
            assert observed["value"] is None
        finally:
            set_approval_callback(None)

    def test_child_thread_registration_is_visible_and_cleared_in_finally(self):
        """The fix pattern: register INSIDE the worker thread, clear in finally.

        This is exactly what cli.py's run_agent() closure does. If this test
        fails, the CLI approval prompt freeze (#13617) has regressed.
        """
        from tools.terminal_tool import (
            set_approval_callback,
            set_sudo_password_callback,
            _get_approval_callback,
            _get_sudo_password_callback,
        )

        def approval_cb(_cmd, _desc):
            return "once"

        def sudo_cb():
            return "hunter2"

        observed = {}

        def worker_body():
            # Mimic cli.py's run_agent() thread target.
            set_approval_callback(approval_cb)
            set_sudo_password_callback(sudo_cb)
            try:
                observed["approval"] = _get_approval_callback()
                observed["sudo"] = _get_sudo_password_callback()
            finally:
                set_approval_callback(None)
                set_sudo_password_callback(None)
                observed["approval_after"] = _get_approval_callback()
                observed["sudo_after"] = _get_sudo_password_callback()

        worker = threading.Thread(target=worker_body, daemon=True)
        worker.start()
        worker.join(timeout=2)

        assert observed["approval"] is approval_cb
        assert observed["sudo"] is sudo_cb
        # The finally block must clear both slots — otherwise a reused thread
        # would hold a stale reference to a disposed CLI instance.
        assert observed["approval_after"] is None
        assert observed["sudo_after"] is None
||||
|
|
|
|||
|
|
@ -147,6 +147,37 @@ class TestEscapedSpaces:
|
|||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["remainder"] == "what is this?"
|
||||
|
||||
def test_unquoted_spaces_in_path(self, tmp_image_with_spaces):
    """A bare path containing spaces (no quoting) still detects as an image drop."""
    detected = _detect_file_drop(str(tmp_image_with_spaces))
    assert detected is not None
    assert detected["path"] == tmp_image_with_spaces
    assert detected["is_image"] is True
    assert detected["remainder"] == ""
|
||||
def test_unquoted_spaces_with_trailing_text(self, tmp_image_with_spaces):
    """Text after a spaced path is split off into the remainder."""
    detected = _detect_file_drop(f"{tmp_image_with_spaces} what is this?")
    assert detected is not None
    assert detected["path"] == tmp_image_with_spaces
    assert detected["remainder"] == "what is this?"
|
||||
def test_mixed_escaped_and_literal_spaces_in_path(self, tmp_path):
    """A path mixing backslash-escaped and literal spaces still resolves."""
    img = tmp_path / "Screenshot 2026-04-21 at 1.04.43 PM.png"
    img.write_bytes(b"\x89PNG\r\n\x1a\n")  # minimal PNG magic so it's treated as an image
    # Escape some — but not all — of the spaces in the path string.
    mixed = (
        str(img)
        .replace("Screenshot ", "Screenshot\\ ")
        .replace("2026-04-21 ", "2026-04-21\\ ")
        .replace("at ", "at\\ ")
    )
    detected = _detect_file_drop(mixed)
    assert detected is not None
    assert detected["path"] == img
    assert detected["is_image"] is True
    assert detected["remainder"] == ""
|
||||
def test_file_uri_image_path(self, tmp_image_with_spaces):
    """A file:// URI drop resolves back to the filesystem path."""
    detected = _detect_file_drop(tmp_image_with_spaces.as_uri())
    assert detected is not None
    assert detected["path"] == tmp_image_with_spaces
    assert detected["is_image"] is True
|
||||
def test_tilde_prefixed_path(self, tmp_path, monkeypatch):
|
||||
home = tmp_path / "home"
|
||||
img = home / "storage" / "shared" / "Pictures" / "cat.png"
|
||||
|
|
|
|||
|
|
@ -115,3 +115,27 @@ def test_final_assistant_content_can_leave_markdown_raw():
|
|||
|
||||
output = _render_to_text(renderable)
|
||||
assert "***Bold italic***" in output
|
||||
|
||||
|
||||
def test_strip_mode_preserves_intraword_underscores_in_snake_case_identifiers():
    """strip mode must not eat underscores inside identifiers and paths."""
    sample = (
        "Let me look at test_case_with_underscores and SOME_CONST "
        "then /tmp/snake_case_dir/file_with_name.py"
    )
    renderable = _render_final_assistant_content(sample, mode="strip")
    output = _render_to_text(renderable)

    # Every snake_case token must survive with its underscores intact.
    for identifier in (
        "test_case_with_underscores",
        "SOME_CONST",
        "snake_case_dir",
        "file_with_name",
    ):
        assert identifier in output
|
||||
def test_strip_mode_still_strips_boundary_underscore_emphasis():
    """Word-boundary underscore emphasis (_x_, __y__) is still removed in strip mode."""
    renderable = _render_final_assistant_content("say _hi_ and __bold__ now", mode="strip")
    output = _render_to_text(renderable)
    assert "say hi and bold now" in output
||||
|
|
|
|||
91
tests/gateway/test_complete_path_at_filter.py
Normal file
91
tests/gateway/test_complete_path_at_filter.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
"""Regression tests for the TUI gateway's `complete.path` handler.
|
||||
|
||||
Reported during the TUI v2 blitz retest: typing `@folder:` (and `@folder`
|
||||
with no colon yet) still surfaced files alongside directories in the
|
||||
TUI composer, because the gateway-side completion lives in
|
||||
`tui_gateway/server.py` and was never touched by the earlier fix to
|
||||
`hermes_cli/commands.py`.
|
||||
|
||||
Covers:
|
||||
- `@folder:` only yields directories
|
||||
- `@file:` only yields regular files
|
||||
- Bare `@folder` / `@file` (no colon) lists cwd directly
|
||||
- Explicit prefix is preserved in the completion text
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from tui_gateway import server
|
||||
|
||||
|
||||
def _fixture(tmp_path: Path):
|
||||
(tmp_path / "readme.md").write_text("x")
|
||||
(tmp_path / ".env").write_text("x")
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "docs").mkdir()
|
||||
|
||||
|
||||
def _items(word: str):
    """Run a complete.path request through the gateway and flatten the items.

    Returns (text, display, meta) triples for easy assertions.
    """
    request = {"id": "1", "method": "complete.path", "params": {"word": word}}
    response = server.handle_request(request)
    return [
        (item["text"], item["display"], item.get("meta", ""))
        for item in response["result"]["items"]
    ]
||||
|
||||
def test_at_folder_colon_only_dirs(tmp_path, monkeypatch):
    """`@folder:` completions list directories only, never files."""
    monkeypatch.chdir(tmp_path)
    _fixture(tmp_path)

    texts = [text for text, _, _ in _items("@folder:")]

    assert all(text.startswith("@folder:") for text in texts), texts
    assert "@folder:src/" in texts
    assert "@folder:docs/" in texts
    # Files (including hidden ones) must be filtered out.
    assert "@folder:readme.md" not in texts
    assert "@folder:.env" not in texts
||||
|
||||
def test_at_file_colon_only_files(tmp_path, monkeypatch):
    """`@file:` completions list regular files only, never directories."""
    monkeypatch.chdir(tmp_path)
    _fixture(tmp_path)

    texts = [text for text, _, _ in _items("@file:")]

    assert all(text.startswith("@file:") for text in texts), texts
    assert "@file:readme.md" in texts
    assert "@file:src/" not in texts
    assert "@file:docs/" not in texts
||||
|
||||
def test_at_folder_bare_without_colon_lists_dirs(tmp_path, monkeypatch):
    """Bare `@folder` (colon not yet typed) already lists cwd directories."""
    monkeypatch.chdir(tmp_path)
    _fixture(tmp_path)

    texts = [text for text, _, _ in _items("@folder")]

    assert "@folder:src/" in texts, texts
    assert "@folder:docs/" in texts, texts
    assert "@folder:readme.md" not in texts
||||
|
||||
def test_at_file_bare_without_colon_lists_files(tmp_path, monkeypatch):
    """Bare `@file` (colon not yet typed) already lists cwd files."""
    monkeypatch.chdir(tmp_path)
    _fixture(tmp_path)

    texts = [text for text, _, _ in _items("@file")]

    assert "@file:readme.md" in texts, texts
    assert "@file:src/" not in texts
||||
|
||||
def test_bare_at_still_shows_static_refs(tmp_path, monkeypatch):
    """`@` alone should list the static references so users discover the
    available prefixes. (Unchanged behaviour; regression guard.)
    """
    monkeypatch.chdir(tmp_path)

    texts = [text for text, _, _ in _items("@")]

    static_refs = ("@diff", "@staged", "@file:", "@folder:", "@url:", "@git:")
    for expected in static_refs:
        assert expected in texts, f"missing static ref {expected!r} in {texts!r}"
|
||||
159
tests/gateway/test_reply_to_injection.py
Normal file
159
tests/gateway/test_reply_to_injection.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
"""Tests for reply-to pointer injection in _prepare_inbound_message_text.
|
||||
|
||||
The `[Replying to: "..."]` prefix is a *disambiguation pointer*, not
|
||||
deduplication. It must always be injected when the user explicitly replies
|
||||
to a prior message — even when the quoted text already exists somewhere
|
||||
in the conversation history. History can contain the same or similar text
|
||||
multiple times, and without an explicit pointer the agent has to guess
|
||||
which prior message the user is referencing.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_runner() -> GatewayRunner:
    """Build a GatewayRunner without running __init__ (which would do I/O).

    Only the attributes read by _prepare_inbound_message_text are populated.
    """
    runner = object.__new__(GatewayRunner)
    telegram_cfg = PlatformConfig(enabled=True, token="fake")
    runner.config = GatewayConfig(platforms={Platform.TELEGRAM: telegram_cfg})
    runner.adapters = {}
    runner._model = "openai/gpt-4.1-mini"
    runner._base_url = None
    return runner
||||
|
||||
def _source() -> SessionSource:
    """Return a canned Telegram DM session source used across these tests."""
    return SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="123",
        chat_name="DM",
        chat_type="private",
        user_name="Alice",
    )
||||
|
||||
@pytest.mark.asyncio
async def test_reply_prefix_injected_when_text_absent_from_history():
    """The reply pointer is prepended when the quoted text is not in history."""
    runner = _make_runner()
    source = _source()
    quoted = "Japan is great for culture, food, and efficiency."
    event = MessageEvent(
        text="What's the best time to go?",
        source=source,
        reply_to_message_id="42",
        reply_to_text=quoted,
    )

    result = await runner._prepare_inbound_message_text(
        event=event,
        source=source,
        history=[{"role": "user", "content": "unrelated"}],
    )

    assert result is not None
    assert result.startswith(f'[Replying to: "{quoted}"]')
    assert result.endswith("What's the best time to go?")
||||
|
||||
@pytest.mark.asyncio
async def test_reply_prefix_still_injected_when_text_in_history():
    """Regression test: the pointer must survive even when the quoted text
    already appears in history. Previously a `found_in_history` guard
    silently dropped the prefix, leaving the agent to guess which prior
    message the user was referencing."""
    runner = _make_runner()
    source = _source()
    quoted = "Japan is great for culture, food, and efficiency."
    event = MessageEvent(
        text="What's the best time to go?",
        source=source,
        reply_to_message_id="42",
        reply_to_text=quoted,
    )

    # History that already contains the quoted sentence verbatim.
    history = [
        {"role": "user", "content": "I'm thinking of going to Japan or Italy."},
        {
            "role": "assistant",
            "content": f"{quoted} Italy is better if you prefer a relaxed pace.",
        },
        {"role": "user", "content": "How long should I stay?"},
        {"role": "assistant", "content": "For Japan, 10-14 days is ideal."},
    ]

    result = await runner._prepare_inbound_message_text(
        event=event,
        source=source,
        history=history,
    )

    assert result is not None
    assert result.startswith(f'[Replying to: "{quoted}"]')
    assert result.endswith("What's the best time to go?")
||||
|
||||
@pytest.mark.asyncio
async def test_no_prefix_without_reply_context():
    """A plain message (no reply metadata) passes through unchanged."""
    runner = _make_runner()
    source = _source()

    result = await runner._prepare_inbound_message_text(
        event=MessageEvent(text="hello", source=source),
        source=source,
        history=[],
    )

    assert result == "hello"
||||
|
||||
@pytest.mark.asyncio
async def test_no_prefix_when_reply_to_text_is_empty():
    """reply_to_message_id alone without text (e.g. a reply to a media-only
    message) should not produce an empty `[Replying to: ""]` prefix."""
    runner = _make_runner()
    source = _source()
    event = MessageEvent(
        text="hi",
        source=source,
        reply_to_message_id="42",
        reply_to_text=None,
    )

    result = await runner._prepare_inbound_message_text(
        event=event,
        source=source,
        history=[],
    )

    assert result == "hi"
||||
|
||||
@pytest.mark.asyncio
async def test_reply_snippet_truncated_to_500_chars():
    """Oversized quoted text is cut to 500 characters in the pointer prefix."""
    runner = _make_runner()
    source = _source()
    event = MessageEvent(
        text="follow-up",
        source=source,
        reply_to_message_id="42",
        reply_to_text="x" * 800,
    )

    result = await runner._prepare_inbound_message_text(
        event=event,
        source=source,
        history=[],
    )

    assert result is not None
    # Exactly 500 x's survive inside the prefix; 501 in a row never appear.
    assert result.startswith('[Replying to: "' + "x" * 500 + '"]')
    assert "x" * 501 not in result
|
||||
76
tests/gateway/test_session_list_allowed_sources.py
Normal file
76
tests/gateway/test_session_list_allowed_sources.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Regression tests for the TUI gateway's ``session.list`` handler.
|
||||
|
||||
Reported during TUI v2 blitz retest: the ``/resume`` modal inside a TUI
|
||||
session only surfaced ``tui``/``cli`` rows, hiding telegram sessions users
|
||||
could still resume directly via ``hermes --tui --resume <id>``.
|
||||
|
||||
The fix widens the picker to a curated allowlist of user-facing sources
|
||||
(tui/cli + chat adapters) while still filtering internal/system sources.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from tui_gateway import server
|
||||
|
||||
|
||||
class _StubDB:
|
||||
def __init__(self, rows):
|
||||
self.rows = rows
|
||||
self.calls: list[dict] = []
|
||||
|
||||
def list_sessions_rich(self, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
return list(self.rows)
|
||||
|
||||
|
||||
def _call(limit: int = 20):
    """Invoke the gateway's ``session.list`` handler with the given page size."""
    request = {
        "id": "1",
        "method": "session.list",
        "params": {"limit": limit},
    }
    return server.handle_request(request)
|
||||
|
||||
|
||||
def test_session_list_includes_telegram_but_filters_internal_sources(monkeypatch):
    """Chat-adapter sessions (telegram) show in the picker; internal ones don't."""
    stub = _StubDB([
        {"id": "tui-1", "source": "tui", "started_at": 9},
        {"id": "tool-1", "source": "tool", "started_at": 8},
        {"id": "tg-1", "source": "telegram", "started_at": 7},
        {"id": "acp-1", "source": "acp", "started_at": 6},
        {"id": "cli-1", "source": "cli", "started_at": 5},
    ])
    monkeypatch.setattr(server, "_get_db", lambda: stub)

    ids = [row["id"] for row in _call(limit=10)["result"]["sessions"]]

    # User-facing sources survive the allowlist filter...
    assert "tg-1" in ids and "tui-1" in ids and "cli-1" in ids, ids
    # ...internal/system sources do not.
    assert "tool-1" not in ids and "acp-1" not in ids, ids
|
||||
|
||||
|
||||
def test_session_list_fetches_wider_window_before_filtering(monkeypatch):
    """The handler queries one unfiltered 100-row window, then filters in-process."""
    stub = _StubDB([{"id": "x", "source": "cli", "started_at": 1}])
    monkeypatch.setattr(server, "_get_db", lambda: stub)

    _call(limit=10)

    assert len(stub.calls) == 1
    (query,) = stub.calls
    assert query.get("source") is None, query
    assert query.get("limit") == 100, query
|
||||
|
||||
|
||||
def test_session_list_preserves_ordering_after_filter(monkeypatch):
    """Filtering must not disturb the DB's newest-first ordering."""
    rows = [
        {"id": "newest", "source": "telegram", "started_at": 5},
        {"id": "internal", "source": "tool", "started_at": 4},
        {"id": "middle", "source": "tui", "started_at": 3},
        {"id": "oldest", "source": "discord", "started_at": 1},
    ]
    monkeypatch.setattr(server, "_get_db", lambda: _StubDB(rows))

    sessions = _call()["result"]["sessions"]

    assert [s["id"] for s in sessions] == ["newest", "middle", "oldest"]
|
||||
|
|
@ -71,7 +71,11 @@ class TestProviderRegistry:
|
|||
|
||||
def test_kimi_env_vars(self):
    """kimi-coding reads KIMI_API_KEY (primary) with KIMI_CODING_API_KEY
    as a secondary fallback, and KIMI_BASE_URL for the endpoint.

    Fix: the old exact-tuple assert (``== ("KIMI_API_KEY",)``) contradicted
    the membership assert for KIMI_CODING_API_KEY below — if the tuple were
    exactly one element, the membership check could never pass. The exact
    shape is intentionally left open; only the required members are pinned.
    """
    pconfig = PROVIDER_REGISTRY["kimi-coding"]
    # KIMI_API_KEY is the primary env var; KIMI_CODING_API_KEY is a
    # secondary fallback for Kimi Code sk-kimi- keys so users don't
    # have to overload the same variable.
    assert "KIMI_API_KEY" in pconfig.api_key_env_vars
    assert "KIMI_CODING_API_KEY" in pconfig.api_key_env_vars
    assert pconfig.base_url_env_var == "KIMI_BASE_URL"
|
||||
|
||||
def test_minimax_env_vars(self):
|
||||
|
|
|
|||
90
tests/hermes_cli/test_at_context_completion_filter.py
Normal file
90
tests/hermes_cli/test_at_context_completion_filter.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""Regression test: `@folder:` completion must only surface directories and
|
||||
`@file:` must only surface regular files.
|
||||
|
||||
Reported during TUI v2 blitz testing: typing `@folder:` showed .dockerignore,
|
||||
.env, .gitignore, etc. alongside the actual directories because the path-
|
||||
completion branch yielded every entry regardless of the explicit prefix, and
|
||||
auto-switched the completion kind based on `is_dir`. That defeated the user's
|
||||
explicit choice and rendered the `@folder:` / `@file:` prefixes useless for
|
||||
filtering.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from hermes_cli.commands import SlashCommandCompleter
|
||||
|
||||
|
||||
def _run(tmp_path: Path, word: str) -> list[tuple[str, str]]:
    """Seed ``tmp_path`` with two files and two dirs, then collect the
    ``@file:``/``@folder:`` completions produced for ``word``."""
    for fname in ("readme.md", ".env"):
        (tmp_path / fname).write_text("x")
    for dname in ("src", "docs"):
        (tmp_path / dname).mkdir()

    # __new__ skips __init__ so no prompt-toolkit session is required.
    completer = SlashCommandCompleter.__new__(SlashCommandCompleter)
    produced: Iterable = completer._context_completions(word)

    prefixes = ("@file:", "@folder:")
    return [(c.text, c.display_meta) for c in produced if c.text.startswith(prefixes)]
|
||||
|
||||
|
||||
def test_at_folder_only_yields_directories(tmp_path, monkeypatch):
    """`@folder:` surfaces src/ and docs/ and no regular files."""
    monkeypatch.chdir(tmp_path)

    texts = [text for text, _meta in _run(tmp_path, "@folder:")]

    assert all(t.startswith("@folder:") for t in texts), texts
    assert "@folder:src/" in texts
    assert "@folder:docs/" in texts
    assert "@folder:readme.md" not in texts
    assert "@folder:.env" not in texts
|
||||
|
||||
|
||||
def test_at_file_only_yields_files(tmp_path, monkeypatch):
    """`@file:` surfaces regular files (dotfiles included) and no directories."""
    monkeypatch.chdir(tmp_path)

    texts = [text for text, _meta in _run(tmp_path, "@file:")]

    assert all(t.startswith("@file:") for t in texts), texts
    assert "@file:readme.md" in texts
    assert "@file:.env" in texts
    assert "@file:src/" not in texts
    assert "@file:docs/" not in texts
|
||||
|
||||
|
||||
def test_at_folder_preserves_prefix_on_empty_match(tmp_path, monkeypatch):
    """User typed `@folder:` (no partial) — every completion must keep the
    `@folder:` prefix. The previous implementation auto-rewrote non-dir
    entries to `@file:`, leaking the wrong prefix."""
    monkeypatch.chdir(tmp_path)

    texts = [text for text, _meta in _run(tmp_path, "@folder:")]

    assert texts, "expected at least one directory completion"
    for text in texts:
        assert text.startswith("@folder:"), f"prefix leaked: {text}"
|
||||
|
||||
|
||||
def test_at_folder_bare_without_colon_lists_directories(tmp_path, monkeypatch):
    """Typing `@folder` alone (no colon yet) already lists directories, so
    users don't have to accept the static `@folder:` hint first."""
    monkeypatch.chdir(tmp_path)

    texts = [text for text, _meta in _run(tmp_path, "@folder")]

    assert "@folder:src/" in texts, texts
    assert "@folder:docs/" in texts, texts
    assert "@folder:readme.md" not in texts
|
||||
|
||||
|
||||
def test_at_file_bare_without_colon_lists_files(tmp_path, monkeypatch):
    """Typing `@file` alone (no colon yet) already lists regular files."""
    monkeypatch.chdir(tmp_path)

    texts = [text for text, _meta in _run(tmp_path, "@file")]

    assert "@file:readme.md" in texts, texts
    assert "@file:src/" not in texts
|
||||
|
|
@ -376,7 +376,6 @@ class TestLoginNousSkipKeepsCurrent:
|
|||
lambda *a, **kw: prompt_returns,
|
||||
)
|
||||
monkeypatch.setattr(models_mod, "get_pricing_for_provider", lambda p: {})
|
||||
monkeypatch.setattr(models_mod, "filter_nous_free_models", lambda ids, p: ids)
|
||||
monkeypatch.setattr(models_mod, "check_nous_free_tier", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
models_mod, "partition_nous_models_by_tier",
|
||||
|
|
|
|||
36
tests/hermes_cli/test_config_drift.py
Normal file
36
tests/hermes_cli/test_config_drift.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Regression tests for removed dead config keys.
|
||||
|
||||
This file guards against accidental re-introduction of config keys that were
|
||||
documented or declared at some point but never actually wired up to read code.
|
||||
Future dead-config regressions can accumulate here.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
def test_delegation_default_toolsets_removed_from_cli_config():
    """delegation.default_toolsets was dead config — declared but never read
    by _load_config() or anything else — and was removed.

    Guard against re-introduction in cli.py's CLI_CONFIG defaults. We check
    load_cli_config()'s *source text* rather than the runtime CLI_CONFIG
    dict: CLI_CONFIG deep-merges the user's ~/.hermes/config.yaml over the
    defaults, so a contributor with the legacy key in their own config would
    trip a runtime assertion falsely. HERMES_HOME patching can't help either,
    because cli._hermes_home is frozen at module import time, before any
    autouse fixture runs. Inspecting the defaults literal sidesteps both.
    """
    from cli import load_cli_config

    defaults_src = inspect.getsource(load_cli_config)
    assert '"default_toolsets"' not in defaults_src, (
        "delegation.default_toolsets was removed because it was never read. "
        "Do not re-add it to cli.py's CLI_CONFIG default dict; "
        "use tools/delegate_tool.py's DEFAULT_TOOLSETS module constant or "
        "wire a new config key through _load_config()."
    )
|
||||
174
tests/hermes_cli/test_image_gen_picker.py
Normal file
174
tests/hermes_cli/test_image_gen_picker.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""Tests for plugin image_gen providers injecting themselves into the picker.
|
||||
|
||||
Covers `_plugin_image_gen_providers`, `_visible_providers`, and
|
||||
`_toolset_needs_configuration_prompt` handling of plugin providers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import image_gen_registry
|
||||
from agent.image_gen_provider import ImageGenProvider
|
||||
|
||||
|
||||
class _FakeProvider(ImageGenProvider):
    """Minimal in-memory ImageGenProvider for picker/registry tests.

    Defaults give each fake one model (``<name>-model-v1``) and a setup
    schema requiring a single ``<NAME>_API_KEY`` env var; both can be
    overridden per-test via ``schema``/``models``.
    """

    def __init__(self, name: str, available: bool = True, schema=None, models=None):
        self._name = name
        self._available = available
        self._schema = schema or {
            "name": name.title(),
            "badge": "test",
            "tag": f"{name} test tag",
            "env_vars": [{"key": f"{name.upper()}_API_KEY", "prompt": f"{name} key"}],
        }
        self._models = models or [
            {"id": f"{name}-model-v1", "display": f"{name} v1",
             "speed": "~5s", "strengths": "test", "price": "$"},
        ]

    @property
    def name(self) -> str:
        """Registry key for this provider."""
        return self._name

    def is_available(self) -> bool:
        # Fixed at construction time — lets tests simulate un/configured providers.
        return self._available

    def list_models(self):
        """Return a defensive copy of the model rows."""
        return list(self._models)

    def default_model(self):
        """First model id, or None when the model list is empty."""
        return self._models[0]["id"] if self._models else None

    def get_setup_schema(self):
        """Return a defensive copy of the setup schema dict."""
        return dict(self._schema)

    def generate(self, prompt, aspect_ratio="landscape", **kw):
        # Never does I/O — encodes the provider name and prompt into a fake URI.
        return {"success": True, "image": f"{self._name}://{prompt}"}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_registry():
    # Every test starts and ends with an empty image-gen provider registry,
    # so registrations in one test can't leak into another.
    image_gen_registry._reset_for_tests()
    yield
    image_gen_registry._reset_for_tests()
|
||||
|
||||
|
||||
class TestPluginPickerInjection:
    """Plugin image_gen providers must appear in the tools-config picker."""

    def test_plugin_providers_returns_registered(self, monkeypatch):
        """A registered plugin shows up with its display name and plugin key."""
        from hermes_cli import tools_config

        image_gen_registry.register_provider(_FakeProvider("myimg"))

        rows = tools_config._plugin_image_gen_providers()
        names = [r["name"] for r in rows]
        plugin_names = [r.get("image_gen_plugin_name") for r in rows]

        # Display name comes from the schema ("Myimg" via name.title()).
        assert "Myimg" in names
        assert "myimg" in plugin_names

    def test_fal_skipped_to_avoid_duplicate(self, monkeypatch):
        """A plugin named "fal" is suppressed; other plugins pass through."""
        from hermes_cli import tools_config

        # Simulate a FAL plugin being registered — the picker already has
        # hardcoded FAL rows in TOOL_CATEGORIES, so plugin-FAL must be
        # skipped to avoid showing FAL twice.
        image_gen_registry.register_provider(_FakeProvider("fal"))
        image_gen_registry.register_provider(_FakeProvider("openai"))

        rows = tools_config._plugin_image_gen_providers()
        names = [r.get("image_gen_plugin_name") for r in rows]
        assert "fal" not in names
        assert "openai" in names

    def test_visible_providers_includes_plugins_for_image_gen(self, monkeypatch):
        """_visible_providers merges plugin rows into the image_gen category."""
        from hermes_cli import tools_config

        image_gen_registry.register_provider(_FakeProvider("someimg"))

        cat = tools_config.TOOL_CATEGORIES["image_gen"]
        visible = tools_config._visible_providers(cat, {})
        plugin_names = [p.get("image_gen_plugin_name") for p in visible if p.get("image_gen_plugin_name")]
        assert "someimg" in plugin_names

    def test_visible_providers_does_not_inject_into_other_categories(self, monkeypatch):
        """Plugin injection is scoped to image_gen only."""
        from hermes_cli import tools_config

        image_gen_registry.register_provider(_FakeProvider("someimg"))

        # Browser category must NOT see image_gen plugins.
        browser = tools_config.TOOL_CATEGORIES["browser"]
        visible = tools_config._visible_providers(browser, {})
        assert all(p.get("image_gen_plugin_name") is None for p in visible)
|
||||
|
||||
|
||||
class TestPluginCatalog:
    """The per-plugin model catalog lookup helper."""

    def test_plugin_catalog_returns_models(self):
        """A registered plugin's models and default id come back from the catalog."""
        from hermes_cli import tools_config

        image_gen_registry.register_provider(_FakeProvider("catimg"))

        models, default_id = tools_config._plugin_image_gen_catalog("catimg")

        assert "catimg-model-v1" in models
        assert default_id == "catimg-model-v1"

    def test_plugin_catalog_empty_for_unknown(self):
        """Unknown plugin names yield an empty catalog and no default."""
        from hermes_cli import tools_config

        models, default_id = tools_config._plugin_image_gen_catalog("does-not-exist")

        assert models == {}
        assert default_id is None
|
||||
|
||||
|
||||
class TestConfigPrompt:
    """_toolset_needs_configuration_prompt must account for plugin providers."""

    def test_image_gen_satisfied_by_plugin_provider(self, monkeypatch, tmp_path):
        """When a plugin provider reports is_available(), the picker should
        not force a setup prompt on the user."""
        from hermes_cli import tools_config

        # Isolate from the developer's real config/keys.
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.delenv("FAL_KEY", raising=False)

        image_gen_registry.register_provider(_FakeProvider("avail-img", available=True))

        assert tools_config._toolset_needs_configuration_prompt("image_gen", {}) is False

    def test_image_gen_still_prompts_when_nothing_available(self, monkeypatch, tmp_path):
        """An unavailable plugin does not satisfy the requirement — prompt stays."""
        from hermes_cli import tools_config

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        monkeypatch.delenv("FAL_KEY", raising=False)

        image_gen_registry.register_provider(_FakeProvider("unavail-img", available=False))

        assert tools_config._toolset_needs_configuration_prompt("image_gen", {}) is True
|
||||
|
||||
|
||||
class TestConfigWriting:
    """Selecting a plugin provider must persist provider + model config."""

    def test_picking_plugin_provider_writes_provider_and_model(self, monkeypatch, tmp_path):
        """When a user picks a plugin-backed image_gen provider with no
        env vars needed, ``_configure_provider`` should write both
        ``image_gen.provider`` and ``image_gen.model``."""
        from hermes_cli import tools_config

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        # Schema with env_vars=[] → no key prompts during configuration.
        image_gen_registry.register_provider(_FakeProvider("noenv", schema={
            "name": "NoEnv",
            "badge": "free",
            "tag": "",
            "env_vars": [],
        }))

        # Stub out the interactive model picker — no TTY in tests.
        monkeypatch.setattr(tools_config, "_prompt_choice", lambda *a, **kw: 0)

        config: dict = {}
        provider_row = {
            "name": "NoEnv",
            "env_vars": [],
            "image_gen_plugin_name": "noenv",
        }
        tools_config._configure_provider(provider_row, config)

        assert config["image_gen"]["provider"] == "noenv"
        # Model defaults to the plugin's first (default) model id.
        assert config["image_gen"]["model"] == "noenv-model-v1"
|
||||
|
|
@ -4,7 +4,6 @@ from unittest.mock import patch, MagicMock
|
|||
|
||||
from hermes_cli.models import (
|
||||
OPENROUTER_MODELS, fetch_openrouter_models, model_ids, detect_provider_for_model,
|
||||
filter_nous_free_models, _NOUS_ALLOWED_FREE_MODELS,
|
||||
is_nous_free_tier, partition_nous_models_by_tier,
|
||||
check_nous_free_tier, _FREE_TIER_CACHE_TTL,
|
||||
)
|
||||
|
|
@ -293,89 +292,6 @@ class TestDetectProviderForModel:
|
|||
assert result[0] not in ("nous",) # nous has claude models but shouldn't be suggested
|
||||
|
||||
|
||||
class TestFilterNousFreeModels:
    """Tests for filter_nous_free_models — Nous Portal free-model policy.

    Policy under test: paid models always pass; free models pass only when
    allowlisted; allowlisted models pass only while they are actually free.
    """

    # Pricing dicts use per-token string prices, mirroring the Portal payload.
    _PAID = {"prompt": "0.000003", "completion": "0.000015"}
    _FREE = {"prompt": "0", "completion": "0"}

    def test_paid_models_kept(self):
        """Regular paid models pass through unchanged."""
        models = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]
        pricing = {m: self._PAID for m in models}
        assert filter_nous_free_models(models, pricing) == models

    def test_free_non_allowlist_models_removed(self):
        """Free models NOT in the allowlist are filtered out."""
        models = ["anthropic/claude-opus-4.6", "arcee-ai/trinity-large-preview:free"]
        pricing = {
            "anthropic/claude-opus-4.6": self._PAID,
            "arcee-ai/trinity-large-preview:free": self._FREE,
        }
        result = filter_nous_free_models(models, pricing)
        assert result == ["anthropic/claude-opus-4.6"]

    def test_allowlist_model_kept_when_free(self):
        """Allowlist models are kept when they report as free."""
        models = ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro"]
        pricing = {
            "anthropic/claude-opus-4.6": self._PAID,
            "xiaomi/mimo-v2-pro": self._FREE,
        }
        result = filter_nous_free_models(models, pricing)
        assert result == ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro"]

    def test_allowlist_model_removed_when_paid(self):
        """Allowlist models are removed when they are NOT free."""
        models = ["anthropic/claude-opus-4.6", "xiaomi/mimo-v2-pro"]
        pricing = {
            "anthropic/claude-opus-4.6": self._PAID,
            "xiaomi/mimo-v2-pro": self._PAID,
        }
        result = filter_nous_free_models(models, pricing)
        assert result == ["anthropic/claude-opus-4.6"]

    def test_no_pricing_returns_all(self):
        """When pricing data is unavailable, all models pass through."""
        models = ["anthropic/claude-opus-4.6", "nvidia/nemotron-3-super-120b-a12b:free"]
        assert filter_nous_free_models(models, {}) == models

    def test_model_with_no_pricing_entry_treated_as_paid(self):
        """A model missing from the pricing dict is kept (assumed paid)."""
        models = ["anthropic/claude-opus-4.6", "openai/gpt-5.4"]
        pricing = {"anthropic/claude-opus-4.6": self._PAID}  # gpt-5.4 not in pricing
        result = filter_nous_free_models(models, pricing)
        assert result == models

    def test_mixed_scenario(self):
        """End-to-end: mix of paid, free-allowed, free-disallowed, allowlist-not-free."""
        models = [
            "anthropic/claude-opus-4.6",  # paid, not allowlist → keep
            "nvidia/nemotron-3-super-120b-a12b:free",  # free, not allowlist → drop
            "xiaomi/mimo-v2-pro",  # free, allowlist → keep
            "xiaomi/mimo-v2-omni",  # paid, allowlist → drop
            "openai/gpt-5.4",  # paid, not allowlist → keep
        ]
        pricing = {
            "anthropic/claude-opus-4.6": self._PAID,
            "nvidia/nemotron-3-super-120b-a12b:free": self._FREE,
            "xiaomi/mimo-v2-pro": self._FREE,
            "xiaomi/mimo-v2-omni": self._PAID,
            "openai/gpt-5.4": self._PAID,
        }
        result = filter_nous_free_models(models, pricing)
        assert result == [
            "anthropic/claude-opus-4.6",
            "xiaomi/mimo-v2-pro",
            "openai/gpt-5.4",
        ]

    def test_allowlist_contains_expected_models(self):
        """Sanity: the allowlist has the models we expect."""
        assert "xiaomi/mimo-v2-pro" in _NOUS_ALLOWED_FREE_MODELS
        assert "xiaomi/mimo-v2-omni" in _NOUS_ALLOWED_FREE_MODELS
|
||||
|
||||
|
||||
class TestIsNousFreeTier:
|
||||
"""Tests for is_nous_free_tier — account tier detection."""
|
||||
|
||||
|
|
@ -501,3 +417,190 @@ class TestCheckNousFreeTierCache:
|
|||
def test_cache_ttl_is_short(self):
    """TTL should be short enough to catch upgrades quickly (<=5 min)."""
    # A long TTL would let a cached "free tier" verdict mask an account upgrade.
    assert _FREE_TIER_CACHE_TTL <= 300
|
||||
|
||||
|
||||
class TestNousRecommendedModels:
    """Tests for fetch_nous_recommended_models + get_nous_recommended_aux_model."""

    # Portal payload shape: paid/free recommendation slots, each either None
    # or a {"modelName": ..., "displayName": ...} dict.
    _SAMPLE_PAYLOAD = {
        "paidRecommendedModels": [],
        "freeRecommendedModels": [],
        "paidRecommendedCompactionModel": None,
        "paidRecommendedVisionModel": None,
        "freeRecommendedCompactionModel": {
            "modelName": "google/gemini-3-flash-preview",
            "displayName": "Google: Gemini 3 Flash Preview",
        },
        "freeRecommendedVisionModel": {
            "modelName": "google/gemini-3-flash-preview",
            "displayName": "Google: Gemini 3 Flash Preview",
        },
    }

    def setup_method(self):
        # Cold cache before each test so urlopen call-count assertions hold.
        _models_mod._nous_recommended_cache.clear()

    def teardown_method(self):
        _models_mod._nous_recommended_cache.clear()

    def _mock_urlopen(self, payload):
        """Return a context-manager mock mimicking urllib.request.urlopen()."""
        import json as _json
        response = MagicMock()
        response.read.return_value = _json.dumps(payload).encode()
        cm = MagicMock()
        cm.__enter__.return_value = response
        cm.__exit__.return_value = False
        return cm

    def test_fetch_caches_per_portal_url(self):
        """Two fetches against the same portal hit the network only once."""
        from hermes_cli.models import fetch_nous_recommended_models
        mock_cm = self._mock_urlopen(self._SAMPLE_PAYLOAD)
        with patch("urllib.request.urlopen", return_value=mock_cm) as mock_urlopen:
            a = fetch_nous_recommended_models("https://portal.example.com")
            b = fetch_nous_recommended_models("https://portal.example.com")
        assert a == self._SAMPLE_PAYLOAD
        assert b == self._SAMPLE_PAYLOAD
        assert mock_urlopen.call_count == 1  # second call served from cache

    def test_fetch_cache_is_keyed_per_portal(self):
        """Different portal URLs do not share a cache entry."""
        from hermes_cli.models import fetch_nous_recommended_models
        mock_cm = self._mock_urlopen(self._SAMPLE_PAYLOAD)
        with patch("urllib.request.urlopen", return_value=mock_cm) as mock_urlopen:
            fetch_nous_recommended_models("https://portal.example.com")
            fetch_nous_recommended_models("https://portal.staging-nousresearch.com")
        assert mock_urlopen.call_count == 2  # different portals → separate fetches

    def test_fetch_returns_empty_on_network_failure(self):
        """Network errors degrade to an empty payload instead of raising."""
        from hermes_cli.models import fetch_nous_recommended_models
        with patch("urllib.request.urlopen", side_effect=OSError("boom")):
            result = fetch_nous_recommended_models("https://portal.example.com")
        assert result == {}

    def test_fetch_force_refresh_bypasses_cache(self):
        """force_refresh=True re-fetches even with a warm cache."""
        from hermes_cli.models import fetch_nous_recommended_models
        mock_cm = self._mock_urlopen(self._SAMPLE_PAYLOAD)
        with patch("urllib.request.urlopen", return_value=mock_cm) as mock_urlopen:
            fetch_nous_recommended_models("https://portal.example.com")
            fetch_nous_recommended_models("https://portal.example.com", force_refresh=True)
        assert mock_urlopen.call_count == 2

    def test_get_aux_model_returns_vision_recommendation(self):
        """vision=True picks the vision recommendation slot."""
        from hermes_cli.models import get_nous_recommended_aux_model
        with patch(
            "hermes_cli.models.fetch_nous_recommended_models",
            return_value=self._SAMPLE_PAYLOAD,
        ):
            # Free tier → free vision recommendation.
            model = get_nous_recommended_aux_model(vision=True, free_tier=True)
        assert model == "google/gemini-3-flash-preview"

    def test_get_aux_model_returns_compaction_recommendation(self):
        """vision=False picks the compaction recommendation slot."""
        from hermes_cli.models import get_nous_recommended_aux_model
        payload = dict(self._SAMPLE_PAYLOAD)
        payload["freeRecommendedCompactionModel"] = {"modelName": "minimax/minimax-m2.7"}
        with patch(
            "hermes_cli.models.fetch_nous_recommended_models",
            return_value=payload,
        ):
            model = get_nous_recommended_aux_model(vision=False, free_tier=True)
        assert model == "minimax/minimax-m2.7"

    def test_get_aux_model_returns_none_when_field_null(self):
        """A null recommendation slot yields None, not a crash."""
        from hermes_cli.models import get_nous_recommended_aux_model
        payload = dict(self._SAMPLE_PAYLOAD)
        payload["freeRecommendedCompactionModel"] = None
        with patch(
            "hermes_cli.models.fetch_nous_recommended_models",
            return_value=payload,
        ):
            model = get_nous_recommended_aux_model(vision=False, free_tier=True)
        assert model is None

    def test_get_aux_model_returns_none_on_empty_payload(self):
        """An empty payload (e.g. fetch failure) yields None for every combo."""
        from hermes_cli.models import get_nous_recommended_aux_model
        with patch("hermes_cli.models.fetch_nous_recommended_models", return_value={}):
            assert get_nous_recommended_aux_model(vision=False, free_tier=True) is None
            assert get_nous_recommended_aux_model(vision=True, free_tier=False) is None

    def test_get_aux_model_returns_none_when_modelname_blank(self):
        """Whitespace-only modelName is treated as missing."""
        from hermes_cli.models import get_nous_recommended_aux_model
        payload = {"freeRecommendedCompactionModel": {"modelName": " "}}
        with patch(
            "hermes_cli.models.fetch_nous_recommended_models",
            return_value=payload,
        ):
            assert get_nous_recommended_aux_model(vision=False, free_tier=True) is None

    def test_paid_tier_prefers_paid_recommendation(self):
        """Paid-tier users should get the paid model when it's populated."""
        from hermes_cli.models import get_nous_recommended_aux_model
        payload = {
            "paidRecommendedCompactionModel": {"modelName": "anthropic/claude-opus-4.7"},
            "freeRecommendedCompactionModel": {"modelName": "google/gemini-3-flash-preview"},
            "paidRecommendedVisionModel": {"modelName": "openai/gpt-5.4"},
            "freeRecommendedVisionModel": {"modelName": "google/gemini-3-flash-preview"},
        }
        with patch("hermes_cli.models.fetch_nous_recommended_models", return_value=payload):
            text = get_nous_recommended_aux_model(vision=False, free_tier=False)
            vision = get_nous_recommended_aux_model(vision=True, free_tier=False)
        assert text == "anthropic/claude-opus-4.7"
        assert vision == "openai/gpt-5.4"

    def test_paid_tier_falls_back_to_free_when_paid_is_null(self):
        """If the Portal returns null for the paid field, fall back to free."""
        from hermes_cli.models import get_nous_recommended_aux_model
        payload = {
            "paidRecommendedCompactionModel": None,
            "freeRecommendedCompactionModel": {"modelName": "google/gemini-3-flash-preview"},
            "paidRecommendedVisionModel": None,
            "freeRecommendedVisionModel": {"modelName": "google/gemini-3-flash-preview"},
        }
        with patch("hermes_cli.models.fetch_nous_recommended_models", return_value=payload):
            text = get_nous_recommended_aux_model(vision=False, free_tier=False)
            vision = get_nous_recommended_aux_model(vision=True, free_tier=False)
        assert text == "google/gemini-3-flash-preview"
        assert vision == "google/gemini-3-flash-preview"

    def test_free_tier_never_uses_paid_recommendation(self):
        """Free-tier users must not get paid-only recommendations."""
        from hermes_cli.models import get_nous_recommended_aux_model
        payload = {
            "paidRecommendedCompactionModel": {"modelName": "anthropic/claude-opus-4.7"},
            "freeRecommendedCompactionModel": None,  # no free recommendation
        }
        with patch("hermes_cli.models.fetch_nous_recommended_models", return_value=payload):
            model = get_nous_recommended_aux_model(vision=False, free_tier=True)
        # Free tier must return None — never leak the paid model.
        assert model is None

    def test_auto_detects_tier_when_not_supplied(self):
        """Default behaviour: call check_nous_free_tier() to pick the tier."""
        from hermes_cli.models import get_nous_recommended_aux_model
        payload = {
            "paidRecommendedCompactionModel": {"modelName": "paid-model"},
            "freeRecommendedCompactionModel": {"modelName": "free-model"},
        }
        with (
            patch("hermes_cli.models.fetch_nous_recommended_models", return_value=payload),
            patch("hermes_cli.models.check_nous_free_tier", return_value=True),
        ):
            assert get_nous_recommended_aux_model(vision=False) == "free-model"
        with (
            patch("hermes_cli.models.fetch_nous_recommended_models", return_value=payload),
            patch("hermes_cli.models.check_nous_free_tier", return_value=False),
        ):
            assert get_nous_recommended_aux_model(vision=False) == "paid-model"

    def test_tier_detection_error_defaults_to_paid(self):
        """If tier detection raises, assume paid so we don't downgrade silently."""
        from hermes_cli.models import get_nous_recommended_aux_model
        payload = {
            "paidRecommendedCompactionModel": {"modelName": "paid-model"},
            "freeRecommendedCompactionModel": {"modelName": "free-model"},
        }
        with (
            patch("hermes_cli.models.fetch_nous_recommended_models", return_value=payload),
            patch("hermes_cli.models.check_nous_free_tier", side_effect=RuntimeError("boom")),
        ):
            assert get_nous_recommended_aux_model(vision=False) == "paid-model"
|
||||
|
|
|
|||
357
tests/hermes_cli/test_plugin_scanner_recursion.py
Normal file
357
tests/hermes_cli/test_plugin_scanner_recursion.py
Normal file
|
|
@ -0,0 +1,357 @@
|
|||
"""Tests for PR1 pluggable image gen: scanner recursion, kinds, path keys.
|
||||
|
||||
Covers ``_scan_directory`` recursion into category namespaces
|
||||
(``plugins/image_gen/openai/``), ``kind`` parsing, path-derived registry
|
||||
keys, and the new gate logic (bundled backends auto-load; user backends
|
||||
still opt-in; exclusive kind skipped; unknown kinds → standalone warning).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from hermes_cli.plugins import PluginManager, PluginManifest
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _write_plugin(
|
||||
root: Path,
|
||||
segments: list[str],
|
||||
*,
|
||||
manifest_extra: Dict[str, Any] | None = None,
|
||||
register_body: str = "pass",
|
||||
) -> Path:
|
||||
"""Create a plugin dir at ``root/<segments...>/`` with plugin.yaml + __init__.py.
|
||||
|
||||
``segments`` lets tests build both flat (``["my-plugin"]``) and
|
||||
category-namespaced (``["image_gen", "openai"]``) layouts.
|
||||
"""
|
||||
plugin_dir = root
|
||||
for seg in segments:
|
||||
plugin_dir = plugin_dir / seg
|
||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
manifest = {
|
||||
"name": segments[-1],
|
||||
"version": "0.1.0",
|
||||
"description": f"Test plugin {'/'.join(segments)}",
|
||||
}
|
||||
if manifest_extra:
|
||||
manifest.update(manifest_extra)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump(manifest))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
f"def register(ctx):\n {register_body}\n"
|
||||
)
|
||||
return plugin_dir
|
||||
|
||||
|
||||
def _enable(hermes_home: Path, name: str) -> None:
|
||||
"""Append ``name`` to ``plugins.enabled`` in ``<hermes_home>/config.yaml``."""
|
||||
cfg_path = hermes_home / "config.yaml"
|
||||
cfg: dict = {}
|
||||
if cfg_path.exists():
|
||||
try:
|
||||
cfg = yaml.safe_load(cfg_path.read_text()) or {}
|
||||
except Exception:
|
||||
cfg = {}
|
||||
plugins_cfg = cfg.setdefault("plugins", {})
|
||||
enabled = plugins_cfg.setdefault("enabled", [])
|
||||
if isinstance(enabled, list) and name not in enabled:
|
||||
enabled.append(name)
|
||||
cfg_path.write_text(yaml.safe_dump(cfg))
|
||||
|
||||
|
||||
# ── Scanner recursion ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCategoryNamespaceRecursion:
|
||||
def test_category_namespace_discovered(self, tmp_path, monkeypatch):
|
||||
"""``<root>/image_gen/openai/plugin.yaml`` is discovered with key
|
||||
``image_gen/openai`` when the ``image_gen`` parent has no manifest."""
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
user_plugins = hermes_home / "plugins"
|
||||
|
||||
_write_plugin(user_plugins, ["image_gen", "openai"])
|
||||
_enable(hermes_home, "image_gen/openai")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "image_gen/openai" in mgr._plugins
|
||||
loaded = mgr._plugins["image_gen/openai"]
|
||||
assert loaded.manifest.key == "image_gen/openai"
|
||||
assert loaded.manifest.name == "openai"
|
||||
assert loaded.enabled is True
|
||||
|
||||
def test_flat_plugin_key_matches_name(self, tmp_path, monkeypatch):
|
||||
"""Flat plugins keep their bare name as the key (back-compat)."""
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
user_plugins = hermes_home / "plugins"
|
||||
|
||||
_write_plugin(user_plugins, ["my-plugin"])
|
||||
_enable(hermes_home, "my-plugin")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "my-plugin" in mgr._plugins
|
||||
assert mgr._plugins["my-plugin"].manifest.key == "my-plugin"
|
||||
|
||||
def test_depth_cap_two(self, tmp_path, monkeypatch):
|
||||
"""Plugins nested three levels deep are not discovered.
|
||||
|
||||
``<root>/a/b/c/plugin.yaml`` should NOT be picked up — cap is
|
||||
two segments.
|
||||
"""
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
user_plugins = hermes_home / "plugins"
|
||||
|
||||
_write_plugin(user_plugins, ["a", "b", "c"])
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
non_bundled = [
|
||||
k for k, p in mgr._plugins.items()
|
||||
if p.manifest.source != "bundled"
|
||||
]
|
||||
assert non_bundled == []
|
||||
|
||||
def test_category_dir_with_manifest_is_leaf(self, tmp_path, monkeypatch):
|
||||
"""If ``image_gen/plugin.yaml`` exists, ``image_gen`` itself IS the
|
||||
plugin and its children are ignored."""
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
user_plugins = hermes_home / "plugins"
|
||||
|
||||
# parent has a manifest → stop recursing
|
||||
_write_plugin(user_plugins, ["image_gen"])
|
||||
# child also has a manifest — should NOT be found because we stop
|
||||
# at the parent.
|
||||
_write_plugin(user_plugins, ["image_gen", "openai"])
|
||||
_enable(hermes_home, "image_gen")
|
||||
_enable(hermes_home, "image_gen/openai")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# The bundled plugins/image_gen/openai/ exists in the repo — filter
|
||||
# it out so we're only asserting on the user-dir layout.
|
||||
user_plugins_in_registry = {
|
||||
k for k, p in mgr._plugins.items() if p.manifest.source != "bundled"
|
||||
}
|
||||
assert "image_gen" in user_plugins_in_registry
|
||||
assert "image_gen/openai" not in user_plugins_in_registry
|
||||
|
||||
|
||||
# ── Kind parsing ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestKindField:
|
||||
def test_default_kind_is_standalone(self, tmp_path, monkeypatch):
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
_write_plugin(hermes_home / "plugins", ["p1"])
|
||||
_enable(hermes_home, "p1")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert mgr._plugins["p1"].manifest.kind == "standalone"
|
||||
|
||||
@pytest.mark.parametrize("kind", ["backend", "exclusive", "standalone"])
|
||||
def test_valid_kinds_parsed(self, kind, tmp_path, monkeypatch):
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
["p1"],
|
||||
manifest_extra={"kind": kind},
|
||||
)
|
||||
# Not all kinds auto-load, but manifest should parse.
|
||||
_enable(hermes_home, "p1")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "p1" in mgr._plugins
|
||||
assert mgr._plugins["p1"].manifest.kind == kind
|
||||
|
||||
def test_unknown_kind_falls_back_to_standalone(self, tmp_path, monkeypatch, caplog):
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
["p1"],
|
||||
manifest_extra={"kind": "bogus"},
|
||||
)
|
||||
_enable(hermes_home, "p1")
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert mgr._plugins["p1"].manifest.kind == "standalone"
|
||||
assert any(
|
||||
"unknown kind" in rec.getMessage() for rec in caplog.records
|
||||
)
|
||||
|
||||
|
||||
# ── Gate logic ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBackendGate:
|
||||
def test_user_backend_still_gated_by_enabled(self, tmp_path, monkeypatch):
|
||||
"""User-installed ``kind: backend`` plugins still require opt-in —
|
||||
they're not trusted by default."""
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
user_plugins = hermes_home / "plugins"
|
||||
|
||||
_write_plugin(
|
||||
user_plugins,
|
||||
["image_gen", "fancy"],
|
||||
manifest_extra={"kind": "backend"},
|
||||
)
|
||||
# Do NOT opt in.
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
loaded = mgr._plugins["image_gen/fancy"]
|
||||
assert loaded.enabled is False
|
||||
assert "not enabled" in (loaded.error or "")
|
||||
|
||||
def test_user_backend_loads_when_enabled(self, tmp_path, monkeypatch):
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
user_plugins = hermes_home / "plugins"
|
||||
|
||||
_write_plugin(
|
||||
user_plugins,
|
||||
["image_gen", "fancy"],
|
||||
manifest_extra={"kind": "backend"},
|
||||
)
|
||||
_enable(hermes_home, "image_gen/fancy")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert mgr._plugins["image_gen/fancy"].enabled is True
|
||||
|
||||
def test_exclusive_kind_skipped(self, tmp_path, monkeypatch):
|
||||
"""``kind: exclusive`` plugins are recorded but not loaded — the
|
||||
category's own discovery system handles them (memory today)."""
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
["some-backend"],
|
||||
manifest_extra={"kind": "exclusive"},
|
||||
)
|
||||
_enable(hermes_home, "some-backend")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
loaded = mgr._plugins["some-backend"]
|
||||
assert loaded.enabled is False
|
||||
assert "exclusive" in (loaded.error or "")
|
||||
|
||||
|
||||
# ── Bundled backend auto-load (integration with real bundled plugin) ────────
|
||||
|
||||
|
||||
class TestBundledBackendAutoLoad:
|
||||
def test_bundled_image_gen_openai_autoloads(self, tmp_path, monkeypatch):
|
||||
"""The bundled ``plugins/image_gen/openai/`` plugin loads without
|
||||
any opt-in — it's ``kind: backend`` and shipped in-repo."""
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "image_gen/openai" in mgr._plugins
|
||||
loaded = mgr._plugins["image_gen/openai"]
|
||||
assert loaded.manifest.source == "bundled"
|
||||
assert loaded.manifest.kind == "backend"
|
||||
assert loaded.enabled is True, f"error: {loaded.error}"
|
||||
|
||||
|
||||
# ── PluginContext.register_image_gen_provider ───────────────────────────────
|
||||
|
||||
|
||||
class TestRegisterImageGenProvider:
|
||||
def test_accepts_valid_provider(self, tmp_path, monkeypatch):
|
||||
from agent import image_gen_registry
|
||||
from agent.image_gen_provider import ImageGenProvider
|
||||
|
||||
image_gen_registry._reset_for_tests()
|
||||
|
||||
class FakeProvider(ImageGenProvider):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "fake-test"
|
||||
|
||||
def generate(self, prompt, aspect_ratio="landscape", **kw):
|
||||
return {"success": True, "image": "test://fake"}
|
||||
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
plugin_dir = _write_plugin(
|
||||
hermes_home / "plugins",
|
||||
["my-img-plugin"],
|
||||
register_body=(
|
||||
"from agent.image_gen_provider import ImageGenProvider\n"
|
||||
" class P(ImageGenProvider):\n"
|
||||
" @property\n"
|
||||
" def name(self): return 'fake-ctx'\n"
|
||||
" def generate(self, prompt, aspect_ratio='landscape', **kw):\n"
|
||||
" return {'success': True, 'image': 'x://y'}\n"
|
||||
" ctx.register_image_gen_provider(P())"
|
||||
),
|
||||
)
|
||||
_enable(hermes_home, "my-img-plugin")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert mgr._plugins["my-img-plugin"].enabled is True
|
||||
assert image_gen_registry.get_provider("fake-ctx") is not None
|
||||
|
||||
image_gen_registry._reset_for_tests()
|
||||
|
||||
def test_rejects_non_provider(self, tmp_path, monkeypatch, caplog):
|
||||
from agent import image_gen_registry
|
||||
|
||||
image_gen_registry._reset_for_tests()
|
||||
|
||||
import os
|
||||
hermes_home = Path(os.environ["HERMES_HOME"]) # set by hermetic conftest fixture
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
["bad-img-plugin"],
|
||||
register_body="ctx.register_image_gen_provider('not a provider')",
|
||||
)
|
||||
_enable(hermes_home, "bad-img-plugin")
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Plugin loaded (register returned normally) but nothing was
|
||||
# registered in the provider registry.
|
||||
assert mgr._plugins["bad-img-plugin"].enabled is True
|
||||
assert image_gen_registry.get_provider("not a provider") is None
|
||||
|
||||
image_gen_registry._reset_for_tests()
|
||||
0
tests/plugins/image_gen/__init__.py
Normal file
0
tests/plugins/image_gen/__init__.py
Normal file
243
tests/plugins/image_gen/test_openai_provider.py
Normal file
243
tests/plugins/image_gen/test_openai_provider.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
"""Tests for the bundled OpenAI image_gen plugin (gpt-image-2, three tiers)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import plugins.image_gen.openai as openai_plugin
|
||||
|
||||
|
||||
# 1×1 transparent PNG — valid bytes for save_b64_image()
|
||||
_PNG_HEX = (
|
||||
"89504e470d0a1a0a0000000d49484452000000010000000108060000001f15c4"
|
||||
"890000000d49444154789c6300010000000500010d0a2db40000000049454e44"
|
||||
"ae426082"
|
||||
)
|
||||
|
||||
|
||||
def _b64_png() -> str:
|
||||
import base64
|
||||
return base64.b64encode(bytes.fromhex(_PNG_HEX)).decode()
|
||||
|
||||
|
||||
def _fake_response(*, b64=None, url=None, revised_prompt=None):
|
||||
item = SimpleNamespace(b64_json=b64, url=url, revised_prompt=revised_prompt)
|
||||
return SimpleNamespace(data=[item])
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _tmp_hermes_home(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
yield tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider(monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-key")
|
||||
return openai_plugin.OpenAIImageGenProvider()
|
||||
|
||||
|
||||
def _patched_openai(fake_client: MagicMock):
|
||||
fake_openai = MagicMock()
|
||||
fake_openai.OpenAI.return_value = fake_client
|
||||
return patch.dict("sys.modules", {"openai": fake_openai})
|
||||
|
||||
|
||||
# ── Metadata ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMetadata:
|
||||
def test_name(self, provider):
|
||||
assert provider.name == "openai"
|
||||
|
||||
def test_default_model(self, provider):
|
||||
assert provider.default_model() == "gpt-image-2-medium"
|
||||
|
||||
def test_list_models_three_tiers(self, provider):
|
||||
ids = [m["id"] for m in provider.list_models()]
|
||||
assert ids == ["gpt-image-2-low", "gpt-image-2-medium", "gpt-image-2-high"]
|
||||
|
||||
def test_catalog_entries_have_display_speed_strengths(self, provider):
|
||||
for entry in provider.list_models():
|
||||
assert entry["display"].startswith("GPT Image 2")
|
||||
assert entry["speed"]
|
||||
assert entry["strengths"]
|
||||
|
||||
|
||||
# ── Availability ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAvailability:
|
||||
def test_no_api_key_unavailable(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
assert openai_plugin.OpenAIImageGenProvider().is_available() is False
|
||||
|
||||
def test_api_key_set_available(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test")
|
||||
assert openai_plugin.OpenAIImageGenProvider().is_available() is True
|
||||
|
||||
|
||||
# ── Model resolution ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestModelResolution:
|
||||
def test_default_is_medium(self):
|
||||
model_id, meta = openai_plugin._resolve_model()
|
||||
assert model_id == "gpt-image-2-medium"
|
||||
assert meta["quality"] == "medium"
|
||||
|
||||
def test_env_var_override(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_IMAGE_MODEL", "gpt-image-2-high")
|
||||
model_id, meta = openai_plugin._resolve_model()
|
||||
assert model_id == "gpt-image-2-high"
|
||||
assert meta["quality"] == "high"
|
||||
|
||||
def test_env_var_unknown_falls_back(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_IMAGE_MODEL", "bogus-tier")
|
||||
model_id, _ = openai_plugin._resolve_model()
|
||||
assert model_id == openai_plugin.DEFAULT_MODEL
|
||||
|
||||
def test_config_openai_model(self, tmp_path):
|
||||
import yaml
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.safe_dump({"image_gen": {"openai": {"model": "gpt-image-2-low"}}})
|
||||
)
|
||||
model_id, meta = openai_plugin._resolve_model()
|
||||
assert model_id == "gpt-image-2-low"
|
||||
assert meta["quality"] == "low"
|
||||
|
||||
def test_config_top_level_model(self, tmp_path):
|
||||
"""``image_gen.model: gpt-image-2-high`` also works (top-level)."""
|
||||
import yaml
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.safe_dump({"image_gen": {"model": "gpt-image-2-high"}})
|
||||
)
|
||||
model_id, meta = openai_plugin._resolve_model()
|
||||
assert model_id == "gpt-image-2-high"
|
||||
assert meta["quality"] == "high"
|
||||
|
||||
|
||||
# ── Generate ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGenerate:
|
||||
def test_empty_prompt_rejected(self, provider):
|
||||
result = provider.generate("", aspect_ratio="square")
|
||||
assert result["success"] is False
|
||||
assert result["error_type"] == "invalid_argument"
|
||||
|
||||
def test_missing_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
result = openai_plugin.OpenAIImageGenProvider().generate("a cat")
|
||||
assert result["success"] is False
|
||||
assert result["error_type"] == "auth_required"
|
||||
|
||||
def test_b64_saves_to_cache(self, provider, tmp_path):
|
||||
import base64
|
||||
png_bytes = bytes.fromhex(_PNG_HEX)
|
||||
fake_client = MagicMock()
|
||||
fake_client.images.generate.return_value = _fake_response(b64=_b64_png())
|
||||
|
||||
with _patched_openai(fake_client):
|
||||
result = provider.generate("a cat", aspect_ratio="landscape")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["model"] == "gpt-image-2-medium"
|
||||
assert result["aspect_ratio"] == "landscape"
|
||||
assert result["provider"] == "openai"
|
||||
assert result["quality"] == "medium"
|
||||
|
||||
saved = Path(result["image"])
|
||||
assert saved.exists()
|
||||
assert saved.parent == tmp_path / "cache" / "images"
|
||||
assert saved.read_bytes() == png_bytes
|
||||
|
||||
call_kwargs = fake_client.images.generate.call_args.kwargs
|
||||
# All tiers hit the single underlying API model.
|
||||
assert call_kwargs["model"] == "gpt-image-2"
|
||||
assert call_kwargs["quality"] == "medium"
|
||||
assert call_kwargs["size"] == "1536x1024"
|
||||
# gpt-image-2 rejects response_format — we must NOT send it.
|
||||
assert "response_format" not in call_kwargs
|
||||
|
||||
@pytest.mark.parametrize("tier,expected_quality", [
|
||||
("gpt-image-2-low", "low"),
|
||||
("gpt-image-2-medium", "medium"),
|
||||
("gpt-image-2-high", "high"),
|
||||
])
|
||||
def test_tier_maps_to_quality(self, provider, monkeypatch, tier, expected_quality):
|
||||
monkeypatch.setenv("OPENAI_IMAGE_MODEL", tier)
|
||||
fake_client = MagicMock()
|
||||
fake_client.images.generate.return_value = _fake_response(b64=_b64_png())
|
||||
|
||||
with _patched_openai(fake_client):
|
||||
result = provider.generate("a cat")
|
||||
|
||||
assert result["model"] == tier
|
||||
assert result["quality"] == expected_quality
|
||||
assert fake_client.images.generate.call_args.kwargs["quality"] == expected_quality
|
||||
# Always the same underlying API model regardless of tier.
|
||||
assert fake_client.images.generate.call_args.kwargs["model"] == "gpt-image-2"
|
||||
|
||||
@pytest.mark.parametrize("aspect,expected_size", [
|
||||
("landscape", "1536x1024"),
|
||||
("square", "1024x1024"),
|
||||
("portrait", "1024x1536"),
|
||||
])
|
||||
def test_aspect_ratio_mapping(self, provider, aspect, expected_size):
|
||||
fake_client = MagicMock()
|
||||
fake_client.images.generate.return_value = _fake_response(b64=_b64_png())
|
||||
|
||||
with _patched_openai(fake_client):
|
||||
provider.generate("a cat", aspect_ratio=aspect)
|
||||
|
||||
assert fake_client.images.generate.call_args.kwargs["size"] == expected_size
|
||||
|
||||
def test_revised_prompt_passed_through(self, provider):
|
||||
fake_client = MagicMock()
|
||||
fake_client.images.generate.return_value = _fake_response(
|
||||
b64=_b64_png(), revised_prompt="A photo of a cat",
|
||||
)
|
||||
|
||||
with _patched_openai(fake_client):
|
||||
result = provider.generate("a cat")
|
||||
|
||||
assert result["revised_prompt"] == "A photo of a cat"
|
||||
|
||||
def test_api_error_returns_error_response(self, provider):
|
||||
fake_client = MagicMock()
|
||||
fake_client.images.generate.side_effect = RuntimeError("boom")
|
||||
|
||||
with _patched_openai(fake_client):
|
||||
result = provider.generate("a cat")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error_type"] == "api_error"
|
||||
assert "boom" in result["error"]
|
||||
|
||||
def test_empty_response_data(self, provider):
|
||||
fake_client = MagicMock()
|
||||
fake_client.images.generate.return_value = SimpleNamespace(data=[])
|
||||
|
||||
with _patched_openai(fake_client):
|
||||
result = provider.generate("a cat")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error_type"] == "empty_response"
|
||||
|
||||
def test_url_fallback_if_api_changes(self, provider):
|
||||
"""Defensive: if OpenAI ever returns URL instead of b64, pass through."""
|
||||
fake_client = MagicMock()
|
||||
fake_client.images.generate.return_value = _fake_response(
|
||||
b64=None, url="https://example.com/img.png",
|
||||
)
|
||||
|
||||
with _patched_openai(fake_client):
|
||||
result = provider.generate("a cat")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["image"] == "https://example.com/img.png"
|
||||
|
|
@ -12,6 +12,7 @@ from types import SimpleNamespace
|
|||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input, _normalize_codex_response, _preflight_codex_input_items
|
||||
|
||||
sys.modules.setdefault("fire", types.SimpleNamespace(Fire=lambda *a, **k: None))
|
||||
sys.modules.setdefault("firecrawl", types.SimpleNamespace(Firecrawl=object))
|
||||
|
|
@ -446,7 +447,7 @@ class TestChatMessagesToResponsesInput:
|
|||
agent = _make_agent(monkeypatch, "openai-codex", api_mode="codex_responses",
|
||||
base_url="https://chatgpt.com/backend-api/codex")
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
assert items == [{"role": "user", "content": "hello"}]
|
||||
|
||||
def test_system_messages_filtered(self, monkeypatch):
|
||||
|
|
@ -456,7 +457,7 @@ class TestChatMessagesToResponsesInput:
|
|||
{"role": "system", "content": "be helpful"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
assert len(items) == 1
|
||||
assert items[0]["role"] == "user"
|
||||
|
||||
|
|
@ -472,7 +473,7 @@ class TestChatMessagesToResponsesInput:
|
|||
"function": {"name": "web_search", "arguments": '{"query": "test"}'},
|
||||
}],
|
||||
}]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
fc_items = [i for i in items if i.get("type") == "function_call"]
|
||||
assert len(fc_items) == 1
|
||||
assert fc_items[0]["name"] == "web_search"
|
||||
|
|
@ -482,7 +483,7 @@ class TestChatMessagesToResponsesInput:
|
|||
agent = _make_agent(monkeypatch, "openai-codex", api_mode="codex_responses",
|
||||
base_url="https://chatgpt.com/backend-api/codex")
|
||||
messages = [{"role": "tool", "tool_call_id": "call_abc", "content": "result here"}]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
assert items[0]["type"] == "function_call_output"
|
||||
assert items[0]["call_id"] == "call_abc"
|
||||
assert items[0]["output"] == "result here"
|
||||
|
|
@ -502,7 +503,7 @@ class TestChatMessagesToResponsesInput:
|
|||
},
|
||||
{"role": "user", "content": "continue"},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
reasoning_items = [i for i in items if i.get("type") == "reasoning"]
|
||||
assert len(reasoning_items) == 1
|
||||
assert reasoning_items[0]["encrypted_content"] == "gAAAA_test_blob"
|
||||
|
|
@ -515,7 +516,7 @@ class TestChatMessagesToResponsesInput:
|
|||
{"role": "assistant", "content": "hi"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
reasoning_items = [i for i in items if i.get("type") == "reasoning"]
|
||||
assert len(reasoning_items) == 0
|
||||
|
||||
|
|
@ -539,7 +540,7 @@ class TestNormalizeCodexResponse:
|
|||
],
|
||||
status="completed",
|
||||
)
|
||||
msg, reason = agent._normalize_codex_response(response)
|
||||
msg, reason = _normalize_codex_response(response)
|
||||
assert msg.content == "Hello!"
|
||||
assert reason == "stop"
|
||||
|
||||
|
|
@ -557,7 +558,7 @@ class TestNormalizeCodexResponse:
|
|||
],
|
||||
status="completed",
|
||||
)
|
||||
msg, reason = agent._normalize_codex_response(response)
|
||||
msg, reason = _normalize_codex_response(response)
|
||||
assert msg.content == "42"
|
||||
assert "math" in msg.reasoning
|
||||
assert reason == "stop"
|
||||
|
|
@ -576,7 +577,7 @@ class TestNormalizeCodexResponse:
|
|||
],
|
||||
status="completed",
|
||||
)
|
||||
msg, reason = agent._normalize_codex_response(response)
|
||||
msg, reason = _normalize_codex_response(response)
|
||||
assert msg.codex_reasoning_items is not None
|
||||
assert len(msg.codex_reasoning_items) == 1
|
||||
assert msg.codex_reasoning_items[0]["encrypted_content"] == "gAAAA_secret_blob_123"
|
||||
|
|
@ -592,7 +593,7 @@ class TestNormalizeCodexResponse:
|
|||
],
|
||||
status="completed",
|
||||
)
|
||||
msg, reason = agent._normalize_codex_response(response)
|
||||
msg, reason = _normalize_codex_response(response)
|
||||
assert msg.codex_reasoning_items is None
|
||||
|
||||
def test_tool_calls_extracted(self, monkeypatch):
|
||||
|
|
@ -605,7 +606,7 @@ class TestNormalizeCodexResponse:
|
|||
],
|
||||
status="completed",
|
||||
)
|
||||
msg, reason = agent._normalize_codex_response(response)
|
||||
msg, reason = _normalize_codex_response(response)
|
||||
assert reason == "tool_calls"
|
||||
assert len(msg.tool_calls) == 1
|
||||
assert msg.tool_calls[0].function.name == "web_search"
|
||||
|
|
@ -821,7 +822,7 @@ class TestCodexReasoningPreflight:
|
|||
"summary": [{"type": "summary_text", "text": "Thinking about it"}]},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
]
|
||||
normalized = agent._preflight_codex_input_items(raw_input)
|
||||
normalized = _preflight_codex_input_items(raw_input)
|
||||
reasoning_items = [i for i in normalized if i.get("type") == "reasoning"]
|
||||
assert len(reasoning_items) == 1
|
||||
assert reasoning_items[0]["encrypted_content"] == "abc123encrypted"
|
||||
|
|
@ -837,7 +838,7 @@ class TestCodexReasoningPreflight:
|
|||
raw_input = [
|
||||
{"type": "reasoning", "encrypted_content": "abc123"},
|
||||
]
|
||||
normalized = agent._preflight_codex_input_items(raw_input)
|
||||
normalized = _preflight_codex_input_items(raw_input)
|
||||
assert len(normalized) == 1
|
||||
assert "id" not in normalized[0]
|
||||
assert normalized[0]["summary"] == [] # default empty summary
|
||||
|
|
@ -849,7 +850,7 @@ class TestCodexReasoningPreflight:
|
|||
{"type": "reasoning", "encrypted_content": ""},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
normalized = agent._preflight_codex_input_items(raw_input)
|
||||
normalized = _preflight_codex_input_items(raw_input)
|
||||
reasoning_items = [i for i in normalized if i.get("type") == "reasoning"]
|
||||
assert len(reasoning_items) == 0
|
||||
|
||||
|
|
@ -868,7 +869,7 @@ class TestCodexReasoningPreflight:
|
|||
},
|
||||
{"role": "user", "content": "follow up"},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
reasoning_items = [i for i in items if isinstance(i, dict) and i.get("type") == "reasoning"]
|
||||
assert len(reasoning_items) == 1
|
||||
assert reasoning_items[0]["encrypted_content"] == "enc123"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from types import SimpleNamespace
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input, _normalize_codex_response, _preflight_codex_input_items
|
||||
|
||||
import run_agent
|
||||
from run_agent import AIAgent
|
||||
|
|
@ -4248,7 +4249,7 @@ class TestNormalizeCodexDictArguments:
|
|||
json.dumps, not str(), so downstream json.loads() succeeds."""
|
||||
args_dict = {"query": "weather in NYC", "units": "celsius"}
|
||||
response = self._make_codex_response("function_call", args_dict)
|
||||
msg, _ = agent._normalize_codex_response(response)
|
||||
msg, _ = _normalize_codex_response(response)
|
||||
tc = msg.tool_calls[0]
|
||||
parsed = json.loads(tc.function.arguments)
|
||||
assert parsed == args_dict
|
||||
|
|
@ -4257,7 +4258,7 @@ class TestNormalizeCodexDictArguments:
|
|||
"""dict arguments from custom_tool_call must also use json.dumps."""
|
||||
args_dict = {"path": "/tmp/test.txt", "content": "hello"}
|
||||
response = self._make_codex_response("custom_tool_call", args_dict)
|
||||
msg, _ = agent._normalize_codex_response(response)
|
||||
msg, _ = _normalize_codex_response(response)
|
||||
tc = msg.tool_calls[0]
|
||||
parsed = json.loads(tc.function.arguments)
|
||||
assert parsed == args_dict
|
||||
|
|
@ -4266,7 +4267,7 @@ class TestNormalizeCodexDictArguments:
|
|||
"""String arguments must pass through without modification."""
|
||||
args_str = '{"query": "test"}'
|
||||
response = self._make_codex_response("function_call", args_str)
|
||||
msg, _ = agent._normalize_codex_response(response)
|
||||
msg, _ = _normalize_codex_response(response)
|
||||
tc = msg.tool_calls[0]
|
||||
assert tc.function.arguments == args_str
|
||||
|
||||
|
|
|
|||
|
|
@ -640,7 +640,8 @@ def test_run_conversation_codex_tool_round_trip(monkeypatch):
|
|||
|
||||
def test_chat_messages_to_responses_input_uses_call_id_for_function_call(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
items = agent._chat_messages_to_responses_input(
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input
|
||||
items = _chat_messages_to_responses_input(
|
||||
[
|
||||
{"role": "user", "content": "Run terminal"},
|
||||
{
|
||||
|
|
@ -668,7 +669,8 @@ def test_chat_messages_to_responses_input_uses_call_id_for_function_call(monkeyp
|
|||
|
||||
def test_chat_messages_to_responses_input_accepts_call_pipe_fc_ids(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
items = agent._chat_messages_to_responses_input(
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input
|
||||
items = _chat_messages_to_responses_input(
|
||||
[
|
||||
{"role": "user", "content": "Run terminal"},
|
||||
{
|
||||
|
|
@ -696,7 +698,8 @@ def test_chat_messages_to_responses_input_accepts_call_pipe_fc_ids(monkeypatch):
|
|||
|
||||
def test_preflight_codex_api_kwargs_strips_optional_function_call_id(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
preflight = agent._preflight_codex_api_kwargs(
|
||||
from agent.codex_responses_adapter import _preflight_codex_api_kwargs
|
||||
preflight = _preflight_codex_api_kwargs(
|
||||
{
|
||||
"model": "gpt-5-codex",
|
||||
"instructions": "You are Hermes.",
|
||||
|
|
@ -724,7 +727,8 @@ def test_preflight_codex_api_kwargs_rejects_function_call_output_without_call_id
|
|||
agent = _build_agent(monkeypatch)
|
||||
|
||||
with pytest.raises(ValueError, match="function_call_output is missing call_id"):
|
||||
agent._preflight_codex_api_kwargs(
|
||||
from agent.codex_responses_adapter import _preflight_codex_api_kwargs
|
||||
_preflight_codex_api_kwargs(
|
||||
{
|
||||
"model": "gpt-5-codex",
|
||||
"instructions": "You are Hermes.",
|
||||
|
|
@ -741,7 +745,8 @@ def test_preflight_codex_api_kwargs_rejects_unsupported_request_fields(monkeypat
|
|||
kwargs["some_unknown_field"] = "value"
|
||||
|
||||
with pytest.raises(ValueError, match="unsupported field"):
|
||||
agent._preflight_codex_api_kwargs(kwargs)
|
||||
from agent.codex_responses_adapter import _preflight_codex_api_kwargs
|
||||
_preflight_codex_api_kwargs(kwargs)
|
||||
|
||||
|
||||
def test_preflight_codex_api_kwargs_allows_reasoning_and_temperature(monkeypatch):
|
||||
|
|
@ -752,7 +757,8 @@ def test_preflight_codex_api_kwargs_allows_reasoning_and_temperature(monkeypatch
|
|||
kwargs["temperature"] = 0.7
|
||||
kwargs["max_output_tokens"] = 4096
|
||||
|
||||
result = agent._preflight_codex_api_kwargs(kwargs)
|
||||
from agent.codex_responses_adapter import _preflight_codex_api_kwargs
|
||||
result = _preflight_codex_api_kwargs(kwargs)
|
||||
assert result["reasoning"] == {"effort": "high", "summary": "auto"}
|
||||
assert result["include"] == ["reasoning.encrypted_content"]
|
||||
assert result["temperature"] == 0.7
|
||||
|
|
@ -764,7 +770,8 @@ def test_preflight_codex_api_kwargs_allows_service_tier(monkeypatch):
|
|||
kwargs = _codex_request_kwargs()
|
||||
kwargs["service_tier"] = "priority"
|
||||
|
||||
result = agent._preflight_codex_api_kwargs(kwargs)
|
||||
from agent.codex_responses_adapter import _preflight_codex_api_kwargs
|
||||
result = _preflight_codex_api_kwargs(kwargs)
|
||||
assert result["service_tier"] == "priority"
|
||||
|
||||
|
||||
|
|
@ -841,7 +848,8 @@ def test_run_conversation_codex_continues_after_incomplete_interim_message(monke
|
|||
|
||||
def test_normalize_codex_response_marks_commentary_only_message_as_incomplete(monkeypatch):
|
||||
agent = _build_agent(monkeypatch)
|
||||
assistant_message, finish_reason = agent._normalize_codex_response(
|
||||
from agent.codex_responses_adapter import _normalize_codex_response
|
||||
assistant_message, finish_reason = _normalize_codex_response(
|
||||
_codex_commentary_message_response("I'll inspect the repository first.")
|
||||
)
|
||||
|
||||
|
|
@ -1068,7 +1076,8 @@ def test_normalize_codex_response_marks_reasoning_only_as_incomplete(monkeypatch
|
|||
sends them into the empty-content retry loop (3 retries then failure).
|
||||
"""
|
||||
agent = _build_agent(monkeypatch)
|
||||
assistant_message, finish_reason = agent._normalize_codex_response(
|
||||
from agent.codex_responses_adapter import _normalize_codex_response
|
||||
assistant_message, finish_reason = _normalize_codex_response(
|
||||
_codex_reasoning_only_response()
|
||||
)
|
||||
|
||||
|
|
@ -1101,7 +1110,8 @@ def test_normalize_codex_response_reasoning_with_content_is_stop(monkeypatch):
|
|||
status="completed",
|
||||
model="gpt-5-codex",
|
||||
)
|
||||
assistant_message, finish_reason = agent._normalize_codex_response(response)
|
||||
from agent.codex_responses_adapter import _normalize_codex_response
|
||||
assistant_message, finish_reason = _normalize_codex_response(response)
|
||||
|
||||
assert finish_reason == "stop"
|
||||
assert "Here is the answer" in assistant_message.content
|
||||
|
|
@ -1186,7 +1196,8 @@ def test_chat_messages_to_responses_input_reasoning_only_has_following_item(monk
|
|||
],
|
||||
},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
|
||||
# Find the reasoning item
|
||||
reasoning_indices = [i for i, it in enumerate(items) if it.get("type") == "reasoning"]
|
||||
|
|
@ -1273,7 +1284,8 @@ def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch
|
|||
],
|
||||
},
|
||||
]
|
||||
items = agent._chat_messages_to_responses_input(messages)
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input
|
||||
items = _chat_messages_to_responses_input(messages)
|
||||
|
||||
reasoning_items = [it for it in items if it.get("type") == "reasoning"]
|
||||
# Dedup: rs_aaa appears in both turns but should only be emitted once.
|
||||
|
|
@ -1299,7 +1311,8 @@ def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch):
|
|||
{"type": "reasoning", "id": "rs_zzz", "encrypted_content": "enc_b"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
normalized = agent._preflight_codex_input_items(raw_input)
|
||||
from agent.codex_responses_adapter import _preflight_codex_input_items
|
||||
normalized = _preflight_codex_input_items(raw_input)
|
||||
|
||||
reasoning_items = [it for it in normalized if it.get("type") == "reasoning"]
|
||||
# rs_xyz duplicate should be collapsed to one item; rs_zzz kept.
|
||||
|
|
|
|||
93
tests/run_agent/test_switch_model_fallback_prune.py
Normal file
93
tests/run_agent/test_switch_model_fallback_prune.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""Regression test for TUI v2 blitz bug: explicit /model --provider switch
|
||||
silently fell back to the old primary provider on the next turn because the
|
||||
fallback chain — seeded from config at agent __init__ — kept entries for the
|
||||
provider the user just moved away from.
|
||||
|
||||
Reported: "switched from openrouter provider to anthropic api key via hermes
|
||||
model and the tui keeps trying openrouter".
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _make_agent(chain):
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
|
||||
agent.provider = "openrouter"
|
||||
agent.model = "x-ai/grok-4"
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.api_key = "or-key"
|
||||
agent.api_mode = "chat_completions"
|
||||
agent.client = MagicMock()
|
||||
agent._client_kwargs = {"api_key": "or-key", "base_url": "https://openrouter.ai/api/v1"}
|
||||
agent.context_compressor = None
|
||||
agent._anthropic_api_key = ""
|
||||
agent._anthropic_base_url = None
|
||||
agent._anthropic_client = None
|
||||
agent._is_anthropic_oauth = False
|
||||
agent._cached_system_prompt = "cached"
|
||||
agent._primary_runtime = {}
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_index = 0
|
||||
agent._fallback_chain = list(chain)
|
||||
agent._fallback_model = chain[0] if chain else None
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def _switch_to_anthropic(agent):
|
||||
with (
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-xyz"),
|
||||
patch("agent.anthropic_adapter._is_oauth_token", return_value=False),
|
||||
patch("hermes_cli.timeouts.get_provider_request_timeout", return_value=None),
|
||||
):
|
||||
agent.switch_model(
|
||||
new_model="claude-sonnet-4-5",
|
||||
new_provider="anthropic",
|
||||
api_key="sk-ant-xyz",
|
||||
base_url="https://api.anthropic.com",
|
||||
api_mode="anthropic_messages",
|
||||
)
|
||||
|
||||
|
||||
def test_switch_drops_old_primary_from_fallback_chain():
|
||||
agent = _make_agent([
|
||||
{"provider": "openrouter", "model": "x-ai/grok-4"},
|
||||
{"provider": "nous", "model": "hermes-4"},
|
||||
])
|
||||
|
||||
_switch_to_anthropic(agent)
|
||||
|
||||
providers = [entry["provider"] for entry in agent._fallback_chain]
|
||||
|
||||
assert "openrouter" not in providers, "old primary must be pruned"
|
||||
assert "anthropic" not in providers, "new primary is redundant in the chain"
|
||||
assert providers == ["nous"]
|
||||
assert agent._fallback_model == {"provider": "nous", "model": "hermes-4"}
|
||||
|
||||
|
||||
def test_switch_with_empty_chain_stays_empty():
|
||||
agent = _make_agent([])
|
||||
|
||||
_switch_to_anthropic(agent)
|
||||
|
||||
assert agent._fallback_chain == []
|
||||
assert agent._fallback_model is None
|
||||
|
||||
|
||||
def test_switch_within_same_provider_preserves_chain():
|
||||
chain = [{"provider": "openrouter", "model": "x-ai/grok-4"}]
|
||||
agent = _make_agent(chain)
|
||||
|
||||
with patch("hermes_cli.timeouts.get_provider_request_timeout", return_value=None):
|
||||
agent.switch_model(
|
||||
new_model="openai/gpt-5",
|
||||
new_provider="openrouter",
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
|
||||
assert agent._fallback_chain == chain
|
||||
|
|
@ -147,6 +147,27 @@ class TestEscapedSpaces:
|
|||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["remainder"] == "what is this?"
|
||||
|
||||
def test_unquoted_spaces_in_path(self, tmp_image_with_spaces):
|
||||
result = _detect_file_drop(str(tmp_image_with_spaces))
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["is_image"] is True
|
||||
assert result["remainder"] == ""
|
||||
|
||||
def test_unquoted_spaces_with_trailing_text(self, tmp_image_with_spaces):
|
||||
user_input = f"{tmp_image_with_spaces} what is this?"
|
||||
result = _detect_file_drop(user_input)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["remainder"] == "what is this?"
|
||||
|
||||
def test_file_uri_image_path(self, tmp_image_with_spaces):
|
||||
uri = tmp_image_with_spaces.as_uri()
|
||||
result = _detect_file_drop(uri)
|
||||
assert result is not None
|
||||
assert result["path"] == tmp_image_with_spaces
|
||||
assert result["is_image"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: edge cases
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
|
@ -230,6 +231,48 @@ def test_config_set_model_global_persists(monkeypatch):
|
|||
assert saved["model"]["base_url"] == "https://api.anthropic.com"
|
||||
|
||||
|
||||
def test_config_set_model_syncs_inference_provider_env(monkeypatch):
|
||||
"""After an explicit provider switch, HERMES_INFERENCE_PROVIDER must
|
||||
reflect the user's choice so ambient re-resolution (credential pool
|
||||
refresh, aux clients) picks up the new provider instead of the original
|
||||
one persisted in config or shell env.
|
||||
|
||||
Regression: a TUI user switched openrouter → anthropic and the TUI kept
|
||||
trying openrouter because the env-var-backed resolvers still saw the old
|
||||
provider.
|
||||
"""
|
||||
class _Agent:
|
||||
provider = "openrouter"
|
||||
model = "old/model"
|
||||
base_url = ""
|
||||
api_key = "sk-or"
|
||||
|
||||
def switch_model(self, **_kwargs):
|
||||
return None
|
||||
|
||||
result = types.SimpleNamespace(
|
||||
success=True,
|
||||
new_model="claude-sonnet-4.6",
|
||||
target_provider="anthropic",
|
||||
api_key="sk-ant",
|
||||
base_url="https://api.anthropic.com",
|
||||
api_mode="anthropic_messages",
|
||||
warning_message="",
|
||||
)
|
||||
|
||||
server._sessions["sid"] = _session(agent=_Agent())
|
||||
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openrouter")
|
||||
monkeypatch.setattr("hermes_cli.model_switch.switch_model", lambda **_kwargs: result)
|
||||
monkeypatch.setattr(server, "_restart_slash_worker", lambda session: None)
|
||||
monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
|
||||
|
||||
server.handle_request(
|
||||
{"id": "1", "method": "config.set", "params": {"session_id": "sid", "key": "model", "value": "claude-sonnet-4.6 --provider anthropic"}}
|
||||
)
|
||||
|
||||
assert os.environ["HERMES_INFERENCE_PROVIDER"] == "anthropic"
|
||||
|
||||
|
||||
def test_config_set_personality_rejects_unknown_name(monkeypatch):
|
||||
monkeypatch.setattr(server, "_available_personalities", lambda cfg=None: {"helpful": "You are helpful."})
|
||||
resp = server.handle_request(
|
||||
|
|
@ -350,6 +393,11 @@ def test_prompt_submit_expands_context_refs(monkeypatch):
|
|||
def test_image_attach_appends_local_image(monkeypatch):
|
||||
fake_cli = types.ModuleType("cli")
|
||||
fake_cli._IMAGE_EXTENSIONS = {".png"}
|
||||
fake_cli._detect_file_drop = lambda raw: {
|
||||
"path": Path("/tmp/cat.png"),
|
||||
"is_image": True,
|
||||
"remainder": "",
|
||||
}
|
||||
fake_cli._split_path_input = lambda raw: (raw, "")
|
||||
fake_cli._resolve_attachment_path = lambda raw: Path("/tmp/cat.png")
|
||||
|
||||
|
|
@ -363,6 +411,31 @@ def test_image_attach_appends_local_image(monkeypatch):
|
|||
assert len(server._sessions["sid"]["attached_images"]) == 1
|
||||
|
||||
|
||||
def test_image_attach_accepts_unquoted_screenshot_path_with_spaces(monkeypatch):
|
||||
screenshot = Path("/tmp/Screenshot 2026-04-21 at 1.04.43 PM.png")
|
||||
fake_cli = types.ModuleType("cli")
|
||||
fake_cli._IMAGE_EXTENSIONS = {".png"}
|
||||
fake_cli._detect_file_drop = lambda raw: {
|
||||
"path": screenshot,
|
||||
"is_image": True,
|
||||
"remainder": "",
|
||||
}
|
||||
fake_cli._split_path_input = lambda raw: ("/tmp/Screenshot", "2026-04-21 at 1.04.43 PM.png")
|
||||
fake_cli._resolve_attachment_path = lambda raw: None
|
||||
|
||||
server._sessions["sid"] = _session()
|
||||
monkeypatch.setitem(sys.modules, "cli", fake_cli)
|
||||
|
||||
resp = server.handle_request(
|
||||
{"id": "1", "method": "image.attach", "params": {"session_id": "sid", "path": str(screenshot)}}
|
||||
)
|
||||
|
||||
assert resp["result"]["attached"] is True
|
||||
assert resp["result"]["path"] == str(screenshot)
|
||||
assert resp["result"]["remainder"] == ""
|
||||
assert len(server._sessions["sid"]["attached_images"]) == 1
|
||||
|
||||
|
||||
def test_commands_catalog_surfaces_quick_commands(monkeypatch):
|
||||
monkeypatch.setattr(server, "_load_cfg", lambda: {"quick_commands": {
|
||||
"build": {"type": "exec", "command": "npm run build"},
|
||||
|
|
|
|||
|
|
@ -20,11 +20,14 @@ from unittest.mock import MagicMock, patch
|
|||
from tools.delegate_tool import (
|
||||
DELEGATE_BLOCKED_TOOLS,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
DelegateEvent,
|
||||
_get_max_concurrent_children,
|
||||
_LEGACY_EVENT_MAP,
|
||||
MAX_DEPTH,
|
||||
check_delegate_requirements,
|
||||
delegate_task,
|
||||
_build_child_agent,
|
||||
_build_child_progress_callback,
|
||||
_build_child_system_prompt,
|
||||
_strip_blocked_tools,
|
||||
_resolve_child_credential_pool,
|
||||
|
|
@ -387,7 +390,7 @@ class TestToolNamePreservation(unittest.TestCase):
|
|||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = MagicMock()
|
||||
|
||||
def capture_and_return(user_message):
|
||||
def capture_and_return(user_message, task_id=None):
|
||||
captured["saved"] = list(mock_child._delegate_saved_tool_names)
|
||||
return {"final_response": "ok", "completed": True, "api_calls": 1}
|
||||
|
||||
|
|
@ -568,8 +571,16 @@ class TestBlockedTools(unittest.TestCase):
|
|||
self.assertIn(tool, DELEGATE_BLOCKED_TOOLS)
|
||||
|
||||
def test_constants(self):
|
||||
from tools.delegate_tool import (
|
||||
_get_max_spawn_depth, _get_orchestrator_enabled,
|
||||
_MIN_SPAWN_DEPTH, _MAX_SPAWN_DEPTH_CAP,
|
||||
)
|
||||
self.assertEqual(_get_max_concurrent_children(), 3)
|
||||
self.assertEqual(MAX_DEPTH, 2)
|
||||
self.assertEqual(MAX_DEPTH, 1)
|
||||
self.assertEqual(_get_max_spawn_depth(), 1) # default: flat
|
||||
self.assertTrue(_get_orchestrator_enabled()) # default
|
||||
self.assertEqual(_MIN_SPAWN_DEPTH, 1)
|
||||
self.assertEqual(_MAX_SPAWN_DEPTH_CAP, 3)
|
||||
|
||||
|
||||
class TestDelegationCredentialResolution(unittest.TestCase):
|
||||
|
|
@ -1325,5 +1336,635 @@ class TestDelegationReasoningEffort(unittest.TestCase):
|
|||
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": True, "effort": "medium"})
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Dispatch helper, progress events, concurrency
|
||||
# =========================================================================
|
||||
|
||||
class TestDispatchDelegateTask(unittest.TestCase):
|
||||
"""Tests for the _dispatch_delegate_task helper and full param forwarding."""
|
||||
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
def test_acp_args_forwarded(self, mock_creds, mock_cfg):
|
||||
"""Both acp_command and acp_args reach delegate_task via the helper."""
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
with patch("tools.delegate_tool._build_child_agent") as mock_build:
|
||||
mock_child = MagicMock()
|
||||
mock_child.run_conversation.return_value = {
|
||||
"final_response": "done", "completed": True,
|
||||
"api_calls": 1, "messages": [],
|
||||
}
|
||||
mock_child._delegate_saved_tool_names = []
|
||||
mock_child._credential_pool = None
|
||||
mock_child.session_prompt_tokens = 0
|
||||
mock_child.session_completion_tokens = 0
|
||||
mock_child.model = "test"
|
||||
mock_build.return_value = mock_child
|
||||
|
||||
delegate_task(
|
||||
goal="test",
|
||||
acp_command="claude",
|
||||
acp_args=["--acp", "--stdio"],
|
||||
parent_agent=parent,
|
||||
)
|
||||
_, kwargs = mock_build.call_args
|
||||
self.assertEqual(kwargs["override_acp_command"], "claude")
|
||||
self.assertEqual(kwargs["override_acp_args"], ["--acp", "--stdio"])
|
||||
|
||||
class TestDelegateEventEnum(unittest.TestCase):
|
||||
"""Tests for DelegateEvent enum and back-compat aliases."""
|
||||
|
||||
def test_enum_values_are_strings(self):
|
||||
for event in DelegateEvent:
|
||||
self.assertIsInstance(event.value, str)
|
||||
self.assertTrue(event.value.startswith("delegate."))
|
||||
|
||||
def test_legacy_map_covers_all_old_names(self):
|
||||
expected_legacy = {"_thinking", "reasoning.available",
|
||||
"tool.started", "tool.completed", "subagent_progress"}
|
||||
self.assertEqual(set(_LEGACY_EVENT_MAP.keys()), expected_legacy)
|
||||
|
||||
def test_legacy_map_values_are_delegate_events(self):
|
||||
for old_name, event in _LEGACY_EVENT_MAP.items():
|
||||
self.assertIsInstance(event, DelegateEvent)
|
||||
|
||||
def test_progress_callback_normalises_tool_started(self):
|
||||
"""_build_child_progress_callback handles tool.started via enum."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = MagicMock()
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
self.assertIsNotNone(cb)
|
||||
|
||||
cb("tool.started", tool_name="terminal", preview="ls")
|
||||
parent._delegate_spinner.print_above.assert_called()
|
||||
|
||||
def test_progress_callback_normalises_thinking(self):
|
||||
"""Both _thinking and reasoning.available route to TASK_THINKING."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
|
||||
cb("_thinking", tool_name=None, preview="pondering...")
|
||||
assert any("💭" in str(c) for c in parent._delegate_spinner.print_above.call_args_list)
|
||||
|
||||
parent._delegate_spinner.print_above.reset_mock()
|
||||
cb("reasoning.available", tool_name=None, preview="hmm")
|
||||
assert any("💭" in str(c) for c in parent._delegate_spinner.print_above.call_args_list)
|
||||
|
||||
def test_progress_callback_tool_completed_is_noop(self):
|
||||
"""tool.completed is normalised but produces no display output."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb("tool.completed", tool_name="terminal")
|
||||
parent._delegate_spinner.print_above.assert_not_called()
|
||||
|
||||
def test_progress_callback_ignores_unknown_events(self):
|
||||
"""Unknown event types are silently ignored."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
# Should not raise
|
||||
cb("some.unknown.event", tool_name="x")
|
||||
parent._delegate_spinner.print_above.assert_not_called()
|
||||
|
||||
def test_progress_callback_accepts_enum_value_directly(self):
|
||||
"""cb(DelegateEvent.TASK_THINKING, ...) must route to the thinking
|
||||
branch. Pre-fix the callback only handled legacy strings via
|
||||
_LEGACY_EVENT_MAP.get and silently dropped enum-typed callers."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb(DelegateEvent.TASK_THINKING, preview="pondering")
|
||||
# If the enum was accepted, the thinking emoji got printed.
|
||||
assert any(
|
||||
"💭" in str(c)
|
||||
for c in parent._delegate_spinner.print_above.call_args_list
|
||||
)
|
||||
|
||||
def test_progress_callback_accepts_new_style_string(self):
|
||||
"""cb('delegate.task_thinking', ...) — the string form of the
|
||||
enum value — must route to the thinking branch too, so new-style
|
||||
emitters don't have to import DelegateEvent."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb("delegate.task_thinking", preview="hmm")
|
||||
assert any(
|
||||
"💭" in str(c)
|
||||
for c in parent._delegate_spinner.print_above.call_args_list
|
||||
)
|
||||
|
||||
def test_progress_callback_task_progress_not_misrendered(self):
|
||||
"""'subagent_progress' (legacy name for TASK_PROGRESS) carries a
|
||||
pre-batched summary in the tool_name slot. Before the fix, this
|
||||
fell through to the TASK_TOOL_STARTED rendering path, treating
|
||||
the summary string as a tool name. After the fix: distinct
|
||||
render (no tool-start emoji lookup) and pass-through relay
|
||||
upward (no re-batching).
|
||||
|
||||
Regression path only reachable once nested orchestration is
|
||||
enabled: nested orchestrators relay subagent_progress from
|
||||
grandchildren upward through this callback.
|
||||
"""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = MagicMock()
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb("subagent_progress", tool_name="🔀 [1] terminal, file")
|
||||
|
||||
# Spinner gets a distinct 🔀-prefixed line, NOT a tool emoji
|
||||
# followed by the summary string as if it were a tool name.
|
||||
calls = parent._delegate_spinner.print_above.call_args_list
|
||||
self.assertTrue(any("🔀 🔀 [1] terminal, file" in str(c) for c in calls))
|
||||
# Parent callback receives the relay (pass-through, no re-batching).
|
||||
parent.tool_progress_callback.assert_called_once()
|
||||
# No '⚡' tool-start emoji should appear — that's the pre-fix bug.
|
||||
self.assertFalse(any("⚡" in str(c) for c in calls))
|
||||
|
||||
|
||||
class TestConcurrencyDefaults(unittest.TestCase):
|
||||
"""Tests for the concurrency default and no hard ceiling."""
|
||||
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
def test_default_is_three(self, mock_cfg):
|
||||
# Clear env var if set
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
self.assertEqual(_get_max_concurrent_children(), 3)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_concurrent_children": 10})
|
||||
def test_no_upper_ceiling(self, mock_cfg):
|
||||
"""Users can raise concurrency as high as they want — no hard cap."""
|
||||
self.assertEqual(_get_max_concurrent_children(), 10)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_concurrent_children": 100})
|
||||
def test_very_high_values_honored(self, mock_cfg):
|
||||
self.assertEqual(_get_max_concurrent_children(), 100)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_concurrent_children": 0})
|
||||
def test_zero_clamped_to_one(self, mock_cfg):
|
||||
"""Floor of 1 is enforced; zero or negative values raise to 1."""
|
||||
self.assertEqual(_get_max_concurrent_children(), 1)
|
||||
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
def test_env_var_honored_uncapped(self, mock_cfg):
|
||||
with patch.dict(os.environ, {"DELEGATION_MAX_CONCURRENT_CHILDREN": "12"}):
|
||||
self.assertEqual(_get_max_concurrent_children(), 12)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_concurrent_children": 6})
|
||||
def test_configured_value_returned(self, mock_cfg):
|
||||
self.assertEqual(_get_max_concurrent_children(), 6)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# max_spawn_depth clamping
|
||||
# =========================================================================
|
||||
|
||||
class TestMaxSpawnDepth(unittest.TestCase):
|
||||
"""Tests for _get_max_spawn_depth clamping and fallback behavior."""
|
||||
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
def test_max_spawn_depth_defaults_to_1(self, mock_cfg):
|
||||
from tools.delegate_tool import _get_max_spawn_depth
|
||||
self.assertEqual(_get_max_spawn_depth(), 1)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 0})
|
||||
def test_max_spawn_depth_clamped_below_one(self, mock_cfg):
|
||||
import logging
|
||||
from tools.delegate_tool import _get_max_spawn_depth
|
||||
with self.assertLogs("tools.delegate_tool", level=logging.WARNING) as cm:
|
||||
result = _get_max_spawn_depth()
|
||||
self.assertEqual(result, 1)
|
||||
self.assertTrue(any("clamping to 1" in m for m in cm.output))
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 99})
|
||||
def test_max_spawn_depth_clamped_above_three(self, mock_cfg):
|
||||
import logging
|
||||
from tools.delegate_tool import _get_max_spawn_depth
|
||||
with self.assertLogs("tools.delegate_tool", level=logging.WARNING) as cm:
|
||||
result = _get_max_spawn_depth()
|
||||
self.assertEqual(result, 3)
|
||||
self.assertTrue(any("clamping to 3" in m for m in cm.output))
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": "not-a-number"})
|
||||
def test_max_spawn_depth_invalid_falls_back_to_default(self, mock_cfg):
|
||||
from tools.delegate_tool import _get_max_spawn_depth
|
||||
self.assertEqual(_get_max_spawn_depth(), 1)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# role param plumbing
|
||||
# =========================================================================
|
||||
#
|
||||
# These tests cover the schema + signature + stash plumbing of the role
|
||||
# param. The full role-honoring behavior (toolset re-add, role-aware
|
||||
# prompt) lives in TestOrchestratorRoleBehavior below; these tests only
|
||||
# assert on _delegate_role stashing and on the schema shape.
|
||||
|
||||
|
||||
class TestOrchestratorRoleSchema(unittest.TestCase):
|
||||
"""Tests that the role param reaches the child via dispatch."""
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 2})
|
||||
def _run_with_mock_child(self, role_arg, mock_cfg, mock_creds):
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = MagicMock()
|
||||
mock_child.run_conversation.return_value = {
|
||||
"final_response": "done", "completed": True,
|
||||
"api_calls": 1, "messages": [],
|
||||
}
|
||||
mock_child._delegate_saved_tool_names = []
|
||||
mock_child._credential_pool = None
|
||||
mock_child.session_prompt_tokens = 0
|
||||
mock_child.session_completion_tokens = 0
|
||||
mock_child.model = "test"
|
||||
MockAgent.return_value = mock_child
|
||||
kwargs = {"goal": "test", "parent_agent": parent}
|
||||
if role_arg is not _SENTINEL:
|
||||
kwargs["role"] = role_arg
|
||||
delegate_task(**kwargs)
|
||||
return mock_child
|
||||
|
||||
def test_default_role_is_leaf(self):
|
||||
child = self._run_with_mock_child(_SENTINEL)
|
||||
self.assertEqual(child._delegate_role, "leaf")
|
||||
|
||||
def test_explicit_orchestrator_role_stashed(self):
|
||||
"""role='orchestrator' reaches _build_child_agent and is stashed.
|
||||
Full behavior (toolset re-add) lands in commit 3; commit 2 only
|
||||
verifies the plumbing."""
|
||||
child = self._run_with_mock_child("orchestrator")
|
||||
self.assertEqual(child._delegate_role, "orchestrator")
|
||||
|
||||
def test_unknown_role_coerces_to_leaf(self):
|
||||
"""role='nonsense' → _normalize_role warns and returns 'leaf'."""
|
||||
import logging
|
||||
with self.assertLogs("tools.delegate_tool", level=logging.WARNING) as cm:
|
||||
child = self._run_with_mock_child("nonsense")
|
||||
self.assertEqual(child._delegate_role, "leaf")
|
||||
self.assertTrue(any("coercing" in m.lower() for m in cm.output))
|
||||
|
||||
def test_schema_has_role_top_level_and_per_task(self):
|
||||
from tools.delegate_tool import DELEGATE_TASK_SCHEMA
|
||||
props = DELEGATE_TASK_SCHEMA["parameters"]["properties"]
|
||||
self.assertIn("role", props)
|
||||
self.assertEqual(props["role"]["enum"], ["leaf", "orchestrator"])
|
||||
task_props = props["tasks"]["items"]["properties"]
|
||||
self.assertIn("role", task_props)
|
||||
self.assertEqual(task_props["role"]["enum"], ["leaf", "orchestrator"])
|
||||
|
||||
|
||||
# Sentinel used to distinguish "role kwarg omitted" from "role=None".
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# role-honoring behavior
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _make_role_mock_child():
|
||||
"""Helper: mock child with minimal fields for delegate_task to process."""
|
||||
mock_child = MagicMock()
|
||||
mock_child.run_conversation.return_value = {
|
||||
"final_response": "done", "completed": True,
|
||||
"api_calls": 1, "messages": [],
|
||||
}
|
||||
mock_child._delegate_saved_tool_names = []
|
||||
mock_child._credential_pool = None
|
||||
mock_child.session_prompt_tokens = 0
|
||||
mock_child.session_completion_tokens = 0
|
||||
mock_child.model = "test"
|
||||
return mock_child
|
||||
|
||||
|
||||
class TestOrchestratorRoleBehavior(unittest.TestCase):
    """Tests that role='orchestrator' actually changes toolset + prompt."""

    # NOTE: @patch decorators apply bottom-up, so the innermost
    # (_load_config) binds to the first mock parameter after self.
    @patch("tools.delegate_tool._resolve_delegation_credentials")
    @patch("tools.delegate_tool._load_config",
           return_value={"max_spawn_depth": 2})
    def test_orchestrator_role_keeps_delegation_at_depth_1(
        self, mock_cfg, mock_creds
    ):
        """role='orchestrator' + depth-0 parent with max_spawn_depth=2 →
        child at depth 1 gets 'delegation' in enabled_toolsets (can
        further delegate). Requires max_spawn_depth>=2 since the new
        default is 1 (flat)."""
        # No real credentials needed — the spawned child is a mock.
        mock_creds.return_value = {
            "provider": None, "base_url": None,
            "api_key": None, "api_mode": None, "model": None,
        }
        parent = _make_mock_parent(depth=0)
        parent.enabled_toolsets = ["terminal", "file"]
        with patch("run_agent.AIAgent") as MockAgent:
            mock_child = _make_role_mock_child()
            MockAgent.return_value = mock_child
            delegate_task(goal="test", role="orchestrator", parent_agent=parent)
            kwargs = MockAgent.call_args[1]
            self.assertIn("delegation", kwargs["enabled_toolsets"])
            self.assertEqual(mock_child._delegate_role, "orchestrator")

    @patch("tools.delegate_tool._resolve_delegation_credentials")
    @patch("tools.delegate_tool._load_config",
           return_value={"max_spawn_depth": 2})
    def test_orchestrator_blocked_at_max_spawn_depth(
        self, mock_cfg, mock_creds
    ):
        """Parent at depth 1 with max_spawn_depth=2 spawns child
        at depth 2 (the floor); role='orchestrator' degrades to leaf."""
        mock_creds.return_value = {
            "provider": None, "base_url": None,
            "api_key": None, "api_mode": None, "model": None,
        }
        parent = _make_mock_parent(depth=1)
        parent.enabled_toolsets = ["terminal", "delegation"]
        with patch("run_agent.AIAgent") as MockAgent:
            mock_child = _make_role_mock_child()
            MockAgent.return_value = mock_child
            delegate_task(goal="test", role="orchestrator", parent_agent=parent)
            kwargs = MockAgent.call_args[1]
            self.assertNotIn("delegation", kwargs["enabled_toolsets"])
            self.assertEqual(mock_child._delegate_role, "leaf")

    @patch("tools.delegate_tool._resolve_delegation_credentials")
    @patch("tools.delegate_tool._load_config", return_value={})
    def test_orchestrator_blocked_at_default_flat_depth(
        self, mock_cfg, mock_creds
    ):
        """With default max_spawn_depth=1 (flat), role='orchestrator'
        on a depth-0 parent produces a depth-1 child that is already at
        the floor — the role degrades to 'leaf' and the delegation
        toolset is stripped. This is the new default posture."""
        mock_creds.return_value = {
            "provider": None, "base_url": None,
            "api_key": None, "api_mode": None, "model": None,
        }
        parent = _make_mock_parent(depth=0)
        parent.enabled_toolsets = ["terminal", "file", "delegation"]
        with patch("run_agent.AIAgent") as MockAgent:
            mock_child = _make_role_mock_child()
            MockAgent.return_value = mock_child
            delegate_task(goal="test", role="orchestrator", parent_agent=parent)
            kwargs = MockAgent.call_args[1]
            self.assertNotIn("delegation", kwargs["enabled_toolsets"])
            self.assertEqual(mock_child._delegate_role, "leaf")

    @patch("tools.delegate_tool._resolve_delegation_credentials")
    def test_orchestrator_enabled_false_forces_leaf(self, mock_creds):
        """Kill switch delegation.orchestrator_enabled=false overrides
        role='orchestrator'."""
        mock_creds.return_value = {
            "provider": None, "base_url": None,
            "api_key": None, "api_mode": None, "model": None,
        }
        parent = _make_mock_parent(depth=0)
        parent.enabled_toolsets = ["terminal", "delegation"]
        # Config patched inline (not via decorator) because this test
        # needs a different return_value than the class's other tests.
        with patch("tools.delegate_tool._load_config",
                   return_value={"orchestrator_enabled": False}):
            with patch("run_agent.AIAgent") as MockAgent:
                mock_child = _make_role_mock_child()
                MockAgent.return_value = mock_child
                delegate_task(goal="test", role="orchestrator",
                              parent_agent=parent)
                kwargs = MockAgent.call_args[1]
                self.assertNotIn("delegation", kwargs["enabled_toolsets"])
                self.assertEqual(mock_child._delegate_role, "leaf")

    # ── Role-aware system prompt ────────────────────────────────────────

    def test_leaf_prompt_does_not_mention_delegation(self):
        # Leaf children must not be told they can delegate.
        prompt = _build_child_system_prompt(
            "Fix tests", role="leaf",
            max_spawn_depth=2, child_depth=1,
        )
        self.assertNotIn("delegate_task", prompt)
        self.assertNotIn("Orchestrator Role", prompt)

    def test_orchestrator_prompt_mentions_delegation_capability(self):
        prompt = _build_child_system_prompt(
            "Survey approaches", role="orchestrator",
            max_spawn_depth=2, child_depth=1,
        )
        self.assertIn("delegate_task", prompt)
        self.assertIn("Orchestrator Role", prompt)
        # Depth/max-depth note present and literal:
        self.assertIn("depth 1", prompt)
        self.assertIn("max_spawn_depth=2", prompt)

    def test_orchestrator_prompt_at_depth_floor_says_children_are_leaves(self):
        """With max_spawn_depth=2 and child_depth=1, the orchestrator's
        own children would be at depth 2 (the floor) → must be leaves."""
        prompt = _build_child_system_prompt(
            "Survey", role="orchestrator",
            max_spawn_depth=2, child_depth=1,
        )
        self.assertIn("MUST be leaves", prompt)

    def test_orchestrator_prompt_below_floor_allows_more_nesting(self):
        """With max_spawn_depth=3 and child_depth=1, the orchestrator's
        own children can themselves be orchestrators (depth 2 < 3)."""
        prompt = _build_child_system_prompt(
            "Deep work", role="orchestrator",
            max_spawn_depth=3, child_depth=1,
        )
        self.assertIn("can themselves be orchestrators", prompt)

    # ── Batch mode and intersection ─────────────────────────────────────

    @patch("tools.delegate_tool._resolve_delegation_credentials")
    @patch("tools.delegate_tool._load_config",
           return_value={"max_spawn_depth": 2})
    def test_batch_mode_per_task_role_override(self, mock_cfg, mock_creds):
        """Per-task role beats top-level; no top-level role → "leaf".

        tasks=[{role:'orchestrator'},{role:'leaf'},{}] → first gets
        delegation, second and third don't. Requires max_spawn_depth>=2
        (raised explicitly here) since the new default is 1 (flat).
        """
        mock_creds.return_value = {
            "provider": None, "base_url": None,
            "api_key": None, "api_mode": None, "model": None,
        }
        parent = _make_mock_parent(depth=0)
        parent.enabled_toolsets = ["terminal", "file", "delegation"]
        built_toolsets = []

        def _factory(*a, **kw):
            # Record the toolset each spawned child was built with.
            m = _make_role_mock_child()
            built_toolsets.append(kw.get("enabled_toolsets"))
            return m

        with patch("run_agent.AIAgent", side_effect=_factory):
            delegate_task(
                tasks=[
                    {"goal": "A", "role": "orchestrator"},
                    {"goal": "B", "role": "leaf"},
                    {"goal": "C"},  # no role → falls back to top_role (leaf)
                ],
                parent_agent=parent,
            )
        self.assertIn("delegation", built_toolsets[0])
        self.assertNotIn("delegation", built_toolsets[1])
        self.assertNotIn("delegation", built_toolsets[2])

    @patch("tools.delegate_tool._resolve_delegation_credentials")
    @patch("tools.delegate_tool._load_config",
           return_value={"max_spawn_depth": 2})
    def test_intersection_preserves_delegation_bound(
        self, mock_cfg, mock_creds
    ):
        """Design decision: orchestrator capability is granted by role,
        NOT inherited from the parent's toolset. A parent without
        'delegation' in its enabled_toolsets can still spawn an
        orchestrator child — the re-add in _build_child_agent runs
        unconditionally for orchestrators (when max_spawn_depth allows).

        If you want to change to "parent must have delegation too",
        update _build_child_agent to check parent_toolsets before the
        re-add and update this test to match.
        """
        mock_creds.return_value = {
            "provider": None, "base_url": None,
            "api_key": None, "api_mode": None, "model": None,
        }
        parent = _make_mock_parent(depth=0)
        parent.enabled_toolsets = ["terminal", "file"]  # no delegation
        with patch("run_agent.AIAgent") as MockAgent:
            mock_child = _make_role_mock_child()
            MockAgent.return_value = mock_child
            delegate_task(goal="test", role="orchestrator",
                          parent_agent=parent)
            self.assertIn("delegation", MockAgent.call_args[1]["enabled_toolsets"])
|
||||
|
||||
|
||||
class TestOrchestratorEndToEnd(unittest.TestCase):
    """End-to-end: parent -> orchestrator -> two-leaf nested orchestration.

    Covers the acceptance gate: parent delegates to an orchestrator
    child; the orchestrator delegates to two leaf grandchildren; the
    role/toolset/depth chain all resolve correctly.

    Mock strategy: a single AIAgent patch with a side_effect factory
    that keys on the child's ephemeral_system_prompt — orchestrator
    prompts contain the string "Orchestrator Role" (see
    _build_child_system_prompt), leaves don't. The orchestrator
    mock's run_conversation recursively calls delegate_task with
    tasks=[{goal:...},{goal:...}] to spawn two leaves. This keeps
    the test in one patch context and avoids depth-indexed nesting.
    """

    @patch("tools.delegate_tool._resolve_delegation_credentials")
    @patch("tools.delegate_tool._load_config",
           return_value={"max_spawn_depth": 2})
    def test_end_to_end_nested_orchestration(self, mock_cfg, mock_creds):
        mock_creds.return_value = {
            "provider": None, "base_url": None,
            "api_key": None, "api_mode": None, "model": None,
        }
        parent = _make_mock_parent(depth=0)
        parent.enabled_toolsets = ["terminal", "file", "delegation"]

        # (enabled_toolsets, _delegate_role) for each agent built
        built_agents: list = []
        # Keep the orchestrator mock around so the re-entrant delegate_task
        # can reach it via closure.
        orch_mock = {}

        def _factory(*a, **kw):
            # Classify each spawned agent by its prompt: only orchestrator
            # prompts contain the "Orchestrator Role" marker string.
            prompt = kw.get("ephemeral_system_prompt", "") or ""
            is_orchestrator = "Orchestrator Role" in prompt
            m = _make_role_mock_child()
            built_agents.append({
                "enabled_toolsets": list(kw.get("enabled_toolsets") or []),
                "is_orchestrator_prompt": is_orchestrator,
            })

            if is_orchestrator:
                # Prepare the orchestrator mock as a parent-capable object
                # so the nested delegate_task call succeeds.
                m._delegate_depth = 1
                m._delegate_role = "orchestrator"
                m._active_children = []
                m._active_children_lock = threading.Lock()
                m._session_db = None
                m.platform = "cli"
                m.enabled_toolsets = ["terminal", "file", "delegation"]
                m.api_key = "***"
                m.base_url = ""
                m.provider = None
                m.api_mode = None
                m.providers_allowed = None
                m.providers_ignored = None
                m.providers_order = None
                m.provider_sort = None
                m._print_fn = None
                m.tool_progress_callback = None
                m.thinking_callback = None
                orch_mock["agent"] = m

                def _orchestrator_run(user_message=None, task_id=None):
                    # Re-entrant: orchestrator spawns two leaves
                    delegate_task(
                        tasks=[{"goal": "leaf-A"}, {"goal": "leaf-B"}],
                        parent_agent=m,
                    )
                    return {
                        "final_response": "orchestrated 2 workers",
                        "completed": True, "api_calls": 1,
                        "messages": [],
                    }
                m.run_conversation.side_effect = _orchestrator_run

            return m

        with patch("run_agent.AIAgent", side_effect=_factory) as MockAgent:
            delegate_task(
                goal="top-level orchestration",
                role="orchestrator",
                parent_agent=parent,
            )

        # 1 orchestrator + 2 leaf grandchildren = 3 agents
        self.assertEqual(MockAgent.call_count, 3)
        # First built = the orchestrator (parent's direct child)
        self.assertIn("delegation", built_agents[0]["enabled_toolsets"])
        self.assertTrue(built_agents[0]["is_orchestrator_prompt"])
        # Next two = leaves (grandchildren)
        self.assertNotIn("delegation", built_agents[1]["enabled_toolsets"])
        self.assertFalse(built_agents[1]["is_orchestrator_prompt"])
        self.assertNotIn("delegation", built_agents[2]["enabled_toolsets"])
        self.assertFalse(built_agents[2]["is_orchestrator_prompt"])
|
||||
|
||||
|
||||
# Allow running this test module directly (outside pytest).
if __name__ == "__main__":
    unittest.main()
|
||||
|
|
|
|||
287
tests/tools/test_file_state_registry.py
Normal file
287
tests/tools/test_file_state_registry.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for the cross-agent FileStateRegistry (tools/file_state.py).
|
||||
|
||||
Covers the three layers added for safe concurrent subagent file edits:
|
||||
|
||||
1. Cross-agent staleness detection via ``check_stale``
|
||||
2. Per-path serialization via ``lock_path``
|
||||
3. Delegate-completion reminder via ``writes_since``
|
||||
|
||||
Plus integration through the real ``read_file_tool`` / ``write_file_tool``
|
||||
/ ``patch_tool`` handlers so the full hook wiring is exercised.
|
||||
|
||||
Run:
|
||||
python -m pytest tests/tools/test_file_state_registry.py -v
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from tools import file_state
|
||||
from tools.file_tools import (
|
||||
read_file_tool,
|
||||
write_file_tool,
|
||||
patch_tool,
|
||||
)
|
||||
|
||||
|
||||
def _tmp_file(content: str = "initial\n") -> str:
|
||||
fd, path = tempfile.mkstemp(prefix="hermes_file_state_test_", suffix=".txt")
|
||||
with os.fdopen(fd, "w") as f:
|
||||
f.write(content)
|
||||
return path
|
||||
|
||||
|
||||
class FileStateRegistryUnitTests(unittest.TestCase):
    """Direct unit tests on the registry singleton."""

    def setUp(self) -> None:
        # Start every test from a clean registry; track temp files for cleanup.
        file_state.get_registry().clear()
        self._tmpfiles: list[str] = []

    def tearDown(self) -> None:
        for p in self._tmpfiles:
            try:
                os.unlink(p)
            except OSError:
                pass
        file_state.get_registry().clear()

    def _mk(self, content: str = "x\n") -> str:
        """Create a temp file tracked for tearDown cleanup; return its path."""
        p = _tmp_file(content)
        self._tmpfiles.append(p)
        return p

    def test_record_read_then_check_stale_returns_none(self):
        p = self._mk()
        file_state.record_read("A", p)
        self.assertIsNone(file_state.check_stale("A", p))

    def test_sibling_write_flags_other_agent_as_stale(self):
        p = self._mk()
        file_state.record_read("A", p)
        # Simulate sibling writing this file later
        time.sleep(0.01)  # ensure ts ordering across resolution
        file_state.note_write("B", p)
        warn = file_state.check_stale("A", p)
        self.assertIsNotNone(warn)
        self.assertIn("B", warn)
        self.assertIn("sibling", warn.lower())

    def test_write_without_read_flagged(self):
        p = self._mk()
        # Agent A never read this file.
        file_state.note_write("B", p)  # another agent touched it
        warn = file_state.check_stale("A", p)
        self.assertIsNotNone(warn)

    def test_partial_read_flagged_on_write(self):
        p = self._mk()
        file_state.record_read("A", p, partial=True)
        warn = file_state.check_stale("A", p)
        self.assertIsNotNone(warn)
        self.assertIn("partial", warn.lower())

    def test_external_mtime_drift_flagged(self):
        p = self._mk()
        file_state.record_read("A", p)
        # Bump the on-disk mtime without going through the registry.
        time.sleep(0.01)
        os.utime(p, None)
        with open(p, "w") as f:
            f.write("externally modified\n")
        warn = file_state.check_stale("A", p)
        self.assertIsNotNone(warn)
        self.assertIn("modified since you last read", warn)

    def test_own_write_updates_stamp_so_next_write_is_clean(self):
        p = self._mk()
        file_state.record_read("A", p)
        file_state.note_write("A", p)
        # Second write by the same agent — should not be flagged.
        self.assertIsNone(file_state.check_stale("A", p))

    def test_different_paths_dont_interfere(self):
        a = self._mk()
        b = self._mk()
        file_state.record_read("A", a)
        file_state.note_write("B", b)
        # A reads only `a`; B writes `b`. A writing `a` is NOT stale.
        self.assertIsNone(file_state.check_stale("A", a))

    def test_lock_path_serializes_same_path(self):
        p = self._mk()
        events: list[tuple[str, int]] = []
        lock = threading.Lock()

        def worker(i: int) -> None:
            with file_state.lock_path(p):
                with lock:
                    events.append(("enter", i))
                time.sleep(0.01)
                with lock:
                    events.append(("exit", i))

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Every enter must be immediately followed by its matching exit.
        self.assertEqual(len(events), 8)
        for i in range(0, 8, 2):
            self.assertEqual(events[i][0], "enter")
            self.assertEqual(events[i + 1][0], "exit")
            self.assertEqual(events[i][1], events[i + 1][1])

    def test_lock_path_is_per_path_not_global(self):
        a = self._mk()
        b = self._mk()
        b_entered = threading.Event()

        def hold_a() -> None:
            with file_state.lock_path(a):
                b_entered.wait(timeout=2.0)

        def enter_b() -> None:
            time.sleep(0.02)  # let A grab its lock
            with file_state.lock_path(b):
                b_entered.set()

        ta = threading.Thread(target=hold_a)
        tb = threading.Thread(target=enter_b)
        ta.start()
        tb.start()
        # If the lock were global, enter_b would deadlock behind hold_a.
        self.assertTrue(b_entered.wait(timeout=3.0))
        ta.join(timeout=3.0)
        tb.join(timeout=3.0)

    def test_writes_since_filters_by_parent_read_set(self):
        foo = self._mk()
        bar = self._mk()
        baz = self._mk()
        file_state.record_read("parent", foo)
        file_state.record_read("parent", bar)
        since = time.time()
        time.sleep(0.01)
        file_state.note_write("child", foo)  # parent read this — report
        file_state.note_write("child", baz)  # parent never saw — skip

        # Caller passes only paths the parent actually read (this is what
        # delegate_tool does via ``known_reads(parent_task_id)``).
        parent_reads = file_state.known_reads("parent")
        out = file_state.writes_since("parent", since, parent_reads)
        self.assertIn("child", out)
        self.assertIn(foo, out["child"])
        self.assertNotIn(baz, out["child"])

    def test_writes_since_excludes_the_target_agent(self):
        p = self._mk()
        file_state.record_read("parent", p)
        since = time.time()
        time.sleep(0.01)
        file_state.note_write("parent", p)  # parent's own write
        out = file_state.writes_since("parent", since, [p])
        self.assertEqual(out, {})

    def test_kill_switch_env_var(self):
        """HERMES_DISABLE_FILE_STATE_GUARD=1 turns the registry into a no-op."""
        p = self._mk()
        # Save any pre-existing value so the test restores — not clobbers —
        # the developer's environment (the old version always deleted it).
        prev = os.environ.get("HERMES_DISABLE_FILE_STATE_GUARD")
        os.environ["HERMES_DISABLE_FILE_STATE_GUARD"] = "1"
        try:
            file_state.record_read("A", p)
            file_state.note_write("B", p)
            self.assertIsNone(file_state.check_stale("A", p))
            self.assertEqual(file_state.known_reads("A"), [])
            self.assertEqual(
                file_state.writes_since("A", 0.0, [p]),
                {},
            )
        finally:
            if prev is None:
                del os.environ["HERMES_DISABLE_FILE_STATE_GUARD"]
            else:
                os.environ["HERMES_DISABLE_FILE_STATE_GUARD"] = prev
|
||||
|
||||
|
||||
class FileToolsIntegrationTests(unittest.TestCase):
    """Integration through the real file_tools handlers.

    These exercise the wiring: read_file_tool → registry.record_read,
    write_file_tool / patch_tool → check_stale + lock_path + note_write.
    """

    def setUp(self) -> None:
        # Clean registry + a dedicated temp dir per test.
        file_state.get_registry().clear()
        self._tmpdir = tempfile.mkdtemp(prefix="hermes_file_state_int_")

    def tearDown(self) -> None:
        import shutil
        shutil.rmtree(self._tmpdir, ignore_errors=True)
        file_state.get_registry().clear()

    def _write_seed(self, name: str, content: str = "seed\n") -> str:
        """Seed a file directly on disk (bypassing the handlers/registry)."""
        p = os.path.join(self._tmpdir, name)
        with open(p, "w") as f:
            f.write(content)
        return p

    def test_sibling_agent_write_surfaces_warning_through_handler(self):
        p = self._write_seed("shared.txt")
        r = json.loads(read_file_tool(path=p, task_id="agentA"))
        self.assertNotIn("error", r)

        w_b = json.loads(write_file_tool(path=p, content="B wrote\n", task_id="agentB"))
        self.assertNotIn("error", w_b)

        # agentA's stamp is now behind agentB's write → warning expected.
        w_a = json.loads(write_file_tool(path=p, content="A stale\n", task_id="agentA"))
        warn = w_a.get("_warning", "")
        self.assertTrue(warn, f"expected warning, got: {w_a}")
        # The cross-agent message names the sibling task_id.
        self.assertIn("agentB", warn)
        self.assertIn("sibling", warn.lower())

    def test_same_agent_consecutive_writes_no_false_warning(self):
        p = self._write_seed("own.txt")
        json.loads(read_file_tool(path=p, task_id="agentC"))
        w1 = json.loads(write_file_tool(path=p, content="one\n", task_id="agentC"))
        self.assertFalse(w1.get("_warning"))
        w2 = json.loads(write_file_tool(path=p, content="two\n", task_id="agentC"))
        self.assertFalse(w2.get("_warning"))

    def test_patch_tool_also_surfaces_sibling_warning(self):
        p = self._write_seed("p.txt", "hello world\n")
        json.loads(read_file_tool(path=p, task_id="agentA"))
        json.loads(write_file_tool(path=p, content="hello planet\n", task_id="agentB"))
        r = json.loads(
            patch_tool(
                mode="replace",
                path=p,
                old_string="hello",
                new_string="HI",
                task_id="agentA",
            )
        )
        warn = r.get("_warning", "")
        # Patch may fail (sibling changed the content so old_string may not
        # match) or succeed — either way, the cross-agent warning should be
        # present when old_string still happens to match. What matters is
        # that if the patch succeeded or the warning was reported, it names
        # the sibling. When old_string doesn't match, the patch itself
        # returns an error but the warning is still set from the pre-check.
        if warn:
            self.assertIn("agentB", warn)

    def test_net_new_file_no_warning(self):
        p = os.path.join(self._tmpdir, "brand_new.txt")
        # Nobody has read or written this before.
        w = json.loads(write_file_tool(path=p, content="hi\n", task_id="agentX"))
        self.assertFalse(w.get("_warning"))
        self.assertNotIn("error", w)
|
||||
|
||||
|
||||
# Allow running this test module directly (outside pytest).
if __name__ == "__main__":
    unittest.main()
|
||||
|
|
@ -136,6 +136,49 @@ class TestGptLiteralFamily:
|
|||
assert p["image_size"] == "1024x1536"
|
||||
|
||||
|
||||
class TestGptImage2Presets:
    """GPT Image 2 uses preset enum sizes (not literal strings like 1.5).
    Mapped to 4:3 variants so we stay above the 655,360 min-pixel floor
    (16:9 presets at 1024x576 = 589,824 would be rejected)."""

    # NOTE(review): `image_tool` is a pytest fixture defined elsewhere in
    # this test module — presumably the imported tools.image module.

    def test_gpt2_landscape_uses_4_3_preset(self, image_tool):
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hello", "landscape")
        assert p["image_size"] == "landscape_4_3"

    def test_gpt2_square_uses_square_hd(self, image_tool):
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hello", "square")
        assert p["image_size"] == "square_hd"

    def test_gpt2_portrait_uses_4_3_preset(self, image_tool):
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hello", "portrait")
        assert p["image_size"] == "portrait_4_3"

    def test_gpt2_quality_pinned_to_medium(self, image_tool):
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hi", "square")
        assert p["quality"] == "medium"

    def test_gpt2_strips_byok_and_unsupported_overrides(self, image_tool):
        """openai_api_key (BYOK) is deliberately not in supports — all users
        route through shared FAL billing. guidance_scale/num_inference_steps
        aren't in the model's API surface either."""
        p = image_tool._build_fal_payload(
            "fal-ai/gpt-image-2", "hi", "square",
            overrides={
                "openai_api_key": "sk-...",
                "guidance_scale": 7.5,
                "num_inference_steps": 50,
            },
        )
        assert "openai_api_key" not in p
        assert "guidance_scale" not in p
        assert "num_inference_steps" not in p

    def test_gpt2_strips_seed_even_if_passed(self, image_tool):
        # seed isn't in the GPT Image 2 API surface either.
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hi", "square", seed=42)
        assert "seed" not in p
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supports whitelist — the main safety property
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -231,10 +274,11 @@ class TestGptQualityPinnedToMedium:
|
|||
assert p["quality"] == "medium"
|
||||
|
||||
def test_non_gpt_model_never_gets_quality(self, image_tool):
|
||||
"""quality is only meaningful for gpt-image-1.5 — other models should
|
||||
never have it in their payload."""
|
||||
"""quality is only meaningful for GPT-Image models (1.5, 2) — other
|
||||
models should never have it in their payload."""
|
||||
gpt_models = {"fal-ai/gpt-image-1.5", "fal-ai/gpt-image-2"}
|
||||
for mid in image_tool.FAL_MODELS:
|
||||
if mid == "fal-ai/gpt-image-1.5":
|
||||
if mid in gpt_models:
|
||||
continue
|
||||
p = image_tool._build_fal_payload(mid, "hi", "square")
|
||||
assert "quality" not in p, f"{mid} unexpectedly has 'quality' in payload"
|
||||
|
|
|
|||
197
tests/tools/test_tts_max_text_length.py
Normal file
197
tests/tools/test_tts_max_text_length.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
"""Tests for per-provider TTS input-character limits.
|
||||
|
||||
Replaces the old global ``MAX_TEXT_LENGTH = 4000`` cap that truncated every
|
||||
provider at 4000 chars even though OpenAI allows 4096, xAI allows 15000,
|
||||
MiniMax allows 10000, and ElevenLabs allows 5000-40000 depending on model.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.tts_tool import (
|
||||
ELEVENLABS_MODEL_MAX_TEXT_LENGTH,
|
||||
FALLBACK_MAX_TEXT_LENGTH,
|
||||
PROVIDER_MAX_TEXT_LENGTH,
|
||||
_resolve_max_text_length,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveMaxTextLength:
    """Unit tests for _resolve_max_text_length(provider, config):
    per-provider defaults, user overrides, and the ElevenLabs
    model-id-aware lookup."""

    def test_edge_default(self):
        assert _resolve_max_text_length("edge", {}) == PROVIDER_MAX_TEXT_LENGTH["edge"]

    def test_openai_default_is_4096(self):
        assert _resolve_max_text_length("openai", {}) == 4096

    def test_xai_default_is_15000(self):
        assert _resolve_max_text_length("xai", {}) == 15000

    def test_minimax_default_is_10000(self):
        assert _resolve_max_text_length("minimax", {}) == 10000

    def test_mistral_default(self):
        assert _resolve_max_text_length("mistral", {}) == PROVIDER_MAX_TEXT_LENGTH["mistral"]

    def test_gemini_default(self):
        assert _resolve_max_text_length("gemini", {}) == PROVIDER_MAX_TEXT_LENGTH["gemini"]

    def test_unknown_provider_falls_back(self):
        assert _resolve_max_text_length("does-not-exist", {}) == FALLBACK_MAX_TEXT_LENGTH

    def test_empty_provider_falls_back(self):
        assert _resolve_max_text_length("", {}) == FALLBACK_MAX_TEXT_LENGTH
        assert _resolve_max_text_length(None, {}) == FALLBACK_MAX_TEXT_LENGTH

    def test_case_insensitive(self):
        # Provider names are normalized (lowercased/stripped) before lookup.
        assert _resolve_max_text_length("OpenAI", {}) == 4096
        assert _resolve_max_text_length(" XAI ", {}) == 15000

    # --- Overrides ---

    def test_override_wins(self):
        cfg = {"openai": {"max_text_length": 9999}}
        assert _resolve_max_text_length("openai", cfg) == 9999

    def test_override_zero_falls_through(self):
        # A broken/zero override must not disable truncation
        cfg = {"openai": {"max_text_length": 0}}
        assert _resolve_max_text_length("openai", cfg) == 4096

    def test_override_negative_falls_through(self):
        cfg = {"xai": {"max_text_length": -1}}
        assert _resolve_max_text_length("xai", cfg) == 15000

    def test_override_non_int_falls_through(self):
        cfg = {"minimax": {"max_text_length": "lots"}}
        assert _resolve_max_text_length("minimax", cfg) == 10000

    def test_override_bool_falls_through(self):
        # bool is technically an int; make sure we don't treat True as 1 char
        cfg = {"openai": {"max_text_length": True}}
        assert _resolve_max_text_length("openai", cfg) == 4096

    def test_missing_provider_section_uses_default(self):
        cfg = {"provider": "openai"}  # no "openai" key
        assert _resolve_max_text_length("openai", cfg) == 4096

    # --- ElevenLabs model-aware ---

    def test_elevenlabs_default_model_multilingual_v2(self):
        cfg = {"elevenlabs": {"model_id": "eleven_multilingual_v2"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 10000

    def test_elevenlabs_flash_v2_5_gets_40k(self):
        cfg = {"elevenlabs": {"model_id": "eleven_flash_v2_5"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 40000

    def test_elevenlabs_flash_v2_gets_30k(self):
        cfg = {"elevenlabs": {"model_id": "eleven_flash_v2"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 30000

    def test_elevenlabs_v3_gets_5k(self):
        cfg = {"elevenlabs": {"model_id": "eleven_v3"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 5000

    def test_elevenlabs_unknown_model_falls_back_to_provider_default(self):
        cfg = {"elevenlabs": {"model_id": "eleven_experimental_xyz"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == PROVIDER_MAX_TEXT_LENGTH["elevenlabs"]

    def test_elevenlabs_override_beats_model_lookup(self):
        cfg = {"elevenlabs": {"model_id": "eleven_flash_v2_5", "max_text_length": 1000}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 1000

    def test_elevenlabs_no_model_id_uses_default_model_mapping(self):
        # Falls back to DEFAULT_ELEVENLABS_MODEL_ID = eleven_multilingual_v2 -> 10000
        assert _resolve_max_text_length("elevenlabs", {}) == 10000

    def test_provider_config_not_a_dict(self):
        # Malformed config section must not crash the resolver.
        cfg = {"openai": "not-a-dict"}
        assert _resolve_max_text_length("openai", cfg) == 4096

    # --- Sanity: the table covers every provider listed in the schema ---

    def test_all_documented_providers_have_defaults(self):
        expected = {"edge", "openai", "xai", "minimax", "mistral",
                    "gemini", "elevenlabs", "neutts", "kittentts"}
        assert expected.issubset(PROVIDER_MAX_TEXT_LENGTH.keys())
|
||||
|
||||
|
||||
class TestTextToSpeechToolTruncation:
    """End-to-end: verify the resolver actually drives the text_to_speech_tool
    truncation path rather than the old 4000-char global."""

    @staticmethod
    def _make_fake_generator(captured):
        """Return a stub provider backend that records the text it receives.

        The stub matches the real generator signature (text, output_path, config),
        stores the (possibly truncated) text in *captured* under the "text" key,
        and writes one placeholder byte so the tool sees a non-empty output file.
        Extracted because all three tests previously duplicated this stub inline.
        """
        def fake_generate(text, output_path, config):
            captured["text"] = text
            with open(output_path, "wb") as f:
                f.write(b"\x00")
            return output_path

        return fake_generate

    def test_openai_truncates_at_4096_not_4000(self, tmp_path, monkeypatch, caplog):
        import logging

        caplog.set_level(logging.WARNING, logger="tools.tts_tool")

        # 5000 chars -- over OpenAI's 4096 limit but under xAI's 15k
        text = "A" * 5000
        captured = {}

        monkeypatch.setattr("tools.tts_tool._generate_openai_tts",
                            self._make_fake_generator(captured))
        monkeypatch.setattr("tools.tts_tool._load_tts_config",
                            lambda: {"provider": "openai"})

        from tools.tts_tool import text_to_speech_tool
        out = str(tmp_path / "out.mp3")
        result = json.loads(text_to_speech_tool(text=text, output_path=out))

        assert result["success"] is True
        # Should be truncated to 4096, not the old 4000
        assert len(captured["text"]) == 4096
        # And the warning should mention the provider
        assert any("openai" in rec.message.lower() for rec in caplog.records)

    def test_xai_accepts_much_longer_input(self, tmp_path, monkeypatch):
        # 12000 chars -- over old global 4000, under xAI's 15000
        text = "B" * 12000
        captured = {}

        monkeypatch.setattr("tools.tts_tool._generate_xai_tts",
                            self._make_fake_generator(captured))
        monkeypatch.setattr("tools.tts_tool._load_tts_config",
                            lambda: {"provider": "xai"})

        from tools.tts_tool import text_to_speech_tool
        out = str(tmp_path / "out.mp3")
        result = json.loads(text_to_speech_tool(text=text, output_path=out))

        assert result["success"] is True
        # xAI should accept the full 12000 chars
        assert len(captured["text"]) == 12000

    def test_user_override_is_respected(self, tmp_path, monkeypatch):
        # User says "cap openai at 100 chars" -- we must honor it
        text = "C" * 500
        captured = {}

        monkeypatch.setattr("tools.tts_tool._generate_openai_tts",
                            self._make_fake_generator(captured))
        monkeypatch.setattr("tools.tts_tool._load_tts_config",
                            lambda: {"provider": "openai",
                                     "openai": {"max_text_length": 100}})

        from tools.tts_tool import text_to_speech_tool
        out = str(tmp_path / "out.mp3")
        result = json.loads(text_to_speech_tool(text=text, output_path=out))

        assert result["success"] is True
        assert len(captured["text"]) == 100