mirror of https://github.com/NousResearch/hermes-agent.git
synced 2026-04-29 01:31:41 +00:00

fix(delegate): resolve merge conflict with upstream thread capture
Commit fda1325e79
522 changed files with 82059 additions and 4541 deletions

@@ -434,6 +434,76 @@ class TestSensitiveRedirectPattern:
        assert dangerous is False
        assert key is None

    def test_redirect_to_local_dotenv_requires_approval(self):
        dangerous, key, desc = detect_dangerous_command("echo TOKEN=x > .env")
        assert dangerous is True
        assert key is not None
        assert "project env/config" in desc.lower()

    def test_redirect_to_nested_config_yaml_requires_approval(self):
        dangerous, key, desc = detect_dangerous_command("echo mode: prod > deploy/config.yaml")
        assert dangerous is True
        assert key is not None
        assert "project env/config" in desc.lower()

    def test_redirect_from_local_dotenv_source_is_safe(self):
        dangerous, key, desc = detect_dangerous_command("cat .env > backup.txt")
        assert dangerous is False
        assert key is None
        assert desc is None


class TestProjectSensitiveCopyPattern:
    def test_cp_to_local_dotenv_requires_approval(self):
        dangerous, key, desc = detect_dangerous_command("cp .env.local .env")
        assert dangerous is True
        assert key is not None
        assert "project env/config" in desc.lower()

    def test_cp_absolute_path_to_dotenv_requires_approval(self):
        # Regression: the real-world bug report was `cp /opt/data/.env.local /opt/data/.env`.
        # The regex must cover absolute paths, not just `./` / bare relative paths.
        dangerous, key, desc = detect_dangerous_command(
            "cp /opt/data/.env.local /opt/data/.env"
        )
        assert dangerous is True
        assert key is not None
        assert "project env/config" in desc.lower()

    def test_redirect_absolute_path_to_dotenv_requires_approval(self):
        dangerous, key, desc = detect_dangerous_command(
            "cat /opt/data/.env.local > /opt/data/.env"
        )
        assert dangerous is True
        assert key is not None
        assert "project env/config" in desc.lower()

    def test_mv_to_nested_config_yaml_requires_approval(self):
        dangerous, key, desc = detect_dangerous_command("mv tmp/generated.yaml config/config.yaml")
        assert dangerous is True
        assert key is not None
        assert "project env/config" in desc.lower()

    def test_install_to_dotenv_requires_approval(self):
        dangerous, key, desc = detect_dangerous_command("install -m 600 template.env .env.production")
        assert dangerous is True
        assert key is not None
        assert "project env/config" in desc.lower()

    def test_cp_from_config_yaml_source_is_safe(self):
        dangerous, key, desc = detect_dangerous_command("cp config.yaml backup.yaml")
        assert dangerous is False
        assert key is None
        assert desc is None


class TestProjectSensitiveTeePattern:
    def test_tee_to_local_dotenv_requires_approval(self):
        dangerous, key, desc = detect_dangerous_command("printenv | tee .env.local")
        assert dangerous is True
        assert key is not None
        assert "project env/config" in desc.lower()


class TestPatternKeyUniqueness:
    """Bug: pattern_key is derived by splitting on \\b and taking [1], so
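# Illustrative sketch (not part of this diff): the behaviour pinned above is
# "writes *to* project env/config files need approval; reads *from* them do
# not". A minimal stand-alone approximation (helper name and regexes are
# hypothetical, not the repository's detect_dangerous_command implementation):

import re

_SENSITIVE = re.compile(r"(^|/)(\.env(\.[\w.-]+)?|[\w.-]*config\.ya?ml)$")


def _targets_project_config(command: str) -> bool:
    """True when a redirect, tee, cp, mv or install targets .env* / *config.yaml."""
    m = re.search(r"(?:>|\|\s*tee\s+)\s*(\S+)", command)       # redirect / tee target
    if m and _SENSITIVE.search(m.group(1)):
        return True
    m = re.match(r"\s*(?:cp|mv|install)\b(.*)", command)        # copy-style: last arg is the destination
    if m:
        args = [a for a in m.group(1).split() if not a.startswith("-")]
        if args and _SENSITIVE.search(args[-1]):
            return True
    return False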
@@ -836,4 +906,3 @@ class TestChmodExecuteCombo:
        cmd = "chmod +x script.sh"
        dangerous, _, _ = detect_dangerous_command(cmd)
        assert dangerous is False
@@ -60,6 +60,22 @@ class TestWrapCommand:
        assert "cd ~" in wrapped
        assert "cd '~'" not in wrapped

    def test_tilde_subpath_with_spaces_uses_home_and_quotes_suffix(self):
        env = _TestableEnv()
        env._snapshot_ready = True
        wrapped = env._wrap_command("ls", "~/my repo")

        assert "cd $HOME/'my repo'" in wrapped
        assert "cd ~/my repo" not in wrapped

    def test_tilde_slash_maps_to_home(self):
        env = _TestableEnv()
        env._snapshot_ready = True
        wrapped = env._wrap_command("ls", "~/")

        assert "cd $HOME" in wrapped
        assert "cd ~/" not in wrapped

    def test_cd_failure_exit_126(self):
        env = _TestableEnv()
        env._snapshot_ready = True
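# Illustrative sketch (not part of this diff): the tilde handling pinned by the
# tests above, as a stand-alone helper (hypothetical; the real logic lives in
# _wrap_command on the terminal environment class):

import shlex


def _cd_target(cwd: str) -> str:
    if cwd == "~":
        return "~"                                  # bare tilde: let the shell expand it
    if cwd.startswith("~/"):
        suffix = cwd[2:].rstrip("/")
        if not suffix:
            return "$HOME"                          # "~/" maps straight to $HOME
        return "$HOME/" + shlex.quote(suffix)       # quote only the suffix, keep $HOME expandable
    return shlex.quote(cwd)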
@@ -283,7 +283,7 @@ class TestCamofoxVisionConfig:
        with (
            patch("tools.browser_camofox.open", create=True) as mock_open,
            patch("agent.auxiliary_client.call_llm", return_value=mock_response) as mock_llm,
-            patch("hermes_cli.config.load_config", return_value={"auxiliary": {"vision": {"temperature": 1, "timeout": 45}}}),
+            patch("tools.browser_camofox.load_config", return_value={"auxiliary": {"vision": {"temperature": 1, "timeout": 45}}}),
        ):
            mock_open.return_value.__enter__.return_value.read.return_value = b"fakepng"
            result = json.loads(camofox_vision("what is on the page?", annotate=True, task_id="t11"))
@@ -315,7 +315,7 @@ class TestCamofoxVisionConfig:
        with (
            patch("tools.browser_camofox.open", create=True) as mock_open,
            patch("agent.auxiliary_client.call_llm", return_value=mock_response) as mock_llm,
-            patch("hermes_cli.config.load_config", return_value={"auxiliary": {"vision": {}}}),
+            patch("tools.browser_camofox.load_config", return_value={"auxiliary": {"vision": {}}}),
        ):
            mock_open.return_value.__enter__.return_value.read.return_value = b"fakepng"
            result = json.loads(camofox_vision("what is on the page?", annotate=True, task_id="t12"))
@@ -351,7 +351,10 @@ def test_registered_in_browser_toolset():

    entry = registry.get_entry("browser_cdp")
    assert entry is not None
-    assert entry.toolset == "browser"
+    # browser_cdp lives in its own toolset so its stricter check_fn
+    # (requires reachable CDP endpoint) doesn't gate the whole browser
+    # toolset — see commit 96b0f3700.
+    assert entry.toolset == "browser-cdp"
    assert entry.schema["name"] == "browser_cdp"
    assert entry.schema["parameters"]["required"] == ["method"]
    assert "Chrome DevTools Protocol" in entry.schema["description"]

tests/tools/test_browser_supervisor.py (new file, 563 lines)

@@ -0,0 +1,563 @@
"""Integration tests for tools.browser_supervisor.
|
||||
|
||||
Exercises the supervisor end-to-end against a real local Chrome
|
||||
(``--remote-debugging-port``). Skipped when Chrome is not installed
|
||||
— these are the tests that actually verify the CDP wire protocol
|
||||
works, since mock-CDP unit tests can only prove the happy paths we
|
||||
thought to model.
|
||||
|
||||
Run manually:
|
||||
scripts/run_tests.sh tests/tools/test_browser_supervisor.py
|
||||
|
||||
Automated: skipped in CI unless ``HERMES_E2E_BROWSER=1`` is set.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not shutil.which("google-chrome") and not shutil.which("chromium"),
|
||||
reason="Chrome/Chromium not installed",
|
||||
)
|
||||
|
||||
|
||||
def _find_chrome() -> str:
|
||||
for candidate in ("google-chrome", "chromium", "chromium-browser"):
|
||||
path = shutil.which(candidate)
|
||||
if path:
|
||||
return path
|
||||
pytest.skip("no Chrome binary found")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chrome_cdp(worker_id):
|
||||
"""Start a headless Chrome with --remote-debugging-port, yield its WS URL.
|
||||
|
||||
Uses a unique port per xdist worker to avoid cross-worker collisions.
|
||||
Always launches with ``--site-per-process`` so cross-origin iframes
|
||||
become real OOPIFs (needed by the iframe interaction tests).
|
||||
"""
|
||||
import socket
|
||||
|
||||
# xdist worker_id is "master" in single-process mode or "gw0".."gwN" otherwise.
|
||||
if worker_id == "master":
|
||||
port_offset = 0
|
||||
else:
|
||||
port_offset = int(worker_id.lstrip("gw"))
|
||||
port = 9225 + port_offset
|
||||
profile = tempfile.mkdtemp(prefix="hermes-supervisor-test-")
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
_find_chrome(),
|
||||
f"--remote-debugging-port={port}",
|
||||
f"--user-data-dir={profile}",
|
||||
"--no-first-run",
|
||||
"--no-default-browser-check",
|
||||
"--headless=new",
|
||||
"--disable-gpu",
|
||||
"--site-per-process", # force OOPIFs for cross-origin iframes
|
||||
],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
ws_url = None
|
||||
deadline = time.monotonic() + 15
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
import urllib.request
|
||||
with urllib.request.urlopen(
|
||||
f"http://127.0.0.1:{port}/json/version", timeout=1
|
||||
) as r:
|
||||
info = json.loads(r.read().decode())
|
||||
ws_url = info["webSocketDebuggerUrl"]
|
||||
break
|
||||
except Exception:
|
||||
time.sleep(0.25)
|
||||
if ws_url is None:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=5)
|
||||
shutil.rmtree(profile, ignore_errors=True)
|
||||
pytest.skip("Chrome didn't expose CDP in time")
|
||||
|
||||
yield ws_url, port
|
||||
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except Exception:
|
||||
proc.kill()
|
||||
shutil.rmtree(profile, ignore_errors=True)
|
||||
|
||||
|
||||
def _test_page_url() -> str:
|
||||
html = """<!doctype html>
|
||||
<html><head><title>Supervisor pytest</title></head><body>
|
||||
<h1>Supervisor pytest</h1>
|
||||
<iframe id="inner" srcdoc="<body><h2>frame-marker</h2></body>" width="400" height="100"></iframe>
|
||||
</body></html>"""
|
||||
return "data:text/html;base64," + base64.b64encode(html.encode()).decode()
|
||||
|
||||
|
||||
def _fire_on_page(cdp_url: str, expression: str) -> None:
|
||||
"""Navigate the first page target to a data URL and fire `expression`."""
|
||||
import asyncio
|
||||
import websockets as _ws_mod
|
||||
|
||||
async def run():
|
||||
async with _ws_mod.connect(cdp_url, max_size=50 * 1024 * 1024) as ws:
|
||||
next_id = [1]
|
||||
|
||||
async def call(method, params=None, session_id=None):
|
||||
cid = next_id[0]
|
||||
next_id[0] += 1
|
||||
p = {"id": cid, "method": method}
|
||||
if params:
|
||||
p["params"] = params
|
||||
if session_id:
|
||||
p["sessionId"] = session_id
|
||||
await ws.send(json.dumps(p))
|
||||
async for raw in ws:
|
||||
m = json.loads(raw)
|
||||
if m.get("id") == cid:
|
||||
return m
|
||||
|
||||
targets = (await call("Target.getTargets"))["result"]["targetInfos"]
|
||||
page = next(t for t in targets if t.get("type") == "page")
|
||||
attach = await call(
|
||||
"Target.attachToTarget", {"targetId": page["targetId"], "flatten": True}
|
||||
)
|
||||
sid = attach["result"]["sessionId"]
|
||||
await call("Page.navigate", {"url": _test_page_url()}, session_id=sid)
|
||||
await asyncio.sleep(1.5) # let the page load
|
||||
await call(
|
||||
"Runtime.evaluate",
|
||||
{"expression": expression, "returnByValue": True},
|
||||
session_id=sid,
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def supervisor_registry():
|
||||
"""Yield the global registry and tear down any supervisors after the test."""
|
||||
from tools.browser_supervisor import SUPERVISOR_REGISTRY
|
||||
|
||||
yield SUPERVISOR_REGISTRY
|
||||
SUPERVISOR_REGISTRY.stop_all()
|
||||
|
||||
|
||||
def _wait_for_dialog(supervisor, timeout: float = 5.0):
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
snap = supervisor.snapshot()
|
||||
if snap.pending_dialogs:
|
||||
return snap.pending_dialogs
|
||||
time.sleep(0.1)
|
||||
return ()
|
||||
|
||||
|
||||
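# The tests below drive a small supervisor surface. Roughly, and based only on
# the calls made in this file (not the full API of tools.browser_supervisor):
#
#     sv = SUPERVISOR_REGISTRY.get_or_start(task_id=..., cdp_url=ws_url[, dialog_policy=...])
#     snap = sv.snapshot()          # .active, .task_id, .frame_tree,
#                                   # .pending_dialogs, .recent_dialogs
#     sv.respond_to_dialog("accept" or "dismiss", prompt_text=...)
#     SUPERVISOR_REGISTRY.stop(task_id); SUPERVISOR_REGISTRY.stop_all()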
def test_supervisor_start_and_snapshot(chrome_cdp, supervisor_registry):
    """Supervisor attaches, exposes an active snapshot with a top frame."""
    cdp_url, _port = chrome_cdp
    supervisor = supervisor_registry.get_or_start(task_id="pytest-1", cdp_url=cdp_url)

    # Navigate so the frame tree populates.
    _fire_on_page(cdp_url, "/* no dialog */ void 0")

    # Give a moment for frame events to propagate
    time.sleep(1.0)
    snap = supervisor.snapshot()
    assert snap.active is True
    assert snap.task_id == "pytest-1"
    assert snap.pending_dialogs == ()
    # At minimum a top frame should exist after the navigate.
    assert snap.frame_tree.get("top") is not None


def test_main_frame_alert_detection_and_dismiss(chrome_cdp, supervisor_registry):
    """alert() in the main frame surfaces and can be dismissed via the sync API."""
    cdp_url, _port = chrome_cdp
    supervisor = supervisor_registry.get_or_start(task_id="pytest-2", cdp_url=cdp_url)

    _fire_on_page(cdp_url, "setTimeout(() => alert('PYTEST-MAIN-ALERT'), 50)")
    dialogs = _wait_for_dialog(supervisor)
    assert dialogs, "no dialog detected"
    d = dialogs[0]
    assert d.type == "alert"
    assert "PYTEST-MAIN-ALERT" in d.message

    result = supervisor.respond_to_dialog("dismiss")
    assert result["ok"] is True
    # State cleared after dismiss
    time.sleep(0.3)
    assert supervisor.snapshot().pending_dialogs == ()


def test_iframe_contentwindow_alert(chrome_cdp, supervisor_registry):
    """alert() fired from inside a same-origin iframe surfaces too."""
    cdp_url, _port = chrome_cdp
    supervisor = supervisor_registry.get_or_start(task_id="pytest-3", cdp_url=cdp_url)

    _fire_on_page(
        cdp_url,
        "setTimeout(() => document.querySelector('#inner').contentWindow.alert('PYTEST-IFRAME'), 50)",
    )
    dialogs = _wait_for_dialog(supervisor)
    assert dialogs, "no iframe dialog detected"
    assert any("PYTEST-IFRAME" in d.message for d in dialogs)

    result = supervisor.respond_to_dialog("accept")
    assert result["ok"] is True


def test_prompt_dialog_with_response_text(chrome_cdp, supervisor_registry):
    """prompt() gets our prompt_text back inside the page."""
    cdp_url, _port = chrome_cdp
    supervisor = supervisor_registry.get_or_start(task_id="pytest-4", cdp_url=cdp_url)

    # Fire a prompt and stash the answer on window
    _fire_on_page(
        cdp_url,
        "setTimeout(() => { window.__promptResult = prompt('give me a token', 'default-x'); }, 50)",
    )
    dialogs = _wait_for_dialog(supervisor)
    assert dialogs
    d = dialogs[0]
    assert d.type == "prompt"
    assert d.default_prompt == "default-x"

    result = supervisor.respond_to_dialog("accept", prompt_text="PYTEST-PROMPT-REPLY")
    assert result["ok"] is True


def test_respond_with_no_pending_dialog_errors_cleanly(chrome_cdp, supervisor_registry):
    """Calling respond_to_dialog when nothing is pending returns a clean error, not an exception."""
    cdp_url, _port = chrome_cdp
    supervisor = supervisor_registry.get_or_start(task_id="pytest-5", cdp_url=cdp_url)

    result = supervisor.respond_to_dialog("accept")
    assert result["ok"] is False
    assert "no dialog" in result["error"].lower()


def test_auto_dismiss_policy(chrome_cdp, supervisor_registry):
    """auto_dismiss policy clears dialogs without the agent responding."""
    from tools.browser_supervisor import DIALOG_POLICY_AUTO_DISMISS

    cdp_url, _port = chrome_cdp
    supervisor = supervisor_registry.get_or_start(
        task_id="pytest-6",
        cdp_url=cdp_url,
        dialog_policy=DIALOG_POLICY_AUTO_DISMISS,
    )

    _fire_on_page(cdp_url, "setTimeout(() => alert('PYTEST-AUTO-DISMISS'), 50)")
    # Give the supervisor a moment to see + auto-dismiss
    time.sleep(2.0)
    snap = supervisor.snapshot()
    # Nothing pending because auto-dismiss cleared it immediately
    assert snap.pending_dialogs == ()


def test_registry_idempotent_get_or_start(chrome_cdp, supervisor_registry):
    """Calling get_or_start twice with the same (task, url) returns the same instance."""
    cdp_url, _port = chrome_cdp
    a = supervisor_registry.get_or_start(task_id="pytest-idem", cdp_url=cdp_url)
    b = supervisor_registry.get_or_start(task_id="pytest-idem", cdp_url=cdp_url)
    assert a is b


def test_registry_stop(chrome_cdp, supervisor_registry):
    """stop() tears down the supervisor and snapshot reports inactive."""
    cdp_url, _port = chrome_cdp
    supervisor = supervisor_registry.get_or_start(task_id="pytest-stop", cdp_url=cdp_url)
    assert supervisor.snapshot().active is True
    supervisor_registry.stop("pytest-stop")
    # Post-stop snapshot reports inactive; supervisor obj may still exist
    assert supervisor.snapshot().active is False


def test_browser_dialog_tool_no_supervisor():
    """browser_dialog returns a clear error when no supervisor is attached."""
    from tools.browser_dialog_tool import browser_dialog

    r = json.loads(browser_dialog(action="accept", task_id="nonexistent-task"))
    assert r["success"] is False
    assert "No CDP supervisor" in r["error"]


def test_browser_dialog_invalid_action(chrome_cdp, supervisor_registry):
    """browser_dialog rejects actions that aren't accept/dismiss."""
    from tools.browser_dialog_tool import browser_dialog

    cdp_url, _port = chrome_cdp
    supervisor_registry.get_or_start(task_id="pytest-bad-action", cdp_url=cdp_url)

    r = json.loads(browser_dialog(action="eat", task_id="pytest-bad-action"))
    assert r["success"] is False
    assert "accept" in r["error"] and "dismiss" in r["error"]


def test_recent_dialogs_ring_buffer(chrome_cdp, supervisor_registry):
    """Closed dialogs show up in recent_dialogs with a closed_by tag."""
    from tools.browser_supervisor import DIALOG_POLICY_AUTO_DISMISS

    cdp_url, _port = chrome_cdp
    sv = supervisor_registry.get_or_start(
        task_id="pytest-recent",
        cdp_url=cdp_url,
        dialog_policy=DIALOG_POLICY_AUTO_DISMISS,
    )

    _fire_on_page(cdp_url, "setTimeout(() => alert('PYTEST-RECENT'), 50)")
    # Wait for auto-dismiss to cycle the dialog through
    deadline = time.time() + 5
    while time.time() < deadline:
        recent = sv.snapshot().recent_dialogs
        if recent and any("PYTEST-RECENT" in r.message for r in recent):
            break
        time.sleep(0.1)

    recent = sv.snapshot().recent_dialogs
    assert recent, "recent_dialogs should contain the auto-dismissed dialog"
    match = next((r for r in recent if "PYTEST-RECENT" in r.message), None)
    assert match is not None
    assert match.type == "alert"
    assert match.closed_by == "auto_policy"
    assert match.closed_at >= match.opened_at


def test_browser_dialog_tool_end_to_end(chrome_cdp, supervisor_registry):
    """Full agent-path check: fire an alert, call the tool handler directly."""
    from tools.browser_dialog_tool import browser_dialog

    cdp_url, _port = chrome_cdp
    supervisor = supervisor_registry.get_or_start(task_id="pytest-tool", cdp_url=cdp_url)

    _fire_on_page(cdp_url, "setTimeout(() => alert('PYTEST-TOOL-END2END'), 50)")
    assert _wait_for_dialog(supervisor), "no dialog detected via wait_for_dialog"

    r = json.loads(browser_dialog(action="dismiss", task_id="pytest-tool"))
    assert r["success"] is True
    assert r["action"] == "dismiss"
    assert "PYTEST-TOOL-END2END" in r["dialog"]["message"]
def test_browser_cdp_frame_id_routes_via_supervisor(chrome_cdp, supervisor_registry, monkeypatch):
    """browser_cdp(frame_id=...) routes Runtime.evaluate through supervisor.

    Mocks the supervisor with a known frame and verifies browser_cdp sends
    the call via the supervisor's loop rather than opening a stateless
    WebSocket. This is the path that makes cross-origin iframe eval work
    on Browserbase.
    """
    cdp_url, _port = chrome_cdp
    sv = supervisor_registry.get_or_start(task_id="frame-id-test", cdp_url=cdp_url)
    assert sv.snapshot().active

    # Inject a fake OOPIF frame pointing at the SUPERVISOR's own page session
    # so we can verify routing. We fake is_oopif=True so the code path
    # treats it as an OOPIF child.
    import tools.browser_supervisor as _bs
    with sv._state_lock:
        fake_frame_id = "FAKE-FRAME-001"
        sv._frames[fake_frame_id] = _bs.FrameInfo(
            frame_id=fake_frame_id,
            url="fake://",
            origin="",
            parent_frame_id=None,
            is_oopif=True,
            cdp_session_id=sv._page_session_id,  # route at page scope
        )

    # Route the tool through the supervisor. Should succeed and return
    # something that clearly came from CDP.
    from tools.browser_cdp_tool import browser_cdp
    result = browser_cdp(
        method="Runtime.evaluate",
        params={"expression": "1 + 1", "returnByValue": True},
        frame_id=fake_frame_id,
        task_id="frame-id-test",
    )
    r = json.loads(result)
    assert r.get("success") is True, f"expected success, got: {r}"
    assert r.get("frame_id") == fake_frame_id
    assert r.get("session_id") == sv._page_session_id
    value = r.get("result", {}).get("result", {}).get("value")
    assert value == 2, f"expected 2, got {value!r}"


def test_browser_cdp_frame_id_real_oopif_smoke_documented():
    """Document that real-OOPIF E2E was manually verified — see PR #14540.

    A pytest version of this hits an asyncio version-quirk in the venv
    (3.11) that doesn't show up in standalone scripts (3.13 + system
    websockets). The mechanism IS verified end-to-end by two separate
    smoke scripts in /tmp/dialog-iframe-test/:

    * smoke_local_oopif.py — local Chrome + 2 http servers on
      different hostnames + --site-per-process. Outer page on
      localhost:18905, iframe src=http://127.0.0.1:18906. Calls
      browser_cdp(method='Runtime.evaluate', frame_id=<OOPIF>) and
      verifies inner page's title comes back from the OOPIF session.
      PASSED on 2026-04-23: iframe document.title = 'INNER-FRAME-XYZ'

    * smoke_bb_iframe_agent_path.py — Browserbase + real cross-origin
      iframe (src=https://example.com/). Same browser_cdp(frame_id=)
      path. PASSED on 2026-04-23: iframe document.title =
      'Example Domain'

    The test_browser_cdp_frame_id_routes_via_supervisor pytest covers
    the supervisor-routing plumbing with a fake injected OOPIF.
    """
    pytest.skip(
        "Real-OOPIF E2E verified manually with smoke_local_oopif.py and "
        "smoke_bb_iframe_agent_path.py — pytest version hits an asyncio "
        "version quirk between venv (3.11) and standalone (3.13). "
        "Smoke logs preserved in /tmp/dialog-iframe-test/."
    )
def test_browser_cdp_frame_id_missing_supervisor():
    """browser_cdp(frame_id=...) errors cleanly when no supervisor is attached."""
    from tools.browser_cdp_tool import browser_cdp
    result = browser_cdp(
        method="Runtime.evaluate",
        params={"expression": "1"},
        frame_id="any-frame-id",
        task_id="no-such-task",
    )
    r = json.loads(result)
    assert r.get("success") is not True
    assert "supervisor" in (r.get("error") or "").lower()


def test_browser_cdp_frame_id_not_in_frame_tree(chrome_cdp, supervisor_registry):
    """browser_cdp(frame_id=...) errors when the frame_id isn't known."""
    cdp_url, _port = chrome_cdp
    sv = supervisor_registry.get_or_start(task_id="bad-frame-test", cdp_url=cdp_url)
    assert sv.snapshot().active

    from tools.browser_cdp_tool import browser_cdp
    result = browser_cdp(
        method="Runtime.evaluate",
        params={"expression": "1"},
        frame_id="nonexistent-frame",
        task_id="bad-frame-test",
    )
    r = json.loads(result)
    assert r.get("success") is not True
    assert "not found" in (r.get("error") or "").lower()


def test_bridge_captures_prompt_and_returns_reply_text(chrome_cdp, supervisor_registry):
    """End-to-end: agent's prompt_text round-trips INTO the page's JS.

    Proves the bridge isn't just catching dialogs — it's properly round-
    tripping our reply back into the page via Fetch.fulfillRequest, so
    ``prompt()`` actually returns the agent-supplied string to the page.
    """
    import base64 as _b64

    cdp_url, _port = chrome_cdp
    sv = supervisor_registry.get_or_start(task_id="pytest-bridge-prompt", cdp_url=cdp_url)

    # Page fires prompt and stashes the return value on window.
    html = """<!doctype html><html><body><script>
window.__ret = null;
setTimeout(() => { window.__ret = prompt('PROMPT-MSG', 'default'); }, 50);
</script></body></html>"""
    url = "data:text/html;base64," + _b64.b64encode(html.encode()).decode()

    import asyncio as _asyncio
    import websockets as _ws_mod

    async def nav_and_read():
        async with _ws_mod.connect(cdp_url, max_size=50 * 1024 * 1024) as ws:
            nid = [1]
            pending: dict = {}

            async def reader_fn():
                try:
                    async for raw in ws:
                        m = json.loads(raw)
                        if "id" in m:
                            fut = pending.pop(m["id"], None)
                            if fut and not fut.done():
                                fut.set_result(m)
                except Exception:
                    pass

            rd = _asyncio.create_task(reader_fn())

            async def call(method, params=None, sid=None):
                c = nid[0]; nid[0] += 1
                p = {"id": c, "method": method}
                if params: p["params"] = params
                if sid: p["sessionId"] = sid
                fut = _asyncio.get_event_loop().create_future()
                pending[c] = fut
                await ws.send(json.dumps(p))
                return await _asyncio.wait_for(fut, timeout=20)

            try:
                t = (await call("Target.getTargets"))["result"]["targetInfos"]
                pg = next(x for x in t if x.get("type") == "page")
                a = await call("Target.attachToTarget", {"targetId": pg["targetId"], "flatten": True})
                sid = a["result"]["sessionId"]

                # Fire navigate but don't await — prompt() blocks the page
                nav_id = nid[0]; nid[0] += 1
                nav_fut = _asyncio.get_event_loop().create_future()
                pending[nav_id] = nav_fut
                await ws.send(json.dumps({"id": nav_id, "method": "Page.navigate", "params": {"url": url}, "sessionId": sid}))

                # Wait for supervisor to see the prompt
                deadline = time.monotonic() + 10
                dialog = None
                while time.monotonic() < deadline:
                    snap = sv.snapshot()
                    if snap.pending_dialogs:
                        dialog = snap.pending_dialogs[0]
                        break
                    await _asyncio.sleep(0.05)
                assert dialog is not None, "no dialog captured"
                assert dialog.bridge_request_id is not None, "expected bridge path"
                assert dialog.type == "prompt"

                # Agent responds
                resp = sv.respond_to_dialog("accept", prompt_text="AGENT-SUPPLIED-REPLY")
                assert resp["ok"] is True

                # Wait for nav to complete + read back
                try:
                    await _asyncio.wait_for(nav_fut, timeout=10)
                except Exception:
                    pass
                await _asyncio.sleep(0.5)
                r = await call(
                    "Runtime.evaluate",
                    {"expression": "window.__ret", "returnByValue": True},
                    sid=sid,
                )
                return r.get("result", {}).get("result", {}).get("value")
            finally:
                rd.cancel()
                try: await rd
                except BaseException: pass

    value = asyncio.run(nav_and_read())
    assert value == "AGENT-SUPPLIED-REPLY", f"expected AGENT-SUPPLIED-REPLY, got {value!r}"
@@ -357,12 +357,33 @@ class TestWorkingDirResolution:
        result = mgr.get_working_dir_for_path(str(subdir / "file.py"))
        assert result == str(project)

-    def test_falls_back_to_parent(self, tmp_path):
+    def test_falls_back_to_parent(self, tmp_path, monkeypatch):
        mgr = CheckpointManager(enabled=True)
        filepath = tmp_path / "random" / "file.py"
        filepath.parent.mkdir(parents=True)
        filepath.write_text("x\\n")

        # The walk-up scan for project markers (.git, pyproject.toml, etc.)
        # stops at tmp_path — otherwise stray markers in ``/tmp`` (e.g.
        # ``/tmp/pyproject.toml`` left by other tools on the host) get
        # picked up as the project root and this test flakes on shared CI.
        import pathlib as _pl
        _real_exists = _pl.Path.exists

        def _guarded_exists(self):
            s = str(self)
            stop = str(tmp_path)
            if not s.startswith(stop) and any(
                s.endswith("/" + m) or s == "/" + m
                for m in (".git", "pyproject.toml", "package.json",
                          "Cargo.toml", "go.mod", "Makefile", "pom.xml",
                          ".hg", "Gemfile")
            ):
                return False
            return _real_exists(self)

        monkeypatch.setattr(_pl.Path, "exists", _guarded_exists)

        result = mgr.get_working_dir_for_path(str(filepath))
        assert result == str(filepath.parent)
@@ -69,7 +69,10 @@ class TestDelegateRequirements(unittest.TestCase):
        self.assertIn("tasks", props)
        self.assertIn("context", props)
        self.assertIn("toolsets", props)
-        self.assertIn("max_iterations", props)
+        # max_iterations is intentionally NOT exposed to the model — it's
+        # config-authoritative via delegation.max_iterations so users get
+        # predictable budgets.
+        self.assertNotIn("max_iterations", props)
        self.assertNotIn("maxItems", props["tasks"])  # removed — limit is now runtime-configurable
@@ -1316,6 +1319,112 @@ class TestDelegateHeartbeat(unittest.TestCase):
            any("API call #5 completed" in desc for desc in touch_calls),
            f"Heartbeat should include last_activity_desc: {touch_calls}")

    def test_heartbeat_does_not_trip_idle_stale_while_inside_tool(self):
        """A long-running tool (no iteration advance, but current_tool set)
        must not be flagged stale at the idle threshold.

        Bug #13041: when a child is legitimately busy inside a slow tool
        (terminal command, browser fetch), api_call_count does not advance.
        The previous stale check treated this as idle and stopped the
        heartbeat after 5 cycles (~150s), letting the gateway kill the
        session. The fix uses a much higher in-tool threshold and only
        applies the tight idle threshold when current_tool is None.
        """
        from tools.delegate_tool import _run_single_child

        parent = _make_mock_parent()
        touch_calls = []
        parent._touch_activity = lambda desc: touch_calls.append(desc)

        child = MagicMock()
        # Child is stuck inside a single terminal call for the whole run.
        # api_call_count never advances, current_tool is always set.
        child.get_activity_summary.return_value = {
            "current_tool": "terminal",
            "api_call_count": 1,
            "max_iterations": 50,
            "last_activity_desc": "executing tool: terminal",
        }

        def slow_run(**kwargs):
            # Long enough to exceed the OLD idle threshold (5 cycles) at
            # the patched interval, but shorter than the new in-tool
            # threshold.
            time.sleep(0.4)
            return {"final_response": "done", "completed": True, "api_calls": 1}

        child.run_conversation.side_effect = slow_run

        # Patch both the interval AND the idle ceiling so the test proves
        # the in-tool branch takes effect: with a 0.05s interval and the
        # default _HEARTBEAT_STALE_CYCLES_IDLE=5, the old behavior would
        # trip after 0.25s and stop firing. We should see heartbeats
        # continuing through the full 0.4s run.
        with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
            _run_single_child(
                task_index=0,
                goal="Test long-running tool",
                child=child,
                parent_agent=parent,
            )

        # With the old idle threshold (5 cycles = 0.25s), touch_calls
        # would cap at ~5. With the in-tool threshold (20 cycles = 1.0s),
        # we should see substantially more heartbeats over 0.4s.
        self.assertGreater(
            len(touch_calls), 6,
            f"Heartbeat stopped too early while child was inside a tool; "
            f"got {len(touch_calls)} touches over 0.4s at 0.05s interval",
        )
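    # Illustrative sketch (not part of this diff): the branch this test pins.
    # _HEARTBEAT_STALE_CYCLES_IDLE is named in the comments above; the in-tool
    # constant name and the exact loop shape are assumptions about
    # tools.delegate_tool's implementation:
    #
    #     limit = (_HEARTBEAT_STALE_CYCLES_IN_TOOL        # generous, e.g. 20 cycles
    #              if summary.get("current_tool")
    #              else _HEARTBEAT_STALE_CYCLES_IDLE)      # tight, default 5 cycles
    #     if cycles_without_progress >= limit:
    #         break                                        # stop heartbeating; gateway timeout applies
    #     parent_agent._touch_activity(summary.get("last_activity_desc", ""))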
    def test_heartbeat_still_trips_idle_stale_when_no_tool(self):
        """A wedged child with no current_tool still trips the idle threshold.

        Regression guard: the fix for #13041 must not disable stale
        detection entirely. A child that's hung between turns (no tool
        running, no iteration progress) must still stop touching the
        parent so the gateway timeout can fire.
        """
        from tools.delegate_tool import _run_single_child

        parent = _make_mock_parent()
        touch_calls = []
        parent._touch_activity = lambda desc: touch_calls.append(desc)

        child = MagicMock()
        # Wedged child: no tool running, iteration frozen.
        child.get_activity_summary.return_value = {
            "current_tool": None,
            "api_call_count": 3,
            "max_iterations": 50,
            "last_activity_desc": "waiting for API response",
        }

        def slow_run(**kwargs):
            time.sleep(0.6)
            return {"final_response": "done", "completed": True, "api_calls": 3}

        child.run_conversation.side_effect = slow_run

        # At interval 0.05s, idle threshold (5 cycles) trips at ~0.25s.
        # We should see the heartbeat stop firing well before 0.6s.
        with patch("tools.delegate_tool._HEARTBEAT_INTERVAL", 0.05):
            _run_single_child(
                task_index=0,
                goal="Test wedged child",
                child=child,
                parent_agent=parent,
            )

        # With idle threshold=5 + interval=0.05s, touches should cap
        # around 5. Bound loosely to avoid timing flakes.
        self.assertLess(
            len(touch_calls), 9,
            f"Idle stale detection did not fire: got {len(touch_calls)} "
            f"touches over 0.6s — expected heartbeat to stop after "
            f"~5 stale cycles",
        )


class TestDelegationReasoningEffort(unittest.TestCase):
    """Tests for delegation.reasoning_effort config override."""
tests/tools/test_delegate_subagent_timeout_diagnostic.py (new file, 286 lines)

@@ -0,0 +1,286 @@
"""Regression tests for subagent timeout diagnostic dump (issue #14726).

When delegate_task's child subagent times out without having made any API
call, a structured diagnostic file is written under
``~/.hermes/logs/subagent-timeout-<sid>-<ts>.log``. This gives users a
concrete artifact to inspect (worker thread stack, system prompt size,
tool schema bytes, credential pool state, etc.) instead of the previous
opaque "subagent timed out" error.

These tests pin:
- the diagnostic writer's output format and content
- the timeout branch in _run_single_child only dumps when api_calls == 0
- the error message surfaces the diagnostic path
- api_calls > 0 timeouts do NOT write a dump (the old "stuck on slow API
  call" explanation still applies)
"""
from __future__ import annotations

import os
import threading
import time
from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def hermes_home(tmp_path, monkeypatch):
    home = tmp_path / ".hermes"
    home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(home))
    return home


class _StubChild:
    """Minimal stand-in for an AIAgent subagent."""
    def __init__(
        self,
        *,
        api_call_count: int = 0,
        hang_seconds: float = 5.0,
        subagent_id: str = "sa-0-stubabc",
        tool_schema=None,
    ):
        self._subagent_id = subagent_id
        self._delegate_depth = 1
        self._delegate_role = "leaf"
        self.model = "test/model"
        self.provider = "testprov"
        self.api_mode = "chat_completions"
        self.base_url = "https://example.test/v1"
        self.max_iterations = 30
        self.quiet_mode = True
        self.skip_memory = True
        self.skip_context_files = True
        self.platform = "cli"
        self.ephemeral_system_prompt = "sys prompt"
        self.enabled_toolsets = ["web", "terminal"]
        self.valid_tool_names = {"web_search", "terminal"}
        self.tools = tool_schema if tool_schema is not None else [
            {"name": "web_search", "description": "search"},
            {"name": "terminal", "description": "shell"},
        ]
        self._api_call_count = api_call_count
        self._hang = threading.Event()
        self._hang_seconds = hang_seconds

    def get_activity_summary(self):
        return {
            "api_call_count": self._api_call_count,
            "max_iterations": self.max_iterations,
            "current_tool": None,
            "seconds_since_activity": 60,
        }

    def run_conversation(self, user_message, task_id=None):
        self._hang.wait(self._hang_seconds)
        return {"final_response": "", "completed": False, "api_calls": self._api_call_count}

    def interrupt(self):
        self._hang.set()
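# Illustrative sketch (not part of this diff): reconstructed from the
# assertions below, the diagnostic writer roughly (a) resolves HERMES_HOME/logs,
# (b) returns None instead of raising when the directory can't be created, and
# (c) writes one section per heading asserted in the tests. Everything beyond
# those pinned strings is an assumption about tools.delegate_tool:
#
#     def _dump_subagent_timeout_diagnostic(*, child, task_index, timeout_seconds,
#                                           duration_seconds, worker_thread, goal):
#         logs = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))) / "logs"
#         try:
#             logs.mkdir(parents=True, exist_ok=True)
#         except OSError:
#             return None                       # diagnostics must never raise
#         path = logs / f"subagent-timeout-{child._subagent_id}-{int(time.time())}.log"
#         path.write_text(...)                  # header, timeout facts, goal (truncated to
#                                               # 1000 chars), child config, toolsets,
#                                               # prompt/schema sizes, activity summary,
#                                               # worker thread stack
#         return str(path)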
# ── _dump_subagent_timeout_diagnostic ──────────────────────────────────

class TestDumpSubagentTimeoutDiagnostic:

    def test_writes_log_with_expected_sections(self, hermes_home):
        from tools.delegate_tool import _dump_subagent_timeout_diagnostic
        child = _StubChild(subagent_id="sa-7-abc123")

        worker = threading.Thread(
            target=lambda: child.run_conversation("test"),
            daemon=True,
        )
        worker.start()
        time.sleep(0.1)
        try:
            path = _dump_subagent_timeout_diagnostic(
                child=child,
                task_index=7,
                timeout_seconds=300.0,
                duration_seconds=300.01,
                worker_thread=worker,
                goal="Research something long",
            )
        finally:
            child.interrupt()
            worker.join(timeout=2.0)

        assert path is not None
        p = Path(path)
        assert p.is_file()
        # File lives under HERMES_HOME/logs/
        assert p.parent == hermes_home / "logs"
        assert p.name.startswith("subagent-timeout-sa-7-abc123-")
        assert p.suffix == ".log"

        content = p.read_text()
        # Header references the issue for future grep-ability
        assert "issue #14726" in content
        # Timeout facts
        assert "task_index: 7" in content
        assert "subagent_id: sa-7-abc123" in content
        assert "configured_timeout: 300.0s" in content
        assert "actual_duration: 300.01s" in content
        # Goal
        assert "Research something long" in content
        # Child config
        assert "model: 'test/model'" in content
        assert "provider: 'testprov'" in content
        assert "base_url: 'https://example.test/v1'" in content
        assert "max_iterations: 30" in content
        # Toolsets
        assert "enabled_toolsets: ['web', 'terminal']" in content
        assert "loaded tool count: 2" in content
        # Prompt / schema sizes
        assert "system_prompt_bytes:" in content
        assert "tool_schema_count: 2" in content
        assert "tool_schema_bytes:" in content
        # Activity summary
        assert "api_call_count: 0" in content
        # Worker stack
        assert "Worker thread stack at timeout" in content
        # The thread is parked inside _hang.wait → cond.wait → waiter.acquire
        assert "acquire" in content or "wait" in content

    def test_truncates_very_long_goal(self, hermes_home):
        from tools.delegate_tool import _dump_subagent_timeout_diagnostic
        child = _StubChild()
        huge_goal = "x" * 5000

        path = _dump_subagent_timeout_diagnostic(
            child=child,
            task_index=0,
            timeout_seconds=300.0,
            duration_seconds=300.0,
            worker_thread=None,
            goal=huge_goal,
        )
        child.interrupt()

        content = Path(path).read_text()
        assert "[truncated]" in content
        # Goal section trimmed to 1000 chars + suffix
        goal_block = content.split("## Goal", 1)[1].split("## Child config", 1)[0]
        assert len(goal_block) < 1200

    def test_missing_worker_thread_is_handled(self, hermes_home):
        from tools.delegate_tool import _dump_subagent_timeout_diagnostic
        child = _StubChild()
        path = _dump_subagent_timeout_diagnostic(
            child=child,
            task_index=0,
            timeout_seconds=300.0,
            duration_seconds=300.0,
            worker_thread=None,
            goal="x",
        )
        child.interrupt()
        content = Path(path).read_text()
        assert "<no worker thread handle>" in content

    def test_exited_worker_thread_is_handled(self, hermes_home):
        from tools.delegate_tool import _dump_subagent_timeout_diagnostic
        child = _StubChild()
        # A thread that has already finished
        t = threading.Thread(target=lambda: None)
        t.start()
        t.join()
        assert not t.is_alive()
        path = _dump_subagent_timeout_diagnostic(
            child=child,
            task_index=0,
            timeout_seconds=300.0,
            duration_seconds=300.0,
            worker_thread=t,
            goal="x",
        )
        child.interrupt()
        content = Path(path).read_text()
        assert "<worker thread already exited>" in content

    def test_returns_none_on_unwritable_logs_dir(self, tmp_path, monkeypatch):
        # Point HERMES_HOME at an unwritable path so logs/ can't be created
        # (simulates permission-denied). Helper must not raise.
        from tools.delegate_tool import _dump_subagent_timeout_diagnostic
        bogus = tmp_path / "does-not-exist" / ".hermes"
        monkeypatch.setenv("HERMES_HOME", str(bogus))
        child = _StubChild()

        # Make the logs dir itself unwritable by creating it as a FILE
        # so mkdir(exist_ok=True) → NotADirectoryError and we fall through.
        bogus.parent.mkdir(parents=True, exist_ok=True)
        bogus.mkdir()
        (bogus / "logs").write_text("not a dir")
        result = _dump_subagent_timeout_diagnostic(
            child=child,
            task_index=0,
            timeout_seconds=300.0,
            duration_seconds=300.0,
            worker_thread=None,
            goal="x",
        )
        child.interrupt()
        # Either None (mkdir failed) or a real path; must never raise.
        # We assert no exception propagates — the return value is advisory.
        assert result is None or Path(result).exists()


# ── _run_single_child timeout branch wiring ───────────────────────────

class TestRunSingleChildTimeoutDump:
    """The timeout branch in _run_single_child must emit the diagnostic
    dump when api_calls == 0, and must NOT emit it when api_calls > 0."""

    def _invoke_with_short_timeout(self, child, monkeypatch):
        """Run _run_single_child with a tiny timeout to force the timeout branch."""
        from tools import delegate_tool
        # Force a 0.3s timeout so the test is fast
        monkeypatch.setattr(delegate_tool, "_get_child_timeout", lambda: 0.3)

        parent = MagicMock()
        parent._touch_activity = MagicMock()
        parent._current_task_id = None
        return delegate_tool._run_single_child(
            task_index=0,
            goal="test goal",
            child=child,
            parent_agent=parent,
        )

    def test_zero_api_calls_writes_dump_and_surfaces_path(self, hermes_home, monkeypatch):
        child = _StubChild(api_call_count=0, hang_seconds=10.0)
        result = self._invoke_with_short_timeout(child, monkeypatch)

        assert result["status"] == "timeout"
        assert result["api_calls"] == 0
        assert result["diagnostic_path"] is not None
        dump_path = Path(result["diagnostic_path"])
        assert dump_path.is_file()
        assert dump_path.parent == hermes_home / "logs"

        # Error message surfaces the path and the "no API call" phrasing
        assert "without making any API call" in result["error"]
        assert "Diagnostic:" in result["error"]
        assert str(dump_path) in result["error"]

    def test_nonzero_api_calls_skips_dump_and_uses_old_message(self, hermes_home, monkeypatch):
        child = _StubChild(api_call_count=5, hang_seconds=10.0)
        result = self._invoke_with_short_timeout(child, monkeypatch)

        assert result["status"] == "timeout"
        assert result["api_calls"] == 5
        # No diagnostic file should be written for timeouts that made
        # actual API calls — the old generic "stuck on slow call" message
        # still applies.
        assert result.get("diagnostic_path") is None
        assert "stuck on a slow API call" in result["error"]
        # And no subagent-timeout-* file should exist under logs/
        logs_dir = hermes_home / "logs"
        if logs_dir.is_dir():
            dumps = list(logs_dir.glob("subagent-timeout-*.log"))
            assert dumps == []
tests/tools/test_dockerfile_pid1_reaping.py (new file, 78 lines)

@@ -0,0 +1,78 @@
"""Contract tests for the container Dockerfile.

These tests assert invariants about how the Dockerfile composes its runtime —
they deliberately avoid snapshotting specific package versions, line numbers,
or exact flag choices. What they DO assert is that the Dockerfile maintains
the properties required for correct production behaviour:

- A PID-1 init (tini) is installed and wraps the entrypoint, so that orphaned
  subprocesses (MCP stdio servers, git, bun, browser daemons) get reaped
  instead of accumulating as zombies (#15012).
- Signal forwarding runs through the init so ``docker stop`` triggers
  hermes's own graceful-shutdown path.
"""

from __future__ import annotations

from pathlib import Path

import pytest


REPO_ROOT = Path(__file__).resolve().parents[2]
DOCKERFILE = REPO_ROOT / "Dockerfile"


@pytest.fixture(scope="module")
def dockerfile_text() -> str:
    if not DOCKERFILE.exists():
        pytest.skip("Dockerfile not present in this checkout")
    return DOCKERFILE.read_text()


def test_dockerfile_installs_an_init_for_zombie_reaping(dockerfile_text):
    """Some init (tini, dumb-init, catatonit) must be installed.

    Without a PID-1 init that handles SIGCHLD, hermes accumulates zombie
    processes from MCP stdio subprocesses, git operations, browser
    daemons, etc. In long-running Docker deployments this eventually
    exhausts the PID table.
    """
    # Accept any of the common reapers. The contract is behavioural:
    # something must be installed that reaps orphans.
    known_inits = ("tini", "dumb-init", "catatonit")
    installed = any(name in dockerfile_text for name in known_inits)
    assert installed, (
        "No PID-1 init detected in Dockerfile (looked for: "
        f"{', '.join(known_inits)}). Without an init process to reap "
        "orphaned subprocesses, hermes accumulates zombies in Docker "
        "deployments. See issue #15012."
    )


def test_dockerfile_entrypoint_routes_through_the_init(dockerfile_text):
    """The ENTRYPOINT must invoke the init, not the entrypoint script directly.

    Installing tini is only half the fix — the container must actually run
    with tini as PID 1. If the ENTRYPOINT executes the shell script
    directly, the shell becomes PID 1 and will ``exec`` into hermes,
    which then runs as PID 1 without any zombie reaping.
    """
    # Find the last uncommented ENTRYPOINT line — Docker honours the final one.
    entrypoint_line = None
    for raw_line in dockerfile_text.splitlines():
        line = raw_line.strip()
        if line.startswith("#"):
            continue
        if line.startswith("ENTRYPOINT"):
            entrypoint_line = line

    assert entrypoint_line is not None, "Dockerfile is missing an ENTRYPOINT directive"

    known_inits = ("tini", "dumb-init", "catatonit")
    routes_through_init = any(name in entrypoint_line for name in known_inits)
    assert routes_through_init, (
        f"ENTRYPOINT does not route through an init: {entrypoint_line!r}. "
        "If tini is only installed but not wired into ENTRYPOINT, hermes "
        "still runs as PID 1 and zombies will accumulate (#15012)."
    )
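# Illustrative only (not part of this diff): a Dockerfile wiring that would
# satisfy both contract tests above could look roughly like
#
#   RUN apt-get update && apt-get install -y --no-install-recommends tini
#   ENTRYPOINT ["/usr/bin/tini", "--", "/entrypoint.sh"]
#
# tini then runs as PID 1, reaps orphaned children, and forwards signals to
# the entrypoint script, so `docker stop` reaches hermes's shutdown path.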
@@ -13,8 +13,10 @@ import os
import tempfile
import time
import unittest
from types import SimpleNamespace
from unittest.mock import patch, MagicMock

from tools import file_state
from tools.file_tools import (
    read_file_tool,
    write_file_tool,
@@ -76,6 +78,7 @@ class TestStalenessCheck(unittest.TestCase):

    def setUp(self):
        _read_tracker.clear()
        file_state.get_registry().clear()
        self._tmpdir = tempfile.mkdtemp()
        self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt")
        with open(self._tmpfile, "w") as f:
@@ -83,6 +86,7 @@ class TestStalenessCheck(unittest.TestCase):

    def tearDown(self):
        _read_tracker.clear()
        file_state.get_registry().clear()
        try:
            os.unlink(self._tmpfile)
            os.rmdir(self._tmpdir)
@@ -145,6 +149,53 @@ class TestStalenessCheck(unittest.TestCase):
        result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b"))
        self.assertNotIn("_warning", result)

    @patch("tools.file_tools._get_file_ops")
    def test_relative_path_uses_live_cwd_for_staleness_tracking(self, mock_ops):
        """Relative-path stale tracking must follow the live terminal cwd."""
        start_dir = os.path.join(self._tmpdir, "start")
        live_dir = os.path.join(self._tmpdir, "worktree")
        os.makedirs(start_dir, exist_ok=True)
        os.makedirs(live_dir, exist_ok=True)

        start_file = os.path.join(start_dir, "shared.txt")
        live_file = os.path.join(live_dir, "shared.txt")
        with open(start_file, "w") as f:
            f.write("start copy\n")
        with open(live_file, "w") as f:
            f.write("live copy\n")

        fake_ops = _make_fake_ops("live copy\n", 10)
        fake_ops.env = SimpleNamespace(cwd=live_dir)
        fake_ops.cwd = start_dir
        mock_ops.return_value = fake_ops

        from tools import file_tools

        with file_tools._file_ops_lock:
            previous = file_tools._file_ops_cache.get("live_task")
            file_tools._file_ops_cache["live_task"] = fake_ops

        try:
            with patch.dict(os.environ, {"TERMINAL_CWD": start_dir}, clear=False):
                read_file_tool("shared.txt", task_id="live_task")

                time.sleep(0.05)
                with open(live_file, "w") as f:
                    f.write("live copy modified elsewhere\n")

                result = json.loads(
                    write_file_tool("shared.txt", "replacement", task_id="live_task")
                )
        finally:
            with file_tools._file_ops_lock:
                if previous is None:
                    file_tools._file_ops_cache.pop("live_task", None)
                else:
                    file_tools._file_ops_cache["live_task"] = previous

        self.assertIn("_warning", result)
        self.assertIn("modified since you last read", result["_warning"])


# ---------------------------------------------------------------------------
# Staleness in patch
@@ -154,6 +205,7 @@ class TestPatchStaleness(unittest.TestCase):

    def setUp(self):
        _read_tracker.clear()
        file_state.get_registry().clear()
        self._tmpdir = tempfile.mkdtemp()
        self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt")
        with open(self._tmpfile, "w") as f:
@@ -161,6 +213,7 @@ class TestPatchStaleness(unittest.TestCase):

    def tearDown(self):
        _read_tracker.clear()
        file_state.get_registry().clear()
        try:
            os.unlink(self._tmpfile)
            os.rmdir(self._tmpdir)
@@ -207,9 +260,11 @@ class TestCheckFileStalenessHelper(unittest.TestCase):

    def setUp(self):
        _read_tracker.clear()
        file_state.get_registry().clear()

    def tearDown(self):
        _read_tracker.clear()
        file_state.get_registry().clear()

    def test_returns_none_for_unknown_task(self):
        self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent"))
@@ -247,7 +247,9 @@ class TestPatchHints:

        from tools.file_tools import patch_tool
        raw = patch_tool(mode="replace", path="foo.py", old_string="x", new_string="y")
-        assert "[Hint:" in raw
+        # patch_tool surfaces the hint as a structured "_hint" field on the
+        # JSON error payload (not an inline "[Hint: ..." tail).
+        assert "_hint" in raw
        assert "read_file" in raw

    @patch("tools.file_tools._get_file_ops")
@@ -260,7 +262,7 @@ class TestPatchHints:

        from tools.file_tools import patch_tool
        raw = patch_tool(mode="replace", path="foo.py", old_string="x", new_string="y")
-        assert "[Hint:" not in raw
+        assert "_hint" not in raw


class TestSearchHints:
@@ -77,7 +77,7 @@ class TestStdioPidTracking:
        from tools.mcp_tool import _stdio_pids, _lock
        with _lock:
            # Might have residual state from other tests, just check type
-            assert isinstance(_stdio_pids, set)
+            assert isinstance(_stdio_pids, dict)

    def test_kill_orphaned_noop_when_empty(self):
        """_kill_orphaned_mcp_children does nothing when no PIDs tracked."""
@@ -96,7 +96,7 @@ class TestStdioPidTracking:
        # Use a PID that definitely doesn't exist
        fake_pid = 999999999
        with _lock:
-            _stdio_pids.add(fake_pid)
+            _stdio_pids[fake_pid] = "test"

        # Should not raise (ProcessLookupError is caught)
        _kill_orphaned_mcp_children()
@ -105,40 +105,49 @@ class TestStdioPidTracking:
|
|||
assert fake_pid not in _stdio_pids
|
||||
|
||||
def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch):
|
||||
"""Unix-like platforms should keep using SIGKILL for orphan cleanup."""
|
||||
"""SIGTERM-first then SIGKILL after 2s for orphan cleanup."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
fake_pid = 424242
|
||||
with _lock:
|
||||
_stdio_pids.clear()
|
||||
_stdio_pids.add(fake_pid)
|
||||
_stdio_pids[fake_pid] = "test"
|
||||
|
||||
fake_sigkill = 9
|
||||
monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False)
|
||||
|
||||
with patch("tools.mcp_tool.os.kill") as mock_kill:
|
||||
with patch("tools.mcp_tool.os.kill") as mock_kill, \
|
||||
patch("time.sleep") as mock_sleep:
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
mock_kill.assert_called_once_with(fake_pid, fake_sigkill)
|
||||
# SIGTERM, then alive-check (signal 0), then SIGKILL
|
||||
mock_kill.assert_any_call(fake_pid, signal.SIGTERM)
|
||||
mock_kill.assert_any_call(fake_pid, 0) # alive check
|
||||
mock_kill.assert_any_call(fake_pid, fake_sigkill)
|
||||
assert mock_kill.call_count == 3
|
||||
mock_sleep.assert_called_once_with(2)
|
||||
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
||||
def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch):
|
||||
"""Windows-like signal modules without SIGKILL should fall back to SIGTERM."""
|
||||
"""Without SIGKILL, SIGTERM is used for both phases."""
|
||||
from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
|
||||
|
||||
fake_pid = 434343
|
||||
with _lock:
|
||||
_stdio_pids.clear()
|
||||
_stdio_pids.add(fake_pid)
|
||||
_stdio_pids[fake_pid] = "test"
|
||||
|
||||
monkeypatch.delattr(signal, "SIGKILL", raising=False)
|
||||
|
||||
with patch("tools.mcp_tool.os.kill") as mock_kill:
|
||||
with patch("tools.mcp_tool.os.kill") as mock_kill, \
|
||||
patch("time.sleep") as mock_sleep:
|
||||
_kill_orphaned_mcp_children()
|
||||
|
||||
mock_kill.assert_called_once_with(fake_pid, signal.SIGTERM)
|
||||
# SIGTERM phase, alive check raises (process gone), no escalation
|
||||
mock_kill.assert_any_call(fake_pid, signal.SIGTERM)
|
||||
assert mock_sleep.called
|
||||
|
||||
with _lock:
|
||||
assert fake_pid not in _stdio_pids
|
||||
|
|
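The two orphan-cleanup tests above pin down a specific shutdown sequence: SIGTERM first, a 2-second grace period, a signal-0 liveness probe, and SIGKILL only if the process is still alive and the platform defines it. A minimal per-PID sketch of that sequence, using a hypothetical helper name (the real _kill_orphaned_mcp_children in tools/mcp_tool.py also maintains the _stdio_pids map under _lock):

import os
import signal
import time

def _terminate_orphan_sketch(pid: int) -> None:
    """Sketch: graceful-then-forceful kill matching the call pattern asserted above."""
    try:
        os.kill(pid, signal.SIGTERM)                               # phase 1: polite shutdown
        time.sleep(2)                                              # grace period (mock_sleep in the tests)
        os.kill(pid, 0)                                            # liveness probe; raises if already gone
        os.kill(pid, getattr(signal, "SIGKILL", signal.SIGTERM))   # phase 2: escalate (SIGTERM fallback on Windows)
    except ProcessLookupError:
        pass                                                       # process exited at some earlier phase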
|
|||
|
|
@ -120,6 +120,177 @@ class TestSchemaConversion:
|
|||
|
||||
assert schema["parameters"] == {"type": "object", "properties": {}}
|
||||
|
||||
def test_definitions_refs_are_rewritten_to_defs(self):
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
mcp_tool = _make_mcp_tool(
|
||||
name="submit",
|
||||
description="Submit a payload",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {"$ref": "#/definitions/Payload"},
|
||||
},
|
||||
"required": ["input"],
|
||||
"definitions": {
|
||||
"Payload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
schema = _convert_mcp_schema("forms", mcp_tool)
|
||||
|
||||
assert schema["parameters"]["properties"]["input"]["$ref"] == "#/$defs/Payload"
|
||||
assert "$defs" in schema["parameters"]
|
||||
assert "definitions" not in schema["parameters"]
|
||||
|
||||
def test_nested_definition_refs_are_rewritten_recursively(self):
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
mcp_tool = _make_mcp_tool(
|
||||
name="nested",
|
||||
description="Nested schema",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/definitions/Entry"},
|
||||
},
|
||||
},
|
||||
"definitions": {
|
||||
"Entry": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"child": {"$ref": "#/definitions/Child"},
|
||||
},
|
||||
},
|
||||
"Child": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
schema = _convert_mcp_schema("forms", mcp_tool)
|
||||
|
||||
assert schema["parameters"]["properties"]["items"]["items"]["$ref"] == "#/$defs/Entry"
|
||||
assert schema["parameters"]["$defs"]["Entry"]["properties"]["child"]["$ref"] == "#/$defs/Child"
|
||||
|
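The two $ref tests above describe a recursive keyword migration from draft-07 style definitions to 2020-12 $defs. A hedged sketch of that rewrite; the helper name and traversal are assumptions, not the shipped _convert_mcp_schema internals:

def _rewrite_definitions_refs_sketch(node):
    """Recursively rename 'definitions' to '$defs' and rewrite matching $ref targets."""
    if isinstance(node, dict):
        if "definitions" in node:
            node["$defs"] = node.pop("definitions")
        ref = node.get("$ref")
        if isinstance(ref, str) and ref.startswith("#/definitions/"):
            node["$ref"] = "#/$defs/" + ref[len("#/definitions/"):]
        for value in node.values():
            _rewrite_definitions_refs_sketch(value)
    elif isinstance(node, list):
        for value in node:
            _rewrite_definitions_refs_sketch(value)
    return node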
||||
def test_missing_type_on_object_is_coerced(self):
|
||||
"""Schemas that describe an object but omit ``type`` get type='object'."""
|
||||
from tools.mcp_tool import _normalize_mcp_input_schema
|
||||
|
||||
schema = _normalize_mcp_input_schema({
|
||||
"properties": {"q": {"type": "string"}},
|
||||
"required": ["q"],
|
||||
})
|
||||
|
||||
assert schema["type"] == "object"
|
||||
assert schema["properties"]["q"]["type"] == "string"
|
||||
assert schema["required"] == ["q"]
|
||||
|
||||
def test_null_type_on_object_is_coerced(self):
|
||||
"""type: None should be treated like missing type (common MCP server bug)."""
|
||||
from tools.mcp_tool import _normalize_mcp_input_schema
|
||||
|
||||
schema = _normalize_mcp_input_schema({
|
||||
"type": None,
|
||||
"properties": {"x": {"type": "integer"}},
|
||||
})
|
||||
|
||||
assert schema["type"] == "object"
|
||||
|
||||
def test_required_pruned_when_property_missing(self):
|
||||
"""Gemini 400s on required names that don't exist in properties."""
|
||||
from tools.mcp_tool import _normalize_mcp_input_schema
|
||||
|
||||
schema = _normalize_mcp_input_schema({
|
||||
"type": "object",
|
||||
"properties": {"a": {"type": "string"}},
|
||||
"required": ["a", "ghost", "phantom"],
|
||||
})
|
||||
|
||||
assert schema["required"] == ["a"]
|
||||
|
||||
def test_required_removed_when_all_names_dangle(self):
|
||||
from tools.mcp_tool import _normalize_mcp_input_schema
|
||||
|
||||
schema = _normalize_mcp_input_schema({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": ["ghost"],
|
||||
})
|
||||
|
||||
assert "required" not in schema
|
||||
|
||||
def test_required_pruning_applies_recursively_inside_nested_objects(self):
|
||||
"""Nested object schemas also get required pruning."""
|
||||
from tools.mcp_tool import _normalize_mcp_input_schema
|
||||
|
||||
schema = _normalize_mcp_input_schema({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filter": {
|
||||
"type": "object",
|
||||
"properties": {"field": {"type": "string"}},
|
||||
"required": ["field", "missing"],
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert schema["properties"]["filter"]["required"] == ["field"]
|
||||
|
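The normalization cases from test_missing_type_on_object_is_coerced down to the nested required-pruning test amount to one recursive clean-up pass. A hedged sketch of the rules they exercise, illustrative only (the real _normalize_mcp_input_schema is the source of truth):

def _normalize_schema_node_sketch(node):
    """Sketch: coerce object type, backfill properties, prune dangling required names."""
    if not isinstance(node, dict):
        return node
    if "properties" in node or node.get("type") in (None, "object"):
        node["type"] = "object"
        node.setdefault("properties", {})
        kept = [name for name in node.get("required", []) if name in node["properties"]]
        if kept:
            node["required"] = kept
        else:
            node.pop("required", None)
        for child in node["properties"].values():
            _normalize_schema_node_sketch(child)
    if node.get("type") == "array" and isinstance(node.get("items"), dict):
        _normalize_schema_node_sketch(node["items"])
    return node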
||||
def test_object_in_array_items_gets_properties_filled(self):
|
||||
"""Array-item object schemas without properties get an empty dict."""
|
||||
from tools.mcp_tool import _normalize_mcp_input_schema
|
||||
|
||||
schema = _normalize_mcp_input_schema({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert schema["properties"]["items"]["items"]["properties"] == {}
|
||||
|
||||
def test_convert_mcp_schema_survives_missing_inputschema_attribute(self):
|
||||
"""A Tool object without .inputSchema must not crash registration."""
|
||||
import types
|
||||
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
bare_tool = types.SimpleNamespace(name="probe", description="Probe")
|
||||
schema = _convert_mcp_schema("srv", bare_tool)
|
||||
|
||||
assert schema["name"] == "mcp_srv_probe"
|
||||
assert schema["parameters"] == {"type": "object", "properties": {}}
|
||||
|
||||
def test_convert_mcp_schema_with_none_inputschema(self):
|
||||
"""Tool with inputSchema=None produces a valid empty object schema."""
|
||||
import types
|
||||
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
# Note: _make_mcp_tool(input_schema=None) falls back to a default —
|
||||
# build the namespace directly so .inputSchema really is None.
|
||||
mcp_tool = types.SimpleNamespace(name="probe", description="Probe", inputSchema=None)
|
||||
schema = _convert_mcp_schema("srv", mcp_tool)
|
||||
|
||||
assert schema["parameters"] == {"type": "object", "properties": {}}
|
||||
|
||||
def test_tool_name_prefix_format(self):
|
||||
from tools.mcp_tool import _convert_mcp_schema
|
||||
|
||||
|
|
@ -1029,6 +1200,92 @@ class TestHTTPConfig:
|
|||
|
||||
asyncio.run(_test())
|
||||
|
||||
def test_http_seeds_initial_protocol_header(self):
|
||||
from tools.mcp_tool import LATEST_PROTOCOL_VERSION, MCPServerTask
|
||||
|
||||
server = MCPServerTask("remote")
|
||||
captured = {}
|
||||
|
||||
class DummyAsyncClient:
|
||||
def __init__(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummyTransportCtx:
|
||||
async def __aenter__(self):
|
||||
return MagicMock(), MagicMock(), (lambda: None)
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def initialize(self):
|
||||
return None
|
||||
|
||||
class DummyLegacyTransportCtx:
|
||||
def __init__(self, **kwargs):
|
||||
captured["legacy_headers"] = kwargs.get("headers")
|
||||
|
||||
async def __aenter__(self):
|
||||
return MagicMock(), MagicMock(), (lambda: None)
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def _discover_tools(self):
|
||||
self._shutdown_event.set()
|
||||
|
||||
async def _run(config, *, new_http):
|
||||
captured.clear()
|
||||
with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
|
||||
patch("tools.mcp_tool._MCP_NEW_HTTP", new_http), \
|
||||
patch("httpx.AsyncClient", DummyAsyncClient), \
|
||||
patch("tools.mcp_tool.streamable_http_client", return_value=DummyTransportCtx()), \
|
||||
patch("tools.mcp_tool.streamablehttp_client", side_effect=lambda url, **kwargs: DummyLegacyTransportCtx(**kwargs)), \
|
||||
patch("tools.mcp_tool.ClientSession", DummySession), \
|
||||
patch.object(MCPServerTask, "_discover_tools", _discover_tools):
|
||||
await server._run_http(config)
|
||||
|
||||
asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=True))
|
||||
assert captured["headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION
|
||||
|
||||
asyncio.run(_run({
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": {"mcp-protocol-version": "custom-version"},
|
||||
}, new_http=True))
|
||||
assert captured["headers"]["mcp-protocol-version"] == "custom-version"
|
||||
|
||||
asyncio.run(_run({
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": {"MCP-Protocol-Version": "custom-version"},
|
||||
}, new_http=True))
|
||||
assert captured["headers"]["MCP-Protocol-Version"] == "custom-version"
|
||||
assert "mcp-protocol-version" not in captured["headers"]
|
||||
|
||||
asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=False))
|
||||
assert captured["legacy_headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION
|
||||
|
||||
asyncio.run(_run({
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": {"MCP-Protocol-Version": "custom-version"},
|
||||
}, new_http=False))
|
||||
assert captured["legacy_headers"]["MCP-Protocol-Version"] == "custom-version"
|
||||
assert "mcp-protocol-version" not in captured["legacy_headers"]
|
||||
|
||||
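The header assertions above reduce to a case-insensitive default: seed mcp-protocol-version only when the caller has not supplied that header under any casing, for both the new and the legacy HTTP transport paths. A minimal sketch with an assumed helper name:

def _seed_protocol_header_sketch(headers, latest_version):
    """Sketch: add the default protocol header unless the user already set it (any casing)."""
    merged = dict(headers or {})
    if not any(key.lower() == "mcp-protocol-version" for key in merged):
        merged["mcp-protocol-version"] = latest_version
    return merged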
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reconnection logic
|
||||
|
|
|
|||
359
tests/tools/test_mcp_tool_session_expired.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
"""Tests for MCP tool-handler transport-session auto-reconnect.
|
||||
|
||||
When a Streamable HTTP MCP server garbage-collects its server-side
|
||||
session (idle TTL, server restart, pod rotation, …), it rejects
|
||||
subsequent requests with a JSON-RPC error containing phrases like
|
||||
``"Invalid or expired session"``. The OAuth token remains valid —
|
||||
only the transport session state needs rebuilding.
|
||||
|
||||
Before the #13383 fix, this class of failure fell through as a plain
|
||||
tool error with no recovery path, so every subsequent call on the
|
||||
affected MCP server failed until the gateway was manually restarted.
|
||||
"""
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_session_expired_error — unit coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_session_expired_detects_invalid_or_expired_session():
|
||||
"""Reporter's exact wpcom-mcp error message (#13383)."""
|
||||
from tools.mcp_tool import _is_session_expired_error
|
||||
exc = RuntimeError("Invalid params: Invalid or expired session")
|
||||
assert _is_session_expired_error(exc) is True
|
||||
|
||||
|
||||
def test_is_session_expired_detects_expired_session_variant():
|
||||
"""Generic ``session expired`` / ``expired session`` phrasings used
|
||||
by other SDK servers."""
|
||||
from tools.mcp_tool import _is_session_expired_error
|
||||
assert _is_session_expired_error(RuntimeError("Session expired")) is True
|
||||
assert _is_session_expired_error(RuntimeError("expired session: abc")) is True
|
||||
|
||||
|
||||
def test_is_session_expired_detects_session_not_found():
|
||||
"""Server-side GC produces ``session not found`` / ``unknown session``
|
||||
on some implementations."""
|
||||
from tools.mcp_tool import _is_session_expired_error
|
||||
assert _is_session_expired_error(RuntimeError("session not found")) is True
|
||||
assert _is_session_expired_error(RuntimeError("Unknown session: abc123")) is True
|
||||
|
||||
|
||||
def test_is_session_expired_is_case_insensitive():
|
||||
"""Match uses lower-cased comparison so servers that emit the
|
||||
message in different cases (SDK formatter quirks) still trigger."""
|
||||
from tools.mcp_tool import _is_session_expired_error
|
||||
assert _is_session_expired_error(RuntimeError("INVALID OR EXPIRED SESSION")) is True
|
||||
assert _is_session_expired_error(RuntimeError("Session Expired")) is True
|
||||
|
||||
|
||||
def test_is_session_expired_rejects_unrelated_errors():
|
||||
"""Narrow scope: only the specific session-expired markers trigger.
|
||||
A regular RuntimeError / ValueError does not."""
|
||||
from tools.mcp_tool import _is_session_expired_error
|
||||
assert _is_session_expired_error(RuntimeError("Tool failed to execute")) is False
|
||||
assert _is_session_expired_error(ValueError("Missing parameter")) is False
|
||||
assert _is_session_expired_error(Exception("Connection refused")) is False
|
||||
# 401 is handled by the sibling _is_auth_error path, not here.
|
||||
assert _is_session_expired_error(RuntimeError("401 Unauthorized")) is False
|
||||
|
||||
|
||||
def test_is_session_expired_rejects_interrupted_error():
|
||||
"""InterruptedError is the user-cancel signal — must never route
|
||||
through the session-reconnect path."""
|
||||
from tools.mcp_tool import _is_session_expired_error
|
||||
assert _is_session_expired_error(InterruptedError()) is False
|
||||
assert _is_session_expired_error(InterruptedError("Invalid or expired session")) is False
|
||||
|
||||
|
||||
def test_is_session_expired_rejects_empty_message():
|
||||
"""Bare exceptions with no message shouldn't match."""
|
||||
from tools.mcp_tool import _is_session_expired_error
|
||||
assert _is_session_expired_error(RuntimeError("")) is False
|
||||
assert _is_session_expired_error(Exception()) is False
|
||||
|
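The matcher these unit tests pin down is essentially a lower-cased substring check over a short phrase list, with InterruptedError excluded so user cancellation never routes into reconnect. A hedged sketch; the phrase list is inferred from the assertions above, not copied from tools/mcp_tool.py:

_SESSION_EXPIRED_MARKERS = (
    "invalid or expired session",
    "session expired",
    "expired session",
    "session not found",
    "unknown session",
)

def _is_session_expired_error_sketch(exc: BaseException) -> bool:
    """Sketch: narrow, case-insensitive match on known session-expiry phrases."""
    if isinstance(exc, InterruptedError):    # user-cancel signal, never reconnect
        return False
    message = str(exc).lower()
    return bool(message) and any(marker in message for marker in _SESSION_EXPIRED_MARKERS)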
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler integration — verify the recovery plumbing wires end-to-end
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _install_stub_server(name: str = "wpcom"):
|
||||
"""Register a minimal server stub that _handle_session_expired_and_retry
|
||||
can signal via _reconnect_event, and that reports ready+session after
|
||||
the event fires."""
|
||||
from tools import mcp_tool
|
||||
|
||||
mcp_tool._ensure_mcp_loop()
|
||||
|
||||
server = MagicMock()
|
||||
server.name = name
|
||||
# _reconnect_event is called via loop.call_soon_threadsafe(…set); use
|
||||
# a thread-safe substitute.
|
||||
reconnect_flag = threading.Event()
|
||||
|
||||
class _EventAdapter:
|
||||
def set(self):
|
||||
reconnect_flag.set()
|
||||
|
||||
server._reconnect_event = _EventAdapter()
|
||||
|
||||
# Immediately "ready" — simulates a fast reconnect (_ready.is_set()
|
||||
# is polled by _handle_session_expired_and_retry until the timeout).
|
||||
ready_flag = threading.Event()
|
||||
ready_flag.set()
|
||||
server._ready = MagicMock()
|
||||
server._ready.is_set = ready_flag.is_set
|
||||
|
||||
# session attr must be truthy for the handler's initial check
|
||||
# (``if not server or not server.session``) and for the post-
|
||||
# reconnect readiness probe (``srv.session is not None``).
|
||||
server.session = MagicMock()
|
||||
return server, reconnect_flag
|
||||
|
||||
|
||||
def test_call_tool_handler_reconnects_on_session_expired(monkeypatch, tmp_path):
|
||||
"""Reporter's exact repro: call_tool raises "Invalid or expired
|
||||
session", handler triggers reconnect, retries once, and returns
|
||||
the retry's successful JSON (not the generic error)."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools import mcp_tool
|
||||
from tools.mcp_tool import _make_tool_handler
|
||||
|
||||
server, reconnect_flag = _install_stub_server("wpcom")
|
||||
mcp_tool._servers["wpcom"] = server
|
||||
mcp_tool._server_error_counts.pop("wpcom", None)
|
||||
|
||||
# First call raises session-expired; second call (post-reconnect)
|
||||
# returns a proper MCP tool result.
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def _call_sequence(*a, **kw):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise RuntimeError("Invalid params: Invalid or expired session")
|
||||
# Second call: mimic the MCP SDK's structured success response.
|
||||
result = MagicMock()
|
||||
result.isError = False
|
||||
result.content = [MagicMock(type="text", text="tool completed")]
|
||||
result.structuredContent = None
|
||||
return result
|
||||
|
||||
server.session.call_tool = _call_sequence
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("wpcom", "wpcom-mcp-content-authoring", 10.0)
|
||||
out = handler({"slug": "hello"})
|
||||
parsed = json.loads(out)
|
||||
# Retry succeeded — no error surfaced to caller.
|
||||
assert "error" not in parsed, (
|
||||
f"Expected retry to succeed after reconnect; got: {parsed}"
|
||||
)
|
||||
# _reconnect_event was signalled exactly once.
|
||||
assert reconnect_flag.is_set(), (
|
||||
"Handler did not trigger transport reconnect on session-expired "
|
||||
"error — the reconnect flow is the whole point of this fix."
|
||||
)
|
||||
# Exactly 2 call attempts (original + one retry).
|
||||
assert call_count["n"] == 2, (
|
||||
f"Expected 1 original + 1 retry = 2 calls; got {call_count['n']}"
|
||||
)
|
||||
finally:
|
||||
mcp_tool._servers.pop("wpcom", None)
|
||||
mcp_tool._server_error_counts.pop("wpcom", None)
|
||||
|
||||
|
||||
def test_call_tool_handler_non_session_expired_error_falls_through(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
"""Preserved-behaviour canary: a non-session-expired exception must
|
||||
NOT trigger reconnect — it must fall through to the generic error
|
||||
path so the caller sees the real failure."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools import mcp_tool
|
||||
from tools.mcp_tool import _make_tool_handler
|
||||
|
||||
server, reconnect_flag = _install_stub_server("srv")
|
||||
mcp_tool._servers["srv"] = server
|
||||
mcp_tool._server_error_counts.pop("srv", None)
|
||||
|
||||
async def _raises(*a, **kw):
|
||||
raise RuntimeError("Tool execution failed — unrelated error")
|
||||
|
||||
server.session.call_tool = _raises
|
||||
|
||||
try:
|
||||
handler = _make_tool_handler("srv", "mytool", 10.0)
|
||||
out = handler({"arg": "v"})
|
||||
parsed = json.loads(out)
|
||||
# Generic error path surfaced the failure.
|
||||
assert "MCP call failed" in parsed.get("error", "")
|
||||
# Reconnect was NOT triggered for this unrelated failure.
|
||||
assert not reconnect_flag.is_set(), (
|
||||
"Reconnect must not fire for non-session-expired errors — "
|
||||
"this would cause spurious transport churn on every tool "
|
||||
"failure."
|
||||
)
|
||||
finally:
|
||||
mcp_tool._servers.pop("srv", None)
|
||||
mcp_tool._server_error_counts.pop("srv", None)
|
||||
|
||||
|
||||
def test_session_expired_handler_returns_none_without_loop(monkeypatch):
|
||||
"""Defensive: if the MCP loop isn't running (cold start / shutdown
|
||||
race), the handler must fall through cleanly instead of hanging
|
||||
or raising."""
|
||||
from tools import mcp_tool
|
||||
from tools.mcp_tool import _handle_session_expired_and_retry
|
||||
|
||||
# Install a server stub but make the event loop unavailable.
|
||||
server = MagicMock()
|
||||
server._reconnect_event = MagicMock()
|
||||
server._ready = MagicMock()
|
||||
server._ready.is_set = MagicMock(return_value=True)
|
||||
server.session = MagicMock()
|
||||
mcp_tool._servers["srv-noloop"] = server
|
||||
|
||||
monkeypatch.setattr(mcp_tool, "_mcp_loop", None)
|
||||
|
||||
try:
|
||||
out = _handle_session_expired_and_retry(
|
||||
"srv-noloop",
|
||||
RuntimeError("Invalid or expired session"),
|
||||
lambda: '{"ok": true}',
|
||||
"tools/call",
|
||||
)
|
||||
assert out is None, (
|
||||
"Without an event loop, session-expired handler must fall "
|
||||
"through to caller's generic error path — not hang or raise."
|
||||
)
|
||||
finally:
|
||||
mcp_tool._servers.pop("srv-noloop", None)
|
||||
|
||||
|
||||
def test_session_expired_handler_returns_none_without_server_record():
|
||||
"""If the server has been torn down / isn't in _servers, fall
|
||||
through cleanly — nothing to reconnect to."""
|
||||
from tools.mcp_tool import _handle_session_expired_and_retry
|
||||
out = _handle_session_expired_and_retry(
|
||||
"does-not-exist",
|
||||
RuntimeError("Invalid or expired session"),
|
||||
lambda: '{"ok": true}',
|
||||
"tools/call",
|
||||
)
|
||||
assert out is None
|
||||
|
||||
|
||||
def test_session_expired_handler_returns_none_when_retry_also_fails(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
"""If the retry after reconnect also raises, fall through to the
|
||||
generic error path (don't loop forever, don't mask the second
|
||||
failure)."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools import mcp_tool
|
||||
from tools.mcp_tool import _handle_session_expired_and_retry
|
||||
|
||||
server, _ = _install_stub_server("srv-retry-fail")
|
||||
mcp_tool._servers["srv-retry-fail"] = server
|
||||
|
||||
def _retry_raises():
|
||||
raise RuntimeError("retry blew up too")
|
||||
|
||||
try:
|
||||
out = _handle_session_expired_and_retry(
|
||||
"srv-retry-fail",
|
||||
RuntimeError("Invalid or expired session"),
|
||||
_retry_raises,
|
||||
"tools/call",
|
||||
)
|
||||
assert out is None, (
|
||||
"When the retry itself fails, the handler must return None "
|
||||
"so the caller's generic error path runs — no retry loop."
|
||||
)
|
||||
finally:
|
||||
mcp_tool._servers.pop("srv-retry-fail", None)
|
||||
|
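Taken together, the handler tests above describe the recovery contract: on a session-expired error, signal the server's reconnect event on the MCP loop, wait briefly for readiness, and replay the original call exactly once; every failure branch (no loop, no server record, retry also fails) returns None so the caller's generic error path runs. A hedged sketch of that shape, with names that only approximate the real _handle_session_expired_and_retry:

import time

def handle_session_expired_and_retry_sketch(servers, loop, server_name, retry, ready_timeout=10.0):
    """Sketch: one reconnect plus one retry; None means fall through to the generic error path."""
    server = servers.get(server_name)
    if server is None or server.session is None or loop is None:
        return None
    loop.call_soon_threadsafe(server._reconnect_event.set)   # ask the MCP loop to rebuild the transport
    deadline = time.monotonic() + ready_timeout
    while time.monotonic() < deadline:
        if server._ready.is_set() and server.session is not None:
            break
        time.sleep(0.1)
    else:
        return None                                          # never became ready again
    try:
        return retry()                                       # exactly one retry
    except Exception:
        return None                                          # second failure surfaces via the generic path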
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parallel coverage for resources/list, resources/read, prompts/list,
|
||||
# prompts/get — all four handlers share the same exception path.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"handler_factory, handler_kwargs, session_method, op_label",
|
||||
[
|
||||
("_make_list_resources_handler", {"tool_timeout": 10.0}, "list_resources", "list_resources"),
|
||||
("_make_read_resource_handler", {"tool_timeout": 10.0}, "read_resource", "read_resource"),
|
||||
("_make_list_prompts_handler", {"tool_timeout": 10.0}, "list_prompts", "list_prompts"),
|
||||
("_make_get_prompt_handler", {"tool_timeout": 10.0}, "get_prompt", "get_prompt"),
|
||||
],
|
||||
)
|
||||
def test_non_tool_handlers_also_reconnect_on_session_expired(
|
||||
monkeypatch, tmp_path, handler_factory, handler_kwargs, session_method, op_label
|
||||
):
|
||||
"""All four non-``tools/call`` MCP handlers share the recovery
|
||||
pattern and must reconnect the same way on session-expired."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools import mcp_tool
|
||||
|
||||
server, reconnect_flag = _install_stub_server(f"srv-{op_label}")
|
||||
mcp_tool._servers[f"srv-{op_label}"] = server
|
||||
mcp_tool._server_error_counts.pop(f"srv-{op_label}", None)
|
||||
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def _sequence(*a, **kw):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
raise RuntimeError("Invalid or expired session")
|
||||
# Return something with the shapes each handler expects.
|
||||
# Explicitly set primitive attrs — MagicMock's default auto-attr
|
||||
# behaviour surfaces ``MagicMock`` values for optional fields
|
||||
# like ``description``, which break ``json.dumps`` downstream.
|
||||
result = MagicMock()
|
||||
result.resources = []
|
||||
result.prompts = []
|
||||
result.contents = []
|
||||
result.messages = [] # get_prompt
|
||||
result.description = None # get_prompt optional field
|
||||
return result
|
||||
|
||||
setattr(server.session, session_method, _sequence)
|
||||
|
||||
factory = getattr(mcp_tool, handler_factory)
|
||||
# list_resources / list_prompts take (server_name, timeout).
|
||||
# read_resource / get_prompt take the same signature.
|
||||
try:
|
||||
handler = factory(f"srv-{op_label}", **handler_kwargs)
|
||||
if op_label == "read_resource":
|
||||
out = handler({"uri": "file://foo"})
|
||||
elif op_label == "get_prompt":
|
||||
out = handler({"name": "p1"})
|
||||
else:
|
||||
out = handler({})
|
||||
parsed = json.loads(out)
|
||||
assert "error" not in parsed, (
|
||||
f"{op_label}: expected retry success, got {parsed}"
|
||||
)
|
||||
assert reconnect_flag.is_set(), (
|
||||
f"{op_label}: reconnect should fire for session-expired"
|
||||
)
|
||||
assert call_count["n"] == 2, (
|
||||
f"{op_label}: expected 1 original + 1 retry"
|
||||
)
|
||||
finally:
|
||||
mcp_tool._servers.pop(f"srv-{op_label}", None)
|
||||
mcp_tool._server_error_counts.pop(f"srv-{op_label}", None)
|
||||
|
|
@ -8,14 +8,17 @@ import pytest
|
|||
moa = importlib.import_module("tools.mixture_of_agents_tool")
|
||||
|
||||
|
||||
def test_moa_defaults_track_current_openrouter_frontier_models():
|
||||
assert moa.REFERENCE_MODELS == [
|
||||
"anthropic/claude-opus-4.6",
|
||||
"google/gemini-3-pro-preview",
|
||||
"openai/gpt-5.4-pro",
|
||||
"deepseek/deepseek-v3.2",
|
||||
]
|
||||
assert moa.AGGREGATOR_MODEL == "anthropic/claude-opus-4.6"
|
||||
def test_moa_defaults_are_well_formed():
|
||||
# Invariants, not a catalog snapshot: the exact model list churns with
|
||||
# OpenRouter availability (see PR #6636 where gemini-3-pro-preview was
|
||||
# removed upstream). What we care about is that the defaults are present
|
||||
# and valid vendor/model slugs.
|
||||
assert isinstance(moa.REFERENCE_MODELS, list)
|
||||
assert len(moa.REFERENCE_MODELS) >= 1
|
||||
for m in moa.REFERENCE_MODELS:
|
||||
assert isinstance(m, str) and "/" in m and not m.startswith("/")
|
||||
assert isinstance(moa.AGGREGATOR_MODEL, str)
|
||||
assert "/" in moa.AGGREGATOR_MODEL
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -292,6 +292,7 @@ class TestBuiltinDiscovery:
|
|||
def test_matches_previous_manual_builtin_tool_set(self):
|
||||
expected = {
|
||||
"tools.browser_cdp_tool",
|
||||
"tools.browser_dialog_tool",
|
||||
"tools.browser_tool",
|
||||
"tools.clarify_tool",
|
||||
"tools.code_execution_tool",
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -22,8 +23,9 @@ class TestResolvePath:
|
|||
monkeypatch.setenv("TERMINAL_CWD", str(tmp_path))
|
||||
from tools.file_tools import _resolve_path
|
||||
|
||||
result = _resolve_path("/etc/hosts")
|
||||
assert result == Path("/etc/hosts")
|
||||
absolute = (tmp_path / "already-absolute.txt").resolve()
|
||||
result = _resolve_path(str(absolute))
|
||||
assert result == absolute
|
||||
|
||||
def test_falls_back_to_cwd_without_terminal_cwd(self, monkeypatch):
|
||||
"""Without TERMINAL_CWD, falls back to os.getcwd()."""
|
||||
|
|
@ -50,3 +52,34 @@ class TestResolvePath:
|
|||
result = _resolve_path("a/../b/file.txt")
|
||||
assert ".." not in str(result)
|
||||
assert result == (tmp_path / "b" / "file.txt")
|
||||
|
||||
def test_relative_path_prefers_live_file_ops_cwd(self, monkeypatch, tmp_path):
|
||||
"""Live env.cwd must win after the terminal session changes directory."""
|
||||
start_dir = tmp_path / "start"
|
||||
live_dir = tmp_path / "worktree"
|
||||
start_dir.mkdir()
|
||||
live_dir.mkdir()
|
||||
monkeypatch.setenv("TERMINAL_CWD", str(start_dir))
|
||||
|
||||
from tools import file_tools
|
||||
|
||||
task_id = "live-cwd"
|
||||
fake_ops = SimpleNamespace(
|
||||
env=SimpleNamespace(cwd=str(live_dir)),
|
||||
cwd=str(start_dir),
|
||||
)
|
||||
|
||||
with file_tools._file_ops_lock:
|
||||
previous = file_tools._file_ops_cache.get(task_id)
|
||||
file_tools._file_ops_cache[task_id] = fake_ops
|
||||
|
||||
try:
|
||||
result = file_tools._resolve_path("nested/file.txt", task_id=task_id)
|
||||
finally:
|
||||
with file_tools._file_ops_lock:
|
||||
if previous is None:
|
||||
file_tools._file_ops_cache.pop(task_id, None)
|
||||
else:
|
||||
file_tools._file_ops_cache[task_id] = previous
|
||||
|
||||
assert result == live_dir / "nested" / "file.txt"
|
||||
|
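The resolution order the tests above rely on: absolute paths pass through untouched, while relative paths are joined against the live session cwd when the cached file-ops object exposes one, then TERMINAL_CWD, then the process cwd. A rough sketch with assumed lookups (the real _resolve_path also reads the per-task cache and normalizes '..' components):

import os
from pathlib import Path

def resolve_path_sketch(path: str, live_cwd: str | None = None) -> Path:
    """Sketch: live session cwd > TERMINAL_CWD > process cwd for relative paths."""
    candidate = Path(path)
    if candidate.is_absolute():
        return candidate
    base = live_cwd or os.environ.get("TERMINAL_CWD") or os.getcwd()
    return (Path(base) / candidate).resolve()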
|
|
|||
205
tests/tools/test_schema_sanitizer.py
Normal file
|
|
@ -0,0 +1,205 @@
|
|||
"""Tests for tools/schema_sanitizer.py.
|
||||
|
||||
Targets the known llama.cpp ``json-schema-to-grammar`` failure modes that
|
||||
cause ``HTTP 400: Unable to generate parser for this template. ...
|
||||
Unrecognized schema: "object"`` errors on local inference backends.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
|
||||
from tools.schema_sanitizer import sanitize_tool_schemas
|
||||
|
||||
|
||||
def _tool(name: str, parameters: dict) -> dict:
|
||||
return {"type": "function", "function": {"name": name, "parameters": parameters}}
|
||||
|
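Before the individual cases, a rough sketch of the defensive pass they collectively describe: coerce bare string schemas into dicts, backfill properties on object schemas, prune dangling required names, and recurse through properties, items, additionalProperties and anyOf. Illustrative only; it omits the deep-copy and the nullable type-array collapse tested further down, and tools/schema_sanitizer.py remains the authoritative implementation:

def _sanitize_node_sketch(node):
    """Sketch of a subset of the repairs the tests below exercise."""
    if isinstance(node, str):
        # Bare "object"/"string" schema values become proper schema dicts.
        return {"type": "object", "properties": {}} if node == "object" else {"type": node}
    if not isinstance(node, dict):
        return node
    if node.get("type") == "object":
        node.setdefault("properties", {})
        kept = [name for name in node.get("required", []) if name in node["properties"]]
        if kept:
            node["required"] = kept
        else:
            node.pop("required", None)
    if isinstance(node.get("properties"), dict):
        node["properties"] = {k: _sanitize_node_sketch(v) for k, v in node["properties"].items()}
    if isinstance(node.get("items"), dict):
        node["items"] = _sanitize_node_sketch(node["items"])
    if isinstance(node.get("additionalProperties"), dict):
        node["additionalProperties"] = _sanitize_node_sketch(node["additionalProperties"])
    if isinstance(node.get("anyOf"), list):
        node["anyOf"] = [_sanitize_node_sketch(v) for v in node["anyOf"]]
    return node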
||||
|
||||
def test_object_without_properties_gets_empty_properties():
|
||||
tools = [_tool("t", {"type": "object"})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
assert out[0]["function"]["parameters"] == {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
def test_nested_object_without_properties_gets_empty_properties():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"arguments": {"type": "object", "description": "free-form"},
|
||||
},
|
||||
"required": ["name"],
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
args = out[0]["function"]["parameters"]["properties"]["arguments"]
|
||||
assert args["type"] == "object"
|
||||
assert args["properties"] == {}
|
||||
assert args["description"] == "free-form"
|
||||
|
||||
|
||||
def test_bare_string_object_value_replaced_with_schema_dict():
|
||||
# Malformed: a property's schema value is the bare string "object".
|
||||
# This is the exact shape llama.cpp reports as `Unrecognized schema: "object"`.
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payload": "object", # <-- invalid, should be {"type": "object"}
|
||||
},
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
payload = out[0]["function"]["parameters"]["properties"]["payload"]
|
||||
assert isinstance(payload, dict)
|
||||
assert payload["type"] == "object"
|
||||
assert payload["properties"] == {}
|
||||
|
||||
|
||||
def test_bare_string_primitive_value_replaced_with_schema_dict():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {"name": "string"},
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
assert out[0]["function"]["parameters"]["properties"]["name"] == {"type": "string"}
|
||||
|
||||
|
||||
def test_nullable_type_array_collapsed_to_single_string():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"maybe_name": {"type": ["string", "null"]},
|
||||
},
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
prop = out[0]["function"]["parameters"]["properties"]["maybe_name"]
|
||||
assert prop["type"] == "string"
|
||||
assert prop.get("nullable") is True
|
||||
|
||||
|
||||
def test_anyof_nested_objects_sanitized():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"opt": {
|
||||
"anyOf": [
|
||||
{"type": "object"}, # bare object
|
||||
{"type": "string"},
|
||||
],
|
||||
},
|
||||
},
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
variants = out[0]["function"]["parameters"]["properties"]["opt"]["anyOf"]
|
||||
assert variants[0] == {"type": "object", "properties": {}}
|
||||
assert variants[1] == {"type": "string"}
|
||||
|
||||
|
||||
def test_missing_parameters_gets_default_object_schema():
|
||||
tools = [{"type": "function", "function": {"name": "t"}}]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
assert out[0]["function"]["parameters"] == {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
def test_non_dict_parameters_gets_default_object_schema():
|
||||
tools = [_tool("t", "object")] # pathological
|
||||
out = sanitize_tool_schemas(tools)
|
||||
assert out[0]["function"]["parameters"] == {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
def test_required_pruned_to_existing_properties():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
"required": ["name", "missing_field"],
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
assert out[0]["function"]["parameters"]["required"] == ["name"]
|
||||
|
||||
|
||||
def test_required_all_missing_is_dropped():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": ["x", "y"],
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
assert "required" not in out[0]["function"]["parameters"]
|
||||
|
||||
|
||||
def test_well_formed_schema_unchanged():
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "File path"},
|
||||
"offset": {"type": "integer", "minimum": 1},
|
||||
},
|
||||
"required": ["path"],
|
||||
}
|
||||
tools = [_tool("read_file", copy.deepcopy(schema))]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
assert out[0]["function"]["parameters"] == schema
|
||||
|
||||
|
||||
def test_additional_properties_bool_preserved():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payload": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": True,
|
||||
},
|
||||
},
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
payload = out[0]["function"]["parameters"]["properties"]["payload"]
|
||||
assert payload["additionalProperties"] is True
|
||||
|
||||
|
||||
def test_additional_properties_schema_sanitized():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"dict_field": {
|
||||
"type": "object",
|
||||
"additionalProperties": {"type": "object"}, # bare object schema
|
||||
},
|
||||
},
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
field = out[0]["function"]["parameters"]["properties"]["dict_field"]
|
||||
assert field["additionalProperties"] == {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
def test_deepcopy_does_not_mutate_input():
|
||||
original = {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "object"}},
|
||||
}
|
||||
tools = [_tool("t", original)]
|
||||
_ = sanitize_tool_schemas(tools)
|
||||
# Original should still lack properties on the nested object
|
||||
assert "properties" not in original["properties"]["x"]
|
||||
|
||||
|
||||
def test_items_sanitized_in_array_schema():
|
||||
tools = [_tool("t", {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bag": {
|
||||
"type": "array",
|
||||
"items": {"type": "object"}, # bare object items
|
||||
},
|
||||
},
|
||||
})]
|
||||
out = sanitize_tool_schemas(tools)
|
||||
items = out[0]["function"]["parameters"]["properties"]["bag"]["items"]
|
||||
assert items == {"type": "object", "properties": {}}
|
||||
|
||||
|
||||
def test_empty_tools_list_returns_empty():
|
||||
assert sanitize_tool_schemas([]) == []
|
||||
|
||||
|
||||
def test_none_tools_returns_none():
|
||||
assert sanitize_tool_schemas(None) is None
|
||||
|
|
@ -484,3 +484,85 @@ class TestSkillManageDispatcher:
|
|||
raw = skill_manage(action="create", name="test-skill", content=VALID_SKILL_CONTENT)
|
||||
result = json.loads(raw)
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
class TestSecurityScanGate:
|
||||
"""_security_scan_skill is gated by skills.guard_agent_created config flag."""
|
||||
|
||||
def test_scan_noop_when_flag_off(self, tmp_path):
|
||||
"""Default config (flag off) short-circuits before running scan_skill."""
|
||||
from tools.skill_manager_tool import _security_scan_skill
|
||||
|
||||
with patch("tools.skill_manager_tool._guard_agent_created_enabled", return_value=False), \
|
||||
patch("tools.skill_manager_tool.scan_skill") as mock_scan:
|
||||
result = _security_scan_skill(tmp_path)
|
||||
|
||||
assert result is None
|
||||
mock_scan.assert_not_called() # scan never ran
|
||||
|
||||
def test_scan_runs_when_flag_on(self, tmp_path):
|
||||
"""When flag is on, scan_skill is invoked and its verdict is honored."""
|
||||
from tools.skill_manager_tool import _security_scan_skill
|
||||
from tools.skills_guard import ScanResult
|
||||
|
||||
# Fake a safe scan result — caller should return None (allow)
|
||||
fake_result = ScanResult(
|
||||
skill_name="test",
|
||||
source="agent-created",
|
||||
trust_level="agent-created",
|
||||
verdict="safe",
|
||||
findings=[],
|
||||
summary="ok",
|
||||
)
|
||||
with patch("tools.skill_manager_tool._guard_agent_created_enabled", return_value=True), \
|
||||
patch("tools.skill_manager_tool.scan_skill", return_value=fake_result) as mock_scan:
|
||||
result = _security_scan_skill(tmp_path)
|
||||
|
||||
assert result is None
|
||||
mock_scan.assert_called_once()
|
||||
|
||||
def test_scan_blocks_dangerous_when_flag_on(self, tmp_path):
|
||||
"""Dangerous verdict + flag on → returns an error string for the agent."""
|
||||
from tools.skill_manager_tool import _security_scan_skill
|
||||
from tools.skills_guard import ScanResult, Finding
|
||||
|
||||
finding = Finding(
|
||||
pattern_id="test", severity="critical", category="exfiltration",
|
||||
file="SKILL.md", line=1, match="curl $TOKEN", description="test",
|
||||
)
|
||||
fake_result = ScanResult(
|
||||
skill_name="test",
|
||||
source="agent-created",
|
||||
trust_level="agent-created",
|
||||
verdict="dangerous",
|
||||
findings=[finding],
|
||||
summary="dangerous",
|
||||
)
|
||||
with patch("tools.skill_manager_tool._guard_agent_created_enabled", return_value=True), \
|
||||
patch("tools.skill_manager_tool.scan_skill", return_value=fake_result):
|
||||
result = _security_scan_skill(tmp_path)
|
||||
|
||||
assert result is not None
|
||||
assert "Security scan blocked" in result
|
||||
|
||||
def test_guard_flag_reads_config_default_false(self):
|
||||
"""_guard_agent_created_enabled returns False when config doesn't set it."""
|
||||
from tools.skill_manager_tool import _guard_agent_created_enabled
|
||||
|
||||
with patch("hermes_cli.config.load_config", return_value={"skills": {}}):
|
||||
assert _guard_agent_created_enabled() is False
|
||||
|
||||
def test_guard_flag_reads_config_when_set(self):
|
||||
"""_guard_agent_created_enabled returns True when user explicitly enables."""
|
||||
from tools.skill_manager_tool import _guard_agent_created_enabled
|
||||
|
||||
with patch("hermes_cli.config.load_config",
|
||||
return_value={"skills": {"guard_agent_created": True}}):
|
||||
assert _guard_agent_created_enabled() is True
|
||||
|
||||
def test_guard_flag_handles_config_error(self):
|
||||
"""If load_config raises, _guard_agent_created_enabled defaults to False (fail-safe off)."""
|
||||
from tools.skill_manager_tool import _guard_agent_created_enabled
|
||||
|
||||
with patch("hermes_cli.config.load_config", side_effect=RuntimeError("boom")):
|
||||
assert _guard_agent_created_enabled() is False
|
||||
|
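The three flag tests above fully specify the gate: read skills.guard_agent_created from config, default to False, and treat any config error as False (fail-safe off). A minimal sketch consistent with those assertions; the real helper lives in tools/skill_manager_tool.py:

def _guard_agent_created_enabled_sketch() -> bool:
    """Sketch: config-gated flag that defaults to off and fails safe to off."""
    try:
        from hermes_cli.config import load_config
        config = load_config() or {}
        return bool(config.get("skills", {}).get("guard_agent_created", False))
    except Exception:
        return False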
|
|
|||
|
|
@ -174,27 +174,24 @@ class TestShouldAllowInstall:
|
|||
assert allowed is True
|
||||
assert "agent-created" in reason
|
||||
|
||||
def test_dangerous_agent_created_allowed(self):
|
||||
"""Agent-created skills bypass verdict gating — agent can already
|
||||
execute the same code via terminal(), so skill_manage allows all
|
||||
verdicts. This prevents friction when the agent writes skills that
|
||||
mention risky keywords in prose (e.g. describing cache-busting or
|
||||
persistence semantics in a PR-review skill)."""
|
||||
def test_dangerous_agent_created_asks(self):
|
||||
"""Agent-created skills with dangerous verdict return None (ask for confirmation)
|
||||
when the scan runs. The caller (_security_scan_skill) surfaces this as an error
|
||||
to the agent, who can retry without the flagged content.
|
||||
|
||||
This gate only runs when skills.guard_agent_created is enabled (off by default)."""
|
||||
f = [Finding("env_exfil_curl", "critical", "exfiltration", "SKILL.md", 1, "curl $TOKEN", "exfiltration")]
|
||||
allowed, reason = should_allow_install(self._result("agent-created", "dangerous", f))
|
||||
assert allowed is True
|
||||
assert "agent-created" in reason
|
||||
assert allowed is None
|
||||
assert "Requires confirmation" in reason
|
||||
|
||||
def test_force_noop_for_agent_created_dangerous(self):
|
||||
"""With agent-created dangerous mapped to 'allow', force becomes a
|
||||
no-op — the allow branch returns first. Force still works for any
|
||||
trust level that maps to block (community/trusted)."""
|
||||
def test_force_overrides_dangerous_for_agent_created(self):
|
||||
f = [Finding("x", "critical", "c", "f", 1, "m", "d")]
|
||||
allowed, reason = should_allow_install(
|
||||
self._result("agent-created", "dangerous", f), force=True
|
||||
)
|
||||
assert allowed is True
|
||||
assert "agent-created" in reason
|
||||
assert "Force-installed" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -44,6 +44,18 @@ description: Description for {name}.
|
|||
return skill_dir
|
||||
|
||||
|
||||
def _symlink_category(skills_dir: Path, linked_root: Path, category: str) -> Path:
|
||||
"""Create a category symlink under skills_dir pointing outside the tree."""
|
||||
external_category = linked_root / category
|
||||
external_category.mkdir(parents=True, exist_ok=True)
|
||||
symlink_path = skills_dir / category
|
||||
try:
|
||||
symlink_path.symlink_to(external_category, target_is_directory=True)
|
||||
except (OSError, NotImplementedError) as exc:
|
||||
pytest.skip(f"symlinks unavailable in test environment: {exc}")
|
||||
return external_category
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_frontmatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -255,6 +267,20 @@ class TestFindAllSkills:
|
|||
assert len(skills) == 1
|
||||
assert skills[0]["name"] == "real-skill"
|
||||
|
||||
def test_finds_skills_in_symlinked_category_dir(self, tmp_path):
|
||||
external_root = tmp_path / "repo"
|
||||
skills_root = tmp_path / "skills"
|
||||
skills_root.mkdir()
|
||||
|
||||
external_category = _symlink_category(skills_root, external_root, "linked")
|
||||
_make_skill(external_category.parent, "knowledge-brain", category="linked")
|
||||
|
||||
with patch("tools.skills_tool.SKILLS_DIR", skills_root):
|
||||
skills = _find_all_skills()
|
||||
|
||||
assert [s["name"] for s in skills] == ["knowledge-brain"]
|
||||
assert skills[0]["category"] == "linked"
|
||||
|
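The symlinked-category tests work because pathlib's directory checks follow symlinks by default, so a category that is a symlink into an external tree is traversed like any real directory. A small sketch of that traversal detail (helper name hypothetical):

from pathlib import Path

def iter_category_dirs_sketch(skills_dir: Path):
    """Sketch: Path.is_dir() follows symlinks, so linked categories are included."""
    for entry in sorted(skills_dir.iterdir()):
        if entry.is_dir():            # True for real directories and for directory symlinks
            yield entry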
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# skills_list
|
||||
|
|
@ -288,6 +314,23 @@ class TestSkillsList:
|
|||
assert result["count"] == 1
|
||||
assert result["skills"][0]["name"] == "skill-a"
|
||||
|
||||
def test_category_filter_finds_symlinked_category(self, tmp_path):
|
||||
external_root = tmp_path / "repo"
|
||||
skills_root = tmp_path / "skills"
|
||||
skills_root.mkdir()
|
||||
|
||||
external_category = _symlink_category(skills_root, external_root, "linked")
|
||||
_make_skill(external_category.parent, "knowledge-brain", category="linked")
|
||||
|
||||
with patch("tools.skills_tool.SKILLS_DIR", skills_root):
|
||||
raw = skills_list(category="linked")
|
||||
|
||||
result = json.loads(raw)
|
||||
assert result["success"] is True
|
||||
assert result["count"] == 1
|
||||
assert result["categories"] == ["linked"]
|
||||
assert result["skills"][0]["name"] == "knowledge-brain"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# skill_view
|
||||
|
|
@ -304,6 +347,70 @@ class TestSkillView:
|
|||
assert result["name"] == "my-skill"
|
||||
assert "Step 1" in result["content"]
|
||||
|
||||
def test_skill_view_applies_template_vars(self, tmp_path):
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch(
|
||||
"agent.skill_preprocessing.load_skills_config",
|
||||
return_value={"template_vars": True, "inline_shell": False},
|
||||
),
|
||||
):
|
||||
skill_dir = _make_skill(
|
||||
tmp_path,
|
||||
"templated",
|
||||
body="Run ${HERMES_SKILL_DIR}/scripts/do.sh in ${HERMES_SESSION_ID}",
|
||||
)
|
||||
raw = skill_view("templated", task_id="session-123")
|
||||
|
||||
result = json.loads(raw)
|
||||
assert result["success"] is True
|
||||
assert f"Run {skill_dir}/scripts/do.sh in session-123" in result["content"]
|
||||
assert "${HERMES_SKILL_DIR}" not in result["content"]
|
||||
|
||||
def test_skill_view_applies_inline_shell_when_enabled(self, tmp_path):
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch(
|
||||
"agent.skill_preprocessing.load_skills_config",
|
||||
return_value={
|
||||
"template_vars": True,
|
||||
"inline_shell": True,
|
||||
"inline_shell_timeout": 5,
|
||||
},
|
||||
),
|
||||
):
|
||||
_make_skill(
|
||||
tmp_path,
|
||||
"dynamic",
|
||||
body="Current date: !`printf 2026-04-24`",
|
||||
)
|
||||
raw = skill_view("dynamic")
|
||||
|
||||
result = json.loads(raw)
|
||||
assert result["success"] is True
|
||||
assert "Current date: 2026-04-24" in result["content"]
|
||||
assert "!`printf 2026-04-24`" not in result["content"]
|
||||
|
||||
def test_skill_view_leaves_inline_shell_literal_when_disabled(self, tmp_path):
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch(
|
||||
"agent.skill_preprocessing.load_skills_config",
|
||||
return_value={"template_vars": True, "inline_shell": False},
|
||||
),
|
||||
):
|
||||
_make_skill(
|
||||
tmp_path,
|
||||
"static",
|
||||
body="Current date: !`printf SHOULD_NOT_RUN`",
|
||||
)
|
||||
raw = skill_view("static")
|
||||
|
||||
result = json.loads(raw)
|
||||
assert result["success"] is True
|
||||
assert "Current date: !`printf SHOULD_NOT_RUN`" in result["content"]
|
||||
assert "Current date: SHOULD_NOT_RUN" not in result["content"]
|
||||
|
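The three preprocessing tests above describe two config-gated passes over the skill body: ${HERMES_*} template-variable substitution, and optional !`cmd` inline-shell expansion with a timeout. A hedged sketch of those passes; the variable names come from the tests, while the real logic lives in agent/skill_preprocessing.py:

import re
import subprocess

def preprocess_skill_body_sketch(body: str, skill_dir: str, session_id: str,
                                 inline_shell: bool = False, timeout: int = 5) -> str:
    """Sketch: substitute template vars, then optionally expand !`...` shell snippets."""
    body = body.replace("${HERMES_SKILL_DIR}", skill_dir)
    body = body.replace("${HERMES_SESSION_ID}", session_id)
    if inline_shell:
        def _run(match):
            result = subprocess.run(match.group(1), shell=True, capture_output=True,
                                    text=True, timeout=timeout)
            return result.stdout.strip()
        body = re.sub(r"!`([^`]+)`", _run, body)
    return body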
||||
def test_view_nonexistent_skill(self, tmp_path):
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_skill(tmp_path, "other-skill")
|
||||
|
|
@ -389,6 +496,35 @@ class TestSkillView:
|
|||
result = json.loads(raw)
|
||||
assert result["success"] is True
|
||||
|
||||
def test_view_finds_skill_in_symlinked_category_dir(self, tmp_path):
|
||||
external_root = tmp_path / "repo"
|
||||
skills_root = tmp_path / "skills"
|
||||
skills_root.mkdir()
|
||||
|
||||
external_category = _symlink_category(skills_root, external_root, "linked")
|
||||
_make_skill(external_category.parent, "knowledge-brain", category="linked")
|
||||
|
||||
with patch("tools.skills_tool.SKILLS_DIR", skills_root):
|
||||
raw = skill_view("knowledge-brain")
|
||||
|
||||
result = json.loads(raw)
|
||||
assert result["success"] is True
|
||||
assert result["name"] == "knowledge-brain"
|
||||
|
||||
def test_not_found_hint_uses_same_order_as_skills_list(self, tmp_path):
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
_make_skill(tmp_path, "zeta", category="z-cat")
|
||||
_make_skill(tmp_path, "alpha", category="a-cat")
|
||||
_make_skill(tmp_path, "beta", category="a-cat")
|
||||
|
||||
list_result = json.loads(skills_list())
|
||||
view_result = json.loads(skill_view("missing-skill"))
|
||||
|
||||
assert view_result["success"] is False
|
||||
assert view_result["available_skills"] == [
|
||||
skill["name"] for skill in list_result["skills"]
|
||||
]
|
||||
|
||||
|
||||
class TestSkillViewSecureSetupOnLoad:
|
||||
def test_requests_missing_required_env_and_continues(self, tmp_path, monkeypatch):
|
||||
|
|
|
|||
299
tests/tools/test_spotify_client.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from plugins.spotify import client as spotify_mod
|
||||
from plugins.spotify import tools as spotify_tool
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status_code: int, payload: dict | None = None, *, text: str = "", headers: dict | None = None):
|
||||
self.status_code = status_code
|
||||
self._payload = payload
|
||||
self.text = text or (json.dumps(payload) if payload is not None else "")
|
||||
self.headers = headers or {"content-type": "application/json"}
|
||||
self.content = self.text.encode("utf-8") if self.text else b""
|
||||
|
||||
def json(self):
|
||||
if self._payload is None:
|
||||
raise ValueError("no json")
|
||||
return self._payload
|
||||
|
||||
|
||||
class _StubSpotifyClient:
|
||||
def __init__(self, payload):
|
||||
self.payload = payload
|
||||
|
||||
def get_currently_playing(self, *, market=None):
|
||||
return self.payload
|
||||
|
||||
|
||||
def test_spotify_client_retries_once_after_401(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
calls: list[str] = []
|
||||
tokens = iter([
|
||||
{
|
||||
"access_token": "token-1",
|
||||
"base_url": "https://api.spotify.com/v1",
|
||||
},
|
||||
{
|
||||
"access_token": "token-2",
|
||||
"base_url": "https://api.spotify.com/v1",
|
||||
},
|
||||
])
|
||||
|
||||
monkeypatch.setattr(
|
||||
spotify_mod,
|
||||
"resolve_spotify_runtime_credentials",
|
||||
lambda **kwargs: next(tokens),
|
||||
)
|
||||
|
||||
def fake_request(method, url, headers=None, params=None, json=None, timeout=None):
|
||||
calls.append(headers["Authorization"])
|
||||
if len(calls) == 1:
|
||||
return _FakeResponse(401, {"error": {"message": "expired token"}})
|
||||
return _FakeResponse(200, {"devices": [{"id": "dev-1"}]})
|
||||
|
||||
monkeypatch.setattr(spotify_mod.httpx, "request", fake_request)
|
||||
|
||||
client = spotify_mod.SpotifyClient()
|
||||
payload = client.get_devices()
|
||||
|
||||
assert payload["devices"][0]["id"] == "dev-1"
|
||||
assert calls == ["Bearer token-1", "Bearer token-2"]
|
||||
|
||||
|
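The retry test encodes a single-shot refresh: on a 401, re-resolve credentials once and replay the request with the new bearer token; a second 401 is returned to the caller as-is. A hedged sketch of that control flow (function and parameter names are assumptions, not taken from plugins/spotify/client.py):

import httpx

def request_with_refresh_sketch(resolve_credentials, method: str, url: str, **kwargs):
    """Sketch: retry exactly once with fresh credentials when the API returns 401."""
    base_headers = kwargs.pop("headers", None) or {}
    response = None
    for attempt in range(2):
        creds = resolve_credentials()                 # re-resolved on retry, yielding a new token
        headers = {**base_headers, "Authorization": f"Bearer {creds['access_token']}"}
        response = httpx.request(method, url, headers=headers, **kwargs)
        if response.status_code != 401:
            break
    return response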
||||
def test_normalize_spotify_uri_accepts_urls() -> None:
|
||||
uri = spotify_mod.normalize_spotify_uri(
|
||||
"https://open.spotify.com/track/7ouMYWpwJ422jRcDASZB7P",
|
||||
"track",
|
||||
)
|
||||
assert uri == "spotify:track:7ouMYWpwJ422jRcDASZB7P"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "path", "payload", "expected"),
|
||||
[
|
||||
(
|
||||
403,
|
||||
"/me/player/play",
|
||||
{"error": {"message": "Premium required"}},
|
||||
"Spotify rejected this playback request. Playback control usually requires a Spotify Premium account and an active Spotify Connect device.",
|
||||
),
|
||||
(
|
||||
404,
|
||||
"/me/player",
|
||||
{"error": {"message": "Device not found"}},
|
||||
"Spotify could not find an active playback device or player session for this request.",
|
||||
),
|
||||
(
|
||||
429,
|
||||
"/search",
|
||||
{"error": {"message": "rate limit"}},
|
||||
"Spotify rate limit exceeded. Retry after 7 seconds.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_spotify_client_formats_friendly_api_errors(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
status_code: int,
|
||||
path: str,
|
||||
payload: dict,
|
||||
expected: str,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
spotify_mod,
|
||||
"resolve_spotify_runtime_credentials",
|
||||
lambda **kwargs: {
|
||||
"access_token": "token-1",
|
||||
"base_url": "https://api.spotify.com/v1",
|
||||
},
|
||||
)
|
||||
|
||||
def fake_request(method, url, headers=None, params=None, json=None, timeout=None):
|
||||
return _FakeResponse(status_code, payload, headers={"content-type": "application/json", "Retry-After": "7"})
|
||||
|
||||
monkeypatch.setattr(spotify_mod.httpx, "request", fake_request)
|
||||
|
||||
client = spotify_mod.SpotifyClient()
|
||||
with pytest.raises(spotify_mod.SpotifyAPIError) as exc:
|
||||
client.request("GET", path)
|
||||
|
||||
assert str(exc.value) == expected
|
||||
|
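The parametrized cases above pin the friendly-message mapping: 403 on a playback endpoint hints at Premium and an active device, 404 on a player endpoint means no active device or session, and 429 surfaces the Retry-After value. A sketch of that mapping, with the message strings copied from the expectations above and a hypothetical helper name:

def format_spotify_api_error_sketch(status_code: int, path: str, retry_after: str | None) -> str:
    """Sketch: translate common Spotify API failures into actionable messages."""
    if status_code == 403 and path.startswith("/me/player"):
        return ("Spotify rejected this playback request. Playback control usually requires "
                "a Spotify Premium account and an active Spotify Connect device.")
    if status_code == 404 and path.startswith("/me/player"):
        return "Spotify could not find an active playback device or player session for this request."
    if status_code == 429:
        return f"Spotify rate limit exceeded. Retry after {retry_after or 'a few'} seconds."
    return f"Spotify API error {status_code} on {path}."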
||||
|
||||
def test_get_currently_playing_returns_explanatory_empty_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
spotify_mod,
|
||||
"resolve_spotify_runtime_credentials",
|
||||
lambda **kwargs: {
|
||||
"access_token": "token-1",
|
||||
"base_url": "https://api.spotify.com/v1",
|
||||
},
|
||||
)
|
||||
|
||||
def fake_request(method, url, headers=None, params=None, json=None, timeout=None):
|
||||
return _FakeResponse(204, None, text="", headers={"content-type": "application/json"})
|
||||
|
||||
monkeypatch.setattr(spotify_mod.httpx, "request", fake_request)
|
||||
|
||||
client = spotify_mod.SpotifyClient()
|
||||
payload = client.get_currently_playing()
|
||||
|
||||
assert payload == {
|
||||
"status_code": 204,
|
||||
"empty": True,
|
||||
"message": "Spotify is not currently playing anything. Start playback in Spotify and try again.",
|
||||
}
|
||||
|
||||
|
||||
def test_spotify_playback_get_currently_playing_returns_explanatory_empty_result(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
spotify_tool,
|
||||
"_spotify_client",
|
||||
lambda: _StubSpotifyClient({
|
||||
"status_code": 204,
|
||||
"empty": True,
|
||||
"message": "Spotify is not currently playing anything. Start playback in Spotify and try again.",
|
||||
}),
|
||||
)
|
||||
|
||||
payload = json.loads(spotify_tool._handle_spotify_playback({"action": "get_currently_playing"}))
|
||||
|
||||
assert payload == {
|
||||
"success": True,
|
||||
"action": "get_currently_playing",
|
||||
"is_playing": False,
|
||||
"status_code": 204,
|
||||
"message": "Spotify is not currently playing anything. Start playback in Spotify and try again.",
|
||||
}
|
||||
|
||||
|
||||
def test_library_contains_uses_generic_library_endpoint(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
seen: list[tuple[str, str, dict | None]] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
spotify_mod,
|
||||
"resolve_spotify_runtime_credentials",
|
||||
lambda **kwargs: {
|
||||
"access_token": "token-1",
|
||||
"base_url": "https://api.spotify.com/v1",
|
||||
},
|
||||
)
|
||||
|
||||
def fake_request(method, url, headers=None, params=None, json=None, timeout=None):
|
||||
seen.append((method, url, params))
|
||||
return _FakeResponse(200, [True])
|
||||
|
||||
monkeypatch.setattr(spotify_mod.httpx, "request", fake_request)
|
||||
|
||||
client = spotify_mod.SpotifyClient()
|
||||
payload = client.library_contains(uris=["spotify:album:abc", "spotify:track:def"])
|
||||
|
||||
assert payload == [True]
|
||||
assert seen == [
|
||||
(
|
||||
"GET",
|
||||
"https://api.spotify.com/v1/me/library/contains",
|
||||
{"uris": "spotify:album:abc,spotify:track:def"},
|
||||
)
|
||||
]
|
|
@pytest.mark.parametrize(
("method_name", "item_key", "item_value", "expected_uris"),
[
("remove_saved_tracks", "track_ids", ["track-a", "track-b"], ["spotify:track:track-a", "spotify:track:track-b"]),
("remove_saved_albums", "album_ids", ["album-a"], ["spotify:album:album-a"]),
],
)
def test_library_remove_uses_generic_library_endpoint(
monkeypatch: pytest.MonkeyPatch,
method_name: str,
item_key: str,
item_value: list[str],
expected_uris: list[str],
) -> None:
seen: list[tuple[str, str, dict | None]] = []
|
monkeypatch.setattr(
spotify_mod,
"resolve_spotify_runtime_credentials",
lambda **kwargs: {
"access_token": "token-1",
"base_url": "https://api.spotify.com/v1",
},
)
|
def fake_request(method, url, headers=None, params=None, json=None, timeout=None):
seen.append((method, url, params))
return _FakeResponse(200, {})
|
monkeypatch.setattr(spotify_mod.httpx, "request", fake_request)
|
client = spotify_mod.SpotifyClient()
getattr(client, method_name)(**{item_key: item_value})
|
assert seen == [
(
"DELETE",
"https://api.spotify.com/v1/me/library",
{"uris": ",".join(expected_uris)},
)
]
|
|
def test_spotify_library_tracks_list_routes_to_saved_tracks(monkeypatch: pytest.MonkeyPatch) -> None:
seen: list[str] = []
|
class _LibStub:
def get_saved_tracks(self, **kw):
seen.append("tracks")
return {"items": [], "total": 0}
|
def get_saved_albums(self, **kw):
seen.append("albums")
return {"items": [], "total": 0}
|
monkeypatch.setattr(spotify_tool, "_spotify_client", lambda: _LibStub())
json.loads(spotify_tool._handle_spotify_library({"kind": "tracks", "action": "list"}))
assert seen == ["tracks"]
|
|
def test_spotify_library_albums_list_routes_to_saved_albums(monkeypatch: pytest.MonkeyPatch) -> None:
seen: list[str] = []
|
class _LibStub:
def get_saved_tracks(self, **kw):
seen.append("tracks")
return {"items": [], "total": 0}
|
def get_saved_albums(self, **kw):
seen.append("albums")
return {"items": [], "total": 0}
|
monkeypatch.setattr(spotify_tool, "_spotify_client", lambda: _LibStub())
json.loads(spotify_tool._handle_spotify_library({"kind": "albums", "action": "list"}))
assert seen == ["albums"]
|
|
def test_spotify_library_rejects_missing_kind() -> None:
payload = json.loads(spotify_tool._handle_spotify_library({"action": "list"}))
assert "kind" in (payload.get("error") or "").lower()
|
|
def test_spotify_playback_recently_played_action(monkeypatch: pytest.MonkeyPatch) -> None:
"""recently_played is now an action on spotify_playback (folded from spotify_activity)."""
seen: list[dict] = []
|
class _RecentStub:
def get_recently_played(self, **kw):
seen.append(kw)
return {"items": [{"track": {"name": "x"}}]}
|
monkeypatch.setattr(spotify_tool, "_spotify_client", lambda: _RecentStub())
payload = json.loads(spotify_tool._handle_spotify_playback({"action": "recently_played", "limit": 5}))
assert seen and seen[0]["limit"] == 5
assert isinstance(payload, dict)
|
152
tests/tools/test_tool_output_limits.py
Normal file
|
@ -0,0 +1,152 @@
|
"""Tests for tools.tool_output_limits.
|
Covers:
1. Default values when no config is provided.
2. Config override picks up user-supplied max_bytes / max_lines /
max_line_length.
3. Malformed values (None, negative, wrong type) fall back to defaults
rather than raising.
4. Integration: the helpers return what the terminal_tool and
file_operations call paths will actually consume.
|
Port-tracking: anomalyco/opencode PR #23770
(feat(truncate): allow configuring tool output truncation limits).
"""
|
from __future__ import annotations
|
from unittest.mock import patch
|
import pytest
|
from tools import tool_output_limits as tol
|
|
class TestDefaults:
def test_defaults_match_previous_hardcoded_values(self):
assert tol.DEFAULT_MAX_BYTES == 50_000
assert tol.DEFAULT_MAX_LINES == 2000
assert tol.DEFAULT_MAX_LINE_LENGTH == 2000
|
def test_get_limits_returns_defaults_when_config_missing(self):
with patch("hermes_cli.config.load_config", return_value={}):
limits = tol.get_tool_output_limits()
assert limits == {
"max_bytes": tol.DEFAULT_MAX_BYTES,
"max_lines": tol.DEFAULT_MAX_LINES,
"max_line_length": tol.DEFAULT_MAX_LINE_LENGTH,
}
|
def test_get_limits_returns_defaults_when_config_not_a_dict(self):
# load_config should always return a dict but be defensive anyway.
with patch("hermes_cli.config.load_config", return_value="not a dict"):
limits = tol.get_tool_output_limits()
assert limits["max_bytes"] == tol.DEFAULT_MAX_BYTES
|
def test_get_limits_returns_defaults_when_load_config_raises(self):
def _boom():
raise RuntimeError("boom")
|
with patch("hermes_cli.config.load_config", side_effect=_boom):
limits = tol.get_tool_output_limits()
assert limits["max_lines"] == tol.DEFAULT_MAX_LINES
|
|
class TestOverrides:
def test_user_config_overrides_all_three(self):
cfg = {
"tool_output": {
"max_bytes": 100_000,
"max_lines": 5000,
"max_line_length": 4096,
}
}
with patch("hermes_cli.config.load_config", return_value=cfg):
limits = tol.get_tool_output_limits()
assert limits == {
"max_bytes": 100_000,
"max_lines": 5000,
"max_line_length": 4096,
}
|
def test_partial_override_preserves_other_defaults(self):
cfg = {"tool_output": {"max_bytes": 200_000}}
with patch("hermes_cli.config.load_config", return_value=cfg):
limits = tol.get_tool_output_limits()
assert limits["max_bytes"] == 200_000
assert limits["max_lines"] == tol.DEFAULT_MAX_LINES
assert limits["max_line_length"] == tol.DEFAULT_MAX_LINE_LENGTH
|
def test_section_not_a_dict_falls_back(self):
cfg = {"tool_output": "nonsense"}
with patch("hermes_cli.config.load_config", return_value=cfg):
limits = tol.get_tool_output_limits()
assert limits["max_bytes"] == tol.DEFAULT_MAX_BYTES
|
|
class TestCoercion:
@pytest.mark.parametrize("bad", [None, "not a number", -1, 0, [], {}])
def test_invalid_values_fall_back_to_defaults(self, bad):
cfg = {"tool_output": {"max_bytes": bad, "max_lines": bad, "max_line_length": bad}}
with patch("hermes_cli.config.load_config", return_value=cfg):
limits = tol.get_tool_output_limits()
assert limits["max_bytes"] == tol.DEFAULT_MAX_BYTES
assert limits["max_lines"] == tol.DEFAULT_MAX_LINES
assert limits["max_line_length"] == tol.DEFAULT_MAX_LINE_LENGTH
|
def test_string_integer_is_coerced(self):
cfg = {"tool_output": {"max_bytes": "75000"}}
with patch("hermes_cli.config.load_config", return_value=cfg):
limits = tol.get_tool_output_limits()
assert limits["max_bytes"] == 75_000
|
|
class TestShortcuts:
def test_individual_accessors_delegate_to_get_tool_output_limits(self):
cfg = {
"tool_output": {
"max_bytes": 111,
"max_lines": 222,
"max_line_length": 333,
}
}
with patch("hermes_cli.config.load_config", return_value=cfg):
assert tol.get_max_bytes() == 111
assert tol.get_max_lines() == 222
assert tol.get_max_line_length() == 333
|
|
class TestDefaultConfigHasSection:
"""The DEFAULT_CONFIG in hermes_cli.config must expose tool_output so
that ``hermes setup`` and default installs stay in sync with the
helpers here."""
|
def test_default_config_contains_tool_output_section(self):
from hermes_cli.config import DEFAULT_CONFIG
assert "tool_output" in DEFAULT_CONFIG
section = DEFAULT_CONFIG["tool_output"]
assert isinstance(section, dict)
assert section["max_bytes"] == tol.DEFAULT_MAX_BYTES
assert section["max_lines"] == tol.DEFAULT_MAX_LINES
assert section["max_line_length"] == tol.DEFAULT_MAX_LINE_LENGTH
|
|
class TestIntegrationReadPagination:
"""normalize_read_pagination uses get_max_lines() — verify the plumbing."""
|
def test_pagination_limit_clamped_by_config_value(self):
from tools.file_operations import normalize_read_pagination
cfg = {"tool_output": {"max_lines": 50}}
with patch("hermes_cli.config.load_config", return_value=cfg):
offset, limit = normalize_read_pagination(offset=1, limit=1000)
# limit should have been clamped to 50 (the configured max_lines)
assert limit == 50
assert offset == 1
|
def test_pagination_default_when_config_missing(self):
from tools.file_operations import normalize_read_pagination
with patch("hermes_cli.config.load_config", return_value={}):
offset, limit = normalize_read_pagination(offset=10, limit=100000)
# Clamped to default MAX_LINES (2000).
assert limit == tol.DEFAULT_MAX_LINES
assert offset == 10
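Taken together, these tests describe the whole contract of the limits helper: read the optional tool_output config section, coerce each value to a positive integer, and fall back to the module defaults whenever the config is missing, malformed, or fails to load. Below is a minimal sketch of logic that would satisfy them, with load_config passed in as a parameter purely to keep the example self-contained (the real helper takes no arguments and reads hermes_cli.config internally); it is an assumption, not the project's implementation.

# Hypothetical sketch; defaults mirror the constants asserted in TestDefaults.
DEFAULT_MAX_BYTES = 50_000
DEFAULT_MAX_LINES = 2000
DEFAULT_MAX_LINE_LENGTH = 2000


def _coerce_positive_int(value, default: int) -> int:
    if isinstance(value, bool):  # bool is an int subclass; treat it as invalid
        return default
    try:
        number = int(value)
    except (TypeError, ValueError):
        return default
    return number if number > 0 else default


def get_tool_output_limits(load_config) -> dict:
    try:
        config = load_config()
    except Exception:
        config = {}
    section = config.get("tool_output") if isinstance(config, dict) else None
    if not isinstance(section, dict):
        section = {}
    return {
        "max_bytes": _coerce_positive_int(section.get("max_bytes"), DEFAULT_MAX_BYTES),
        "max_lines": _coerce_positive_int(section.get("max_lines"), DEFAULT_MAX_LINES),
        "max_line_length": _coerce_positive_int(section.get("max_line_length"), DEFAULT_MAX_LINE_LENGTH),
    }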
|
|
@ -505,6 +505,101 @@ class TestTranscribeLocalExtended:
|
assert result["success"] is True
assert result["transcript"] == "Hello world"
|
def test_load_time_cuda_lib_failure_falls_back_to_cpu(self, tmp_path):
"""Missing libcublas at load time → reload on CPU, succeed."""
audio = tmp_path / "test.ogg"
audio.write_bytes(b"fake")
|
seg = MagicMock()
seg.text = "hi"
info = MagicMock()
info.language = "en"
info.duration = 1.0
|
cpu_model = MagicMock()
cpu_model.transcribe.return_value = ([seg], info)
|
call_args = []
|
def fake_whisper(model_name, device, compute_type):
call_args.append((device, compute_type))
if device == "auto":
raise RuntimeError("Library libcublas.so.12 is not found or cannot be loaded")
return cpu_model
|
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
patch("faster_whisper.WhisperModel", side_effect=fake_whisper), \
patch("tools.transcription_tools._local_model", None), \
patch("tools.transcription_tools._local_model_name", None):
from tools.transcription_tools import _transcribe_local
result = _transcribe_local(str(audio), "base")
|
assert result["success"] is True
assert result["transcript"] == "hi"
assert call_args == [("auto", "auto"), ("cpu", "int8")]
|
def test_runtime_cuda_lib_failure_evicts_cache_and_retries_on_cpu(self, tmp_path):
"""libcublas dlopen fails at transcribe() → evict cache, reload CPU, retry."""
audio = tmp_path / "test.ogg"
audio.write_bytes(b"fake")
|
seg = MagicMock()
seg.text = "recovered"
info = MagicMock()
info.language = "en"
info.duration = 1.0
|
# First model loads fine (auto), but transcribe() blows up on dlopen
gpu_model = MagicMock()
gpu_model.transcribe.side_effect = RuntimeError(
"Library libcublas.so.12 is not found or cannot be loaded"
)
# Second model (forced CPU) works
cpu_model = MagicMock()
cpu_model.transcribe.return_value = ([seg], info)
|
models = [gpu_model, cpu_model]
call_args = []
|
def fake_whisper(model_name, device, compute_type):
call_args.append((device, compute_type))
return models.pop(0)
|
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
patch("faster_whisper.WhisperModel", side_effect=fake_whisper), \
patch("tools.transcription_tools._local_model", None), \
patch("tools.transcription_tools._local_model_name", None):
from tools.transcription_tools import _transcribe_local
result = _transcribe_local(str(audio), "base")
|
assert result["success"] is True
assert result["transcript"] == "recovered"
# First load is auto, retry forces CPU.
assert call_args == [("auto", "auto"), ("cpu", "int8")]
# Cached-bad-model eviction: the broken GPU model was called once,
# then discarded; the CPU model served the retry.
assert gpu_model.transcribe.call_count == 1
assert cpu_model.transcribe.call_count == 1
|
def test_cuda_out_of_memory_does_not_trigger_cpu_fallback(self, tmp_path):
"""'CUDA out of memory' is a real error, not a missing lib — surface it."""
audio = tmp_path / "test.ogg"
audio.write_bytes(b"fake")
|
mock_whisper_cls = MagicMock(side_effect=RuntimeError("CUDA out of memory"))
|
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
patch("faster_whisper.WhisperModel", mock_whisper_cls), \
patch("tools.transcription_tools._local_model", None), \
patch("tools.transcription_tools._local_model_name", None):
from tools.transcription_tools import _transcribe_local
result = _transcribe_local(str(audio), "base")
|
# Single call — no CPU retry, because OOM isn't a missing-lib symptom.
assert mock_whisper_cls.call_count == 1
assert result["success"] is False
assert "CUDA out of memory" in result["error"]
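These three tests spell out the fallback policy: a missing CUDA runtime library (a dlopen-style failure at model load or at transcribe time) earns exactly one retry on CPU with int8, while a genuine CUDA error such as out-of-memory is surfaced as a failure with no retry. A rough sketch of that policy follows, assuming the faster_whisper-style WhisperModel interface the tests stub out; the helper names here are invented for illustration.

# Hypothetical sketch of the retry policy the tests above pin down.
def _is_missing_cuda_lib_error(exc: Exception) -> bool:
    text = str(exc).lower()
    return "libcublas" in text or "libcudnn" in text or "cannot be loaded" in text


def transcribe_with_cpu_fallback(whisper_model_cls, model_name: str, audio_path: str) -> dict:
    try:
        model = whisper_model_cls(model_name, device="auto", compute_type="auto")
        segments, info = model.transcribe(audio_path)
    except RuntimeError as exc:
        if not _is_missing_cuda_lib_error(exc):
            # Real CUDA errors (e.g. out of memory) are reported, not retried.
            return {"success": False, "error": str(exc)}
        # Drop the broken model and retry once on CPU with int8.
        model = whisper_model_cls(model_name, device="cpu", compute_type="int8")
        segments, info = model.transcribe(audio_path)
    transcript = "".join(seg.text for seg in segments)
    return {"success": True, "transcript": transcript, "language": info.language}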
|
|
# ============================================================================
# Model auto-correction
|
|
@ -33,7 +33,12 @@ class TestWriteDenyExactPaths:
|
assert _is_write_denied(path) is True
|
def test_hermes_env(self):
path = os.path.join(str(Path.home()), ".hermes", ".env")
# ``.env`` under the active HERMES_HOME (profile-aware, not just
# ``~/.hermes``) must be write-denied. The hermetic test conftest
# points HERMES_HOME at a tempdir — resolve via get_hermes_home()
# to match the denylist.
from hermes_constants import get_hermes_home
path = str(get_hermes_home() / ".env")
assert _is_write_denied(path) is True
|
def test_shell_profiles(self):
|
|
@ -110,8 +110,8 @@ class TestAgentCloseMethod:
|
agent.client = None
|
with patch("tools.process_registry.process_registry") as mock_registry, \
patch("tools.terminal_tool.cleanup_vm") as mock_cleanup_vm, \
patch("tools.browser_tool.cleanup_browser") as mock_cleanup_browser:
patch("run_agent.cleanup_vm") as mock_cleanup_vm, \
patch("run_agent.cleanup_browser") as mock_cleanup_browser:
agent.close()
|
mock_registry.kill_all.assert_called_once_with(
|
|
@ -172,9 +172,9 @@ class TestAgentCloseMethod:
|
with patch(
"tools.process_registry.process_registry"
) as mock_reg, patch(
"tools.terminal_tool.cleanup_vm"
"run_agent.cleanup_vm"
) as mock_vm, patch(
"tools.browser_tool.cleanup_browser"
"run_agent.cleanup_browser"
) as mock_browser:
mock_reg.kill_all.side_effect = RuntimeError("boom")