hermes-agent/tests/tui_gateway/test_protocol.py
Teknium dca439fe92
fix(tui): scope session.interrupt pending-prompt release to the calling session (#12441)
session.interrupt on session A was blast-resolving pending
clarify/sudo/secret prompts on ALL sessions sharing the same
tui_gateway process.  Other sessions' agent threads unblocked with
empty-string answers as if the user had cancelled — silent
cross-session corruption.

Root cause: _pending and _answers were globals keyed by random rid
with no record of the owning session.  _clear_pending() iterated
every entry, so the session.interrupt handler had no way to limit
the release to its own sid.

Fix:
- tui_gateway/server.py: _pending now maps rid to (sid, Event)
  tuples.  _clear_pending takes an optional sid argument and filters
  by owner_sid when provided.  session.interrupt passes the calling
  sid so unrelated sessions are untouched.  _clear_pending(None)
  remains the shutdown path for completeness.
- _block and _respond updated to pack/unpack the new tuple format.

Tests (tests/test_tui_gateway_server.py): 4 new cases.
- test_interrupt_only_clears_own_session_pending: two sessions with
  pending prompts, interrupting one must not release the other.
- test_interrupt_clears_multiple_own_pending: same-sid multi-prompt
  release works.
- test_clear_pending_without_sid_clears_all: shutdown path preserved.
- test_respond_unpacks_sid_tuple_correctly: _respond handles the
  tuple format.

Also updated tests/tui_gateway/test_protocol.py to use the new tuple
format for test_block_and_respond and test_clear_pending.

Live E2E against the live Python environment confirmed cross-session
isolation: interrupting sid_a released its own pending prompt without
touching sid_b's.  All 78 related tests pass.
2026-04-19 00:03:58 -07:00

434 lines
14 KiB
Python

"""Tests for tui_gateway JSON-RPC protocol plumbing."""
import io
import json
import sys
import threading
from unittest.mock import MagicMock, patch
import pytest
# Snapshot of ``sys.stdout`` taken when this test module is imported; the
# autouse fixture below reinstates it after every test.
_original_stdout = sys.stdout


@pytest.fixture(autouse=True)
def _restore_stdout():
    """Reinstate the import-time ``sys.stdout`` once each test has finished.

    Runs automatically for every test in this module.
    NOTE(review): presumably guards against code under test rebinding
    ``sys.stdout`` — confirm against ``tui_gateway.server``.
    """
    yield
    sys.stdout = _original_stdout
@pytest.fixture()
def server():
    """Yield the ``tui_gateway.server`` module imported against stub modules.

    Four modules are replaced in ``sys.modules`` with ``MagicMock`` stand-ins
    for the duration of the test.  After the test, the module-level
    ``_sessions``, ``_pending``, ``_answers`` and ``_methods`` containers are
    emptied and the module is reloaded.
    """
    import importlib

    stub_modules = {
        "hermes_constants": MagicMock(
            get_hermes_home=MagicMock(return_value="/tmp/hermes_test"),
        ),
        "hermes_cli.env_loader": MagicMock(),
        "hermes_cli.banner": MagicMock(),
        "hermes_state": MagicMock(),
    }
    with patch.dict("sys.modules", stub_modules):
        gateway = importlib.import_module("tui_gateway.server")
        yield gateway

        # Teardown stays inside the patch.dict block: the reload has to run
        # while the stub modules are still installed in sys.modules.
        for registry in (
            gateway._sessions,
            gateway._pending,
            gateway._answers,
            gateway._methods,
        ):
            registry.clear()
        # NOTE(review): presumably the reload rebuilds module-level state
        # (e.g. the method registry just cleared) for the next test — confirm.
        importlib.reload(gateway)
@pytest.fixture()
def capture(server):
    """Point ``server._real_stdout`` at an in-memory buffer.

    Returns a ``(server, buffer)`` pair so tests can inspect what was written.
    """
    sink = io.StringIO()
    server._real_stdout = sink
    return (server, sink)
# ── JSON-RPC envelope ────────────────────────────────────────────────
def test_unknown_method(server):
    """A request for an unregistered method is answered with error -32601."""
    request = {"id": "1", "method": "bogus"}
    reply = server.handle_request(request)
    assert reply["error"]["code"] == -32601
def test_ok_envelope(server):
    """``_ok`` wraps a result in a JSON-RPC 2.0 success envelope."""
    expected = {"jsonrpc": "2.0", "id": "r1", "result": {"x": 1}}
    envelope = server._ok("r1", {"x": 1})
    assert envelope == expected
def test_err_envelope(server):
    """``_err`` wraps a code and message in a JSON-RPC 2.0 error envelope."""
    expected = {
        "jsonrpc": "2.0",
        "id": "r2",
        "error": {"code": 4001, "message": "nope"},
    }
    envelope = server._err("r2", 4001, "nope")
    assert envelope == expected
# ── write_json ───────────────────────────────────────────────────────
def test_write_json(capture):
    """``write_json`` returns a truthy value and emits parseable JSON."""
    server, sink = capture
    succeeded = server.write_json({"test": True})
    assert succeeded
    written = json.loads(sink.getvalue())
    assert written == {"test": True}
def test_write_json_broken_pipe(server):
    """``write_json`` returns ``False`` when the output stream raises ``BrokenPipeError``."""

    class _ClosedPipe:
        """Stream stand-in whose every operation fails like a closed pipe."""

        def write(self, _):
            raise BrokenPipeError

        def flush(self):
            raise BrokenPipeError

    server._real_stdout = _ClosedPipe()
    outcome = server.write_json({"x": 1})
    assert outcome is False
# ── _emit ────────────────────────────────────────────────────────────
def test_emit_with_payload(capture):
    """``_emit`` writes an ``event`` message carrying type, session id and payload."""
    server, sink = capture
    server._emit("test.event", "s1", {"key": "val"})

    message = json.loads(sink.getvalue())
    assert message["method"] == "event"

    params = message["params"]
    assert params["type"] == "test.event"
    assert params["session_id"] == "s1"
    assert params["payload"]["key"] == "val"
def test_emit_without_payload(capture):
    """``_emit`` leaves the ``payload`` key out of params when none is given."""
    server, sink = capture
    server._emit("ping", "s2")
    emitted = json.loads(sink.getvalue())
    assert "payload" not in emitted["params"]
# ── Blocking prompt round-trip ───────────────────────────────────────
def test_block_and_respond(capture):
    """``_block`` returns the answer stored for its request id once its Event fires.

    A worker thread calls ``server._block``; the test waits for the pending
    entry to appear, stores an answer under that request id, sets the Event,
    and then joins the worker before checking the returned value.
    """
    server, _ = capture
    result = [None]

    def _run_block():
        result[0] = server._block("test.prompt", "s1", {"q": "?"}, timeout=5)

    # Daemon thread: if _block ever hangs, the worker must not keep the
    # interpreter alive after the test has already failed.
    worker = threading.Thread(target=_run_block, daemon=True)
    worker.start()

    # Poll (bounded to ~1s) until the worker has registered its prompt.
    for _ in range(100):
        if server._pending:
            break
        threading.Event().wait(0.01)
    assert server._pending, "_block never registered a pending prompt"

    rid = next(iter(server._pending))
    server._answers[rid] = "my_answer"
    # _pending values are (sid, Event) tuples — unpack to set the Event
    _, ev = server._pending[rid]
    ev.set()

    # Join rather than sleep a fixed interval: a fixed wait races against the
    # worker on a loaded machine and makes the final assertion flaky.
    worker.join(timeout=5)
    assert not worker.is_alive(), "_block did not return after its Event was set"
    assert result[0] == "my_answer"
def test_clear_pending(server):
    """``_clear_pending()`` with no sid sets the Event and records an empty answer."""
    waiter = threading.Event()
    # Entries in _pending are (sid, Event) pairs.
    server._pending["r1"] = ("sid-x", waiter)

    server._clear_pending()

    assert waiter.is_set()
    assert server._answers["r1"] == ""
# ── Session lookup ───────────────────────────────────────────────────
def test_sess_missing(server):
    """``_sess`` reports error code 4001 for an unknown session id."""
    lookup = server._sess({"session_id": "nope"}, "r1")
    _, failure = lookup
    assert failure["error"]["code"] == 4001
def test_sess_found(server):
    """``_sess`` returns a session and no error for a registered session id."""
    server._sessions["abc"] = {"agent": MagicMock()}
    session, failure = server._sess({"session_id": "abc"}, "r1")
    assert session is not None
    assert failure is None
# ── session.resume payload ────────────────────────────────────────────
def test_session_resume_returns_hydrated_messages(server, monkeypatch):
    """``session.resume`` returns the stored conversation in display form.

    The stub database holds six messages.  The response is expected to list
    only three: the user message, the non-blank assistant message and the
    tool message (rendered with ``name``/``context`` instead of ``text``).
    The whitespace-only, ``None``-content and unknown-role entries must be
    absent.
    """

    class _StubDB:
        """Minimal stand-in for the session database used by the handler."""

        def get_session(self, _sid):
            return {"id": "20260409_010101_abc123"}

        def get_session_by_title(self, _title):
            return None

        def reopen_session(self, _sid):
            return None

        def get_messages_as_conversation(self, _sid):
            return [
                {"role": "user", "content": "hello"},
                {"role": "assistant", "content": "yo"},
                {"role": "tool", "content": "searched"},
                {"role": "assistant", "content": "   "},
                {"role": "assistant", "content": None},
                {"role": "narrator", "content": "skip"},
            ]

    # Replace the collaborators so no real DB or agent is touched.
    monkeypatch.setattr(server, "_get_db", lambda: _StubDB())
    monkeypatch.setattr(
        server, "_make_agent", lambda sid, key, session_id=None: object()
    )
    monkeypatch.setattr(
        server, "_init_session", lambda sid, key, agent, history, cols=80: None
    )
    monkeypatch.setattr(
        server, "_session_info", lambda _agent: {"model": "test/model"}
    )

    request = {
        "id": "r1",
        "method": "session.resume",
        "params": {"session_id": "20260409_010101_abc123", "cols": 100},
    }
    resp = server.handle_request(request)

    assert "error" not in resp
    outcome = resp["result"]
    assert outcome["message_count"] == 3
    assert outcome["messages"] == [
        {"role": "user", "text": "hello"},
        {"role": "assistant", "text": "yo"},
        {"role": "tool", "name": "tool", "context": ""},
    ]
# ── Config I/O ───────────────────────────────────────────────────────
def test_config_load_missing(server, tmp_path):
    """``_load_cfg`` returns an empty dict when the home directory is a fresh temp dir."""
    server._hermes_home = tmp_path
    loaded = server._load_cfg()
    assert loaded == {}
def test_config_roundtrip(server, tmp_path):
    """A value written by ``_save_cfg`` is read back by ``_load_cfg``."""
    server._hermes_home = tmp_path
    server._save_cfg({"model": "test/model"})
    reloaded = server._load_cfg()
    assert reloaded["model"] == "test/model"
# ── _cli_exec_blocked ────────────────────────────────────────────────
@pytest.mark.parametrize(
    "argv",
    [
        [],
        ["setup"],
        ["gateway"],
        ["sessions", "browse"],
        ["config", "edit"],
    ],
)
def test_cli_exec_blocked(server, argv):
    """``_cli_exec_blocked`` returns a non-``None`` value for each of these argvs."""
    verdict = server._cli_exec_blocked(argv)
    assert verdict is not None
@pytest.mark.parametrize(
    "argv",
    [
        ["version"],
        ["sessions", "list"],
    ],
)
def test_cli_exec_allowed(server, argv):
    """``_cli_exec_blocked`` returns ``None`` for each of these argvs."""
    verdict = server._cli_exec_blocked(argv)
    assert verdict is None
# ── slash.exec skill command interception ────────────────────────────
def test_slash_exec_rejects_skill_commands(server):
    """slash.exec refuses skill commands, leaving them for command.dispatch."""
    # Minimal session entry for the handler to look up.
    sid = "test-session"
    server._sessions[sid] = {"session_key": sid, "agent": None}

    # Patch get_skill_commands so exactly one known skill is reported.
    known_skills = {
        "/hermes-agent-dev": {
            "name": "hermes-agent-dev",
            "description": "Dev workflow",
        },
    }
    request = {
        "id": "r1",
        "method": "slash.exec",
        "params": {"command": "hermes-agent-dev", "session_id": sid},
    }
    with patch("agent.skill_commands.get_skill_commands", return_value=known_skills):
        resp = server.handle_request(request)

    # The error reply is what lets the TUI's .catch() fall back to command.dispatch.
    assert "error" in resp
    failure = resp["error"]
    assert failure["code"] == 4018
    assert "skill command" in failure["message"]
@pytest.mark.parametrize(
    "cmd",
    ["retry", "queue hello", "q hello", "steer fix the test", "plan"],
)
def test_slash_exec_rejects_pending_input_commands(server, cmd):
    """slash.exec refuses commands that rely on _pending_input in the CLI."""
    sid = "test-session"
    server._sessions[sid] = {"session_key": sid, "agent": None}

    request = {
        "id": "r1",
        "method": "slash.exec",
        "params": {"command": cmd, "session_id": sid},
    }
    resp = server.handle_request(request)

    assert "error" in resp
    failure = resp["error"]
    assert failure["code"] == 4018
    assert "pending-input command" in failure["message"]
def test_command_dispatch_queue_sends_message(server):
    """command.dispatch /queue answers with a ``send`` result carrying the argument."""
    sid = "test-session"
    server._sessions[sid] = {"session_key": sid}

    request = {
        "id": "r1",
        "method": "command.dispatch",
        "params": {
            "name": "queue",
            "arg": "tell me about quantum computing",
            "session_id": sid,
        },
    }
    resp = server.handle_request(request)

    assert "error" not in resp
    outcome = resp["result"]
    assert outcome["type"] == "send"
    assert outcome["message"] == "tell me about quantum computing"
def test_command_dispatch_queue_requires_arg(server):
    """command.dispatch /queue with an empty argument fails with code 4004."""
    sid = "test-session"
    server._sessions[sid] = {"session_key": sid}

    request = {
        "id": "r2",
        "method": "command.dispatch",
        "params": {"name": "queue", "arg": "", "session_id": sid},
    }
    resp = server.handle_request(request)

    assert "error" in resp
    assert resp["error"]["code"] == 4004
def test_command_dispatch_steer_fallback_sends_message(server):
    """command.dispatch /steer on a session whose agent is ``None`` yields a ``send``."""
    sid = "test-session"
    server._sessions[sid] = {"session_key": sid, "agent": None}

    request = {
        "id": "r3",
        "method": "command.dispatch",
        "params": {"name": "steer", "arg": "focus on testing", "session_id": sid},
    }
    resp = server.handle_request(request)

    assert "error" not in resp
    outcome = resp["result"]
    assert outcome["type"] == "send"
    assert outcome["message"] == "focus on testing"
def test_command_dispatch_retry_finds_last_user_message(server):
    """command.dispatch /retry resends the most recent user message.

    The session history is also expected to shrink to the two entries before
    that message, and ``history_version`` to go from 0 to 1.
    """
    sid = "test-session"
    server._sessions[sid] = {
        "session_key": sid,
        "agent": None,
        "history": [
            {"role": "user", "content": "first question"},
            {"role": "assistant", "content": "first answer"},
            {"role": "user", "content": "second question"},
            {"role": "assistant", "content": "second answer"},
        ],
        "history_lock": threading.Lock(),
        "history_version": 0,
    }

    request = {
        "id": "r4",
        "method": "command.dispatch",
        "params": {"name": "retry", "session_id": sid},
    }
    resp = server.handle_request(request)

    assert "error" not in resp
    outcome = resp["result"]
    assert outcome["type"] == "send"
    assert outcome["message"] == "second question"

    # Everything from the last user message onward must be gone from history.
    session_after = server._sessions[sid]
    assert len(session_after["history"]) == 2
    assert session_after["history"][-1]["role"] == "assistant"
    assert session_after["history_version"] == 1
def test_command_dispatch_retry_empty_history(server):
    """command.dispatch /retry on an empty history fails with code 4018."""
    sid = "test-session"
    session = {
        "session_key": sid,
        "agent": None,
        "history": [],
        "history_lock": threading.Lock(),
        "history_version": 0,
    }
    server._sessions[sid] = session

    request = {
        "id": "r5",
        "method": "command.dispatch",
        "params": {"name": "retry", "session_id": sid},
    }
    resp = server.handle_request(request)

    assert "error" in resp
    assert resp["error"]["code"] == 4018
def test_command_dispatch_retry_handles_multipart_content(server):
    """command.dispatch /retry resends the text part of a multipart user message."""
    sid = "test-session"
    multipart_user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": "analyze this"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    }
    server._sessions[sid] = {
        "session_key": sid,
        "agent": None,
        "history": [
            multipart_user_turn,
            {"role": "assistant", "content": "I see the image."},
        ],
        "history_lock": threading.Lock(),
        "history_version": 0,
    }

    request = {
        "id": "r6",
        "method": "command.dispatch",
        "params": {"name": "retry", "session_id": sid},
    }
    resp = server.handle_request(request)

    assert "error" not in resp
    outcome = resp["result"]
    assert outcome["type"] == "send"
    assert outcome["message"] == "analyze this"
def test_command_dispatch_returns_skill_payload(server):
    """command.dispatch answers a skill name with a structured ``skill`` result."""
    sid = "test-session"
    server._sessions[sid] = {"session_key": sid}

    known_skills = {
        "/hermes-agent-dev": {
            "name": "hermes-agent-dev",
            "description": "Dev workflow",
        },
    }
    invocation_text = "Loaded skill content here"
    request = {
        "id": "r2",
        "method": "command.dispatch",
        "params": {"name": "hermes-agent-dev", "session_id": sid},
    }

    # Both the skill scan and the message builder are patched for the call.
    with patch("agent.skill_commands.scan_skill_commands", return_value=known_skills):
        with patch(
            "agent.skill_commands.build_skill_invocation_message",
            return_value=invocation_text,
        ):
            resp = server.handle_request(request)

    assert "error" not in resp
    outcome = resp["result"]
    assert outcome["type"] == "skill"
    assert outcome["message"] == invocation_text
    assert outcome["name"] == "hermes-agent-dev"