diff --git a/gateway/run.py b/gateway/run.py index 8683c5a75..af3946d4a 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2383,10 +2383,23 @@ class GatewayRunner: # existing ``.restart_failure_counts`` stuck-loop counter # (incremented below, threshold 3), which sets # ``suspended=True`` and overrides resume_pending. + # + # Iterate self._running_agents (current) rather than the + # drain-start ``active_agents`` snapshot — the snapshot + # may include sessions that finished gracefully during + # the drain window, and marking those falsely would give + # them a stray restart-interruption system note on their + # next turn even though their previous turn completed + # cleanly. Skip pending sentinels for the same reason + # _interrupt_running_agents() does: their agent hasn't + # started yet, there's nothing to interrupt, and the + # session shouldn't carry a misleading resume flag. _resume_reason = ( "restart_timeout" if self._restart_requested else "shutdown_timeout" ) - for _sk in list(active_agents.keys()): + for _sk, _agent in list(self._running_agents.items()): + if _agent is _AGENT_PENDING_SENTINEL: + continue try: self.session_store.mark_resume_pending(_sk, _resume_reason) except Exception as _e: diff --git a/tests/gateway/test_restart_resume_pending.py b/tests/gateway/test_restart_resume_pending.py index a18d85cc4..c11b2740d 100644 --- a/tests/gateway/test_restart_resume_pending.py +++ b/tests/gateway/test_restart_resume_pending.py @@ -516,6 +516,84 @@ async def test_clean_drain_does_not_mark_resume_pending(): running_agent.interrupt.assert_not_called() +@pytest.mark.asyncio +async def test_drain_timeout_only_marks_still_running_sessions(): + """A session that finished gracefully during the drain window must + NOT be marked ``resume_pending`` — it completed cleanly and its + next turn should be a normal fresh turn, not one prefixed with the + restart-interruption system note. + + Regression guard for using ``self._running_agents`` at timeout + rather than the ``active_agents`` drain-start snapshot. + """ + runner, adapter = make_restart_runner() + adapter.disconnect = AsyncMock() + # Long enough for the finisher to exit, short enough to still time out + # with the stuck session still present. + runner._restart_drain_timeout = 0.3 + + session_key_finisher = "agent:main:telegram:dm:A" + session_key_stuck = "agent:main:telegram:dm:B" + runner._running_agents = { + session_key_finisher: MagicMock(), + session_key_stuck: MagicMock(), + } + + async def finish_one(): + await asyncio.sleep(0.05) + runner._running_agents.pop(session_key_finisher, None) + + asyncio.create_task(finish_one()) + + session_store = MagicMock() + session_store.mark_resume_pending = MagicMock(return_value=True) + runner.session_store = session_store + + with patch("gateway.status.remove_pid_file"), patch( + "gateway.status.write_runtime_status" + ): + await runner.stop() + + calls = session_store.mark_resume_pending.call_args_list + marked = {args[0][0] for args in calls} + # Only the session still running at timeout is marked; the finisher is not. + assert marked == {session_key_stuck} + + +@pytest.mark.asyncio +async def test_drain_timeout_skips_pending_sentinel_sessions(): + """Pending sentinels — sessions whose AIAgent construction hasn't + produced a real agent yet — are skipped by + ``_interrupt_running_agents()``. The resume_pending marking must + mirror that: no agent started means no turn was interrupted. + """ + from gateway.run import _AGENT_PENDING_SENTINEL + + runner, adapter = make_restart_runner() + adapter.disconnect = AsyncMock() + runner._restart_drain_timeout = 0.05 + + session_key_real = "agent:main:telegram:dm:A" + session_key_sentinel = "agent:main:telegram:dm:B" + runner._running_agents = { + session_key_real: MagicMock(), + session_key_sentinel: _AGENT_PENDING_SENTINEL, + } + + session_store = MagicMock() + session_store.mark_resume_pending = MagicMock(return_value=True) + runner.session_store = session_store + + with patch("gateway.status.remove_pid_file"), patch( + "gateway.status.write_runtime_status" + ): + await runner.stop() + + calls = session_store.mark_resume_pending.call_args_list + marked = {args[0][0] for args in calls} + assert marked == {session_key_real} + + # --------------------------------------------------------------------------- # Shutdown banner wording # ---------------------------------------------------------------------------