mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): reset approval and yolo state on session boundary
This commit is contained in:
parent
64c38cc4d0
commit
050aabe2d4
2 changed files with 188 additions and 0 deletions
|
|
@ -7217,6 +7217,7 @@ class GatewayRunner:
|
|||
new_entry = self.session_store.switch_session(session_key, target_id)
|
||||
if not new_entry:
|
||||
return "Failed to switch session."
|
||||
self._clear_session_boundary_security_state(session_key)
|
||||
|
||||
# Get the title for confirmation
|
||||
title = self._session_db.get_session_title(target_id) or name
|
||||
|
|
@ -7306,6 +7307,7 @@ class GatewayRunner:
|
|||
new_entry = self.session_store.switch_session(session_key, new_session_id)
|
||||
if not new_entry:
|
||||
return "Branch created but failed to switch to it."
|
||||
self._clear_session_boundary_security_state(session_key)
|
||||
|
||||
# Evict any cached agent for this session
|
||||
self._evict_cached_agent(session_key)
|
||||
|
|
@ -8680,6 +8682,29 @@ class GatewayRunner:
|
|||
if hasattr(self, "_busy_ack_ts"):
|
||||
self._busy_ack_ts.pop(session_key, None)
|
||||
|
||||
def _clear_session_boundary_security_state(self, session_key: str) -> None:
|
||||
"""Clear approval state that must not survive a real conversation switch."""
|
||||
if not session_key:
|
||||
return
|
||||
|
||||
pending_approvals = getattr(self, "_pending_approvals", None)
|
||||
if isinstance(pending_approvals, dict):
|
||||
pending_approvals.pop(session_key, None)
|
||||
|
||||
try:
|
||||
from tools.approval import clear_session as _clear_approval_session
|
||||
except Exception:
|
||||
return
|
||||
|
||||
try:
|
||||
_clear_approval_session(session_key)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Failed to clear approval state for session boundary %s: %s",
|
||||
session_key,
|
||||
e,
|
||||
)
|
||||
|
||||
def _begin_session_run_generation(self, session_key: str) -> int:
|
||||
"""Claim a fresh run generation token for ``session_key``.
|
||||
|
||||
|
|
|
|||
163
tests/gateway/test_session_boundary_security_state.py
Normal file
163
tests/gateway/test_session_boundary_security_state.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
"""Regression tests for approval-state cleanup on session boundaries."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
from tools import approval as approval_mod
|
||||
from tools.approval import (
|
||||
approve_session,
|
||||
enable_session_yolo,
|
||||
is_approved,
|
||||
is_session_yolo_enabled,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_approval_state():
|
||||
approval_mod._gateway_queues.clear()
|
||||
approval_mod._gateway_notify_cbs.clear()
|
||||
approval_mod._session_approved.clear()
|
||||
approval_mod._session_yolo.clear()
|
||||
approval_mod._permanent_approved.clear()
|
||||
approval_mod._pending.clear()
|
||||
yield
|
||||
approval_mod._gateway_queues.clear()
|
||||
approval_mod._gateway_notify_cbs.clear()
|
||||
approval_mod._session_approved.clear()
|
||||
approval_mod._session_yolo.clear()
|
||||
approval_mod._permanent_approved.clear()
|
||||
approval_mod._pending.clear()
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(text=text, source=_make_source(), message_id="m1")
|
||||
|
||||
|
||||
def _make_entry(session_id: str, source: SessionSource | None = None) -> SessionEntry:
|
||||
source = source or _make_source()
|
||||
return SessionEntry(
|
||||
session_key=build_session_key(source),
|
||||
session_id=session_id,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
origin=source,
|
||||
platform=source.platform,
|
||||
chat_type=source.chat_type,
|
||||
)
|
||||
|
||||
|
||||
def _make_resume_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
source = _make_source()
|
||||
session_key = build_session_key(source)
|
||||
current_entry = _make_entry("current-session", source)
|
||||
resumed_entry = _make_entry("resumed-session", source)
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._background_tasks = set()
|
||||
runner._async_flush_memories = AsyncMock()
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._busy_ack_ts = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._agent_cache_lock = None
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = current_entry
|
||||
runner.session_store.switch_session.return_value = resumed_entry
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.resolve_session_by_title.return_value = "resumed-session"
|
||||
runner._session_db.get_session_title.return_value = "Resumed Work"
|
||||
return runner, session_key
|
||||
|
||||
|
||||
def _make_branch_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
source = _make_source()
|
||||
session_key = build_session_key(source)
|
||||
current_entry = _make_entry("current-session", source)
|
||||
branched_entry = _make_entry("branched-session", source)
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner.config = {}
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._busy_ack_ts = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._agent_cache_lock = None
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = current_entry
|
||||
runner.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "world"},
|
||||
]
|
||||
runner.session_store.switch_session.return_value = branched_entry
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = "Current Work"
|
||||
runner._session_db.get_next_title_in_lineage.return_value = "Current Work #2"
|
||||
return runner, session_key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_clears_session_scoped_approval_and_yolo_state():
|
||||
runner, session_key = _make_resume_runner()
|
||||
other_key = "agent:main:telegram:dm:other-chat"
|
||||
|
||||
approve_session(session_key, "recursive delete")
|
||||
approve_session(other_key, "recursive delete")
|
||||
enable_session_yolo(session_key)
|
||||
enable_session_yolo(other_key)
|
||||
runner._pending_approvals[session_key] = {"command": "rm -rf /tmp/demo"}
|
||||
runner._pending_approvals[other_key] = {"command": "rm -rf /tmp/other"}
|
||||
|
||||
result = await runner._handle_resume_command(_make_event("/resume Resumed Work"))
|
||||
|
||||
assert "Resumed session" in result
|
||||
assert is_approved(session_key, "recursive delete") is False
|
||||
assert is_session_yolo_enabled(session_key) is False
|
||||
assert session_key not in runner._pending_approvals
|
||||
assert is_approved(other_key, "recursive delete") is True
|
||||
assert is_session_yolo_enabled(other_key) is True
|
||||
assert other_key in runner._pending_approvals
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_branch_clears_session_scoped_approval_and_yolo_state():
|
||||
runner, session_key = _make_branch_runner()
|
||||
other_key = "agent:main:telegram:dm:other-chat"
|
||||
|
||||
approve_session(session_key, "recursive delete")
|
||||
approve_session(other_key, "recursive delete")
|
||||
enable_session_yolo(session_key)
|
||||
enable_session_yolo(other_key)
|
||||
runner._pending_approvals[session_key] = {"command": "rm -rf /tmp/demo"}
|
||||
runner._pending_approvals[other_key] = {"command": "rm -rf /tmp/other"}
|
||||
|
||||
result = await runner._handle_branch_command(_make_event("/branch"))
|
||||
|
||||
assert "Branched to" in result
|
||||
assert is_approved(session_key, "recursive delete") is False
|
||||
assert is_session_yolo_enabled(session_key) is False
|
||||
assert session_key not in runner._pending_approvals
|
||||
assert is_approved(other_key, "recursive delete") is True
|
||||
assert is_session_yolo_enabled(other_key) is True
|
||||
assert other_key in runner._pending_approvals
|
||||
Loading…
Add table
Add a link
Reference in a new issue