mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
fix: harden gateway startup and turn persistence
Persist the inbound user turn before provider/tool execution so a crash before run_conversation() (e.g. provider/httpx client init failure) keeps the inbound message in the transcript. Repair stale/missing SSL_CERT_FILE state on gateway startup, and avoid duplicate gateway fallback writes.
This commit is contained in:
parent
591e6fb8f4
commit
330ca4585b
5 changed files with 227 additions and 89 deletions
|
|
@ -600,6 +600,19 @@ def run_conversation(
|
|||
|
||||
active_system_prompt = agent._cached_system_prompt
|
||||
|
||||
# Crash-resilience: persist the inbound user turn as soon as the session row
|
||||
# has a valid system prompt, before any provider call or tool execution can
|
||||
# hang/kill the process. The normal end-of-turn persist still runs later;
|
||||
# _last_flushed_db_idx makes this idempotent and prevents duplicate rows.
|
||||
try:
|
||||
agent._persist_session(messages, conversation_history)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Early turn-start session persistence failed for session=%s",
|
||||
agent.session_id or "none",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ── Preflight context compression ──
|
||||
# Before entering the main loop, check if the loaded conversation
|
||||
# history already exceeds the model's context threshold. This handles
|
||||
|
|
|
|||
|
|
@ -771,9 +771,23 @@ def _collect_auto_append_media_tags(
|
|||
# Must run BEFORE any HTTP library (discord, aiohttp, etc.) is imported.
|
||||
# ---------------------------------------------------------------------------
|
||||
def _ensure_ssl_certs() -> None:
|
||||
"""Set SSL_CERT_FILE if the system doesn't expose CA certs to Python."""
|
||||
if "SSL_CERT_FILE" in os.environ:
|
||||
return # user already configured it
|
||||
"""Set SSL_CERT_FILE if the system doesn't expose CA certs to Python.
|
||||
|
||||
Windows startup paths (Desktop, Scheduled Tasks, installer children) can
|
||||
occasionally inherit a stale SSL_CERT_FILE. Returning just because the
|
||||
variable is present makes every later httpx/OpenAI client construction fail
|
||||
with FileNotFoundError from ssl.load_verify_locations(). Treat a missing
|
||||
path as unset and fall back to certifi instead.
|
||||
"""
|
||||
configured_cert = os.environ.get("SSL_CERT_FILE")
|
||||
if configured_cert:
|
||||
if os.path.exists(configured_cert):
|
||||
return # user already configured it to a real file
|
||||
logging.getLogger(__name__).warning(
|
||||
"Ignoring stale SSL_CERT_FILE=%r because the path does not exist",
|
||||
configured_cert,
|
||||
)
|
||||
os.environ.pop("SSL_CERT_FILE", None)
|
||||
|
||||
import ssl
|
||||
|
||||
|
|
@ -9898,6 +9912,37 @@ class GatewayRunner:
|
|||
except Exception:
|
||||
pass
|
||||
logger.exception("Agent error in session %s", session_key)
|
||||
# Crash-resilience for failures that happen before AIAgent enters
|
||||
# run_conversation() (for example: provider/httpx client init
|
||||
# failures). In that path the agent cannot persist the current
|
||||
# inbound turn itself, so append the user message here once. If the
|
||||
# agent already reached its early turn-start persistence, the latest
|
||||
# transcript user row will match and we skip the duplicate.
|
||||
try:
|
||||
if 'message_text' in locals() and message_text is not None and session_entry is not None:
|
||||
_already_persisted = False
|
||||
try:
|
||||
_recent_transcript = self.session_store.load_transcript(session_entry.session_id)
|
||||
except Exception:
|
||||
_recent_transcript = []
|
||||
for _msg in reversed(_recent_transcript[-10:]):
|
||||
if _msg.get("role") == "user":
|
||||
_already_persisted = (_msg.get("content") == message_text)
|
||||
break
|
||||
if not _already_persisted:
|
||||
_user_entry = {
|
||||
"role": "user",
|
||||
"content": message_text,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
if getattr(event, "message_id", None):
|
||||
_user_entry["message_id"] = str(event.message_id)
|
||||
self.session_store.append_to_transcript(
|
||||
session_entry.session_id,
|
||||
_user_entry,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to persist inbound user message after agent exception", exc_info=True)
|
||||
error_type = type(e).__name__
|
||||
error_detail = str(e)[:300] if str(e) else "no details available"
|
||||
status_hint = ""
|
||||
|
|
|
|||
45
tests/gateway/test_ssl_cert_detection.py
Normal file
45
tests/gateway/test_ssl_cert_detection.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
"""Regression tests for gateway SSL certificate environment repair."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def test_ensure_ssl_certs_ignores_stale_ssl_cert_file(monkeypatch, tmp_path):
    """A dangling SSL_CERT_FILE is replaced by the bundle certifi reports."""
    import os
    import ssl
    import sys

    from gateway.run import _ensure_ssl_certs

    # A real file for the stand-in certifi module to hand back ...
    bundle_path = tmp_path / "cacert.pem"
    bundle_path.write_text("dummy cert bundle", encoding="utf-8")
    # ... and a path that is deliberately never created.
    dangling_path = tmp_path / "missing.pem"

    monkeypatch.setenv("SSL_CERT_FILE", str(dangling_path))

    # Make the interpreter report no default CA file.
    # NOTE(review): presumably this forces _ensure_ssl_certs down its certifi
    # fallback; the rest of that function is not visible from this test.
    def _no_default_paths():
        return SimpleNamespace(cafile=None, openssl_cafile=None)

    monkeypatch.setattr(ssl, "get_default_verify_paths", _no_default_paths)

    # Replace certifi with a stub whose where() points at our temp bundle.
    stub_certifi = SimpleNamespace(where=lambda: str(bundle_path))
    monkeypatch.setitem(sys.modules, "certifi", stub_certifi)

    _ensure_ssl_certs()

    # The stale path must not have been created as a side effect, and the
    # environment must now reference the stub certifi bundle.
    assert not dangling_path.exists()
    assert os.environ["SSL_CERT_FILE"] == str(bundle_path)
|
||||
|
||||
|
||||
def test_ensure_ssl_certs_keeps_existing_ssl_cert_file(monkeypatch, tmp_path):
    """An SSL_CERT_FILE that points at a real file is left untouched."""
    import os

    from gateway.run import _ensure_ssl_certs

    # Create an actual file so the configured path exists on disk.
    user_bundle = tmp_path / "existing.pem"
    user_bundle.write_text("dummy cert bundle", encoding="utf-8")
    expected_value = str(user_bundle)
    monkeypatch.setenv("SSL_CERT_FILE", expected_value)

    _ensure_ssl_certs()

    # The user-provided value must survive the call unchanged.
    assert os.environ["SSL_CERT_FILE"] == expected_value
|
||||
|
|
@ -107,6 +107,40 @@ def agent():
|
|||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_current_user_turn_is_persisted_before_provider_call(agent):
    """_persist_session sees the new user turn before the provider is called."""
    inbound_text = "new message that must survive a crash"
    # Ordered log of (kind, messages snapshot, history snapshot) tuples.
    timeline = []

    def _capture_persist(messages, conversation_history):
        # Snapshot both lists so later mutation cannot alter what was seen.
        messages_snapshot = list(messages)
        history_snapshot = list(conversation_history or [])
        timeline.append(("persist", messages_snapshot, history_snapshot))

    def _explode_in_provider(*_args, **_kwargs):
        # Record the provider call, then fail as a crashing provider would.
        timeline.append(("provider", [], []))
        raise RuntimeError("provider died after turn-start persistence")

    agent.client.chat.completions.create.side_effect = _explode_in_provider

    with (
        patch.object(agent, "_persist_session", side_effect=_capture_persist),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation(
            inbound_text,
            conversation_history=[{"role": "user", "content": "old message"}],
        )

    assert result.get("failed") is True

    # Persistence must be the first recorded event, the provider the second.
    assert timeline[0][0] == "persist"
    assert timeline[1][0] == "provider"

    # The last message handed to the first persist call is the inbound turn.
    first_persisted_messages = timeline[0][1]
    assert first_persisted_messages[-1] == {
        "role": "user",
        "content": inbound_text,
    }
|
||||
|
||||
|
||||
class TestHTTP413Compression:
|
||||
"""413 errors should trigger compression, not abort as generic 4xx."""
|
||||
|
||||
|
|
@ -217,7 +251,7 @@ class TestHTTP413Compression:
|
|||
patch.object(agent, "_compress_context") as mock_compress,
|
||||
patch.object(
|
||||
agent, "_persist_session",
|
||||
side_effect=lambda msgs, hist: persist_calls.append(hist),
|
||||
side_effect=lambda msgs, hist: persist_calls.append((list(msgs), hist)),
|
||||
),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
|
|
@ -228,12 +262,10 @@ class TestHTTP413Compression:
|
|||
)
|
||||
agent.run_conversation("hello", conversation_history=big_history)
|
||||
|
||||
assert len(persist_calls) >= 1, "Expected at least one _persist_session call"
|
||||
for hist in persist_calls:
|
||||
assert hist is None, (
|
||||
f"conversation_history should be None after mid-loop compression, "
|
||||
f"got list with {len(hist)} items"
|
||||
)
|
||||
assert any(hist is None for _msgs, hist in persist_calls), (
|
||||
"Expected at least one post-compression _persist_session call "
|
||||
"with conversation_history=None"
|
||||
)
|
||||
|
||||
def test_context_overflow_clears_conversation_history_on_persist(self, agent):
|
||||
"""After context-overflow compression, _persist_session must receive None history."""
|
||||
|
|
@ -256,7 +288,7 @@ class TestHTTP413Compression:
|
|||
patch.object(agent, "_compress_context") as mock_compress,
|
||||
patch.object(
|
||||
agent, "_persist_session",
|
||||
side_effect=lambda msgs, hist: persist_calls.append(hist),
|
||||
side_effect=lambda msgs, hist: persist_calls.append((list(msgs), hist)),
|
||||
),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
|
|
@ -267,12 +299,7 @@ class TestHTTP413Compression:
|
|||
)
|
||||
agent.run_conversation("hello", conversation_history=big_history)
|
||||
|
||||
assert len(persist_calls) >= 1
|
||||
for hist in persist_calls:
|
||||
assert hist is None, (
|
||||
f"conversation_history should be None after context-overflow compression, "
|
||||
f"got list with {len(hist)} items"
|
||||
)
|
||||
assert any(hist is None for _msgs, hist in persist_calls)
|
||||
|
||||
def test_400_context_length_triggers_compression(self, agent):
|
||||
"""A 400 with 'maximum context length' should trigger compression, not abort as generic 4xx.
|
||||
|
|
|
|||
|
|
@ -46,28 +46,30 @@ class TestFlushDeduplication:
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test.db"
|
||||
db = SessionDB(db_path=db_path)
|
||||
try:
|
||||
agent = self._make_agent(db)
|
||||
|
||||
agent = self._make_agent(db)
|
||||
conversation_history = [
|
||||
{"role": "user", "content": "old message"},
|
||||
]
|
||||
messages = list(conversation_history) + [
|
||||
{"role": "user", "content": "new question"},
|
||||
{"role": "assistant", "content": "new answer"},
|
||||
]
|
||||
|
||||
conversation_history = [
|
||||
{"role": "user", "content": "old message"},
|
||||
]
|
||||
messages = list(conversation_history) + [
|
||||
{"role": "user", "content": "new question"},
|
||||
{"role": "assistant", "content": "new answer"},
|
||||
]
|
||||
# First flush — should write 2 new messages
|
||||
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||
|
||||
# First flush — should write 2 new messages
|
||||
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 2, f"Expected 2 messages, got {len(rows)}"
|
||||
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 2, f"Expected 2 messages, got {len(rows)}"
|
||||
# Second flush with SAME messages — should write 0 new messages
|
||||
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||
|
||||
# Second flush with SAME messages — should write 0 new messages
|
||||
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 2, f"Expected still 2 messages after second flush, got {len(rows)}"
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 2, f"Expected still 2 messages after second flush, got {len(rows)}"
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def test_flush_writes_incrementally(self):
|
||||
"""Messages added between flushes are written exactly once."""
|
||||
|
|
@ -76,27 +78,29 @@ class TestFlushDeduplication:
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test.db"
|
||||
db = SessionDB(db_path=db_path)
|
||||
try:
|
||||
agent = self._make_agent(db)
|
||||
|
||||
agent = self._make_agent(db)
|
||||
conversation_history = []
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
|
||||
conversation_history = []
|
||||
messages = [
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
# First flush — 1 message
|
||||
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 1
|
||||
|
||||
# First flush — 1 message
|
||||
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 1
|
||||
# Add more messages
|
||||
messages.append({"role": "assistant", "content": "hi there"})
|
||||
messages.append({"role": "user", "content": "follow up"})
|
||||
|
||||
# Add more messages
|
||||
messages.append({"role": "assistant", "content": "hi there"})
|
||||
messages.append({"role": "user", "content": "follow up"})
|
||||
|
||||
# Second flush — should write only 2 new messages
|
||||
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 3, f"Expected 3 total messages, got {len(rows)}"
|
||||
# Second flush — should write only 2 new messages
|
||||
agent._flush_messages_to_session_db(messages, conversation_history)
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 3, f"Expected 3 total messages, got {len(rows)}"
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def test_persist_session_multiple_calls_no_duplication(self):
|
||||
"""Multiple _persist_session calls don't duplicate DB entries."""
|
||||
|
|
@ -105,23 +109,25 @@ class TestFlushDeduplication:
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test.db"
|
||||
db = SessionDB(db_path=db_path)
|
||||
try:
|
||||
agent = self._make_agent(db)
|
||||
|
||||
agent = self._make_agent(db)
|
||||
conversation_history = [{"role": "user", "content": "old"}]
|
||||
messages = list(conversation_history) + [
|
||||
{"role": "user", "content": "q1"},
|
||||
{"role": "assistant", "content": "a1"},
|
||||
{"role": "user", "content": "q2"},
|
||||
{"role": "assistant", "content": "a2"},
|
||||
]
|
||||
|
||||
conversation_history = [{"role": "user", "content": "old"}]
|
||||
messages = list(conversation_history) + [
|
||||
{"role": "user", "content": "q1"},
|
||||
{"role": "assistant", "content": "a1"},
|
||||
{"role": "user", "content": "q2"},
|
||||
{"role": "assistant", "content": "a2"},
|
||||
]
|
||||
# Simulate multiple persist calls (like the agent's many exit paths)
|
||||
for _ in range(5):
|
||||
agent._persist_session(messages, conversation_history)
|
||||
|
||||
# Simulate multiple persist calls (like the agent's many exit paths)
|
||||
for _ in range(5):
|
||||
agent._persist_session(messages, conversation_history)
|
||||
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 4, f"Expected 4 messages, got {len(rows)} (duplication bug!)"
|
||||
rows = db.get_messages(agent.session_id)
|
||||
assert len(rows) == 4, f"Expected 4 messages, got {len(rows)} (duplication bug!)"
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def test_flush_reset_after_compression(self):
|
||||
"""After compression creates a new session, flush index resets."""
|
||||
|
|
@ -130,36 +136,38 @@ class TestFlushDeduplication:
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test.db"
|
||||
db = SessionDB(db_path=db_path)
|
||||
try:
|
||||
agent = self._make_agent(db)
|
||||
|
||||
agent = self._make_agent(db)
|
||||
# Write some messages
|
||||
messages = [
|
||||
{"role": "user", "content": "msg1"},
|
||||
{"role": "assistant", "content": "reply1"},
|
||||
]
|
||||
agent._flush_messages_to_session_db(messages, [])
|
||||
|
||||
# Write some messages
|
||||
messages = [
|
||||
{"role": "user", "content": "msg1"},
|
||||
{"role": "assistant", "content": "reply1"},
|
||||
]
|
||||
agent._flush_messages_to_session_db(messages, [])
|
||||
old_session = agent.session_id
|
||||
assert agent._last_flushed_db_idx == 2
|
||||
|
||||
old_session = agent.session_id
|
||||
assert agent._last_flushed_db_idx == 2
|
||||
# Simulate what _compress_context does: new session, reset idx
|
||||
agent.session_id = "compressed-session-new"
|
||||
db.create_session(session_id=agent.session_id, source="test")
|
||||
agent._last_flushed_db_idx = 0
|
||||
|
||||
# Simulate what _compress_context does: new session, reset idx
|
||||
agent.session_id = "compressed-session-new"
|
||||
db.create_session(session_id=agent.session_id, source="test")
|
||||
agent._last_flushed_db_idx = 0
|
||||
# Now flush compressed messages to new session
|
||||
compressed_messages = [
|
||||
{"role": "user", "content": "summary of conversation"},
|
||||
]
|
||||
agent._flush_messages_to_session_db(compressed_messages, [])
|
||||
|
||||
# Now flush compressed messages to new session
|
||||
compressed_messages = [
|
||||
{"role": "user", "content": "summary of conversation"},
|
||||
]
|
||||
agent._flush_messages_to_session_db(compressed_messages, [])
|
||||
new_rows = db.get_messages(agent.session_id)
|
||||
assert len(new_rows) == 1
|
||||
|
||||
new_rows = db.get_messages(agent.session_id)
|
||||
assert len(new_rows) == 1
|
||||
|
||||
# Old session should still have its 2 messages
|
||||
old_rows = db.get_messages(old_session)
|
||||
assert len(old_rows) == 2
|
||||
# Old session should still have its 2 messages
|
||||
old_rows = db.get_messages(old_session)
|
||||
assert len(old_rows) == 2
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue