mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
commit
bd09e42eac
14 changed files with 1072 additions and 100 deletions
|
|
@ -49,6 +49,7 @@ def make_tool_progress_cb(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
loop: asyncio.AbstractEventLoop,
|
loop: asyncio.AbstractEventLoop,
|
||||||
tool_call_ids: Dict[str, Deque[str]],
|
tool_call_ids: Dict[str, Deque[str]],
|
||||||
|
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Create a ``tool_progress_callback`` for AIAgent.
|
"""Create a ``tool_progress_callback`` for AIAgent.
|
||||||
|
|
||||||
|
|
@ -84,6 +85,16 @@ def make_tool_progress_cb(
|
||||||
tool_call_ids[name] = queue
|
tool_call_ids[name] = queue
|
||||||
queue.append(tc_id)
|
queue.append(tc_id)
|
||||||
|
|
||||||
|
snapshot = None
|
||||||
|
if name in {"write_file", "patch", "skill_manage"}:
|
||||||
|
try:
|
||||||
|
from agent.display import capture_local_edit_snapshot
|
||||||
|
|
||||||
|
snapshot = capture_local_edit_snapshot(name, args)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True)
|
||||||
|
tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot}
|
||||||
|
|
||||||
update = build_tool_start(tc_id, name, args)
|
update = build_tool_start(tc_id, name, args)
|
||||||
_send_update(conn, session_id, loop, update)
|
_send_update(conn, session_id, loop, update)
|
||||||
|
|
||||||
|
|
@ -119,6 +130,7 @@ def make_step_cb(
|
||||||
session_id: str,
|
session_id: str,
|
||||||
loop: asyncio.AbstractEventLoop,
|
loop: asyncio.AbstractEventLoop,
|
||||||
tool_call_ids: Dict[str, Deque[str]],
|
tool_call_ids: Dict[str, Deque[str]],
|
||||||
|
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
"""Create a ``step_callback`` for AIAgent.
|
"""Create a ``step_callback`` for AIAgent.
|
||||||
|
|
||||||
|
|
@ -132,10 +144,12 @@ def make_step_cb(
|
||||||
for tool_info in prev_tools:
|
for tool_info in prev_tools:
|
||||||
tool_name = None
|
tool_name = None
|
||||||
result = None
|
result = None
|
||||||
|
function_args = None
|
||||||
|
|
||||||
if isinstance(tool_info, dict):
|
if isinstance(tool_info, dict):
|
||||||
tool_name = tool_info.get("name") or tool_info.get("function_name")
|
tool_name = tool_info.get("name") or tool_info.get("function_name")
|
||||||
result = tool_info.get("result") or tool_info.get("output")
|
result = tool_info.get("result") or tool_info.get("output")
|
||||||
|
function_args = tool_info.get("arguments") or tool_info.get("args")
|
||||||
elif isinstance(tool_info, str):
|
elif isinstance(tool_info, str):
|
||||||
tool_name = tool_info
|
tool_name = tool_info
|
||||||
|
|
||||||
|
|
@ -145,8 +159,13 @@ def make_step_cb(
|
||||||
tool_call_ids[tool_name] = queue
|
tool_call_ids[tool_name] = queue
|
||||||
if tool_name and queue:
|
if tool_name and queue:
|
||||||
tc_id = queue.popleft()
|
tc_id = queue.popleft()
|
||||||
|
meta = tool_call_meta.pop(tc_id, {})
|
||||||
update = build_tool_complete(
|
update = build_tool_complete(
|
||||||
tc_id, tool_name, result=str(result) if result is not None else None
|
tc_id,
|
||||||
|
tool_name,
|
||||||
|
result=str(result) if result is not None else None,
|
||||||
|
function_args=function_args or meta.get("args"),
|
||||||
|
snapshot=meta.get("snapshot"),
|
||||||
)
|
)
|
||||||
_send_update(conn, session_id, loop, update)
|
_send_update(conn, session_id, loop, update)
|
||||||
if not queue:
|
if not queue:
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from acp.schema import (
|
||||||
McpServerHttp,
|
McpServerHttp,
|
||||||
McpServerSse,
|
McpServerSse,
|
||||||
McpServerStdio,
|
McpServerStdio,
|
||||||
|
ModelInfo,
|
||||||
NewSessionResponse,
|
NewSessionResponse,
|
||||||
PromptResponse,
|
PromptResponse,
|
||||||
ResumeSessionResponse,
|
ResumeSessionResponse,
|
||||||
|
|
@ -36,6 +37,7 @@ from acp.schema import (
|
||||||
SessionCapabilities,
|
SessionCapabilities,
|
||||||
SessionForkCapabilities,
|
SessionForkCapabilities,
|
||||||
SessionListCapabilities,
|
SessionListCapabilities,
|
||||||
|
SessionModelState,
|
||||||
SessionResumeCapabilities,
|
SessionResumeCapabilities,
|
||||||
SessionInfo,
|
SessionInfo,
|
||||||
TextContentBlock,
|
TextContentBlock,
|
||||||
|
|
@ -147,6 +149,98 @@ class HermesACPAgent(acp.Agent):
|
||||||
self._conn = conn
|
self._conn = conn
|
||||||
logger.info("ACP client connected")
|
logger.info("ACP client connected")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _encode_model_choice(provider: str | None, model: str | None) -> str:
|
||||||
|
"""Encode a model selection so ACP clients can keep provider context."""
|
||||||
|
raw_model = str(model or "").strip()
|
||||||
|
if not raw_model:
|
||||||
|
return ""
|
||||||
|
raw_provider = str(provider or "").strip().lower()
|
||||||
|
if not raw_provider:
|
||||||
|
return raw_model
|
||||||
|
return f"{raw_provider}:{raw_model}"
|
||||||
|
|
||||||
|
def _build_model_state(self, state: SessionState) -> SessionModelState | None:
|
||||||
|
"""Return the ACP model selector payload for editors like Zed."""
|
||||||
|
model = str(state.model or getattr(state.agent, "model", "") or "").strip()
|
||||||
|
provider = getattr(state.agent, "provider", None) or detect_provider() or "openrouter"
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hermes_cli.models import curated_models_for_provider, normalize_provider, provider_label
|
||||||
|
|
||||||
|
normalized_provider = normalize_provider(provider)
|
||||||
|
provider_name = provider_label(normalized_provider)
|
||||||
|
available_models: list[ModelInfo] = []
|
||||||
|
seen_ids: set[str] = set()
|
||||||
|
|
||||||
|
for model_id, description in curated_models_for_provider(normalized_provider):
|
||||||
|
rendered_model = str(model_id or "").strip()
|
||||||
|
if not rendered_model:
|
||||||
|
continue
|
||||||
|
choice_id = self._encode_model_choice(normalized_provider, rendered_model)
|
||||||
|
if choice_id in seen_ids:
|
||||||
|
continue
|
||||||
|
desc_parts = [f"Provider: {provider_name}"]
|
||||||
|
if description:
|
||||||
|
desc_parts.append(str(description).strip())
|
||||||
|
if rendered_model == model:
|
||||||
|
desc_parts.append("current")
|
||||||
|
available_models.append(
|
||||||
|
ModelInfo(
|
||||||
|
model_id=choice_id,
|
||||||
|
name=rendered_model,
|
||||||
|
description=" • ".join(part for part in desc_parts if part),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_ids.add(choice_id)
|
||||||
|
|
||||||
|
current_model_id = self._encode_model_choice(normalized_provider, model)
|
||||||
|
if current_model_id and current_model_id not in seen_ids:
|
||||||
|
available_models.insert(
|
||||||
|
0,
|
||||||
|
ModelInfo(
|
||||||
|
model_id=current_model_id,
|
||||||
|
name=model,
|
||||||
|
description=f"Provider: {provider_name} • current",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if available_models:
|
||||||
|
return SessionModelState(
|
||||||
|
available_models=available_models,
|
||||||
|
current_model_id=current_model_id or available_models[0].model_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Could not build ACP model state", exc_info=True)
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
return None
|
||||||
|
|
||||||
|
fallback_choice = self._encode_model_choice(provider, model)
|
||||||
|
return SessionModelState(
|
||||||
|
available_models=[ModelInfo(model_id=fallback_choice, name=model)],
|
||||||
|
current_model_id=fallback_choice,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_model_selection(raw_model: str, current_provider: str) -> tuple[str, str]:
|
||||||
|
"""Resolve ``provider:model`` input into the provider and normalized model id."""
|
||||||
|
target_provider = current_provider
|
||||||
|
new_model = raw_model.strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from hermes_cli.models import detect_provider_for_model, parse_model_input
|
||||||
|
|
||||||
|
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||||
|
if target_provider == current_provider:
|
||||||
|
detected = detect_provider_for_model(new_model, current_provider)
|
||||||
|
if detected:
|
||||||
|
target_provider, new_model = detected
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||||
|
|
||||||
|
return target_provider, new_model
|
||||||
|
|
||||||
async def _register_session_mcp_servers(
|
async def _register_session_mcp_servers(
|
||||||
self,
|
self,
|
||||||
state: SessionState,
|
state: SessionState,
|
||||||
|
|
@ -273,7 +367,10 @@ class HermesACPAgent(acp.Agent):
|
||||||
await self._register_session_mcp_servers(state, mcp_servers)
|
await self._register_session_mcp_servers(state, mcp_servers)
|
||||||
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
||||||
self._schedule_available_commands_update(state.session_id)
|
self._schedule_available_commands_update(state.session_id)
|
||||||
return NewSessionResponse(session_id=state.session_id)
|
return NewSessionResponse(
|
||||||
|
session_id=state.session_id,
|
||||||
|
models=self._build_model_state(state),
|
||||||
|
)
|
||||||
|
|
||||||
async def load_session(
|
async def load_session(
|
||||||
self,
|
self,
|
||||||
|
|
@ -289,7 +386,7 @@ class HermesACPAgent(acp.Agent):
|
||||||
await self._register_session_mcp_servers(state, mcp_servers)
|
await self._register_session_mcp_servers(state, mcp_servers)
|
||||||
logger.info("Loaded session %s", session_id)
|
logger.info("Loaded session %s", session_id)
|
||||||
self._schedule_available_commands_update(session_id)
|
self._schedule_available_commands_update(session_id)
|
||||||
return LoadSessionResponse()
|
return LoadSessionResponse(models=self._build_model_state(state))
|
||||||
|
|
||||||
async def resume_session(
|
async def resume_session(
|
||||||
self,
|
self,
|
||||||
|
|
@ -305,7 +402,7 @@ class HermesACPAgent(acp.Agent):
|
||||||
await self._register_session_mcp_servers(state, mcp_servers)
|
await self._register_session_mcp_servers(state, mcp_servers)
|
||||||
logger.info("Resumed session %s", state.session_id)
|
logger.info("Resumed session %s", state.session_id)
|
||||||
self._schedule_available_commands_update(state.session_id)
|
self._schedule_available_commands_update(state.session_id)
|
||||||
return ResumeSessionResponse()
|
return ResumeSessionResponse(models=self._build_model_state(state))
|
||||||
|
|
||||||
async def cancel(self, session_id: str, **kwargs: Any) -> None:
|
async def cancel(self, session_id: str, **kwargs: Any) -> None:
|
||||||
state = self.session_manager.get_session(session_id)
|
state = self.session_manager.get_session(session_id)
|
||||||
|
|
@ -340,11 +437,20 @@ class HermesACPAgent(acp.Agent):
|
||||||
cwd: str | None = None,
|
cwd: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ListSessionsResponse:
|
) -> ListSessionsResponse:
|
||||||
infos = self.session_manager.list_sessions()
|
infos = self.session_manager.list_sessions(cwd=cwd)
|
||||||
sessions = [
|
sessions = []
|
||||||
SessionInfo(session_id=s["session_id"], cwd=s["cwd"])
|
for s in infos:
|
||||||
for s in infos
|
updated_at = s.get("updated_at")
|
||||||
]
|
if updated_at is not None and not isinstance(updated_at, str):
|
||||||
|
updated_at = str(updated_at)
|
||||||
|
sessions.append(
|
||||||
|
SessionInfo(
|
||||||
|
session_id=s["session_id"],
|
||||||
|
cwd=s["cwd"],
|
||||||
|
title=s.get("title"),
|
||||||
|
updated_at=updated_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
return ListSessionsResponse(sessions=sessions)
|
return ListSessionsResponse(sessions=sessions)
|
||||||
|
|
||||||
# ---- Prompt (core) ------------------------------------------------------
|
# ---- Prompt (core) ------------------------------------------------------
|
||||||
|
|
@ -389,12 +495,13 @@ class HermesACPAgent(acp.Agent):
|
||||||
state.cancel_event.clear()
|
state.cancel_event.clear()
|
||||||
|
|
||||||
tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
|
tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
|
||||||
|
tool_call_meta: dict[str, dict[str, Any]] = {}
|
||||||
previous_approval_cb = None
|
previous_approval_cb = None
|
||||||
|
|
||||||
if conn:
|
if conn:
|
||||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids)
|
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||||
thinking_cb = make_thinking_cb(conn, session_id, loop)
|
thinking_cb = make_thinking_cb(conn, session_id, loop)
|
||||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids)
|
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||||
message_cb = make_message_cb(conn, session_id, loop)
|
message_cb = make_message_cb(conn, session_id, loop)
|
||||||
approval_cb = make_approval_callback(conn.request_permission, loop, session_id)
|
approval_cb = make_approval_callback(conn.request_permission, loop, session_id)
|
||||||
else:
|
else:
|
||||||
|
|
@ -449,6 +556,19 @@ class HermesACPAgent(acp.Agent):
|
||||||
self.session_manager.save_session(session_id)
|
self.session_manager.save_session(session_id)
|
||||||
|
|
||||||
final_response = result.get("final_response", "")
|
final_response = result.get("final_response", "")
|
||||||
|
if final_response:
|
||||||
|
try:
|
||||||
|
from agent.title_generator import maybe_auto_title
|
||||||
|
|
||||||
|
maybe_auto_title(
|
||||||
|
self.session_manager._get_db(),
|
||||||
|
session_id,
|
||||||
|
user_text,
|
||||||
|
final_response,
|
||||||
|
state.history,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True)
|
||||||
if final_response and conn:
|
if final_response and conn:
|
||||||
update = acp.update_agent_message_text(final_response)
|
update = acp.update_agent_message_text(final_response)
|
||||||
await conn.session_update(session_id, update)
|
await conn.session_update(session_id, update)
|
||||||
|
|
@ -556,27 +676,15 @@ class HermesACPAgent(acp.Agent):
|
||||||
provider = getattr(state.agent, "provider", None) or "auto"
|
provider = getattr(state.agent, "provider", None) or "auto"
|
||||||
return f"Current model: {model}\nProvider: {provider}"
|
return f"Current model: {model}\nProvider: {provider}"
|
||||||
|
|
||||||
new_model = args.strip()
|
|
||||||
target_provider = None
|
|
||||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||||
|
target_provider, new_model = self._resolve_model_selection(args, current_provider)
|
||||||
# Auto-detect provider for the requested model
|
|
||||||
try:
|
|
||||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
|
||||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
|
||||||
if target_provider == current_provider:
|
|
||||||
detected = detect_provider_for_model(new_model, current_provider)
|
|
||||||
if detected:
|
|
||||||
target_provider, new_model = detected
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
|
||||||
|
|
||||||
state.model = new_model
|
state.model = new_model
|
||||||
state.agent = self.session_manager._make_agent(
|
state.agent = self.session_manager._make_agent(
|
||||||
session_id=state.session_id,
|
session_id=state.session_id,
|
||||||
cwd=state.cwd,
|
cwd=state.cwd,
|
||||||
model=new_model,
|
model=new_model,
|
||||||
requested_provider=target_provider or current_provider,
|
requested_provider=target_provider,
|
||||||
)
|
)
|
||||||
self.session_manager.save_session(state.session_id)
|
self.session_manager.save_session(state.session_id)
|
||||||
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
||||||
|
|
@ -678,20 +786,30 @@ class HermesACPAgent(acp.Agent):
|
||||||
"""Switch the model for a session (called by ACP protocol)."""
|
"""Switch the model for a session (called by ACP protocol)."""
|
||||||
state = self.session_manager.get_session(session_id)
|
state = self.session_manager.get_session(session_id)
|
||||||
if state:
|
if state:
|
||||||
state.model = model_id
|
|
||||||
current_provider = getattr(state.agent, "provider", None)
|
current_provider = getattr(state.agent, "provider", None)
|
||||||
current_base_url = getattr(state.agent, "base_url", None)
|
requested_provider, resolved_model = self._resolve_model_selection(
|
||||||
current_api_mode = getattr(state.agent, "api_mode", None)
|
model_id,
|
||||||
|
current_provider or "openrouter",
|
||||||
|
)
|
||||||
|
state.model = resolved_model
|
||||||
|
provider_changed = bool(current_provider and requested_provider != current_provider)
|
||||||
|
current_base_url = None if provider_changed else getattr(state.agent, "base_url", None)
|
||||||
|
current_api_mode = None if provider_changed else getattr(state.agent, "api_mode", None)
|
||||||
state.agent = self.session_manager._make_agent(
|
state.agent = self.session_manager._make_agent(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
cwd=state.cwd,
|
cwd=state.cwd,
|
||||||
model=model_id,
|
model=resolved_model,
|
||||||
requested_provider=current_provider,
|
requested_provider=requested_provider,
|
||||||
base_url=current_base_url,
|
base_url=current_base_url,
|
||||||
api_mode=current_api_mode,
|
api_mode=current_api_mode,
|
||||||
)
|
)
|
||||||
self.session_manager.save_session(session_id)
|
self.session_manager.save_session(session_id)
|
||||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
logger.info(
|
||||||
|
"Session %s: model switched to %s via provider %s",
|
||||||
|
session_id,
|
||||||
|
resolved_model,
|
||||||
|
requested_provider,
|
||||||
|
)
|
||||||
return SetSessionModelResponse()
|
return SetSessionModelResponse()
|
||||||
logger.warning("Session %s: model switch requested for missing session", session_id)
|
logger.warning("Session %s: model switch requested for missing session", session_id)
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
|
|
@ -13,8 +13,12 @@ from hermes_constants import get_hermes_home
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
@ -22,6 +26,64 @@ from typing import Any, Dict, List, Optional
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_cwd_for_compare(cwd: str | None) -> str:
|
||||||
|
raw = str(cwd or ".").strip()
|
||||||
|
if not raw:
|
||||||
|
raw = "."
|
||||||
|
expanded = os.path.expanduser(raw)
|
||||||
|
|
||||||
|
# Normalize Windows drive paths into the equivalent WSL mount form so
|
||||||
|
# ACP history filters match the same workspace across Windows and WSL.
|
||||||
|
match = re.match(r"^([A-Za-z]):[\\/](.*)$", expanded)
|
||||||
|
if match:
|
||||||
|
drive = match.group(1).lower()
|
||||||
|
tail = match.group(2).replace("\\", "/")
|
||||||
|
expanded = f"/mnt/{drive}/{tail}"
|
||||||
|
elif re.match(r"^/mnt/[A-Za-z]/", expanded):
|
||||||
|
expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}"
|
||||||
|
|
||||||
|
return os.path.normpath(expanded)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str:
|
||||||
|
explicit = str(title or "").strip()
|
||||||
|
if explicit:
|
||||||
|
return explicit
|
||||||
|
preview_text = str(preview or "").strip()
|
||||||
|
if preview_text:
|
||||||
|
return preview_text
|
||||||
|
leaf = os.path.basename(str(cwd or "").rstrip("/\\"))
|
||||||
|
return leaf or "New thread"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_updated_at(value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str) and value.strip():
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _updated_at_sort_key(value: Any) -> float:
|
||||||
|
if value is None:
|
||||||
|
return float("-inf")
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return float(value)
|
||||||
|
raw = str(value).strip()
|
||||||
|
if not raw:
|
||||||
|
return float("-inf")
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
return float(raw)
|
||||||
|
except Exception:
|
||||||
|
return float("-inf")
|
||||||
|
|
||||||
|
|
||||||
def _acp_stderr_print(*args, **kwargs) -> None:
|
def _acp_stderr_print(*args, **kwargs) -> None:
|
||||||
"""Best-effort human-readable output sink for ACP stdio sessions.
|
"""Best-effort human-readable output sink for ACP stdio sessions.
|
||||||
|
|
||||||
|
|
@ -162,47 +224,78 @@ class SessionManager:
|
||||||
logger.info("Forked ACP session %s -> %s", session_id, new_id)
|
logger.info("Forked ACP session %s -> %s", session_id, new_id)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def list_sessions(self) -> List[Dict[str, Any]]:
|
def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]:
|
||||||
"""Return lightweight info dicts for all sessions (memory + database)."""
|
"""Return lightweight info dicts for all sessions (memory + database)."""
|
||||||
|
normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None
|
||||||
|
db = self._get_db()
|
||||||
|
persisted_rows: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
if db is not None:
|
||||||
|
try:
|
||||||
|
for row in db.list_sessions_rich(source="acp", limit=1000):
|
||||||
|
persisted_rows[str(row["id"])] = dict(row)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Failed to load ACP sessions from DB", exc_info=True)
|
||||||
|
|
||||||
# Collect in-memory sessions first.
|
# Collect in-memory sessions first.
|
||||||
with self._lock:
|
with self._lock:
|
||||||
seen_ids = set(self._sessions.keys())
|
seen_ids = set(self._sessions.keys())
|
||||||
results = [
|
results = []
|
||||||
{
|
for s in self._sessions.values():
|
||||||
"session_id": s.session_id,
|
history_len = len(s.history)
|
||||||
"cwd": s.cwd,
|
if history_len <= 0:
|
||||||
"model": s.model,
|
continue
|
||||||
"history_len": len(s.history),
|
if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd:
|
||||||
}
|
continue
|
||||||
for s in self._sessions.values()
|
persisted = persisted_rows.get(s.session_id, {})
|
||||||
]
|
preview = next(
|
||||||
|
(
|
||||||
|
str(msg.get("content") or "").strip()
|
||||||
|
for msg in s.history
|
||||||
|
if msg.get("role") == "user" and str(msg.get("content") or "").strip()
|
||||||
|
),
|
||||||
|
persisted.get("preview") or "",
|
||||||
|
)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"session_id": s.session_id,
|
||||||
|
"cwd": s.cwd,
|
||||||
|
"model": s.model,
|
||||||
|
"history_len": history_len,
|
||||||
|
"title": _build_session_title(persisted.get("title"), preview, s.cwd),
|
||||||
|
"updated_at": _format_updated_at(
|
||||||
|
persisted.get("last_active") or persisted.get("started_at") or time.time()
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Merge any persisted sessions not currently in memory.
|
# Merge any persisted sessions not currently in memory.
|
||||||
db = self._get_db()
|
for sid, row in persisted_rows.items():
|
||||||
if db is not None:
|
if sid in seen_ids:
|
||||||
try:
|
continue
|
||||||
rows = db.search_sessions(source="acp", limit=1000)
|
message_count = int(row.get("message_count") or 0)
|
||||||
for row in rows:
|
if message_count <= 0:
|
||||||
sid = row["id"]
|
continue
|
||||||
if sid in seen_ids:
|
# Extract cwd from model_config JSON.
|
||||||
continue
|
session_cwd = "."
|
||||||
# Extract cwd from model_config JSON.
|
mc = row.get("model_config")
|
||||||
cwd = "."
|
if mc:
|
||||||
mc = row.get("model_config")
|
try:
|
||||||
if mc:
|
session_cwd = json.loads(mc).get("cwd", ".")
|
||||||
try:
|
except (json.JSONDecodeError, TypeError):
|
||||||
cwd = json.loads(mc).get("cwd", ".")
|
pass
|
||||||
except (json.JSONDecodeError, TypeError):
|
if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd:
|
||||||
pass
|
continue
|
||||||
results.append({
|
results.append({
|
||||||
"session_id": sid,
|
"session_id": sid,
|
||||||
"cwd": cwd,
|
"cwd": session_cwd,
|
||||||
"model": row.get("model") or "",
|
"model": row.get("model") or "",
|
||||||
"history_len": row.get("message_count") or 0,
|
"history_len": message_count,
|
||||||
})
|
"title": _build_session_title(row.get("title"), row.get("preview"), session_cwd),
|
||||||
except Exception:
|
"updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")),
|
||||||
logger.debug("Failed to list ACP sessions from DB", exc_info=True)
|
})
|
||||||
|
|
||||||
|
results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
|
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
@ -96,6 +97,170 @@ def build_tool_title(tool_name: str, args: Dict[str, Any]) -> str:
|
||||||
return tool_name
|
return tool_name
|
||||||
|
|
||||||
|
|
||||||
|
def _build_patch_mode_content(patch_text: str) -> List[Any]:
|
||||||
|
"""Parse V4A patch mode input into ACP diff blocks when possible."""
|
||||||
|
if not patch_text:
|
||||||
|
return [acp.tool_content(acp.text_block(""))]
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tools.patch_parser import OperationType, parse_v4a_patch
|
||||||
|
|
||||||
|
operations, error = parse_v4a_patch(patch_text)
|
||||||
|
if error or not operations:
|
||||||
|
return [acp.tool_content(acp.text_block(patch_text))]
|
||||||
|
|
||||||
|
content: List[Any] = []
|
||||||
|
for op in operations:
|
||||||
|
if op.operation == OperationType.UPDATE:
|
||||||
|
old_chunks: list[str] = []
|
||||||
|
new_chunks: list[str] = []
|
||||||
|
for hunk in op.hunks:
|
||||||
|
old_lines = [line.content for line in hunk.lines if line.prefix in (" ", "-")]
|
||||||
|
new_lines = [line.content for line in hunk.lines if line.prefix in (" ", "+")]
|
||||||
|
if old_lines or new_lines:
|
||||||
|
old_chunks.append("\n".join(old_lines))
|
||||||
|
new_chunks.append("\n".join(new_lines))
|
||||||
|
|
||||||
|
old_text = "\n...\n".join(chunk for chunk in old_chunks if chunk)
|
||||||
|
new_text = "\n...\n".join(chunk for chunk in new_chunks if chunk)
|
||||||
|
if old_text or new_text:
|
||||||
|
content.append(
|
||||||
|
acp.tool_diff_content(
|
||||||
|
path=op.file_path,
|
||||||
|
old_text=old_text or None,
|
||||||
|
new_text=new_text or "",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if op.operation == OperationType.ADD:
|
||||||
|
added_lines = [line.content for hunk in op.hunks for line in hunk.lines if line.prefix == "+"]
|
||||||
|
content.append(
|
||||||
|
acp.tool_diff_content(
|
||||||
|
path=op.file_path,
|
||||||
|
new_text="\n".join(added_lines),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if op.operation == OperationType.DELETE:
|
||||||
|
content.append(
|
||||||
|
acp.tool_diff_content(
|
||||||
|
path=op.file_path,
|
||||||
|
old_text=f"Delete file: {op.file_path}",
|
||||||
|
new_text="",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if op.operation == OperationType.MOVE:
|
||||||
|
content.append(
|
||||||
|
acp.tool_content(acp.text_block(f"Move file: {op.file_path} -> {op.new_path}"))
|
||||||
|
)
|
||||||
|
|
||||||
|
return content or [acp.tool_content(acp.text_block(patch_text))]
|
||||||
|
except Exception:
|
||||||
|
return [acp.tool_content(acp.text_block(patch_text))]
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_diff_prefix(path: str) -> str:
|
||||||
|
raw = str(path or "").strip()
|
||||||
|
if raw.startswith(("a/", "b/")):
|
||||||
|
return raw[2:]
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_unified_diff_content(diff_text: str) -> List[Any]:
|
||||||
|
"""Convert unified diff text into ACP diff content blocks."""
|
||||||
|
if not diff_text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
content: List[Any] = []
|
||||||
|
current_old_path: Optional[str] = None
|
||||||
|
current_new_path: Optional[str] = None
|
||||||
|
old_lines: list[str] = []
|
||||||
|
new_lines: list[str] = []
|
||||||
|
|
||||||
|
def _flush() -> None:
|
||||||
|
nonlocal current_old_path, current_new_path, old_lines, new_lines
|
||||||
|
if current_old_path is None and current_new_path is None:
|
||||||
|
return
|
||||||
|
path = current_new_path if current_new_path and current_new_path != "/dev/null" else current_old_path
|
||||||
|
if not path or path == "/dev/null":
|
||||||
|
current_old_path = None
|
||||||
|
current_new_path = None
|
||||||
|
old_lines = []
|
||||||
|
new_lines = []
|
||||||
|
return
|
||||||
|
content.append(
|
||||||
|
acp.tool_diff_content(
|
||||||
|
path=_strip_diff_prefix(path),
|
||||||
|
old_text="\n".join(old_lines) if old_lines else None,
|
||||||
|
new_text="\n".join(new_lines),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_old_path = None
|
||||||
|
current_new_path = None
|
||||||
|
old_lines = []
|
||||||
|
new_lines = []
|
||||||
|
|
||||||
|
for line in diff_text.splitlines():
|
||||||
|
if line.startswith("--- "):
|
||||||
|
_flush()
|
||||||
|
current_old_path = line[4:].strip()
|
||||||
|
continue
|
||||||
|
if line.startswith("+++ "):
|
||||||
|
current_new_path = line[4:].strip()
|
||||||
|
continue
|
||||||
|
if line.startswith("@@"):
|
||||||
|
continue
|
||||||
|
if current_old_path is None and current_new_path is None:
|
||||||
|
continue
|
||||||
|
if line.startswith("+"):
|
||||||
|
new_lines.append(line[1:])
|
||||||
|
elif line.startswith("-"):
|
||||||
|
old_lines.append(line[1:])
|
||||||
|
elif line.startswith(" "):
|
||||||
|
shared = line[1:]
|
||||||
|
old_lines.append(shared)
|
||||||
|
new_lines.append(shared)
|
||||||
|
|
||||||
|
_flush()
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tool_complete_content(
|
||||||
|
tool_name: str,
|
||||||
|
result: Optional[str],
|
||||||
|
*,
|
||||||
|
function_args: Optional[Dict[str, Any]] = None,
|
||||||
|
snapshot: Any = None,
|
||||||
|
) -> List[Any]:
|
||||||
|
"""Build structured ACP completion content, falling back to plain text."""
|
||||||
|
display_result = result or ""
|
||||||
|
if len(display_result) > 5000:
|
||||||
|
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||||
|
|
||||||
|
if tool_name in {"write_file", "patch", "skill_manage"}:
|
||||||
|
try:
|
||||||
|
from agent.display import extract_edit_diff
|
||||||
|
|
||||||
|
diff_text = extract_edit_diff(
|
||||||
|
tool_name,
|
||||||
|
result,
|
||||||
|
function_args=function_args,
|
||||||
|
snapshot=snapshot,
|
||||||
|
)
|
||||||
|
if isinstance(diff_text, str) and diff_text.strip():
|
||||||
|
diff_content = _parse_unified_diff_content(diff_text)
|
||||||
|
if diff_content:
|
||||||
|
return diff_content
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return [acp.tool_content(acp.text_block(display_result))]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Build ACP content objects for tool-call events
|
# Build ACP content objects for tool-call events
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -119,9 +284,8 @@ def build_tool_start(
|
||||||
new = arguments.get("new_string", "")
|
new = arguments.get("new_string", "")
|
||||||
content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)]
|
content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)]
|
||||||
else:
|
else:
|
||||||
# Patch mode — show the patch content as text
|
|
||||||
patch_text = arguments.get("patch", "")
|
patch_text = arguments.get("patch", "")
|
||||||
content = [acp.tool_content(acp.text_block(patch_text))]
|
content = _build_patch_mode_content(patch_text)
|
||||||
return acp.start_tool_call(
|
return acp.start_tool_call(
|
||||||
tool_call_id, title, kind=kind, content=content, locations=locations,
|
tool_call_id, title, kind=kind, content=content, locations=locations,
|
||||||
raw_input=arguments,
|
raw_input=arguments,
|
||||||
|
|
@ -178,16 +342,17 @@ def build_tool_complete(
|
||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
tool_name: str,
|
tool_name: str,
|
||||||
result: Optional[str] = None,
|
result: Optional[str] = None,
|
||||||
|
function_args: Optional[Dict[str, Any]] = None,
|
||||||
|
snapshot: Any = None,
|
||||||
) -> ToolCallProgress:
|
) -> ToolCallProgress:
|
||||||
"""Create a ToolCallUpdate (progress) event for a completed tool call."""
|
"""Create a ToolCallUpdate (progress) event for a completed tool call."""
|
||||||
kind = get_tool_kind(tool_name)
|
kind = get_tool_kind(tool_name)
|
||||||
|
content = _build_tool_complete_content(
|
||||||
# Truncate very large results for the UI
|
tool_name,
|
||||||
display_result = result or ""
|
result,
|
||||||
if len(display_result) > 5000:
|
function_args=function_args,
|
||||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
snapshot=snapshot,
|
||||||
|
)
|
||||||
content = [acp.tool_content(acp.text_block(display_result))]
|
|
||||||
return acp.update_tool_call(
|
return acp.update_tool_call(
|
||||||
tool_call_id,
|
tool_call_id,
|
||||||
kind=kind,
|
kind=kind,
|
||||||
|
|
|
||||||
|
|
@ -76,8 +76,8 @@ termux = [
|
||||||
"hermes-agent[honcho]",
|
"hermes-agent[honcho]",
|
||||||
"hermes-agent[acp]",
|
"hermes-agent[acp]",
|
||||||
]
|
]
|
||||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
dingtalk = ["dingtalk-stream>=0.1.0,<1", "qrcode>=7.0,<8"]
|
||||||
feishu = ["lark-oapi>=1.5.3,<2"]
|
feishu = ["lark-oapi>=1.5.3,<2", "qrcode>=7.0,<8"]
|
||||||
web = ["fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1"]
|
web = ["fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1"]
|
||||||
rl = [
|
rl = [
|
||||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30",
|
"atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30",
|
||||||
|
|
|
||||||
109
run_agent.py
109
run_agent.py
|
|
@ -353,12 +353,50 @@ def _sanitize_surrogates(text: str) -> str:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_structure_surrogates(payload: Any) -> bool:
|
||||||
|
"""Replace surrogate code points in nested dict/list payloads in-place.
|
||||||
|
|
||||||
|
Mirror of ``_sanitize_structure_non_ascii`` but for surrogate recovery.
|
||||||
|
Used to scrub nested structured fields (e.g. ``reasoning_details`` — an
|
||||||
|
array of dicts with ``summary``/``text`` strings) that flat per-field
|
||||||
|
checks don't reach. Returns True if any surrogates were replaced.
|
||||||
|
"""
|
||||||
|
found = False
|
||||||
|
|
||||||
|
def _walk(node):
|
||||||
|
nonlocal found
|
||||||
|
if isinstance(node, dict):
|
||||||
|
for key, value in node.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
if _SURROGATE_RE.search(value):
|
||||||
|
node[key] = _SURROGATE_RE.sub('\ufffd', value)
|
||||||
|
found = True
|
||||||
|
elif isinstance(value, (dict, list)):
|
||||||
|
_walk(value)
|
||||||
|
elif isinstance(node, list):
|
||||||
|
for idx, value in enumerate(node):
|
||||||
|
if isinstance(value, str):
|
||||||
|
if _SURROGATE_RE.search(value):
|
||||||
|
node[idx] = _SURROGATE_RE.sub('\ufffd', value)
|
||||||
|
found = True
|
||||||
|
elif isinstance(value, (dict, list)):
|
||||||
|
_walk(value)
|
||||||
|
|
||||||
|
_walk(payload)
|
||||||
|
return found
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_messages_surrogates(messages: list) -> bool:
|
def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||||
"""Sanitize surrogate characters from all string content in a messages list.
|
"""Sanitize surrogate characters from all string content in a messages list.
|
||||||
|
|
||||||
Walks message dicts in-place. Returns True if any surrogates were found
|
Walks message dicts in-place. Returns True if any surrogates were found
|
||||||
and replaced, False otherwise. Covers content/text, name, and tool call
|
and replaced, False otherwise. Covers content/text, name, tool call
|
||||||
metadata/arguments so retries don't fail on a non-content field.
|
metadata/arguments, AND any additional string or nested structured fields
|
||||||
|
(``reasoning``, ``reasoning_content``, ``reasoning_details``, etc.) so
|
||||||
|
retries don't fail on a non-content field. Byte-level reasoning models
|
||||||
|
(xiaomi/mimo, kimi, glm) can emit lone surrogates in reasoning output
|
||||||
|
that flow through to ``api_messages["reasoning_content"]`` on the next
|
||||||
|
turn and crash json.dumps inside the OpenAI SDK.
|
||||||
"""
|
"""
|
||||||
found = False
|
found = False
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
|
|
@ -398,6 +436,21 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||||
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
|
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
|
||||||
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
|
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
|
||||||
found = True
|
found = True
|
||||||
|
# Walk any additional string / nested fields (reasoning,
|
||||||
|
# reasoning_content, reasoning_details, etc.) — surrogates from
|
||||||
|
# byte-level reasoning models (xiaomi/mimo, kimi, glm) can lurk
|
||||||
|
# in these fields and aren't covered by the per-field checks above.
|
||||||
|
# Matches _sanitize_messages_non_ascii's coverage (PR #10537).
|
||||||
|
for key, value in msg.items():
|
||||||
|
if key in {"content", "name", "tool_calls", "role"}:
|
||||||
|
continue
|
||||||
|
if isinstance(value, str):
|
||||||
|
if _SURROGATE_RE.search(value):
|
||||||
|
msg[key] = _SURROGATE_RE.sub('\ufffd', value)
|
||||||
|
found = True
|
||||||
|
elif isinstance(value, (dict, list)):
|
||||||
|
if _sanitize_structure_surrogates(value):
|
||||||
|
found = True
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -8689,6 +8742,7 @@ class AIAgent:
|
||||||
{
|
{
|
||||||
"name": tc["function"]["name"],
|
"name": tc["function"]["name"],
|
||||||
"result": _results_by_id.get(tc.get("id")),
|
"result": _results_by_id.get(tc.get("id")),
|
||||||
|
"arguments": tc["function"].get("arguments"),
|
||||||
}
|
}
|
||||||
for tc in _m["tool_calls"]
|
for tc in _m["tool_calls"]
|
||||||
if isinstance(tc, dict)
|
if isinstance(tc, dict)
|
||||||
|
|
@ -9303,8 +9357,7 @@ class AIAgent:
|
||||||
"and had none left for the actual response.\n\n"
|
"and had none left for the actual response.\n\n"
|
||||||
"To fix this:\n"
|
"To fix this:\n"
|
||||||
"→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n"
|
"→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n"
|
||||||
"→ Increase the output token limit: "
|
"→ Or switch to a larger/non-reasoning model with `/model`"
|
||||||
"set `model.max_tokens` in config.yaml"
|
|
||||||
)
|
)
|
||||||
self._cleanup_task_resources(effective_task_id)
|
self._cleanup_task_resources(effective_task_id)
|
||||||
self._persist_session(messages, conversation_history)
|
self._persist_session(messages, conversation_history)
|
||||||
|
|
@ -9571,13 +9624,51 @@ class AIAgent:
|
||||||
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
|
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
|
||||||
_err_str = str(api_error).lower()
|
_err_str = str(api_error).lower()
|
||||||
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
|
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
|
||||||
|
# Detect surrogate errors — utf-8 codec refusing to
|
||||||
|
# encode U+D800..U+DFFF. The error text is:
|
||||||
|
# "'utf-8' codec can't encode characters in position
|
||||||
|
# N-M: surrogates not allowed"
|
||||||
|
_is_surrogate_error = (
|
||||||
|
"surrogate" in _err_str
|
||||||
|
or ("'utf-8'" in _err_str and not _is_ascii_codec)
|
||||||
|
)
|
||||||
|
# Sanitize surrogates from both the canonical `messages`
|
||||||
|
# list AND `api_messages` (the API-copy, which may carry
|
||||||
|
# `reasoning_content`/`reasoning_details` transformed
|
||||||
|
# from `reasoning` — fields the canonical list doesn't
|
||||||
|
# have directly). Also clean `api_kwargs` if built and
|
||||||
|
# `prefill_messages` if present. Mirrors the ASCII
|
||||||
|
# codec recovery below.
|
||||||
_surrogates_found = _sanitize_messages_surrogates(messages)
|
_surrogates_found = _sanitize_messages_surrogates(messages)
|
||||||
if _surrogates_found:
|
if isinstance(api_messages, list):
|
||||||
|
if _sanitize_messages_surrogates(api_messages):
|
||||||
|
_surrogates_found = True
|
||||||
|
if isinstance(api_kwargs, dict):
|
||||||
|
if _sanitize_structure_surrogates(api_kwargs):
|
||||||
|
_surrogates_found = True
|
||||||
|
if isinstance(getattr(self, "prefill_messages", None), list):
|
||||||
|
if _sanitize_messages_surrogates(self.prefill_messages):
|
||||||
|
_surrogates_found = True
|
||||||
|
# Gate the retry on the error type, not on whether we
|
||||||
|
# found anything — _force_ascii_payload / the extended
|
||||||
|
# surrogate walker above cover all known paths, but a
|
||||||
|
# new transformed field could still slip through. If
|
||||||
|
# the error was a surrogate encode failure, always let
|
||||||
|
# the retry run; the proactive sanitizer at line ~8781
|
||||||
|
# runs again on the next iteration. Bounded by
|
||||||
|
# _unicode_sanitization_passes < 2 (outer guard).
|
||||||
|
if _surrogates_found or _is_surrogate_error:
|
||||||
self._unicode_sanitization_passes += 1
|
self._unicode_sanitization_passes += 1
|
||||||
self._vprint(
|
if _surrogates_found:
|
||||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
self._vprint(
|
||||||
force=True,
|
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||||
)
|
force=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._vprint(
|
||||||
|
f"{self.log_prefix}⚠️ Surrogate encoding error — retrying after full-payload sanitization...",
|
||||||
|
force=True,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
if _is_ascii_codec:
|
if _is_ascii_codec:
|
||||||
self._force_ascii_payload = True
|
self._force_ascii_payload = True
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,7 @@ AUTHOR_MAP = {
|
||||||
"dangtc94@gmail.com": "dieutx",
|
"dangtc94@gmail.com": "dieutx",
|
||||||
"jaisehgal11299@gmail.com": "jaisup",
|
"jaisehgal11299@gmail.com": "jaisup",
|
||||||
"percydikec@gmail.com": "PercyDikec",
|
"percydikec@gmail.com": "PercyDikec",
|
||||||
|
"noonou7@gmail.com": "HenkDz",
|
||||||
"dean.kerr@gmail.com": "deankerr",
|
"dean.kerr@gmail.com": "deankerr",
|
||||||
"socrates1024@gmail.com": "socrates1024",
|
"socrates1024@gmail.com": "socrates1024",
|
||||||
"satelerd@gmail.com": "satelerd",
|
"satelerd@gmail.com": "satelerd",
|
||||||
|
|
|
||||||
|
|
@ -42,9 +42,10 @@ class TestToolProgressCallback:
|
||||||
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
|
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
|
||||||
"""Tool progress should emit a ToolCallStart update."""
|
"""Tool progress should emit a ToolCallStart update."""
|
||||||
tool_call_ids = {}
|
tool_call_ids = {}
|
||||||
|
tool_call_meta = {}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||||
|
|
||||||
# Run callback in the event loop context
|
# Run callback in the event loop context
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||||
|
|
@ -66,9 +67,10 @@ class TestToolProgressCallback:
|
||||||
def test_handles_string_args(self, mock_conn, event_loop_fixture):
|
def test_handles_string_args(self, mock_conn, event_loop_fixture):
|
||||||
"""If args is a JSON string, it should be parsed."""
|
"""If args is a JSON string, it should be parsed."""
|
||||||
tool_call_ids = {}
|
tool_call_ids = {}
|
||||||
|
tool_call_meta = {}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||||
|
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||||
future = MagicMock(spec=Future)
|
future = MagicMock(spec=Future)
|
||||||
|
|
@ -82,9 +84,10 @@ class TestToolProgressCallback:
|
||||||
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
|
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
|
||||||
"""If args is not a dict, it should be wrapped."""
|
"""If args is not a dict, it should be wrapped."""
|
||||||
tool_call_ids = {}
|
tool_call_ids = {}
|
||||||
|
tool_call_meta = {}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||||
|
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||||
future = MagicMock(spec=Future)
|
future = MagicMock(spec=Future)
|
||||||
|
|
@ -98,10 +101,11 @@ class TestToolProgressCallback:
|
||||||
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
|
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
|
||||||
"""Multiple same-name tool calls should be tracked independently in order."""
|
"""Multiple same-name tool calls should be tracked independently in order."""
|
||||||
tool_call_ids = {}
|
tool_call_ids = {}
|
||||||
|
tool_call_meta = {}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||||
|
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||||
future = MagicMock(spec=Future)
|
future = MagicMock(spec=Future)
|
||||||
|
|
@ -163,7 +167,7 @@ class TestStepCallback:
|
||||||
tool_call_ids = {"terminal": "tc-abc123"}
|
tool_call_ids = {"terminal": "tc-abc123"}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||||
|
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||||
future = MagicMock(spec=Future)
|
future = MagicMock(spec=Future)
|
||||||
|
|
@ -181,7 +185,7 @@ class TestStepCallback:
|
||||||
tool_call_ids = {}
|
tool_call_ids = {}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||||
|
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||||
cb(1, [{"name": "unknown_tool", "result": "ok"}])
|
cb(1, [{"name": "unknown_tool", "result": "ok"}])
|
||||||
|
|
@ -193,7 +197,7 @@ class TestStepCallback:
|
||||||
tool_call_ids = {"read_file": "tc-def456"}
|
tool_call_ids = {"read_file": "tc-def456"}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||||
|
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||||
future = MagicMock(spec=Future)
|
future = MagicMock(spec=Future)
|
||||||
|
|
@ -212,7 +216,7 @@ class TestStepCallback:
|
||||||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||||
|
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||||
|
|
@ -224,7 +228,7 @@ class TestStepCallback:
|
||||||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||||
|
|
||||||
mock_btc.assert_called_once_with(
|
mock_btc.assert_called_once_with(
|
||||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
"tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||||
|
|
@ -234,7 +238,7 @@ class TestStepCallback:
|
||||||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||||
loop = event_loop_fixture
|
loop = event_loop_fixture
|
||||||
|
|
||||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||||
|
|
||||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||||
|
|
@ -244,7 +248,50 @@ class TestStepCallback:
|
||||||
|
|
||||||
cb(1, [{"name": "web_search", "result": None}])
|
cb(1, [{"name": "web_search", "result": None}])
|
||||||
|
|
||||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None)
|
||||||
|
|
||||||
|
def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture):
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
tool_call_ids = {"write_file": deque(["tc-write"])}
|
||||||
|
tool_call_meta = {"tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"}}
|
||||||
|
loop = event_loop_fixture
|
||||||
|
|
||||||
|
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||||
|
|
||||||
|
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||||
|
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||||
|
future = MagicMock(spec=Future)
|
||||||
|
future.result.return_value = None
|
||||||
|
mock_rcts.return_value = future
|
||||||
|
|
||||||
|
cb(1, [{"name": "write_file", "result": '{"bytes_written": 23}', "arguments": {"path": "diff-test.txt"}}])
|
||||||
|
|
||||||
|
mock_btc.assert_called_once_with(
|
||||||
|
"tc-write",
|
||||||
|
"write_file",
|
||||||
|
result='{"bytes_written": 23}',
|
||||||
|
function_args={"path": "diff-test.txt"},
|
||||||
|
snapshot="snap",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture):
|
||||||
|
tool_call_ids = {}
|
||||||
|
tool_call_meta = {}
|
||||||
|
loop = event_loop_fixture
|
||||||
|
|
||||||
|
with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"), \
|
||||||
|
patch("acp_adapter.events._send_update") as mock_send, \
|
||||||
|
patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"):
|
||||||
|
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||||
|
cb("tool.started", "write_file", None, {"path": "diff-test.txt", "content": "hello"})
|
||||||
|
|
||||||
|
assert list(tool_call_ids["write_file"]) == ["tc-meta"]
|
||||||
|
assert tool_call_meta["tc-meta"] == {
|
||||||
|
"args": {"path": "diff-test.txt", "content": "hello"},
|
||||||
|
"snapshot": "snapshot",
|
||||||
|
}
|
||||||
|
mock_send.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ from acp.schema import (
|
||||||
|
|
||||||
from acp_adapter.server import HermesACPAgent
|
from acp_adapter.server import HermesACPAgent
|
||||||
from acp_adapter.session import SessionManager
|
from acp_adapter.session import SessionManager
|
||||||
|
from acp_adapter.tools import build_tool_start
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -181,6 +182,25 @@ class TestMcpRegistrationE2E:
|
||||||
assert complete_event.raw_output is not None
|
assert complete_event.raw_output is not None
|
||||||
assert "hello" in str(complete_event.raw_output)
|
assert "hello" in str(complete_event.raw_output)
|
||||||
|
|
||||||
|
def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self):
|
||||||
|
update = build_tool_start(
|
||||||
|
"tc-1",
|
||||||
|
"patch",
|
||||||
|
{
|
||||||
|
"mode": "patch",
|
||||||
|
"patch": "*** Begin Patch\n*** Update File: src/app.py\n@@\n-old line\n+new line\n*** Add File: src/new.py\n+hello\n*** End Patch",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(update.content) == 2
|
||||||
|
assert update.content[0].type == "diff"
|
||||||
|
assert update.content[0].path == "src/app.py"
|
||||||
|
assert update.content[0].old_text == "old line"
|
||||||
|
assert update.content[0].new_text == "new line"
|
||||||
|
assert update.content[1].type == "diff"
|
||||||
|
assert update.content[1].path == "src/new.py"
|
||||||
|
assert update.content[1].new_text == "hello"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,9 @@ from acp.schema import (
|
||||||
NewSessionResponse,
|
NewSessionResponse,
|
||||||
PromptResponse,
|
PromptResponse,
|
||||||
ResumeSessionResponse,
|
ResumeSessionResponse,
|
||||||
|
SessionModelState,
|
||||||
SetSessionConfigOptionResponse,
|
SetSessionConfigOptionResponse,
|
||||||
|
SetSessionModelResponse,
|
||||||
SetSessionModeResponse,
|
SetSessionModeResponse,
|
||||||
SessionInfo,
|
SessionInfo,
|
||||||
TextContentBlock,
|
TextContentBlock,
|
||||||
|
|
@ -127,6 +129,25 @@ class TestSessionOps:
|
||||||
assert state is not None
|
assert state is not None
|
||||||
assert state.cwd == "/home/user/project"
|
assert state.cwd == "/home/user/project"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_session_returns_model_state(self):
|
||||||
|
manager = SessionManager(
|
||||||
|
agent_factory=lambda: SimpleNamespace(model="gpt-5.4", provider="openai-codex")
|
||||||
|
)
|
||||||
|
acp_agent = HermesACPAgent(session_manager=manager)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"hermes_cli.models.curated_models_for_provider",
|
||||||
|
return_value=[("gpt-5.4", "recommended"), ("gpt-5.4-mini", "")],
|
||||||
|
):
|
||||||
|
resp = await acp_agent.new_session(cwd="/tmp")
|
||||||
|
|
||||||
|
assert isinstance(resp.models, SessionModelState)
|
||||||
|
assert resp.models.current_model_id == "openai-codex:gpt-5.4"
|
||||||
|
assert resp.models.available_models[0].model_id == "openai-codex:gpt-5.4"
|
||||||
|
assert resp.models.available_models[0].description is not None
|
||||||
|
assert "Provider:" in resp.models.available_models[0].description
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_available_commands_include_help(self, agent):
|
async def test_available_commands_include_help(self, agent):
|
||||||
help_cmd = next(
|
help_cmd = next(
|
||||||
|
|
@ -204,6 +225,33 @@ class TestListAndFork:
|
||||||
assert fork_resp.session_id
|
assert fork_resp.session_id
|
||||||
assert fork_resp.session_id != new_resp.session_id
|
assert fork_resp.session_id != new_resp.session_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sessions_includes_title_and_updated_at(self, agent):
|
||||||
|
with patch.object(
|
||||||
|
agent.session_manager,
|
||||||
|
"list_sessions",
|
||||||
|
return_value=[
|
||||||
|
{
|
||||||
|
"session_id": "session-1",
|
||||||
|
"cwd": "/tmp/project",
|
||||||
|
"title": "Fix Zed session history",
|
||||||
|
"updated_at": 123.0,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
):
|
||||||
|
resp = await agent.list_sessions(cwd="/tmp/project")
|
||||||
|
|
||||||
|
assert isinstance(resp.sessions[0], SessionInfo)
|
||||||
|
assert resp.sessions[0].title == "Fix Zed session history"
|
||||||
|
assert resp.sessions[0].updated_at == "123.0"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sessions_passes_cwd_filter(self, agent):
|
||||||
|
with patch.object(agent.session_manager, "list_sessions", return_value=[]) as mock_list:
|
||||||
|
await agent.list_sessions(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||||
|
|
||||||
|
mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# session configuration / model routing
|
# session configuration / model routing
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -257,6 +305,53 @@ class TestSessionConfiguration:
|
||||||
assert result == {}
|
assert result == {}
|
||||||
assert state.model == "gpt-5.4"
|
assert state.model == "gpt-5.4"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_session_model_accepts_provider_prefixed_choice(self, tmp_path, monkeypatch):
|
||||||
|
runtime_calls = []
|
||||||
|
|
||||||
|
def fake_resolve_runtime_provider(requested=None, **kwargs):
|
||||||
|
runtime_calls.append(requested)
|
||||||
|
provider = requested or "openrouter"
|
||||||
|
return {
|
||||||
|
"provider": provider,
|
||||||
|
"api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
|
||||||
|
"base_url": f"https://{provider}.example/v1",
|
||||||
|
"api_key": f"{provider}-key",
|
||||||
|
"command": None,
|
||||||
|
"args": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def fake_agent(**kwargs):
|
||||||
|
return SimpleNamespace(
|
||||||
|
model=kwargs.get("model"),
|
||||||
|
provider=kwargs.get("provider"),
|
||||||
|
base_url=kwargs.get("base_url"),
|
||||||
|
api_mode=kwargs.get("api_mode"),
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
|
||||||
|
"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}
|
||||||
|
})
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||||
|
fake_resolve_runtime_provider,
|
||||||
|
)
|
||||||
|
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
|
||||||
|
|
||||||
|
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||||
|
acp_agent = HermesACPAgent(session_manager=manager)
|
||||||
|
state = manager.create_session(cwd="/tmp")
|
||||||
|
result = await acp_agent.set_session_model(
|
||||||
|
model_id="anthropic:claude-sonnet-4-6",
|
||||||
|
session_id=state.session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, SetSessionModelResponse)
|
||||||
|
assert state.model == "claude-sonnet-4-6"
|
||||||
|
assert state.agent.provider == "anthropic"
|
||||||
|
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||||
|
assert runtime_calls[-1] == "anthropic"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# prompt
|
# prompt
|
||||||
|
|
@ -354,6 +449,31 @@ class TestPrompt:
|
||||||
update = last_call[1].get("update") or last_call[0][1]
|
update = last_call[1].get("update") or last_call[0][1]
|
||||||
assert update.session_update == "agent_message_chunk"
|
assert update.session_update == "agent_message_chunk"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_auto_titles_session(self, agent):
|
||||||
|
new_resp = await agent.new_session(cwd=".")
|
||||||
|
state = agent.session_manager.get_session(new_resp.session_id)
|
||||||
|
state.agent.run_conversation = MagicMock(return_value={
|
||||||
|
"final_response": "Here is the fix.",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "fix the broken ACP history"},
|
||||||
|
{"role": "assistant", "content": "Here is the fix."},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_conn = MagicMock(spec=acp.Client)
|
||||||
|
mock_conn.session_update = AsyncMock()
|
||||||
|
agent._conn = mock_conn
|
||||||
|
|
||||||
|
with patch("agent.title_generator.maybe_auto_title") as mock_title:
|
||||||
|
prompt = [TextContentBlock(type="text", text="fix the broken ACP history")]
|
||||||
|
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||||
|
|
||||||
|
mock_title.assert_called_once()
|
||||||
|
assert mock_title.call_args.args[1] == new_resp.session_id
|
||||||
|
assert mock_title.call_args.args[2] == "fix the broken ACP history"
|
||||||
|
assert mock_title.call_args.args[3] == "Here is the fix."
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
|
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
|
||||||
"""ACP should map top-level token fields into PromptResponse.usage."""
|
"""ACP should map top-level token fields into PromptResponse.usage."""
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
@ -100,15 +101,23 @@ class TestListAndCleanup:
|
||||||
def test_list_sessions_returns_created(self, manager):
|
def test_list_sessions_returns_created(self, manager):
|
||||||
s1 = manager.create_session(cwd="/a")
|
s1 = manager.create_session(cwd="/a")
|
||||||
s2 = manager.create_session(cwd="/b")
|
s2 = manager.create_session(cwd="/b")
|
||||||
|
s1.history.append({"role": "user", "content": "hello from a"})
|
||||||
|
s2.history.append({"role": "user", "content": "hello from b"})
|
||||||
listing = manager.list_sessions()
|
listing = manager.list_sessions()
|
||||||
ids = {s["session_id"] for s in listing}
|
ids = {s["session_id"] for s in listing}
|
||||||
assert s1.session_id in ids
|
assert s1.session_id in ids
|
||||||
assert s2.session_id in ids
|
assert s2.session_id in ids
|
||||||
assert len(listing) == 2
|
assert len(listing) == 2
|
||||||
|
|
||||||
|
def test_list_sessions_hides_empty_threads(self, manager):
|
||||||
|
manager.create_session(cwd="/empty")
|
||||||
|
assert manager.list_sessions() == []
|
||||||
|
|
||||||
def test_cleanup_clears_all(self, manager):
|
def test_cleanup_clears_all(self, manager):
|
||||||
manager.create_session()
|
s1 = manager.create_session()
|
||||||
manager.create_session()
|
s2 = manager.create_session()
|
||||||
|
s1.history.append({"role": "user", "content": "one"})
|
||||||
|
s2.history.append({"role": "user", "content": "two"})
|
||||||
assert len(manager.list_sessions()) == 2
|
assert len(manager.list_sessions()) == 2
|
||||||
manager.cleanup()
|
manager.cleanup()
|
||||||
assert manager.list_sessions() == []
|
assert manager.list_sessions() == []
|
||||||
|
|
@ -194,6 +203,8 @@ class TestPersistence:
|
||||||
def test_list_sessions_includes_db_only(self, manager):
|
def test_list_sessions_includes_db_only(self, manager):
|
||||||
"""Sessions only in DB (not in memory) appear in list_sessions."""
|
"""Sessions only in DB (not in memory) appear in list_sessions."""
|
||||||
state = manager.create_session(cwd="/db-only")
|
state = manager.create_session(cwd="/db-only")
|
||||||
|
state.history.append({"role": "user", "content": "database only thread"})
|
||||||
|
manager.save_session(state.session_id)
|
||||||
sid = state.session_id
|
sid = state.session_id
|
||||||
|
|
||||||
# Drop from memory.
|
# Drop from memory.
|
||||||
|
|
@ -204,6 +215,53 @@ class TestPersistence:
|
||||||
ids = {s["session_id"] for s in listing}
|
ids = {s["session_id"] for s in listing}
|
||||||
assert sid in ids
|
assert sid in ids
|
||||||
|
|
||||||
|
def test_list_sessions_filters_by_cwd(self, manager):
|
||||||
|
keep = manager.create_session(cwd="/keep")
|
||||||
|
drop = manager.create_session(cwd="/drop")
|
||||||
|
keep.history.append({"role": "user", "content": "keep me"})
|
||||||
|
drop.history.append({"role": "user", "content": "drop me"})
|
||||||
|
|
||||||
|
listing = manager.list_sessions(cwd="/keep")
|
||||||
|
ids = {s["session_id"] for s in listing}
|
||||||
|
assert keep.session_id in ids
|
||||||
|
assert drop.session_id not in ids
|
||||||
|
|
||||||
|
def test_list_sessions_matches_windows_and_wsl_paths(self, manager):
|
||||||
|
state = manager.create_session(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||||
|
state.history.append({"role": "user", "content": "same project from WSL"})
|
||||||
|
|
||||||
|
listing = manager.list_sessions(cwd=r"E:\Projects\AI\browser-link-3")
|
||||||
|
ids = {s["session_id"] for s in listing}
|
||||||
|
assert state.session_id in ids
|
||||||
|
|
||||||
|
def test_list_sessions_prefers_title_then_preview(self, manager):
|
||||||
|
state = manager.create_session(cwd="/named")
|
||||||
|
state.history.append({"role": "user", "content": "Investigate broken ACP history in Zed"})
|
||||||
|
manager.save_session(state.session_id)
|
||||||
|
db = manager._get_db()
|
||||||
|
db.set_session_title(state.session_id, "Fix Zed ACP history")
|
||||||
|
|
||||||
|
listing = manager.list_sessions(cwd="/named")
|
||||||
|
assert listing[0]["title"] == "Fix Zed ACP history"
|
||||||
|
|
||||||
|
db.set_session_title(state.session_id, "")
|
||||||
|
listing = manager.list_sessions(cwd="/named")
|
||||||
|
assert listing[0]["title"].startswith("Investigate broken ACP history")
|
||||||
|
|
||||||
|
def test_list_sessions_sorted_by_most_recent_activity(self, manager):
|
||||||
|
older = manager.create_session(cwd="/ordered")
|
||||||
|
older.history.append({"role": "user", "content": "older"})
|
||||||
|
manager.save_session(older.session_id)
|
||||||
|
time.sleep(0.02)
|
||||||
|
newer = manager.create_session(cwd="/ordered")
|
||||||
|
newer.history.append({"role": "user", "content": "newer"})
|
||||||
|
manager.save_session(newer.session_id)
|
||||||
|
|
||||||
|
listing = manager.list_sessions(cwd="/ordered")
|
||||||
|
assert [item["session_id"] for item in listing[:2]] == [newer.session_id, older.session_id]
|
||||||
|
assert listing[0]["updated_at"]
|
||||||
|
assert listing[1]["updated_at"]
|
||||||
|
|
||||||
def test_fork_restores_source_from_db(self, manager):
|
def test_fork_restores_source_from_db(self, manager):
|
||||||
"""Forking a session that is only in DB should work."""
|
"""Forking a session that is only in DB should work."""
|
||||||
original = manager.create_session()
|
original = manager.create_session()
|
||||||
|
|
|
||||||
|
|
@ -215,6 +215,46 @@ class TestBuildToolComplete:
|
||||||
assert len(display_text) < 6000
|
assert len(display_text) < 6000
|
||||||
assert "truncated" in display_text
|
assert "truncated" in display_text
|
||||||
|
|
||||||
|
def test_build_tool_complete_for_patch_uses_diff_blocks(self):
|
||||||
|
"""Completed patch calls should keep structured diff content for Zed."""
|
||||||
|
patch_result = (
|
||||||
|
'{"success": true, "diff": "--- a/README.md\\n+++ b/README.md\\n@@ -1 +1,2 @@\\n old line\\n+new line\\n", '
|
||||||
|
'"files_modified": ["README.md"]}'
|
||||||
|
)
|
||||||
|
result = build_tool_complete("tc-p1", "patch", patch_result)
|
||||||
|
assert isinstance(result, ToolCallProgress)
|
||||||
|
assert len(result.content) == 1
|
||||||
|
diff_item = result.content[0]
|
||||||
|
assert isinstance(diff_item, FileEditToolCallContent)
|
||||||
|
assert diff_item.path == "README.md"
|
||||||
|
assert diff_item.old_text == "old line"
|
||||||
|
assert diff_item.new_text == "old line\nnew line"
|
||||||
|
|
||||||
|
def test_build_tool_complete_for_patch_falls_back_to_text_when_no_diff(self):
|
||||||
|
result = build_tool_complete("tc-p2", "patch", '{"success": true}')
|
||||||
|
assert isinstance(result, ToolCallProgress)
|
||||||
|
assert isinstance(result.content[0], ContentToolCallContent)
|
||||||
|
|
||||||
|
def test_build_tool_complete_for_write_file_uses_snapshot_diff(self, tmp_path):
|
||||||
|
target = tmp_path / "diff-test.txt"
|
||||||
|
snapshot = type("Snapshot", (), {"paths": [target], "before": {str(target): None}})()
|
||||||
|
target.write_text("hello from hermes\n", encoding="utf-8")
|
||||||
|
|
||||||
|
result = build_tool_complete(
|
||||||
|
"tc-wf1",
|
||||||
|
"write_file",
|
||||||
|
'{"bytes_written": 18, "dirs_created": false}',
|
||||||
|
function_args={"path": str(target), "content": "hello from hermes\n"},
|
||||||
|
snapshot=snapshot,
|
||||||
|
)
|
||||||
|
assert isinstance(result, ToolCallProgress)
|
||||||
|
assert len(result.content) == 1
|
||||||
|
diff_item = result.content[0]
|
||||||
|
assert isinstance(diff_item, FileEditToolCallContent)
|
||||||
|
assert diff_item.path.endswith("diff-test.txt")
|
||||||
|
assert diff_item.old_text is None
|
||||||
|
assert diff_item.new_text == "hello from hermes"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# extract_locations
|
# extract_locations
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@
|
||||||
|
|
||||||
Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps()
|
Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps()
|
||||||
inside the OpenAI SDK. They can appear via clipboard paste from rich-text
|
inside the OpenAI SDK. They can appear via clipboard paste from rich-text
|
||||||
editors like Google Docs.
|
editors like Google Docs, OR from byte-level reasoning models (xiaomi/mimo,
|
||||||
|
kimi, glm) emitting lone halves in reasoning output.
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -11,6 +12,7 @@ from unittest.mock import MagicMock, patch
|
||||||
from run_agent import (
|
from run_agent import (
|
||||||
_sanitize_surrogates,
|
_sanitize_surrogates,
|
||||||
_sanitize_messages_surrogates,
|
_sanitize_messages_surrogates,
|
||||||
|
_sanitize_structure_surrogates,
|
||||||
_SURROGATE_RE,
|
_SURROGATE_RE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -109,6 +111,186 @@ class TestSanitizeMessagesSurrogates:
|
||||||
assert "\ufffd" in msgs[0]["content"]
|
assert "\ufffd" in msgs[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestReasoningFieldSurrogates:
|
||||||
|
"""Surrogates in reasoning fields (byte-level reasoning models).
|
||||||
|
|
||||||
|
xiaomi/mimo, kimi, glm and similar byte-level tokenizers can emit lone
|
||||||
|
surrogates in reasoning output. These fields are carried through to the
|
||||||
|
API as `reasoning_content` on assistant messages, and must be sanitized
|
||||||
|
or json.dumps() crashes with 'utf-8' codec can't encode surrogates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_reasoning_field_sanitized(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "assistant", "content": "ok", "reasoning": "thought \udce2 here"},
|
||||||
|
]
|
||||||
|
assert _sanitize_messages_surrogates(msgs) is True
|
||||||
|
assert "\udce2" not in msgs[0]["reasoning"]
|
||||||
|
assert "\ufffd" in msgs[0]["reasoning"]
|
||||||
|
|
||||||
|
def test_reasoning_content_field_sanitized(self):
|
||||||
|
"""api_messages carry `reasoning_content` built from `reasoning`."""
|
||||||
|
msgs = [
|
||||||
|
{"role": "assistant", "content": "ok", "reasoning_content": "thought \udce2 here"},
|
||||||
|
]
|
||||||
|
assert _sanitize_messages_surrogates(msgs) is True
|
||||||
|
assert "\udce2" not in msgs[0]["reasoning_content"]
|
||||||
|
assert "\ufffd" in msgs[0]["reasoning_content"]
|
||||||
|
|
||||||
|
def test_reasoning_details_nested_sanitized(self):
|
||||||
|
"""reasoning_details is a list of dicts with nested string fields."""
|
||||||
|
msgs = [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "ok",
|
||||||
|
"reasoning_details": [
|
||||||
|
{"type": "reasoning.summary", "summary": "summary \udce2 text"},
|
||||||
|
{"type": "reasoning.text", "text": "chain \udc00 of thought"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
assert _sanitize_messages_surrogates(msgs) is True
|
||||||
|
assert "\udce2" not in msgs[0]["reasoning_details"][0]["summary"]
|
||||||
|
assert "\ufffd" in msgs[0]["reasoning_details"][0]["summary"]
|
||||||
|
assert "\udc00" not in msgs[0]["reasoning_details"][1]["text"]
|
||||||
|
assert "\ufffd" in msgs[0]["reasoning_details"][1]["text"]
|
||||||
|
|
||||||
|
def test_deeply_nested_reasoning_sanitized(self):
|
||||||
|
"""Nested dicts / lists inside extra fields are recursed into."""
|
||||||
|
msgs = [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "ok",
|
||||||
|
"reasoning_details": [
|
||||||
|
{
|
||||||
|
"type": "reasoning.encrypted",
|
||||||
|
"content": {
|
||||||
|
"encrypted_content": "opaque",
|
||||||
|
"text_parts": ["part1", "part2 \udce2 part"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
assert _sanitize_messages_surrogates(msgs) is True
|
||||||
|
assert (
|
||||||
|
msgs[0]["reasoning_details"][0]["content"]["text_parts"][1]
|
||||||
|
== "part2 \ufffd part"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_reasoning_end_to_end_json_serialization(self):
|
||||||
|
"""After sanitization, the full message dict must serialize clean."""
|
||||||
|
msgs = [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "answer",
|
||||||
|
"reasoning_content": "reasoning with \udce2 surrogate",
|
||||||
|
"reasoning_details": [
|
||||||
|
{"summary": "nested \udcb0 surrogate"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
_sanitize_messages_surrogates(msgs)
|
||||||
|
# Must round-trip through json + utf-8 encoding without error
|
||||||
|
payload = json.dumps(msgs, ensure_ascii=False).encode("utf-8")
|
||||||
|
assert b"\\" not in payload[:0] # sanity — just ensure we got bytes
|
||||||
|
assert len(payload) > 0
|
||||||
|
|
||||||
|
def test_no_surrogates_returns_false(self):
|
||||||
|
"""Clean reasoning fields don't trigger a modification."""
|
||||||
|
msgs = [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "ok",
|
||||||
|
"reasoning": "clean thought",
|
||||||
|
"reasoning_content": "also clean",
|
||||||
|
"reasoning_details": [{"summary": "clean summary"}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
assert _sanitize_messages_surrogates(msgs) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizeStructureSurrogates:
|
||||||
|
"""Test the _sanitize_structure_surrogates() helper for nested payloads."""
|
||||||
|
|
||||||
|
def test_empty_payload(self):
|
||||||
|
assert _sanitize_structure_surrogates({}) is False
|
||||||
|
assert _sanitize_structure_surrogates([]) is False
|
||||||
|
|
||||||
|
def test_flat_dict(self):
|
||||||
|
payload = {"a": "clean", "b": "dirty \udce2 text"}
|
||||||
|
assert _sanitize_structure_surrogates(payload) is True
|
||||||
|
assert payload["a"] == "clean"
|
||||||
|
assert "\ufffd" in payload["b"]
|
||||||
|
|
||||||
|
def test_flat_list(self):
|
||||||
|
payload = ["clean", "dirty \udce2"]
|
||||||
|
assert _sanitize_structure_surrogates(payload) is True
|
||||||
|
assert payload[0] == "clean"
|
||||||
|
assert "\ufffd" in payload[1]
|
||||||
|
|
||||||
|
def test_nested_dict_in_list(self):
|
||||||
|
payload = [{"x": "dirty \udce2"}, {"x": "clean"}]
|
||||||
|
assert _sanitize_structure_surrogates(payload) is True
|
||||||
|
assert "\ufffd" in payload[0]["x"]
|
||||||
|
assert payload[1]["x"] == "clean"
|
||||||
|
|
||||||
|
def test_deeply_nested(self):
|
||||||
|
payload = {
|
||||||
|
"level1": {
|
||||||
|
"level2": [
|
||||||
|
{"level3": "deep \udce2 surrogate"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert _sanitize_structure_surrogates(payload) is True
|
||||||
|
assert "\ufffd" in payload["level1"]["level2"][0]["level3"]
|
||||||
|
|
||||||
|
def test_clean_payload_returns_false(self):
|
||||||
|
payload = {"a": "clean", "b": [{"c": "also clean"}]}
|
||||||
|
assert _sanitize_structure_surrogates(payload) is False
|
||||||
|
|
||||||
|
def test_non_string_values_ignored(self):
|
||||||
|
payload = {"int": 42, "list": [1, 2, 3], "dict": {"none": None}, "bool": True}
|
||||||
|
assert _sanitize_structure_surrogates(payload) is False
|
||||||
|
# Non-string values survive unchanged
|
||||||
|
assert payload["int"] == 42
|
||||||
|
assert payload["list"] == [1, 2, 3]
|
||||||
|
|
||||||
|
|
||||||
|
class TestApiMessagesSurrogateRecovery:
|
||||||
|
"""Integration: verify the recovery block sanitizes api_messages.
|
||||||
|
|
||||||
|
The bug this guards against: a surrogate in `reasoning_content` on
|
||||||
|
api_messages (transformed from `reasoning` during build) crashes the
|
||||||
|
OpenAI SDK's json.dumps(), and the recovery block previously only
|
||||||
|
sanitized the canonical `messages` list — not `api_messages` — so the
|
||||||
|
next retry would send the same broken payload and fail 3 times.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_api_messages_reasoning_content_sanitized(self):
|
||||||
|
"""The extended sanitizer catches reasoning_content in api_messages."""
|
||||||
|
api_messages = [
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "response",
|
||||||
|
"reasoning_content": "thought \udce2 trail",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"function": {"name": "tool", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "content": "result", "tool_call_id": "call_1"},
|
||||||
|
]
|
||||||
|
assert _sanitize_messages_surrogates(api_messages) is True
|
||||||
|
assert "\udce2" not in api_messages[1]["reasoning_content"]
|
||||||
|
# Full payload must now serialize clean
|
||||||
|
json.dumps(api_messages, ensure_ascii=False).encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
class TestRunConversationSurrogateSanitization:
|
class TestRunConversationSurrogateSanitization:
|
||||||
"""Integration: verify run_conversation sanitizes user_message."""
|
"""Integration: verify run_conversation sanitizes user_message."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -34,3 +34,21 @@ def test_messaging_extra_includes_qrcode_for_weixin_setup():
|
||||||
|
|
||||||
messaging_extra = optional_dependencies["messaging"]
|
messaging_extra = optional_dependencies["messaging"]
|
||||||
assert any(dep.startswith("qrcode") for dep in messaging_extra)
|
assert any(dep.startswith("qrcode") for dep in messaging_extra)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dingtalk_extra_includes_qrcode_for_qr_auth():
|
||||||
|
"""DingTalk's QR-code device-flow auth (hermes_cli/dingtalk_auth.py)
|
||||||
|
needs the qrcode package."""
|
||||||
|
optional_dependencies = _load_optional_dependencies()
|
||||||
|
|
||||||
|
dingtalk_extra = optional_dependencies["dingtalk"]
|
||||||
|
assert any(dep.startswith("qrcode") for dep in dingtalk_extra)
|
||||||
|
|
||||||
|
|
||||||
|
def test_feishu_extra_includes_qrcode_for_qr_login():
|
||||||
|
"""Feishu's QR login flow (gateway/platforms/feishu.py) needs the
|
||||||
|
qrcode package."""
|
||||||
|
optional_dependencies = _load_optional_dependencies()
|
||||||
|
|
||||||
|
feishu_extra = optional_dependencies["feishu"]
|
||||||
|
assert any(dep.startswith("qrcode") for dep in feishu_extra)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue