# Reviewer preamble. These '#' lines sit ahead of the first `diff --git` header and
# are not part of any hunk; everything from `diff --git` onward is unchanged.
#
# Purpose of the patch, as far as the hunks show:
#   * tui_gateway/server.py: the handlers whose error text names /undo, /compress,
#     /retry and full rollback.restore now return error code 4009 ("session busy")
#     when session.get("running") is truthy. rollback.restore is only rejected
#     when no file_path is given.
#   * tui_gateway/server.py (prompt.submit path): if the session's history_version
#     differs from the value captured for the turn, result["messages"] is not
#     written to session["history"]; a line is printed to stderr and the
#     message.complete payload gets a "warning" entry.
#   * tests/test_tui_gateway_server.py: tests for the undo/compress/rollback
#     rejections and for both outcomes of the history_version comparison.
#
# NOTE(review): test_session_compress_rejects_while_running and
#   test_rollback_restore_rejects_full_history_while_running accept `monkeypatch`
#   but never use it.
# NOTE(review): the rollback test docstring says file-scoped rollback is still
#   allowed, but only the full-history rejection is asserted in this patch.
# NOTE(review): no test in this patch exercises the /retry guard.
# NOTE(review): the `running` checks for /undo and /compress are read before
#   history_lock is acquired; presumably the version-mismatch branch is what
#   covers a turn starting in between. Confirm where `running` is set.
# NOTE(review): the new print(..., file=sys.stderr) needs `sys` in scope in
#   server.py; the import is not visible in these hunks. Confirm.
# NOTE(review): line breaks and indentation of the payload appear collapsed in
#   this copy; hunk line counts cannot be checked against it as stored.
diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index 35bc3f449b..8831efb896 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -546,3 +546,169 @@ def test_session_info_includes_mcp_servers(monkeypatch): assert info["mcp_servers"] == fake_status + +# --------------------------------------------------------------------------- +# History-mutating commands must reject while session.running is True. +# Without these guards, prompt.submit's post-run history write either +# clobbers the mutation (version matches) or silently drops the agent's +# output (version mismatch) — both produce UI<->backend state desync. +# --------------------------------------------------------------------------- + + +def test_session_undo_rejects_while_running(): + """Fix for TUI silent-drop #1: /undo must not mutate history + while the agent is mid-turn — would either clobber the undo or + cause prompt.submit to silently drop the agent's response.""" + server._sessions["sid"] = _session(running=True, history=[ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ]) + try: + resp = server.handle_request( + {"id": "1", "method": "session.undo", "params": {"session_id": "sid"}} + ) + assert resp.get("error"), "session.undo should reject while running" + assert resp["error"]["code"] == 4009 + assert "session busy" in resp["error"]["message"] + # History must be unchanged + assert len(server._sessions["sid"]["history"]) == 2 + finally: + server._sessions.pop("sid", None) + + +def test_session_undo_allowed_when_idle(): + """Regression guard: when not running, /undo still works.""" + server._sessions["sid"] = _session(running=False, history=[ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ]) + try: + resp = server.handle_request( + {"id": "1", "method": "session.undo", "params": {"session_id": "sid"}} + ) + assert resp.get("result"), f"got error: 
{resp.get('error')}" + assert resp["result"]["removed"] == 2 + assert server._sessions["sid"]["history"] == [] + finally: + server._sessions.pop("sid", None) + + +def test_session_compress_rejects_while_running(monkeypatch): + server._sessions["sid"] = _session(running=True) + try: + resp = server.handle_request( + {"id": "1", "method": "session.compress", "params": {"session_id": "sid"}} + ) + assert resp.get("error") + assert resp["error"]["code"] == 4009 + finally: + server._sessions.pop("sid", None) + + +def test_rollback_restore_rejects_full_history_while_running(monkeypatch): + """Full-history rollback must reject; file-scoped rollback still allowed.""" + server._sessions["sid"] = _session(running=True) + try: + resp = server.handle_request( + {"id": "1", "method": "rollback.restore", "params": {"session_id": "sid", "hash": "abc"}} + ) + assert resp.get("error"), "full-history rollback should reject while running" + assert resp["error"]["code"] == 4009 + finally: + server._sessions.pop("sid", None) + + +def test_prompt_submit_history_version_mismatch_surfaces_warning(monkeypatch): + """Fix for TUI silent-drop #2: the defensive backstop at prompt.submit + must attach a 'warning' to message.complete when history was + mutated externally during the turn (instead of silently dropping + the agent's output).""" + # Agent bumps history_version itself mid-run to simulate an external + # mutation slipping past the guards. + session_ref = {"s": None} + + class _RacyAgent: + def run_conversation(self, prompt, conversation_history=None, stream_callback=None): + # Simulate: something external bumped history_version + # while we were running. 
+ with session_ref["s"]["history_lock"]: + session_ref["s"]["history_version"] += 1 + return {"final_response": "agent reply", "messages": [{"role": "assistant", "content": "agent reply"}]} + + class _ImmediateThread: + def __init__(self, target=None, daemon=None): + self._target = target + + def start(self): + self._target() + + server._sessions["sid"] = _session(agent=_RacyAgent()) + session_ref["s"] = server._sessions["sid"] + emits: list[tuple] = [] + try: + monkeypatch.setattr(server.threading, "Thread", _ImmediateThread) + monkeypatch.setattr(server, "_get_usage", lambda _a: {}) + monkeypatch.setattr(server, "render_message", lambda _t, _c: "") + monkeypatch.setattr(server, "_emit", lambda *a: emits.append(a)) + + resp = server.handle_request( + {"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "hi"}} + ) + assert resp.get("result"), f"got error: {resp.get('error')}" + + # History should NOT contain the agent's output (version mismatch) + assert server._sessions["sid"]["history"] == [] + + # message.complete must carry a 'warning' so the UI / operator + # knows the output was not persisted. 
+ complete_calls = [a for a in emits if a[0] == "message.complete"] + assert len(complete_calls) == 1 + _, _, payload = complete_calls[0] + assert "warning" in payload, ( + "message.complete must include a 'warning' field on " + "history_version mismatch — otherwise the UI silently " + "shows output that was never persisted" + ) + assert "not saved" in payload["warning"].lower() or "changed" in payload["warning"].lower() + finally: + server._sessions.pop("sid", None) + + +def test_prompt_submit_history_version_match_persists_normally(monkeypatch): + """Regression guard: the backstop does not affect the happy path.""" + class _Agent: + def run_conversation(self, prompt, conversation_history=None, stream_callback=None): + return {"final_response": "reply", "messages": [{"role": "assistant", "content": "reply"}]} + + class _ImmediateThread: + def __init__(self, target=None, daemon=None): + self._target = target + + def start(self): + self._target() + + server._sessions["sid"] = _session(agent=_Agent()) + emits: list[tuple] = [] + try: + monkeypatch.setattr(server.threading, "Thread", _ImmediateThread) + monkeypatch.setattr(server, "_get_usage", lambda _a: {}) + monkeypatch.setattr(server, "render_message", lambda _t, _c: "") + monkeypatch.setattr(server, "_emit", lambda *a: emits.append(a)) + + resp = server.handle_request( + {"id": "1", "method": "prompt.submit", "params": {"session_id": "sid", "text": "hi"}} + ) + assert resp.get("result") + + # History was written + assert server._sessions["sid"]["history"] == [{"role": "assistant", "content": "reply"}] + assert server._sessions["sid"]["history_version"] == 1 + + # No warning should be attached + complete_calls = [a for a in emits if a[0] == "message.complete"] + assert len(complete_calls) == 1 + _, _, payload = complete_calls[0] + assert "warning" not in payload + finally: + server._sessions.pop("sid", None) + diff --git a/tui_gateway/server.py b/tui_gateway/server.py index d86db00066..c58c65763e 100644 --- 
a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -1224,6 +1224,13 @@ def _(rid, params: dict) -> dict: session, err = _sess(params, rid) if err: return err + # Reject during an in-flight turn. If we mutated history while + # the agent thread is running, prompt.submit's post-run history + # write would either clobber the undo (version matches) or + # silently drop the agent's output (version mismatch, see below). + # Neither is what the user wants — make them /interrupt first. + if session.get("running"): + return _err(rid, 4009, "session busy — /interrupt the current turn before /undo") removed = 0 with session["history_lock"]: history = session.get("history", []) @@ -1243,6 +1250,8 @@ def _(rid, params: dict) -> dict: session, err = _sess(params, rid) if err: return err + if session.get("running"): + return _err(rid, 4009, "session busy — /interrupt the current turn before /compress") try: with session["history_lock"]: removed, usage = _compress_session_history(session, str(params.get("focus_topic", "") or "").strip()) @@ -1443,12 +1452,33 @@ def _(rid, params: dict) -> dict: ) last_reasoning = None + status_note = None if isinstance(result, dict): if isinstance(result.get("messages"), list): with session["history_lock"]: - if int(session.get("history_version", 0)) == history_version: + current_version = int(session.get("history_version", 0)) + if current_version == history_version: session["history"] = result["messages"] session["history_version"] = history_version + 1 + else: + # History mutated externally during the turn + # (undo/compress/retry/rollback now guard on + # session.running, but this is the defensive + # backstop for any path that slips past). + # Surface the desync rather than silently + # dropping the agent's output — the UI can + # show the response and warn that it was + # not persisted. 
+ print( + f"[tui_gateway] prompt.submit: history_version mismatch " + f"(expected={history_version} current={current_version}) — " + f"agent output NOT written to session history", + file=sys.stderr, + ) + status_note = ( + "History changed during this turn — the response above is visible " + "but was not saved to session history." + ) raw = result.get("final_response", "") status = "interrupted" if result.get("interrupted") else "error" if result.get("error") else "complete" lr = result.get("last_reasoning") @@ -1461,6 +1491,8 @@ def _(rid, params: dict) -> dict: payload = {"text": raw, "usage": _get_usage(agent), "status": status} if last_reasoning: payload["reasoning"] = last_reasoning + if status_note: + payload["warning"] = status_note rendered = render_message(raw, cols) if rendered: payload["rendered"] = rendered @@ -2168,6 +2200,8 @@ def _(rid, params: dict) -> dict: if name == "retry": if not session: return _err(rid, 4001, "no active session to retry") + if session.get("running"): + return _err(rid, 4009, "session busy — /interrupt the current turn before /retry") history = session.get("history", []) if not history: return _err(rid, 4018, "no previous user message to retry") @@ -2578,6 +2612,13 @@ def _(rid, params: dict) -> dict: file_path = params.get("file_path", "") if not target: return _err(rid, 4014, "hash required") + # Full-history rollback mutates session history. Rejecting during + # an in-flight turn prevents prompt.submit from silently dropping + # the agent's output (version mismatch path) or clobbering the + # rollback (version-matches path). A file-scoped rollback only + # touches disk, so we allow it. + if not file_path and session.get("running"): + return _err(rid, 4009, "session busy — /interrupt the current turn before full rollback.restore") try: def go(mgr, cwd): resolved = _resolve_checkpoint_hash(mgr, cwd, target)