diff --git a/tests/test_tui_mcp_late_refresh.py b/tests/test_tui_mcp_late_refresh.py
new file mode 100644
index 00000000000..e3f423fba6f
--- /dev/null
+++ b/tests/test_tui_mcp_late_refresh.py
@@ -0,0 +1,187 @@
+"""Tests for the TUI gateway's late MCP tool-snapshot refresh.
+
+When an MCP server connects slower than the bounded wait in ``_make_agent``,
+the agent is built without its tools and the banner/tool count is stale for the
+session. ``_schedule_mcp_late_refresh`` waits for discovery to land, then
+rebuilds the snapshot and re-emits ``session.info`` — but only while the
+session is still pre-first-turn, so it never invalidates a cached prompt.
+"""
+
+import threading
+import time
+import types
+
+import model_tools
+from tui_gateway import server
+from tui_gateway import entry
+
+
+def _make_fake_agent(initial_tools, *, user_turns=0, api_calls=0):
+    agent = types.SimpleNamespace()
+    agent.tools = list(initial_tools)
+    agent.valid_tool_names = {t["function"]["name"] for t in initial_tools}
+    agent._user_turn_count = user_turns
+    agent._api_call_count = api_calls
+    return agent
+
+
+def _tool(name):
+    return {"type": "function", "function": {"name": name, "description": "", "parameters": {}}}
+
+
+def _drain_refresh_threads(timeout=5.0):
+    """Join every late-refresh worker and fail if one is still running."""
+    deadline = time.monotonic() + timeout
+    for th in list(threading.enumerate()):
+        if th.name.startswith("tui-mcp-late-refresh-"):
+            th.join(timeout=max(0.0, deadline - time.monotonic()))
+            assert not th.is_alive(), f"{th.name} did not finish within {timeout}s"
+
+
+def _install(monkeypatch, *, in_flight, join_result, new_defs):
+    """Wire entry discovery accessors + get_tool_definitions, capture emits."""
+    monkeypatch.setattr(entry, "mcp_discovery_in_flight", lambda: in_flight)
+    monkeypatch.setattr(entry, "join_mcp_discovery", lambda timeout=None: join_result)
+    monkeypatch.setattr(model_tools, "get_tool_definitions", lambda **kw: list(new_defs))
+    monkeypatch.setattr(server, "_load_enabled_toolsets", lambda: None)
+    monkeypatch.setattr(server, "_session_info", lambda agent, session: {"tools_len": len(agent.tools)})
+
+    emitted = []
+    monkeypatch.setattr(server, "_emit", lambda event, sid, payload=None: emitted.append((event, sid, payload)))
+    return emitted
+
+
+def test_late_refresh_adds_tools_and_reemits_when_pre_first_turn(monkeypatch):
+    base = [_tool("read_file"), _tool("write_file")]
+    full = base + [_tool("mcp__nous_support__a")]  # discovery added one tool
+    agent = _make_fake_agent(base)
+    sid = "sess-late-1"
+    server._sessions[sid] = {"agent": agent}
+    try:
+        emitted = _install(monkeypatch, in_flight=True, join_result=True, new_defs=full)
+        server._schedule_mcp_late_refresh(sid, agent)
+        _drain_refresh_threads()
+
+        assert len(agent.tools) == 3
+        assert "mcp__nous_support__a" in agent.valid_tool_names
+        assert ("session.info", sid, {"tools_len": 3}) in emitted
+    finally:
+        server._sessions.pop(sid, None)
+
+
+def test_no_refresh_when_discovery_not_in_flight(monkeypatch):
+    base = [_tool("read_file")]
+    agent = _make_fake_agent(base)
+    sid = "sess-late-2"
+    server._sessions[sid] = {"agent": agent}
+    try:
+        # in_flight=False → helper returns immediately, no thread, no rebuild.
+        emitted = _install(monkeypatch, in_flight=False, join_result=True, new_defs=base + [_tool("x")])
+        server._schedule_mcp_late_refresh(sid, agent)
+        _drain_refresh_threads()
+
+        assert len(agent.tools) == 1
+        assert emitted == []
+    finally:
+        server._sessions.pop(sid, None)
+
+
+def test_no_refresh_once_conversation_started(monkeypatch):
+    """Cache safety: never rebuild the tool list after the first turn."""
+    base = [_tool("read_file")]
+    full = base + [_tool("mcp__late__b")]
+    agent = _make_fake_agent(base, user_turns=1)  # a turn already happened
+    sid = "sess-late-3"
+    server._sessions[sid] = {"agent": agent}
+    try:
+        emitted = _install(monkeypatch, in_flight=True, join_result=True, new_defs=full)
+        server._schedule_mcp_late_refresh(sid, agent)
+        _drain_refresh_threads()
+
+        # Snapshot frozen; no re-emit that would invalidate the prompt cache.
+        assert len(agent.tools) == 1
+        assert emitted == []
+    finally:
+        server._sessions.pop(sid, None)
+
+
+def test_no_reemit_when_discovery_added_nothing(monkeypatch):
+    base = [_tool("read_file"), _tool("write_file")]
+    agent = _make_fake_agent(base)
+    sid = "sess-late-4"
+    server._sessions[sid] = {"agent": agent}
+    try:
+        # Discovery finished but the registry is unchanged (same count) →
+        # don't churn the client with a redundant session.info.
+        emitted = _install(monkeypatch, in_flight=True, join_result=True, new_defs=list(base))
+        server._schedule_mcp_late_refresh(sid, agent)
+        _drain_refresh_threads()
+
+        assert len(agent.tools) == 2
+        assert emitted == []
+    finally:
+        server._sessions.pop(sid, None)
+
+
+def test_refresh_when_tool_set_changes_at_same_count(monkeypatch):
+    """Same-size registry with different tools must still refresh."""
+    base = [_tool("read_file"), _tool("stale_tool")]
+    swapped = [_tool("read_file"), _tool("mcp__late__e")]
+    agent = _make_fake_agent(base)
+    sid = "sess-late-7"
+    server._sessions[sid] = {"agent": agent}
+    try:
+        emitted = _install(monkeypatch, in_flight=True, join_result=True, new_defs=swapped)
+        server._schedule_mcp_late_refresh(sid, agent)
+        _drain_refresh_threads()
+
+        assert agent.valid_tool_names == {"read_file", "mcp__late__e"}
+        assert ("session.info", sid, {"tools_len": 2}) in emitted
+    finally:
+        server._sessions.pop(sid, None)
+
+
+def test_no_refresh_when_join_times_out(monkeypatch):
+    base = [_tool("read_file")]
+    full = base + [_tool("mcp__slow__c")]
+    agent = _make_fake_agent(base)
+    sid = "sess-late-5"
+    server._sessions[sid] = {"agent": agent}
+    try:
+        # Server never connected within the bound → join returns False, no rebuild.
+        emitted = _install(monkeypatch, in_flight=True, join_result=False, new_defs=full)
+        server._schedule_mcp_late_refresh(sid, agent)
+        _drain_refresh_threads()
+
+        assert len(agent.tools) == 1
+        assert emitted == []
+    finally:
+        server._sessions.pop(sid, None)
+
+
+def test_no_refresh_when_session_replaced(monkeypatch):
+    """If the session's agent was swapped (e.g. /new) while we waited, bail."""
+    base = [_tool("read_file")]
+    full = base + [_tool("mcp__late__d")]
+    agent = _make_fake_agent(base)
+    other_agent = _make_fake_agent(base)
+    sid = "sess-late-6"
+    server._sessions[sid] = {"agent": agent}
+    try:
+        emitted = _install(monkeypatch, in_flight=True, join_result=True, new_defs=full)
+
+        # Swap the stored agent out the moment join is awaited.
+        def _swap_join(timeout=None):
+            server._sessions[sid]["agent"] = other_agent
+            return True
+
+        monkeypatch.setattr(entry, "join_mcp_discovery", _swap_join)
+        server._schedule_mcp_late_refresh(sid, agent)
+        _drain_refresh_threads()
+
+        # Neither agent's snapshot was rebuilt; no emit.
+        assert len(agent.tools) == 1
+        assert len(other_agent.tools) == 1
+        assert emitted == []
+    finally:
+        server._sessions.pop(sid, None)
diff --git a/tui_gateway/entry.py b/tui_gateway/entry.py
index 7069ec97605..28c055d57b2 100644
--- a/tui_gateway/entry.py
+++ b/tui_gateway/entry.py
@@ -210,6 +210,33 @@ def wait_for_mcp_discovery(timeout: float = 0.75) -> None:
     thread.join(timeout=timeout)
 
 
