fix: harden gateway startup and turn persistence

Persist the inbound user turn before provider/tool execution so a crash
before run_conversation() (e.g. provider/httpx client init failure) keeps
the inbound message in the transcript. Repair stale/missing SSL_CERT_FILE
state on gateway startup, and avoid duplicate gateway fallback writes.
This commit is contained in:
bmoore210 2026-06-07 01:44:48 -07:00 committed by Teknium
parent 591e6fb8f4
commit 330ca4585b
5 changed files with 227 additions and 89 deletions

View file

@ -600,6 +600,19 @@ def run_conversation(
active_system_prompt = agent._cached_system_prompt
# Crash-resilience: persist the inbound user turn as soon as the session row
# has a valid system prompt, before any provider call or tool execution can
# hang/kill the process. The normal end-of-turn persist still runs later;
# _last_flushed_db_idx makes this idempotent and prevents duplicate rows.
try:
agent._persist_session(messages, conversation_history)
except Exception:
logger.warning(
"Early turn-start session persistence failed for session=%s",
agent.session_id or "none",
exc_info=True,
)
# ── Preflight context compression ──
# Before entering the main loop, check if the loaded conversation
# history already exceeds the model's context threshold. This handles

View file

@ -771,9 +771,23 @@ def _collect_auto_append_media_tags(
# Must run BEFORE any HTTP library (discord, aiohttp, etc.) is imported.
# ---------------------------------------------------------------------------
def _ensure_ssl_certs() -> None:
"""Set SSL_CERT_FILE if the system doesn't expose CA certs to Python."""
if "SSL_CERT_FILE" in os.environ:
return # user already configured it
"""Set SSL_CERT_FILE if the system doesn't expose CA certs to Python.
Windows startup paths (Desktop, Scheduled Tasks, installer children) can
occasionally inherit a stale SSL_CERT_FILE. Returning just because the
variable is present makes every later httpx/OpenAI client construction fail
with FileNotFoundError from ssl.load_verify_locations(). Treat a missing
path as unset and fall back to certifi instead.
"""
configured_cert = os.environ.get("SSL_CERT_FILE")
if configured_cert:
if os.path.exists(configured_cert):
return # user already configured it to a real file
logging.getLogger(__name__).warning(
"Ignoring stale SSL_CERT_FILE=%r because the path does not exist",
configured_cert,
)
os.environ.pop("SSL_CERT_FILE", None)
import ssl
@ -9898,6 +9912,37 @@ class GatewayRunner:
except Exception:
pass
logger.exception("Agent error in session %s", session_key)
# Crash-resilience for failures that happen before AIAgent enters
# run_conversation() (for example: provider/httpx client init
# failures). In that path the agent cannot persist the current
# inbound turn itself, so append the user message here once. If the
# agent already reached its early turn-start persistence, the latest
# transcript user row will match and we skip the duplicate.
try:
if 'message_text' in locals() and message_text is not None and session_entry is not None:
_already_persisted = False
try:
_recent_transcript = self.session_store.load_transcript(session_entry.session_id)
except Exception:
_recent_transcript = []
for _msg in reversed(_recent_transcript[-10:]):
if _msg.get("role") == "user":
_already_persisted = (_msg.get("content") == message_text)
break
if not _already_persisted:
_user_entry = {
"role": "user",
"content": message_text,
"timestamp": datetime.now().isoformat(),
}
if getattr(event, "message_id", None):
_user_entry["message_id"] = str(event.message_id)
self.session_store.append_to_transcript(
session_entry.session_id,
_user_entry,
)
except Exception:
logger.debug("Failed to persist inbound user message after agent exception", exc_info=True)
error_type = type(e).__name__
error_detail = str(e)[:300] if str(e) else "no details available"
status_hint = ""

View file

@ -0,0 +1,45 @@
"""Regression tests for gateway SSL certificate environment repair."""
from types import SimpleNamespace
def test_ensure_ssl_certs_ignores_stale_ssl_cert_file(monkeypatch, tmp_path):
"""A missing SSL_CERT_FILE should be treated as unset, not trusted."""
import ssl
import sys
from gateway.run import _ensure_ssl_certs
cert_file = tmp_path / "cacert.pem"
cert_file.write_text("dummy cert bundle", encoding="utf-8")
stale_file = tmp_path / "missing.pem"
monkeypatch.setenv("SSL_CERT_FILE", str(stale_file))
monkeypatch.setattr(
ssl,
"get_default_verify_paths",
lambda: SimpleNamespace(cafile=None, openssl_cafile=None),
)
monkeypatch.setitem(
sys.modules,
"certifi",
SimpleNamespace(where=lambda: str(cert_file)),
)
_ensure_ssl_certs()
assert stale_file.exists() is False
assert __import__("os").environ["SSL_CERT_FILE"] == str(cert_file)
def test_ensure_ssl_certs_keeps_existing_ssl_cert_file(monkeypatch, tmp_path):
"""A valid user-provided SSL_CERT_FILE must not be overwritten."""
from gateway.run import _ensure_ssl_certs
cert_file = tmp_path / "existing.pem"
cert_file.write_text("dummy cert bundle", encoding="utf-8")
monkeypatch.setenv("SSL_CERT_FILE", str(cert_file))
_ensure_ssl_certs()
assert __import__("os").environ["SSL_CERT_FILE"] == str(cert_file)

View file

@ -107,6 +107,40 @@ def agent():
# Tests
# ---------------------------------------------------------------------------
def test_current_user_turn_is_persisted_before_provider_call(agent):
"""The inbound user turn is flushed before provider/tool work can crash."""
observed = []
def _record_persist(messages, conversation_history):
observed.append(("persist", list(messages), list(conversation_history or [])))
def _provider_crash(*_args, **_kwargs):
observed.append(("provider", [], []))
raise RuntimeError("provider died after turn-start persistence")
agent.client.chat.completions.create.side_effect = _provider_crash
with (
patch.object(agent, "_persist_session", side_effect=_record_persist),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
result = agent.run_conversation(
"new message that must survive a crash",
conversation_history=[{"role": "user", "content": "old message"}],
)
assert result.get("failed") is True
assert observed[0][0] == "persist"
assert observed[1][0] == "provider"
persisted_messages = observed[0][1]
assert persisted_messages[-1] == {
"role": "user",
"content": "new message that must survive a crash",
}
class TestHTTP413Compression:
"""413 errors should trigger compression, not abort as generic 4xx."""
@ -217,7 +251,7 @@ class TestHTTP413Compression:
patch.object(agent, "_compress_context") as mock_compress,
patch.object(
agent, "_persist_session",
side_effect=lambda msgs, hist: persist_calls.append(hist),
side_effect=lambda msgs, hist: persist_calls.append((list(msgs), hist)),
),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
@ -228,12 +262,10 @@ class TestHTTP413Compression:
)
agent.run_conversation("hello", conversation_history=big_history)
assert len(persist_calls) >= 1, "Expected at least one _persist_session call"
for hist in persist_calls:
assert hist is None, (
f"conversation_history should be None after mid-loop compression, "
f"got list with {len(hist)} items"
)
assert any(hist is None for _msgs, hist in persist_calls), (
"Expected at least one post-compression _persist_session call "
"with conversation_history=None"
)
def test_context_overflow_clears_conversation_history_on_persist(self, agent):
"""After context-overflow compression, _persist_session must receive None history."""
@ -256,7 +288,7 @@ class TestHTTP413Compression:
patch.object(agent, "_compress_context") as mock_compress,
patch.object(
agent, "_persist_session",
side_effect=lambda msgs, hist: persist_calls.append(hist),
side_effect=lambda msgs, hist: persist_calls.append((list(msgs), hist)),
),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
@ -267,12 +299,7 @@ class TestHTTP413Compression:
)
agent.run_conversation("hello", conversation_history=big_history)
assert len(persist_calls) >= 1
for hist in persist_calls:
assert hist is None, (
f"conversation_history should be None after context-overflow compression, "
f"got list with {len(hist)} items"
)
assert any(hist is None for _msgs, hist in persist_calls)
def test_400_context_length_triggers_compression(self, agent):
"""A 400 with 'maximum context length' should trigger compression, not abort as generic 4xx.

View file

@ -46,28 +46,30 @@ class TestFlushDeduplication:
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
db = SessionDB(db_path=db_path)
try:
agent = self._make_agent(db)
agent = self._make_agent(db)
conversation_history = [
{"role": "user", "content": "old message"},
]
messages = list(conversation_history) + [
{"role": "user", "content": "new question"},
{"role": "assistant", "content": "new answer"},
]
conversation_history = [
{"role": "user", "content": "old message"},
]
messages = list(conversation_history) + [
{"role": "user", "content": "new question"},
{"role": "assistant", "content": "new answer"},
]
# First flush — should write 2 new messages
agent._flush_messages_to_session_db(messages, conversation_history)
# First flush — should write 2 new messages
agent._flush_messages_to_session_db(messages, conversation_history)
rows = db.get_messages(agent.session_id)
assert len(rows) == 2, f"Expected 2 messages, got {len(rows)}"
rows = db.get_messages(agent.session_id)
assert len(rows) == 2, f"Expected 2 messages, got {len(rows)}"
# Second flush with SAME messages — should write 0 new messages
agent._flush_messages_to_session_db(messages, conversation_history)
# Second flush with SAME messages — should write 0 new messages
agent._flush_messages_to_session_db(messages, conversation_history)
rows = db.get_messages(agent.session_id)
assert len(rows) == 2, f"Expected still 2 messages after second flush, got {len(rows)}"
rows = db.get_messages(agent.session_id)
assert len(rows) == 2, f"Expected still 2 messages after second flush, got {len(rows)}"
finally:
db.close()
def test_flush_writes_incrementally(self):
"""Messages added between flushes are written exactly once."""
@ -76,27 +78,29 @@ class TestFlushDeduplication:
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
db = SessionDB(db_path=db_path)
try:
agent = self._make_agent(db)
agent = self._make_agent(db)
conversation_history = []
messages = [
{"role": "user", "content": "hello"},
]
conversation_history = []
messages = [
{"role": "user", "content": "hello"},
]
# First flush — 1 message
agent._flush_messages_to_session_db(messages, conversation_history)
rows = db.get_messages(agent.session_id)
assert len(rows) == 1
# First flush — 1 message
agent._flush_messages_to_session_db(messages, conversation_history)
rows = db.get_messages(agent.session_id)
assert len(rows) == 1
# Add more messages
messages.append({"role": "assistant", "content": "hi there"})
messages.append({"role": "user", "content": "follow up"})
# Add more messages
messages.append({"role": "assistant", "content": "hi there"})
messages.append({"role": "user", "content": "follow up"})
# Second flush — should write only 2 new messages
agent._flush_messages_to_session_db(messages, conversation_history)
rows = db.get_messages(agent.session_id)
assert len(rows) == 3, f"Expected 3 total messages, got {len(rows)}"
# Second flush — should write only 2 new messages
agent._flush_messages_to_session_db(messages, conversation_history)
rows = db.get_messages(agent.session_id)
assert len(rows) == 3, f"Expected 3 total messages, got {len(rows)}"
finally:
db.close()
def test_persist_session_multiple_calls_no_duplication(self):
"""Multiple _persist_session calls don't duplicate DB entries."""
@ -105,23 +109,25 @@ class TestFlushDeduplication:
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
db = SessionDB(db_path=db_path)
try:
agent = self._make_agent(db)
agent = self._make_agent(db)
conversation_history = [{"role": "user", "content": "old"}]
messages = list(conversation_history) + [
{"role": "user", "content": "q1"},
{"role": "assistant", "content": "a1"},
{"role": "user", "content": "q2"},
{"role": "assistant", "content": "a2"},
]
conversation_history = [{"role": "user", "content": "old"}]
messages = list(conversation_history) + [
{"role": "user", "content": "q1"},
{"role": "assistant", "content": "a1"},
{"role": "user", "content": "q2"},
{"role": "assistant", "content": "a2"},
]
# Simulate multiple persist calls (like the agent's many exit paths)
for _ in range(5):
agent._persist_session(messages, conversation_history)
# Simulate multiple persist calls (like the agent's many exit paths)
for _ in range(5):
agent._persist_session(messages, conversation_history)
rows = db.get_messages(agent.session_id)
assert len(rows) == 4, f"Expected 4 messages, got {len(rows)} (duplication bug!)"
rows = db.get_messages(agent.session_id)
assert len(rows) == 4, f"Expected 4 messages, got {len(rows)} (duplication bug!)"
finally:
db.close()
def test_flush_reset_after_compression(self):
"""After compression creates a new session, flush index resets."""
@ -130,36 +136,38 @@ class TestFlushDeduplication:
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.db"
db = SessionDB(db_path=db_path)
try:
agent = self._make_agent(db)
agent = self._make_agent(db)
# Write some messages
messages = [
{"role": "user", "content": "msg1"},
{"role": "assistant", "content": "reply1"},
]
agent._flush_messages_to_session_db(messages, [])
# Write some messages
messages = [
{"role": "user", "content": "msg1"},
{"role": "assistant", "content": "reply1"},
]
agent._flush_messages_to_session_db(messages, [])
old_session = agent.session_id
assert agent._last_flushed_db_idx == 2
old_session = agent.session_id
assert agent._last_flushed_db_idx == 2
# Simulate what _compress_context does: new session, reset idx
agent.session_id = "compressed-session-new"
db.create_session(session_id=agent.session_id, source="test")
agent._last_flushed_db_idx = 0
# Simulate what _compress_context does: new session, reset idx
agent.session_id = "compressed-session-new"
db.create_session(session_id=agent.session_id, source="test")
agent._last_flushed_db_idx = 0
# Now flush compressed messages to new session
compressed_messages = [
{"role": "user", "content": "summary of conversation"},
]
agent._flush_messages_to_session_db(compressed_messages, [])
# Now flush compressed messages to new session
compressed_messages = [
{"role": "user", "content": "summary of conversation"},
]
agent._flush_messages_to_session_db(compressed_messages, [])
new_rows = db.get_messages(agent.session_id)
assert len(new_rows) == 1
new_rows = db.get_messages(agent.session_id)
assert len(new_rows) == 1
# Old session should still have its 2 messages
old_rows = db.get_messages(old_session)
assert len(old_rows) == 2
# Old session should still have its 2 messages
old_rows = db.get_messages(old_session)
assert len(old_rows) == 2
finally:
db.close()
# ---------------------------------------------------------------------------