mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix: tighten gateway interrupt salvage follow-ups
Follow-up on top of the helix4u #12388 cherry-picks: - make deferred post-delivery callbacks generation-aware end-to-end so stale runs cannot clear callbacks registered by a fresher run for the same session - bind callback ownership to the active session event at run start and snapshot that generation inside base adapter processing so later event mutation cannot retarget cleanup - pass run_generation through proxy mode and drop stale proxy streams / final results the same way local runs are dropped - centralize stop/new interrupt cleanup into one helper and replace the open-coded branches with shared logic - unify internal control interrupt reason strings via shared constants - remove the return from base.py's finally block so cleanup no longer swallows cancellation/exception flow - add focused regressions for generation forwarding, proxy stale suppression, and newer-callback preservation This addresses all review findings from the initial #12388 review while keeping the fix scoped to stale-output/typing-loop interrupt handling.
This commit is contained in:
parent
8466268ca5
commit
4b6ff0eb7f
4 changed files with 315 additions and 58 deletions
|
|
@ -881,10 +881,11 @@ class BasePlatformAdapter(ABC):
|
|||
# working on a task after --replace or manual restarts.
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
# One-shot callbacks to fire after the main response is delivered.
|
||||
# Keyed by session_key. GatewayRunner uses this to defer
|
||||
# background-review notifications ("💾 Skill created") until the
|
||||
# primary reply has been sent.
|
||||
self._post_delivery_callbacks: Dict[str, Callable] = {}
|
||||
# Keyed by session_key. Values are either a bare callback (legacy) or
|
||||
# a ``(generation, callback)`` tuple so GatewayRunner can make deferred
|
||||
# deliveries generation-aware and avoid stale runs clearing callbacks
|
||||
# registered by a fresher run for the same session.
|
||||
self._post_delivery_callbacks: Dict[str, Any] = {}
|
||||
self._expected_cancelled_tasks: set[asyncio.Task] = set()
|
||||
self._busy_session_handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]] = None
|
||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||
|
|
@ -1471,6 +1472,48 @@ class BasePlatformAdapter(ABC):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
def register_post_delivery_callback(
|
||||
self,
|
||||
session_key: str,
|
||||
callback: Callable,
|
||||
*,
|
||||
generation: int | None = None,
|
||||
) -> None:
|
||||
"""Register a deferred callback to fire after the main response.
|
||||
|
||||
``generation`` lets callers tie the callback to a specific gateway run
|
||||
generation so stale runs cannot clear callbacks owned by a fresher run.
|
||||
"""
|
||||
if not session_key or not callable(callback):
|
||||
return
|
||||
if generation is None:
|
||||
self._post_delivery_callbacks[session_key] = callback
|
||||
else:
|
||||
self._post_delivery_callbacks[session_key] = (int(generation), callback)
|
||||
|
||||
def pop_post_delivery_callback(
|
||||
self,
|
||||
session_key: str,
|
||||
*,
|
||||
generation: int | None = None,
|
||||
) -> Callable | None:
|
||||
"""Pop a deferred callback, optionally requiring generation ownership."""
|
||||
if not session_key:
|
||||
return None
|
||||
entry = self._post_delivery_callbacks.get(session_key)
|
||||
if entry is None:
|
||||
return None
|
||||
if isinstance(entry, tuple) and len(entry) == 2:
|
||||
entry_generation, callback = entry
|
||||
if generation is not None and int(entry_generation) != int(generation):
|
||||
return None
|
||||
self._post_delivery_callbacks.pop(session_key, None)
|
||||
return callback if callable(callback) else None
|
||||
if generation is not None:
|
||||
return None
|
||||
self._post_delivery_callbacks.pop(session_key, None)
|
||||
return entry if callable(entry) else None
|
||||
|
||||
# ── Processing lifecycle hooks ──────────────────────────────────────────
|
||||
# Subclasses override these to react to message processing events
|
||||
# (e.g. Discord adds 👀/✅/❌ reactions).
|
||||
|
|
@ -1741,6 +1784,7 @@ class BasePlatformAdapter(ABC):
|
|||
# Fall back to a new Event only if the entry was removed externally.
|
||||
interrupt_event = self._active_sessions.get(session_key) or asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
callback_generation = getattr(interrupt_event, "_hermes_run_generation", None)
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
_thread_metadata = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
||||
|
|
@ -2015,7 +2059,14 @@ class BasePlatformAdapter(ABC):
|
|||
finally:
|
||||
# Fire any one-shot post-delivery callback registered for this
|
||||
# session (e.g. deferred background-review notifications).
|
||||
_post_cb = getattr(self, "_post_delivery_callbacks", {}).pop(session_key, None)
|
||||
_callback_generation = callback_generation
|
||||
if hasattr(self, "pop_post_delivery_callback"):
|
||||
_post_cb = self.pop_post_delivery_callback(
|
||||
session_key,
|
||||
generation=_callback_generation,
|
||||
)
|
||||
else:
|
||||
_post_cb = getattr(self, "_post_delivery_callbacks", {}).pop(session_key, None)
|
||||
if callable(_post_cb):
|
||||
try:
|
||||
_post_cb()
|
||||
|
|
@ -2061,10 +2112,10 @@ class BasePlatformAdapter(ABC):
|
|||
pass
|
||||
# Leave _active_sessions[session_key] populated — the drain
|
||||
# task's own lifecycle will clean it up.
|
||||
return
|
||||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
else:
|
||||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
async def cancel_background_tasks(self) -> None:
|
||||
"""Cancel any in-flight background message-processing tasks.
|
||||
|
|
|
|||
198
gateway/run.py
198
gateway/run.py
|
|
@ -402,14 +402,21 @@ def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
|
|||
return adapter.get_pending_message(session_key)
|
||||
|
||||
|
||||
# Canonical reason strings for internally-generated interrupts. Every control
# path (stop/new commands, inactivity timeout, SSE disconnect, gateway
# shutdown/restart) shares these constants instead of open-coding the text.
_INTERRUPT_REASON_STOP = "Stop requested"
_INTERRUPT_REASON_RESET = "Session reset requested"
_INTERRUPT_REASON_TIMEOUT = "Execution timed out (inactivity)"
_INTERRUPT_REASON_SSE_DISCONNECT = "SSE client disconnected"
_INTERRUPT_REASON_GATEWAY_SHUTDOWN = "Gateway shutting down"
_INTERRUPT_REASON_GATEWAY_RESTART = "Gateway restarting"

# Lower-cased lookup set used to recognize control interrupts; derived from
# the constants above so the two can never drift apart.
_CONTROL_INTERRUPT_MESSAGES = frozenset(
    _reason.lower()
    for _reason in (
        _INTERRUPT_REASON_STOP,
        _INTERRUPT_REASON_RESET,
        _INTERRUPT_REASON_TIMEOUT,
        _INTERRUPT_REASON_SSE_DISCONNECT,
        _INTERRUPT_REASON_GATEWAY_SHUTDOWN,
        _INTERRUPT_REASON_GATEWAY_RESTART,
    )
)
|
||||
|
||||
|
|
@ -2514,7 +2521,7 @@ class GatewayRunner:
|
|||
_sk[:20], _e,
|
||||
)
|
||||
self._interrupt_running_agents(
|
||||
"Gateway restarting" if self._restart_requested else "Gateway shutting down"
|
||||
_INTERRUPT_REASON_GATEWAY_RESTART if self._restart_requested else _INTERRUPT_REASON_GATEWAY_SHUTDOWN
|
||||
)
|
||||
interrupt_deadline = asyncio.get_running_loop().time() + 5.0
|
||||
while self._running_agents and asyncio.get_running_loop().time() < interrupt_deadline:
|
||||
|
|
@ -3112,21 +3119,12 @@ class GatewayRunner:
|
|||
# _interrupt_requested. Force-clean _running_agents so the session
|
||||
# is unlocked and subsequent messages are processed normally.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "stop":
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||
running_agent.interrupt("Stop requested")
|
||||
# Force-clean: remove the session lock regardless of agent state
|
||||
self._invalidate_session_run_generation(
|
||||
await self._interrupt_and_clear_session(
|
||||
_quick_key,
|
||||
reason="stop_command",
|
||||
source,
|
||||
interrupt_reason=_INTERRUPT_REASON_STOP,
|
||||
invalidation_reason="stop_command",
|
||||
)
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
|
||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||
adapter.get_pending_message(_quick_key) # consume and discard
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
self._release_running_agent_state(_quick_key)
|
||||
logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key[:20])
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
|
||||
|
|
@ -3138,23 +3136,15 @@ class GatewayRunner:
|
|||
# doesn't get re-processed as a user message after the
|
||||
# interrupt completes.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "new":
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
if running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
|
||||
running_agent.interrupt("Session reset requested")
|
||||
# Clear any pending messages so the old text doesn't replay
|
||||
self._invalidate_session_run_generation(
|
||||
await self._interrupt_and_clear_session(
|
||||
_quick_key,
|
||||
reason="new_command",
|
||||
source,
|
||||
interrupt_reason=_INTERRUPT_REASON_RESET,
|
||||
invalidation_reason="new_command",
|
||||
)
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||
await adapter.interrupt_session_activity(_quick_key, source.chat_id)
|
||||
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||
adapter.get_pending_message(_quick_key) # consume and discard
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
# Clean up the running agent entry so the reset handler
|
||||
# doesn't think an agent is still active.
|
||||
self._release_running_agent_state(_quick_key)
|
||||
return await self._handle_reset_command(event)
|
||||
|
||||
# /queue <prompt> — queue without interrupting
|
||||
|
|
@ -4266,6 +4256,15 @@ class GatewayRunner:
|
|||
if message_text is None:
|
||||
return
|
||||
|
||||
# Bind this gateway run generation to the adapter's active-session
|
||||
# event so deferred post-delivery callbacks can be released by the
|
||||
# same run that registered them.
|
||||
self._bind_adapter_run_generation(
|
||||
self.adapters.get(source.platform),
|
||||
session_key,
|
||||
run_generation,
|
||||
)
|
||||
|
||||
try:
|
||||
# Emit agent:start hook
|
||||
hook_ctx = {
|
||||
|
|
@ -4304,7 +4303,12 @@ class GatewayRunner:
|
|||
run_generation,
|
||||
)
|
||||
_stale_adapter = self.adapters.get(source.platform)
|
||||
if _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"):
|
||||
if getattr(type(_stale_adapter), "pop_post_delivery_callback", None) is not None:
|
||||
_stale_adapter.pop_post_delivery_callback(
|
||||
_quick_key,
|
||||
generation=run_generation,
|
||||
)
|
||||
elif _stale_adapter and hasattr(_stale_adapter, "_post_delivery_callbacks"):
|
||||
_stale_adapter._post_delivery_callbacks.pop(_quick_key, None)
|
||||
return None
|
||||
|
||||
|
|
@ -4982,22 +4986,23 @@ class GatewayRunner:
|
|||
agent = self._running_agents.get(session_key)
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
self._invalidate_session_run_generation(session_key, reason="stop_command_pending")
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||
await adapter.interrupt_session_activity(session_key, source.chat_id)
|
||||
self._release_running_agent_state(session_key)
|
||||
await self._interrupt_and_clear_session(
|
||||
session_key,
|
||||
source,
|
||||
interrupt_reason=_INTERRUPT_REASON_STOP,
|
||||
invalidation_reason="stop_command_pending",
|
||||
)
|
||||
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
||||
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
||||
if agent:
|
||||
agent.interrupt("Stop requested")
|
||||
# Force-clean the session lock so a truly hung agent doesn't
|
||||
# keep it locked forever.
|
||||
self._invalidate_session_run_generation(session_key, reason="stop_command_handler")
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and hasattr(adapter, "interrupt_session_activity"):
|
||||
await adapter.interrupt_session_activity(session_key, source.chat_id)
|
||||
self._release_running_agent_state(session_key)
|
||||
await self._interrupt_and_clear_session(
|
||||
session_key,
|
||||
source,
|
||||
interrupt_reason=_INTERRUPT_REASON_STOP,
|
||||
invalidation_reason="stop_command_handler",
|
||||
)
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
else:
|
||||
return "No active task to stop."
|
||||
|
|
@ -8481,6 +8486,47 @@ class GatewayRunner:
|
|||
generations = self.__dict__.get("_session_run_generation") or {}
|
||||
return int(generations.get(session_key, 0)) == int(generation)
|
||||
|
||||
def _bind_adapter_run_generation(
|
||||
self,
|
||||
adapter: Any,
|
||||
session_key: str,
|
||||
generation: int | None,
|
||||
) -> None:
|
||||
"""Bind a gateway run generation to the adapter's active-session event."""
|
||||
if not adapter or not session_key or generation is None:
|
||||
return
|
||||
try:
|
||||
interrupt_event = getattr(adapter, "_active_sessions", {}).get(session_key)
|
||||
if interrupt_event is not None:
|
||||
setattr(interrupt_event, "_hermes_run_generation", int(generation))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _interrupt_and_clear_session(
    self,
    session_key: str,
    source: "SessionSource",
    *,
    interrupt_reason: str,
    invalidation_reason: str,
    release_running_state: bool = True,
) -> None:
    """Interrupt the current run and clear queued session state consistently.

    Shared by the stop/new control paths so every interrupt performs the same
    cleanup: interrupt the live agent, invalidate the run generation, drain
    adapter-side activity and pending messages, and (optionally) release the
    running-agent bookkeeping.
    """
    if not session_key:
        return

    # Only a real agent object gets interrupted; the pending sentinel means
    # no agent has started yet.
    agent = self._running_agents.get(session_key)
    if agent and agent is not _AGENT_PENDING_SENTINEL:
        agent.interrupt(interrupt_reason)

    self._invalidate_session_run_generation(session_key, reason=invalidation_reason)

    adapter = self.adapters.get(source.platform)
    if adapter:
        if hasattr(adapter, "interrupt_session_activity"):
            await adapter.interrupt_session_activity(session_key, source.chat_id)
        if hasattr(adapter, "get_pending_message"):
            adapter.get_pending_message(session_key)  # consume and discard

    self._pending_messages.pop(session_key, None)
    if release_running_state:
        self._release_running_agent_state(session_key)
