Merge pull request #52871 from NousResearch/bb/fix-tui-interrupt-queued

fix(tui-gateway): make stop interrupt queued turns
This commit is contained in:
brooklyn! 2026-06-26 00:37:13 -05:00 committed by GitHub
commit 6ba551e942
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 150 additions and 5 deletions

View file

@ -4901,6 +4901,140 @@ def test_interrupt_clears_multiple_own_pending():
server._answers.pop(key, None)
def test_run_prompt_submit_registers_turn_thread_for_interrupt(monkeypatch):
"""_run_prompt_submit must expose the actual turn thread to session.interrupt.
prompt.submit's outer wrapper only waits for agent initialization, then
_run_prompt_submit starts the real conversation thread. If the session keeps
the wrapper thread handle, stop/esc sees a dead thread and never calls
agent.interrupt() on the live turn.
"""
calls = {"interrupted": False, "started": False}
class _FakeThread:
def __init__(self, target=None, daemon=None):
self.target = target
def start(self):
calls["started"] = True
def is_alive(self):
return True
agent = types.SimpleNamespace(
interrupt=lambda: calls.__setitem__("interrupted", True),
run_conversation=lambda *args, **kwargs: {},
)
session = _session(agent=agent, running=True)
server._sessions["sid"] = session
try:
monkeypatch.setattr(server.threading, "Thread", _FakeThread)
monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
server._run_prompt_submit("1", "sid", session, "hello")
assert session.get("_run_thread") is not None
resp = server.handle_request(
{"id": "2", "method": "session.interrupt", "params": {"session_id": "sid"}}
)
assert resp.get("result"), f"got error: {resp.get('error')}"
assert calls["interrupted"] is True
finally:
server._sessions.pop("sid", None)
def test_interrupt_drops_queued_prompt_for_session():
"""Explicit stop cancels a queued next turn instead of auto-draining it."""
calls = {"interrupted": False}
class _LiveThread:
def is_alive(self):
return True
session = _session(
agent=types.SimpleNamespace(
interrupt=lambda: calls.__setitem__("interrupted", True)
),
running=True,
queued_prompt={"text": "next prompt", "transport": None},
_run_thread=_LiveThread(),
)
server._sessions["sid"] = session
try:
resp = server.handle_request(
{"id": "1", "method": "session.interrupt", "params": {"session_id": "sid"}}
)
assert resp.get("result"), f"got error: {resp.get('error')}"
assert calls["interrupted"] is True
assert session.get("queued_prompt") is None
finally:
server._sessions.pop("sid", None)
def test_interrupt_before_agent_ready_prevents_late_turn_start(monkeypatch):
"""Stop during lazy agent startup must not start the turn after init finishes."""
threads = []
calls = {"run_prompt": 0}
class _FakeThread:
def __init__(self, target=None, daemon=None):
self.target = target
threads.append(self)
def start(self):
return None
def is_alive(self):
return True
session = _session()
session["agent"] = None
server._sessions["sid"] = session
try:
monkeypatch.setattr(server.threading, "Thread", _FakeThread)
monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)
monkeypatch.setattr(server, "_ensure_session_db_row", lambda session: None)
monkeypatch.setattr(server, "_persist_branch_seed", lambda session: None)
monkeypatch.setattr(server, "_start_agent_build", lambda sid, session: None)
monkeypatch.setattr(server, "_wait_agent", lambda session, rid: None)
monkeypatch.setattr(
server,
"_run_prompt_submit",
lambda *args, **kwargs: calls.__setitem__(
"run_prompt", calls["run_prompt"] + 1
),
)
submit = server.handle_request(
{
"id": "1",
"method": "prompt.submit",
"params": {"session_id": "sid", "text": "hello"},
}
)
assert submit.get("result"), f"got error: {submit.get('error')}"
assert session["running"] is True
assert len(threads) == 1
stop = server.handle_request(
{"id": "2", "method": "session.interrupt", "params": {"session_id": "sid"}}
)
assert stop.get("result"), f"got error: {stop.get('error')}"
threads[0].target()
assert calls["run_prompt"] == 0
assert session["running"] is False
assert session.get("inflight_turn") is None
finally:
server._sessions.pop("sid", None)
def test_clear_pending_without_sid_clears_all():
"""_clear_pending(None) is the shutdown path — must still release
every pending prompt regardless of owning session."""

View file

@ -7693,14 +7693,17 @@ def _(rid, params: dict) -> dict:
# stuck (a crash/desync that skipped the run loop's `finally`), force-clear it
# so the session can't be permanently bricked at 4009 "session busy" — every
# send/restore/resume would otherwise reject until a full backend restart.
# A genuinely live turn is left alone: its cooperative interrupt + `finally`
# release `running` the normal way; clearing it here would let a second turn
# race the first on the same session.
# Always tell the agent to interrupt when the session claims a run is active:
# stale flags are cleared below, and fresh turns clear the interrupt flag at
# entry. This keeps a stale/missing thread handle from making Stop a no-op.
run_thread = session.get("_run_thread")
run_thread_alive = run_thread is not None and run_thread.is_alive()
should_interrupt = bool(session.get("running")) and run_thread_alive
should_interrupt = bool(session.get("running"))
if should_interrupt and hasattr(session["agent"], "interrupt"):
session["agent"].interrupt()
with session["history_lock"]:
session["_turn_cancel_requested"] = True
session["queued_prompt"] = None
if not run_thread_alive:
with session["history_lock"]:
if session.get("running"):
@ -8031,6 +8034,7 @@ def _(rid, params: dict) -> dict:
except Exception as exc:
print(f"[tui_gateway] prompt.submit: replace_messages failed: {exc}", file=sys.stderr)
session["running"] = True
session["_turn_cancel_requested"] = False
session["last_active"] = time.time()
_start_inflight_turn(session, text)
@ -8057,6 +8061,11 @@ def _(rid, params: dict) -> dict:
session["running"] = False
_clear_inflight_turn(session)
return
with session["history_lock"]:
if session.get("_turn_cancel_requested") or not session.get("running"):
session["running"] = False
_clear_inflight_turn(session)
return
_run_prompt_submit(rid, sid, session, text)
run_thread = threading.Thread(target=run_after_agent_ready, daemon=True)
@ -8713,7 +8722,9 @@ def _run_prompt_submit(rid, sid: str, session: dict, text: Any) -> None:
file=sys.stderr,
)
threading.Thread(target=run, daemon=True).start()
run_thread = threading.Thread(target=run, daemon=True)
session["_run_thread"] = run_thread
run_thread.start()
@method("clipboard.paste")