fix: make gateway approval block agent thread like CLI does (#4557)

The gateway's dangerous command approval system was fundamentally broken:
the agent loop continued running after a command was flagged, and the
approval request only reached the user after the agent finished its
entire conversation loop. By then the context was lost.

This change makes the gateway approval mirror the CLI's synchronous
behavior. When a dangerous command is detected:

1. The agent thread blocks on a threading.Event
2. The approval request is sent to the user immediately
3. The user responds with /approve or /deny
4. The event is signaled and the agent resumes with the real result

The agent never sees 'approval_required' as a tool result. It either
gets the command output (approved) or a definitive BLOCKED message
(denied/timed out) — same as CLI mode.

Queue-based design supports multiple concurrent approvals (parallel
subagents via delegate_task, execute_code RPC handlers). Each approval
gets its own _ApprovalEntry with its own threading.Event. /approve
resolves the oldest (FIFO); /approve all resolves all at once.

Changes:
- tools/approval.py: Queue-based per-session blocking gateway approval
  (register/unregister callbacks, resolve with FIFO or all-at-once)
- gateway/run.py: Register approval callback in run_sync(), remove
  post-loop pop_pending hack, /approve and /deny support 'all' flag
- tests: 21 tests including parallel subagent E2E scenarios
This commit is contained in:
Teknium 2026-04-02 01:47:19 -07:00 committed by GitHub
parent 64584a931f
commit 624ad582a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 742 additions and 277 deletions

View file

@ -146,6 +146,94 @@ _pending: dict[str, dict] = {}
_session_approved: dict[str, set] = {}
_permanent_approved: set = set()
# =========================================================================
# Blocking gateway approval (mirrors CLI's synchronous input() flow)
# =========================================================================
# Per-session QUEUE of pending approvals. Multiple threads (parallel
# subagents, execute_code RPC handlers) can block concurrently — each gets
# its own threading.Event. /approve resolves the oldest, /approve all
# resolves every pending approval in the session.
class _ApprovalEntry:
"""One pending dangerous-command approval inside a gateway session."""
__slots__ = ("event", "data", "result")
def __init__(self, data: dict):
self.event = threading.Event()
self.data = data # command, description, pattern_keys, …
self.result: Optional[str] = None # "once"|"session"|"always"|"deny"
_gateway_queues: dict[str, list] = {} # session_key → [_ApprovalEntry, …]
_gateway_notify_cbs: dict[str, object] = {} # session_key → callable(approval_data)
def register_gateway_notify(session_key: str, cb) -> None:
"""Register a per-session callback for sending approval requests to the user.
The callback signature is ``cb(approval_data: dict) -> None`` where
*approval_data* contains ``command``, ``description``, and
``pattern_keys``. The callback bridges syncasync (runs in the agent
thread, must schedule the actual send on the event loop).
"""
with _lock:
_gateway_notify_cbs[session_key] = cb
def unregister_gateway_notify(session_key: str) -> None:
"""Unregister the per-session gateway approval callback.
Signals ALL blocked threads for this session so they don't hang forever
(e.g. when the agent run finishes or is interrupted).
"""
with _lock:
_gateway_notify_cbs.pop(session_key, None)
entries = _gateway_queues.pop(session_key, [])
for entry in entries:
entry.event.set()
def resolve_gateway_approval(session_key: str, choice: str,
resolve_all: bool = False) -> int:
"""Called by the gateway's /approve or /deny handler to unblock
waiting agent thread(s).
When *resolve_all* is True every pending approval in the session is
resolved at once (``/approve all``). Otherwise only the oldest one
is resolved (FIFO).
Returns the number of approvals resolved (0 means nothing was pending).
"""
with _lock:
queue = _gateway_queues.get(session_key)
if not queue:
return 0
if resolve_all:
targets = list(queue)
queue.clear()
else:
targets = [queue.pop(0)]
if not queue:
_gateway_queues.pop(session_key, None)
for entry in targets:
entry.result = choice
entry.event.set()
return len(targets)
def has_blocking_approval(session_key: str) -> bool:
"""Check if a session has one or more blocking gateway approvals waiting."""
with _lock:
return bool(_gateway_queues.get(session_key))
def pending_approval_count(session_key: str) -> int:
"""Return the number of pending blocking approvals for a session."""
with _lock:
return len(_gateway_queues.get(session_key, []))
def submit_pending(session_key: str, approval: dict):
"""Store a pending approval request for a session."""
@ -202,6 +290,11 @@ def clear_session(session_key: str):
with _lock:
_session_approved.pop(session_key, None)
_pending.pop(session_key, None)
_gateway_notify_cbs.pop(session_key, None)
# Signal ALL blocked threads so they don't hang forever
entries = _gateway_queues.pop(session_key, [])
for entry in entries:
entry.event.set()
# =========================================================================
@ -622,13 +715,90 @@ def check_all_command_guards(command: str, env_type: str,
all_keys = [key for key, _, _ in warnings]
has_tirith = any(is_t for _, _, is_t in warnings)
# Gateway/async: single approval_required with combined description
# Store all pattern keys so gateway replay approves all of them
# Gateway/async approval — block the agent thread until the user
# responds with /approve or /deny, mirroring the CLI's synchronous
# input() flow. The agent never sees "approval_required"; it either
# gets the command output (approved) or a definitive "BLOCKED" message.
if is_gateway or is_ask:
notify_cb = None
with _lock:
notify_cb = _gateway_notify_cbs.get(session_key)
if notify_cb is not None:
# --- Blocking gateway approval (queue-based) ---
# Each call gets its own _ApprovalEntry so parallel subagents
# and execute_code threads can block concurrently.
approval_data = {
"command": command,
"pattern_key": primary_key,
"pattern_keys": all_keys,
"description": combined_desc,
}
entry = _ApprovalEntry(approval_data)
with _lock:
_gateway_queues.setdefault(session_key, []).append(entry)
# Notify the user (bridges sync agent thread → async gateway)
try:
notify_cb(approval_data)
except Exception as exc:
logger.warning("Gateway approval notify failed: %s", exc)
with _lock:
queue = _gateway_queues.get(session_key, [])
if entry in queue:
queue.remove(entry)
if not queue:
_gateway_queues.pop(session_key, None)
return {
"approved": False,
"message": "BLOCKED: Failed to send approval request to user. Do NOT retry.",
"pattern_key": primary_key,
"description": combined_desc,
}
# Block until the user responds or timeout (default 5 min)
timeout = _get_approval_config().get("gateway_timeout", 300)
try:
timeout = int(timeout)
except (ValueError, TypeError):
timeout = 300
resolved = entry.event.wait(timeout=timeout)
# Clean up this entry from the queue
with _lock:
queue = _gateway_queues.get(session_key, [])
if entry in queue:
queue.remove(entry)
if not queue:
_gateway_queues.pop(session_key, None)
choice = entry.result
if not resolved or choice is None or choice == "deny":
reason = "timed out" if not resolved else "denied by user"
return {
"approved": False,
"message": f"BLOCKED: Command {reason}. Do NOT retry this command.",
"pattern_key": primary_key,
"description": combined_desc,
}
# User approved — persist based on scope (same logic as CLI)
for key, _, is_tirith in warnings:
if choice in ("once", "session") or (choice == "always" and is_tirith):
approve_session(session_key, key)
elif choice == "always":
approve_session(session_key, key)
approve_permanent(key)
save_permanent_allowlist(_permanent_approved)
return {"approved": True, "message": None}
# Fallback: no gateway callback registered (e.g. cron, batch).
# Return approval_required for backward compat.
submit_pending(session_key, {
"command": command,
"pattern_key": primary_key, # backward compat
"pattern_keys": all_keys, # all keys for replay
"pattern_key": primary_key,
"pattern_keys": all_keys,
"description": combined_desc,
})
return {