fix: suppress streamed copilot tool-call markup

This commit is contained in:
David Zhang 2026-04-24 18:09:04 +07:00
parent 0524a40790
commit 4a672ab6d9
2 changed files with 99 additions and 1 deletions

View file

@ -31,6 +31,8 @@ logger = logging.getLogger(__name__)
_TOOL_CALL_BLOCK_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
_TOOL_CALL_JSON_RE = re.compile(r"\{\s*\"id\"\s*:\s*\"[^\"]+\"\s*,\s*\"type\"\s*:\s*\"function\"\s*,\s*\"function\"\s*:\s*\{.*?\}\s*\}", re.DOTALL)
_TOOL_CALL_START_TAG = "<tool_call>"
_TOOL_CALL_END_TAG = "</tool_call>"
def _resolve_timeout_seconds() -> float:
@ -313,6 +315,8 @@ class CopilotACPClient:
self._activity_callback = activity_callback
self._stream_delta_callback = stream_delta_callback
self._reasoning_callback = reasoning_callback
self._stream_filter_buffer = ""
self._stream_filter_in_tool_call = False
self._acp_command = acp_command or command or _resolve_command()
self._acp_args = list(acp_args or args or _resolve_args())
self._acp_cwd = str(Path(acp_cwd or os.getcwd()).resolve())
@ -505,6 +509,7 @@ class CopilotACPClient:
raise TimeoutError(f"Timed out waiting for Copilot ACP response to {method}.")
try:
self._reset_stream_filters()
_request(
"initialize",
{
@ -655,8 +660,11 @@ class CopilotACPClient:
cb = self._stream_delta_callback
if cb is None or not text:
return
safe_text = self._filter_stream_text(text)
if not safe_text:
return
try:
cb(text)
cb(safe_text)
except Exception:
logger.debug("Copilot ACP stream delta callback failed", exc_info=True)
@ -668,3 +676,53 @@ class CopilotACPClient:
cb(text)
except Exception:
logger.debug("Copilot ACP reasoning callback failed", exc_info=True)
def _reset_stream_filters(self) -> None:
self._stream_filter_buffer = ""
self._stream_filter_in_tool_call = False
def _filter_stream_text(self, text: str) -> str:
if not text:
return ""
buffer = self._stream_filter_buffer + text
self._stream_filter_buffer = ""
visible_parts: list[str] = []
while buffer:
if self._stream_filter_in_tool_call:
end_idx = buffer.find(_TOOL_CALL_END_TAG)
if end_idx < 0:
self._stream_filter_buffer = buffer
return "".join(visible_parts)
buffer = buffer[end_idx + len(_TOOL_CALL_END_TAG):]
self._stream_filter_in_tool_call = False
continue
start_idx = buffer.find(_TOOL_CALL_START_TAG)
if start_idx >= 0:
if start_idx > 0:
visible_parts.append(buffer[:start_idx])
buffer = buffer[start_idx + len(_TOOL_CALL_START_TAG):]
self._stream_filter_in_tool_call = True
continue
holdback = self._longest_tool_tag_prefix_suffix(buffer)
if holdback:
visible_parts.append(buffer[:-holdback])
self._stream_filter_buffer = buffer[-holdback:]
else:
visible_parts.append(buffer)
self._stream_filter_buffer = ""
break
return "".join(visible_parts)
@staticmethod
def _longest_tool_tag_prefix_suffix(text: str) -> int:
for tag in (_TOOL_CALL_START_TAG, _TOOL_CALL_END_TAG):
max_prefix = min(len(text), len(tag) - 1)
for size in range(max_prefix, 0, -1):
if text.endswith(tag[:size]):
return size
return 0

View file

@ -196,6 +196,46 @@ class CopilotACPClientSafetyTests(unittest.TestCase):
self.assertEqual(streamed, ["hello"])
self.assertEqual(reasoned, ["thinking"])
def test_session_update_filters_tool_call_markup_from_stream_callbacks(self) -> None:
streamed: list[str] = []
self.client = CopilotACPClient(
acp_cwd="/tmp",
stream_delta_callback=streamed.append,
)
text_parts: list[str] = []
process = _FakeProcess()
chunks = [
"Before ",
"<tool",
"_call>{\"id\":\"1\",\"type\":\"function\",\"function\":{\"name\":\"terminal\",\"arguments\":\"{}\"}}",
"</tool_call>",
" after",
]
for chunk in chunks:
handled = self.client._handle_server_message(
{
"jsonrpc": "2.0",
"method": "session/update",
"params": {
"update": {
"sessionUpdate": "agent_message_chunk",
"content": {"text": chunk},
}
},
},
process=process,
cwd="/tmp",
text_parts=text_parts,
reasoning_parts=[],
)
self.assertTrue(handled)
self.assertEqual("".join(text_parts), "".join(chunks))
self.assertEqual("".join(streamed), "Before after")
if __name__ == "__main__":
unittest.main()