mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix: make gateway approval block agent thread like CLI does (#4557)
The gateway's dangerous command approval system was fundamentally broken: the agent loop continued running after a command was flagged, and the approval request only reached the user after the agent finished its entire conversation loop. By then the context was lost. This change makes the gateway approval mirror the CLI's synchronous behavior. When a dangerous command is detected: 1. The agent thread blocks on a threading.Event 2. The approval request is sent to the user immediately 3. The user responds with /approve or /deny 4. The event is signaled and the agent resumes with the real result The agent never sees 'approval_required' as a tool result. It either gets the command output (approved) or a definitive BLOCKED message (denied/timed out) — same as CLI mode. Queue-based design supports multiple concurrent approvals (parallel subagents via delegate_task, execute_code RPC handlers). Each approval gets its own _ApprovalEntry with its own threading.Event. /approve resolves the oldest (FIFO); /approve all resolves all at once. Changes: - tools/approval.py: Queue-based per-session blocking gateway approval (register/unregister callbacks, resolve with FIFO or all-at-once) - gateway/run.py: Register approval callback in run_sync(), remove post-loop pop_pending hack, /approve and /deny support 'all' flag - tests: 21 tests including parallel subagent E2E scenarios
This commit is contained in:
parent
64584a931f
commit
624ad582a5
3 changed files with 742 additions and 277 deletions
222
gateway/run.py
222
gateway/run.py
|
|
@ -2721,27 +2721,12 @@ class GatewayRunner:
|
|||
except Exception as e:
|
||||
logger.error("Process watcher setup error: %s", e)
|
||||
|
||||
# Check if the agent encountered a dangerous command needing approval
|
||||
try:
|
||||
from tools.approval import pop_pending
|
||||
import time as _time
|
||||
pending = pop_pending(session_key)
|
||||
if pending:
|
||||
pending["timestamp"] = _time.time()
|
||||
self._pending_approvals[session_key] = pending
|
||||
# Append structured instructions so the user knows how to respond
|
||||
cmd_preview = pending.get("command", "")
|
||||
if len(cmd_preview) > 200:
|
||||
cmd_preview = cmd_preview[:200] + "..."
|
||||
approval_hint = (
|
||||
f"\n\n⚠️ **Dangerous command requires approval:**\n"
|
||||
f"```\n{cmd_preview}\n```\n"
|
||||
f"Reply `/approve` to execute, `/approve session` to approve this pattern "
|
||||
f"for the session, or `/deny` to cancel."
|
||||
)
|
||||
response = (response or "") + approval_hint
|
||||
except Exception as e:
|
||||
logger.debug("Failed to check pending approvals: %s", e)
|
||||
# NOTE: Dangerous command approvals are now handled inline by the
|
||||
# blocking gateway approval mechanism in tools/approval.py. The agent
|
||||
# thread blocks until the user responds with /approve or /deny, so by
|
||||
# the time we reach here the approval has already been resolved. The
|
||||
# old post-loop pop_pending + approval_hint code was removed in favour
|
||||
# of the blocking approach that mirrors CLI's synchronous input().
|
||||
|
||||
# Save the full conversation to the transcript, including tool calls.
|
||||
# This preserves the complete agent loop (tool_calls, tool results,
|
||||
|
|
@ -4730,123 +4715,93 @@ class GatewayRunner:
|
|||
_APPROVAL_TIMEOUT_SECONDS = 300 # 5 minutes
|
||||
|
||||
async def _handle_approve_command(self, event: MessageEvent) -> Optional[str]:
|
||||
"""Handle /approve command — execute a pending dangerous command.
|
||||
"""Handle /approve command — unblock waiting agent thread(s).
|
||||
|
||||
After execution, re-invokes the agent with the command result so it
|
||||
can continue its multi-step task (fixes the "dead agent" bug where
|
||||
the agent loop exited on approval_required and never resumed).
|
||||
The agent thread(s) are blocked inside tools/approval.py waiting for
|
||||
the user to respond. This handler signals the event so the agent
|
||||
resumes and the terminal_tool executes the command inline — the same
|
||||
flow as the CLI's synchronous input() approval.
|
||||
|
||||
Supports multiple concurrent approvals (parallel subagents,
|
||||
execute_code). ``/approve`` resolves the oldest pending command;
|
||||
``/approve all`` resolves every pending command at once.
|
||||
|
||||
Usage:
|
||||
/approve — approve and execute the pending command
|
||||
/approve session — approve and remember for this session
|
||||
/approve always — approve this pattern permanently
|
||||
/approve — approve oldest pending command once
|
||||
/approve all — approve ALL pending commands at once
|
||||
/approve session — approve oldest + remember for session
|
||||
/approve all session — approve all + remember for session
|
||||
/approve always — approve oldest + remember permanently
|
||||
/approve all always — approve all + remember permanently
|
||||
"""
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
|
||||
if session_key not in self._pending_approvals:
|
||||
from tools.approval import (
|
||||
resolve_gateway_approval, has_blocking_approval,
|
||||
pending_approval_count,
|
||||
)
|
||||
|
||||
if not has_blocking_approval(session_key):
|
||||
if session_key in self._pending_approvals:
|
||||
self._pending_approvals.pop(session_key)
|
||||
return "⚠️ Approval expired (agent is no longer waiting). Ask the agent to try again."
|
||||
return "No pending command to approve."
|
||||
|
||||
import time as _time
|
||||
approval = self._pending_approvals[session_key]
|
||||
# Parse args: support "all", "all session", "all always", "session", "always"
|
||||
args = event.get_command_args().strip().lower().split()
|
||||
resolve_all = "all" in args
|
||||
remaining = [a for a in args if a != "all"]
|
||||
|
||||
# Check for timeout
|
||||
ts = approval.get("timestamp", 0)
|
||||
if _time.time() - ts > self._APPROVAL_TIMEOUT_SECONDS:
|
||||
self._pending_approvals.pop(session_key, None)
|
||||
return "⚠️ Approval expired (timed out after 5 minutes). Ask the agent to try again."
|
||||
|
||||
self._pending_approvals.pop(session_key)
|
||||
cmd = approval["command"]
|
||||
pattern_keys = approval.get("pattern_keys", [])
|
||||
if not pattern_keys:
|
||||
pk = approval.get("pattern_key", "")
|
||||
pattern_keys = [pk] if pk else []
|
||||
|
||||
# Determine approval scope from args
|
||||
args = event.get_command_args().strip().lower()
|
||||
from tools.approval import approve_session, approve_permanent
|
||||
|
||||
if args in ("always", "permanent", "permanently"):
|
||||
for pk in pattern_keys:
|
||||
approve_permanent(pk)
|
||||
if any(a in ("always", "permanent", "permanently") for a in remaining):
|
||||
choice = "always"
|
||||
scope_msg = " (pattern approved permanently)"
|
||||
elif args in ("session", "ses"):
|
||||
for pk in pattern_keys:
|
||||
approve_session(session_key, pk)
|
||||
elif any(a in ("session", "ses") for a in remaining):
|
||||
choice = "session"
|
||||
scope_msg = " (pattern approved for this session)"
|
||||
else:
|
||||
# One-time approval — just approve for session so the immediate
|
||||
# replay works, but don't advertise it as session-wide
|
||||
for pk in pattern_keys:
|
||||
approve_session(session_key, pk)
|
||||
choice = "once"
|
||||
scope_msg = ""
|
||||
|
||||
logger.info("User approved dangerous command via /approve: %s...%s", cmd[:60], scope_msg)
|
||||
from tools.terminal_tool import terminal_tool
|
||||
result = await asyncio.to_thread(terminal_tool, command=cmd, force=True)
|
||||
count = resolve_gateway_approval(session_key, choice, resolve_all=resolve_all)
|
||||
if not count:
|
||||
return "No pending command to approve."
|
||||
|
||||
# Send immediate feedback so the user sees the command output right away
|
||||
immediate_msg = f"✅ Command approved and executed{scope_msg}.\n\n```\n{result[:3500]}\n```"
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
try:
|
||||
await adapter.send(source.chat_id, immediate_msg)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send approval feedback: %s", e)
|
||||
|
||||
# Re-invoke the agent with the command result so it can continue its task.
|
||||
# The agent's conversation history (persisted in SQLite) already contains
|
||||
# the tool call that returned approval_required — the continuation message
|
||||
# provides the actual execution output so the agent can pick up where it
|
||||
# left off.
|
||||
continuation_text = (
|
||||
f"[System: The user approved the previously blocked command and it has been executed.\n"
|
||||
f"Command: {cmd}\n"
|
||||
f"<command_output>\n{result[:3500]}\n</command_output>\n\n"
|
||||
f"Continue with the task you were working on.]"
|
||||
)
|
||||
|
||||
synthetic_event = MessageEvent(
|
||||
text=continuation_text,
|
||||
source=source,
|
||||
message_id=f"approve-continuation-{uuid.uuid4().hex}",
|
||||
)
|
||||
|
||||
async def _continue_agent():
|
||||
try:
|
||||
response = await self._handle_message(synthetic_event)
|
||||
if response and adapter:
|
||||
await adapter.send(source.chat_id, response)
|
||||
except Exception as e:
|
||||
logger.error("Failed to continue agent after /approve: %s", e)
|
||||
if adapter:
|
||||
try:
|
||||
await adapter.send(
|
||||
source.chat_id,
|
||||
f"⚠️ Failed to resume agent after approval: {e}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_task = asyncio.create_task(_continue_agent())
|
||||
self._background_tasks.add(_task)
|
||||
_task.add_done_callback(self._background_tasks.discard)
|
||||
# Return None — we already sent the immediate feedback and the agent
|
||||
# continuation is running in the background.
|
||||
return None
|
||||
count_msg = f" ({count} commands)" if count > 1 else ""
|
||||
logger.info("User approved %d dangerous command(s) via /approve%s", count, scope_msg)
|
||||
return f"✅ Command{'s' if count > 1 else ''} approved{scope_msg}{count_msg}. The agent is resuming..."
|
||||
|
||||
async def _handle_deny_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /deny command — reject a pending dangerous command."""
|
||||
"""Handle /deny command — reject pending dangerous command(s).
|
||||
|
||||
Signals blocked agent thread(s) with a 'deny' result so they receive
|
||||
a definitive BLOCKED message, same as the CLI deny flow.
|
||||
|
||||
``/deny`` denies the oldest; ``/deny all`` denies everything.
|
||||
"""
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
|
||||
if session_key not in self._pending_approvals:
|
||||
from tools.approval import (
|
||||
resolve_gateway_approval, has_blocking_approval,
|
||||
)
|
||||
|
||||
if not has_blocking_approval(session_key):
|
||||
if session_key in self._pending_approvals:
|
||||
self._pending_approvals.pop(session_key)
|
||||
return "❌ Command denied (approval was stale)."
|
||||
return "No pending command to deny."
|
||||
|
||||
self._pending_approvals.pop(session_key)
|
||||
logger.info("User denied dangerous command via /deny")
|
||||
return "❌ Command denied."
|
||||
args = event.get_command_args().strip().lower()
|
||||
resolve_all = "all" in args
|
||||
|
||||
count = resolve_gateway_approval(session_key, "deny", resolve_all=resolve_all)
|
||||
if not count:
|
||||
return "No pending command to deny."
|
||||
|
||||
count_msg = f" ({count} commands)" if count > 1 else ""
|
||||
logger.info("User denied %d dangerous command(s) via /deny", count)
|
||||
return f"❌ Command{'s' if count > 1 else ''} denied{count_msg}."
|
||||
|
||||
async def _handle_update_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /update command — update Hermes Agent to the latest version.
|
||||
|
|
@ -5829,7 +5784,42 @@ class GatewayRunner:
|
|||
if _p:
|
||||
_history_media_paths.add(_p)
|
||||
|
||||
result = agent.run_conversation(message, conversation_history=agent_history, task_id=session_id)
|
||||
# Register per-session gateway approval callback so dangerous
|
||||
# command approval blocks the agent thread (mirrors CLI input()).
|
||||
# The callback bridges sync→async to send the approval request
|
||||
# to the user immediately.
|
||||
from tools.approval import register_gateway_notify, unregister_gateway_notify
|
||||
|
||||
def _approval_notify_sync(approval_data: dict) -> None:
|
||||
"""Send the approval request to the user from the agent thread."""
|
||||
cmd = approval_data.get("command", "")
|
||||
cmd_preview = cmd[:200] + "..." if len(cmd) > 200 else cmd
|
||||
desc = approval_data.get("description", "dangerous command")
|
||||
msg = (
|
||||
f"⚠️ **Dangerous command requires approval:**\n"
|
||||
f"```\n{cmd_preview}\n```\n"
|
||||
f"Reason: {desc}\n\n"
|
||||
f"Reply `/approve` to execute, `/approve session` to approve this pattern "
|
||||
f"for the session, `/approve always` to approve permanently, or `/deny` to cancel."
|
||||
)
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_status_adapter.send(
|
||||
_status_chat_id,
|
||||
msg,
|
||||
metadata=_status_thread_metadata,
|
||||
),
|
||||
_loop_for_step,
|
||||
).result(timeout=15)
|
||||
except Exception as _e:
|
||||
logger.error("Failed to send approval request: %s", _e)
|
||||
|
||||
_approval_session_key = session_key or ""
|
||||
register_gateway_notify(_approval_session_key, _approval_notify_sync)
|
||||
try:
|
||||
result = agent.run_conversation(message, conversation_history=agent_history, task_id=session_id)
|
||||
finally:
|
||||
unregister_gateway_notify(_approval_session_key)
|
||||
result_holder[0] = result
|
||||
|
||||
# Signal the stream consumer that the agent is done
|
||||
|
|
|
|||
|
|
@ -1,10 +1,16 @@
|
|||
"""Tests for /approve and /deny gateway commands.
|
||||
|
||||
Verifies that dangerous command approvals require explicit /approve or /deny
|
||||
slash commands, not bare "yes"/"no" text matching.
|
||||
Verifies that dangerous command approvals use the blocking gateway approval
|
||||
mechanism — the agent thread blocks until the user responds with /approve
|
||||
or /deny, mirroring the CLI's synchronous input() flow.
|
||||
|
||||
Supports multiple concurrent approvals (parallel subagents, execute_code)
|
||||
via a per-session queue.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
|
@ -61,14 +67,140 @@ def _make_runner():
|
|||
return runner
|
||||
|
||||
|
||||
def _make_pending_approval(command="sudo rm -rf /tmp/test", pattern_key="sudo"):
|
||||
return {
|
||||
"command": command,
|
||||
"pattern_key": pattern_key,
|
||||
"pattern_keys": [pattern_key],
|
||||
"description": "sudo command",
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
def _clear_approval_state():
|
||||
"""Reset all module-level approval state between tests."""
|
||||
from tools import approval as mod
|
||||
mod._gateway_queues.clear()
|
||||
mod._gateway_notify_cbs.clear()
|
||||
mod._session_approved.clear()
|
||||
mod._permanent_approved.clear()
|
||||
mod._pending.clear()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Blocking gateway approval infrastructure (tools/approval.py)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlockingGatewayApproval:
|
||||
"""Tests for the blocking approval mechanism in tools/approval.py."""
|
||||
|
||||
def setup_method(self):
|
||||
_clear_approval_state()
|
||||
|
||||
def test_register_and_resolve_unblocks_entry(self):
|
||||
"""resolve_gateway_approval signals the entry's event."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, unregister_gateway_notify,
|
||||
resolve_gateway_approval, has_blocking_approval,
|
||||
_ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-session"
|
||||
register_gateway_notify(session_key, lambda d: None)
|
||||
|
||||
# Simulate what check_all_command_guards does
|
||||
entry = _ApprovalEntry({"command": "rm -rf /"})
|
||||
_gateway_queues.setdefault(session_key, []).append(entry)
|
||||
|
||||
assert has_blocking_approval(session_key) is True
|
||||
|
||||
# Resolve from another thread
|
||||
def resolve():
|
||||
time.sleep(0.1)
|
||||
resolve_gateway_approval(session_key, "once")
|
||||
|
||||
t = threading.Thread(target=resolve)
|
||||
t.start()
|
||||
resolved = entry.event.wait(timeout=5)
|
||||
t.join()
|
||||
|
||||
assert resolved is True
|
||||
assert entry.result == "once"
|
||||
unregister_gateway_notify(session_key)
|
||||
|
||||
def test_resolve_returns_zero_when_no_pending(self):
|
||||
from tools.approval import resolve_gateway_approval
|
||||
assert resolve_gateway_approval("nonexistent", "once") == 0
|
||||
|
||||
def test_resolve_all_unblocks_multiple_entries(self):
|
||||
"""resolve_gateway_approval with resolve_all=True signals all entries."""
|
||||
from tools.approval import (
|
||||
resolve_gateway_approval, _ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-all"
|
||||
e1 = _ApprovalEntry({"command": "cmd1"})
|
||||
e2 = _ApprovalEntry({"command": "cmd2"})
|
||||
e3 = _ApprovalEntry({"command": "cmd3"})
|
||||
_gateway_queues[session_key] = [e1, e2, e3]
|
||||
|
||||
count = resolve_gateway_approval(session_key, "session", resolve_all=True)
|
||||
assert count == 3
|
||||
assert all(e.event.is_set() for e in [e1, e2, e3])
|
||||
assert all(e.result == "session" for e in [e1, e2, e3])
|
||||
|
||||
def test_resolve_single_pops_oldest_fifo(self):
|
||||
"""resolve_gateway_approval without resolve_all resolves oldest first."""
|
||||
from tools.approval import (
|
||||
resolve_gateway_approval, pending_approval_count,
|
||||
_ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-fifo"
|
||||
e1 = _ApprovalEntry({"command": "first"})
|
||||
e2 = _ApprovalEntry({"command": "second"})
|
||||
_gateway_queues[session_key] = [e1, e2]
|
||||
|
||||
count = resolve_gateway_approval(session_key, "once")
|
||||
assert count == 1
|
||||
assert e1.event.is_set()
|
||||
assert e1.result == "once"
|
||||
assert not e2.event.is_set()
|
||||
assert pending_approval_count(session_key) == 1
|
||||
|
||||
def test_unregister_signals_all_entries(self):
|
||||
"""unregister_gateway_notify signals all waiting entries to prevent hangs."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, unregister_gateway_notify,
|
||||
_ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-cleanup"
|
||||
register_gateway_notify(session_key, lambda d: None)
|
||||
|
||||
e1 = _ApprovalEntry({"command": "cmd1"})
|
||||
e2 = _ApprovalEntry({"command": "cmd2"})
|
||||
_gateway_queues[session_key] = [e1, e2]
|
||||
|
||||
unregister_gateway_notify(session_key)
|
||||
assert e1.event.is_set()
|
||||
assert e2.event.is_set()
|
||||
|
||||
def test_clear_session_signals_all_entries(self):
|
||||
"""clear_session should unblock all waiting approval threads."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, clear_session,
|
||||
_ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-clear"
|
||||
register_gateway_notify(session_key, lambda d: None)
|
||||
|
||||
e1 = _ApprovalEntry({"command": "cmd1"})
|
||||
e2 = _ApprovalEntry({"command": "cmd2"})
|
||||
_gateway_queues[session_key] = [e1, e2]
|
||||
|
||||
clear_session(session_key)
|
||||
assert e1.event.is_set()
|
||||
assert e2.event.is_set()
|
||||
|
||||
def test_pending_approval_count(self):
|
||||
from tools.approval import (
|
||||
pending_approval_count, _ApprovalEntry, _gateway_queues,
|
||||
)
|
||||
session_key = "test-count"
|
||||
assert pending_approval_count(session_key) == 0
|
||||
_gateway_queues[session_key] = [
|
||||
_ApprovalEntry({"command": "a"}),
|
||||
_ApprovalEntry({"command": "b"}),
|
||||
]
|
||||
assert pending_approval_count(session_key) == 2
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -78,146 +210,81 @@ def _make_pending_approval(command="sudo rm -rf /tmp/test", pattern_key="sudo"):
|
|||
|
||||
class TestApproveCommand:
|
||||
|
||||
def setup_method(self):
|
||||
_clear_approval_state()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_executes_pending_command(self):
|
||||
"""Basic /approve executes the pending command and sends feedback."""
|
||||
async def test_approve_resolves_blocking_approval(self):
|
||||
"""Basic /approve signals the oldest blocked agent thread."""
|
||||
from tools.approval import _ApprovalEntry, _gateway_queues
|
||||
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve")
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done") as mock_term,
|
||||
patch.object(runner, "_handle_message", new_callable=AsyncMock, return_value="agent continued"),
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
# Yield to let the background continuation task run.
|
||||
# This works because mocks return immediately (no real await points).
|
||||
await asyncio.sleep(0)
|
||||
entry = _ApprovalEntry({"command": "test"})
|
||||
_gateway_queues[session_key] = [entry]
|
||||
|
||||
# Returns None because feedback is sent directly via adapter
|
||||
assert result is None
|
||||
mock_term.assert_called_once_with(command="sudo rm -rf /tmp/test", force=True)
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
# Immediate feedback sent via adapter
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
sent_text = adapter.send.call_args_list[0][0][1]
|
||||
assert "Command approved and executed" in sent_text
|
||||
result = await runner._handle_approve_command(_make_event("/approve"))
|
||||
assert "approved" in result.lower()
|
||||
assert "resuming" in result.lower()
|
||||
assert entry.event.is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_session_remembers_pattern(self):
|
||||
"""/approve session approves the pattern for the session."""
|
||||
async def test_approve_all_resolves_multiple(self):
|
||||
"""/approve all resolves all pending approvals."""
|
||||
from tools.approval import _ApprovalEntry, _gateway_queues
|
||||
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve session")
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done"),
|
||||
patch("tools.approval.approve_session") as mock_session,
|
||||
patch.object(runner, "_handle_message", new_callable=AsyncMock, return_value=None),
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
# Yield to let the background continuation task run.
|
||||
# This works because mocks return immediately (no real await points).
|
||||
await asyncio.sleep(0)
|
||||
e1 = _ApprovalEntry({"command": "cmd1"})
|
||||
e2 = _ApprovalEntry({"command": "cmd2"})
|
||||
_gateway_queues[session_key] = [e1, e2]
|
||||
|
||||
assert result is None
|
||||
mock_session.assert_called_once_with(session_key, "sudo")
|
||||
|
||||
# Verify scope message in adapter feedback
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
sent_text = adapter.send.call_args_list[0][0][1]
|
||||
assert "pattern approved for this session" in sent_text
|
||||
result = await runner._handle_approve_command(_make_event("/approve all"))
|
||||
assert "2 commands" in result
|
||||
assert e1.event.is_set()
|
||||
assert e2.event.is_set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_always_approves_permanently(self):
|
||||
"""/approve always approves the pattern permanently."""
|
||||
async def test_approve_all_session(self):
|
||||
"""/approve all session resolves all with session scope."""
|
||||
from tools.approval import _ApprovalEntry, _gateway_queues
|
||||
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve always")
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="done"),
|
||||
patch("tools.approval.approve_permanent") as mock_perm,
|
||||
patch.object(runner, "_handle_message", new_callable=AsyncMock, return_value=None),
|
||||
):
|
||||
result = await runner._handle_approve_command(event)
|
||||
# Yield to let the background continuation task run.
|
||||
# This works because mocks return immediately (no real await points).
|
||||
await asyncio.sleep(0)
|
||||
e1 = _ApprovalEntry({"command": "cmd1"})
|
||||
e2 = _ApprovalEntry({"command": "cmd2"})
|
||||
_gateway_queues[session_key] = [e1, e2]
|
||||
|
||||
assert result is None
|
||||
mock_perm.assert_called_once_with("sudo")
|
||||
|
||||
# Verify scope message in adapter feedback
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
sent_text = adapter.send.call_args_list[0][0][1]
|
||||
assert "pattern approved permanently" in sent_text
|
||||
result = await runner._handle_approve_command(_make_event("/approve all session"))
|
||||
assert "session" in result.lower()
|
||||
assert e1.result == "session"
|
||||
assert e2.result == "session"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_no_pending(self):
|
||||
"""/approve with no pending approval returns helpful message."""
|
||||
runner = _make_runner()
|
||||
event = _make_event("/approve")
|
||||
result = await runner._handle_approve_command(event)
|
||||
result = await runner._handle_approve_command(_make_event("/approve"))
|
||||
assert "No pending command" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_expired(self):
|
||||
"""/approve on a timed-out approval rejects it."""
|
||||
async def test_approve_stale_old_style_pending(self):
|
||||
"""Old-style _pending_approvals without blocking event reports expired."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
approval = _make_pending_approval()
|
||||
approval["timestamp"] = time.time() - 600 # 10 minutes ago
|
||||
runner._pending_approvals[session_key] = approval
|
||||
runner._pending_approvals[session_key] = {"command": "test"}
|
||||
|
||||
event = _make_event("/approve")
|
||||
result = await runner._handle_approve_command(event)
|
||||
|
||||
assert "expired" in result
|
||||
result = await runner._handle_approve_command(_make_event("/approve"))
|
||||
assert "expired" in result.lower() or "no longer waiting" in result.lower()
|
||||
assert session_key not in runner._pending_approvals
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_reinvokes_agent_with_result(self):
|
||||
"""After executing, /approve re-invokes the agent with command output."""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/approve")
|
||||
mock_handle = AsyncMock(return_value="I continued the task.")
|
||||
|
||||
with (
|
||||
patch("tools.terminal_tool.terminal_tool", return_value="file deleted"),
|
||||
patch.object(runner, "_handle_message", mock_handle),
|
||||
):
|
||||
await runner._handle_approve_command(event)
|
||||
# Yield to let the background continuation task run.
|
||||
# This works because mocks return immediately (no real await points).
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Agent was re-invoked via _handle_message with a synthetic event
|
||||
mock_handle.assert_called_once()
|
||||
synthetic_event = mock_handle.call_args[0][0]
|
||||
assert "approved" in synthetic_event.text.lower()
|
||||
assert "file deleted" in synthetic_event.text
|
||||
assert "sudo rm -rf /tmp/test" in synthetic_event.text
|
||||
|
||||
# The continuation response was sent to the user
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
# First call: immediate feedback, second call: agent continuation
|
||||
assert adapter.send.call_count == 2
|
||||
continuation_response = adapter.send.call_args_list[1][0][1]
|
||||
assert continuation_response == "I continued the task."
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /deny command
|
||||
|
|
@ -226,26 +293,48 @@ class TestApproveCommand:
|
|||
|
||||
class TestDenyCommand:
|
||||
|
||||
def setup_method(self):
|
||||
_clear_approval_state()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_clears_pending(self):
|
||||
"""/deny clears the pending approval."""
|
||||
async def test_deny_resolves_blocking_approval(self):
|
||||
"""/deny signals the oldest blocked agent thread with 'deny'."""
|
||||
from tools.approval import _ApprovalEntry, _gateway_queues
|
||||
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
event = _make_event("/deny")
|
||||
result = await runner._handle_deny_command(event)
|
||||
entry = _ApprovalEntry({"command": "test"})
|
||||
_gateway_queues[session_key] = [entry]
|
||||
|
||||
assert "❌ Command denied" in result
|
||||
assert session_key not in runner._pending_approvals
|
||||
result = await runner._handle_deny_command(_make_event("/deny"))
|
||||
assert "denied" in result.lower()
|
||||
assert entry.event.is_set()
|
||||
assert entry.result == "deny"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_all_resolves_all(self):
|
||||
"""/deny all denies all pending approvals."""
|
||||
from tools.approval import _ApprovalEntry, _gateway_queues
|
||||
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
|
||||
e1 = _ApprovalEntry({"command": "cmd1"})
|
||||
e2 = _ApprovalEntry({"command": "cmd2"})
|
||||
_gateway_queues[session_key] = [e1, e2]
|
||||
|
||||
result = await runner._handle_deny_command(_make_event("/deny all"))
|
||||
assert "2 commands" in result
|
||||
assert all(e.result == "deny" for e in [e1, e2])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_no_pending(self):
|
||||
"""/deny with no pending approval returns helpful message."""
|
||||
runner = _make_runner()
|
||||
event = _make_event("/deny")
|
||||
result = await runner._handle_deny_command(event)
|
||||
result = await runner._handle_deny_command(_make_event("/deny"))
|
||||
assert "No pending command" in result
|
||||
|
||||
|
||||
|
|
@ -256,51 +345,267 @@ class TestDenyCommand:
|
|||
|
||||
class TestBareTextNoLongerApproves:
|
||||
|
||||
def setup_method(self):
|
||||
_clear_approval_state()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_yes_does_not_execute_pending_command(self):
|
||||
"""Saying 'yes' in normal conversation must not execute a pending command.
|
||||
"""Saying 'yes' must not trigger approval. Only /approve works."""
|
||||
from tools.approval import _ApprovalEntry, _gateway_queues
|
||||
|
||||
This is the core bug from issue #1888: bare text matching against
|
||||
'yes'/'no' could intercept unrelated user messages.
|
||||
"""
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._pending_approvals[session_key] = _make_pending_approval()
|
||||
|
||||
# Simulate the user saying "yes" as a normal message.
|
||||
# The old code would have executed the pending command.
|
||||
# Now it should fall through to normal processing (agent handles it).
|
||||
event = _make_event("yes")
|
||||
entry = _ApprovalEntry({"command": "test"})
|
||||
_gateway_queues[session_key] = [entry]
|
||||
|
||||
# The approval should still be pending — "yes" is not /approve
|
||||
# We can't easily run _handle_message end-to-end, but we CAN verify
|
||||
# the old text-matching block no longer exists by confirming the
|
||||
# approval is untouched after the command dispatch section.
|
||||
# The key assertion is that _pending_approvals is NOT consumed.
|
||||
assert session_key in runner._pending_approvals
|
||||
# "yes" is not /approve — entry should still be pending
|
||||
assert not entry.event.is_set()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Approval hint appended to response
|
||||
# End-to-end blocking flow
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApprovalHint:
|
||||
class TestBlockingApprovalE2E:
|
||||
"""Test the full blocking flow: agent thread blocks → user approves → agent resumes."""
|
||||
|
||||
def test_approval_hint_appended_to_response(self):
|
||||
"""When a pending approval is collected, structured instructions
|
||||
should be appended to the agent response."""
|
||||
# This tests the approval collection logic at the end of _handle_message.
|
||||
# We verify the hint format directly.
|
||||
cmd = "sudo rm -rf /tmp/dangerous"
|
||||
cmd_preview = cmd
|
||||
hint = (
|
||||
f"\n\n⚠️ **Dangerous command requires approval:**\n"
|
||||
f"```\n{cmd_preview}\n```\n"
|
||||
f"Reply `/approve` to execute, `/approve session` to approve this pattern "
|
||||
f"for the session, or `/deny` to cancel."
|
||||
def setup_method(self):
|
||||
_clear_approval_state()
|
||||
|
||||
def test_blocking_approval_approve_once(self):
|
||||
"""check_all_command_guards blocks until resolve_gateway_approval is called."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, unregister_gateway_notify,
|
||||
resolve_gateway_approval, check_all_command_guards,
|
||||
)
|
||||
assert "/approve" in hint
|
||||
assert "/deny" in hint
|
||||
assert cmd in hint
|
||||
|
||||
session_key = "e2e-test"
|
||||
notified = []
|
||||
|
||||
register_gateway_notify(session_key, lambda d: notified.append(d))
|
||||
|
||||
result_holder = [None]
|
||||
|
||||
def agent_thread():
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
result_holder[0] = check_all_command_guards(
|
||||
"rm -rf /important", "local"
|
||||
)
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
|
||||
t = threading.Thread(target=agent_thread)
|
||||
t.start()
|
||||
|
||||
for _ in range(50):
|
||||
if notified:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
|
||||
assert len(notified) == 1
|
||||
assert "rm -rf /important" in notified[0]["command"]
|
||||
|
||||
resolve_gateway_approval(session_key, "once")
|
||||
t.join(timeout=5)
|
||||
|
||||
assert result_holder[0] is not None
|
||||
assert result_holder[0]["approved"] is True
|
||||
unregister_gateway_notify(session_key)
|
||||
|
||||
def test_blocking_approval_deny(self):
|
||||
"""check_all_command_guards returns BLOCKED when denied."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, unregister_gateway_notify,
|
||||
resolve_gateway_approval, check_all_command_guards,
|
||||
)
|
||||
|
||||
session_key = "e2e-deny"
|
||||
notified = []
|
||||
register_gateway_notify(session_key, lambda d: notified.append(d))
|
||||
|
||||
result_holder = [None]
|
||||
|
||||
def agent_thread():
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
result_holder[0] = check_all_command_guards(
|
||||
"rm -rf /important", "local"
|
||||
)
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
|
||||
t = threading.Thread(target=agent_thread)
|
||||
t.start()
|
||||
for _ in range(50):
|
||||
if notified:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
|
||||
resolve_gateway_approval(session_key, "deny")
|
||||
t.join(timeout=5)
|
||||
|
||||
assert result_holder[0]["approved"] is False
|
||||
assert "BLOCKED" in result_holder[0]["message"]
|
||||
unregister_gateway_notify(session_key)
|
||||
|
||||
def test_blocking_approval_timeout(self):
|
||||
"""check_all_command_guards returns BLOCKED on timeout."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, unregister_gateway_notify,
|
||||
check_all_command_guards,
|
||||
)
|
||||
|
||||
session_key = "e2e-timeout"
|
||||
register_gateway_notify(session_key, lambda d: None)
|
||||
|
||||
result_holder = [None]
|
||||
|
||||
def agent_thread():
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
with patch("tools.approval._get_approval_config",
|
||||
return_value={"gateway_timeout": 1}):
|
||||
result_holder[0] = check_all_command_guards(
|
||||
"rm -rf /important", "local"
|
||||
)
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
|
||||
t = threading.Thread(target=agent_thread)
|
||||
t.start()
|
||||
t.join(timeout=10)
|
||||
|
||||
assert result_holder[0]["approved"] is False
|
||||
assert "timed out" in result_holder[0]["message"]
|
||||
unregister_gateway_notify(session_key)
|
||||
|
||||
def test_parallel_subagent_approvals(self):
|
||||
"""Multiple threads can block concurrently and be resolved independently."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, unregister_gateway_notify,
|
||||
resolve_gateway_approval, check_all_command_guards,
|
||||
pending_approval_count,
|
||||
)
|
||||
|
||||
session_key = "e2e-parallel"
|
||||
notified = []
|
||||
register_gateway_notify(session_key, lambda d: notified.append(d))
|
||||
|
||||
results = [None, None, None]
|
||||
|
||||
def make_agent(idx, cmd):
|
||||
def run():
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
results[idx] = check_all_command_guards(cmd, "local")
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
return run
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=make_agent(0, "rm -rf /a")),
|
||||
threading.Thread(target=make_agent(1, "rm -rf /b")),
|
||||
threading.Thread(target=make_agent(2, "rm -rf /c")),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
# Wait for all 3 to block
|
||||
for _ in range(100):
|
||||
if len(notified) >= 3:
|
||||
break
|
||||
time.sleep(0.05)
|
||||
|
||||
assert len(notified) == 3
|
||||
assert pending_approval_count(session_key) == 3
|
||||
|
||||
# Approve all at once
|
||||
count = resolve_gateway_approval(session_key, "session", resolve_all=True)
|
||||
assert count == 3
|
||||
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
assert all(r is not None for r in results)
|
||||
assert all(r["approved"] is True for r in results)
|
||||
unregister_gateway_notify(session_key)
|
||||
|
||||
def test_parallel_mixed_approve_deny(self):
|
||||
"""Approve some, deny others in a parallel batch."""
|
||||
from tools.approval import (
|
||||
register_gateway_notify, unregister_gateway_notify,
|
||||
resolve_gateway_approval, check_all_command_guards,
|
||||
)
|
||||
|
||||
session_key = "e2e-mixed"
|
||||
register_gateway_notify(session_key, lambda d: None)
|
||||
|
||||
results = [None, None]
|
||||
|
||||
def make_agent(idx, cmd):
|
||||
def run():
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key
|
||||
try:
|
||||
results[idx] = check_all_command_guards(cmd, "local")
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
return run
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=make_agent(0, "rm -rf /x")),
|
||||
threading.Thread(target=make_agent(1, "rm -rf /y")),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
time.sleep(0.3)
|
||||
|
||||
# Approve first, deny second
|
||||
resolve_gateway_approval(session_key, "once") # oldest
|
||||
resolve_gateway_approval(session_key, "deny") # next
|
||||
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
assert results[0]["approved"] is True
|
||||
assert results[1]["approved"] is False
|
||||
unregister_gateway_notify(session_key)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Fallback: no gateway callback (cron/batch mode)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFallbackNoCallback:
|
||||
|
||||
def setup_method(self):
|
||||
_clear_approval_state()
|
||||
|
||||
def test_no_callback_returns_approval_required(self):
|
||||
"""Without a registered callback, the old approval_required path is used."""
|
||||
from tools.approval import check_all_command_guards, _pending
|
||||
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
os.environ["HERMES_SESSION_KEY"] = "no-callback-test"
|
||||
try:
|
||||
result = check_all_command_guards("rm -rf /important", "local")
|
||||
finally:
|
||||
os.environ.pop("HERMES_EXEC_ASK", None)
|
||||
os.environ.pop("HERMES_SESSION_KEY", None)
|
||||
|
||||
assert result["approved"] is False
|
||||
assert result.get("status") == "approval_required"
|
||||
|
|
|
|||
|
|
@ -146,6 +146,94 @@ _pending: dict[str, dict] = {}
|
|||
_session_approved: dict[str, set] = {}
|
||||
_permanent_approved: set = set()
|
||||
|
||||
# =========================================================================
|
||||
# Blocking gateway approval (mirrors CLI's synchronous input() flow)
|
||||
# =========================================================================
|
||||
# Per-session QUEUE of pending approvals. Multiple threads (parallel
|
||||
# subagents, execute_code RPC handlers) can block concurrently — each gets
|
||||
# its own threading.Event. /approve resolves the oldest, /approve all
|
||||
# resolves every pending approval in the session.
|
||||
|
||||
|
||||
class _ApprovalEntry:
|
||||
"""One pending dangerous-command approval inside a gateway session."""
|
||||
__slots__ = ("event", "data", "result")
|
||||
|
||||
def __init__(self, data: dict):
|
||||
self.event = threading.Event()
|
||||
self.data = data # command, description, pattern_keys, …
|
||||
self.result: Optional[str] = None # "once"|"session"|"always"|"deny"
|
||||
|
||||
|
||||
_gateway_queues: dict[str, list] = {} # session_key → [_ApprovalEntry, …]
|
||||
_gateway_notify_cbs: dict[str, object] = {} # session_key → callable(approval_data)
|
||||
|
||||
|
||||
def register_gateway_notify(session_key: str, cb) -> None:
|
||||
"""Register a per-session callback for sending approval requests to the user.
|
||||
|
||||
The callback signature is ``cb(approval_data: dict) -> None`` where
|
||||
*approval_data* contains ``command``, ``description``, and
|
||||
``pattern_keys``. The callback bridges sync→async (runs in the agent
|
||||
thread, must schedule the actual send on the event loop).
|
||||
"""
|
||||
with _lock:
|
||||
_gateway_notify_cbs[session_key] = cb
|
||||
|
||||
|
||||
def unregister_gateway_notify(session_key: str) -> None:
|
||||
"""Unregister the per-session gateway approval callback.
|
||||
|
||||
Signals ALL blocked threads for this session so they don't hang forever
|
||||
(e.g. when the agent run finishes or is interrupted).
|
||||
"""
|
||||
with _lock:
|
||||
_gateway_notify_cbs.pop(session_key, None)
|
||||
entries = _gateway_queues.pop(session_key, [])
|
||||
for entry in entries:
|
||||
entry.event.set()
|
||||
|
||||
|
||||
def resolve_gateway_approval(session_key: str, choice: str,
|
||||
resolve_all: bool = False) -> int:
|
||||
"""Called by the gateway's /approve or /deny handler to unblock
|
||||
waiting agent thread(s).
|
||||
|
||||
When *resolve_all* is True every pending approval in the session is
|
||||
resolved at once (``/approve all``). Otherwise only the oldest one
|
||||
is resolved (FIFO).
|
||||
|
||||
Returns the number of approvals resolved (0 means nothing was pending).
|
||||
"""
|
||||
with _lock:
|
||||
queue = _gateway_queues.get(session_key)
|
||||
if not queue:
|
||||
return 0
|
||||
if resolve_all:
|
||||
targets = list(queue)
|
||||
queue.clear()
|
||||
else:
|
||||
targets = [queue.pop(0)]
|
||||
if not queue:
|
||||
_gateway_queues.pop(session_key, None)
|
||||
|
||||
for entry in targets:
|
||||
entry.result = choice
|
||||
entry.event.set()
|
||||
return len(targets)
|
||||
|
||||
|
||||
def has_blocking_approval(session_key: str) -> bool:
|
||||
"""Check if a session has one or more blocking gateway approvals waiting."""
|
||||
with _lock:
|
||||
return bool(_gateway_queues.get(session_key))
|
||||
|
||||
|
||||
def pending_approval_count(session_key: str) -> int:
|
||||
"""Return the number of pending blocking approvals for a session."""
|
||||
with _lock:
|
||||
return len(_gateway_queues.get(session_key, []))
|
||||
|
||||
|
||||
def submit_pending(session_key: str, approval: dict):
|
||||
"""Store a pending approval request for a session."""
|
||||
|
|
@ -202,6 +290,11 @@ def clear_session(session_key: str):
|
|||
with _lock:
|
||||
_session_approved.pop(session_key, None)
|
||||
_pending.pop(session_key, None)
|
||||
_gateway_notify_cbs.pop(session_key, None)
|
||||
# Signal ALL blocked threads so they don't hang forever
|
||||
entries = _gateway_queues.pop(session_key, [])
|
||||
for entry in entries:
|
||||
entry.event.set()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
|
|
@ -622,13 +715,90 @@ def check_all_command_guards(command: str, env_type: str,
|
|||
all_keys = [key for key, _, _ in warnings]
|
||||
has_tirith = any(is_t for _, _, is_t in warnings)
|
||||
|
||||
# Gateway/async: single approval_required with combined description
|
||||
# Store all pattern keys so gateway replay approves all of them
|
||||
# Gateway/async approval — block the agent thread until the user
|
||||
# responds with /approve or /deny, mirroring the CLI's synchronous
|
||||
# input() flow. The agent never sees "approval_required"; it either
|
||||
# gets the command output (approved) or a definitive "BLOCKED" message.
|
||||
if is_gateway or is_ask:
|
||||
notify_cb = None
|
||||
with _lock:
|
||||
notify_cb = _gateway_notify_cbs.get(session_key)
|
||||
|
||||
if notify_cb is not None:
|
||||
# --- Blocking gateway approval (queue-based) ---
|
||||
# Each call gets its own _ApprovalEntry so parallel subagents
|
||||
# and execute_code threads can block concurrently.
|
||||
approval_data = {
|
||||
"command": command,
|
||||
"pattern_key": primary_key,
|
||||
"pattern_keys": all_keys,
|
||||
"description": combined_desc,
|
||||
}
|
||||
entry = _ApprovalEntry(approval_data)
|
||||
with _lock:
|
||||
_gateway_queues.setdefault(session_key, []).append(entry)
|
||||
|
||||
# Notify the user (bridges sync agent thread → async gateway)
|
||||
try:
|
||||
notify_cb(approval_data)
|
||||
except Exception as exc:
|
||||
logger.warning("Gateway approval notify failed: %s", exc)
|
||||
with _lock:
|
||||
queue = _gateway_queues.get(session_key, [])
|
||||
if entry in queue:
|
||||
queue.remove(entry)
|
||||
if not queue:
|
||||
_gateway_queues.pop(session_key, None)
|
||||
return {
|
||||
"approved": False,
|
||||
"message": "BLOCKED: Failed to send approval request to user. Do NOT retry.",
|
||||
"pattern_key": primary_key,
|
||||
"description": combined_desc,
|
||||
}
|
||||
|
||||
# Block until the user responds or timeout (default 5 min)
|
||||
timeout = _get_approval_config().get("gateway_timeout", 300)
|
||||
try:
|
||||
timeout = int(timeout)
|
||||
except (ValueError, TypeError):
|
||||
timeout = 300
|
||||
resolved = entry.event.wait(timeout=timeout)
|
||||
|
||||
# Clean up this entry from the queue
|
||||
with _lock:
|
||||
queue = _gateway_queues.get(session_key, [])
|
||||
if entry in queue:
|
||||
queue.remove(entry)
|
||||
if not queue:
|
||||
_gateway_queues.pop(session_key, None)
|
||||
|
||||
choice = entry.result
|
||||
if not resolved or choice is None or choice == "deny":
|
||||
reason = "timed out" if not resolved else "denied by user"
|
||||
return {
|
||||
"approved": False,
|
||||
"message": f"BLOCKED: Command {reason}. Do NOT retry this command.",
|
||||
"pattern_key": primary_key,
|
||||
"description": combined_desc,
|
||||
}
|
||||
|
||||
# User approved — persist based on scope (same logic as CLI)
|
||||
for key, _, is_tirith in warnings:
|
||||
if choice in ("once", "session") or (choice == "always" and is_tirith):
|
||||
approve_session(session_key, key)
|
||||
elif choice == "always":
|
||||
approve_session(session_key, key)
|
||||
approve_permanent(key)
|
||||
save_permanent_allowlist(_permanent_approved)
|
||||
|
||||
return {"approved": True, "message": None}
|
||||
|
||||
# Fallback: no gateway callback registered (e.g. cron, batch).
|
||||
# Return approval_required for backward compat.
|
||||
submit_pending(session_key, {
|
||||
"command": command,
|
||||
"pattern_key": primary_key, # backward compat
|
||||
"pattern_keys": all_keys, # all keys for replay
|
||||
"pattern_key": primary_key,
|
||||
"pattern_keys": all_keys,
|
||||
"description": combined_desc,
|
||||
})
|
||||
return {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue