diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index dc0f22d2a..2b8536062 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -881,10 +881,11 @@ class BasePlatformAdapter(ABC): # working on a task after --replace or manual restarts. self._background_tasks: set[asyncio.Task] = set() # One-shot callbacks to fire after the main response is delivered. - # Keyed by session_key. GatewayRunner uses this to defer - # background-review notifications ("💾 Skill created") until the - # primary reply has been sent. - self._post_delivery_callbacks: Dict[str, Callable] = {} + # Keyed by session_key. Values are either a bare callback (legacy) or + # a ``(generation, callback)`` tuple so GatewayRunner can make deferred + # deliveries generation-aware and avoid stale runs clearing callbacks + # registered by a fresher run for the same session. + self._post_delivery_callbacks: Dict[str, Any] = {} self._expected_cancelled_tasks: set[asyncio.Task] = set() self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None # Chats where auto-TTS on voice input is disabled (set by /voice off) @@ -1471,6 +1472,48 @@ class BasePlatformAdapter(ABC): except Exception: pass + def register_post_delivery_callback( + self, + session_key: str, + callback: Callable, + *, + generation: int | None = None, + ) -> None: + """Register a deferred callback to fire after the main response. + + ``generation`` lets callers tie the callback to a specific gateway run + generation so stale runs cannot clear callbacks owned by a fresher run. + """ + if not session_key or not callable(callback): + return + if generation is None: + self._post_delivery_callbacks[session_key] = callback + else: + self._post_delivery_callbacks[session_key] = (int(generation), callback) + + def pop_post_delivery_callback( + self, + session_key: str, + *, + generation: int | None = None, + ) -> Callable | None: + """Pop a deferred callback, optionally requiring generation ownership.""" + if not session_key: + return None + entry = self._post_delivery_callbacks.get(session_key) + if entry is None: + return None + if isinstance(entry, tuple) and len(entry) == 2: + entry_generation, callback = entry + if generation is not None and int(entry_generation) != int(generation): + return None + self._post_delivery_callbacks.pop(session_key, None) + return callback if callable(callback) else None + if generation is not None: + return None + self._post_delivery_callbacks.pop(session_key, None) + return entry if callable(entry) else None + # ── Processing lifecycle hooks ────────────────────────────────────────── # Subclasses override these to react to message processing events # (e.g. Discord adds 👀/✅/❌ reactions). @@ -1741,6 +1784,7 @@ class BasePlatformAdapter(ABC): # Fall back to a new Event only if the entry was removed externally. interrupt_event = self._active_sessions.get(session_key) or asyncio.Event() self._active_sessions[session_key] = interrupt_event + callback_generation = getattr(interrupt_event, "_hermes_run_generation", None) # Start continuous typing indicator (refreshes every 2 seconds) _thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None @@ -2015,7 +2059,14 @@ class BasePlatformAdapter(ABC): finally: # Fire any one-shot post-delivery callback registered for this # session (e.g. deferred background-review notifications). - _post_cb = getattr(self, "_post_delivery_callbacks", {}).pop(session_key, None) + _callback_generation = callback_generation + if hasattr(self, "pop_post_delivery_callback"): + _post_cb = self.pop_post_delivery_callback( + session_key, + generation=_callback_generation, + ) + else: + _post_cb = getattr(self, "_post_delivery_callbacks", {}).pop(session_key, None) if callable(_post_cb): try: _post_cb() @@ -2061,10 +2112,10 @@ class BasePlatformAdapter(ABC): pass # Leave _active_sessions[session_key] populated — the drain # task's own lifecycle will clean it up. - return - # Clean up session tracking - if session_key in self._active_sessions: - del self._active_sessions[session_key] + else: + # Clean up session tracking + if session_key in self._active_sessions: + del self._active_sessions[session_key] async def cancel_background_tasks(self) -> None: """Cancel any in-flight background message-processing tasks. diff --git a/gateway/run.py b/gateway/run.py index ed3b6b5ee..60c57495b 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -402,14 +402,21 @@ def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None: return adapter.get_pending_message(session_key) +_INTERRUPT_REASON_STOP = "Stop requested" +_INTERRUPT_REASON_RESET = "Session reset requested" +_INTERRUPT_REASON_TIMEOUT = "Execution timed out (inactivity)" +_INTERRUPT_REASON_SSE_DISCONNECT = "SSE client disconnected" +_INTERRUPT_REASON_GATEWAY_SHUTDOWN = "Gateway shutting down" +_INTERRUPT_REASON_GATEWAY_RESTART = "Gateway restarting" + _CONTROL_INTERRUPT_MESSAGES = frozenset( { - "stop requested", - "session reset requested", - "execution timed out (inactivity)", - "sse client disconnected", - "gateway shutting down", - "gateway restarting", + _INTERRUPT_REASON_STOP.lower(), + _INTERRUPT_REASON_RESET.lower(), + _INTERRUPT_REASON_TIMEOUT.lower(), + _INTERRUPT_REASON_SSE_DISCONNECT.lower(), + _INTERRUPT_REASON_GATEWAY_SHUTDOWN.lower(), + _INTERRUPT_REASON_GATEWAY_RESTART.lower(), } ) @@ -2514,7 +2521,7 @@ class GatewayRunner: _sk[:20], _e, ) self._interrupt_running_agents( - "Gateway restarting" if self._restart_requested else "Gateway shutting down" + _INTERRUPT_REASON_GATEWAY_RESTART if self._restart_requested else _INTERRUPT_REASON_GATEWAY_SHUTDOWN ) interrupt_deadline = asyncio.get_running_loop().time() + 5.0 while self._running_agents and asyncio.get_running_loop().time() < interrupt_deadline: @@ -3112,21 +3119,12 @@ class GatewayRunner: # _interrupt_requested. Force-clean _running_agents so the session # is unlocked and subsequent messages are processed normally. if _cmd_def_inner and _cmd_def_inner.name == "stop": - running_agent = self._running_agents.get(_quick_key) - if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: - running_agent.interrupt("Stop requested") - # Force-clean: remove the session lock regardless of agent state - self._invalidate_session_run_generation( + await self._interrupt_and_clear_session( _quick_key, - reason="stop_command", + source, + interrupt_reason=_INTERRUPT_REASON_STOP, + invalidation_reason="stop_command", ) - adapter = self.adapters.get(source.platform) - if adapter and hasattr(adapter, "interrupt_session_activity"): - await adapter.interrupt_session_activity(_quick_key, source.chat_id) - if adapter and hasattr(adapter, 'get_pending_message'): - adapter.get_pending_message(_quick_key) # consume and discard - self._pending_messages.pop(_quick_key, None) - self._release_running_agent_state(_quick_key) logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key[:20]) return "⚡ Stopped. You can continue this session." @@ -3138,23 +3136,15 @@ class GatewayRunner: # doesn't get re-processed as a user message after the # interrupt completes. if _cmd_def_inner and _cmd_def_inner.name == "new": - running_agent = self._running_agents.get(_quick_key) - if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: - running_agent.interrupt("Session reset requested") # Clear any pending messages so the old text doesn't replay - self._invalidate_session_run_generation( + await self._interrupt_and_clear_session( _quick_key, - reason="new_command", + source, + interrupt_reason=_INTERRUPT_REASON_RESET, + invalidation_reason="new_command", ) - adapter = self.adapters.get(source.platform) - if adapter and hasattr(adapter, "interrupt_session_activity"): - await adapter.interrupt_session_activity(_quick_key, source.chat_id) - if adapter and hasattr(adapter, 'get_pending_message'): - adapter.get_pending_message(_quick_key) # consume and discard - self._pending_messages.pop(_quick_key, None) # Clean up the running agent entry so the reset handler # doesn't think an agent is still active. - self._release_running_agent_state(_quick_key) return await self._handle_reset_command(event) # /queue — queue without interrupting @@ -4266,6 +4256,15 @@ class GatewayRunner: if message_text is None: return + # Bind this gateway run generation to the adapter's active-session + # event so deferred post-delivery callbacks can be released by the + # same run that registered them. + self._bind_adapter_run_generation( + self.adapters.get(source.platform), + session_key, + run_generation, + ) + try: # Emit agent:start hook hook_ctx = { @@ -4304,7 +4303,12 @@ class GatewayRunner: run_generation, ) _stale_adapter = self.adapters.get(source.platform) - if _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"): + if getattr(type(_stale_adapter), "pop_post_delivery_callback", None) is not None: + _stale_adapter.pop_post_delivery_callback( + _quick_key, + generation=run_generation, + ) + elif _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"): _stale_adapter._post_delivery_callbacks.pop(_quick_key, None) return None @@ -4982,22 +4986,23 @@ class GatewayRunner: agent = self._running_agents.get(session_key) if agent is _AGENT_PENDING_SENTINEL: # Force-clean the sentinel so the session is unlocked. - self._invalidate_session_run_generation(session_key, reason="stop_command_pending") - adapter = self.adapters.get(source.platform) - if adapter and hasattr(adapter, "interrupt_session_activity"): - await adapter.interrupt_session_activity(session_key, source.chat_id) - self._release_running_agent_state(session_key) + await self._interrupt_and_clear_session( + session_key, + source, + interrupt_reason=_INTERRUPT_REASON_STOP, + invalidation_reason="stop_command_pending", + ) logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20]) return "⚡ Stopped. The agent hadn't started yet — you can continue this session." if agent: - agent.interrupt("Stop requested") # Force-clean the session lock so a truly hung agent doesn't # keep it locked forever. - self._invalidate_session_run_generation(session_key, reason="stop_command_handler") - adapter = self.adapters.get(source.platform) - if adapter and hasattr(adapter, "interrupt_session_activity"): - await adapter.interrupt_session_activity(session_key, source.chat_id) - self._release_running_agent_state(session_key) + await self._interrupt_and_clear_session( + session_key, + source, + interrupt_reason=_INTERRUPT_REASON_STOP, + invalidation_reason="stop_command_handler", + ) return "⚡ Stopped. You can continue this session." else: return "No active task to stop." @@ -8481,6 +8486,47 @@ class GatewayRunner: generations = self.__dict__.get("_session_run_generation") or {} return int(generations.get(session_key, 0)) == int(generation) + def _bind_adapter_run_generation( + self, + adapter: Any, + session_key: str, + generation: int | None, + ) -> None: + """Bind a gateway run generation to the adapter's active-session event.""" + if not adapter or not session_key or generation is None: + return + try: + interrupt_event = getattr(adapter, "_active_sessions", {}).get(session_key) + if interrupt_event is not None: + setattr(interrupt_event, "_hermes_run_generation", int(generation)) + except Exception: + pass + + async def _interrupt_and_clear_session( + self, + session_key: str, + source: SessionSource, + *, + interrupt_reason: str, + invalidation_reason: str, + release_running_state: bool = True, + ) -> None: + """Interrupt the current run and clear queued session state consistently.""" + if not session_key: + return + running_agent = self._running_agents.get(session_key) + if running_agent and running_agent is not _AGENT_PENDING_SENTINEL: + running_agent.interrupt(interrupt_reason) + self._invalidate_session_run_generation(session_key, reason=invalidation_reason) + adapter = self.adapters.get(source.platform) + if adapter and hasattr(adapter, "interrupt_session_activity"): + await adapter.interrupt_session_activity(session_key, source.chat_id) + if adapter and hasattr(adapter, "get_pending_message"): + adapter.get_pending_message(session_key) # consume and discard + self._pending_messages.pop(session_key, None) + if release_running_state: + self._release_running_agent_state(session_key) + def _evict_cached_agent(self, session_key: str) -> None: """Remove a cached agent for a session (called on /new, /model, etc).""" _lock = getattr(self, "_agent_cache_lock", None) @@ -8662,6 +8708,7 @@ class GatewayRunner: source: "SessionSource", session_id: str, session_key: str = None, + run_generation: Optional[int] = None, event_message_id: Optional[str] = None, ) -> Dict[str, Any]: """Forward the message to a remote Hermes API server instead of @@ -8697,6 +8744,11 @@ class GatewayRunner: proxy_key = os.getenv("GATEWAY_PROXY_KEY", "").strip() + def _run_still_current() -> bool: + if run_generation is None or not session_key: + return True + return self._is_session_run_current(session_key, run_generation) + # Build messages in OpenAI chat format -------------------------- # # The remote api_server can maintain session continuity via @@ -8826,6 +8878,21 @@ class GatewayRunner: # Parse SSE stream buffer = "" async for chunk in resp.content.iter_any(): + if not _run_still_current(): + logger.info( + "Discarding stale proxy stream for %s — generation %d is no longer current", + session_key[:20] if session_key else "?", + run_generation or 0, + ) + return { + "final_response": "", + "messages": [], + "api_calls": 0, + "tools": [], + "history_offset": len(history), + "session_id": session_id, + "response_previewed": False, + } text = chunk.decode("utf-8", errors="replace") buffer += text @@ -8875,6 +8942,21 @@ class GatewayRunner: stream_task.cancel() _elapsed = time.time() - _start + if not _run_still_current(): + logger.info( + "Discarding stale proxy result for %s — generation %d is no longer current", + session_key[:20] if session_key else "?", + run_generation or 0, + ) + return { + "final_response": "", + "messages": [], + "api_calls": 0, + "tools": [], + "history_offset": len(history), + "session_id": session_id, + "response_previewed": False, + } logger.info( "proxy response: url=%s session=%s time=%.1fs response=%d chars", proxy_url, (session_id or "")[:20], _elapsed, len(full_response), @@ -8929,6 +9011,7 @@ class GatewayRunner: source=source, session_id=session_id, session_key=session_key, + run_generation=run_generation, event_message_id=event_message_id, ) @@ -9527,9 +9610,16 @@ class GatewayRunner: # Register the release hook on the adapter so base.py's finally # block can fire it after delivering the main response. if _status_adapter and session_key: - _pdc = getattr(_status_adapter, "_post_delivery_callbacks", None) - if _pdc is not None: - _pdc[session_key] = _release_bg_review_messages + if getattr(type(_status_adapter), "register_post_delivery_callback", None) is not None: + _status_adapter.register_post_delivery_callback( + session_key, + _release_bg_review_messages, + generation=run_generation, + ) + else: + _pdc = getattr(_status_adapter, "_post_delivery_callbacks", None) + if _pdc is not None: + _pdc[session_key] = _release_bg_review_messages # Store agent reference for interrupt support agent_holder[0] = agent @@ -10131,7 +10221,7 @@ class GatewayRunner: # Interrupt the agent if it's still running so the thread # pool worker is freed. if _timed_out_agent and hasattr(_timed_out_agent, "interrupt"): - _timed_out_agent.interrupt("Execution timed out (inactivity)") + _timed_out_agent.interrupt(_INTERRUPT_REASON_TIMEOUT) _timeout_mins = int(_agent_timeout // 60) or 1 @@ -10309,7 +10399,17 @@ class GatewayRunner: # first response has been delivered. Pop from the # adapter's callback dict (prevents double-fire in # base.py's finally block) and call it. - if adapter and hasattr(adapter, "_post_delivery_callbacks"): + if getattr(type(adapter), "pop_post_delivery_callback", None) is not None: + _bg_cb = adapter.pop_post_delivery_callback( + session_key, + generation=run_generation, + ) + if callable(_bg_cb): + try: + _bg_cb() + except Exception: + pass + elif adapter and hasattr(adapter, "_post_delivery_callbacks"): _bg_cb = adapter._post_delivery_callbacks.pop(session_key, None) if callable(_bg_cb): try: diff --git a/tests/gateway/test_proxy_mode.py b/tests/gateway/test_proxy_mode.py index f3024cb09..11180639e 100644 --- a/tests/gateway/test_proxy_mode.py +++ b/tests/gateway/test_proxy_mode.py @@ -19,6 +19,7 @@ def _make_runner(proxy_url=None): runner.config = MagicMock() runner.config.streaming = StreamingConfig() runner._running_agents = {} + runner._session_run_generation = {} runner._session_model_overrides = {} runner._agent_cache = {} runner._agent_cache_lock = None @@ -160,10 +161,12 @@ class TestRunAgentProxyDispatch: source=source, session_id="test-session-123", session_key="test-key", + run_generation=7, ) assert result["final_response"] == "Hello from remote!" runner._run_agent_via_proxy.assert_called_once() + assert runner._run_agent_via_proxy.call_args.kwargs["run_generation"] == 7 @pytest.mark.asyncio async def test_run_agent_skips_proxy_when_not_configured(self, monkeypatch): @@ -370,6 +373,40 @@ class TestRunAgentViaProxy: assert "session_id" in result assert result["session_id"] == "sess-123" + @pytest.mark.asyncio + async def test_proxy_stale_generation_returns_empty_result(self, monkeypatch): + monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642") + monkeypatch.delenv("GATEWAY_PROXY_KEY", raising=False) + runner = _make_runner() + source = _make_source() + runner._session_run_generation["test-key"] = 2 + + resp = _FakeSSEResponse( + status=200, + sse_chunks=[ + 'data: {"choices":[{"delta":{"content":"stale"}}]}\n\n', + "data: [DONE]\n\n", + ], + ) + session = _FakeSession(resp) + + with patch("gateway.run._load_gateway_config", return_value={}): + with _patch_aiohttp(session): + with patch("aiohttp.ClientTimeout"): + result = await runner._run_agent_via_proxy( + message="hi", + context_prompt="", + history=[], + source=source, + session_id="sess-123", + session_key="test-key", + run_generation=1, + ) + + assert result["final_response"] == "" + assert result["messages"] == [] + assert result["api_calls"] == 0 + @pytest.mark.asyncio async def test_no_auth_header_without_key(self, monkeypatch): monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642") diff --git a/tests/gateway/test_status_command.py b/tests/gateway/test_status_command.py index 3cdf637dd..50e1c52cc 100644 --- a/tests/gateway/test_status_command.py +++ b/tests/gateway/test_status_command.py @@ -270,6 +270,75 @@ async def test_handle_message_discards_stale_result_after_session_invalidation(m assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks +@pytest.mark.asyncio +async def test_handle_message_stale_result_keeps_newer_generation_callback(monkeypatch): + import gateway.run as gateway_run + + class _Adapter: + def __init__(self): + self._post_delivery_callbacks = {} + + async def send(self, *args, **kwargs): + return None + + def pop_post_delivery_callback(self, session_key, *, generation=None): + entry = self._post_delivery_callbacks.get(session_key) + if entry is None: + return None + if isinstance(entry, tuple): + entry_generation, callback = entry + if generation is not None and entry_generation != generation: + return None + self._post_delivery_callbacks.pop(session_key, None) + return callback + if generation is not None: + return None + return self._post_delivery_callbacks.pop(session_key, None) + + session_entry = SessionEntry( + session_key=build_session_key(_make_source()), + session_id="sess-1", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + runner = _make_runner(session_entry) + runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}] + session_key = session_entry.session_key + adapter = _Adapter() + runner.adapters[Platform.TELEGRAM] = adapter + + async def _stale_result(**kwargs): + # Simulate a newer run claiming the callback slot before the stale run unwinds. + runner._session_run_generation[session_key] = 2 + adapter._post_delivery_callbacks[session_key] = (2, lambda: None) + return { + "final_response": "late reply", + "messages": [], + "tools": [], + "history_offset": 0, + "last_prompt_tokens": 80, + "input_tokens": 120, + "output_tokens": 45, + "model": "openai/test-model", + } + + runner._run_agent = AsyncMock(side_effect=_stale_result) + + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}) + monkeypatch.setattr( + "agent.model_metadata.get_model_context_length", + lambda *_args, **_kwargs: 100000, + ) + + result = await runner._handle_message(_make_event("hello")) + + assert result is None + assert session_key in adapter._post_delivery_callbacks + assert adapter._post_delivery_callbacks[session_key][0] == 2 + + @pytest.mark.asyncio async def test_status_command_bypasses_active_session_guard():