Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor

This commit is contained in:
Brooklyn Nicholson 2026-04-17 15:44:57 -05:00
commit bd09e42eac
14 changed files with 1072 additions and 100 deletions

View file

@ -49,6 +49,7 @@ def make_tool_progress_cb(
session_id: str, session_id: str,
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
tool_call_ids: Dict[str, Deque[str]], tool_call_ids: Dict[str, Deque[str]],
tool_call_meta: Dict[str, Dict[str, Any]],
) -> Callable: ) -> Callable:
"""Create a ``tool_progress_callback`` for AIAgent. """Create a ``tool_progress_callback`` for AIAgent.
@ -84,6 +85,16 @@ def make_tool_progress_cb(
tool_call_ids[name] = queue tool_call_ids[name] = queue
queue.append(tc_id) queue.append(tc_id)
snapshot = None
if name in {"write_file", "patch", "skill_manage"}:
try:
from agent.display import capture_local_edit_snapshot
snapshot = capture_local_edit_snapshot(name, args)
except Exception:
logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True)
tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot}
update = build_tool_start(tc_id, name, args) update = build_tool_start(tc_id, name, args)
_send_update(conn, session_id, loop, update) _send_update(conn, session_id, loop, update)
@ -119,6 +130,7 @@ def make_step_cb(
session_id: str, session_id: str,
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
tool_call_ids: Dict[str, Deque[str]], tool_call_ids: Dict[str, Deque[str]],
tool_call_meta: Dict[str, Dict[str, Any]],
) -> Callable: ) -> Callable:
"""Create a ``step_callback`` for AIAgent. """Create a ``step_callback`` for AIAgent.
@ -132,10 +144,12 @@ def make_step_cb(
for tool_info in prev_tools: for tool_info in prev_tools:
tool_name = None tool_name = None
result = None result = None
function_args = None
if isinstance(tool_info, dict): if isinstance(tool_info, dict):
tool_name = tool_info.get("name") or tool_info.get("function_name") tool_name = tool_info.get("name") or tool_info.get("function_name")
result = tool_info.get("result") or tool_info.get("output") result = tool_info.get("result") or tool_info.get("output")
function_args = tool_info.get("arguments") or tool_info.get("args")
elif isinstance(tool_info, str): elif isinstance(tool_info, str):
tool_name = tool_info tool_name = tool_info
@ -145,8 +159,13 @@ def make_step_cb(
tool_call_ids[tool_name] = queue tool_call_ids[tool_name] = queue
if tool_name and queue: if tool_name and queue:
tc_id = queue.popleft() tc_id = queue.popleft()
meta = tool_call_meta.pop(tc_id, {})
update = build_tool_complete( update = build_tool_complete(
tc_id, tool_name, result=str(result) if result is not None else None tc_id,
tool_name,
result=str(result) if result is not None else None,
function_args=function_args or meta.get("args"),
snapshot=meta.get("snapshot"),
) )
_send_update(conn, session_id, loop, update) _send_update(conn, session_id, loop, update)
if not queue: if not queue:

View file

@ -26,6 +26,7 @@ from acp.schema import (
McpServerHttp, McpServerHttp,
McpServerSse, McpServerSse,
McpServerStdio, McpServerStdio,
ModelInfo,
NewSessionResponse, NewSessionResponse,
PromptResponse, PromptResponse,
ResumeSessionResponse, ResumeSessionResponse,
@ -36,6 +37,7 @@ from acp.schema import (
SessionCapabilities, SessionCapabilities,
SessionForkCapabilities, SessionForkCapabilities,
SessionListCapabilities, SessionListCapabilities,
SessionModelState,
SessionResumeCapabilities, SessionResumeCapabilities,
SessionInfo, SessionInfo,
TextContentBlock, TextContentBlock,
@ -147,6 +149,98 @@ class HermesACPAgent(acp.Agent):
self._conn = conn self._conn = conn
logger.info("ACP client connected") logger.info("ACP client connected")
@staticmethod
def _encode_model_choice(provider: str | None, model: str | None) -> str:
"""Encode a model selection so ACP clients can keep provider context."""
raw_model = str(model or "").strip()
if not raw_model:
return ""
raw_provider = str(provider or "").strip().lower()
if not raw_provider:
return raw_model
return f"{raw_provider}:{raw_model}"
def _build_model_state(self, state: SessionState) -> SessionModelState | None:
"""Return the ACP model selector payload for editors like Zed."""
model = str(state.model or getattr(state.agent, "model", "") or "").strip()
provider = getattr(state.agent, "provider", None) or detect_provider() or "openrouter"
try:
from hermes_cli.models import curated_models_for_provider, normalize_provider, provider_label
normalized_provider = normalize_provider(provider)
provider_name = provider_label(normalized_provider)
available_models: list[ModelInfo] = []
seen_ids: set[str] = set()
for model_id, description in curated_models_for_provider(normalized_provider):
rendered_model = str(model_id or "").strip()
if not rendered_model:
continue
choice_id = self._encode_model_choice(normalized_provider, rendered_model)
if choice_id in seen_ids:
continue
desc_parts = [f"Provider: {provider_name}"]
if description:
desc_parts.append(str(description).strip())
if rendered_model == model:
desc_parts.append("current")
available_models.append(
ModelInfo(
model_id=choice_id,
name=rendered_model,
description="".join(part for part in desc_parts if part),
)
)
seen_ids.add(choice_id)
current_model_id = self._encode_model_choice(normalized_provider, model)
if current_model_id and current_model_id not in seen_ids:
available_models.insert(
0,
ModelInfo(
model_id=current_model_id,
name=model,
description=f"Provider: {provider_name} • current",
),
)
if available_models:
return SessionModelState(
available_models=available_models,
current_model_id=current_model_id or available_models[0].model_id,
)
except Exception:
logger.debug("Could not build ACP model state", exc_info=True)
if not model:
return None
fallback_choice = self._encode_model_choice(provider, model)
return SessionModelState(
available_models=[ModelInfo(model_id=fallback_choice, name=model)],
current_model_id=fallback_choice,
)
@staticmethod
def _resolve_model_selection(raw_model: str, current_provider: str) -> tuple[str, str]:
"""Resolve ``provider:model`` input into the provider and normalized model id."""
target_provider = current_provider
new_model = raw_model.strip()
try:
from hermes_cli.models import detect_provider_for_model, parse_model_input
target_provider, new_model = parse_model_input(new_model, current_provider)
if target_provider == current_provider:
detected = detect_provider_for_model(new_model, current_provider)
if detected:
target_provider, new_model = detected
except Exception:
logger.debug("Provider detection failed, using model as-is", exc_info=True)
return target_provider, new_model
async def _register_session_mcp_servers( async def _register_session_mcp_servers(
self, self,
state: SessionState, state: SessionState,
@ -273,7 +367,10 @@ class HermesACPAgent(acp.Agent):
await self._register_session_mcp_servers(state, mcp_servers) await self._register_session_mcp_servers(state, mcp_servers)
logger.info("New session %s (cwd=%s)", state.session_id, cwd) logger.info("New session %s (cwd=%s)", state.session_id, cwd)
self._schedule_available_commands_update(state.session_id) self._schedule_available_commands_update(state.session_id)
return NewSessionResponse(session_id=state.session_id) return NewSessionResponse(
session_id=state.session_id,
models=self._build_model_state(state),
)
async def load_session( async def load_session(
self, self,
@ -289,7 +386,7 @@ class HermesACPAgent(acp.Agent):
await self._register_session_mcp_servers(state, mcp_servers) await self._register_session_mcp_servers(state, mcp_servers)
logger.info("Loaded session %s", session_id) logger.info("Loaded session %s", session_id)
self._schedule_available_commands_update(session_id) self._schedule_available_commands_update(session_id)
return LoadSessionResponse() return LoadSessionResponse(models=self._build_model_state(state))
async def resume_session( async def resume_session(
self, self,
@ -305,7 +402,7 @@ class HermesACPAgent(acp.Agent):
await self._register_session_mcp_servers(state, mcp_servers) await self._register_session_mcp_servers(state, mcp_servers)
logger.info("Resumed session %s", state.session_id) logger.info("Resumed session %s", state.session_id)
self._schedule_available_commands_update(state.session_id) self._schedule_available_commands_update(state.session_id)
return ResumeSessionResponse() return ResumeSessionResponse(models=self._build_model_state(state))
async def cancel(self, session_id: str, **kwargs: Any) -> None: async def cancel(self, session_id: str, **kwargs: Any) -> None:
state = self.session_manager.get_session(session_id) state = self.session_manager.get_session(session_id)
@ -340,11 +437,20 @@ class HermesACPAgent(acp.Agent):
cwd: str | None = None, cwd: str | None = None,
**kwargs: Any, **kwargs: Any,
) -> ListSessionsResponse: ) -> ListSessionsResponse:
infos = self.session_manager.list_sessions() infos = self.session_manager.list_sessions(cwd=cwd)
sessions = [ sessions = []
SessionInfo(session_id=s["session_id"], cwd=s["cwd"]) for s in infos:
for s in infos updated_at = s.get("updated_at")
] if updated_at is not None and not isinstance(updated_at, str):
updated_at = str(updated_at)
sessions.append(
SessionInfo(
session_id=s["session_id"],
cwd=s["cwd"],
title=s.get("title"),
updated_at=updated_at,
)
)
return ListSessionsResponse(sessions=sessions) return ListSessionsResponse(sessions=sessions)
# ---- Prompt (core) ------------------------------------------------------ # ---- Prompt (core) ------------------------------------------------------
@ -389,12 +495,13 @@ class HermesACPAgent(acp.Agent):
state.cancel_event.clear() state.cancel_event.clear()
tool_call_ids: dict[str, Deque[str]] = defaultdict(deque) tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
tool_call_meta: dict[str, dict[str, Any]] = {}
previous_approval_cb = None previous_approval_cb = None
if conn: if conn:
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids) tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
thinking_cb = make_thinking_cb(conn, session_id, loop) thinking_cb = make_thinking_cb(conn, session_id, loop)
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids) step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
message_cb = make_message_cb(conn, session_id, loop) message_cb = make_message_cb(conn, session_id, loop)
approval_cb = make_approval_callback(conn.request_permission, loop, session_id) approval_cb = make_approval_callback(conn.request_permission, loop, session_id)
else: else:
@ -449,6 +556,19 @@ class HermesACPAgent(acp.Agent):
self.session_manager.save_session(session_id) self.session_manager.save_session(session_id)
final_response = result.get("final_response", "") final_response = result.get("final_response", "")
if final_response:
try:
from agent.title_generator import maybe_auto_title
maybe_auto_title(
self.session_manager._get_db(),
session_id,
user_text,
final_response,
state.history,
)
except Exception:
logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True)
if final_response and conn: if final_response and conn:
update = acp.update_agent_message_text(final_response) update = acp.update_agent_message_text(final_response)
await conn.session_update(session_id, update) await conn.session_update(session_id, update)
@ -556,27 +676,15 @@ class HermesACPAgent(acp.Agent):
provider = getattr(state.agent, "provider", None) or "auto" provider = getattr(state.agent, "provider", None) or "auto"
return f"Current model: {model}\nProvider: {provider}" return f"Current model: {model}\nProvider: {provider}"
new_model = args.strip()
target_provider = None
current_provider = getattr(state.agent, "provider", None) or "openrouter" current_provider = getattr(state.agent, "provider", None) or "openrouter"
target_provider, new_model = self._resolve_model_selection(args, current_provider)
# Auto-detect provider for the requested model
try:
from hermes_cli.models import parse_model_input, detect_provider_for_model
target_provider, new_model = parse_model_input(new_model, current_provider)
if target_provider == current_provider:
detected = detect_provider_for_model(new_model, current_provider)
if detected:
target_provider, new_model = detected
except Exception:
logger.debug("Provider detection failed, using model as-is", exc_info=True)
state.model = new_model state.model = new_model
state.agent = self.session_manager._make_agent( state.agent = self.session_manager._make_agent(
session_id=state.session_id, session_id=state.session_id,
cwd=state.cwd, cwd=state.cwd,
model=new_model, model=new_model,
requested_provider=target_provider or current_provider, requested_provider=target_provider,
) )
self.session_manager.save_session(state.session_id) self.session_manager.save_session(state.session_id)
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
@ -678,20 +786,30 @@ class HermesACPAgent(acp.Agent):
"""Switch the model for a session (called by ACP protocol).""" """Switch the model for a session (called by ACP protocol)."""
state = self.session_manager.get_session(session_id) state = self.session_manager.get_session(session_id)
if state: if state:
state.model = model_id
current_provider = getattr(state.agent, "provider", None) current_provider = getattr(state.agent, "provider", None)
current_base_url = getattr(state.agent, "base_url", None) requested_provider, resolved_model = self._resolve_model_selection(
current_api_mode = getattr(state.agent, "api_mode", None) model_id,
current_provider or "openrouter",
)
state.model = resolved_model
provider_changed = bool(current_provider and requested_provider != current_provider)
current_base_url = None if provider_changed else getattr(state.agent, "base_url", None)
current_api_mode = None if provider_changed else getattr(state.agent, "api_mode", None)
state.agent = self.session_manager._make_agent( state.agent = self.session_manager._make_agent(
session_id=session_id, session_id=session_id,
cwd=state.cwd, cwd=state.cwd,
model=model_id, model=resolved_model,
requested_provider=current_provider, requested_provider=requested_provider,
base_url=current_base_url, base_url=current_base_url,
api_mode=current_api_mode, api_mode=current_api_mode,
) )
self.session_manager.save_session(session_id) self.session_manager.save_session(session_id)
logger.info("Session %s: model switched to %s", session_id, model_id) logger.info(
"Session %s: model switched to %s via provider %s",
session_id,
resolved_model,
requested_provider,
)
return SetSessionModelResponse() return SetSessionModelResponse()
logger.warning("Session %s: model switch requested for missing session", session_id) logger.warning("Session %s: model switch requested for missing session", session_id)
return None return None

View file

@ -13,8 +13,12 @@ from hermes_constants import get_hermes_home
import copy import copy
import json import json
import logging import logging
import os
import re
import sys import sys
import time
import uuid import uuid
from datetime import datetime, timezone
from dataclasses import dataclass, field from dataclasses import dataclass, field
from threading import Lock from threading import Lock
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -22,6 +26,64 @@ from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _normalize_cwd_for_compare(cwd: str | None) -> str:
raw = str(cwd or ".").strip()
if not raw:
raw = "."
expanded = os.path.expanduser(raw)
# Normalize Windows drive paths into the equivalent WSL mount form so
# ACP history filters match the same workspace across Windows and WSL.
match = re.match(r"^([A-Za-z]):[\\/](.*)$", expanded)
if match:
drive = match.group(1).lower()
tail = match.group(2).replace("\\", "/")
expanded = f"/mnt/{drive}/{tail}"
elif re.match(r"^/mnt/[A-Za-z]/", expanded):
expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}"
return os.path.normpath(expanded)
def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str:
explicit = str(title or "").strip()
if explicit:
return explicit
preview_text = str(preview or "").strip()
if preview_text:
return preview_text
leaf = os.path.basename(str(cwd or "").rstrip("/\\"))
return leaf or "New thread"
def _format_updated_at(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str) and value.strip():
return value
try:
return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat()
except Exception:
return None
def _updated_at_sort_key(value: Any) -> float:
if value is None:
return float("-inf")
if isinstance(value, (int, float)):
return float(value)
raw = str(value).strip()
if not raw:
return float("-inf")
try:
return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
except Exception:
try:
return float(raw)
except Exception:
return float("-inf")
def _acp_stderr_print(*args, **kwargs) -> None: def _acp_stderr_print(*args, **kwargs) -> None:
"""Best-effort human-readable output sink for ACP stdio sessions. """Best-effort human-readable output sink for ACP stdio sessions.
@ -162,47 +224,78 @@ class SessionManager:
logger.info("Forked ACP session %s -> %s", session_id, new_id) logger.info("Forked ACP session %s -> %s", session_id, new_id)
return state return state
def list_sessions(self) -> List[Dict[str, Any]]: def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]:
"""Return lightweight info dicts for all sessions (memory + database).""" """Return lightweight info dicts for all sessions (memory + database)."""
normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None
db = self._get_db()
persisted_rows: dict[str, dict[str, Any]] = {}
if db is not None:
try:
for row in db.list_sessions_rich(source="acp", limit=1000):
persisted_rows[str(row["id"])] = dict(row)
except Exception:
logger.debug("Failed to load ACP sessions from DB", exc_info=True)
# Collect in-memory sessions first. # Collect in-memory sessions first.
with self._lock: with self._lock:
seen_ids = set(self._sessions.keys()) seen_ids = set(self._sessions.keys())
results = [ results = []
{ for s in self._sessions.values():
"session_id": s.session_id, history_len = len(s.history)
"cwd": s.cwd, if history_len <= 0:
"model": s.model, continue
"history_len": len(s.history), if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd:
} continue
for s in self._sessions.values() persisted = persisted_rows.get(s.session_id, {})
] preview = next(
(
str(msg.get("content") or "").strip()
for msg in s.history
if msg.get("role") == "user" and str(msg.get("content") or "").strip()
),
persisted.get("preview") or "",
)
results.append(
{
"session_id": s.session_id,
"cwd": s.cwd,
"model": s.model,
"history_len": history_len,
"title": _build_session_title(persisted.get("title"), preview, s.cwd),
"updated_at": _format_updated_at(
persisted.get("last_active") or persisted.get("started_at") or time.time()
),
}
)
# Merge any persisted sessions not currently in memory. # Merge any persisted sessions not currently in memory.
db = self._get_db() for sid, row in persisted_rows.items():
if db is not None: if sid in seen_ids:
try: continue
rows = db.search_sessions(source="acp", limit=1000) message_count = int(row.get("message_count") or 0)
for row in rows: if message_count <= 0:
sid = row["id"] continue
if sid in seen_ids: # Extract cwd from model_config JSON.
continue session_cwd = "."
# Extract cwd from model_config JSON. mc = row.get("model_config")
cwd = "." if mc:
mc = row.get("model_config") try:
if mc: session_cwd = json.loads(mc).get("cwd", ".")
try: except (json.JSONDecodeError, TypeError):
cwd = json.loads(mc).get("cwd", ".") pass
except (json.JSONDecodeError, TypeError): if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd:
pass continue
results.append({ results.append({
"session_id": sid, "session_id": sid,
"cwd": cwd, "cwd": session_cwd,
"model": row.get("model") or "", "model": row.get("model") or "",
"history_len": row.get("message_count") or 0, "history_len": message_count,
}) "title": _build_session_title(row.get("title"), row.get("preview"), session_cwd),
except Exception: "updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")),
logger.debug("Failed to list ACP sessions from DB", exc_info=True) })
results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True)
return results return results
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]: def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
import uuid import uuid
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -96,6 +97,170 @@ def build_tool_title(tool_name: str, args: Dict[str, Any]) -> str:
return tool_name return tool_name
def _build_patch_mode_content(patch_text: str) -> List[Any]:
"""Parse V4A patch mode input into ACP diff blocks when possible."""
if not patch_text:
return [acp.tool_content(acp.text_block(""))]
try:
from tools.patch_parser import OperationType, parse_v4a_patch
operations, error = parse_v4a_patch(patch_text)
if error or not operations:
return [acp.tool_content(acp.text_block(patch_text))]
content: List[Any] = []
for op in operations:
if op.operation == OperationType.UPDATE:
old_chunks: list[str] = []
new_chunks: list[str] = []
for hunk in op.hunks:
old_lines = [line.content for line in hunk.lines if line.prefix in (" ", "-")]
new_lines = [line.content for line in hunk.lines if line.prefix in (" ", "+")]
if old_lines or new_lines:
old_chunks.append("\n".join(old_lines))
new_chunks.append("\n".join(new_lines))
old_text = "\n...\n".join(chunk for chunk in old_chunks if chunk)
new_text = "\n...\n".join(chunk for chunk in new_chunks if chunk)
if old_text or new_text:
content.append(
acp.tool_diff_content(
path=op.file_path,
old_text=old_text or None,
new_text=new_text or "",
)
)
continue
if op.operation == OperationType.ADD:
added_lines = [line.content for hunk in op.hunks for line in hunk.lines if line.prefix == "+"]
content.append(
acp.tool_diff_content(
path=op.file_path,
new_text="\n".join(added_lines),
)
)
continue
if op.operation == OperationType.DELETE:
content.append(
acp.tool_diff_content(
path=op.file_path,
old_text=f"Delete file: {op.file_path}",
new_text="",
)
)
continue
if op.operation == OperationType.MOVE:
content.append(
acp.tool_content(acp.text_block(f"Move file: {op.file_path} -> {op.new_path}"))
)
return content or [acp.tool_content(acp.text_block(patch_text))]
except Exception:
return [acp.tool_content(acp.text_block(patch_text))]
def _strip_diff_prefix(path: str) -> str:
raw = str(path or "").strip()
if raw.startswith(("a/", "b/")):
return raw[2:]
return raw
def _parse_unified_diff_content(diff_text: str) -> List[Any]:
"""Convert unified diff text into ACP diff content blocks."""
if not diff_text:
return []
content: List[Any] = []
current_old_path: Optional[str] = None
current_new_path: Optional[str] = None
old_lines: list[str] = []
new_lines: list[str] = []
def _flush() -> None:
nonlocal current_old_path, current_new_path, old_lines, new_lines
if current_old_path is None and current_new_path is None:
return
path = current_new_path if current_new_path and current_new_path != "/dev/null" else current_old_path
if not path or path == "/dev/null":
current_old_path = None
current_new_path = None
old_lines = []
new_lines = []
return
content.append(
acp.tool_diff_content(
path=_strip_diff_prefix(path),
old_text="\n".join(old_lines) if old_lines else None,
new_text="\n".join(new_lines),
)
)
current_old_path = None
current_new_path = None
old_lines = []
new_lines = []
for line in diff_text.splitlines():
if line.startswith("--- "):
_flush()
current_old_path = line[4:].strip()
continue
if line.startswith("+++ "):
current_new_path = line[4:].strip()
continue
if line.startswith("@@"):
continue
if current_old_path is None and current_new_path is None:
continue
if line.startswith("+"):
new_lines.append(line[1:])
elif line.startswith("-"):
old_lines.append(line[1:])
elif line.startswith(" "):
shared = line[1:]
old_lines.append(shared)
new_lines.append(shared)
_flush()
return content
def _build_tool_complete_content(
tool_name: str,
result: Optional[str],
*,
function_args: Optional[Dict[str, Any]] = None,
snapshot: Any = None,
) -> List[Any]:
"""Build structured ACP completion content, falling back to plain text."""
display_result = result or ""
if len(display_result) > 5000:
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
if tool_name in {"write_file", "patch", "skill_manage"}:
try:
from agent.display import extract_edit_diff
diff_text = extract_edit_diff(
tool_name,
result,
function_args=function_args,
snapshot=snapshot,
)
if isinstance(diff_text, str) and diff_text.strip():
diff_content = _parse_unified_diff_content(diff_text)
if diff_content:
return diff_content
except Exception:
pass
return [acp.tool_content(acp.text_block(display_result))]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Build ACP content objects for tool-call events # Build ACP content objects for tool-call events
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -119,9 +284,8 @@ def build_tool_start(
new = arguments.get("new_string", "") new = arguments.get("new_string", "")
content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)] content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)]
else: else:
# Patch mode — show the patch content as text
patch_text = arguments.get("patch", "") patch_text = arguments.get("patch", "")
content = [acp.tool_content(acp.text_block(patch_text))] content = _build_patch_mode_content(patch_text)
return acp.start_tool_call( return acp.start_tool_call(
tool_call_id, title, kind=kind, content=content, locations=locations, tool_call_id, title, kind=kind, content=content, locations=locations,
raw_input=arguments, raw_input=arguments,
@ -178,16 +342,17 @@ def build_tool_complete(
tool_call_id: str, tool_call_id: str,
tool_name: str, tool_name: str,
result: Optional[str] = None, result: Optional[str] = None,
function_args: Optional[Dict[str, Any]] = None,
snapshot: Any = None,
) -> ToolCallProgress: ) -> ToolCallProgress:
"""Create a ToolCallUpdate (progress) event for a completed tool call.""" """Create a ToolCallUpdate (progress) event for a completed tool call."""
kind = get_tool_kind(tool_name) kind = get_tool_kind(tool_name)
content = _build_tool_complete_content(
# Truncate very large results for the UI tool_name,
display_result = result or "" result,
if len(display_result) > 5000: function_args=function_args,
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)" snapshot=snapshot,
)
content = [acp.tool_content(acp.text_block(display_result))]
return acp.update_tool_call( return acp.update_tool_call(
tool_call_id, tool_call_id,
kind=kind, kind=kind,

View file

@ -76,8 +76,8 @@ termux = [
"hermes-agent[honcho]", "hermes-agent[honcho]",
"hermes-agent[acp]", "hermes-agent[acp]",
] ]
dingtalk = ["dingtalk-stream>=0.1.0,<1"] dingtalk = ["dingtalk-stream>=0.1.0,<1", "qrcode>=7.0,<8"]
feishu = ["lark-oapi>=1.5.3,<2"] feishu = ["lark-oapi>=1.5.3,<2", "qrcode>=7.0,<8"]
web = ["fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1"] web = ["fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1"]
rl = [ rl = [
"atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30", "atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30",

View file

@ -353,12 +353,50 @@ def _sanitize_surrogates(text: str) -> str:
return text return text
def _sanitize_structure_surrogates(payload: Any) -> bool:
"""Replace surrogate code points in nested dict/list payloads in-place.
Mirror of ``_sanitize_structure_non_ascii`` but for surrogate recovery.
Used to scrub nested structured fields (e.g. ``reasoning_details`` an
array of dicts with ``summary``/``text`` strings) that flat per-field
checks don't reach. Returns True if any surrogates were replaced.
"""
found = False
def _walk(node):
nonlocal found
if isinstance(node, dict):
for key, value in node.items():
if isinstance(value, str):
if _SURROGATE_RE.search(value):
node[key] = _SURROGATE_RE.sub('\ufffd', value)
found = True
elif isinstance(value, (dict, list)):
_walk(value)
elif isinstance(node, list):
for idx, value in enumerate(node):
if isinstance(value, str):
if _SURROGATE_RE.search(value):
node[idx] = _SURROGATE_RE.sub('\ufffd', value)
found = True
elif isinstance(value, (dict, list)):
_walk(value)
_walk(payload)
return found
def _sanitize_messages_surrogates(messages: list) -> bool: def _sanitize_messages_surrogates(messages: list) -> bool:
"""Sanitize surrogate characters from all string content in a messages list. """Sanitize surrogate characters from all string content in a messages list.
Walks message dicts in-place. Returns True if any surrogates were found Walks message dicts in-place. Returns True if any surrogates were found
and replaced, False otherwise. Covers content/text, name, and tool call and replaced, False otherwise. Covers content/text, name, tool call
metadata/arguments so retries don't fail on a non-content field. metadata/arguments, AND any additional string or nested structured fields
(``reasoning``, ``reasoning_content``, ``reasoning_details``, etc.) so
retries don't fail on a non-content field. Byte-level reasoning models
(xiaomi/mimo, kimi, glm) can emit lone surrogates in reasoning output
that flow through to ``api_messages["reasoning_content"]`` on the next
turn and crash json.dumps inside the OpenAI SDK.
""" """
found = False found = False
for msg in messages: for msg in messages:
@ -398,6 +436,21 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args): if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args) fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
found = True found = True
# Walk any additional string / nested fields (reasoning,
# reasoning_content, reasoning_details, etc.) — surrogates from
# byte-level reasoning models (xiaomi/mimo, kimi, glm) can lurk
# in these fields and aren't covered by the per-field checks above.
# Matches _sanitize_messages_non_ascii's coverage (PR #10537).
for key, value in msg.items():
if key in {"content", "name", "tool_calls", "role"}:
continue
if isinstance(value, str):
if _SURROGATE_RE.search(value):
msg[key] = _SURROGATE_RE.sub('\ufffd', value)
found = True
elif isinstance(value, (dict, list)):
if _sanitize_structure_surrogates(value):
found = True
return found return found
@ -8689,6 +8742,7 @@ class AIAgent:
{ {
"name": tc["function"]["name"], "name": tc["function"]["name"],
"result": _results_by_id.get(tc.get("id")), "result": _results_by_id.get(tc.get("id")),
"arguments": tc["function"].get("arguments"),
} }
for tc in _m["tool_calls"] for tc in _m["tool_calls"]
if isinstance(tc, dict) if isinstance(tc, dict)
@ -9303,8 +9357,7 @@ class AIAgent:
"and had none left for the actual response.\n\n" "and had none left for the actual response.\n\n"
"To fix this:\n" "To fix this:\n"
"→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n" "→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n"
"→ Increase the output token limit: " "→ Or switch to a larger/non-reasoning model with `/model`"
"set `model.max_tokens` in config.yaml"
) )
self._cleanup_task_resources(effective_task_id) self._cleanup_task_resources(effective_task_id)
self._persist_session(messages, conversation_history) self._persist_session(messages, conversation_history)
@ -9571,13 +9624,51 @@ class AIAgent:
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2: if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
_err_str = str(api_error).lower() _err_str = str(api_error).lower()
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str _is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
# Detect surrogate errors — utf-8 codec refusing to
# encode U+D800..U+DFFF. The error text is:
# "'utf-8' codec can't encode characters in position
# N-M: surrogates not allowed"
_is_surrogate_error = (
"surrogate" in _err_str
or ("'utf-8'" in _err_str and not _is_ascii_codec)
)
# Sanitize surrogates from both the canonical `messages`
# list AND `api_messages` (the API-copy, which may carry
# `reasoning_content`/`reasoning_details` transformed
# from `reasoning` — fields the canonical list doesn't
# have directly). Also clean `api_kwargs` if built and
# `prefill_messages` if present. Mirrors the ASCII
# codec recovery below.
_surrogates_found = _sanitize_messages_surrogates(messages) _surrogates_found = _sanitize_messages_surrogates(messages)
if _surrogates_found: if isinstance(api_messages, list):
if _sanitize_messages_surrogates(api_messages):
_surrogates_found = True
if isinstance(api_kwargs, dict):
if _sanitize_structure_surrogates(api_kwargs):
_surrogates_found = True
if isinstance(getattr(self, "prefill_messages", None), list):
if _sanitize_messages_surrogates(self.prefill_messages):
_surrogates_found = True
# Gate the retry on the error type, not on whether we
# found anything — _force_ascii_payload / the extended
# surrogate walker above cover all known paths, but a
# new transformed field could still slip through. If
# the error was a surrogate encode failure, always let
# the retry run; the proactive sanitizer at line ~8781
# runs again on the next iteration. Bounded by
# _unicode_sanitization_passes < 2 (outer guard).
if _surrogates_found or _is_surrogate_error:
self._unicode_sanitization_passes += 1 self._unicode_sanitization_passes += 1
self._vprint( if _surrogates_found:
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...", self._vprint(
force=True, f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
) force=True,
)
else:
self._vprint(
f"{self.log_prefix}⚠️ Surrogate encoding error — retrying after full-payload sanitization...",
force=True,
)
continue continue
if _is_ascii_codec: if _is_ascii_codec:
self._force_ascii_payload = True self._force_ascii_payload = True

