mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
Merge remote-tracking branch 'origin/main' into sid/types-and-lints
# Conflicts: # gateway/run.py # tools/delegate_tool.py
This commit is contained in:
commit
847ffca715
171 changed files with 15125 additions and 1675 deletions
|
|
@ -20,11 +20,14 @@ from unittest.mock import MagicMock, patch
|
|||
from tools.delegate_tool import (
|
||||
DELEGATE_BLOCKED_TOOLS,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
DelegateEvent,
|
||||
_get_max_concurrent_children,
|
||||
_LEGACY_EVENT_MAP,
|
||||
MAX_DEPTH,
|
||||
check_delegate_requirements,
|
||||
delegate_task,
|
||||
_build_child_agent,
|
||||
_build_child_progress_callback,
|
||||
_build_child_system_prompt,
|
||||
_strip_blocked_tools,
|
||||
_resolve_child_credential_pool,
|
||||
|
|
@ -387,7 +390,7 @@ class TestToolNamePreservation(unittest.TestCase):
|
|||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = MagicMock()
|
||||
|
||||
def capture_and_return(user_message):
|
||||
def capture_and_return(user_message, task_id=None):
|
||||
captured["saved"] = list(mock_child._delegate_saved_tool_names)
|
||||
return {"final_response": "ok", "completed": True, "api_calls": 1}
|
||||
|
||||
|
|
@ -568,8 +571,16 @@ class TestBlockedTools(unittest.TestCase):
|
|||
self.assertIn(tool, DELEGATE_BLOCKED_TOOLS)
|
||||
|
||||
def test_constants(self):
|
||||
from tools.delegate_tool import (
|
||||
_get_max_spawn_depth, _get_orchestrator_enabled,
|
||||
_MIN_SPAWN_DEPTH, _MAX_SPAWN_DEPTH_CAP,
|
||||
)
|
||||
self.assertEqual(_get_max_concurrent_children(), 3)
|
||||
self.assertEqual(MAX_DEPTH, 2)
|
||||
self.assertEqual(MAX_DEPTH, 1)
|
||||
self.assertEqual(_get_max_spawn_depth(), 1) # default: flat
|
||||
self.assertTrue(_get_orchestrator_enabled()) # default
|
||||
self.assertEqual(_MIN_SPAWN_DEPTH, 1)
|
||||
self.assertEqual(_MAX_SPAWN_DEPTH_CAP, 3)
|
||||
|
||||
|
||||
class TestDelegationCredentialResolution(unittest.TestCase):
|
||||
|
|
@ -1325,5 +1336,635 @@ class TestDelegationReasoningEffort(unittest.TestCase):
|
|||
self.assertEqual(call_kwargs["reasoning_config"], {"enabled": True, "effort": "medium"})
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Dispatch helper, progress events, concurrency
|
||||
# =========================================================================
|
||||
|
||||
class TestDispatchDelegateTask(unittest.TestCase):
|
||||
"""Tests for the _dispatch_delegate_task helper and full param forwarding."""
|
||||
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
def test_acp_args_forwarded(self, mock_creds, mock_cfg):
|
||||
"""Both acp_command and acp_args reach delegate_task via the helper."""
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
with patch("tools.delegate_tool._build_child_agent") as mock_build:
|
||||
mock_child = MagicMock()
|
||||
mock_child.run_conversation.return_value = {
|
||||
"final_response": "done", "completed": True,
|
||||
"api_calls": 1, "messages": [],
|
||||
}
|
||||
mock_child._delegate_saved_tool_names = []
|
||||
mock_child._credential_pool = None
|
||||
mock_child.session_prompt_tokens = 0
|
||||
mock_child.session_completion_tokens = 0
|
||||
mock_child.model = "test"
|
||||
mock_build.return_value = mock_child
|
||||
|
||||
delegate_task(
|
||||
goal="test",
|
||||
acp_command="claude",
|
||||
acp_args=["--acp", "--stdio"],
|
||||
parent_agent=parent,
|
||||
)
|
||||
_, kwargs = mock_build.call_args
|
||||
self.assertEqual(kwargs["override_acp_command"], "claude")
|
||||
self.assertEqual(kwargs["override_acp_args"], ["--acp", "--stdio"])
|
||||
|
||||
class TestDelegateEventEnum(unittest.TestCase):
|
||||
"""Tests for DelegateEvent enum and back-compat aliases."""
|
||||
|
||||
def test_enum_values_are_strings(self):
|
||||
for event in DelegateEvent:
|
||||
self.assertIsInstance(event.value, str)
|
||||
self.assertTrue(event.value.startswith("delegate."))
|
||||
|
||||
def test_legacy_map_covers_all_old_names(self):
|
||||
expected_legacy = {"_thinking", "reasoning.available",
|
||||
"tool.started", "tool.completed", "subagent_progress"}
|
||||
self.assertEqual(set(_LEGACY_EVENT_MAP.keys()), expected_legacy)
|
||||
|
||||
def test_legacy_map_values_are_delegate_events(self):
|
||||
for old_name, event in _LEGACY_EVENT_MAP.items():
|
||||
self.assertIsInstance(event, DelegateEvent)
|
||||
|
||||
def test_progress_callback_normalises_tool_started(self):
|
||||
"""_build_child_progress_callback handles tool.started via enum."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = MagicMock()
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
self.assertIsNotNone(cb)
|
||||
|
||||
cb("tool.started", tool_name="terminal", preview="ls")
|
||||
parent._delegate_spinner.print_above.assert_called()
|
||||
|
||||
def test_progress_callback_normalises_thinking(self):
|
||||
"""Both _thinking and reasoning.available route to TASK_THINKING."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
|
||||
cb("_thinking", tool_name=None, preview="pondering...")
|
||||
assert any("💭" in str(c) for c in parent._delegate_spinner.print_above.call_args_list)
|
||||
|
||||
parent._delegate_spinner.print_above.reset_mock()
|
||||
cb("reasoning.available", tool_name=None, preview="hmm")
|
||||
assert any("💭" in str(c) for c in parent._delegate_spinner.print_above.call_args_list)
|
||||
|
||||
def test_progress_callback_tool_completed_is_noop(self):
|
||||
"""tool.completed is normalised but produces no display output."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb("tool.completed", tool_name="terminal")
|
||||
parent._delegate_spinner.print_above.assert_not_called()
|
||||
|
||||
def test_progress_callback_ignores_unknown_events(self):
|
||||
"""Unknown event types are silently ignored."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
# Should not raise
|
||||
cb("some.unknown.event", tool_name="x")
|
||||
parent._delegate_spinner.print_above.assert_not_called()
|
||||
|
||||
def test_progress_callback_accepts_enum_value_directly(self):
|
||||
"""cb(DelegateEvent.TASK_THINKING, ...) must route to the thinking
|
||||
branch. Pre-fix the callback only handled legacy strings via
|
||||
_LEGACY_EVENT_MAP.get and silently dropped enum-typed callers."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb(DelegateEvent.TASK_THINKING, preview="pondering")
|
||||
# If the enum was accepted, the thinking emoji got printed.
|
||||
assert any(
|
||||
"💭" in str(c)
|
||||
for c in parent._delegate_spinner.print_above.call_args_list
|
||||
)
|
||||
|
||||
def test_progress_callback_accepts_new_style_string(self):
|
||||
"""cb('delegate.task_thinking', ...) — the string form of the
|
||||
enum value — must route to the thinking branch too, so new-style
|
||||
emitters don't have to import DelegateEvent."""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb("delegate.task_thinking", preview="hmm")
|
||||
assert any(
|
||||
"💭" in str(c)
|
||||
for c in parent._delegate_spinner.print_above.call_args_list
|
||||
)
|
||||
|
||||
def test_progress_callback_task_progress_not_misrendered(self):
|
||||
"""'subagent_progress' (legacy name for TASK_PROGRESS) carries a
|
||||
pre-batched summary in the tool_name slot. Before the fix, this
|
||||
fell through to the TASK_TOOL_STARTED rendering path, treating
|
||||
the summary string as a tool name. After the fix: distinct
|
||||
render (no tool-start emoji lookup) and pass-through relay
|
||||
upward (no re-batching).
|
||||
|
||||
Regression path only reachable once nested orchestration is
|
||||
enabled: nested orchestrators relay subagent_progress from
|
||||
grandchildren upward through this callback.
|
||||
"""
|
||||
parent = _make_mock_parent()
|
||||
parent._delegate_spinner = MagicMock()
|
||||
parent.tool_progress_callback = MagicMock()
|
||||
|
||||
cb = _build_child_progress_callback(0, "test goal", parent, task_count=1)
|
||||
cb("subagent_progress", tool_name="🔀 [1] terminal, file")
|
||||
|
||||
# Spinner gets a distinct 🔀-prefixed line, NOT a tool emoji
|
||||
# followed by the summary string as if it were a tool name.
|
||||
calls = parent._delegate_spinner.print_above.call_args_list
|
||||
self.assertTrue(any("🔀 🔀 [1] terminal, file" in str(c) for c in calls))
|
||||
# Parent callback receives the relay (pass-through, no re-batching).
|
||||
parent.tool_progress_callback.assert_called_once()
|
||||
# No '⚡' tool-start emoji should appear — that's the pre-fix bug.
|
||||
self.assertFalse(any("⚡" in str(c) for c in calls))
|
||||
|
||||
|
||||
class TestConcurrencyDefaults(unittest.TestCase):
|
||||
"""Tests for the concurrency default and no hard ceiling."""
|
||||
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
def test_default_is_three(self, mock_cfg):
|
||||
# Clear env var if set
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
self.assertEqual(_get_max_concurrent_children(), 3)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_concurrent_children": 10})
|
||||
def test_no_upper_ceiling(self, mock_cfg):
|
||||
"""Users can raise concurrency as high as they want — no hard cap."""
|
||||
self.assertEqual(_get_max_concurrent_children(), 10)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_concurrent_children": 100})
|
||||
def test_very_high_values_honored(self, mock_cfg):
|
||||
self.assertEqual(_get_max_concurrent_children(), 100)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_concurrent_children": 0})
|
||||
def test_zero_clamped_to_one(self, mock_cfg):
|
||||
"""Floor of 1 is enforced; zero or negative values raise to 1."""
|
||||
self.assertEqual(_get_max_concurrent_children(), 1)
|
||||
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
def test_env_var_honored_uncapped(self, mock_cfg):
|
||||
with patch.dict(os.environ, {"DELEGATION_MAX_CONCURRENT_CHILDREN": "12"}):
|
||||
self.assertEqual(_get_max_concurrent_children(), 12)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_concurrent_children": 6})
|
||||
def test_configured_value_returned(self, mock_cfg):
|
||||
self.assertEqual(_get_max_concurrent_children(), 6)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# max_spawn_depth clamping
|
||||
# =========================================================================
|
||||
|
||||
class TestMaxSpawnDepth(unittest.TestCase):
|
||||
"""Tests for _get_max_spawn_depth clamping and fallback behavior."""
|
||||
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
def test_max_spawn_depth_defaults_to_1(self, mock_cfg):
|
||||
from tools.delegate_tool import _get_max_spawn_depth
|
||||
self.assertEqual(_get_max_spawn_depth(), 1)
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 0})
|
||||
def test_max_spawn_depth_clamped_below_one(self, mock_cfg):
|
||||
import logging
|
||||
from tools.delegate_tool import _get_max_spawn_depth
|
||||
with self.assertLogs("tools.delegate_tool", level=logging.WARNING) as cm:
|
||||
result = _get_max_spawn_depth()
|
||||
self.assertEqual(result, 1)
|
||||
self.assertTrue(any("clamping to 1" in m for m in cm.output))
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 99})
|
||||
def test_max_spawn_depth_clamped_above_three(self, mock_cfg):
|
||||
import logging
|
||||
from tools.delegate_tool import _get_max_spawn_depth
|
||||
with self.assertLogs("tools.delegate_tool", level=logging.WARNING) as cm:
|
||||
result = _get_max_spawn_depth()
|
||||
self.assertEqual(result, 3)
|
||||
self.assertTrue(any("clamping to 3" in m for m in cm.output))
|
||||
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": "not-a-number"})
|
||||
def test_max_spawn_depth_invalid_falls_back_to_default(self, mock_cfg):
|
||||
from tools.delegate_tool import _get_max_spawn_depth
|
||||
self.assertEqual(_get_max_spawn_depth(), 1)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# role param plumbing
|
||||
# =========================================================================
|
||||
#
|
||||
# These tests cover the schema + signature + stash plumbing of the role
|
||||
# param. The full role-honoring behavior (toolset re-add, role-aware
|
||||
# prompt) lives in TestOrchestratorRoleBehavior below; these tests only
|
||||
# assert on _delegate_role stashing and on the schema shape.
|
||||
|
||||
|
||||
class TestOrchestratorRoleSchema(unittest.TestCase):
|
||||
"""Tests that the role param reaches the child via dispatch."""
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 2})
|
||||
def _run_with_mock_child(self, role_arg, mock_cfg, mock_creds):
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = MagicMock()
|
||||
mock_child.run_conversation.return_value = {
|
||||
"final_response": "done", "completed": True,
|
||||
"api_calls": 1, "messages": [],
|
||||
}
|
||||
mock_child._delegate_saved_tool_names = []
|
||||
mock_child._credential_pool = None
|
||||
mock_child.session_prompt_tokens = 0
|
||||
mock_child.session_completion_tokens = 0
|
||||
mock_child.model = "test"
|
||||
MockAgent.return_value = mock_child
|
||||
kwargs = {"goal": "test", "parent_agent": parent}
|
||||
if role_arg is not _SENTINEL:
|
||||
kwargs["role"] = role_arg
|
||||
delegate_task(**kwargs)
|
||||
return mock_child
|
||||
|
||||
def test_default_role_is_leaf(self):
|
||||
child = self._run_with_mock_child(_SENTINEL)
|
||||
self.assertEqual(child._delegate_role, "leaf")
|
||||
|
||||
def test_explicit_orchestrator_role_stashed(self):
|
||||
"""role='orchestrator' reaches _build_child_agent and is stashed.
|
||||
Full behavior (toolset re-add) lands in commit 3; commit 2 only
|
||||
verifies the plumbing."""
|
||||
child = self._run_with_mock_child("orchestrator")
|
||||
self.assertEqual(child._delegate_role, "orchestrator")
|
||||
|
||||
def test_unknown_role_coerces_to_leaf(self):
|
||||
"""role='nonsense' → _normalize_role warns and returns 'leaf'."""
|
||||
import logging
|
||||
with self.assertLogs("tools.delegate_tool", level=logging.WARNING) as cm:
|
||||
child = self._run_with_mock_child("nonsense")
|
||||
self.assertEqual(child._delegate_role, "leaf")
|
||||
self.assertTrue(any("coercing" in m.lower() for m in cm.output))
|
||||
|
||||
def test_schema_has_role_top_level_and_per_task(self):
|
||||
from tools.delegate_tool import DELEGATE_TASK_SCHEMA
|
||||
props = DELEGATE_TASK_SCHEMA["parameters"]["properties"]
|
||||
self.assertIn("role", props)
|
||||
self.assertEqual(props["role"]["enum"], ["leaf", "orchestrator"])
|
||||
task_props = props["tasks"]["items"]["properties"]
|
||||
self.assertIn("role", task_props)
|
||||
self.assertEqual(task_props["role"]["enum"], ["leaf", "orchestrator"])
|
||||
|
||||
|
||||
# Sentinel used to distinguish "role kwarg omitted" from "role=None".
|
||||
_SENTINEL = object()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# role-honoring behavior
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _make_role_mock_child():
|
||||
"""Helper: mock child with minimal fields for delegate_task to process."""
|
||||
mock_child = MagicMock()
|
||||
mock_child.run_conversation.return_value = {
|
||||
"final_response": "done", "completed": True,
|
||||
"api_calls": 1, "messages": [],
|
||||
}
|
||||
mock_child._delegate_saved_tool_names = []
|
||||
mock_child._credential_pool = None
|
||||
mock_child.session_prompt_tokens = 0
|
||||
mock_child.session_completion_tokens = 0
|
||||
mock_child.model = "test"
|
||||
return mock_child
|
||||
|
||||
|
||||
class TestOrchestratorRoleBehavior(unittest.TestCase):
|
||||
"""Tests that role='orchestrator' actually changes toolset + prompt."""
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 2})
|
||||
def test_orchestrator_role_keeps_delegation_at_depth_1(
|
||||
self, mock_cfg, mock_creds
|
||||
):
|
||||
"""role='orchestrator' + depth-0 parent with max_spawn_depth=2 →
|
||||
child at depth 1 gets 'delegation' in enabled_toolsets (can
|
||||
further delegate). Requires max_spawn_depth>=2 since the new
|
||||
default is 1 (flat)."""
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = _make_role_mock_child()
|
||||
MockAgent.return_value = mock_child
|
||||
delegate_task(goal="test", role="orchestrator", parent_agent=parent)
|
||||
kwargs = MockAgent.call_args[1]
|
||||
self.assertIn("delegation", kwargs["enabled_toolsets"])
|
||||
self.assertEqual(mock_child._delegate_role, "orchestrator")
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 2})
|
||||
def test_orchestrator_blocked_at_max_spawn_depth(
|
||||
self, mock_cfg, mock_creds
|
||||
):
|
||||
"""Parent at depth 1 with max_spawn_depth=2 spawns child
|
||||
at depth 2 (the floor); role='orchestrator' degrades to leaf."""
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=1)
|
||||
parent.enabled_toolsets = ["terminal", "delegation"]
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = _make_role_mock_child()
|
||||
MockAgent.return_value = mock_child
|
||||
delegate_task(goal="test", role="orchestrator", parent_agent=parent)
|
||||
kwargs = MockAgent.call_args[1]
|
||||
self.assertNotIn("delegation", kwargs["enabled_toolsets"])
|
||||
self.assertEqual(mock_child._delegate_role, "leaf")
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
@patch("tools.delegate_tool._load_config", return_value={})
|
||||
def test_orchestrator_blocked_at_default_flat_depth(
|
||||
self, mock_cfg, mock_creds
|
||||
):
|
||||
"""With default max_spawn_depth=1 (flat), role='orchestrator'
|
||||
on a depth-0 parent produces a depth-1 child that is already at
|
||||
the floor — the role degrades to 'leaf' and the delegation
|
||||
toolset is stripped. This is the new default posture."""
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
parent.enabled_toolsets = ["terminal", "file", "delegation"]
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = _make_role_mock_child()
|
||||
MockAgent.return_value = mock_child
|
||||
delegate_task(goal="test", role="orchestrator", parent_agent=parent)
|
||||
kwargs = MockAgent.call_args[1]
|
||||
self.assertNotIn("delegation", kwargs["enabled_toolsets"])
|
||||
self.assertEqual(mock_child._delegate_role, "leaf")
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
def test_orchestrator_enabled_false_forces_leaf(self, mock_creds):
|
||||
"""Kill switch delegation.orchestrator_enabled=false overrides
|
||||
role='orchestrator'."""
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
parent.enabled_toolsets = ["terminal", "delegation"]
|
||||
with patch("tools.delegate_tool._load_config",
|
||||
return_value={"orchestrator_enabled": False}):
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = _make_role_mock_child()
|
||||
MockAgent.return_value = mock_child
|
||||
delegate_task(goal="test", role="orchestrator",
|
||||
parent_agent=parent)
|
||||
kwargs = MockAgent.call_args[1]
|
||||
self.assertNotIn("delegation", kwargs["enabled_toolsets"])
|
||||
self.assertEqual(mock_child._delegate_role, "leaf")
|
||||
|
||||
# ── Role-aware system prompt ────────────────────────────────────────
|
||||
|
||||
def test_leaf_prompt_does_not_mention_delegation(self):
|
||||
prompt = _build_child_system_prompt(
|
||||
"Fix tests", role="leaf",
|
||||
max_spawn_depth=2, child_depth=1,
|
||||
)
|
||||
self.assertNotIn("delegate_task", prompt)
|
||||
self.assertNotIn("Orchestrator Role", prompt)
|
||||
|
||||
def test_orchestrator_prompt_mentions_delegation_capability(self):
|
||||
prompt = _build_child_system_prompt(
|
||||
"Survey approaches", role="orchestrator",
|
||||
max_spawn_depth=2, child_depth=1,
|
||||
)
|
||||
self.assertIn("delegate_task", prompt)
|
||||
self.assertIn("Orchestrator Role", prompt)
|
||||
# Depth/max-depth note present and literal:
|
||||
self.assertIn("depth 1", prompt)
|
||||
self.assertIn("max_spawn_depth=2", prompt)
|
||||
|
||||
def test_orchestrator_prompt_at_depth_floor_says_children_are_leaves(self):
|
||||
"""With max_spawn_depth=2 and child_depth=1, the orchestrator's
|
||||
own children would be at depth 2 (the floor) → must be leaves."""
|
||||
prompt = _build_child_system_prompt(
|
||||
"Survey", role="orchestrator",
|
||||
max_spawn_depth=2, child_depth=1,
|
||||
)
|
||||
self.assertIn("MUST be leaves", prompt)
|
||||
|
||||
def test_orchestrator_prompt_below_floor_allows_more_nesting(self):
|
||||
"""With max_spawn_depth=3 and child_depth=1, the orchestrator's
|
||||
own children can themselves be orchestrators (depth 2 < 3)."""
|
||||
prompt = _build_child_system_prompt(
|
||||
"Deep work", role="orchestrator",
|
||||
max_spawn_depth=3, child_depth=1,
|
||||
)
|
||||
self.assertIn("can themselves be orchestrators", prompt)
|
||||
|
||||
# ── Batch mode and intersection ─────────────────────────────────────
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 2})
|
||||
def test_batch_mode_per_task_role_override(self, mock_cfg, mock_creds):
|
||||
"""Per-task role beats top-level; no top-level role → "leaf".
|
||||
|
||||
tasks=[{role:'orchestrator'},{role:'leaf'},{}] → first gets
|
||||
delegation, second and third don't. Requires max_spawn_depth>=2
|
||||
(raised explicitly here) since the new default is 1 (flat).
|
||||
"""
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
parent.enabled_toolsets = ["terminal", "file", "delegation"]
|
||||
built_toolsets = []
|
||||
|
||||
def _factory(*a, **kw):
|
||||
m = _make_role_mock_child()
|
||||
built_toolsets.append(kw.get("enabled_toolsets"))
|
||||
return m
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=_factory):
|
||||
delegate_task(
|
||||
tasks=[
|
||||
{"goal": "A", "role": "orchestrator"},
|
||||
{"goal": "B", "role": "leaf"},
|
||||
{"goal": "C"}, # no role → falls back to top_role (leaf)
|
||||
],
|
||||
parent_agent=parent,
|
||||
)
|
||||
self.assertIn("delegation", built_toolsets[0])
|
||||
self.assertNotIn("delegation", built_toolsets[1])
|
||||
self.assertNotIn("delegation", built_toolsets[2])
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 2})
|
||||
def test_intersection_preserves_delegation_bound(
|
||||
self, mock_cfg, mock_creds
|
||||
):
|
||||
"""Design decision: orchestrator capability is granted by role,
|
||||
NOT inherited from the parent's toolset. A parent without
|
||||
'delegation' in its enabled_toolsets can still spawn an
|
||||
orchestrator child — the re-add in _build_child_agent runs
|
||||
unconditionally for orchestrators (when max_spawn_depth allows).
|
||||
|
||||
If you want to change to "parent must have delegation too",
|
||||
update _build_child_agent to check parent_toolsets before the
|
||||
re-add and update this test to match.
|
||||
"""
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
parent.enabled_toolsets = ["terminal", "file"] # no delegation
|
||||
with patch("run_agent.AIAgent") as MockAgent:
|
||||
mock_child = _make_role_mock_child()
|
||||
MockAgent.return_value = mock_child
|
||||
delegate_task(goal="test", role="orchestrator",
|
||||
parent_agent=parent)
|
||||
self.assertIn("delegation", MockAgent.call_args[1]["enabled_toolsets"])
|
||||
|
||||
|
||||
class TestOrchestratorEndToEnd(unittest.TestCase):
|
||||
"""End-to-end: parent -> orchestrator -> two-leaf nested orchestration.
|
||||
|
||||
Covers the acceptance gate: parent delegates to an orchestrator
|
||||
child; the orchestrator delegates to two leaf grandchildren; the
|
||||
role/toolset/depth chain all resolve correctly.
|
||||
|
||||
Mock strategy: a single AIAgent patch with a side_effect factory
|
||||
that keys on the child's ephemeral_system_prompt — orchestrator
|
||||
prompts contain the string "Orchestrator Role" (see
|
||||
_build_child_system_prompt), leaves don't. The orchestrator
|
||||
mock's run_conversation recursively calls delegate_task with
|
||||
tasks=[{goal:...},{goal:...}] to spawn two leaves. This keeps
|
||||
the test in one patch context and avoids depth-indexed nesting.
|
||||
"""
|
||||
|
||||
@patch("tools.delegate_tool._resolve_delegation_credentials")
|
||||
@patch("tools.delegate_tool._load_config",
|
||||
return_value={"max_spawn_depth": 2})
|
||||
def test_end_to_end_nested_orchestration(self, mock_cfg, mock_creds):
|
||||
mock_creds.return_value = {
|
||||
"provider": None, "base_url": None,
|
||||
"api_key": None, "api_mode": None, "model": None,
|
||||
}
|
||||
parent = _make_mock_parent(depth=0)
|
||||
parent.enabled_toolsets = ["terminal", "file", "delegation"]
|
||||
|
||||
# (enabled_toolsets, _delegate_role) for each agent built
|
||||
built_agents: list = []
|
||||
# Keep the orchestrator mock around so the re-entrant delegate_task
|
||||
# can reach it via closure.
|
||||
orch_mock = {}
|
||||
|
||||
def _factory(*a, **kw):
|
||||
prompt = kw.get("ephemeral_system_prompt", "") or ""
|
||||
is_orchestrator = "Orchestrator Role" in prompt
|
||||
m = _make_role_mock_child()
|
||||
built_agents.append({
|
||||
"enabled_toolsets": list(kw.get("enabled_toolsets") or []),
|
||||
"is_orchestrator_prompt": is_orchestrator,
|
||||
})
|
||||
|
||||
if is_orchestrator:
|
||||
# Prepare the orchestrator mock as a parent-capable object
|
||||
# so the nested delegate_task call succeeds.
|
||||
m._delegate_depth = 1
|
||||
m._delegate_role = "orchestrator"
|
||||
m._active_children = []
|
||||
m._active_children_lock = threading.Lock()
|
||||
m._session_db = None
|
||||
m.platform = "cli"
|
||||
m.enabled_toolsets = ["terminal", "file", "delegation"]
|
||||
m.api_key = "***"
|
||||
m.base_url = ""
|
||||
m.provider = None
|
||||
m.api_mode = None
|
||||
m.providers_allowed = None
|
||||
m.providers_ignored = None
|
||||
m.providers_order = None
|
||||
m.provider_sort = None
|
||||
m._print_fn = None
|
||||
m.tool_progress_callback = None
|
||||
m.thinking_callback = None
|
||||
orch_mock["agent"] = m
|
||||
|
||||
def _orchestrator_run(user_message=None, task_id=None):
|
||||
# Re-entrant: orchestrator spawns two leaves
|
||||
delegate_task(
|
||||
tasks=[{"goal": "leaf-A"}, {"goal": "leaf-B"}],
|
||||
parent_agent=m,
|
||||
)
|
||||
return {
|
||||
"final_response": "orchestrated 2 workers",
|
||||
"completed": True, "api_calls": 1,
|
||||
"messages": [],
|
||||
}
|
||||
m.run_conversation.side_effect = _orchestrator_run
|
||||
|
||||
return m
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=_factory) as MockAgent:
|
||||
delegate_task(
|
||||
goal="top-level orchestration",
|
||||
role="orchestrator",
|
||||
parent_agent=parent,
|
||||
)
|
||||
|
||||
# 1 orchestrator + 2 leaf grandchildren = 3 agents
|
||||
self.assertEqual(MockAgent.call_count, 3)
|
||||
# First built = the orchestrator (parent's direct child)
|
||||
self.assertIn("delegation", built_agents[0]["enabled_toolsets"])
|
||||
self.assertTrue(built_agents[0]["is_orchestrator_prompt"])
|
||||
# Next two = leaves (grandchildren)
|
||||
self.assertNotIn("delegation", built_agents[1]["enabled_toolsets"])
|
||||
self.assertFalse(built_agents[1]["is_orchestrator_prompt"])
|
||||
self.assertNotIn("delegation", built_agents[2]["enabled_toolsets"])
|
||||
self.assertFalse(built_agents[2]["is_orchestrator_prompt"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
287
tests/tools/test_file_state_registry.py
Normal file
287
tests/tools/test_file_state_registry.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for the cross-agent FileStateRegistry (tools/file_state.py).
|
||||
|
||||
Covers the three layers added for safe concurrent subagent file edits:
|
||||
|
||||
1. Cross-agent staleness detection via ``check_stale``
|
||||
2. Per-path serialization via ``lock_path``
|
||||
3. Delegate-completion reminder via ``writes_since``
|
||||
|
||||
Plus integration through the real ``read_file_tool`` / ``write_file_tool``
|
||||
/ ``patch_tool`` handlers so the full hook wiring is exercised.
|
||||
|
||||
Run:
|
||||
python -m pytest tests/tools/test_file_state_registry.py -v
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from tools import file_state
|
||||
from tools.file_tools import (
|
||||
read_file_tool,
|
||||
write_file_tool,
|
||||
patch_tool,
|
||||
)
|
||||
|
||||
|
||||
def _tmp_file(content: str = "initial\n") -> str:
|
||||
fd, path = tempfile.mkstemp(prefix="hermes_file_state_test_", suffix=".txt")
|
||||
with os.fdopen(fd, "w") as f:
|
||||
f.write(content)
|
||||
return path
|
||||
|
||||
|
||||
class FileStateRegistryUnitTests(unittest.TestCase):
    """Direct unit tests on the registry singleton.

    The registry is process-global (``file_state.get_registry()``), so
    setUp/tearDown clear it to keep tests independent of execution order.
    """

    def setUp(self) -> None:
        file_state.get_registry().clear()
        # Paths created via _mk(); removed again in tearDown.
        self._tmpfiles: list[str] = []

    def tearDown(self) -> None:
        for p in self._tmpfiles:
            try:
                os.unlink(p)
            except OSError:
                # Already removed (or never created) — nothing to clean up.
                pass
        file_state.get_registry().clear()

    def _mk(self, content: str = "x\n") -> str:
        """Create a temp file tracked for cleanup; return its path."""
        p = _tmp_file(content)
        self._tmpfiles.append(p)
        return p

    def test_record_read_then_check_stale_returns_none(self):
        # A freshly-read file is not stale for the agent that read it.
        p = self._mk()
        file_state.record_read("A", p)
        self.assertIsNone(file_state.check_stale("A", p))

    def test_sibling_write_flags_other_agent_as_stale(self):
        p = self._mk()
        file_state.record_read("A", p)
        # Simulate sibling writing this file later
        time.sleep(0.01)  # ensure ts ordering across resolution
        file_state.note_write("B", p)
        warn = file_state.check_stale("A", p)
        self.assertIsNotNone(warn)
        # The warning message names the writing agent and calls it a sibling.
        self.assertIn("B", warn)
        self.assertIn("sibling", warn.lower())

    def test_write_without_read_flagged(self):
        p = self._mk()
        # Agent A never read this file.
        file_state.note_write("B", p)  # another agent touched it
        warn = file_state.check_stale("A", p)
        self.assertIsNotNone(warn)

    def test_partial_read_flagged_on_write(self):
        # A partial read is not enough context to write safely.
        p = self._mk()
        file_state.record_read("A", p, partial=True)
        warn = file_state.check_stale("A", p)
        self.assertIsNotNone(warn)
        self.assertIn("partial", warn.lower())

    def test_external_mtime_drift_flagged(self):
        p = self._mk()
        file_state.record_read("A", p)
        # Bump the on-disk mtime without going through the registry.
        time.sleep(0.01)
        os.utime(p, None)
        with open(p, "w") as f:
            f.write("externally modified\n")
        warn = file_state.check_stale("A", p)
        self.assertIsNotNone(warn)
        self.assertIn("modified since you last read", warn)

    def test_own_write_updates_stamp_so_next_write_is_clean(self):
        p = self._mk()
        file_state.record_read("A", p)
        file_state.note_write("A", p)
        # Second write by the same agent — should not be flagged.
        self.assertIsNone(file_state.check_stale("A", p))

    def test_different_paths_dont_interfere(self):
        a = self._mk()
        b = self._mk()
        file_state.record_read("A", a)
        file_state.note_write("B", b)
        # A reads only `a`; B writes `b`. A writing `a` is NOT stale.
        self.assertIsNone(file_state.check_stale("A", a))

    def test_lock_path_serializes_same_path(self):
        p = self._mk()
        # (event, worker-index) pairs, appended under `lock` so the list
        # itself is not a source of races.
        events: list[tuple[str, int]] = []
        lock = threading.Lock()

        def worker(i: int) -> None:
            with file_state.lock_path(p):
                with lock:
                    events.append(("enter", i))
                time.sleep(0.01)
                with lock:
                    events.append(("exit", i))

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Every enter must be immediately followed by its matching exit.
        self.assertEqual(len(events), 8)
        for i in range(0, 8, 2):
            self.assertEqual(events[i][0], "enter")
            self.assertEqual(events[i + 1][0], "exit")
            self.assertEqual(events[i][1], events[i + 1][1])

    def test_lock_path_is_per_path_not_global(self):
        # If the lock were global, thread B could never enter while A holds
        # its lock, `b_entered` would never be set, and hold_a would block
        # until its wait timeout.
        a = self._mk()
        b = self._mk()
        b_entered = threading.Event()

        def hold_a() -> None:
            with file_state.lock_path(a):
                b_entered.wait(timeout=2.0)

        def enter_b() -> None:
            time.sleep(0.02)  # let A grab its lock
            with file_state.lock_path(b):
                b_entered.set()

        ta = threading.Thread(target=hold_a)
        tb = threading.Thread(target=enter_b)
        ta.start()
        tb.start()
        self.assertTrue(b_entered.wait(timeout=3.0))
        ta.join(timeout=3.0)
        tb.join(timeout=3.0)

    def test_writes_since_filters_by_parent_read_set(self):
        foo = self._mk()
        bar = self._mk()  # read by parent but never written — should not appear
        baz = self._mk()
        file_state.record_read("parent", foo)
        file_state.record_read("parent", bar)
        since = time.time()
        time.sleep(0.01)
        file_state.note_write("child", foo)  # parent read this — report
        file_state.note_write("child", baz)  # parent never saw — skip

        # Caller passes only paths the parent actually read (this is what
        # delegate_tool does via ``known_reads(parent_task_id)``).
        parent_reads = file_state.known_reads("parent")
        out = file_state.writes_since("parent", since, parent_reads)
        self.assertIn("child", out)
        self.assertIn(foo, out["child"])
        self.assertNotIn(baz, out["child"])

    def test_writes_since_excludes_the_target_agent(self):
        # An agent's own writes are never reported back to it.
        p = self._mk()
        file_state.record_read("parent", p)
        since = time.time()
        time.sleep(0.01)
        file_state.note_write("parent", p)  # parent's own write
        out = file_state.writes_since("parent", since, [p])
        self.assertEqual(out, {})

    def test_kill_switch_env_var(self):
        # With the guard disabled, the registry reports nothing at all:
        # no staleness, no known reads, no writes-since.
        p = self._mk()
        os.environ["HERMES_DISABLE_FILE_STATE_GUARD"] = "1"
        try:
            file_state.record_read("A", p)
            file_state.note_write("B", p)
            self.assertIsNone(file_state.check_stale("A", p))
            self.assertEqual(file_state.known_reads("A"), [])
            self.assertEqual(
                file_state.writes_since("A", 0.0, [p]),
                {},
            )
        finally:
            del os.environ["HERMES_DISABLE_FILE_STATE_GUARD"]
|
||||
|
||||
|
||||
class FileToolsIntegrationTests(unittest.TestCase):
    """Integration through the real file_tools handlers.

    These exercise the wiring: read_file_tool → registry.record_read,
    write_file_tool / patch_tool → check_stale + lock_path + note_write.
    """

    def setUp(self) -> None:
        file_state.get_registry().clear()
        # Per-test scratch directory; removed wholesale in tearDown.
        self._tmpdir = tempfile.mkdtemp(prefix="hermes_file_state_int_")

    def tearDown(self) -> None:
        import shutil
        shutil.rmtree(self._tmpdir, ignore_errors=True)
        file_state.get_registry().clear()

    def _write_seed(self, name: str, content: str = "seed\n") -> str:
        """Create *name* under the scratch dir with *content*; return its path."""
        p = os.path.join(self._tmpdir, name)
        with open(p, "w") as f:
            f.write(content)
        return p

    def test_sibling_agent_write_surfaces_warning_through_handler(self):
        p = self._write_seed("shared.txt")
        r = json.loads(read_file_tool(path=p, task_id="agentA"))
        self.assertNotIn("error", r)

        # agentB writes after agentA's read — agentA is now stale.
        w_b = json.loads(write_file_tool(path=p, content="B wrote\n", task_id="agentB"))
        self.assertNotIn("error", w_b)

        w_a = json.loads(write_file_tool(path=p, content="A stale\n", task_id="agentA"))
        warn = w_a.get("_warning", "")
        self.assertTrue(warn, f"expected warning, got: {w_a}")
        # The cross-agent message names the sibling task_id.
        self.assertIn("agentB", warn)
        self.assertIn("sibling", warn.lower())

    def test_same_agent_consecutive_writes_no_false_warning(self):
        p = self._write_seed("own.txt")
        json.loads(read_file_tool(path=p, task_id="agentC"))
        w1 = json.loads(write_file_tool(path=p, content="one\n", task_id="agentC"))
        self.assertFalse(w1.get("_warning"))
        w2 = json.loads(write_file_tool(path=p, content="two\n", task_id="agentC"))
        self.assertFalse(w2.get("_warning"))

    def test_patch_tool_also_surfaces_sibling_warning(self):
        p = self._write_seed("p.txt", "hello world\n")
        json.loads(read_file_tool(path=p, task_id="agentA"))
        json.loads(write_file_tool(path=p, content="hello planet\n", task_id="agentB"))
        r = json.loads(
            patch_tool(
                mode="replace",
                path=p,
                old_string="hello",
                new_string="HI",
                task_id="agentA",
            )
        )
        warn = r.get("_warning", "")
        # Patch may fail (sibling changed the content so old_string may not
        # match) or succeed — either way, the cross-agent warning should be
        # present when old_string still happens to match. What matters is
        # that if the patch succeeded or the warning was reported, it names
        # the sibling. When old_string doesn't match, the patch itself
        # returns an error but the warning is still set from the pre-check.
        if warn:
            self.assertIn("agentB", warn)

    def test_net_new_file_no_warning(self):
        p = os.path.join(self._tmpdir, "brand_new.txt")
        # Nobody has read or written this before.
        w = json.loads(write_file_tool(path=p, content="hi\n", task_id="agentX"))
        self.assertFalse(w.get("_warning"))
        self.assertNotIn("error", w)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (python path/to/file.py)
    # in addition to pytest discovery.
    unittest.main()
|
||||
|
|
@ -136,6 +136,49 @@ class TestGptLiteralFamily:
|
|||
assert p["image_size"] == "1024x1536"
|
||||
|
||||
|
||||
class TestGptImage2Presets:
    """GPT Image 2 uses preset enum sizes (not literal strings like 1.5).
    Mapped to 4:3 variants so we stay above the 655,360 min-pixel floor
    (16:9 presets at 1024x576 = 589,824 would be rejected)."""

    def test_gpt2_landscape_uses_4_3_preset(self, image_tool):
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hello", "landscape")
        assert p["image_size"] == "landscape_4_3"

    def test_gpt2_square_uses_square_hd(self, image_tool):
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hello", "square")
        assert p["image_size"] == "square_hd"

    def test_gpt2_portrait_uses_4_3_preset(self, image_tool):
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hello", "portrait")
        assert p["image_size"] == "portrait_4_3"

    def test_gpt2_quality_pinned_to_medium(self, image_tool):
        # Quality is fixed, not caller-selectable.
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hi", "square")
        assert p["quality"] == "medium"

    def test_gpt2_strips_byok_and_unsupported_overrides(self, image_tool):
        """openai_api_key (BYOK) is deliberately not in supports — all users
        route through shared FAL billing. guidance_scale/num_inference_steps
        aren't in the model's API surface either."""
        p = image_tool._build_fal_payload(
            "fal-ai/gpt-image-2", "hi", "square",
            overrides={
                "openai_api_key": "sk-...",
                "guidance_scale": 7.5,
                "num_inference_steps": 50,
            },
        )
        assert "openai_api_key" not in p
        assert "guidance_scale" not in p
        assert "num_inference_steps" not in p

    def test_gpt2_strips_seed_even_if_passed(self, image_tool):
        # seed isn't in the GPT Image 2 API surface either.
        p = image_tool._build_fal_payload("fal-ai/gpt-image-2", "hi", "square", seed=42)
        assert "seed" not in p
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Supports whitelist — the main safety property
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -231,10 +274,11 @@ class TestGptQualityPinnedToMedium:
|
|||
assert p["quality"] == "medium"
|
||||
|
||||
    def test_non_gpt_model_never_gets_quality(self, image_tool):
        """quality is only meaningful for GPT-Image models (1.5, 2) — other
        models should never have it in their payload."""
        # NOTE(review): reconstructed from the post-merge side of the diff
        # hunk — confirm against the merged file.
        gpt_models = {"fal-ai/gpt-image-1.5", "fal-ai/gpt-image-2"}
        for mid in image_tool.FAL_MODELS:
            if mid in gpt_models:
                continue
            p = image_tool._build_fal_payload(mid, "hi", "square")
            assert "quality" not in p, f"{mid} unexpectedly has 'quality' in payload"
|
||||
|
|
|
|||
197
tests/tools/test_tts_max_text_length.py
Normal file
197
tests/tools/test_tts_max_text_length.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
"""Tests for per-provider TTS input-character limits.
|
||||
|
||||
Replaces the old global ``MAX_TEXT_LENGTH = 4000`` cap that truncated every
|
||||
provider at 4000 chars even though OpenAI allows 4096, xAI allows 15000,
|
||||
MiniMax allows 10000, and ElevenLabs allows 5000-40000 depending on model.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.tts_tool import (
|
||||
ELEVENLABS_MODEL_MAX_TEXT_LENGTH,
|
||||
FALLBACK_MAX_TEXT_LENGTH,
|
||||
PROVIDER_MAX_TEXT_LENGTH,
|
||||
_resolve_max_text_length,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveMaxTextLength:
    """Unit tests for ``_resolve_max_text_length(provider, config)``.

    Covers: per-provider defaults, the fallback for unknown/empty providers,
    user overrides (including malformed ones that must fall through), and
    the ElevenLabs model-aware limits.
    """

    def test_edge_default(self):
        assert _resolve_max_text_length("edge", {}) == PROVIDER_MAX_TEXT_LENGTH["edge"]

    def test_openai_default_is_4096(self):
        assert _resolve_max_text_length("openai", {}) == 4096

    def test_xai_default_is_15000(self):
        assert _resolve_max_text_length("xai", {}) == 15000

    def test_minimax_default_is_10000(self):
        assert _resolve_max_text_length("minimax", {}) == 10000

    def test_mistral_default(self):
        assert _resolve_max_text_length("mistral", {}) == PROVIDER_MAX_TEXT_LENGTH["mistral"]

    def test_gemini_default(self):
        assert _resolve_max_text_length("gemini", {}) == PROVIDER_MAX_TEXT_LENGTH["gemini"]

    def test_unknown_provider_falls_back(self):
        assert _resolve_max_text_length("does-not-exist", {}) == FALLBACK_MAX_TEXT_LENGTH

    def test_empty_provider_falls_back(self):
        assert _resolve_max_text_length("", {}) == FALLBACK_MAX_TEXT_LENGTH
        assert _resolve_max_text_length(None, {}) == FALLBACK_MAX_TEXT_LENGTH

    def test_case_insensitive(self):
        # Provider names are normalized (case + surrounding whitespace).
        assert _resolve_max_text_length("OpenAI", {}) == 4096
        assert _resolve_max_text_length(" XAI ", {}) == 15000

    # --- Overrides ---

    def test_override_wins(self):
        cfg = {"openai": {"max_text_length": 9999}}
        assert _resolve_max_text_length("openai", cfg) == 9999

    def test_override_zero_falls_through(self):
        # A broken/zero override must not disable truncation
        cfg = {"openai": {"max_text_length": 0}}
        assert _resolve_max_text_length("openai", cfg) == 4096

    def test_override_negative_falls_through(self):
        cfg = {"xai": {"max_text_length": -1}}
        assert _resolve_max_text_length("xai", cfg) == 15000

    def test_override_non_int_falls_through(self):
        cfg = {"minimax": {"max_text_length": "lots"}}
        assert _resolve_max_text_length("minimax", cfg) == 10000

    def test_override_bool_falls_through(self):
        # bool is technically an int; make sure we don't treat True as 1 char
        cfg = {"openai": {"max_text_length": True}}
        assert _resolve_max_text_length("openai", cfg) == 4096

    def test_missing_provider_section_uses_default(self):
        cfg = {"provider": "openai"}  # no "openai" key
        assert _resolve_max_text_length("openai", cfg) == 4096

    # --- ElevenLabs model-aware ---

    def test_elevenlabs_default_model_multilingual_v2(self):
        cfg = {"elevenlabs": {"model_id": "eleven_multilingual_v2"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 10000

    def test_elevenlabs_flash_v2_5_gets_40k(self):
        cfg = {"elevenlabs": {"model_id": "eleven_flash_v2_5"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 40000

    def test_elevenlabs_flash_v2_gets_30k(self):
        cfg = {"elevenlabs": {"model_id": "eleven_flash_v2"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 30000

    def test_elevenlabs_v3_gets_5k(self):
        cfg = {"elevenlabs": {"model_id": "eleven_v3"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 5000

    def test_elevenlabs_unknown_model_falls_back_to_provider_default(self):
        cfg = {"elevenlabs": {"model_id": "eleven_experimental_xyz"}}
        assert _resolve_max_text_length("elevenlabs", cfg) == PROVIDER_MAX_TEXT_LENGTH["elevenlabs"]

    def test_elevenlabs_override_beats_model_lookup(self):
        # Explicit user override wins even over the model-specific limit.
        cfg = {"elevenlabs": {"model_id": "eleven_flash_v2_5", "max_text_length": 1000}}
        assert _resolve_max_text_length("elevenlabs", cfg) == 1000

    def test_elevenlabs_no_model_id_uses_default_model_mapping(self):
        # Falls back to DEFAULT_ELEVENLABS_MODEL_ID = eleven_multilingual_v2 -> 10000
        assert _resolve_max_text_length("elevenlabs", {}) == 10000

    def test_provider_config_not_a_dict(self):
        # Malformed provider section must not crash the resolver.
        cfg = {"openai": "not-a-dict"}
        assert _resolve_max_text_length("openai", cfg) == 4096

    # --- Sanity: the table covers every provider listed in the schema ---

    def test_all_documented_providers_have_defaults(self):
        expected = {"edge", "openai", "xai", "minimax", "mistral",
                    "gemini", "elevenlabs", "neutts", "kittentts"}
        assert expected.issubset(PROVIDER_MAX_TEXT_LENGTH.keys())
|
||||
|
||||
|
||||
class TestTextToSpeechToolTruncation:
    """End-to-end: verify the resolver actually drives the text_to_speech_tool
    truncation path rather than the old 4000-char global.

    Each test stubs out the provider backend (no network) and captures the
    text it receives, so the assertion is on the length after truncation.
    """

    def test_openai_truncates_at_4096_not_4000(self, tmp_path, monkeypatch, caplog):
        import logging
        caplog.set_level(logging.WARNING, logger="tools.tts_tool")

        # 5000 chars -- over OpenAI's 4096 limit but under xAI's 15k
        text = "A" * 5000
        captured_text = {}

        def fake_openai(t, out, cfg):
            # Record the (possibly truncated) text and produce a dummy file
            # so the tool's success path completes.
            captured_text["text"] = t
            with open(out, "wb") as f:
                f.write(b"\x00")
            return out

        monkeypatch.setattr("tools.tts_tool._generate_openai_tts", fake_openai)
        monkeypatch.setattr("tools.tts_tool._load_tts_config",
                            lambda: {"provider": "openai"})

        from tools.tts_tool import text_to_speech_tool
        out = str(tmp_path / "out.mp3")
        result = json.loads(text_to_speech_tool(text=text, output_path=out))

        assert result["success"] is True
        # Should be truncated to 4096, not the old 4000
        assert len(captured_text["text"]) == 4096
        # And the warning should mention the provider
        assert any("openai" in rec.message.lower() for rec in caplog.records)

    def test_xai_accepts_much_longer_input(self, tmp_path, monkeypatch):
        # 12000 chars -- over old global 4000, under xAI's 15000
        text = "B" * 12000
        captured_text = {}

        def fake_xai(t, out, cfg):
            captured_text["text"] = t
            with open(out, "wb") as f:
                f.write(b"\x00")
            return out

        monkeypatch.setattr("tools.tts_tool._generate_xai_tts", fake_xai)
        monkeypatch.setattr("tools.tts_tool._load_tts_config",
                            lambda: {"provider": "xai"})

        from tools.tts_tool import text_to_speech_tool
        out = str(tmp_path / "out.mp3")
        result = json.loads(text_to_speech_tool(text=text, output_path=out))

        assert result["success"] is True
        # xAI should accept the full 12000 chars
        assert len(captured_text["text"]) == 12000

    def test_user_override_is_respected(self, tmp_path, monkeypatch):
        # User says "cap openai at 100 chars" -- we must honor it
        text = "C" * 500
        captured_text = {}

        def fake_openai(t, out, cfg):
            captured_text["text"] = t
            with open(out, "wb") as f:
                f.write(b"\x00")
            return out

        monkeypatch.setattr("tools.tts_tool._generate_openai_tts", fake_openai)
        monkeypatch.setattr("tools.tts_tool._load_tts_config",
                            lambda: {"provider": "openai",
                                     "openai": {"max_text_length": 100}})

        from tools.tts_tool import text_to_speech_tool
        out = str(tmp_path / "out.mp3")
        result = json.loads(text_to_speech_tool(text=text, output_path=out))

        assert result["success"] is True
        assert len(captured_text["text"]) == 100
|
||||
Loading…
Add table
Add a link
Reference in a new issue