diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index ab7b52df0..2c50065b2 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -1446,6 +1446,75 @@ def test_session_create_no_race_keeps_worker_alive(monkeypatch): server._sessions.pop(sid, None) +def test_get_db_degrades_cleanly_when_sessiondb_init_fails(monkeypatch): + fake_mod = types.ModuleType("hermes_state") + + class _BrokenSessionDB: + def __init__(self): + raise RuntimeError("locking protocol") + + fake_mod.SessionDB = _BrokenSessionDB + monkeypatch.setitem(sys.modules, "hermes_state", fake_mod) + monkeypatch.setattr(server, "_db", None) + monkeypatch.setattr(server, "_db_error", None) + + assert server._get_db() is None + assert server._db_error == "locking protocol" + + +def test_session_create_continues_when_state_db_is_unavailable(monkeypatch): + class _FakeWorker: + def __init__(self, key, model): + self.key = key + + def close(self): + return None + + class _FakeAgent: + def __init__(self): + self.model = "x" + self.provider = "openrouter" + self.base_url = "" + self.api_key = "" + + emits = [] + + monkeypatch.setattr(server, "_make_agent", lambda sid, key: _FakeAgent()) + monkeypatch.setattr(server, "_SlashWorker", _FakeWorker) + monkeypatch.setattr(server, "_get_db", lambda: None) + monkeypatch.setattr(server, "_session_info", lambda _a: {"model": "x"}) + monkeypatch.setattr(server, "_probe_credentials", lambda _a: None) + monkeypatch.setattr(server, "_wire_callbacks", lambda _sid: None) + monkeypatch.setattr(server, "_emit", lambda *a, **kw: emits.append(a)) + + import tools.approval as _approval + monkeypatch.setattr(_approval, "register_gateway_notify", lambda key, cb: None) + monkeypatch.setattr(_approval, "load_permanent_allowlist", lambda: None) + + resp = server.handle_request( + {"id": "1", "method": "session.create", "params": {"cols": 80}} + ) + sid = resp["result"]["session_id"] + session = server._sessions[sid] + session["agent_ready"].wait(timeout=2.0) + + assert session["agent_error"] is None + assert session["agent"] is not None + assert not any(args and args[0] == "error" for args in emits) + + server._sessions.pop(sid, None) + + +def test_session_list_returns_clean_error_when_state_db_is_unavailable(monkeypatch): + monkeypatch.setattr(server, "_get_db", lambda: None) + monkeypatch.setattr(server, "_db_error", "locking protocol") + + resp = server.handle_request({"id": "1", "method": "session.list", "params": {}}) + + assert "error" in resp + assert "state.db unavailable: locking protocol" in resp["error"]["message"] + + # -------------------------------------------------------------------------- # model.options — curated-list parity with `hermes model` and classic /model # -------------------------------------------------------------------------- diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 3aac77192..982536d24 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -2,6 +2,7 @@ import atexit import concurrent.futures import copy import json +import logging import os import queue import subprocess @@ -15,6 +16,8 @@ from pathlib import Path from hermes_constants import get_hermes_home from hermes_cli.env_loader import load_hermes_dotenv +logger = logging.getLogger(__name__) + _hermes_home = get_hermes_home() load_hermes_dotenv( hermes_home=_hermes_home, project_env=Path(__file__).parent.parent / ".env" @@ -34,6 +37,7 @@ _methods: dict[str, callable] = {} _pending: dict[str, tuple[str, threading.Event]] = {} _answers: dict[str, str] = {} _db = None +_db_error: str | None = None _stdout_lock = threading.Lock() _cfg_lock = threading.Lock() _cfg_cache: dict | None = None @@ -170,14 +174,28 @@ atexit.register( def _get_db(): - global _db + global _db, _db_error if _db is None: from hermes_state import SessionDB - _db = SessionDB() + try: + _db = SessionDB() + _db_error = None + except Exception as exc: + _db_error = str(exc) + logger.warning( + "TUI session store unavailable — continuing without state.db features: %s", + exc, + ) + return None return _db +def _db_unavailable_error(rid, *, code: int): + detail = _db_error or "state.db unavailable" + return _err(rid, code, f"state.db unavailable: {detail}") + + def write_json(obj: dict) -> bool: line = json.dumps(obj, ensure_ascii=False) + "\n" try: @@ -1329,7 +1347,9 @@ def _(rid, params: dict) -> dict: finally: _clear_session_context(tokens) - _get_db().create_session(key, source="tui", model=_resolve_model()) + db = _get_db() + if db is not None: + db.create_session(key, source="tui", model=_resolve_model()) session["agent"] = agent try: @@ -1401,6 +1421,9 @@ def _(rid, params: dict) -> dict: @method("session.list") def _(rid, params: dict) -> dict: + db = _get_db() + if db is None: + return _db_unavailable_error(rid, code=5006) try: # Resume picker should include human conversation surfaces beyond # tui/cli (notably telegram from blitz row #7), but avoid internal @@ -1427,7 +1450,7 @@ def _(rid, params: dict) -> dict: fetch_limit = max(limit * 5, 100) rows = [ s - for s in _get_db().list_sessions_rich(source=None, limit=fetch_limit) + for s in db.list_sessions_rich(source=None, limit=fetch_limit) if (s.get("source") or "").strip().lower() in allow ][:limit] return _ok( @@ -1456,6 +1479,8 @@ def _(rid, params: dict) -> dict: if not target: return _err(rid, 4006, "session_id required") db = _get_db() + if db is None: + return _db_unavailable_error(rid, code=5000) found = db.get_session(target) if not found: found = db.get_session_by_title(target) @@ -1494,13 +1519,16 @@ def _(rid, params: dict) -> dict: session, err = _sess(params, rid) if err: return err + db = _get_db() + if db is None: + return _db_unavailable_error(rid, code=5007) title, key = params.get("title", ""), session["session_key"] if not title: return _ok( - rid, {"title": _get_db().get_session_title(key) or "", "session_key": key} + rid, {"title": db.get_session_title(key) or "", "session_key": key} ) try: - _get_db().set_session_title(key, title) + db.set_session_title(key, title) return _ok(rid, {"title": title}) except Exception as e: return _err(rid, 5007, str(e)) @@ -1636,6 +1664,8 @@ def _(rid, params: dict) -> dict: if err: return err db = _get_db() + if db is None: + return _db_unavailable_error(rid, code=5008) old_key = session["session_key"] with session["history_lock"]: history = [dict(msg) for msg in session.get("history", [])] @@ -3483,11 +3513,14 @@ def _(rid, params: dict) -> dict: @method("insights.get") def _(rid, params: dict) -> dict: days = params.get("days", 30) + db = _get_db() + if db is None: + return _db_unavailable_error(rid, code=5017) try: cutoff = time.time() - days * 86400 rows = [ s - for s in _get_db().list_sessions_rich(limit=500) + for s in db.list_sessions_rich(limit=500) if (s.get("started_at") or 0) >= cutoff ] return _ok(