View file

@ -103,6 +103,7 @@ AUTHOR_MAP = {
"dangtc94@gmail.com": "dieutx", "dangtc94@gmail.com": "dieutx",
"jaisehgal11299@gmail.com": "jaisup", "jaisehgal11299@gmail.com": "jaisup",
"percydikec@gmail.com": "PercyDikec", "percydikec@gmail.com": "PercyDikec",
"noonou7@gmail.com": "HenkDz",
"dean.kerr@gmail.com": "deankerr", "dean.kerr@gmail.com": "deankerr",
"socrates1024@gmail.com": "socrates1024", "socrates1024@gmail.com": "socrates1024",
"satelerd@gmail.com": "satelerd", "satelerd@gmail.com": "satelerd",

View file

@ -42,9 +42,10 @@ class TestToolProgressCallback:
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture): def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
"""Tool progress should emit a ToolCallStart update.""" """Tool progress should emit a ToolCallStart update."""
tool_call_ids = {} tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture loop = event_loop_fixture
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids) cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
# Run callback in the event loop context # Run callback in the event loop context
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
@ -66,9 +67,10 @@ class TestToolProgressCallback:
def test_handles_string_args(self, mock_conn, event_loop_fixture): def test_handles_string_args(self, mock_conn, event_loop_fixture):
"""If args is a JSON string, it should be parsed.""" """If args is a JSON string, it should be parsed."""
tool_call_ids = {} tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture loop = event_loop_fixture
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids) cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future) future = MagicMock(spec=Future)
@ -82,9 +84,10 @@ class TestToolProgressCallback:
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture): def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
"""If args is not a dict, it should be wrapped.""" """If args is not a dict, it should be wrapped."""
tool_call_ids = {} tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture loop = event_loop_fixture
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids) cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future) future = MagicMock(spec=Future)
@ -98,10 +101,11 @@ class TestToolProgressCallback:
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture): def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
"""Multiple same-name tool calls should be tracked independently in order.""" """Multiple same-name tool calls should be tracked independently in order."""
tool_call_ids = {} tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture loop = event_loop_fixture
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids) progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future) future = MagicMock(spec=Future)
@ -163,7 +167,7 @@ class TestStepCallback:
tool_call_ids = {"terminal": "tc-abc123"} tool_call_ids = {"terminal": "tc-abc123"}
loop = event_loop_fixture loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future) future = MagicMock(spec=Future)
@ -181,7 +185,7 @@ class TestStepCallback:
tool_call_ids = {} tool_call_ids = {}
loop = event_loop_fixture loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
cb(1, [{"name": "unknown_tool", "result": "ok"}]) cb(1, [{"name": "unknown_tool", "result": "ok"}])
@ -193,7 +197,7 @@ class TestStepCallback:
tool_call_ids = {"read_file": "tc-def456"} tool_call_ids = {"read_file": "tc-def456"}
loop = event_loop_fixture loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts: with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
future = MagicMock(spec=Future) future = MagicMock(spec=Future)
@ -212,7 +216,7 @@ class TestStepCallback:
tool_call_ids = {"terminal": deque(["tc-xyz789"])} tool_call_ids = {"terminal": deque(["tc-xyz789"])}
loop = event_loop_fixture loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \ with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
patch("acp_adapter.events.build_tool_complete") as mock_btc: patch("acp_adapter.events.build_tool_complete") as mock_btc:
@ -224,7 +228,7 @@ class TestStepCallback:
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}]) cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
mock_btc.assert_called_once_with( mock_btc.assert_called_once_with(
"tc-xyz789", "terminal", result='{"output": "hello"}' "tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None
) )
def test_none_result_passed_through(self, mock_conn, event_loop_fixture): def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
@ -234,7 +238,7 @@ class TestStepCallback:
tool_call_ids = {"web_search": deque(["tc-aaa"])} tool_call_ids = {"web_search": deque(["tc-aaa"])}
loop = event_loop_fixture loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids) cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \ with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
patch("acp_adapter.events.build_tool_complete") as mock_btc: patch("acp_adapter.events.build_tool_complete") as mock_btc:
@ -244,7 +248,50 @@ class TestStepCallback:
cb(1, [{"name": "web_search", "result": None}]) cb(1, [{"name": "web_search", "result": None}])
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None) mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None)
def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture):
from collections import deque
tool_call_ids = {"write_file": deque(["tc-write"])}
tool_call_meta = {"tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"}}
loop = event_loop_fixture
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
patch("acp_adapter.events.build_tool_complete") as mock_btc:
future = MagicMock(spec=Future)
future.result.return_value = None
mock_rcts.return_value = future
cb(1, [{"name": "write_file", "result": '{"bytes_written": 23}', "arguments": {"path": "diff-test.txt"}}])
mock_btc.assert_called_once_with(
"tc-write",
"write_file",
result='{"bytes_written": 23}',
function_args={"path": "diff-test.txt"},
snapshot="snap",
)
def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture):
tool_call_ids = {}
tool_call_meta = {}
loop = event_loop_fixture
with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"), \
patch("acp_adapter.events._send_update") as mock_send, \
patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"):
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
cb("tool.started", "write_file", None, {"path": "diff-test.txt", "content": "hello"})
assert list(tool_call_ids["write_file"]) == ["tc-meta"]
assert tool_call_meta["tc-meta"] == {
"args": {"path": "diff-test.txt", "content": "hello"},
"snapshot": "snapshot",
}
mock_send.assert_called_once()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -29,6 +29,7 @@ from acp.schema import (
from acp_adapter.server import HermesACPAgent from acp_adapter.server import HermesACPAgent
from acp_adapter.session import SessionManager from acp_adapter.session import SessionManager
from acp_adapter.tools import build_tool_start
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -181,6 +182,25 @@ class TestMcpRegistrationE2E:
assert complete_event.raw_output is not None assert complete_event.raw_output is not None
assert "hello" in str(complete_event.raw_output) assert "hello" in str(complete_event.raw_output)
def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self):
update = build_tool_start(
"tc-1",
"patch",
{
"mode": "patch",
"patch": "*** Begin Patch\n*** Update File: src/app.py\n@@\n-old line\n+new line\n*** Add File: src/new.py\n+hello\n*** End Patch",
},
)
assert len(update.content) == 2
assert update.content[0].type == "diff"
assert update.content[0].path == "src/app.py"
assert update.content[0].old_text == "old line"
assert update.content[0].new_text == "new line"
assert update.content[1].type == "diff"
assert update.content[1].path == "src/new.py"
assert update.content[1].new_text == "hello"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager): async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's.""" """The ToolCallUpdate's toolCallId must match the ToolCallStart's."""

View file

@ -20,7 +20,9 @@ from acp.schema import (
NewSessionResponse, NewSessionResponse,
PromptResponse, PromptResponse,
ResumeSessionResponse, ResumeSessionResponse,
SessionModelState,
SetSessionConfigOptionResponse, SetSessionConfigOptionResponse,
SetSessionModelResponse,
SetSessionModeResponse, SetSessionModeResponse,
SessionInfo, SessionInfo,
TextContentBlock, TextContentBlock,
@ -127,6 +129,25 @@ class TestSessionOps:
assert state is not None assert state is not None
assert state.cwd == "/home/user/project" assert state.cwd == "/home/user/project"
@pytest.mark.asyncio
async def test_new_session_returns_model_state(self):
manager = SessionManager(
agent_factory=lambda: SimpleNamespace(model="gpt-5.4", provider="openai-codex")
)
acp_agent = HermesACPAgent(session_manager=manager)
with patch(
"hermes_cli.models.curated_models_for_provider",
return_value=[("gpt-5.4", "recommended"), ("gpt-5.4-mini", "")],
):
resp = await acp_agent.new_session(cwd="/tmp")
assert isinstance(resp.models, SessionModelState)
assert resp.models.current_model_id == "openai-codex:gpt-5.4"
assert resp.models.available_models[0].model_id == "openai-codex:gpt-5.4"
assert resp.models.available_models[0].description is not None
assert "Provider:" in resp.models.available_models[0].description
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_available_commands_include_help(self, agent): async def test_available_commands_include_help(self, agent):
help_cmd = next( help_cmd = next(
@ -204,6 +225,33 @@ class TestListAndFork:
assert fork_resp.session_id assert fork_resp.session_id
assert fork_resp.session_id != new_resp.session_id assert fork_resp.session_id != new_resp.session_id
@pytest.mark.asyncio
async def test_list_sessions_includes_title_and_updated_at(self, agent):
with patch.object(
agent.session_manager,
"list_sessions",
return_value=[
{
"session_id": "session-1",
"cwd": "/tmp/project",
"title": "Fix Zed session history",
"updated_at": 123.0,
}
],
):
resp = await agent.list_sessions(cwd="/tmp/project")
assert isinstance(resp.sessions[0], SessionInfo)
assert resp.sessions[0].title == "Fix Zed session history"
assert resp.sessions[0].updated_at == "123.0"
@pytest.mark.asyncio
async def test_list_sessions_passes_cwd_filter(self, agent):
with patch.object(agent.session_manager, "list_sessions", return_value=[]) as mock_list:
await agent.list_sessions(cwd="/mnt/e/Projects/AI/browser-link-3")
mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# session configuration / model routing # session configuration / model routing
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -257,6 +305,53 @@ class TestSessionConfiguration:
assert result == {} assert result == {}
assert state.model == "gpt-5.4" assert state.model == "gpt-5.4"
@pytest.mark.asyncio
async def test_set_session_model_accepts_provider_prefixed_choice(self, tmp_path, monkeypatch):
runtime_calls = []
def fake_resolve_runtime_provider(requested=None, **kwargs):
runtime_calls.append(requested)
provider = requested or "openrouter"
return {
"provider": provider,
"api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
"base_url": f"https://{provider}.example/v1",
"api_key": f"{provider}-key",
"command": None,
"args": [],
}
def fake_agent(**kwargs):
return SimpleNamespace(
model=kwargs.get("model"),
provider=kwargs.get("provider"),
base_url=kwargs.get("base_url"),
api_mode=kwargs.get("api_mode"),
)
monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}
})
monkeypatch.setattr(
"hermes_cli.runtime_provider.resolve_runtime_provider",
fake_resolve_runtime_provider,
)
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
with patch("run_agent.AIAgent", side_effect=fake_agent):
acp_agent = HermesACPAgent(session_manager=manager)
state = manager.create_session(cwd="/tmp")
result = await acp_agent.set_session_model(
model_id="anthropic:claude-sonnet-4-6",
session_id=state.session_id,
)
assert isinstance(result, SetSessionModelResponse)
assert state.model == "claude-sonnet-4-6"
assert state.agent.provider == "anthropic"
assert state.agent.base_url == "https://anthropic.example/v1"
assert runtime_calls[-1] == "anthropic"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# prompt # prompt
@ -354,6 +449,31 @@ class TestPrompt:
update = last_call[1].get("update") or last_call[0][1] update = last_call[1].get("update") or last_call[0][1]
assert update.session_update == "agent_message_chunk" assert update.session_update == "agent_message_chunk"
@pytest.mark.asyncio
async def test_prompt_auto_titles_session(self, agent):
new_resp = await agent.new_session(cwd=".")
state = agent.session_manager.get_session(new_resp.session_id)
state.agent.run_conversation = MagicMock(return_value={
"final_response": "Here is the fix.",
"messages": [
{"role": "user", "content": "fix the broken ACP history"},
{"role": "assistant", "content": "Here is the fix."},
],
})
mock_conn = MagicMock(spec=acp.Client)
mock_conn.session_update = AsyncMock()
agent._conn = mock_conn
with patch("agent.title_generator.maybe_auto_title") as mock_title:
prompt = [TextContentBlock(type="text", text="fix the broken ACP history")]
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
mock_title.assert_called_once()
assert mock_title.call_args.args[1] == new_resp.session_id
assert mock_title.call_args.args[2] == "fix the broken ACP history"
assert mock_title.call_args.args[3] == "Here is the fix."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent): async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
"""ACP should map top-level token fields into PromptResponse.usage.""" """ACP should map top-level token fields into PromptResponse.usage."""

