fix(gateway): isolate approval session key per turn

This commit is contained in:
Tranquil-Flow 2026-04-03 01:09:45 +00:00 committed by Teknium
parent 5359921199
commit 3bfb39a25f
4 changed files with 136 additions and 5 deletions

View file

@ -390,6 +390,9 @@ class TestBlockingApprovalE2E:
result_holder = [None]
def agent_thread():
from tools.approval import reset_current_session_key, set_current_session_key
token = set_current_session_key(session_key)
os.environ["HERMES_EXEC_ASK"] = "1"
os.environ["HERMES_SESSION_KEY"] = session_key
try:
@ -399,6 +402,7 @@ class TestBlockingApprovalE2E:
finally:
os.environ.pop("HERMES_EXEC_ASK", None)
os.environ.pop("HERMES_SESSION_KEY", None)
reset_current_session_key(token)
t = threading.Thread(target=agent_thread)
t.start()
@ -432,6 +436,9 @@ class TestBlockingApprovalE2E:
result_holder = [None]
def agent_thread():
from tools.approval import reset_current_session_key, set_current_session_key
token = set_current_session_key(session_key)
os.environ["HERMES_EXEC_ASK"] = "1"
os.environ["HERMES_SESSION_KEY"] = session_key
try:
@ -441,6 +448,7 @@ class TestBlockingApprovalE2E:
finally:
os.environ.pop("HERMES_EXEC_ASK", None)
os.environ.pop("HERMES_SESSION_KEY", None)
reset_current_session_key(token)
t = threading.Thread(target=agent_thread)
t.start()
@ -469,6 +477,9 @@ class TestBlockingApprovalE2E:
result_holder = [None]
def agent_thread():
from tools.approval import reset_current_session_key, set_current_session_key
token = set_current_session_key(session_key)
os.environ["HERMES_EXEC_ASK"] = "1"
os.environ["HERMES_SESSION_KEY"] = session_key
try:
@ -480,6 +491,7 @@ class TestBlockingApprovalE2E:
finally:
os.environ.pop("HERMES_EXEC_ASK", None)
os.environ.pop("HERMES_SESSION_KEY", None)
reset_current_session_key(token)
t = threading.Thread(target=agent_thread)
t.start()
@ -505,6 +517,9 @@ class TestBlockingApprovalE2E:
def make_agent(idx, cmd):
def run():
from tools.approval import reset_current_session_key, set_current_session_key
token = set_current_session_key(session_key)
os.environ["HERMES_EXEC_ASK"] = "1"
os.environ["HERMES_SESSION_KEY"] = session_key
try:
@ -512,6 +527,7 @@ class TestBlockingApprovalE2E:
finally:
os.environ.pop("HERMES_EXEC_ASK", None)
os.environ.pop("HERMES_SESSION_KEY", None)
reset_current_session_key(token)
return run
threads = [
@ -556,6 +572,9 @@ class TestBlockingApprovalE2E:
def make_agent(idx, cmd):
def run():
from tools.approval import reset_current_session_key, set_current_session_key
token = set_current_session_key(session_key)
os.environ["HERMES_EXEC_ASK"] = "1"
os.environ["HERMES_SESSION_KEY"] = session_key
try:
@ -563,6 +582,7 @@ class TestBlockingApprovalE2E:
finally:
os.environ.pop("HERMES_EXEC_ASK", None)
os.environ.pop("HERMES_SESSION_KEY", None)
reset_current_session_key(token)
return run
threads = [
@ -580,8 +600,9 @@ class TestBlockingApprovalE2E:
for t in threads:
t.join(timeout=5)
assert results[0]["approved"] is True
assert results[1]["approved"] is False
assert all(r is not None for r in results)
assert sorted(r["approved"] for r in results) == [False, True]
assert sum("BLOCKED" in (r.get("message") or "") for r in results) == 1
unregister_gateway_notify(session_key)

View file

@ -1,5 +1,7 @@
"""Tests for the dangerous command approval module."""
import ast
from pathlib import Path
from unittest.mock import patch as mock_patch
import tools.approval as approval_module
@ -148,6 +150,79 @@ class TestApproveAndCheckSession:
assert has_pending(key) is False
class TestSessionKeyContext:
def test_context_session_key_overrides_process_env(self):
token = approval_module.set_current_session_key("alice")
try:
with mock_patch.dict("os.environ", {"HERMES_SESSION_KEY": "bob"}, clear=False):
assert approval_module.get_current_session_key() == "alice"
finally:
approval_module.reset_current_session_key(token)
def test_gateway_runner_binds_session_key_to_context_before_agent_run(self):
run_py = Path(__file__).resolve().parents[2] / "gateway" / "run.py"
module = ast.parse(run_py.read_text(encoding="utf-8"))
run_sync = None
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef) and node.name == "run_sync":
run_sync = node
break
assert run_sync is not None, "gateway.run.run_sync not found"
called_names = set()
for node in ast.walk(run_sync):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
called_names.add(node.func.id)
assert "set_current_session_key" in called_names
assert "reset_current_session_key" in called_names
def test_context_keeps_pending_approval_attached_to_originating_session(self):
import os
import threading
clear_session("alice")
clear_session("bob")
pop_pending("alice")
pop_pending("bob")
approval_module._permanent_approved.clear()
alice_ready = threading.Event()
bob_ready = threading.Event()
def worker_alice():
token = approval_module.set_current_session_key("alice")
try:
os.environ["HERMES_EXEC_ASK"] = "1"
os.environ["HERMES_SESSION_KEY"] = "alice"
alice_ready.set()
bob_ready.wait()
approval_module.check_all_command_guards("rm -rf /tmp/alice-secret", "local")
finally:
approval_module.reset_current_session_key(token)
def worker_bob():
alice_ready.wait()
token = approval_module.set_current_session_key("bob")
try:
os.environ["HERMES_SESSION_KEY"] = "bob"
bob_ready.set()
finally:
approval_module.reset_current_session_key(token)
t1 = threading.Thread(target=worker_alice)
t2 = threading.Thread(target=worker_bob)
t1.start()
t2.start()
t1.join()
t2.join()
assert pop_pending("alice") is not None
assert pop_pending("bob") is None
class TestRmFalsePositiveFix:
"""Regression tests: filenames starting with 'r' must NOT trigger recursive delete."""