From 16d9f58445e7960715585ea96d07908a1c7b5bdc Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 1 Apr 2026 12:05:02 -0700 Subject: [PATCH] fix(gateway): persist memory flush state to prevent redundant re-flushes on restart (#4481) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: force-close TCP sockets on client cleanup, detect and recover dead connections When a provider drops connections mid-stream (e.g. OpenRouter outage), httpx's graceful close leaves sockets in CLOSE-WAIT indefinitely. These zombie connections accumulate and can prevent recovery without restarting. Changes: - _force_close_tcp_sockets: walks the httpx connection pool and issues socket.shutdown(SHUT_RDWR) + close() to force TCP RST on every socket when a client is closed, preventing CLOSE-WAIT accumulation - _cleanup_dead_connections: probes the primary client's pool for dead sockets (recv MSG_PEEK), rebuilds the client if any are found - Pre-turn health check at the start of each run_conversation call that auto-recovers with a user-facing status message - Primary client rebuild after stale stream detection to purge pool - User-facing messages on streaming connection failures: "Connection to provider dropped — Reconnecting (attempt 2/3)" "Connection failed after 3 attempts — try again in a moment" Made-with: Cursor * fix: pool entry missing base_url for openrouter, clean error messages - _resolve_runtime_from_pool_entry: add OPENROUTER_BASE_URL fallback when pool entry has no runtime_base_url (pool entries from auth.json credential_pool often omit base_url) - Replace Rich console.print for auth errors with plain print() to prevent ANSI escape code mangling through prompt_toolkit's stdout patch - Force-close TCP sockets on client cleanup to prevent CLOSE-WAIT accumulation after provider outages - Pre-turn dead connection detection with auto-recovery and user message - Primary client rebuild after stale stream detection - User-facing status messages on streaming connection failures/retries Made-with: Cursor * fix(gateway): persist memory flush state to prevent redundant re-flushes on restart The _session_expiry_watcher tracked flushed sessions in an in-memory set (_pre_flushed_sessions) that was lost on gateway restart. Expired sessions remained in sessions.json and were re-discovered every restart, causing redundant AIAgent runs that burned API credits and blocked the event loop. Fix: Add a memory_flushed boolean field to SessionEntry, persisted in sessions.json. The watcher sets it after a successful flush. On restart, the flag survives and the watcher skips already-flushed sessions. - Add memory_flushed field to SessionEntry with to_dict/from_dict support - Old sessions.json entries without the field default to False (backward compat) - Remove the ephemeral _pre_flushed_sessions set from SessionStore - Update tests: save/load roundtrip, legacy entry compat, auto-reset behavior --- cli.py | 6 +- gateway/run.py | 14 +- gateway/session.py | 16 ++- hermes_cli/runtime_provider.py | 2 + run_agent.py | 174 ++++++++++++++++++++++- tests/gateway/test_async_memory_flush.py | 113 ++++++++++++--- 6 files changed, 290 insertions(+), 35 deletions(-) diff --git a/cli.py b/cli.py index b13317fe9..f7e45eded 100644 --- a/cli.py +++ b/cli.py @@ -1979,10 +1979,12 @@ class HermesCLI: base_url, _source, ) else: - self.console.print("[bold red]Provider resolver returned an empty API key.[/]") + print("\n⚠️ Provider resolver returned an empty API key. " + "Set OPENROUTER_API_KEY or run: hermes setup") return False if not isinstance(base_url, str) or not base_url: - self.console.print("[bold red]Provider resolver returned an empty base URL.[/]") + print("\n⚠️ Provider resolver returned an empty base URL. " + "Check your provider config or run: hermes setup") return False credentials_changed = api_key != self.api_key or base_url != self.base_url diff --git a/gateway/run.py b/gateway/run.py index 0a2dba7a1..b440ee71c 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1280,8 +1280,8 @@ class GatewayRunner: try: self.session_store._ensure_loaded() for key, entry in list(self.session_store._entries.items()): - if entry.session_id in self.session_store._pre_flushed_sessions: - continue # already flushed this session + if entry.memory_flushed: + continue # already flushed this session (persisted to disk) if not self.session_store._is_session_expired(entry): continue # session still active # Session has expired — flush memories in the background @@ -1292,7 +1292,15 @@ class GatewayRunner: try: await self._async_flush_memories(entry.session_id, key) self._shutdown_gateway_honcho(key) - self.session_store._pre_flushed_sessions.add(entry.session_id) + # Mark as flushed and persist to disk so the flag + # survives gateway restarts. + with self.session_store._lock: + entry.memory_flushed = True + self.session_store._save() + logger.info( + "Pre-reset memory flush completed for session %s", + entry.session_id, + ) except Exception as e: logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e) except Exception as e: diff --git a/gateway/session.py b/gateway/session.py index 5aefb6c01..fdf5cb6bb 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -364,6 +364,12 @@ class SessionEntry: auto_reset_reason: Optional[str] = None # "idle" or "daily" reset_had_activity: bool = False # whether the expired session had any messages + # Set by the background expiry watcher after it successfully flushes + # memories for this session. Persisted to sessions.json so the flag + # survives gateway restarts (the old in-memory _pre_flushed_sessions + # set was lost on restart, causing redundant re-flushes). + memory_flushed: bool = False + def to_dict(self) -> Dict[str, Any]: result = { "session_key": self.session_key, @@ -381,6 +387,7 @@ class SessionEntry: "last_prompt_tokens": self.last_prompt_tokens, "estimated_cost_usd": self.estimated_cost_usd, "cost_status": self.cost_status, + "memory_flushed": self.memory_flushed, } if self.origin: result["origin"] = self.origin.to_dict() @@ -416,6 +423,7 @@ class SessionEntry: last_prompt_tokens=data.get("last_prompt_tokens", 0), estimated_cost_usd=data.get("estimated_cost_usd", 0.0), cost_status=data.get("cost_status", "unknown"), + memory_flushed=data.get("memory_flushed", False), ) @@ -479,9 +487,6 @@ class SessionStore: self._loaded = False self._lock = threading.Lock() self._has_active_processes_fn = has_active_processes_fn - # on_auto_reset is deprecated — memory flush now runs proactively - # via the background session expiry watcher in GatewayRunner. - self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher # Initialize SQLite session database self._db = None @@ -684,15 +689,12 @@ class SessionStore: self._save() return entry else: - # Session is being auto-reset. The background expiry watcher - # should have already flushed memories proactively; discard - # the marker so it doesn't accumulate. + # Session is being auto-reset. was_auto_reset = True auto_reset_reason = reset_reason # Track whether the expired session had any real conversation reset_had_activity = entry.total_tokens > 0 db_end_session_id = entry.session_id - self._pre_flushed_sessions.discard(entry.session_id) else: was_auto_reset = False auto_reset_reason = None diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index bb5f4758a..aba5bb0cc 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -133,6 +133,8 @@ def _resolve_runtime_from_pool_entry( if cfg_provider == "anthropic": cfg_base_url = str(model_cfg.get("base_url") or "").strip().rstrip("/") base_url = cfg_base_url or base_url or "https://api.anthropic.com" + elif provider == "openrouter": + base_url = base_url or OPENROUTER_BASE_URL elif provider == "nous": api_mode = "chat_completions" elif provider == "copilot": diff --git a/run_agent.py b/run_agent.py index 5f77a2619..92ab62fde 100644 --- a/run_agent.py +++ b/run_agent.py @@ -3543,15 +3543,78 @@ class AIAgent: ) return client + @staticmethod + def _force_close_tcp_sockets(client: Any) -> int: + """Force-close underlying TCP sockets to prevent CLOSE-WAIT accumulation. + + When a provider drops a connection mid-stream, httpx's ``client.close()`` + performs a graceful shutdown which leaves sockets in CLOSE-WAIT until the + OS times them out (often minutes). This method walks the httpx transport + pool and issues ``socket.shutdown(SHUT_RDWR)`` + ``socket.close()`` to + force an immediate TCP RST, freeing the file descriptors. + + Returns the number of sockets force-closed. + """ + import socket as _socket + + closed = 0 + try: + http_client = getattr(client, "_client", None) + if http_client is None: + return 0 + transport = getattr(http_client, "_transport", None) + if transport is None: + return 0 + pool = getattr(transport, "_pool", None) + if pool is None: + return 0 + # httpx uses httpcore connection pools; connections live in + # _connections (list) or _pool (list) depending on version. + connections = ( + getattr(pool, "_connections", None) + or getattr(pool, "_pool", None) + or [] + ) + for conn in list(connections): + stream = ( + getattr(conn, "_network_stream", None) + or getattr(conn, "_stream", None) + ) + if stream is None: + continue + sock = getattr(stream, "_sock", None) + if sock is None: + sock = getattr(stream, "stream", None) + if sock is not None: + sock = getattr(sock, "_sock", None) + if sock is None: + continue + try: + sock.shutdown(_socket.SHUT_RDWR) + except OSError: + pass + try: + sock.close() + except OSError: + pass + closed += 1 + except Exception as exc: + logger.debug("Force-close TCP sockets sweep error: %s", exc) + return closed + def _close_openai_client(self, client: Any, *, reason: str, shared: bool) -> None: if client is None: return + # Force-close TCP sockets first to prevent CLOSE-WAIT accumulation, + # then do the graceful SDK-level close. + force_closed = self._force_close_tcp_sockets(client) try: client.close() logger.info( - "OpenAI client closed (%s, shared=%s) %s", + "OpenAI client closed (%s, shared=%s, tcp_force_closed=%d) %s", reason, shared, + force_closed, self._client_log_context(), ) except Exception as exc: @@ -3596,6 +3659,76 @@ class AIAgent: with self._openai_client_lock(): return self.client + def _cleanup_dead_connections(self) -> bool: + """Detect and clean up dead TCP connections on the primary client. + + Inspects the httpx connection pool for sockets in unhealthy states + (CLOSE-WAIT, errors). If any are found, force-closes all sockets + and rebuilds the primary client from scratch. + + Returns True if dead connections were found and cleaned up. + """ + client = getattr(self, "client", None) + if client is None: + return False + try: + http_client = getattr(client, "_client", None) + if http_client is None: + return False + transport = getattr(http_client, "_transport", None) + if transport is None: + return False + pool = getattr(transport, "_pool", None) + if pool is None: + return False + connections = ( + getattr(pool, "_connections", None) + or getattr(pool, "_pool", None) + or [] + ) + dead_count = 0 + for conn in list(connections): + # Check for connections that are idle but have closed sockets + stream = ( + getattr(conn, "_network_stream", None) + or getattr(conn, "_stream", None) + ) + if stream is None: + continue + sock = getattr(stream, "_sock", None) + if sock is None: + sock = getattr(stream, "stream", None) + if sock is not None: + sock = getattr(sock, "_sock", None) + if sock is None: + continue + # Probe socket health with a non-blocking recv peek + import socket as _socket + try: + sock.setblocking(False) + data = sock.recv(1, _socket.MSG_PEEK | _socket.MSG_DONTWAIT) + if data == b"": + dead_count += 1 + except BlockingIOError: + pass # No data available — socket is healthy + except OSError: + dead_count += 1 + finally: + try: + sock.setblocking(True) + except OSError: + pass + if dead_count > 0: + logger.warning( + "Found %d dead connection(s) in client pool — rebuilding client", + dead_count, + ) + self._replace_primary_openai_client(reason="dead_connection_cleanup") + return True + except Exception as exc: + logger.debug("Dead connection check error: %s", exc) + return False + def _create_request_openai_client(self, *, reason: str) -> Any: from unittest.mock import Mock @@ -4387,6 +4520,11 @@ class AIAgent: type(e).__name__, e, ) + self._emit_status( + f"⚠️ Connection to provider dropped " + f"({type(e).__name__}). Reconnecting… " + f"(attempt {_stream_attempt + 2}/{_max_stream_retries + 1})" + ) # Close the stale request client before retry stale = request_client_holder.get("client") if stale is not None: @@ -4394,7 +4532,21 @@ class AIAgent: stale, reason="stream_retry_cleanup" ) request_client_holder["client"] = None + # Also rebuild the primary client to purge + # any dead connections from the pool. + try: + self._replace_primary_openai_client( + reason="stream_retry_pool_cleanup" + ) + except Exception: + pass continue + self._emit_status( + "❌ Connection to provider failed after " + f"{_max_stream_retries + 1} attempts. " + "The provider may be experiencing issues — " + "try again in a moment." + ) logger.warning( "Streaming exhausted %s retries on transient error, " "falling back to non-streaming: %s", @@ -4466,6 +4618,12 @@ class AIAgent: self._close_request_openai_client(rc, reason="stale_stream_kill") except Exception: pass + # Rebuild the primary client too — its connection pool + # may hold dead sockets from the same provider outage. + try: + self._replace_primary_openai_client(reason="stale_stream_pool_cleanup") + except Exception: + pass # Reset the timer so we don't kill repeatedly while # the inner thread processes the closure. last_chunk_time["t"] = time.time() @@ -6254,6 +6412,20 @@ class AIAgent: self._last_content_with_tools = None self._mute_post_response = False self._surrogate_sanitized = False + + # Pre-turn connection health check: detect and clean up dead TCP + # connections left over from provider outages or dropped streams. + # This prevents the next API call from hanging on a zombie socket. + if self.api_mode != "anthropic_messages": + try: + if self._cleanup_dead_connections(): + self._emit_status( + "🔌 Detected stale connections from a previous provider " + "issue — cleaned up automatically. Proceeding with fresh " + "connection." + ) + except Exception: + pass # NOTE: _turns_since_memory and _iters_since_skill are NOT reset here. # They are initialized in __init__ and must persist across run_conversation # calls so that nudge logic accumulates correctly in CLI mode. diff --git a/tests/gateway/test_async_memory_flush.py b/tests/gateway/test_async_memory_flush.py index 675746920..0d7319490 100644 --- a/tests/gateway/test_async_memory_flush.py +++ b/tests/gateway/test_async_memory_flush.py @@ -3,7 +3,7 @@ Verifies that: 1. _is_session_expired() works from a SessionEntry alone (no source needed) 2. The sync callback is no longer called in get_or_create_session -3. _pre_flushed_sessions tracking works correctly +3. memory_flushed flag persists across save/load cycles (prevents restart re-flush) 4. The background watcher can detect expired sessions """ @@ -115,8 +115,8 @@ class TestIsSessionExpired: class TestGetOrCreateSessionNoCallback: """get_or_create_session should NOT call a sync flush callback.""" - def test_auto_reset_cleans_pre_flushed_marker(self, idle_store): - """When a session auto-resets, the pre_flushed marker should be discarded.""" + def test_auto_reset_creates_new_session_after_flush(self, idle_store): + """When a flushed session auto-resets, a new session_id is created.""" source = SessionSource( platform=Platform.TELEGRAM, chat_id="123", @@ -127,7 +127,7 @@ class TestGetOrCreateSessionNoCallback: old_sid = entry1.session_id # Simulate the watcher having flushed it - idle_store._pre_flushed_sessions.add(old_sid) + entry1.memory_flushed = True # Simulate the session going idle entry1.updated_at = datetime.now() - timedelta(minutes=120) @@ -137,9 +137,8 @@ class TestGetOrCreateSessionNoCallback: entry2 = idle_store.get_or_create_session(source) assert entry2.session_id != old_sid assert entry2.was_auto_reset is True - - # The old session_id should be removed from pre_flushed - assert old_sid not in idle_store._pre_flushed_sessions + # New session starts with memory_flushed=False + assert entry2.memory_flushed is False def test_no_sync_callback_invoked(self, idle_store): """No synchronous callback should block during auto-reset.""" @@ -160,21 +159,91 @@ class TestGetOrCreateSessionNoCallback: assert entry2.was_auto_reset is True -class TestPreFlushedSessionsTracking: - """The _pre_flushed_sessions set should prevent double-flushing.""" +class TestMemoryFlushedFlag: + """The memory_flushed flag on SessionEntry prevents double-flushing.""" - def test_starts_empty(self, idle_store): - assert len(idle_store._pre_flushed_sessions) == 0 + def test_defaults_to_false(self): + entry = SessionEntry( + session_key="agent:main:telegram:dm:123", + session_id="sid_new", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + assert entry.memory_flushed is False - def test_add_and_check(self, idle_store): - idle_store._pre_flushed_sessions.add("sid_old") - assert "sid_old" in idle_store._pre_flushed_sessions - assert "sid_other" not in idle_store._pre_flushed_sessions + def test_persists_through_save_load(self, idle_store): + """memory_flushed=True must survive a save/load cycle (simulates restart).""" + key = "agent:main:discord:thread:789" + entry = SessionEntry( + session_key=key, + session_id="sid_flushed", + created_at=datetime.now() - timedelta(hours=5), + updated_at=datetime.now() - timedelta(hours=5), + platform=Platform.DISCORD, + chat_type="thread", + memory_flushed=True, + ) + idle_store._entries[key] = entry + idle_store._save() - def test_discard_on_reset(self, idle_store): - """discard should remove without raising if not present.""" - idle_store._pre_flushed_sessions.add("sid_a") - idle_store._pre_flushed_sessions.discard("sid_a") - assert "sid_a" not in idle_store._pre_flushed_sessions - # discard on non-existent should not raise - idle_store._pre_flushed_sessions.discard("sid_nonexistent") + # Simulate restart: clear in-memory state, reload from disk + idle_store._entries.clear() + idle_store._loaded = False + idle_store._ensure_loaded() + + reloaded = idle_store._entries[key] + assert reloaded.memory_flushed is True + + def test_unflushed_entry_survives_restart_as_unflushed(self, idle_store): + """An entry without memory_flushed stays False after reload.""" + key = "agent:main:telegram:dm:456" + entry = SessionEntry( + session_key=key, + session_id="sid_not_flushed", + created_at=datetime.now() - timedelta(hours=2), + updated_at=datetime.now() - timedelta(hours=2), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + idle_store._entries[key] = entry + idle_store._save() + + idle_store._entries.clear() + idle_store._loaded = False + idle_store._ensure_loaded() + + reloaded = idle_store._entries[key] + assert reloaded.memory_flushed is False + + def test_roundtrip_to_dict_from_dict(self): + """to_dict/from_dict must preserve memory_flushed.""" + entry = SessionEntry( + session_key="agent:main:telegram:dm:999", + session_id="sid_rt", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + memory_flushed=True, + ) + d = entry.to_dict() + assert d["memory_flushed"] is True + + restored = SessionEntry.from_dict(d) + assert restored.memory_flushed is True + + def test_legacy_entry_without_field_defaults_false(self): + """Old sessions.json entries missing memory_flushed should default to False.""" + data = { + "session_key": "agent:main:telegram:dm:legacy", + "session_id": "sid_legacy", + "created_at": datetime.now().isoformat(), + "updated_at": datetime.now().isoformat(), + "platform": "telegram", + "chat_type": "dm", + # no memory_flushed key + } + entry = SessionEntry.from_dict(data) + assert entry.memory_flushed is False