mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
refactor(cli): implement approval locking mechanism to serialize concurrent requests
- Introduced _approval_lock to ensure that approval prompts are handled sequentially, preventing state clobbering from parallel delegation subtasks. - Updated approval_callback and HermesCLI methods to utilize the lock for managing approval state and deadlines. - Added tests for the config bridging logic to ensure correct environment variable mapping from config.yaml.
This commit is contained in:
parent
a20d373945
commit
163fa4a9d1
3 changed files with 231 additions and 69 deletions
72
cli.py
72
cli.py
|
|
@ -3571,48 +3571,51 @@ class HermesCLI:
|
|||
|
||||
Called from the agent thread. Shows a selection UI similar to clarify
|
||||
with choices: once / session / always / deny.
|
||||
|
||||
Uses _approval_lock to serialize concurrent requests (e.g. from
|
||||
parallel delegation subtasks) so each prompt gets its own turn
|
||||
and the shared _approval_state / _approval_deadline aren't clobbered.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
timeout = 60
|
||||
response_queue = queue.Queue()
|
||||
choices = ["once", "session", "always", "deny"]
|
||||
with self._approval_lock:
|
||||
timeout = 60
|
||||
response_queue = queue.Queue()
|
||||
choices = ["once", "session", "always", "deny"]
|
||||
|
||||
self._approval_state = {
|
||||
"command": command,
|
||||
"description": description,
|
||||
"choices": choices,
|
||||
"selected": 0,
|
||||
"response_queue": response_queue,
|
||||
}
|
||||
self._approval_deadline = _time.monotonic() + timeout
|
||||
self._approval_state = {
|
||||
"command": command,
|
||||
"description": description,
|
||||
"choices": choices,
|
||||
"selected": 0,
|
||||
"response_queue": response_queue,
|
||||
}
|
||||
self._approval_deadline = _time.monotonic() + timeout
|
||||
|
||||
self._invalidate()
|
||||
self._invalidate()
|
||||
|
||||
# Same throttled countdown as _clarify_callback — repaint only
|
||||
# every 5 s to avoid flicker in Kitty / ghostty / etc.
|
||||
_last_countdown_refresh = _time.monotonic()
|
||||
while True:
|
||||
try:
|
||||
result = response_queue.get(timeout=1)
|
||||
self._approval_state = None
|
||||
self._approval_deadline = 0
|
||||
self._invalidate()
|
||||
return result
|
||||
except queue.Empty:
|
||||
remaining = self._approval_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
now = _time.monotonic()
|
||||
if now - _last_countdown_refresh >= 5.0:
|
||||
_last_countdown_refresh = now
|
||||
_last_countdown_refresh = _time.monotonic()
|
||||
while True:
|
||||
try:
|
||||
result = response_queue.get(timeout=1)
|
||||
self._approval_state = None
|
||||
self._approval_deadline = 0
|
||||
self._invalidate()
|
||||
return result
|
||||
except queue.Empty:
|
||||
remaining = self._approval_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
now = _time.monotonic()
|
||||
if now - _last_countdown_refresh >= 5.0:
|
||||
_last_countdown_refresh = now
|
||||
self._invalidate()
|
||||
|
||||
self._approval_state = None
|
||||
self._approval_deadline = 0
|
||||
self._invalidate()
|
||||
_cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}")
|
||||
return "deny"
|
||||
self._approval_state = None
|
||||
self._approval_deadline = 0
|
||||
self._invalidate()
|
||||
_cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}")
|
||||
return "deny"
|
||||
|
||||
def _secret_capture_callback(self, var_name: str, prompt: str, metadata=None) -> dict:
|
||||
return prompt_for_secret(self, var_name, prompt, metadata)
|
||||
|
|
@ -3920,6 +3923,7 @@ class HermesCLI:
|
|||
# Dangerous command approval state (similar mechanism to clarify)
|
||||
self._approval_state = None # dict with command, description, choices, selected, response_queue
|
||||
self._approval_deadline = 0
|
||||
self._approval_lock = threading.Lock() # serialize concurrent approval prompts (delegation race fix)
|
||||
|
||||
# Slash command loading state
|
||||
self._command_running = False
|
||||
|
|
|
|||
|
|
@ -227,43 +227,53 @@ def approval_callback(cli, command: str, description: str) -> str:
|
|||
Shows a selection UI with choices: once / session / always / deny.
|
||||
When the command is longer than 70 characters, a "view" option is
|
||||
included so the user can reveal the full text before deciding.
|
||||
|
||||
Uses cli._approval_lock to serialize concurrent requests (e.g. from
|
||||
parallel delegation subtasks) so each prompt gets its own turn.
|
||||
"""
|
||||
timeout = 60
|
||||
response_queue = queue.Queue()
|
||||
choices = ["once", "session", "always", "deny"]
|
||||
if len(command) > 70:
|
||||
choices.append("view")
|
||||
lock = getattr(cli, "_approval_lock", None)
|
||||
if lock is None:
|
||||
import threading
|
||||
cli._approval_lock = threading.Lock()
|
||||
lock = cli._approval_lock
|
||||
|
||||
cli._approval_state = {
|
||||
"command": command,
|
||||
"description": description,
|
||||
"choices": choices,
|
||||
"selected": 0,
|
||||
"response_queue": response_queue,
|
||||
}
|
||||
cli._approval_deadline = _time.monotonic() + timeout
|
||||
with lock:
|
||||
timeout = 60
|
||||
response_queue = queue.Queue()
|
||||
choices = ["once", "session", "always", "deny"]
|
||||
if len(command) > 70:
|
||||
choices.append("view")
|
||||
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
cli._approval_state = {
|
||||
"command": command,
|
||||
"description": description,
|
||||
"choices": choices,
|
||||
"selected": 0,
|
||||
"response_queue": response_queue,
|
||||
}
|
||||
cli._approval_deadline = _time.monotonic() + timeout
|
||||
|
||||
while True:
|
||||
try:
|
||||
result = response_queue.get(timeout=1)
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
return result
|
||||
except queue.Empty:
|
||||
remaining = cli._approval_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}")
|
||||
return "deny"
|
||||
while True:
|
||||
try:
|
||||
result = response_queue.get(timeout=1)
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
return result
|
||||
except queue.Empty:
|
||||
remaining = cli._approval_deadline - _time.monotonic()
|
||||
if remaining <= 0:
|
||||
break
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
|
||||
cli._approval_state = None
|
||||
cli._approval_deadline = 0
|
||||
if hasattr(cli, "_app") and cli._app:
|
||||
cli._app.invalidate()
|
||||
cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}")
|
||||
return "deny"
|
||||
|
|
|
|||
148
tests/gateway/test_config_cwd_bridge.py
Normal file
148
tests/gateway/test_config_cwd_bridge.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""Tests for the config.yaml → env var bridge logic in gateway/run.py.
|
||||
|
||||
Specifically tests that top-level `cwd:` and `backend:` in config.yaml
|
||||
are correctly bridged to TERMINAL_CWD / TERMINAL_ENV env vars as
|
||||
convenience aliases for `terminal.cwd` / `terminal.backend`.
|
||||
|
||||
The bridge logic is module-level code in gateway/run.py, so we test
|
||||
the semantics by reimplementing the relevant config bridge snippet and
|
||||
asserting the expected env var outcomes.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
|
||||
|
||||
def _simulate_config_bridge(cfg: dict, initial_env: dict | None = None):
|
||||
"""Simulate the gateway config bridge logic from gateway/run.py.
|
||||
|
||||
Returns the resulting env dict (only TERMINAL_* and MESSAGING_CWD keys).
|
||||
"""
|
||||
env = dict(initial_env or {})
|
||||
|
||||
# --- Replicate lines 54-56: generic top-level bridge (for context) ---
|
||||
for key, val in cfg.items():
|
||||
if isinstance(val, (str, int, float, bool)) and key not in env:
|
||||
env[key] = str(val)
|
||||
|
||||
# --- Replicate lines 59-87: terminal config bridge ---
|
||||
terminal_cfg = cfg.get("terminal", {})
|
||||
if terminal_cfg and isinstance(terminal_cfg, dict):
|
||||
terminal_env_map = {
|
||||
"backend": "TERMINAL_ENV",
|
||||
"cwd": "TERMINAL_CWD",
|
||||
"timeout": "TERMINAL_TIMEOUT",
|
||||
}
|
||||
for cfg_key, env_var in terminal_env_map.items():
|
||||
if cfg_key in terminal_cfg:
|
||||
val = terminal_cfg[cfg_key]
|
||||
if isinstance(val, list):
|
||||
env[env_var] = json.dumps(val)
|
||||
else:
|
||||
env[env_var] = str(val)
|
||||
|
||||
# --- NEW: top-level aliases (the fix being tested) ---
|
||||
top_level_aliases = {
|
||||
"cwd": "TERMINAL_CWD",
|
||||
"backend": "TERMINAL_ENV",
|
||||
}
|
||||
for alias_key, alias_env in top_level_aliases.items():
|
||||
if alias_env not in env:
|
||||
alias_val = cfg.get(alias_key)
|
||||
if isinstance(alias_val, str) and alias_val.strip():
|
||||
env[alias_env] = alias_val.strip()
|
||||
|
||||
# --- Replicate lines 144-147: MESSAGING_CWD fallback ---
|
||||
configured_cwd = env.get("TERMINAL_CWD", "")
|
||||
if not configured_cwd or configured_cwd in (".", "auto", "cwd"):
|
||||
messaging_cwd = env.get("MESSAGING_CWD") or "/root" # Path.home() for root
|
||||
env["TERMINAL_CWD"] = messaging_cwd
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class TestTopLevelCwdAlias:
|
||||
"""Top-level `cwd:` should be treated as `terminal.cwd`."""
|
||||
|
||||
def test_top_level_cwd_sets_terminal_cwd(self):
|
||||
cfg = {"cwd": "/home/hermes/projects"}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/projects"
|
||||
|
||||
def test_top_level_backend_sets_terminal_env(self):
|
||||
cfg = {"backend": "docker"}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_ENV"] == "docker"
|
||||
|
||||
def test_top_level_cwd_and_backend(self):
|
||||
cfg = {"backend": "local", "cwd": "/home/hermes/projects"}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/projects"
|
||||
assert result["TERMINAL_ENV"] == "local"
|
||||
|
||||
def test_nested_terminal_takes_precedence_over_top_level(self):
|
||||
"""terminal.cwd should win over top-level cwd."""
|
||||
cfg = {
|
||||
"cwd": "/should/not/use",
|
||||
"terminal": {"cwd": "/home/hermes/real"},
|
||||
}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/real"
|
||||
|
||||
def test_nested_terminal_backend_takes_precedence(self):
|
||||
cfg = {
|
||||
"backend": "should-not-use",
|
||||
"terminal": {"backend": "docker"},
|
||||
}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_ENV"] == "docker"
|
||||
|
||||
def test_no_cwd_falls_back_to_messaging_cwd(self):
|
||||
cfg = {}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes/projects"})
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/projects"
|
||||
|
||||
def test_no_cwd_no_messaging_cwd_falls_back_to_home(self):
|
||||
cfg = {}
|
||||
result = _simulate_config_bridge(cfg)
|
||||
assert result["TERMINAL_CWD"] == "/root" # Path.home() for root user
|
||||
|
||||
def test_dot_cwd_triggers_messaging_fallback(self):
|
||||
"""cwd: '.' should trigger MESSAGING_CWD fallback."""
|
||||
cfg = {"cwd": "."}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes"})
|
||||
# "." is stripped but truthy, so it gets set as TERMINAL_CWD
|
||||
# Then the MESSAGING_CWD fallback does NOT trigger since TERMINAL_CWD
|
||||
# is set and not in (".", "auto", "cwd").
|
||||
# Wait — "." IS in the fallback list! So this should fall through.
|
||||
# Actually the alias sets it to ".", then the messaging fallback
|
||||
# checks if it's in (".", "auto", "cwd") and overrides.
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes"
|
||||
|
||||
def test_auto_cwd_triggers_messaging_fallback(self):
|
||||
cfg = {"cwd": "auto"}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes"})
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes"
|
||||
|
||||
def test_empty_cwd_ignored(self):
|
||||
cfg = {"cwd": ""}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes"})
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes"
|
||||
|
||||
def test_whitespace_only_cwd_ignored(self):
|
||||
cfg = {"cwd": " "}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/fallback"})
|
||||
assert result["TERMINAL_CWD"] == "/fallback"
|
||||
|
||||
def test_messaging_cwd_env_var_works(self):
|
||||
"""MESSAGING_CWD in initial env should be picked up as fallback."""
|
||||
cfg = {}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/home/hermes/projects"})
|
||||
assert result["TERMINAL_CWD"] == "/home/hermes/projects"
|
||||
|
||||
def test_top_level_cwd_beats_messaging_cwd(self):
|
||||
"""Explicit top-level cwd should take precedence over MESSAGING_CWD."""
|
||||
cfg = {"cwd": "/from/config"}
|
||||
result = _simulate_config_bridge(cfg, {"MESSAGING_CWD": "/from/env"})
|
||||
assert result["TERMINAL_CWD"] == "/from/config"
|
||||
Loading…
Add table
Add a link
Reference in a new issue