mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix: suppress streamed copilot tool-call markup
This commit is contained in:
parent
0524a40790
commit
4a672ab6d9
2 changed files with 99 additions and 1 deletions
|
|
@ -31,6 +31,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
_TOOL_CALL_BLOCK_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
|
||||
_TOOL_CALL_JSON_RE = re.compile(r"\{\s*\"id\"\s*:\s*\"[^\"]+\"\s*,\s*\"type\"\s*:\s*\"function\"\s*,\s*\"function\"\s*:\s*\{.*?\}\s*\}", re.DOTALL)
|
||||
_TOOL_CALL_START_TAG = "<tool_call>"
|
||||
_TOOL_CALL_END_TAG = "</tool_call>"
|
||||
|
||||
|
||||
def _resolve_timeout_seconds() -> float:
|
||||
|
|
@ -313,6 +315,8 @@ class CopilotACPClient:
|
|||
self._activity_callback = activity_callback
|
||||
self._stream_delta_callback = stream_delta_callback
|
||||
self._reasoning_callback = reasoning_callback
|
||||
self._stream_filter_buffer = ""
|
||||
self._stream_filter_in_tool_call = False
|
||||
self._acp_command = acp_command or command or _resolve_command()
|
||||
self._acp_args = list(acp_args or args or _resolve_args())
|
||||
self._acp_cwd = str(Path(acp_cwd or os.getcwd()).resolve())
|
||||
|
|
@ -505,6 +509,7 @@ class CopilotACPClient:
|
|||
raise TimeoutError(f"Timed out waiting for Copilot ACP response to {method}.")
|
||||
|
||||
try:
|
||||
self._reset_stream_filters()
|
||||
_request(
|
||||
"initialize",
|
||||
{
|
||||
|
|
@ -655,8 +660,11 @@ class CopilotACPClient:
|
|||
cb = self._stream_delta_callback
|
||||
if cb is None or not text:
|
||||
return
|
||||
safe_text = self._filter_stream_text(text)
|
||||
if not safe_text:
|
||||
return
|
||||
try:
|
||||
cb(text)
|
||||
cb(safe_text)
|
||||
except Exception:
|
||||
logger.debug("Copilot ACP stream delta callback failed", exc_info=True)
|
||||
|
||||
|
|
@ -668,3 +676,53 @@ class CopilotACPClient:
|
|||
cb(text)
|
||||
except Exception:
|
||||
logger.debug("Copilot ACP reasoning callback failed", exc_info=True)
|
||||
|
||||
def _reset_stream_filters(self) -> None:
|
||||
self._stream_filter_buffer = ""
|
||||
self._stream_filter_in_tool_call = False
|
||||
|
||||
def _filter_stream_text(self, text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
buffer = self._stream_filter_buffer + text
|
||||
self._stream_filter_buffer = ""
|
||||
visible_parts: list[str] = []
|
||||
|
||||
while buffer:
|
||||
if self._stream_filter_in_tool_call:
|
||||
end_idx = buffer.find(_TOOL_CALL_END_TAG)
|
||||
if end_idx < 0:
|
||||
self._stream_filter_buffer = buffer
|
||||
return "".join(visible_parts)
|
||||
buffer = buffer[end_idx + len(_TOOL_CALL_END_TAG):]
|
||||
self._stream_filter_in_tool_call = False
|
||||
continue
|
||||
|
||||
start_idx = buffer.find(_TOOL_CALL_START_TAG)
|
||||
if start_idx >= 0:
|
||||
if start_idx > 0:
|
||||
visible_parts.append(buffer[:start_idx])
|
||||
buffer = buffer[start_idx + len(_TOOL_CALL_START_TAG):]
|
||||
self._stream_filter_in_tool_call = True
|
||||
continue
|
||||
|
||||
holdback = self._longest_tool_tag_prefix_suffix(buffer)
|
||||
if holdback:
|
||||
visible_parts.append(buffer[:-holdback])
|
||||
self._stream_filter_buffer = buffer[-holdback:]
|
||||
else:
|
||||
visible_parts.append(buffer)
|
||||
self._stream_filter_buffer = ""
|
||||
break
|
||||
|
||||
return "".join(visible_parts)
|
||||
|
||||
@staticmethod
|
||||
def _longest_tool_tag_prefix_suffix(text: str) -> int:
|
||||
for tag in (_TOOL_CALL_START_TAG, _TOOL_CALL_END_TAG):
|
||||
max_prefix = min(len(text), len(tag) - 1)
|
||||
for size in range(max_prefix, 0, -1):
|
||||
if text.endswith(tag[:size]):
|
||||
return size
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -196,6 +196,46 @@ class CopilotACPClientSafetyTests(unittest.TestCase):
|
|||
self.assertEqual(streamed, ["hello"])
|
||||
self.assertEqual(reasoned, ["thinking"])
|
||||
|
||||
def test_session_update_filters_tool_call_markup_from_stream_callbacks(self) -> None:
|
||||
streamed: list[str] = []
|
||||
|
||||
self.client = CopilotACPClient(
|
||||
acp_cwd="/tmp",
|
||||
stream_delta_callback=streamed.append,
|
||||
)
|
||||
|
||||
text_parts: list[str] = []
|
||||
process = _FakeProcess()
|
||||
chunks = [
|
||||
"Before ",
|
||||
"<tool",
|
||||
"_call>{\"id\":\"1\",\"type\":\"function\",\"function\":{\"name\":\"terminal\",\"arguments\":\"{}\"}}",
|
||||
"</tool_call>",
|
||||
" after",
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
handled = self.client._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "session/update",
|
||||
"params": {
|
||||
"update": {
|
||||
"sessionUpdate": "agent_message_chunk",
|
||||
"content": {"text": chunk},
|
||||
}
|
||||
},
|
||||
},
|
||||
process=process,
|
||||
cwd="/tmp",
|
||||
text_parts=text_parts,
|
||||
reasoning_parts=[],
|
||||
)
|
||||
self.assertTrue(handled)
|
||||
|
||||
self.assertEqual("".join(text_parts), "".join(chunks))
|
||||
self.assertEqual("".join(streamed), "Before after")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue