fix(tui): keep resumed live history current

This commit is contained in:
rexdotsh 2026-06-03 20:59:23 +05:30 committed by kshitij
parent 98903d0313
commit bd6d098762
2 changed files with 106 additions and 2 deletions

View file

@ -496,6 +496,105 @@ def test_session_resume_reuses_existing_live_session(server, monkeypatch):
assert created_sids == [first["result"]["session_id"]]
def test_session_resume_live_payload_uses_current_history_with_ancestors(server, monkeypatch):
"""Live resume should not reuse a stale ancestor-inclusive snapshot."""
target = "20260409_010101_child"
ancestor_history = [{"role": "user", "content": "ancestor"}]
current_history = [
{"role": "user", "content": "current"},
{"role": "assistant", "content": "current reply"},
]
class _DB:
def get_session(self, _sid):
return {"id": target}
def get_session_by_title(self, _title):
return None
def reopen_session(self, _sid):
return None
def get_messages_as_conversation(self, _sid, include_ancestors=False):
if include_ancestors:
return ancestor_history + current_history
return list(current_history)
class _Worker:
def close(self):
pass
monkeypatch.setattr(server, "_get_db", lambda: _DB())
monkeypatch.setattr(
server,
"_make_agent",
lambda _sid, key, session_id=None: types.SimpleNamespace(
model="test/model", session_id=session_id or key
),
)
monkeypatch.setattr(server, "_SlashWorker", lambda _key, _model: _Worker())
monkeypatch.setattr(
server,
"_start_notification_poller",
lambda _sid, _session: threading.Event(),
)
monkeypatch.setattr(server, "_notify_session_boundary", lambda *_args, **_kwargs: None)
monkeypatch.setattr(server, "_wire_callbacks", lambda _sid: None)
monkeypatch.setattr(server, "_emit", lambda *_args, **_kwargs: None)
monkeypatch.setattr(
server,
"_session_info",
lambda _agent, _session=None: {"model": "test/model"},
)
fake_approval = types.SimpleNamespace(
load_permanent_allowlist=lambda: None,
register_gateway_notify=lambda *_args, **_kwargs: None,
)
with patch.dict(sys.modules, {"tools.approval": fake_approval}):
first = server.handle_request(
{
"id": "first",
"method": "session.resume",
"params": {"session_id": target, "cols": 100},
}
)
assert "error" not in first
sid = first["result"]["session_id"]
assert first["result"]["messages"] == [
{"role": "user", "text": "ancestor"},
{"role": "user", "text": "current"},
{"role": "assistant", "text": "current reply"},
]
with server._sessions[sid]["history_lock"]:
server._sessions[sid]["history"] = current_history + [
{"role": "user", "content": "new live turn"},
{"role": "assistant", "content": "new live reply"},
]
second = server.handle_request(
{
"id": "second",
"method": "session.resume",
"params": {"session_id": target, "cols": 120},
}
)
assert "error" not in second
assert second["result"]["session_id"] == sid
assert second["result"]["messages"] == [
{"role": "user", "text": "ancestor"},
{"role": "user", "text": "current"},
{"role": "assistant", "text": "current reply"},
{"role": "user", "text": "new live turn"},
{"role": "assistant", "text": "new live reply"},
]
def test_make_agent_accepts_list_system_prompt(server, monkeypatch):
captured = {}

View file

@ -3016,6 +3016,9 @@ def _(rid, params: dict) -> dict:
display_history = db.get_messages_as_conversation(
target, include_ancestors=True
)
display_history_prefix = display_history[
: max(0, len(display_history) - len(history))
]
messages = _history_to_messages(display_history)
tokens = _set_session_context(target)
try:
@ -3024,7 +3027,7 @@ def _(rid, params: dict) -> dict:
_clear_session_context(tokens)
_init_session(sid, target, agent, history, cols=cols)
if sid in _sessions:
_sessions[sid]["display_history"] = display_history
_sessions[sid]["display_history_prefix"] = display_history_prefix
except Exception as e:
return _err(rid, 5000, f"resume failed: {e}")
session = _sessions.get(sid) or {}
@ -3170,7 +3173,9 @@ def _live_session_payload(
session["transport"] = transport
if touch:
session["last_active"] = time.time()
history = list(session.get("display_history") or session.get("history") or [])
history = list(session.get("display_history_prefix") or []) + list(
session.get("history") or []
)
inflight = _inflight_snapshot(session)
running = bool(session.get("running"))
payload = {