+def mcp_discovery_in_flight() -> bool:
+    """Return True if the background MCP discovery thread is still running.
+
+    Used by the agent-build path to decide whether to schedule a late tool
+    snapshot refresh: if discovery didn't land within the bounded
+    ``wait_for_mcp_discovery`` join, the agent was built without those tools
+    and the banner/tool count will be stale until they arrive.
+    """
+    thread = _mcp_discovery_thread
+    return thread is not None and thread.is_alive()
+
+
+def join_mcp_discovery(timeout: float | None = None) -> bool:
+    """Block until background MCP discovery finishes, up to ``timeout`` seconds.
+
+    Returns True if discovery has completed (thread absent or no longer alive),
+    False if it is still running after the timeout. Unlike
+    ``wait_for_mcp_discovery`` this accepts an unbounded/long wait and reports
+    the outcome, for the off-critical-path late-refresh waiter.
+    """
+    thread = _mcp_discovery_thread
+    if thread is None:
+        return True
+    thread.join(timeout=timeout)
+    return not thread.is_alive()
+
+
 def main():
     _install_sidecar_publisher()
 
diff --git a/tui_gateway/server.py b/tui_gateway/server.py
index f3ceaa95637..cc9399a7c2e 100644
--- a/tui_gateway/server.py
+++ b/tui_gateway/server.py
@@ -1015,6 +1015,11 @@ def _start_agent_build(sid: str, session: dict) -> None:
                 info["config_warning"] = cfg_warn
                 logger.warning(cfg_warn)
             _emit("session.info", sid, info)
