fix: tighten gateway interrupt salvage follow-ups

Follow-up on top of the helix4u #12388 cherry-picks:
- make deferred post-delivery callbacks generation-aware end-to-end so
  stale runs cannot clear callbacks registered by a fresher run for the
  same session
- bind callback ownership to the active session event at run start and
  snapshot that generation inside base adapter processing so later event
  mutation cannot retarget cleanup
- pass run_generation through proxy mode and drop stale proxy streams /
  final results the same way local runs are dropped
- centralize stop/new interrupt cleanup into one helper and replace the
  open-coded branches with shared logic
- unify internal control interrupt reason strings via shared constants
- remove the return from base.py's finally block so cleanup no longer
  swallows cancellation/exception flow
- add focused regressions for generation forwarding, proxy stale
  suppression, and newer-callback preservation

This addresses all review findings from the initial #12388 review while
keeping the fix scoped to stale-output/typing-loop interrupt handling.
This commit is contained in:
kshitijk4poor 2026-04-19 15:05:14 +05:30 committed by kshitij
parent 8466268ca5
commit 4b6ff0eb7f
4 changed files with 315 additions and 58 deletions

View file

@ -881,10 +881,11 @@ class BasePlatformAdapter(ABC):
# working on a task after --replace or manual restarts.
self._background_tasks: set[asyncio.Task] = set()
# One-shot callbacks to fire after the main response is delivered.
# Keyed by session_key. GatewayRunner uses this to defer
# background-review notifications ("💾 Skill created") until the
# primary reply has been sent.
self._post_delivery_callbacks: Dict[str, Callable] = {}
# Keyed by session_key. Values are either a bare callback (legacy) or
# a ``(generation, callback)`` tuple so GatewayRunner can make deferred
# deliveries generation-aware and avoid stale runs clearing callbacks
# registered by a fresher run for the same session.
self._post_delivery_callbacks: Dict[str, Any] = {}
self._expected_cancelled_tasks: set[asyncio.Task] = set()
self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None
# Chats where auto-TTS on voice input is disabled (set by /voice off)
@ -1471,6 +1472,48 @@ class BasePlatformAdapter(ABC):
except Exception:
pass
def register_post_delivery_callback(
self,
session_key: str,
callback: Callable,
*,
generation: int | None = None,
) -> None:
"""Register a deferred callback to fire after the main response.
``generation`` lets callers tie the callback to a specific gateway run
generation so stale runs cannot clear callbacks owned by a fresher run.
"""
if not session_key or not callable(callback):
return
if generation is None:
self._post_delivery_callbacks[session_key] = callback
else:
self._post_delivery_callbacks[session_key] = (int(generation), callback)
def pop_post_delivery_callback(
self,
session_key: str,
*,
generation: int | None = None,
) -> Callable | None:
"""Pop a deferred callback, optionally requiring generation ownership."""
if not session_key:
return None
entry = self._post_delivery_callbacks.get(session_key)
if entry is None:
return None
if isinstance(entry, tuple) and len(entry) == 2:
entry_generation, callback = entry
if generation is not None and int(entry_generation) != int(generation):
return None
self._post_delivery_callbacks.pop(session_key, None)
return callback if callable(callback) else None
if generation is not None:
return None
self._post_delivery_callbacks.pop(session_key, None)
return entry if callable(entry) else None
# ── Processing lifecycle hooks ──────────────────────────────────────────
# Subclasses override these to react to message processing events
# (e.g. Discord adds 👀/✅/❌ reactions).
@ -1741,6 +1784,7 @@ class BasePlatformAdapter(ABC):
# Fall back to a new Event only if the entry was removed externally.
interrupt_event = self._active_sessions.get(session_key) or asyncio.Event()
self._active_sessions[session_key] = interrupt_event
callback_generation = getattr(interrupt_event, "_hermes_run_generation", None)
# Start continuous typing indicator (refreshes every 2 seconds)
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
@ -2015,7 +2059,14 @@ class BasePlatformAdapter(ABC):
finally:
# Fire any one-shot post-delivery callback registered for this
# session (e.g. deferred background-review notifications).
_post_cb = getattr(self, "_post_delivery_callbacks", {}).pop(session_key, None)
_callback_generation = callback_generation
if hasattr(self, "pop_post_delivery_callback"):
_post_cb = self.pop_post_delivery_callback(
session_key,
generation=_callback_generation,
)
else:
_post_cb = getattr(self, "_post_delivery_callbacks", {}).pop(session_key, None)
if callable(_post_cb):
try:
_post_cb()
@ -2061,10 +2112,10 @@ class BasePlatformAdapter(ABC):
pass
# Leave _active_sessions[session_key] populated — the drain
# task's own lifecycle will clean it up.
return
# Clean up session tracking
if session_key in self._active_sessions:
del self._active_sessions[session_key]
else:
# Clean up session tracking
if session_key in self._active_sessions:
del self._active_sessions[session_key]
async def cancel_background_tasks(self) -> None:
"""Cancel any in-flight background message-processing tasks.

View file

@ -402,14 +402,21 @@ def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
return adapter.get_pending_message(session_key)
# Canonical reason strings for gateway-internal ("control") interrupts.
# Call sites pass these to ``interrupt(...)`` instead of repeating the
# literal text, so the wording lives in exactly one place.
_INTERRUPT_REASON_STOP = "Stop requested"
_INTERRUPT_REASON_RESET = "Session reset requested"
_INTERRUPT_REASON_TIMEOUT = "Execution timed out (inactivity)"
_INTERRUPT_REASON_SSE_DISCONNECT = "SSE client disconnected"
_INTERRUPT_REASON_GATEWAY_SHUTDOWN = "Gateway shutting down"
_INTERRUPT_REASON_GATEWAY_RESTART = "Gateway restarting"

# Lower-cased form of every control interrupt reason above. Derived from the
# constants (rather than re-typed literals) so the two can never drift apart.
# NOTE(review): presumably matched against lower-cased interrupt text to tell
# control interrupts from user messages — the consumer is not visible here.
_CONTROL_INTERRUPT_MESSAGES = frozenset(
    reason.lower()
    for reason in (
        _INTERRUPT_REASON_STOP,
        _INTERRUPT_REASON_RESET,
        _INTERRUPT_REASON_TIMEOUT,
        _INTERRUPT_REASON_SSE_DISCONNECT,
        _INTERRUPT_REASON_GATEWAY_SHUTDOWN,
        _INTERRUPT_REASON_GATEWAY_RESTART,
    )
)
@ -2514,7 +2521,7 @@ class GatewayRunner:
_sk[:20], _e,
)
self._interrupt_running_agents(
"Gateway restarting" if self._restart_requested else "Gateway shutting down"
_INTERRUPT_REASON_GATEWAY_RESTART if self._restart_requested else _INTERRUPT_REASON_GATEWAY_SHUTDOWN
)
interrupt_deadline = asyncio.get_running_loop().time() + 5.0
while self._running_agents and asyncio.get_running_loop().time() < interrupt_deadline:
@ -3112,21 +3119,12 @@ class GatewayRunner:
# _interrupt_requested. Force-clean _running_agents so the session
# is unlocked and subsequent messages are processed normally.
if _cmd_def_inner and _cmd_def_inner.name == "stop":
running_agent = self._running_agents.get(_quick_key)
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
running_agent.interrupt("Stop requested")
# Force-clean: remove the session lock regardless of agent state
self._invalidate_session_run_generation(
await self._interrupt_and_clear_session(
_quick_key,
reason="stop_command",
source,
interrupt_reason=_INTERRUPT_REASON_STOP,
invalidation_reason="stop_command",
)
adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, "interrupt_session_activity"):
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
if adapter and hasattr(adapter, 'get_pending_message'):
adapter.get_pending_message(_quick_key) # consume and discard
self._pending_messages.pop(_quick_key, None)
self._release_running_agent_state(_quick_key)
logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key[:20])
return "⚡ Stopped. You can continue this session."
@ -3138,23 +3136,15 @@ class GatewayRunner:
# doesn't get re-processed as a user message after the
# interrupt completes.
if _cmd_def_inner and _cmd_def_inner.name == "new":
running_agent = self._running_agents.get(_quick_key)
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
running_agent.interrupt("Session reset requested")
# Clear any pending messages so the old text doesn't replay
self._invalidate_session_run_generation(
await self._interrupt_and_clear_session(
_quick_key,
reason="new_command",
source,
interrupt_reason=_INTERRUPT_REASON_RESET,
invalidation_reason="new_command",
)
adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, "interrupt_session_activity"):
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
if adapter and hasattr(adapter, 'get_pending_message'):
adapter.get_pending_message(_quick_key) # consume and discard
self._pending_messages.pop(_quick_key, None)
# Clean up the running agent entry so the reset handler
# doesn't think an agent is still active.
self._release_running_agent_state(_quick_key)
return await self._handle_reset_command(event)
# /queue <prompt> — queue without interrupting
@ -4266,6 +4256,15 @@ class GatewayRunner:
if message_text is None:
return
# Bind this gateway run generation to the adapter's active-session
# event so deferred post-delivery callbacks can be released by the
# same run that registered them.
self._bind_adapter_run_generation(
self.adapters.get(source.platform),
session_key,
run_generation,
)
try:
# Emit agent:start hook
hook_ctx = {
@ -4304,7 +4303,12 @@ class GatewayRunner:
run_generation,
)
_stale_adapter = self.adapters.get(source.platform)
if _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"):
if getattr(type(_stale_adapter), "pop_post_delivery_callback", None) is not None:
_stale_adapter.pop_post_delivery_callback(
_quick_key,
generation=run_generation,
)
elif _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"):
_stale_adapter._post_delivery_callbacks.pop(_quick_key, None)
return None
@ -4982,22 +4986,23 @@ class GatewayRunner:
agent = self._running_agents.get(session_key)
if agent is _AGENT_PENDING_SENTINEL:
# Force-clean the sentinel so the session is unlocked.
self._invalidate_session_run_generation(session_key, reason="stop_command_pending")
adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, "interrupt_session_activity"):
await adapter.interrupt_session_activity(session_key, source.chat_id)
self._release_running_agent_state(session_key)
await self._interrupt_and_clear_session(
session_key,
source,
interrupt_reason=_INTERRUPT_REASON_STOP,
invalidation_reason="stop_command_pending",
)
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
if agent:
agent.interrupt("Stop requested")
# Force-clean the session lock so a truly hung agent doesn't
# keep it locked forever.
self._invalidate_session_run_generation(session_key, reason="stop_command_handler")
adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, "interrupt_session_activity"):
await adapter.interrupt_session_activity(session_key, source.chat_id)
self._release_running_agent_state(session_key)
await self._interrupt_and_clear_session(
session_key,
source,
interrupt_reason=_INTERRUPT_REASON_STOP,
invalidation_reason="stop_command_handler",
)
return "⚡ Stopped. You can continue this session."
else:
return "No active task to stop."
@ -8481,6 +8486,47 @@ class GatewayRunner:
generations = self.__dict__.get("_session_run_generation") or {}
return int(generations.get(session_key, 0)) == int(generation)
def _bind_adapter_run_generation(
self,
adapter: Any,
session_key: str,
generation: int | None,
) -> None:
"""Bind a gateway run generation to the adapter's active-session event."""
if not adapter or not session_key or generation is None:
return
try:
interrupt_event = getattr(adapter, "_active_sessions", {}).get(session_key)
if interrupt_event is not None:
setattr(interrupt_event, "_hermes_run_generation", int(generation))
except Exception:
pass
async def _interrupt_and_clear_session(
    self,
    session_key: str,
    source: SessionSource,
    *,
    interrupt_reason: str,
    invalidation_reason: str,
    release_running_state: bool = True,
) -> None:
    """Shared teardown used when a control command interrupts a session.

    Steps, in order:

    1. Interrupt the running agent for the session (skipped when there is
       none or only the pending sentinel is registered).
    2. Invalidate the session's run generation.
    3. Ask the platform adapter, if it supports it, to interrupt session
       activity and hand over (discard) its pending message.
    4. Drop the runner's own pending message for the session.
    5. Optionally release the running-agent state.

    Args:
        session_key: Session to tear down; a falsy key makes this a no-op.
        source: Origin of the command; supplies the platform and chat id.
        interrupt_reason: Text passed to the agent's ``interrupt``.
        invalidation_reason: Reason forwarded to the generation invalidation.
        release_running_state: When true (default), also release the
            running-agent state for the session.
    """
    if not session_key:
        return

    active_agent = self._running_agents.get(session_key)
    if active_agent and active_agent is not _AGENT_PENDING_SENTINEL:
        active_agent.interrupt(interrupt_reason)

    self._invalidate_session_run_generation(session_key, reason=invalidation_reason)

    platform_adapter = self.adapters.get(source.platform)
    if platform_adapter:
        if hasattr(platform_adapter, "interrupt_session_activity"):
            await platform_adapter.interrupt_session_activity(session_key, source.chat_id)
        if hasattr(platform_adapter, "get_pending_message"):
            # Return value intentionally ignored: the call consumes the message.
            platform_adapter.get_pending_message(session_key)
    self._pending_messages.pop(session_key, None)

    if release_running_state:
        self._release_running_agent_state(session_key)
def _evict_cached_agent(self, session_key: str) -> None:
"""Remove a cached agent for a session (called on /new, /model, etc)."""
_lock = getattr(self, "_agent_cache_lock", None)
@ -8662,6 +8708,7 @@ class GatewayRunner:
source: "SessionSource",
session_id: str,
session_key: str = None,
run_generation: Optional[int] = None,
event_message_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Forward the message to a remote Hermes API server instead of
@ -8697,6 +8744,11 @@ class GatewayRunner:
proxy_key = os.getenv("GATEWAY_PROXY_KEY", "").strip()
def _run_still_current() -> bool:
    """Return whether this proxy run may still deliver output.

    Runs without a generation or without a session key are not tracked and
    are always treated as current.
    """
    return (
        run_generation is None
        or not session_key
        or self._is_session_run_current(session_key, run_generation)
    )
# Build messages in OpenAI chat format --------------------------
#
# The remote api_server can maintain session continuity via
@ -8826,6 +8878,21 @@ class GatewayRunner:
# Parse SSE stream
buffer = ""
async for chunk in resp.content.iter_any():
if not _run_still_current():
logger.info(
"Discarding stale proxy stream for %s — generation %d is no longer current",
session_key[:20] if session_key else "?",
run_generation or 0,
)
return {
"final_response": "",
"messages": [],
"api_calls": 0,
"tools": [],
"history_offset": len(history),
"session_id": session_id,
"response_previewed": False,
}
text = chunk.decode("utf-8", errors="replace")
buffer += text
@ -8875,6 +8942,21 @@ class GatewayRunner:
stream_task.cancel()
_elapsed = time.time() - _start
if not _run_still_current():
logger.info(
"Discarding stale proxy result for %s — generation %d is no longer current",
session_key[:20] if session_key else "?",
run_generation or 0,
)
return {
"final_response": "",
"messages": [],
"api_calls": 0,
"tools": [],
"history_offset": len(history),
"session_id": session_id,
"response_previewed": False,
}
logger.info(
"proxy response: url=%s session=%s time=%.1fs response=%d chars",
proxy_url, (session_id or "")[:20], _elapsed, len(full_response),
@ -8929,6 +9011,7 @@ class GatewayRunner:
source=source,
session_id=session_id,
session_key=session_key,
run_generation=run_generation,
event_message_id=event_message_id,
)
@ -9527,9 +9610,16 @@ class GatewayRunner:
# Register the release hook on the adapter so base.py's finally
# block can fire it after delivering the main response.
if _status_adapter and session_key:
_pdc = getattr(_status_adapter, "_post_delivery_callbacks", None)
if _pdc is not None:
_pdc[session_key] = _release_bg_review_messages
if getattr(type(_status_adapter), "register_post_delivery_callback", None) is not None:
_status_adapter.register_post_delivery_callback(
session_key,
_release_bg_review_messages,
generation=run_generation,
)
else:
_pdc = getattr(_status_adapter, "_post_delivery_callbacks", None)
if _pdc is not None:
_pdc[session_key] = _release_bg_review_messages
# Store agent reference for interrupt support
agent_holder[0] = agent
@ -10131,7 +10221,7 @@ class GatewayRunner:
# Interrupt the agent if it's still running so the thread
# pool worker is freed.
if _timed_out_agent and hasattr(_timed_out_agent, "interrupt"):
_timed_out_agent.interrupt("Execution timed out (inactivity)")
_timed_out_agent.interrupt(_INTERRUPT_REASON_TIMEOUT)
_timeout_mins = int(_agent_timeout // 60) or 1
@ -10309,7 +10399,17 @@ class GatewayRunner:
# first response has been delivered. Pop from the
# adapter's callback dict (prevents double-fire in
# base.py's finally block) and call it.
if adapter and hasattr(adapter, "_post_delivery_callbacks"):
if getattr(type(adapter), "pop_post_delivery_callback", None) is not None:
_bg_cb = adapter.pop_post_delivery_callback(
session_key,
generation=run_generation,
)
if callable(_bg_cb):
try:
_bg_cb()
except Exception:
pass
elif adapter and hasattr(adapter, "_post_delivery_callbacks"):
_bg_cb = adapter._post_delivery_callbacks.pop(session_key, None)
if callable(_bg_cb):
try:

View file

@ -19,6 +19,7 @@ def _make_runner(proxy_url=None):
runner.config = MagicMock()
runner.config.streaming = StreamingConfig()
runner._running_agents = {}
runner._session_run_generation = {}
runner._session_model_overrides = {}
runner._agent_cache = {}
runner._agent_cache_lock = None
@ -160,10 +161,12 @@ class TestRunAgentProxyDispatch:
source=source,
session_id="test-session-123",
session_key="test-key",
run_generation=7,
)
assert result["final_response"] == "Hello from remote!"
runner._run_agent_via_proxy.assert_called_once()
assert runner._run_agent_via_proxy.call_args.kwargs["run_generation"] == 7
@pytest.mark.asyncio
async def test_run_agent_skips_proxy_when_not_configured(self, monkeypatch):
@ -370,6 +373,40 @@ class TestRunAgentViaProxy:
assert "session_id" in result
assert result["session_id"] == "sess-123"
@pytest.mark.asyncio
async def test_proxy_stale_generation_returns_empty_result(self, monkeypatch):
    """A proxy run whose generation was superseded returns an empty result.

    The runner records generation 2 for the session while the proxy call is
    made with generation 1, so the streamed "stale" content must be dropped.
    """
    monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
    monkeypatch.delenv("GATEWAY_PROXY_KEY", raising=False)

    runner = _make_runner()
    source = _make_source()
    # The session has already moved on to a newer generation.
    runner._session_run_generation["test-key"] = 2

    stale_stream = _FakeSSEResponse(
        status=200,
        sse_chunks=[
            'data: {"choices":[{"delta":{"content":"stale"}}]}\n\n',
            "data: [DONE]\n\n",
        ],
    )
    fake_session = _FakeSession(stale_stream)

    with patch("gateway.run._load_gateway_config", return_value={}), \
            _patch_aiohttp(fake_session), \
            patch("aiohttp.ClientTimeout"):
        result = await runner._run_agent_via_proxy(
            message="hi",
            context_prompt="",
            history=[],
            source=source,
            session_id="sess-123",
            session_key="test-key",
            run_generation=1,
        )

    observed = {
        field: result[field]
        for field in ("final_response", "messages", "api_calls")
    }
    assert observed == {"final_response": "", "messages": [], "api_calls": 0}
@pytest.mark.asyncio
async def test_no_auth_header_without_key(self, monkeypatch):
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")

View file

@ -270,6 +270,75 @@ async def test_handle_message_discards_stale_result_after_session_invalidation(m
assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
@pytest.mark.asyncio
async def test_handle_message_stale_result_keeps_newer_generation_callback(monkeypatch):
    """A stale run must not clear a callback registered by a newer generation.

    While the (stubbed) agent run is in flight, the session is bumped to
    generation 2 and a generation-2 callback is stored. The stale result is
    discarded and the newer callback must still be present afterwards.
    """
    import gateway.run as gateway_run

    class _Adapter:
        """Minimal adapter stub with a generation-aware callback registry."""

        def __init__(self):
            self._post_delivery_callbacks = {}

        async def send(self, *args, **kwargs):
            return None

        def pop_post_delivery_callback(self, session_key, *, generation=None):
            registry = self._post_delivery_callbacks
            stored = registry.get(session_key)
            if stored is None:
                return None
            if not isinstance(stored, tuple):
                # Bare entries can only be claimed by an untagged pop.
                if generation is not None:
                    return None
                return registry.pop(session_key, None)
            owner_generation, callback = stored
            if generation is not None and owner_generation != generation:
                return None
            registry.pop(session_key, None)
            return callback

    session_entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(session_entry)
    runner.session_store.load_transcript.return_value = [
        {"role": "user", "content": "earlier"},
    ]
    key = session_entry.session_key
    stub_adapter = _Adapter()
    runner.adapters[Platform.TELEGRAM] = stub_adapter

    async def _stale_result(**kwargs):
        # Simulate a newer run claiming the callback slot before the stale run unwinds.
        runner._session_run_generation[key] = 2
        stub_adapter._post_delivery_callbacks[key] = (2, lambda: None)
        return {
            "final_response": "late reply",
            "messages": [],
            "tools": [],
            "history_offset": 0,
            "last_prompt_tokens": 80,
            "input_tokens": 120,
            "output_tokens": 45,
            "model": "openai/test-model",
        }

    runner._run_agent = AsyncMock(side_effect=_stale_result)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )

    outcome = await runner._handle_message(_make_event("hello"))

    assert outcome is None
    remaining = stub_adapter._post_delivery_callbacks
    assert key in remaining
    assert remaining[key][0] == 2
@pytest.mark.asyncio
async def test_status_command_bypasses_active_session_guard():