feat: scroll aware sticky prompt

This commit is contained in:
Brooklyn Nicholson 2026-04-14 11:49:32 -05:00
commit 9a3a2925ed
141 changed files with 8867 additions and 829 deletions

View file

@ -313,6 +313,17 @@ def disable_session_yolo(session_key: str) -> None:
_session_yolo.discard(session_key)
def clear_session(session_key: str) -> None:
    """Drop every piece of per-session state: approvals, yolo flag, pending
    requests, and gateway queues. No-op for an empty key."""
    if session_key:
        with _lock:
            # Each store tolerates a missing key, so ordering is free.
            _session_approved.pop(session_key, None)
            _pending.pop(session_key, None)
            _gateway_queues.pop(session_key, None)
            _session_yolo.discard(session_key)
def is_session_yolo_enabled(session_key: str) -> bool:
"""Return True when YOLO bypass is enabled for a specific session."""
if not session_key:

View file

@ -556,27 +556,54 @@ class ShellFileOperations(FileOperations):
def _suggest_similar_files(self, path: str) -> ReadResult:
"""Suggest similar files when the requested file is not found."""
# Get directory and filename
dir_path = os.path.dirname(path) or "."
filename = os.path.basename(path)
# List files in directory
ls_cmd = f"ls -1 {self._escape_shell_arg(dir_path)} 2>/dev/null | head -20"
basename_no_ext = os.path.splitext(filename)[0]
ext = os.path.splitext(filename)[1].lower()
lower_name = filename.lower()
# List files in the target directory
ls_cmd = f"ls -1 {self._escape_shell_arg(dir_path)} 2>/dev/null | head -50"
ls_result = self._exec(ls_cmd)
similar = []
scored: list = [] # (score, filepath) — higher is better
if ls_result.exit_code == 0 and ls_result.stdout.strip():
files = ls_result.stdout.strip().split('\n')
# Simple similarity: files that share some characters with the target
for f in files:
# Check if filenames share significant overlap
common = set(filename.lower()) & set(f.lower())
if len(common) >= len(filename) * 0.5: # 50% character overlap
similar.append(os.path.join(dir_path, f))
for f in ls_result.stdout.strip().split('\n'):
if not f:
continue
lf = f.lower()
score = 0
# Exact match (shouldn't happen, but guard)
if lf == lower_name:
score = 100
# Same base name, different extension (e.g. config.yml vs config.yaml)
elif os.path.splitext(f)[0].lower() == basename_no_ext.lower():
score = 90
# Target is prefix of candidate or vice-versa
elif lf.startswith(lower_name) or lower_name.startswith(lf):
score = 70
# Substring match (candidate contains query)
elif lower_name in lf:
score = 60
# Reverse substring (query contains candidate name)
elif lf in lower_name and len(lf) > 2:
score = 40
# Same extension with some overlap
elif ext and os.path.splitext(f)[1].lower() == ext:
common = set(lower_name) & set(lf)
if len(common) >= max(len(lower_name), len(lf)) * 0.4:
score = 30
if score > 0:
scored.append((score, os.path.join(dir_path, f)))
scored.sort(key=lambda x: -x[0])
similar = [fp for _, fp in scored[:5]]
return ReadResult(
error=f"File not found: {path}",
similar_files=similar[:5] # Limit to 5 suggestions
similar_files=similar
)
def read_file_raw(self, path: str) -> ReadResult:
@ -845,8 +872,33 @@ class ShellFileOperations(FileOperations):
# Validate that the path exists before searching
check = self._exec(f"test -e {self._escape_shell_arg(path)} && echo exists || echo not_found")
if "not_found" in check.stdout:
# Try to suggest nearby paths
parent = os.path.dirname(path) or "."
basename_query = os.path.basename(path)
hint_parts = [f"Path not found: {path}"]
# Check if parent directory exists and list similar entries
parent_check = self._exec(
f"test -d {self._escape_shell_arg(parent)} && echo yes || echo no"
)
if "yes" in parent_check.stdout and basename_query:
ls_result = self._exec(
f"ls -1 {self._escape_shell_arg(parent)} 2>/dev/null | head -20"
)
if ls_result.exit_code == 0 and ls_result.stdout.strip():
lower_q = basename_query.lower()
candidates = []
for entry in ls_result.stdout.strip().split('\n'):
if not entry:
continue
le = entry.lower()
if lower_q in le or le in lower_q or le.startswith(lower_q[:3]):
candidates.append(os.path.join(parent, entry))
if candidates:
hint_parts.append(
"Similar paths: " + ", ".join(candidates[:5])
)
return SearchResult(
error=f"Path not found: {path}. Verify the path exists (use 'terminal' to check).",
error=". ".join(hint_parts),
total_count=0
)
@ -912,7 +964,8 @@ class ShellFileOperations(FileOperations):
rg --files respects .gitignore and excludes hidden directories by
default, and uses parallel directory traversal for ~200x speedup
over find on wide trees.
over find on wide trees. Results are sorted by modification time
(most recently edited first) when rg >= 13.0 supports --sortr.
"""
# rg --files -g uses glob patterns; wrap bare names so they match
# at any depth (equivalent to find -name).
@ -922,14 +975,25 @@ class ShellFileOperations(FileOperations):
glob_pattern = pattern
fetch_limit = limit + offset
cmd = (
f"rg --files -g {self._escape_shell_arg(glob_pattern)} "
# Try mtime-sorted first (rg 13+); fall back to unsorted if not supported.
cmd_sorted = (
f"rg --files --sortr=modified -g {self._escape_shell_arg(glob_pattern)} "
f"{self._escape_shell_arg(path)} 2>/dev/null "
f"| head -n {fetch_limit}"
)
result = self._exec(cmd, timeout=60)
result = self._exec(cmd_sorted, timeout=60)
all_files = [f for f in result.stdout.strip().split('\n') if f]
if not all_files:
# --sortr may have failed on older rg; retry without it.
cmd_plain = (
f"rg --files -g {self._escape_shell_arg(glob_pattern)} "
f"{self._escape_shell_arg(path)} 2>/dev/null "
f"| head -n {fetch_limit}"
)
result = self._exec(cmd_plain, timeout=60)
all_files = [f for f in result.stdout.strip().split('\n') if f]
page = all_files[offset:offset + limit]
return SearchResult(

View file

@ -70,6 +70,7 @@ Thread safety:
"""
import asyncio
import concurrent.futures
import inspect
import json
import logging
@ -1167,13 +1168,43 @@ def _ensure_mcp_loop():
def _run_on_mcp_loop(coro, timeout: float = 30):
"""Schedule a coroutine on the MCP event loop and block until done."""
"""Schedule a coroutine on the MCP event loop and block until done.
Poll in short intervals so the calling agent thread can honor user
interrupts while the MCP work is still running on the background loop.
"""
from tools.interrupt import is_interrupted
with _lock:
loop = _mcp_loop
if loop is None or not loop.is_running():
raise RuntimeError("MCP event loop is not running")
future = asyncio.run_coroutine_threadsafe(coro, loop)
return future.result(timeout=timeout)
deadline = None if timeout is None else time.monotonic() + timeout
while True:
if is_interrupted():
future.cancel()
raise InterruptedError("User sent a new message")
wait_timeout = 0.1
if deadline is not None:
remaining = deadline - time.monotonic()
if remaining <= 0:
return future.result(timeout=0)
wait_timeout = min(wait_timeout, remaining)
try:
return future.result(timeout=wait_timeout)
except concurrent.futures.TimeoutError:
continue
def _interrupted_call_result() -> str:
"""Standardized JSON error for a user-interrupted MCP tool call."""
return json.dumps({
"error": "MCP call interrupted: user sent a new message"
})
# ---------------------------------------------------------------------------
@ -1299,6 +1330,8 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
try:
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except InterruptedError:
return _interrupted_call_result()
except Exception as exc:
logger.error(
"MCP tool %s/%s call failed: %s",
@ -1342,6 +1375,8 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
try:
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except InterruptedError:
return _interrupted_call_result()
except Exception as exc:
logger.error(
"MCP %s/list_resources failed: %s", server_name, exc,
@ -1386,6 +1421,8 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
try:
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except InterruptedError:
return _interrupted_call_result()
except Exception as exc:
logger.error(
"MCP %s/read_resource failed: %s", server_name, exc,
@ -1433,6 +1470,8 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
try:
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except InterruptedError:
return _interrupted_call_result()
except Exception as exc:
logger.error(
"MCP %s/list_prompts failed: %s", server_name, exc,
@ -1488,6 +1527,8 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
try:
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
except InterruptedError:
return _interrupted_call_result()
except Exception as exc:
logger.error(
"MCP %s/get_prompt failed: %s", server_name, exc,

View file

@ -16,6 +16,7 @@ Import chain (circular-import safe):
import json
import logging
import threading
from typing import Callable, Dict, List, Optional, Set
logger = logging.getLogger(__name__)
@ -51,6 +52,49 @@ class ToolRegistry:
def __init__(self):
self._tools: Dict[str, ToolEntry] = {}
self._toolset_checks: Dict[str, Callable] = {}
# MCP dynamic refresh can mutate the registry while other threads are
# reading tool metadata, so keep mutations serialized and readers on
# stable snapshots.
self._lock = threading.RLock()
def _snapshot_state(self) -> tuple[List[ToolEntry], Dict[str, Callable]]:
"""Return a coherent snapshot of registry entries and toolset checks."""
with self._lock:
return list(self._tools.values()), dict(self._toolset_checks)
def _snapshot_entries(self) -> List[ToolEntry]:
"""Return a stable snapshot of registered tool entries."""
return self._snapshot_state()[0]
def _snapshot_toolset_checks(self) -> Dict[str, Callable]:
"""Return a stable snapshot of toolset availability checks."""
return self._snapshot_state()[1]
def _evaluate_toolset_check(self, toolset: str, check: Callable | None) -> bool:
"""Run a toolset check, treating missing or failing checks as unavailable/available."""
if not check:
return True
try:
return bool(check())
except Exception:
logger.debug("Toolset %s check raised; marking unavailable", toolset)
return False
def get_entry(self, name: str) -> Optional[ToolEntry]:
    """Look up a tool entry by *name*, returning None when unregistered."""
    with self._lock:
        entry = self._tools.get(name)
    return entry
def get_registered_toolset_names(self) -> List[str]:
    """Sorted, de-duplicated toolset names across all registered tools."""
    unique = {entry.toolset for entry in self._snapshot_entries()}
    return sorted(unique)
def get_tool_names_for_toolset(self, toolset: str) -> List[str]:
    """Sorted names of every tool registered under *toolset*."""
    matching = [
        entry.name
        for entry in self._snapshot_entries()
        if entry.toolset == toolset
    ]
    matching.sort()
    return matching
# ------------------------------------------------------------------
# Registration
@ -70,27 +114,28 @@ class ToolRegistry:
max_result_size_chars: int | float | None = None,
):
"""Register a tool. Called at module-import time by each tool file."""
existing = self._tools.get(name)
if existing and existing.toolset != toolset:
logger.warning(
"Tool name collision: '%s' (toolset '%s') is being "
"overwritten by toolset '%s'",
name, existing.toolset, toolset,
with self._lock:
existing = self._tools.get(name)
if existing and existing.toolset != toolset:
logger.warning(
"Tool name collision: '%s' (toolset '%s') is being "
"overwritten by toolset '%s'",
name, existing.toolset, toolset,
)
self._tools[name] = ToolEntry(
name=name,
toolset=toolset,
schema=schema,
handler=handler,
check_fn=check_fn,
requires_env=requires_env or [],
is_async=is_async,
description=description or schema.get("description", ""),
emoji=emoji,
max_result_size_chars=max_result_size_chars,
)
self._tools[name] = ToolEntry(
name=name,
toolset=toolset,
schema=schema,
handler=handler,
check_fn=check_fn,
requires_env=requires_env or [],
is_async=is_async,
description=description or schema.get("description", ""),
emoji=emoji,
max_result_size_chars=max_result_size_chars,
)
if check_fn and toolset not in self._toolset_checks:
self._toolset_checks[toolset] = check_fn
if check_fn and toolset not in self._toolset_checks:
self._toolset_checks[toolset] = check_fn
def deregister(self, name: str) -> None:
"""Remove a tool from the registry.
@ -99,14 +144,15 @@ class ToolRegistry:
same toolset. Used by MCP dynamic tool discovery to nuke-and-repave
when a server sends ``notifications/tools/list_changed``.
"""
entry = self._tools.pop(name, None)
if entry is None:
return
# Drop the toolset check if this was the last tool in that toolset
if entry.toolset in self._toolset_checks and not any(
e.toolset == entry.toolset for e in self._tools.values()
):
self._toolset_checks.pop(entry.toolset, None)
with self._lock:
entry = self._tools.pop(name, None)
if entry is None:
return
# Drop the toolset check if this was the last tool in that toolset
if entry.toolset in self._toolset_checks and not any(
e.toolset == entry.toolset for e in self._tools.values()
):
self._toolset_checks.pop(entry.toolset, None)
logger.debug("Deregistered tool: %s", name)
# ------------------------------------------------------------------
@ -121,8 +167,9 @@ class ToolRegistry:
"""
result = []
check_results: Dict[Callable, bool] = {}
entries_by_name = {entry.name: entry for entry in self._snapshot_entries()}
for name in sorted(tool_names):
entry = self._tools.get(name)
entry = entries_by_name.get(name)
if not entry:
continue
if entry.check_fn:
@ -153,7 +200,7 @@ class ToolRegistry:
* All exceptions are caught and returned as ``{"error": "..."}``
for consistent error format.
"""
entry = self._tools.get(name)
entry = self.get_entry(name)
if not entry:
return json.dumps({"error": f"Unknown tool: {name}"})
try:
@ -171,7 +218,7 @@ class ToolRegistry:
def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float:
"""Return per-tool max result size, or *default* (or global default)."""
entry = self._tools.get(name)
entry = self.get_entry(name)
if entry and entry.max_result_size_chars is not None:
return entry.max_result_size_chars
if default is not None:
@ -181,7 +228,7 @@ class ToolRegistry:
def get_all_tool_names(self) -> List[str]:
"""Return sorted list of all registered tool names."""
return sorted(self._tools.keys())
return sorted(entry.name for entry in self._snapshot_entries())
def get_schema(self, name: str) -> Optional[dict]:
"""Return a tool's raw schema dict, bypassing check_fn filtering.
@ -189,22 +236,22 @@ class ToolRegistry:
Useful for token estimation and introspection where availability
doesn't matter — only the schema content does.
"""
entry = self._tools.get(name)
entry = self.get_entry(name)
return entry.schema if entry else None
def get_toolset_for_tool(self, name: str) -> Optional[str]:
"""Return the toolset a tool belongs to, or None."""
entry = self._tools.get(name)
entry = self.get_entry(name)
return entry.toolset if entry else None
def get_emoji(self, name: str, default: str = "") -> str:
"""Return the emoji for a tool, or *default* if unset."""
entry = self._tools.get(name)
entry = self.get_entry(name)
return (entry.emoji if entry and entry.emoji else default)
def get_tool_to_toolset_map(self) -> Dict[str, str]:
"""Return ``{tool_name: toolset_name}`` for every registered tool."""
return {name: e.toolset for name, e in self._tools.items()}
return {entry.name: entry.toolset for entry in self._snapshot_entries()}
def is_toolset_available(self, toolset: str) -> bool:
"""Check if a toolset's requirements are met.
@ -212,28 +259,30 @@ class ToolRegistry:
Returns False (rather than crashing) when the check function raises
an unexpected exception (e.g. network error, missing import, bad config).
"""
check = self._toolset_checks.get(toolset)
if not check:
return True
try:
return bool(check())
except Exception:
logger.debug("Toolset %s check raised; marking unavailable", toolset)
return False
with self._lock:
check = self._toolset_checks.get(toolset)
return self._evaluate_toolset_check(toolset, check)
def check_toolset_requirements(self) -> Dict[str, bool]:
"""Return ``{toolset: available_bool}`` for every toolset."""
toolsets = set(e.toolset for e in self._tools.values())
return {ts: self.is_toolset_available(ts) for ts in sorted(toolsets)}
entries, toolset_checks = self._snapshot_state()
toolsets = sorted({entry.toolset for entry in entries})
return {
toolset: self._evaluate_toolset_check(toolset, toolset_checks.get(toolset))
for toolset in toolsets
}
def get_available_toolsets(self) -> Dict[str, dict]:
"""Return toolset metadata for UI display."""
toolsets: Dict[str, dict] = {}
for entry in self._tools.values():
entries, toolset_checks = self._snapshot_state()
for entry in entries:
ts = entry.toolset
if ts not in toolsets:
toolsets[ts] = {
"available": self.is_toolset_available(ts),
"available": self._evaluate_toolset_check(
ts, toolset_checks.get(ts)
),
"tools": [],
"description": "",
"requirements": [],
@ -248,13 +297,14 @@ class ToolRegistry:
def get_toolset_requirements(self) -> Dict[str, dict]:
"""Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
result: Dict[str, dict] = {}
for entry in self._tools.values():
entries, toolset_checks = self._snapshot_state()
for entry in entries:
ts = entry.toolset
if ts not in result:
result[ts] = {
"name": ts,
"env_vars": [],
"check_fn": self._toolset_checks.get(ts),
"check_fn": toolset_checks.get(ts),
"setup_url": None,
"tools": [],
}
@ -270,18 +320,19 @@ class ToolRegistry:
available = []
unavailable = []
seen = set()
for entry in self._tools.values():
entries, toolset_checks = self._snapshot_state()
for entry in entries:
ts = entry.toolset
if ts in seen:
continue
seen.add(ts)
if self.is_toolset_available(ts):
if self._evaluate_toolset_check(ts, toolset_checks.get(ts)):
available.append(ts)
else:
unavailable.append({
"name": ts,
"env_vars": entry.requires_env,
"tools": [e.name for e in self._tools.values() if e.toolset == ts],
"tools": [e.name for e in entries if e.toolset == ts],
})
return available, unavailable

View file

@ -152,6 +152,7 @@ def _handle_send(args):
"whatsapp": Platform.WHATSAPP,
"signal": Platform.SIGNAL,
"bluebubbles": Platform.BLUEBUBBLES,
"qqbot": Platform.QQBOT,
"matrix": Platform.MATRIX,
"mattermost": Platform.MATTERMOST,
"homeassistant": Platform.HOMEASSISTANT,
@ -426,6 +427,8 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
result = await _send_wecom(pconfig.extra, chat_id, chunk)
elif platform == Platform.BLUEBUBBLES:
result = await _send_bluebubbles(pconfig.extra, chat_id, chunk)
elif platform == Platform.QQBOT:
result = await _send_qqbot(pconfig, chat_id, chunk)
else:
result = {"error": f"Direct sending not yet implemented for {platform.value}"}
@ -1038,6 +1041,58 @@ def _check_send_message():
return False
async def _send_qqbot(pconfig, chat_id, message):
    """Send via QQBot using the REST API directly (no WebSocket needed).

    Uses the QQ Bot Open Platform REST endpoints to get an access token
    and post a message. Works for guild channels without requiring
    a running gateway adapter.

    Args:
        pconfig: Platform config; ``token``/``extra`` supply credentials,
            falling back to ``QQ_APP_ID`` / ``QQ_CLIENT_SECRET`` env vars.
        chat_id: Target channel id (interpolated into the messages URL).
        message: Text to send; truncated to 4000 chars before posting.

    Returns:
        ``{"success": True, ...}`` on success, otherwise an ``_error`` dict.
    """
    try:
        import httpx
    except ImportError:
        return _error("QQBot direct send requires httpx. Run: pip install httpx")

    extra = pconfig.extra or {}
    appid = extra.get("app_id") or os.getenv("QQ_APP_ID", "")
    secret = (pconfig.token or extra.get("client_secret")
              or os.getenv("QQ_CLIENT_SECRET", ""))
    if not appid or not secret:
        return _error("QQBot: QQ_APP_ID / QQ_CLIENT_SECRET not configured.")

    try:
        async with httpx.AsyncClient(timeout=15) as client:
            # Step 1: exchange app credentials for a short-lived access token.
            token_resp = await client.post(
                "https://bots.qq.com/app/getAppAccessToken",
                json={"appId": str(appid), "clientSecret": str(secret)},
            )
            if token_resp.status_code != 200:
                return _error(f"QQBot token request failed: {token_resp.status_code}")
            token_data = token_resp.json()
            access_token = token_data.get("access_token")
            if not access_token:
                # Fixed message (was an f-string with no placeholders).
                return _error("QQBot: no access_token in response")

            # Step 2: post the message to the channel messages endpoint.
            headers = {
                "Authorization": f"QQBotAccessToken {access_token}",
                "Content-Type": "application/json",
            }
            url = f"https://api.sgroup.qq.com/channels/{chat_id}/messages"
            # msg_type 0 == plain text; content capped at 4000 chars.
            payload = {"content": message[:4000], "msg_type": 0}
            resp = await client.post(url, json=payload, headers=headers)
            if resp.status_code in (200, 201):
                data = resp.json()
                return {"success": True, "platform": "qqbot", "chat_id": chat_id,
                        "message_id": data.get("id")}
            return _error(f"QQBot send failed: {resp.status_code} {resp.text}")
    except Exception as e:
        return _error(f"QQBot send failed: {e}")
# --- Registry ---
from tools.registry import registry, tool_error

View file

@ -245,6 +245,9 @@ def _get_required_environment_variables(
if isinstance(required_for, str) and required_for.strip():
normalized["required_for"] = required_for.strip()
if entry.get("optional"):
normalized["optional"] = True
seen.add(env_name)
required.append(normalized)
@ -378,6 +381,8 @@ def _remaining_required_environment_names(
remaining = []
for entry in required_env_vars:
name = entry["name"]
if entry.get("optional"):
continue
if name in missing_names or not _is_env_var_persisted(name, env_snapshot):
remaining.append(name)
return remaining
@ -1042,7 +1047,8 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
missing_required_env_vars = [
e
for e in required_env_vars
if not _is_env_var_persisted(e["name"], env_snapshot)
if not e.get("optional")
and not _is_env_var_persisted(e["name"], env_snapshot)
]
capture_result = _capture_required_environment_variables(
skill_name,