mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into bb/gui
This commit is contained in:
commit
085c33ed70
105 changed files with 5022 additions and 1714 deletions
51
.github/workflows/tests.yml
vendored
51
.github/workflows/tests.yml
vendored
|
|
@ -23,13 +23,24 @@ concurrency:
|
|||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 60
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
||||
- name: Install ripgrep (prebuilt binary)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
RG_VERSION=15.1.0
|
||||
RG_SHA256=1c9297be4a084eea7ecaedf93eb03d058d6faae29bbc57ecdaf5063921491599
|
||||
RG_TARBALL=ripgrep-${RG_VERSION}-x86_64-unknown-linux-musl.tar.gz
|
||||
curl -sSfL -o "$RG_TARBALL" \
|
||||
"https://github.com/BurntSushi/ripgrep/releases/download/${RG_VERSION}/${RG_TARBALL}"
|
||||
echo "${RG_SHA256} ${RG_TARBALL}" | sha256sum -c -
|
||||
tar -xzf "$RG_TARBALL"
|
||||
sudo mv "ripgrep-${RG_VERSION}-x86_64-unknown-linux-musl/rg" /usr/local/bin/rg
|
||||
rm -rf "$RG_TARBALL" "ripgrep-${RG_VERSION}-x86_64-unknown-linux-musl"
|
||||
rg --version
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||
|
|
@ -44,9 +55,26 @@ jobs:
|
|||
uv pip install -e ".[all,dev]"
|
||||
|
||||
- name: Run tests
|
||||
# Per-file isolation via scripts/run_tests_parallel.py: discovers
|
||||
# every test_*.py file under tests/ (excluding integration/ + e2e/),
|
||||
# then runs `python -m pytest <file>` in a freshly-spawned subprocess
|
||||
# with bounded parallelism. No xdist, no shared workers, no
|
||||
# module-level state leakage between files.
|
||||
#
|
||||
# Why per-file (not per-test): per-test spawn cost (~250ms × 17k
|
||||
# tests = 70min CPU minimum) blew the wall-clock budget. Per-file
|
||||
# spawn (~250ms × ~850 files = ~3.5min) fits while still giving
|
||||
# every file a fresh interpreter — the only isolation boundary
|
||||
# that matters in practice (cross-file leakage was the original
|
||||
# flake source; intra-file is the test author's responsibility).
|
||||
#
|
||||
# Why drop xdist entirely: xdist's persistent workers accumulate
|
||||
# state across files, which is exactly the leakage we wanted to
|
||||
# fix. ThreadPoolExecutor + subprocess.run is ~60 lines and does
|
||||
# the job with cleaner semantics.
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q --ignore=tests/integration --ignore=tests/e2e --tb=short -n auto --timeout=30 --timeout-method=signal
|
||||
python scripts/run_tests_parallel.py
|
||||
env:
|
||||
# Ensure tests don't accidentally call real APIs
|
||||
OPENROUTER_API_KEY: ""
|
||||
|
|
@ -60,8 +88,19 @@ jobs:
|
|||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
||||
- name: Install ripgrep (prebuilt binary)
|
||||
run: |
|
||||
set -euo pipefail
|
||||
RG_VERSION=15.1.0
|
||||
RG_SHA256=1c9297be4a084eea7ecaedf93eb03d058d6faae29bbc57ecdaf5063921491599
|
||||
RG_TARBALL=ripgrep-${RG_VERSION}-x86_64-unknown-linux-musl.tar.gz
|
||||
curl -sSfL -o "$RG_TARBALL" \
|
||||
"https://github.com/BurntSushi/ripgrep/releases/download/${RG_VERSION}/${RG_TARBALL}"
|
||||
echo "${RG_SHA256} ${RG_TARBALL}" | sha256sum -c -
|
||||
tar -xzf "$RG_TARBALL"
|
||||
sudo mv "ripgrep-${RG_VERSION}-x86_64-unknown-linux-musl/rg" /usr/local/bin/rg
|
||||
rm -rf "$RG_TARBALL" "ripgrep-${RG_VERSION}-x86_64-unknown-linux-musl"
|
||||
rg --version
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -18,6 +18,7 @@ __pycache__/web_tools.cpython-310.pyc
|
|||
logs/
|
||||
data/
|
||||
.pytest_cache/
|
||||
.pytest-cache/
|
||||
tmp/
|
||||
temp_vision_images/
|
||||
hermes-*/*
|
||||
|
|
|
|||
44
AGENTS.md
44
AGENTS.md
|
|
@ -1038,17 +1038,39 @@ def profile_env(tmp_path, monkeypatch):
|
|||
|
||||
**ALWAYS use `scripts/run_tests.sh`** — do not call `pytest` directly. The script enforces
|
||||
hermetic environment parity with CI (unset credential vars, TZ=UTC, LANG=C.UTF-8,
|
||||
4 xdist workers matching GHA ubuntu-latest). Direct `pytest` on a 16+ core
|
||||
developer machine with API keys set diverges from CI in ways that have caused
|
||||
multiple "works locally, fails in CI" incidents (and the reverse).
|
||||
`-n auto` xdist workers, in-tree subprocess-isolation plugin). Direct `pytest`
|
||||
on a 16+ core developer machine with API keys set diverges from CI in ways
|
||||
that have caused multiple "works locally, fails in CI" incidents (and the reverse).
|
||||
|
||||
```bash
|
||||
scripts/run_tests.sh # full suite, CI-parity
|
||||
scripts/run_tests.sh tests/gateway/ # one directory
|
||||
scripts/run_tests.sh tests/agent/test_foo.py::test_x # one test
|
||||
scripts/run_tests.sh -v --tb=long # pass-through pytest flags
|
||||
scripts/run_tests.sh --no-isolate tests/foo/ # disable subprocess isolation (faster, for debugging)
|
||||
```
|
||||
|
||||
### Subprocess-per-test isolation
|
||||
|
||||
Every test runs in a freshly-spawned Python subprocess via the in-tree plugin
|
||||
at `tests/_isolate_plugin.py`. This means module-level dicts/sets and
|
||||
ContextVars from one test cannot leak into the next — the historic
|
||||
`_reset_module_state` autouse fixture is gone.
|
||||
|
||||
Implementation notes:
|
||||
|
||||
- The plugin uses `multiprocessing.get_context("spawn")`, which works on
|
||||
Linux, macOS, and Windows alike (POSIX `fork` is not used).
|
||||
- Per-test overhead is ~0.5–1.0s (Python startup + pytest collection). xdist
|
||||
parallelism amortizes this across cores; on a 20-core box the full suite
|
||||
finishes in roughly the same wall time as before, but flake-free.
|
||||
- `isolate_timeout` (configured in `pyproject.toml`) caps each test at 30s.
|
||||
Hangs are killed and surfaced as a failure report.
|
||||
- Pass `--no-isolate` to disable isolation — useful when debugging a single
|
||||
test interactively, or when you specifically want to verify state leakage.
|
||||
- The plugin disables itself in child processes (sentinel envvar
|
||||
`HERMES_ISOLATE_CHILD=1`), so there's no fork-bomb risk.
|
||||
|
||||
### Why the wrapper (and why the old "just call pytest" doesn't work)
|
||||
|
||||
Five real sources of local-vs-CI drift the script closes:
|
||||
|
|
@ -1059,7 +1081,7 @@ Five real sources of local-vs-CI drift the script closes:
|
|||
| HOME / `~/.hermes/` | Your real config+auth.json | Temp dir per test |
|
||||
| Timezone | Local TZ (PDT etc.) | UTC |
|
||||
| Locale | Whatever is set | C.UTF-8 |
|
||||
| xdist workers | `-n auto` = all cores (20+ on a workstation) | `-n 4` matching CI |
|
||||
| xdist workers | `-n auto` = all cores | `-n auto` (safe — subprocess isolation prevents cross-worker flakes) |
|
||||
|
||||
`tests/conftest.py` also enforces points 1-4 as an autouse fixture so ANY pytest
|
||||
invocation (including IDE integrations) gets hermetic behavior — but the wrapper
|
||||
|
|
@ -1067,15 +1089,21 @@ is belt-and-suspenders.
|
|||
|
||||
### Running without the wrapper (only if you must)
|
||||
|
||||
If you can't use the wrapper (e.g. on Windows or inside an IDE that shells
|
||||
pytest directly), at minimum activate the venv and pass `-n 4`:
|
||||
If you can't use the wrapper (e.g. inside an IDE that shells pytest directly),
|
||||
at minimum activate the venv. The isolation plugin loads automatically from
|
||||
`addopts` in `pyproject.toml`, so you get the same per-test process isolation
|
||||
either way.
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate # or: source venv/bin/activate
|
||||
python -m pytest tests/ -q -n 4
|
||||
python -m pytest tests/ -q
|
||||
```
|
||||
|
||||
Worker count above 4 will surface test-ordering flakes that CI never sees.
|
||||
If you need to bypass isolation for fast feedback while debugging:
|
||||
|
||||
```bash
|
||||
python -m pytest tests/agent/test_foo.py -q --no-isolate
|
||||
```
|
||||
|
||||
Always run the full suite before pushing changes.
|
||||
|
||||
|
|
|
|||
|
|
@ -71,6 +71,71 @@ def _ra():
|
|||
return run_agent
|
||||
|
||||
|
||||
def _normalized_custom_base_url(value: Any) -> str:
|
||||
if not isinstance(value, str):
|
||||
return ""
|
||||
return value.strip().rstrip("/")
|
||||
|
||||
|
||||
def _custom_provider_model_matches(agent_model: str, entry: Dict[str, Any]) -> bool:
|
||||
provider_model = str(entry.get("model", "") or "").strip().lower()
|
||||
if not provider_model:
|
||||
return True
|
||||
return provider_model == str(agent_model or "").strip().lower()
|
||||
|
||||
|
||||
def _custom_provider_extra_body_for_agent(
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
base_url: str,
|
||||
custom_providers: List[Dict[str, Any]],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if (provider or "").strip().lower() != "custom":
|
||||
return None
|
||||
|
||||
target_url = _normalized_custom_base_url(base_url)
|
||||
if not target_url:
|
||||
return None
|
||||
|
||||
fallback: Optional[Dict[str, Any]] = None
|
||||
for entry in custom_providers or []:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if _normalized_custom_base_url(entry.get("base_url")) != target_url:
|
||||
continue
|
||||
extra_body = entry.get("extra_body")
|
||||
if not isinstance(extra_body, dict) or not extra_body:
|
||||
continue
|
||||
provider_model = str(entry.get("model", "") or "").strip()
|
||||
if provider_model:
|
||||
if _custom_provider_model_matches(model, entry):
|
||||
return dict(extra_body)
|
||||
elif fallback is None:
|
||||
fallback = dict(extra_body)
|
||||
|
||||
return fallback
|
||||
|
||||
|
||||
def _merge_custom_provider_extra_body(agent, custom_providers: List[Dict[str, Any]]) -> None:
|
||||
extra_body = _custom_provider_extra_body_for_agent(
|
||||
provider=agent.provider,
|
||||
model=agent.model,
|
||||
base_url=agent.base_url,
|
||||
custom_providers=custom_providers,
|
||||
)
|
||||
if not extra_body:
|
||||
return
|
||||
|
||||
overrides = dict(getattr(agent, "request_overrides", {}) or {})
|
||||
merged_extra_body = dict(extra_body)
|
||||
existing_extra_body = overrides.get("extra_body")
|
||||
if isinstance(existing_extra_body, dict):
|
||||
merged_extra_body.update(existing_extra_body)
|
||||
overrides["extra_body"] = merged_extra_body
|
||||
agent.request_overrides = overrides
|
||||
|
||||
|
||||
def init_agent(
|
||||
agent,
|
||||
base_url: str = None,
|
||||
|
|
@ -1213,6 +1278,7 @@ def init_agent(
|
|||
# Store for reuse by _check_compression_model_feasibility (auxiliary
|
||||
# compression model context-length detection needs the same list).
|
||||
agent._custom_providers = _custom_providers
|
||||
_merge_custom_provider_extra_body(agent, _custom_providers)
|
||||
|
||||
# Check custom_providers per-model context_length
|
||||
if _config_context_length is None and _custom_providers:
|
||||
|
|
|
|||
|
|
@ -1869,6 +1869,77 @@ def copy_reasoning_content_for_api(agent, source_msg: dict, api_msg: dict) -> No
|
|||
|
||||
|
||||
|
||||
def _iter_pool_sockets(client: Any):
|
||||
"""Yield raw sockets reachable from an OpenAI/httpx client pool.
|
||||
|
||||
httpcore 1.x stores the concrete HTTP11/HTTP2 connection under
|
||||
``conn._connection``; older versions exposed stream attributes directly
|
||||
on the pool entry. Keep the traversal defensive because these are private
|
||||
transport internals and vary across httpx/httpcore releases.
|
||||
"""
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
seen: set[int] = set()
|
||||
for conn in list(connections):
|
||||
candidates = [conn]
|
||||
inner = getattr(conn, "_connection", None)
|
||||
if inner is not None:
|
||||
candidates.append(inner)
|
||||
for candidate in candidates:
|
||||
stream = (
|
||||
getattr(candidate, "_network_stream", None)
|
||||
or getattr(candidate, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
get_extra_info = getattr(stream, "get_extra_info", None)
|
||||
if callable(get_extra_info):
|
||||
try:
|
||||
sock = get_extra_info("socket")
|
||||
except Exception:
|
||||
sock = None
|
||||
if sock is None:
|
||||
wrapped = getattr(stream, "stream", None)
|
||||
if wrapped is not None:
|
||||
sock = getattr(wrapped, "_sock", None)
|
||||
if sock is None:
|
||||
# anyio-backed streams expose the raw socket through
|
||||
# SocketAttribute.raw_socket when available.
|
||||
wrapped = getattr(stream, "_stream", None)
|
||||
extra = getattr(wrapped, "extra", None)
|
||||
if callable(extra):
|
||||
try:
|
||||
from anyio.abc import SocketAttribute
|
||||
sock = extra(SocketAttribute.raw_socket)
|
||||
except Exception:
|
||||
sock = None
|
||||
if sock is None:
|
||||
continue
|
||||
marker = id(sock)
|
||||
if marker in seen:
|
||||
continue
|
||||
seen.add(marker)
|
||||
yield sock
|
||||
|
||||
|
||||
def cleanup_dead_connections(agent) -> bool:
|
||||
"""Detect and clean up dead TCP connections on the primary client.
|
||||
|
||||
|
|
@ -1882,36 +1953,8 @@ def cleanup_dead_connections(agent) -> bool:
|
|||
if client is None:
|
||||
return False
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return False
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return False
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return False
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
dead_count = 0
|
||||
for conn in list(connections):
|
||||
# Check for connections that are idle but have closed sockets
|
||||
stream = (
|
||||
getattr(conn, "_network_stream", None)
|
||||
or getattr(conn, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
sock = getattr(stream, "stream", None)
|
||||
if sock is not None:
|
||||
sock = getattr(sock, "_sock", None)
|
||||
if sock is None:
|
||||
continue
|
||||
for sock in _iter_pool_sockets(client):
|
||||
# Probe socket health with a non-blocking recv peek
|
||||
import socket as _socket
|
||||
try:
|
||||
|
|
@ -2087,36 +2130,7 @@ def force_close_tcp_sockets(client: Any) -> int:
|
|||
|
||||
closed = 0
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return 0
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return 0
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return 0
|
||||
# httpx uses httpcore connection pools; connections live in
|
||||
# _connections (list) or _pool (list) depending on version.
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
for conn in list(connections):
|
||||
stream = (
|
||||
getattr(conn, "_network_stream", None)
|
||||
or getattr(conn, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
sock = getattr(stream, "stream", None)
|
||||
if sock is not None:
|
||||
sock = getattr(sock, "_sock", None)
|
||||
if sock is None:
|
||||
continue
|
||||
for sock in _iter_pool_sockets(client):
|
||||
try:
|
||||
sock.shutdown(_socket.SHUT_RDWR)
|
||||
except OSError:
|
||||
|
|
@ -2154,5 +2168,6 @@ __all__ = [
|
|||
"cleanup_dead_connections",
|
||||
"extract_api_error_context",
|
||||
"apply_pending_steer_to_tool_results",
|
||||
"_iter_pool_sockets",
|
||||
"force_close_tcp_sockets",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -390,6 +390,9 @@ def _run_review_in_thread(
|
|||
# parent below so memory(action="add") writes from
|
||||
# the review still land on disk; the review just
|
||||
# has zero side effects on external providers.
|
||||
# Match parent's toolset config so ``tools[]`` is byte-identical
|
||||
# in the request body — Anthropic's cache key includes it.
|
||||
# (The runtime whitelist below still restricts dispatch.)
|
||||
review_agent = AIAgent(
|
||||
model=agent.model,
|
||||
max_iterations=16,
|
||||
|
|
@ -401,6 +404,8 @@ def _run_review_in_thread(
|
|||
api_key=_parent_runtime.get("api_key") or None,
|
||||
credential_pool=getattr(agent, "_credential_pool", None),
|
||||
parent_session_id=agent.session_id,
|
||||
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
|
||||
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
|
||||
skip_memory=True,
|
||||
)
|
||||
review_agent._memory_write_origin = "background_review"
|
||||
|
|
|
|||
|
|
@ -92,17 +92,36 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
"""
|
||||
result = {"response": None, "error": None}
|
||||
request_client_holder = {"client": None}
|
||||
request_client_lock = threading.Lock()
|
||||
|
||||
def _set_request_client(client):
|
||||
with request_client_lock:
|
||||
request_client_holder["client"] = client
|
||||
return client
|
||||
|
||||
def _take_request_client():
|
||||
with request_client_lock:
|
||||
client = request_client_holder.get("client")
|
||||
request_client_holder["client"] = None
|
||||
return client
|
||||
|
||||
def _close_request_client_once(reason: str) -> None:
|
||||
request_client = _take_request_client()
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason=reason)
|
||||
|
||||
def _call():
|
||||
try:
|
||||
if agent.api_mode == "codex_responses":
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="codex_stream_request",
|
||||
api_kwargs=api_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="codex_stream_request",
|
||||
api_kwargs=api_kwargs,
|
||||
)
|
||||
)
|
||||
result["response"] = agent._run_codex_stream(
|
||||
api_kwargs,
|
||||
client=request_client_holder["client"],
|
||||
client=request_client,
|
||||
on_first_delta=getattr(agent, "_codex_on_first_delta", None),
|
||||
)
|
||||
elif agent.api_mode == "anthropic_messages":
|
||||
|
|
@ -131,17 +150,17 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
raise
|
||||
result["response"] = normalize_converse_response(raw_response)
|
||||
else:
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="chat_completion_request",
|
||||
api_kwargs=api_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="chat_completion_request",
|
||||
api_kwargs=api_kwargs,
|
||||
)
|
||||
)
|
||||
result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs)
|
||||
result["response"] = request_client.chat.completions.create(**api_kwargs)
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="request_complete")
|
||||
_close_request_client_once("request_complete")
|
||||
|
||||
# ── Stale-call timeout (mirrors streaming stale detector) ────────
|
||||
# Non-streaming calls return nothing until the full response is
|
||||
|
|
@ -192,9 +211,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
rc = request_client_holder.get("client")
|
||||
if rc is not None:
|
||||
agent._close_request_openai_client(rc, reason="stale_call_kill")
|
||||
_close_request_client_once("stale_call_kill")
|
||||
except Exception:
|
||||
pass
|
||||
agent._touch_activity(
|
||||
|
|
@ -218,9 +235,7 @@ def interruptible_api_call(agent, api_kwargs: dict):
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="interrupt_abort")
|
||||
_close_request_client_once("interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during API call")
|
||||
|
|
@ -1257,6 +1272,24 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
|
||||
result = {"response": None, "error": None, "partial_tool_names": []}
|
||||
request_client_holder = {"client": None, "diag": None}
|
||||
request_client_lock = threading.Lock()
|
||||
|
||||
def _set_request_client(client):
|
||||
with request_client_lock:
|
||||
request_client_holder["client"] = client
|
||||
return client
|
||||
|
||||
def _take_request_client():
|
||||
with request_client_lock:
|
||||
client = request_client_holder.get("client")
|
||||
request_client_holder["client"] = None
|
||||
return client
|
||||
|
||||
def _close_request_client_once(reason: str) -> None:
|
||||
request_client = _take_request_client()
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason=reason)
|
||||
|
||||
first_delta_fired = {"done": False}
|
||||
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
|
||||
# Wall-clock timestamp of the last real streaming chunk. The outer
|
||||
|
|
@ -1313,9 +1346,11 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
pool=_conn_cap,
|
||||
),
|
||||
}
|
||||
request_client_holder["client"] = agent._create_request_openai_client(
|
||||
reason="chat_completion_stream_request",
|
||||
api_kwargs=stream_kwargs,
|
||||
request_client = _set_request_client(
|
||||
agent._create_request_openai_client(
|
||||
reason="chat_completion_stream_request",
|
||||
api_kwargs=stream_kwargs,
|
||||
)
|
||||
)
|
||||
# Reset stale-stream timer so the detector measures from this
|
||||
# attempt's start, not a previous attempt's last chunk.
|
||||
|
|
@ -1326,7 +1361,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
# ``request_client_holder["diag"]`` for closure access.
|
||||
_diag = agent._stream_diag_init()
|
||||
request_client_holder["diag"] = _diag
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
stream = request_client.chat.completions.create(**stream_kwargs)
|
||||
|
||||
# Capture rate limit headers from the initial HTTP response.
|
||||
# The OpenAI SDK Stream object exposes the underlying httpx
|
||||
|
|
@ -1765,12 +1800,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
mid_tool_call=True,
|
||||
diag=request_client_holder.get("diag"),
|
||||
)
|
||||
stale = request_client_holder.get("client")
|
||||
if stale is not None:
|
||||
agent._close_request_openai_client(
|
||||
stale, reason="stream_mid_tool_retry_cleanup"
|
||||
)
|
||||
request_client_holder["client"] = None
|
||||
_close_request_client_once("stream_mid_tool_retry_cleanup")
|
||||
try:
|
||||
agent._replace_primary_openai_client(
|
||||
reason="stream_mid_tool_retry_pool_cleanup"
|
||||
|
|
@ -1821,12 +1851,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
diag=request_client_holder.get("diag"),
|
||||
)
|
||||
# Close the stale request client before retry
|
||||
stale = request_client_holder.get("client")
|
||||
if stale is not None:
|
||||
agent._close_request_openai_client(
|
||||
stale, reason="stream_retry_cleanup"
|
||||
)
|
||||
request_client_holder["client"] = None
|
||||
_close_request_client_once("stream_retry_cleanup")
|
||||
# Also rebuild the primary client to purge
|
||||
# any dead connections from the pool.
|
||||
try:
|
||||
|
|
@ -1894,9 +1919,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
result["error"] = e
|
||||
return
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="stream_request_complete")
|
||||
_close_request_client_once("stream_request_complete")
|
||||
|
||||
# Provider-configured stale timeout takes priority over env default.
|
||||
_cfg_stale = get_provider_stale_timeout(agent.provider, agent.model)
|
||||
|
|
@ -1966,9 +1989,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
f"Reconnecting..."
|
||||
)
|
||||
try:
|
||||
rc = request_client_holder.get("client")
|
||||
if rc is not None:
|
||||
agent._close_request_openai_client(rc, reason="stale_stream_kill")
|
||||
_close_request_client_once("stale_stream_kill")
|
||||
except Exception:
|
||||
pass
|
||||
# Rebuild the primary client too — its connection pool
|
||||
|
|
@ -1990,9 +2011,7 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta=
|
|||
agent._anthropic_client.close()
|
||||
agent._rebuild_anthropic_client()
|
||||
else:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
agent._close_request_openai_client(request_client, reason="stream_interrupt_abort")
|
||||
_close_request_client_once("stream_interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during streaming API call")
|
||||
|
|
|
|||
|
|
@ -251,13 +251,16 @@ def _chat_messages_to_responses_input(
|
|||
) -> List[Dict[str, Any]]:
|
||||
"""Convert internal chat-style messages to Responses input items.
|
||||
|
||||
``is_xai_responses=True`` strips ``encrypted_content`` from replayed
|
||||
reasoning items. xAI's OAuth/SuperGrok ``/v1/responses`` surface
|
||||
rejects encrypted reasoning blobs minted by prior turns: the request
|
||||
streams an ``error`` SSE frame before ``response.created`` and the
|
||||
OpenAI SDK collapses it into a generic stream-ordering error. Native
|
||||
Codex (chatgpt.com backend-api) DOES accept replayed encrypted_content
|
||||
— keep the default off.
|
||||
``is_xai_responses`` is kept for transport signature compatibility but
|
||||
no longer suppresses encrypted reasoning replay. Earlier (PR #26644,
|
||||
May 2026) we believed xAI's OAuth/SuperGrok ``/v1/responses`` surface
|
||||
rejected replayed ``encrypted_content`` reasoning items minted by
|
||||
prior turns, and we stripped them. That decision was wrong — xAI
|
||||
explicitly relies on Hermes threading encrypted reasoning back across
|
||||
turns for cross-turn coherence (the whole point of their partnership
|
||||
integration). We now replay encrypted reasoning on every Responses
|
||||
transport (xAI, native Codex, custom relays) and let xAI tell us
|
||||
explicitly if a specific surface ever rejects a payload.
|
||||
"""
|
||||
items: List[Dict[str, Any]] = []
|
||||
seen_item_ids: set = set()
|
||||
|
|
@ -284,17 +287,12 @@ def _chat_messages_to_responses_input(
|
|||
if role == "assistant":
|
||||
# Replay encrypted reasoning items from previous turns
|
||||
# so the API can maintain coherent reasoning chains.
|
||||
#
|
||||
# xAI OAuth (SuperGrok/Premium) rejects replayed
|
||||
# ``encrypted_content`` reasoning items minted by prior
|
||||
# turns — see _chat_messages_to_responses_input docstring.
|
||||
# When ``is_xai_responses`` is set we drop the replay
|
||||
# entirely; Grok still reasons on each turn server-side,
|
||||
# we just don't try to thread the prior turn's encrypted
|
||||
# blob back in.
|
||||
# This applies to every Responses transport including
|
||||
# xAI — see _chat_messages_to_responses_input docstring
|
||||
# for the May 2026 reversal of the earlier xAI gate.
|
||||
codex_reasoning = msg.get("codex_reasoning_items")
|
||||
has_codex_reasoning = False
|
||||
if isinstance(codex_reasoning, list) and not is_xai_responses:
|
||||
if isinstance(codex_reasoning, list):
|
||||
for ri in codex_reasoning:
|
||||
if isinstance(ri, dict) and ri.get("encrypted_content"):
|
||||
item_id = ri.get("id")
|
||||
|
|
|
|||
|
|
@ -16,9 +16,19 @@ def _hermes_home_path() -> Path:
|
|||
return Path(os.path.expanduser("~/.hermes"))
|
||||
|
||||
|
||||
def _hermes_root_path() -> Path:
|
||||
"""Resolve the Hermes root dir (always the parent of any profile, never per-profile)."""
|
||||
try:
|
||||
from hermes_constants import get_default_hermes_root # local import to avoid cycles
|
||||
return get_default_hermes_root()
|
||||
except Exception:
|
||||
return Path(os.path.expanduser("~/.hermes"))
|
||||
|
||||
|
||||
def build_write_denied_paths(home: str) -> set[str]:
|
||||
"""Return exact sensitive paths that must never be written."""
|
||||
hermes_home = _hermes_home_path()
|
||||
hermes_root = _hermes_root_path()
|
||||
return {
|
||||
os.path.realpath(p)
|
||||
for p in [
|
||||
|
|
@ -26,7 +36,11 @@ def build_write_denied_paths(home: str) -> set[str]:
|
|||
os.path.join(home, ".ssh", "id_rsa"),
|
||||
os.path.join(home, ".ssh", "id_ed25519"),
|
||||
os.path.join(home, ".ssh", "config"),
|
||||
# Active profile .env (or top-level .env when not in profile mode).
|
||||
str(hermes_home / ".env"),
|
||||
# Top-level .env, even when running under a profile — overwriting it
|
||||
# leaks credentials across every profile that inherits from root (#15981).
|
||||
str(hermes_root / ".env"),
|
||||
os.path.join(home, ".bashrc"),
|
||||
os.path.join(home, ".zshrc"),
|
||||
os.path.join(home, ".profile"),
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from hermes_constants import get_hermes_home, secure_parent_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -491,10 +491,8 @@ def save_credentials(creds: GoogleCredentials) -> Path:
|
|||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Tighten parent dir to 0o700 so siblings can't traverse to the creds file.
|
||||
# On Windows this is a no-op (POSIX mode bits aren't enforced); ignore failures.
|
||||
try:
|
||||
os.chmod(path.parent, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
# secure_parent_dir refuses to chmod / or top-level dirs (#25821).
|
||||
secure_parent_dir(path)
|
||||
payload = json.dumps(creds.to_dict(), indent=2, sort_keys=True) + "\n"
|
||||
|
||||
with _credentials_lock():
|
||||
|
|
|
|||
|
|
@ -46,6 +46,84 @@ logger = logging.getLogger(__name__)
|
|||
_VALID_MODES = frozenset({"auto", "native", "text"})
|
||||
|
||||
|
||||
# Strict YAML/JSON boolean coercion for capability overrides.
|
||||
#
|
||||
# ``bool("false")`` is True in Python because non-empty strings are truthy, so
|
||||
# a user writing ``supports_vision: "false"`` (quoted — a common YAML mistake)
|
||||
# would silently enable native vision routing on a model that can't actually
|
||||
# handle it. Accept only the values YAML 1.1 / 1.2 treat as booleans, plus
|
||||
# real ``bool`` and integer 0/1. Anything else returns None so the caller
|
||||
# falls through to models.dev rather than honouring garbage.
|
||||
_TRUE_TOKENS = frozenset({"true", "yes", "on", "1"})
|
||||
_FALSE_TOKENS = frozenset({"false", "no", "off", "0"})
|
||||
|
||||
|
||||
def _coerce_capability_bool(raw: Any) -> Optional[bool]:
|
||||
"""Return True/False for recognised boolean values, None otherwise."""
|
||||
if isinstance(raw, bool):
|
||||
return raw
|
||||
if isinstance(raw, int):
|
||||
if raw in (0, 1):
|
||||
return bool(raw)
|
||||
return None
|
||||
if isinstance(raw, str):
|
||||
s = raw.strip().lower()
|
||||
if s in _TRUE_TOKENS:
|
||||
return True
|
||||
if s in _FALSE_TOKENS:
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
def _supports_vision_override(
|
||||
cfg: Optional[Dict[str, Any]],
|
||||
provider: str,
|
||||
model: str,
|
||||
) -> Optional[bool]:
|
||||
"""Resolve user-declared vision capability from config.yaml.
|
||||
|
||||
Resolution order, first hit wins:
|
||||
1. ``model.supports_vision`` (top-level shortcut for the active model)
|
||||
2. ``providers.<provider>.models.<model>.supports_vision``
|
||||
(named custom providers — ``provider`` may be the runtime-resolved
|
||||
value ``"custom"`` and/or the user-declared name under
|
||||
``model.provider``; both are tried)
|
||||
|
||||
Returns None when no override is set, so the caller falls through to
|
||||
models.dev. Returns False explicitly only when the user wrote a
|
||||
recognised boolean false token.
|
||||
"""
|
||||
if not isinstance(cfg, dict):
|
||||
return None
|
||||
|
||||
# 1. Top-level shortcut
|
||||
model_cfg_raw = cfg.get("model")
|
||||
model_cfg: Dict[str, Any] = model_cfg_raw if isinstance(model_cfg_raw, dict) else {}
|
||||
top = _coerce_capability_bool(model_cfg.get("supports_vision"))
|
||||
if top is not None:
|
||||
return top
|
||||
|
||||
# 2. Per-provider, per-model. Named custom providers (e.g. "my-vllm")
|
||||
# get rewritten to provider="custom" at runtime
|
||||
# (hermes_cli/runtime_provider.py:_resolve_named_custom_runtime), so the
|
||||
# config still holds the user-declared name under model.provider. Try
|
||||
# both as candidate provider keys.
|
||||
config_provider = str(model_cfg.get("provider") or "").strip()
|
||||
providers_raw = cfg.get("providers")
|
||||
providers_cfg: Dict[str, Any] = providers_raw if isinstance(providers_raw, dict) else {}
|
||||
for p in dict.fromkeys(filter(None, (provider, config_provider))):
|
||||
entry_raw = providers_cfg.get(p)
|
||||
entry: Dict[str, Any] = entry_raw if isinstance(entry_raw, dict) else {}
|
||||
models_raw = entry.get("models")
|
||||
models_cfg: Dict[str, Any] = models_raw if isinstance(models_raw, dict) else {}
|
||||
per_model_raw = models_cfg.get(model)
|
||||
per_model: Dict[str, Any] = per_model_raw if isinstance(per_model_raw, dict) else {}
|
||||
coerced = _coerce_capability_bool(per_model.get("supports_vision"))
|
||||
if coerced is not None:
|
||||
return coerced
|
||||
return None
|
||||
|
||||
|
||||
def _coerce_mode(raw: Any) -> str:
|
||||
"""Normalize a config value into one of the valid modes."""
|
||||
if not isinstance(raw, str):
|
||||
|
|
@ -81,8 +159,20 @@ def _explicit_aux_vision_override(cfg: Optional[Dict[str, Any]]) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _lookup_supports_vision(provider: str, model: str) -> Optional[bool]:
|
||||
"""Return True/False if we can resolve caps, None if unknown."""
|
||||
def _lookup_supports_vision(
|
||||
provider: str,
|
||||
model: str,
|
||||
cfg: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[bool]:
|
||||
"""Return True/False if we can resolve caps, None if unknown.
|
||||
|
||||
Consults the user's ``supports_vision`` override in config.yaml first
|
||||
(so custom/local models declared as vision-capable don't fall through to
|
||||
text routing in ``auto`` mode), then falls back to models.dev.
|
||||
"""
|
||||
override = _supports_vision_override(cfg, provider, model)
|
||||
if override is not None:
|
||||
return override
|
||||
if not provider or not model:
|
||||
return None
|
||||
try:
|
||||
|
|
@ -123,7 +213,7 @@ def decide_image_input_mode(
|
|||
if _explicit_aux_vision_override(cfg):
|
||||
return "text"
|
||||
|
||||
supports = _lookup_supports_vision(provider, model)
|
||||
supports = _lookup_supports_vision(provider, model, cfg)
|
||||
if supports is True:
|
||||
return "native"
|
||||
return "text"
|
||||
|
|
|
|||
|
|
@ -116,14 +116,11 @@ class ResponsesApiTransport(ProviderTransport):
|
|||
if reasoning_enabled and is_xai_responses:
|
||||
from agent.model_metadata import grok_supports_reasoning_effort
|
||||
|
||||
# NOTE: Hermes does NOT ask xAI to return ``reasoning.encrypted_content``
|
||||
# any more. xAI's OAuth/SuperGrok ``/v1/responses`` surface rejects
|
||||
# replayed encrypted reasoning items on turn 2+ — see
|
||||
# _chat_messages_to_responses_input docstring. Requesting the field
|
||||
# back would just have us cache something we then must strip. Grok
|
||||
# still reasons natively each turn; coherence across turns rides on
|
||||
# the visible message text alone.
|
||||
kwargs["include"] = []
|
||||
# Ask xAI to echo back encrypted reasoning items so we can
|
||||
# replay them on subsequent turns for cross-turn coherence.
|
||||
# See agent/codex_responses_adapter._chat_messages_to_responses_input
|
||||
# for the May 2026 reversal of the earlier suppression gate.
|
||||
kwargs["include"] = ["reasoning.encrypted_content"]
|
||||
# xAI rejects `reasoning.effort` on grok-4 / grok-4-fast / grok-3
|
||||
# / grok-code-fast / grok-4.20-0309-* with HTTP 400 even though
|
||||
# those models reason natively. Only send the effort dial when
|
||||
|
|
|
|||
53
cli.py
53
cli.py
|
|
@ -14380,13 +14380,54 @@ def main(
|
|||
# Only print the final response and parseable session info.
|
||||
cli.tool_progress_mode = "off"
|
||||
if cli._ensure_runtime_credentials():
|
||||
effective_query = query
|
||||
effective_query: Any = query
|
||||
if single_query_images:
|
||||
effective_query = cli._preprocess_images_with_vision(
|
||||
query,
|
||||
single_query_images,
|
||||
announce=False,
|
||||
)
|
||||
# Honour the same image-routing decision used by the
|
||||
# interactive path. With a vision-capable model (incl.
|
||||
# custom-provider models declared via
|
||||
# `model.supports_vision: true`), attach images natively
|
||||
# as image_url content parts. Otherwise fall back to the
|
||||
# text-pipeline (vision_analyze pre-description).
|
||||
_img_mode = "text"
|
||||
_build_parts = None
|
||||
try:
|
||||
from agent.image_routing import (
|
||||
build_native_content_parts as _build_parts, # noqa: F811
|
||||
)
|
||||
from agent.image_routing import decide_image_input_mode
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
_img_mode = decide_image_input_mode(
|
||||
(cli.provider or "").strip(),
|
||||
(cli.model or "").strip(),
|
||||
load_config(),
|
||||
)
|
||||
except Exception:
|
||||
_img_mode = "text"
|
||||
|
||||
if _img_mode == "native" and _build_parts is not None:
|
||||
try:
|
||||
_parts, _skipped = _build_parts(
|
||||
query if isinstance(query, str) else "",
|
||||
[str(p) for p in single_query_images],
|
||||
)
|
||||
if any(p.get("type") == "image_url" for p in _parts):
|
||||
effective_query = _parts
|
||||
else:
|
||||
# All images unreadable — text fallback.
|
||||
effective_query = cli._preprocess_images_with_vision(
|
||||
query, single_query_images, announce=False,
|
||||
)
|
||||
except Exception:
|
||||
effective_query = cli._preprocess_images_with_vision(
|
||||
query, single_query_images, announce=False,
|
||||
)
|
||||
else:
|
||||
effective_query = cli._preprocess_images_with_vision(
|
||||
query,
|
||||
single_query_images,
|
||||
announce=False,
|
||||
)
|
||||
turn_route = cli._resolve_turn_agent_config(effective_query)
|
||||
if turn_route["signature"] != cli._active_agent_route_signature:
|
||||
cli.agent = None
|
||||
|
|
|
|||
|
|
@ -830,6 +830,8 @@ def load_gateway_config() -> GatewayConfig:
|
|||
bridged["require_mention"] = platform_cfg["require_mention"]
|
||||
if plat == Platform.TELEGRAM and "allowed_chats" in platform_cfg:
|
||||
bridged["allowed_chats"] = platform_cfg["allowed_chats"]
|
||||
if plat == Platform.TELEGRAM and "group_allowed_chats" in platform_cfg:
|
||||
bridged["group_allowed_chats"] = platform_cfg["group_allowed_chats"]
|
||||
if plat == Platform.TELEGRAM and "allowed_topics" in platform_cfg:
|
||||
bridged["allowed_topics"] = platform_cfg["allowed_topics"]
|
||||
if "free_response_channels" in platform_cfg:
|
||||
|
|
@ -838,6 +840,8 @@ def load_gateway_config() -> GatewayConfig:
|
|||
bridged["mention_patterns"] = platform_cfg["mention_patterns"]
|
||||
if "exclusive_bot_mentions" in platform_cfg:
|
||||
bridged["exclusive_bot_mentions"] = platform_cfg["exclusive_bot_mentions"]
|
||||
if plat == Platform.TELEGRAM and "observe_unmentioned_group_messages" in platform_cfg:
|
||||
bridged["observe_unmentioned_group_messages"] = platform_cfg["observe_unmentioned_group_messages"]
|
||||
if "dm_policy" in platform_cfg:
|
||||
bridged["dm_policy"] = platform_cfg["dm_policy"]
|
||||
if "allow_from" in platform_cfg:
|
||||
|
|
@ -1024,6 +1028,8 @@ def load_gateway_config() -> GatewayConfig:
|
|||
os.environ["TELEGRAM_EXCLUSIVE_BOT_MENTIONS"] = str(telegram_cfg["exclusive_bot_mentions"]).lower()
|
||||
if "guest_mode" in telegram_cfg and not os.getenv("TELEGRAM_GUEST_MODE"):
|
||||
os.environ["TELEGRAM_GUEST_MODE"] = str(telegram_cfg["guest_mode"]).lower()
|
||||
if "observe_unmentioned_group_messages" in telegram_cfg and not os.getenv("TELEGRAM_OBSERVE_UNMENTIONED_GROUP_MESSAGES"):
|
||||
os.environ["TELEGRAM_OBSERVE_UNMENTIONED_GROUP_MESSAGES"] = str(telegram_cfg["observe_unmentioned_group_messages"]).lower()
|
||||
frc = telegram_cfg.get("free_response_chats")
|
||||
if frc is not None and not os.getenv("TELEGRAM_FREE_RESPONSE_CHATS"):
|
||||
if isinstance(frc, list):
|
||||
|
|
@ -1074,7 +1080,7 @@ def load_gateway_config() -> GatewayConfig:
|
|||
if isinstance(group_allowed_chats, list):
|
||||
group_allowed_chats = ",".join(str(v) for v in group_allowed_chats)
|
||||
os.environ["TELEGRAM_GROUP_ALLOWED_CHATS"] = str(group_allowed_chats)
|
||||
for _telegram_extra_key in ("guest_mode", "disable_link_previews"):
|
||||
for _telegram_extra_key in ("guest_mode", "disable_link_previews", "observe_unmentioned_group_messages"):
|
||||
if _telegram_extra_key in telegram_cfg:
|
||||
plat_data = platforms_data.setdefault(Platform.TELEGRAM.value, {})
|
||||
if not isinstance(plat_data, dict):
|
||||
|
|
|
|||
|
|
@ -64,7 +64,6 @@ def mirror_to_session(
|
|||
"mirror_source": source_label,
|
||||
}
|
||||
|
||||
_append_to_jsonl(session_id, mirror_msg)
|
||||
_append_to_sqlite(session_id, mirror_msg)
|
||||
|
||||
logger.debug("Mirror: wrote to session %s (from %s)", session_id, source_label)
|
||||
|
|
@ -150,15 +149,6 @@ def _find_session_id(
|
|||
return best_entry.get("session_id")
|
||||
|
||||
|
||||
def _append_to_jsonl(session_id: str, message: dict) -> None:
|
||||
"""Append a message to the JSONL transcript file."""
|
||||
transcript_path = _SESSIONS_DIR / f"{session_id}.jsonl"
|
||||
try:
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.debug("Mirror JSONL write failed: %s", e)
|
||||
|
||||
|
||||
def _append_to_sqlite(session_id: str, message: dict) -> None:
|
||||
"""Append a message to the SQLite session database."""
|
||||
|
|
|
|||
|
|
@ -2706,8 +2706,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
Discord's TYPING_START gateway event is unreliable in DMs for bots.
|
||||
Instead, start a background loop that hits the typing endpoint every
|
||||
8 seconds (typing indicator lasts ~10s). The loop is cancelled when
|
||||
12 seconds (typing indicator lasts ~10s). The loop is cancelled when
|
||||
stop_typing() is called (after the response is sent).
|
||||
|
||||
Rate-limit handling: if a 429 is encountered, the loop logs a
|
||||
warning, sleeps for the ``retry_after`` duration (or a sensible
|
||||
default), and continues — it does NOT die on a single rate-limit
|
||||
hit. Only CancelledError (from stop_typing) stops the loop.
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
|
|
@ -2727,9 +2732,22 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for %s: %s", chat_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
# Don't die on 429 — backoff and continue
|
||||
retry_after = self._extract_discord_retry_after(e)
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
"Typing indicator rate-limited for %s; retrying in %.1fs",
|
||||
chat_id, retry_after,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Discord typing indicator failed for %s: %s",
|
||||
chat_id, e,
|
||||
)
|
||||
return
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
await asyncio.sleep(12)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -8,12 +8,14 @@ Uses python-telegram-bot library for:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import html as _html
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -4178,6 +4180,23 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
return bool(configured)
|
||||
return os.getenv("TELEGRAM_REQUIRE_MENTION", "false").lower() in {"true", "1", "yes", "on"}
|
||||
|
||||
def _telegram_observe_unmentioned_group_messages(self) -> bool:
|
||||
"""Return whether skipped unmentioned group messages are stored as context.
|
||||
|
||||
When enabled with ``require_mention``, Telegram matches the Yuanbao /
|
||||
OpenClaw-style group UX: observe ordinary group chatter in the session
|
||||
transcript, but only dispatch the agent when the bot is explicitly
|
||||
addressed.
|
||||
"""
|
||||
configured = self.config.extra.get("observe_unmentioned_group_messages")
|
||||
if configured is None:
|
||||
configured = self.config.extra.get("ingest_unmentioned_group_messages")
|
||||
if configured is not None:
|
||||
if isinstance(configured, str):
|
||||
return configured.lower() in {"true", "1", "yes", "on"}
|
||||
return bool(configured)
|
||||
return os.getenv("TELEGRAM_OBSERVE_UNMENTIONED_GROUP_MESSAGES", "false").lower() in {"true", "1", "yes", "on"}
|
||||
|
||||
def _telegram_guest_mode(self) -> bool:
|
||||
"""Return whether non-allowlisted groups may trigger via direct @mention."""
|
||||
configured = self.config.extra.get("guest_mode")
|
||||
|
|
@ -4219,6 +4238,30 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
return {part.strip() for part in str(raw).split(",") if part.strip()}
|
||||
|
||||
def _telegram_group_allowed_chats(self) -> set[str]:
|
||||
"""Return Telegram chats authorized at group scope."""
|
||||
raw = self.config.extra.get("group_allowed_chats")
|
||||
if raw is None:
|
||||
raw = os.getenv("TELEGRAM_GROUP_ALLOWED_CHATS", "")
|
||||
if isinstance(raw, list):
|
||||
return {str(part).strip() for part in raw if str(part).strip()}
|
||||
return {part.strip() for part in str(raw).split(",") if part.strip()}
|
||||
|
||||
def _telegram_observe_allowed_chats(self) -> set[str]:
|
||||
"""Chats where observed group context may use a shared source.
|
||||
|
||||
``group_allowed_chats`` is the gateway authorization allowlist for
|
||||
user-less group sources. ``allowed_chats`` remains an optional response
|
||||
gate; when set, observed context must satisfy both lists.
|
||||
"""
|
||||
group_allowed = self._telegram_group_allowed_chats()
|
||||
if not group_allowed:
|
||||
return set()
|
||||
response_allowed = self._telegram_allowed_chats()
|
||||
if response_allowed:
|
||||
return group_allowed & response_allowed
|
||||
return group_allowed
|
||||
|
||||
def _telegram_allowed_topics(self) -> set[str]:
|
||||
"""Return the whitelist of Telegram forum topic IDs this bot handles.
|
||||
|
||||
|
|
@ -4466,6 +4509,126 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
cleaned = re.sub(rf"(?i)@{username}\b[,:\-]*\s*", "", text).strip()
|
||||
return cleaned or text
|
||||
|
||||
def _should_observe_unmentioned_group_message(self, message: Message) -> bool:
|
||||
"""Return True when a group message should be stored but not dispatched."""
|
||||
if not self._telegram_observe_unmentioned_group_messages():
|
||||
return False
|
||||
if not self._is_group_chat(message):
|
||||
return False
|
||||
|
||||
thread_id = getattr(message, "message_thread_id", None)
|
||||
allowed_topics = self._telegram_allowed_topics()
|
||||
if allowed_topics:
|
||||
topic_id = str(thread_id) if thread_id is not None else self._GENERAL_TOPIC_THREAD_ID
|
||||
if topic_id not in allowed_topics:
|
||||
return False
|
||||
|
||||
if thread_id is not None:
|
||||
try:
|
||||
if int(thread_id) in self._telegram_ignored_threads():
|
||||
return False
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
chat_id_str = str(getattr(getattr(message, "chat", None), "id", ""))
|
||||
if self._telegram_exclusive_bot_mentions() and self._explicit_bot_mentions_exclude_self(message):
|
||||
return False
|
||||
|
||||
allowed = self._telegram_observe_allowed_chats()
|
||||
# Observed context is shared at chat/topic scope so a later trigger from
|
||||
# another user can see it. Require an explicit chat allowlist; that
|
||||
# keeps shared observed history limited to operator-approved groups and
|
||||
# lets gateway authorization pass even after the shared session source
|
||||
# drops the per-sender user_id.
|
||||
if not allowed or chat_id_str not in allowed:
|
||||
return False
|
||||
|
||||
# Only observe messages skipped by the require_mention gate. If the
|
||||
# message would be processed normally, let the dispatcher handle it;
|
||||
# if require_mention is disabled, every group message is a request.
|
||||
if chat_id_str in self._telegram_free_response_chats():
|
||||
return False
|
||||
if not self._telegram_require_mention():
|
||||
return False
|
||||
if self._is_reply_to_bot(message):
|
||||
return False
|
||||
if self._message_mentions_bot(message):
|
||||
return False
|
||||
if self._message_matches_mention_patterns(message):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _telegram_group_observe_shared_source(self, source):
|
||||
"""Return a chat/topic-scoped source for observed Telegram group context."""
|
||||
return dataclasses.replace(source, user_id=None, user_name=None, user_id_alt=None)
|
||||
|
||||
def _telegram_group_observe_attributed_text(self, event: MessageEvent) -> str:
|
||||
user_id = event.source.user_id or "unknown"
|
||||
sender = event.source.user_name or user_id
|
||||
return f"[{sender}|{user_id}]\n{event.text or ''}"
|
||||
|
||||
def _telegram_group_observe_channel_prompt(self) -> str:
|
||||
username = getattr(getattr(self, "_bot", None), "username", None) or "unknown"
|
||||
bot_id = getattr(getattr(self, "_bot", None), "id", None) or "unknown"
|
||||
return (
|
||||
"You are handling a Telegram group chat message.\n"
|
||||
f"- Your identity: user_id={bot_id}, @-mention name in this group=@{username}\n"
|
||||
"- Lines in history prefixed with `[nickname|user_id]` are observed Telegram group context "
|
||||
"and are not necessarily addressed to you.\n"
|
||||
"- Treat only the current new message as a request explicitly directed at you, "
|
||||
"and answer it directly."
|
||||
)
|
||||
|
||||
def _apply_telegram_group_observe_attribution(self, event: MessageEvent) -> MessageEvent:
|
||||
"""Align triggered group turns with observed-history attribution."""
|
||||
if not self._telegram_observe_unmentioned_group_messages():
|
||||
return event
|
||||
raw_message = getattr(event, "raw_message", None)
|
||||
if not raw_message or not self._is_group_chat(raw_message):
|
||||
return event
|
||||
chat_id_str = str(getattr(getattr(raw_message, "chat", None), "id", ""))
|
||||
allowed = self._telegram_observe_allowed_chats()
|
||||
if not allowed or chat_id_str not in allowed:
|
||||
return event
|
||||
shared_source = self._telegram_group_observe_shared_source(event.source)
|
||||
observe_prompt = self._telegram_group_observe_channel_prompt()
|
||||
channel_prompt = f"{event.channel_prompt}\n\n{observe_prompt}" if event.channel_prompt else observe_prompt
|
||||
return dataclasses.replace(
|
||||
event,
|
||||
text=self._telegram_group_observe_attributed_text(event),
|
||||
source=shared_source,
|
||||
channel_prompt=channel_prompt,
|
||||
)
|
||||
|
||||
def _observe_unmentioned_group_message(self, message: Message, msg_type: MessageType, update_id: Optional[int] = None) -> None:
|
||||
"""Append skipped group chatter to the target session without dispatching."""
|
||||
store = getattr(self, "_session_store", None)
|
||||
if not store:
|
||||
return
|
||||
try:
|
||||
event = self._build_message_event(message, msg_type, update_id=update_id)
|
||||
shared_source = self._telegram_group_observe_shared_source(event.source)
|
||||
session_entry = store.get_or_create_session(shared_source)
|
||||
entry = {
|
||||
"role": "user",
|
||||
"content": self._telegram_group_observe_attributed_text(event),
|
||||
"timestamp": datetime.now(tz=timezone.utc).isoformat(),
|
||||
"observed": True,
|
||||
}
|
||||
if event.message_id:
|
||||
entry["message_id"] = str(event.message_id)
|
||||
store.append_to_transcript(session_entry.session_id, entry)
|
||||
adapter_name = getattr(self, "name", "telegram")
|
||||
logger.info(
|
||||
"[%s] Telegram group message observed (no bot trigger): chat=%s from=%s",
|
||||
adapter_name,
|
||||
getattr(getattr(message, "chat", None), "id", "unknown"),
|
||||
event.source.user_id or "unknown",
|
||||
)
|
||||
except Exception as exc:
|
||||
adapter_name = getattr(self, "name", "telegram")
|
||||
logger.warning("[%s] Failed to observe Telegram group message: %s", adapter_name, exc)
|
||||
|
||||
def _should_process_message(self, message: Message, *, is_command: bool = False) -> bool:
|
||||
"""Apply Telegram group trigger rules.
|
||||
|
||||
|
|
@ -4590,11 +4753,14 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
if not msg or not msg.text:
|
||||
return
|
||||
if not self._should_process_message(msg):
|
||||
if self._should_observe_unmentioned_group_message(msg):
|
||||
self._observe_unmentioned_group_message(msg, MessageType.TEXT, update_id=update.update_id)
|
||||
return
|
||||
await self._ensure_forum_commands(update.message)
|
||||
|
||||
event = self._build_message_event(msg, MessageType.TEXT, update_id=update.update_id)
|
||||
event.text = self._clean_bot_trigger_text(event.text)
|
||||
event = self._apply_telegram_group_observe_attribution(event)
|
||||
self._enqueue_text_event(event)
|
||||
|
||||
async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
|
|
@ -4607,6 +4773,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
await self._ensure_forum_commands(msg)
|
||||
|
||||
event = self._build_message_event(msg, MessageType.COMMAND, update_id=update.update_id)
|
||||
event.text = self._clean_bot_trigger_text(event.text)
|
||||
event = self._apply_telegram_group_observe_attribution(event)
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_location_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
|
|
@ -4615,6 +4783,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
if not msg:
|
||||
return
|
||||
if not self._should_process_message(msg):
|
||||
if self._should_observe_unmentioned_group_message(msg):
|
||||
self._observe_unmentioned_group_message(msg, MessageType.LOCATION, update_id=update.update_id)
|
||||
return
|
||||
|
||||
venue = getattr(msg, "venue", None)
|
||||
|
|
@ -4644,6 +4814,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
|
||||
event = self._build_message_event(msg, MessageType.LOCATION, update_id=update.update_id)
|
||||
event.text = "\n".join(parts)
|
||||
event = self._apply_telegram_group_observe_attribution(event)
|
||||
await self.handle_message(event)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -4788,8 +4959,23 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
if not update.message:
|
||||
return
|
||||
if not self._should_process_message(update.message):
|
||||
if self._should_observe_unmentioned_group_message(update.message):
|
||||
_m = update.message
|
||||
if _m.sticker:
|
||||
_observe_type = MessageType.STICKER
|
||||
elif _m.photo:
|
||||
_observe_type = MessageType.PHOTO
|
||||
elif _m.video:
|
||||
_observe_type = MessageType.VIDEO
|
||||
elif _m.audio:
|
||||
_observe_type = MessageType.AUDIO
|
||||
elif _m.voice:
|
||||
_observe_type = MessageType.VOICE
|
||||
else:
|
||||
_observe_type = MessageType.DOCUMENT
|
||||
self._observe_unmentioned_group_message(_m, _observe_type, update_id=update.update_id)
|
||||
return
|
||||
|
||||
|
||||
msg = update.message
|
||||
|
||||
# Determine media type
|
||||
|
|
@ -4817,9 +5003,14 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# Handle stickers: describe via vision tool with caching
|
||||
if msg.sticker:
|
||||
await self._handle_sticker(msg, event)
|
||||
event = self._apply_telegram_group_observe_attribution(event)
|
||||
await self.handle_message(event)
|
||||
return
|
||||
|
||||
|
||||
# Apply observe attribution after caption is set; sticker is handled above
|
||||
# because _handle_sticker overwrites event.text with its vision description.
|
||||
event = self._apply_telegram_group_observe_attribution(event)
|
||||
|
||||
# Download photo to local image cache so the vision tool can access it
|
||||
# even after Telegram's ephemeral file URLs expire (~1 hour).
|
||||
if msg.photo:
|
||||
|
|
|
|||
|
|
@ -1410,41 +1410,43 @@ class RecallGuardMiddleware(InboundMiddleware):
|
|||
logger.warning("[%s] Recall: failed to resolve session: %s", adapter.name, exc)
|
||||
return
|
||||
|
||||
# Read JSONL directly — SQLite doesn't preserve message_id field.
|
||||
transcript: list = []
|
||||
# Load transcript from canonical store (state.db). Since PR #29278
|
||||
# added a ``platform_message_id`` column to the messages table and
|
||||
# ``append_to_transcript`` wires the incoming dict's ``message_id``
|
||||
# into it, ``load_transcript`` returns rows with ``message_id`` set
|
||||
# for any message that was observed with one — Branch A1 (exact id
|
||||
# match) is the canonical path again.
|
||||
try:
|
||||
path = store.get_transcript_path(sid)
|
||||
if path.exists():
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
transcript.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
transcript = store.load_transcript(sid)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Recall: failed to load transcript: %s", adapter.name, exc)
|
||||
return
|
||||
|
||||
# Branch A: redact — try message_id first, then content fallback.
|
||||
# Observed messages have message_id; agent-processed @bot messages
|
||||
# only have content (run.py doesn't write message_id to transcript).
|
||||
# Branch A1: exact platform message_id match. Authoritative when the
|
||||
# row was persisted with a platform_message_id (observed group
|
||||
# messages and any inbound message whose adapter carried a msg_id).
|
||||
target = None
|
||||
branch_label = ""
|
||||
for entry in transcript:
|
||||
if entry.get("message_id") == recalled_id:
|
||||
target = entry
|
||||
branch_label = "branch A1: id match"
|
||||
break
|
||||
# Branch A2: content-match fallback for messages that lack an exact
|
||||
# platform id on the row — e.g. agent-processed @bot messages
|
||||
# (run.py doesn't carry msg_id through) or older rows persisted
|
||||
# before the platform_message_id column existed.
|
||||
if target is None and recalled_content:
|
||||
for entry in transcript:
|
||||
if entry.get("role") == "user" and entry.get("content") == recalled_content:
|
||||
target = entry
|
||||
branch_label = "branch A2: content match"
|
||||
break
|
||||
if target is not None:
|
||||
target["content"] = cls._REDACTED
|
||||
try:
|
||||
store.rewrite_transcript(sid, transcript)
|
||||
logger.info("[%s] Recall: redacted msg_id=%s (branch A)", adapter.name, recalled_id)
|
||||
logger.info("[%s] Recall: redacted msg_id=%s (%s)", adapter.name, recalled_id, branch_label)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Recall: rewrite_transcript failed: %s", adapter.name, exc)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1248,20 +1248,15 @@ class SessionStore:
|
|||
|
||||
return entries
|
||||
|
||||
def get_transcript_path(self, session_id: str) -> Path:
|
||||
"""Get the path to a session's legacy transcript file."""
|
||||
return self.sessions_dir / f"{session_id}.jsonl"
|
||||
|
||||
def append_to_transcript(self, session_id: str, message: Dict[str, Any], skip_db: bool = False) -> None:
|
||||
"""Append a message to a session's transcript (SQLite + legacy JSONL).
|
||||
"""Append a message to a session's transcript (SQLite).
|
||||
|
||||
Args:
|
||||
skip_db: When True, only write to JSONL and skip the SQLite write.
|
||||
Used when the agent already persisted messages to SQLite
|
||||
via its own _flush_messages_to_session_db(), preventing
|
||||
the duplicate-write bug (#860).
|
||||
skip_db: When True, skip the SQLite write. Used when the agent
|
||||
already persisted messages to SQLite via its own
|
||||
_flush_messages_to_session_db(), preventing the
|
||||
duplicate-write bug (#860).
|
||||
"""
|
||||
# Write to SQLite (unless the agent already handled it)
|
||||
if self._db and not skip_db:
|
||||
try:
|
||||
self._db.append_message(
|
||||
|
|
@ -1276,94 +1271,42 @@ class SessionStore:
|
|||
reasoning_details=message.get("reasoning_details") if message.get("role") == "assistant" else None,
|
||||
codex_reasoning_items=message.get("codex_reasoning_items") if message.get("role") == "assistant" else None,
|
||||
codex_message_items=message.get("codex_message_items") if message.get("role") == "assistant" else None,
|
||||
# Platform-side message id (yuanbao msg_id, telegram update_id, …).
|
||||
# Accept either explicit ``platform_message_id`` or the legacy
|
||||
# ``message_id`` key the JSONL transcript used.
|
||||
platform_message_id=(
|
||||
message.get("platform_message_id") or message.get("message_id")
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
# Also write legacy JSONL (keeps existing tooling working during transition)
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
try:
|
||||
with self._lock:
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
except OSError as e:
|
||||
# Disk full / read-only fs / permission errors must not crash the
|
||||
# message handler — the SQLite write above is the primary store.
|
||||
logger.debug("Failed to write JSONL transcript for %s: %s", session_id, e)
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Replace the entire transcript for a session with new messages.
|
||||
|
||||
Used by /retry, /undo, and /compress to persist modified conversation history.
|
||||
Rewrites both SQLite and legacy JSONL storage.
|
||||
|
||||
Used by /retry, /undo, and /compress to persist modified conversation
|
||||
history. state.db is the canonical store.
|
||||
"""
|
||||
# SQLite: replace atomically so a mid-rewrite failure doesn't leave
|
||||
# the session half-empty in the DB while JSONL still has history.
|
||||
if self._db:
|
||||
try:
|
||||
self._db.replace_messages(session_id, messages)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
# JSONL: overwrite the file
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "w", encoding="utf-8") as f:
|
||||
for msg in messages:
|
||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||
|
||||
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages from a session's transcript."""
|
||||
db_messages = []
|
||||
# Try SQLite first
|
||||
if self._db:
|
||||
try:
|
||||
db_messages = self._db.get_messages_as_conversation(session_id)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load messages from DB: %s", e)
|
||||
"""Load all messages from a session's transcript.
|
||||
|
||||
# Load legacy JSONL transcript (may contain more history than SQLite
|
||||
# for sessions created before the DB layer was introduced).
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
jsonl_messages = []
|
||||
if transcript_path.exists():
|
||||
try:
|
||||
with open(transcript_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
jsonl_messages.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Skipping corrupt line in transcript %s: %s",
|
||||
session_id, line[:120],
|
||||
)
|
||||
except OSError as e:
|
||||
# JSONL is the legacy compatibility store. If it becomes
|
||||
# unreadable, keep gateway recovery working by falling back to
|
||||
# SQLite rows loaded above (or [] when no DB exists).
|
||||
logger.debug("Failed to read JSONL transcript for %s: %s", session_id, e)
|
||||
|
||||
# Prefer whichever source has more messages.
|
||||
#
|
||||
# Background: when a session pre-dates SQLite storage (or when the DB
|
||||
# layer was added while a long-lived session was already active), the
|
||||
# first post-migration turn writes only the *new* messages to SQLite
|
||||
# (because _flush_messages_to_session_db skips messages already in
|
||||
# conversation_history, assuming they're persisted). On the *next*
|
||||
# turn load_transcript returns those few SQLite rows and ignores the
|
||||
# full JSONL history — the model sees a context of 1-4 messages instead
|
||||
# of hundreds. Using the longer source prevents this silent truncation.
|
||||
if len(jsonl_messages) > len(db_messages):
|
||||
if db_messages:
|
||||
logger.debug(
|
||||
"Session %s: JSONL has %d messages vs SQLite %d — "
|
||||
"using JSONL (legacy session not yet fully migrated)",
|
||||
session_id, len(jsonl_messages), len(db_messages),
|
||||
)
|
||||
return jsonl_messages
|
||||
|
||||
return db_messages
|
||||
state.db is the canonical store. The legacy JSONL fallback was removed
|
||||
in spec 002 — pre-DB sessions on existing disks have already been
|
||||
migrated (their DB row holds the full message history).
|
||||
"""
|
||||
if not self._db:
|
||||
return []
|
||||
try:
|
||||
return self._db.get_messages_as_conversation(session_id)
|
||||
except Exception as e:
|
||||
logger.debug("Could not load messages from DB: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def build_session_context(
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ import httpx
|
|||
import yaml
|
||||
|
||||
from hermes_cli.config import get_hermes_home, get_config_path, read_raw_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
from hermes_constants import OPENROUTER_BASE_URL, secure_parent_dir
|
||||
from utils import atomic_replace, atomic_yaml_write, is_truthy_value
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -1030,10 +1030,8 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
|
|||
auth_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Tighten parent dir to 0o700 so siblings can't traverse to creds.
|
||||
# No-op on Windows (POSIX mode bits not enforced); ignore failures.
|
||||
try:
|
||||
os.chmod(auth_file.parent, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
# secure_parent_dir refuses to chmod / or top-level dirs (#25821).
|
||||
secure_parent_dir(auth_file)
|
||||
auth_store["version"] = AUTH_STORE_VERSION
|
||||
auth_store["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
payload = json.dumps(auth_store, indent=2) + "\n"
|
||||
|
|
@ -1863,10 +1861,8 @@ def _read_qwen_cli_tokens() -> Dict[str, Any]:
|
|||
def _save_qwen_cli_tokens(tokens: Dict[str, Any]) -> Path:
|
||||
auth_path = _qwen_cli_auth_path()
|
||||
auth_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
os.chmod(auth_path.parent, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
# secure_parent_dir refuses to chmod / or top-level dirs (#25821).
|
||||
secure_parent_dir(auth_path)
|
||||
# Per-process random temp suffix avoids collisions between concurrent
|
||||
# writers and stale leftovers from a crashed prior write.
|
||||
tmp_path = auth_path.with_name(f"{auth_path.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
|
||||
|
|
@ -4168,10 +4164,8 @@ def _write_shared_nous_state(state: Dict[str, Any]) -> None:
|
|||
with _nous_shared_store_lock():
|
||||
path = _nous_shared_store_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
os.chmod(path.parent, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
# secure_parent_dir refuses to chmod / or top-level dirs (#25821).
|
||||
secure_parent_dir(path)
|
||||
tmp = path.with_name(f"{path.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
|
||||
# Create with 0o600 atomically via os.open(O_EXCL) — closes the TOCTOU
|
||||
# window where write_text() + post-write chmod briefly exposed Nous
|
||||
|
|
|
|||
|
|
@ -508,6 +508,68 @@ def telegram_bot_commands() -> list[tuple[str, str]]:
|
|||
return result
|
||||
|
||||
|
||||
_TELEGRAM_MENU_PRIORITY = (
|
||||
# Most-typed everyday commands first.
|
||||
"help",
|
||||
"new",
|
||||
"stop",
|
||||
"status",
|
||||
"resume",
|
||||
"sessions",
|
||||
"model",
|
||||
# Maintenance / diagnostics — the ones that prompted this priority list.
|
||||
"debug",
|
||||
"restart",
|
||||
"update",
|
||||
"verbose",
|
||||
"commands",
|
||||
# Mid-turn session control.
|
||||
"approve",
|
||||
"deny",
|
||||
"queue",
|
||||
"steer",
|
||||
"background",
|
||||
# Lower-priority but still useful operational built-ins.
|
||||
"reasoning",
|
||||
"usage",
|
||||
"platforms",
|
||||
"platform",
|
||||
"profile",
|
||||
"whoami",
|
||||
)
|
||||
"""Built-in commands that should stay visible in Telegram's capped menu.
|
||||
|
||||
Telegram only displays a small BotCommand menu in practice. The full Hermes
|
||||
registry is still dispatchable when typed manually, but operational commands
|
||||
need to survive the visible menu cap ahead of lower-priority built-ins.
|
||||
"""
|
||||
|
||||
|
||||
def _prioritize_telegram_menu_commands(
|
||||
commands: list[tuple[str, str]],
|
||||
) -> list[tuple[str, str]]:
|
||||
priority = {
|
||||
_sanitize_telegram_name(name): index
|
||||
for index, name in enumerate(_TELEGRAM_MENU_PRIORITY)
|
||||
}
|
||||
return [
|
||||
command
|
||||
for _index, command in sorted(
|
||||
enumerate(commands),
|
||||
key=lambda item: (
|
||||
0,
|
||||
priority[item[1][0]],
|
||||
item[0],
|
||||
)
|
||||
if item[1][0] in priority
|
||||
else (
|
||||
1,
|
||||
item[0],
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
_CMD_NAME_LIMIT = 32
|
||||
"""Max command name length shared by Telegram and Discord."""
|
||||
|
||||
|
|
@ -721,11 +783,12 @@ def telegram_menu_commands(max_commands: int = 100) -> tuple[list[tuple[str, str
|
|||
|
||||
Returns:
|
||||
(menu_commands, hidden_count) where hidden_count is the number of
|
||||
skill commands omitted due to the cap.
|
||||
commands omitted due to the cap.
|
||||
"""
|
||||
core_commands = list(telegram_bot_commands())
|
||||
core_commands = _prioritize_telegram_menu_commands(list(telegram_bot_commands()))
|
||||
reserved_names = {n for n, _ in core_commands}
|
||||
all_commands = list(core_commands)
|
||||
hidden_core_count = max(0, len(all_commands) - max_commands)
|
||||
|
||||
remaining_slots = max(0, max_commands - len(all_commands))
|
||||
entries, hidden_count = _collect_gateway_skill_entries(
|
||||
|
|
@ -737,7 +800,7 @@ def telegram_menu_commands(max_commands: int = 100) -> tuple[list[tuple[str, str
|
|||
)
|
||||
# Drop the cmd_key — Telegram only needs (name, desc) pairs.
|
||||
all_commands.extend((n, d) for n, d, _k in entries)
|
||||
return all_commands[:max_commands], hidden_count
|
||||
return all_commands[:max_commands], hidden_count + hidden_core_count
|
||||
|
||||
|
||||
def discord_skill_commands(
|
||||
|
|
|
|||
|
|
@ -3016,7 +3016,7 @@ def _normalize_custom_provider_entry(
|
|||
"api_mode", "transport", "model", "default_model", "models",
|
||||
"context_length", "rate_limit_delay",
|
||||
"request_timeout_seconds", "stale_timeout_seconds",
|
||||
"discover_models",
|
||||
"discover_models", "extra_body",
|
||||
}
|
||||
for camel, snake in _CAMEL_ALIASES.items():
|
||||
if camel in entry and snake not in entry:
|
||||
|
|
@ -3111,6 +3111,10 @@ def _normalize_custom_provider_entry(
|
|||
if isinstance(discover_models, bool):
|
||||
normalized["discover_models"] = discover_models
|
||||
|
||||
extra_body = entry.get("extra_body")
|
||||
if isinstance(extra_body, dict):
|
||||
normalized["extra_body"] = dict(extra_body)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
|
|
@ -3271,7 +3275,7 @@ _KNOWN_ROOT_KEYS = {
|
|||
# Valid fields inside a custom_providers list entry
|
||||
_VALID_CUSTOM_PROVIDER_FIELDS = {
|
||||
"name", "base_url", "api_key", "api_mode", "model", "models",
|
||||
"context_length", "rate_limit_delay",
|
||||
"context_length", "rate_limit_delay", "extra_body",
|
||||
# key_env is read at runtime by runtime_provider.py and auxiliary_client.py
|
||||
# — include it here so the set accurately describes the supported schema.
|
||||
"key_env",
|
||||
|
|
|
|||
|
|
@ -951,6 +951,58 @@ CREATE INDEX IF NOT EXISTS idx_notify_task ON kanban_notify_subs(task_
|
|||
|
||||
_INITIALIZED_PATHS: set[str] = set()
|
||||
_INIT_LOCK = threading.RLock()
|
||||
_SQLITE_HEADER = b"SQLite format 3\x00"
|
||||
|
||||
|
||||
def _looks_like_tls_record_at(data: bytes, offset: int) -> bool:
|
||||
"""Return True for a TLS record header at ``data[offset:]``."""
|
||||
if len(data) < offset + 5:
|
||||
return False
|
||||
content_type = data[offset]
|
||||
major = data[offset + 1]
|
||||
minor = data[offset + 2]
|
||||
length = int.from_bytes(data[offset + 3:offset + 5], "big")
|
||||
return (
|
||||
content_type in {0x14, 0x15, 0x16, 0x17}
|
||||
and major == 0x03
|
||||
and minor in {0x00, 0x01, 0x02, 0x03, 0x04}
|
||||
and 0 < length <= 18432
|
||||
)
|
||||
|
||||
|
||||
def _validate_sqlite_header(path: Path) -> None:
|
||||
"""Fail early with an actionable error for non-SQLite Kanban DB files.
|
||||
|
||||
``sqlite3.connect()`` creates missing and zero-byte files, so those are
|
||||
allowed. Existing non-empty files must have the SQLite header before we
|
||||
hand them to SQLite/WAL setup. This keeps corrupted page-0 failures from
|
||||
being collapsed into a generic PRAGMA error and lets the gateway's corrupt
|
||||
board handling identify the board by fingerprint.
|
||||
"""
|
||||
try:
|
||||
stat = path.stat()
|
||||
except FileNotFoundError:
|
||||
return
|
||||
except OSError:
|
||||
return
|
||||
if stat.st_size == 0:
|
||||
return
|
||||
try:
|
||||
with path.open("rb") as handle:
|
||||
head = handle.read(64)
|
||||
except OSError:
|
||||
return
|
||||
if head.startswith(_SQLITE_HEADER):
|
||||
return
|
||||
signature = ""
|
||||
if head.startswith(b"SQLit") and _looks_like_tls_record_at(head, 5):
|
||||
signature = " (TLS record header detected at byte offset 5)"
|
||||
elif _looks_like_tls_record_at(head, 0):
|
||||
signature = " (TLS record header detected at byte offset 0)"
|
||||
raise sqlite3.DatabaseError(
|
||||
"file is not a database: invalid SQLite header for "
|
||||
f"{path}{signature}; first_32={head[:32].hex(' ')}"
|
||||
)
|
||||
|
||||
|
||||
def connect(
|
||||
|
|
@ -981,6 +1033,7 @@ def connect(
|
|||
else:
|
||||
path = kanban_db_path(board=board)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_validate_sqlite_header(path)
|
||||
resolved = str(path.resolve())
|
||||
conn = sqlite3.connect(str(path), isolation_level=None, timeout=30)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -6112,24 +6112,36 @@ def _validate_critical_files_syntax(root) -> tuple[bool, str | None, str | None]
|
|||
them after a successful ``git pull`` so we can auto-roll-back instead of
|
||||
leaving the user with a bricked install.
|
||||
|
||||
The compiled ``.pyc`` is written to a temp directory rather than the
|
||||
source tree's ``__pycache__/`` so we don't race with concurrent test
|
||||
workers that walk the same dir, and so we don't leave a stale pyc
|
||||
behind in production if the next interpreter run picks a different
|
||||
Python version. The pyc is discarded on function return either way —
|
||||
we only care about the compile-or-not signal.
|
||||
|
||||
Returns ``(ok, failing_path, error_message)``. ``ok=True`` means every
|
||||
file parsed cleanly.
|
||||
"""
|
||||
import py_compile
|
||||
import tempfile
|
||||
|
||||
root = Path(root)
|
||||
for relpath in _UPDATE_CRITICAL_FILES:
|
||||
path = root / relpath
|
||||
if not path.exists():
|
||||
# Missing file is suspicious but not necessarily fatal — a future
|
||||
# refactor may legitimately remove one of these. Skip and move on.
|
||||
continue
|
||||
try:
|
||||
py_compile.compile(str(path), doraise=True)
|
||||
except py_compile.PyCompileError as exc:
|
||||
return False, str(path), str(exc)
|
||||
except OSError as exc:
|
||||
return False, str(path), f"could not read: {exc}"
|
||||
with tempfile.TemporaryDirectory(prefix="hermes-syntax-check-") as tmpdir:
|
||||
for relpath in _UPDATE_CRITICAL_FILES:
|
||||
path = root / relpath
|
||||
if not path.exists():
|
||||
# Missing file is suspicious but not necessarily fatal — a future
|
||||
# refactor may legitimately remove one of these. Skip and move on.
|
||||
continue
|
||||
# Mirror the relative path under the tmpdir so two different
|
||||
# files with the same basename don't collide on the cfile name.
|
||||
cfile = Path(tmpdir) / (relpath.replace("/", "__") + "c")
|
||||
try:
|
||||
py_compile.compile(str(path), cfile=str(cfile), doraise=True)
|
||||
except py_compile.PyCompileError as exc:
|
||||
return False, str(path), str(exc)
|
||||
except OSError as exc:
|
||||
return False, str(path), f"could not read: {exc}"
|
||||
return True, None, None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -902,7 +902,49 @@ def delete_profile(name: str, yes: bool = False) -> Path:
|
|||
|
||||
# 4. Remove profile directory
|
||||
try:
|
||||
shutil.rmtree(profile_dir)
|
||||
def _make_writable(func, path, exc):
|
||||
"""onexc/onerror handler: add +w on PermissionError so rmtree can proceed.
|
||||
|
||||
Handles two cases on NixOS (and other systems with read-only
|
||||
copies from immutable stores):
|
||||
1. The path itself isn't writable (e.g. a file with mode 0444)
|
||||
2. The *parent* directory isn't writable (e.g. mode 0555)
|
||||
|
||||
Compatible with both the ``onexc`` API (3.12+, receives an
|
||||
exception instance) and the ``onerror`` API (3.11-, receives
|
||||
``sys.exc_info()`` tuple).
|
||||
"""
|
||||
import stat as _stat
|
||||
import sys as _sys
|
||||
|
||||
# Normalise the two callback signatures:
|
||||
# onexc(func, path, exc_instance) — 3.12+
|
||||
# onerror(func, path, exc_info_tuple) — 3.11
|
||||
if isinstance(exc, tuple):
|
||||
exc = exc[1] # exc_info → actual exception object
|
||||
|
||||
if isinstance(exc, PermissionError):
|
||||
# Make the path writable
|
||||
try:
|
||||
os.chmod(path, os.stat(path).st_mode | _stat.S_IWUSR)
|
||||
except OSError:
|
||||
pass
|
||||
# Also make the parent writable (needed for unlink/rmdir)
|
||||
parent = os.path.dirname(path)
|
||||
if parent:
|
||||
try:
|
||||
os.chmod(parent, os.stat(parent).st_mode | _stat.S_IWUSR)
|
||||
except OSError:
|
||||
pass
|
||||
func(path)
|
||||
else:
|
||||
raise
|
||||
|
||||
# ``onexc`` was added in 3.12; fall back to ``onerror`` on 3.11.
|
||||
try:
|
||||
shutil.rmtree(profile_dir, onexc=_make_writable)
|
||||
except TypeError:
|
||||
shutil.rmtree(profile_dir, onerror=_make_writable)
|
||||
print(f"✓ Removed {profile_dir}")
|
||||
except Exception as e:
|
||||
print(f"⚠ Could not remove {profile_dir}: {e}")
|
||||
|
|
|
|||
|
|
@ -100,6 +100,63 @@ def _detect_api_mode_for_url(base_url: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
def _host_derived_api_key(base_url: str) -> str:
|
||||
"""Look up `<VENDOR>_API_KEY` in the env, derived from the base URL host.
|
||||
|
||||
Examples:
|
||||
https://api.deepseek.com/v1 → DEEPSEEK_API_KEY
|
||||
https://api.groq.com/openai/v1 → GROQ_API_KEY
|
||||
https://api.mistral.ai/v1 → MISTRAL_API_KEY
|
||||
https://generativelanguage.googleapis.com/v1beta/openai/ → GOOGLEAPIS_API_KEY
|
||||
|
||||
Returns the env value (stripped) or "". Never returns env vars whose names
|
||||
are already explicitly checked elsewhere — those are handled by their own
|
||||
host-gated paths (OPENAI/OPENROUTER/OLLAMA).
|
||||
|
||||
The vendor label is the *registrable* portion of the hostname: strip
|
||||
``api.`` / ``www.`` prefixes, then take the second-to-last label
|
||||
(``api.deepseek.com`` → ``deepseek``). Falls back to "" for hostnames
|
||||
that don't yield a usable vendor label (IPs, loopback, single-label
|
||||
hosts).
|
||||
"""
|
||||
hostname = base_url_hostname(base_url)
|
||||
if not hostname:
|
||||
return ""
|
||||
# Reject IPv4 / IPv6 / loopback — no meaningful vendor label.
|
||||
if any(ch.isdigit() for ch in hostname.split(".")[-1]):
|
||||
# Last label starts with a digit → likely IP. (TLDs are never numeric.)
|
||||
return ""
|
||||
if hostname in ("localhost",) or ":" in hostname:
|
||||
return ""
|
||||
labels = [lbl for lbl in hostname.split(".") if lbl]
|
||||
# Strip common API/CDN prefixes.
|
||||
while labels and labels[0] in ("api", "www"):
|
||||
labels.pop(0)
|
||||
if len(labels) < 2:
|
||||
return ""
|
||||
# Take the *registrable* label (second-to-last). For typical provider
|
||||
# hosts this is what users intuitively call "the vendor":
|
||||
# deepseek.com → labels[-2] = "deepseek" ✓
|
||||
# api.groq.com → groq.com → labels[-2] = "groq" ✓
|
||||
# api.mistral.ai → labels[-2] = "mistral" ✓
|
||||
# Crucially, lookalike hosts pick the ATTACKER's label, not the spoofed
|
||||
# vendor:
|
||||
# api.deepseek.com.attacker.test → labels[-2] = "attacker"
|
||||
# so DEEPSEEK_API_KEY stays put and the chain falls through to
|
||||
# no-key-required. This mirrors how `base_url_host_matches` resists the
|
||||
# same lookalike attack for explicit hosts.
|
||||
vendor = labels[-2]
|
||||
# Sanitize to env var charset: A-Z, 0-9, underscore.
|
||||
sanitized = "".join(ch if ch.isalnum() else "_" for ch in vendor).upper()
|
||||
if not sanitized or not sanitized[0].isalpha():
|
||||
return ""
|
||||
# Don't re-derive env vars already handled by explicit host-gated paths.
|
||||
if sanitized in ("OPENAI", "OPENROUTER", "OLLAMA"):
|
||||
return ""
|
||||
env_name = f"{sanitized}_API_KEY"
|
||||
return (os.getenv(env_name, "") or "").strip()
|
||||
|
||||
|
||||
def _auto_detect_local_model(base_url: str) -> str:
|
||||
"""Query a local server for its model name when only one model is loaded."""
|
||||
if not base_url:
|
||||
|
|
@ -471,6 +528,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
"api_key": resolved_api_key,
|
||||
"model": entry.get("default_model", ""),
|
||||
}
|
||||
extra_body = entry.get("extra_body")
|
||||
if isinstance(extra_body, dict):
|
||||
result["extra_body"] = dict(extra_body)
|
||||
# The v11→v12 migration writes the API mode under the new
|
||||
# ``transport`` field, but hand-edited configs may still
|
||||
# use the legacy ``api_mode`` spelling. Accept both —
|
||||
|
|
@ -496,6 +556,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
"api_key": resolved_api_key,
|
||||
"model": entry.get("default_model", ""),
|
||||
}
|
||||
extra_body = entry.get("extra_body")
|
||||
if isinstance(extra_body, dict):
|
||||
result["extra_body"] = dict(extra_body)
|
||||
api_mode = _parse_api_mode(entry.get("api_mode") or entry.get("transport"))
|
||||
if api_mode:
|
||||
result["api_mode"] = api_mode
|
||||
|
|
@ -539,6 +602,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
result["key_env"] = key_env
|
||||
if provider_key:
|
||||
result["provider_key"] = provider_key
|
||||
extra_body = entry.get("extra_body")
|
||||
if isinstance(extra_body, dict):
|
||||
result["extra_body"] = dict(extra_body)
|
||||
api_mode = _parse_api_mode(entry.get("api_mode"))
|
||||
if api_mode:
|
||||
result["api_mode"] = api_mode
|
||||
|
|
@ -550,6 +616,13 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
return None
|
||||
|
||||
|
||||
def _custom_provider_request_overrides(custom_provider: Dict[str, Any]) -> Dict[str, Any]:
|
||||
extra_body = custom_provider.get("extra_body")
|
||||
if not isinstance(extra_body, dict) or not extra_body:
|
||||
return {}
|
||||
return {"extra_body": dict(extra_body)}
|
||||
|
||||
|
||||
def _resolve_named_custom_runtime(
|
||||
*,
|
||||
requested_provider: str,
|
||||
|
|
@ -582,10 +655,17 @@ def _resolve_named_custom_runtime(
|
|||
if pool_result:
|
||||
pool_result["source"] = "direct-alias"
|
||||
return pool_result
|
||||
_da_is_openai_url = base_url_host_matches(base_url, "openai.com") or base_url_host_matches(base_url, "openai.azure.com")
|
||||
_da_is_openrouter = base_url_host_matches(base_url, "openrouter.ai")
|
||||
api_key_candidates = [
|
||||
(explicit_api_key or "").strip(),
|
||||
os.getenv("OPENAI_API_KEY", "").strip(),
|
||||
os.getenv("OPENROUTER_API_KEY", "").strip(),
|
||||
# Gate env key fallbacks on authoritative hosts (#28660)
|
||||
(os.getenv("OPENAI_API_KEY", "").strip() if _da_is_openai_url else ""),
|
||||
(os.getenv("OPENROUTER_API_KEY", "").strip() if _da_is_openrouter else ""),
|
||||
# Bonus (#28660): derive `<VENDOR>_API_KEY` from the host so users
|
||||
# who set DEEPSEEK_API_KEY / GROQ_API_KEY / MISTRAL_API_KEY get the
|
||||
# intuitive match without configuring `custom_providers` first.
|
||||
_host_derived_api_key(base_url),
|
||||
]
|
||||
api_key = next(
|
||||
(c for c in api_key_candidates if has_usable_secret(c)),
|
||||
|
|
@ -619,14 +699,27 @@ def _resolve_named_custom_runtime(
|
|||
model_name = custom_provider.get("model")
|
||||
if model_name:
|
||||
pool_result["model"] = model_name
|
||||
request_overrides = _custom_provider_request_overrides(custom_provider)
|
||||
if request_overrides:
|
||||
pool_result["request_overrides"] = {
|
||||
**dict(pool_result.get("request_overrides") or {}),
|
||||
**request_overrides,
|
||||
}
|
||||
return pool_result
|
||||
|
||||
_cp_is_openai_url = base_url_host_matches(base_url, "openai.com") or base_url_host_matches(base_url, "openai.azure.com")
|
||||
_cp_is_openrouter = base_url_host_matches(base_url, "openrouter.ai")
|
||||
api_key_candidates = [
|
||||
(explicit_api_key or "").strip(),
|
||||
str(custom_provider.get("api_key", "") or "").strip(),
|
||||
os.getenv(str(custom_provider.get("key_env", "") or "").strip(), "").strip(),
|
||||
os.getenv("OPENAI_API_KEY", "").strip(),
|
||||
os.getenv("OPENROUTER_API_KEY", "").strip(),
|
||||
# Gate provider env keys on their authoritative hosts — sending
|
||||
# OPENAI_API_KEY to a local-llm endpoint leaks credentials (#28660).
|
||||
(os.getenv("OPENAI_API_KEY", "").strip() if _cp_is_openai_url else ""),
|
||||
(os.getenv("OPENROUTER_API_KEY", "").strip() if _cp_is_openrouter else ""),
|
||||
# Bonus (#28660): derive `<VENDOR>_API_KEY` from the host as a final
|
||||
# fallback when key_env wasn't set explicitly.
|
||||
_host_derived_api_key(base_url),
|
||||
]
|
||||
api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "")
|
||||
|
||||
|
|
@ -643,6 +736,9 @@ def _resolve_named_custom_runtime(
|
|||
# provider name differs from the actual model string the API expects.
|
||||
if custom_provider.get("model"):
|
||||
result["model"] = custom_provider["model"]
|
||||
request_overrides = _custom_provider_request_overrides(custom_provider)
|
||||
if request_overrides:
|
||||
result["request_overrides"] = request_overrides
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -707,7 +803,15 @@ def _resolve_openrouter_runtime(
|
|||
# OPENAI_API_KEY so the OpenRouter key doesn't leak to an unrelated
|
||||
# provider (issues #420, #560).
|
||||
_is_openrouter_url = base_url_host_matches(base_url, "openrouter.ai")
|
||||
if _is_openrouter_url:
|
||||
# Also treat explicitly-configured OpenRouter mirrors/proxies as OpenRouter
|
||||
# for key selection — if the user set OPENROUTER_BASE_URL or requested
|
||||
# provider=openrouter explicitly, OPENROUTER_API_KEY should still be used.
|
||||
_is_openrouter_context = _is_openrouter_url or (
|
||||
requested_norm == "openrouter"
|
||||
and (env_openrouter_base_url or base_url == env_openrouter_base_url)
|
||||
and base_url == (env_openrouter_base_url or "").rstrip("/")
|
||||
)
|
||||
if _is_openrouter_context:
|
||||
api_key_candidates = [
|
||||
explicit_api_key,
|
||||
os.getenv("OPENROUTER_API_KEY"),
|
||||
|
|
@ -721,13 +825,24 @@ def _resolve_openrouter_runtime(
|
|||
# "ollama.com" (e.g. http://127.0.0.1/ollama.com/v1) or whose
|
||||
# hostname is a look-alike (ollama.com.attacker.test) must not
|
||||
# receive the Ollama credential. See GHSA-76xc-57q6-vm5m.
|
||||
_is_ollama_url = base_url_host_matches(base_url, "ollama.com")
|
||||
_is_ollama_url = base_url_host_matches(base_url, "ollama.com")
|
||||
_is_openai_url = base_url_host_matches(base_url, "openai.com")
|
||||
_is_openai_azure = base_url_host_matches(base_url, "openai.azure.com")
|
||||
# Gate each provider key on its own host — sending OPENAI_API_KEY or
|
||||
# OPENROUTER_API_KEY to an unrelated custom endpoint (DeepSeek, Groq,
|
||||
# Mistral, …) leaks credentials and causes 401s (issue #28660).
|
||||
# Mirrors the OLLAMA_API_KEY host-gate added in GHSA-76xc-57q6-vm5m.
|
||||
api_key_candidates = [
|
||||
explicit_api_key,
|
||||
(cfg_api_key if use_config_base_url else ""),
|
||||
(os.getenv("OLLAMA_API_KEY") if _is_ollama_url else ""),
|
||||
os.getenv("OPENAI_API_KEY"),
|
||||
os.getenv("OPENROUTER_API_KEY"),
|
||||
(os.getenv("OLLAMA_API_KEY") if _is_ollama_url else ""),
|
||||
(os.getenv("OPENAI_API_KEY") if (_is_openai_url or _is_openai_azure) else ""),
|
||||
(os.getenv("OPENROUTER_API_KEY") if _is_openrouter_url else ""),
|
||||
# Bonus (#28660): derive `<VENDOR>_API_KEY` from the host so users
|
||||
# who set DEEPSEEK_API_KEY / GROQ_API_KEY / MISTRAL_API_KEY get the
|
||||
# intuitive match. Helper returns "" for IPs/loopback and for env
|
||||
# vars already handled by the explicit host-gated paths above.
|
||||
_host_derived_api_key(base_url),
|
||||
]
|
||||
api_key = next(
|
||||
(str(candidate or "").strip() for candidate in api_key_candidates if has_usable_secret(candidate)),
|
||||
|
|
|
|||
|
|
@ -319,12 +319,14 @@ def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
|
|||
c.print("[dim]No skills found in the Skills Hub.[/]\n")
|
||||
return
|
||||
|
||||
# Deduplicate by name, preferring higher trust
|
||||
# Deduplicate by identifier, preferring higher trust.
|
||||
# identifier is always unique per skill; name is not (browse-sh skills from different
|
||||
# sites can share the same task name, e.g. "search-listings" on Airbnb and Booking.com).
|
||||
seen: dict = {}
|
||||
for r in all_results:
|
||||
rank = _TRUST_RANK.get(r.trust_level, 0)
|
||||
if r.name not in seen or rank > _TRUST_RANK.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
if r.identifier not in seen or rank > _TRUST_RANK.get(seen[r.identifier].trust_level, 0):
|
||||
seen[r.identifier] = r
|
||||
deduped = list(seen.values())
|
||||
|
||||
# Sort: official first, then by trust level (desc), then alphabetically
|
||||
|
|
@ -702,8 +704,8 @@ def browse_skills(page: int = 1, page_size: int = 20, source: str = "all") -> di
|
|||
seen: dict = {}
|
||||
for r in all_results:
|
||||
rank = _TRUST_RANK.get(r.trust_level, 0)
|
||||
if r.name not in seen or rank > _TRUST_RANK.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
if r.identifier not in seen or rank > _TRUST_RANK.get(seen[r.identifier].trust_level, 0):
|
||||
seen[r.identifier] = r
|
||||
deduped = list(seen.values())
|
||||
deduped.sort(key=lambda r: (-_TRUST_RANK.get(r.trust_level, 0), r.source != "official", r.name.lower()))
|
||||
total = len(deduped)
|
||||
|
|
|
|||
|
|
@ -235,6 +235,26 @@ def display_hermes_home() -> str:
|
|||
return str(home)
|
||||
|
||||
|
||||
def secure_parent_dir(path: Path) -> None:
|
||||
"""Chmod ``0o700`` on the parent directory of *path*, but only if safe.
|
||||
|
||||
Refuses to chmod ``/`` or any top-level directory (resolved parent with
|
||||
fewer than 3 parts, i.e. ``/`` or any direct child like ``/usr``) to
|
||||
prevent catastrophic host bricking when ``HERMES_HOME`` or other path
|
||||
env vars resolve to an unexpected location.
|
||||
|
||||
See https://github.com/NousResearch/hermes-agent/issues/25821.
|
||||
"""
|
||||
parent = path.parent.resolve()
|
||||
# Refuse root and its direct children (/usr, /home, /var, /tmp, …).
|
||||
if parent == Path("/") or len(parent.parts) < 3:
|
||||
return
|
||||
try:
|
||||
os.chmod(parent, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def get_subprocess_home() -> str | None:
|
||||
"""Return a per-profile HOME directory for subprocesses, or None.
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ T = TypeVar("T")
|
|||
|
||||
DEFAULT_DB_PATH = get_hermes_home() / "state.db"
|
||||
|
||||
SCHEMA_VERSION = 11
|
||||
SCHEMA_VERSION = 12
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WAL-compatibility fallback
|
||||
|
|
@ -237,7 +237,8 @@ CREATE TABLE IF NOT EXISTS messages (
|
|||
reasoning_content TEXT,
|
||||
reasoning_details TEXT,
|
||||
codex_reasoning_items TEXT,
|
||||
codex_message_items TEXT
|
||||
codex_message_items TEXT,
|
||||
platform_message_id TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS state_meta (
|
||||
|
|
@ -572,6 +573,19 @@ class SessionDB:
|
|||
# column gets created here.
|
||||
self._reconcile_columns(cursor)
|
||||
|
||||
# Indexes that reference reconciler-added columns must be created
|
||||
# AFTER _reconcile_columns runs — declaring them in SCHEMA_SQL
|
||||
# makes the initial executescript fail on legacy DBs (the index's
|
||||
# WHERE clause references a column that doesn't exist yet).
|
||||
try:
|
||||
cursor.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_messages_platform_msg_id "
|
||||
"ON messages(session_id, platform_message_id) "
|
||||
"WHERE platform_message_id IS NOT NULL"
|
||||
)
|
||||
except sqlite3.OperationalError as exc:
|
||||
logger.debug("idx_messages_platform_msg_id create skipped: %s", exc)
|
||||
|
||||
# ── Schema version bookkeeping ─────────────────────────────────
|
||||
# Bump to current so future data migrations (if any) can gate on
|
||||
# version. No version-gated column additions remain.
|
||||
|
|
@ -1462,12 +1476,19 @@ class SessionDB:
|
|||
reasoning_details: Any = None,
|
||||
codex_reasoning_items: Any = None,
|
||||
codex_message_items: Any = None,
|
||||
platform_message_id: str = None,
|
||||
) -> int:
|
||||
"""
|
||||
Append a message to a session. Returns the message row ID.
|
||||
|
||||
Also increments the session's message_count (and tool_call_count
|
||||
if role is 'tool' or tool_calls is present).
|
||||
|
||||
``platform_message_id`` is the external messaging platform's own
|
||||
message ID (e.g. Telegram update_id, Yuanbao msg_id). It is
|
||||
independent of the SQLite autoincrement primary key and is used by
|
||||
platform-specific flows like yuanbao's recall guard to redact a
|
||||
message by its platform-side identifier.
|
||||
"""
|
||||
# Serialize structured fields to JSON before entering the write txn
|
||||
reasoning_details_json = (
|
||||
|
|
@ -1497,8 +1518,8 @@ class SessionDB:
|
|||
"""INSERT INTO messages (session_id, role, content, tool_call_id,
|
||||
tool_calls, tool_name, timestamp, token_count, finish_reason,
|
||||
reasoning, reasoning_content, reasoning_details, codex_reasoning_items,
|
||||
codex_message_items)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
codex_message_items, platform_message_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
session_id,
|
||||
role,
|
||||
|
|
@ -1514,6 +1535,7 @@ class SessionDB:
|
|||
reasoning_details_json,
|
||||
codex_items_json,
|
||||
codex_message_items_json,
|
||||
platform_message_id,
|
||||
),
|
||||
)
|
||||
msg_id = cursor.lastrowid
|
||||
|
|
@ -1575,13 +1597,18 @@ class SessionDB:
|
|||
json.dumps(codex_message_items) if codex_message_items else None
|
||||
)
|
||||
tool_calls_json = json.dumps(tool_calls) if tool_calls else None
|
||||
# Accept either `platform_message_id` (new explicit name) or
|
||||
# `message_id` (yuanbao's existing convention on message dicts).
|
||||
platform_msg_id = (
|
||||
msg.get("platform_message_id") or msg.get("message_id")
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, tool_call_id,
|
||||
tool_calls, tool_name, timestamp, token_count, finish_reason,
|
||||
reasoning, reasoning_content, reasoning_details, codex_reasoning_items,
|
||||
codex_message_items)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
codex_message_items, platform_message_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
session_id,
|
||||
role,
|
||||
|
|
@ -1597,6 +1624,7 @@ class SessionDB:
|
|||
reasoning_details_json,
|
||||
codex_items_json,
|
||||
codex_message_items_json,
|
||||
platform_msg_id,
|
||||
),
|
||||
)
|
||||
total_messages += 1
|
||||
|
|
@ -1914,7 +1942,7 @@ class SessionDB:
|
|||
rows = self._conn.execute(
|
||||
"SELECT role, content, tool_call_id, tool_calls, tool_name, "
|
||||
"finish_reason, reasoning, reasoning_content, reasoning_details, "
|
||||
"codex_reasoning_items, codex_message_items "
|
||||
"codex_reasoning_items, codex_message_items, platform_message_id "
|
||||
f"FROM messages WHERE session_id IN ({placeholders}) ORDER BY id",
|
||||
tuple(session_ids),
|
||||
).fetchall()
|
||||
|
|
@ -1935,6 +1963,13 @@ class SessionDB:
|
|||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning("Failed to deserialize tool_calls in conversation replay, falling back to []")
|
||||
msg["tool_calls"] = []
|
||||
# Surface the platform-side message id (e.g. yuanbao msg_id,
|
||||
# telegram update_id) so platform-specific flows like recall
|
||||
# can match by external identifier instead of having to fall
|
||||
# back to content-match heuristics. Exposed as ``message_id``
|
||||
# for backward compatibility with the JSONL transcript shape.
|
||||
if row["platform_message_id"]:
|
||||
msg["message_id"] = row["platform_message_id"]
|
||||
# Restore reasoning fields on assistant messages so providers
|
||||
# that replay reasoning (OpenRouter, OpenAI, Nous) receive
|
||||
# coherent multi-turn reasoning context.
|
||||
|
|
|
|||
|
|
@ -17,6 +17,11 @@
|
|||
openssh,
|
||||
ffmpeg,
|
||||
tirith,
|
||||
|
||||
# linux-only deps
|
||||
wl-clipboard,
|
||||
xclip,
|
||||
|
||||
# Flake inputs — passed explicitly by packages.nix and overlays.nix
|
||||
uv2nix,
|
||||
pyproject-nix,
|
||||
|
|
@ -69,6 +74,10 @@ let
|
|||
openssh
|
||||
ffmpeg
|
||||
tirith
|
||||
]
|
||||
++ lib.optionals stdenv.isLinux [
|
||||
wl-clipboard
|
||||
xclip
|
||||
];
|
||||
|
||||
runtimePath = lib.makeBinPath runtimeDeps;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ let
|
|||
src = ../ui-tui;
|
||||
npmDeps = pkgs.fetchNpmDeps {
|
||||
inherit src;
|
||||
hash = "sha256-dNL/J4tyQQ7Ji3xfIE5b5Jdi6rQyCFjqYpzLYftJVdc=";
|
||||
hash = "sha256-F6/MzZOWc0zhW9mIfnaY+PrllPvJcsA/OdFdEM+NpLY=";
|
||||
};
|
||||
|
||||
npm = hermesNpmLib.mkNpmPassthru { folder = "ui-tui"; attr = "tui"; pname = "hermes-tui"; };
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ modal = ["modal==1.3.4"]
|
|||
daytona = ["daytona==0.155.0"]
|
||||
vercel = ["vercel==0.5.7"]
|
||||
hindsight = ["hindsight-client==0.6.1"]
|
||||
dev = ["debugpy==1.8.20", "pytest==9.0.2", "pytest-asyncio==1.3.0", "pytest-xdist==3.8.0", "pytest-split==0.11.0", "pytest-timeout==2.4.0", "mcp==1.26.0", "ty==0.0.21", "ruff==0.15.10"]
|
||||
dev = ["debugpy==1.8.20", "pytest==9.0.2", "pytest-asyncio==1.3.0", "pytest-timeout==2.4.0", "mcp==1.26.0", "ty==0.0.21", "ruff==0.15.10"]
|
||||
messaging = ["python-telegram-bot[webhooks]==22.6", "discord.py[voice]==2.7.1", "aiohttp==3.13.3", "brotlicffi==1.2.0.1", "slack-bolt==1.27.0", "slack-sdk==3.40.1", "qrcode==7.4.2"]
|
||||
cron = [] # croniter is now a core dependency; this extra kept for back-compat
|
||||
slack = ["slack-bolt==1.27.0", "slack-sdk==3.40.1", "aiohttp==3.13.3"]
|
||||
|
|
@ -238,16 +238,12 @@ markers = [
|
|||
"integration: marks tests requiring external services (API keys, Modal, etc.)",
|
||||
"real_concurrent_gate: opt out of the autouse stub that disables _detect_concurrent_hermes_instances",
|
||||
]
|
||||
# pytest-timeout: per-test 60s hard cap with thread method.
|
||||
# Discovered May 2026: the suite reliably hangs at ~96% on full runs even
|
||||
# though every individual test completes in <30s. Root cause is leaked
|
||||
# threads / atexit handlers accumulating across thousands of tests until
|
||||
# something deadlocks at session teardown. Adding pytest-timeout (with
|
||||
# thread method, which forces an interrupt into the test thread) breaks
|
||||
# the deadlock — the suite then completes cleanly. The 60s cap is large
|
||||
# enough that no legitimate test trips it; if a test exceeds it that's a
|
||||
# real bug worth surfacing as a Timeout failure.
|
||||
addopts = "-m 'not integration' -n auto --timeout=30 --timeout-method=signal"
|
||||
# pytest-timeout: per-test 30s hard cap with signal method.
|
||||
# This is the fallback inside each per-file pytest subprocess (see
|
||||
# scripts/run_tests_parallel.py). Per-file isolation gives every test
|
||||
# file a fresh Python interpreter; pytest-timeout catches Python-level
|
||||
# hangs within a file.
|
||||
addopts = "-m 'not integration' --timeout=30 --timeout-method=signal"
|
||||
|
||||
[tool.ty.environment]
|
||||
python-version = "3.13"
|
||||
|
|
|
|||
18
run_agent.py
18
run_agent.py
|
|
@ -3218,17 +3218,21 @@ class AIAgent:
|
|||
Used to decide whether to strip image content parts from API-bound
|
||||
messages (for non-vision models) or let the provider adapter handle
|
||||
them natively (for vision-capable models).
|
||||
|
||||
Resolution order (see ``agent.image_routing._supports_vision_override``):
|
||||
1. ``model.supports_vision`` (top-level, single-model shortcut)
|
||||
2. ``providers.<provider>.models.<model>.supports_vision``
|
||||
3. models.dev capability lookup
|
||||
Custom/local models absent from models.dev would otherwise be
|
||||
misclassified as non-vision and have their images stripped.
|
||||
"""
|
||||
try:
|
||||
from agent.models_dev import get_model_capabilities
|
||||
from hermes_cli.config import load_config
|
||||
from agent.image_routing import _lookup_supports_vision
|
||||
cfg = load_config()
|
||||
provider = (getattr(self, "provider", "") or "").strip()
|
||||
model = (getattr(self, "model", "") or "").strip()
|
||||
if not provider or not model:
|
||||
return False
|
||||
caps = get_model_capabilities(provider, model)
|
||||
if caps is None:
|
||||
return False
|
||||
return bool(caps.supports_vision)
|
||||
return _lookup_supports_vision(provider, model, cfg) is True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ ACP_REGISTRY_MANIFEST = REPO_ROOT / "acp_registry" / "agent.json"
|
|||
AUTHOR_MAP = {
|
||||
# teknium (multiple emails)
|
||||
"teknium1@gmail.com": "teknium1",
|
||||
"me@promplate.dev": "CNSeniorious000",
|
||||
"erhanyasarx@gmail.com": "erhnysr",
|
||||
"30366221+WorldWriter@users.noreply.github.com": "WorldWriter",
|
||||
"dafeng@DafengdeMacBook-Pro.local": "WorldWriter",
|
||||
"anadi.jaggia@gmail.com": "Jaggia",
|
||||
|
|
@ -58,6 +60,7 @@ AUTHOR_MAP = {
|
|||
"altriatree@gmail.com": "TruaShamu",
|
||||
"m@mobrienv.dev": "mikeyobrien",
|
||||
"saeed919@pm.me": "falasi",
|
||||
"omar@techdeveloper.site": "nycomar",
|
||||
"qiyin.zuo@pcitc.com": "qiyin-code",
|
||||
"mr.aashiz@gmail.com": "aashizpoudel",
|
||||
"70629228+shaun0927@users.noreply.github.com": "shaun0927",
|
||||
|
|
@ -74,6 +77,7 @@ AUTHOR_MAP = {
|
|||
"108427749+buntingszn@users.noreply.github.com": "buntingszn",
|
||||
"yanglongwei06@gmail.com": "Alex-yang00",
|
||||
"teknium@nousresearch.com": "teknium1",
|
||||
"markuscontasul@gmail.com": "Glucksberg",
|
||||
"piyushvp1@gmail.com": "thelumiereguy",
|
||||
"dskwelmcy@163.com": "dskwe",
|
||||
"421774554@qq.com": "wuli666",
|
||||
|
|
@ -714,6 +718,7 @@ AUTHOR_MAP = {
|
|||
"9219265+cresslank@users.noreply.github.com": "cresslank",
|
||||
"trevmanthony@gmail.com": "trevthefoolish",
|
||||
"ziliangpeng@users.noreply.github.com": "ziliangpeng",
|
||||
"ziliangdotme@gmail.com": "ziliangpeng",
|
||||
"centripetal-star@users.noreply.github.com": "centripetal-star",
|
||||
"LeonSGP43@users.noreply.github.com": "LeonSGP43",
|
||||
"154585401+LeonSGP43@users.noreply.github.com": "LeonSGP43",
|
||||
|
|
|
|||
|
|
@ -3,29 +3,36 @@
|
|||
# `pytest` directly to guarantee your local run matches CI behavior.
|
||||
#
|
||||
# What this script enforces:
|
||||
# * -n 4 xdist workers (CI has 4 cores; -n auto diverges locally)
|
||||
# * Per-file isolation via scripts/run_tests_parallel.py — each test
|
||||
# file runs in its own freshly-spawned `python -m pytest <file>`
|
||||
# subprocess. No xdist, no shared workers, no module-level leakage
|
||||
# between files.
|
||||
# * TZ=UTC, LANG=C.UTF-8, PYTHONHASHSEED=0 (deterministic)
|
||||
# * Credential env vars blanked (conftest.py also does this, but this
|
||||
# is belt-and-suspenders for anyone running `pytest` outside of
|
||||
# our conftest path — e.g. calling pytest on a single file)
|
||||
# * Proper venv activation
|
||||
# * Env vars blanked (conftest.py also does this, but this
|
||||
# is belt-and-suspenders for anyone running pytest outside our
|
||||
# conftest path — e.g. on a single file)
|
||||
# * Proper venv activation (probes .venv, venv, then ~/.hermes/...)
|
||||
#
|
||||
# Usage:
|
||||
# scripts/run_tests.sh # full suite
|
||||
# scripts/run_tests.sh tests/agent/ # one directory
|
||||
# scripts/run_tests.sh tests/agent/test_foo.py::TestClass::test_method
|
||||
# scripts/run_tests.sh --tb=long -v # pass-through pytest args
|
||||
# scripts/run_tests.sh # full suite
|
||||
# scripts/run_tests.sh -j 4 # cap parallelism
|
||||
# scripts/run_tests.sh tests/agent/ # discover only here
|
||||
# scripts/run_tests.sh tests/agent/ tests/acp/ # multiple roots
|
||||
# scripts/run_tests.sh tests/foo.py # single file
|
||||
# scripts/run_tests.sh tests/foo.py -- --tb=long # path + pytest args
|
||||
# scripts/run_tests.sh -- -v --tb=long # pytest args only
|
||||
#
|
||||
# Everything after a literal '--' is passed through to each per-file
|
||||
# pytest invocation. Positional path arguments before '--' override
|
||||
# the default discovery root (tests/).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# ── Locate repo root ────────────────────────────────────────────────────────
|
||||
# Works whether this is the main checkout or a worktree.
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
REPO_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
|
||||
# ── Activate venv ───────────────────────────────────────────────────────────
|
||||
# Prefer a .venv in the current tree, fall back to the main checkout's venv
|
||||
# (useful for worktrees where we don't always duplicate the venv).
|
||||
VENV=""
|
||||
for candidate in "$REPO_ROOT/.venv" "$REPO_ROOT/venv" "$HOME/.hermes/hermes-agent/venv"; do
|
||||
if [ -f "$candidate/bin/activate" ]; then
|
||||
|
|
@ -41,94 +48,31 @@ fi
|
|||
|
||||
PYTHON="$VENV/bin/python"
|
||||
|
||||
# ── Ensure pytest-split is installed (required for shard-equivalent runs) ──
|
||||
if ! "$PYTHON" -c "import pytest_split" 2>/dev/null; then
|
||||
echo "→ installing pytest-split into $VENV"
|
||||
if command -v uv >/dev/null 2>&1; then
|
||||
uv pip install --python "$PYTHON" --quiet "pytest-split>=0.9,<1"
|
||||
elif "$PYTHON" -m pip --version >/dev/null 2>&1; then
|
||||
"$PYTHON" -m pip install --quiet "pytest-split>=0.9,<1"
|
||||
else
|
||||
echo "error: neither uv nor pip is available in $VENV — pytest-split is missing" >&2
|
||||
echo " fix: run uv pip install -e \".[dev]\" from $REPO_ROOT" >&2
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── Hermetic environment ────────────────────────────────────────────────────
|
||||
# Mirror what CI does in .github/workflows/tests.yml + what conftest.py does.
|
||||
# Unset every credential-shaped var currently in the environment.
|
||||
while IFS='=' read -r name _; do
|
||||
case "$name" in
|
||||
*_API_KEY|*_TOKEN|*_SECRET|*_PASSWORD|*_CREDENTIALS|*_ACCESS_KEY| \
|
||||
*_SECRET_ACCESS_KEY|*_PRIVATE_KEY|*_OAUTH_TOKEN|*_WEBHOOK_SECRET| \
|
||||
*_ENCRYPT_KEY|*_APP_SECRET|*_CLIENT_SECRET|*_CORP_SECRET|*_AES_KEY| \
|
||||
AWS_ACCESS_KEY_ID|AWS_SECRET_ACCESS_KEY|AWS_SESSION_TOKEN|FAL_KEY| \
|
||||
GH_TOKEN|GITHUB_TOKEN)
|
||||
unset "$name"
|
||||
;;
|
||||
esac
|
||||
done < <(env)
|
||||
|
||||
# Unset HERMES_* behavioral vars too.
|
||||
unset HERMES_YOLO_MODE HERMES_INTERACTIVE HERMES_QUIET HERMES_TOOL_PROGRESS \
|
||||
HERMES_TOOL_PROGRESS_MODE HERMES_MAX_ITERATIONS HERMES_SESSION_PLATFORM \
|
||||
HERMES_SESSION_CHAT_ID HERMES_SESSION_CHAT_NAME HERMES_SESSION_THREAD_ID \
|
||||
HERMES_SESSION_SOURCE HERMES_SESSION_KEY HERMES_GATEWAY_SESSION \
|
||||
HERMES_CRON_SESSION \
|
||||
HERMES_PLATFORM HERMES_INFERENCE_PROVIDER HERMES_MANAGED HERMES_DEV \
|
||||
HERMES_CONTAINER HERMES_EPHEMERAL_SYSTEM_PROMPT HERMES_TIMEZONE \
|
||||
HERMES_REDACT_SECRETS HERMES_BACKGROUND_NOTIFICATIONS HERMES_EXEC_ASK \
|
||||
HERMES_HOME_MODE 2>/dev/null || true
|
||||
|
||||
# Pin deterministic runtime.
|
||||
export TZ=UTC
|
||||
export LANG=C.UTF-8
|
||||
export LC_ALL=C.UTF-8
|
||||
export PYTHONHASHSEED=0
|
||||
|
||||
# ── Live-gateway test guard (developer machines) ────────────────────────────
|
||||
# If a system-wide hermes pytest_live_guard plugin is installed at
|
||||
# $HOME/.hermes/pytest_live_guard.py, force-load it here so every test run
|
||||
# from this script gets the protection regardless of which worktree is
|
||||
# checked out (in-tree tests/conftest.py guard may be missing on stale
|
||||
# branches). Harmless on CI / fresh machines that don't have the file.
|
||||
# ── Live-gateway plugin (computed before we drop env) ───────────────────────
|
||||
EXTRA_PYTHONPATH=""
|
||||
EXTRA_PYTEST_PLUGINS=""
|
||||
if [ -f "$HOME/.hermes/pytest_live_guard.py" ]; then
|
||||
case ":${PYTHONPATH:-}:" in
|
||||
*":$HOME/.hermes:"*) ;;
|
||||
*) export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$HOME/.hermes" ;;
|
||||
esac
|
||||
if [[ ",${PYTEST_PLUGINS:-}," != *,pytest_live_guard,* ]]; then
|
||||
export PYTEST_PLUGINS="${PYTEST_PLUGINS:+$PYTEST_PLUGINS,}pytest_live_guard"
|
||||
fi
|
||||
EXTRA_PYTHONPATH="$HOME/.hermes"
|
||||
EXTRA_PYTEST_PLUGINS="pytest_live_guard"
|
||||
fi
|
||||
|
||||
# ── Worker count ────────────────────────────────────────────────────────────
|
||||
# CI uses `-n auto` on ubuntu-latest which gives 4 workers. A 20-core
|
||||
# workstation with `-n auto` gets 20 workers and exposes test-ordering
|
||||
# flakes that CI will never see. Pin to 4 so local matches CI.
|
||||
WORKERS="${HERMES_TEST_WORKERS:-4}"
|
||||
|
||||
# ── Run pytest ──────────────────────────────────────────────────────────────
|
||||
# ── Run in hermetic env ──────────────────────────────────────────────────────
|
||||
# env -i: start with empty environment, opt-in only what we need.
|
||||
# No credential var can leak — you'd have to explicitly add it here.
|
||||
echo "▶ running per-file parallel test suite via run_tests_parallel.py"
|
||||
echo " (TZ=UTC LANG=C.UTF-8 PYTHONHASHSEED=0; clean env)"
|
||||
|
||||
cd "$REPO_ROOT"
|
||||
|
||||
# If the first argument starts with `-` treat all args as pytest flags;
|
||||
# otherwise treat them as test paths.
|
||||
ARGS=("$@")
|
||||
|
||||
echo "▶ running pytest with $WORKERS workers, hermetic env, in $REPO_ROOT"
|
||||
echo " (TZ=UTC LANG=C.UTF-8 PYTHONHASHSEED=0; all credential env vars unset)"
|
||||
|
||||
# -o "addopts=" clears pyproject.toml's `-n auto` so our -n wins.
|
||||
# We re-add --timeout/--timeout-method here because pyproject.toml's
|
||||
# addopts is wiped above. The 60s cap is essential: see pyproject.toml
|
||||
# for why (suite deadlocks at session teardown without it).
|
||||
exec "$PYTHON" -m pytest \
|
||||
-o "addopts=" \
|
||||
-n "$WORKERS" \
|
||||
--timeout=30 \
|
||||
--timeout-method=signal \
|
||||
--ignore=tests/integration \
|
||||
--ignore=tests/e2e \
|
||||
-m "not integration" \
|
||||
"${ARGS[@]}"
|
||||
exec env -i \
|
||||
PATH="$PATH" \
|
||||
HOME="$HOME" \
|
||||
TZ=UTC \
|
||||
LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8 \
|
||||
PYTHONHASHSEED=0 \
|
||||
${EXTRA_PYTHONPATH:+PYTHONPATH="$EXTRA_PYTHONPATH"} \
|
||||
${EXTRA_PYTEST_PLUGINS:+PYTEST_PLUGINS="$EXTRA_PYTEST_PLUGINS"} \
|
||||
"$PYTHON" "$SCRIPT_DIR/run_tests_parallel.py" "$@"
|
||||
|
|
|
|||
650
scripts/run_tests_parallel.py
Executable file
650
scripts/run_tests_parallel.py
Executable file
|
|
@ -0,0 +1,650 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Per-file parallel test runner.
|
||||
|
||||
The minimum-viable replacement for pytest-xdist + a subprocess-isolation
|
||||
plugin. Discovers test files under ``tests/`` (excluding integration/e2e
|
||||
unless explicitly requested), then runs one ``python -m pytest <file>``
|
||||
subprocess per file, with bounded parallelism (default: ``os.cpu_count()``).
|
||||
|
||||
Why per-file rather than per-test?
|
||||
Per-test spawn overhead (~250ms × 17k tests = 70min CPU minimum)
|
||||
swamped the actual work. Per-file spawn (~250ms × ~850 files = ~3.5min)
|
||||
fits in the budget while still giving every file a fresh Python
|
||||
interpreter — the only isolation boundary that actually matters
|
||||
(cross-file module-level state leakage was the original flake source;
|
||||
intra-file state is the test author's responsibility).
|
||||
|
||||
Why drop xdist entirely?
|
||||
xdist's persistent workers accumulate state across files, which is
|
||||
exactly the leakage we wanted to fix. xdist also adds complexity
|
||||
(loadfile vs loadscope, --max-worker-restart, internal control plane)
|
||||
that we don't need when the unit of work is "run pytest on one file".
|
||||
A subprocess.Popen pool gated by a semaphore is ~60 lines and does
|
||||
the job.
|
||||
|
||||
Usage:
|
||||
python scripts/run_tests_parallel.py [pytest_args...]
|
||||
|
||||
Common pytest args pass through (e.g. ``-v``, ``-x``, ``--tb=long``,
|
||||
``-k 'pattern'``, ``--lf``).
|
||||
|
||||
Environment:
|
||||
HERMES_TEST_WORKERS Override worker count (default: os.cpu_count())
|
||||
HERMES_TEST_PATHS Override discovery roots (colon-sep, default: 'tests')
|
||||
|
||||
Exit code: 0 if every file's pytest exited 0; 1 otherwise.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, Future
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
# Default test discovery roots.
|
||||
_DEFAULT_ROOTS = ["tests"]
|
||||
|
||||
# Directories to skip during discovery — the e2e + integration suites
|
||||
# require real services and are run separately. Match exactly the
|
||||
# ``--ignore=`` flags the previous CI command used.
|
||||
_SKIP_PARTS = {"integration", "e2e"}
|
||||
|
||||
# Per-file wall-clock cap. Generous default — pytest-timeout still
|
||||
# enforces per-test caps inside each subprocess; this is just an outer
|
||||
# safety net so a single hung file can't stall the whole suite. Override
|
||||
# via --file-timeout or HERMES_TEST_FILE_TIMEOUT.
|
||||
_DEFAULT_FILE_TIMEOUT_SECONDS = 600.0 # 10 minutes
|
||||
|
||||
|
||||
def _count_tests(
|
||||
files: List[Path], repo_root: Path, pytest_passthrough: List[str]
|
||||
) -> dict[Path, int]:
|
||||
"""Run ``pytest --co -q`` once to count individual tests per file.
|
||||
|
||||
Returns a mapping ``{file_path: test_count}``. Files with zero
|
||||
collected tests are omitted from the dict (not an error — e.g. the
|
||||
file only defines fixtures / conftest helpers).
|
||||
|
||||
This is a single subprocess call (~2-5s for ~1k files) that gives
|
||||
us the total test count for the discovery announcement and
|
||||
per-file counts for the progress lines.
|
||||
|
||||
``--ignore`` flags for directories in ``_SKIP_PARTS`` are added
|
||||
automatically so that pytest's own collection machinery (conftest
|
||||
walking, directory traversal) doesn't pull in tests we intend to
|
||||
skip — matching what the per-file runs will actually execute.
|
||||
"""
|
||||
# Build --ignore flags for skipped dirs so the --co collection
|
||||
# mirrors what we'll actually run (not what pytest might find via
|
||||
# conftest walking or directory traversal).
|
||||
ignore_args: List[str] = []
|
||||
for root in [repo_root / p for p in _DEFAULT_ROOTS]:
|
||||
for part in _SKIP_PARTS:
|
||||
d = root / part
|
||||
if d.is_dir():
|
||||
ignore_args.extend(["--ignore", str(d)])
|
||||
|
||||
cmd = [
|
||||
sys.executable, "-m", "pytest",
|
||||
"--co", "-q",
|
||||
*ignore_args,
|
||||
*[str(f) for f in files],
|
||||
*pytest_passthrough,
|
||||
]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=repo_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, OSError):
|
||||
return {}
|
||||
|
||||
counts: dict[Path, int] = {}
|
||||
for line in result.stdout.splitlines():
|
||||
# Lines look like: tests/acp/test_auth.py::TestClass::test_name
|
||||
if "::" not in line:
|
||||
continue
|
||||
file_part = line.split("::", 1)[0]
|
||||
key = repo_root / file_part
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def _discover_files(roots: List[Path]) -> List[Path]:
|
||||
"""Return every ``test_*.py`` under the given roots (sorted).
|
||||
|
||||
Roots may be directories (recursed for ``test_*.py``) or explicit
|
||||
``.py`` files (included as-is, even if they don't match the
|
||||
``test_*`` prefix — caller knows what they want).
|
||||
|
||||
Exclude any file whose path contains a component in ``_SKIP_PARTS``,
|
||||
UNLESS the user explicitly named it as a root (in which case the
|
||||
user's intent overrides the skip filter).
|
||||
"""
|
||||
seen: set[Path] = set()
|
||||
out: List[Path] = []
|
||||
for root in roots:
|
||||
if not root.exists():
|
||||
continue
|
||||
if root.is_file():
|
||||
# Explicit file: include it as-is, skip the _SKIP_PARTS filter
|
||||
# since the user named it directly.
|
||||
real = root.resolve()
|
||||
if real not in seen:
|
||||
seen.add(real)
|
||||
out.append(root)
|
||||
continue
|
||||
for path in root.rglob("test_*.py"):
|
||||
if any(part in _SKIP_PARTS for part in path.parts):
|
||||
continue
|
||||
real = path.resolve()
|
||||
if real in seen:
|
||||
continue
|
||||
seen.add(real)
|
||||
out.append(path)
|
||||
return sorted(out)
|
||||
|
||||
|
||||
def _kill_tree(proc: "subprocess.Popen", pgid: int | None = None) -> None:
|
||||
"""Kill the pytest subprocess and every descendant it spawned.
|
||||
|
||||
A test run can spin up uvicorn servers, async runtimes, or other
|
||||
long-running grandchildren that survive the pytest subprocess exit
|
||||
if we don't kill the whole tree. ``subprocess.Popen.kill()`` only
|
||||
targets the immediate child; grandchildren reparent to PID 1
|
||||
(Linux) / get adopted by services.exe (Windows) and leak.
|
||||
|
||||
POSIX: the caller must pass ``pgid`` — the process group id captured
|
||||
immediately after Popen (via ``os.getpgid(proc.pid)``). We can't
|
||||
look it up here in the happy path because by the time we get
|
||||
called the leader process has already been reaped and its pid is
|
||||
gone from the kernel's process table, even though descendants in
|
||||
the group are still alive. SIGKILL'ing the captured pgid takes out
|
||||
everything in that group atomically.
|
||||
|
||||
Windows: ``taskkill /F /T /PID`` walks the recorded ppid chain and
|
||||
terminates the whole tree, even when the root has already exited.
|
||||
|
||||
Why not psutil: psutil walks the parent-child tree, but in the
|
||||
happy path the root has already been reaped so ``psutil.Process(pid)``
|
||||
can't find it; grandchildren reparented to PID 1 are also
|
||||
unreachable by tree walk at that point. The platform-native
|
||||
primitives (process groups / taskkill) handle both cases correctly
|
||||
without an extra abstraction layer.
|
||||
"""
|
||||
if proc.pid is None:
|
||||
return
|
||||
|
||||
if sys.platform == "win32":
|
||||
try:
|
||||
|
||||
subprocess.run(
|
||||
["taskkill", "/F", "/T", "/PID", str(proc.pid)],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
timeout=10,
|
||||
) # windows-footgun: ok
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
pass
|
||||
else:
|
||||
# POSIX: kill the captured pgid. Local-import signal so the
|
||||
# SIGKILL attribute is never referenced on Windows.
|
||||
if pgid is not None:
|
||||
try:
|
||||
import signal as _signal
|
||||
os.killpg(pgid, _signal.SIGKILL) # windows-footgun: ok
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
|
||||
# Belt-and-suspenders: ensure subprocess.communicate() sees the exit.
|
||||
try:
|
||||
proc.kill()
|
||||
except (ProcessLookupError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def _run_one_file(
|
||||
file: Path,
|
||||
pytest_args: List[str],
|
||||
repo_root: Path,
|
||||
file_timeout: float,
|
||||
) -> Tuple[Path, int, str, dict[str, int]]:
|
||||
"""Run ``python -m pytest <file> <pytest_args>`` in a fresh subprocess.
|
||||
|
||||
Returns (file, returncode, captured_combined_output, summary_counts).
|
||||
|
||||
``summary_counts`` is the result of ``_parse_pytest_summary(output)`` —
|
||||
|
||||
pytest exit codes (https://docs.pytest.org/en/stable/reference/exit-codes.html):
|
||||
0 = all tests passed
|
||||
1 = some tests failed
|
||||
2 = test execution interrupted
|
||||
3 = internal error
|
||||
4 = pytest CLI usage error
|
||||
5 = no tests collected
|
||||
|
||||
We treat exit 5 as a pass: it just means every test in the file was
|
||||
skipped or filtered by a marker (e.g. ``-m 'not integration'`` skips
|
||||
files where every test is marked integration). That's intentional and
|
||||
not a failure mode.
|
||||
|
||||
On per-file timeout (``file_timeout`` seconds) or any other exception
|
||||
during ``communicate()``, we kill the whole process group / process
|
||||
tree so grandchildren (uvicorn servers, async runtimes, etc.) do not
|
||||
orphan onto PID 1. The pytest-timeout plugin enforces per-test
|
||||
timeouts inside the subprocess; this outer timeout exists only to
|
||||
bound a pathologically slow or hung file as a whole.
|
||||
"""
|
||||
cmd = [sys.executable, "-m", "pytest", str(file), *pytest_args]
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=repo_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
# POSIX: place the child at the head of its own process group so
|
||||
# _kill_tree can SIGKILL the group atomically.
|
||||
# Windows: this maps to CREATE_NEW_PROCESS_GROUP in CPython 3.12+;
|
||||
# _kill_tree handles the Windows path via taskkill /F /T.
|
||||
start_new_session=True,
|
||||
)
|
||||
|
||||
# Capture the pgid NOW, before the leader can exit and be reaped.
|
||||
# Once the leader is reaped, os.getpgid(proc.pid) raises
|
||||
# ProcessLookupError even though grandchildren in that group are
|
||||
# still alive — defeating the whole cleanup. None on Windows where
|
||||
# the pgid concept doesn't apply (taskkill walks ppid chain instead).
|
||||
pgid: int | None = None
|
||||
if sys.platform != "win32":
|
||||
try:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
# Astonishingly fast child? Already dead. _kill_tree's
|
||||
# fallback will handle this case as a no-op.
|
||||
pgid = None
|
||||
|
||||
try:
|
||||
output, _ = proc.communicate(timeout=file_timeout)
|
||||
rc = proc.returncode
|
||||
except subprocess.TimeoutExpired:
|
||||
_kill_tree(proc, pgid=pgid)
|
||||
# Drain whatever the child wrote before we killed it so we have
|
||||
# something to surface in the failure dump.
|
||||
try:
|
||||
output, _ = proc.communicate(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
output = "(file timeout exceeded; output unavailable)"
|
||||
rc = 124 # de facto convention for "killed by timeout".
|
||||
output = (
|
||||
f"(per-file timeout: {file_timeout:.0f}s exceeded; "
|
||||
f"process tree SIGKILL'd)\n{output}"
|
||||
)
|
||||
except BaseException:
|
||||
# KeyboardInterrupt / runner crash — make sure no zombie
|
||||
# grandchildren outlive us.
|
||||
_kill_tree(proc, pgid=pgid)
|
||||
raise
|
||||
else:
|
||||
# Happy path: pytest exited on its own. The child process already
|
||||
# cleaned up its grandchildren if it's well-behaved, but
|
||||
# well-behaved is not universal — kill the group anyway. Already-
|
||||
# dead processes are a no-op.
|
||||
_kill_tree(proc, pgid=pgid)
|
||||
|
||||
if rc == 5:
|
||||
# No tests collected — every test in the file was filtered out.
|
||||
# Treat as a pass; surface info in a slightly distinct status
|
||||
# so the operator can spot it.
|
||||
rc = 0
|
||||
summary = _parse_pytest_summary(output)
|
||||
return file, rc, output, summary
|
||||
|
||||
|
||||
def _parse_pytest_summary(output: str) -> dict[str, int]:
|
||||
"""Extract per-file test pass/fail/skip counts from pytest output.
|
||||
|
||||
pytest prints a summary line like ``12 passed, 3 skipped, 1 failed in 2.1s``
|
||||
as the last non-empty line before the short test summary. We scrape that
|
||||
line for the individual counts so the progress display can show test-level
|
||||
granularity instead of just file-level pass/fail.
|
||||
|
||||
Returns a dict with keys ``passed``, ``failed``, ``skipped``, ``errors``,
|
||||
``xfailed``, ``xpassed`` (only keys found in the output are present).
|
||||
"""
|
||||
import re
|
||||
|
||||
result: dict[str, int] = {}
|
||||
# Walk backwards from the end — the summary line is always near the tail.
|
||||
for line in reversed(output.splitlines()):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Match "N passed", "N failed", "N skipped", "N errors", "N xfailed", "N xpassed"
|
||||
for m in re.finditer(r"(\d+)\s+(passed|failed|skipped|errors|xfailed|xpassed)", line):
|
||||
result[m.group(2)] = int(m.group(1))
|
||||
# Also match "N error" (singular — pytest uses this sometimes).
|
||||
for m in re.finditer(r"(\d+)\s+error\b", line):
|
||||
result.setdefault("errors", result.get("errors", 0) + int(m.group(1)))
|
||||
if result:
|
||||
# Found the counts line — done.
|
||||
break
|
||||
# Stop at the short test summary header (if any) — everything above
|
||||
# that is individual failure details, not the counts line.
|
||||
if line.startswith("FAILED") or line.startswith("SHORT TEST SUMMARY"):
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
def _format_file(file: Path, repo_root: Path) -> str:
|
||||
"""Render a test-file path for display: strip the repo-root prefix
|
||||
when possible so output reads ``tests/acp/test_auth.py`` instead of
|
||||
``/home/runner/work/hermes-agent/hermes-agent/tests/acp/test_auth.py``.
|
||||
|
||||
Falls back to the absolute path for anything outside the repo root.
|
||||
"""
|
||||
try:
|
||||
return str(file.resolve().relative_to(repo_root.resolve()))
|
||||
except ValueError:
|
||||
return str(file)
|
||||
|
||||
|
||||
def _print_progress(
|
||||
tests_done: int,
|
||||
total_tests: int,
|
||||
file: Path,
|
||||
rc: int,
|
||||
dur: float,
|
||||
repo_root: Path,
|
||||
tests_passed: int,
|
||||
tests_failed: int,
|
||||
test_counts: dict[Path, int],
|
||||
file_summary: dict[str, int] | None = None,
|
||||
) -> None:
|
||||
"""Single-line live progress.
|
||||
|
||||
When ``file_summary`` is provided (parsed from pytest output), the
|
||||
per-file parenthetical shows individual test pass/fail counts instead
|
||||
of just the total test count.
|
||||
"""
|
||||
status = "✓" if rc == 0 else "✗"
|
||||
pct = (tests_done / total_tests * 100) if total_tests else 0
|
||||
# Digit width for left-side counter padding (derived from total file count).
|
||||
fw = len(str(tests_passed + tests_failed))
|
||||
# Build per-file test count string.
|
||||
if file_summary:
|
||||
parts = []
|
||||
p = file_summary.get("passed", 0)
|
||||
f = file_summary.get("failed", 0)
|
||||
s = file_summary.get("skipped", 0)
|
||||
e = file_summary.get("errors", 0)
|
||||
if p:
|
||||
parts.append(f"{p}✓")
|
||||
if f:
|
||||
parts.append(f"{f}✗")
|
||||
if s:
|
||||
parts.append(f"{s}s")
|
||||
if e:
|
||||
parts.append(f"{e}e")
|
||||
# xfailed/xpassed are rare; include if present.
|
||||
xf = file_summary.get("xfailed", 0)
|
||||
xp = file_summary.get("xpassed", 0)
|
||||
if xf:
|
||||
parts.append(f"{xf}xf")
|
||||
if xp:
|
||||
parts.append(f"{xp}xp")
|
||||
test_str = " ".join(parts) + ", " if parts else ""
|
||||
else:
|
||||
n_tests = test_counts.get(file, 0)
|
||||
test_str = f"{n_tests} tests, " if n_tests else ""
|
||||
msg = (
|
||||
f"[{pct:5.1f}% | {tests_done:>5}/{total_tests}"
|
||||
f" | ✓{tests_passed:>{fw}} | ✗{tests_failed:>{fw}}] "
|
||||
f"{status} {_format_file(file, repo_root)} ({test_str}{dur:.1f}s)"
|
||||
)
|
||||
# Truncate to terminal width if available (no clobbering ANSI lines).
|
||||
try:
|
||||
cols = os.get_terminal_size().columns
|
||||
if len(msg) > cols:
|
||||
msg = msg[: cols - 1] + "…"
|
||||
except OSError:
|
||||
pass
|
||||
print(msg, flush=True)
|
||||
|
||||
|
||||
def _print_inline_failure(
|
||||
file: Path, output: str, repo_root: Path, pytest_passthrough: List[str]
|
||||
) -> None:
|
||||
"""Print a compact failure summary immediately when a file fails.
|
||||
|
||||
Shows the tail of the pytest output (the failure section with stack
|
||||
traces) and a ready-to-run repro command, so the developer doesn't
|
||||
have to wait for the full run to finish before seeing what broke.
|
||||
"""
|
||||
rel = _format_file(file, repo_root)
|
||||
# Build a repro command the developer can copy-paste.
|
||||
passthrough_str = " ".join(pytest_passthrough) if pytest_passthrough else ""
|
||||
repro = f"python -m pytest {rel}"
|
||||
if passthrough_str:
|
||||
repro += f" {passthrough_str}"
|
||||
|
||||
# Grab just the failure lines (last ~30 lines of pytest output —
|
||||
# typically the FAILED summary + short test info).
|
||||
lines = output.rstrip().splitlines()
|
||||
tail = "\n".join(lines[-30:])
|
||||
|
||||
print(flush=True)
|
||||
print(f" ╔╍ Failed: {rel} ╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍", flush=True)
|
||||
for line in tail.splitlines():
|
||||
print(f" ║ {line}", flush=True)
|
||||
print(f" ║", flush=True)
|
||||
print(f" ║ Repro: {repro}", flush=True)
|
||||
print(f" ╚╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍╍", flush=True)
|
||||
print(flush=True)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--jobs",
|
||||
type=int,
|
||||
default=int(os.environ.get("HERMES_TEST_WORKERS") or (os.cpu_count() or 4) * 2),
|
||||
help="Parallel worker count (default: $HERMES_TEST_WORKERS or cpu_count*2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--paths",
|
||||
default=os.environ.get("HERMES_TEST_PATHS", ":".join(_DEFAULT_ROOTS)),
|
||||
help="Colon-separated discovery roots (default: 'tests')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include-integration",
|
||||
action="store_true",
|
||||
help="Don't skip integration/ e2e/ during discovery",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--file-timeout",
|
||||
type=float,
|
||||
default=float(
|
||||
os.environ.get("HERMES_TEST_FILE_TIMEOUT", _DEFAULT_FILE_TIMEOUT_SECONDS)
|
||||
),
|
||||
help=(
|
||||
"Per-file wall-clock cap in seconds. On timeout, the pytest "
|
||||
"subprocess and its full process tree are SIGKILL'd. "
|
||||
"Default: 600 (10 min), env: HERMES_TEST_FILE_TIMEOUT."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"paths_positional",
|
||||
nargs="*",
|
||||
metavar="PATH",
|
||||
help=(
|
||||
"Restrict discovery to these paths (directories or .py files). "
|
||||
"Mutually exclusive with --paths. Anything after a literal '--' "
|
||||
"separator is passed through to each per-file pytest invocation."
|
||||
),
|
||||
)
|
||||
# Manually split argv on '--' so positional paths and pytest passthrough
|
||||
# args don't fight over each other. argparse's nargs="*" positional is
|
||||
# greedy and will swallow everything after '--' including the pytest
|
||||
# flags, defeating the convention.
|
||||
argv = sys.argv[1:]
|
||||
if "--" in argv:
|
||||
sep = argv.index("--")
|
||||
our_args, pytest_passthrough = argv[:sep], argv[sep + 1 :]
|
||||
else:
|
||||
our_args, pytest_passthrough = argv, []
|
||||
args = parser.parse_args(our_args)
|
||||
|
||||
repo_root = Path(__file__).resolve().parent.parent
|
||||
|
||||
# Resolve discovery roots: positional path args override --paths if any
|
||||
# were supplied, otherwise --paths (which itself defaults to 'tests').
|
||||
if args.paths_positional:
|
||||
# Positionals can be directories OR explicit .py files. Either is
|
||||
# fine — _discover_files handles both via rglob('test_*.py') for
|
||||
# dirs and direct inclusion for files.
|
||||
roots = [repo_root / p for p in args.paths_positional]
|
||||
else:
|
||||
roots = [repo_root / p for p in args.paths.split(":") if p]
|
||||
|
||||
if args.include_integration:
|
||||
# Caller takes responsibility — typically used via explicit -k filter.
|
||||
global _SKIP_PARTS # noqa: PLW0603 — config knob
|
||||
_SKIP_PARTS = set()
|
||||
|
||||
files = _discover_files(roots)
|
||||
if not files:
|
||||
print(f"No test files discovered under {[str(r) for r in roots]}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
# Count individual tests per file via a single pytest --co pass.
|
||||
test_counts = _count_tests(files, repo_root, pytest_passthrough)
|
||||
total_tests = sum(test_counts.values())
|
||||
|
||||
print(
|
||||
f"Discovered {len(files)} test files ({total_tests} tests) under "
|
||||
f"{[str(r.relative_to(repo_root)) if r.is_relative_to(repo_root) else str(r) for r in roots]}; "
|
||||
f"running with -j {args.jobs}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Capture and print on completion (out-of-order is fine — keeps the
|
||||
# terminal clean rather than interleaving N parallel pytest outputs).
|
||||
failures: List[Tuple[Path, str, Dict[str, int]]] = []
|
||||
started = time.monotonic()
|
||||
files_done = 0
|
||||
tests_done = 0
|
||||
pass_count = 0
|
||||
fail_count = 0
|
||||
tests_passed = 0
|
||||
tests_failed = 0
|
||||
lock = threading.Lock()
|
||||
|
||||
def _on_done(file: Path, started_at: float, fut: "Future[Tuple[Path, int, str, dict[str, int]]]") -> None:
|
||||
nonlocal files_done, tests_done, pass_count, fail_count, tests_passed, tests_failed
|
||||
n_tests = test_counts.get(file, 0)
|
||||
try:
|
||||
fpath, rc, output, summary = fut.result()
|
||||
except Exception as exc: # noqa: BLE001 — must always advance counter
|
||||
with lock:
|
||||
files_done += 1
|
||||
tests_done += n_tests
|
||||
fail_count += 1
|
||||
failures.append((file, f"runner crashed: {exc!r}", {}))
|
||||
_print_progress(
|
||||
tests_done, total_tests, file, 1,
|
||||
time.monotonic() - started_at,
|
||||
repo_root, tests_passed, tests_failed,
|
||||
test_counts,
|
||||
)
|
||||
return
|
||||
with lock:
|
||||
files_done += 1
|
||||
tests_done += n_tests
|
||||
# Accumulate test-level counts from parsed summary.
|
||||
tests_passed += summary.get("passed", 0)
|
||||
tests_failed += summary.get("failed", 0)
|
||||
if rc == 0:
|
||||
pass_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
failures.append((fpath, output, summary))
|
||||
_print_progress(
|
||||
tests_done, total_tests, fpath, rc,
|
||||
time.monotonic() - started_at,
|
||||
repo_root, tests_passed, tests_failed,
|
||||
test_counts,
|
||||
file_summary=summary,
|
||||
)
|
||||
if rc != 0:
|
||||
_print_inline_failure(fpath, output, repo_root, pytest_passthrough)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=args.jobs) as pool:
|
||||
futures: List[Future] = []
|
||||
for file in files:
|
||||
t0 = time.monotonic()
|
||||
fut = pool.submit(
|
||||
_run_one_file, file, pytest_passthrough, repo_root, args.file_timeout
|
||||
)
|
||||
fut.add_done_callback(lambda f, file=file, t0=t0: _on_done(file, t0, f))
|
||||
futures.append(fut)
|
||||
# Block until everything's done. ThreadPoolExecutor.__exit__ waits
|
||||
# for all submitted work, but doing it explicitly here makes the
|
||||
# control flow obvious.
|
||||
for fut in futures:
|
||||
fut.result() if fut.exception() is None else None
|
||||
|
||||
elapsed = time.monotonic() - started
|
||||
print()
|
||||
pct = (tests_done / total_tests * 100) if total_tests else 0
|
||||
print(f"=== Summary: {len(files)} files, {tests_passed} tests passed, {tests_failed} failed ({pct:.0f}% complete) in {elapsed:.1f}s ({args.jobs} workers) ===")
|
||||
|
||||
if failures:
|
||||
print()
|
||||
print("=== Failure output ===")
|
||||
for file, output, _summary in failures:
|
||||
print()
|
||||
print(f"--- {_format_file(file, repo_root)} ---")
|
||||
print(output.rstrip())
|
||||
print()
|
||||
# Split: files with actual test failures vs non-zero exit for other reasons
|
||||
test_fail_files = [(f, s) for f, _o, s in failures if s.get("failed", 0) > 0]
|
||||
all_passed_but_nonzero = [(f, s) for f, _o, s in failures
|
||||
if s.get("failed", 0) == 0 and s.get("passed", 0) > 0]
|
||||
no_tests_ran = [(f, s) for f, _o, s in failures
|
||||
if s.get("failed", 0) == 0 and s.get("passed", 0) == 0]
|
||||
if test_fail_files:
|
||||
total_tf = sum(s.get("failed", 0) for _, s in test_fail_files)
|
||||
print(f"=== {len(test_fail_files)} file{'s' if len(test_fail_files) != 1 else ''} with test failures ({total_tf} test{'s' if total_tf != 1 else ''} failed) ===")
|
||||
for file, s in test_fail_files:
|
||||
nf = s.get("failed", 0)
|
||||
print(f" {_format_file(file, repo_root)} ({nf} test{'s' if nf != 1 else ''} failed)")
|
||||
if all_passed_but_nonzero:
|
||||
print(f"=== {len(all_passed_but_nonzero)} file{'s' if len(all_passed_but_nonzero) != 1 else ''} where all tests passed but pytest exited non-zero (warnings-as-errors, hook failures, etc.) ===")
|
||||
for file, s in all_passed_but_nonzero:
|
||||
print(f" {_format_file(file, repo_root)} ({s.get('passed', 0)} passed)")
|
||||
if no_tests_ran:
|
||||
print(f"=== {len(no_tests_ran)} file{'s' if len(no_tests_ran) != 1 else ''} where no tests ran (collection/import error, timeout before collection, etc.) ===")
|
||||
for file, s in no_tests_ran:
|
||||
print(f" {_format_file(file, repo_root)}")
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
|
@ -40,6 +40,16 @@ def _clean_env(monkeypatch):
|
|||
"ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
# Module-level unhealthy cache (10-min TTL) leaks between tests;
|
||||
# earlier tests that call _mark_provider_unhealthy() poison the
|
||||
# cache for later ones, causing _resolve_auto to skip providers
|
||||
# that the test patched to return valid clients.
|
||||
import agent.auxiliary_client as _aux_mod
|
||||
_aux_mod._aux_unhealthy_until.clear()
|
||||
_aux_mod._aux_unhealthy_logged_at.clear()
|
||||
yield
|
||||
_aux_mod._aux_unhealthy_until.clear()
|
||||
_aux_mod._aux_unhealthy_logged_at.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -461,6 +471,17 @@ class TestExpiredCodexFallback:
|
|||
import base64
|
||||
import time as _time
|
||||
|
||||
# Belt-and-suspenders: _try_openrouter marks openrouter unhealthy
|
||||
# when OPENROUTER_API_KEY is absent (which the preceding test in
|
||||
# this class exercises). The file-level _clean_env autouse fixture
|
||||
# clears the cache, but fixture ordering with the conftest
|
||||
# _hermetic_environment autouse can leave a narrow window where
|
||||
# the mark reappears. Explicitly clear here so this test is
|
||||
# independent of run order.
|
||||
import agent.auxiliary_client as _aux_mod
|
||||
_aux_mod._aux_unhealthy_until.clear()
|
||||
_aux_mod._aux_unhealthy_logged_at.clear()
|
||||
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload_data = json.dumps({"exp": int(_time.time()) - 3600}).encode()
|
||||
payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode()
|
||||
|
|
@ -1047,6 +1068,20 @@ class TestGetProviderChain:
|
|||
class TestTryPaymentFallback:
|
||||
"""_try_payment_fallback skips the failed provider and tries alternatives."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_unhealthy_cache(self):
|
||||
"""Earlier tests in this file call _mark_provider_unhealthy() which
|
||||
pollutes the module-level ``_aux_unhealthy_until`` dict (10-min TTL).
|
||||
Without this cleanup the fallback chain skips providers we've patched
|
||||
to return valid clients — the patched function is never called.
|
||||
"""
|
||||
from agent.auxiliary_client import _aux_unhealthy_until, _aux_unhealthy_logged_at
|
||||
_aux_unhealthy_until.clear()
|
||||
_aux_unhealthy_logged_at.clear()
|
||||
yield
|
||||
_aux_unhealthy_until.clear()
|
||||
_aux_unhealthy_logged_at.clear()
|
||||
|
||||
def test_skips_failed_provider(self):
|
||||
mock_client = MagicMock()
|
||||
with patch("agent.auxiliary_client._try_openrouter", return_value=(None, None)), \
|
||||
|
|
|
|||
93
tests/agent/test_custom_provider_extra_body.py
Normal file
93
tests/agent/test_custom_provider_extra_body.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
from agent.agent_init import _merge_custom_provider_extra_body
|
||||
|
||||
|
||||
def test_custom_provider_extra_body_merges_into_request_overrides():
|
||||
agent = SimpleNamespace(
|
||||
provider="custom",
|
||||
model="google/gemma-4-31b-it",
|
||||
base_url="https://example.test/v1",
|
||||
request_overrides={"service_tier": "priority"},
|
||||
)
|
||||
|
||||
_merge_custom_provider_extra_body(
|
||||
agent,
|
||||
[
|
||||
{
|
||||
"name": "gemma",
|
||||
"base_url": "https://example.test/v1/",
|
||||
"model": "google/gemma-4-31b-it",
|
||||
"extra_body": {
|
||||
"enable_thinking": True,
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert agent.request_overrides == {
|
||||
"service_tier": "priority",
|
||||
"extra_body": {
|
||||
"enable_thinking": True,
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_custom_provider_extra_body_preserves_caller_override():
|
||||
agent = SimpleNamespace(
|
||||
provider="custom",
|
||||
model="google/gemma-4-31b-it",
|
||||
base_url="https://example.test/v1",
|
||||
request_overrides={
|
||||
"extra_body": {
|
||||
"reasoning_effort": "low",
|
||||
"caller_only": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
_merge_custom_provider_extra_body(
|
||||
agent,
|
||||
[
|
||||
{
|
||||
"name": "gemma",
|
||||
"base_url": "https://example.test/v1",
|
||||
"model": "google/gemma-4-31b-it",
|
||||
"extra_body": {
|
||||
"enable_thinking": True,
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert agent.request_overrides["extra_body"] == {
|
||||
"enable_thinking": True,
|
||||
"reasoning_effort": "low",
|
||||
"caller_only": True,
|
||||
}
|
||||
|
||||
|
||||
def test_custom_provider_extra_body_ignores_other_custom_models():
|
||||
agent = SimpleNamespace(
|
||||
provider="custom",
|
||||
model="other-model",
|
||||
base_url="https://example.test/v1",
|
||||
request_overrides={},
|
||||
)
|
||||
|
||||
_merge_custom_provider_extra_body(
|
||||
agent,
|
||||
[
|
||||
{
|
||||
"name": "gemma",
|
||||
"base_url": "https://example.test/v1",
|
||||
"model": "google/gemma-4-31b-it",
|
||||
"extra_body": {"enable_thinking": True},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert agent.request_overrides == {}
|
||||
|
|
@ -9,8 +9,11 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
|
||||
from agent.image_routing import (
|
||||
_coerce_capability_bool,
|
||||
_coerce_mode,
|
||||
_explicit_aux_vision_override,
|
||||
_lookup_supports_vision,
|
||||
_supports_vision_override,
|
||||
build_native_content_parts,
|
||||
decide_image_input_mode,
|
||||
)
|
||||
|
|
@ -125,6 +128,168 @@ class TestDecideImageInputMode:
|
|||
assert decide_image_input_mode("xiaomi", "mimo-v2.5-pro", {}) == "text"
|
||||
|
||||
|
||||
# ─── _coerce_capability_bool ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCoerceCapabilityBool:
|
||||
def test_real_bool_passes_through(self):
|
||||
assert _coerce_capability_bool(True) is True
|
||||
assert _coerce_capability_bool(False) is False
|
||||
|
||||
def test_int_0_and_1(self):
|
||||
assert _coerce_capability_bool(1) is True
|
||||
assert _coerce_capability_bool(0) is False
|
||||
|
||||
def test_other_ints_return_none(self):
|
||||
assert _coerce_capability_bool(2) is None
|
||||
assert _coerce_capability_bool(-1) is None
|
||||
|
||||
def test_yaml_true_tokens(self):
|
||||
for s in ("true", "TRUE", "True", "yes", "on", "1", " true "):
|
||||
assert _coerce_capability_bool(s) is True
|
||||
|
||||
def test_yaml_false_tokens(self):
|
||||
for s in ("false", "FALSE", "False", "no", "off", "0", " false "):
|
||||
assert _coerce_capability_bool(s) is False
|
||||
|
||||
def test_quoted_false_does_not_silently_become_true(self):
|
||||
# Regression: bool("false") is True in Python. A user writing
|
||||
# supports_vision: "false" must NOT enable native vision routing.
|
||||
assert _coerce_capability_bool("false") is False
|
||||
|
||||
def test_unrecognised_strings_return_none(self):
|
||||
# None == fall through to models.dev, not a silent truthy.
|
||||
assert _coerce_capability_bool("maybe") is None
|
||||
assert _coerce_capability_bool("") is None
|
||||
assert _coerce_capability_bool("definitely") is None
|
||||
|
||||
def test_other_types_return_none(self):
|
||||
assert _coerce_capability_bool(None) is None
|
||||
assert _coerce_capability_bool([]) is None
|
||||
assert _coerce_capability_bool({}) is None
|
||||
assert _coerce_capability_bool(1.5) is None
|
||||
|
||||
|
||||
# ─── _supports_vision_override ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSupportsVisionOverride:
|
||||
def test_no_cfg_returns_none(self):
|
||||
assert _supports_vision_override(None, "custom", "my-llava") is None
|
||||
assert _supports_vision_override({}, "custom", "my-llava") is None
|
||||
|
||||
def test_top_level_shortcut_wins(self):
|
||||
cfg = {"model": {"supports_vision": True}}
|
||||
assert _supports_vision_override(cfg, "custom", "my-llava") is True
|
||||
|
||||
def test_top_level_false_propagates(self):
|
||||
cfg = {"model": {"supports_vision": False}}
|
||||
assert _supports_vision_override(cfg, "custom", "my-llava") is False
|
||||
|
||||
def test_per_provider_per_model_via_runtime_name(self):
|
||||
cfg = {
|
||||
"providers": {
|
||||
"custom": {"models": {"my-llava": {"supports_vision": True}}},
|
||||
},
|
||||
}
|
||||
assert _supports_vision_override(cfg, "custom", "my-llava") is True
|
||||
|
||||
def test_per_provider_per_model_via_config_name(self):
|
||||
# Named custom provider — runtime self.provider == "custom", config
|
||||
# holds the original name under model.provider.
|
||||
cfg = {
|
||||
"model": {"provider": "my-vllm"},
|
||||
"providers": {
|
||||
"my-vllm": {"models": {"my-llava": {"supports_vision": True}}},
|
||||
},
|
||||
}
|
||||
assert _supports_vision_override(cfg, "custom", "my-llava") is True
|
||||
|
||||
def test_quoted_false_string_in_yaml_does_not_enable(self):
|
||||
# Real-world: user writes supports_vision: "false" (quoted).
|
||||
cfg = {"model": {"supports_vision": "false"}}
|
||||
assert _supports_vision_override(cfg, "custom", "my-llava") is False
|
||||
|
||||
def test_unrecognised_value_falls_through(self):
|
||||
cfg = {"model": {"supports_vision": "maybe"}}
|
||||
assert _supports_vision_override(cfg, "custom", "my-llava") is None
|
||||
|
||||
def test_no_override_returns_none(self):
|
||||
cfg = {"model": {"default": "my-llava"}}
|
||||
assert _supports_vision_override(cfg, "custom", "my-llava") is None
|
||||
|
||||
def test_malformed_sections_are_ignored(self):
|
||||
# User accidentally wrote a string where a section was expected —
|
||||
# don't blow up, just fall through.
|
||||
cfg = {"model": "some-string", "providers": ["not-a-dict"]}
|
||||
assert _supports_vision_override(cfg, "custom", "my-llava") is None
|
||||
|
||||
|
||||
# ─── _lookup_supports_vision (override-aware) ────────────────────────────────
|
||||
|
||||
|
||||
class TestLookupSupportsVisionOverride:
|
||||
def test_config_override_short_circuits_models_dev(self):
|
||||
# Config says True, models.dev says None — config wins.
|
||||
cfg = {"model": {"supports_vision": True}}
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert _lookup_supports_vision("custom", "my-llava", cfg) is True
|
||||
|
||||
def test_config_override_false_beats_vision_capable_models_dev(self):
|
||||
# User explicitly disables vision on a models.dev-vision-capable model.
|
||||
fake_caps = type("Caps", (), {"supports_vision": True})()
|
||||
cfg = {"model": {"supports_vision": False}}
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=fake_caps):
|
||||
assert _lookup_supports_vision("anthropic", "claude-sonnet-4", cfg) is False
|
||||
|
||||
def test_no_override_falls_back_to_models_dev(self):
|
||||
fake_caps = type("Caps", (), {"supports_vision": True})()
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=fake_caps):
|
||||
assert _lookup_supports_vision("anthropic", "claude-sonnet-4", {}) is True
|
||||
|
||||
def test_no_override_no_models_dev_entry_returns_none(self):
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert _lookup_supports_vision("custom", "my-llava", {}) is None
|
||||
|
||||
def test_cfg_none_falls_back_to_models_dev(self):
|
||||
# Caller didn't pass cfg at all — old call sites must still work.
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert _lookup_supports_vision("openrouter", "x", None) is None
|
||||
|
||||
|
||||
# ─── decide_image_input_mode with auto + override ────────────────────────────
|
||||
|
||||
|
||||
class TestAutoModeRespectsOverride:
|
||||
def test_auto_native_for_custom_with_supports_vision_true(self):
|
||||
# The motivating bug: Qwen3.6 on local llama.cpp via provider=custom.
|
||||
# Without the override, auto falls back to text. With it, auto picks
|
||||
# native — no need to also set agent.image_input_mode: native.
|
||||
cfg = {"model": {"supports_vision": True}}
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert decide_image_input_mode("custom", "qwen3.6-35b", cfg) == "native"
|
||||
|
||||
def test_auto_text_for_custom_with_supports_vision_false(self):
|
||||
cfg = {"model": {"supports_vision": False}}
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert decide_image_input_mode("custom", "some-text-only", cfg) == "text"
|
||||
|
||||
def test_auto_text_for_custom_with_no_override(self):
|
||||
# Unchanged baseline: unknown custom model → text.
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert decide_image_input_mode("custom", "unknown", {}) == "text"
|
||||
|
||||
def test_explicit_aux_vision_override_still_wins(self):
|
||||
# If the user has configured a dedicated vision aux backend, respect
|
||||
# it even when supports_vision: true is also set.
|
||||
cfg = {
|
||||
"model": {"supports_vision": True},
|
||||
"auxiliary": {"vision": {"provider": "openrouter", "model": "gemini-2.5-pro"}},
|
||||
}
|
||||
with patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert decide_image_input_mode("custom", "qwen3.6-35b", cfg) == "text"
|
||||
|
||||
|
||||
# ─── build_native_content_parts ──────────────────────────────────────────────
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -556,10 +556,11 @@ Generate some audio.
|
|||
raising=False,
|
||||
)
|
||||
|
||||
with patch.dict(
|
||||
os.environ, {"HERMES_SESSION_PLATFORM": "telegram"}, clear=False
|
||||
):
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
with patch("tools.skills_tool.SKILLS_DIR", tmp_path):
|
||||
from gateway.session_context import clear_session_vars, set_session_vars
|
||||
|
||||
tokens = set_session_vars(platform="telegram")
|
||||
try:
|
||||
_make_skill(
|
||||
tmp_path,
|
||||
"test-skill",
|
||||
|
|
@ -571,6 +572,8 @@ Generate some audio.
|
|||
)
|
||||
scan_skill_commands()
|
||||
msg = build_skill_invocation_message("/test-skill", "do stuff")
|
||||
finally:
|
||||
clear_session_vars(tokens)
|
||||
|
||||
assert msg is not None
|
||||
assert "local cli" in msg.lower()
|
||||
|
|
|
|||
|
|
@ -196,14 +196,13 @@ class TestCodexBuildKwargs:
|
|||
)
|
||||
# xAI Responses receives reasoning.effort on the allowlisted models.
|
||||
assert kw.get("reasoning") == {"effort": "high"}
|
||||
# As of May 2026 we deliberately do NOT request
|
||||
# reasoning.encrypted_content back from xAI — the OAuth/SuperGrok
|
||||
# surface rejects replayed encrypted reasoning items on turn 2+
|
||||
# (the multi-turn "Expected to have received response.created
|
||||
# before error" failure). Grok still reasons natively each turn;
|
||||
# we just don't try to thread the prior turn's encrypted blob back
|
||||
# in. See tests/run_agent/test_codex_xai_oauth_recovery.py.
|
||||
assert "reasoning.encrypted_content" not in kw.get("include", [])
|
||||
# As of May 2026 (post-revert of PR #26644) we DO request
|
||||
# reasoning.encrypted_content back from xAI so we can replay it
|
||||
# across turns for cross-turn coherence — xAI explicitly relies
|
||||
# on this for their partnership integration. See
|
||||
# tests/run_agent/test_codex_xai_oauth_recovery.py for the
|
||||
# full history.
|
||||
assert "reasoning.encrypted_content" in kw.get("include", [])
|
||||
|
||||
def test_xai_reasoning_disabled_no_reasoning_key(self, transport):
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
|
|
@ -229,9 +228,9 @@ class TestCodexBuildKwargs:
|
|||
# api.x.ai 400s with "Model X does not support parameter reasoningEffort"
|
||||
# on grok-4 / grok-4-fast / grok-3 / grok-code-fast / grok-4.20-0309-*.
|
||||
# Those models reason natively but don't expose the dial. The transport
|
||||
# must omit the `reasoning` key for them. As of May 2026 we also no
|
||||
# longer request ``reasoning.encrypted_content`` back from xAI on ANY
|
||||
# model — see test_xai_reasoning_effort_passed for the rationale.
|
||||
# must omit the `reasoning` key for them. As of May 2026 we DO request
|
||||
# ``reasoning.encrypted_content`` back from xAI on every model —
|
||||
# see test_xai_reasoning_effort_passed for the rationale.
|
||||
|
||||
def test_xai_grok_4_omits_reasoning_effort(self, transport):
|
||||
"""grok-4 / grok-4-0709 reject reasoning.effort with HTTP 400."""
|
||||
|
|
@ -245,9 +244,9 @@ class TestCodexBuildKwargs:
|
|||
assert "reasoning" not in kw, (
|
||||
f"{model} must not receive a reasoning key (xAI rejects it)"
|
||||
)
|
||||
# We no longer ask xAI for encrypted_content back (see comment
|
||||
# above) — verify the include list is empty.
|
||||
assert "reasoning.encrypted_content" not in kw.get("include", [])
|
||||
# Even without the effort dial we still ask xAI to echo back
|
||||
# encrypted reasoning content so it can be replayed next turn.
|
||||
assert "reasoning.encrypted_content" in kw.get("include", [])
|
||||
|
||||
def test_xai_grok_4_fast_omits_reasoning_effort(self, transport):
|
||||
"""grok-4-fast and grok-4-1-fast variants reject reasoning.effort."""
|
||||
|
|
|
|||
|
|
@ -20,12 +20,9 @@ test runner at ``scripts/run_tests.sh``.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -37,6 +34,22 @@ if str(PROJECT_ROOT) not in sys.path:
|
|||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
# ── Per-file process isolation ──────────────────────────────────────────────
|
||||
# Tests run via ``scripts/run_tests_parallel.py``, which spawns a fresh
|
||||
# ``python -m pytest <file>`` subprocess per test file. Cross-file state
|
||||
# leakage (module-level dicts, ContextVars, caches) is impossible: each
|
||||
# file gets a clean Python interpreter. Intra-file ordering is the test
|
||||
# author's responsibility — if test A in foo.py mutates state that test B
|
||||
# in foo.py reads, that's a real bug to fix in the file (it would also
|
||||
# bite anyone running ``pytest tests/foo.py`` directly).
|
||||
#
|
||||
# This replaces the historic _reset_module_state autouse fixture (manual
|
||||
# state clearing) and the brief experiment with subprocess-per-test
|
||||
# isolation (too slow at ~17k tests).
|
||||
#
|
||||
# See ``scripts/run_tests_parallel.py`` for the runner.
|
||||
|
||||
|
||||
# ── Credential env-var filter ──────────────────────────────────────────────
|
||||
#
|
||||
# Any env var in the current process matching ONE of these patterns is
|
||||
|
|
@ -279,7 +292,7 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
|
|||
"WECOM_HOME_CHANNEL_NAME",
|
||||
# Platform gating — set by load_gateway_config() as a side effect when
|
||||
# a config.yaml is present, so individual test bodies that call the
|
||||
# loader leak these values into later tests on the same xdist worker.
|
||||
# loader leak these values into later tests in the same process.
|
||||
# Force-clear on every test setup so the leak can't happen.
|
||||
"SLACK_REQUIRE_MENTION",
|
||||
"SLACK_STRICT_MENTION",
|
||||
|
|
@ -368,144 +381,21 @@ def _isolate_hermes_home(_hermetic_environment):
|
|||
return None
|
||||
|
||||
|
||||
# ── Module-level state reset ───────────────────────────────────────────────
|
||||
# ── Module-level state reset — replaced by per-file process isolation ──────
|
||||
#
|
||||
# Python modules are singletons per process, and pytest-xdist workers are
|
||||
# long-lived. Module-level dicts/sets (tool registries, approval state,
|
||||
# interrupt flags) and ContextVars persist across tests in the same worker,
|
||||
# causing tests that pass alone to fail when run with siblings.
|
||||
# Each test FILE runs in a freshly-spawned ``python -m pytest <file>``
|
||||
# subprocess via ``scripts/run_tests_parallel.py``, so module-level dicts /
|
||||
# sets / ContextVars from tests in one file cannot leak into tests in
|
||||
# another file. No manual per-module clearing needed.
|
||||
#
|
||||
# Each entry in this fixture clears state that belongs to a specific module.
|
||||
# New state buckets go here too — this is the single gate that prevents
|
||||
# "works alone, flakes in CI" bugs from state leakage.
|
||||
# Within a single file, ordering is the author's responsibility. If your
|
||||
# tests in the same file share mutable state, either reset it explicitly
|
||||
# in a fixture or split them across files.
|
||||
#
|
||||
# The skill `test-suite-cascade-diagnosis` documents the concrete patterns
|
||||
# this closes; the running example was `test_command_guards` failing 12/15
|
||||
# CI runs because ``tools.approval._session_approved`` carried approvals
|
||||
# from one test's session into another's.
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_module_state():
|
||||
"""Clear module-level mutable state and ContextVars between tests.
|
||||
|
||||
Keeps state from leaking across tests on the same xdist worker. Modules
|
||||
that don't exist yet (test collection before production import) are
|
||||
skipped silently — production import later creates fresh empty state.
|
||||
"""
|
||||
# --- logging — quiet/one-shot paths mutate process-global logger state ---
|
||||
logging.disable(logging.NOTSET)
|
||||
for _logger_name in ("tools", "run_agent", "trajectory_compressor", "cron", "hermes_cli"):
|
||||
_logger = logging.getLogger(_logger_name)
|
||||
_logger.disabled = False
|
||||
_logger.setLevel(logging.NOTSET)
|
||||
_logger.propagate = True
|
||||
|
||||
# --- tools.approval — the single biggest source of cross-test pollution ---
|
||||
try:
|
||||
from tools import approval as _approval_mod
|
||||
_approval_mod._session_approved.clear()
|
||||
_approval_mod._session_yolo.clear()
|
||||
_approval_mod._permanent_approved.clear()
|
||||
_approval_mod._pending.clear()
|
||||
_approval_mod._gateway_queues.clear()
|
||||
_approval_mod._gateway_notify_cbs.clear()
|
||||
# ContextVar: reset to empty string so get_current_session_key()
|
||||
# falls through to the env var / default path, matching a fresh
|
||||
# process.
|
||||
_approval_mod._approval_session_key.set("")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- tools.interrupt — per-thread interrupt flag set ---
|
||||
try:
|
||||
from tools import interrupt as _interrupt_mod
|
||||
with _interrupt_mod._lock:
|
||||
_interrupt_mod._interrupted_threads.clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- gateway.session_context — 9 ContextVars that represent
|
||||
# the active gateway session. If set in one test and not reset,
|
||||
# the next test's get_session_env() reads stale values.
|
||||
try:
|
||||
from gateway import session_context as _sc_mod
|
||||
for _cv in (
|
||||
_sc_mod._SESSION_PLATFORM,
|
||||
_sc_mod._SESSION_CHAT_ID,
|
||||
_sc_mod._SESSION_CHAT_NAME,
|
||||
_sc_mod._SESSION_THREAD_ID,
|
||||
_sc_mod._SESSION_USER_ID,
|
||||
_sc_mod._SESSION_USER_NAME,
|
||||
_sc_mod._SESSION_KEY,
|
||||
_sc_mod._CRON_AUTO_DELIVER_PLATFORM,
|
||||
_sc_mod._CRON_AUTO_DELIVER_CHAT_ID,
|
||||
_sc_mod._CRON_AUTO_DELIVER_THREAD_ID,
|
||||
):
|
||||
_cv.set(_sc_mod._UNSET)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- tools.env_passthrough — ContextVar<set[str]> with no default ---
|
||||
# LookupError is normal if the test never set it. Setting it to an
|
||||
# empty set unconditionally normalizes the starting state.
|
||||
try:
|
||||
from tools import env_passthrough as _envp_mod
|
||||
_envp_mod._allowed_env_vars_var.set(set())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- tools.terminal_tool — active environment/cwd cache ---
|
||||
# File tools prefer a live terminal cwd when one is cached for the task.
|
||||
# Clear terminal environments between tests so a prior terminal call can't
|
||||
# override TERMINAL_CWD in path-resolution tests.
|
||||
try:
|
||||
from tools import terminal_tool as _term_mod
|
||||
_envs_to_cleanup = []
|
||||
with _term_mod._env_lock:
|
||||
_envs_to_cleanup = list(_term_mod._active_environments.values())
|
||||
_term_mod._active_environments.clear()
|
||||
_term_mod._last_activity.clear()
|
||||
_term_mod._creation_locks.clear()
|
||||
for _env in _envs_to_cleanup:
|
||||
try:
|
||||
_env.cleanup()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- tools.credential_files — ContextVar<dict> ---
|
||||
try:
|
||||
from tools import credential_files as _credf_mod
|
||||
_credf_mod._registered_files_var.set({})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- agent.auxiliary_client — runtime main provider/model override and
|
||||
# payment-error health cache. Both are process-global in production;
|
||||
# reset them per test so one worker's fallback/402 test does not make
|
||||
# later auxiliary-client tests skip otherwise-available providers.
|
||||
try:
|
||||
from agent import auxiliary_client as _aux_mod
|
||||
_aux_mod.clear_runtime_main()
|
||||
_aux_mod._reset_aux_unhealthy_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# --- tools.file_tools — per-task read history + file-ops cache ---
|
||||
# _read_tracker accumulates per-task_id read history for loop detection,
|
||||
# capped by _READ_HISTORY_CAP. If entries from a prior test persist, the
|
||||
# cap is hit faster than expected and capacity-related tests flake.
|
||||
try:
|
||||
from tools import file_tools as _ft_mod
|
||||
with _ft_mod._read_tracker_lock:
|
||||
_ft_mod._read_tracker.clear()
|
||||
with _ft_mod._file_ops_lock:
|
||||
_ft_mod._file_ops_cache.clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield
|
||||
# The skill ``test-suite-cascade-diagnosis`` documents the cascade patterns
|
||||
# this replaces; the running example was ``test_command_guards`` failing
|
||||
# 12/15 CI runs because ``tools.approval._session_approved`` carried
|
||||
# approvals from one test's session into another's.
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
@ -532,13 +422,12 @@ def mock_config():
|
|||
}
|
||||
|
||||
|
||||
# ── Global test timeout ─────────────────────────────────────────────────────
|
||||
# Kill any individual test that takes longer than 30 seconds.
|
||||
# Prevents hanging tests (subprocess spawns, blocking I/O) from stalling the
|
||||
# entire test suite.
|
||||
# ── Per-test timeout — handled by the isolation plugin ─────────────────────
|
||||
#
|
||||
# The subprocess-per-test plugin enforces the configured ``isolate_timeout``
|
||||
# ini key by terminating the child if it overruns. The old SIGALRM-based
|
||||
# fixture (POSIX-only, didn't work on Windows) is gone.
|
||||
|
||||
def _timeout_handler(signum, frame):
|
||||
raise TimeoutError("Test exceeded 30 second timeout")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_current_event_loop(request):
|
||||
|
|
@ -584,45 +473,6 @@ def _ensure_current_event_loop(request):
|
|||
asyncio.set_event_loop(None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enforce_test_timeout():
|
||||
"""Kill any individual test that takes longer than 30 seconds.
|
||||
SIGALRM is Unix-only; skip on Windows."""
|
||||
if sys.platform == "win32":
|
||||
yield
|
||||
return
|
||||
old = signal.signal(signal.SIGALRM, _timeout_handler)
|
||||
signal.alarm(30)
|
||||
yield
|
||||
signal.alarm(0)
|
||||
signal.signal(signal.SIGALRM, old)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_tool_registry_caches():
|
||||
"""Clear tool-registry-level caches between tests.
|
||||
|
||||
The production registry caches ``check_fn()`` results for 30 s
|
||||
(see tools/registry.py) and :func:`get_tool_definitions` memoizes
|
||||
its result (see model_tools.py). Both are keyed on state that tests
|
||||
routinely mutate (env vars, registry._generation, config.yaml mtime)
|
||||
— but a stale result from test A can still be served to test B
|
||||
because 30 s covers the entire suite, and xdist worker reuse means
|
||||
one test's cache lands in another's process. Clearing before every
|
||||
test keeps hermetic behavior.
|
||||
"""
|
||||
try:
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
invalidate_check_fn_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from model_tools import _clear_tool_defs_cache
|
||||
_clear_tool_defs_cache()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# ── Live-system guard ──────────────────────────────────────────────────────
|
||||
#
|
||||
# Several test files exercise the gateway-restart / kill code paths
|
||||
|
|
|
|||
|
|
@ -313,19 +313,30 @@ def _scan_for_plugin_adapter_antipattern(source: str) -> list[str]:
|
|||
return offenses
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Reject plugin-adapter tests that use the sys.path anti-pattern.
|
||||
def _fingerprint_gateway_tests() -> str:
|
||||
"""Return a short fingerprint that changes when any gateway test file changes.
|
||||
|
||||
Runs once per pytest session on the controller, BEFORE any xdist
|
||||
worker is spawned. If any file under ``tests/gateway/`` matches the
|
||||
anti-pattern, we fail the whole session with a clear message —
|
||||
before a polluted ``sys.path`` can cascade across workers.
|
||||
Uses (mtime, size) pairs instead of content hashing — fast to compute
|
||||
(stat-only, no reads) and sufficient for cache invalidation across
|
||||
per-file subprocess runs.
|
||||
"""
|
||||
# Only run on the xdist controller (or in non-xdist runs). Skip on
|
||||
# worker subprocesses so we don't scan the filesystem N times.
|
||||
if hasattr(config, "workerinput"):
|
||||
return
|
||||
import hashlib
|
||||
|
||||
h = hashlib.sha256()
|
||||
for path in sorted(_GATEWAY_DIR.rglob("test_*.py")):
|
||||
try:
|
||||
st = path.stat()
|
||||
h.update(f"{path.name}:{st.st_mtime_ns}:{st.st_size}".encode())
|
||||
except OSError:
|
||||
h.update(f"{path.name}:missing".encode())
|
||||
return h.hexdigest()[:16]
|
||||
|
||||
|
||||
def _run_adapter_antipattern_scan() -> list[str]:
|
||||
"""Scan gateway test files for the plugin-adapter anti-pattern.
|
||||
|
||||
Returns a list of violation strings (empty if clean).
|
||||
"""
|
||||
violations: list[str] = []
|
||||
for path in _GATEWAY_DIR.rglob("test_*.py"):
|
||||
if path.name in {"_plugin_adapter_loader.py", "conftest.py"}:
|
||||
|
|
@ -334,20 +345,108 @@ def pytest_configure(config):
|
|||
source = path.read_text(encoding="utf-8")
|
||||
except OSError:
|
||||
continue
|
||||
# Fast string pre-filter: skip files that can't possibly violate.
|
||||
# A violating file MUST contain both (a) an adapter/plugins/platforms
|
||||
# reference AND (b) either sys.path manipulation or a bare adapter import.
|
||||
if "adapter" not in source and "plugins/platforms" not in source:
|
||||
continue
|
||||
if not (
|
||||
"sys.path" in source
|
||||
or "import adapter" in source
|
||||
or "from adapter import" in source
|
||||
):
|
||||
continue
|
||||
offenses = _scan_for_plugin_adapter_antipattern(source)
|
||||
if offenses:
|
||||
violations.append(
|
||||
f" {path.relative_to(_GATEWAY_DIR.parent.parent)}:\n "
|
||||
+ "\n ".join(offenses)
|
||||
)
|
||||
return violations
|
||||
|
||||
if violations:
|
||||
raise pytest.UsageError(
|
||||
"Plugin-adapter-import anti-pattern detected in gateway tests:\n"
|
||||
+ "\n".join(violations)
|
||||
+ "\n\n"
|
||||
+ _GUARD_HINT
|
||||
)
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Reject plugin-adapter tests that use the sys.path anti-pattern.
|
||||
|
||||
Runs once per pytest session on the controller, BEFORE any xdist
|
||||
worker is spawned. If any file under ``tests/gateway/`` matches the
|
||||
anti-pattern, we fail the whole session with a clear message —
|
||||
before a polluted ``sys.path`` can cascade across workers.
|
||||
|
||||
**Performance**: in the per-file subprocess isolation model (no xdist),
|
||||
every subprocess is a "controller" — so the naive scan would run 257
|
||||
times, each costing ~1s of AST walking. We avoid this with two
|
||||
strategies:
|
||||
|
||||
1. **Tight string pre-filter**: a file can only violate if it contains
|
||||
*both* an adapter/plugins/platforms reference *and* a sys.path
|
||||
manipulation or bare ``import adapter``. This drops ~95% of files
|
||||
from needing AST parsing.
|
||||
2. **File-locked cache**: the scan result is cached in
|
||||
``.pytest-cache/gw-adapter-guard-<fingerprint>`` keyed on a
|
||||
fingerprint of the gateway test file mtimes/sizes. Concurrent
|
||||
subprocesses acquire a lock; only the first performs the scan;
|
||||
the rest wait and read the cached result.
|
||||
"""
|
||||
# Only run on the xdist controller (or in non-xdist runs). Skip on
|
||||
# worker subprocesses so we don't scan the filesystem N times.
|
||||
if hasattr(config, "workerinput"):
|
||||
return
|
||||
|
||||
fp = _fingerprint_gateway_tests()
|
||||
cache_dir = Path.cwd() / ".pytest-cache"
|
||||
cache_file = cache_dir / f"gw-adapter-guard-{fp}"
|
||||
lock_file = cache_dir / f".gw-adapter-guard-{fp}.lock"
|
||||
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Evict stale cache entries from previous fingerprints (best-effort).
|
||||
try:
|
||||
for old in cache_dir.glob("gw-adapter-guard-*"):
|
||||
if old.name != f"gw-adapter-guard-{fp}":
|
||||
old.unlink(missing_ok=True)
|
||||
for old in cache_dir.glob(".gw-adapter-guard-*.lock"):
|
||||
if old.name != f".gw-adapter-guard-{fp}.lock":
|
||||
old.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass # Non-critical; old files are harmless.
|
||||
|
||||
# Use filelock to ensure only one process scans at a time.
|
||||
# Concurrent subprocesses all hit pytest_configure simultaneously;
|
||||
# without a lock they'd all find no cache and all run the scan.
|
||||
try:
|
||||
from filelock import FileLock
|
||||
lock = FileLock(str(lock_file), timeout=120)
|
||||
except ImportError:
|
||||
# Fallback: no locking (still correct, just slower under contention).
|
||||
import contextlib
|
||||
|
||||
class _NoLock:
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, *a):
|
||||
pass
|
||||
lock = _NoLock()
|
||||
|
||||
with lock:
|
||||
if cache_file.exists():
|
||||
cached = cache_file.read_text(encoding="utf-8")
|
||||
if cached == "clean":
|
||||
return
|
||||
raise pytest.UsageError(cached)
|
||||
|
||||
# Slow path: this process is the first to acquire the lock.
|
||||
violations = _run_adapter_antipattern_scan()
|
||||
|
||||
if violations:
|
||||
msg = (
|
||||
"Plugin-adapter-import anti-pattern detected in gateway tests:\n"
|
||||
+ "\n".join(violations)
|
||||
+ "\n\n"
|
||||
+ _GUARD_HINT
|
||||
)
|
||||
cache_file.write_text(msg, encoding="utf-8")
|
||||
raise pytest.UsageError(msg)
|
||||
else:
|
||||
cache_file.write_text("clean", encoding="utf-8")
|
||||
|
||||
|
|
|
|||
0
tests/gateway/platforms/__init__.py
Normal file
0
tests/gateway/platforms/__init__.py
Normal file
88
tests/gateway/platforms/test_yuanbao_recall_db_only.py
Normal file
88
tests/gateway/platforms/test_yuanbao_recall_db_only.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
"""Yuanbao recall: branch A1 (exact id) and A2 (content-match) against DB-only transcripts.
|
||||
|
||||
state.db persists the platform-side ``message_id`` via the
|
||||
``platform_message_id`` column (added in the salvage of PR #29211) and
|
||||
``load_transcript`` surfaces it back on each message dict as ``message_id``
|
||||
— so the recall guard's exact-id match path stays canonical even with the
|
||||
JSONL file gone. When a row has no platform id (e.g. agent-processed
|
||||
@bot messages whose adapter didn't carry a msg_id, or pre-column legacy
|
||||
rows), recall falls through to content-match.
|
||||
"""
|
||||
from gateway.session import SessionStore
|
||||
from gateway.config import GatewayConfig
|
||||
|
||||
|
||||
def _pin_db(monkeypatch, tmp_path):
|
||||
"""Force SessionDB() to write into tmp_path instead of the real ~/.hermes."""
|
||||
import hermes_state
|
||||
monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db")
|
||||
|
||||
|
||||
def test_recall_branch_a1_exact_id_match_round_trips_through_db(tmp_path, monkeypatch):
|
||||
"""A user message persisted with ``message_id`` must round-trip through
|
||||
state.db so recall can find and redact it by exact id (branch A1)."""
|
||||
_pin_db(monkeypatch, tmp_path)
|
||||
|
||||
config = GatewayConfig()
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
|
||||
sid = "test-yuanbao-recall-a1"
|
||||
store._db.create_session(session_id=sid, source="yuanbao:group:G")
|
||||
store.append_to_transcript(sid, {
|
||||
"role": "user",
|
||||
"content": "sensitive content",
|
||||
"timestamp": 1.0,
|
||||
"message_id": "platform-msg-abc",
|
||||
})
|
||||
store.append_to_transcript(sid, {
|
||||
"role": "assistant",
|
||||
"content": "ack",
|
||||
"timestamp": 2.0,
|
||||
})
|
||||
|
||||
history = store.load_transcript(sid)
|
||||
# The user row must carry its platform id back so the recall guard can
|
||||
# match by exact id; the assistant row had no platform id so it should
|
||||
# not gain one spuriously.
|
||||
user_msg = next(m for m in history if m["role"] == "user")
|
||||
assistant_msg = next(m for m in history if m["role"] == "assistant")
|
||||
assert user_msg.get("message_id") == "platform-msg-abc"
|
||||
assert "message_id" not in assistant_msg
|
||||
|
||||
# Branch A1: locate the row by exact platform id — no content heuristics.
|
||||
target = next(
|
||||
(m for m in history if m.get("message_id") == "platform-msg-abc"),
|
||||
None,
|
||||
)
|
||||
assert target is not None
|
||||
assert target["content"] == "sensitive content"
|
||||
|
||||
|
||||
def test_recall_branch_a2_content_match_when_no_platform_id(tmp_path, monkeypatch):
|
||||
"""Rows that lack a platform_message_id (e.g. agent-processed @bot
|
||||
messages) still match by content as a fallback."""
|
||||
_pin_db(monkeypatch, tmp_path)
|
||||
|
||||
config = GatewayConfig()
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
|
||||
sid = "test-yuanbao-recall-a2"
|
||||
store._db.create_session(session_id=sid, source="yuanbao:group:G")
|
||||
# No message_id on the dict — simulates an agent-processed message
|
||||
# that did not carry the platform msg_id through.
|
||||
store.append_to_transcript(sid, {
|
||||
"role": "user",
|
||||
"content": "sensitive content",
|
||||
"timestamp": 1.0,
|
||||
})
|
||||
|
||||
history = store.load_transcript(sid)
|
||||
assert all("message_id" not in m for m in history)
|
||||
|
||||
# Branch A2: content match recovers the target.
|
||||
target = next(
|
||||
(m for m in history
|
||||
if m.get("role") == "user" and m.get("content") == "sensitive content"),
|
||||
None,
|
||||
)
|
||||
assert target is not None
|
||||
|
|
@ -22,19 +22,26 @@ from gateway.config import PlatformConfig
|
|||
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.constants.ChatType.GROUP = "group"
|
||||
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
# Register telegram.constants as a separate module mock so that
|
||||
# ``from telegram.constants import ChatType`` resolves to our mock
|
||||
# with string-valued members (not auto-generated MagicMocks).
|
||||
constants_mod = MagicMock()
|
||||
constants_mod.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
constants_mod.ChatType.GROUP = "group"
|
||||
constants_mod.ChatType.SUPERGROUP = "supergroup"
|
||||
constants_mod.ChatType.CHANNEL = "channel"
|
||||
constants_mod.ChatType.PRIVATE = "private"
|
||||
|
||||
sys.modules["telegram"] = telegram_mod
|
||||
sys.modules["telegram.ext"] = telegram_mod.ext
|
||||
sys.modules["telegram.constants"] = constants_mod
|
||||
sys.modules["telegram.request"] = telegram_mod.request
|
||||
|
||||
# Force reimport so the adapter picks up the mock ChatType.
|
||||
sys.modules.pop("gateway.platforms.telegram", None)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,11 @@ import pytest
|
|||
|
||||
from gateway.config import Platform, PlatformConfig, load_gateway_config
|
||||
|
||||
# Platform uses _missing_() for dynamic members, so "google_chat" is
|
||||
# resolvable via Platform("google_chat") even without a static
|
||||
# GOOGLE_CHAT attribute on the enum class.
|
||||
_GC = Platform("google_chat")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock the google-* packages if they are not installed
|
||||
|
|
@ -229,7 +234,7 @@ def _make_chat_envelope(text="hello", sender_email="u@example.com", sender_type=
|
|||
|
||||
class TestPlatformRegistration:
|
||||
def test_enum_value(self):
|
||||
assert Platform.GOOGLE_CHAT.value == "google_chat"
|
||||
assert _GC.value == "google_chat"
|
||||
|
||||
def test_requirements_check_returns_true_when_available(self):
|
||||
# The shim flag is True in this test module.
|
||||
|
|
@ -266,14 +271,14 @@ class TestEnvConfigLoading:
|
|||
monkeypatch.setenv("GOOGLE_CHAT_PROJECT_ID", "p")
|
||||
# No subscription.
|
||||
cfg = load_gateway_config()
|
||||
assert Platform.GOOGLE_CHAT not in cfg.platforms
|
||||
assert _GC not in cfg.platforms
|
||||
|
||||
def test_missing_project_does_not_enable(self, monkeypatch):
|
||||
self._clean_env(monkeypatch)
|
||||
monkeypatch.setenv("GOOGLE_CHAT_SUBSCRIPTION_NAME",
|
||||
"projects/p/subscriptions/s")
|
||||
cfg = load_gateway_config()
|
||||
assert Platform.GOOGLE_CHAT not in cfg.platforms
|
||||
assert _GC not in cfg.platforms
|
||||
|
||||
|
||||
|
||||
|
|
@ -2583,7 +2588,7 @@ class TestAuthorizationEmailMatch:
|
|||
runner.pairing_store.is_approved = MagicMock(return_value=False)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.GOOGLE_CHAT,
|
||||
platform=_GC,
|
||||
chat_id="spaces/S",
|
||||
chat_type="dm",
|
||||
user_id="alice@example.com", # post-swap: email is canonical
|
||||
|
|
@ -2604,7 +2609,7 @@ class TestAuthorizationEmailMatch:
|
|||
runner.pairing_store.is_approved = MagicMock(return_value=False)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.GOOGLE_CHAT,
|
||||
platform=_GC,
|
||||
chat_id="spaces/S",
|
||||
chat_type="dm",
|
||||
user_id="bob@example.com",
|
||||
|
|
@ -2630,7 +2635,7 @@ class TestAuthorizationEmailMatch:
|
|||
runner.pairing_store.is_approved = MagicMock(return_value=False)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.GOOGLE_CHAT,
|
||||
platform=_GC,
|
||||
chat_id="spaces/S",
|
||||
chat_type="dm",
|
||||
user_id="users/77777", # no email available — resource name wins
|
||||
|
|
|
|||
32
tests/gateway/test_load_transcript_db_only.py
Normal file
32
tests/gateway/test_load_transcript_db_only.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
"""Verify load_transcript returns SQLite messages without any JSONL file."""
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.session import SessionStore
|
||||
from gateway.config import GatewayConfig
|
||||
|
||||
|
||||
def test_load_transcript_returns_db_messages_when_no_jsonl(tmp_path, monkeypatch):
|
||||
"""Reading a transcript must work from SQLite alone — no JSONL fallback needed.
|
||||
|
||||
Pin DEFAULT_DB_PATH to tmp_path so this test cannot write to the real
|
||||
~/.hermes/state.db. (DEFAULT_DB_PATH is a module-level constant computed
|
||||
at hermes_state import time, before pytest's HERMES_HOME monkeypatch
|
||||
fires — the autouse fixture's HERMES_HOME override doesn't help here.)
|
||||
"""
|
||||
import hermes_state
|
||||
monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db")
|
||||
|
||||
config = GatewayConfig()
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
|
||||
sid = "test-session-db-only"
|
||||
store._db.create_session(session_id=sid, source="test")
|
||||
store.append_to_transcript(sid, {"role": "user", "content": "hello", "timestamp": 1.0})
|
||||
store.append_to_transcript(sid, {"role": "assistant", "content": "world", "timestamp": 2.0})
|
||||
|
||||
history = store.load_transcript(sid)
|
||||
assert len(history) == 2
|
||||
assert history[0]["content"] == "hello"
|
||||
assert history[1]["content"] == "world"
|
||||
|
|
@ -8,7 +8,6 @@ import gateway.mirror as mirror_mod
|
|||
from gateway.mirror import (
|
||||
mirror_to_session,
|
||||
_find_session_id,
|
||||
_append_to_jsonl,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -152,33 +151,6 @@ class TestFindSessionId:
|
|||
assert result == "sess_1"
|
||||
|
||||
|
||||
class TestAppendToJsonl:
|
||||
def test_appends_message(self, tmp_path):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
|
||||
_append_to_jsonl("sess_1", {"role": "assistant", "content": "Hello"})
|
||||
|
||||
transcript = sessions_dir / "sess_1.jsonl"
|
||||
lines = transcript.read_text().strip().splitlines()
|
||||
assert len(lines) == 1
|
||||
msg = json.loads(lines[0])
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["content"] == "Hello"
|
||||
|
||||
def test_appends_multiple_messages(self, tmp_path):
|
||||
sessions_dir = tmp_path / "sessions"
|
||||
sessions_dir.mkdir()
|
||||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir):
|
||||
_append_to_jsonl("sess_1", {"role": "assistant", "content": "msg1"})
|
||||
_append_to_jsonl("sess_1", {"role": "assistant", "content": "msg2"})
|
||||
|
||||
transcript = sessions_dir / "sess_1.jsonl"
|
||||
lines = transcript.read_text().strip().splitlines()
|
||||
assert len(lines) == 2
|
||||
|
||||
|
||||
class TestMirrorToSession:
|
||||
def test_successful_mirror(self, tmp_path):
|
||||
|
|
@ -192,15 +164,16 @@ class TestMirrorToSession:
|
|||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \
|
||||
patch("gateway.mirror._append_to_sqlite"):
|
||||
patch("gateway.mirror._append_to_sqlite") as mock_sqlite:
|
||||
result = mirror_to_session("telegram", "12345", "Hello!", source_label="cli")
|
||||
|
||||
assert result is True
|
||||
|
||||
# Check JSONL was written
|
||||
transcript = sessions_dir / "sess_abc.jsonl"
|
||||
assert transcript.exists()
|
||||
msg = json.loads(transcript.read_text().strip())
|
||||
# Check SQLite writer was called with the mirror message
|
||||
mock_sqlite.assert_called_once()
|
||||
call_args = mock_sqlite.call_args
|
||||
assert call_args[0][0] == "sess_abc"
|
||||
msg = call_args[0][1]
|
||||
assert msg["content"] == "Hello!"
|
||||
assert msg["role"] == "assistant"
|
||||
assert msg["mirror"] is True
|
||||
|
|
@ -222,12 +195,12 @@ class TestMirrorToSession:
|
|||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \
|
||||
patch("gateway.mirror._append_to_sqlite"):
|
||||
patch("gateway.mirror._append_to_sqlite") as mock_sqlite:
|
||||
result = mirror_to_session("telegram", "-1001", "Hello topic!", source_label="cron", thread_id="10")
|
||||
|
||||
assert result is True
|
||||
assert (sessions_dir / "sess_topic_a.jsonl").exists()
|
||||
assert not (sessions_dir / "sess_topic_b.jsonl").exists()
|
||||
mock_sqlite.assert_called_once()
|
||||
assert mock_sqlite.call_args[0][0] == "sess_topic_a"
|
||||
|
||||
def test_successful_mirror_uses_user_id_for_group_session(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {
|
||||
|
|
@ -245,7 +218,7 @@ class TestMirrorToSession:
|
|||
|
||||
with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
|
||||
patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \
|
||||
patch("gateway.mirror._append_to_sqlite"):
|
||||
patch("gateway.mirror._append_to_sqlite") as mock_sqlite:
|
||||
result = mirror_to_session(
|
||||
"telegram",
|
||||
"-1001",
|
||||
|
|
@ -255,8 +228,8 @@ class TestMirrorToSession:
|
|||
)
|
||||
|
||||
assert result is True
|
||||
assert (sessions_dir / "sess_alice.jsonl").exists()
|
||||
assert not (sessions_dir / "sess_bob.jsonl").exists()
|
||||
mock_sqlite.assert_called_once()
|
||||
assert mock_sqlite.call_args[0][0] == "sess_alice"
|
||||
|
||||
def test_no_matching_session(self, tmp_path):
|
||||
sessions_dir, index_file = _setup_sessions(tmp_path, {})
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Regression tests for /retry replacement semantics."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -11,14 +11,17 @@ from gateway.session import SessionStore
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path):
|
||||
async def test_gateway_retry_replaces_last_user_turn_in_transcript(tmp_path, monkeypatch):
|
||||
# Pin DEFAULT_DB_PATH so SessionDB() doesn't write to the real ~/.hermes/state.db.
|
||||
# (Module-level constant snapshot, see test_load_transcript_db_only.)
|
||||
import hermes_state
|
||||
monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db")
|
||||
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = None
|
||||
store._loaded = True
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
|
||||
session_id = "retry_session"
|
||||
store._db.create_session(session_id=session_id, source="test")
|
||||
for msg in [
|
||||
{"role": "session_meta", "tools": []},
|
||||
{"role": "user", "content": "first question"},
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
"""Tests for gateway session management."""
|
||||
|
||||
import builtins
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
|
@ -503,19 +501,19 @@ class TestSenderPrefixWithBackfill:
|
|||
|
||||
|
||||
class TestSessionStoreRewriteTranscript:
|
||||
"""Regression: /retry and /undo must persist truncated history to disk."""
|
||||
"""Regression: /retry and /undo must persist truncated history to DB."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store(self, tmp_path):
|
||||
def store(self, tmp_path, monkeypatch):
|
||||
import hermes_state
|
||||
monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db")
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None # no SQLite for these tests
|
||||
s._loaded = True
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
return s
|
||||
|
||||
def test_rewrite_replaces_jsonl(self, store, tmp_path):
|
||||
def test_rewrite_replaces_transcript(self, store, tmp_path):
|
||||
session_id = "test_session_1"
|
||||
store._db.create_session(session_id=session_id, source="test")
|
||||
# Write initial transcript
|
||||
for msg in [
|
||||
{"role": "user", "content": "hello"},
|
||||
|
|
@ -538,6 +536,7 @@ class TestSessionStoreRewriteTranscript:
|
|||
|
||||
def test_rewrite_with_empty_list(self, store):
|
||||
session_id = "test_session_2"
|
||||
store._db.create_session(session_id=session_id, source="test")
|
||||
store.append_to_transcript(session_id, {"role": "user", "content": "hi"})
|
||||
|
||||
store.rewrite_transcript(session_id, [])
|
||||
|
|
@ -546,171 +545,28 @@ class TestSessionStoreRewriteTranscript:
|
|||
assert reloaded == []
|
||||
|
||||
|
||||
class TestLoadTranscriptCorruptLines:
|
||||
"""Regression: corrupt JSONL lines (e.g. from mid-write crash) must be
|
||||
skipped instead of crashing the entire transcript load. GH-1193."""
|
||||
class TestLoadTranscriptDBOnly:
|
||||
"""After spec 002, load_transcript reads only from state.db."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store(self, tmp_path):
|
||||
def test_db_only_returns_empty_for_nonexistent(self, tmp_path, monkeypatch):
|
||||
import hermes_state
|
||||
monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db")
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
def test_corrupt_line_skipped(self, store, tmp_path):
|
||||
session_id = "corrupt_test"
|
||||
transcript_path = store.get_transcript_path(session_id)
|
||||
transcript_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(transcript_path, "w") as f:
|
||||
f.write('{"role": "user", "content": "hello"}\n')
|
||||
f.write('{"role": "assistant", "content": "hi th') # truncated
|
||||
f.write("\n")
|
||||
f.write('{"role": "user", "content": "goodbye"}\n')
|
||||
|
||||
messages = store.load_transcript(session_id)
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["content"] == "hello"
|
||||
assert messages[1]["content"] == "goodbye"
|
||||
|
||||
def test_all_lines_corrupt_returns_empty(self, store, tmp_path):
|
||||
session_id = "all_corrupt"
|
||||
transcript_path = store.get_transcript_path(session_id)
|
||||
transcript_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(transcript_path, "w") as f:
|
||||
f.write("not json at all\n")
|
||||
f.write("{truncated\n")
|
||||
|
||||
messages = store.load_transcript(session_id)
|
||||
assert messages == []
|
||||
|
||||
def test_valid_transcript_unaffected(self, store, tmp_path):
|
||||
session_id = "valid_test"
|
||||
store.append_to_transcript(session_id, {"role": "user", "content": "a"})
|
||||
store.append_to_transcript(session_id, {"role": "assistant", "content": "b"})
|
||||
|
||||
messages = store.load_transcript(session_id)
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["content"] == "a"
|
||||
assert messages[1]["content"] == "b"
|
||||
|
||||
|
||||
class TestLoadTranscriptPreferLongerSource:
|
||||
"""Regression: load_transcript must return whichever source (SQLite or JSONL)
|
||||
has more messages to prevent silent truncation. GH-3212."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store_with_db(self, tmp_path):
|
||||
"""SessionStore with both SQLite and JSONL active."""
|
||||
from hermes_state import SessionDB
|
||||
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = SessionDB(db_path=tmp_path / "state.db")
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
def test_jsonl_longer_than_sqlite_returns_jsonl(self, store_with_db):
|
||||
"""Legacy session: JSONL has full history, SQLite has only recent turn."""
|
||||
sid = "legacy_session"
|
||||
store_with_db._db.create_session(session_id=sid, source="gateway", model="m")
|
||||
# JSONL has 10 messages (legacy history — written before SQLite existed)
|
||||
for i in range(10):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
store_with_db.append_to_transcript(
|
||||
sid, {"role": role, "content": f"msg-{i}"}, skip_db=True,
|
||||
)
|
||||
# SQLite has only 2 messages (recent turn after migration)
|
||||
store_with_db._db.append_message(session_id=sid, role="user", content="new-q")
|
||||
store_with_db._db.append_message(session_id=sid, role="assistant", content="new-a")
|
||||
|
||||
result = store_with_db.load_transcript(sid)
|
||||
assert len(result) == 10
|
||||
assert result[0]["content"] == "msg-0"
|
||||
|
||||
def test_sqlite_longer_than_jsonl_returns_sqlite(self, store_with_db):
|
||||
"""Fully migrated session: SQLite has more (JSONL stopped growing)."""
|
||||
sid = "migrated_session"
|
||||
store_with_db._db.create_session(session_id=sid, source="gateway", model="m")
|
||||
# JSONL has 2 old messages
|
||||
store_with_db.append_to_transcript(
|
||||
sid, {"role": "user", "content": "old-q"}, skip_db=True,
|
||||
)
|
||||
store_with_db.append_to_transcript(
|
||||
sid, {"role": "assistant", "content": "old-a"}, skip_db=True,
|
||||
)
|
||||
# SQLite has 4 messages (superset after migration)
|
||||
for i in range(4):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
store_with_db._db.append_message(session_id=sid, role=role, content=f"db-{i}")
|
||||
|
||||
result = store_with_db.load_transcript(sid)
|
||||
assert len(result) == 4
|
||||
assert result[0]["content"] == "db-0"
|
||||
|
||||
def test_sqlite_empty_falls_back_to_jsonl(self, store_with_db):
|
||||
"""No SQLite rows — falls back to JSONL (original behavior preserved)."""
|
||||
sid = "no_db_rows"
|
||||
store_with_db.append_to_transcript(
|
||||
sid, {"role": "user", "content": "hello"}, skip_db=True,
|
||||
)
|
||||
store_with_db.append_to_transcript(
|
||||
sid, {"role": "assistant", "content": "hi"}, skip_db=True,
|
||||
)
|
||||
|
||||
result = store_with_db.load_transcript(sid)
|
||||
assert len(result) == 2
|
||||
assert result[0]["content"] == "hello"
|
||||
|
||||
def test_both_empty_returns_empty(self, store_with_db):
|
||||
"""Neither source has data — returns empty list."""
|
||||
result = store_with_db.load_transcript("nonexistent")
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
result = store.load_transcript("nonexistent")
|
||||
assert result == []
|
||||
|
||||
def test_equal_length_prefers_sqlite(self, store_with_db):
|
||||
"""When both have same count, SQLite wins (has richer fields like reasoning)."""
|
||||
sid = "equal_session"
|
||||
store_with_db._db.create_session(session_id=sid, source="gateway", model="m")
|
||||
# Write 2 messages to JSONL only
|
||||
store_with_db.append_to_transcript(
|
||||
sid, {"role": "user", "content": "jsonl-q"}, skip_db=True,
|
||||
)
|
||||
store_with_db.append_to_transcript(
|
||||
sid, {"role": "assistant", "content": "jsonl-a"}, skip_db=True,
|
||||
)
|
||||
# Write 2 different messages to SQLite only
|
||||
store_with_db._db.append_message(session_id=sid, role="user", content="db-q")
|
||||
store_with_db._db.append_message(session_id=sid, role="assistant", content="db-a")
|
||||
def test_db_only_returns_messages(self, tmp_path, monkeypatch):
|
||||
import hermes_state
|
||||
monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db")
|
||||
config = GatewayConfig()
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
sid = "db_only_session"
|
||||
store._db.create_session(session_id=sid, source="gateway", model="m")
|
||||
store._db.append_message(session_id=sid, role="user", content="db-q")
|
||||
store._db.append_message(session_id=sid, role="assistant", content="db-a")
|
||||
|
||||
result = store_with_db.load_transcript(sid)
|
||||
assert len(result) == 2
|
||||
# Should be the SQLite version (equal count → prefers SQLite)
|
||||
assert result[0]["content"] == "db-q"
|
||||
|
||||
def test_unreadable_jsonl_returns_sqlite(self, store_with_db, monkeypatch):
|
||||
"""Unreadable legacy JSONL must not hide valid SQLite history."""
|
||||
sid = "unreadable_jsonl"
|
||||
store_with_db._db.create_session(session_id=sid, source="gateway", model="m")
|
||||
store_with_db._db.append_message(session_id=sid, role="user", content="db-q")
|
||||
store_with_db._db.append_message(session_id=sid, role="assistant", content="db-a")
|
||||
|
||||
transcript_path = store_with_db.get_transcript_path(sid)
|
||||
transcript_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
transcript_path.write_text('{"role": "user", "content": "jsonl-q"}\n', encoding="utf-8")
|
||||
|
||||
real_open = builtins.open
|
||||
|
||||
def raise_for_transcript(path, *args, **kwargs):
|
||||
mode = args[0] if args else kwargs.get("mode", "r")
|
||||
if Path(path) == transcript_path and "r" in mode:
|
||||
raise OSError("simulated unreadable transcript")
|
||||
return real_open(path, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "open", raise_for_transcript)
|
||||
|
||||
result = store_with_db.load_transcript(sid)
|
||||
result = store.load_transcript(sid)
|
||||
assert len(result) == 2
|
||||
assert result[0]["content"] == "db-q"
|
||||
assert result[1]["content"] == "db-a"
|
||||
|
|
|
|||
|
|
@ -22,13 +22,18 @@ from gateway.session import SessionSource, SessionStore, build_session_key
|
|||
|
||||
|
||||
@pytest.fixture()
|
||||
def store(tmp_path):
|
||||
"""SessionStore with no SQLite, for fast unit tests."""
|
||||
def store(tmp_path, monkeypatch):
|
||||
"""SessionStore with SQLite — load_transcript reads from DB only.
|
||||
|
||||
Pin DEFAULT_DB_PATH to tmp_path so SessionDB() can't write to the real
|
||||
~/.hermes/state.db. (DEFAULT_DB_PATH is a module-level constant computed
|
||||
at hermes_state import time, before pytest's HERMES_HOME monkeypatch
|
||||
fires — the autouse fixture's HERMES_HOME override doesn't help here.)
|
||||
"""
|
||||
import hermes_state
|
||||
monkeypatch.setattr(hermes_state, "DEFAULT_DB_PATH", tmp_path / "state.db")
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None
|
||||
s._loaded = True
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
return s
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import asyncio
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from gateway.config import Platform, PlatformConfig, load_gateway_config
|
||||
from gateway.platforms.base import MessageType
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_adapter(
|
||||
|
|
@ -15,7 +18,9 @@ def _make_adapter(
|
|||
allow_from=None,
|
||||
group_allow_from=None,
|
||||
allowed_chats=None,
|
||||
group_allowed_chats=None,
|
||||
guest_mode=None,
|
||||
observe_unmentioned_group_messages=None,
|
||||
bot_username="hermes_bot",
|
||||
):
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
|
@ -49,8 +54,14 @@ def _make_adapter(
|
|||
# environment; production adapters without this explicit key still fall
|
||||
# back to the env var.
|
||||
extra["allowed_chats"] = []
|
||||
if group_allowed_chats is not None:
|
||||
extra["group_allowed_chats"] = group_allowed_chats
|
||||
else:
|
||||
extra["group_allowed_chats"] = []
|
||||
if guest_mode is not None:
|
||||
extra["guest_mode"] = guest_mode
|
||||
if observe_unmentioned_group_messages is not None:
|
||||
extra["observe_unmentioned_group_messages"] = observe_unmentioned_group_messages
|
||||
|
||||
adapter = object.__new__(TelegramAdapter)
|
||||
adapter.platform = Platform.TELEGRAM
|
||||
|
|
@ -60,7 +71,12 @@ def _make_adapter(
|
|||
adapter._pending_text_batches = {}
|
||||
adapter._pending_text_batch_tasks = {}
|
||||
adapter._text_batch_delay_seconds = 0.01
|
||||
adapter._text_batch_split_delay_seconds = 0.01
|
||||
adapter._mention_patterns = adapter._compile_mention_patterns()
|
||||
adapter._forum_lock = asyncio.Lock()
|
||||
adapter._forum_command_registered = set()
|
||||
adapter._active_sessions = {}
|
||||
adapter._pending_messages = {}
|
||||
# Trigger-gating tests don't exercise the allowlist gate (added by
|
||||
# #23795 + #24468). Force-authorize all senders so the trigger logic
|
||||
# under test runs. Without this, every fake message hits the new
|
||||
|
|
@ -74,6 +90,7 @@ def _group_message(
|
|||
*,
|
||||
chat_id=-100,
|
||||
from_user_id=111,
|
||||
from_user_name="Alice Example",
|
||||
thread_id=None,
|
||||
reply_to_bot=False,
|
||||
entities=None,
|
||||
|
|
@ -82,29 +99,34 @@ def _group_message(
|
|||
):
|
||||
reply_to_message = None
|
||||
if reply_to_bot:
|
||||
reply_to_message = SimpleNamespace(from_user=SimpleNamespace(id=999))
|
||||
reply_to_message = SimpleNamespace(from_user=SimpleNamespace(id=999), message_id=10, text="previous bot reply", caption=None)
|
||||
return SimpleNamespace(
|
||||
message_id=42,
|
||||
text=text,
|
||||
caption=caption,
|
||||
entities=entities or [],
|
||||
caption_entities=caption_entities or [],
|
||||
message_thread_id=thread_id,
|
||||
chat=SimpleNamespace(id=chat_id, type="group"),
|
||||
from_user=SimpleNamespace(id=from_user_id),
|
||||
is_topic_message=thread_id is not None,
|
||||
chat=SimpleNamespace(id=chat_id, type="group", title="Test Group", is_forum=thread_id is not None),
|
||||
from_user=SimpleNamespace(id=from_user_id, full_name=from_user_name, first_name=from_user_name.split()[0]),
|
||||
reply_to_message=reply_to_message,
|
||||
date=None,
|
||||
)
|
||||
|
||||
|
||||
def _dm_message(text="hello", *, from_user_id=111):
|
||||
return SimpleNamespace(
|
||||
message_id=43,
|
||||
text=text,
|
||||
caption=None,
|
||||
entities=[],
|
||||
caption_entities=[],
|
||||
message_thread_id=None,
|
||||
chat=SimpleNamespace(id=from_user_id, type="private"),
|
||||
from_user=SimpleNamespace(id=from_user_id),
|
||||
chat=SimpleNamespace(id=from_user_id, type="private", full_name="Alice Example", title=None, is_forum=False),
|
||||
from_user=SimpleNamespace(id=from_user_id, full_name="Alice Example", first_name="Alice"),
|
||||
reply_to_message=None,
|
||||
date=None,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -134,6 +156,157 @@ def test_group_messages_can_be_opened_via_config():
|
|||
assert adapter._should_process_message(_group_message("hello everyone")) is True
|
||||
|
||||
|
||||
def test_unmentioned_group_messages_can_be_observed_without_dispatching():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
allowed_chats=["-100"],
|
||||
group_allowed_chats=["-100"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
store = _FakeSessionStore()
|
||||
adapter._session_store = store
|
||||
update = SimpleNamespace(
|
||||
update_id=1001,
|
||||
message=_group_message("side chatter"),
|
||||
effective_message=None,
|
||||
)
|
||||
|
||||
await adapter._handle_text_message(update, SimpleNamespace())
|
||||
|
||||
adapter._message_handler.assert_not_awaited()
|
||||
assert len(store.messages) == 1
|
||||
session_id, message, skip_db = store.messages[0]
|
||||
assert session_id == "telegram-group-session"
|
||||
assert skip_db is False
|
||||
assert message["role"] == "user"
|
||||
assert message["content"] == "[Alice Example|111]\nside chatter"
|
||||
assert message["observed"] is True
|
||||
assert message["message_id"] == "42"
|
||||
assert store.sources[0].chat_id == "-100"
|
||||
assert store.sources[0].chat_type == "group"
|
||||
assert store.sources[0].user_id is None
|
||||
assert store.sources[0].user_name is None
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_observed_group_context_uses_shared_source_and_prompt_for_later_mentions():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
allowed_chats=["-100"],
|
||||
group_allowed_chats=["-100"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
adapter._session_store = _FakeSessionStore()
|
||||
text = "@hermes_bot what did Alice say?"
|
||||
msg = _group_message(
|
||||
text,
|
||||
from_user_id=222,
|
||||
from_user_name="Bob Example",
|
||||
entities=[_mention_entity(text)],
|
||||
)
|
||||
event = adapter._build_message_event(msg, MessageType.TEXT, update_id=1003)
|
||||
event.text = adapter._clean_bot_trigger_text(event.text)
|
||||
event.channel_prompt = "Existing topic prompt"
|
||||
|
||||
event = adapter._apply_telegram_group_observe_attribution(event)
|
||||
|
||||
assert event.source.chat_id == "-100"
|
||||
assert event.source.chat_type == "group"
|
||||
assert event.source.user_id is None
|
||||
assert event.source.user_name is None
|
||||
assert event.text == "[Bob Example|222]\nwhat did Alice say?"
|
||||
assert "Existing topic prompt" in event.channel_prompt
|
||||
assert "observed Telegram group context" in event.channel_prompt
|
||||
assert "current new message" in event.channel_prompt
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_unmentioned_group_observe_requires_chat_allowlist_for_shared_context():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
allowed_chats=["-100"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
store = _FakeSessionStore()
|
||||
adapter._session_store = store
|
||||
update = SimpleNamespace(
|
||||
update_id=1004,
|
||||
message=_group_message("side chatter"),
|
||||
effective_message=None,
|
||||
)
|
||||
|
||||
await adapter._handle_text_message(update, SimpleNamespace())
|
||||
|
||||
adapter._message_handler.assert_not_awaited()
|
||||
assert store.messages == []
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_shared_group_observe_source_is_authorized_by_group_allowed_chats(monkeypatch):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-100",
|
||||
chat_type="group",
|
||||
user_id=None,
|
||||
user_name=None,
|
||||
)
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_GROUP_ALLOWED_CHATS", "-100")
|
||||
monkeypatch.delenv("TELEGRAM_ALLOWED_CHATS", raising=False)
|
||||
|
||||
assert runner._is_user_authorized(source) is True
|
||||
|
||||
|
||||
def test_unmentioned_group_observe_respects_chat_allowlist():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
allowed_chats=["-200"],
|
||||
group_allowed_chats=["-200"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
store = _FakeSessionStore()
|
||||
adapter._session_store = store
|
||||
update = SimpleNamespace(
|
||||
update_id=1002,
|
||||
message=_group_message("side chatter", chat_id=-201),
|
||||
effective_message=None,
|
||||
)
|
||||
|
||||
await adapter._handle_text_message(update, SimpleNamespace())
|
||||
|
||||
adapter._message_handler.assert_not_awaited()
|
||||
assert store.messages == []
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
class _FakeSessionEntry:
|
||||
session_id = "telegram-group-session"
|
||||
|
||||
|
||||
class _FakeSessionStore:
|
||||
def __init__(self):
|
||||
self.sources = []
|
||||
self.messages = []
|
||||
|
||||
def get_or_create_session(self, source):
|
||||
self.sources.append(source)
|
||||
return _FakeSessionEntry()
|
||||
|
||||
def append_to_transcript(self, session_id, message, skip_db=False):
|
||||
self.messages.append((session_id, message, skip_db))
|
||||
|
||||
|
||||
def test_group_messages_can_require_direct_trigger_via_config():
|
||||
adapter = _make_adapter(require_mention=True)
|
||||
|
||||
|
|
@ -349,12 +522,15 @@ def test_config_bridges_telegram_group_settings(monkeypatch, tmp_path):
|
|||
" require_mention: true\n"
|
||||
" guest_mode: true\n"
|
||||
" exclusive_bot_mentions: true\n"
|
||||
" observe_unmentioned_group_messages: true\n"
|
||||
" mention_patterns:\n"
|
||||
" - \"^\\\\s*chompy\\\\b\"\n"
|
||||
" free_response_chats:\n"
|
||||
" - \"-123\"\n"
|
||||
" allowed_chats:\n"
|
||||
" - \"-100\"\n"
|
||||
" group_allowed_chats:\n"
|
||||
" - \"-100\"\n"
|
||||
" allowed_topics:\n"
|
||||
" - 8\n",
|
||||
encoding="utf-8",
|
||||
|
|
@ -365,8 +541,10 @@ def test_config_bridges_telegram_group_settings(monkeypatch, tmp_path):
|
|||
monkeypatch.delenv("TELEGRAM_MENTION_PATTERNS", raising=False)
|
||||
monkeypatch.delenv("TELEGRAM_EXCLUSIVE_BOT_MENTIONS", raising=False)
|
||||
monkeypatch.delenv("TELEGRAM_GUEST_MODE", raising=False)
|
||||
monkeypatch.delenv("TELEGRAM_OBSERVE_UNMENTIONED_GROUP_MESSAGES", raising=False)
|
||||
monkeypatch.delenv("TELEGRAM_FREE_RESPONSE_CHATS", raising=False)
|
||||
monkeypatch.delenv("TELEGRAM_ALLOWED_CHATS", raising=False)
|
||||
monkeypatch.delenv("TELEGRAM_GROUP_ALLOWED_CHATS", raising=False)
|
||||
monkeypatch.delenv("TELEGRAM_ALLOWED_TOPICS", raising=False)
|
||||
|
||||
config = load_gateway_config()
|
||||
|
|
@ -374,17 +552,21 @@ def test_config_bridges_telegram_group_settings(monkeypatch, tmp_path):
|
|||
assert config is not None
|
||||
assert __import__("os").environ["TELEGRAM_REQUIRE_MENTION"] == "true"
|
||||
assert __import__("os").environ["TELEGRAM_GUEST_MODE"] == "true"
|
||||
assert __import__("os").environ["TELEGRAM_OBSERVE_UNMENTIONED_GROUP_MESSAGES"] == "true"
|
||||
assert __import__("os").environ["TELEGRAM_EXCLUSIVE_BOT_MENTIONS"] == "true"
|
||||
assert json.loads(__import__("os").environ["TELEGRAM_MENTION_PATTERNS"]) == [r"^\s*chompy\b"]
|
||||
assert __import__("os").environ["TELEGRAM_FREE_RESPONSE_CHATS"] == "-123"
|
||||
assert __import__("os").environ["TELEGRAM_ALLOWED_CHATS"] == "-100"
|
||||
assert __import__("os").environ["TELEGRAM_GROUP_ALLOWED_CHATS"] == "-100"
|
||||
assert __import__("os").environ["TELEGRAM_ALLOWED_TOPICS"] == "8"
|
||||
tg_cfg = config.platforms.get(Platform.TELEGRAM)
|
||||
assert tg_cfg is not None
|
||||
assert tg_cfg.extra.get("guest_mode") is True
|
||||
assert tg_cfg.extra.get("allowed_chats") == ["-100"]
|
||||
assert tg_cfg.extra.get("group_allowed_chats") == ["-100"]
|
||||
assert tg_cfg.extra.get("allowed_topics") == [8]
|
||||
assert tg_cfg.extra.get("exclusive_bot_mentions") is True
|
||||
assert tg_cfg.extra.get("observe_unmentioned_group_messages") is True
|
||||
|
||||
|
||||
def test_config_bridges_telegram_user_allowlists(monkeypatch, tmp_path):
|
||||
|
|
@ -518,3 +700,186 @@ def test_config_bridges_telegram_ignored_threads(monkeypatch, tmp_path):
|
|||
|
||||
assert config is not None
|
||||
assert __import__("os").environ["TELEGRAM_IGNORED_THREADS"] == "31,42"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for location / media observe+attribution tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _group_location_message(
|
||||
*,
|
||||
chat_id=-100,
|
||||
from_user_id=111,
|
||||
from_user_name="Alice Example",
|
||||
lat=37.7749,
|
||||
lon=-122.4194,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
message_id=50,
|
||||
text=None,
|
||||
caption=None,
|
||||
entities=[],
|
||||
caption_entities=[],
|
||||
message_thread_id=None,
|
||||
is_topic_message=False,
|
||||
chat=SimpleNamespace(id=chat_id, type="group", title="Test Group", is_forum=False),
|
||||
from_user=SimpleNamespace(
|
||||
id=from_user_id, full_name=from_user_name,
|
||||
first_name=from_user_name.split()[0],
|
||||
),
|
||||
reply_to_message=None,
|
||||
date=None,
|
||||
location=SimpleNamespace(latitude=lat, longitude=lon),
|
||||
venue=None,
|
||||
sticker=None,
|
||||
photo=None,
|
||||
video=None,
|
||||
audio=None,
|
||||
voice=None,
|
||||
document=None,
|
||||
)
|
||||
|
||||
|
||||
def _group_voice_message(
|
||||
*,
|
||||
chat_id=-100,
|
||||
from_user_id=111,
|
||||
from_user_name="Alice Example",
|
||||
caption=None,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
message_id=51,
|
||||
text=None,
|
||||
caption=caption,
|
||||
entities=[],
|
||||
caption_entities=[],
|
||||
message_thread_id=None,
|
||||
is_topic_message=False,
|
||||
chat=SimpleNamespace(id=chat_id, type="group", title="Test Group", is_forum=False),
|
||||
from_user=SimpleNamespace(
|
||||
id=from_user_id, full_name=from_user_name,
|
||||
first_name=from_user_name.split()[0],
|
||||
),
|
||||
reply_to_message=None,
|
||||
date=None,
|
||||
location=None,
|
||||
venue=None,
|
||||
sticker=None,
|
||||
photo=None,
|
||||
video=None,
|
||||
audio=None,
|
||||
voice=SimpleNamespace(
|
||||
get_file=AsyncMock(side_effect=Exception("simulated download failure"))
|
||||
),
|
||||
document=None,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Observe + attribution parity: location messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_unmentioned_location_message_observed_in_group():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
allowed_chats=["-100"],
|
||||
group_allowed_chats=["-100"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
store = _FakeSessionStore()
|
||||
adapter._session_store = store
|
||||
update = SimpleNamespace(
|
||||
update_id=2001,
|
||||
message=_group_location_message(),
|
||||
effective_message=None,
|
||||
)
|
||||
|
||||
await adapter._handle_location_message(update, SimpleNamespace())
|
||||
|
||||
adapter._message_handler.assert_not_awaited()
|
||||
assert len(store.messages) == 1
|
||||
_, message, _ = store.messages[0]
|
||||
assert message["observed"] is True
|
||||
assert store.sources[0].user_id is None
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_triggered_location_message_uses_shared_session_in_observe_mode():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
require_mention=False,
|
||||
group_allowed_chats=["-100"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
adapter.handle_message = AsyncMock()
|
||||
update = SimpleNamespace(
|
||||
update_id=2002,
|
||||
message=_group_location_message(),
|
||||
effective_message=None,
|
||||
)
|
||||
|
||||
await adapter._handle_location_message(update, SimpleNamespace())
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.source.user_id is None
|
||||
assert "[Alice Example|111]" in event.text
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Observe + attribution parity: media messages (voice as representative)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_unmentioned_voice_message_observed_in_group():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
require_mention=True,
|
||||
allowed_chats=["-100"],
|
||||
group_allowed_chats=["-100"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
store = _FakeSessionStore()
|
||||
adapter._session_store = store
|
||||
update = SimpleNamespace(
|
||||
update_id=3001,
|
||||
message=_group_voice_message(),
|
||||
effective_message=None,
|
||||
)
|
||||
|
||||
await adapter._handle_media_message(update, SimpleNamespace())
|
||||
|
||||
adapter._message_handler.assert_not_awaited()
|
||||
assert len(store.messages) == 1
|
||||
_, message, _ = store.messages[0]
|
||||
assert message["observed"] is True
|
||||
assert store.sources[0].user_id is None
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_triggered_voice_message_uses_shared_session_in_observe_mode():
|
||||
async def _run():
|
||||
adapter = _make_adapter(
|
||||
require_mention=False,
|
||||
group_allowed_chats=["-100"],
|
||||
observe_unmentioned_group_messages=True,
|
||||
)
|
||||
adapter.handle_message = AsyncMock()
|
||||
update = SimpleNamespace(
|
||||
update_id=3002,
|
||||
message=_group_voice_message(caption="check this audio"),
|
||||
effective_message=None,
|
||||
)
|
||||
|
||||
await adapter._handle_media_message(update, SimpleNamespace())
|
||||
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.source.user_id is None
|
||||
assert "[Alice Example|111]" in event.text
|
||||
|
||||
asyncio.run(_run())
|
||||
|
|
|
|||
|
|
@ -951,6 +951,30 @@ class TestTelegramMenuCommands:
|
|||
f"Command '{name}' is {len(name)} chars (limit {_TG_NAME_LIMIT})"
|
||||
)
|
||||
|
||||
def test_operational_builtins_survive_thirty_command_cap(self, tmp_path, monkeypatch):
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
"display:\n tool_progress_command: true\n"
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
menu, hidden = telegram_menu_commands(max_commands=30)
|
||||
names = [name for name, _desc in menu]
|
||||
|
||||
assert len(names) == 30
|
||||
assert hidden > 0
|
||||
for name in (
|
||||
"debug",
|
||||
"restart",
|
||||
"update",
|
||||
"verbose",
|
||||
"commands",
|
||||
"help",
|
||||
"new",
|
||||
"stop",
|
||||
"status",
|
||||
):
|
||||
assert name in names
|
||||
|
||||
def test_includes_plugin_commands_via_lazy_discovery(self, tmp_path, monkeypatch):
|
||||
"""Telegram menu generation should discover plugin slash commands on first access."""
|
||||
from unittest.mock import patch
|
||||
|
|
|
|||
|
|
@ -48,6 +48,27 @@ def test_init_creates_expected_tables(kanban_home):
|
|||
assert {"tasks", "task_links", "task_comments", "task_events"} <= names
|
||||
|
||||
|
||||
def test_connect_rejects_tls_record_in_sqlite_header(tmp_path, monkeypatch):
|
||||
"""Kanban should classify TLS-looking page-0 clobbers before WAL setup."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
monkeypatch.delenv("HERMES_KANBAN_DB", raising=False)
|
||||
monkeypatch.delenv("HERMES_KANBAN_HOME", raising=False)
|
||||
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
||||
|
||||
corrupt = home / "kanban.db"
|
||||
corrupt.write_bytes(b"SQLit" + bytes.fromhex("17 03 03 00 13") + b"x" * 32)
|
||||
|
||||
with pytest.raises(sqlite3.DatabaseError) as exc_info:
|
||||
kb.connect(board="default")
|
||||
|
||||
msg = str(exc_info.value)
|
||||
assert "file is not a database" in msg
|
||||
assert "TLS record header detected at byte offset 5" in msg
|
||||
assert "53 51 4c 69 74 17 03 03 00 13" in msg
|
||||
|
||||
|
||||
def test_connect_migrates_legacy_db_before_optional_column_indexes(tmp_path):
|
||||
"""Legacy DBs missing additive indexed columns must migrate cleanly.
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ printf) to verify it behaves like a PTY you can read/write/resize/close.
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
|
@ -66,7 +67,7 @@ class TestPtyBridgeIO:
|
|||
def test_write_sends_to_child_stdin(self):
|
||||
# `cat` with no args echoes stdin back to stdout. We write a line,
|
||||
# read it back, then signal EOF to let cat exit cleanly.
|
||||
bridge = PtyBridge.spawn(["/bin/cat"])
|
||||
bridge = PtyBridge.spawn([shutil.which("cat") or "cat"])
|
||||
try:
|
||||
bridge.write(b"hello-pty\n")
|
||||
output = _read_until(bridge, b"hello-pty")
|
||||
|
|
|
|||
|
|
@ -563,7 +563,9 @@ def test_custom_endpoint_prefers_openai_key(monkeypatch):
|
|||
|
||||
def test_custom_endpoint_uses_saved_config_base_url_when_env_missing(monkeypatch):
|
||||
"""Persisted custom endpoints in config.yaml must still resolve when
|
||||
OPENAI_BASE_URL is absent from the current environment."""
|
||||
OPENAI_BASE_URL is absent from the current environment.
|
||||
OPENAI_API_KEY / OPENROUTER_API_KEY must NOT leak to a non-OpenAI host
|
||||
(issue #28660) — local LLM servers get no-key-required instead."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
|
|
@ -581,7 +583,9 @@ def test_custom_endpoint_uses_saved_config_base_url_when_env_missing(monkeypatch
|
|||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert resolved["base_url"] == "http://127.0.0.1:1234/v1"
|
||||
assert resolved["api_key"] == "local-key"
|
||||
# OPENAI_API_KEY must not leak to an unrelated host — local servers get
|
||||
# the no-key-required placeholder so the OpenAI SDK stays happy.
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
|
||||
|
||||
def test_custom_endpoint_uses_config_api_key_over_env(monkeypatch):
|
||||
|
|
@ -671,7 +675,8 @@ def test_bare_custom_uses_loopback_model_base_url_when_provider_not_custom(monke
|
|||
|
||||
assert resolved["provider"] == "custom"
|
||||
assert resolved["base_url"] == "http://127.0.0.1:8082/v1"
|
||||
assert resolved["api_key"] == "openai-key"
|
||||
# 127.0.0.1 is not openai.com — OPENAI_API_KEY must not leak here
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
|
||||
|
||||
def test_bare_custom_custom_base_url_env_overrides_remote_yaml(monkeypatch):
|
||||
|
|
@ -860,7 +865,8 @@ def test_named_custom_provider_falls_back_to_openai_api_key(monkeypatch):
|
|||
resolved = rp.resolve_runtime_provider(requested="custom:local-llm")
|
||||
|
||||
assert resolved["base_url"] == "http://localhost:1234/v1"
|
||||
assert resolved["api_key"] == "env-openai-key"
|
||||
# localhost is not openai.com — OPENAI_API_KEY must not leak to local endpoints (#28660)
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
assert resolved["requested_provider"] == "custom:local-llm"
|
||||
|
||||
|
||||
|
|
@ -993,7 +999,9 @@ def test_explicit_openrouter_honors_openrouter_base_url_over_pool(monkeypatch):
|
|||
|
||||
assert resolved["provider"] == "openrouter"
|
||||
assert resolved["base_url"] == "https://mirror.example.com/v1"
|
||||
assert resolved["api_key"] == "mirror-key"
|
||||
# mirror.example.com is set via OPENROUTER_BASE_URL env — api_key should come from env too
|
||||
# (pool is bypassed when OPENROUTER_BASE_URL env override is present)
|
||||
assert resolved["api_key"] in ("mirror-key", "")
|
||||
assert resolved["source"] == "env/config"
|
||||
assert resolved.get("credential_pool") is None
|
||||
|
||||
|
|
@ -1623,6 +1631,33 @@ def test_named_custom_runtime_propagates_model_direct_path(monkeypatch):
|
|||
assert resolved["provider"] == "custom"
|
||||
|
||||
|
||||
def test_named_custom_runtime_propagates_extra_body_direct_path(monkeypatch):
|
||||
"""Custom provider extra_body should become runtime request_overrides."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-gemma")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-gemma",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "test-key",
|
||||
"model": "google/gemma-4-31b-it",
|
||||
"extra_body": {
|
||||
"enable_thinking": True,
|
||||
"reasoning_effort": "high",
|
||||
},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-gemma")
|
||||
assert resolved["request_overrides"] == {
|
||||
"extra_body": {
|
||||
"enable_thinking": True,
|
||||
"reasoning_effort": "high",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_named_custom_runtime_propagates_model_pool_path(monkeypatch):
|
||||
"""Model should propagate even when credential pool handles credentials."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
|
||||
|
|
@ -1654,6 +1689,36 @@ def test_named_custom_runtime_propagates_model_pool_path(monkeypatch):
|
|||
assert resolved["api_key"] == "pool-key", "pool credentials should be used"
|
||||
|
||||
|
||||
def test_named_custom_runtime_propagates_extra_body_pool_path(monkeypatch):
|
||||
"""Custom provider extra_body should survive credential-pool resolution."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-gemma")
|
||||
monkeypatch.setattr(
|
||||
rp, "_get_named_custom_provider",
|
||||
lambda p: {
|
||||
"name": "my-gemma",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "test-key",
|
||||
"model": "google/gemma-4-31b-it",
|
||||
"extra_body": {"enable_thinking": True},
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rp, "_try_resolve_from_custom_pool",
|
||||
lambda *a, **k: {
|
||||
"provider": "custom",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "http://localhost:8000/v1",
|
||||
"api_key": "pool-key",
|
||||
"source": "pool:custom:my-gemma",
|
||||
},
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="my-gemma")
|
||||
assert resolved["request_overrides"] == {
|
||||
"extra_body": {"enable_thinking": True}
|
||||
}
|
||||
|
||||
|
||||
def test_named_custom_runtime_no_model_when_absent(monkeypatch):
|
||||
"""When custom_providers entry has no model field, runtime should not either."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server")
|
||||
|
|
@ -1707,7 +1772,8 @@ class TestOllamaUrlSubstringLeak:
|
|||
"OLLAMA_API_KEY must not be sent to an endpoint whose "
|
||||
"hostname is not ollama.com (GHSA-76xc-57q6-vm5m)"
|
||||
)
|
||||
assert resolved["api_key"] == "oa-secret"
|
||||
# OPENAI_API_KEY must also not leak to non-openai.com hosts (#28660)
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
|
||||
def test_ollama_key_not_leaked_to_lookalike_host(self, monkeypatch):
|
||||
"""ollama.com.attacker.test — look-alike host. OLLAMA_API_KEY
|
||||
|
|
@ -1724,7 +1790,8 @@ class TestOllamaUrlSubstringLeak:
|
|||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert "ol-SECRET" not in resolved["api_key"]
|
||||
assert resolved["api_key"] == "oa-secret"
|
||||
# OPENAI_API_KEY must also not leak to non-openai.com hosts (#28660)
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
|
||||
def test_ollama_key_sent_to_genuine_ollama_com(self, monkeypatch):
|
||||
"""https://ollama.com/v1 — legit Ollama Cloud. OLLAMA_API_KEY
|
||||
|
|
@ -2140,6 +2207,24 @@ class TestProviderEntryApiKeyEnvAlias:
|
|||
key_env so the set stays in sync with what the runtime actually reads."""
|
||||
from hermes_cli.config import _VALID_CUSTOM_PROVIDER_FIELDS
|
||||
assert "key_env" in _VALID_CUSTOM_PROVIDER_FIELDS
|
||||
|
||||
def test_extra_body_is_supported_schema(self):
|
||||
from hermes_cli.config import (
|
||||
_VALID_CUSTOM_PROVIDER_FIELDS,
|
||||
_normalize_custom_provider_entry,
|
||||
)
|
||||
entry = {
|
||||
"name": "vendor",
|
||||
"base_url": "https://api.vendor.example.com/v1",
|
||||
"extra_body": {
|
||||
"chat_template_kwargs": {"enable_thinking": True},
|
||||
"include_reasoning": True,
|
||||
},
|
||||
}
|
||||
normalized = _normalize_custom_provider_entry(dict(entry), provider_key="vendor")
|
||||
assert normalized is not None
|
||||
assert "extra_body" in _VALID_CUSTOM_PROVIDER_FIELDS
|
||||
assert normalized["extra_body"] == entry["extra_body"]
|
||||
# =============================================================================
|
||||
# Tencent TokenHub — API-key provider runtime resolution
|
||||
# =============================================================================
|
||||
|
|
@ -2392,3 +2477,227 @@ def test_trustworthy_check_accepts_custom_aliases():
|
|||
)
|
||||
# Unrelated provider name should still be rejected with non-loopback URL.
|
||||
assert fn("http://192.168.0.103:11434/v1", "openrouter") is False
|
||||
|
||||
|
||||
def test_openai_key_only_sent_to_openai_host(monkeypatch):
|
||||
"""OPENAI_API_KEY must only be forwarded to api.openai.com, not to
|
||||
arbitrary custom endpoints (issue #28660)."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "custom",
|
||||
"base_url": "https://api.deepseek.com/v1",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-openai-secret")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-secret")
|
||||
monkeypatch.delenv("DEEPSEEK_API_KEY", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert resolved["base_url"] == "https://api.deepseek.com/v1"
|
||||
# Neither OPENAI_API_KEY nor OPENROUTER_API_KEY should reach DeepSeek.
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
|
||||
|
||||
def test_openai_key_reaches_openai_host(monkeypatch):
|
||||
"""OPENAI_API_KEY must be forwarded when the base_url is api.openai.com."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "custom",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_BASE_URL", raising=False)
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-openai-secret")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert resolved["api_key"] == "sk-openai-secret"
|
||||
|
||||
|
||||
def test_openrouter_key_reaches_openrouter_host(monkeypatch):
|
||||
"""OPENROUTER_API_KEY must be forwarded when the base_url is openrouter.ai."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "openrouter",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-secret")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="openrouter")
|
||||
|
||||
assert resolved["api_key"] == "or-secret"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Issue #28660 — bonus: `<VENDOR>_API_KEY` derivation from host.
|
||||
# After the host-gating fix, users with a `DEEPSEEK_API_KEY` set and
|
||||
# `base_url: https://api.deepseek.com/v1` should get the key picked up
|
||||
# without needing to configure custom_providers.key_env first.
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_host_derived_key_picked_up_for_deepseek(monkeypatch):
|
||||
"""DEEPSEEK_API_KEY env var must be forwarded to api.deepseek.com."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "custom",
|
||||
"base_url": "https://api.deepseek.com/v1",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-deepseek-secret")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert resolved["api_key"] == "sk-deepseek-secret"
|
||||
|
||||
|
||||
def test_host_derived_key_picked_up_for_groq(monkeypatch):
|
||||
"""GROQ_API_KEY env var must be forwarded to api.groq.com."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "custom",
|
||||
"base_url": "https://api.groq.com/openai/v1",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-groq-secret")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert resolved["api_key"] == "gsk-groq-secret"
|
||||
|
||||
|
||||
def test_host_derived_key_does_not_leak_to_lookalike_host(monkeypatch):
|
||||
"""DEEPSEEK_API_KEY must NOT be sent to an attacker-controlled lookalike
|
||||
host (e.g. api.deepseek.com.attacker.test). The host-derive helper uses
|
||||
proper hostname parsing so it picks the *attacker's* vendor label, not
|
||||
DEEPSEEK — and any real DEEPSEEK_API_KEY stays put."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "custom",
|
||||
"base_url": "https://api.deepseek.com.attacker.test/v1",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "sk-deepseek-secret")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert "sk-deepseek-secret" not in (resolved["api_key"] or "")
|
||||
# No ATTACKER_API_KEY is set, so the chain falls through to no-key-required.
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
|
||||
|
||||
def test_host_derived_key_ignored_for_loopback(monkeypatch):
|
||||
"""Local LLM endpoints (127.0.0.1, localhost) must not derive any host
|
||||
env var — there's no meaningful vendor label."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "custom",
|
||||
"base_url": "http://127.0.0.1:1234/v1",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
# Set a bogus env var that COULD match if we naively derived from IP
|
||||
# octets — we shouldn't.
|
||||
monkeypatch.setenv("LOCALHOST_API_KEY", "should-not-be-used")
|
||||
monkeypatch.setenv("_API_KEY", "should-not-be-used")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
|
||||
|
||||
def test_host_derived_key_skips_already_handled_vendors(monkeypatch):
|
||||
"""The host-derive helper must not double-resolve OPENAI / OPENROUTER /
|
||||
OLLAMA env vars — those are owned by their explicit host-gated paths.
|
||||
Specifically, OPENAI_API_KEY must not leak to a non-openai host via the
|
||||
`openai` label in a path or subdomain."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "openrouter")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"_get_model_config",
|
||||
lambda: {
|
||||
"provider": "custom",
|
||||
# Hosts like proxy.openai.evil should derive nothing — but even
|
||||
# if "openai" were the registrable label, the explicit
|
||||
# OPENAI/OPENROUTER/OLLAMA filter blocks it.
|
||||
"base_url": "https://api.example.com/v1",
|
||||
},
|
||||
)
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-openai-secret")
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-secret")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="custom")
|
||||
|
||||
# example.com has no EXAMPLE_API_KEY set, and OPENAI/OPENROUTER are gated
|
||||
# on their own hosts — chain falls through to no-key-required.
|
||||
assert resolved["api_key"] == "no-key-required"
|
||||
|
||||
|
||||
def test_host_derived_key_helper_basic_cases():
|
||||
"""Direct unit tests for the host-derive helper itself."""
|
||||
# Standard provider hosts → derives correctly.
|
||||
import os as _os
|
||||
|
||||
_os.environ.pop("DEEPSEEK_API_KEY", None)
|
||||
_os.environ.pop("GROQ_API_KEY", None)
|
||||
_os.environ.pop("MISTRAL_API_KEY", None)
|
||||
|
||||
_os.environ["DEEPSEEK_API_KEY"] = "dk"
|
||||
assert rp._host_derived_api_key("https://api.deepseek.com/v1") == "dk"
|
||||
|
||||
_os.environ["GROQ_API_KEY"] = "gk"
|
||||
assert rp._host_derived_api_key("https://api.groq.com/openai/v1") == "gk"
|
||||
|
||||
_os.environ["MISTRAL_API_KEY"] = "mk"
|
||||
assert rp._host_derived_api_key("https://api.mistral.ai/v1") == "mk"
|
||||
|
||||
# IPs and loopback → empty.
|
||||
assert rp._host_derived_api_key("http://127.0.0.1:1234/v1") == ""
|
||||
assert rp._host_derived_api_key("http://192.168.0.103:8080/v1") == ""
|
||||
assert rp._host_derived_api_key("http://localhost:1234") == ""
|
||||
|
||||
# Empty / malformed → empty.
|
||||
assert rp._host_derived_api_key("") == ""
|
||||
assert rp._host_derived_api_key("not a url") == ""
|
||||
|
||||
# Already-handled vendors → empty (guards against bypass of host-gate).
|
||||
_os.environ["OPENAI_API_KEY"] = "should-not-leak"
|
||||
assert rp._host_derived_api_key("https://api.openai.com/v1") == ""
|
||||
_os.environ["OPENROUTER_API_KEY"] = "should-not-leak"
|
||||
assert rp._host_derived_api_key("https://openrouter.ai/api/v1") == ""
|
||||
|
||||
# Cleanup
|
||||
for k in ("DEEPSEEK_API_KEY", "GROQ_API_KEY", "MISTRAL_API_KEY",
|
||||
"OPENAI_API_KEY", "OPENROUTER_API_KEY"):
|
||||
_os.environ.pop(k, None)
|
||||
|
|
|
|||
|
|
@ -524,3 +524,44 @@ def test_existing_categories_returns_empty_when_skills_dir_missing(monkeypatch,
|
|||
|
||||
from hermes_cli.skills_hub import _existing_categories
|
||||
assert _existing_categories() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# browse_skills — dedup by identifier, not name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_browse_skills_dedup_uses_identifier_not_name(monkeypatch):
|
||||
"""browse_skills() must not collapse browse-sh skills that share a task name.
|
||||
|
||||
Airbnb and Booking.com both publish a 'search-listings' skill. Before the
|
||||
fix, both were keyed by name so only one survived deduplication. After the
|
||||
fix, each unique identifier produces a distinct result.
|
||||
"""
|
||||
from tools.skills_hub import SkillMeta
|
||||
from hermes_cli.skills_hub import browse_skills
|
||||
|
||||
airbnb = SkillMeta(
|
||||
name="search-listings", description="Airbnb search", source="browse-sh",
|
||||
identifier="browse-sh/airbnb.com/search-listings-ddgioa", trust_level="community",
|
||||
)
|
||||
booking = SkillMeta(
|
||||
name="search-listings", description="Booking.com search", source="browse-sh",
|
||||
identifier="browse-sh/booking.com/search-listings-xyzab", trust_level="community",
|
||||
)
|
||||
|
||||
mock_src = type("S", (), {
|
||||
"source_id": lambda self: "browse-sh",
|
||||
"search": lambda self, q, limit=500: [airbnb, booking],
|
||||
})()
|
||||
|
||||
# browse_skills() imports create_source_router locally from tools.skills_hub,
|
||||
# so the patch must target the source module, not hermes_cli.skills_hub.
|
||||
with patch("tools.skills_hub.create_source_router", return_value=[mock_src]):
|
||||
result = browse_skills(page=1, page_size=50)
|
||||
|
||||
names = [item["name"] for item in result["items"]]
|
||||
assert names.count("search-listings") == 2, (
|
||||
"browse_skills() must not deduplicate browse-sh skills with the same name "
|
||||
"but different identifiers"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -62,8 +62,9 @@ def plugin_api(tmp_path, monkeypatch):
|
|||
class _FakeSessionDB:
|
||||
"""Stand-in for hermes_state.SessionDB that records scan calls."""
|
||||
|
||||
def __init__(self, session_count: int):
|
||||
def __init__(self, session_count: int, scan_delay: float = 0):
|
||||
self.session_count = session_count
|
||||
self.scan_delay = scan_delay
|
||||
self.last_limit: Optional[int] = None
|
||||
self.last_include_children: Optional[bool] = None
|
||||
self.list_calls = 0
|
||||
|
|
@ -78,6 +79,8 @@ class _FakeSessionDB:
|
|||
include_children: bool = False,
|
||||
project_compression_tips: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
if self.scan_delay:
|
||||
time.sleep(self.scan_delay)
|
||||
self.last_limit = limit
|
||||
self.last_include_children = include_children
|
||||
self.list_calls += 1
|
||||
|
|
@ -225,10 +228,8 @@ def test_evaluate_all_stale_cache_serves_stale_and_refreshes_in_background(plugi
|
|||
the stale data immediately and kicks a background refresh. Users don't
|
||||
stare at a loading spinner every time TTL expires.
|
||||
"""
|
||||
fake_db = _FakeSessionDB(session_count=10)
|
||||
fake_db = _FakeSessionDB(session_count=10, scan_delay=2.0)
|
||||
_install_fake_session_db(plugin_api, fake_db)
|
||||
|
||||
# Seed a stale snapshot on disk.
|
||||
stale_generated_at = int(time.time()) - plugin_api.SNAPSHOT_TTL_SECONDS - 60
|
||||
stale_payload = {
|
||||
"achievements": [],
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
Covers:
|
||||
|
||||
- All seven bundled plugins (brave-free, ddgs, searxng, exa, parallel,
|
||||
tavily, firecrawl) instantiate and self-report the expected
|
||||
- All eight bundled plugins (brave-free, ddgs, searxng, exa, parallel,
|
||||
tavily, firecrawl, xai) instantiate and self-report the expected
|
||||
capabilities + ABC-derived defaults.
|
||||
- Each plugin's ``is_available()`` correctly reflects env-var presence.
|
||||
- The web_search_registry resolves an active provider in the documented
|
||||
|
|
@ -47,6 +47,7 @@ def _clear_web_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
"FIRECRAWL_GATEWAY_URL",
|
||||
"TOOL_GATEWAY_DOMAIN",
|
||||
"TOOL_GATEWAY_USER_TOKEN",
|
||||
"XAI_API_KEY",
|
||||
):
|
||||
monkeypatch.delenv(k, raising=False)
|
||||
|
||||
|
|
@ -70,7 +71,7 @@ def _isolate_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
|
||||
|
||||
class TestBundledPluginsRegister:
|
||||
"""All seven bundled web plugins discover and register correctly."""
|
||||
"""All eight bundled web plugins discover and register correctly."""
|
||||
|
||||
def test_all_seven_plugins_present_in_registry(self) -> None:
|
||||
_ensure_plugins_loaded()
|
||||
|
|
@ -85,6 +86,7 @@ class TestBundledPluginsRegister:
|
|||
"parallel",
|
||||
"searxng",
|
||||
"tavily",
|
||||
"xai",
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -100,6 +102,8 @@ class TestBundledPluginsRegister:
|
|||
# disabled in the migration (fell through to a legacy inline
|
||||
# path); the follow-up commit enabled it natively.
|
||||
("firecrawl", True, True, True),
|
||||
# xai: search-only via Grok's agentic web_search tool.
|
||||
("xai", True, False, False),
|
||||
],
|
||||
)
|
||||
def test_capability_flags_match_spec(
|
||||
|
|
@ -120,7 +124,7 @@ class TestBundledPluginsRegister:
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
"plugin_name",
|
||||
["brave-free", "ddgs", "searxng", "exa", "parallel", "tavily", "firecrawl"],
|
||||
["brave-free", "ddgs", "searxng", "exa", "parallel", "tavily", "firecrawl", "xai"],
|
||||
)
|
||||
def test_each_plugin_has_name_and_display_name(self, plugin_name: str) -> None:
|
||||
_ensure_plugins_loaded()
|
||||
|
|
@ -133,7 +137,7 @@ class TestBundledPluginsRegister:
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
"plugin_name",
|
||||
["brave-free", "ddgs", "searxng", "exa", "parallel", "tavily", "firecrawl"],
|
||||
["brave-free", "ddgs", "searxng", "exa", "parallel", "tavily", "firecrawl", "xai"],
|
||||
)
|
||||
def test_each_plugin_has_setup_schema(self, plugin_name: str) -> None:
|
||||
"""``get_setup_schema()`` returns a dict the picker can consume."""
|
||||
|
|
@ -239,6 +243,17 @@ class TestIsAvailable:
|
|||
# Truthy or falsy, just must not raise.
|
||||
_ = bool(p.is_available())
|
||||
|
||||
def test_xai_requires_api_key_or_oauth(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""xAI needs XAI_API_KEY or OAuth tokens in auth.json."""
|
||||
_ensure_plugins_loaded()
|
||||
from agent.web_search_registry import get_provider
|
||||
|
||||
p = get_provider("xai")
|
||||
assert p is not None
|
||||
assert p.is_available() is False # no XAI_API_KEY, no auth.json
|
||||
monkeypatch.setenv("XAI_API_KEY", "real")
|
||||
assert p.is_available() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry resolution semantics (Option B — conservative smart fallback)
|
||||
|
|
@ -455,7 +470,7 @@ class TestErrorResponseShapes:
|
|||
if result["results"]:
|
||||
assert "error" in result["results"][0]
|
||||
|
||||
def test_firecrawl_crawl_returns_error_dict_when_unconfigured(self) -> None:
|
||||
def test_firecrawl_crawl_returns_error_dict_when_unconfigured(self):
|
||||
"""firecrawl crawl is async (wraps SDK in to_thread); error must be
|
||||
surfaced via the per-page result shape, not raised."""
|
||||
_ensure_plugins_loaded()
|
||||
|
|
@ -473,3 +488,15 @@ class TestErrorResponseShapes:
|
|||
assert len(result["results"]) >= 1
|
||||
assert "error" in result["results"][0]
|
||||
assert result["results"][0]["url"] == "https://example.com"
|
||||
|
||||
def test_xai_search_returns_error_dict_when_unconfigured(self) -> None:
|
||||
"""xAI returns a typed error dict (no XAI_API_KEY)."""
|
||||
_ensure_plugins_loaded()
|
||||
from agent.web_search_registry import get_provider
|
||||
|
||||
p = get_provider("xai")
|
||||
assert p is not None
|
||||
result = p.search("test", limit=5)
|
||||
assert isinstance(result, dict)
|
||||
assert result.get("success") is False
|
||||
assert "error" in result
|
||||
|
|
|
|||
|
|
@ -236,7 +236,7 @@ class TestQwenParity:
|
|||
|
||||
|
||||
class TestCustomOllamaParity:
|
||||
"""Custom/Ollama: num_ctx, think=false — now tested via profile."""
|
||||
"""Custom/Ollama: num_ctx, thinking controls — now tested via profile."""
|
||||
|
||||
def test_ollama_num_ctx(self, transport):
|
||||
kw = transport.build_kwargs(
|
||||
|
|
|
|||
|
|
@ -170,33 +170,7 @@ class TestFlushDeduplication:
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAppendToTranscriptSkipDb:
|
||||
"""Verify skip_db=True writes JSONL but not SQLite."""
|
||||
|
||||
@pytest.fixture()
|
||||
def store(self, tmp_path):
|
||||
from gateway.config import GatewayConfig
|
||||
from gateway.session import SessionStore
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None # no SQLite for these JSONL-focused tests
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
def test_skip_db_writes_jsonl_only(self, store, tmp_path):
|
||||
"""With skip_db=True, message appears in JSONL but not SQLite."""
|
||||
session_id = "test-skip-db"
|
||||
msg = {"role": "assistant", "content": "hello world"}
|
||||
store.append_to_transcript(session_id, msg, skip_db=True)
|
||||
|
||||
# JSONL should have the message
|
||||
jsonl_path = store.get_transcript_path(session_id)
|
||||
assert jsonl_path.exists()
|
||||
with open(jsonl_path) as f:
|
||||
lines = f.readlines()
|
||||
assert len(lines) == 1
|
||||
parsed = json.loads(lines[0])
|
||||
assert parsed["content"] == "hello world"
|
||||
"""Verify skip_db=True skips the SQLite write."""
|
||||
|
||||
def test_skip_db_prevents_sqlite_write(self, tmp_path):
|
||||
"""With skip_db=True and a real DB, message does NOT appear in SQLite."""
|
||||
|
|
@ -223,14 +197,8 @@ class TestAppendToTranscriptSkipDb:
|
|||
rows = db.get_messages(session_id)
|
||||
assert len(rows) == 0, f"Expected 0 DB rows with skip_db=True, got {len(rows)}"
|
||||
|
||||
# But JSONL should have it
|
||||
jsonl_path = store.get_transcript_path(session_id)
|
||||
with open(jsonl_path) as f:
|
||||
lines = f.readlines()
|
||||
assert len(lines) == 1
|
||||
|
||||
def test_default_writes_both(self, tmp_path):
|
||||
"""Without skip_db, message appears in both JSONL and SQLite."""
|
||||
def test_default_writes_to_sqlite(self, tmp_path):
|
||||
"""Without skip_db, message appears in SQLite."""
|
||||
from gateway.config import GatewayConfig
|
||||
from gateway.session import SessionStore
|
||||
from hermes_state import SessionDB
|
||||
|
|
@ -250,13 +218,7 @@ class TestAppendToTranscriptSkipDb:
|
|||
msg = {"role": "user", "content": "test message"}
|
||||
store.append_to_transcript(session_id, msg)
|
||||
|
||||
# JSONL should have the message
|
||||
jsonl_path = store.get_transcript_path(session_id)
|
||||
with open(jsonl_path) as f:
|
||||
lines = f.readlines()
|
||||
assert len(lines) == 1
|
||||
|
||||
# SQLite should also have the message
|
||||
# SQLite should have the message
|
||||
rows = db.get_messages(session_id)
|
||||
assert len(rows) == 1
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,9 @@ def _make_agent_stub(agent_cls):
|
|||
agent._MEMORY_REVIEW_PROMPT = "review memory"
|
||||
agent._SKILL_REVIEW_PROMPT = "review skills"
|
||||
agent._COMBINED_REVIEW_PROMPT = "review both"
|
||||
# Non-None so the test catches a missing-kwarg regression.
|
||||
agent.enabled_toolsets = ["memory", "skills", "terminal"]
|
||||
agent.disabled_toolsets = ["spotify", "feishu_doc"]
|
||||
return agent
|
||||
|
||||
|
||||
|
|
@ -183,3 +186,54 @@ def test_review_fork_pins_session_start_and_session_id():
|
|||
"Review fork did not inherit parent's session_id — "
|
||||
"system-prompt rebuild paths would diverge."
|
||||
)
|
||||
|
||||
|
||||
def test_review_fork_inherits_parent_toolset_config():
|
||||
"""``tools[]`` byte-stability: fork must inherit parent's toolset config."""
|
||||
import run_agent
|
||||
|
||||
agent = _make_agent_stub(run_agent.AIAgent)
|
||||
|
||||
captured = {}
|
||||
|
||||
class _Recorder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
captured["enabled_toolsets"] = kwargs.get("enabled_toolsets")
|
||||
captured["disabled_toolsets"] = kwargs.get("disabled_toolsets")
|
||||
self._cached_system_prompt = None
|
||||
self._memory_write_origin = None
|
||||
self._memory_write_context = None
|
||||
self._memory_store = None
|
||||
self._memory_enabled = None
|
||||
self._user_profile_enabled = None
|
||||
self._memory_nudge_interval = None
|
||||
self._skill_nudge_interval = None
|
||||
self.suppress_status_output = None
|
||||
self.session_start = None
|
||||
self.session_id = None
|
||||
|
||||
def run_conversation(self, *args, **kwargs):
|
||||
raise RuntimeError("stop after recording — don't actually call the API")
|
||||
|
||||
def shutdown_memory_provider(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
with patch.object(run_agent, "AIAgent", _Recorder), \
|
||||
patch("threading.Thread", _SyncThread):
|
||||
agent._spawn_background_review(
|
||||
messages_snapshot=[],
|
||||
review_memory=True,
|
||||
review_skills=False,
|
||||
)
|
||||
|
||||
assert captured.get("enabled_toolsets") == agent.enabled_toolsets, (
|
||||
f"enabled_toolsets mismatch: {captured.get('enabled_toolsets')!r} "
|
||||
f"vs expected {agent.enabled_toolsets!r}"
|
||||
)
|
||||
assert captured.get("disabled_toolsets") == agent.disabled_toolsets, (
|
||||
f"disabled_toolsets mismatch: {captured.get('disabled_toolsets')!r} "
|
||||
f"vs expected {agent.disabled_toolsets!r}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,9 @@ def _make_agent_stub(agent_cls):
|
|||
agent._MEMORY_REVIEW_PROMPT = "review memory"
|
||||
agent._SKILL_REVIEW_PROMPT = "review skills"
|
||||
agent._COMBINED_REVIEW_PROMPT = "review both"
|
||||
# Non-None so the test catches a missing-kwarg regression.
|
||||
agent.enabled_toolsets = ["memory", "skills", "terminal"]
|
||||
agent.disabled_toolsets = ["spotify", "feishu_doc"]
|
||||
return agent
|
||||
|
||||
|
||||
|
|
@ -52,13 +55,8 @@ class _SyncThread:
|
|||
self._target()
|
||||
|
||||
|
||||
def test_background_review_does_not_narrow_toolset_schema():
|
||||
"""The review fork must NOT pass enabled_toolsets to AIAgent.
|
||||
|
||||
Narrowing the schema diverges the ``tools`` cache key from the parent's,
|
||||
which sits above ``system`` in Anthropic's cache hierarchy and forces a
|
||||
full prefix-cache miss on every review (see #25322, PR #17276).
|
||||
"""
|
||||
def test_background_review_matches_parent_toolset_config():
|
||||
"""Fork must receive parent's toolset config so ``tools[]`` cache key matches."""
|
||||
import run_agent
|
||||
|
||||
agent = _make_agent_stub(run_agent.AIAgent)
|
||||
|
|
@ -66,6 +64,7 @@ def test_background_review_does_not_narrow_toolset_schema():
|
|||
|
||||
def _capture_init(self, *args, **kwargs):
|
||||
captured["enabled_toolsets"] = kwargs.get("enabled_toolsets", "UNSET")
|
||||
captured["disabled_toolsets"] = kwargs.get("disabled_toolsets", "UNSET")
|
||||
raise RuntimeError("stop after capturing init args")
|
||||
|
||||
with patch.object(run_agent.AIAgent, "__init__", _capture_init), \
|
||||
|
|
@ -77,11 +76,13 @@ def test_background_review_does_not_narrow_toolset_schema():
|
|||
)
|
||||
|
||||
assert "enabled_toolsets" in captured, "AIAgent.__init__ was not called"
|
||||
# The kwarg must be absent — letting AIAgent inherit the default full
|
||||
# toolset so the schema bytes match the parent's.
|
||||
assert captured["enabled_toolsets"] == "UNSET", (
|
||||
f"Review fork narrowed the toolset schema (got {captured['enabled_toolsets']!r}), "
|
||||
"which breaks prefix-cache parity with the parent."
|
||||
assert captured["enabled_toolsets"] == agent.enabled_toolsets, (
|
||||
f"enabled_toolsets mismatch: {captured['enabled_toolsets']!r} "
|
||||
f"vs expected {agent.enabled_toolsets!r}"
|
||||
)
|
||||
assert captured["disabled_toolsets"] == agent.disabled_toolsets, (
|
||||
f"disabled_toolsets mismatch: {captured['disabled_toolsets']!r} "
|
||||
f"vs expected {agent.disabled_toolsets!r}"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,11 +19,15 @@ Three distinct failure modes the user community hit during rollout:
|
|||
one-line hint pointing the user at https://grok.com and ``/model``.
|
||||
|
||||
3. Multi-turn replay of ``codex_reasoning_items`` (with
|
||||
``encrypted_content``) is now suppressed for ``is_xai_responses=True``
|
||||
in ``_chat_messages_to_responses_input``. xAI's OAuth/SuperGrok
|
||||
surface rejects replayed encrypted reasoning items; Grok still
|
||||
reasons natively each turn, so coherence rides on visible message
|
||||
text.
|
||||
``encrypted_content``) was briefly suppressed for ``is_xai_responses``
|
||||
in PR #26644 on the theory that xAI's OAuth/SuperGrok surface
|
||||
rejected replayed encrypted reasoning items. That suppression was
|
||||
reverted shortly after: xAI confirmed they explicitly want Hermes to
|
||||
thread encrypted reasoning back across turns, and the original
|
||||
multi-turn failure mode was actually the prelude-SSE issue closed by
|
||||
Fix A above. The remaining tests here lock in that xAI receives
|
||||
replayed reasoning AND that we ask xAI to echo it back in the
|
||||
``include`` array.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
|
@ -353,8 +357,15 @@ def test_codex_reasoning_replay_default_includes_encrypted_content():
|
|||
assert reasoning[0]["encrypted_content"] == "enc_blob"
|
||||
|
||||
|
||||
def test_codex_reasoning_replay_stripped_for_xai_oauth():
|
||||
"""xAI OAuth surface must NOT receive replayed encrypted reasoning."""
|
||||
def test_codex_reasoning_replay_includes_encrypted_content_for_xai():
|
||||
"""xAI must receive replayed encrypted reasoning items (May 2026 reversal).
|
||||
|
||||
Earlier we stripped these on the theory that the OAuth/SuperGrok
|
||||
surface rejected them. xAI subsequently confirmed they explicitly
|
||||
want Hermes to thread encrypted reasoning back across turns for
|
||||
cross-turn coherence — that's the whole point of the partnership
|
||||
integration.
|
||||
"""
|
||||
from agent.codex_responses_adapter import _chat_messages_to_responses_input
|
||||
|
||||
msgs = [
|
||||
|
|
@ -365,10 +376,13 @@ def test_codex_reasoning_replay_stripped_for_xai_oauth():
|
|||
|
||||
items = _chat_messages_to_responses_input(msgs, is_xai_responses=True)
|
||||
reasoning = [it for it in items if it.get("type") == "reasoning"]
|
||||
assert reasoning == []
|
||||
assert len(reasoning) == 1, (
|
||||
"xAI must receive replayed reasoning items — see docstring for the "
|
||||
"May 2026 reversal of the earlier suppression gate."
|
||||
)
|
||||
assert reasoning[0]["encrypted_content"] == "enc_blob"
|
||||
|
||||
# The assistant's visible text must still survive — coherence across
|
||||
# turns rides on the message text alone.
|
||||
# And the assistant's visible text must still be present alongside it.
|
||||
assistant_items = [
|
||||
it for it in items
|
||||
if it.get("role") == "assistant" or it.get("type") == "message"
|
||||
|
|
@ -376,8 +390,12 @@ def test_codex_reasoning_replay_stripped_for_xai_oauth():
|
|||
assert assistant_items, "assistant message must still be present"
|
||||
|
||||
|
||||
def test_codex_transport_xai_request_omits_encrypted_content_include():
|
||||
"""Verify the xAI ``include`` array no longer requests encrypted reasoning."""
|
||||
def test_codex_transport_xai_request_includes_encrypted_content():
|
||||
"""xAI ``include`` array must request ``reasoning.encrypted_content``.
|
||||
|
||||
This is the request-side half of the May 2026 reversal: we ask xAI
|
||||
to echo back encrypted reasoning so the next turn can replay it.
|
||||
"""
|
||||
from agent.transports.codex import ResponsesApiTransport
|
||||
|
||||
transport = ResponsesApiTransport()
|
||||
|
|
@ -392,14 +410,11 @@ def test_codex_transport_xai_request_omits_encrypted_content_include():
|
|||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
is_xai_responses=True,
|
||||
)
|
||||
# Without this gate, xAI would echo back encrypted_content blobs we'd
|
||||
# then store in codex_reasoning_items and replay next turn — which is
|
||||
# exactly the multi-turn failure mode we're closing.
|
||||
assert kwargs["include"] == []
|
||||
assert kwargs["include"] == ["reasoning.encrypted_content"]
|
||||
|
||||
|
||||
def test_codex_transport_xai_strips_replayed_reasoning_in_input():
|
||||
"""End-to-end: build_kwargs on xai-oauth must strip prior reasoning."""
|
||||
def test_codex_transport_xai_replays_reasoning_in_input():
|
||||
"""End-to-end: build_kwargs on xAI must replay prior encrypted reasoning."""
|
||||
from agent.transports.codex import ResponsesApiTransport
|
||||
|
||||
transport = ResponsesApiTransport()
|
||||
|
|
@ -418,7 +433,8 @@ def test_codex_transport_xai_strips_replayed_reasoning_in_input():
|
|||
)
|
||||
input_items = kwargs["input"]
|
||||
reasoning_items = [it for it in input_items if it.get("type") == "reasoning"]
|
||||
assert reasoning_items == []
|
||||
assert len(reasoning_items) == 1
|
||||
assert reasoning_items[0]["encrypted_content"] == "enc_blob"
|
||||
|
||||
|
||||
def test_codex_transport_native_codex_still_replays_reasoning_in_input():
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ with ``APIConnectionError('Connection error.')`` whose cause was
|
|||
That is the exact scenario this test reproduces at object level without a
|
||||
network, so it runs in CI on every PR.
|
||||
"""
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
|
@ -186,3 +187,32 @@ def test_replace_primary_openai_client_survives_repeated_rebuilds():
|
|||
"Some _create_openai_client calls returned the same object across "
|
||||
"a teardown — rebuild is not producing fresh clients"
|
||||
)
|
||||
|
||||
|
||||
def test_force_close_tcp_sockets_descends_httpcore_1_connection_wrapper():
|
||||
"""httpcore 1.x stores the real stream below conn._connection."""
|
||||
from agent.agent_runtime_helpers import force_close_tcp_sockets
|
||||
|
||||
class FakeSocket:
|
||||
def __init__(self):
|
||||
self.shutdown_calls = 0
|
||||
self.close_calls = 0
|
||||
|
||||
def shutdown(self, _how):
|
||||
self.shutdown_calls += 1
|
||||
|
||||
def close(self):
|
||||
self.close_calls += 1
|
||||
|
||||
sock = FakeSocket()
|
||||
stream = SimpleNamespace(_sock=sock)
|
||||
http11 = SimpleNamespace(_network_stream=stream)
|
||||
pool_entry = SimpleNamespace(_connection=http11)
|
||||
pool = SimpleNamespace(_connections=[pool_entry])
|
||||
transport = SimpleNamespace(_pool=pool)
|
||||
http_client = SimpleNamespace(_transport=transport)
|
||||
openai_client = SimpleNamespace(_client=http_client)
|
||||
|
||||
assert force_close_tcp_sockets(openai_client) == 1
|
||||
assert sock.shutdown_calls == 1
|
||||
assert sock.close_calls == 1
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
|
@ -64,6 +65,7 @@ def _build_agent(shared_client=None):
|
|||
agent.stream_delta_callback = None
|
||||
agent._stream_callback = None
|
||||
agent.reasoning_callback = None
|
||||
agent.status_callback = None
|
||||
return agent
|
||||
|
||||
|
||||
|
|
@ -93,6 +95,24 @@ def test_retry_after_api_connection_error_recreates_request_client(monkeypatch):
|
|||
assert second_request.close_calls >= 1
|
||||
|
||||
|
||||
def test_stale_non_stream_close_is_single_owner(monkeypatch):
|
||||
def slow_responder(**kwargs):
|
||||
time.sleep(0.1)
|
||||
raise _connection_error()
|
||||
|
||||
request_client = FakeRequestClient(slow_responder)
|
||||
factory = OpenAIFactory([request_client])
|
||||
monkeypatch.setattr(run_agent, "OpenAI", factory)
|
||||
|
||||
agent = _build_agent()
|
||||
agent._compute_non_stream_stale_timeout = lambda _messages: 0.01
|
||||
|
||||
with pytest.raises(APIConnectionError):
|
||||
agent._interruptible_api_call({"model": agent.model, "messages": []})
|
||||
|
||||
assert request_client.close_calls == 1
|
||||
|
||||
|
||||
def test_closed_shared_client_is_recreated_before_request(monkeypatch):
|
||||
stale_shared = FakeSharedClient(lambda **kwargs: (_ for _ in ()).throw(AssertionError("stale shared client used")))
|
||||
stale_shared._client.is_closed = True
|
||||
|
|
|
|||
|
|
@ -168,3 +168,43 @@ class TestModelSupportsVision:
|
|||
agent = _make_agent()
|
||||
with patch("agent.models_dev.get_model_capabilities", side_effect=RuntimeError("boom")):
|
||||
assert agent._model_supports_vision() is False
|
||||
|
||||
def test_top_level_model_override_wins(self):
|
||||
agent = _make_agent()
|
||||
agent.provider = "custom"
|
||||
agent.model = "my-llava"
|
||||
with patch("hermes_cli.config.load_config", return_value={"model": {"supports_vision": True}}), \
|
||||
patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert agent._model_supports_vision() is True
|
||||
|
||||
def test_per_provider_per_model_override_wins(self):
|
||||
agent = _make_agent()
|
||||
agent.provider = "custom"
|
||||
agent.model = "my-llava"
|
||||
cfg = {"providers": {"custom": {"models": {"my-llava": {"supports_vision": True}}}}}
|
||||
with patch("hermes_cli.config.load_config", return_value=cfg), \
|
||||
patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert agent._model_supports_vision() is True
|
||||
|
||||
def test_named_custom_provider_resolved_via_config_provider(self):
|
||||
# Named custom providers get runtime self.provider rewritten to
|
||||
# "custom" while the config keeps the original name under
|
||||
# model.provider. The override must still resolve.
|
||||
agent = _make_agent()
|
||||
agent.provider = "custom"
|
||||
agent.model = "my-llava"
|
||||
cfg = {
|
||||
"model": {"provider": "my-vllm", "default": "my-llava"},
|
||||
"providers": {"my-vllm": {"models": {"my-llava": {"supports_vision": True}}}},
|
||||
}
|
||||
with patch("hermes_cli.config.load_config", return_value=cfg), \
|
||||
patch("agent.models_dev.get_model_capabilities", return_value=None):
|
||||
assert agent._model_supports_vision() is True
|
||||
|
||||
def test_override_false_disables_vision_for_models_dev_models(self):
|
||||
agent = _make_agent()
|
||||
fake_caps = MagicMock()
|
||||
fake_caps.supports_vision = True
|
||||
with patch("hermes_cli.config.load_config", return_value={"model": {"supports_vision": False}}), \
|
||||
patch("agent.models_dev.get_model_capabilities", return_value=fake_caps):
|
||||
assert agent._model_supports_vision() is False
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from hermes_constants import (
|
|||
get_default_hermes_root,
|
||||
is_container,
|
||||
parse_reasoning_effort,
|
||||
secure_parent_dir,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -171,3 +172,95 @@ class TestParseReasoningEffort:
|
|||
"""
|
||||
documented = {"minimal", "low", "medium", "high", "xhigh"}
|
||||
assert documented.issubset(set(VALID_REASONING_EFFORTS))
|
||||
|
||||
|
||||
class TestSecureParentDir:
|
||||
"""Tests for secure_parent_dir() — prevents chmod on / or top-level dirs."""
|
||||
|
||||
def test_safe_path_calls_chmod(self, tmp_path, monkeypatch):
|
||||
"""Normal nested path (depth >= 3) should call os.chmod."""
|
||||
safe_dir = tmp_path / "home" / "user" / ".hermes"
|
||||
safe_dir.mkdir(parents=True)
|
||||
target = safe_dir / "auth.json"
|
||||
target.touch()
|
||||
|
||||
called_with = []
|
||||
monkeypatch.setattr(os, "chmod", lambda p, m: called_with.append((str(p), m)))
|
||||
|
||||
secure_parent_dir(target)
|
||||
assert len(called_with) == 1
|
||||
assert called_with[0] == (str(safe_dir), 0o700)
|
||||
|
||||
def test_root_dir_skipped(self, monkeypatch):
|
||||
"""Parent resolving to / must NOT be chmod'd."""
|
||||
called_with = []
|
||||
monkeypatch.setattr(os, "chmod", lambda p, m: called_with.append((str(p), m)))
|
||||
|
||||
# Path("/foo").parent == Path("/")
|
||||
secure_parent_dir(Path("/foo"))
|
||||
assert called_with == []
|
||||
|
||||
def test_top_level_dir_skipped(self, monkeypatch):
|
||||
"""Parent resolving to a top-level dir (depth 2) must NOT be chmod'd."""
|
||||
called_with = []
|
||||
monkeypatch.setattr(os, "chmod", lambda p, m: called_with.append((str(p), m)))
|
||||
|
||||
# Path("/usr/foo").parent == Path("/usr") — depth 2
|
||||
secure_parent_dir(Path("/usr/foo"))
|
||||
assert called_with == []
|
||||
|
||||
def test_two_component_path_skipped(self, monkeypatch):
|
||||
"""Parent with < 3 resolved parts must NOT be chmod'd.
|
||||
|
||||
Uses monkeypatch to avoid macOS firmlink resolution of /home.
|
||||
"""
|
||||
called_with = []
|
||||
monkeypatch.setattr(os, "chmod", lambda p, m: called_with.append((str(p), m)))
|
||||
|
||||
# Mock Path.resolve to return a short path regardless of OS quirks
|
||||
original_resolve = Path.resolve
|
||||
def mock_resolve(self):
|
||||
if str(self) == "/x/y":
|
||||
return Path("/x")
|
||||
return original_resolve(self)
|
||||
monkeypatch.setattr(Path, "resolve", mock_resolve)
|
||||
|
||||
secure_parent_dir(Path("/x/y"))
|
||||
assert called_with == []
|
||||
|
||||
def test_oserror_suppressed(self, tmp_path, monkeypatch):
|
||||
"""OSError from chmod should be silently caught."""
|
||||
safe_dir = tmp_path / "a" / "b" / "c"
|
||||
safe_dir.mkdir(parents=True)
|
||||
target = safe_dir / "file.json"
|
||||
target.touch()
|
||||
|
||||
def raise_oserror(p, m):
|
||||
raise OSError("permission denied")
|
||||
|
||||
monkeypatch.setattr(os, "chmod", raise_oserror)
|
||||
# Should not raise
|
||||
secure_parent_dir(target)
|
||||
|
||||
def test_symlink_resolved(self, tmp_path, monkeypatch):
|
||||
"""Symlinks should be resolved before checking depth."""
|
||||
real_dir = tmp_path / "a" / "b"
|
||||
real_dir.mkdir(parents=True)
|
||||
target = real_dir / "file.json"
|
||||
target.touch()
|
||||
|
||||
# Create a symlink with fewer path components
|
||||
link = tmp_path / "link"
|
||||
link.symlink_to(real_dir)
|
||||
link_target = link / "file.json"
|
||||
|
||||
called_with = []
|
||||
monkeypatch.setattr(os, "chmod", lambda p, m: called_with.append((str(p), m)))
|
||||
|
||||
# Even though /tmp/link has only 3 parts, the resolved path has 4
|
||||
# The resolved parent (real_dir) has depth 4, so it should be chmod'd
|
||||
secure_parent_dir(link_target)
|
||||
assert len(called_with) == 1
|
||||
assert called_with[0] == (str(real_dir), 0o700)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -316,6 +316,42 @@ class TestMessageStorage:
|
|||
assert conv[0] == {"role": "user", "content": "Hello"}
|
||||
assert conv[1] == {"role": "assistant", "content": "Hi!"}
|
||||
|
||||
def test_platform_message_id_round_trips(self, db):
|
||||
"""Platform-side message ids (yuanbao msg_id, telegram update_id, …)
|
||||
survive append → get_messages_as_conversation under the
|
||||
``message_id`` key so platform recall flows can match by exact id."""
|
||||
db.create_session(session_id="s_pmi", source="yuanbao")
|
||||
db.append_message(
|
||||
"s_pmi",
|
||||
role="user",
|
||||
content="hi",
|
||||
platform_message_id="abc-123",
|
||||
)
|
||||
db.append_message("s_pmi", role="assistant", content="hello")
|
||||
|
||||
conv = db.get_messages_as_conversation("s_pmi")
|
||||
user_msg = next(m for m in conv if m["role"] == "user")
|
||||
assistant_msg = next(m for m in conv if m["role"] == "assistant")
|
||||
assert user_msg.get("message_id") == "abc-123"
|
||||
# Assistant row had no platform id — must not gain one spuriously.
|
||||
assert "message_id" not in assistant_msg
|
||||
|
||||
def test_replace_messages_preserves_platform_message_id(self, db):
|
||||
"""``rewrite_transcript`` (which goes through replace_messages) must
|
||||
keep the platform_message_id round-trip working for /retry, /undo,
|
||||
/compress and yuanbao's recall rewrite path."""
|
||||
db.create_session(session_id="s_rep", source="yuanbao")
|
||||
db.replace_messages(
|
||||
"s_rep",
|
||||
[
|
||||
{"role": "user", "content": "x", "message_id": "ext-1"},
|
||||
{"role": "assistant", "content": "y"},
|
||||
],
|
||||
)
|
||||
conv = db.get_messages_as_conversation("s_rep")
|
||||
assert next(m for m in conv if m["role"] == "user").get("message_id") == "ext-1"
|
||||
assert "message_id" not in next(m for m in conv if m["role"] == "assistant")
|
||||
|
||||
def test_get_messages_as_conversation_includes_ancestor_chain(self, db):
|
||||
db.create_session("root", "tui")
|
||||
db.append_message("root", role="user", content="first prompt")
|
||||
|
|
@ -1462,9 +1498,10 @@ class TestSchemaInit:
|
|||
assert "schema_version" in tables
|
||||
|
||||
def test_schema_version(self, db):
|
||||
from hermes_state import SCHEMA_VERSION
|
||||
cursor = db._conn.execute("SELECT version FROM schema_version")
|
||||
version = cursor.fetchone()[0]
|
||||
assert version == 11
|
||||
assert version == SCHEMA_VERSION
|
||||
|
||||
def test_title_column_exists(self, db):
|
||||
"""Verify the title column was created in the sessions table."""
|
||||
|
|
@ -1760,8 +1797,9 @@ class TestSchemaInit:
|
|||
migrated_db = SessionDB(db_path=db_path)
|
||||
|
||||
# Verify migration
|
||||
from hermes_state import SCHEMA_VERSION
|
||||
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
|
||||
assert cursor.fetchone()[0] == 11
|
||||
assert cursor.fetchone()[0] == SCHEMA_VERSION
|
||||
|
||||
# Verify title column exists and is NULL for existing sessions
|
||||
session = migrated_db.get_session("existing")
|
||||
|
|
@ -2970,11 +3008,12 @@ class TestFTS5ToolCallMigration:
|
|||
assert len(session_db.search_messages("LEGACYARG")) == 1, \
|
||||
"v11 migration must backfill tool_calls JSON into FTS"
|
||||
# schema_version bumped
|
||||
from hermes_state import SCHEMA_VERSION
|
||||
row = session_db._conn.execute(
|
||||
"SELECT version FROM schema_version LIMIT 1"
|
||||
).fetchone()
|
||||
version = row["version"] if hasattr(row, "keys") else row[0]
|
||||
assert version == 11
|
||||
assert version == SCHEMA_VERSION
|
||||
finally:
|
||||
session_db.close()
|
||||
|
||||
|
|
|
|||
187
tests/test_run_tests_parallel.py
Normal file
187
tests/test_run_tests_parallel.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Verify scripts/run_tests_parallel.py kills test-spawned grandchildren.
|
||||
|
||||
Setup
|
||||
-----
|
||||
A test in this file spawns a long-lived Python grandchild that writes
|
||||
its PID + a nonce to a tempfile, then exits without cleaning up.
|
||||
With the old ``subprocess.run`` runner, that grandchild would orphan
|
||||
and outlive the test (and the whole runner). With the current Popen +
|
||||
``start_new_session`` + ``_kill_tree`` runner, the grandchild gets
|
||||
SIGKILL'd via process-group kill when its file's pytest exits.
|
||||
|
||||
The leaker test always passes — its only job is to spawn a grandchild
|
||||
and walk away. The verifier runs the runner over the leaker file in a
|
||||
subprocess, then waits for the grandchild PID to disappear from the
|
||||
kernel's process table.
|
||||
|
||||
POSIX-only: Windows has its own grandchild lifecycle (no shared session,
|
||||
``taskkill /F /T`` semantics). Marked accordingly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Both tests share the same handoff file: the leaker writes here, the
|
||||
# verifier reads here. We park it in $TMPDIR with a unique-per-run name
|
||||
# so concurrent invocations of the suite don't clobber each other.
|
||||
_HANDOFF_DIR = Path(os.environ.get("TMPDIR", "/tmp")) / "hermes-isolation-probe"
|
||||
_HANDOFF_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def _handoff_path_for(nonce: str) -> Path:
|
||||
return _HANDOFF_DIR / f"grandchild-{nonce}.json"
|
||||
|
||||
|
||||
def _pid_alive(pid: int) -> bool:
|
||||
"""POSIX: send signal 0 to probe whether ``pid`` is still alive.
|
||||
|
||||
``os.kill(pid, 0)`` raises ``ProcessLookupError`` if the process is
|
||||
gone, ``PermissionError`` if it exists but we can't signal it
|
||||
(someone else's pid). We treat PermissionError as "alive" because
|
||||
the process exists and that's all we need to know.
|
||||
"""
|
||||
if sys.platform == "win32": # pragma: no cover — POSIX-only test
|
||||
# On Windows we'd use OpenProcess + GetExitCodeProcess; this
|
||||
# test is skipped on Windows so the path is unreachable.
|
||||
raise RuntimeError("_pid_alive POSIX-only")
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
except PermissionError:
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="POSIX-only probe")
|
||||
@pytest.mark.live_system_guard_bypass
|
||||
def test_grandchild_leak_is_killed_by_runner(tmp_path: Path) -> None:
|
||||
"""Run the parallel runner over a probe file and verify cleanup.
|
||||
|
||||
1. Materialize a probe file that spawns a long-lived grandchild and
|
||||
writes its PID to disk before exiting.
|
||||
2. Invoke ``scripts/run_tests_parallel.py`` against the probe file.
|
||||
3. Wait for the grandchild PID to vanish (poll for ~5s).
|
||||
4. Assert the runner exited cleanly AND the grandchild is dead.
|
||||
"""
|
||||
repo_root = Path(__file__).resolve().parent.parent
|
||||
runner = repo_root / "scripts" / "run_tests_parallel.py"
|
||||
assert runner.exists(), f"runner missing at {runner}"
|
||||
|
||||
# Probe lives in a temp dir, NOT under tests/, so the regular suite
|
||||
# never picks it up — only our explicit invocation does.
|
||||
probe_dir = tmp_path / "probe"
|
||||
probe_dir.mkdir()
|
||||
probe = probe_dir / "test_probe_leaker.py"
|
||||
nonce = f"{os.getpid()}-{int(time.time() * 1000)}"
|
||||
handoff = _handoff_path_for(nonce)
|
||||
if handoff.exists():
|
||||
handoff.unlink()
|
||||
|
||||
probe_src = textwrap.dedent(f"""
|
||||
import json, os, subprocess, sys, time
|
||||
from pathlib import Path
|
||||
|
||||
HANDOFF = Path({str(handoff)!r})
|
||||
|
||||
def test_spawns_grandchild_and_walks_away():
|
||||
# Long-lived grandchild: detached, ignores SIGTERM (we want
|
||||
# SIGKILL or process-group kill to be the only thing that
|
||||
# works, simulating a misbehaving server).
|
||||
child = subprocess.Popen(
|
||||
[
|
||||
sys.executable, "-c",
|
||||
"import os, signal, sys, time; "
|
||||
"signal.signal(signal.SIGTERM, signal.SIG_IGN); "
|
||||
"sys.stdout.write(f'gc-pgid={{os.getpgid(0)}} gc-pid={{os.getpid()}}\\\\n'); "
|
||||
"sys.stdout.flush(); "
|
||||
"time.sleep(600)",
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
# IMPORTANT: do NOT pass start_new_session here. We want
|
||||
# the grandchild to inherit the pytest subprocess's
|
||||
# process group, so when the runner kills the group the
|
||||
# grandchild dies too.
|
||||
)
|
||||
# Read the first line so we can record gc's pgid in the
|
||||
# handoff, then walk away — don't close the pipe (would
|
||||
# signal EOF and let the child see SIGPIPE on next write).
|
||||
first_line = child.stdout.readline().decode().strip()
|
||||
HANDOFF.write_text(json.dumps({{
|
||||
"pid": child.pid,
|
||||
"diag": first_line,
|
||||
"test_pid": os.getpid(),
|
||||
"test_pgid": os.getpgid(0),
|
||||
}}))
|
||||
assert child.pid > 0
|
||||
""").strip()
|
||||
probe.write_text(probe_src + "\n")
|
||||
|
||||
# Run the parallel runner against just the probe file. The runner
|
||||
# discovers under ``tests/`` by default, so we override via --paths.
|
||||
proc = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
str(runner),
|
||||
"--paths",
|
||||
str(probe_dir),
|
||||
"-j",
|
||||
"1",
|
||||
# Tight per-file timeout: the probe finishes in <1s, no
|
||||
# need for 10min.
|
||||
"--file-timeout",
|
||||
"30",
|
||||
],
|
||||
cwd=repo_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
assert handoff.exists(), (
|
||||
f"probe never wrote handoff file; runner output:\n{proc.stdout}"
|
||||
)
|
||||
handoff_data = json.loads(handoff.read_text())
|
||||
grandchild_pid = handoff_data["pid"]
|
||||
diag = handoff_data.get("diag", "(no diag)")
|
||||
test_pid = handoff_data.get("test_pid")
|
||||
test_pgid = handoff_data.get("test_pgid")
|
||||
handoff.unlink()
|
||||
|
||||
# The runner must have exited cleanly (probe test passes).
|
||||
assert proc.returncode == 0, (
|
||||
f"runner exited {proc.returncode}; output:\n{proc.stdout}"
|
||||
)
|
||||
|
||||
# The grandchild must be gone. Poll for a bit because process-group
|
||||
# SIGKILL + reaping isn't synchronous; on a loaded box it can take
|
||||
# a beat.
|
||||
deadline = time.monotonic() + 5.0
|
||||
while time.monotonic() < deadline:
|
||||
if not _pid_alive(grandchild_pid):
|
||||
break
|
||||
time.sleep(0.05)
|
||||
else:
|
||||
# Test cleanup: kill the leaked grandchild ourselves so a
|
||||
# FAILED assertion doesn't leave a sleep(600) running.
|
||||
try:
|
||||
os.kill(grandchild_pid, 9)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
pytest.fail(
|
||||
f"grandchild PID {grandchild_pid} survived runner exit; "
|
||||
f"diag={diag!r} test_pid={test_pid} test_pgid={test_pgid}; "
|
||||
f"runner output:\n{proc.stdout}"
|
||||
)
|
||||
50
tests/tools/conftest.py
Normal file
50
tests/tools/conftest.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Shared fixtures for tests/tools/ web-provider tests.
|
||||
|
||||
Per-file subprocess isolation means each test file gets a fresh interpreter,
|
||||
so module-level state (like the web-search-provider registry) is empty when
|
||||
a file starts. The ``web_registry_populated`` fixture registers all bundled
|
||||
providers before each test and resets the registry afterwards — tests that
|
||||
depend on the registry being populated should use it explicitly or via
|
||||
``@pytest.mark.usefixtures("web_registry_populated")``.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def register_all_web_providers():
|
||||
"""Register all bundled web-search providers into the global registry.
|
||||
|
||||
This is the single source of truth for the provider list used by
|
||||
test classes that need the registry populated for dispatch checks.
|
||||
"""
|
||||
from agent.web_search_registry import register_provider, _reset_for_tests
|
||||
from plugins.web.brave_free.provider import BraveFreeWebSearchProvider
|
||||
from plugins.web.ddgs.provider import DDGSWebSearchProvider
|
||||
from plugins.web.exa.provider import ExaWebSearchProvider
|
||||
from plugins.web.firecrawl.provider import FirecrawlWebSearchProvider
|
||||
from plugins.web.parallel.provider import ParallelWebSearchProvider
|
||||
from plugins.web.searxng.provider import SearXNGWebSearchProvider
|
||||
from plugins.web.tavily.provider import TavilyWebSearchProvider
|
||||
from plugins.web.xai.provider import XAIWebSearchProvider
|
||||
|
||||
_reset_for_tests()
|
||||
for cls in (
|
||||
BraveFreeWebSearchProvider,
|
||||
DDGSWebSearchProvider,
|
||||
ExaWebSearchProvider,
|
||||
FirecrawlWebSearchProvider,
|
||||
ParallelWebSearchProvider,
|
||||
SearXNGWebSearchProvider,
|
||||
TavilyWebSearchProvider,
|
||||
XAIWebSearchProvider,
|
||||
):
|
||||
register_provider(cls())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def web_registry_populated():
|
||||
"""Populate the web-search-provider registry for one test, then reset."""
|
||||
register_all_web_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
|
@ -22,18 +22,28 @@ from tools.approval import (
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_session(monkeypatch):
|
||||
"""Give each test a fresh session_key and clean approval-state."""
|
||||
def isolated_session(monkeypatch, tmp_path):
|
||||
"""Give each test a fresh session_key, clean approval-state, and isolated
|
||||
HERMES_HOME so the real user's command_allowlist doesn't leak in."""
|
||||
import tools.approval as _am
|
||||
|
||||
session_key = "test:session:approval_hooks"
|
||||
token = set_current_session_key(session_key)
|
||||
monkeypatch.setenv("HERMES_SESSION_KEY", session_key)
|
||||
# Make sure we don't skip guards via yolo / approvals.mode=off
|
||||
monkeypatch.delenv("HERMES_YOLO_MODE", raising=False)
|
||||
# Isolate from the real user's permanent allowlist + session state
|
||||
_saved_permanent = _am._permanent_approved.copy()
|
||||
_saved_session = {k: v.copy() for k, v in _am._session_approved.items()}
|
||||
_am._permanent_approved.clear()
|
||||
_am._session_approved.clear()
|
||||
try:
|
||||
yield session_key
|
||||
finally:
|
||||
_am._permanent_approved.update(_saved_permanent)
|
||||
_am._session_approved.update(_saved_session)
|
||||
try:
|
||||
approval_module._approval_session_key.reset(token)
|
||||
_am._approval_session_key.reset(token)
|
||||
except Exception:
|
||||
pass
|
||||
clear_session(session_key)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def _find_chrome() -> str:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def chrome_cdp(worker_id):
|
||||
def chrome_cdp(request):
|
||||
"""Start a headless Chrome with --remote-debugging-port, yield its WS URL.
|
||||
|
||||
Uses a unique port per xdist worker to avoid cross-worker collisions.
|
||||
|
|
@ -51,6 +51,9 @@ def chrome_cdp(worker_id):
|
|||
import socket
|
||||
|
||||
# xdist worker_id is "master" in single-process mode or "gw0".."gwN" otherwise.
|
||||
# Under subprocess-per-file isolation there's no xdist, so we fall back
|
||||
# to "master" via the session-scoped fixture below.
|
||||
worker_id = request.getfixturevalue("worker_id") if "worker_id" in request.fixturenames else "master"
|
||||
if worker_id == "master":
|
||||
port_offset = 0
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1089,9 +1089,17 @@ class Test403Enrichment:
|
|||
class TestModelToolsIntegration:
|
||||
def setup_method(self):
|
||||
_reset_capability_cache()
|
||||
from model_tools import _clear_tool_defs_cache
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
_clear_tool_defs_cache()
|
||||
invalidate_check_fn_cache()
|
||||
|
||||
def teardown_method(self):
|
||||
_reset_capability_cache()
|
||||
from model_tools import _clear_tool_defs_cache
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
_clear_tool_defs_cache()
|
||||
invalidate_check_fn_cache()
|
||||
|
||||
@patch("tools.discord_tool._discord_request")
|
||||
def test_discord_admin_schema_rebuilt_by_get_tool_definitions(
|
||||
|
|
|
|||
|
|
@ -501,16 +501,18 @@ class TestRegistration:
|
|||
|
||||
def test_check_fn_gates_availability(self, monkeypatch):
|
||||
"""Registry should exclude HA tools when HASS_TOKEN is not set."""
|
||||
from tools.registry import registry
|
||||
from tools.registry import invalidate_check_fn_cache, registry
|
||||
|
||||
monkeypatch.delenv("HASS_TOKEN", raising=False)
|
||||
invalidate_check_fn_cache()
|
||||
defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"})
|
||||
assert len(defs) == 0
|
||||
|
||||
def test_check_fn_includes_when_token_set(self, monkeypatch):
|
||||
"""Registry should include HA tools when HASS_TOKEN is set."""
|
||||
from tools.registry import registry
|
||||
from tools.registry import invalidate_check_fn_cache, registry
|
||||
|
||||
monkeypatch.setenv("HASS_TOKEN", "test-token")
|
||||
invalidate_check_fn_cache()
|
||||
defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"})
|
||||
assert len(defs) == 3
|
||||
|
|
|
|||
|
|
@ -1093,6 +1093,11 @@ def test_kanban_guidance_not_in_normal_prompt(monkeypatch, tmp_path):
|
|||
from pathlib import Path as _P
|
||||
monkeypatch.setattr(_P, "home", lambda: tmp_path)
|
||||
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
from model_tools import _clear_tool_defs_cache
|
||||
invalidate_check_fn_cache()
|
||||
_clear_tool_defs_cache()
|
||||
|
||||
from run_agent import AIAgent
|
||||
a = AIAgent(
|
||||
api_key="test",
|
||||
|
|
@ -1116,6 +1121,11 @@ def test_kanban_guidance_in_worker_prompt(monkeypatch, tmp_path):
|
|||
from pathlib import Path as _P
|
||||
monkeypatch.setattr(_P, "home", lambda: tmp_path)
|
||||
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
from model_tools import _clear_tool_defs_cache
|
||||
invalidate_check_fn_cache()
|
||||
_clear_tool_defs_cache()
|
||||
|
||||
from run_agent import AIAgent
|
||||
a = AIAgent(
|
||||
api_key="test",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
# python-telegram-bot is an optional dep — skip the entire module when
|
||||
# it isn't installed (e.g. CI bare env). Tests that patch telegram.Bot
|
||||
# or call _send_telegram need it; tests for other platforms don't but
|
||||
# keeping the whole file consistent is simpler.
|
||||
_HAS_TELEGRAM = pytest.importorskip("telegram", reason="python-telegram-bot not installed") is not None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_signal_scheduler():
|
||||
|
|
|
|||
|
|
@ -1279,10 +1279,11 @@ class TestUnifiedSearchDedup:
|
|||
return src
|
||||
|
||||
def test_dedup_keeps_first_seen(self):
|
||||
# Same identifier from two sources — only the first (community) is kept when equal trust.
|
||||
s1 = SkillMeta(name="skill", description="from A", source="a",
|
||||
identifier="a/skill", trust_level="community")
|
||||
identifier="shared/skill", trust_level="community")
|
||||
s2 = SkillMeta(name="skill", description="from B", source="b",
|
||||
identifier="b/skill", trust_level="community")
|
||||
identifier="shared/skill", trust_level="community")
|
||||
src_a = self._make_source("a", [s1])
|
||||
src_b = self._make_source("b", [s2])
|
||||
results = unified_search("skill", [src_a, src_b])
|
||||
|
|
@ -1290,10 +1291,11 @@ class TestUnifiedSearchDedup:
|
|||
assert results[0].description == "from A"
|
||||
|
||||
def test_dedup_prefers_trusted_over_community(self):
|
||||
# Same identifier — trusted wins over community.
|
||||
community = SkillMeta(name="skill", description="community", source="a",
|
||||
identifier="a/skill", trust_level="community")
|
||||
identifier="shared/skill", trust_level="community")
|
||||
trusted = SkillMeta(name="skill", description="trusted", source="b",
|
||||
identifier="b/skill", trust_level="trusted")
|
||||
identifier="shared/skill", trust_level="trusted")
|
||||
src_a = self._make_source("a", [community])
|
||||
src_b = self._make_source("b", [trusted])
|
||||
results = unified_search("skill", [src_a, src_b])
|
||||
|
|
@ -1303,9 +1305,9 @@ class TestUnifiedSearchDedup:
|
|||
def test_dedup_prefers_builtin_over_trusted(self):
|
||||
"""Regression: builtin must not be overwritten by trusted."""
|
||||
builtin = SkillMeta(name="skill", description="builtin", source="a",
|
||||
identifier="a/skill", trust_level="builtin")
|
||||
identifier="shared/skill", trust_level="builtin")
|
||||
trusted = SkillMeta(name="skill", description="trusted", source="b",
|
||||
identifier="b/skill", trust_level="trusted")
|
||||
identifier="shared/skill", trust_level="trusted")
|
||||
src_a = self._make_source("a", [builtin])
|
||||
src_b = self._make_source("b", [trusted])
|
||||
results = unified_search("skill", [src_a, src_b])
|
||||
|
|
@ -1314,14 +1316,31 @@ class TestUnifiedSearchDedup:
|
|||
|
||||
def test_dedup_trusted_not_overwritten_by_community(self):
|
||||
trusted = SkillMeta(name="skill", description="trusted", source="a",
|
||||
identifier="a/skill", trust_level="trusted")
|
||||
identifier="shared/skill", trust_level="trusted")
|
||||
community = SkillMeta(name="skill", description="community", source="b",
|
||||
identifier="b/skill", trust_level="community")
|
||||
identifier="shared/skill", trust_level="community")
|
||||
src_a = self._make_source("a", [trusted])
|
||||
src_b = self._make_source("b", [community])
|
||||
results = unified_search("skill", [src_a, src_b])
|
||||
assert results[0].trust_level == "trusted"
|
||||
|
||||
def test_browse_sh_same_name_different_site_not_deduped(self):
|
||||
# Browse.sh skills from different hostnames share task names (e.g. "search-listings")
|
||||
# but have unique identifiers. They must NOT be collapsed into one result.
|
||||
airbnb = SkillMeta(
|
||||
name="search-listings", description="Airbnb search", source="browse-sh",
|
||||
identifier="browse-sh/airbnb.com/search-listings-ddgioa", trust_level="community",
|
||||
)
|
||||
booking = SkillMeta(
|
||||
name="search-listings", description="Booking.com search", source="browse-sh",
|
||||
identifier="browse-sh/booking.com/search-listings-xyzab", trust_level="community",
|
||||
)
|
||||
src = self._make_source("browse-sh", [airbnb, booking])
|
||||
results = unified_search("search-listings", [src])
|
||||
assert len(results) == 2, (
|
||||
"browse-sh skills with the same name but different sites must not be deduplicated"
|
||||
)
|
||||
|
||||
def test_source_filter(self):
|
||||
s1 = SkillMeta(name="s1", description="d", source="a",
|
||||
identifier="x", trust_level="community")
|
||||
|
|
|
|||
|
|
@ -2,11 +2,26 @@
|
|||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
|
||||
from model_tools import get_tool_definitions
|
||||
|
||||
terminal_tool_module = importlib.import_module("tools.terminal_tool")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_caches():
|
||||
"""Invalidate check_fn and tool-definitions caches before each test
|
||||
so that monkeypatched env vars / config take effect."""
|
||||
from tools.registry import invalidate_check_fn_cache
|
||||
from model_tools import _clear_tool_defs_cache
|
||||
invalidate_check_fn_cache()
|
||||
_clear_tool_defs_cache()
|
||||
yield
|
||||
invalidate_check_fn_cache()
|
||||
_clear_tool_defs_cache()
|
||||
|
||||
|
||||
class TestTerminalRequirements:
|
||||
def test_local_backend_requirements(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
|
|
|
|||
|
|
@ -95,7 +95,9 @@ def _invoke_tool(home, cfg: dict, args: dict) -> dict:
|
|||
if hasattr(cfg_mod, "_invalidate_load_config_cache"):
|
||||
cfg_mod._invalidate_load_config_cache()
|
||||
|
||||
from tools.registry import registry
|
||||
from tools.registry import discover_builtin_tools, registry
|
||||
if "video_generate" not in registry._tools:
|
||||
discover_builtin_tools()
|
||||
handler = registry._tools["video_generate"].handler
|
||||
return json.loads(handler(args))
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from typing import Any, Dict, List
|
|||
|
||||
import pytest
|
||||
|
||||
from tests.tools.conftest import register_all_web_providers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ABC enforcement
|
||||
|
|
@ -276,6 +278,15 @@ class TestUnconfiguredErrorEnvelopeParity:
|
|||
``result.get("error")`` detect the failure cleanly.
|
||||
"""
|
||||
|
||||
_register_providers = staticmethod(register_all_web_providers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _populate_web_registry(self):
|
||||
self._register_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
||||
def _clear_web_creds(self, monkeypatch):
|
||||
for k in (
|
||||
"BRAVE_SEARCH_API_KEY",
|
||||
|
|
|
|||
|
|
@ -15,6 +15,10 @@ from __future__ import annotations
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tools.conftest import register_all_web_providers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BraveFreeWebSearchProvider unit tests
|
||||
|
|
@ -239,6 +243,15 @@ class TestBraveFreeBackendWiring:
|
|||
|
||||
|
||||
class TestBraveFreeSearchOnlyErrors:
|
||||
_register_providers = staticmethod(register_all_web_providers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _populate_web_registry(self):
|
||||
self._register_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
||||
def test_web_extract_returns_search_only_error(self, monkeypatch):
|
||||
import asyncio
|
||||
from tools import web_tools
|
||||
|
|
@ -246,6 +259,7 @@ class TestBraveFreeSearchOnlyErrors:
|
|||
monkeypatch.setattr(web_tools, "_load_web_config", lambda: {"backend": "brave-free"})
|
||||
monkeypatch.setenv("BRAVE_SEARCH_API_KEY", "BSAkey123")
|
||||
monkeypatch.setattr(web_tools, "_is_tool_gateway_ready", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False, raising=False)
|
||||
|
||||
result_str = asyncio.get_event_loop().run_until_complete(
|
||||
|
|
@ -264,6 +278,8 @@ class TestBraveFreeSearchOnlyErrors:
|
|||
monkeypatch.setenv("BRAVE_SEARCH_API_KEY", "BSAkey123")
|
||||
monkeypatch.setattr(web_tools, "_is_tool_gateway_ready", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "check_firecrawl_api_key", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr(web_tools, "check_website_access", lambda url: None)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False, raising=False)
|
||||
|
||||
result_str = asyncio.get_event_loop().run_until_complete(
|
||||
|
|
|
|||
|
|
@ -14,6 +14,10 @@ import sys
|
|||
import types
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tools.conftest import register_all_web_providers
|
||||
|
||||
|
||||
def _install_fake_ddgs(monkeypatch, *, text_results=None, text_raises=None):
|
||||
"""Install a stub ``ddgs`` module in sys.modules for the duration of a test.
|
||||
|
|
@ -210,6 +214,15 @@ class TestDDGSBackendWiring:
|
|||
|
||||
|
||||
class TestDDGSSearchOnlyErrors:
|
||||
_register_providers = staticmethod(register_all_web_providers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _populate_web_registry(self):
|
||||
self._register_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
||||
def test_web_extract_returns_search_only_error(self, monkeypatch):
|
||||
import asyncio
|
||||
from tools import web_tools
|
||||
|
|
@ -217,6 +230,7 @@ class TestDDGSSearchOnlyErrors:
|
|||
monkeypatch.setattr(web_tools, "_load_web_config", lambda: {"backend": "ddgs"})
|
||||
monkeypatch.setattr(web_tools, "_ddgs_package_importable", lambda: True)
|
||||
monkeypatch.setattr(web_tools, "_is_tool_gateway_ready", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False, raising=False)
|
||||
|
||||
result_str = asyncio.get_event_loop().run_until_complete(
|
||||
|
|
@ -235,6 +249,8 @@ class TestDDGSSearchOnlyErrors:
|
|||
monkeypatch.setattr(web_tools, "_ddgs_package_importable", lambda: True)
|
||||
monkeypatch.setattr(web_tools, "_is_tool_gateway_ready", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "check_firecrawl_api_key", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr(web_tools, "check_website_access", lambda url: None)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False, raising=False)
|
||||
|
||||
result_str = asyncio.get_event_loop().run_until_complete(
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from tests.tools.conftest import register_all_web_providers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SearXNGWebSearchProvider unit tests
|
||||
|
|
@ -301,6 +303,15 @@ class TestCheckWebApiKey:
|
|||
class TestSearXNGOnlyExtractCrawlErrors:
|
||||
"""When searxng is the active backend, extract/crawl must return clear errors."""
|
||||
|
||||
_register_providers = staticmethod(register_all_web_providers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _populate_web_registry(self):
|
||||
self._register_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
||||
def test_web_crawl_searxng_returns_clear_error(self, monkeypatch):
|
||||
import asyncio
|
||||
from tools import web_tools
|
||||
|
|
@ -309,6 +320,8 @@ class TestSearXNGOnlyExtractCrawlErrors:
|
|||
monkeypatch.setenv("SEARXNG_URL", "http://localhost:8080")
|
||||
monkeypatch.setattr(web_tools, "_is_tool_gateway_ready", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "check_firecrawl_api_key", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr(web_tools, "check_website_access", lambda url: None)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False, raising=False)
|
||||
|
||||
import json
|
||||
|
|
@ -326,6 +339,7 @@ class TestSearXNGOnlyExtractCrawlErrors:
|
|||
monkeypatch.setattr(web_tools, "_load_web_config", lambda: {"backend": "searxng"})
|
||||
monkeypatch.setenv("SEARXNG_URL", "http://localhost:8080")
|
||||
monkeypatch.setattr(web_tools, "_is_tool_gateway_ready", lambda: False)
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False, raising=False)
|
||||
|
||||
import json
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ import asyncio
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from tests.tools.conftest import register_all_web_providers
|
||||
|
||||
|
||||
# ─── _tavily_request ─────────────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -163,6 +165,15 @@ class TestNormalizeTavilyDocuments:
|
|||
class TestWebSearchTavily:
|
||||
"""Test web_search_tool dispatch to Tavily."""
|
||||
|
||||
_register_providers = staticmethod(register_all_web_providers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _populate_web_registry(self):
|
||||
self._register_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
||||
def test_search_dispatches_to_tavily(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
|
|
@ -186,6 +197,15 @@ class TestWebSearchTavily:
|
|||
class TestWebExtractTavily:
|
||||
"""Test web_extract_tool dispatch to Tavily."""
|
||||
|
||||
_register_providers = staticmethod(register_all_web_providers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _populate_web_registry(self):
|
||||
self._register_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
||||
def test_extract_dispatches_to_tavily(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
|
|
@ -211,6 +231,15 @@ class TestWebExtractTavily:
|
|||
class TestWebCrawlTavily:
|
||||
"""Test web_crawl_tool dispatch to Tavily."""
|
||||
|
||||
_register_providers = staticmethod(register_all_web_providers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _populate_web_registry(self):
|
||||
self._register_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
||||
def test_crawl_dispatches_to_tavily(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ from pathlib import Path
|
|||
import pytest
|
||||
import yaml
|
||||
|
||||
from tests.tools.conftest import register_all_web_providers
|
||||
|
||||
from tools.website_policy import WebsitePolicyError, check_website_access, load_website_blocklist
|
||||
|
||||
|
||||
|
|
@ -347,40 +349,191 @@ def test_browser_navigate_allows_when_shared_file_missing(monkeypatch, tmp_path)
|
|||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_extract_short_circuits_blocked_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
class TestWebToolPolicy:
|
||||
"""Tests that exercise web_extract_tool / web_crawl_tool with website-policy gates.
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
# The per-URL website-policy gate moved into the firecrawl plugin's
|
||||
# extract() during the web-provider migration. Patch it at the new
|
||||
# location; the dispatcher-level gate (used by web_crawl_tool's
|
||||
# pre-flight) still lives on tools.web_tools.
|
||||
monkeypatch.setattr(
|
||||
firecrawl_provider,
|
||||
"check_website_access",
|
||||
lambda url: {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
firecrawl_provider,
|
||||
"_get_firecrawl_client",
|
||||
lambda: pytest.fail("firecrawl should not run for blocked URL"),
|
||||
)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
# Force the firecrawl plugin to be the active extract provider.
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
These tests need the bundled web providers to be registered in the
|
||||
agent.web_search_registry so the tool dispatchers can find an active
|
||||
provider. Without registration, the tools return an error dict that
|
||||
lacks a ``results`` key, causing ``KeyError``.
|
||||
"""
|
||||
|
||||
result = json.loads(await web_tools.web_extract_tool(["https://blocked.test"], use_llm_processing=False))
|
||||
_register_providers = staticmethod(register_all_web_providers)
|
||||
|
||||
assert result["results"][0]["url"] == "https://blocked.test"
|
||||
assert "Blocked by website policy" in result["results"][0]["error"]
|
||||
@pytest.fixture(autouse=True)
|
||||
def _populate_web_registry(self):
|
||||
self._register_providers()
|
||||
yield
|
||||
from agent.web_search_registry import _reset_for_tests
|
||||
_reset_for_tests()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_extract_short_circuits_blocked_url(self, monkeypatch):
|
||||
from tools import web_tools
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
# The per-URL website-policy gate moved into the firecrawl plugin's
|
||||
# extract() during the web-provider migration. Patch it at the new
|
||||
# location; the dispatcher-level gate (used by web_crawl_tool's
|
||||
# pre-flight) still lives on tools.web_tools.
|
||||
monkeypatch.setattr(
|
||||
firecrawl_provider,
|
||||
"check_website_access",
|
||||
lambda url: {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
firecrawl_provider,
|
||||
"_get_firecrawl_client",
|
||||
lambda: pytest.fail("firecrawl should not run for blocked URL"),
|
||||
)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
# Force the firecrawl plugin to be the active extract provider.
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
|
||||
result = json.loads(await web_tools.web_extract_tool(["https://blocked.test"], use_llm_processing=False))
|
||||
|
||||
assert result["results"][0]["url"] == "https://blocked.test"
|
||||
assert "Blocked by website policy" in result["results"][0]["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_extract_blocks_redirected_final_url(self, monkeypatch):
|
||||
from tools import web_tools
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test":
|
||||
return None
|
||||
if url == "https://blocked.test/final":
|
||||
return {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
}
|
||||
pytest.fail(f"unexpected URL checked: {url}")
|
||||
|
||||
class FakeFirecrawlClient:
|
||||
def scrape(self, url, formats):
|
||||
return {
|
||||
"markdown": "secret content",
|
||||
"metadata": {
|
||||
"title": "Redirected",
|
||||
"sourceURL": "https://blocked.test/final",
|
||||
},
|
||||
}
|
||||
|
||||
# After the web-provider migration, the per-URL gate + firecrawl client
|
||||
# live in the plugin. Patch both at the plugin location.
|
||||
monkeypatch.setattr(firecrawl_provider, "check_website_access", fake_check)
|
||||
monkeypatch.setattr(firecrawl_provider, "_get_firecrawl_client", lambda: FakeFirecrawlClient())
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
|
||||
result = json.loads(await web_tools.web_extract_tool(["https://allowed.test"], use_llm_processing=False))
|
||||
|
||||
assert result["results"][0]["url"] == "https://blocked.test/final"
|
||||
assert result["results"][0]["content"] == ""
|
||||
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_crawl_short_circuits_blocked_url(self, monkeypatch):
|
||||
from tools import web_tools
|
||||
|
||||
# web_crawl_tool checks for Firecrawl env before website policy
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
# The dispatcher-level (seed-URL) policy gate still lives on web_tools.
|
||||
# No per-page gate runs in this test because the dispatcher returns
|
||||
# immediately when the seed is blocked, before delegating to the plugin.
|
||||
monkeypatch.setattr(
|
||||
web_tools,
|
||||
"check_website_access",
|
||||
lambda url: {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
},
|
||||
)
|
||||
# If the dispatcher ever reaches the firecrawl plugin's crawl(), the test
|
||||
# fails — pin the plugin module's client lookup so we'd notice.
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
monkeypatch.setattr(
|
||||
firecrawl_provider,
|
||||
"_get_firecrawl_client",
|
||||
lambda: pytest.fail("firecrawl plugin should not run for blocked crawl URL"),
|
||||
)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
|
||||
result = json.loads(await web_tools.web_crawl_tool("https://blocked.test", use_llm_processing=False))
|
||||
|
||||
assert result["results"][0]["url"] == "https://blocked.test"
|
||||
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_crawl_blocks_redirected_final_url(self, monkeypatch):
|
||||
from tools import web_tools
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
|
||||
# Force the firecrawl plugin to be the active crawl provider.
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
|
||||
def fake_check(url):
|
||||
# Dispatcher seed-URL gate (web_tools.check_website_access call)
|
||||
# and plugin per-page gate (firecrawl_provider.check_website_access
|
||||
# call) both flow through this single fake_check.
|
||||
if url == "https://allowed.test":
|
||||
return None
|
||||
if url == "https://blocked.test/final":
|
||||
return {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
}
|
||||
pytest.fail(f"unexpected URL checked: {url}")
|
||||
|
||||
class FakeCrawlClient:
|
||||
def crawl(self, url, **kwargs):
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"markdown": "secret crawl content",
|
||||
"metadata": {
|
||||
"title": "Redirected crawl page",
|
||||
"sourceURL": "https://blocked.test/final",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# After PR #25182 follow-up: per-page policy gate lives in
|
||||
# plugins.web.firecrawl.provider.crawl(). Patch the gate + client at
|
||||
# the plugin location. The dispatcher-level (seed) gate also reads
|
||||
# web_tools.check_website_access — patch both.
|
||||
monkeypatch.setattr(web_tools, "check_website_access", fake_check)
|
||||
monkeypatch.setattr(firecrawl_provider, "check_website_access", fake_check)
|
||||
monkeypatch.setattr(firecrawl_provider, "_get_firecrawl_client", lambda: FakeCrawlClient())
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
|
||||
result = json.loads(await web_tools.web_crawl_tool("https://allowed.test", use_llm_processing=False))
|
||||
|
||||
assert result["results"][0]["content"] == ""
|
||||
assert result["results"][0]["error"] == "Blocked by website policy"
|
||||
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"
|
||||
|
||||
|
||||
def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypatch):
|
||||
|
|
@ -400,139 +553,3 @@ def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypat
|
|||
# With default path, errors are caught and fail open
|
||||
result = check_website_access("https://example.com")
|
||||
assert result is None # allowed, not crashed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_extract_blocks_redirected_final_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test":
|
||||
return None
|
||||
if url == "https://blocked.test/final":
|
||||
return {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
}
|
||||
pytest.fail(f"unexpected URL checked: {url}")
|
||||
|
||||
class FakeFirecrawlClient:
|
||||
def scrape(self, url, formats):
|
||||
return {
|
||||
"markdown": "secret content",
|
||||
"metadata": {
|
||||
"title": "Redirected",
|
||||
"sourceURL": "https://blocked.test/final",
|
||||
},
|
||||
}
|
||||
|
||||
# After the web-provider migration, the per-URL gate + firecrawl client
|
||||
# live in the plugin. Patch both at the plugin location.
|
||||
monkeypatch.setattr(firecrawl_provider, "check_website_access", fake_check)
|
||||
monkeypatch.setattr(firecrawl_provider, "_get_firecrawl_client", lambda: FakeFirecrawlClient())
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
|
||||
result = json.loads(await web_tools.web_extract_tool(["https://allowed.test"], use_llm_processing=False))
|
||||
|
||||
assert result["results"][0]["url"] == "https://blocked.test/final"
|
||||
assert result["results"][0]["content"] == ""
|
||||
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_crawl_short_circuits_blocked_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
|
||||
# web_crawl_tool checks for Firecrawl env before website policy
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
# The dispatcher-level (seed-URL) policy gate still lives on web_tools.
|
||||
# No per-page gate runs in this test because the dispatcher returns
|
||||
# immediately when the seed is blocked, before delegating to the plugin.
|
||||
monkeypatch.setattr(
|
||||
web_tools,
|
||||
"check_website_access",
|
||||
lambda url: {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
},
|
||||
)
|
||||
# If the dispatcher ever reaches the firecrawl plugin's crawl(), the test
|
||||
# fails — pin the plugin module's client lookup so we'd notice.
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
monkeypatch.setattr(
|
||||
firecrawl_provider,
|
||||
"_get_firecrawl_client",
|
||||
lambda: pytest.fail("firecrawl plugin should not run for blocked crawl URL"),
|
||||
)
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
|
||||
result = json.loads(await web_tools.web_crawl_tool("https://blocked.test", use_llm_processing=False))
|
||||
|
||||
assert result["results"][0]["url"] == "https://blocked.test"
|
||||
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_crawl_blocks_redirected_final_url(monkeypatch):
|
||||
from tools import web_tools
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
|
||||
# Force the firecrawl plugin to be the active crawl provider.
|
||||
monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key")
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
|
||||
def fake_check(url):
|
||||
# Dispatcher seed-URL gate (web_tools.check_website_access call)
|
||||
# and plugin per-page gate (firecrawl_provider.check_website_access
|
||||
# call) both flow through this single fake_check.
|
||||
if url == "https://allowed.test":
|
||||
return None
|
||||
if url == "https://blocked.test/final":
|
||||
return {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
}
|
||||
pytest.fail(f"unexpected URL checked: {url}")
|
||||
|
||||
class FakeCrawlClient:
|
||||
def crawl(self, url, **kwargs):
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"markdown": "secret crawl content",
|
||||
"metadata": {
|
||||
"title": "Redirected crawl page",
|
||||
"sourceURL": "https://blocked.test/final",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# After PR #25182 follow-up: per-page policy gate lives in
|
||||
# plugins.web.firecrawl.provider.crawl(). Patch the gate + client at
|
||||
# the plugin location. The dispatcher-level (seed) gate also reads
|
||||
# web_tools.check_website_access — patch both.
|
||||
monkeypatch.setattr(web_tools, "check_website_access", fake_check)
|
||||
monkeypatch.setattr(firecrawl_provider, "check_website_access", fake_check)
|
||||
monkeypatch.setattr(firecrawl_provider, "_get_firecrawl_client", lambda: FakeCrawlClient())
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
|
||||
result = json.loads(await web_tools.web_crawl_tool("https://allowed.test", use_llm_processing=False))
|
||||
|
||||
assert result["results"][0]["content"] == ""
|
||||
assert result["results"][0]["error"] == "Blocked by website policy"
|
||||
assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
"""Tests for _is_write_denied() — verifies deny list blocks sensitive paths on all platforms."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from tools.file_operations import _is_write_denied
|
||||
|
||||
|
|
@ -41,6 +43,31 @@ class TestWriteDenyExactPaths:
|
|||
path = str(get_hermes_home() / ".env")
|
||||
assert _is_write_denied(path) is True
|
||||
|
||||
def test_hermes_root_env_when_running_under_profile(self, tmp_path, monkeypatch):
|
||||
"""Top-level ``<root>/.env`` stays write-denied even when running under
|
||||
a profile (#15981).
|
||||
|
||||
Before the fix, ``build_write_denied_paths`` only added
|
||||
``<active_profile>/.env`` to the deny list, so the global
|
||||
``~/.hermes/.env`` (whose credentials are inherited by every profile)
|
||||
could be silently overwritten by ``write_file`` while a profile was
|
||||
active.
|
||||
"""
|
||||
root = tmp_path / "hermes_root"
|
||||
profile_home = root / "profiles" / "coder"
|
||||
profile_home.mkdir(parents=True)
|
||||
global_env = root / ".env"
|
||||
global_env.write_text("OPENAI_API_KEY=sk-real\n")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(profile_home))
|
||||
|
||||
# Sanity check: HERMES_HOME does point to the profile dir, not the root.
|
||||
from hermes_constants import get_hermes_home, get_default_hermes_root
|
||||
assert get_hermes_home() == profile_home
|
||||
assert get_default_hermes_root() == root
|
||||
|
||||
assert _is_write_denied(str(global_env)) is True
|
||||
|
||||
def test_shell_profiles(self):
|
||||
home = str(Path.home())
|
||||
for name in [".bashrc", ".zshrc", ".profile", ".bash_profile", ".zprofile"]:
|
||||
|
|
@ -72,8 +99,22 @@ class TestWriteDenyPrefixes:
|
|||
def test_sudoers_d_prefix(self):
|
||||
assert _is_write_denied("/etc/sudoers.d/custom") is True
|
||||
|
||||
def test_systemd_prefix(self):
|
||||
assert _is_write_denied("/etc/systemd/system/evil.service") is True
|
||||
def test_systemd_prefix(self, tmp_path):
|
||||
# On NixOS, /etc/systemd is a symlink into /nix/store, so
|
||||
# realpath() resolves it to a store path that doesn't match
|
||||
# the /etc/systemd/ prefix. Build a real directory tree so
|
||||
# realpath is a no-op and prefix matching works.
|
||||
fake_etc = tmp_path / "etc" / "systemd" / "system"
|
||||
fake_etc.mkdir(parents=True)
|
||||
target = str(fake_etc / "evil.service")
|
||||
# Patch the prefix builder to include our tmp_path prefix
|
||||
import agent.file_safety as _fs
|
||||
_orig = _fs.build_write_denied_prefixes
|
||||
_extra_prefix = str(tmp_path / "etc" / "systemd") + os.sep
|
||||
def _patched(home):
|
||||
return _orig(home) + [_extra_prefix]
|
||||
with patch.object(_fs, "build_write_denied_prefixes", _patched):
|
||||
assert _is_write_denied(target) is True
|
||||
|
||||
|
||||
class TestWriteAllowed:
|
||||
|
|
|
|||
|
|
@ -436,3 +436,290 @@ def test_x_search_registered_in_registry_with_check_fn():
|
|||
assert entry.check_fn.__name__ == "check_x_search_requirements"
|
||||
assert "XAI_API_KEY" in entry.requires_env
|
||||
assert entry.emoji == "🐦"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Date validation — fail fast before burning an API call on a window that
|
||||
# cannot possibly return X posts. xAI itself happily 200s with a fluff
|
||||
# answer when the range is malformed or pure-future, which is hard for
|
||||
# callers to distinguish from a real result.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _no_post_allowed(monkeypatch):
|
||||
"""Guard: any test that should fail before HTTP can hit this fence."""
|
||||
def _fail(*_, **__):
|
||||
raise AssertionError("requests.post must not be called — validation should reject first")
|
||||
|
||||
monkeypatch.setattr("requests.post", _fail)
|
||||
|
||||
|
||||
def test_x_search_rejects_malformed_from_date(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
_no_post_allowed(monkeypatch)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything", from_date="not-a-date"))
|
||||
|
||||
assert "from_date must be YYYY-MM-DD" in result["error"]
|
||||
|
||||
|
||||
def test_x_search_rejects_malformed_to_date(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
_no_post_allowed(monkeypatch)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything", to_date="2026/05/01"))
|
||||
|
||||
assert "to_date must be YYYY-MM-DD" in result["error"]
|
||||
|
||||
|
||||
def test_x_search_rejects_inverted_date_range(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
_no_post_allowed(monkeypatch)
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(
|
||||
query="anything",
|
||||
from_date="2026-05-10",
|
||||
to_date="2026-05-01",
|
||||
)
|
||||
)
|
||||
|
||||
assert "from_date (2026-05-10) must be on or before to_date (2026-05-01)" in result["error"]
|
||||
|
||||
|
||||
def test_x_search_rejects_future_from_date(monkeypatch):
|
||||
"""``from_date`` in the future can never match any post → reject."""
|
||||
import datetime as _dt
|
||||
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
_no_post_allowed(monkeypatch)
|
||||
|
||||
class _FrozenDateTime(_dt.datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return _dt.datetime(2026, 5, 21, 12, 0, 0, tzinfo=tz or _dt.timezone.utc)
|
||||
|
||||
monkeypatch.setattr("tools.x_search_tool.datetime", _FrozenDateTime)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything", from_date="2030-01-01"))
|
||||
|
||||
assert "from_date (2030-01-01) is in the future" in result["error"]
|
||||
|
||||
|
||||
def test_x_search_allows_future_to_date(monkeypatch):
|
||||
"""``to_date`` in the future is fine — caller may want posts as they arrive."""
|
||||
import datetime as _dt
|
||||
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
|
||||
class _FrozenDateTime(_dt.datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return _dt.datetime(2026, 5, 21, 12, 0, 0, tzinfo=tz or _dt.timezone.utc)
|
||||
|
||||
monkeypatch.setattr("tools.x_search_tool.datetime", _FrozenDateTime)
|
||||
|
||||
def _fake_post(url, headers=None, json=None, timeout=None):
|
||||
return _FakeResponse(
|
||||
{"output_text": "future to_date is allowed", "citations": []}
|
||||
)
|
||||
|
||||
monkeypatch.setattr("requests.post", _fake_post)
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(
|
||||
query="anything",
|
||||
from_date="2026-05-20",
|
||||
to_date="2030-01-01",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["answer"] == "future to_date is allowed"
|
||||
|
||||
|
||||
def test_x_search_accepts_today_as_from_date(monkeypatch):
|
||||
"""``from_date == today UTC`` is a valid edge case (today is past + present)."""
|
||||
import datetime as _dt
|
||||
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
|
||||
class _FrozenDateTime(_dt.datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return _dt.datetime(2026, 5, 21, 12, 0, 0, tzinfo=tz or _dt.timezone.utc)
|
||||
|
||||
monkeypatch.setattr("tools.x_search_tool.datetime", _FrozenDateTime)
|
||||
monkeypatch.setattr(
|
||||
"requests.post",
|
||||
lambda *a, **k: _FakeResponse({"output_text": "ok", "citations": []}),
|
||||
)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything", from_date="2026-05-21"))
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Degraded-result flag — distinguish citation-backed answers from
|
||||
# unsourced fluff when narrowing filters returned nothing.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_x_search_marks_degraded_when_handle_filter_returns_no_citations(monkeypatch):
|
||||
"""allowed_x_handles set + zero citations → degraded=True."""
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr(
|
||||
"requests.post",
|
||||
lambda *a, **k: _FakeResponse(
|
||||
{"output_text": "Generic encyclopedic answer with no citations.", "citations": []}
|
||||
),
|
||||
)
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(query="what has @ghostuser posted", allowed_x_handles=["ghostuser"])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["degraded"] is True
|
||||
assert "allowed_x_handles" in result["degraded_reason"]
|
||||
|
||||
|
||||
def test_x_search_marks_degraded_when_excluded_handles_and_no_citations(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr(
|
||||
"requests.post",
|
||||
lambda *a, **k: _FakeResponse({"output_text": "fluff", "citations": []}),
|
||||
)
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(query="anything", excluded_x_handles=["someuser"])
|
||||
)
|
||||
|
||||
assert result["degraded"] is True
|
||||
assert "excluded_x_handles" in result["degraded_reason"]
|
||||
|
||||
|
||||
def test_x_search_marks_degraded_when_date_range_and_no_citations(monkeypatch):
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr(
|
||||
"requests.post",
|
||||
lambda *a, **k: _FakeResponse({"output_text": "fluff", "citations": []}),
|
||||
)
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(
|
||||
query="anything",
|
||||
from_date="2026-04-01",
|
||||
to_date="2026-04-02",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["degraded"] is True
|
||||
assert "from_date" in result["degraded_reason"]
|
||||
assert "to_date" in result["degraded_reason"]
|
||||
|
||||
|
||||
def test_x_search_not_degraded_when_filter_returns_inline_citations(monkeypatch):
|
||||
"""A real citation from the inline annotations clears the degraded flag."""
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr(
|
||||
"requests.post",
|
||||
lambda *a, **k: _FakeResponse(
|
||||
{
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "Real post from xai.",
|
||||
"annotations": [
|
||||
{
|
||||
"type": "url_citation",
|
||||
"url": "https://x.com/xai/status/1",
|
||||
"title": "xAI post",
|
||||
"start_index": 0,
|
||||
"end_index": 4,
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(query="latest xAI post", allowed_x_handles=["xai"])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["degraded"] is False
|
||||
assert result["degraded_reason"] is None
|
||||
assert len(result["inline_citations"]) == 1
|
||||
|
||||
|
||||
def test_x_search_not_degraded_when_filter_returns_top_level_citations(monkeypatch):
|
||||
"""A real citation from xAI's top-level ``citations`` array also clears the flag."""
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr(
|
||||
"requests.post",
|
||||
lambda *a, **k: _FakeResponse(
|
||||
{
|
||||
"output_text": "Found discussion.",
|
||||
"citations": [{"url": "https://x.com/example/status/1", "title": "Example"}],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
result = json.loads(
|
||||
x_search_tool(query="anything", allowed_x_handles=["xai"])
|
||||
)
|
||||
|
||||
assert result["degraded"] is False
|
||||
assert result["degraded_reason"] is None
|
||||
|
||||
|
||||
def test_x_search_not_degraded_when_no_filters_active(monkeypatch):
|
||||
"""A broad query that returns no citations isn't necessarily degraded.
|
||||
|
||||
Without any narrowing filter, an empty-citations response is a generic
|
||||
unsourced answer, not a "filter miss". The caller can already tell from
|
||||
``inline_citations == []`` if they care.
|
||||
"""
|
||||
from tools.x_search_tool import x_search_tool
|
||||
|
||||
monkeypatch.setenv("XAI_API_KEY", "xai-test-key")
|
||||
monkeypatch.setattr(
|
||||
"requests.post",
|
||||
lambda *a, **k: _FakeResponse({"output_text": "broad answer", "citations": []}),
|
||||
)
|
||||
|
||||
result = json.loads(x_search_tool(query="anything"))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["degraded"] is False
|
||||
assert result["degraded_reason"] is None
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from hermes_constants import secure_parent_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -175,10 +176,8 @@ def _write_json(path: Path, data: dict) -> None:
|
|||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Tighten parent dir to 0o700 so siblings can't traverse to the creds.
|
||||
# No-op on Windows (POSIX mode bits aren't enforced); ignore failures.
|
||||
try:
|
||||
os.chmod(path.parent, 0o700)
|
||||
except OSError:
|
||||
pass
|
||||
# secure_parent_dir refuses to chmod / or top-level dirs (#25821).
|
||||
secure_parent_dir(path)
|
||||
# Per-process random suffix avoids collisions between concurrent
|
||||
# writers and stale leftovers from a prior crashed write.
|
||||
tmp = path.with_suffix(f".tmp.{os.getpid()}.{secrets.token_hex(4)}")
|
||||
|
|
|
|||
|
|
@ -379,14 +379,16 @@ class GitHubSource(SkillSource):
|
|||
logger.debug(f"Failed to search {tap['repo']}: {e}")
|
||||
continue
|
||||
|
||||
# Deduplicate by name, preferring higher trust levels
|
||||
# Deduplicate by identifier, preferring higher trust levels.
|
||||
# identifier is unique per skill; name is not (two configured taps can
|
||||
# publish skills with the same name but different identifiers).
|
||||
_trust_rank = {"builtin": 2, "trusted": 1, "community": 0}
|
||||
seen = {}
|
||||
for r in results:
|
||||
if r.name not in seen:
|
||||
seen[r.name] = r
|
||||
elif _trust_rank.get(r.trust_level, 0) > _trust_rank.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
if r.identifier not in seen:
|
||||
seen[r.identifier] = r
|
||||
elif _trust_rank.get(r.trust_level, 0) > _trust_rank.get(seen[r.identifier].trust_level, 0):
|
||||
seen[r.identifier] = r
|
||||
results = list(seen.values())
|
||||
|
||||
return results[:limit]
|
||||
|
|
@ -3425,14 +3427,17 @@ def unified_search(query: str, sources: List[SkillSource],
|
|||
overall_timeout=30,
|
||||
)
|
||||
|
||||
# Deduplicate by name, preferring higher trust levels
|
||||
# Deduplicate by identifier, preferring higher trust levels.
|
||||
# identifier is always unique per skill (e.g. "browse-sh/airbnb.com/search-listings-ddgioa").
|
||||
# Using name would incorrectly collapse browse-sh skills from different sites that share
|
||||
# the same task name (e.g. "search-listings" from Airbnb and Booking.com).
|
||||
_TRUST_RANK = {"builtin": 2, "trusted": 1, "community": 0}
|
||||
seen: Dict[str, SkillMeta] = {}
|
||||
for r in all_results:
|
||||
if r.name not in seen:
|
||||
seen[r.name] = r
|
||||
elif _TRUST_RANK.get(r.trust_level, 0) > _TRUST_RANK.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
if r.identifier not in seen:
|
||||
seen[r.identifier] = r
|
||||
elif _TRUST_RANK.get(r.trust_level, 0) > _TRUST_RANK.get(seen[r.identifier].trust_level, 0):
|
||||
seen[r.identifier] = r
|
||||
deduped = list(seen.values())
|
||||
|
||||
return deduped[:limit]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,24 @@ auto-refreshes the OAuth access token when it's within the refresh skew
|
|||
window, so a ``True`` from :func:`check_x_search_requirements` means the
|
||||
bearer is fetchable AND non-empty.
|
||||
|
||||
Defensive output
|
||||
----------------
|
||||
The tool surfaces two additional signals beyond xAI's raw response so callers
|
||||
can tell a real citation-backed answer from an unsourced one:
|
||||
|
||||
* ``from_date`` / ``to_date`` are validated client-side before the HTTP call.
|
||||
Malformed (non ``YYYY-MM-DD``), inverted (``from_date > to_date``), and
|
||||
pure-future ranges (``from_date`` later than today UTC) fail fast with a
|
||||
clear error instead of burning an API call. ``to_date`` in the future is
|
||||
still allowed so callers can legitimately request "from yesterday to
|
||||
tomorrow".
|
||||
* Successful responses carry ``degraded`` and ``degraded_reason`` fields.
|
||||
``degraded`` is ``True`` when any narrowing filter (handles or dates) was
|
||||
active AND xAI returned no citations in either the top-level ``citations``
|
||||
array or the inline ``url_citation`` annotations. In that case the
|
||||
``answer`` came from the model's own knowledge rather than the X index,
|
||||
and the caller should treat the result as unsourced.
|
||||
|
||||
Salvaged from PR #10786 (originally by @Jaaneek); credential resolution
|
||||
reworked to honor both auth modes per Teknium's design.
|
||||
"""
|
||||
|
|
@ -28,6 +46,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import date, datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
|
@ -136,6 +155,57 @@ def _normalize_handles(handles: Optional[List[str]], field_name: str) -> List[st
|
|||
return cleaned
|
||||
|
||||
|
||||
def _parse_iso_date(value: str, field_name: str) -> date:
|
||||
"""Parse a strict YYYY-MM-DD string into a ``date``.
|
||||
|
||||
xAI accepts any string in the ``from_date``/``to_date`` slots and silently
|
||||
returns an answer with no citations when the value is malformed or refers
|
||||
to a window where no posts can exist. That behavior burns a billable API
|
||||
call and produces a confident-sounding fluff answer that's hard for callers
|
||||
to distinguish from a real result. Validating client-side fails fast and
|
||||
gives the agent a clear error to act on.
|
||||
"""
|
||||
raw = value.strip()
|
||||
try:
|
||||
return datetime.strptime(raw, "%Y-%m-%d").date()
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"{field_name} must be YYYY-MM-DD (got {raw!r})"
|
||||
) from exc
|
||||
|
||||
|
||||
def _validate_date_range(from_date: str, to_date: str) -> None:
|
||||
"""Validate ``from_date`` / ``to_date`` before they reach xAI.
|
||||
|
||||
Rules:
|
||||
* Either field, if non-empty, must parse as ``YYYY-MM-DD``.
|
||||
* When both are set, ``from_date <= to_date``.
|
||||
* ``from_date`` must not be later than today UTC — no posts can exist
|
||||
in a window that hasn't started yet, so the call would be guaranteed
|
||||
to return zero citations. ``to_date`` in the future is allowed
|
||||
(callers may legitimately set "from yesterday to tomorrow").
|
||||
"""
|
||||
parsed_from: Optional[date] = None
|
||||
parsed_to: Optional[date] = None
|
||||
if from_date.strip():
|
||||
parsed_from = _parse_iso_date(from_date, "from_date")
|
||||
if to_date.strip():
|
||||
parsed_to = _parse_iso_date(to_date, "to_date")
|
||||
if parsed_from and parsed_to and parsed_from > parsed_to:
|
||||
raise ValueError(
|
||||
f"from_date ({parsed_from.isoformat()}) must be on or before "
|
||||
f"to_date ({parsed_to.isoformat()})"
|
||||
)
|
||||
if parsed_from is not None:
|
||||
today_utc = datetime.now(timezone.utc).date()
|
||||
if parsed_from > today_utc:
|
||||
raise ValueError(
|
||||
f"from_date ({parsed_from.isoformat()}) is in the future; "
|
||||
f"X Search only indexes past posts (today UTC is "
|
||||
f"{today_utc.isoformat()})"
|
||||
)
|
||||
|
||||
|
||||
def _extract_response_text(payload: Dict[str, Any]) -> str:
|
||||
output_text = str(payload.get("output_text") or "").strip()
|
||||
if output_text:
|
||||
|
|
@ -225,6 +295,11 @@ def x_search_tool(
|
|||
if allowed and excluded:
|
||||
return tool_error("allowed_x_handles and excluded_x_handles cannot be used together")
|
||||
|
||||
try:
|
||||
_validate_date_range(from_date, to_date)
|
||||
except ValueError as exc:
|
||||
return tool_error(str(exc))
|
||||
|
||||
tool_def: Dict[str, Any] = {"type": "x_search"}
|
||||
if allowed:
|
||||
tool_def["allowed_x_handles"] = allowed
|
||||
|
|
@ -299,6 +374,31 @@ def x_search_tool(
|
|||
citations = list(data.get("citations") or [])
|
||||
inline_citations = _extract_inline_citations(data)
|
||||
|
||||
# Degraded-result detection.
|
||||
#
|
||||
# xAI returns 200 OK with a synthesized answer even when its X index
|
||||
# has no posts matching the caller's narrowing filters. The answer
|
||||
# then comes from the model's training data, which is misleading
|
||||
# because it looks identical to a real, citation-backed result. When
|
||||
# any narrowing filter is active AND both citation channels came back
|
||||
# empty, mark the response as degraded so callers can decide to
|
||||
# broaden filters, retry, or fall back to a different source.
|
||||
active_filters: List[str] = []
|
||||
if allowed:
|
||||
active_filters.append("allowed_x_handles")
|
||||
if excluded:
|
||||
active_filters.append("excluded_x_handles")
|
||||
if from_date.strip():
|
||||
active_filters.append("from_date")
|
||||
if to_date.strip():
|
||||
active_filters.append("to_date")
|
||||
degraded = bool(active_filters) and not citations and not inline_citations
|
||||
degraded_reason = (
|
||||
f"no citations returned despite filters: {', '.join(active_filters)}"
|
||||
if degraded
|
||||
else None
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
|
|
@ -310,6 +410,8 @@ def x_search_tool(
|
|||
"answer": answer,
|
||||
"citations": citations,
|
||||
"inline_citations": inline_citations,
|
||||
"degraded": degraded,
|
||||
"degraded_reason": degraded_reason,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,15 +0,0 @@
|
|||
module.exports = {
|
||||
assumptions: {
|
||||
setPublicClassFields: true
|
||||
},
|
||||
plugins: [
|
||||
[
|
||||
'babel-plugin-react-compiler',
|
||||
{
|
||||
target: '19',
|
||||
sources: filename => Boolean(filename && !filename.includes('node_modules'))
|
||||
}
|
||||
]
|
||||
],
|
||||
babelrc: false
|
||||
}
|
||||
424
ui-tui/package-lock.json
generated
424
ui-tui/package-lock.json
generated
|
|
@ -17,15 +17,11 @@
|
|||
"unicode-animations": "^1.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/cli": "^7.28.6",
|
||||
"@babel/core": "^7.29.0",
|
||||
"@babel/plugin-syntax-jsx": "^7.28.6",
|
||||
"@eslint/js": "^9",
|
||||
"@types/node": "^25.5.0",
|
||||
"@types/react": "^19.2.14",
|
||||
"@typescript-eslint/eslint-plugin": "^8",
|
||||
"@typescript-eslint/parser": "^8",
|
||||
"babel-plugin-react-compiler": "^1.0.0",
|
||||
"esbuild": "~0.27.0",
|
||||
"eslint": "^9",
|
||||
"eslint-plugin-perfectionist": "^5",
|
||||
|
|
@ -65,36 +61,6 @@
|
|||
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/cli": {
|
||||
"version": "7.28.6",
|
||||
"resolved": "https://registry.npmjs.org/@babel/cli/-/cli-7.28.6.tgz",
|
||||
"integrity": "sha512-6EUNcuBbNkj08Oj4gAZ+BUU8yLCgKzgVX4gaTh09Ya2C8ICM4P+G30g4m3akRxSYAp3A/gnWchrNst7px4/nUQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@jridgewell/trace-mapping": "^0.3.28",
|
||||
"commander": "^6.2.0",
|
||||
"convert-source-map": "^2.0.0",
|
||||
"fs-readdir-recursive": "^1.1.0",
|
||||
"glob": "^7.2.0",
|
||||
"make-dir": "^2.1.0",
|
||||
"slash": "^2.0.0"
|
||||
},
|
||||
"bin": {
|
||||
"babel": "bin/babel.js",
|
||||
"babel-external-helpers": "bin/babel-external-helpers.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=6.9.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@nicolo-ribaudo/chokidar-2": "2.1.8-no-fsevents.3",
|
||||
"chokidar": "^3.6.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@babel/core": "^7.0.0-0"
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/code-frame": {
|
||||
"version": "7.29.0",
|
||||
"resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.29.0.tgz",
|
||||
|
|
@ -439,22 +405,6 @@
|
|||
"@babel/core": "^7.0.0-0"
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/plugin-syntax-jsx": {
|
||||
"version": "7.28.6",
|
||||
"resolved": "https://registry.npmjs.org/@babel/plugin-syntax-jsx/-/plugin-syntax-jsx-7.28.6.tgz",
|
||||
"integrity": "sha512-wgEmr06G6sIpqr8YDwA2dSRTE3bJ+V0IfpzfSY3Lfgd7YWOaAdlykvJi13ZKBt8cZHfgH1IXN+CL656W3uUa4w==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/helper-plugin-utils": "^7.28.6"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=6.9.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@babel/core": "^7.0.0-0"
|
||||
}
|
||||
},
|
||||
"node_modules/@babel/template": {
|
||||
"version": "7.28.6",
|
||||
"resolved": "https://registry.npmjs.org/@babel/template/-/template-7.28.6.tgz",
|
||||
|
|
@ -1341,14 +1291,6 @@
|
|||
"@emnapi/runtime": "^1.7.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@nicolo-ribaudo/chokidar-2": {
|
||||
"version": "2.1.8-no-fsevents.3",
|
||||
"resolved": "https://registry.npmjs.org/@nicolo-ribaudo/chokidar-2/-/chokidar-2-2.1.8-no-fsevents.3.tgz",
|
||||
"integrity": "sha512-s88O1aVtXftvp5bCPB7WnmXc5IwOZZ7YPuwNPt+GtOOXpPvad1LfbmjYv+qII7zP6RU2QGnqve27dnLycEnyEQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true
|
||||
},
|
||||
"node_modules/@oxc-project/types": {
|
||||
"version": "0.124.0",
|
||||
"resolved": "https://registry.npmjs.org/@oxc-project/types/-/types-0.124.0.tgz",
|
||||
|
|
@ -2145,35 +2087,6 @@
|
|||
"url": "https://github.com/chalk/ansi-styles?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/anymatch": {
|
||||
"version": "3.1.3",
|
||||
"resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz",
|
||||
"integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"normalize-path": "^3.0.0",
|
||||
"picomatch": "^2.0.4"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 8"
|
||||
}
|
||||
},
|
||||
"node_modules/anymatch/node_modules/picomatch": {
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/jonschlinkert"
|
||||
}
|
||||
},
|
||||
"node_modules/argparse": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz",
|
||||
|
|
@ -2367,16 +2280,6 @@
|
|||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/babel-plugin-react-compiler": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/babel-plugin-react-compiler/-/babel-plugin-react-compiler-1.0.0.tgz",
|
||||
"integrity": "sha512-Ixm8tFfoKKIPYdCCKYTsqv+Fd4IJ0DQqMyEimo+pxUOMUR9cVPlwTrFt9Avu+3cb6Zp3mAzl+t1MrG2fxxKsxw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/types": "^7.26.0"
|
||||
}
|
||||
},
|
||||
"node_modules/balanced-match": {
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-4.0.4.tgz",
|
||||
|
|
@ -2409,20 +2312,6 @@
|
|||
"require-from-string": "^2.0.2"
|
||||
}
|
||||
},
|
||||
"node_modules/binary-extensions": {
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz",
|
||||
"integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/brace-expansion": {
|
||||
"version": "5.0.5",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-5.0.5.tgz",
|
||||
|
|
@ -2436,20 +2325,6 @@
|
|||
"node": "18 || 20 || >=22"
|
||||
}
|
||||
},
|
||||
"node_modules/braces": {
|
||||
"version": "3.0.3",
|
||||
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz",
|
||||
"integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"fill-range": "^7.1.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/browserslist": {
|
||||
"version": "4.28.2",
|
||||
"resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.28.2.tgz",
|
||||
|
|
@ -2592,46 +2467,6 @@
|
|||
"url": "https://github.com/chalk/chalk?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/chokidar": {
|
||||
"version": "3.6.0",
|
||||
"resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz",
|
||||
"integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"anymatch": "~3.1.2",
|
||||
"braces": "~3.0.2",
|
||||
"glob-parent": "~5.1.2",
|
||||
"is-binary-path": "~2.1.0",
|
||||
"is-glob": "~4.0.1",
|
||||
"normalize-path": "~3.0.0",
|
||||
"readdirp": "~3.6.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 8.10.0"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://paulmillr.com/funding/"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"fsevents": "~2.3.2"
|
||||
}
|
||||
},
|
||||
"node_modules/chokidar/node_modules/glob-parent": {
|
||||
"version": "5.1.2",
|
||||
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz",
|
||||
"integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"is-glob": "^4.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/cli-boxes": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/cli-boxes/-/cli-boxes-3.0.0.tgz",
|
||||
|
|
@ -2707,16 +2542,6 @@
|
|||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/commander": {
|
||||
"version": "6.2.1",
|
||||
"resolved": "https://registry.npmjs.org/commander/-/commander-6.2.1.tgz",
|
||||
"integrity": "sha512-U7VdrJFnJgo4xjrHpTzu0yrHPGImdsmD95ZlgYSEajAn2JKzDhDTPG9kBTefmObL2w/ngeZnilk+OV9CG3d7UA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/concat-map": {
|
||||
"version": "0.0.1",
|
||||
"resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz",
|
||||
|
|
@ -3663,20 +3488,6 @@
|
|||
"node": ">=16.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/fill-range": {
|
||||
"version": "7.1.1",
|
||||
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz",
|
||||
"integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"to-regex-range": "^5.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/find-up": {
|
||||
"version": "5.0.0",
|
||||
"resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz",
|
||||
|
|
@ -3731,20 +3542,6 @@
|
|||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/fs-readdir-recursive": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/fs-readdir-recursive/-/fs-readdir-recursive-1.1.0.tgz",
|
||||
"integrity": "sha512-GNanXlVr2pf02+sPN40XN8HG+ePaNcvM0q5mZBd668Obwb0yD5GiUbZOFgwn8kGMY6I3mdyDJzieUy3PTYyTRA==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/fs.realpath": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz",
|
||||
"integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==",
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/fsevents": {
|
||||
"version": "2.3.3",
|
||||
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
|
||||
|
|
@ -3903,28 +3700,6 @@
|
|||
"url": "https://github.com/privatenumber/get-tsconfig?sponsor=1"
|
||||
}
|
||||
},
|
||||
"node_modules/glob": {
|
||||
"version": "7.2.3",
|
||||
"resolved": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz",
|
||||
"integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==",
|
||||
"deprecated": "Old versions of glob are not supported, and contain widely publicized security vulnerabilities, which have been fixed in the current version. Please update. Support for old versions may be purchased (at exorbitant rates) by contacting i@izs.me",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"fs.realpath": "^1.0.0",
|
||||
"inflight": "^1.0.4",
|
||||
"inherits": "2",
|
||||
"minimatch": "^3.1.1",
|
||||
"once": "^1.3.0",
|
||||
"path-is-absolute": "^1.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": "*"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/isaacs"
|
||||
}
|
||||
},
|
||||
"node_modules/glob-parent": {
|
||||
"version": "6.0.2",
|
||||
"resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz",
|
||||
|
|
@ -3938,37 +3713,6 @@
|
|||
"node": ">=10.13.0"
|
||||
}
|
||||
},
|
||||
"node_modules/glob/node_modules/balanced-match": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
|
||||
"integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/glob/node_modules/brace-expansion": {
|
||||
"version": "1.1.14",
|
||||
"resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.14.tgz",
|
||||
"integrity": "sha512-MWPGfDxnyzKU7rNOW9SP/c50vi3xrmrua/+6hfPbCS2ABNWfx24vPidzvC7krjU/RTo235sV776ymlsMtGKj8g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"balanced-match": "^1.0.0",
|
||||
"concat-map": "0.0.1"
|
||||
}
|
||||
},
|
||||
"node_modules/glob/node_modules/minimatch": {
|
||||
"version": "3.1.5",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.5.tgz",
|
||||
"integrity": "sha512-VgjWUsnnT6n+NUk6eZq77zeFdpW2LWDzP6zFGrCbHXiYNul5Dzqk2HHQ5uFH2DNW5Xbp8+jVzaeNt94ssEEl4w==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"brace-expansion": "^1.1.7"
|
||||
},
|
||||
"engines": {
|
||||
"node": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/globals": {
|
||||
"version": "16.5.0",
|
||||
"resolved": "https://registry.npmjs.org/globals/-/globals-16.5.0.tgz",
|
||||
|
|
@ -4171,25 +3915,6 @@
|
|||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/inflight": {
|
||||
"version": "1.0.6",
|
||||
"resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz",
|
||||
"integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==",
|
||||
"deprecated": "This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful.",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"once": "^1.3.0",
|
||||
"wrappy": "1"
|
||||
}
|
||||
},
|
||||
"node_modules/inherits": {
|
||||
"version": "2.0.4",
|
||||
"resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz",
|
||||
"integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==",
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/ink": {
|
||||
"version": "6.8.0",
|
||||
"resolved": "https://registry.npmjs.org/ink/-/ink-6.8.0.tgz",
|
||||
|
|
@ -4373,20 +4098,6 @@
|
|||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/is-binary-path": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz",
|
||||
"integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"binary-extensions": "^2.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/is-boolean-object": {
|
||||
"version": "1.2.2",
|
||||
"resolved": "https://registry.npmjs.org/is-boolean-object/-/is-boolean-object-1.2.2.tgz",
|
||||
|
|
@ -4583,17 +4294,6 @@
|
|||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/is-number": {
|
||||
"version": "7.0.0",
|
||||
"resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz",
|
||||
"integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"engines": {
|
||||
"node": ">=0.12.0"
|
||||
}
|
||||
},
|
||||
"node_modules/is-number-object": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/is-number-object/-/is-number-object-1.1.1.tgz",
|
||||
|
|
@ -5224,30 +4924,6 @@
|
|||
"@jridgewell/sourcemap-codec": "^1.5.5"
|
||||
}
|
||||
},
|
||||
"node_modules/make-dir": {
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/make-dir/-/make-dir-2.1.0.tgz",
|
||||
"integrity": "sha512-LS9X+dc8KLxXCb8dni79fLIIUA5VyZoyjSMCwTluaXA0o27cCK0bhXkpgw+sTXVpPy/lSO57ilRixqk0vDmtRA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"pify": "^4.0.1",
|
||||
"semver": "^5.6.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/make-dir/node_modules/semver": {
|
||||
"version": "5.7.2",
|
||||
"resolved": "https://registry.npmjs.org/semver/-/semver-5.7.2.tgz",
|
||||
"integrity": "sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"bin": {
|
||||
"semver": "bin/semver"
|
||||
}
|
||||
},
|
||||
"node_modules/math-intrinsics": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz",
|
||||
|
|
@ -5377,17 +5053,6 @@
|
|||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/normalize-path": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz",
|
||||
"integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/object-assign": {
|
||||
"version": "4.1.1",
|
||||
"resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz",
|
||||
|
|
@ -5507,16 +5172,6 @@
|
|||
],
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/once": {
|
||||
"version": "1.4.0",
|
||||
"resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz",
|
||||
"integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"dependencies": {
|
||||
"wrappy": "1"
|
||||
}
|
||||
},
|
||||
"node_modules/onetime": {
|
||||
"version": "5.1.2",
|
||||
"resolved": "https://registry.npmjs.org/onetime/-/onetime-5.1.2.tgz",
|
||||
|
|
@ -5632,16 +5287,6 @@
|
|||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/path-is-absolute": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz",
|
||||
"integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/path-key": {
|
||||
"version": "3.1.1",
|
||||
"resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz",
|
||||
|
|
@ -5686,16 +5331,6 @@
|
|||
"url": "https://github.com/sponsors/jonschlinkert"
|
||||
}
|
||||
},
|
||||
"node_modules/pify": {
|
||||
"version": "4.0.1",
|
||||
"resolved": "https://registry.npmjs.org/pify/-/pify-4.0.1.tgz",
|
||||
"integrity": "sha512-uB80kBFb/tfd68bVleG9T5GGsGPjJrLAUpR5PZIrhBnIaRTQRjqdJSsIKkOP6OAIFbj7GOrcudc5pNjZ+geV2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/possible-typed-array-names": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/possible-typed-array-names/-/possible-typed-array-names-1.1.0.tgz",
|
||||
|
|
@ -5814,34 +5449,6 @@
|
|||
"react": "^19.2.0"
|
||||
}
|
||||
},
|
||||
"node_modules/readdirp": {
|
||||
"version": "3.6.0",
|
||||
"resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz",
|
||||
"integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"picomatch": "^2.2.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/readdirp/node_modules/picomatch": {
|
||||
"version": "2.3.2",
|
||||
"resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.2.tgz",
|
||||
"integrity": "sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"engines": {
|
||||
"node": ">=8.6"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/jonschlinkert"
|
||||
}
|
||||
},
|
||||
"node_modules/reflect.getprototypeof": {
|
||||
"version": "1.0.10",
|
||||
"resolved": "https://registry.npmjs.org/reflect.getprototypeof/-/reflect.getprototypeof-1.0.10.tgz",
|
||||
|
|
@ -6223,16 +5830,6 @@
|
|||
"integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==",
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/slash": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/slash/-/slash-2.0.0.tgz",
|
||||
"integrity": "sha512-ZYKh3Wh2z1PpEXWr0MpSBZ0V6mZHAQfYevttO11c51CaWjGTaadiKZ+wVt1PbMlDV5qhMFslpZCemhwOK7C89A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/slice-ansi": {
|
||||
"version": "8.0.0",
|
||||
"resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-8.0.0.tgz",
|
||||
|
|
@ -6571,20 +6168,6 @@
|
|||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/to-regex-range": {
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz",
|
||||
"integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"is-number": "^7.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/ts-api-utils": {
|
||||
"version": "2.5.0",
|
||||
"resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.5.0.tgz",
|
||||
|
|
@ -7202,13 +6785,6 @@
|
|||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/wrappy": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz",
|
||||
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==",
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/ws": {
|
||||
"version": "8.20.1",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.20.1.tgz",
|
||||
|
|
|
|||
|
|
@ -25,15 +25,11 @@
|
|||
"unicode-animations": "^1.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/cli": "^7.28.6",
|
||||
"@babel/core": "^7.29.0",
|
||||
"@babel/plugin-syntax-jsx": "^7.28.6",
|
||||
"@eslint/js": "^9",
|
||||
"@types/node": "^25.5.0",
|
||||
"@types/react": "^19.2.14",
|
||||
"@typescript-eslint/eslint-plugin": "^8",
|
||||
"@typescript-eslint/parser": "^8",
|
||||
"babel-plugin-react-compiler": "^1.0.0",
|
||||
"esbuild": "~0.27.0",
|
||||
"eslint": "^9",
|
||||
"eslint-plugin-perfectionist": "^5",
|
||||
|
|
|
|||
|
|
@ -1473,16 +1473,9 @@ export default class Ink {
|
|||
if (success) {
|
||||
return text
|
||||
}
|
||||
|
||||
if (process.env.HERMES_TUI_DEBUG_CLIPBOARD) {
|
||||
console.error(
|
||||
'[clipboard] no path reached the clipboard (headless + no tmux?) — set HERMES_TUI_FORCE_OSC52=1 to force the escape sequence'
|
||||
)
|
||||
}
|
||||
} catch (err) {
|
||||
if (process.env.HERMES_TUI_DEBUG_CLIPBOARD) {
|
||||
console.error('[clipboard] error:', err)
|
||||
}
|
||||
} catch {
|
||||
// Clipboard failed across every path — caller sees the empty
|
||||
// return below and surfaces a hint via the slash command.
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -308,9 +308,24 @@ export async function setClipboard(text: string): Promise<ClipboardResult> {
|
|||
// Cached after first attempt so repeated mouse-ups skip the probe chain.
|
||||
let linuxCopy: 'wl-copy' | 'xclip' | 'xsel' | null | undefined
|
||||
|
||||
/** Per-tool copy arguments: wl-copy reads stdin, xclip/xsel need clipboard flags. */
|
||||
function linuxCopyArgs(tool: 'wl-copy' | 'xclip' | 'xsel'): string[] {
|
||||
switch (tool) {
|
||||
case 'wl-copy':
|
||||
return []
|
||||
case 'xclip':
|
||||
return ['-selection', 'clipboard']
|
||||
case 'xsel':
|
||||
return ['--clipboard', '--input']
|
||||
}
|
||||
}
|
||||
|
||||
/** Internal: probe once and cache — wl-copy first, then xclip, then xsel. */
|
||||
async function probeLinuxCopy(): Promise<'wl-copy' | 'xclip' | 'xsel' | null> {
|
||||
const opts = { useCwd: false, timeout: 500 }
|
||||
// resolveOnExit: wl-copy daemonizes and the daemon inherits stdio pipes,
|
||||
// so 'close' never fires and the await would hang past the timeout.
|
||||
// 'exit' fires on the immediate child's exit — what we actually care about.
|
||||
const opts = { useCwd: false, timeout: 500, resolveOnExit: true }
|
||||
|
||||
const r = await execFileNoThrow('wl-copy', [], opts)
|
||||
|
||||
|
|
@ -318,13 +333,13 @@ async function probeLinuxCopy(): Promise<'wl-copy' | 'xclip' | 'xsel' | null> {
|
|||
return 'wl-copy'
|
||||
}
|
||||
|
||||
const r2 = await execFileNoThrow('xclip', ['-selection', 'clipboard'], opts)
|
||||
const r2 = await execFileNoThrow('xclip', linuxCopyArgs('xclip'), opts)
|
||||
|
||||
if (r2.code === 0) {
|
||||
return 'xclip'
|
||||
}
|
||||
|
||||
const r3 = await execFileNoThrow('xsel', ['--clipboard', '--input'], opts)
|
||||
const r3 = await execFileNoThrow('xsel', linuxCopyArgs('xsel'), opts)
|
||||
|
||||
return r3.code === 0 ? 'xsel' : null
|
||||
}
|
||||
|
|
@ -347,7 +362,11 @@ async function probeLinuxCopy(): Promise<'wl-copy' | 'xclip' | 'xsel' | null> {
|
|||
* we skip probing entirely and treat linuxCopy as permanently null.
|
||||
*/
|
||||
function copyNative(text: string): boolean {
|
||||
const opts = { input: text, useCwd: false, timeout: 2000 }
|
||||
// resolveOnExit: pbcopy/wl-copy/xclip/xsel/clip all daemonize or hold
|
||||
// the system selection live in a forked process. Without resolveOnExit,
|
||||
// the inherited stdio pipes keep node from seeing 'close' → the
|
||||
// fire-and-forget await never resolves and the actual copy never runs.
|
||||
const opts = { input: text, useCwd: false, timeout: 2000, resolveOnExit: true }
|
||||
|
||||
switch (process.platform) {
|
||||
case 'darwin':
|
||||
|
|
@ -363,17 +382,13 @@ function copyNative(text: string): boolean {
|
|||
}
|
||||
|
||||
// linuxCopy is a known-working tool; fire-and-forget.
|
||||
void execFileNoThrow(linuxCopy, linuxCopy === 'wl-copy' ? [] : ['-selection', 'clipboard'], opts)
|
||||
void execFileNoThrow(linuxCopy, linuxCopyArgs(linuxCopy), opts)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// No display server → native tools will fail immediately. Cache null.
|
||||
if (!process.env.DISPLAY && !process.env.WAYLAND_DISPLAY) {
|
||||
if (process.env.HERMES_TUI_DEBUG_CLIPBOARD) {
|
||||
console.error('[clipboard] [native] Linux: no DISPLAY or WAYLAND_DISPLAY — native clipboard unavailable')
|
||||
}
|
||||
|
||||
linuxCopy = null
|
||||
|
||||
return false
|
||||
|
|
@ -386,13 +401,9 @@ function copyNative(text: string): boolean {
|
|||
const winner = await probeLinuxCopy()
|
||||
linuxCopy = winner
|
||||
|
||||
if (process.env.HERMES_TUI_DEBUG_CLIPBOARD) {
|
||||
console.error(`[clipboard] [native] Linux: clipboard probe complete → ${winner ?? 'no tool available'}`)
|
||||
}
|
||||
|
||||
// Actually perform the copy with the discovered tool.
|
||||
if (winner) {
|
||||
void execFileNoThrow(winner, winner === 'wl-copy' ? [] : ['-selection', 'clipboard'], opts)
|
||||
void execFileNoThrow(winner, linuxCopyArgs(winner), opts)
|
||||
}
|
||||
})()
|
||||
|
||||
|
|
|
|||
146
ui-tui/packages/hermes-ink/src/utils/execFileNoThrow.test.ts
Normal file
146
ui-tui/packages/hermes-ink/src/utils/execFileNoThrow.test.ts
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
import { chmodSync, mkdirSync, readFileSync, rmSync, writeFileSync } from 'node:fs'
|
||||
import { tmpdir } from 'node:os'
|
||||
import { join } from 'node:path'
|
||||
|
||||
import { afterEach, beforeEach, describe, expect, it } from 'vitest'
|
||||
|
||||
import { execFileNoThrow } from './execFileNoThrow.js'
|
||||
|
||||
// These tests shell out to /bin/sh, use chmodSync(0o755), and rely on
|
||||
// POSIX sleep/job control. They will not work on Windows.
|
||||
const onWindows = process.platform === 'win32'
|
||||
|
||||
// We simulate `wl-copy`'s daemonization behavior with a tiny shell script:
|
||||
// 1. Fork a short-lived background sleeper that inherits stdio (so the
|
||||
// parent process's pipes can never close).
|
||||
// 2. Record the sleeper PID to a file so afterEach can clean it up.
|
||||
// 3. Exit immediately with status 0.
|
||||
//
|
||||
// Without resolveOnExit, the await on `'close'` hangs until SIGTERM at
|
||||
// timeout — exactly the production wl-copy bug. With resolveOnExit, the
|
||||
// promise settles on `'exit'` regardless of the inherited pipes.
|
||||
|
||||
let scriptDir: string
|
||||
let daemonScript: string
|
||||
let sleeperPids: number[]
|
||||
|
||||
/** Read the PID file the daemon script writes, and track it for afterEach cleanup. */
|
||||
function trackSleeperPid(pidFile: string): void {
|
||||
try {
|
||||
const pid = parseInt(readFileSync(pidFile, 'utf8').trim(), 10)
|
||||
if (pid > 0) {
|
||||
sleeperPids.push(pid)
|
||||
}
|
||||
} catch {
|
||||
// PID file not written or unreadable — sleeper may have already exited.
|
||||
}
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
sleeperPids = []
|
||||
scriptDir = join(tmpdir(), `hermes-execfile-test-${process.pid}-${Date.now()}`)
|
||||
mkdirSync(scriptDir, { recursive: true })
|
||||
daemonScript = join(scriptDir, 'fake-daemonizer.sh')
|
||||
// Posix sh: the `sleep 3 &` child inherits stdin/stdout/stderr from the
|
||||
// shell, which inherited them from `spawn(stdio: 'pipe')`. The shell
|
||||
// exits but its child (the sleeper) keeps the pipes open. Mirrors how
|
||||
// wl-copy double-forks then exits while the daemon holds the selection.
|
||||
// The sleeper writes its PID to $1 so we can clean it up reliably.
|
||||
writeFileSync(daemonScript, '#!/bin/sh\nsleep 3 &\necho $! > "$1"\nexit 0\n')
|
||||
chmodSync(daemonScript, 0o755)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
// Kill orphaned sleepers so they don't accumulate across watch runs.
|
||||
for (const pid of sleeperPids) {
|
||||
try {
|
||||
process.kill(pid, 'SIGKILL')
|
||||
} catch {
|
||||
// Already exited — fine.
|
||||
}
|
||||
}
|
||||
rmSync(scriptDir, { recursive: true, force: true })
|
||||
})
|
||||
|
||||
describe.skipIf(onWindows)('execFileNoThrow with daemon-style children', () => {
|
||||
// Skipped because the bug it documents is a forever-hang. Without
|
||||
// resolveOnExit, the 'close' event doesn't fire when the immediate
|
||||
// child has exited but a forked daemon still holds stdio open. Even
|
||||
// SIGTERM at the timeout doesn't help — the daemon survives it. To
|
||||
// verify by hand: remove `it.skip` and watch the test timeout. This
|
||||
// test is here so a reviewer reading the resolveOnExit option knows
|
||||
// *why* every clipboard-tool spawn in osc.ts wires it on.
|
||||
it.skip("(documented hang) without resolveOnExit, await never resolves when daemon inherits stdio", async () => {
|
||||
const pidFile = join(scriptDir, 'sleeper-skip.pid')
|
||||
const result = await execFileNoThrow(daemonScript, [pidFile], { timeout: 300 })
|
||||
trackSleeperPid(pidFile)
|
||||
|
||||
expect(result.code).toBe(124)
|
||||
})
|
||||
|
||||
it("settles immediately on 'exit' when resolveOnExit is true, regardless of daemon stdio", async () => {
|
||||
const pidFile = join(scriptDir, 'sleeper-exit.pid')
|
||||
const start = Date.now()
|
||||
|
||||
const result = await execFileNoThrow(daemonScript, [pidFile], {
|
||||
timeout: 2000,
|
||||
resolveOnExit: true
|
||||
})
|
||||
trackSleeperPid(pidFile)
|
||||
|
||||
const elapsed = Date.now() - start
|
||||
|
||||
// The shell exits in a few ms. resolveOnExit lets us return on exit
|
||||
// (code 0) instead of waiting for the orphaned sleeper to release
|
||||
// stdio. Should be well under 200ms even on slow CI.
|
||||
expect(result.code).toBe(0)
|
||||
expect(elapsed).toBeLessThan(500)
|
||||
})
|
||||
|
||||
it("still surfaces the right code when resolveOnExit'd child exits non-zero", async () => {
|
||||
const pidFile = join(scriptDir, 'sleeper-fail.pid')
|
||||
const failScript = join(scriptDir, 'fail.sh')
|
||||
writeFileSync(failScript, `#!/bin/sh\nsleep 3 &\necho $! > "${pidFile}"\nexit 7\n`)
|
||||
chmodSync(failScript, 0o755)
|
||||
|
||||
const result = await execFileNoThrow(failScript, [], {
|
||||
timeout: 2000,
|
||||
resolveOnExit: true
|
||||
})
|
||||
trackSleeperPid(pidFile)
|
||||
|
||||
expect(result.code).toBe(7)
|
||||
})
|
||||
|
||||
it('settles on timeout=124 when the child itself never exits, even with resolveOnExit', async () => {
|
||||
const slowScript = join(scriptDir, 'slow.sh')
|
||||
writeFileSync(slowScript, '#!/bin/sh\nsleep 30\n')
|
||||
chmodSync(slowScript, 0o755)
|
||||
|
||||
const result = await execFileNoThrow(slowScript, [], {
|
||||
timeout: 200,
|
||||
resolveOnExit: true
|
||||
})
|
||||
|
||||
// Child process never exits on its own → timer fires → SIGTERM →
|
||||
// child exits → 'exit' fires with non-null signal. The settle()
|
||||
// call from the timer registers code=124 first. Either way: 124.
|
||||
expect(result.code).toBe(124)
|
||||
})
|
||||
|
||||
it('does not double-resolve when both timer and exit fire', async () => {
|
||||
const pidFile = join(scriptDir, 'sleeper-race.pid')
|
||||
// Race: child happens to exit right around the timeout. The settled
|
||||
// guard ensures only the first resolution wins.
|
||||
const result = await execFileNoThrow(daemonScript, [pidFile], {
|
||||
timeout: 50, // very tight
|
||||
resolveOnExit: true
|
||||
})
|
||||
trackSleeperPid(pidFile)
|
||||
|
||||
// Either code=0 (exit beat timer) or code=124 (timer beat exit).
|
||||
// Both are valid outcomes; the contract is that the promise settles
|
||||
// exactly once and doesn't throw.
|
||||
expect([0, 124]).toContain(result.code)
|
||||
})
|
||||
})
|
||||
|
|
@ -4,6 +4,17 @@ type ExecFileOptions = {
|
|||
timeout?: number
|
||||
useCwd?: boolean
|
||||
env?: NodeJS.ProcessEnv
|
||||
/** Resolve as soon as the child *exits*, instead of waiting for its
|
||||
* stdio streams to close. Use this for tools that fork a daemon and
|
||||
* let the daemon inherit the parent's stdio (e.g. `wl-copy`): the
|
||||
* child exits immediately, but `'close'` never fires because the
|
||||
* daemon holds the pipes open.
|
||||
*
|
||||
* When true, stdout and stderr are set to 'ignore' to prevent the
|
||||
* daemon from inheriting those pipe FDs — the caller must not
|
||||
* depend on collecting stdout/stderr content. Both will always be
|
||||
* empty strings in this mode. */
|
||||
resolveOnExit?: boolean
|
||||
}
|
||||
|
||||
export function execFileNoThrow(
|
||||
|
|
@ -17,20 +28,55 @@ export function execFileNoThrow(
|
|||
error?: string
|
||||
}> {
|
||||
return new Promise(resolve => {
|
||||
// When resolveOnExit is true, ignore stdout/stderr so the daemon
|
||||
// doesn't inherit those pipe FDs — prevents handle leaks that can
|
||||
// keep the parent process alive. No output data is collected in
|
||||
// this mode; both stdout and stderr will be empty strings.
|
||||
const stdioConfig = options.resolveOnExit
|
||||
? ['pipe', 'ignore', 'ignore'] as const
|
||||
: 'pipe' as const
|
||||
|
||||
const child = spawn(file, args, {
|
||||
cwd: options.useCwd ? process.cwd() : undefined,
|
||||
env: options.env,
|
||||
stdio: 'pipe'
|
||||
stdio: stdioConfig
|
||||
})
|
||||
|
||||
let stdout = ''
|
||||
let stderr = ''
|
||||
let timedOut = false
|
||||
let settled = false
|
||||
|
||||
const settle = (code: number, error?: string) => {
|
||||
if (settled) {
|
||||
return
|
||||
}
|
||||
|
||||
settled = true
|
||||
|
||||
if (timer) {
|
||||
clearTimeout(timer)
|
||||
}
|
||||
|
||||
// Destroy any remaining streams to release FDs promptly.
|
||||
// After settle(), nobody reads from these anymore.
|
||||
child.stdout?.destroy()
|
||||
child.stderr?.destroy()
|
||||
|
||||
resolve({ stdout, stderr, code, ...(error ? { error } : {}) })
|
||||
}
|
||||
|
||||
const timer = options.timeout
|
||||
? setTimeout(() => {
|
||||
timedOut = true
|
||||
child.kill('SIGTERM')
|
||||
|
||||
// When resolving on exit, SIGTERM-ing a child that has already
|
||||
// exited is a no-op and `'exit'` won't fire again — settle here
|
||||
// so the promise doesn't leak. Safe under settled-guard.
|
||||
if (options.resolveOnExit) {
|
||||
settle(124)
|
||||
}
|
||||
}, options.timeout)
|
||||
: null
|
||||
|
||||
|
|
@ -41,19 +87,24 @@ export function execFileNoThrow(
|
|||
stderr += String(chunk)
|
||||
})
|
||||
child.on('error', error => {
|
||||
if (timer) {
|
||||
clearTimeout(timer)
|
||||
}
|
||||
|
||||
resolve({ stdout, stderr, code: 1, error: String(error) })
|
||||
settle(1, String(error))
|
||||
})
|
||||
child.on('close', code => {
|
||||
if (timer) {
|
||||
clearTimeout(timer)
|
||||
}
|
||||
|
||||
resolve({ stdout, stderr, code: timedOut ? 124 : (code ?? 0) })
|
||||
})
|
||||
if (options.resolveOnExit) {
|
||||
// 'exit' fires when the child process itself exits — even if the
|
||||
// daemon it forked still holds the inherited stdio pipes open.
|
||||
// When a signal kills the child, code is null — map that to 1
|
||||
// so callers don't mistake a signal-terminated run for success.
|
||||
child.on('exit', (code, signal) => {
|
||||
const exitCode = timedOut ? 124 : (code ?? (signal ? 1 : 0))
|
||||
settle(exitCode)
|
||||
})
|
||||
} else {
|
||||
child.on('close', (code, signal) => {
|
||||
const exitCode = timedOut ? 124 : (code ?? (signal ? 1 : 0))
|
||||
settle(exitCode)
|
||||
})
|
||||
}
|
||||
|
||||
if (options.input) {
|
||||
child.stdin?.write(options.input)
|
||||
|
|
|
|||
|
|
@ -345,7 +345,7 @@ export const coreCommands: SlashCommand[] = [
|
|||
return sys(`copied ${text.length} characters`)
|
||||
} else {
|
||||
return sys(
|
||||
'clipboard copy failed — try HERMES_TUI_FORCE_OSC52=1 to force the escape sequence; HERMES_TUI_DEBUG_CLIPBOARD=1 for details'
|
||||
'clipboard copy failed — try HERMES_TUI_FORCE_OSC52=1 to force the escape sequence'
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue