mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): stop typing loops on session interrupt
This commit is contained in:
parent
b05d30418d
commit
150382e8b7
6 changed files with 456 additions and 18 deletions
|
|
@ -1401,7 +1401,13 @@ class BasePlatformAdapter(ABC):
|
|||
|
||||
return paths, cleaned
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0, metadata=None) -> None:
|
||||
async def _keep_typing(
|
||||
self,
|
||||
chat_id: str,
|
||||
interval: float = 2.0,
|
||||
metadata=None,
|
||||
stop_event: asyncio.Event | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
|
||||
|
|
@ -1415,9 +1421,18 @@ class BasePlatformAdapter(ABC):
|
|||
"""
|
||||
try:
|
||||
while True:
|
||||
if stop_event is not None and stop_event.is_set():
|
||||
return
|
||||
if chat_id not in self._typing_paused:
|
||||
await self.send_typing(chat_id, metadata=metadata)
|
||||
await asyncio.sleep(interval)
|
||||
if stop_event is None:
|
||||
await asyncio.sleep(interval)
|
||||
continue
|
||||
try:
|
||||
await asyncio.wait_for(stop_event.wait(), timeout=interval)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
pass # Normal cancellation when handler completes
|
||||
finally:
|
||||
|
|
@ -1444,6 +1459,17 @@ class BasePlatformAdapter(ABC):
|
|||
"""Resume typing indicator for a chat after approval resolves."""
|
||||
self._typing_paused.discard(chat_id)
|
||||
|
||||
async def interrupt_session_activity(self, session_key: str, chat_id: str) -> None:
|
||||
"""Signal the active session loop to stop and clear typing immediately."""
|
||||
if session_key:
|
||||
interrupt_event = self._active_sessions.get(session_key)
|
||||
if interrupt_event is not None:
|
||||
interrupt_event.set()
|
||||
try:
|
||||
await self.stop_typing(chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Processing lifecycle hooks ──────────────────────────────────────────
|
||||
# Subclasses override these to react to message processing events
|
||||
# (e.g. Discord adds 👀/✅/❌ reactions).
|
||||
|
|
@ -1717,7 +1743,13 @@ class BasePlatformAdapter(ABC):
|
|||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id, metadata=_thread_metadata))
|
||||
typing_task = asyncio.create_task(
|
||||
self._keep_typing(
|
||||
event.source.chat_id,
|
||||
metadata=_thread_metadata,
|
||||
stop_event=interrupt_event,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await self._run_processing_hook("on_processing_start", event)
|
||||
|
|
|
|||
147
gateway/run.py
147
gateway/run.py
|
|
@ -402,6 +402,26 @@ def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
|
|||
return adapter.get_pending_message(session_key)
|
||||
|
||||
|
||||
_CONTROL_INTERRUPT_MESSAGES = frozenset(
|
||||
{
|
||||
"stop requested",
|
||||
"session reset requested",
|
||||
"execution timed out (inactivity)",
|
||||
"sse client disconnected",
|
||||
"gateway shutting down",
|
||||
"gateway restarting",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_control_interrupt_message(message: Optional[str]) -> bool:
|
||||
"""Return True when an interrupt message is internal control flow."""
|
||||
if not message:
|
||||
return False
|
||||
normalized = " ".join(str(message).strip().split()).lower()
|
||||
return normalized in _CONTROL_INTERRUPT_MESSAGES
|
||||
|
||||
|
||||
def _check_unavailable_skill(command_name: str) -> str | None:
|
||||
"""Check if a command matches a known-but-inactive skill.
|
||||
|
||||
|
|
@ -630,6 +650,7 @@ class GatewayRunner:
|
|||
self._running_agents_ts: Dict[str, float] = {} # start timestamp per session
|
||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||
self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce)
|
||||
self._session_run_generation: Dict[str, int] = {}
|
||||
|
||||
# Cache AIAgent instances per session to preserve prompt caching.
|
||||
# Without this, a new AIAgent is created per message, rebuilding the
|
||||
|
|
@ -3064,6 +3085,10 @@ class GatewayRunner:
|
|||
_quick_key[:30], _stale_age, _stale_idle,
|
||||
_raw_stale_timeout, _stale_detail,
|
||||
)
|
||||
self._invalidate_session_run_generation(
|
||||
_quick_key,
|
||||
reason="stale_running_agent_eviction",
|
||||
)
|
||||
self._release_running_agent_state(_quick_key)
|
||||
|
||||
if _quick_key in self._running_agents:
|
||||
|
|
@ -3091,7 +3116,13 @@ class GatewayRunner:
|
|||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||
running_agent.interrupt("Stop requested")
|
||||
# Force-clean: remove the session lock regardless of agent state
|
||||
self._invalidate_session_run_generation(
|
||||
_quick_key,
|
||||
reason="stop_command",
|
||||
)
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
|
||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||
adapter.get_pending_message(_quick_key) # consume and discard
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
|
|
@ -3111,7 +3142,13 @@ class GatewayRunner:
|
|||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||
running_agent.interrupt("Session reset requested")
|
||||
# Clear any pending messages so the old text doesn't replay
|
||||
self._invalidate_session_run_generation(
|
||||
_quick_key,
|
||||
reason="new_command",
|
||||
)
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
|
||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||
adapter.get_pending_message(_quick_key) # consume and discard
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
|
|
@ -3598,9 +3635,10 @@ class GatewayRunner:
|
|||
# same session — corrupting the transcript.
|
||||
self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL
|
||||
self._running_agents_ts[_quick_key] = time.time()
|
||||
_run_generation = self._begin_session_run_generation(_quick_key)
|
||||
|
||||
try:
|
||||
return await self._handle_message_with_agent(event, source, _quick_key)
|
||||
return await self._handle_message_with_agent(event, source, _quick_key, _run_generation)
|
||||
finally:
|
||||
# If _run_agent replaced the sentinel with a real agent and
|
||||
# then cleaned it up, this is a no-op. If we exited early
|
||||
|
|
@ -3771,7 +3809,7 @@ class GatewayRunner:
|
|||
|
||||
return message_text
|
||||
|
||||
async def _handle_message_with_agent(self, event, source, _quick_key: str):
|
||||
async def _handle_message_with_agent(self, event, source, _quick_key: str, run_generation: int):
|
||||
"""Inner handler that runs under the _running_agents sentinel guard."""
|
||||
_msg_start_time = time.time()
|
||||
_platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform)
|
||||
|
|
@ -4246,6 +4284,7 @@ class GatewayRunner:
|
|||
source=source,
|
||||
session_id=session_entry.session_id,
|
||||
session_key=session_key,
|
||||
run_generation=run_generation,
|
||||
event_message_id=event.message_id,
|
||||
channel_prompt=event.channel_prompt,
|
||||
)
|
||||
|
|
@ -4258,6 +4297,17 @@ class GatewayRunner:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
if not self._is_session_run_current(_quick_key, run_generation):
|
||||
logger.info(
|
||||
"Discarding stale agent result for %s — generation %d is no longer current",
|
||||
_quick_key[:20] if _quick_key else "?",
|
||||
run_generation,
|
||||
)
|
||||
_stale_adapter = self.adapters.get(source.platform)
|
||||
if _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"):
|
||||
_stale_adapter._post_delivery_callbacks.pop(_quick_key, None)
|
||||
return None
|
||||
|
||||
response = agent_result.get("final_response") or ""
|
||||
|
||||
# Convert the agent's internal "(empty)" sentinel into a
|
||||
|
|
@ -4672,6 +4722,7 @@ class GatewayRunner:
|
|||
|
||||
# Get existing session key
|
||||
session_key = self._session_key_for_source(source)
|
||||
self._invalidate_session_run_generation(session_key, reason="session_reset")
|
||||
|
||||
# Flush memories in the background (fire-and-forget) so the user
|
||||
# gets the "Session reset!" response immediately.
|
||||
|
|
@ -4931,6 +4982,10 @@ class GatewayRunner:
|
|||
agent = self._running_agents.get(session_key)
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
self._invalidate_session_run_generation(session_key, reason="stop_command_pending")
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||
await adapter.interrupt_session_activity(session_key, source.chat_id)
|
||||
self._release_running_agent_state(session_key)
|
||||
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
||||
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
||||
|
|
@ -4938,6 +4993,10 @@ class GatewayRunner:
|
|||
agent.interrupt("Stop requested")
|
||||
# Force-clean the session lock so a truly hung agent doesn't
|
||||
# keep it locked forever.
|
||||
self._invalidate_session_run_generation(session_key, reason="stop_command_handler")
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||
await adapter.interrupt_session_activity(session_key, source.chat_id)
|
||||
self._release_running_agent_state(session_key)
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
else:
|
||||
|
|
@ -8385,6 +8444,43 @@ class GatewayRunner:
|
|||
if hasattr(self, "_busy_ack_ts"):
|
||||
self._busy_ack_ts.pop(session_key, None)
|
||||
|
||||
def _begin_session_run_generation(self, session_key: str) -> int:
|
||||
"""Claim a fresh run generation token for ``session_key``.
|
||||
|
||||
Every top-level gateway turn gets a monotonically increasing token.
|
||||
If a later command like /stop or /new invalidates that token while the
|
||||
old worker is still unwinding, the late result can be recognized and
|
||||
dropped instead of bleeding into the fresh session.
|
||||
"""
|
||||
if not session_key:
|
||||
return 0
|
||||
generations = self.__dict__.get("_session_run_generation")
|
||||
if generations is None:
|
||||
generations = {}
|
||||
self._session_run_generation = generations
|
||||
next_generation = int(generations.get(session_key, 0)) + 1
|
||||
generations[session_key] = next_generation
|
||||
return next_generation
|
||||
|
||||
def _invalidate_session_run_generation(self, session_key: str, *, reason: str = "") -> int:
|
||||
"""Invalidate any in-flight run token for ``session_key``."""
|
||||
generation = self._begin_session_run_generation(session_key)
|
||||
if reason:
|
||||
logger.info(
|
||||
"Invalidated run generation for %s → %d (%s)",
|
||||
session_key[:20],
|
||||
generation,
|
||||
reason,
|
||||
)
|
||||
return generation
|
||||
|
||||
def _is_session_run_current(self, session_key: str, generation: int) -> bool:
|
||||
"""Return True when ``generation`` is still current for ``session_key``."""
|
||||
if not session_key:
|
||||
return True
|
||||
generations = self.__dict__.get("_session_run_generation") or {}
|
||||
return int(generations.get(session_key, 0)) == int(generation)
|
||||
|
||||
def _evict_cached_agent(self, session_key: str) -> None:
|
||||
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
|
|
@ -8807,6 +8903,7 @@ class GatewayRunner:
|
|||
source: SessionSource,
|
||||
session_id: str,
|
||||
session_key: str = None,
|
||||
run_generation: Optional[int] = None,
|
||||
_interrupt_depth: int = 0,
|
||||
event_message_id: Optional[str] = None,
|
||||
channel_prompt: Optional[str] = None,
|
||||
|
|
@ -8837,6 +8934,11 @@ class GatewayRunner:
|
|||
|
||||
from run_agent import AIAgent
|
||||
import queue
|
||||
|
||||
def _run_still_current() -> bool:
|
||||
if run_generation is None or not session_key:
|
||||
return True
|
||||
return self._is_session_run_current(session_key, run_generation)
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
platform_key = _platform_config_key(source.platform)
|
||||
|
|
@ -8891,7 +8993,7 @@ class GatewayRunner:
|
|||
|
||||
def progress_callback(event_type: str, tool_name: str = None, preview: str = None, args: dict = None, **kwargs):
|
||||
"""Callback invoked by agent on tool lifecycle events."""
|
||||
if not progress_queue:
|
||||
if not progress_queue or not _run_still_current():
|
||||
return
|
||||
|
||||
# Only act on tool.started events (ignore tool.completed, reasoning.available, etc.)
|
||||
|
|
@ -8996,6 +9098,14 @@ class GatewayRunner:
|
|||
|
||||
while True:
|
||||
try:
|
||||
if not _run_still_current():
|
||||
while not progress_queue.empty():
|
||||
try:
|
||||
progress_queue.get_nowait()
|
||||
except Exception:
|
||||
break
|
||||
return
|
||||
|
||||
raw = progress_queue.get_nowait()
|
||||
|
||||
# Handle dedup messages: update last line with repeat counter
|
||||
|
|
@ -9021,6 +9131,9 @@ class GatewayRunner:
|
|||
await asyncio.sleep(_remaining)
|
||||
continue
|
||||
|
||||
if not _run_still_current():
|
||||
return
|
||||
|
||||
if can_edit and progress_msg_id is not None:
|
||||
# Try to edit the existing progress message
|
||||
full_text = "\n".join(progress_lines)
|
||||
|
|
@ -9056,7 +9169,8 @@ class GatewayRunner:
|
|||
|
||||
# Restore typing indicator
|
||||
await asyncio.sleep(0.3)
|
||||
await adapter.send_typing(source.chat_id, metadata=_progress_metadata)
|
||||
if _run_still_current():
|
||||
await adapter.send_typing(source.chat_id, metadata=_progress_metadata)
|
||||
|
||||
except queue.Empty:
|
||||
await asyncio.sleep(0.3)
|
||||
|
|
@ -9100,6 +9214,8 @@ class GatewayRunner:
|
|||
_hooks_ref = self.hooks
|
||||
|
||||
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
||||
if not _run_still_current():
|
||||
return
|
||||
try:
|
||||
# prev_tools may be list[str] or list[dict] with "name"/"result"
|
||||
# keys. Normalise to keep "tool_names" backward-compatible for
|
||||
|
|
@ -9130,7 +9246,7 @@ class GatewayRunner:
|
|||
_status_thread_metadata = {"thread_id": _progress_thread_id} if _progress_thread_id else None
|
||||
|
||||
def _status_callback_sync(event_type: str, message: str) -> None:
|
||||
if not _status_adapter:
|
||||
if not _status_adapter or not _run_still_current():
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
|
|
@ -9261,12 +9377,16 @@ class GatewayRunner:
|
|||
metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None,
|
||||
)
|
||||
if _want_stream_deltas:
|
||||
_stream_delta_cb = _stream_consumer.on_delta
|
||||
def _stream_delta_cb(text: str) -> None:
|
||||
if _run_still_current():
|
||||
_stream_consumer.on_delta(text)
|
||||
stream_consumer_holder[0] = _stream_consumer
|
||||
except Exception as _sc_err:
|
||||
logger.debug("Could not set up stream consumer: %s", _sc_err)
|
||||
|
||||
def _interim_assistant_cb(text: str, *, already_streamed: bool = False) -> None:
|
||||
if not _run_still_current():
|
||||
return
|
||||
if _stream_consumer is not None:
|
||||
if already_streamed:
|
||||
_stream_consumer.on_segment_break()
|
||||
|
|
@ -9370,7 +9490,7 @@ class GatewayRunner:
|
|||
_bg_review_pending_lock = threading.Lock()
|
||||
|
||||
def _deliver_bg_review_message(message: str) -> None:
|
||||
if not _status_adapter:
|
||||
if not _status_adapter or not _run_still_current():
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
|
|
@ -9394,7 +9514,7 @@ class GatewayRunner:
|
|||
|
||||
# Background review delivery — send "💾 Memory updated" etc. to user
|
||||
def _bg_review_send(message: str) -> None:
|
||||
if not _status_adapter:
|
||||
if not _status_adapter or not _run_still_current():
|
||||
return
|
||||
if not _bg_review_release.is_set():
|
||||
with _bg_review_pending_lock:
|
||||
|
|
@ -10076,7 +10196,15 @@ class GatewayRunner:
|
|||
if result and adapter and session_key:
|
||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||
if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
interrupt_message = result.get("interrupt_message")
|
||||
if _is_control_interrupt_message(interrupt_message):
|
||||
logger.info(
|
||||
"Ignoring control interrupt message for session %s: %s",
|
||||
session_key[:20] if session_key else "?",
|
||||
interrupt_message,
|
||||
)
|
||||
else:
|
||||
pending = interrupt_message
|
||||
elif pending_event:
|
||||
pending = pending_event.text or _build_media_placeholder(pending_event)
|
||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||
|
|
@ -10229,6 +10357,7 @@ class GatewayRunner:
|
|||
source=next_source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
run_generation=run_generation,
|
||||
_interrupt_depth=_interrupt_depth + 1,
|
||||
event_message_id=next_message_id,
|
||||
channel_prompt=next_channel_prompt,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,18 @@
|
|||
"""Tests for the pending_event None guard in recursive _run_agent calls.
|
||||
"""Tests for pending follow-up extraction in recursive _run_agent calls.
|
||||
|
||||
When pending_event is None (Path B: pending comes from interrupt_message),
|
||||
accessing pending_event.channel_prompt previously raised AttributeError.
|
||||
This verifies the fix: channel_prompt is captured inside the
|
||||
`if pending_event is not None:` block and falls back to None otherwise.
|
||||
|
||||
Also verifies that internal control interrupt reasons like "Stop requested"
|
||||
do not get recycled into the pending-user-message follow-up path.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from gateway.run import _is_control_interrupt_message
|
||||
|
||||
|
||||
def _extract_channel_prompt(pending_event):
|
||||
"""Reproduce the fixed logic from gateway/run.py.
|
||||
|
|
@ -21,6 +26,15 @@ def _extract_channel_prompt(pending_event):
|
|||
return next_channel_prompt
|
||||
|
||||
|
||||
def _extract_pending_text(interrupted, pending_event, interrupt_message):
|
||||
"""Reproduce the fixed pending-text selection from gateway/run.py."""
|
||||
if interrupted and pending_event is None and interrupt_message:
|
||||
if _is_control_interrupt_message(interrupt_message):
|
||||
return None
|
||||
return interrupt_message
|
||||
return None
|
||||
|
||||
|
||||
class TestPendingEventNoneChannelPrompt:
|
||||
"""Guard against AttributeError when pending_event is None."""
|
||||
|
||||
|
|
@ -40,3 +54,19 @@ class TestPendingEventNoneChannelPrompt:
|
|||
event = SimpleNamespace()
|
||||
result = _extract_channel_prompt(event)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestControlInterruptMessages:
|
||||
"""Control interrupt reasons must not become follow-up user input."""
|
||||
|
||||
def test_stop_requested_is_not_treated_as_pending_user_message(self):
|
||||
result = _extract_pending_text(True, None, "Stop requested")
|
||||
assert result is None
|
||||
|
||||
def test_session_reset_requested_is_not_treated_as_pending_user_message(self):
|
||||
result = _extract_pending_text(True, None, "Session reset requested")
|
||||
assert result is None
|
||||
|
||||
def test_real_user_interrupt_message_still_requeues(self):
|
||||
result = _extract_pending_text(True, None, "actually use postgres instead")
|
||||
assert result == "actually use postgres instead"
|
||||
|
|
|
|||
|
|
@ -51,6 +51,9 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
|
|||
async def send_typing(self, chat_id, metadata=None) -> None:
|
||||
self.typing.append({"chat_id": chat_id, "metadata": metadata})
|
||||
|
||||
async def stop_typing(self, chat_id) -> None:
|
||||
self.typing.append({"chat_id": chat_id, "metadata": {"stopped": True}})
|
||||
|
||||
async def get_chat_info(self, chat_id: str):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
|
@ -90,6 +93,40 @@ class LongPreviewAgent:
|
|||
}
|
||||
|
||||
|
||||
class DelayedProgressAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.tool_progress_callback("tool.started", "terminal", "first command", {})
|
||||
time.sleep(0.45)
|
||||
self.tool_progress_callback("tool.started", "terminal", "second command", {})
|
||||
time.sleep(0.1)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class DelayedInterimAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.interim_assistant_callback("first interim")
|
||||
time.sleep(0.45)
|
||||
self.interim_assistant_callback("second interim")
|
||||
time.sleep(0.1)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
def _make_runner(adapter):
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
|
|
@ -104,6 +141,7 @@ def _make_runner(adapter):
|
|||
runner._fallback_model = None
|
||||
runner._session_db = None
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||
runner.config = SimpleNamespace(
|
||||
thread_sessions_per_user=False,
|
||||
|
|
@ -744,6 +782,154 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send()
|
|||
assert released == [True]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
|
||||
import yaml
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.dump({"display": {"tool_progress": "all"}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = DelayedProgressAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
import tools.terminal_tool # noqa: F401 - register terminal tool metadata
|
||||
|
||||
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="dm-1",
|
||||
chat_type="dm",
|
||||
thread_id=None,
|
||||
)
|
||||
session_key = "agent:main:discord:dm:dm-1"
|
||||
runner._session_run_generation[session_key] = 1
|
||||
|
||||
original_send = adapter.send
|
||||
invalidated = {"done": False}
|
||||
|
||||
async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
|
||||
result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
|
||||
if "first command" in content and not invalidated["done"]:
|
||||
invalidated["done"] = True
|
||||
runner._invalidate_session_run_generation(session_key, reason="test_stop")
|
||||
return result
|
||||
|
||||
adapter.send = send_and_invalidate
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-progress-stop",
|
||||
session_key=session_key,
|
||||
run_generation=1,
|
||||
)
|
||||
|
||||
all_progress_text = " ".join(call["content"] for call in adapter.sent)
|
||||
all_progress_text += " ".join(call["content"] for call in adapter.edits)
|
||||
assert result["final_response"] == "done"
|
||||
assert 'first command' in all_progress_text
|
||||
assert 'second command' not in all_progress_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_drops_interim_commentary_after_generation_invalidation(monkeypatch, tmp_path):
|
||||
import yaml
|
||||
|
||||
(tmp_path / "config.yaml").write_text(
|
||||
yaml.dump({"display": {"tool_progress": "off", "interim_assistant_messages": True}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = DelayedInterimAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="dm-2",
|
||||
chat_type="dm",
|
||||
thread_id=None,
|
||||
)
|
||||
session_key = "agent:main:discord:dm:dm-2"
|
||||
runner._session_run_generation[session_key] = 1
|
||||
|
||||
original_send = adapter.send
|
||||
invalidated = {"done": False}
|
||||
|
||||
async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
|
||||
result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
|
||||
if content == "first interim" and not invalidated["done"]:
|
||||
invalidated["done"] = True
|
||||
runner._invalidate_session_run_generation(session_key, reason="test_stop")
|
||||
return result
|
||||
|
||||
adapter.send = send_and_invalidate
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-commentary-stop",
|
||||
session_key=session_key,
|
||||
run_generation=1,
|
||||
)
|
||||
|
||||
sent_texts = [call["content"] for call in adapter.sent]
|
||||
assert result["final_response"] == "done"
|
||||
assert "first interim" in sent_texts
|
||||
assert "second interim" not in sent_texts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keep_typing_stops_immediately_when_interrupt_event_is_set():
|
||||
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
task = asyncio.create_task(
|
||||
adapter._keep_typing(
|
||||
"dm-typing-stop",
|
||||
interval=30.0,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(0.05)
|
||||
stop_event.set()
|
||||
await asyncio.wait_for(task, timeout=0.5)
|
||||
|
||||
normal_typing_calls = [
|
||||
call for call in adapter.typing if call.get("metadata") != {"stopped": True}
|
||||
]
|
||||
stopped_calls = [
|
||||
call for call in adapter.typing if call.get("metadata") == {"stopped": True}
|
||||
]
|
||||
assert len(normal_typing_calls) == 1
|
||||
assert len(stopped_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
|
||||
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.
|
||||
|
|
|
|||
|
|
@ -24,10 +24,18 @@ class _FakeAdapter:
|
|||
|
||||
def __init__(self):
|
||||
self._pending_messages = {}
|
||||
self._active_sessions = {}
|
||||
self.interrupted_sessions = []
|
||||
|
||||
async def send(self, chat_id, text, **kwargs):
|
||||
pass
|
||||
|
||||
async def interrupt_session_activity(self, session_key, chat_id):
|
||||
self.interrupted_sessions.append((session_key, chat_id))
|
||||
event = self._active_sessions.get(session_key)
|
||||
if event is not None:
|
||||
event.set()
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
|
@ -37,6 +45,7 @@ def _make_runner():
|
|||
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._voice_mode = {}
|
||||
|
|
@ -81,7 +90,7 @@ async def test_sentinel_placed_before_agent_setup():
|
|||
# Patch _handle_message_with_agent to capture state at entry
|
||||
sentinel_was_set = False
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
nonlocal sentinel_was_set
|
||||
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
|
||||
return "ok"
|
||||
|
|
@ -105,7 +114,7 @@ async def test_sentinel_cleaned_up_after_handler_returns():
|
|||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
return "ok"
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
|
|
@ -127,7 +136,7 @@ async def test_sentinel_cleaned_up_on_exception():
|
|||
event = _make_event()
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
async def mock_inner(self_inner, ev, src, qk):
|
||||
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||
|
|
@ -154,7 +163,7 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
|
|||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
async def slow_inner(self_inner, ev, src, qk, generation):
|
||||
# Simulate slow setup — wait until test tells us to proceed
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
|
@ -333,7 +342,7 @@ async def test_stop_during_sentinel_force_cleans_session():
|
|||
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_inner(self_inner, ev, src, qk):
|
||||
async def slow_inner(self_inner, ev, src, qk, generation):
|
||||
await barrier.wait()
|
||||
return "ok"
|
||||
|
||||
|
|
@ -381,6 +390,7 @@ async def test_stop_hard_kills_running_agent():
|
|||
fake_agent = MagicMock()
|
||||
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
|
||||
runner._running_agents[session_key] = fake_agent
|
||||
runner.adapters[Platform.TELEGRAM]._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Send /stop
|
||||
stop_event = _make_event(text="/stop")
|
||||
|
|
@ -393,6 +403,10 @@ async def test_stop_hard_kills_running_agent():
|
|||
assert session_key not in runner._running_agents, (
|
||||
"/stop must remove the agent from _running_agents so the session is unlocked"
|
||||
)
|
||||
assert runner.adapters[Platform.TELEGRAM].interrupted_sessions == [
|
||||
(session_key, "12345")
|
||||
]
|
||||
assert runner.adapters[Platform.TELEGRAM]._active_sessions[session_key].is_set()
|
||||
|
||||
# Must return a confirmation
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ def _make_runner(session_entry: SessionEntry):
|
|||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.update_session = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
|
|
@ -223,6 +224,52 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
|
||||
session_key = session_entry.session_key
|
||||
runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks = {session_key: object()}
|
||||
|
||||
async def _stale_result(**kwargs):
|
||||
runner._invalidate_session_run_generation(kwargs["session_key"], reason="test_stale_result")
|
||||
return {
|
||||
"final_response": "late reply",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"history_offset": 0,
|
||||
"last_prompt_tokens": 80,
|
||||
"input_tokens": 120,
|
||||
"output_tokens": 45,
|
||||
"model": "openai/test-model",
|
||||
}
|
||||
|
||||
runner._run_agent = AsyncMock(side_effect=_stale_result)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
monkeypatch.setattr(
|
||||
"agent.model_metadata.get_model_context_length",
|
||||
lambda *_args, **_kwargs: 100000,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(_make_event("hello"))
|
||||
|
||||
assert result is None
|
||||
runner.session_store.append_to_transcript.assert_not_called()
|
||||
runner.session_store.update_session.assert_not_called()
|
||||
assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_command_bypasses_active_session_guard():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue