fix: surface copilot acp progress and stale detection

This commit is contained in:
David Zhang 2026-04-24 17:54:46 +07:00
parent 18f3fc8a6f
commit 0524a40790
4 changed files with 196 additions and 13 deletions

View file

@ -9,6 +9,7 @@ back into the minimal shape Hermes expects from an OpenAI client.
from __future__ import annotations from __future__ import annotations
import json import json
import logging
import os import os
import queue import queue
import re import re
@ -25,12 +26,31 @@ from agent.file_safety import get_read_block_error, is_write_denied
from agent.redact import redact_sensitive_text from agent.redact import redact_sensitive_text
ACP_MARKER_BASE_URL = "acp://copilot" ACP_MARKER_BASE_URL = "acp://copilot"
_DEFAULT_TIMEOUT_SECONDS = 900.0
logger = logging.getLogger(__name__)
_TOOL_CALL_BLOCK_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL) _TOOL_CALL_BLOCK_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
_TOOL_CALL_JSON_RE = re.compile(r"\{\s*\"id\"\s*:\s*\"[^\"]+\"\s*,\s*\"type\"\s*:\s*\"function\"\s*,\s*\"function\"\s*:\s*\{.*?\}\s*\}", re.DOTALL) _TOOL_CALL_JSON_RE = re.compile(r"\{\s*\"id\"\s*:\s*\"[^\"]+\"\s*,\s*\"type\"\s*:\s*\"function\"\s*,\s*\"function\"\s*:\s*\{.*?\}\s*\}", re.DOTALL)
def _resolve_positive_seconds(env_name: str, fallback: float) -> float:
    """Read a positive, finite duration in seconds from an environment variable.

    Args:
        env_name: Name of the environment variable to read.
        fallback: Value returned when the variable is unset, blank,
            unparsable, or not a usable duration.

    Returns:
        The parsed value when it is strictly positive and finite,
        otherwise ``fallback``.
    """
    raw = os.getenv(env_name, "").strip()
    if not raw:
        return fallback
    try:
        seconds = float(raw)
    except ValueError:
        return fallback
    # float() happily parses "nan", "inf" and overflowing literals such as
    # "1e999".  The chained comparison rejects zero, negatives, NaN (every
    # comparison with NaN is false) and +inf, any of which would otherwise
    # turn the timeout into an unbounded or meaningless wait.
    return seconds if 0 < seconds < float("inf") else fallback


def _resolve_timeout_seconds() -> float:
    """Return the overall Copilot ACP request timeout in seconds.

    Reads ``HERMES_COPILOT_ACP_TIMEOUT_SECONDS``; falls back to 300 seconds
    when the variable is missing or does not hold a positive, finite number.

    NOTE(review): the fallback here is 300s while the module-level
    ``_DEFAULT_TIMEOUT_SECONDS`` is 900.0 -- confirm which default is intended.
    """
    return _resolve_positive_seconds("HERMES_COPILOT_ACP_TIMEOUT_SECONDS", 300.0)


def _resolve_inactivity_timeout_seconds() -> float:
    """Return how long the ACP process may stay silent, in seconds.

    Reads ``HERMES_COPILOT_ACP_INACTIVITY_TIMEOUT_SECONDS``; falls back to
    300 seconds when the variable is missing or does not hold a positive,
    finite number.
    """
    return _resolve_positive_seconds(
        "HERMES_COPILOT_ACP_INACTIVITY_TIMEOUT_SECONDS", 300.0
    )
def _resolve_command() -> str: def _resolve_command() -> str:
return ( return (
os.getenv("HERMES_COPILOT_ACP_COMMAND", "").strip() os.getenv("HERMES_COPILOT_ACP_COMMAND", "").strip()
@ -277,6 +297,9 @@ class CopilotACPClient:
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
default_headers: dict[str, str] | None = None, default_headers: dict[str, str] | None = None,
activity_callback: Any | None = None,
stream_delta_callback: Any | None = None,
reasoning_callback: Any | None = None,
acp_command: str | None = None, acp_command: str | None = None,
acp_args: list[str] | None = None, acp_args: list[str] | None = None,
acp_cwd: str | None = None, acp_cwd: str | None = None,
@ -287,6 +310,9 @@ class CopilotACPClient:
self.api_key = api_key or "copilot-acp" self.api_key = api_key or "copilot-acp"
self.base_url = base_url or ACP_MARKER_BASE_URL self.base_url = base_url or ACP_MARKER_BASE_URL
self._default_headers = dict(default_headers or {}) self._default_headers = dict(default_headers or {})
self._activity_callback = activity_callback
self._stream_delta_callback = stream_delta_callback
self._reasoning_callback = reasoning_callback
self._acp_command = acp_command or command or _resolve_command() self._acp_command = acp_command or command or _resolve_command()
self._acp_args = list(acp_args or args or _resolve_args()) self._acp_args = list(acp_args or args or _resolve_args())
self._acp_cwd = str(Path(acp_cwd or os.getcwd()).resolve()) self._acp_cwd = str(Path(acp_cwd or os.getcwd()).resolve())
@ -331,7 +357,7 @@ class CopilotACPClient:
# Normalise timeout: run_agent.py may pass an httpx.Timeout object # Normalise timeout: run_agent.py may pass an httpx.Timeout object
# (used natively by the OpenAI SDK) rather than a plain float. # (used natively by the OpenAI SDK) rather than a plain float.
if timeout is None: if timeout is None:
_effective_timeout = _DEFAULT_TIMEOUT_SECONDS _effective_timeout = _resolve_timeout_seconds()
elif isinstance(timeout, (int, float)): elif isinstance(timeout, (int, float)):
_effective_timeout = float(timeout) _effective_timeout = float(timeout)
else: else:
@ -342,7 +368,7 @@ class CopilotACPClient:
for attr in ("read", "write", "connect", "pool", "timeout") for attr in ("read", "write", "connect", "pool", "timeout")
] ]
_numeric = [float(v) for v in _candidates if isinstance(v, (int, float))] _numeric = [float(v) for v in _candidates if isinstance(v, (int, float))]
_effective_timeout = max(_numeric) if _numeric else _DEFAULT_TIMEOUT_SECONDS _effective_timeout = max(_numeric) if _numeric else _resolve_timeout_seconds()
response_text, reasoning_text = self._run_prompt( response_text, reasoning_text = self._run_prompt(
prompt_text, prompt_text,
@ -436,14 +462,25 @@ class CopilotACPClient:
proc.stdin.flush() proc.stdin.flush()
deadline = time.time() + timeout_seconds deadline = time.time() + timeout_seconds
last_message_time = time.time()
inactivity_timeout = min(timeout_seconds, _resolve_inactivity_timeout_seconds())
while time.time() < deadline: while time.time() < deadline:
if proc.poll() is not None: if proc.poll() is not None:
break break
try: try:
msg = inbox.get(timeout=0.1) msg = inbox.get(timeout=0.1)
except queue.Empty: except queue.Empty:
silence = time.time() - last_message_time
if silence > inactivity_timeout:
raise TimeoutError(
f"Copilot ACP {method} went silent for {int(silence)}s "
f"(threshold: {int(inactivity_timeout)}s)."
)
continue continue
last_message_time = time.time()
self._notify_activity(f"copilot-acp:{method}")
if self._handle_server_message( if self._handle_server_message(
msg, msg,
process=proc, process=proc,
@ -539,8 +576,10 @@ class CopilotACPClient:
chunk_text = str(content.get("text") or "") chunk_text = str(content.get("text") or "")
if kind == "agent_message_chunk" and chunk_text and text_parts is not None: if kind == "agent_message_chunk" and chunk_text and text_parts is not None:
text_parts.append(chunk_text) text_parts.append(chunk_text)
self._emit_stream_delta(chunk_text)
elif kind == "agent_thought_chunk" and chunk_text and reasoning_parts is not None: elif kind == "agent_thought_chunk" and chunk_text and reasoning_parts is not None:
reasoning_parts.append(chunk_text) reasoning_parts.append(chunk_text)
self._emit_reasoning_delta(chunk_text)
return True return True
if process.stdin is None: if process.stdin is None:
@ -602,3 +641,30 @@ class CopilotACPClient:
process.stdin.write(json.dumps(response) + "\n") process.stdin.write(json.dumps(response) + "\n")
process.stdin.flush() process.stdin.flush()
return True return True
def _notify_activity(self, detail: str | None = None) -> None:
    """Forward a progress marker to the registered activity callback.

    Does nothing when no activity callback was supplied.  ``detail`` is
    passed through unchanged (including ``None``).  Any exception raised by
    the callback is logged at debug level and suppressed so that a faulty
    observer cannot break the caller.
    """
    handler = self._activity_callback
    if handler is not None:
        try:
            handler(detail)
        except Exception:
            logger.debug("Copilot ACP activity callback failed", exc_info=True)
def _emit_stream_delta(self, text: str) -> None:
    """Pass a chunk of assistant text to the stream-delta callback.

    The callback is invoked only when one is registered and ``text`` is
    non-empty.  Exceptions raised by the callback are logged at debug level
    and suppressed.
    """
    handler = self._stream_delta_callback
    if handler is not None and text:
        try:
            handler(text)
        except Exception:
            logger.debug("Copilot ACP stream delta callback failed", exc_info=True)
def _emit_reasoning_delta(self, text: str) -> None:
    """Pass a chunk of reasoning text to the reasoning callback.

    The callback is invoked only when one is registered and ``text`` is
    non-empty.  Exceptions raised by the callback are logged at debug level
    and suppressed.
    """
    handler = self._reasoning_callback
    if handler is not None and text:
        try:
            handler(text)
        except Exception:
            logger.debug("Copilot ACP reasoning callback failed", exc_info=True)

View file

@ -5340,7 +5340,12 @@ class AIAgent:
provider fallback. provider fallback.
""" """
result = {"response": None, "error": None} result = {"response": None, "error": None}
request_client_holder = {"client": None} request_client_holder = {"client": None, "last_provider_activity": None}
def _note_provider_activity(detail: str | None = None) -> None:
request_client_holder["last_provider_activity"] = time.time()
if detail:
self._touch_activity(detail)
def _call(): def _call():
try: try:
@ -5368,7 +5373,21 @@ class AIAgent:
raw_response = client.converse(**api_kwargs) raw_response = client.converse(**api_kwargs)
result["response"] = normalize_converse_response(raw_response) result["response"] = normalize_converse_response(raw_response)
else: else:
request_client_holder["client"] = self._create_request_openai_client(reason="chat_completion_request") if self.provider == "copilot-acp":
with self._openai_client_lock():
request_kwargs = dict(self._client_kwargs)
request_kwargs.update({
"activity_callback": _note_provider_activity,
"stream_delta_callback": self._fire_stream_delta,
"reasoning_callback": self._fire_reasoning_delta,
})
request_client_holder["client"] = self._create_openai_client(
request_kwargs,
reason="chat_completion_request",
shared=False,
)
else:
request_client_holder["client"] = self._create_request_openai_client(reason="chat_completion_request")
result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs) result["response"] = request_client_holder["client"].chat.completions.create(**api_kwargs)
except Exception as e: except Exception as e:
result["error"] = e result["error"] = e
@ -5408,16 +5427,18 @@ class AIAgent:
# Stale-call detector: kill the connection if no response # Stale-call detector: kill the connection if no response
# arrives within the configured timeout. # arrives within the configured timeout.
_elapsed = time.time() - _call_start _elapsed = time.time() - _call_start
if _elapsed > _stale_timeout: _last_provider_activity = request_client_holder.get("last_provider_activity") or _call_start
_stale_elapsed = time.time() - _last_provider_activity
if _stale_elapsed > _stale_timeout:
_est_ctx = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4 _est_ctx = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4
logger.warning( logger.warning(
"Non-streaming API call stale for %.0fs (threshold %.0fs). " "Non-streaming API call stale for %.0fs since last provider activity "
"model=%s context=~%s tokens. Killing connection.", "(wall %.0fs, threshold %.0fs). model=%s context=~%s tokens. Killing connection.",
_elapsed, _stale_timeout, _stale_elapsed, _elapsed, _stale_timeout,
api_kwargs.get("model", "unknown"), f"{_est_ctx:,}", api_kwargs.get("model", "unknown"), f"{_est_ctx:,}",
) )
self._emit_status( self._emit_status(
f"⚠️ No response from provider for {int(_elapsed)}s " f"⚠️ No provider activity for {int(_stale_elapsed)}s "
f"(non-streaming, model: {api_kwargs.get('model', 'unknown')}). " f"(non-streaming, model: {api_kwargs.get('model', 'unknown')}). "
f"Aborting call." f"Aborting call."
) )
@ -5438,7 +5459,7 @@ class AIAgent:
except Exception: except Exception:
pass pass
self._touch_activity( self._touch_activity(
f"stale non-streaming call killed after {int(_elapsed)}s" f"stale non-streaming call killed after {int(_stale_elapsed)}s of silence"
) )
# Wait briefly for the thread to notice the closed connection. # Wait briefly for the thread to notice the closed connection.
t.join(timeout=2.0) t.join(timeout=2.0)
@ -5535,7 +5556,7 @@ class AIAgent:
cb(text) cb(text)
delivered = True delivered = True
except Exception: except Exception:
pass logger.debug("stream_delta_callback error", exc_info=True)
if delivered: if delivered:
self._record_streamed_assistant_text(text) self._record_streamed_assistant_text(text)
@ -5546,7 +5567,7 @@ class AIAgent:
try: try:
cb(text) cb(text)
except Exception: except Exception:
pass logger.debug("reasoning_callback error", exc_info=True)
def _fire_tool_gen_started(self, tool_name: str) -> None: def _fire_tool_gen_started(self, tool_name: str) -> None:
"""Notify display layer that the model is generating tool call arguments. """Notify display layer that the model is generating tool call arguments.

View file

@ -141,6 +141,61 @@ class CopilotACPClientSafetyTests(unittest.TestCase):
self.assertIn("error", response) self.assertIn("error", response)
self.assertFalse(outside.exists()) self.assertFalse(outside.exists())
def test_session_update_streams_live_text_and_reasoning_callbacks(self) -> None:
    """session/update chunks reach both the accumulators and the live callbacks.

    An ``agent_message_chunk`` must be appended to ``text_parts`` and sent to
    ``stream_delta_callback``; an ``agent_thought_chunk`` must be appended to
    ``reasoning_parts`` and sent to ``reasoning_callback``.
    """
    streamed: list[str] = []
    reasoned: list[str] = []
    self.client = CopilotACPClient(
        acp_cwd="/tmp",
        stream_delta_callback=streamed.append,
        reasoning_callback=reasoned.append,
    )
    text_parts: list[str] = []
    reasoning_parts: list[str] = []
    process = _FakeProcess()

    # Feed one message chunk followed by one thought chunk, in that order.
    updates = (
        ("agent_message_chunk", "hello"),
        ("agent_thought_chunk", "thinking"),
    )
    for update_kind, chunk in updates:
        notification = {
            "jsonrpc": "2.0",
            "method": "session/update",
            "params": {
                "update": {
                    "sessionUpdate": update_kind,
                    "content": {"text": chunk},
                }
            },
        }
        handled = self.client._handle_server_message(
            notification,
            process=process,
            cwd="/tmp",
            text_parts=text_parts,
            reasoning_parts=reasoning_parts,
        )
        self.assertTrue(handled)

    self.assertEqual(text_parts, ["hello"])
    self.assertEqual(reasoning_parts, ["thinking"])
    self.assertEqual(streamed, ["hello"])
    self.assertEqual(reasoned, ["thinking"])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -1,5 +1,6 @@
import sys import sys
import threading import threading
import time
import types import types
from types import SimpleNamespace from types import SimpleNamespace
@ -154,6 +155,46 @@ def test_concurrent_requests_do_not_break_each_other_when_one_client_closes(monk
assert len(factory.calls) == 2 assert len(factory.calls) == 2
def test_copilot_acp_provider_activity_prevents_false_stale_timeout(monkeypatch):
    """Provider heartbeats must keep a slow copilot-acp call from being killed.

    The fake client takes about 0.4s in total (8 heartbeats, 0.05s apart)
    while the stale threshold is patched to 0.08s, so the call only succeeds
    if staleness is measured from the last provider activity rather than from
    the start of the call.
    """

    class FakeCopilotRequestClient(FakeRequestClient):
        """Request client that reports activity while it 'works'."""

        def __init__(self, kwargs):
            self._kwargs = dict(kwargs)
            super().__init__(self._responder)

        def _responder(self, **kwargs):
            callbacks = self._kwargs
            heartbeat = callbacks.get("activity_callback")
            for i in range(8):
                if heartbeat is not None:
                    heartbeat(f"copilot-acp heartbeat {i}")
                # Each gap is shorter than the 0.08s stale threshold.
                time.sleep(0.05)
            # Reasoning is emitted before the final text delta.
            for key, payload in (
                ("reasoning_callback", "thinking"),
                ("stream_delta_callback", "done"),
            ):
                callback = callbacks.get(key)
                if callback is not None:
                    callback(payload)
            return {"ok": True}

    monkeypatch.setattr(
        "agent.copilot_acp_client.CopilotACPClient",
        lambda **kwargs: FakeCopilotRequestClient(kwargs),
    )

    agent = _build_agent()
    agent.provider = "copilot-acp"
    agent.base_url = "acp://copilot"
    agent.model = "gpt-5.4"
    agent._client_kwargs = {"api_key": "copilot-acp", "base_url": agent.base_url}
    agent.client = FakeSharedClient(lambda **kwargs: {"shared": True})
    agent.stream_delta_callback = lambda _delta: None
    agent.reasoning_callback = lambda _delta: None
    agent.status_callback = None
    agent._compute_non_stream_stale_timeout = lambda _messages: 0.08

    api_kwargs = {"model": agent.model, "messages": []}
    assert agent._interruptible_api_call(api_kwargs) == {"ok": True}
def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatch): def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatch):
chunks = iter([ chunks = iter([