diff --git a/agent/copilot_acp_client.py b/agent/copilot_acp_client.py index 80a1f842bb..c2b3f99f87 100644 --- a/agent/copilot_acp_client.py +++ b/agent/copilot_acp_client.py @@ -31,6 +31,8 @@ logger = logging.getLogger(__name__) _TOOL_CALL_BLOCK_RE = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) _TOOL_CALL_JSON_RE = re.compile(r"\{\s*\"id\"\s*:\s*\"[^\"]+\"\s*,\s*\"type\"\s*:\s*\"function\"\s*,\s*\"function\"\s*:\s*\{.*?\}\s*\}", re.DOTALL) +_TOOL_CALL_START_TAG = "" +_TOOL_CALL_END_TAG = "" def _resolve_timeout_seconds() -> float: @@ -313,6 +315,8 @@ class CopilotACPClient: self._activity_callback = activity_callback self._stream_delta_callback = stream_delta_callback self._reasoning_callback = reasoning_callback + self._stream_filter_buffer = "" + self._stream_filter_in_tool_call = False self._acp_command = acp_command or command or _resolve_command() self._acp_args = list(acp_args or args or _resolve_args()) self._acp_cwd = str(Path(acp_cwd or os.getcwd()).resolve()) @@ -505,6 +509,7 @@ class CopilotACPClient: raise TimeoutError(f"Timed out waiting for Copilot ACP response to {method}.") try: + self._reset_stream_filters() _request( "initialize", { @@ -655,8 +660,11 @@ class CopilotACPClient: cb = self._stream_delta_callback if cb is None or not text: return + safe_text = self._filter_stream_text(text) + if not safe_text: + return try: - cb(text) + cb(safe_text) except Exception: logger.debug("Copilot ACP stream delta callback failed", exc_info=True) @@ -668,3 +676,53 @@ class CopilotACPClient: cb(text) except Exception: logger.debug("Copilot ACP reasoning callback failed", exc_info=True) + + def _reset_stream_filters(self) -> None: + self._stream_filter_buffer = "" + self._stream_filter_in_tool_call = False + + def _filter_stream_text(self, text: str) -> str: + if not text: + return "" + + buffer = self._stream_filter_buffer + text + self._stream_filter_buffer = "" + visible_parts: list[str] = [] + + while buffer: + if self._stream_filter_in_tool_call: + end_idx = buffer.find(_TOOL_CALL_END_TAG) + if end_idx < 0: + self._stream_filter_buffer = buffer + return "".join(visible_parts) + buffer = buffer[end_idx + len(_TOOL_CALL_END_TAG):] + self._stream_filter_in_tool_call = False + continue + + start_idx = buffer.find(_TOOL_CALL_START_TAG) + if start_idx >= 0: + if start_idx > 0: + visible_parts.append(buffer[:start_idx]) + buffer = buffer[start_idx + len(_TOOL_CALL_START_TAG):] + self._stream_filter_in_tool_call = True + continue + + holdback = self._longest_tool_tag_prefix_suffix(buffer) + if holdback: + visible_parts.append(buffer[:-holdback]) + self._stream_filter_buffer = buffer[-holdback:] + else: + visible_parts.append(buffer) + self._stream_filter_buffer = "" + break + + return "".join(visible_parts) + + @staticmethod + def _longest_tool_tag_prefix_suffix(text: str) -> int: + for tag in (_TOOL_CALL_START_TAG, _TOOL_CALL_END_TAG): + max_prefix = min(len(text), len(tag) - 1) + for size in range(max_prefix, 0, -1): + if text.endswith(tag[:size]): + return size + return 0 diff --git a/tests/agent/test_copilot_acp_client.py b/tests/agent/test_copilot_acp_client.py index 3e5fc77758..51f4d27aea 100644 --- a/tests/agent/test_copilot_acp_client.py +++ b/tests/agent/test_copilot_acp_client.py @@ -196,6 +196,46 @@ class CopilotACPClientSafetyTests(unittest.TestCase): self.assertEqual(streamed, ["hello"]) self.assertEqual(reasoned, ["thinking"]) + def test_session_update_filters_tool_call_markup_from_stream_callbacks(self) -> None: + streamed: list[str] = [] + + self.client = CopilotACPClient( + acp_cwd="/tmp", + stream_delta_callback=streamed.append, + ) + + text_parts: list[str] = [] + process = _FakeProcess() + chunks = [ + "Before ", + "{\"id\":\"1\",\"type\":\"function\",\"function\":{\"name\":\"terminal\",\"arguments\":\"{}\"}}", + "", + " after", + ] + + for chunk in chunks: + handled = self.client._handle_server_message( + { + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "update": { + "sessionUpdate": "agent_message_chunk", + "content": {"text": chunk}, + } + }, + }, + process=process, + cwd="/tmp", + text_parts=text_parts, + reasoning_parts=[], + ) + self.assertTrue(handled) + + self.assertEqual("".join(text_parts), "".join(chunks)) + self.assertEqual("".join(streamed), "Before after") + if __name__ == "__main__": unittest.main()