+            # If MCP discovery is still in flight (a server slower than the
+            # bounded wait_for_mcp_discovery join in _make_agent), the agent
+            # was built without those tools. Catch up once they land — see
+            # _schedule_mcp_late_refresh. Cache-safe (pre-first-turn only).
+            _schedule_mcp_late_refresh(sid, agent)
         except Exception as e:
             current["agent_error"] = str(e)
             _emit("error", sid, {"message": f"agent init failed: {e}"})
@@ -3405,6 +3410,91 @@ def _reset_session_agent(sid: str, session: dict) -> dict:
     return info
 
 
+def _schedule_mcp_late_refresh(sid: str, agent) -> None:
+    """Refresh a session's tool snapshot when MCP discovery lands late.
+
+    The agent snapshots ``agent.tools`` once at build time and never re-reads
+    the registry (run_agent/agent_init). ``_make_agent`` briefly joins the
+    background MCP discovery thread (``wait_for_mcp_discovery``, ~0.75s) so
+    already-spawning servers land in that snapshot — but a server that takes
+    longer than the bound to connect (common for an HTTP MCP server on first
+    connect) lands *after* the agent is built. Its tools are then absent from
+    both the agent and the banner for the whole session, even though the
+    classic CLI shows them (the CLI re-derives ``get_tool_definitions`` at
+    banner render time, which re-waits, so it picks them up).
+
+    This schedules an off-critical-path daemon that waits for discovery to
+    finish, then rebuilds the snapshot and re-emits ``session.info`` so both
+    the agent's callable tools and the banner count catch up — the same
+    rebuild ``/reload-mcp`` performs, but automatic.
+
+    Cache safety: the rebuild only runs while the session is still pre-first-
+    turn (no API call made yet → nothing cached to invalidate). If the user
+    has already sent a message, we leave the snapshot frozen rather than
+    invalidate the prompt cache mid-conversation — those late tools then
+    require an explicit ``/reload-mcp`` (which gates on user consent), exactly
+    as today. No-op when discovery already finished before the agent build.
+    """
+    try:
+        from tui_gateway.entry import mcp_discovery_in_flight, join_mcp_discovery
+    except Exception:
+        return
+    if not mcp_discovery_in_flight():
+        return
+
+    def _wait_then_refresh() -> None:
+        # Bounded but generous — a server still not connected after this is
+        # genuinely slow/dead; the user can /reload-mcp once it recovers.
+        if not join_mcp_discovery(timeout=30.0):
+            return
+        with _sessions_lock:
+            session = _sessions.get(sid)
+            # Session may have been closed/reset while we waited.
+            if session is None or session.get("agent") is not agent:
+                return
+            # Cache safety: never rebuild the tool list once the conversation
+            # has started — that would invalidate the cached prompt prefix.
+            if (
+                int(getattr(agent, "_user_turn_count", 0) or 0) > 0
+                or int(getattr(agent, "_api_call_count", 0) or 0) > 0
+            ):
+                return
+            try:
+                from model_tools import get_tool_definitions
+
+                new_defs = get_tool_definitions(
+                    enabled_toolsets=_load_enabled_toolsets(),
+                    quiet_mode=True,
+                )
+            except Exception as exc:
+                logger.warning(
+                    "Late MCP refresh: get_tool_definitions failed for %s: %s",
+                    sid,
+                    exc,
+                )
+                return
+            # Compare tool *names*, not counts: a same-size registry with a
+            # different tool set must still refresh, while an identical set
+            # must not churn the client with a redundant session.info.
+            new_defs = new_defs or []
+            new_names = {t["function"]["name"] for t in new_defs}
+            old_defs = getattr(agent, "tools", None) or []
+            old_names = {t["function"]["name"] for t in old_defs}
+            if new_names == old_names:
+                return
+            agent.tools = new_defs
+            agent.valid_tool_names = new_names
+            info = _session_info(agent, session)
+        # Emit outside the lock — write_json must not block under _sessions_lock.
+        _emit("session.info", sid, info)
+
+    threading.Thread(
+        target=_wait_then_refresh,
+        name=f"tui-mcp-late-refresh-{sid}",
+        daemon=True,
+    ).start()
+
+
 def _make_agent(
     sid: str,
     key: str,
@@ -3643,6 +3733,7 @@ def _init_session(
     _sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
     _notify_session_boundary("on_session_reset", key)
     _emit("session.info", sid, _session_info(agent, _sessions.get(sid, {})))
+    _schedule_mcp_late_refresh(sid, agent)
 
 
 def _new_session_key() -> str: