fix(tui): degrade gracefully when state.db init fails

This commit is contained in:
helix4u 2026-04-22 13:49:33 -06:00
parent de849c410d
commit 5dead0f2a0
2 changed files with 109 additions and 7 deletions

View file

@ -1183,6 +1183,75 @@ def test_session_create_no_race_keeps_worker_alive(monkeypatch):
server._sessions.pop(sid, None)
def test_get_db_degrades_cleanly_when_sessiondb_init_fails(monkeypatch):
    """A failing ``SessionDB`` constructor yields ``None`` and records the reason."""

    class _ExplodingSessionDB:
        def __init__(self):
            raise RuntimeError("locking protocol")

    # Swap in a stand-in ``hermes_state`` module whose SessionDB always raises.
    stub_state = types.ModuleType("hermes_state")
    stub_state.SessionDB = _ExplodingSessionDB
    monkeypatch.setitem(sys.modules, "hermes_state", stub_state)

    # Reset both module-level slots before exercising the accessor.
    for attr in ("_db", "_db_error"):
        monkeypatch.setattr(server, attr, None)

    handle = server._get_db()

    assert handle is None
    assert server._db_error == "locking protocol"
def test_session_create_continues_when_state_db_is_unavailable(monkeypatch):
    """``session.create`` still produces an agent when ``_get_db()`` returns ``None``.

    Every collaborator of the handler is stubbed on ``server`` so that the only
    thing under test is the "no state.db" path: the session must end up with an
    agent, no recorded agent error, and no ``"error"`` emit.
    """

    class _FakeWorker:
        def __init__(self, key, model):
            self.key = key

        def close(self):
            return None

    class _FakeAgent:
        def __init__(self):
            self.model = "x"
            self.provider = "openrouter"
            self.base_url = ""
            self.api_key = ""

    # Collects the positional args of every server._emit call.
    emits = []

    monkeypatch.setattr(server, "_make_agent", lambda sid, key: _FakeAgent())
    monkeypatch.setattr(server, "_SlashWorker", _FakeWorker)
    # Simulate the degraded state: the session store could not be opened.
    monkeypatch.setattr(server, "_get_db", lambda: None)
    monkeypatch.setattr(server, "_session_info", lambda _a: {"model": "x"})
    monkeypatch.setattr(server, "_probe_credentials", lambda _a: None)
    monkeypatch.setattr(server, "_wire_callbacks", lambda _sid: None)
    monkeypatch.setattr(server, "_emit", lambda *a, **kw: emits.append(a))

    import tools.approval as _approval

    monkeypatch.setattr(_approval, "register_gateway_notify", lambda key, cb: None)
    monkeypatch.setattr(_approval, "load_permanent_allowlist", lambda: None)

    resp = server.handle_request(
        {"id": "1", "method": "session.create", "params": {"cols": 80}}
    )
    sid = resp["result"]["session_id"]

    try:
        session = server._sessions[sid]
        # Give the agent setup up to two seconds to signal readiness.
        session["agent_ready"].wait(timeout=2.0)

        assert session["agent_error"] is None
        assert session["agent"] is not None
        assert not any(args and args[0] == "error" for args in emits)
    finally:
        # Always drop the session, even when an assertion fails, so it cannot
        # leak into server._sessions for the tests that run afterwards.
        server._sessions.pop(sid, None)
def test_session_list_returns_clean_error_when_state_db_is_unavailable(monkeypatch):
    """``session.list`` answers with an error payload naming the stored DB failure."""
    monkeypatch.setattr(server, "_db_error", "locking protocol")
    monkeypatch.setattr(server, "_get_db", lambda: None)

    request = {"id": "1", "method": "session.list", "params": {}}
    response = server.handle_request(request)

    assert "error" in response
    message = response["error"]["message"]
    assert "state.db unavailable: locking protocol" in message
# --------------------------------------------------------------------------
# model.options — curated-list parity with `hermes model` and classic /model
# --------------------------------------------------------------------------

View file

@ -2,6 +2,7 @@ import atexit
import concurrent.futures
import copy
import json
import logging
import os
import queue
import subprocess
@ -15,6 +16,8 @@ from pathlib import Path
from hermes_constants import get_hermes_home
from hermes_cli.env_loader import load_hermes_dotenv
logger = logging.getLogger(__name__)
_hermes_home = get_hermes_home()
load_hermes_dotenv(
hermes_home=_hermes_home, project_env=Path(__file__).parent.parent / ".env"
@ -34,6 +37,7 @@ _methods: dict[str, callable] = {}
_pending: dict[str, tuple[str, threading.Event]] = {}
_answers: dict[str, str] = {}
_db = None
_db_error: str | None = None
_stdout_lock = threading.Lock()
_cfg_lock = threading.Lock()
_cfg_cache: dict | None = None
@ -170,14 +174,28 @@ atexit.register(
def _get_db():
    """Return the shared ``SessionDB`` instance, or ``None`` if it cannot be created.

    The instance is built on first use and cached in the module-level ``_db``.
    If the constructor raises, the failure is recorded in ``_db_error``, a
    warning is logged, and ``None`` is returned so callers can continue without
    state.db features. ``_db`` stays ``None`` after a failure, so the next call
    attempts construction again.

    Returns:
        The cached ``SessionDB`` object, or ``None`` when construction failed.
    """
    global _db, _db_error

    if _db is not None:
        return _db

    from hermes_state import SessionDB

    # Construct exactly once, inside the guard: an unguarded SessionDB() call
    # here would propagate the exception and defeat the graceful fallback.
    try:
        _db = SessionDB()
    except Exception as exc:
        # str(exc) is empty for message-less exceptions; fall back to the class
        # name so the recorded detail is never blank.
        _db_error = str(exc) or type(exc).__name__
        logger.warning(
            "TUI session store unavailable — continuing without state.db features: %s",
            exc,
        )
        return None

    _db_error = None
    return _db
def _db_unavailable_error(rid, *, code: int):
    """Build the error response used when the session store is unavailable.

    Args:
        rid: Passed through unchanged as the first argument to ``_err``.
        code: Error code passed to ``_err``; keyword-only so call sites stay
            explicit about which code they report.

    Returns:
        Whatever ``_err`` returns for the composed message.
    """
    message = "state.db unavailable"
    # Append the recorded failure only when there is one; otherwise the message
    # would read "state.db unavailable: state.db unavailable".
    if _db_error:
        message = f"{message}: {_db_error}"
    return _err(rid, code, message)
def write_json(obj: dict) -> bool:
line = json.dumps(obj, ensure_ascii=False) + "\n"
try:
@ -1318,7 +1336,9 @@ def _(rid, params: dict) -> dict:
finally:
_clear_session_context(tokens)
_get_db().create_session(key, source="tui", model=_resolve_model())
db = _get_db()
if db is not None:
db.create_session(key, source="tui", model=_resolve_model())
session["agent"] = agent
try:
@ -1390,6 +1410,9 @@ def _(rid, params: dict) -> dict:
@method("session.list")
def _(rid, params: dict) -> dict:
db = _get_db()
if db is None:
return _db_unavailable_error(rid, code=5006)
try:
# Resume picker should include human conversation surfaces beyond
# tui/cli (notably telegram from blitz row #7), but avoid internal
@ -1416,7 +1439,7 @@ def _(rid, params: dict) -> dict:
fetch_limit = max(limit * 5, 100)
rows = [
s
for s in _get_db().list_sessions_rich(source=None, limit=fetch_limit)
for s in db.list_sessions_rich(source=None, limit=fetch_limit)
if (s.get("source") or "").strip().lower() in allow
][:limit]
return _ok(
@ -1445,6 +1468,8 @@ def _(rid, params: dict) -> dict:
if not target:
return _err(rid, 4006, "session_id required")
db = _get_db()
if db is None:
return _db_unavailable_error(rid, code=5000)
found = db.get_session(target)
if not found:
found = db.get_session_by_title(target)
@ -1483,13 +1508,16 @@ def _(rid, params: dict) -> dict:
session, err = _sess(params, rid)
if err:
return err
db = _get_db()
if db is None:
return _db_unavailable_error(rid, code=5007)
title, key = params.get("title", ""), session["session_key"]
if not title:
return _ok(
rid, {"title": _get_db().get_session_title(key) or "", "session_key": key}
rid, {"title": db.get_session_title(key) or "", "session_key": key}
)
try:
_get_db().set_session_title(key, title)
db.set_session_title(key, title)
return _ok(rid, {"title": title})
except Exception as e:
return _err(rid, 5007, str(e))
@ -1625,6 +1653,8 @@ def _(rid, params: dict) -> dict:
if err:
return err
db = _get_db()
if db is None:
return _db_unavailable_error(rid, code=5008)
old_key = session["session_key"]
with session["history_lock"]:
history = [dict(msg) for msg in session.get("history", [])]
@ -3453,11 +3483,14 @@ def _(rid, params: dict) -> dict:
@method("insights.get")
def _(rid, params: dict) -> dict:
days = params.get("days", 30)
db = _get_db()
if db is None:
return _db_unavailable_error(rid, code=5017)
try:
cutoff = time.time() - days * 86400
rows = [
s
for s in _get_db().list_sessions_rich(limit=500)
for s in db.list_sessions_rich(limit=500)
if (s.get("started_at") or 0) >= cutoff
]
return _ok(