fix(agent): stream copilot ACP chat completions

This commit is contained in:
sgaofen 2026-04-23 00:30:22 -07:00 committed by Teknium
parent 0106082d1f
commit b481348fbc
2 changed files with 122 additions and 1 deletions

View file

@ -249,6 +249,52 @@ def _build_openai_tool_call(
)
def _completion_to_stream_chunks(completion: SimpleNamespace) -> list[SimpleNamespace]:
    """Convert a one-shot ACP response into OpenAI-style stream chunks.

    The whole reply is packed into a single data chunk (role, content, tool
    calls, reasoning fields and the finish reason), followed by a chunk with
    no choices that carries the usage totals.

    Args:
        completion: Completion-shaped object exposing ``choices``, ``model``
            and ``usage``. Only the first choice is converted.

    Returns:
        A two-element list: ``[data_chunk, usage_chunk]``.

    Raises:
        IndexError: If ``completion.choices`` is empty.
    """
    choice = completion.choices[0]
    message = choice.message

    # Optional message fields are read defensively so that a message object
    # lacking them (e.g. no ``reasoning_content``) still converts cleanly
    # instead of raising AttributeError.
    tool_calls = getattr(message, "tool_calls", None)

    # Stay ``None`` (not an empty list) when the message has no tool calls.
    tool_call_deltas = None
    if tool_calls:
        tool_call_deltas = [
            SimpleNamespace(
                index=index,
                id=getattr(tool_call, "id", None),
                type=getattr(tool_call, "type", "function"),
                function=SimpleNamespace(
                    name=getattr(tool_call.function, "name", None),
                    arguments=getattr(tool_call.function, "arguments", None),
                ),
            )
            for index, tool_call in enumerate(tool_calls)
        ]

    delta = SimpleNamespace(
        role="assistant",
        # Empty content is normalised to ``None``.
        content=getattr(message, "content", None) or None,
        tool_calls=tool_call_deltas,
        reasoning_content=getattr(message, "reasoning_content", None),
        reasoning=getattr(message, "reasoning", None),
    )
    data_chunk = SimpleNamespace(
        choices=[
            SimpleNamespace(
                index=0,
                delta=delta,
                finish_reason=choice.finish_reason,
            )
        ],
        model=completion.model,
        usage=None,
    )
    # Usage travels on its own trailing chunk with an empty ``choices`` list.
    usage_chunk = SimpleNamespace(
        choices=[],
        model=completion.model,
        usage=completion.usage,
    )
    return [data_chunk, usage_chunk]
def _extract_tool_calls_from_text(text: str) -> tuple[list[ChatCompletionMessageToolCall], str]:
if not isinstance(text, str) or not text.strip():
return [], ""
@ -399,6 +445,7 @@ class CopilotACPClient:
timeout: float | None = None,
tools: list[dict[str, Any]] | None = None,
tool_choice: Any = None,
stream: bool = False,
**_: Any,
) -> Any:
prompt_text = _format_messages_as_prompt(
@ -445,11 +492,14 @@ class CopilotACPClient:
)
finish_reason = "tool_calls" if tool_calls else "stop"
choice = SimpleNamespace(message=assistant_message, finish_reason=finish_reason)
return SimpleNamespace(
completion = SimpleNamespace(
choices=[choice],
usage=usage,
model=model or "copilot-acp",
)
if stream:
return _completion_to_stream_chunks(completion)
return completion
def _run_prompt(self, prompt_text: str, *, timeout_seconds: float) -> tuple[str, str]:
try:

View file

@ -56,6 +56,77 @@ class CopilotACPClientSafetyTests(unittest.TestCase):
self.assertEqual(dict(tool_call.function)["name"], "read_file")
self.assertEqual(choice.message.content, "I'll inspect that.")
def test_stream_true_returns_iterable_text_chunks(self) -> None:
    """A streamed plain-text reply yields one data chunk plus a usage chunk."""
    canned_reply = ("Hello from ACP", "")
    request = {
        "model": "copilot-acp",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    }
    with patch.object(self.client, "_run_prompt", return_value=canned_reply):
        chunks = list(self.client._create_chat_completion(**request))

    # The length check runs first so the indexing below cannot mask a
    # wrong chunk count.
    self.assertEqual(len(chunks), 2)
    first_choice = chunks[0].choices[0]
    self.assertEqual(first_choice.delta.content, "Hello from ACP")
    self.assertIsNone(first_choice.delta.tool_calls)
    self.assertEqual(first_choice.finish_reason, "stop")

    usage_chunk = chunks[1]
    self.assertEqual(usage_chunk.choices, [])
    self.assertEqual(usage_chunk.usage.total_tokens, 0)
def test_stream_true_preserves_tool_call_deltas(self) -> None:
    """A streamed tool-call reply exposes the call as an indexed delta with no text."""
    raw_reply = "".join(
        (
            "<tool_call>",
            '{"id":"call_read","type":"function",',
            '"function":{"name":"read_file","arguments":"{\\"path\\":\\"README.md\\"}"}}',
            "</tool_call>",
        )
    )
    with patch.object(self.client, "_run_prompt", return_value=(raw_reply, "")):
        chunks = list(
            self.client._create_chat_completion(
                model="copilot-acp",
                messages=[{"role": "user", "content": "read README.md"}],
                stream=True,
            )
        )

    first_choice = chunks[0].choices[0]
    streamed = first_choice.delta
    self.assertIsNone(streamed.content)
    self.assertEqual(first_choice.finish_reason, "tool_calls")
    self.assertEqual(len(streamed.tool_calls), 1)

    call_delta = streamed.tool_calls[0]
    self.assertEqual(call_delta.index, 0)
    self.assertEqual(call_delta.id, "call_read")
    self.assertEqual(call_delta.function.name, "read_file")
    # Arguments arrive as a JSON string; compare the decoded payload.
    parsed_arguments = json.loads(call_delta.function.arguments)
    self.assertEqual(parsed_arguments, {"path": "README.md"})

    self.assertEqual(chunks[1].choices, [])
def test_timeout_object_is_coerced_for_streaming_requests(self) -> None:
    """A timeout object with per-phase attributes reaches ``_run_prompt`` as its ``read`` value."""
    recorded: dict[str, float] = {}

    def record_timeout(prompt_text: str, *, timeout_seconds: float) -> tuple[str, str]:
        # Capture what the client actually forwards, then answer normally.
        recorded["timeout"] = timeout_seconds
        return "ok", ""

    phase_limits = {"read": 12.0, "write": 5.0, "connect": 3.0, "pool": 1.0}
    timeout_cls = type("TimeoutLike", (), phase_limits)
    timeout = timeout_cls()

    with patch.object(self.client, "_run_prompt", side_effect=record_timeout):
        stream = self.client._create_chat_completion(
            model="copilot-acp",
            messages=[{"role": "user", "content": "hello"}],
            timeout=timeout,
            stream=True,
        )
        # Drain the stream while the patch is still active.
        list(stream)

    self.assertEqual(recorded["timeout"], 12.0)
def _dispatch(self, message: dict, *, cwd: str) -> dict:
process = _FakeProcess()
handled = self.client._handle_server_message(