From 330ca4585ba101e2268288f5e949525b5ba00b43 Mon Sep 17 00:00:00 2001 From: bmoore210 <266365592+bmoore210@users.noreply.github.com> Date: Sun, 7 Jun 2026 01:44:48 -0700 Subject: [PATCH] fix: harden gateway startup and turn persistence Persist the inbound user turn before provider/tool execution so a crash before run_conversation() (e.g. provider/httpx client init failure) keeps the inbound message in the transcript. Repair stale/missing SSL_CERT_FILE state on gateway startup, and avoid duplicate gateway fallback writes. --- agent/conversation_loop.py | 13 ++ gateway/run.py | 51 +++++++- tests/gateway/test_ssl_cert_detection.py | 45 +++++++ tests/run_agent/test_413_compression.py | 55 +++++--- tests/run_agent/test_860_dedup.py | 152 ++++++++++++----------- 5 files changed, 227 insertions(+), 89 deletions(-) create mode 100644 tests/gateway/test_ssl_cert_detection.py diff --git a/agent/conversation_loop.py b/agent/conversation_loop.py index 660792feab6..330d37df270 100644 --- a/agent/conversation_loop.py +++ b/agent/conversation_loop.py @@ -600,6 +600,19 @@ def run_conversation( active_system_prompt = agent._cached_system_prompt + # Crash-resilience: persist the inbound user turn as soon as the session row + # has a valid system prompt, before any provider call or tool execution can + # hang/kill the process. The normal end-of-turn persist still runs later; + # _last_flushed_db_idx makes this idempotent and prevents duplicate rows. + try: + agent._persist_session(messages, conversation_history) + except Exception: + logger.warning( + "Early turn-start session persistence failed for session=%s", + agent.session_id or "none", + exc_info=True, + ) + # ── Preflight context compression ── # Before entering the main loop, check if the loaded conversation # history already exceeds the model's context threshold. This handles diff --git a/gateway/run.py b/gateway/run.py index 3a8fb06d598..cb93dce1c15 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -771,9 +771,23 @@ def _collect_auto_append_media_tags( # Must run BEFORE any HTTP library (discord, aiohttp, etc.) is imported. # --------------------------------------------------------------------------- def _ensure_ssl_certs() -> None: - """Set SSL_CERT_FILE if the system doesn't expose CA certs to Python.""" - if "SSL_CERT_FILE" in os.environ: - return # user already configured it + """Set SSL_CERT_FILE if the system doesn't expose CA certs to Python. + + Windows startup paths (Desktop, Scheduled Tasks, installer children) can + occasionally inherit a stale SSL_CERT_FILE. Returning just because the + variable is present makes every later httpx/OpenAI client construction fail + with FileNotFoundError from ssl.load_verify_locations(). Treat a missing + path as unset and fall back to certifi instead. + """ + configured_cert = os.environ.get("SSL_CERT_FILE") + if configured_cert: + if os.path.exists(configured_cert): + return # user already configured it to a real file + logging.getLogger(__name__).warning( + "Ignoring stale SSL_CERT_FILE=%r because the path does not exist", + configured_cert, + ) + os.environ.pop("SSL_CERT_FILE", None) import ssl @@ -9898,6 +9912,37 @@ class GatewayRunner: except Exception: pass logger.exception("Agent error in session %s", session_key) + # Crash-resilience for failures that happen before AIAgent enters + # run_conversation() (for example: provider/httpx client init + # failures). In that path the agent cannot persist the current + # inbound turn itself, so append the user message here once. If the + # agent already reached its early turn-start persistence, the latest + # transcript user row will match and we skip the duplicate. + try: + if 'message_text' in locals() and message_text is not None and session_entry is not None: + _already_persisted = False + try: + _recent_transcript = self.session_store.load_transcript(session_entry.session_id) + except Exception: + _recent_transcript = [] + for _msg in reversed(_recent_transcript[-10:]): + if _msg.get("role") == "user": + _already_persisted = (_msg.get("content") == message_text) + break + if not _already_persisted: + _user_entry = { + "role": "user", + "content": message_text, + "timestamp": datetime.now().isoformat(), + } + if getattr(event, "message_id", None): + _user_entry["message_id"] = str(event.message_id) + self.session_store.append_to_transcript( + session_entry.session_id, + _user_entry, + ) + except Exception: + logger.debug("Failed to persist inbound user message after agent exception", exc_info=True) error_type = type(e).__name__ error_detail = str(e)[:300] if str(e) else "no details available" status_hint = "" diff --git a/tests/gateway/test_ssl_cert_detection.py b/tests/gateway/test_ssl_cert_detection.py new file mode 100644 index 00000000000..b6704c382ae --- /dev/null +++ b/tests/gateway/test_ssl_cert_detection.py @@ -0,0 +1,45 @@ +"""Regression tests for gateway SSL certificate environment repair.""" + +from types import SimpleNamespace + + +def test_ensure_ssl_certs_ignores_stale_ssl_cert_file(monkeypatch, tmp_path): + """A missing SSL_CERT_FILE should be treated as unset, not trusted.""" + import ssl + import sys + + from gateway.run import _ensure_ssl_certs + + cert_file = tmp_path / "cacert.pem" + cert_file.write_text("dummy cert bundle", encoding="utf-8") + stale_file = tmp_path / "missing.pem" + + monkeypatch.setenv("SSL_CERT_FILE", str(stale_file)) + monkeypatch.setattr( + ssl, + "get_default_verify_paths", + lambda: SimpleNamespace(cafile=None, openssl_cafile=None), + ) + monkeypatch.setitem( + sys.modules, + "certifi", + SimpleNamespace(where=lambda: str(cert_file)), + ) + + _ensure_ssl_certs() + + assert stale_file.exists() is False + assert __import__("os").environ["SSL_CERT_FILE"] == str(cert_file) + + +def test_ensure_ssl_certs_keeps_existing_ssl_cert_file(monkeypatch, tmp_path): + """A valid user-provided SSL_CERT_FILE must not be overwritten.""" + from gateway.run import _ensure_ssl_certs + + cert_file = tmp_path / "existing.pem" + cert_file.write_text("dummy cert bundle", encoding="utf-8") + monkeypatch.setenv("SSL_CERT_FILE", str(cert_file)) + + _ensure_ssl_certs() + + assert __import__("os").environ["SSL_CERT_FILE"] == str(cert_file) \ No newline at end of file diff --git a/tests/run_agent/test_413_compression.py b/tests/run_agent/test_413_compression.py index 2b8c32e297b..939c3682b88 100644 --- a/tests/run_agent/test_413_compression.py +++ b/tests/run_agent/test_413_compression.py @@ -107,6 +107,40 @@ def agent(): # Tests # --------------------------------------------------------------------------- + +def test_current_user_turn_is_persisted_before_provider_call(agent): + """The inbound user turn is flushed before provider/tool work can crash.""" + observed = [] + + def _record_persist(messages, conversation_history): + observed.append(("persist", list(messages), list(conversation_history or []))) + + def _provider_crash(*_args, **_kwargs): + observed.append(("provider", [], [])) + raise RuntimeError("provider died after turn-start persistence") + + agent.client.chat.completions.create.side_effect = _provider_crash + + with ( + patch.object(agent, "_persist_session", side_effect=_record_persist), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation( + "new message that must survive a crash", + conversation_history=[{"role": "user", "content": "old message"}], + ) + + assert result.get("failed") is True + assert observed[0][0] == "persist" + assert observed[1][0] == "provider" + persisted_messages = observed[0][1] + assert persisted_messages[-1] == { + "role": "user", + "content": "new message that must survive a crash", + } + + class TestHTTP413Compression: """413 errors should trigger compression, not abort as generic 4xx.""" @@ -217,7 +251,7 @@ class TestHTTP413Compression: patch.object(agent, "_compress_context") as mock_compress, patch.object( agent, "_persist_session", - side_effect=lambda msgs, hist: persist_calls.append(hist), + side_effect=lambda msgs, hist: persist_calls.append((list(msgs), hist)), ), patch.object(agent, "_save_trajectory"), patch.object(agent, "_cleanup_task_resources"), @@ -228,12 +262,10 @@ class TestHTTP413Compression: ) agent.run_conversation("hello", conversation_history=big_history) - assert len(persist_calls) >= 1, "Expected at least one _persist_session call" - for hist in persist_calls: - assert hist is None, ( - f"conversation_history should be None after mid-loop compression, " - f"got list with {len(hist)} items" - ) + assert any(hist is None for _msgs, hist in persist_calls), ( + "Expected at least one post-compression _persist_session call " + "with conversation_history=None" + ) def test_context_overflow_clears_conversation_history_on_persist(self, agent): """After context-overflow compression, _persist_session must receive None history.""" @@ -256,7 +288,7 @@ class TestHTTP413Compression: patch.object(agent, "_compress_context") as mock_compress, patch.object( agent, "_persist_session", - side_effect=lambda msgs, hist: persist_calls.append(hist), + side_effect=lambda msgs, hist: persist_calls.append((list(msgs), hist)), ), patch.object(agent, "_save_trajectory"), patch.object(agent, "_cleanup_task_resources"), @@ -267,12 +299,7 @@ class TestHTTP413Compression: ) agent.run_conversation("hello", conversation_history=big_history) - assert len(persist_calls) >= 1 - for hist in persist_calls: - assert hist is None, ( - f"conversation_history should be None after context-overflow compression, " - f"got list with {len(hist)} items" - ) + assert any(hist is None for _msgs, hist in persist_calls) def test_400_context_length_triggers_compression(self, agent): """A 400 with 'maximum context length' should trigger compression, not abort as generic 4xx. diff --git a/tests/run_agent/test_860_dedup.py b/tests/run_agent/test_860_dedup.py index 39a7c0f3154..3a70f95adb7 100644 --- a/tests/run_agent/test_860_dedup.py +++ b/tests/run_agent/test_860_dedup.py @@ -46,28 +46,30 @@ class TestFlushDeduplication: with tempfile.TemporaryDirectory() as tmpdir: db_path = Path(tmpdir) / "test.db" db = SessionDB(db_path=db_path) + try: + agent = self._make_agent(db) - agent = self._make_agent(db) + conversation_history = [ + {"role": "user", "content": "old message"}, + ] + messages = list(conversation_history) + [ + {"role": "user", "content": "new question"}, + {"role": "assistant", "content": "new answer"}, + ] - conversation_history = [ - {"role": "user", "content": "old message"}, - ] - messages = list(conversation_history) + [ - {"role": "user", "content": "new question"}, - {"role": "assistant", "content": "new answer"}, - ] + # First flush — should write 2 new messages + agent._flush_messages_to_session_db(messages, conversation_history) - # First flush — should write 2 new messages - agent._flush_messages_to_session_db(messages, conversation_history) + rows = db.get_messages(agent.session_id) + assert len(rows) == 2, f"Expected 2 messages, got {len(rows)}" - rows = db.get_messages(agent.session_id) - assert len(rows) == 2, f"Expected 2 messages, got {len(rows)}" + # Second flush with SAME messages — should write 0 new messages + agent._flush_messages_to_session_db(messages, conversation_history) - # Second flush with SAME messages — should write 0 new messages - agent._flush_messages_to_session_db(messages, conversation_history) - - rows = db.get_messages(agent.session_id) - assert len(rows) == 2, f"Expected still 2 messages after second flush, got {len(rows)}" + rows = db.get_messages(agent.session_id) + assert len(rows) == 2, f"Expected still 2 messages after second flush, got {len(rows)}" + finally: + db.close() def test_flush_writes_incrementally(self): """Messages added between flushes are written exactly once.""" @@ -76,27 +78,29 @@ class TestFlushDeduplication: with tempfile.TemporaryDirectory() as tmpdir: db_path = Path(tmpdir) / "test.db" db = SessionDB(db_path=db_path) + try: + agent = self._make_agent(db) - agent = self._make_agent(db) + conversation_history = [] + messages = [ + {"role": "user", "content": "hello"}, + ] - conversation_history = [] - messages = [ - {"role": "user", "content": "hello"}, - ] + # First flush — 1 message + agent._flush_messages_to_session_db(messages, conversation_history) + rows = db.get_messages(agent.session_id) + assert len(rows) == 1 - # First flush — 1 message - agent._flush_messages_to_session_db(messages, conversation_history) - rows = db.get_messages(agent.session_id) - assert len(rows) == 1 + # Add more messages + messages.append({"role": "assistant", "content": "hi there"}) + messages.append({"role": "user", "content": "follow up"}) - # Add more messages - messages.append({"role": "assistant", "content": "hi there"}) - messages.append({"role": "user", "content": "follow up"}) - - # Second flush — should write only 2 new messages - agent._flush_messages_to_session_db(messages, conversation_history) - rows = db.get_messages(agent.session_id) - assert len(rows) == 3, f"Expected 3 total messages, got {len(rows)}" + # Second flush — should write only 2 new messages + agent._flush_messages_to_session_db(messages, conversation_history) + rows = db.get_messages(agent.session_id) + assert len(rows) == 3, f"Expected 3 total messages, got {len(rows)}" + finally: + db.close() def test_persist_session_multiple_calls_no_duplication(self): """Multiple _persist_session calls don't duplicate DB entries.""" @@ -105,23 +109,25 @@ class TestFlushDeduplication: with tempfile.TemporaryDirectory() as tmpdir: db_path = Path(tmpdir) / "test.db" db = SessionDB(db_path=db_path) + try: + agent = self._make_agent(db) - agent = self._make_agent(db) + conversation_history = [{"role": "user", "content": "old"}] + messages = list(conversation_history) + [ + {"role": "user", "content": "q1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "q2"}, + {"role": "assistant", "content": "a2"}, + ] - conversation_history = [{"role": "user", "content": "old"}] - messages = list(conversation_history) + [ - {"role": "user", "content": "q1"}, - {"role": "assistant", "content": "a1"}, - {"role": "user", "content": "q2"}, - {"role": "assistant", "content": "a2"}, - ] + # Simulate multiple persist calls (like the agent's many exit paths) + for _ in range(5): + agent._persist_session(messages, conversation_history) - # Simulate multiple persist calls (like the agent's many exit paths) - for _ in range(5): - agent._persist_session(messages, conversation_history) - - rows = db.get_messages(agent.session_id) - assert len(rows) == 4, f"Expected 4 messages, got {len(rows)} (duplication bug!)" + rows = db.get_messages(agent.session_id) + assert len(rows) == 4, f"Expected 4 messages, got {len(rows)} (duplication bug!)" + finally: + db.close() def test_flush_reset_after_compression(self): """After compression creates a new session, flush index resets.""" @@ -130,36 +136,38 @@ class TestFlushDeduplication: with tempfile.TemporaryDirectory() as tmpdir: db_path = Path(tmpdir) / "test.db" db = SessionDB(db_path=db_path) + try: + agent = self._make_agent(db) - agent = self._make_agent(db) + # Write some messages + messages = [ + {"role": "user", "content": "msg1"}, + {"role": "assistant", "content": "reply1"}, + ] + agent._flush_messages_to_session_db(messages, []) - # Write some messages - messages = [ - {"role": "user", "content": "msg1"}, - {"role": "assistant", "content": "reply1"}, - ] - agent._flush_messages_to_session_db(messages, []) + old_session = agent.session_id + assert agent._last_flushed_db_idx == 2 - old_session = agent.session_id - assert agent._last_flushed_db_idx == 2 + # Simulate what _compress_context does: new session, reset idx + agent.session_id = "compressed-session-new" + db.create_session(session_id=agent.session_id, source="test") + agent._last_flushed_db_idx = 0 - # Simulate what _compress_context does: new session, reset idx - agent.session_id = "compressed-session-new" - db.create_session(session_id=agent.session_id, source="test") - agent._last_flushed_db_idx = 0 + # Now flush compressed messages to new session + compressed_messages = [ + {"role": "user", "content": "summary of conversation"}, + ] + agent._flush_messages_to_session_db(compressed_messages, []) - # Now flush compressed messages to new session - compressed_messages = [ - {"role": "user", "content": "summary of conversation"}, - ] - agent._flush_messages_to_session_db(compressed_messages, []) + new_rows = db.get_messages(agent.session_id) + assert len(new_rows) == 1 - new_rows = db.get_messages(agent.session_id) - assert len(new_rows) == 1 - - # Old session should still have its 2 messages - old_rows = db.get_messages(old_session) - assert len(old_rows) == 2 + # Old session should still have its 2 messages + old_rows = db.get_messages(old_session) + assert len(old_rows) == 2 + finally: + db.close() # ---------------------------------------------------------------------------