mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(acp): improve zed integration
This commit is contained in:
parent
d0e1388ca9
commit
cb883f9e97
10 changed files with 769 additions and 88 deletions
|
|
@ -49,6 +49,7 @@ def make_tool_progress_cb(
|
|||
session_id: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
tool_call_ids: Dict[str, Deque[str]],
|
||||
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||
) -> Callable:
|
||||
"""Create a ``tool_progress_callback`` for AIAgent.
|
||||
|
||||
|
|
@ -84,6 +85,16 @@ def make_tool_progress_cb(
|
|||
tool_call_ids[name] = queue
|
||||
queue.append(tc_id)
|
||||
|
||||
snapshot = None
|
||||
if name in {"write_file", "patch", "skill_manage"}:
|
||||
try:
|
||||
from agent.display import capture_local_edit_snapshot
|
||||
|
||||
snapshot = capture_local_edit_snapshot(name, args)
|
||||
except Exception:
|
||||
logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True)
|
||||
tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot}
|
||||
|
||||
update = build_tool_start(tc_id, name, args)
|
||||
_send_update(conn, session_id, loop, update)
|
||||
|
||||
|
|
@ -119,6 +130,7 @@ def make_step_cb(
|
|||
session_id: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
tool_call_ids: Dict[str, Deque[str]],
|
||||
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||
) -> Callable:
|
||||
"""Create a ``step_callback`` for AIAgent.
|
||||
|
||||
|
|
@ -132,10 +144,12 @@ def make_step_cb(
|
|||
for tool_info in prev_tools:
|
||||
tool_name = None
|
||||
result = None
|
||||
function_args = None
|
||||
|
||||
if isinstance(tool_info, dict):
|
||||
tool_name = tool_info.get("name") or tool_info.get("function_name")
|
||||
result = tool_info.get("result") or tool_info.get("output")
|
||||
function_args = tool_info.get("arguments") or tool_info.get("args")
|
||||
elif isinstance(tool_info, str):
|
||||
tool_name = tool_info
|
||||
|
||||
|
|
@ -145,8 +159,13 @@ def make_step_cb(
|
|||
tool_call_ids[tool_name] = queue
|
||||
if tool_name and queue:
|
||||
tc_id = queue.popleft()
|
||||
meta = tool_call_meta.pop(tc_id, {})
|
||||
update = build_tool_complete(
|
||||
tc_id, tool_name, result=str(result) if result is not None else None
|
||||
tc_id,
|
||||
tool_name,
|
||||
result=str(result) if result is not None else None,
|
||||
function_args=function_args or meta.get("args"),
|
||||
snapshot=meta.get("snapshot"),
|
||||
)
|
||||
_send_update(conn, session_id, loop, update)
|
||||
if not queue:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from acp.schema import (
|
|||
McpServerHttp,
|
||||
McpServerSse,
|
||||
McpServerStdio,
|
||||
ModelInfo,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
|
|
@ -36,6 +37,7 @@ from acp.schema import (
|
|||
SessionCapabilities,
|
||||
SessionForkCapabilities,
|
||||
SessionListCapabilities,
|
||||
SessionModelState,
|
||||
SessionResumeCapabilities,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
|
|
@ -147,6 +149,98 @@ class HermesACPAgent(acp.Agent):
|
|||
self._conn = conn
|
||||
logger.info("ACP client connected")
|
||||
|
||||
@staticmethod
|
||||
def _encode_model_choice(provider: str | None, model: str | None) -> str:
|
||||
"""Encode a model selection so ACP clients can keep provider context."""
|
||||
raw_model = str(model or "").strip()
|
||||
if not raw_model:
|
||||
return ""
|
||||
raw_provider = str(provider or "").strip().lower()
|
||||
if not raw_provider:
|
||||
return raw_model
|
||||
return f"{raw_provider}:{raw_model}"
|
||||
|
||||
def _build_model_state(self, state: SessionState) -> SessionModelState | None:
|
||||
"""Return the ACP model selector payload for editors like Zed."""
|
||||
model = str(state.model or getattr(state.agent, "model", "") or "").strip()
|
||||
provider = getattr(state.agent, "provider", None) or detect_provider() or "openrouter"
|
||||
|
||||
try:
|
||||
from hermes_cli.models import curated_models_for_provider, normalize_provider, provider_label
|
||||
|
||||
normalized_provider = normalize_provider(provider)
|
||||
provider_name = provider_label(normalized_provider)
|
||||
available_models: list[ModelInfo] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for model_id, description in curated_models_for_provider(normalized_provider):
|
||||
rendered_model = str(model_id or "").strip()
|
||||
if not rendered_model:
|
||||
continue
|
||||
choice_id = self._encode_model_choice(normalized_provider, rendered_model)
|
||||
if choice_id in seen_ids:
|
||||
continue
|
||||
desc_parts = [f"Provider: {provider_name}"]
|
||||
if description:
|
||||
desc_parts.append(str(description).strip())
|
||||
if rendered_model == model:
|
||||
desc_parts.append("current")
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
model_id=choice_id,
|
||||
name=rendered_model,
|
||||
description=" • ".join(part for part in desc_parts if part),
|
||||
)
|
||||
)
|
||||
seen_ids.add(choice_id)
|
||||
|
||||
current_model_id = self._encode_model_choice(normalized_provider, model)
|
||||
if current_model_id and current_model_id not in seen_ids:
|
||||
available_models.insert(
|
||||
0,
|
||||
ModelInfo(
|
||||
model_id=current_model_id,
|
||||
name=model,
|
||||
description=f"Provider: {provider_name} • current",
|
||||
),
|
||||
)
|
||||
|
||||
if available_models:
|
||||
return SessionModelState(
|
||||
available_models=available_models,
|
||||
current_model_id=current_model_id or available_models[0].model_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not build ACP model state", exc_info=True)
|
||||
|
||||
if not model:
|
||||
return None
|
||||
|
||||
fallback_choice = self._encode_model_choice(provider, model)
|
||||
return SessionModelState(
|
||||
available_models=[ModelInfo(model_id=fallback_choice, name=model)],
|
||||
current_model_id=fallback_choice,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_model_selection(raw_model: str, current_provider: str) -> tuple[str, str]:
|
||||
"""Resolve ``provider:model`` input into the provider and normalized model id."""
|
||||
target_provider = current_provider
|
||||
new_model = raw_model.strip()
|
||||
|
||||
try:
|
||||
from hermes_cli.models import detect_provider_for_model, parse_model_input
|
||||
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
|
||||
return target_provider, new_model
|
||||
|
||||
async def _register_session_mcp_servers(
|
||||
self,
|
||||
state: SessionState,
|
||||
|
|
@ -273,7 +367,10 @@ class HermesACPAgent(acp.Agent):
|
|||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
||||
self._schedule_available_commands_update(state.session_id)
|
||||
return NewSessionResponse(session_id=state.session_id)
|
||||
return NewSessionResponse(
|
||||
session_id=state.session_id,
|
||||
models=self._build_model_state(state),
|
||||
)
|
||||
|
||||
async def load_session(
|
||||
self,
|
||||
|
|
@ -289,7 +386,7 @@ class HermesACPAgent(acp.Agent):
|
|||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Loaded session %s", session_id)
|
||||
self._schedule_available_commands_update(session_id)
|
||||
return LoadSessionResponse()
|
||||
return LoadSessionResponse(models=self._build_model_state(state))
|
||||
|
||||
async def resume_session(
|
||||
self,
|
||||
|
|
@ -305,7 +402,7 @@ class HermesACPAgent(acp.Agent):
|
|||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Resumed session %s", state.session_id)
|
||||
self._schedule_available_commands_update(state.session_id)
|
||||
return ResumeSessionResponse()
|
||||
return ResumeSessionResponse(models=self._build_model_state(state))
|
||||
|
||||
async def cancel(self, session_id: str, **kwargs: Any) -> None:
|
||||
state = self.session_manager.get_session(session_id)
|
||||
|
|
@ -340,11 +437,20 @@ class HermesACPAgent(acp.Agent):
|
|||
cwd: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ListSessionsResponse:
|
||||
infos = self.session_manager.list_sessions()
|
||||
sessions = [
|
||||
SessionInfo(session_id=s["session_id"], cwd=s["cwd"])
|
||||
for s in infos
|
||||
]
|
||||
infos = self.session_manager.list_sessions(cwd=cwd)
|
||||
sessions = []
|
||||
for s in infos:
|
||||
updated_at = s.get("updated_at")
|
||||
if updated_at is not None and not isinstance(updated_at, str):
|
||||
updated_at = str(updated_at)
|
||||
sessions.append(
|
||||
SessionInfo(
|
||||
session_id=s["session_id"],
|
||||
cwd=s["cwd"],
|
||||
title=s.get("title"),
|
||||
updated_at=updated_at,
|
||||
)
|
||||
)
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
# ---- Prompt (core) ------------------------------------------------------
|
||||
|
|
@ -389,12 +495,13 @@ class HermesACPAgent(acp.Agent):
|
|||
state.cancel_event.clear()
|
||||
|
||||
tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
|
||||
tool_call_meta: dict[str, dict[str, Any]] = {}
|
||||
previous_approval_cb = None
|
||||
|
||||
if conn:
|
||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids)
|
||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||
thinking_cb = make_thinking_cb(conn, session_id, loop)
|
||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids)
|
||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||
message_cb = make_message_cb(conn, session_id, loop)
|
||||
approval_cb = make_approval_callback(conn.request_permission, loop, session_id)
|
||||
else:
|
||||
|
|
@ -449,6 +556,19 @@ class HermesACPAgent(acp.Agent):
|
|||
self.session_manager.save_session(session_id)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if final_response:
|
||||
try:
|
||||
from agent.title_generator import maybe_auto_title
|
||||
|
||||
maybe_auto_title(
|
||||
self.session_manager._get_db(),
|
||||
session_id,
|
||||
user_text,
|
||||
final_response,
|
||||
state.history,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True)
|
||||
if final_response and conn:
|
||||
update = acp.update_agent_message_text(final_response)
|
||||
await conn.session_update(session_id, update)
|
||||
|
|
@ -556,27 +676,15 @@ class HermesACPAgent(acp.Agent):
|
|||
provider = getattr(state.agent, "provider", None) or "auto"
|
||||
return f"Current model: {model}\nProvider: {provider}"
|
||||
|
||||
new_model = args.strip()
|
||||
target_provider = None
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
|
||||
# Auto-detect provider for the requested model
|
||||
try:
|
||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
target_provider, new_model = self._resolve_model_selection(args, current_provider)
|
||||
|
||||
state.model = new_model
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=state.session_id,
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
requested_provider=target_provider or current_provider,
|
||||
requested_provider=target_provider,
|
||||
)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
||||
|
|
@ -678,20 +786,30 @@ class HermesACPAgent(acp.Agent):
|
|||
"""Switch the model for a session (called by ACP protocol)."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
state.model = model_id
|
||||
current_provider = getattr(state.agent, "provider", None)
|
||||
current_base_url = getattr(state.agent, "base_url", None)
|
||||
current_api_mode = getattr(state.agent, "api_mode", None)
|
||||
requested_provider, resolved_model = self._resolve_model_selection(
|
||||
model_id,
|
||||
current_provider or "openrouter",
|
||||
)
|
||||
state.model = resolved_model
|
||||
provider_changed = bool(current_provider and requested_provider != current_provider)
|
||||
current_base_url = None if provider_changed else getattr(state.agent, "base_url", None)
|
||||
current_api_mode = None if provider_changed else getattr(state.agent, "api_mode", None)
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=state.cwd,
|
||||
model=model_id,
|
||||
requested_provider=current_provider,
|
||||
model=resolved_model,
|
||||
requested_provider=requested_provider,
|
||||
base_url=current_base_url,
|
||||
api_mode=current_api_mode,
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
logger.info(
|
||||
"Session %s: model switched to %s via provider %s",
|
||||
session_id,
|
||||
resolved_model,
|
||||
requested_provider,
|
||||
)
|
||||
return SetSessionModelResponse()
|
||||
logger.warning("Session %s: model switch requested for missing session", session_id)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -13,8 +13,12 @@ from hermes_constants import get_hermes_home
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
|
@ -22,6 +26,64 @@ from typing import Any, Dict, List, Optional
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_cwd_for_compare(cwd: str | None) -> str:
|
||||
raw = str(cwd or ".").strip()
|
||||
if not raw:
|
||||
raw = "."
|
||||
expanded = os.path.expanduser(raw)
|
||||
|
||||
# Normalize Windows drive paths into the equivalent WSL mount form so
|
||||
# ACP history filters match the same workspace across Windows and WSL.
|
||||
match = re.match(r"^([A-Za-z]):[\\/](.*)$", expanded)
|
||||
if match:
|
||||
drive = match.group(1).lower()
|
||||
tail = match.group(2).replace("\\", "/")
|
||||
expanded = f"/mnt/{drive}/{tail}"
|
||||
elif re.match(r"^/mnt/[A-Za-z]/", expanded):
|
||||
expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}"
|
||||
|
||||
return os.path.normpath(expanded)
|
||||
|
||||
|
||||
def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str:
|
||||
explicit = str(title or "").strip()
|
||||
if explicit:
|
||||
return explicit
|
||||
preview_text = str(preview or "").strip()
|
||||
if preview_text:
|
||||
return preview_text
|
||||
leaf = os.path.basename(str(cwd or "").rstrip("/\\"))
|
||||
return leaf or "New thread"
|
||||
|
||||
|
||||
def _format_updated_at(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _updated_at_sort_key(value: Any) -> float:
|
||||
if value is None:
|
||||
return float("-inf")
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return float("-inf")
|
||||
try:
|
||||
return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
|
||||
except Exception:
|
||||
try:
|
||||
return float(raw)
|
||||
except Exception:
|
||||
return float("-inf")
|
||||
|
||||
|
||||
def _acp_stderr_print(*args, **kwargs) -> None:
|
||||
"""Best-effort human-readable output sink for ACP stdio sessions.
|
||||
|
||||
|
|
@ -162,47 +224,78 @@ class SessionManager:
|
|||
logger.info("Forked ACP session %s -> %s", session_id, new_id)
|
||||
return state
|
||||
|
||||
def list_sessions(self) -> List[Dict[str, Any]]:
|
||||
def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]:
|
||||
"""Return lightweight info dicts for all sessions (memory + database)."""
|
||||
normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None
|
||||
db = self._get_db()
|
||||
persisted_rows: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if db is not None:
|
||||
try:
|
||||
for row in db.list_sessions_rich(source="acp", limit=1000):
|
||||
persisted_rows[str(row["id"])] = dict(row)
|
||||
except Exception:
|
||||
logger.debug("Failed to load ACP sessions from DB", exc_info=True)
|
||||
|
||||
# Collect in-memory sessions first.
|
||||
with self._lock:
|
||||
seen_ids = set(self._sessions.keys())
|
||||
results = [
|
||||
results = []
|
||||
for s in self._sessions.values():
|
||||
history_len = len(s.history)
|
||||
if history_len <= 0:
|
||||
continue
|
||||
if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd:
|
||||
continue
|
||||
persisted = persisted_rows.get(s.session_id, {})
|
||||
preview = next(
|
||||
(
|
||||
str(msg.get("content") or "").strip()
|
||||
for msg in s.history
|
||||
if msg.get("role") == "user" and str(msg.get("content") or "").strip()
|
||||
),
|
||||
persisted.get("preview") or "",
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
"model": s.model,
|
||||
"history_len": len(s.history),
|
||||
"history_len": history_len,
|
||||
"title": _build_session_title(persisted.get("title"), preview, s.cwd),
|
||||
"updated_at": _format_updated_at(
|
||||
persisted.get("last_active") or persisted.get("started_at") or time.time()
|
||||
),
|
||||
}
|
||||
for s in self._sessions.values()
|
||||
]
|
||||
)
|
||||
|
||||
# Merge any persisted sessions not currently in memory.
|
||||
db = self._get_db()
|
||||
if db is not None:
|
||||
try:
|
||||
rows = db.search_sessions(source="acp", limit=1000)
|
||||
for row in rows:
|
||||
sid = row["id"]
|
||||
for sid, row in persisted_rows.items():
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
message_count = int(row.get("message_count") or 0)
|
||||
if message_count <= 0:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
cwd = "."
|
||||
session_cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
session_cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd:
|
||||
continue
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": cwd,
|
||||
"cwd": session_cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": row.get("message_count") or 0,
|
||||
"history_len": message_count,
|
||||
"title": _build_session_title(row.get("title"), row.get("preview"), session_cwd),
|
||||
"updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")),
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("Failed to list ACP sessions from DB", exc_info=True)
|
||||
|
||||
results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True)
|
||||
return results
|
||||
|
||||
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
|
@ -96,6 +97,170 @@ def build_tool_title(tool_name: str, args: Dict[str, Any]) -> str:
|
|||
return tool_name
|
||||
|
||||
|
||||
def _build_patch_mode_content(patch_text: str) -> List[Any]:
|
||||
"""Parse V4A patch mode input into ACP diff blocks when possible."""
|
||||
if not patch_text:
|
||||
return [acp.tool_content(acp.text_block(""))]
|
||||
|
||||
try:
|
||||
from tools.patch_parser import OperationType, parse_v4a_patch
|
||||
|
||||
operations, error = parse_v4a_patch(patch_text)
|
||||
if error or not operations:
|
||||
return [acp.tool_content(acp.text_block(patch_text))]
|
||||
|
||||
content: List[Any] = []
|
||||
for op in operations:
|
||||
if op.operation == OperationType.UPDATE:
|
||||
old_chunks: list[str] = []
|
||||
new_chunks: list[str] = []
|
||||
for hunk in op.hunks:
|
||||
old_lines = [line.content for line in hunk.lines if line.prefix in (" ", "-")]
|
||||
new_lines = [line.content for line in hunk.lines if line.prefix in (" ", "+")]
|
||||
if old_lines or new_lines:
|
||||
old_chunks.append("\n".join(old_lines))
|
||||
new_chunks.append("\n".join(new_lines))
|
||||
|
||||
old_text = "\n...\n".join(chunk for chunk in old_chunks if chunk)
|
||||
new_text = "\n...\n".join(chunk for chunk in new_chunks if chunk)
|
||||
if old_text or new_text:
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
old_text=old_text or None,
|
||||
new_text=new_text or "",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.ADD:
|
||||
added_lines = [line.content for hunk in op.hunks for line in hunk.lines if line.prefix == "+"]
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
new_text="\n".join(added_lines),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.DELETE:
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
old_text=f"Delete file: {op.file_path}",
|
||||
new_text="",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.MOVE:
|
||||
content.append(
|
||||
acp.tool_content(acp.text_block(f"Move file: {op.file_path} -> {op.new_path}"))
|
||||
)
|
||||
|
||||
return content or [acp.tool_content(acp.text_block(patch_text))]
|
||||
except Exception:
|
||||
return [acp.tool_content(acp.text_block(patch_text))]
|
||||
|
||||
|
||||
def _strip_diff_prefix(path: str) -> str:
|
||||
raw = str(path or "").strip()
|
||||
if raw.startswith(("a/", "b/")):
|
||||
return raw[2:]
|
||||
return raw
|
||||
|
||||
|
||||
def _parse_unified_diff_content(diff_text: str) -> List[Any]:
|
||||
"""Convert unified diff text into ACP diff content blocks."""
|
||||
if not diff_text:
|
||||
return []
|
||||
|
||||
content: List[Any] = []
|
||||
current_old_path: Optional[str] = None
|
||||
current_new_path: Optional[str] = None
|
||||
old_lines: list[str] = []
|
||||
new_lines: list[str] = []
|
||||
|
||||
def _flush() -> None:
|
||||
nonlocal current_old_path, current_new_path, old_lines, new_lines
|
||||
if current_old_path is None and current_new_path is None:
|
||||
return
|
||||
path = current_new_path if current_new_path and current_new_path != "/dev/null" else current_old_path
|
||||
if not path or path == "/dev/null":
|
||||
current_old_path = None
|
||||
current_new_path = None
|
||||
old_lines = []
|
||||
new_lines = []
|
||||
return
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=_strip_diff_prefix(path),
|
||||
old_text="\n".join(old_lines) if old_lines else None,
|
||||
new_text="\n".join(new_lines),
|
||||
)
|
||||
)
|
||||
current_old_path = None
|
||||
current_new_path = None
|
||||
old_lines = []
|
||||
new_lines = []
|
||||
|
||||
for line in diff_text.splitlines():
|
||||
if line.startswith("--- "):
|
||||
_flush()
|
||||
current_old_path = line[4:].strip()
|
||||
continue
|
||||
if line.startswith("+++ "):
|
||||
current_new_path = line[4:].strip()
|
||||
continue
|
||||
if line.startswith("@@"):
|
||||
continue
|
||||
if current_old_path is None and current_new_path is None:
|
||||
continue
|
||||
if line.startswith("+"):
|
||||
new_lines.append(line[1:])
|
||||
elif line.startswith("-"):
|
||||
old_lines.append(line[1:])
|
||||
elif line.startswith(" "):
|
||||
shared = line[1:]
|
||||
old_lines.append(shared)
|
||||
new_lines.append(shared)
|
||||
|
||||
_flush()
|
||||
return content
|
||||
|
||||
|
||||
def _build_tool_complete_content(
|
||||
tool_name: str,
|
||||
result: Optional[str],
|
||||
*,
|
||||
function_args: Optional[Dict[str, Any]] = None,
|
||||
snapshot: Any = None,
|
||||
) -> List[Any]:
|
||||
"""Build structured ACP completion content, falling back to plain text."""
|
||||
display_result = result or ""
|
||||
if len(display_result) > 5000:
|
||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||
|
||||
if tool_name in {"write_file", "patch", "skill_manage"}:
|
||||
try:
|
||||
from agent.display import extract_edit_diff
|
||||
|
||||
diff_text = extract_edit_diff(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
if isinstance(diff_text, str) and diff_text.strip():
|
||||
diff_content = _parse_unified_diff_content(diff_text)
|
||||
if diff_content:
|
||||
return diff_content
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return [acp.tool_content(acp.text_block(display_result))]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build ACP content objects for tool-call events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -119,9 +284,8 @@ def build_tool_start(
|
|||
new = arguments.get("new_string", "")
|
||||
content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)]
|
||||
else:
|
||||
# Patch mode — show the patch content as text
|
||||
patch_text = arguments.get("patch", "")
|
||||
content = [acp.tool_content(acp.text_block(patch_text))]
|
||||
content = _build_patch_mode_content(patch_text)
|
||||
return acp.start_tool_call(
|
||||
tool_call_id, title, kind=kind, content=content, locations=locations,
|
||||
raw_input=arguments,
|
||||
|
|
@ -178,16 +342,17 @@ def build_tool_complete(
|
|||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: Optional[str] = None,
|
||||
function_args: Optional[Dict[str, Any]] = None,
|
||||
snapshot: Any = None,
|
||||
) -> ToolCallProgress:
|
||||
"""Create a ToolCallUpdate (progress) event for a completed tool call."""
|
||||
kind = get_tool_kind(tool_name)
|
||||
|
||||
# Truncate very large results for the UI
|
||||
display_result = result or ""
|
||||
if len(display_result) > 5000:
|
||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||
|
||||
content = [acp.tool_content(acp.text_block(display_result))]
|
||||
content = _build_tool_complete_content(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
return acp.update_tool_call(
|
||||
tool_call_id,
|
||||
kind=kind,
|
||||
|
|
|
|||
|
|
@ -8688,6 +8688,7 @@ class AIAgent:
|
|||
{
|
||||
"name": tc["function"]["name"],
|
||||
"result": _results_by_id.get(tc.get("id")),
|
||||
"arguments": tc["function"].get("arguments"),
|
||||
}
|
||||
for tc in _m["tool_calls"]
|
||||
if isinstance(tc, dict)
|
||||
|
|
|
|||
|
|
@ -42,9 +42,10 @@ class TestToolProgressCallback:
|
|||
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
|
||||
"""Tool progress should emit a ToolCallStart update."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
# Run callback in the event loop context
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
|
|
@ -66,9 +67,10 @@ class TestToolProgressCallback:
|
|||
def test_handles_string_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is a JSON string, it should be parsed."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -82,9 +84,10 @@ class TestToolProgressCallback:
|
|||
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is not a dict, it should be wrapped."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -98,10 +101,11 @@ class TestToolProgressCallback:
|
|||
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
|
||||
"""Multiple same-name tool calls should be tracked independently in order."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -163,7 +167,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"terminal": "tc-abc123"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -181,7 +185,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
cb(1, [{"name": "unknown_tool", "result": "ok"}])
|
||||
|
|
@ -193,7 +197,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"read_file": "tc-def456"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -212,7 +216,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
|
|
@ -224,7 +228,7 @@ class TestStepCallback:
|
|||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None
|
||||
)
|
||||
|
||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||
|
|
@ -234,7 +238,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
|
|
@ -244,7 +248,50 @@ class TestStepCallback:
|
|||
|
||||
cb(1, [{"name": "web_search", "result": None}])
|
||||
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None)
|
||||
|
||||
def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture):
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"write_file": deque(["tc-write"])}
|
||||
tool_call_meta = {"tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"}}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "write_file", "result": '{"bytes_written": 23}', "arguments": {"path": "diff-test.txt"}}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-write",
|
||||
"write_file",
|
||||
result='{"bytes_written": 23}',
|
||||
function_args={"path": "diff-test.txt"},
|
||||
snapshot="snap",
|
||||
)
|
||||
|
||||
def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture):
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"), \
|
||||
patch("acp_adapter.events._send_update") as mock_send, \
|
||||
patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"):
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
cb("tool.started", "write_file", None, {"path": "diff-test.txt", "content": "hello"})
|
||||
|
||||
assert list(tool_call_ids["write_file"]) == ["tc-meta"]
|
||||
assert tool_call_meta["tc-meta"] == {
|
||||
"args": {"path": "diff-test.txt", "content": "hello"},
|
||||
"snapshot": "snapshot",
|
||||
}
|
||||
mock_send.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from acp.schema import (
|
|||
|
||||
from acp_adapter.server import HermesACPAgent
|
||||
from acp_adapter.session import SessionManager
|
||||
from acp_adapter.tools import build_tool_start
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -181,6 +182,25 @@ class TestMcpRegistrationE2E:
|
|||
assert complete_event.raw_output is not None
|
||||
assert "hello" in str(complete_event.raw_output)
|
||||
|
||||
def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self):
|
||||
update = build_tool_start(
|
||||
"tc-1",
|
||||
"patch",
|
||||
{
|
||||
"mode": "patch",
|
||||
"patch": "*** Begin Patch\n*** Update File: src/app.py\n@@\n-old line\n+new line\n*** Add File: src/new.py\n+hello\n*** End Patch",
|
||||
},
|
||||
)
|
||||
|
||||
assert len(update.content) == 2
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "src/app.py"
|
||||
assert update.content[0].old_text == "old line"
|
||||
assert update.content[0].new_text == "new line"
|
||||
assert update.content[1].type == "diff"
|
||||
assert update.content[1].path == "src/new.py"
|
||||
assert update.content[1].new_text == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ from acp.schema import (
|
|||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SessionModelState,
|
||||
SetSessionConfigOptionResponse,
|
||||
SetSessionModelResponse,
|
||||
SetSessionModeResponse,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
|
|
@ -127,6 +129,25 @@ class TestSessionOps:
|
|||
assert state is not None
|
||||
assert state.cwd == "/home/user/project"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_returns_model_state(self):
|
||||
manager = SessionManager(
|
||||
agent_factory=lambda: SimpleNamespace(model="gpt-5.4", provider="openai-codex")
|
||||
)
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
|
||||
with patch(
|
||||
"hermes_cli.models.curated_models_for_provider",
|
||||
return_value=[("gpt-5.4", "recommended"), ("gpt-5.4-mini", "")],
|
||||
):
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
|
||||
assert isinstance(resp.models, SessionModelState)
|
||||
assert resp.models.current_model_id == "openai-codex:gpt-5.4"
|
||||
assert resp.models.available_models[0].model_id == "openai-codex:gpt-5.4"
|
||||
assert resp.models.available_models[0].description is not None
|
||||
assert "Provider:" in resp.models.available_models[0].description
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_available_commands_include_help(self, agent):
|
||||
help_cmd = next(
|
||||
|
|
@ -204,6 +225,33 @@ class TestListAndFork:
|
|||
assert fork_resp.session_id
|
||||
assert fork_resp.session_id != new_resp.session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_includes_title_and_updated_at(self, agent):
|
||||
with patch.object(
|
||||
agent.session_manager,
|
||||
"list_sessions",
|
||||
return_value=[
|
||||
{
|
||||
"session_id": "session-1",
|
||||
"cwd": "/tmp/project",
|
||||
"title": "Fix Zed session history",
|
||||
"updated_at": 123.0,
|
||||
}
|
||||
],
|
||||
):
|
||||
resp = await agent.list_sessions(cwd="/tmp/project")
|
||||
|
||||
assert isinstance(resp.sessions[0], SessionInfo)
|
||||
assert resp.sessions[0].title == "Fix Zed session history"
|
||||
assert resp.sessions[0].updated_at == "123.0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_passes_cwd_filter(self, agent):
|
||||
with patch.object(agent.session_manager, "list_sessions", return_value=[]) as mock_list:
|
||||
await agent.list_sessions(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session configuration / model routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -257,6 +305,53 @@ class TestSessionConfiguration:
|
|||
assert result == {}
|
||||
assert state.model == "gpt-5.4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_session_model_accepts_provider_prefixed_choice(self, tmp_path, monkeypatch):
|
||||
runtime_calls = []
|
||||
|
||||
def fake_resolve_runtime_provider(requested=None, **kwargs):
|
||||
runtime_calls.append(requested)
|
||||
provider = requested or "openrouter"
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
|
||||
"base_url": f"https://{provider}.example/v1",
|
||||
"api_key": f"{provider}-key",
|
||||
"command": None,
|
||||
"args": [],
|
||||
}
|
||||
|
||||
def fake_agent(**kwargs):
|
||||
return SimpleNamespace(
|
||||
model=kwargs.get("model"),
|
||||
provider=kwargs.get("provider"),
|
||||
base_url=kwargs.get("base_url"),
|
||||
api_mode=kwargs.get("api_mode"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
|
||||
"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}
|
||||
})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
fake_resolve_runtime_provider,
|
||||
)
|
||||
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
state = manager.create_session(cwd="/tmp")
|
||||
result = await acp_agent.set_session_model(
|
||||
model_id="anthropic:claude-sonnet-4-6",
|
||||
session_id=state.session_id,
|
||||
)
|
||||
|
||||
assert isinstance(result, SetSessionModelResponse)
|
||||
assert state.model == "claude-sonnet-4-6"
|
||||
assert state.agent.provider == "anthropic"
|
||||
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||
assert runtime_calls[-1] == "anthropic"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt
|
||||
|
|
@ -354,6 +449,31 @@ class TestPrompt:
|
|||
update = last_call[1].get("update") or last_call[0][1]
|
||||
assert update.session_update == "agent_message_chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_auto_titles_session(self, agent):
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "Here is the fix.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "fix the broken ACP history"},
|
||||
{"role": "assistant", "content": "Here is the fix."},
|
||||
],
|
||||
})
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
with patch("agent.title_generator.maybe_auto_title") as mock_title:
|
||||
prompt = [TextContentBlock(type="text", text="fix the broken ACP history")]
|
||||
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
mock_title.assert_called_once()
|
||||
assert mock_title.call_args.args[1] == new_resp.session_id
|
||||
assert mock_title.call_args.args[2] == "fix the broken ACP history"
|
||||
assert mock_title.call_args.args[3] == "Here is the fix."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
|
||||
"""ACP should map top-level token fields into PromptResponse.usage."""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
|
@ -100,15 +101,23 @@ class TestListAndCleanup:
|
|||
def test_list_sessions_returns_created(self, manager):
|
||||
s1 = manager.create_session(cwd="/a")
|
||||
s2 = manager.create_session(cwd="/b")
|
||||
s1.history.append({"role": "user", "content": "hello from a"})
|
||||
s2.history.append({"role": "user", "content": "hello from b"})
|
||||
listing = manager.list_sessions()
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert s1.session_id in ids
|
||||
assert s2.session_id in ids
|
||||
assert len(listing) == 2
|
||||
|
||||
def test_list_sessions_hides_empty_threads(self, manager):
|
||||
manager.create_session(cwd="/empty")
|
||||
assert manager.list_sessions() == []
|
||||
|
||||
def test_cleanup_clears_all(self, manager):
|
||||
manager.create_session()
|
||||
manager.create_session()
|
||||
s1 = manager.create_session()
|
||||
s2 = manager.create_session()
|
||||
s1.history.append({"role": "user", "content": "one"})
|
||||
s2.history.append({"role": "user", "content": "two"})
|
||||
assert len(manager.list_sessions()) == 2
|
||||
manager.cleanup()
|
||||
assert manager.list_sessions() == []
|
||||
|
|
@ -194,6 +203,8 @@ class TestPersistence:
|
|||
def test_list_sessions_includes_db_only(self, manager):
|
||||
"""Sessions only in DB (not in memory) appear in list_sessions."""
|
||||
state = manager.create_session(cwd="/db-only")
|
||||
state.history.append({"role": "user", "content": "database only thread"})
|
||||
manager.save_session(state.session_id)
|
||||
sid = state.session_id
|
||||
|
||||
# Drop from memory.
|
||||
|
|
@ -204,6 +215,53 @@ class TestPersistence:
|
|||
ids = {s["session_id"] for s in listing}
|
||||
assert sid in ids
|
||||
|
||||
def test_list_sessions_filters_by_cwd(self, manager):
|
||||
keep = manager.create_session(cwd="/keep")
|
||||
drop = manager.create_session(cwd="/drop")
|
||||
keep.history.append({"role": "user", "content": "keep me"})
|
||||
drop.history.append({"role": "user", "content": "drop me"})
|
||||
|
||||
listing = manager.list_sessions(cwd="/keep")
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert keep.session_id in ids
|
||||
assert drop.session_id not in ids
|
||||
|
||||
def test_list_sessions_matches_windows_and_wsl_paths(self, manager):
|
||||
state = manager.create_session(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
state.history.append({"role": "user", "content": "same project from WSL"})
|
||||
|
||||
listing = manager.list_sessions(cwd=r"E:\Projects\AI\browser-link-3")
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert state.session_id in ids
|
||||
|
||||
def test_list_sessions_prefers_title_then_preview(self, manager):
|
||||
state = manager.create_session(cwd="/named")
|
||||
state.history.append({"role": "user", "content": "Investigate broken ACP history in Zed"})
|
||||
manager.save_session(state.session_id)
|
||||
db = manager._get_db()
|
||||
db.set_session_title(state.session_id, "Fix Zed ACP history")
|
||||
|
||||
listing = manager.list_sessions(cwd="/named")
|
||||
assert listing[0]["title"] == "Fix Zed ACP history"
|
||||
|
||||
db.set_session_title(state.session_id, "")
|
||||
listing = manager.list_sessions(cwd="/named")
|
||||
assert listing[0]["title"].startswith("Investigate broken ACP history")
|
||||
|
||||
def test_list_sessions_sorted_by_most_recent_activity(self, manager):
|
||||
older = manager.create_session(cwd="/ordered")
|
||||
older.history.append({"role": "user", "content": "older"})
|
||||
manager.save_session(older.session_id)
|
||||
time.sleep(0.02)
|
||||
newer = manager.create_session(cwd="/ordered")
|
||||
newer.history.append({"role": "user", "content": "newer"})
|
||||
manager.save_session(newer.session_id)
|
||||
|
||||
listing = manager.list_sessions(cwd="/ordered")
|
||||
assert [item["session_id"] for item in listing[:2]] == [newer.session_id, older.session_id]
|
||||
assert listing[0]["updated_at"]
|
||||
assert listing[1]["updated_at"]
|
||||
|
||||
def test_fork_restores_source_from_db(self, manager):
|
||||
"""Forking a session that is only in DB should work."""
|
||||
original = manager.create_session()
|
||||
|
|
|
|||
|
|
@ -215,6 +215,46 @@ class TestBuildToolComplete:
|
|||
assert len(display_text) < 6000
|
||||
assert "truncated" in display_text
|
||||
|
||||
def test_build_tool_complete_for_patch_uses_diff_blocks(self):
|
||||
"""Completed patch calls should keep structured diff content for Zed."""
|
||||
patch_result = (
|
||||
'{"success": true, "diff": "--- a/README.md\\n+++ b/README.md\\n@@ -1 +1,2 @@\\n old line\\n+new line\\n", '
|
||||
'"files_modified": ["README.md"]}'
|
||||
)
|
||||
result = build_tool_complete("tc-p1", "patch", patch_result)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert len(result.content) == 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path == "README.md"
|
||||
assert diff_item.old_text == "old line"
|
||||
assert diff_item.new_text == "old line\nnew line"
|
||||
|
||||
def test_build_tool_complete_for_patch_falls_back_to_text_when_no_diff(self):
|
||||
result = build_tool_complete("tc-p2", "patch", '{"success": true}')
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert isinstance(result.content[0], ContentToolCallContent)
|
||||
|
||||
def test_build_tool_complete_for_write_file_uses_snapshot_diff(self, tmp_path):
|
||||
target = tmp_path / "diff-test.txt"
|
||||
snapshot = type("Snapshot", (), {"paths": [target], "before": {str(target): None}})()
|
||||
target.write_text("hello from hermes\n", encoding="utf-8")
|
||||
|
||||
result = build_tool_complete(
|
||||
"tc-wf1",
|
||||
"write_file",
|
||||
'{"bytes_written": 18, "dirs_created": false}',
|
||||
function_args={"path": str(target), "content": "hello from hermes\n"},
|
||||
snapshot=snapshot,
|
||||
)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert len(result.content) == 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path.endswith("diff-test.txt")
|
||||
assert diff_item.old_text is None
|
||||
assert diff_item.new_text == "hello from hermes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_locations
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue