fix(gateway): stop typing loops on session interrupt

This commit is contained in:
helix4u 2026-04-18 21:21:55 -06:00 committed by kshitij
parent b05d30418d
commit 150382e8b7
6 changed files with 456 additions and 18 deletions

View file

@ -1401,7 +1401,13 @@ class BasePlatformAdapter(ABC):
return paths, cleaned return paths, cleaned
async def _keep_typing(self, chat_id: str, interval: float = 2.0, metadata=None) -> None: async def _keep_typing(
self,
chat_id: str,
interval: float = 2.0,
metadata=None,
stop_event: asyncio.Event | None = None,
) -> None:
""" """
Continuously send typing indicator until cancelled. Continuously send typing indicator until cancelled.
@ -1415,9 +1421,18 @@ class BasePlatformAdapter(ABC):
""" """
try: try:
while True: while True:
if stop_event is not None and stop_event.is_set():
return
if chat_id not in self._typing_paused: if chat_id not in self._typing_paused:
await self.send_typing(chat_id, metadata=metadata) await self.send_typing(chat_id, metadata=metadata)
await asyncio.sleep(interval) if stop_event is None:
await asyncio.sleep(interval)
continue
try:
await asyncio.wait_for(stop_event.wait(), timeout=interval)
except asyncio.TimeoutError:
continue
return
except asyncio.CancelledError: except asyncio.CancelledError:
pass # Normal cancellation when handler completes pass # Normal cancellation when handler completes
finally: finally:
@ -1444,6 +1459,17 @@ class BasePlatformAdapter(ABC):
"""Resume typing indicator for a chat after approval resolves.""" """Resume typing indicator for a chat after approval resolves."""
self._typing_paused.discard(chat_id) self._typing_paused.discard(chat_id)
async def interrupt_session_activity(self, session_key: str, chat_id: str) -> None:
"""Signal the active session loop to stop and clear typing immediately."""
if session_key:
interrupt_event = self._active_sessions.get(session_key)
if interrupt_event is not None:
interrupt_event.set()
try:
await self.stop_typing(chat_id)
except Exception:
pass
# ── Processing lifecycle hooks ────────────────────────────────────────── # ── Processing lifecycle hooks ──────────────────────────────────────────
# Subclasses override these to react to message processing events # Subclasses override these to react to message processing events
# (e.g. Discord adds 👀/✅/❌ reactions). # (e.g. Discord adds 👀/✅/❌ reactions).
@ -1717,7 +1743,13 @@ class BasePlatformAdapter(ABC):
# Start continuous typing indicator (refreshes every 2 seconds) # Start continuous typing indicator (refreshes every 2 seconds)
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None _thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id, metadata=_thread_metadata)) typing_task = asyncio.create_task(
self._keep_typing(
event.source.chat_id,
metadata=_thread_metadata,
stop_event=interrupt_event,
)
)
try: try:
await self._run_processing_hook("on_processing_start", event) await self._run_processing_hook("on_processing_start", event)

View file

@ -402,6 +402,26 @@ def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
return adapter.get_pending_message(session_key) return adapter.get_pending_message(session_key)
_CONTROL_INTERRUPT_MESSAGES = frozenset(
{
"stop requested",
"session reset requested",
"execution timed out (inactivity)",
"sse client disconnected",
"gateway shutting down",
"gateway restarting",
}
)
def _is_control_interrupt_message(message: Optional[str]) -> bool:
"""Return True when an interrupt message is internal control flow."""
if not message:
return False
normalized = " ".join(str(message).strip().split()).lower()
return normalized in _CONTROL_INTERRUPT_MESSAGES
def _check_unavailable_skill(command_name: str) -> str | None: def _check_unavailable_skill(command_name: str) -> str | None:
"""Check if a command matches a known-but-inactive skill. """Check if a command matches a known-but-inactive skill.
@ -630,6 +650,7 @@ class GatewayRunner:
self._running_agents_ts: Dict[str, float] = {} # start timestamp per session self._running_agents_ts: Dict[str, float] = {} # start timestamp per session
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce) self._busy_ack_ts: Dict[str, float] = {} # last busy-ack timestamp per session (debounce)
self._session_run_generation: Dict[str, int] = {}
# Cache AIAgent instances per session to preserve prompt caching. # Cache AIAgent instances per session to preserve prompt caching.
# Without this, a new AIAgent is created per message, rebuilding the # Without this, a new AIAgent is created per message, rebuilding the
@ -3064,6 +3085,10 @@ class GatewayRunner:
_quick_key[:30], _stale_age, _stale_idle, _quick_key[:30], _stale_age, _stale_idle,
_raw_stale_timeout, _stale_detail, _raw_stale_timeout, _stale_detail,
) )
self._invalidate_session_run_generation(
_quick_key,
reason="stale_running_agent_eviction",
)
self._release_running_agent_state(_quick_key) self._release_running_agent_state(_quick_key)
if _quick_key in self._running_agents: if _quick_key in self._running_agents:
@ -3091,7 +3116,13 @@ class GatewayRunner:
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
running_agent.interrupt("Stop requested") running_agent.interrupt("Stop requested")
# Force-clean: remove the session lock regardless of agent state # Force-clean: remove the session lock regardless of agent state
self._invalidate_session_run_generation(
_quick_key,
reason="stop_command",
)
adapter = self.adapters.get(source.platform) adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, "interrupt_session_activity"):
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
if adapter and hasattr(adapter, 'get_pending_message'): if adapter and hasattr(adapter, 'get_pending_message'):
adapter.get_pending_message(_quick_key) # consume and discard adapter.get_pending_message(_quick_key) # consume and discard
self._pending_messages.pop(_quick_key, None) self._pending_messages.pop(_quick_key, None)
@ -3111,7 +3142,13 @@ class GatewayRunner:
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
running_agent.interrupt("Session reset requested") running_agent.interrupt("Session reset requested")
# Clear any pending messages so the old text doesn't replay # Clear any pending messages so the old text doesn't replay
self._invalidate_session_run_generation(
_quick_key,
reason="new_command",
)
adapter = self.adapters.get(source.platform) adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, "interrupt_session_activity"):
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
if adapter and hasattr(adapter, 'get_pending_message'): if adapter and hasattr(adapter, 'get_pending_message'):
adapter.get_pending_message(_quick_key) # consume and discard adapter.get_pending_message(_quick_key) # consume and discard
self._pending_messages.pop(_quick_key, None) self._pending_messages.pop(_quick_key, None)
@ -3598,9 +3635,10 @@ class GatewayRunner:
# same session — corrupting the transcript. # same session — corrupting the transcript.
self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL
self._running_agents_ts[_quick_key] = time.time() self._running_agents_ts[_quick_key] = time.time()
_run_generation = self._begin_session_run_generation(_quick_key)
try: try:
return await self._handle_message_with_agent(event, source, _quick_key) return await self._handle_message_with_agent(event, source, _quick_key, _run_generation)
finally: finally:
# If _run_agent replaced the sentinel with a real agent and # If _run_agent replaced the sentinel with a real agent and
# then cleaned it up, this is a no-op. If we exited early # then cleaned it up, this is a no-op. If we exited early
@ -3771,7 +3809,7 @@ class GatewayRunner:
return message_text return message_text
async def _handle_message_with_agent(self, event, source, _quick_key: str): async def _handle_message_with_agent(self, event, source, _quick_key: str, run_generation: int):
"""Inner handler that runs under the _running_agents sentinel guard.""" """Inner handler that runs under the _running_agents sentinel guard."""
_msg_start_time = time.time() _msg_start_time = time.time()
_platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform) _platform_name = source.platform.value if hasattr(source.platform, "value") else str(source.platform)
@ -4246,6 +4284,7 @@ class GatewayRunner:
source=source, source=source,
session_id=session_entry.session_id, session_id=session_entry.session_id,
session_key=session_key, session_key=session_key,
run_generation=run_generation,
event_message_id=event.message_id, event_message_id=event.message_id,
channel_prompt=event.channel_prompt, channel_prompt=event.channel_prompt,
) )
@ -4258,6 +4297,17 @@ class GatewayRunner:
except Exception: except Exception:
pass pass
if not self._is_session_run_current(_quick_key, run_generation):
logger.info(
"Discarding stale agent result for %s — generation %d is no longer current",
_quick_key[:20] if _quick_key else "?",
run_generation,
)
_stale_adapter = self.adapters.get(source.platform)
if _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"):
_stale_adapter._post_delivery_callbacks.pop(_quick_key, None)
return None
response = agent_result.get("final_response") or "" response = agent_result.get("final_response") or ""
# Convert the agent's internal "(empty)" sentinel into a # Convert the agent's internal "(empty)" sentinel into a
@ -4672,6 +4722,7 @@ class GatewayRunner:
# Get existing session key # Get existing session key
session_key = self._session_key_for_source(source) session_key = self._session_key_for_source(source)
self._invalidate_session_run_generation(session_key, reason="session_reset")
# Flush memories in the background (fire-and-forget) so the user # Flush memories in the background (fire-and-forget) so the user
# gets the "Session reset!" response immediately. # gets the "Session reset!" response immediately.
@ -4931,6 +4982,10 @@ class GatewayRunner:
agent = self._running_agents.get(session_key) agent = self._running_agents.get(session_key)
if agent is _AGENT_PENDING_SENTINEL: if agent is _AGENT_PENDING_SENTINEL:
# Force-clean the sentinel so the session is unlocked. # Force-clean the sentinel so the session is unlocked.
self._invalidate_session_run_generation(session_key, reason="stop_command_pending")
adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, "interrupt_session_activity"):
await adapter.interrupt_session_activity(session_key, source.chat_id)
self._release_running_agent_state(session_key) self._release_running_agent_state(session_key)
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20]) logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
return "⚡ Stopped. The agent hadn't started yet — you can continue this session." return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
@ -4938,6 +4993,10 @@ class GatewayRunner:
agent.interrupt("Stop requested") agent.interrupt("Stop requested")
# Force-clean the session lock so a truly hung agent doesn't # Force-clean the session lock so a truly hung agent doesn't
# keep it locked forever. # keep it locked forever.
self._invalidate_session_run_generation(session_key, reason="stop_command_handler")
adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, "interrupt_session_activity"):
await adapter.interrupt_session_activity(session_key, source.chat_id)
self._release_running_agent_state(session_key) self._release_running_agent_state(session_key)
return "⚡ Stopped. You can continue this session." return "⚡ Stopped. You can continue this session."
else: else:
@ -8385,6 +8444,43 @@ class GatewayRunner:
if hasattr(self, "_busy_ack_ts"): if hasattr(self, "_busy_ack_ts"):
self._busy_ack_ts.pop(session_key, None) self._busy_ack_ts.pop(session_key, None)
def _begin_session_run_generation(self, session_key: str) -> int:
"""Claim a fresh run generation token for ``session_key``.
Every top-level gateway turn gets a monotonically increasing token.
If a later command like /stop or /new invalidates that token while the
old worker is still unwinding, the late result can be recognized and
dropped instead of bleeding into the fresh session.
"""
if not session_key:
return 0
generations = self.__dict__.get("_session_run_generation")
if generations is None:
generations = {}
self._session_run_generation = generations
next_generation = int(generations.get(session_key, 0)) + 1
generations[session_key] = next_generation
return next_generation
def _invalidate_session_run_generation(self, session_key: str, *, reason: str = "") -> int:
"""Invalidate any in-flight run token for ``session_key``."""
generation = self._begin_session_run_generation(session_key)
if reason:
logger.info(
"Invalidated run generation for %s%d (%s)",
session_key[:20],
generation,
reason,
)
return generation
def _is_session_run_current(self, session_key: str, generation: int) -> bool:
"""Return True when ``generation`` is still current for ``session_key``."""
if not session_key:
return True
generations = self.__dict__.get("_session_run_generation") or {}
return int(generations.get(session_key, 0)) == int(generation)
def _evict_cached_agent(self, session_key: str) -> None: def _evict_cached_agent(self, session_key: str) -> None:
"""Remove a cached agent for a session (called on /new, /model, etc).""" """Remove a cached agent for a session (called on /new, /model, etc)."""
_lock = getattr(self, "_agent_cache_lock", None) _lock = getattr(self, "_agent_cache_lock", None)
@ -8807,6 +8903,7 @@ class GatewayRunner:
source: SessionSource, source: SessionSource,
session_id: str, session_id: str,
session_key: str = None, session_key: str = None,
run_generation: Optional[int] = None,
_interrupt_depth: int = 0, _interrupt_depth: int = 0,
event_message_id: Optional[str] = None, event_message_id: Optional[str] = None,
channel_prompt: Optional[str] = None, channel_prompt: Optional[str] = None,
@ -8837,6 +8934,11 @@ class GatewayRunner:
from run_agent import AIAgent from run_agent import AIAgent
import queue import queue
def _run_still_current() -> bool:
if run_generation is None or not session_key:
return True
return self._is_session_run_current(session_key, run_generation)
user_config = _load_gateway_config() user_config = _load_gateway_config()
platform_key = _platform_config_key(source.platform) platform_key = _platform_config_key(source.platform)
@ -8891,7 +8993,7 @@ class GatewayRunner:
def progress_callback(event_type: str, tool_name: str = None, preview: str = None, args: dict = None, **kwargs): def progress_callback(event_type: str, tool_name: str = None, preview: str = None, args: dict = None, **kwargs):
"""Callback invoked by agent on tool lifecycle events.""" """Callback invoked by agent on tool lifecycle events."""
if not progress_queue: if not progress_queue or not _run_still_current():
return return
# Only act on tool.started events (ignore tool.completed, reasoning.available, etc.) # Only act on tool.started events (ignore tool.completed, reasoning.available, etc.)
@ -8996,6 +9098,14 @@ class GatewayRunner:
while True: while True:
try: try:
if not _run_still_current():
while not progress_queue.empty():
try:
progress_queue.get_nowait()
except Exception:
break
return
raw = progress_queue.get_nowait() raw = progress_queue.get_nowait()
# Handle dedup messages: update last line with repeat counter # Handle dedup messages: update last line with repeat counter
@ -9021,6 +9131,9 @@ class GatewayRunner:
await asyncio.sleep(_remaining) await asyncio.sleep(_remaining)
continue continue
if not _run_still_current():
return
if can_edit and progress_msg_id is not None: if can_edit and progress_msg_id is not None:
# Try to edit the existing progress message # Try to edit the existing progress message
full_text = "\n".join(progress_lines) full_text = "\n".join(progress_lines)
@ -9056,7 +9169,8 @@ class GatewayRunner:
# Restore typing indicator # Restore typing indicator
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
await adapter.send_typing(source.chat_id, metadata=_progress_metadata) if _run_still_current():
await adapter.send_typing(source.chat_id, metadata=_progress_metadata)
except queue.Empty: except queue.Empty:
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
@ -9100,6 +9214,8 @@ class GatewayRunner:
_hooks_ref = self.hooks _hooks_ref = self.hooks
def _step_callback_sync(iteration: int, prev_tools: list) -> None: def _step_callback_sync(iteration: int, prev_tools: list) -> None:
if not _run_still_current():
return
try: try:
# prev_tools may be list[str] or list[dict] with "name"/"result" # prev_tools may be list[str] or list[dict] with "name"/"result"
# keys. Normalise to keep "tool_names" backward-compatible for # keys. Normalise to keep "tool_names" backward-compatible for
@ -9130,7 +9246,7 @@ class GatewayRunner:
_status_thread_metadata = {"thread_id": _progress_thread_id} if _progress_thread_id else None _status_thread_metadata = {"thread_id": _progress_thread_id} if _progress_thread_id else None
def _status_callback_sync(event_type: str, message: str) -> None: def _status_callback_sync(event_type: str, message: str) -> None:
if not _status_adapter: if not _status_adapter or not _run_still_current():
return return
try: try:
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
@ -9261,12 +9377,16 @@ class GatewayRunner:
metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None, metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None,
) )
if _want_stream_deltas: if _want_stream_deltas:
_stream_delta_cb = _stream_consumer.on_delta def _stream_delta_cb(text: str) -> None:
if _run_still_current():
_stream_consumer.on_delta(text)
stream_consumer_holder[0] = _stream_consumer stream_consumer_holder[0] = _stream_consumer
except Exception as _sc_err: except Exception as _sc_err:
logger.debug("Could not set up stream consumer: %s", _sc_err) logger.debug("Could not set up stream consumer: %s", _sc_err)
def _interim_assistant_cb(text: str, *, already_streamed: bool = False) -> None: def _interim_assistant_cb(text: str, *, already_streamed: bool = False) -> None:
if not _run_still_current():
return
if _stream_consumer is not None: if _stream_consumer is not None:
if already_streamed: if already_streamed:
_stream_consumer.on_segment_break() _stream_consumer.on_segment_break()
@ -9370,7 +9490,7 @@ class GatewayRunner:
_bg_review_pending_lock = threading.Lock() _bg_review_pending_lock = threading.Lock()
def _deliver_bg_review_message(message: str) -> None: def _deliver_bg_review_message(message: str) -> None:
if not _status_adapter: if not _status_adapter or not _run_still_current():
return return
try: try:
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
@ -9394,7 +9514,7 @@ class GatewayRunner:
# Background review delivery — send "💾 Memory updated" etc. to user # Background review delivery — send "💾 Memory updated" etc. to user
def _bg_review_send(message: str) -> None: def _bg_review_send(message: str) -> None:
if not _status_adapter: if not _status_adapter or not _run_still_current():
return return
if not _bg_review_release.is_set(): if not _bg_review_release.is_set():
with _bg_review_pending_lock: with _bg_review_pending_lock:
@ -10076,7 +10196,15 @@ class GatewayRunner:
if result and adapter and session_key: if result and adapter and session_key:
pending_event = _dequeue_pending_event(adapter, session_key) pending_event = _dequeue_pending_event(adapter, session_key)
if result.get("interrupted") and not pending_event and result.get("interrupt_message"): if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
pending = result.get("interrupt_message") interrupt_message = result.get("interrupt_message")
if _is_control_interrupt_message(interrupt_message):
logger.info(
"Ignoring control interrupt message for session %s: %s",
session_key[:20] if session_key else "?",
interrupt_message,
)
else:
pending = interrupt_message
elif pending_event: elif pending_event:
pending = pending_event.text or _build_media_placeholder(pending_event) pending = pending_event.text or _build_media_placeholder(pending_event)
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
@ -10229,6 +10357,7 @@ class GatewayRunner:
source=next_source, source=next_source,
session_id=session_id, session_id=session_id,
session_key=session_key, session_key=session_key,
run_generation=run_generation,
_interrupt_depth=_interrupt_depth + 1, _interrupt_depth=_interrupt_depth + 1,
event_message_id=next_message_id, event_message_id=next_message_id,
channel_prompt=next_channel_prompt, channel_prompt=next_channel_prompt,

View file

@ -1,13 +1,18 @@
"""Tests for the pending_event None guard in recursive _run_agent calls. """Tests for pending follow-up extraction in recursive _run_agent calls.
When pending_event is None (Path B: pending comes from interrupt_message), When pending_event is None (Path B: pending comes from interrupt_message),
accessing pending_event.channel_prompt previously raised AttributeError. accessing pending_event.channel_prompt previously raised AttributeError.
This verifies the fix: channel_prompt is captured inside the This verifies the fix: channel_prompt is captured inside the
`if pending_event is not None:` block and falls back to None otherwise. `if pending_event is not None:` block and falls back to None otherwise.
Also verifies that internal control interrupt reasons like "Stop requested"
do not get recycled into the pending-user-message follow-up path.
""" """
from types import SimpleNamespace from types import SimpleNamespace
from gateway.run import _is_control_interrupt_message
def _extract_channel_prompt(pending_event): def _extract_channel_prompt(pending_event):
"""Reproduce the fixed logic from gateway/run.py. """Reproduce the fixed logic from gateway/run.py.
@ -21,6 +26,15 @@ def _extract_channel_prompt(pending_event):
return next_channel_prompt return next_channel_prompt
def _extract_pending_text(interrupted, pending_event, interrupt_message):
"""Reproduce the fixed pending-text selection from gateway/run.py."""
if interrupted and pending_event is None and interrupt_message:
if _is_control_interrupt_message(interrupt_message):
return None
return interrupt_message
return None
class TestPendingEventNoneChannelPrompt: class TestPendingEventNoneChannelPrompt:
"""Guard against AttributeError when pending_event is None.""" """Guard against AttributeError when pending_event is None."""
@ -40,3 +54,19 @@ class TestPendingEventNoneChannelPrompt:
event = SimpleNamespace() event = SimpleNamespace()
result = _extract_channel_prompt(event) result = _extract_channel_prompt(event)
assert result is None assert result is None
class TestControlInterruptMessages:
"""Control interrupt reasons must not become follow-up user input."""
def test_stop_requested_is_not_treated_as_pending_user_message(self):
result = _extract_pending_text(True, None, "Stop requested")
assert result is None
def test_session_reset_requested_is_not_treated_as_pending_user_message(self):
result = _extract_pending_text(True, None, "Session reset requested")
assert result is None
def test_real_user_interrupt_message_still_requeues(self):
result = _extract_pending_text(True, None, "actually use postgres instead")
assert result == "actually use postgres instead"

View file

@ -51,6 +51,9 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
async def send_typing(self, chat_id, metadata=None) -> None: async def send_typing(self, chat_id, metadata=None) -> None:
self.typing.append({"chat_id": chat_id, "metadata": metadata}) self.typing.append({"chat_id": chat_id, "metadata": metadata})
async def stop_typing(self, chat_id) -> None:
self.typing.append({"chat_id": chat_id, "metadata": {"stopped": True}})
async def get_chat_info(self, chat_id: str): async def get_chat_info(self, chat_id: str):
return {"id": chat_id} return {"id": chat_id}
@ -90,6 +93,40 @@ class LongPreviewAgent:
} }
class DelayedProgressAgent:
def __init__(self, **kwargs):
self.tool_progress_callback = kwargs.get("tool_progress_callback")
self.tools = []
def run_conversation(self, message, conversation_history=None, task_id=None):
self.tool_progress_callback("tool.started", "terminal", "first command", {})
time.sleep(0.45)
self.tool_progress_callback("tool.started", "terminal", "second command", {})
time.sleep(0.1)
return {
"final_response": "done",
"messages": [],
"api_calls": 1,
}
class DelayedInterimAgent:
def __init__(self, **kwargs):
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
self.tools = []
def run_conversation(self, message, conversation_history=None, task_id=None):
self.interim_assistant_callback("first interim")
time.sleep(0.45)
self.interim_assistant_callback("second interim")
time.sleep(0.1)
return {
"final_response": "done",
"messages": [],
"api_calls": 1,
}
def _make_runner(adapter): def _make_runner(adapter):
gateway_run = importlib.import_module("gateway.run") gateway_run = importlib.import_module("gateway.run")
GatewayRunner = gateway_run.GatewayRunner GatewayRunner = gateway_run.GatewayRunner
@ -104,6 +141,7 @@ def _make_runner(adapter):
runner._fallback_model = None runner._fallback_model = None
runner._session_db = None runner._session_db = None
runner._running_agents = {} runner._running_agents = {}
runner._session_run_generation = {}
runner.hooks = SimpleNamespace(loaded_hooks=False) runner.hooks = SimpleNamespace(loaded_hooks=False)
runner.config = SimpleNamespace( runner.config = SimpleNamespace(
thread_sessions_per_user=False, thread_sessions_per_user=False,
@ -744,6 +782,154 @@ async def test_base_processing_releases_post_delivery_callback_after_main_send()
assert released == [True] assert released == [True]
@pytest.mark.asyncio
async def test_run_agent_drops_tool_progress_after_generation_invalidation(monkeypatch, tmp_path):
import yaml
(tmp_path / "config.yaml").write_text(
yaml.dump({"display": {"tool_progress": "all"}}),
encoding="utf-8",
)
fake_dotenv = types.ModuleType("dotenv")
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = DelayedProgressAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
import tools.terminal_tool # noqa: F401 - register terminal tool metadata
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
runner = _make_runner(adapter)
gateway_run = importlib.import_module("gateway.run")
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
source = SessionSource(
platform=Platform.DISCORD,
chat_id="dm-1",
chat_type="dm",
thread_id=None,
)
session_key = "agent:main:discord:dm:dm-1"
runner._session_run_generation[session_key] = 1
original_send = adapter.send
invalidated = {"done": False}
async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
if "first command" in content and not invalidated["done"]:
invalidated["done"] = True
runner._invalidate_session_run_generation(session_key, reason="test_stop")
return result
adapter.send = send_and_invalidate
result = await runner._run_agent(
message="hello",
context_prompt="",
history=[],
source=source,
session_id="sess-progress-stop",
session_key=session_key,
run_generation=1,
)
all_progress_text = " ".join(call["content"] for call in adapter.sent)
all_progress_text += " ".join(call["content"] for call in adapter.edits)
assert result["final_response"] == "done"
assert 'first command' in all_progress_text
assert 'second command' not in all_progress_text
@pytest.mark.asyncio
async def test_run_agent_drops_interim_commentary_after_generation_invalidation(monkeypatch, tmp_path):
import yaml
(tmp_path / "config.yaml").write_text(
yaml.dump({"display": {"tool_progress": "off", "interim_assistant_messages": True}}),
encoding="utf-8",
)
fake_dotenv = types.ModuleType("dotenv")
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = DelayedInterimAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
runner = _make_runner(adapter)
gateway_run = importlib.import_module("gateway.run")
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
source = SessionSource(
platform=Platform.DISCORD,
chat_id="dm-2",
chat_type="dm",
thread_id=None,
)
session_key = "agent:main:discord:dm:dm-2"
runner._session_run_generation[session_key] = 1
original_send = adapter.send
invalidated = {"done": False}
async def send_and_invalidate(chat_id, content, reply_to=None, metadata=None):
result = await original_send(chat_id, content, reply_to=reply_to, metadata=metadata)
if content == "first interim" and not invalidated["done"]:
invalidated["done"] = True
runner._invalidate_session_run_generation(session_key, reason="test_stop")
return result
adapter.send = send_and_invalidate
result = await runner._run_agent(
message="hello",
context_prompt="",
history=[],
source=source,
session_id="sess-commentary-stop",
session_key=session_key,
run_generation=1,
)
sent_texts = [call["content"] for call in adapter.sent]
assert result["final_response"] == "done"
assert "first interim" in sent_texts
assert "second interim" not in sent_texts
@pytest.mark.asyncio
async def test_keep_typing_stops_immediately_when_interrupt_event_is_set():
adapter = ProgressCaptureAdapter(platform=Platform.DISCORD)
stop_event = asyncio.Event()
task = asyncio.create_task(
adapter._keep_typing(
"dm-typing-stop",
interval=30.0,
stop_event=stop_event,
)
)
await asyncio.sleep(0.05)
stop_event.set()
await asyncio.wait_for(task, timeout=0.5)
normal_typing_calls = [
call for call in adapter.typing if call.get("metadata") != {"stopped": True}
]
stopped_calls = [
call for call in adapter.typing if call.get("metadata") == {"stopped": True}
]
assert len(normal_typing_calls) == 1
assert len(stopped_calls) == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path): async def test_verbose_mode_does_not_truncate_args_by_default(monkeypatch, tmp_path):
"""Verbose mode with default tool_preview_length (0) should NOT truncate args. """Verbose mode with default tool_preview_length (0) should NOT truncate args.

View file

@ -24,10 +24,18 @@ class _FakeAdapter:
def __init__(self): def __init__(self):
self._pending_messages = {} self._pending_messages = {}
self._active_sessions = {}
self.interrupted_sessions = []
async def send(self, chat_id, text, **kwargs): async def send(self, chat_id, text, **kwargs):
pass pass
async def interrupt_session_activity(self, session_key, chat_id):
self.interrupted_sessions.append((session_key, chat_id))
event = self._active_sessions.get(session_key)
if event is not None:
event.set()
def _make_runner(): def _make_runner():
runner = object.__new__(GatewayRunner) runner = object.__new__(GatewayRunner)
@ -37,6 +45,7 @@ def _make_runner():
runner.adapters = {Platform.TELEGRAM: _FakeAdapter()} runner.adapters = {Platform.TELEGRAM: _FakeAdapter()}
runner._running_agents = {} runner._running_agents = {}
runner._running_agents_ts = {} runner._running_agents_ts = {}
runner._session_run_generation = {}
runner._pending_messages = {} runner._pending_messages = {}
runner._pending_approvals = {} runner._pending_approvals = {}
runner._voice_mode = {} runner._voice_mode = {}
@ -81,7 +90,7 @@ async def test_sentinel_placed_before_agent_setup():
# Patch _handle_message_with_agent to capture state at entry # Patch _handle_message_with_agent to capture state at entry
sentinel_was_set = False sentinel_was_set = False
async def mock_inner(self_inner, ev, src, qk): async def mock_inner(self_inner, ev, src, qk, generation):
nonlocal sentinel_was_set nonlocal sentinel_was_set
sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL sentinel_was_set = runner._running_agents.get(qk) is _AGENT_PENDING_SENTINEL
return "ok" return "ok"
@ -105,7 +114,7 @@ async def test_sentinel_cleaned_up_after_handler_returns():
event = _make_event() event = _make_event()
session_key = build_session_key(event.source) session_key = build_session_key(event.source)
async def mock_inner(self_inner, ev, src, qk): async def mock_inner(self_inner, ev, src, qk, generation):
return "ok" return "ok"
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner): with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
@ -127,7 +136,7 @@ async def test_sentinel_cleaned_up_on_exception():
event = _make_event() event = _make_event()
session_key = build_session_key(event.source) session_key = build_session_key(event.source)
async def mock_inner(self_inner, ev, src, qk): async def mock_inner(self_inner, ev, src, qk, generation):
raise RuntimeError("boom") raise RuntimeError("boom")
with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner): with patch.object(GatewayRunner, "_handle_message_with_agent", mock_inner):
@ -154,7 +163,7 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
barrier = asyncio.Event() barrier = asyncio.Event()
async def slow_inner(self_inner, ev, src, qk): async def slow_inner(self_inner, ev, src, qk, generation):
# Simulate slow setup — wait until test tells us to proceed # Simulate slow setup — wait until test tells us to proceed
await barrier.wait() await barrier.wait()
return "ok" return "ok"
@ -333,7 +342,7 @@ async def test_stop_during_sentinel_force_cleans_session():
barrier = asyncio.Event() barrier = asyncio.Event()
async def slow_inner(self_inner, ev, src, qk): async def slow_inner(self_inner, ev, src, qk, generation):
await barrier.wait() await barrier.wait()
return "ok" return "ok"
@ -381,6 +390,7 @@ async def test_stop_hard_kills_running_agent():
fake_agent = MagicMock() fake_agent = MagicMock()
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0} fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
runner._running_agents[session_key] = fake_agent runner._running_agents[session_key] = fake_agent
runner.adapters[Platform.TELEGRAM]._active_sessions[session_key] = asyncio.Event()
# Send /stop # Send /stop
stop_event = _make_event(text="/stop") stop_event = _make_event(text="/stop")
@ -393,6 +403,10 @@ async def test_stop_hard_kills_running_agent():
assert session_key not in runner._running_agents, ( assert session_key not in runner._running_agents, (
"/stop must remove the agent from _running_agents so the session is unlocked" "/stop must remove the agent from _running_agents so the session is unlocked"
) )
assert runner.adapters[Platform.TELEGRAM].interrupted_sessions == [
(session_key, "12345")
]
assert runner.adapters[Platform.TELEGRAM]._active_sessions[session_key].is_set()
# Must return a confirmation # Must return a confirmation
assert result is not None assert result is not None

View file

@ -50,6 +50,7 @@ def _make_runner(session_entry: SessionEntry):
runner.session_store.rewrite_transcript = MagicMock() runner.session_store.rewrite_transcript = MagicMock()
runner.session_store.update_session = MagicMock() runner.session_store.update_session = MagicMock()
runner._running_agents = {} runner._running_agents = {}
runner._session_run_generation = {}
runner._pending_messages = {} runner._pending_messages = {}
runner._pending_approvals = {} runner._pending_approvals = {}
runner._session_db = MagicMock() runner._session_db = MagicMock()
@ -223,6 +224,52 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
) )
@pytest.mark.asyncio
async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch):
import gateway.run as gateway_run
session_entry = SessionEntry(
session_key=build_session_key(_make_source()),
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
runner = _make_runner(session_entry)
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
session_key = session_entry.session_key
runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks = {session_key: object()}
async def _stale_result(**kwargs):
runner._invalidate_session_run_generation(kwargs["session_key"], reason="test_stale_result")
return {
"final_response": "late reply",
"messages": [],
"tools": [],
"history_offset": 0,
"last_prompt_tokens": 80,
"input_tokens": 120,
"output_tokens": 45,
"model": "openai/test-model",
}
runner._run_agent = AsyncMock(side_effect=_stale_result)
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
monkeypatch.setattr(
"agent.model_metadata.get_model_context_length",
lambda *_args, **_kwargs: 100000,
)
result = await runner._handle_message(_make_event("hello"))
assert result is None
runner.session_store.append_to_transcript.assert_not_called()
runner.session_store.update_session.assert_not_called()
assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_status_command_bypasses_active_session_guard(): async def test_status_command_bypasses_active_session_guard():