View file

@ -3,6 +3,7 @@
import contextlib import contextlib
import io import io
import json import json
import time
from types import SimpleNamespace from types import SimpleNamespace
import pytest import pytest
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -100,15 +101,23 @@ class TestListAndCleanup:
def test_list_sessions_returns_created(self, manager): def test_list_sessions_returns_created(self, manager):
s1 = manager.create_session(cwd="/a") s1 = manager.create_session(cwd="/a")
s2 = manager.create_session(cwd="/b") s2 = manager.create_session(cwd="/b")
s1.history.append({"role": "user", "content": "hello from a"})
s2.history.append({"role": "user", "content": "hello from b"})
listing = manager.list_sessions() listing = manager.list_sessions()
ids = {s["session_id"] for s in listing} ids = {s["session_id"] for s in listing}
assert s1.session_id in ids assert s1.session_id in ids
assert s2.session_id in ids assert s2.session_id in ids
assert len(listing) == 2 assert len(listing) == 2
def test_list_sessions_hides_empty_threads(self, manager):
manager.create_session(cwd="/empty")
assert manager.list_sessions() == []
def test_cleanup_clears_all(self, manager): def test_cleanup_clears_all(self, manager):
manager.create_session() s1 = manager.create_session()
manager.create_session() s2 = manager.create_session()
s1.history.append({"role": "user", "content": "one"})
s2.history.append({"role": "user", "content": "two"})
assert len(manager.list_sessions()) == 2 assert len(manager.list_sessions()) == 2
manager.cleanup() manager.cleanup()
assert manager.list_sessions() == [] assert manager.list_sessions() == []
@ -194,6 +203,8 @@ class TestPersistence:
def test_list_sessions_includes_db_only(self, manager): def test_list_sessions_includes_db_only(self, manager):
"""Sessions only in DB (not in memory) appear in list_sessions.""" """Sessions only in DB (not in memory) appear in list_sessions."""
state = manager.create_session(cwd="/db-only") state = manager.create_session(cwd="/db-only")
state.history.append({"role": "user", "content": "database only thread"})
manager.save_session(state.session_id)
sid = state.session_id sid = state.session_id
# Drop from memory. # Drop from memory.
@ -204,6 +215,53 @@ class TestPersistence:
ids = {s["session_id"] for s in listing} ids = {s["session_id"] for s in listing}
assert sid in ids assert sid in ids
def test_list_sessions_filters_by_cwd(self, manager):
keep = manager.create_session(cwd="/keep")
drop = manager.create_session(cwd="/drop")
keep.history.append({"role": "user", "content": "keep me"})
drop.history.append({"role": "user", "content": "drop me"})
listing = manager.list_sessions(cwd="/keep")
ids = {s["session_id"] for s in listing}
assert keep.session_id in ids
assert drop.session_id not in ids
def test_list_sessions_matches_windows_and_wsl_paths(self, manager):
state = manager.create_session(cwd="/mnt/e/Projects/AI/browser-link-3")
state.history.append({"role": "user", "content": "same project from WSL"})
listing = manager.list_sessions(cwd=r"E:\Projects\AI\browser-link-3")
ids = {s["session_id"] for s in listing}
assert state.session_id in ids
def test_list_sessions_prefers_title_then_preview(self, manager):
state = manager.create_session(cwd="/named")
state.history.append({"role": "user", "content": "Investigate broken ACP history in Zed"})
manager.save_session(state.session_id)
db = manager._get_db()
db.set_session_title(state.session_id, "Fix Zed ACP history")
listing = manager.list_sessions(cwd="/named")
assert listing[0]["title"] == "Fix Zed ACP history"
db.set_session_title(state.session_id, "")
listing = manager.list_sessions(cwd="/named")
assert listing[0]["title"].startswith("Investigate broken ACP history")
def test_list_sessions_sorted_by_most_recent_activity(self, manager):
older = manager.create_session(cwd="/ordered")
older.history.append({"role": "user", "content": "older"})
manager.save_session(older.session_id)
time.sleep(0.02)
newer = manager.create_session(cwd="/ordered")
newer.history.append({"role": "user", "content": "newer"})
manager.save_session(newer.session_id)
listing = manager.list_sessions(cwd="/ordered")
assert [item["session_id"] for item in listing[:2]] == [newer.session_id, older.session_id]
assert listing[0]["updated_at"]
assert listing[1]["updated_at"]
def test_fork_restores_source_from_db(self, manager): def test_fork_restores_source_from_db(self, manager):
"""Forking a session that is only in DB should work.""" """Forking a session that is only in DB should work."""
original = manager.create_session() original = manager.create_session()

View file

@ -215,6 +215,46 @@ class TestBuildToolComplete:
assert len(display_text) < 6000 assert len(display_text) < 6000
assert "truncated" in display_text assert "truncated" in display_text
def test_build_tool_complete_for_patch_uses_diff_blocks(self):
"""Completed patch calls should keep structured diff content for Zed."""
patch_result = (
'{"success": true, "diff": "--- a/README.md\\n+++ b/README.md\\n@@ -1 +1,2 @@\\n old line\\n+new line\\n", '
'"files_modified": ["README.md"]}'
)
result = build_tool_complete("tc-p1", "patch", patch_result)
assert isinstance(result, ToolCallProgress)
assert len(result.content) == 1
diff_item = result.content[0]
assert isinstance(diff_item, FileEditToolCallContent)
assert diff_item.path == "README.md"
assert diff_item.old_text == "old line"
assert diff_item.new_text == "old line\nnew line"
def test_build_tool_complete_for_patch_falls_back_to_text_when_no_diff(self):
result = build_tool_complete("tc-p2", "patch", '{"success": true}')
assert isinstance(result, ToolCallProgress)
assert isinstance(result.content[0], ContentToolCallContent)
def test_build_tool_complete_for_write_file_uses_snapshot_diff(self, tmp_path):
target = tmp_path / "diff-test.txt"
snapshot = type("Snapshot", (), {"paths": [target], "before": {str(target): None}})()
target.write_text("hello from hermes\n", encoding="utf-8")
result = build_tool_complete(
"tc-wf1",
"write_file",
'{"bytes_written": 18, "dirs_created": false}',
function_args={"path": str(target), "content": "hello from hermes\n"},
snapshot=snapshot,
)
assert isinstance(result, ToolCallProgress)
assert len(result.content) == 1
diff_item = result.content[0]
assert isinstance(diff_item, FileEditToolCallContent)
assert diff_item.path.endswith("diff-test.txt")
assert diff_item.old_text is None
assert diff_item.new_text == "hello from hermes"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# extract_locations # extract_locations

View file

@ -2,7 +2,8 @@
Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps() Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps()
inside the OpenAI SDK. They can appear via clipboard paste from rich-text inside the OpenAI SDK. They can appear via clipboard paste from rich-text
editors like Google Docs. editors like Google Docs, OR from byte-level reasoning models (xiaomi/mimo,
kimi, glm) emitting lone halves in reasoning output.
""" """
import json import json
import pytest import pytest
@ -11,6 +12,7 @@ from unittest.mock import MagicMock, patch
from run_agent import ( from run_agent import (
_sanitize_surrogates, _sanitize_surrogates,
_sanitize_messages_surrogates, _sanitize_messages_surrogates,
_sanitize_structure_surrogates,
_SURROGATE_RE, _SURROGATE_RE,
) )
@ -109,6 +111,186 @@ class TestSanitizeMessagesSurrogates:
assert "\ufffd" in msgs[0]["content"] assert "\ufffd" in msgs[0]["content"]
class TestReasoningFieldSurrogates:
"""Surrogates in reasoning fields (byte-level reasoning models).
xiaomi/mimo, kimi, glm and similar byte-level tokenizers can emit lone
surrogates in reasoning output. These fields are carried through to the
API as `reasoning_content` on assistant messages, and must be sanitized
or json.dumps() crashes with 'utf-8' codec can't encode surrogates.
"""
def test_reasoning_field_sanitized(self):
msgs = [
{"role": "assistant", "content": "ok", "reasoning": "thought \udce2 here"},
]
assert _sanitize_messages_surrogates(msgs) is True
assert "\udce2" not in msgs[0]["reasoning"]
assert "\ufffd" in msgs[0]["reasoning"]
def test_reasoning_content_field_sanitized(self):
"""api_messages carry `reasoning_content` built from `reasoning`."""
msgs = [
{"role": "assistant", "content": "ok", "reasoning_content": "thought \udce2 here"},
]
assert _sanitize_messages_surrogates(msgs) is True
assert "\udce2" not in msgs[0]["reasoning_content"]
assert "\ufffd" in msgs[0]["reasoning_content"]
def test_reasoning_details_nested_sanitized(self):
"""reasoning_details is a list of dicts with nested string fields."""
msgs = [
{
"role": "assistant",
"content": "ok",
"reasoning_details": [
{"type": "reasoning.summary", "summary": "summary \udce2 text"},
{"type": "reasoning.text", "text": "chain \udc00 of thought"},
],
},
]
assert _sanitize_messages_surrogates(msgs) is True
assert "\udce2" not in msgs[0]["reasoning_details"][0]["summary"]
assert "\ufffd" in msgs[0]["reasoning_details"][0]["summary"]
assert "\udc00" not in msgs[0]["reasoning_details"][1]["text"]
assert "\ufffd" in msgs[0]["reasoning_details"][1]["text"]
def test_deeply_nested_reasoning_sanitized(self):
"""Nested dicts / lists inside extra fields are recursed into."""
msgs = [
{
"role": "assistant",
"content": "ok",
"reasoning_details": [
{
"type": "reasoning.encrypted",
"content": {
"encrypted_content": "opaque",
"text_parts": ["part1", "part2 \udce2 part"],
},
},
],
},
]
assert _sanitize_messages_surrogates(msgs) is True
assert (
msgs[0]["reasoning_details"][0]["content"]["text_parts"][1]
== "part2 \ufffd part"
)
def test_reasoning_end_to_end_json_serialization(self):
"""After sanitization, the full message dict must serialize clean."""
msgs = [
{
"role": "assistant",
"content": "answer",
"reasoning_content": "reasoning with \udce2 surrogate",
"reasoning_details": [
{"summary": "nested \udcb0 surrogate"},
],
},
]
_sanitize_messages_surrogates(msgs)
# Must round-trip through json + utf-8 encoding without error
payload = json.dumps(msgs, ensure_ascii=False).encode("utf-8")
assert b"\\" not in payload[:0] # sanity — just ensure we got bytes
assert len(payload) > 0
def test_no_surrogates_returns_false(self):
"""Clean reasoning fields don't trigger a modification."""
msgs = [
{
"role": "assistant",
"content": "ok",
"reasoning": "clean thought",
"reasoning_content": "also clean",
"reasoning_details": [{"summary": "clean summary"}],
},
]
assert _sanitize_messages_surrogates(msgs) is False
class TestSanitizeStructureSurrogates:
"""Test the _sanitize_structure_surrogates() helper for nested payloads."""
def test_empty_payload(self):
assert _sanitize_structure_surrogates({}) is False
assert _sanitize_structure_surrogates([]) is False
def test_flat_dict(self):
payload = {"a": "clean", "b": "dirty \udce2 text"}
assert _sanitize_structure_surrogates(payload) is True
assert payload["a"] == "clean"
assert "\ufffd" in payload["b"]
def test_flat_list(self):
payload = ["clean", "dirty \udce2"]
assert _sanitize_structure_surrogates(payload) is True
assert payload[0] == "clean"
assert "\ufffd" in payload[1]
def test_nested_dict_in_list(self):
payload = [{"x": "dirty \udce2"}, {"x": "clean"}]
assert _sanitize_structure_surrogates(payload) is True
assert "\ufffd" in payload[0]["x"]
assert payload[1]["x"] == "clean"
def test_deeply_nested(self):
payload = {
"level1": {
"level2": [
{"level3": "deep \udce2 surrogate"},
],
},
}
assert _sanitize_structure_surrogates(payload) is True
assert "\ufffd" in payload["level1"]["level2"][0]["level3"]
def test_clean_payload_returns_false(self):
payload = {"a": "clean", "b": [{"c": "also clean"}]}
assert _sanitize_structure_surrogates(payload) is False
def test_non_string_values_ignored(self):
payload = {"int": 42, "list": [1, 2, 3], "dict": {"none": None}, "bool": True}
assert _sanitize_structure_surrogates(payload) is False
# Non-string values survive unchanged
assert payload["int"] == 42
assert payload["list"] == [1, 2, 3]
class TestApiMessagesSurrogateRecovery:
    """Integration: the recovery path must sanitize api_messages too.

    Guards against a real bug: a surrogate in ``reasoning_content`` on
    api_messages (produced from ``reasoning`` during payload build) made
    the OpenAI SDK's json.dumps() blow up, and the recovery block used to
    sanitize only the canonical ``messages`` list — so every retry resent
    the identical broken payload and failed three times in a row.
    """

    def test_api_messages_reasoning_content_sanitized(self):
        """The extended sanitizer catches reasoning_content in api_messages."""
        assistant_msg = {
            "role": "assistant",
            "content": "response",
            "reasoning_content": "thought \udce2 trail",
            "tool_calls": [
                {
                    "id": "call_1",
                    "function": {"name": "tool", "arguments": "{}"},
                }
            ],
        }
        api_messages = [
            {"role": "system", "content": "sys"},
            assistant_msg,
            {"role": "tool", "content": "result", "tool_call_id": "call_1"},
        ]
        assert _sanitize_messages_surrogates(api_messages) is True
        assert "\udce2" not in assistant_msg["reasoning_content"]
        # The whole payload must now survive a JSON + utf-8 round-trip.
        json.dumps(api_messages, ensure_ascii=False).encode("utf-8")
class TestRunConversationSurrogateSanitization:
    """Integration: verify run_conversation sanitizes user_message."""



def test_messaging_extra_includes_qrcode_for_weixin_setup():
    optional_dependencies = _load_optional_dependencies()
    messaging_extra = optional_dependencies["messaging"]
    assert any(dep.startswith("qrcode") for dep in messaging_extra)
def test_dingtalk_extra_includes_qrcode_for_qr_auth():
    """The dingtalk extra must ship the qrcode package for the QR-code
    device-flow auth in hermes_cli/dingtalk_auth.py."""
    extras = _load_optional_dependencies()
    # Equivalent to any(...): collect the qrcode pins and require at least one.
    qrcode_pins = [dep for dep in extras["dingtalk"] if dep.startswith("qrcode")]
    assert qrcode_pins
def test_feishu_extra_includes_qrcode_for_qr_login():
    """The feishu extra must ship the qrcode package for the QR login
    flow in gateway/platforms/feishu.py."""
    extras = _load_optional_dependencies()
    # Equivalent to any(...): collect the qrcode pins and require at least one.
    qrcode_pins = [dep for dep in extras["feishu"] if dep.startswith("qrcode")]
    assert qrcode_pins