diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index 07d9765eea6..97f584d0732 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -4901,6 +4901,140 @@ def test_interrupt_clears_multiple_own_pending(): server._answers.pop(key, None) +def test_run_prompt_submit_registers_turn_thread_for_interrupt(monkeypatch): + """_run_prompt_submit must expose the actual turn thread to session.interrupt. + + prompt.submit's outer wrapper only waits for agent initialization, then + _run_prompt_submit starts the real conversation thread. If the session keeps + the wrapper thread handle, stop/esc sees a dead thread and never calls + agent.interrupt() on the live turn. + """ + calls = {"interrupted": False, "started": False} + + class _FakeThread: + def __init__(self, target=None, daemon=None): + self.target = target + + def start(self): + calls["started"] = True + + def is_alive(self): + return True + + agent = types.SimpleNamespace( + interrupt=lambda: calls.__setitem__("interrupted", True), + run_conversation=lambda *args, **kwargs: {}, + ) + session = _session(agent=agent, running=True) + server._sessions["sid"] = session + + try: + monkeypatch.setattr(server.threading, "Thread", _FakeThread) + monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None) + + server._run_prompt_submit("1", "sid", session, "hello") + + assert session.get("_run_thread") is not None + resp = server.handle_request( + {"id": "2", "method": "session.interrupt", "params": {"session_id": "sid"}} + ) + + assert resp.get("result"), f"got error: {resp.get('error')}" + assert calls["interrupted"] is True + finally: + server._sessions.pop("sid", None) + + +def test_interrupt_drops_queued_prompt_for_session(): + """Explicit stop cancels a queued next turn instead of auto-draining it.""" + calls = {"interrupted": False} + + class _LiveThread: + def is_alive(self): + return True + + session = _session( + agent=types.SimpleNamespace( + interrupt=lambda: calls.__setitem__("interrupted", True) + ), + running=True, + queued_prompt={"text": "next prompt", "transport": None}, + _run_thread=_LiveThread(), + ) + server._sessions["sid"] = session + + try: + resp = server.handle_request( + {"id": "1", "method": "session.interrupt", "params": {"session_id": "sid"}} + ) + + assert resp.get("result"), f"got error: {resp.get('error')}" + assert calls["interrupted"] is True + assert session.get("queued_prompt") is None + finally: + server._sessions.pop("sid", None) + + +def test_interrupt_before_agent_ready_prevents_late_turn_start(monkeypatch): + """Stop during lazy agent startup must not start the turn after init finishes.""" + threads = [] + calls = {"run_prompt": 0} + + class _FakeThread: + def __init__(self, target=None, daemon=None): + self.target = target + threads.append(self) + + def start(self): + return None + + def is_alive(self): + return True + + session = _session() + session["agent"] = None + server._sessions["sid"] = session + + try: + monkeypatch.setattr(server.threading, "Thread", _FakeThread) + monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None) + monkeypatch.setattr(server, "_ensure_session_db_row", lambda session: None) + monkeypatch.setattr(server, "_persist_branch_seed", lambda session: None) + monkeypatch.setattr(server, "_start_agent_build", lambda sid, session: None) + monkeypatch.setattr(server, "_wait_agent", lambda session, rid: None) + monkeypatch.setattr( + server, + "_run_prompt_submit", + lambda *args, **kwargs: calls.__setitem__( + "run_prompt", calls["run_prompt"] + 1 + ), + ) + + submit = server.handle_request( + { + "id": "1", + "method": "prompt.submit", + "params": {"session_id": "sid", "text": "hello"}, + } + ) + assert submit.get("result"), f"got error: {submit.get('error')}" + assert session["running"] is True + assert len(threads) == 1 + + stop = server.handle_request( + {"id": "2", "method": "session.interrupt", "params": {"session_id": "sid"}} + ) + assert stop.get("result"), f"got error: {stop.get('error')}" + + threads[0].target() + + assert calls["run_prompt"] == 0 + assert session["running"] is False + assert session.get("inflight_turn") is None + finally: + server._sessions.pop("sid", None) + + def test_clear_pending_without_sid_clears_all(): """_clear_pending(None) is the shutdown path — must still release every pending prompt regardless of owning session.""" diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 8171caa9cde..826efc9faa2 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -7693,14 +7693,17 @@ def _(rid, params: dict) -> dict: # stuck (a crash/desync that skipped the run loop's `finally`), force-clear it # so the session can't be permanently bricked at 4009 "session busy" — every # send/restore/resume would otherwise reject until a full backend restart. - # A genuinely live turn is left alone: its cooperative interrupt + `finally` - # release `running` the normal way; clearing it here would let a second turn - # race the first on the same session. + # Always tell the agent to interrupt when the session claims a run is active: + # stale flags are cleared below, and fresh turns clear the interrupt flag at + # entry. This keeps a stale/missing thread handle from making Stop a no-op. run_thread = session.get("_run_thread") run_thread_alive = run_thread is not None and run_thread.is_alive() - should_interrupt = bool(session.get("running")) and run_thread_alive + should_interrupt = bool(session.get("running")) if should_interrupt and hasattr(session["agent"], "interrupt"): session["agent"].interrupt() + with session["history_lock"]: + session["_turn_cancel_requested"] = True + session["queued_prompt"] = None if not run_thread_alive: with session["history_lock"]: if session.get("running"): @@ -8031,6 +8034,7 @@ def _(rid, params: dict) -> dict: except Exception as exc: print(f"[tui_gateway] prompt.submit: replace_messages failed: {exc}", file=sys.stderr) session["running"] = True + session["_turn_cancel_requested"] = False session["last_active"] = time.time() _start_inflight_turn(session, text) @@ -8057,6 +8061,11 @@ def _(rid, params: dict) -> dict: session["running"] = False _clear_inflight_turn(session) return + with session["history_lock"]: + if session.get("_turn_cancel_requested") or not session.get("running"): + session["running"] = False + _clear_inflight_turn(session) + return _run_prompt_submit(rid, sid, session, text) run_thread = threading.Thread(target=run_after_agent_ready, daemon=True) @@ -8713,7 +8722,9 @@ def _run_prompt_submit(rid, sid: str, session: dict, text: Any) -> None: file=sys.stderr, ) - threading.Thread(target=run, daemon=True).start() + run_thread = threading.Thread(target=run, daemon=True) + session["_run_thread"] = run_thread + run_thread.start() @method("clipboard.paste")