mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): stop typing loops on session interrupt
This commit is contained in:
parent
b05d30418d
commit
150382e8b7
6 changed files with 456 additions and 18 deletions
|
|
@ -1401,7 +1401,13 @@ class BasePlatformAdapter(ABC):
|
||||||
|
|
||||||
return paths, cleaned
|
return paths, cleaned
|
||||||
|
|
||||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0, metadata=None) -> None:
|
async def _keep_typing(
|
||||||
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
interval: float = 2.0,
|
||||||
|
metadata=None,
|
||||||
|
stop_event: asyncio.Event | None = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Continuously send typing indicator until cancelled.
|
Continuously send typing indicator until cancelled.
|
||||||
|
|
||||||
|
|
@ -1415,9 +1421,18 @@ class BasePlatformAdapter(ABC):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
if stop_event is not None and stop_event.is_set():
|
||||||
|
return
|
||||||
if chat_id not in self._typing_paused:
|
if chat_id not in self._typing_paused:
|
||||||
await self.send_typing(chat_id, metadata=metadata)
|
await self.send_typing(chat_id, metadata=metadata)
|
||||||
await asyncio.sleep(interval)
|
if stop_event is None:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(stop_event.wait(), timeout=interval)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
return
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass # Normal cancellation when handler completes
|
pass # Normal cancellation when handler completes
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -1444,6 +1459,17 @@ class BasePlatformAdapter(ABC):
|
||||||
"""Resume typing indicator for a chat after approval resolves."""
|
"""Resume typing indicator for a chat after approval resolves."""
|
||||||
self._typing_paused.discard(chat_id)
|
self._typing_paused.discard(chat_id)
|
||||||
|
|
||||||
|
async def interrupt_session_activity(self, session_key: str, chat_id: str) -> None:
    """Signal the active session loop to stop and clear typing immediately.

    Looks up the interrupt event registered for ``session_key`` in
    ``self._active_sessions`` and sets it when one exists, then asks the
    adapter to stop the typing indicator for ``chat_id``.

    Args:
        session_key: Key of the session to interrupt. A falsy key skips the
            event lookup; typing is still stopped.
        chat_id: Chat whose typing indicator should be cleared.

    Stopping the typing indicator is best-effort: a failure in
    ``stop_typing`` is logged at debug level and never propagated.
    """
    # Function-scope import keeps this block self-contained regardless of
    # what the surrounding module already imports.
    import logging

    if session_key:
        interrupt_event = self._active_sessions.get(session_key)
        if interrupt_event is not None:
            interrupt_event.set()

    try:
        await self.stop_typing(chat_id)
    except Exception:
        # The interrupt itself has already been signalled above; a failed
        # typing reset must not abort the caller, but it should leave a trace.
        logging.getLogger(__name__).debug(
            "stop_typing failed while interrupting chat %s",
            chat_id,
            exc_info=True,
        )
|
||||||
|
|
||||||
# ── Processing lifecycle hooks ──────────────────────────────────────────
|
# ── Processing lifecycle hooks ──────────────────────────────────────────
|
||||||
# Subclasses override these to react to message processing events
|
# Subclasses override these to react to message processing events
|
||||||
# (e.g. Discord adds 👀/✅/❌ reactions).
|
# (e.g. Discord adds 👀/✅/❌ reactions).
|
||||||
|
|
@ -1717,7 +1743,13 @@ class BasePlatformAdapter(ABC):
|
||||||
|
|
||||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||||
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
||||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id, metadata=_thread_metadata))
|
typing_task = asyncio.create_task(
|
||||||
|
self._keep_typing(
|
||||||
|
event.source.chat_id,
|
||||||
|
metadata=_thread_metadata,
|
||||||
|
stop_event=interrupt_event,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._run_processing_hook("on_processing_start", event)
|
await self._run_processing_hook("on_processing_start", event)
|
||||||
|
|
|
||||||
147
gateway/run.py
147
gateway/run.py
|
|
@ -402,6 +402,26 @@ def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
|
||||||
return adapter.get_pending_message(session_key)
|
return adapter.get_pending_message(session_key)
|
||||||
|
|
||||||
|
|
||||||
|
# Interrupt reasons produced by the gateway itself (stored normalized:
# lower-case, single spaces) rather than typed by a user.
_CONTROL_INTERRUPT_MESSAGES = frozenset(
    (
        "stop requested",
        "session reset requested",
        "execution timed out (inactivity)",
        "sse client disconnected",
        "gateway shutting down",
        "gateway restarting",
    )
)


def _is_control_interrupt_message(message: Optional[str]) -> bool:
    """Tell whether an interrupt message is an internal control reason.

    The message is stringified, whitespace runs are collapsed to single
    spaces, and the result is lower-cased before being looked up in
    ``_CONTROL_INTERRUPT_MESSAGES``. Falsy input (``None``, ``""``) is never
    a control message.
    """
    if not message:
        return False
    # str.split() with no argument already discards leading/trailing
    # whitespace, so joining its pieces yields the collapsed form.
    collapsed = " ".join(str(message).split())
    return collapsed.lower() in _CONTROL_INTERRUPT_MESSAGES
|
||||||
|
|
||||||
|
|
||||||
def _check_unavailable_skill(command_name: str) -> str | None:
|
def _check_unavailable_skill(command_name: str) -> str | None:
|
||||||
"""Check if a command matches a known-but-inactive skill.
|
"""Check if a command matches a known-but-inactive skill.
|
||||||
|
|
||||||
|
|
@ -630,6 +650,7 @@ class GatewayRunner:
|
||||||
self._running_agents_ts: Dict[str, float] = {} # start timestamp per session
|
self._running_agents_ts: Dict[str, float] = {} # start timestamp per session
|
||||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||||
self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce)
|
self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce)
|
||||||
|
self._session_run_generation: Dict[str, int] = {}
|
||||||
|
|
||||||
# Cache AIAgent instances per session to preserve prompt caching.
|
# Cache AIAgent instances per session to preserve prompt caching.
|
||||||
# Without this, a new AIAgent is created per message, rebuilding the
|
# Without this, a new AIAgent is created per message, rebuilding the
|
||||||
|
|
@ -3064,6 +3085,10 @@ class GatewayRunner:
|
||||||
_quick_key[:30], _stale_age, _stale_idle,
|
_quick_key[:30], _stale_age, _stale_idle,
|
||||||
_raw_stale_timeout, _stale_detail,
|
_raw_stale_timeout, _stale_detail,
|
||||||
)
|
)
|
||||||
|
self._invalidate_session_run_generation(
|
||||||
|
_quick_key,
|
||||||
|
reason="stale_running_agent_eviction",
|
||||||
|
)
|
||||||
self._release_running_agent_state(_quick_key)
|
self._release_running_agent_state(_quick_key)
|
||||||
|
|
||||||
if _quick_key in self._running_agents:
|
if _quick_key in self._running_agents:
|
||||||
|
|
@ -3091,7 +3116,13 @@ class GatewayRunner:
|
||||||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||||
running_agent.interrupt("Stop requested")
|
running_agent.interrupt("Stop requested")
|
||||||
# Force-clean: remove the session lock regardless of agent state
|
# Force-clean: remove the session lock regardless of agent state
|
||||||
|
self._invalidate_session_run_generation(
|
||||||
|
_quick_key,
|
||||||
|
reason="stop_command",
|
||||||
|
)
|
||||||
adapter = self.adapters.get(source.platform)
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||||
|
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
|
||||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||||
adapter.get_pending_message(_quick_key) # consume and discard
|
adapter.get_pending_message(_quick_key) # consume and discard
|
||||||
self._pending_messages.pop(_quick_key, None)
|
self._pending_messages.pop(_quick_key, None)
|
||||||
|
|
@ -3111,7 +3142,13 @@ class GatewayRunner:
|
||||||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||||
running_agent.interrupt("Session reset requested")
|
running_agent.interrupt("Session reset requested")
|
||||||
# Clear any pending messages so the old text doesn't replay
|
# Clear any pending messages so the old text doesn't replay
|
||||||
|
self._invalidate_session_run_generation(
|
||||||
|
_quick_key,
|
||||||
|
reason="new_command",
|
||||||
|
)
|
||||||
adapter = self.adapters.get(source.platform)
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||||
|
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
|
||||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||||
adapter.get_pending_message(_quick_key) # consume and discard
|
adapter.get_pending_message(_quick_key) # consume and discard
|
||||||
self._pending_messages.pop(_quick_key, None)
|
self._pending_messages.pop(_quick_key, None)
|
||||||
|
|
@ -3598,9 +3635,10 @@ class GatewayRunner:
|
||||||
# same session — corrupting the transcript.
|
# same session — corrupting the transcript.
|
||||||
self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL
|
self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL
|
||||||
self._running_agents_ts[_quick_key] = time.time()
|
self._running_agents_ts[_quick_key] = time.time()
|
||||||
|
_run_generation = self._begin_session_run_generation(_quick_key)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self._handle_message_with_agent(event, source, _quick_key)
|
return await self._handle_message_with_agent(event, source, _quick_key, _run_generation)
|
||||||
finally:
|
finally:
|
||||||
# If _run_agent replaced the sentinel with a real agent and
|
# If _run_agent replaced the sentinel with a real agent and
|
||||||
# then cleaned it up, this is a no-op. If we exited early
|
# then cleaned it up, this is a no-op. If we exited early
|
||||||
|
|
@ -3771,7 +3809,7 @@ class GatewayRunner:
|
||||||
|
|
||||||
return message_text
|
return message_text
|
||||||
|
|
||||||
async def _handle_message_with_agent(self, event, source, _quick_key: str):
|
async def _handle_message_with_agent(self, event, source, _quick_key: str, run_generation: int):
|
||||||
"""Inner handler that runs under the _running_agents sentinel guard."""
|
"""Inner handler that runs under the _running_agents sentinel guard."""
|
||||||
_msg_start_time = time.time()
|
_msg_start_time = time.time()
|
||||||
_platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform)
|
_platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform)
|
||||||
|
|
@ -4246,6 +4284,7 @@ class GatewayRunner:
|
||||||
source=source,
|
source=source,
|
||||||
session_id=session_entry.session_id,
|
session_id=session_entry.session_id,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
|
run_generation=run_generation,
|
||||||
event_message_id=event.message_id,
|
event_message_id=event.message_id,
|
||||||
channel_prompt=event.channel_prompt,
|
channel_prompt=event.channel_prompt,
|
||||||
)
|
)
|
||||||
|
|
@ -4258,6 +4297,17 @@ class GatewayRunner:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if not self._is_session_run_current(_quick_key, run_generation):
|
||||||
|
logger.info(
|
||||||
|
"Discarding stale agent result for %s — generation %d is no longer current",
|
||||||
|
_quick_key[:20] if _quick_key else "?",
|
||||||
|
run_generation,
|
||||||
|
)
|
||||||
|
_stale_adapter = self.adapters.get(source.platform)
|
||||||
|
if _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"):
|
||||||
|
_stale_adapter._post_delivery_callbacks.pop(_quick_key, None)
|
||||||
|
return None
|
||||||
|
|
||||||
response = agent_result.get("final_response") or ""
|
response = agent_result.get("final_response") or ""
|
||||||
|
|
||||||
# Convert the agent's internal "(empty)" sentinel into a
|
# Convert the agent's internal "(empty)" sentinel into a
|
||||||
|
|
@ -4672,6 +4722,7 @@ class GatewayRunner:
|
||||||
|
|
||||||
# Get existing session key
|
# Get existing session key
|
||||||
session_key = self._session_key_for_source(source)
|
session_key = self._session_key_for_source(source)
|
||||||
|
self._invalidate_session_run_generation(session_key, reason="session_reset")
|
||||||
|
|
||||||
# Flush memories in the background (fire-and-forget) so the user
|
# Flush memories in the background (fire-and-forget) so the user
|
||||||
# gets the "Session reset!" response immediately.
|
# gets the "Session reset!" response immediately.
|
||||||
|
|
@ -4931,6 +4982,10 @@ class GatewayRunner:
|
||||||
agent = self._running_agents.get(session_key)
|
agent = self._running_agents.get(session_key)
|
||||||
if agent is _AGENT_PENDING_SENTINEL:
|
if agent is _AGENT_PENDING_SENTINEL:
|
||||||
# Force-clean the sentinel so the session is unlocked.
|
# Force-clean the sentinel so the session is unlocked.
|
||||||
|
self._invalidate_session_run_generation(session_key, reason="stop_command_pending")
|
||||||
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||||
|
await adapter.interrupt_session_activity(session_key, source.chat_id)
|
||||||
self._release_running_agent_state(session_key)
|
self._release_running_agent_state(session_key)
|
||||||
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
||||||
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
||||||
|
|
@ -4938,6 +4993,10 @@ class GatewayRunner:
|
||||||
agent.interrupt("Stop requested")
|
agent.interrupt("Stop requested")
|
||||||
# Force-clean the session lock so a truly hung agent doesn't
|
# Force-clean the session lock so a truly hung agent doesn't
|
||||||
# keep it locked forever.
|
# keep it locked forever.
|
||||||
|
self._invalidate_session_run_generation(session_key, reason="stop_command_handler")
|
||||||
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||||
|
await adapter.interrupt_session_activity(session_key, source.chat_id)
|
||||||
self._release_running_agent_state(session_key)
|
self._release_running_agent_state(session_key)
|
||||||
return "⚡ Stopped. You can continue this session."
|
return "⚡ Stopped. You can continue this session."
|
||||||
else:
|
else:
|
||||||
|
|
@ -8385,6 +8444,43 @@ class GatewayRunner:
|
||||||
if hasattr(self, "_busy_ack_ts"):
|
if hasattr(self, "_busy_ack_ts"):
|
||||||
self._busy_ack_ts.pop(session_key, None)
|
self._busy_ack_ts.pop(session_key, None)
|
||||||
|
|
||||||
|
def _begin_session_run_generation(self, session_key: str) -> int:
    """Hand out the next run-generation token for ``session_key``.

    Each call bumps the per-session counter by one and returns the new
    value, so an older token stops matching once a newer one is issued.
    That lets a late result from an interrupted run (for example after
    /stop or /new) be recognised as stale and dropped.

    Returns 0 without recording anything when ``session_key`` is falsy.
    """
    if not session_key:
        return 0

    # Read through __dict__ so a runner that never had the mapping assigned
    # gets one created here instead of raising.
    counters = self.__dict__.get("_session_run_generation")
    if counters is None:
        counters = {}
        self._session_run_generation = counters

    token = int(counters.get(session_key, 0)) + 1
    counters[session_key] = token
    return token
|
||||||
|
|
||||||
|
def _invalidate_session_run_generation(self, session_key: str, *, reason: str = "") -> int:
    """Invalidate any in-flight run token for ``session_key``.

    Claims a fresh generation through ``_begin_session_run_generation`` so
    that a token issued to an earlier run no longer compares as current.

    Args:
        session_key: Session whose outstanding token is invalidated.
        reason: Optional label; when non-empty the invalidation is logged
            at info level.

    Returns:
        The generation returned by ``_begin_session_run_generation``.
    """
    generation = self._begin_session_run_generation(session_key)
    if not session_key:
        # No key means there is nothing to report, and slicing a None key
        # for the log line below would raise TypeError.
        return generation
    if reason:
        logger.info(
            "Invalidated run generation for %s → %d (%s)",
            session_key[:20],
            generation,
            reason,
        )
    return generation
|
||||||
|
|
||||||
|
def _is_session_run_current(self, session_key: str, generation: int) -> bool:
    """Report whether ``generation`` is the latest token for ``session_key``.

    A falsy ``session_key`` is always treated as current. Otherwise the
    stored generation (0 when the session has none recorded) is compared
    with ``generation`` after coercing both to ``int``.
    """
    if session_key:
        # A missing or empty mapping means every session is at generation 0.
        counters = self.__dict__.get("_session_run_generation") or {}
        return int(generation) == int(counters.get(session_key, 0))
    return True
|
||||||
|
|
||||||
def _evict_cached_agent(self, session_key: str) -> None:
|
def _evict_cached_agent(self, session_key: str) -> None:
|
||||||
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
||||||
_lock = getattr(self, "_agent_cache_lock", None)
|
_lock = getattr(self, "_agent_cache_lock", None)
|
||||||
|
|
@ -8807,6 +8903,7 @@ class GatewayRunner:
|
||||||
source: SessionSource,
|
source: SessionSource,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
session_key: str = None,
|
session_key: str = None,
|
||||||
|
run_generation: Optional[int] = None,
|
||||||
_interrupt_depth: int = 0,
|
_interrupt_depth: int = 0,
|
||||||
event_message_id: Optional[str] = None,
|
event_message_id: Optional[str] = None,
|
||||||
channel_prompt: Optional[str] = None,
|
channel_prompt: Optional[str] = None,
|
||||||
|
|
@ -8837,6 +8934,11 @@ class GatewayRunner:
|
||||||
|
|
||||||
from run_agent import AIAgent
|
from run_agent import AIAgent
|
||||||
import queue
|
import queue
|
||||||
|
|
||||||
|
def _run_still_current() -> bool:
|
||||||
|
if run_generation is None or not session_key:
|
||||||
|
return True
|
||||||
|
return self._is_session_run_current(session_key, run_generation)
|
||||||
|
|
||||||
user_config = _load_gateway_config()
|
user_config = _load_gateway_config()
|
||||||
platform_key = _platform_config_key(source.platform)
|
platform_key = _platform_config_key(source.platform)
|
||||||
|
|
@ -8891,7 +8993,7 @@ class GatewayRunner:
|
||||||
|
|
||||||
def progress_callback(event_type: str, tool_name: str = None, preview: str = None, args: dict = None, **kwargs):
|
def progress_callback(event_type: str, tool_name: str = None, preview: str = None, args: dict = None, **kwargs):
|
||||||
"""Callback invoked by agent on tool lifecycle events."""
|
"""Callback invoked by agent on tool lifecycle events."""
|
||||||
if not progress_queue:
|
if not progress_queue or not _run_still_current():
|
||||||
return
|
return
|
||||||
|
|
||||||
# Only act on tool.started events (ignore tool.completed, reasoning.available, etc.)
|
# Only act on tool.started events (ignore tool.completed, reasoning.available, etc.)
|
||||||
|
|
@ -8996,6 +9098,14 @@ class GatewayRunner:
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
if not _run_still_current():
|
||||||
|
while not progress_queue.empty():
|
||||||
|
try:
|
||||||
|
progress_queue.get_nowait()
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
return
|
||||||
|
|
||||||
raw = progress_queue.get_nowait()
|
raw = progress_queue.get_nowait()
|
||||||
|
|
||||||
# Handle dedup messages: update last line with repeat counter
|
# Handle dedup messages: update last line with repeat counter
|
||||||
|
|
@ -9021,6 +9131,9 @@ class GatewayRunner:
|
||||||
await asyncio.sleep(_remaining)
|
await asyncio.sleep(_remaining)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if not _run_still_current():
|
||||||
|
return
|
||||||
|
|
||||||
if can_edit and progress_msg_id is not None:
|
if can_edit and progress_msg_id is not None:
|
||||||
# Try to edit the existing progress message
|
# Try to edit the existing progress message
|
||||||
full_text = "\n".join(progress_lines)
|
full_text = "\n".join(progress_lines)
|
||||||
|
|
@ -9056,7 +9169,8 @@ class GatewayRunner:
|
||||||
|
|
||||||
# Restore typing indicator
|
# Restore typing indicator
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
await adapter.send_typing(source.chat_id, metadata=_progress_metadata)
|
if _run_still_current():
|
||||||
|
await adapter.send_typing(source.chat_id, metadata=_progress_metadata)
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
|
|
@ -9100,6 +9214,8 @@ class GatewayRunner:
|
||||||
_hooks_ref = self.hooks
|
_hooks_ref = self.hooks
|
||||||
|
|
||||||
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
def _step_callback_sync(iteration: int, prev_tools: list) -> None:
|
||||||
|
if not _run_still_current():
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
# prev_tools may be list[str] or list[dict] with "name"/"result"
|
# prev_tools may be list[str] or list[dict] with "name"/"result"
|
||||||
# keys. Normalise to keep "tool_names" backward-compatible for
|
# keys. Normalise to keep "tool_names" backward-compatible for
|
||||||
|
|
@ -9130,7 +9246,7 @@ class GatewayRunner:
|
||||||
_status_thread_metadata = {"thread_id": _progress_thread_id} if _progress_thread_id else None
|
_status_thread_metadata = {"thread_id": _progress_thread_id} if _progress_thread_id else None
|
||||||
|
|
||||||
def _status_callback_sync(event_type: str, message: str) -> None:
|
def _status_callback_sync(event_type: str, message: str) -> None:
|
||||||
if not _status_adapter:
|
if not _status_adapter or not _run_still_current():
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
asyncio.run_coroutine_threadsafe(
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
|
@ -9261,12 +9377,16 @@ class GatewayRunner:
|
||||||
metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None,
|
metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None,
|
||||||
)
|
)
|
||||||
if _want_stream_deltas:
|
if _want_stream_deltas:
|
||||||
_stream_delta_cb = _stream_consumer.on_delta
|
def _stream_delta_cb(text: str) -> None:
|
||||||
|
if _run_still_current():
|
||||||
|
_stream_consumer.on_delta(text)
|
||||||
stream_consumer_holder[0] = _stream_consumer
|
stream_consumer_holder[0] = _stream_consumer
|
||||||
except Exception as _sc_err:
|
except Exception as _sc_err:
|
||||||
logger.debug("Could not set up stream consumer: %s", _sc_err)
|
logger.debug("Could not set up stream consumer: %s", _sc_err)
|
||||||
|
|
||||||
def _interim_assistant_cb(text: str, *, already_streamed: bool = False) -> None:
|
def _interim_assistant_cb(text: str, *, already_streamed: bool = False) -> None:
|
||||||
|
if not _run_still_current():
|
||||||
|
return
|
||||||
if _stream_consumer is not None:
|
if _stream_consumer is not None:
|
||||||
if already_streamed:
|
if already_streamed:
|
||||||
_stream_consumer.on_segment_break()
|
_stream_consumer.on_segment_break()
|
||||||
|
|
@ -9370,7 +9490,7 @@ class GatewayRunner:
|
||||||
_bg_review_pending_lock = threading.Lock()
|
_bg_review_pending_lock = threading.Lock()
|
||||||
|
|
||||||
def _deliver_bg_review_message(message: str) -> None:
|
def _deliver_bg_review_message(message: str) -> None:
|
||||||
if not _status_adapter:
|
if not _status_adapter or not _run_still_current():
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
asyncio.run_coroutine_threadsafe(
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
|
@ -9394,7 +9514,7 @@ class GatewayRunner:
|
||||||
|
|
||||||
# Background review delivery — send "💾 Memory updated" etc. to user
|
# Background review delivery — send "💾 Memory updated" etc. to user
|
||||||
def _bg_review_send(message: str) -> None:
|
def _bg_review_send(message: str) -> None:
|
||||||
if not _status_adapter:
|
if not _status_adapter or not _run_still_current():
|
||||||
return
|
return
|
||||||
if not _bg_review_release.is_set():
|
if not _bg_review_release.is_set():
|
||||||
with _bg_review_pending_lock:
|
with _bg_review_pending_lock:
|
||||||
|
|
@ -10076,7 +10196,15 @@ class GatewayRunner:
|
||||||
if result and adapter and session_key:
|
if result and adapter and session_key:
|
||||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||||
if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
|
if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
|
||||||
pending = result.get("interrupt_message")
|
interrupt_message = result.get("interrupt_message")
|
||||||
|
if _is_control_interrupt_message(interrupt_message):
|
||||||
|
logger.info(
|
||||||
|
"Ignoring control interrupt message for session %s: %s",
|
||||||
|
session_key[:20] if session_key else "?",
|
||||||
|
interrupt_message,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pending = interrupt_message
|
||||||
elif pending_event:
|
elif pending_event:
|
||||||
pending = pending_event.text or _build_media_placeholder(pending_event)
|
pending = pending_event.text or _build_media_placeholder(pending_event)
|
||||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||||
|
|
@ -10229,6 +10357,7 @@ class GatewayRunner:
|
||||||
source=next_source,
|
source=next_source,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
|
run_generation=run_generation,
|
||||||
_interrupt_depth=_interrupt_depth + 1,
|
_interrupt_depth=_interrupt_depth + 1,
|
||||||
event_message_id=next_message_id,
|
event_message_id=next_message_id,
|
||||||
channel_prompt=next_channel_prompt,
|
channel_prompt=next_channel_prompt,
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,18 @@
|
||||||
"""Tests for the pending_event None guard in recursive _run_agent calls.
|
"""Tests for pending follow-up extraction in recursive _run_agent calls.
|
||||||
|
|
||||||
When pending_event is None (Path B: pending comes from interrupt_message),
|
When pending_event is None (Path B: pending comes from interrupt_message),
|
||||||
accessing pending_event.channel_prompt previously raised AttributeError.
|
accessing pending_event.channel_prompt previously raised AttributeError.
|
||||||
This verifies the fix: channel_prompt is captured inside the
|
This verifies the fix: channel_prompt is captured inside the
|
||||||
`if pending_event is not None:` block and falls back to None otherwise.
|
`if pending_event is not None:` block and falls back to None otherwise.
|
||||||
|
|
||||||
|
Also verifies that internal control interrupt reasons like "Stop requested"
|
||||||
|
do not get recycled into the pending-user-message follow-up path.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from gateway.run import _is_control_interrupt_message
|
||||||
|
|
||||||
|
|
||||||
def _extract_channel_prompt(pending_event):
|
def _extract_channel_prompt(pending_event):
|
||||||
"""Reproduce the fixed logic from gateway/run.py.
|
"""Reproduce the fixed logic from gateway/run.py.
|
||||||
|
|
@ -21,6 +26,15 @@ def _extract_channel_prompt(pending_event):
|
||||||
return next_channel_prompt
|
return next_channel_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pending_text(interrupted, pending_event, interrupt_message):
    """Mirror the pending-text selection used in gateway/run.py.

    The interrupt message becomes the follow-up text only when the run was
    interrupted, no pending event exists, the message is non-empty, and it
    is not an internal control reason. Every other case yields ``None``.
    """
    if not interrupted or pending_event is not None or not interrupt_message:
        return None
    if _is_control_interrupt_message(interrupt_message):
        return None
    return interrupt_message
|
||||||
|
|
||||||
|
|
||||||
class TestPendingEventNoneChannelPrompt:
|
class TestPendingEventNoneChannelPrompt:
|
||||||
"""Guard against AttributeError when pending_event is None."""
|
"""Guard against AttributeError when pending_event is None."""
|
||||||
|
|
||||||
|
|
@ -40,3 +54,19 @@ class TestPendingEventNoneChannelPrompt:
|
||||||
event = SimpleNamespace()
|
event = SimpleNamespace()
|
||||||
result = _extract_channel_prompt(event)
|
result = _extract_channel_prompt(event)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestControlInterruptMessages:
    """Internal control reasons are never replayed as follow-up user input."""

    def test_stop_requested_is_not_treated_as_pending_user_message(self):
        assert _extract_pending_text(True, None, "Stop requested") is None

    def test_session_reset_requested_is_not_treated_as_pending_user_message(self):
        assert _extract_pending_text(True, None, "Session reset requested") is None

    def test_real_user_interrupt_message_still_requeues(self):
        # Genuine user text must pass through unchanged.
        pending = _extract_pending_text(True, None, "actually use postgres instead")
        assert pending == "actually use postgres instead"
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,9 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
|
||||||
async def send_typing(self, chat_id, metadata=None) -> None:
|
async def send_typing(self, chat_id, metadata=None) -> None:
|
||||||
self.typing.append({"chat_id": chat_id, "metadata": metadata})
|
self.typing.append({"chat_id": chat_id, "metadata": metadata})
|
||||||
|
|
||||||
|
async def stop_typing(self, chat_id) -> None:
|
||||||
|
self.typing.append({"chat_id": chat_id, "metadata": {"stopped": True}})
|
||||||
|
|
||||||
async def get_chat_info(self, chat_id: str):
|
async def get_chat_info(self, chat_id: str):
|
||||||
return {"id": chat_id}
|
return {"id": chat_id}
|
||||||
|
|
||||||
|
|
@ -90,6 +93,40 @@ class LongPreviewAgent:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DelayedProgressAgent:
    """Stub agent that reports two tool starts with pauses after each."""

    def __init__(self, **kwargs):
        self.tools = []
        self.tool_progress_callback = kwargs.get("tool_progress_callback")

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Emit two ``tool.started`` events, sleeping after each, then finish."""
        # (preview text, seconds to sleep after reporting it)
        schedule = (("first command", 0.45), ("second command", 0.1))
        for preview, pause in schedule:
            self.tool_progress_callback("tool.started", "terminal", preview, {})
            time.sleep(pause)
        return {"final_response": "done", "messages": [], "api_calls": 1}
|
||||||
|
|
||||||
|
|
||||||
|
class DelayedInterimAgent:
    """Stub agent that sends two interim assistant messages with pauses."""

    def __init__(self, **kwargs):
        self.tools = []
        self.interim_assistant_callback = kwargs.get("interim_assistant_callback")

    def run_conversation(self, message, conversation_history=None, task_id=None):
        """Send two interim messages, sleeping after each, then finish."""
        # (interim text, seconds to sleep after sending it)
        schedule = (("first interim", 0.45), ("second interim", 0.1))
        for text, pause in schedule:
            self.interim_assistant_callback(text)
            time.sleep(pause)
        return {"final_response": "done", "messages": [], "api_calls": 1}
|
||||||
|
|
||||||
|
|
||||||
def _make_runner(adapter):
|
def _make_runner(adapter):
|
||||||
gateway_run = importlib.import_module("gateway.run")
|
gateway_run = importlib.import_module("gateway.run")
|
||||||
GatewayRunner = gateway_run.GatewayRunner
|
GatewayRunner = gateway_run.GatewayRunner
|
||||||
|
|
@ -104,6 +141,7 @@ def _make_runner(adapter):
|
||||||
runner._fallback_model = None
|
runner._fallback_model = None
|
||||||
runner._session_db = None
|
runner._session_db = None
|
||||||
runner._running_agents = {}
|
runner._running_agents = {}
|
||||||
|
runner._session_run_generation = {}
|
||||||
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
runner.hooks = SimpleNamespace(loaded_hooks=False)
|
||||||
runner.config = SimpleNamespace(
|
runner.config = SimpleNamespace(
|
||||||
thread_sessions_per_user=False,
|
thread_sessions_per_user=False,
|
||||||
|
|
@ -744,6 +782,154 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send()
|
||||||
assert released == [True]
|
assert released == [True]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
    """Tool progress reported after the run generation is invalidated must not be delivered.

    The adapter's ``send`` is wrapped so that delivering the message that
    contains "first command" invalidates the session's run generation. The
    fake agent then reports "second command", which must appear in neither
    the sent nor the edited messages, while the run itself still returns
    its final response.
    """
    import yaml

    display_config = {"display": {"tool_progress": "all"}}
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.dump(display_config), encoding="utf-8")

    # Stub modules registered in sys.modules before the agent run.
    dotenv_stub = types.ModuleType("dotenv")
    dotenv_stub.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", dotenv_stub)

    run_agent_stub = types.ModuleType("run_agent")
    run_agent_stub.AIAgent = DelayedProgressAgent
    monkeypatch.setitem(sys.modules, "run_agent", run_agent_stub)
    import tools.terminal_tool  # noqa: F401 - register terminal tool metadata

    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    source = SessionSource(
        platform=Platform.DISCORD,
        chat_id="dm-1",
        chat_type="dm",
        thread_id=None,
    )
    session_key = "agent:main:discord:dm:dm-1"
    runner._session_run_generation[session_key] = 1

    original_send = adapter.send
    already_invalidated = False

    async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
        # Deliver first, then invalidate exactly once on the first progress message.
        nonlocal already_invalidated
        delivery = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
        if "first command" in content and not already_invalidated:
            already_invalidated = True
            runner._invalidate_session_run_generation(session_key, reason="test_stop")
        return delivery

    adapter.send = send_and_invalidate

    run_kwargs = dict(
        message="hello",
        context_prompt="",
        history=[],
        source=source,
        session_id="sess-progress-stop",
        session_key=session_key,
        run_generation=1,
    )
    result = await runner._run_agent(**run_kwargs)

    sent_text = " ".join(entry["content"] for entry in adapter.sent)
    edited_text = " ".join(entry["content"] for entry in adapter.edits)
    all_progress_text = sent_text + edited_text

    assert result["final_response"] == "done"
    assert "first command" in all_progress_text
    assert "second command" not in all_progress_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_agent_drops_interim_commentary_after_generation_invalidation(monkeypatch, tmp_path):
    """Interim assistant messages emitted after invalidation must not be sent.

    Tool progress is switched off and interim assistant messages are enabled.
    Delivering "first interim" invalidates the session's run generation, so
    the fake agent's later "second interim" must never reach ``adapter.sent``.
    The run itself still returns its final response.
    """
    import yaml

    display_config = {"display": {"tool_progress": "off", "interim_assistant_messages": True}}
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml.dump(display_config), encoding="utf-8")

    # Stub modules registered in sys.modules before the agent run.
    dotenv_stub = types.ModuleType("dotenv")
    dotenv_stub.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", dotenv_stub)

    run_agent_stub = types.ModuleType("run_agent")
    run_agent_stub.AIAgent = DelayedInterimAgent
    monkeypatch.setitem(sys.modules, "run_agent", run_agent_stub)

    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})

    source = SessionSource(
        platform=Platform.DISCORD,
        chat_id="dm-2",
        chat_type="dm",
        thread_id=None,
    )
    session_key = "agent:main:discord:dm:dm-2"
    runner._session_run_generation[session_key] = 1

    original_send = adapter.send
    already_invalidated = False

    async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
        # Deliver first, then invalidate exactly once on the first interim message.
        nonlocal already_invalidated
        delivery = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
        if content == "first interim" and not already_invalidated:
            already_invalidated = True
            runner._invalidate_session_run_generation(session_key, reason="test_stop")
        return delivery

    adapter.send = send_and_invalidate

    run_kwargs = dict(
        message="hello",
        context_prompt="",
        history=[],
        source=source,
        session_id="sess-commentary-stop",
        session_key=session_key,
        run_generation=1,
    )
    result = await runner._run_agent(**run_kwargs)

    sent_texts = [entry["content"] for entry in adapter.sent]

    assert result["final_response"] == "done"
    assert "first interim" in sent_texts
    assert "second interim" not in sent_texts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_keep_typing_stops_immediately_when_interrupt_event_is_set():
    """Setting the stop event must end ``_keep_typing`` without waiting out the interval.

    The loop runs with a 30 s interval but the task is awaited with a 0.5 s
    timeout, so it can only finish in time if the stop event wakes it.
    Exactly one ordinary typing call and one call whose metadata is
    ``{"stopped": True}`` are expected; the latter is presumably recorded by
    the loop's cleanup path, which is not visible from this test.
    """
    adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
    stop_event = asyncio.Event()

    typing_loop = adapter._keep_typing(
        "dm-typing-stop",
        interval=30.0,
        stop_event=stop_event,
    )
    task = asyncio.create_task(typing_loop)
    await asyncio.sleep(0.05)
    stop_event.set()
    await asyncio.wait_for(task, timeout=0.5)

    # Split the recorded typing calls by whether they carry the stop marker.
    stop_marker = {"stopped": True}
    stopped_calls = []
    normal_typing_calls = []
    for call in adapter.typing:
        if call.get("metadata") == stop_marker:
            stopped_calls.append(call)
        else:
            normal_typing_calls.append(call)

    assert len(normal_typing_calls) == 1
    assert len(stopped_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
|
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
|
||||||
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.
|
"""Verbose mode with default tool_preview_length (0) should NOT truncate args.
|
||||||
|
|
|
||||||
|
|
@ -24,10 +24,18 @@ class _FakeAdapter:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._pending_messages = {}
|
self._pending_messages = {}
|
||||||
|
self._active_sessions = {}
|
||||||
|
self.interrupted_sessions = []
|
||||||
|
|
||||||
async def send(self, chat_id, text, **kwargs):
|
async def send(self, chat_id, text, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def interrupt_session_activity(self, session_key, chat_id):
    """Record the interrupt request and set the session's event if one is registered.

    Every call appends ``(session_key, chat_id)`` to ``interrupted_sessions``.
    A session with no entry in ``_active_sessions`` is recorded but otherwise
    left alone.
    """
    self.interrupted_sessions.append((session_key, chat_id))
    active = self._active_sessions.get(session_key)
    if active is None:
        return
    active.set()
|
||||||
|
|
||||||
|
|
||||||
def _make_runner():
|
def _make_runner():
|
||||||
runner = object.__new__(GatewayRunner)
|
runner = object.__new__(GatewayRunner)
|
||||||
|
|
@ -37,6 +45,7 @@ def _make_runner():
|
||||||
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
|
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
|
||||||
runner._running_agents = {}
|
runner._running_agents = {}
|
||||||
runner._running_agents_ts = {}
|
runner._running_agents_ts = {}
|
||||||
|
runner._session_run_generation = {}
|
||||||
runner._pending_messages = {}
|
runner._pending_messages = {}
|
||||||
runner._pending_approvals = {}
|
runner._pending_approvals = {}
|
||||||
runner._voice_mode = {}
|
runner._voice_mode = {}
|
||||||
|
|
@ -81,7 +90,7 @@ async def test_sentinel_placed_before_agent_setup():
|
||||||
# Patch _handle_message_with_agent to capture state at entry
|
# Patch _handle_message_with_agent to capture state at entry
|
||||||
sentinel_was_set = False
|
sentinel_was_set = False
|
||||||
|
|
||||||
async def mock_inner(self_inner, ev, src, qk):
|
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||||
nonlocal sentinel_was_set
|
nonlocal sentinel_was_set
|
||||||
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
|
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
|
||||||
return "ok"
|
return "ok"
|
||||||
|
|
@ -105,7 +114,7 @@ async def test_sentinel_cleaned_up_after_handler_returns():
|
||||||
event = _make_event()
|
event = _make_event()
|
||||||
session_key = build_session_key(event.source)
|
session_key = build_session_key(event.source)
|
||||||
|
|
||||||
async def mock_inner(self_inner, ev, src, qk):
|
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||||
return "ok"
|
return "ok"
|
||||||
|
|
||||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||||
|
|
@ -127,7 +136,7 @@ async def test_sentinel_cleaned_up_on_exception():
|
||||||
event = _make_event()
|
event = _make_event()
|
||||||
session_key = build_session_key(event.source)
|
session_key = build_session_key(event.source)
|
||||||
|
|
||||||
async def mock_inner(self_inner, ev, src, qk):
|
async def mock_inner(self_inner, ev, src, qk, generation):
|
||||||
raise RuntimeError("boom")
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
|
||||||
|
|
@ -154,7 +163,7 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
|
||||||
|
|
||||||
barrier = asyncio.Event()
|
barrier = asyncio.Event()
|
||||||
|
|
||||||
async def slow_inner(self_inner, ev, src, qk):
|
async def slow_inner(self_inner, ev, src, qk, generation):
|
||||||
# Simulate slow setup — wait until test tells us to proceed
|
# Simulate slow setup — wait until test tells us to proceed
|
||||||
await barrier.wait()
|
await barrier.wait()
|
||||||
return "ok"
|
return "ok"
|
||||||
|
|
@ -333,7 +342,7 @@ async def test_stop_during_sentinel_force_cleans_session():
|
||||||
|
|
||||||
barrier = asyncio.Event()
|
barrier = asyncio.Event()
|
||||||
|
|
||||||
async def slow_inner(self_inner, ev, src, qk):
|
async def slow_inner(self_inner, ev, src, qk, generation):
|
||||||
await barrier.wait()
|
await barrier.wait()
|
||||||
return "ok"
|
return "ok"
|
||||||
|
|
||||||
|
|
@ -381,6 +390,7 @@ async def test_stop_hard_kills_running_agent():
|
||||||
fake_agent = MagicMock()
|
fake_agent = MagicMock()
|
||||||
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
|
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
|
||||||
runner._running_agents[session_key] = fake_agent
|
runner._running_agents[session_key] = fake_agent
|
||||||
|
runner.adapters[Platform.TELEGRAM]._active_sessions[session_key] = asyncio.Event()
|
||||||
|
|
||||||
# Send /stop
|
# Send /stop
|
||||||
stop_event = _make_event(text="/stop")
|
stop_event = _make_event(text="/stop")
|
||||||
|
|
@ -393,6 +403,10 @@ async def test_stop_hard_kills_running_agent():
|
||||||
assert session_key not in runner._running_agents, (
|
assert session_key not in runner._running_agents, (
|
||||||
"/stop must remove the agent from _running_agents so the session is unlocked"
|
"/stop must remove the agent from _running_agents so the session is unlocked"
|
||||||
)
|
)
|
||||||
|
assert runner.adapters[Platform.TELEGRAM].interrupted_sessions == [
|
||||||
|
(session_key, "12345")
|
||||||
|
]
|
||||||
|
assert runner.adapters[Platform.TELEGRAM]._active_sessions[session_key].is_set()
|
||||||
|
|
||||||
# Must return a confirmation
|
# Must return a confirmation
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ def _make_runner(session_entry: SessionEntry):
|
||||||
runner.session_store.rewrite_transcript = MagicMock()
|
runner.session_store.rewrite_transcript = MagicMock()
|
||||||
runner.session_store.update_session = MagicMock()
|
runner.session_store.update_session = MagicMock()
|
||||||
runner._running_agents = {}
|
runner._running_agents = {}
|
||||||
|
runner._session_run_generation = {}
|
||||||
runner._pending_messages = {}
|
runner._pending_messages = {}
|
||||||
runner._pending_approvals = {}
|
runner._pending_approvals = {}
|
||||||
runner._session_db = MagicMock()
|
runner._session_db = MagicMock()
|
||||||
|
|
@ -223,6 +224,52 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch):
    """A result produced after the session run was invalidated must be dropped.

    ``_run_agent`` is replaced by a coroutine that invalidates the session's
    run generation before returning a reply. ``_handle_message`` is then
    expected to return ``None``, leave the transcript and session record
    untouched, and remove the session's post-delivery callback.
    """
    import gateway.run as gateway_run

    session_entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(session_entry)
    earlier_turn = {"role": "user", "content": "earlier"}
    runner.session_store.load_transcript.return_value = [earlier_turn]
    session_key = session_entry.session_key
    runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks = {session_key: object()}

    async def _stale_result(**kwargs):
        # Invalidate the run first so the reply below is already stale when returned.
        runner._invalidate_session_run_generation(kwargs["session_key"], reason="test_stale_result")
        return dict(
            final_response="late reply",
            messages=[],
            tools=[],
            history_offset=0,
            last_prompt_tokens=80,
            input_tokens=120,
            output_tokens=45,
            model="openai/test-model",
        )

    runner._run_agent = AsyncMock(side_effect=_stale_result)

    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )

    result = await runner._handle_message(_make_event("hello"))

    assert result is None
    store = runner.session_store
    store.append_to_transcript.assert_not_called()
    store.update_session.assert_not_called()
    assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_status_command_bypasses_active_session_guard():
|
async def test_status_command_bypasses_active_session_guard():
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue