From bd6d0987629ecf6d3602d739997564d18de5225f Mon Sep 17 00:00:00 2001 From: rexdotsh <65942753+rexdotsh@users.noreply.github.com> Date: Wed, 3 Jun 2026 20:59:23 +0530 Subject: [PATCH] fix(tui): keep resumed live history current --- tests/tui_gateway/test_protocol.py | 99 ++++++++++++++++++++++++++++++ tui_gateway/server.py | 9 ++- 2 files changed, 106 insertions(+), 2 deletions(-) diff --git a/tests/tui_gateway/test_protocol.py b/tests/tui_gateway/test_protocol.py index 96c49dac0c9..1d31dd3c976 100644 --- a/tests/tui_gateway/test_protocol.py +++ b/tests/tui_gateway/test_protocol.py @@ -496,6 +496,105 @@ def test_session_resume_reuses_existing_live_session(server, monkeypatch): assert created_sids == [first["result"]["session_id"]] +def test_session_resume_live_payload_uses_current_history_with_ancestors(server, monkeypatch): + """Live resume should not reuse a stale ancestor-inclusive snapshot.""" + + target = "20260409_010101_child" + ancestor_history = [{"role": "user", "content": "ancestor"}] + current_history = [ + {"role": "user", "content": "current"}, + {"role": "assistant", "content": "current reply"}, + ] + + class _DB: + def get_session(self, _sid): + return {"id": target} + + def get_session_by_title(self, _title): + return None + + def reopen_session(self, _sid): + return None + + def get_messages_as_conversation(self, _sid, include_ancestors=False): + if include_ancestors: + return ancestor_history + current_history + return list(current_history) + + class _Worker: + def close(self): + pass + + monkeypatch.setattr(server, "_get_db", lambda: _DB()) + monkeypatch.setattr( + server, + "_make_agent", + lambda _sid, key, session_id=None: types.SimpleNamespace( + model="test/model", session_id=session_id or key + ), + ) + monkeypatch.setattr(server, "_SlashWorker", lambda _key, _model: _Worker()) + monkeypatch.setattr( + server, + "_start_notification_poller", + lambda _sid, _session: threading.Event(), + ) + monkeypatch.setattr(server, "_notify_session_boundary", lambda *_args, **_kwargs: None) + monkeypatch.setattr(server, "_wire_callbacks", lambda _sid: None) + monkeypatch.setattr(server, "_emit", lambda *_args, **_kwargs: None) + monkeypatch.setattr( + server, + "_session_info", + lambda _agent, _session=None: {"model": "test/model"}, + ) + + fake_approval = types.SimpleNamespace( + load_permanent_allowlist=lambda: None, + register_gateway_notify=lambda *_args, **_kwargs: None, + ) + + with patch.dict(sys.modules, {"tools.approval": fake_approval}): + first = server.handle_request( + { + "id": "first", + "method": "session.resume", + "params": {"session_id": target, "cols": 100}, + } + ) + + assert "error" not in first + sid = first["result"]["session_id"] + assert first["result"]["messages"] == [ + {"role": "user", "text": "ancestor"}, + {"role": "user", "text": "current"}, + {"role": "assistant", "text": "current reply"}, + ] + + with server._sessions[sid]["history_lock"]: + server._sessions[sid]["history"] = current_history + [ + {"role": "user", "content": "new live turn"}, + {"role": "assistant", "content": "new live reply"}, + ] + + second = server.handle_request( + { + "id": "second", + "method": "session.resume", + "params": {"session_id": target, "cols": 120}, + } + ) + + assert "error" not in second + assert second["result"]["session_id"] == sid + assert second["result"]["messages"] == [ + {"role": "user", "text": "ancestor"}, + {"role": "user", "text": "current"}, + {"role": "assistant", "text": "current reply"}, + {"role": "user", "text": "new live turn"}, + {"role": "assistant", "text": "new live reply"}, + ] + + def test_make_agent_accepts_list_system_prompt(server, monkeypatch): captured = {} diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 40695facbd4..113e29a1ae9 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -3016,6 +3016,9 @@ def _(rid, params: dict) -> dict: display_history = db.get_messages_as_conversation( target, include_ancestors=True ) + display_history_prefix = display_history[ + : max(0, len(display_history) - len(history)) + ] messages = _history_to_messages(display_history) tokens = _set_session_context(target) try: @@ -3024,7 +3027,7 @@ def _(rid, params: dict) -> dict: _clear_session_context(tokens) _init_session(sid, target, agent, history, cols=cols) if sid in _sessions: - _sessions[sid]["display_history"] = display_history + _sessions[sid]["display_history_prefix"] = display_history_prefix except Exception as e: return _err(rid, 5000, f"resume failed: {e}") session = _sessions.get(sid) or {} @@ -3170,7 +3173,9 @@ def _live_session_payload( session["transport"] = transport if touch: session["last_active"] = time.time() - history = list(session.get("display_history") or session.get("history") or []) + history = list(session.get("display_history_prefix") or []) + list( + session.get("history") or [] + ) inflight = _inflight_snapshot(session) running = bool(session.get("running")) payload = {