diff --git a/acp_adapter/events.py b/acp_adapter/events.py index 08da40a68..1257f902e 100644 --- a/acp_adapter/events.py +++ b/acp_adapter/events.py @@ -49,6 +49,7 @@ def make_tool_progress_cb( session_id: str, loop: asyncio.AbstractEventLoop, tool_call_ids: Dict[str, Deque[str]], + tool_call_meta: Dict[str, Dict[str, Any]], ) -> Callable: """Create a ``tool_progress_callback`` for AIAgent. @@ -84,6 +85,16 @@ def make_tool_progress_cb( tool_call_ids[name] = queue queue.append(tc_id) + snapshot = None + if name in {"write_file", "patch", "skill_manage"}: + try: + from agent.display import capture_local_edit_snapshot + + snapshot = capture_local_edit_snapshot(name, args) + except Exception: + logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True) + tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot} + update = build_tool_start(tc_id, name, args) _send_update(conn, session_id, loop, update) @@ -119,6 +130,7 @@ def make_step_cb( session_id: str, loop: asyncio.AbstractEventLoop, tool_call_ids: Dict[str, Deque[str]], + tool_call_meta: Dict[str, Dict[str, Any]], ) -> Callable: """Create a ``step_callback`` for AIAgent. @@ -132,10 +144,12 @@ def make_step_cb( for tool_info in prev_tools: tool_name = None result = None + function_args = None if isinstance(tool_info, dict): tool_name = tool_info.get("name") or tool_info.get("function_name") result = tool_info.get("result") or tool_info.get("output") + function_args = tool_info.get("arguments") or tool_info.get("args") elif isinstance(tool_info, str): tool_name = tool_info @@ -145,8 +159,13 @@ def make_step_cb( tool_call_ids[tool_name] = queue if tool_name and queue: tc_id = queue.popleft() + meta = tool_call_meta.pop(tc_id, {}) update = build_tool_complete( - tc_id, tool_name, result=str(result) if result is not None else None + tc_id, + tool_name, + result=str(result) if result is not None else None, + function_args=function_args or meta.get("args"), + snapshot=meta.get("snapshot"), ) _send_update(conn, session_id, loop, update) if not queue: diff --git a/acp_adapter/server.py b/acp_adapter/server.py index 29f9a10e8..4685a68a8 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -26,6 +26,7 @@ from acp.schema import ( McpServerHttp, McpServerSse, McpServerStdio, + ModelInfo, NewSessionResponse, PromptResponse, ResumeSessionResponse, @@ -36,6 +37,7 @@ from acp.schema import ( SessionCapabilities, SessionForkCapabilities, SessionListCapabilities, + SessionModelState, SessionResumeCapabilities, SessionInfo, TextContentBlock, @@ -147,6 +149,98 @@ class HermesACPAgent(acp.Agent): self._conn = conn logger.info("ACP client connected") + @staticmethod + def _encode_model_choice(provider: str | None, model: str | None) -> str: + """Encode a model selection so ACP clients can keep provider context.""" + raw_model = str(model or "").strip() + if not raw_model: + return "" + raw_provider = str(provider or "").strip().lower() + if not raw_provider: + return raw_model + return f"{raw_provider}:{raw_model}" + + def _build_model_state(self, state: SessionState) -> SessionModelState | None: + """Return the ACP model selector payload for editors like Zed.""" + model = str(state.model or getattr(state.agent, "model", "") or "").strip() + provider = getattr(state.agent, "provider", None) or detect_provider() or "openrouter" + + try: + from hermes_cli.models import curated_models_for_provider, normalize_provider, provider_label + + normalized_provider = normalize_provider(provider) + provider_name = provider_label(normalized_provider) + available_models: list[ModelInfo] = [] + seen_ids: set[str] = set() + + for model_id, description in curated_models_for_provider(normalized_provider): + rendered_model = str(model_id or "").strip() + if not rendered_model: + continue + choice_id = self._encode_model_choice(normalized_provider, rendered_model) + if choice_id in seen_ids: + continue + desc_parts = [f"Provider: {provider_name}"] + if description: + desc_parts.append(str(description).strip()) + if rendered_model == model: + desc_parts.append("current") + available_models.append( + ModelInfo( + model_id=choice_id, + name=rendered_model, + description=" • ".join(part for part in desc_parts if part), + ) + ) + seen_ids.add(choice_id) + + current_model_id = self._encode_model_choice(normalized_provider, model) + if current_model_id and current_model_id not in seen_ids: + available_models.insert( + 0, + ModelInfo( + model_id=current_model_id, + name=model, + description=f"Provider: {provider_name} • current", + ), + ) + + if available_models: + return SessionModelState( + available_models=available_models, + current_model_id=current_model_id or available_models[0].model_id, + ) + except Exception: + logger.debug("Could not build ACP model state", exc_info=True) + + if not model: + return None + + fallback_choice = self._encode_model_choice(provider, model) + return SessionModelState( + available_models=[ModelInfo(model_id=fallback_choice, name=model)], + current_model_id=fallback_choice, + ) + + @staticmethod + def _resolve_model_selection(raw_model: str, current_provider: str) -> tuple[str, str]: + """Resolve ``provider:model`` input into the provider and normalized model id.""" + target_provider = current_provider + new_model = raw_model.strip() + + try: + from hermes_cli.models import detect_provider_for_model, parse_model_input + + target_provider, new_model = parse_model_input(new_model, current_provider) + if target_provider == current_provider: + detected = detect_provider_for_model(new_model, current_provider) + if detected: + target_provider, new_model = detected + except Exception: + logger.debug("Provider detection failed, using model as-is", exc_info=True) + + return target_provider, new_model + async def _register_session_mcp_servers( self, state: SessionState, @@ -273,7 +367,10 @@ class HermesACPAgent(acp.Agent): await self._register_session_mcp_servers(state, mcp_servers) logger.info("New session %s (cwd=%s)", state.session_id, cwd) self._schedule_available_commands_update(state.session_id) - return NewSessionResponse(session_id=state.session_id) + return NewSessionResponse( + session_id=state.session_id, + models=self._build_model_state(state), + ) async def load_session( self, @@ -289,7 +386,7 @@ class HermesACPAgent(acp.Agent): await self._register_session_mcp_servers(state, mcp_servers) logger.info("Loaded session %s", session_id) self._schedule_available_commands_update(session_id) - return LoadSessionResponse() + return LoadSessionResponse(models=self._build_model_state(state)) async def resume_session( self, @@ -305,7 +402,7 @@ class HermesACPAgent(acp.Agent): await self._register_session_mcp_servers(state, mcp_servers) logger.info("Resumed session %s", state.session_id) self._schedule_available_commands_update(state.session_id) - return ResumeSessionResponse() + return ResumeSessionResponse(models=self._build_model_state(state)) async def cancel(self, session_id: str, **kwargs: Any) -> None: state = self.session_manager.get_session(session_id) @@ -340,11 +437,20 @@ class HermesACPAgent(acp.Agent): cwd: str | None = None, **kwargs: Any, ) -> ListSessionsResponse: - infos = self.session_manager.list_sessions() - sessions = [ - SessionInfo(session_id=s["session_id"], cwd=s["cwd"]) - for s in infos - ] + infos = self.session_manager.list_sessions(cwd=cwd) + sessions = [] + for s in infos: + updated_at = s.get("updated_at") + if updated_at is not None and not isinstance(updated_at, str): + updated_at = str(updated_at) + sessions.append( + SessionInfo( + session_id=s["session_id"], + cwd=s["cwd"], + title=s.get("title"), + updated_at=updated_at, + ) + ) return ListSessionsResponse(sessions=sessions) # ---- Prompt (core) ------------------------------------------------------ @@ -389,12 +495,13 @@ class HermesACPAgent(acp.Agent): state.cancel_event.clear() tool_call_ids: dict[str, Deque[str]] = defaultdict(deque) + tool_call_meta: dict[str, dict[str, Any]] = {} previous_approval_cb = None if conn: - tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids) + tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta) thinking_cb = make_thinking_cb(conn, session_id, loop) - step_cb = make_step_cb(conn, session_id, loop, tool_call_ids) + step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta) message_cb = make_message_cb(conn, session_id, loop) approval_cb = make_approval_callback(conn.request_permission, loop, session_id) else: @@ -449,6 +556,19 @@ class HermesACPAgent(acp.Agent): self.session_manager.save_session(session_id) final_response = result.get("final_response", "") + if final_response: + try: + from agent.title_generator import maybe_auto_title + + maybe_auto_title( + self.session_manager._get_db(), + session_id, + user_text, + final_response, + state.history, + ) + except Exception: + logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True) if final_response and conn: update = acp.update_agent_message_text(final_response) await conn.session_update(session_id, update) @@ -556,27 +676,15 @@ class HermesACPAgent(acp.Agent): provider = getattr(state.agent, "provider", None) or "auto" return f"Current model: {model}\nProvider: {provider}" - new_model = args.strip() - target_provider = None current_provider = getattr(state.agent, "provider", None) or "openrouter" - - # Auto-detect provider for the requested model - try: - from hermes_cli.models import parse_model_input, detect_provider_for_model - target_provider, new_model = parse_model_input(new_model, current_provider) - if target_provider == current_provider: - detected = detect_provider_for_model(new_model, current_provider) - if detected: - target_provider, new_model = detected - except Exception: - logger.debug("Provider detection failed, using model as-is", exc_info=True) + target_provider, new_model = self._resolve_model_selection(args, current_provider) state.model = new_model state.agent = self.session_manager._make_agent( session_id=state.session_id, cwd=state.cwd, model=new_model, - requested_provider=target_provider or current_provider, + requested_provider=target_provider, ) self.session_manager.save_session(state.session_id) provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider @@ -678,20 +786,30 @@ class HermesACPAgent(acp.Agent): """Switch the model for a session (called by ACP protocol).""" state = self.session_manager.get_session(session_id) if state: - state.model = model_id current_provider = getattr(state.agent, "provider", None) - current_base_url = getattr(state.agent, "base_url", None) - current_api_mode = getattr(state.agent, "api_mode", None) + requested_provider, resolved_model = self._resolve_model_selection( + model_id, + current_provider or "openrouter", + ) + state.model = resolved_model + provider_changed = bool(current_provider and requested_provider != current_provider) + current_base_url = None if provider_changed else getattr(state.agent, "base_url", None) + current_api_mode = None if provider_changed else getattr(state.agent, "api_mode", None) state.agent = self.session_manager._make_agent( session_id=session_id, cwd=state.cwd, - model=model_id, - requested_provider=current_provider, + model=resolved_model, + requested_provider=requested_provider, base_url=current_base_url, api_mode=current_api_mode, ) self.session_manager.save_session(session_id) - logger.info("Session %s: model switched to %s", session_id, model_id) + logger.info( + "Session %s: model switched to %s via provider %s", + session_id, + resolved_model, + requested_provider, + ) return SetSessionModelResponse() logger.warning("Session %s: model switch requested for missing session", session_id) return None diff --git a/acp_adapter/session.py b/acp_adapter/session.py index 4bb823987..3f5f78f9a 100644 --- a/acp_adapter/session.py +++ b/acp_adapter/session.py @@ -13,8 +13,12 @@ from hermes_constants import get_hermes_home import copy import json import logging +import os +import re import sys +import time import uuid +from datetime import datetime, timezone from dataclasses import dataclass, field from threading import Lock from typing import Any, Dict, List, Optional @@ -22,6 +26,64 @@ from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) +def _normalize_cwd_for_compare(cwd: str | None) -> str: + raw = str(cwd or ".").strip() + if not raw: + raw = "." + expanded = os.path.expanduser(raw) + + # Normalize Windows drive paths into the equivalent WSL mount form so + # ACP history filters match the same workspace across Windows and WSL. + match = re.match(r"^([A-Za-z]):[\\/](.*)$", expanded) + if match: + drive = match.group(1).lower() + tail = match.group(2).replace("\\", "/") + expanded = f"/mnt/{drive}/{tail}" + elif re.match(r"^/mnt/[A-Za-z]/", expanded): + expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}" + + return os.path.normpath(expanded) + + +def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str: + explicit = str(title or "").strip() + if explicit: + return explicit + preview_text = str(preview or "").strip() + if preview_text: + return preview_text + leaf = os.path.basename(str(cwd or "").rstrip("/\\")) + return leaf or "New thread" + + +def _format_updated_at(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str) and value.strip(): + return value + try: + return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat() + except Exception: + return None + + +def _updated_at_sort_key(value: Any) -> float: + if value is None: + return float("-inf") + if isinstance(value, (int, float)): + return float(value) + raw = str(value).strip() + if not raw: + return float("-inf") + try: + return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp() + except Exception: + try: + return float(raw) + except Exception: + return float("-inf") + + def _acp_stderr_print(*args, **kwargs) -> None: """Best-effort human-readable output sink for ACP stdio sessions. @@ -162,47 +224,78 @@ class SessionManager: logger.info("Forked ACP session %s -> %s", session_id, new_id) return state - def list_sessions(self) -> List[Dict[str, Any]]: + def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]: """Return lightweight info dicts for all sessions (memory + database).""" + normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None + db = self._get_db() + persisted_rows: dict[str, dict[str, Any]] = {} + + if db is not None: + try: + for row in db.list_sessions_rich(source="acp", limit=1000): + persisted_rows[str(row["id"])] = dict(row) + except Exception: + logger.debug("Failed to load ACP sessions from DB", exc_info=True) + # Collect in-memory sessions first. with self._lock: seen_ids = set(self._sessions.keys()) - results = [ - { - "session_id": s.session_id, - "cwd": s.cwd, - "model": s.model, - "history_len": len(s.history), - } - for s in self._sessions.values() - ] + results = [] + for s in self._sessions.values(): + history_len = len(s.history) + if history_len <= 0: + continue + if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd: + continue + persisted = persisted_rows.get(s.session_id, {}) + preview = next( + ( + str(msg.get("content") or "").strip() + for msg in s.history + if msg.get("role") == "user" and str(msg.get("content") or "").strip() + ), + persisted.get("preview") or "", + ) + results.append( + { + "session_id": s.session_id, + "cwd": s.cwd, + "model": s.model, + "history_len": history_len, + "title": _build_session_title(persisted.get("title"), preview, s.cwd), + "updated_at": _format_updated_at( + persisted.get("last_active") or persisted.get("started_at") or time.time() + ), + } + ) # Merge any persisted sessions not currently in memory. - db = self._get_db() - if db is not None: - try: - rows = db.search_sessions(source="acp", limit=1000) - for row in rows: - sid = row["id"] - if sid in seen_ids: - continue - # Extract cwd from model_config JSON. - cwd = "." - mc = row.get("model_config") - if mc: - try: - cwd = json.loads(mc).get("cwd", ".") - except (json.JSONDecodeError, TypeError): - pass - results.append({ - "session_id": sid, - "cwd": cwd, - "model": row.get("model") or "", - "history_len": row.get("message_count") or 0, - }) - except Exception: - logger.debug("Failed to list ACP sessions from DB", exc_info=True) + for sid, row in persisted_rows.items(): + if sid in seen_ids: + continue + message_count = int(row.get("message_count") or 0) + if message_count <= 0: + continue + # Extract cwd from model_config JSON. + session_cwd = "." + mc = row.get("model_config") + if mc: + try: + session_cwd = json.loads(mc).get("cwd", ".") + except (json.JSONDecodeError, TypeError): + pass + if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd: + continue + results.append({ + "session_id": sid, + "cwd": session_cwd, + "model": row.get("model") or "", + "history_len": message_count, + "title": _build_session_title(row.get("title"), row.get("preview"), session_cwd), + "updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")), + }) + results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True) return results def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]: diff --git a/acp_adapter/tools.py b/acp_adapter/tools.py index 52313220b..067652106 100644 --- a/acp_adapter/tools.py +++ b/acp_adapter/tools.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import uuid from typing import Any, Dict, List, Optional @@ -96,6 +97,170 @@ def build_tool_title(tool_name: str, args: Dict[str, Any]) -> str: return tool_name +def _build_patch_mode_content(patch_text: str) -> List[Any]: + """Parse V4A patch mode input into ACP diff blocks when possible.""" + if not patch_text: + return [acp.tool_content(acp.text_block(""))] + + try: + from tools.patch_parser import OperationType, parse_v4a_patch + + operations, error = parse_v4a_patch(patch_text) + if error or not operations: + return [acp.tool_content(acp.text_block(patch_text))] + + content: List[Any] = [] + for op in operations: + if op.operation == OperationType.UPDATE: + old_chunks: list[str] = [] + new_chunks: list[str] = [] + for hunk in op.hunks: + old_lines = [line.content for line in hunk.lines if line.prefix in (" ", "-")] + new_lines = [line.content for line in hunk.lines if line.prefix in (" ", "+")] + if old_lines or new_lines: + old_chunks.append("\n".join(old_lines)) + new_chunks.append("\n".join(new_lines)) + + old_text = "\n...\n".join(chunk for chunk in old_chunks if chunk) + new_text = "\n...\n".join(chunk for chunk in new_chunks if chunk) + if old_text or new_text: + content.append( + acp.tool_diff_content( + path=op.file_path, + old_text=old_text or None, + new_text=new_text or "", + ) + ) + continue + + if op.operation == OperationType.ADD: + added_lines = [line.content for hunk in op.hunks for line in hunk.lines if line.prefix == "+"] + content.append( + acp.tool_diff_content( + path=op.file_path, + new_text="\n".join(added_lines), + ) + ) + continue + + if op.operation == OperationType.DELETE: + content.append( + acp.tool_diff_content( + path=op.file_path, + old_text=f"Delete file: {op.file_path}", + new_text="", + ) + ) + continue + + if op.operation == OperationType.MOVE: + content.append( + acp.tool_content(acp.text_block(f"Move file: {op.file_path} -> {op.new_path}")) + ) + + return content or [acp.tool_content(acp.text_block(patch_text))] + except Exception: + return [acp.tool_content(acp.text_block(patch_text))] + + +def _strip_diff_prefix(path: str) -> str: + raw = str(path or "").strip() + if raw.startswith(("a/", "b/")): + return raw[2:] + return raw + + +def _parse_unified_diff_content(diff_text: str) -> List[Any]: + """Convert unified diff text into ACP diff content blocks.""" + if not diff_text: + return [] + + content: List[Any] = [] + current_old_path: Optional[str] = None + current_new_path: Optional[str] = None + old_lines: list[str] = [] + new_lines: list[str] = [] + + def _flush() -> None: + nonlocal current_old_path, current_new_path, old_lines, new_lines + if current_old_path is None and current_new_path is None: + return + path = current_new_path if current_new_path and current_new_path != "/dev/null" else current_old_path + if not path or path == "/dev/null": + current_old_path = None + current_new_path = None + old_lines = [] + new_lines = [] + return + content.append( + acp.tool_diff_content( + path=_strip_diff_prefix(path), + old_text="\n".join(old_lines) if old_lines else None, + new_text="\n".join(new_lines), + ) + ) + current_old_path = None + current_new_path = None + old_lines = [] + new_lines = [] + + for line in diff_text.splitlines(): + if line.startswith("--- "): + _flush() + current_old_path = line[4:].strip() + continue + if line.startswith("+++ "): + current_new_path = line[4:].strip() + continue + if line.startswith("@@"): + continue + if current_old_path is None and current_new_path is None: + continue + if line.startswith("+"): + new_lines.append(line[1:]) + elif line.startswith("-"): + old_lines.append(line[1:]) + elif line.startswith(" "): + shared = line[1:] + old_lines.append(shared) + new_lines.append(shared) + + _flush() + return content + + +def _build_tool_complete_content( + tool_name: str, + result: Optional[str], + *, + function_args: Optional[Dict[str, Any]] = None, + snapshot: Any = None, +) -> List[Any]: + """Build structured ACP completion content, falling back to plain text.""" + display_result = result or "" + if len(display_result) > 5000: + display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)" + + if tool_name in {"write_file", "patch", "skill_manage"}: + try: + from agent.display import extract_edit_diff + + diff_text = extract_edit_diff( + tool_name, + result, + function_args=function_args, + snapshot=snapshot, + ) + if isinstance(diff_text, str) and diff_text.strip(): + diff_content = _parse_unified_diff_content(diff_text) + if diff_content: + return diff_content + except Exception: + pass + + return [acp.tool_content(acp.text_block(display_result))] + + # --------------------------------------------------------------------------- # Build ACP content objects for tool-call events # --------------------------------------------------------------------------- @@ -119,9 +284,8 @@ def build_tool_start( new = arguments.get("new_string", "") content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)] else: - # Patch mode — show the patch content as text patch_text = arguments.get("patch", "") - content = [acp.tool_content(acp.text_block(patch_text))] + content = _build_patch_mode_content(patch_text) return acp.start_tool_call( tool_call_id, title, kind=kind, content=content, locations=locations, raw_input=arguments, @@ -178,16 +342,17 @@ def build_tool_complete( tool_call_id: str, tool_name: str, result: Optional[str] = None, + function_args: Optional[Dict[str, Any]] = None, + snapshot: Any = None, ) -> ToolCallProgress: """Create a ToolCallUpdate (progress) event for a completed tool call.""" kind = get_tool_kind(tool_name) - - # Truncate very large results for the UI - display_result = result or "" - if len(display_result) > 5000: - display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)" - - content = [acp.tool_content(acp.text_block(display_result))] + content = _build_tool_complete_content( + tool_name, + result, + function_args=function_args, + snapshot=snapshot, + ) return acp.update_tool_call( tool_call_id, kind=kind, diff --git a/pyproject.toml b/pyproject.toml index 7571e51d5..b73bef937 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,8 +76,8 @@ termux = [ "hermes-agent[honcho]", "hermes-agent[acp]", ] -dingtalk = ["dingtalk-stream>=0.1.0,<1"] -feishu = ["lark-oapi>=1.5.3,<2"] +dingtalk = ["dingtalk-stream>=0.1.0,<1", "qrcode>=7.0,<8"] +feishu = ["lark-oapi>=1.5.3,<2", "qrcode>=7.0,<8"] web = ["fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1"] rl = [ "atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30", diff --git a/run_agent.py b/run_agent.py index 49525b8bd..64572001b 100644 --- a/run_agent.py +++ b/run_agent.py @@ -353,12 +353,50 @@ def _sanitize_surrogates(text: str) -> str: return text +def _sanitize_structure_surrogates(payload: Any) -> bool: + """Replace surrogate code points in nested dict/list payloads in-place. + + Mirror of ``_sanitize_structure_non_ascii`` but for surrogate recovery. + Used to scrub nested structured fields (e.g. ``reasoning_details`` — an + array of dicts with ``summary``/``text`` strings) that flat per-field + checks don't reach. Returns True if any surrogates were replaced. + """ + found = False + + def _walk(node): + nonlocal found + if isinstance(node, dict): + for key, value in node.items(): + if isinstance(value, str): + if _SURROGATE_RE.search(value): + node[key] = _SURROGATE_RE.sub('\ufffd', value) + found = True + elif isinstance(value, (dict, list)): + _walk(value) + elif isinstance(node, list): + for idx, value in enumerate(node): + if isinstance(value, str): + if _SURROGATE_RE.search(value): + node[idx] = _SURROGATE_RE.sub('\ufffd', value) + found = True + elif isinstance(value, (dict, list)): + _walk(value) + + _walk(payload) + return found + + def _sanitize_messages_surrogates(messages: list) -> bool: """Sanitize surrogate characters from all string content in a messages list. Walks message dicts in-place. Returns True if any surrogates were found - and replaced, False otherwise. Covers content/text, name, and tool call - metadata/arguments so retries don't fail on a non-content field. + and replaced, False otherwise. Covers content/text, name, tool call + metadata/arguments, AND any additional string or nested structured fields + (``reasoning``, ``reasoning_content``, ``reasoning_details``, etc.) so + retries don't fail on a non-content field. Byte-level reasoning models + (xiaomi/mimo, kimi, glm) can emit lone surrogates in reasoning output + that flow through to ``api_messages["reasoning_content"]`` on the next + turn and crash json.dumps inside the OpenAI SDK. """ found = False for msg in messages: @@ -398,6 +436,21 @@ def _sanitize_messages_surrogates(messages: list) -> bool: if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args): fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args) found = True + # Walk any additional string / nested fields (reasoning, + # reasoning_content, reasoning_details, etc.) — surrogates from + # byte-level reasoning models (xiaomi/mimo, kimi, glm) can lurk + # in these fields and aren't covered by the per-field checks above. + # Matches _sanitize_messages_non_ascii's coverage (PR #10537). + for key, value in msg.items(): + if key in {"content", "name", "tool_calls", "role"}: + continue + if isinstance(value, str): + if _SURROGATE_RE.search(value): + msg[key] = _SURROGATE_RE.sub('\ufffd', value) + found = True + elif isinstance(value, (dict, list)): + if _sanitize_structure_surrogates(value): + found = True return found @@ -8689,6 +8742,7 @@ class AIAgent: { "name": tc["function"]["name"], "result": _results_by_id.get(tc.get("id")), + "arguments": tc["function"].get("arguments"), } for tc in _m["tool_calls"] if isinstance(tc, dict) @@ -9303,8 +9357,7 @@ class AIAgent: "and had none left for the actual response.\n\n" "To fix this:\n" "→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n" - "→ Increase the output token limit: " - "set `model.max_tokens` in config.yaml" + "→ Or switch to a larger/non-reasoning model with `/model`" ) self._cleanup_task_resources(effective_task_id) self._persist_session(messages, conversation_history) @@ -9571,13 +9624,51 @@ class AIAgent: if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2: _err_str = str(api_error).lower() _is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str + # Detect surrogate errors — utf-8 codec refusing to + # encode U+D800..U+DFFF. The error text is: + # "'utf-8' codec can't encode characters in position + # N-M: surrogates not allowed" + _is_surrogate_error = ( + "surrogate" in _err_str + or ("'utf-8'" in _err_str and not _is_ascii_codec) + ) + # Sanitize surrogates from both the canonical `messages` + # list AND `api_messages` (the API-copy, which may carry + # `reasoning_content`/`reasoning_details` transformed + # from `reasoning` — fields the canonical list doesn't + # have directly). Also clean `api_kwargs` if built and + # `prefill_messages` if present. Mirrors the ASCII + # codec recovery below. _surrogates_found = _sanitize_messages_surrogates(messages) - if _surrogates_found: + if isinstance(api_messages, list): + if _sanitize_messages_surrogates(api_messages): + _surrogates_found = True + if isinstance(api_kwargs, dict): + if _sanitize_structure_surrogates(api_kwargs): + _surrogates_found = True + if isinstance(getattr(self, "prefill_messages", None), list): + if _sanitize_messages_surrogates(self.prefill_messages): + _surrogates_found = True + # Gate the retry on the error type, not on whether we + # found anything — _force_ascii_payload / the extended + # surrogate walker above cover all known paths, but a + # new transformed field could still slip through. If + # the error was a surrogate encode failure, always let + # the retry run; the proactive sanitizer at line ~8781 + # runs again on the next iteration. Bounded by + # _unicode_sanitization_passes < 2 (outer guard). + if _surrogates_found or _is_surrogate_error: self._unicode_sanitization_passes += 1 - self._vprint( - f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...", - force=True, - ) + if _surrogates_found: + self._vprint( + f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...", + force=True, + ) + else: + self._vprint( + f"{self.log_prefix}⚠️ Surrogate encoding error — retrying after full-payload sanitization...", + force=True, + ) continue if _is_ascii_codec: self._force_ascii_payload = True diff --git a/scripts/release.py b/scripts/release.py index 55d9f8d1e..028f75ba6 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -103,6 +103,7 @@ AUTHOR_MAP = { "dangtc94@gmail.com": "dieutx", "jaisehgal11299@gmail.com": "jaisup", "percydikec@gmail.com": "PercyDikec", + "noonou7@gmail.com": "HenkDz", "dean.kerr@gmail.com": "deankerr", "socrates1024@gmail.com": "socrates1024", "satelerd@gmail.com": "satelerd", diff --git a/tests/acp/test_events.py b/tests/acp/test_events.py index bfb82ba0d..c9f91a181 100644 --- a/tests/acp/test_events.py +++ b/tests/acp/test_events.py @@ -42,9 +42,10 @@ class TestToolProgressCallback: def test_emits_tool_call_start(self, mock_conn, event_loop_fixture): """Tool progress should emit a ToolCallStart update.""" tool_call_ids = {} + tool_call_meta = {} loop = event_loop_fixture - cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids) + cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta) # Run callback in the event loop context with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: @@ -66,9 +67,10 @@ class TestToolProgressCallback: def test_handles_string_args(self, mock_conn, event_loop_fixture): """If args is a JSON string, it should be parsed.""" tool_call_ids = {} + tool_call_meta = {} loop = event_loop_fixture - cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids) + cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta) with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: future = MagicMock(spec=Future) @@ -82,9 +84,10 @@ class TestToolProgressCallback: def test_handles_non_dict_args(self, mock_conn, event_loop_fixture): """If args is not a dict, it should be wrapped.""" tool_call_ids = {} + tool_call_meta = {} loop = event_loop_fixture - cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids) + cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta) with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: future = MagicMock(spec=Future) @@ -98,10 +101,11 @@ class TestToolProgressCallback: def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture): """Multiple same-name tool calls should be tracked independently in order.""" tool_call_ids = {} + tool_call_meta = {} loop = event_loop_fixture - progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids) - step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) + progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta) + step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta) with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: future = MagicMock(spec=Future) @@ -163,7 +167,7 @@ class TestStepCallback: tool_call_ids = {"terminal": "tc-abc123"} loop = event_loop_fixture - cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) + cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {}) with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: future = MagicMock(spec=Future) @@ -181,7 +185,7 @@ class TestStepCallback: tool_call_ids = {} loop = event_loop_fixture - cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) + cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {}) with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: cb(1, [{"name": "unknown_tool", "result": "ok"}]) @@ -193,7 +197,7 @@ class TestStepCallback: tool_call_ids = {"read_file": "tc-def456"} loop = event_loop_fixture - cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) + cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {}) with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: future = MagicMock(spec=Future) @@ -212,7 +216,7 @@ class TestStepCallback: tool_call_ids = {"terminal": deque(["tc-xyz789"])} loop = event_loop_fixture - cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) + cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {}) with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \ patch("acp_adapter.events.build_tool_complete") as mock_btc: @@ -224,7 +228,7 @@ class TestStepCallback: cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}]) mock_btc.assert_called_once_with( - "tc-xyz789", "terminal", result='{"output": "hello"}' + "tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None ) def test_none_result_passed_through(self, mock_conn, event_loop_fixture): @@ -234,7 +238,7 @@ class TestStepCallback: tool_call_ids = {"web_search": deque(["tc-aaa"])} loop = event_loop_fixture - cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) + cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {}) with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \ patch("acp_adapter.events.build_tool_complete") as mock_btc: @@ -244,7 +248,50 @@ class TestStepCallback: cb(1, [{"name": "web_search", "result": None}]) - mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None) + mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None) + + def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture): + from collections import deque + + tool_call_ids = {"write_file": deque(["tc-write"])} + tool_call_meta = {"tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"}} + loop = event_loop_fixture + + cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta) + + with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \ + patch("acp_adapter.events.build_tool_complete") as mock_btc: + future = MagicMock(spec=Future) + future.result.return_value = None + mock_rcts.return_value = future + + cb(1, [{"name": "write_file", "result": '{"bytes_written": 23}', "arguments": {"path": "diff-test.txt"}}]) + + mock_btc.assert_called_once_with( + "tc-write", + "write_file", + result='{"bytes_written": 23}', + function_args={"path": "diff-test.txt"}, + snapshot="snap", + ) + + def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture): + tool_call_ids = {} + tool_call_meta = {} + loop = event_loop_fixture + + with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"), \ + patch("acp_adapter.events._send_update") as mock_send, \ + patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"): + cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta) + cb("tool.started", "write_file", None, {"path": "diff-test.txt", "content": "hello"}) + + assert list(tool_call_ids["write_file"]) == ["tc-meta"] + assert tool_call_meta["tc-meta"] == { + "args": {"path": "diff-test.txt", "content": "hello"}, + "snapshot": "snapshot", + } + mock_send.assert_called_once() # --------------------------------------------------------------------------- diff --git a/tests/acp/test_mcp_e2e.py b/tests/acp/test_mcp_e2e.py index 186f1b86f..88e89acf2 100644 --- a/tests/acp/test_mcp_e2e.py +++ b/tests/acp/test_mcp_e2e.py @@ -29,6 +29,7 @@ from acp.schema import ( from acp_adapter.server import HermesACPAgent from acp_adapter.session import SessionManager +from acp_adapter.tools import build_tool_start # --------------------------------------------------------------------------- @@ -181,6 +182,25 @@ class TestMcpRegistrationE2E: assert complete_event.raw_output is not None assert "hello" in str(complete_event.raw_output) + def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self): + update = build_tool_start( + "tc-1", + "patch", + { + "mode": "patch", + "patch": "*** Begin Patch\n*** Update File: src/app.py\n@@\n-old line\n+new line\n*** Add File: src/new.py\n+hello\n*** End Patch", + }, + ) + + assert len(update.content) == 2 + assert update.content[0].type == "diff" + assert update.content[0].path == "src/app.py" + assert update.content[0].old_text == "old line" + assert update.content[0].new_text == "new line" + assert update.content[1].type == "diff" + assert update.content[1].path == "src/new.py" + assert update.content[1].new_text == "hello" + @pytest.mark.asyncio async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager): """The ToolCallUpdate's toolCallId must match the ToolCallStart's.""" diff --git a/tests/acp/test_server.py b/tests/acp/test_server.py index 240392887..5893d7907 100644 --- a/tests/acp/test_server.py +++ b/tests/acp/test_server.py @@ -20,7 +20,9 @@ from acp.schema import ( NewSessionResponse, PromptResponse, ResumeSessionResponse, + SessionModelState, SetSessionConfigOptionResponse, + SetSessionModelResponse, SetSessionModeResponse, SessionInfo, TextContentBlock, @@ -127,6 +129,25 @@ class TestSessionOps: assert state is not None assert state.cwd == "/home/user/project" + @pytest.mark.asyncio + async def test_new_session_returns_model_state(self): + manager = SessionManager( + agent_factory=lambda: SimpleNamespace(model="gpt-5.4", provider="openai-codex") + ) + acp_agent = HermesACPAgent(session_manager=manager) + + with patch( + "hermes_cli.models.curated_models_for_provider", + return_value=[("gpt-5.4", "recommended"), ("gpt-5.4-mini", "")], + ): + resp = await acp_agent.new_session(cwd="/tmp") + + assert isinstance(resp.models, SessionModelState) + assert resp.models.current_model_id == "openai-codex:gpt-5.4" + assert resp.models.available_models[0].model_id == "openai-codex:gpt-5.4" + assert resp.models.available_models[0].description is not None + assert "Provider:" in resp.models.available_models[0].description + @pytest.mark.asyncio async def test_available_commands_include_help(self, agent): help_cmd = next( @@ -204,6 +225,33 @@ class TestListAndFork: assert fork_resp.session_id assert fork_resp.session_id != new_resp.session_id + @pytest.mark.asyncio + async def test_list_sessions_includes_title_and_updated_at(self, agent): + with patch.object( + agent.session_manager, + "list_sessions", + return_value=[ + { + "session_id": "session-1", + "cwd": "/tmp/project", + "title": "Fix Zed session history", + "updated_at": 123.0, + } + ], + ): + resp = await agent.list_sessions(cwd="/tmp/project") + + assert isinstance(resp.sessions[0], SessionInfo) + assert resp.sessions[0].title == "Fix Zed session history" + assert resp.sessions[0].updated_at == "123.0" + + @pytest.mark.asyncio + async def test_list_sessions_passes_cwd_filter(self, agent): + with patch.object(agent.session_manager, "list_sessions", return_value=[]) as mock_list: + await agent.list_sessions(cwd="/mnt/e/Projects/AI/browser-link-3") + + mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3") + # --------------------------------------------------------------------------- # session configuration / model routing # --------------------------------------------------------------------------- @@ -257,6 +305,53 @@ class TestSessionConfiguration: assert result == {} assert state.model == "gpt-5.4" + @pytest.mark.asyncio + async def test_set_session_model_accepts_provider_prefixed_choice(self, tmp_path, monkeypatch): + runtime_calls = [] + + def fake_resolve_runtime_provider(requested=None, **kwargs): + runtime_calls.append(requested) + provider = requested or "openrouter" + return { + "provider": provider, + "api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions", + "base_url": f"https://{provider}.example/v1", + "api_key": f"{provider}-key", + "command": None, + "args": [], + } + + def fake_agent(**kwargs): + return SimpleNamespace( + model=kwargs.get("model"), + provider=kwargs.get("provider"), + base_url=kwargs.get("base_url"), + api_mode=kwargs.get("api_mode"), + ) + + monkeypatch.setattr("hermes_cli.config.load_config", lambda: { + "model": {"provider": "openrouter", "default": "openrouter/gpt-5"} + }) + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + fake_resolve_runtime_provider, + ) + manager = SessionManager(db=SessionDB(tmp_path / "state.db")) + + with patch("run_agent.AIAgent", side_effect=fake_agent): + acp_agent = HermesACPAgent(session_manager=manager) + state = manager.create_session(cwd="/tmp") + result = await acp_agent.set_session_model( + model_id="anthropic:claude-sonnet-4-6", + session_id=state.session_id, + ) + + assert isinstance(result, SetSessionModelResponse) + assert state.model == "claude-sonnet-4-6" + assert state.agent.provider == "anthropic" + assert state.agent.base_url == "https://anthropic.example/v1" + assert runtime_calls[-1] == "anthropic" + # --------------------------------------------------------------------------- # prompt @@ -354,6 +449,31 @@ class TestPrompt: update = last_call[1].get("update") or last_call[0][1] assert update.session_update == "agent_message_chunk" + @pytest.mark.asyncio + async def test_prompt_auto_titles_session(self, agent): + new_resp = await agent.new_session(cwd=".") + state = agent.session_manager.get_session(new_resp.session_id) + state.agent.run_conversation = MagicMock(return_value={ + "final_response": "Here is the fix.", + "messages": [ + {"role": "user", "content": "fix the broken ACP history"}, + {"role": "assistant", "content": "Here is the fix."}, + ], + }) + + mock_conn = MagicMock(spec=acp.Client) + mock_conn.session_update = AsyncMock() + agent._conn = mock_conn + + with patch("agent.title_generator.maybe_auto_title") as mock_title: + prompt = [TextContentBlock(type="text", text="fix the broken ACP history")] + await agent.prompt(prompt=prompt, session_id=new_resp.session_id) + + mock_title.assert_called_once() + assert mock_title.call_args.args[1] == new_resp.session_id + assert mock_title.call_args.args[2] == "fix the broken ACP history" + assert mock_title.call_args.args[3] == "Here is the fix." + @pytest.mark.asyncio async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent): """ACP should map top-level token fields into PromptResponse.usage.""" diff --git a/tests/acp/test_session.py b/tests/acp/test_session.py index 2d7cc5db2..50d04b1a9 100644 --- a/tests/acp/test_session.py +++ b/tests/acp/test_session.py @@ -3,6 +3,7 @@ import contextlib import io import json +import time from types import SimpleNamespace import pytest from unittest.mock import MagicMock, patch @@ -100,15 +101,23 @@ class TestListAndCleanup: def test_list_sessions_returns_created(self, manager): s1 = manager.create_session(cwd="/a") s2 = manager.create_session(cwd="/b") + s1.history.append({"role": "user", "content": "hello from a"}) + s2.history.append({"role": "user", "content": "hello from b"}) listing = manager.list_sessions() ids = {s["session_id"] for s in listing} assert s1.session_id in ids assert s2.session_id in ids assert len(listing) == 2 + def test_list_sessions_hides_empty_threads(self, manager): + manager.create_session(cwd="/empty") + assert manager.list_sessions() == [] + def test_cleanup_clears_all(self, manager): - manager.create_session() - manager.create_session() + s1 = manager.create_session() + s2 = manager.create_session() + s1.history.append({"role": "user", "content": "one"}) + s2.history.append({"role": "user", "content": "two"}) assert len(manager.list_sessions()) == 2 manager.cleanup() assert manager.list_sessions() == [] @@ -194,6 +203,8 @@ class TestPersistence: def test_list_sessions_includes_db_only(self, manager): """Sessions only in DB (not in memory) appear in list_sessions.""" state = manager.create_session(cwd="/db-only") + state.history.append({"role": "user", "content": "database only thread"}) + manager.save_session(state.session_id) sid = state.session_id # Drop from memory. @@ -204,6 +215,53 @@ class TestPersistence: ids = {s["session_id"] for s in listing} assert sid in ids + def test_list_sessions_filters_by_cwd(self, manager): + keep = manager.create_session(cwd="/keep") + drop = manager.create_session(cwd="/drop") + keep.history.append({"role": "user", "content": "keep me"}) + drop.history.append({"role": "user", "content": "drop me"}) + + listing = manager.list_sessions(cwd="/keep") + ids = {s["session_id"] for s in listing} + assert keep.session_id in ids + assert drop.session_id not in ids + + def test_list_sessions_matches_windows_and_wsl_paths(self, manager): + state = manager.create_session(cwd="/mnt/e/Projects/AI/browser-link-3") + state.history.append({"role": "user", "content": "same project from WSL"}) + + listing = manager.list_sessions(cwd=r"E:\Projects\AI\browser-link-3") + ids = {s["session_id"] for s in listing} + assert state.session_id in ids + + def test_list_sessions_prefers_title_then_preview(self, manager): + state = manager.create_session(cwd="/named") + state.history.append({"role": "user", "content": "Investigate broken ACP history in Zed"}) + manager.save_session(state.session_id) + db = manager._get_db() + db.set_session_title(state.session_id, "Fix Zed ACP history") + + listing = manager.list_sessions(cwd="/named") + assert listing[0]["title"] == "Fix Zed ACP history" + + db.set_session_title(state.session_id, "") + listing = manager.list_sessions(cwd="/named") + assert listing[0]["title"].startswith("Investigate broken ACP history") + + def test_list_sessions_sorted_by_most_recent_activity(self, manager): + older = manager.create_session(cwd="/ordered") + older.history.append({"role": "user", "content": "older"}) + manager.save_session(older.session_id) + time.sleep(0.02) + newer = manager.create_session(cwd="/ordered") + newer.history.append({"role": "user", "content": "newer"}) + manager.save_session(newer.session_id) + + listing = manager.list_sessions(cwd="/ordered") + assert [item["session_id"] for item in listing[:2]] == [newer.session_id, older.session_id] + assert listing[0]["updated_at"] + assert listing[1]["updated_at"] + def test_fork_restores_source_from_db(self, manager): """Forking a session that is only in DB should work.""" original = manager.create_session() diff --git a/tests/acp/test_tools.py b/tests/acp/test_tools.py index 59401501f..603fe7459 100644 --- a/tests/acp/test_tools.py +++ b/tests/acp/test_tools.py @@ -215,6 +215,46 @@ class TestBuildToolComplete: assert len(display_text) < 6000 assert "truncated" in display_text + def test_build_tool_complete_for_patch_uses_diff_blocks(self): + """Completed patch calls should keep structured diff content for Zed.""" + patch_result = ( + '{"success": true, "diff": "--- a/README.md\\n+++ b/README.md\\n@@ -1 +1,2 @@\\n old line\\n+new line\\n", ' + '"files_modified": ["README.md"]}' + ) + result = build_tool_complete("tc-p1", "patch", patch_result) + assert isinstance(result, ToolCallProgress) + assert len(result.content) == 1 + diff_item = result.content[0] + assert isinstance(diff_item, FileEditToolCallContent) + assert diff_item.path == "README.md" + assert diff_item.old_text == "old line" + assert diff_item.new_text == "old line\nnew line" + + def test_build_tool_complete_for_patch_falls_back_to_text_when_no_diff(self): + result = build_tool_complete("tc-p2", "patch", '{"success": true}') + assert isinstance(result, ToolCallProgress) + assert isinstance(result.content[0], ContentToolCallContent) + + def test_build_tool_complete_for_write_file_uses_snapshot_diff(self, tmp_path): + target = tmp_path / "diff-test.txt" + snapshot = type("Snapshot", (), {"paths": [target], "before": {str(target): None}})() + target.write_text("hello from hermes\n", encoding="utf-8") + + result = build_tool_complete( + "tc-wf1", + "write_file", + '{"bytes_written": 18, "dirs_created": false}', + function_args={"path": str(target), "content": "hello from hermes\n"}, + snapshot=snapshot, + ) + assert isinstance(result, ToolCallProgress) + assert len(result.content) == 1 + diff_item = result.content[0] + assert isinstance(diff_item, FileEditToolCallContent) + assert diff_item.path.endswith("diff-test.txt") + assert diff_item.old_text is None + assert diff_item.new_text == "hello from hermes" + # --------------------------------------------------------------------------- # extract_locations diff --git a/tests/cli/test_surrogate_sanitization.py b/tests/cli/test_surrogate_sanitization.py index 43af7fe16..9d677352c 100644 --- a/tests/cli/test_surrogate_sanitization.py +++ b/tests/cli/test_surrogate_sanitization.py @@ -2,7 +2,8 @@ Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps() inside the OpenAI SDK. They can appear via clipboard paste from rich-text -editors like Google Docs. +editors like Google Docs, OR from byte-level reasoning models (xiaomi/mimo, +kimi, glm) emitting lone halves in reasoning output. """ import json import pytest @@ -11,6 +12,7 @@ from unittest.mock import MagicMock, patch from run_agent import ( _sanitize_surrogates, _sanitize_messages_surrogates, + _sanitize_structure_surrogates, _SURROGATE_RE, ) @@ -109,6 +111,186 @@ class TestSanitizeMessagesSurrogates: assert "\ufffd" in msgs[0]["content"] +class TestReasoningFieldSurrogates: + """Surrogates in reasoning fields (byte-level reasoning models). + + xiaomi/mimo, kimi, glm and similar byte-level tokenizers can emit lone + surrogates in reasoning output. These fields are carried through to the + API as `reasoning_content` on assistant messages, and must be sanitized + or json.dumps() crashes with 'utf-8' codec can't encode surrogates. + """ + + def test_reasoning_field_sanitized(self): + msgs = [ + {"role": "assistant", "content": "ok", "reasoning": "thought \udce2 here"}, + ] + assert _sanitize_messages_surrogates(msgs) is True + assert "\udce2" not in msgs[0]["reasoning"] + assert "\ufffd" in msgs[0]["reasoning"] + + def test_reasoning_content_field_sanitized(self): + """api_messages carry `reasoning_content` built from `reasoning`.""" + msgs = [ + {"role": "assistant", "content": "ok", "reasoning_content": "thought \udce2 here"}, + ] + assert _sanitize_messages_surrogates(msgs) is True + assert "\udce2" not in msgs[0]["reasoning_content"] + assert "\ufffd" in msgs[0]["reasoning_content"] + + def test_reasoning_details_nested_sanitized(self): + """reasoning_details is a list of dicts with nested string fields.""" + msgs = [ + { + "role": "assistant", + "content": "ok", + "reasoning_details": [ + {"type": "reasoning.summary", "summary": "summary \udce2 text"}, + {"type": "reasoning.text", "text": "chain \udc00 of thought"}, + ], + }, + ] + assert _sanitize_messages_surrogates(msgs) is True + assert "\udce2" not in msgs[0]["reasoning_details"][0]["summary"] + assert "\ufffd" in msgs[0]["reasoning_details"][0]["summary"] + assert "\udc00" not in msgs[0]["reasoning_details"][1]["text"] + assert "\ufffd" in msgs[0]["reasoning_details"][1]["text"] + + def test_deeply_nested_reasoning_sanitized(self): + """Nested dicts / lists inside extra fields are recursed into.""" + msgs = [ + { + "role": "assistant", + "content": "ok", + "reasoning_details": [ + { + "type": "reasoning.encrypted", + "content": { + "encrypted_content": "opaque", + "text_parts": ["part1", "part2 \udce2 part"], + }, + }, + ], + }, + ] + assert _sanitize_messages_surrogates(msgs) is True + assert ( + msgs[0]["reasoning_details"][0]["content"]["text_parts"][1] + == "part2 \ufffd part" + ) + + def test_reasoning_end_to_end_json_serialization(self): + """After sanitization, the full message dict must serialize clean.""" + msgs = [ + { + "role": "assistant", + "content": "answer", + "reasoning_content": "reasoning with \udce2 surrogate", + "reasoning_details": [ + {"summary": "nested \udcb0 surrogate"}, + ], + }, + ] + _sanitize_messages_surrogates(msgs) + # Must round-trip through json + utf-8 encoding without error + payload = json.dumps(msgs, ensure_ascii=False).encode("utf-8") + assert b"\\" not in payload[:0] # sanity — just ensure we got bytes + assert len(payload) > 0 + + def test_no_surrogates_returns_false(self): + """Clean reasoning fields don't trigger a modification.""" + msgs = [ + { + "role": "assistant", + "content": "ok", + "reasoning": "clean thought", + "reasoning_content": "also clean", + "reasoning_details": [{"summary": "clean summary"}], + }, + ] + assert _sanitize_messages_surrogates(msgs) is False + + +class TestSanitizeStructureSurrogates: + """Test the _sanitize_structure_surrogates() helper for nested payloads.""" + + def test_empty_payload(self): + assert _sanitize_structure_surrogates({}) is False + assert _sanitize_structure_surrogates([]) is False + + def test_flat_dict(self): + payload = {"a": "clean", "b": "dirty \udce2 text"} + assert _sanitize_structure_surrogates(payload) is True + assert payload["a"] == "clean" + assert "\ufffd" in payload["b"] + + def test_flat_list(self): + payload = ["clean", "dirty \udce2"] + assert _sanitize_structure_surrogates(payload) is True + assert payload[0] == "clean" + assert "\ufffd" in payload[1] + + def test_nested_dict_in_list(self): + payload = [{"x": "dirty \udce2"}, {"x": "clean"}] + assert _sanitize_structure_surrogates(payload) is True + assert "\ufffd" in payload[0]["x"] + assert payload[1]["x"] == "clean" + + def test_deeply_nested(self): + payload = { + "level1": { + "level2": [ + {"level3": "deep \udce2 surrogate"}, + ], + }, + } + assert _sanitize_structure_surrogates(payload) is True + assert "\ufffd" in payload["level1"]["level2"][0]["level3"] + + def test_clean_payload_returns_false(self): + payload = {"a": "clean", "b": [{"c": "also clean"}]} + assert _sanitize_structure_surrogates(payload) is False + + def test_non_string_values_ignored(self): + payload = {"int": 42, "list": [1, 2, 3], "dict": {"none": None}, "bool": True} + assert _sanitize_structure_surrogates(payload) is False + # Non-string values survive unchanged + assert payload["int"] == 42 + assert payload["list"] == [1, 2, 3] + + +class TestApiMessagesSurrogateRecovery: + """Integration: verify the recovery block sanitizes api_messages. + + The bug this guards against: a surrogate in `reasoning_content` on + api_messages (transformed from `reasoning` during build) crashes the + OpenAI SDK's json.dumps(), and the recovery block previously only + sanitized the canonical `messages` list — not `api_messages` — so the + next retry would send the same broken payload and fail 3 times. + """ + + def test_api_messages_reasoning_content_sanitized(self): + """The extended sanitizer catches reasoning_content in api_messages.""" + api_messages = [ + {"role": "system", "content": "sys"}, + { + "role": "assistant", + "content": "response", + "reasoning_content": "thought \udce2 trail", + "tool_calls": [ + { + "id": "call_1", + "function": {"name": "tool", "arguments": "{}"}, + } + ], + }, + {"role": "tool", "content": "result", "tool_call_id": "call_1"}, + ] + assert _sanitize_messages_surrogates(api_messages) is True + assert "\udce2" not in api_messages[1]["reasoning_content"] + # Full payload must now serialize clean + json.dumps(api_messages, ensure_ascii=False).encode("utf-8") + + class TestRunConversationSurrogateSanitization: """Integration: verify run_conversation sanitizes user_message.""" diff --git a/tests/test_project_metadata.py b/tests/test_project_metadata.py index e45b15725..27a1002b5 100644 --- a/tests/test_project_metadata.py +++ b/tests/test_project_metadata.py @@ -34,3 +34,21 @@ def test_messaging_extra_includes_qrcode_for_weixin_setup(): messaging_extra = optional_dependencies["messaging"] assert any(dep.startswith("qrcode") for dep in messaging_extra) + + +def test_dingtalk_extra_includes_qrcode_for_qr_auth(): + """DingTalk's QR-code device-flow auth (hermes_cli/dingtalk_auth.py) + needs the qrcode package.""" + optional_dependencies = _load_optional_dependencies() + + dingtalk_extra = optional_dependencies["dingtalk"] + assert any(dep.startswith("qrcode") for dep in dingtalk_extra) + + +def test_feishu_extra_includes_qrcode_for_qr_login(): + """Feishu's QR login flow (gateway/platforms/feishu.py) needs the + qrcode package.""" + optional_dependencies = _load_optional_dependencies() + + feishu_extra = optional_dependencies["feishu"] + assert any(dep.startswith("qrcode") for dep in feishu_extra)