From b481348fbc2d1ac65ef86fe3bdee2af1e661acfb Mon Sep 17 00:00:00 2001 From: sgaofen <135070653+sgaofen@users.noreply.github.com> Date: Thu, 23 Apr 2026 00:30:22 -0700 Subject: [PATCH] fix(agent): stream copilot ACP chat completions --- agent/copilot_acp_client.py | 52 ++++++++++++++++++- tests/agent/test_copilot_acp_client.py | 71 ++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/agent/copilot_acp_client.py b/agent/copilot_acp_client.py index 79030146f36..ce3ec2c5c40 100644 --- a/agent/copilot_acp_client.py +++ b/agent/copilot_acp_client.py @@ -249,6 +249,52 @@ def _build_openai_tool_call( ) +def _completion_to_stream_chunks(completion: SimpleNamespace) -> list[SimpleNamespace]: + """Convert a one-shot ACP response into OpenAI-style stream chunks.""" + choice = completion.choices[0] + message = choice.message + tool_call_deltas = None + if message.tool_calls: + tool_call_deltas = [] + for index, tool_call in enumerate(message.tool_calls): + tool_call_deltas.append( + SimpleNamespace( + index=index, + id=getattr(tool_call, "id", None), + type=getattr(tool_call, "type", "function"), + function=SimpleNamespace( + name=getattr(tool_call.function, "name", None), + arguments=getattr(tool_call.function, "arguments", None), + ), + ) + ) + + delta = SimpleNamespace( + role="assistant", + content=message.content or None, + tool_calls=tool_call_deltas, + reasoning_content=message.reasoning_content, + reasoning=message.reasoning, + ) + data_chunk = SimpleNamespace( + choices=[ + SimpleNamespace( + index=0, + delta=delta, + finish_reason=choice.finish_reason, + ) + ], + model=completion.model, + usage=None, + ) + usage_chunk = SimpleNamespace( + choices=[], + model=completion.model, + usage=completion.usage, + ) + return [data_chunk, usage_chunk] + + def _extract_tool_calls_from_text(text: str) -> tuple[list[ChatCompletionMessageToolCall], str]: if not isinstance(text, str) or not text.strip(): return [], "" @@ -399,6 +445,7 @@ class CopilotACPClient: timeout: float | None = None, tools: list[dict[str, Any]] | None = None, tool_choice: Any = None, + stream: bool = False, **_: Any, ) -> Any: prompt_text = _format_messages_as_prompt( @@ -445,11 +492,14 @@ class CopilotACPClient: ) finish_reason = "tool_calls" if tool_calls else "stop" choice = SimpleNamespace(message=assistant_message, finish_reason=finish_reason) - return SimpleNamespace( + completion = SimpleNamespace( choices=[choice], usage=usage, model=model or "copilot-acp", ) + if stream: + return _completion_to_stream_chunks(completion) + return completion def _run_prompt(self, prompt_text: str, *, timeout_seconds: float) -> tuple[str, str]: try: diff --git a/tests/agent/test_copilot_acp_client.py b/tests/agent/test_copilot_acp_client.py index dc209f27029..5f2d3c234fe 100644 --- a/tests/agent/test_copilot_acp_client.py +++ b/tests/agent/test_copilot_acp_client.py @@ -56,6 +56,77 @@ class CopilotACPClientSafetyTests(unittest.TestCase): self.assertEqual(dict(tool_call.function)["name"], "read_file") self.assertEqual(choice.message.content, "I'll inspect that.") + def test_stream_true_returns_iterable_text_chunks(self) -> None: + with patch.object(self.client, "_run_prompt", return_value=("Hello from ACP", "")): + stream = self.client._create_chat_completion( + model="copilot-acp", + messages=[{"role": "user", "content": "hello"}], + stream=True, + ) + + chunks = list(stream) + self.assertEqual(len(chunks), 2) + self.assertEqual(chunks[0].choices[0].delta.content, "Hello from ACP") + self.assertIsNone(chunks[0].choices[0].delta.tool_calls) + self.assertEqual(chunks[0].choices[0].finish_reason, "stop") + self.assertEqual(chunks[1].choices, []) + self.assertEqual(chunks[1].usage.total_tokens, 0) + + def test_stream_true_preserves_tool_call_deltas(self) -> None: + tool_response = ( + "" + '{"id":"call_read","type":"function",' + '"function":{"name":"read_file","arguments":"{\\"path\\":\\"README.md\\"}"}}' + "" + ) + + with patch.object(self.client, "_run_prompt", return_value=(tool_response, "")): + stream = self.client._create_chat_completion( + model="copilot-acp", + messages=[{"role": "user", "content": "read README.md"}], + stream=True, + ) + + chunks = list(stream) + delta = chunks[0].choices[0].delta + self.assertIsNone(delta.content) + self.assertEqual(chunks[0].choices[0].finish_reason, "tool_calls") + self.assertEqual(len(delta.tool_calls), 1) + tool_delta = delta.tool_calls[0] + self.assertEqual(tool_delta.index, 0) + self.assertEqual(tool_delta.id, "call_read") + self.assertEqual(tool_delta.function.name, "read_file") + self.assertEqual( + json.loads(tool_delta.function.arguments), + {"path": "README.md"}, + ) + self.assertEqual(chunks[1].choices, []) + + def test_timeout_object_is_coerced_for_streaming_requests(self) -> None: + captured: dict[str, float] = {} + + def fake_run_prompt(prompt_text: str, *, timeout_seconds: float) -> tuple[str, str]: + captured["timeout"] = timeout_seconds + return "ok", "" + + timeout = type( + "TimeoutLike", + (), + {"read": 12.0, "write": 5.0, "connect": 3.0, "pool": 1.0}, + )() + + with patch.object(self.client, "_run_prompt", side_effect=fake_run_prompt): + list( + self.client._create_chat_completion( + model="copilot-acp", + messages=[{"role": "user", "content": "hello"}], + timeout=timeout, + stream=True, + ) + ) + + self.assertEqual(captured["timeout"], 12.0) + def _dispatch(self, message: dict, *, cwd: str) -> dict: process = _FakeProcess() handled = self.client._handle_server_message(