|
||||
|
||||
def _evict_cached_agent(self, session_key: str) -> None:
|
||||
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
|
|
@ -8662,6 +8708,7 @@ class GatewayRunner:
|
|||
source: "SessionSource",
|
||||
session_id: str,
|
||||
session_key: str = None,
|
||||
run_generation: Optional[int] = None,
|
||||
event_message_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Forward the message to a remote Hermes API server instead of
|
||||
|
|
@ -8697,6 +8744,11 @@ class GatewayRunner:
|
|||
|
||||
proxy_key = os.getenv("GATEWAY_PROXY_KEY", "").strip()
|
||||
|
||||
def _run_still_current() -> bool:
|
||||
if run_generation is None or not session_key:
|
||||
return True
|
||||
return self._is_session_run_current(session_key, run_generation)
|
||||
|
||||
# Build messages in OpenAI chat format --------------------------
|
||||
#
|
||||
# The remote api_server can maintain session continuity via
|
||||
|
|
@ -8826,6 +8878,21 @@ class GatewayRunner:
|
|||
# Parse SSE stream
|
||||
buffer = ""
|
||||
async for chunk in resp.content.iter_any():
|
||||
if not _run_still_current():
|
||||
logger.info(
|
||||
"Discarding stale proxy stream for %s — generation %d is no longer current",
|
||||
session_key[:20] if session_key else "?",
|
||||
run_generation or 0,
|
||||
)
|
||||
return {
|
||||
"final_response": "",
|
||||
"messages": [],
|
||||
"api_calls": 0,
|
||||
"tools": [],
|
||||
"history_offset": len(history),
|
||||
"session_id": session_id,
|
||||
"response_previewed": False,
|
||||
}
|
||||
text = chunk.decode("utf-8", errors="replace")
|
||||
buffer += text
|
||||
|
||||
|
|
@ -8875,6 +8942,21 @@ class GatewayRunner:
|
|||
stream_task.cancel()
|
||||
|
||||
_elapsed = time.time() - _start
|
||||
if not _run_still_current():
|
||||
logger.info(
|
||||
"Discarding stale proxy result for %s — generation %d is no longer current",
|
||||
session_key[:20] if session_key else "?",
|
||||
run_generation or 0,
|
||||
)
|
||||
return {
|
||||
"final_response": "",
|
||||
"messages": [],
|
||||
"api_calls": 0,
|
||||
"tools": [],
|
||||
"history_offset": len(history),
|
||||
"session_id": session_id,
|
||||
"response_previewed": False,
|
||||
}
|
||||
logger.info(
|
||||
"proxy response: url=%s session=%s time=%.1fs response=%d chars",
|
||||
proxy_url, (session_id or "")[:20], _elapsed, len(full_response),
|
||||
|
|
@ -8929,6 +9011,7 @@ class GatewayRunner:
|
|||
source=source,
|
||||
session_id=session_id,
|
||||
session_key=session_key,
|
||||
run_generation=run_generation,
|
||||
event_message_id=event_message_id,
|
||||
)
|
||||
|
||||
|
|
@ -9527,9 +9610,16 @@ class GatewayRunner:
|
|||
# Register the release hook on the adapter so base.py's finally
|
||||
# block can fire it after delivering the main response.
|
||||
if _status_adapter and session_key:
|
||||
_pdc = getattr(_status_adapter, "_post_delivery_callbacks", None)
|
||||
if _pdc is not None:
|
||||
_pdc[session_key] = _release_bg_review_messages
|
||||
if getattr(type(_status_adapter), "register_post_delivery_callback", None) is not None:
|
||||
_status_adapter.register_post_delivery_callback(
|
||||
session_key,
|
||||
_release_bg_review_messages,
|
||||
generation=run_generation,
|
||||
)
|
||||
else:
|
||||
_pdc = getattr(_status_adapter, "_post_delivery_callbacks", None)
|
||||
if _pdc is not None:
|
||||
_pdc[session_key] = _release_bg_review_messages
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
|
|
@ -10131,7 +10221,7 @@ class GatewayRunner:
|
|||
# Interrupt the agent if it's still running so the thread
|
||||
# pool worker is freed.
|
||||
if _timed_out_agent and hasattr(_timed_out_agent, "interrupt"):
|
||||
_timed_out_agent.interrupt("Execution timed out (inactivity)")
|
||||
_timed_out_agent.interrupt(_INTERRUPT_REASON_TIMEOUT)
|
||||
|
||||
_timeout_mins = int(_agent_timeout // 60) or 1
|
||||
|
||||
|
|
@ -10309,7 +10399,17 @@ class GatewayRunner:
|
|||
# first response has been delivered. Pop from the
|
||||
# adapter's callback dict (prevents double-fire in
|
||||
# base.py's finally block) and call it.
|
||||
if adapter and hasattr(adapter, "_post_delivery_callbacks"):
|
||||
if getattr(type(adapter), "pop_post_delivery_callback", None) is not None:
|
||||
_bg_cb = adapter.pop_post_delivery_callback(
|
||||
session_key,
|
||||
generation=run_generation,
|
||||
)
|
||||
if callable(_bg_cb):
|
||||
try:
|
||||
_bg_cb()
|
||||
except Exception:
|
||||
pass
|
||||
elif adapter and hasattr(adapter, "_post_delivery_callbacks"):
|
||||
_bg_cb = adapter._post_delivery_callbacks.pop(session_key, None)
|
||||
if callable(_bg_cb):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ def _make_runner(proxy_url=None):
|
|||
runner.config = MagicMock()
|
||||
runner.config.streaming = StreamingConfig()
|
||||
runner._running_agents = {}
|
||||
runner._session_run_generation = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = None
|
||||
|
|
@ -160,10 +161,12 @@ class TestRunAgentProxyDispatch:
|
|||
source=source,
|
||||
session_id="test-session-123",
|
||||
session_key="test-key",
|
||||
run_generation=7,
|
||||
)
|
||||
|
||||
assert result["final_response"] == "Hello from remote!"
|
||||
runner._run_agent_via_proxy.assert_called_once()
|
||||
assert runner._run_agent_via_proxy.call_args.kwargs["run_generation"] == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_skips_proxy_when_not_configured(self, monkeypatch):
|
||||
|
|
@ -370,6 +373,40 @@ class TestRunAgentViaProxy:
|
|||
assert "session_id" in result
|
||||
assert result["session_id"] == "sess-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_stale_generation_returns_empty_result(self, monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
|
||||
monkeypatch.delenv("GATEWAY_PROXY_KEY", raising=False)
|
||||
runner = _make_runner()
|
||||
source = _make_source()
|
||||
runner._session_run_generation["test-key"] = 2
|
||||
|
||||
resp = _FakeSSEResponse(
|
||||
status=200,
|
||||
sse_chunks=[
|
||||
'data: {"choices":[{"delta":{"content":"stale"}}]}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
],
|
||||
)
|
||||
session = _FakeSession(resp)
|
||||
|
||||
with patch("gateway.run._load_gateway_config", return_value={}):
|
||||
with _patch_aiohttp(session):
|
||||
with patch("aiohttp.ClientTimeout"):
|
||||
result = await runner._run_agent_via_proxy(
|
||||
message="hi",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-123",
|
||||
session_key="test-key",
|
||||
run_generation=1,
|
||||
)
|
||||
|
||||
assert result["final_response"] == ""
|
||||
assert result["messages"] == []
|
||||
assert result["api_calls"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_auth_header_without_key(self, monkeypatch):
|
||||
monkeypatch.setenv("GATEWAY_PROXY_URL", "http://host:8642")
|
||||
|
|
|
|||
|
|
@ -270,6 +270,75 @@ async def test_handle_message_discards_stale_result_after_session_invalidation(m
|
|||
assert session_key not in runner.adapters[Platform.TELEGRAM]._post_delivery_callbacks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_message_stale_result_keeps_newer_generation_callback(monkeypatch):
    """A stale run must not pop a callback registered by a newer generation."""
    import gateway.run as gateway_run

    class _GenerationAwareAdapter:
        def __init__(self):
            self._post_delivery_callbacks = {}

        async def send(self, *args, **kwargs):
            return None

        def pop_post_delivery_callback(self, session_key, *, generation=None):
            entry = self._post_delivery_callbacks.get(session_key)
            if entry is None:
                return None
            if not isinstance(entry, tuple):
                # Legacy bare callback: only generation-less callers may pop.
                if generation is not None:
                    return None
                return self._post_delivery_callbacks.pop(session_key, None)
            owner_generation, callback = entry
            if generation is not None and owner_generation != generation:
                return None
            self._post_delivery_callbacks.pop(session_key, None)
            return callback

    session_entry = SessionEntry(
        session_key=build_session_key(_make_source()),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(session_entry)
    runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
    session_key = session_entry.session_key
    adapter = _GenerationAwareAdapter()
    runner.adapters[Platform.TELEGRAM] = adapter

    async def _late_stale_result(**kwargs):
        # Simulate a newer run claiming the callback slot before the stale
        # run unwinds.
        runner._session_run_generation[session_key] = 2
        adapter._post_delivery_callbacks[session_key] = (2, lambda: None)
        return {
            "final_response": "late reply",
            "messages": [],
            "tools": [],
            "history_offset": 0,
            "last_prompt_tokens": 80,
            "input_tokens": 120,
            "output_tokens": 45,
            "model": "openai/test-model",
        }

    runner._run_agent = AsyncMock(side_effect=_late_stale_result)

    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )

    result = await runner._handle_message(_make_event("hello"))

    # The stale (generation-1) run is dropped and the generation-2 callback
    # registered by the fresher run stays untouched.
    assert result is None
    assert session_key in adapter._post_delivery_callbacks
    assert adapter._post_delivery_callbacks[session_key][0] == 2
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_command_bypasses_active_session_guard():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue