fix(gateway): close temporary agents after one-off tasks

Add shared _cleanup_agent_resources() for temporary gateway AIAgent
instances. Apply cleanup to memory flush, background tasks, /btw,
manual /compress, and session-hygiene auto-compression. Prevents
unclosed aiohttp client session leaks.

Cherry-picked from #10899 by @LeonSGP43. Consolidates #10945 by @Lubrsy706.
Fixes #10865.

Co-authored-by: Lubrsy706 <Lubrsy706@users.noreply.github.com>
This commit is contained in:
LeonSGP43 2026-04-16 18:41:31 +05:30 committed by kshitij
parent dc7d47a6b8
commit 465193b7eb
5 changed files with 241 additions and 174 deletions

View file

@ -762,69 +762,72 @@ class GatewayRunner:
enabled_toolsets=["memory", "skills"],
session_id=old_session_id,
)
# Fully silence the flush agent — quiet_mode only suppresses init
# messages; tool call output still leaks to the terminal through
# _safe_print → _print_fn. Set a no-op to prevent that.
tmp_agent._print_fn = lambda *a, **kw: None
# Build conversation history from transcript
msgs = [
{"role": m.get("role"), "content": m.get("content")}
for m in history
if m.get("role") in ("user", "assistant") and m.get("content")
]
# Read live memory state from disk so the flush agent can see
# what's already saved and avoid overwriting newer entries.
_current_memory = ""
try:
from tools.memory_tool import get_memory_dir
_mem_dir = get_memory_dir()
for fname, label in [
("MEMORY.md", "MEMORY (your personal notes)"),
("USER.md", "USER PROFILE (who the user is)"),
]:
fpath = _mem_dir / fname
if fpath.exists():
content = fpath.read_text(encoding="utf-8").strip()
if content:
_current_memory += f"\n\n## Current {label}:\n{content}"
except Exception:
pass # Non-fatal — flush still works, just without the guard
# Fully silence the flush agent — quiet_mode only suppresses init
# messages; tool call output still leaks to the terminal through
# _safe_print → _print_fn. Set a no-op to prevent that.
tmp_agent._print_fn = lambda *a, **kw: None
# Give the agent a real turn to think about what to save
flush_prompt = (
"[System: This session is about to be automatically reset due to "
"inactivity or a scheduled daily reset. The conversation context "
"will be cleared after this turn.\n\n"
"Review the conversation above and:\n"
"1. Save any important facts, preferences, or decisions to memory "
"(user profile or your notes) that would be useful in future sessions.\n"
"2. If you discovered a reusable workflow or solved a non-trivial "
"problem, consider saving it as a skill.\n"
"3. If nothing is worth saving, that's fine — just skip.\n\n"
)
# Build conversation history from transcript
msgs = [
{"role": m.get("role"), "content": m.get("content")}
for m in history
if m.get("role") in ("user", "assistant") and m.get("content")
]
if _current_memory:
flush_prompt += (
"IMPORTANT — here is the current live state of memory. Other "
"sessions, cron jobs, or the user may have updated it since this "
"conversation ended. Do NOT overwrite or remove entries unless "
"the conversation above reveals something that genuinely "
"supersedes them. Only add new information that is not already "
"captured below."
f"{_current_memory}\n\n"
# Read live memory state from disk so the flush agent can see
# what's already saved and avoid overwriting newer entries.
_current_memory = ""
try:
from tools.memory_tool import get_memory_dir
_mem_dir = get_memory_dir()
for fname, label in [
("MEMORY.md", "MEMORY (your personal notes)"),
("USER.md", "USER PROFILE (who the user is)"),
]:
fpath = _mem_dir / fname
if fpath.exists():
content = fpath.read_text(encoding="utf-8").strip()
if content:
_current_memory += f"\n\n## Current {label}:\n{content}"
except Exception:
pass # Non-fatal — flush still works, just without the guard
# Give the agent a real turn to think about what to save
flush_prompt = (
"[System: This session is about to be automatically reset due to "
"inactivity or a scheduled daily reset. The conversation context "
"will be cleared after this turn.\n\n"
"Review the conversation above and:\n"
"1. Save any important facts, preferences, or decisions to memory "
"(user profile or your notes) that would be useful in future sessions.\n"
"2. If you discovered a reusable workflow or solved a non-trivial "
"problem, consider saving it as a skill.\n"
"3. If nothing is worth saving, that's fine — just skip.\n\n"
)
flush_prompt += (
"Do NOT respond to the user. Just use the memory and skill_manage "
"tools if needed, then stop.]"
)
if _current_memory:
flush_prompt += (
"IMPORTANT — here is the current live state of memory. Other "
"sessions, cron jobs, or the user may have updated it since this "
"conversation ended. Do NOT overwrite or remove entries unless "
"the conversation above reveals something that genuinely "
"supersedes them. Only add new information that is not already "
"captured below."
f"{_current_memory}\n\n"
)
tmp_agent.run_conversation(
user_message=flush_prompt,
conversation_history=msgs,
)
flush_prompt += (
"Do NOT respond to the user. Just use the memory and skill_manage "
"tools if needed, then stop.]"
)
tmp_agent.run_conversation(
user_message=flush_prompt,
conversation_history=msgs,
)
finally:
self._cleanup_agent_resources(tmp_agent)
logger.info("Pre-reset memory flush completed for session %s", old_session_id)
except Exception as e:
logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)
@ -1562,19 +1565,25 @@ class GatewayRunner:
)
except Exception:
pass
try:
if hasattr(agent, "shutdown_memory_provider"):
agent.shutdown_memory_provider()
except Exception:
pass
# Close tool resources (terminal sandboxes, browser daemons,
# background processes, httpx clients) to prevent zombie
# process accumulation.
try:
if hasattr(agent, 'close'):
agent.close()
except Exception:
pass
self._cleanup_agent_resources(agent)
def _cleanup_agent_resources(self, agent: Any) -> None:
"""Best-effort cleanup for temporary or cached agent instances."""
if agent is None:
return
try:
if hasattr(agent, "shutdown_memory_provider"):
agent.shutdown_memory_provider()
except Exception:
pass
# Close tool resources (terminal sandboxes, browser daemons,
# background processes, httpx clients) to prevent zombie
# process accumulation.
try:
if hasattr(agent, "close"):
agent.close()
except Exception:
pass
_STUCK_LOOP_THRESHOLD = 3 # restarts while active before auto-suspend
_STUCK_LOOP_FILE = ".restart_failure_counts"
@ -2077,16 +2086,7 @@ class GatewayRunner:
if _cached_agent is None:
_cached_agent = self._running_agents.get(key)
if _cached_agent and _cached_agent is not _AGENT_PENDING_SENTINEL:
try:
if hasattr(_cached_agent, 'shutdown_memory_provider'):
_cached_agent.shutdown_memory_provider()
except Exception:
pass
try:
if hasattr(_cached_agent, 'close'):
_cached_agent.close()
except Exception:
pass
self._cleanup_agent_resources(_cached_agent)
# Mark as flushed and persist to disk so the flag
# survives gateway restarts.
with self.session_store._lock:
@ -3775,51 +3775,54 @@ class GatewayRunner:
enabled_toolsets=["memory"],
session_id=session_entry.session_id,
)
_hyg_agent._print_fn = lambda *a, **kw: None
try:
_hyg_agent._print_fn = lambda *a, **kw: None
loop = asyncio.get_running_loop()
_compressed, _ = await loop.run_in_executor(
None,
lambda: _hyg_agent._compress_context(
_hyg_msgs, "",
approx_tokens=_approx_tokens,
),
)
# _compress_context ends the old session and creates
# a new session_id. Write compressed messages into
# the NEW session so the old transcript stays intact
# and searchable via session_search.
_hyg_new_sid = _hyg_agent.session_id
if _hyg_new_sid != session_entry.session_id:
session_entry.session_id = _hyg_new_sid
self.session_store._save()
self.session_store.rewrite_transcript(
session_entry.session_id, _compressed
)
# Reset stored token count — transcript was rewritten
session_entry.last_prompt_tokens = 0
history = _compressed
_new_count = len(_compressed)
_new_tokens = estimate_messages_tokens_rough(
_compressed
)
logger.info(
"Session hygiene: compressed %s%s msgs, "
"~%s → ~%s tokens",
_msg_count, _new_count,
f"{_approx_tokens:,}", f"{_new_tokens:,}",
)
if _new_tokens >= _warn_token_threshold:
logger.warning(
"Session hygiene: still ~%s tokens after "
"compression",
f"{_new_tokens:,}",
loop = asyncio.get_running_loop()
_compressed, _ = await loop.run_in_executor(
None,
lambda: _hyg_agent._compress_context(
_hyg_msgs, "",
approx_tokens=_approx_tokens,
),
)
# _compress_context ends the old session and creates
# a new session_id. Write compressed messages into
# the NEW session so the old transcript stays intact
# and searchable via session_search.
_hyg_new_sid = _hyg_agent.session_id
if _hyg_new_sid != session_entry.session_id:
session_entry.session_id = _hyg_new_sid
self.session_store._save()
self.session_store.rewrite_transcript(
session_entry.session_id, _compressed
)
# Reset stored token count — transcript was rewritten
session_entry.last_prompt_tokens = 0
history = _compressed
_new_count = len(_compressed)
_new_tokens = estimate_messages_tokens_rough(
_compressed
)
logger.info(
"Session hygiene: compressed %s%s msgs, "
"~%s → ~%s tokens",
_msg_count, _new_count,
f"{_approx_tokens:,}", f"{_new_tokens:,}",
)
if _new_tokens >= _warn_token_threshold:
logger.warning(
"Session hygiene: still ~%s tokens after "
"compression",
f"{_new_tokens:,}",
)
finally:
self._cleanup_agent_resources(_hyg_agent)
except Exception as e:
logger.warning(
"Session hygiene auto-compress failed: %s", e
@ -4337,16 +4340,7 @@ class GatewayRunner:
_cached = self._agent_cache.get(session_key)
_old_agent = _cached[0] if isinstance(_cached, tuple) else _cached if _cached else None
if _old_agent is not None:
try:
if hasattr(_old_agent, "shutdown_memory_provider"):
_old_agent.shutdown_memory_provider()
except Exception:
pass
try:
if hasattr(_old_agent, "close"):
_old_agent.close()
except Exception:
pass
self._cleanup_agent_resources(_old_agent)
self._evict_cached_agent(session_key)
try:
@ -5741,11 +5735,13 @@ class GatewayRunner:
session_db=self._session_db,
fallback_model=self._fallback_model,
)
return agent.run_conversation(
user_message=prompt,
task_id=task_id,
)
try:
return agent.run_conversation(
user_message=prompt,
task_id=task_id,
)
finally:
self._cleanup_agent_resources(agent)
result = await self._run_in_executor_with_context(run_sync)
@ -5923,11 +5919,14 @@ class GatewayRunner:
skip_context_files=True,
persist_session=False,
)
return agent.run_conversation(
user_message=btw_prompt,
conversation_history=history_snapshot,
task_id=task_id,
)
try:
return agent.run_conversation(
user_message=btw_prompt,
conversation_history=history_snapshot,
task_id=task_id,
)
finally:
self._cleanup_agent_resources(agent)
result = await self._run_in_executor_with_context(run_sync)
@ -6256,42 +6255,45 @@ class GatewayRunner:
enabled_toolsets=["memory"],
session_id=session_entry.session_id,
)
tmp_agent._print_fn = lambda *a, **kw: None
try:
tmp_agent._print_fn = lambda *a, **kw: None
compressor = tmp_agent.context_compressor
compress_start = compressor.protect_first_n
compress_start = compressor._align_boundary_forward(msgs, compress_start)
compress_end = compressor._find_tail_cut_by_tokens(msgs, compress_start)
if compress_start >= compress_end:
return "Nothing to compress yet (the transcript is still all protected context)."
compressor = tmp_agent.context_compressor
compress_start = compressor.protect_first_n
compress_start = compressor._align_boundary_forward(msgs, compress_start)
compress_end = compressor._find_tail_cut_by_tokens(msgs, compress_start)
if compress_start >= compress_end:
return "Nothing to compress yet (the transcript is still all protected context)."
loop = asyncio.get_running_loop()
compressed, _ = await loop.run_in_executor(
None,
lambda: tmp_agent._compress_context(msgs, "", approx_tokens=approx_tokens, focus_topic=focus_topic)
)
loop = asyncio.get_running_loop()
compressed, _ = await loop.run_in_executor(
None,
lambda: tmp_agent._compress_context(msgs, "", approx_tokens=approx_tokens, focus_topic=focus_topic)
)
# _compress_context already calls end_session() on the old session
# (preserving its full transcript in SQLite) and creates a new
# session_id for the continuation. Write the compressed messages
# into the NEW session so the original history stays searchable.
new_session_id = tmp_agent.session_id
if new_session_id != session_entry.session_id:
session_entry.session_id = new_session_id
self.session_store._save()
# _compress_context already calls end_session() on the old session
# (preserving its full transcript in SQLite) and creates a new
# session_id for the continuation. Write the compressed messages
# into the NEW session so the original history stays searchable.
new_session_id = tmp_agent.session_id
if new_session_id != session_entry.session_id:
session_entry.session_id = new_session_id
self.session_store._save()
self.session_store.rewrite_transcript(new_session_id, compressed)
# Reset stored token count — transcript changed, old value is stale
self.session_store.update_session(
session_entry.session_key, last_prompt_tokens=0
)
new_tokens = estimate_messages_tokens_rough(compressed)
summary = summarize_manual_compression(
msgs,
compressed,
approx_tokens,
new_tokens,
)
self.session_store.rewrite_transcript(new_session_id, compressed)
# Reset stored token count — transcript changed, old value is stale
self.session_store.update_session(
session_entry.session_key, last_prompt_tokens=0
)
new_tokens = estimate_messages_tokens_rough(compressed)
summary = summarize_manual_compression(
msgs,
compressed,
approx_tokens,
new_tokens,
)
finally:
self._cleanup_agent_resources(tmp_agent)
lines = [f"🗜️ {summary['headline']}"]
if focus_topic:
lines.append(f"Focus: \"{focus_topic}\"")

View file

@ -220,6 +220,8 @@ class TestRunBackgroundTask:
with patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "test-key"}), \
patch("run_agent.AIAgent") as MockAgent:
mock_agent_instance = MagicMock()
mock_agent_instance.shutdown_memory_provider = MagicMock()
mock_agent_instance.close = MagicMock()
mock_agent_instance.run_conversation.return_value = mock_result
MockAgent.return_value = mock_agent_instance
@ -231,6 +233,37 @@ class TestRunBackgroundTask:
content = call_args[1].get("content", call_args[0][1] if len(call_args[0]) > 1 else "")
assert "Background task complete" in content
assert "Hello from background!" in content
mock_agent_instance.shutdown_memory_provider.assert_called_once()
mock_agent_instance.close.assert_called_once()
@pytest.mark.asyncio
async def test_agent_cleanup_runs_when_background_agent_raises(self):
"""Temporary background agents must be cleaned up on error paths too."""
runner = _make_runner()
mock_adapter = AsyncMock()
mock_adapter.send = AsyncMock()
runner.adapters[Platform.TELEGRAM] = mock_adapter
source = SessionSource(
platform=Platform.TELEGRAM,
user_id="12345",
chat_id="67890",
user_name="testuser",
)
with patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "test-key"}), \
patch("run_agent.AIAgent") as MockAgent:
mock_agent_instance = MagicMock()
mock_agent_instance.shutdown_memory_provider = MagicMock()
mock_agent_instance.close = MagicMock()
mock_agent_instance.run_conversation.side_effect = RuntimeError("boom")
MockAgent.return_value = mock_agent_instance
await runner._run_background_task("say hello", source, "bg_test")
mock_adapter.send.assert_called_once()
mock_agent_instance.shutdown_memory_provider.assert_called_once()
mock_agent_instance.close.assert_called_once()
@pytest.mark.asyncio
async def test_exception_sends_error_message(self):

View file

@ -62,6 +62,8 @@ async def test_compress_command_reports_noop_without_success_banner():
history = _make_history()
runner = _make_runner(history)
agent_instance = MagicMock()
agent_instance.shutdown_memory_provider = MagicMock()
agent_instance.close = MagicMock()
agent_instance.context_compressor.protect_first_n = 0
agent_instance.context_compressor._align_boundary_forward.return_value = 0
agent_instance.context_compressor._find_tail_cut_by_tokens.return_value = 2
@ -83,6 +85,8 @@ async def test_compress_command_reports_noop_without_success_banner():
assert "No changes from compression" in result
assert "Compressed:" not in result
assert "Rough transcript estimate: ~100 tokens (unchanged)" in result
agent_instance.shutdown_memory_provider.assert_called_once()
agent_instance.close.assert_called_once()
@pytest.mark.asyncio
@ -95,6 +99,8 @@ async def test_compress_command_explains_when_token_estimate_rises():
]
runner = _make_runner(history)
agent_instance = MagicMock()
agent_instance.shutdown_memory_provider = MagicMock()
agent_instance.close = MagicMock()
agent_instance.context_compressor.protect_first_n = 0
agent_instance.context_compressor._align_boundary_forward.return_value = 0
agent_instance.context_compressor._find_tail_cut_by_tokens.return_value = 2
@ -119,3 +125,5 @@ async def test_compress_command_explains_when_token_estimate_rises():
assert "Compressed: 4 → 3 messages" in result
assert "Rough transcript estimate: ~100 → ~120 tokens" in result
assert "denser summaries" in result
agent_instance.shutdown_memory_provider.assert_called_once()
agent_instance.close.assert_called_once()

View file

@ -202,6 +202,22 @@ class TestFlushAgentSilenced:
sys.stdout = old_stdout
assert buf.getvalue() == "", "no-op print_fn spinner must not write to stdout"
def test_flush_agent_closes_resources_after_run(self, monkeypatch):
"""Memory flush should close temporary agent resources after the turn."""
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
tmp_agent.shutdown_memory_provider = MagicMock()
tmp_agent.close = MagicMock()
with (
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(get_memory_dir=lambda: Path("/nonexistent"))}),
):
runner._flush_memories_for_session("session_cleanup")
tmp_agent.shutdown_memory_provider.assert_called_once()
tmp_agent.close.assert_called_once()
class TestFlushPromptStructure:
"""Verify the flush prompt retains its core instructions."""

View file

@ -305,10 +305,15 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
class FakeCompressAgent:
last_instance = None
def __init__(self, **kwargs):
self.model = kwargs.get("model")
self.session_id = kwargs.get("session_id", "fake-session")
self._print_fn = None
self.shutdown_memory_provider = MagicMock()
self.close = MagicMock()
type(self).last_instance = self
def _compress_context(self, messages, *_args, **_kwargs):
# Simulate real _compress_context: create a new session_id
@ -385,3 +390,6 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t
# Compression warnings are no longer sent to users — compression
# happens silently with server-side logging only.
assert len(adapter.sent) == 0
assert FakeCompressAgent.last_instance is not None
FakeCompressAgent.last_instance.shutdown_memory_provider.assert_called_once()
FakeCompressAgent.last_instance.close.assert_called_once()