Clean cached agents on gateway shutdown

This commit is contained in:
SBA 2026-04-23 21:59:59 +01:00
parent 8f5fee3e3e
commit 7e96310ce1
2 changed files with 50 additions and 0 deletions

View file

@ -1756,6 +1756,31 @@ class GatewayRunner:
pass pass
self._cleanup_agent_resources(agent) self._cleanup_agent_resources(agent)
def _finalize_shutdown_cached_agents(self, active_agents: Dict[str, Any]) -> None:
"""Hard-clean idle cached agents during full gateway shutdown.
Normal cache eviction uses release_clients() so a session can resume
with preserved tool state. Full gateway shutdown is a real session
boundary, so cached agents must be torn down like active ones.
"""
_cache = getattr(self, "_agent_cache", None)
_lock = getattr(self, "_agent_cache_lock", None)
if _cache is None or _lock is None:
return
active_ids = {id(agent) for agent in active_agents.values()}
cached_agents: List[Any] = []
with _lock:
for entry in list(_cache.values()):
agent = entry[0] if isinstance(entry, tuple) and entry else None
if agent is None or id(agent) in active_ids:
continue
cached_agents.append(agent)
_cache.clear()
for agent in cached_agents:
self._cleanup_agent_resources(agent)
def _cleanup_agent_resources(self, agent: Any) -> None: def _cleanup_agent_resources(self, agent: Any) -> None:
"""Best-effort cleanup for temporary or cached agent instances.""" """Best-effort cleanup for temporary or cached agent instances."""
if agent is None: if agent is None:
@ -2628,6 +2653,7 @@ class GatewayRunner:
logger.error("Failed to launch detached gateway restart: %s", e) logger.error("Failed to launch detached gateway restart: %s", e)
self._finalize_shutdown_agents(active_agents) self._finalize_shutdown_agents(active_agents)
self._finalize_shutdown_cached_agents(active_agents)
for platform, adapter in list(self.adapters.items()): for platform, adapter in list(self.adapters.items()):
try: try:

View file

@ -1,4 +1,6 @@
import asyncio import asyncio
import threading
from collections import OrderedDict
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -125,6 +127,28 @@ async def test_gateway_stop_service_restart_sets_named_exit_code():
assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE assert runner._exit_code == GATEWAY_SERVICE_RESTART_EXIT_CODE
@pytest.mark.asyncio
async def test_gateway_stop_hard_cleans_cached_agents():
runner, adapter = make_restart_runner()
adapter.disconnect = AsyncMock()
runner._agent_cache = OrderedDict()
runner._agent_cache_lock = threading.Lock()
runner._cleanup_agent_resources = MagicMock()
cached_a = MagicMock()
cached_b = MagicMock()
runner._agent_cache["session-a"] = (cached_a, "sig-a")
runner._agent_cache["session-b"] = (cached_b, "sig-b")
with patch("gateway.status.remove_pid_file"), patch("gateway.status.write_runtime_status"):
await runner.stop()
assert list(runner._agent_cache.keys()) == []
runner._cleanup_agent_resources.assert_any_call(cached_a)
runner._cleanup_agent_resources.assert_any_call(cached_b)
assert runner._cleanup_agent_resources.call_count == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_drain_active_agents_throttles_status_updates(): async def test_drain_active_agents_throttles_status_updates():
runner, _adapter = make_restart_runner() runner, _adapter = make_restart_runner()