fix(acp): improve zed integration

This commit is contained in:
Henkey 2026-04-17 17:28:59 +01:00 committed by Teknium
parent d0e1388ca9
commit cb883f9e97
10 changed files with 769 additions and 88 deletions

View file

@ -49,6 +49,7 @@ def make_tool_progress_cb(
session_id: str,
loop: asyncio.AbstractEventLoop,
tool_call_ids: Dict[str, Deque[str]],
tool_call_meta: Dict[str, Dict[str, Any]],
) -> Callable:
"""Create a ``tool_progress_callback`` for AIAgent.
@ -84,6 +85,16 @@ def make_tool_progress_cb(
tool_call_ids[name] = queue
queue.append(tc_id)
snapshot = None
if name in {"write_file", "patch", "skill_manage"}:
try:
from agent.display import capture_local_edit_snapshot
snapshot = capture_local_edit_snapshot(name, args)
except Exception:
logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True)
tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot}
update = build_tool_start(tc_id, name, args)
_send_update(conn, session_id, loop, update)
@ -119,6 +130,7 @@ def make_step_cb(
session_id: str,
loop: asyncio.AbstractEventLoop,
tool_call_ids: Dict[str, Deque[str]],
tool_call_meta: Dict[str, Dict[str, Any]],
) -> Callable:
"""Create a ``step_callback`` for AIAgent.
@ -132,10 +144,12 @@ def make_step_cb(
for tool_info in prev_tools:
tool_name = None
result = None
function_args = None
if isinstance(tool_info, dict):
tool_name = tool_info.get("name") or tool_info.get("function_name")
result = tool_info.get("result") or tool_info.get("output")
function_args = tool_info.get("arguments") or tool_info.get("args")
elif isinstance(tool_info, str):
tool_name = tool_info
@ -145,8 +159,13 @@ def make_step_cb(
tool_call_ids[tool_name] = queue
if tool_name and queue:
tc_id = queue.popleft()
meta = tool_call_meta.pop(tc_id, {})
update = build_tool_complete(
tc_id, tool_name, result=str(result) if result is not None else None
tc_id,
tool_name,
result=str(result) if result is not None else None,
function_args=function_args or meta.get("args"),
snapshot=meta.get("snapshot"),
)
_send_update(conn, session_id, loop, update)
if not queue:

View file

@ -26,6 +26,7 @@ from acp.schema import (
McpServerHttp,
McpServerSse,
McpServerStdio,
ModelInfo,
NewSessionResponse,
PromptResponse,
ResumeSessionResponse,
@ -36,6 +37,7 @@ from acp.schema import (
SessionCapabilities,
SessionForkCapabilities,
SessionListCapabilities,
SessionModelState,
SessionResumeCapabilities,
SessionInfo,
TextContentBlock,
@ -147,6 +149,98 @@ class HermesACPAgent(acp.Agent):
self._conn = conn
logger.info("ACP client connected")
@staticmethod
def _encode_model_choice(provider: str | None, model: str | None) -> str:
"""Encode a model selection so ACP clients can keep provider context."""
raw_model = str(model or "").strip()
if not raw_model:
return ""
raw_provider = str(provider or "").strip().lower()
if not raw_provider:
return raw_model
return f"{raw_provider}:{raw_model}"
def _build_model_state(self, state: SessionState) -> SessionModelState | None:
"""Return the ACP model selector payload for editors like Zed."""
model = str(state.model or getattr(state.agent, "model", "") or "").strip()
provider = getattr(state.agent, "provider", None) or detect_provider() or "openrouter"
try:
from hermes_cli.models import curated_models_for_provider, normalize_provider, provider_label
normalized_provider = normalize_provider(provider)
provider_name = provider_label(normalized_provider)
available_models: list[ModelInfo] = []
seen_ids: set[str] = set()
for model_id, description in curated_models_for_provider(normalized_provider):
rendered_model = str(model_id or "").strip()
if not rendered_model:
continue
choice_id = self._encode_model_choice(normalized_provider, rendered_model)
if choice_id in seen_ids:
continue
desc_parts = [f"Provider: {provider_name}"]
if description:
desc_parts.append(str(description).strip())
if rendered_model == model:
desc_parts.append("current")
available_models.append(
ModelInfo(
model_id=choice_id,
name=rendered_model,
description="".join(part for part in desc_parts if part),
)
)
seen_ids.add(choice_id)
current_model_id = self._encode_model_choice(normalized_provider, model)
if current_model_id and current_model_id not in seen_ids:
available_models.insert(
0,
ModelInfo(
model_id=current_model_id,
name=model,
description=f"Provider: {provider_name} • current",
),
)
if available_models:
return SessionModelState(
available_models=available_models,
current_model_id=current_model_id or available_models[0].model_id,
)
except Exception:
logger.debug("Could not build ACP model state", exc_info=True)
if not model:
return None
fallback_choice = self._encode_model_choice(provider, model)
return SessionModelState(
available_models=[ModelInfo(model_id=fallback_choice, name=model)],
current_model_id=fallback_choice,
)
@staticmethod
def _resolve_model_selection(raw_model: str, current_provider: str) -> tuple[str, str]:
"""Resolve ``provider:model`` input into the provider and normalized model id."""
target_provider = current_provider
new_model = raw_model.strip()
try:
from hermes_cli.models import detect_provider_for_model, parse_model_input
target_provider, new_model = parse_model_input(new_model, current_provider)
if target_provider == current_provider:
detected = detect_provider_for_model(new_model, current_provider)
if detected:
target_provider, new_model = detected
except Exception:
logger.debug("Provider detection failed, using model as-is", exc_info=True)
return target_provider, new_model
async def _register_session_mcp_servers(
self,
state: SessionState,
@ -273,7 +367,10 @@ class HermesACPAgent(acp.Agent):
await self._register_session_mcp_servers(state, mcp_servers)
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
self._schedule_available_commands_update(state.session_id)
return NewSessionResponse(session_id=state.session_id)
return NewSessionResponse(
session_id=state.session_id,
models=self._build_model_state(state),
)
async def load_session(
self,
@ -289,7 +386,7 @@ class HermesACPAgent(acp.Agent):
await self._register_session_mcp_servers(state, mcp_servers)
logger.info("Loaded session %s", session_id)
self._schedule_available_commands_update(session_id)
return LoadSessionResponse()
return LoadSessionResponse(models=self._build_model_state(state))
async def resume_session(
self,
@ -305,7 +402,7 @@ class HermesACPAgent(acp.Agent):
await self._register_session_mcp_servers(state, mcp_servers)
logger.info("Resumed session %s", state.session_id)
self._schedule_available_commands_update(state.session_id)
return ResumeSessionResponse()
return ResumeSessionResponse(models=self._build_model_state(state))
async def cancel(self, session_id: str, **kwargs: Any) -> None:
state = self.session_manager.get_session(session_id)
@ -340,11 +437,20 @@ class HermesACPAgent(acp.Agent):
cwd: str | None = None,
**kwargs: Any,
) -> ListSessionsResponse:
infos = self.session_manager.list_sessions()
sessions = [
SessionInfo(session_id=s["session_id"], cwd=s["cwd"])
for s in infos
]
infos = self.session_manager.list_sessions(cwd=cwd)
sessions = []
for s in infos:
updated_at = s.get("updated_at")
if updated_at is not None and not isinstance(updated_at, str):
updated_at = str(updated_at)
sessions.append(
SessionInfo(
session_id=s["session_id"],
cwd=s["cwd"],
title=s.get("title"),
updated_at=updated_at,
)
)
return ListSessionsResponse(sessions=sessions)
# ---- Prompt (core) ------------------------------------------------------
@ -389,12 +495,13 @@ class HermesACPAgent(acp.Agent):
state.cancel_event.clear()
tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
tool_call_meta: dict[str, dict[str, Any]] = {}
previous_approval_cb = None
if conn:
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids)
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
thinking_cb = make_thinking_cb(conn, session_id, loop)
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids)
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
message_cb = make_message_cb(conn, session_id, loop)
approval_cb = make_approval_callback(conn.request_permission, loop, session_id)
else:
@ -449,6 +556,19 @@ class HermesACPAgent(acp.Agent):
self.session_manager.save_session(session_id)
final_response = result.get("final_response", "")
if final_response:
try:
from agent.title_generator import maybe_auto_title
maybe_auto_title(
self.session_manager._get_db(),
session_id,
user_text,
final_response,
state.history,
)
except Exception:
logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True)
if final_response and conn:
update = acp.update_agent_message_text(final_response)
await conn.session_update(session_id, update)
@ -556,27 +676,15 @@ class HermesACPAgent(acp.Agent):
provider = getattr(state.agent, "provider", None) or "auto"
return f"Current model: {model}\nProvider: {provider}"
new_model = args.strip()
target_provider = None
current_provider = getattr(state.agent, "provider", None) or "openrouter"
# Auto-detect provider for the requested model
try:
from hermes_cli.models import parse_model_input, detect_provider_for_model
target_provider, new_model = parse_model_input(new_model, current_provider)
if target_provider == current_provider:
detected = detect_provider_for_model(new_model, current_provider)
if detected:
target_provider, new_model = detected
except Exception:
logger.debug("Provider detection failed, using model as-is", exc_info=True)
target_provider, new_model = self._resolve_model_selection(args, current_provider)
state.model = new_model
state.agent = self.session_manager._make_agent(
session_id=state.session_id,
cwd=state.cwd,
model=new_model,
requested_provider=target_provider or current_provider,
requested_provider=target_provider,
)
self.session_manager.save_session(state.session_id)
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
@ -678,20 +786,30 @@ class HermesACPAgent(acp.Agent):
"""Switch the model for a session (called by ACP protocol)."""
state = self.session_manager.get_session(session_id)
if state:
state.model = model_id
current_provider = getattr(state.agent, "provider", None)
current_base_url = getattr(state.agent, "base_url", None)
current_api_mode = getattr(state.agent, "api_mode", None)
requested_provider, resolved_model = self._resolve_model_selection(
model_id,
current_provider or "openrouter",
)
state.model = resolved_model
provider_changed = bool(current_provider and requested_provider != current_provider)
current_base_url = None if provider_changed else getattr(state.agent, "base_url", None)
current_api_mode = None if provider_changed else getattr(state.agent, "api_mode", None)
state.agent = self.session_manager._make_agent(
session_id=session_id,
cwd=state.cwd,
model=model_id,
requested_provider=current_provider,
model=resolved_model,
requested_provider=requested_provider,
base_url=current_base_url,
api_mode=current_api_mode,
)
self.session_manager.save_session(session_id)
logger.info("Session %s: model switched to %s", session_id, model_id)
logger.info(
"Session %s: model switched to %s via provider %s",
session_id,
resolved_model,
requested_provider,
)
return SetSessionModelResponse()
logger.warning("Session %s: model switch requested for missing session", session_id)
return None

View file

@ -13,8 +13,12 @@ from hermes_constants import get_hermes_home
import copy
import json
import logging
import os
import re
import sys
import time
import uuid
from datetime import datetime, timezone
from dataclasses import dataclass, field
from threading import Lock
from typing import Any, Dict, List, Optional
@ -22,6 +26,64 @@ from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
def _normalize_cwd_for_compare(cwd: str | None) -> str:
raw = str(cwd or ".").strip()
if not raw:
raw = "."
expanded = os.path.expanduser(raw)
# Normalize Windows drive paths into the equivalent WSL mount form so
# ACP history filters match the same workspace across Windows and WSL.
match = re.match(r"^([A-Za-z]):[\\/](.*)$", expanded)
if match:
drive = match.group(1).lower()
tail = match.group(2).replace("\\", "/")
expanded = f"/mnt/{drive}/{tail}"
elif re.match(r"^/mnt/[A-Za-z]/", expanded):
expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}"
return os.path.normpath(expanded)
def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str:
explicit = str(title or "").strip()
if explicit:
return explicit
preview_text = str(preview or "").strip()
if preview_text:
return preview_text
leaf = os.path.basename(str(cwd or "").rstrip("/\\"))
return leaf or "New thread"
def _format_updated_at(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str) and value.strip():
return value
try:
return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat()
except Exception:
return None
def _updated_at_sort_key(value: Any) -> float:
if value is None:
return float("-inf")
if isinstance(value, (int, float)):
return float(value)
raw = str(value).strip()
if not raw:
return float("-inf")
try:
return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
except Exception:
try:
return float(raw)
except Exception:
return float("-inf")
def _acp_stderr_print(*args, **kwargs) -> None:
"""Best-effort human-readable output sink for ACP stdio sessions.
@ -162,47 +224,78 @@ class SessionManager:
logger.info("Forked ACP session %s -> %s", session_id, new_id)
return state
def list_sessions(self) -> List[Dict[str, Any]]:
def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]:
"""Return lightweight info dicts for all sessions (memory + database)."""
normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None
db = self._get_db()
persisted_rows: dict[str, dict[str, Any]] = {}
if db is not None:
try:
for row in db.list_sessions_rich(source="acp", limit=1000):
persisted_rows[str(row["id"])] = dict(row)
except Exception:
logger.debug("Failed to load ACP sessions from DB", exc_info=True)
# Collect in-memory sessions first.
with self._lock:
seen_ids = set(self._sessions.keys())
results = [
{
"session_id": s.session_id,
"cwd": s.cwd,
"model": s.model,
"history_len": len(s.history),
}
for s in self._sessions.values()
]
results = []
for s in self._sessions.values():
history_len = len(s.history)
if history_len <= 0:
continue
if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd:
continue
persisted = persisted_rows.get(s.session_id, {})
preview = next(
(
str(msg.get("content") or "").strip()
for msg in s.history
if msg.get("role") == "user" and str(msg.get("content") or "").strip()
),
persisted.get("preview") or "",
)
results.append(
{
"session_id": s.session_id,
"cwd": s.cwd,
"model": s.model,
"history_len": history_len,
"title": _build_session_title(persisted.get("title"), preview, s.cwd),
"updated_at": _format_updated_at(
persisted.get("last_active") or persisted.get("started_at") or time.time()
),
}
)
# Merge any persisted sessions not currently in memory.
db = self._get_db()
if db is not None:
try:
rows = db.search_sessions(source="acp", limit=1000)
for row in rows:
sid = row["id"]
if sid in seen_ids:
continue
# Extract cwd from model_config JSON.
cwd = "."
mc = row.get("model_config")
if mc:
try:
cwd = json.loads(mc).get("cwd", ".")
except (json.JSONDecodeError, TypeError):
pass
results.append({
"session_id": sid,
"cwd": cwd,
"model": row.get("model") or "",
"history_len": row.get("message_count") or 0,
})
except Exception:
logger.debug("Failed to list ACP sessions from DB", exc_info=True)
for sid, row in persisted_rows.items():
if sid in seen_ids:
continue
message_count = int(row.get("message_count") or 0)
if message_count <= 0:
continue
# Extract cwd from model_config JSON.
session_cwd = "."
mc = row.get("model_config")
if mc:
try:
session_cwd = json.loads(mc).get("cwd", ".")
except (json.JSONDecodeError, TypeError):
pass
if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd:
continue
results.append({
"session_id": sid,
"cwd": session_cwd,
"model": row.get("model") or "",
"history_len": message_count,
"title": _build_session_title(row.get("title"), row.get("preview"), session_cwd),
"updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")),
})
results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True)
return results
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import json
import uuid
from typing import Any, Dict, List, Optional
@ -96,6 +97,170 @@ def build_tool_title(tool_name: str, args: Dict[str, Any]) -> str:
return tool_name
def _build_patch_mode_content(patch_text: str) -> List[Any]:
"""Parse V4A patch mode input into ACP diff blocks when possible."""
if not patch_text:
return [acp.tool_content(acp.text_block(""))]
try:
from tools.patch_parser import OperationType, parse_v4a_patch
operations, error = parse_v4a_patch(patch_text)
if error or not operations:
return [acp.tool_content(acp.text_block(patch_text))]
content: List[Any] = []
for op in operations:
if op.operation == OperationType.UPDATE:
old_chunks: list[str] = []
new_chunks: list[str] = []
for hunk in op.hunks:
old_lines = [line.content for line in hunk.lines if line.prefix in (" ", "-")]
new_lines = [line.content for line in hunk.lines if line.prefix in (" ", "+")]
if old_lines or new_lines:
old_chunks.append("\n".join(old_lines))
new_chunks.append("\n".join(new_lines))
old_text = "\n...\n".join(chunk for chunk in old_chunks if chunk)
new_text = "\n...\n".join(chunk for chunk in new_chunks if chunk)
if old_text or new_text:
content.append(
acp.tool_diff_content(
path=op.file_path,
old_text=old_text or None,
new_text=new_text or "",
)
)
continue
if op.operation == OperationType.ADD:
added_lines = [line.content for hunk in op.hunks for line in hunk.lines if line.prefix == "+"]
content.append(
acp.tool_diff_content(
path=op.file_path,
new_text="\n".join(added_lines),
)
)
continue
if op.operation == OperationType.DELETE:
content.append(
acp.tool_diff_content(
path=op.file_path,
old_text=f"Delete file: {op.file_path}",
new_text="",
)
)
continue
if op.operation == OperationType.MOVE:
content.append(
acp.tool_content(acp.text_block(f"Move file: {op.file_path} -> {op.new_path}"))
)
return content or [acp.tool_content(acp.text_block(patch_text))]
except Exception:
return [acp.tool_content(acp.text_block(patch_text))]
def _strip_diff_prefix(path: str) -> str:
raw = str(path or "").strip()
if raw.startswith(("a/", "b/")):
return raw[2:]
return raw
def _parse_unified_diff_content(diff_text: str) -> List[Any]:
"""Convert unified diff text into ACP diff content blocks."""
if not diff_text:
return []
content: List[Any] = []
current_old_path: Optional[str] = None
current_new_path: Optional[str] = None
old_lines: list[str] = []
new_lines: list[str] = []
def _flush() -> None:
nonlocal current_old_path, current_new_path, old_lines, new_lines
if current_old_path is None and current_new_path is None:
return
path = current_new_path if current_new_path and current_new_path != "/dev/null" else current_old_path
if not path or path == "/dev/null":
current_old_path = None
current_new_path = None
old_lines = []
new_lines = []
return
content.append(
acp.tool_diff_content(
path=_strip_diff_prefix(path),
old_text="\n".join(old_lines) if old_lines else None,
new_text="\n".join(new_lines),
)
)
current_old_path = None
current_new_path = None
old_lines = []
new_lines = []
for line in diff_text.splitlines():
if line.startswith("--- "):
_flush()
current_old_path = line[4:].strip()
continue
if line.startswith("+++ "):
current_new_path = line[4:].strip()
continue
if line.startswith("@@"):
continue
if current_old_path is None and current_new_path is None:
continue
if line.startswith("+"):
new_lines.append(line[1:])
elif line.startswith("-"):
old_lines.append(line[1:])
elif line.startswith(" "):
shared = line[1:]
old_lines.append(shared)
new_lines.append(shared)
_flush()
return content
def _build_tool_complete_content(
tool_name: str,
result: Optional[str],
*,
function_args: Optional[Dict[str, Any]] = None,
snapshot: Any = None,
) -> List[Any]:
"""Build structured ACP completion content, falling back to plain text."""
display_result = result or ""
if len(display_result) > 5000:
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
if tool_name in {"write_file", "patch", "skill_manage"}:
try:
from agent.display import extract_edit_diff
diff_text = extract_edit_diff(
tool_name,
result,
function_args=function_args,
snapshot=snapshot,
)
if isinstance(diff_text, str) and diff_text.strip():
diff_content = _parse_unified_diff_content(diff_text)
if diff_content:
return diff_content
except Exception:
pass
return [acp.tool_content(acp.text_block(display_result))]
# ---------------------------------------------------------------------------
# Build ACP content objects for tool-call events
# ---------------------------------------------------------------------------
@ -119,9 +284,8 @@ def build_tool_start(
new = arguments.get("new_string", "")
content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)]
else:
# Patch mode — show the patch content as text
patch_text = arguments.get("patch", "")
content = [acp.tool_content(acp.text_block(patch_text))]
content = _build_patch_mode_content(patch_text)
return acp.start_tool_call(
tool_call_id, title, kind=kind, content=content, locations=locations,
raw_input=arguments,
@ -178,16 +342,17 @@ def build_tool_complete(
tool_call_id: str,
tool_name: str,
result: Optional[str] = None,
function_args: Optional[Dict[str, Any]] = None,
snapshot: Any = None,
) -> ToolCallProgress:
"""Create a ToolCallUpdate (progress) event for a completed tool call."""
kind = get_tool_kind(tool_name)
# Truncate very large results for the UI
display_result = result or ""
if len(display_result) > 5000:
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
content = [acp.tool_content(acp.text_block(display_result))]
content = _build_tool_complete_content(
tool_name,
result,
function_args=function_args,
snapshot=snapshot,
)
return acp.update_tool_call(
tool_call_id,
kind=kind,