fix(agent): detect truncated streaming tool calls before execution

When a streaming response is cut mid-tool-call (connection drop, timeout),
the accumulated function.arguments is invalid JSON. The mock response
builder defaulted finish_reason to 'stop', so the agent loop treated it
as a valid completed turn and tried to execute tools with broken args.

Fix: validate tool call arguments with json.loads() during mock response
reconstruction. If any are invalid JSON, override finish_reason to
'length'. In the main loop's length handler, if tool calls are present,
refuse to execute and return partial=True with a clear error instead of
silently failing or wasting retries.

Also fixes _thinking_exhausted to not short-circuit when tool calls are
present — truncated tool calls are not thinking exhaustion.

Original cherry-picked from PR #6776 by AIandI0x1.
Closes #6638.
This commit is contained in:
AIandI0x1 2026-04-09 16:18:14 -07:00 committed by Teknium
parent 3b554bf839
commit 2d0d05a337
2 changed files with 125 additions and 5 deletions

View file

@ -4584,20 +4584,31 @@ class AIAgent:
# Build mock response matching non-streaming shape
full_content = "".join(content_parts) or None
mock_tool_calls = None
has_truncated_tool_args = False
if tool_calls_acc:
mock_tool_calls = []
for idx in sorted(tool_calls_acc):
tc = tool_calls_acc[idx]
arguments = tc["function"]["arguments"]
if arguments and arguments.strip():
try:
json.loads(arguments)
except json.JSONDecodeError:
has_truncated_tool_args = True
mock_tool_calls.append(SimpleNamespace(
id=tc["id"],
type=tc["type"],
extra_content=tc.get("extra_content"),
function=SimpleNamespace(
name=tc["function"]["name"],
arguments=tc["function"]["arguments"],
arguments=arguments,
),
))
effective_finish_reason = finish_reason or "stop"
if has_truncated_tool_args:
effective_finish_reason = "length"
full_reasoning = "".join(reasoning_parts) or None
mock_message = SimpleNamespace(
role=role,
@ -4608,7 +4619,7 @@ class AIAgent:
mock_choice = SimpleNamespace(
index=0,
message=mock_message,
finish_reason=finish_reason or "stop",
finish_reason=effective_finish_reason,
)
return SimpleNamespace(
id="stream-" + str(uuid.uuid4()),
@ -7319,6 +7330,7 @@ class AIAgent:
interrupted = False
codex_ack_continuations = 0
length_continue_retries = 0
truncated_tool_call_retries = 0
truncated_response_prefix = ""
compression_attempts = 0
_turn_exit_reason = "unknown" # Diagnostic: why the loop ended
@ -7787,9 +7799,11 @@ class AIAgent:
# retries are pointless. Detect this early and give a
# targeted error instead of wasting 3 API calls.
_trunc_content = None
_trunc_has_tool_calls = False
if self.api_mode == "chat_completions":
_trunc_msg = response.choices[0].message if (hasattr(response, "choices") and response.choices) else None
_trunc_content = getattr(_trunc_msg, "content", None) if _trunc_msg else None
_trunc_has_tool_calls = bool(getattr(_trunc_msg, "tool_calls", None)) if _trunc_msg else False
elif self.api_mode == "anthropic_messages":
# Anthropic response.content is a list of blocks
_text_parts = []
@ -7799,9 +7813,11 @@ class AIAgent:
_trunc_content = "\n".join(_text_parts) if _text_parts else None
_thinking_exhausted = (
_trunc_content is not None
and not self._has_content_after_think_block(_trunc_content)
) or _trunc_content is None
not _trunc_has_tool_calls and (
(_trunc_content is not None and not self._has_content_after_think_block(_trunc_content))
or _trunc_content is None
)
)
if _thinking_exhausted:
_exhaust_error = (
@ -7877,6 +7893,34 @@ class AIAgent:
"error": "Response remained truncated after 3 continuation attempts",
}
if self.api_mode == "chat_completions":
assistant_message = response.choices[0].message
if assistant_message.tool_calls:
if truncated_tool_call_retries < 1:
truncated_tool_call_retries += 1
self._vprint(
f"{self.log_prefix}⚠️ Truncated tool call detected — retrying API call...",
force=True,
)
# Don't append the broken response to messages;
# just re-run the same API call from the current
# message state, giving the model another chance.
continue
self._vprint(
f"{self.log_prefix}⚠️ Truncated tool call response detected again — refusing to execute incomplete tool arguments.",
force=True,
)
self._cleanup_task_resources(effective_task_id)
self._persist_session(messages, conversation_history)
return {
"final_response": None,
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"partial": True,
"error": "Response truncated due to output length limit",
}
# If we have prior messages, roll back to last complete state
if len(messages) > 1:
self._vprint(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn")

View file

@ -1949,6 +1949,68 @@ class TestRunConversation:
assert result["final_response"] is not None
assert "Thinking Budget Exhausted" in result["final_response"]
def test_length_with_tool_calls_returns_partial_without_executing_tools(self, agent):
    """A length-truncated response carrying tool calls must come back as
    partial=True with a clear error — and the broken tool arguments must
    never reach the tool executor."""
    self._setup_agent(agent)
    # Tool call whose arguments were cut mid-stream: invalid JSON.
    truncated_call = _mock_tool_call(
        name="write_file",
        arguments='{"path":"report.md","content":"partial',
        call_id="c1",
    )
    truncated_response = _mock_response(
        content="", finish_reason="length", tool_calls=[truncated_call],
    )
    agent.client.chat.completions.create.return_value = truncated_response
    with (
        patch("run_agent.handle_function_call") as handle_fc,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        result = agent.run_conversation("write the report")
    # Turn ends as a partial failure, not a completed run.
    assert result["completed"] is False
    assert result["partial"] is True
    assert "truncated due to output length limit" in result["error"]
    # The incomplete arguments were never executed.
    handle_fc.assert_not_called()
def test_truncated_tool_call_retries_once_before_refusing(self, agent):
    """When tool call args are truncated, the agent retries the API call
    once. If the retry succeeds (valid JSON args), tool execution proceeds."""
    self._setup_agent(agent)
    agent.valid_tool_names.add("write_file")
    # Call 1: truncated args (invalid JSON) → agent should retry once.
    first_response = _mock_response(
        content="",
        finish_reason="length",
        tool_calls=[
            _mock_tool_call(
                name="write_file",
                arguments='{"path":"report.md","content":"partial',
                call_id="c1",
            ),
        ],
    )
    # Call 2 (the retry): well-formed args → tool should execute.
    second_response = _mock_response(
        content="",
        finish_reason="stop",
        tool_calls=[
            _mock_tool_call(
                name="write_file",
                arguments='{"path":"report.md","content":"full content"}',
                call_id="c2",
            ),
        ],
    )
    # Call 3: closing text turn.
    third_response = _mock_response(content="Done!", finish_reason="stop")
    with (
        patch("run_agent.handle_function_call", return_value='{"success":true}') as handle_fc,
        patch.object(agent, "_persist_session"),
        patch.object(agent, "_save_trajectory"),
        patch.object(agent, "_cleanup_task_resources"),
    ):
        agent.client.chat.completions.create.side_effect = [
            first_response, second_response, third_response,
        ]
        result = agent.run_conversation("write the report")
    # The tool ran exactly once — on the successful retry only.
    handle_fc.assert_called_once()
    assert result["final_response"] == "Done!"
class TestRetryExhaustion:
"""Regression: retry_count > max_retries was dead code (off-by-one).
@ -3082,6 +3144,20 @@ class TestStreamingApiCall:
assert tc[0].function.name == "search"
assert tc[1].function.name == "read"
def test_truncated_tool_call_args_upgrade_finish_reason_to_length(self, agent):
    """A stream that dies mid-tool-call leaves invalid JSON in the
    accumulated arguments; the rebuilt mock response must surface that by
    reporting finish_reason == 'length' instead of defaulting to 'stop'."""
    # Single chunk whose tool-call delta carries cut-off JSON arguments.
    delta = _make_tc_delta(0, "call_1", "write_file", '{"path":"x.txt","content":"hel')
    agent.client.chat.completions.create.return_value = iter(
        [_make_chunk(tool_calls=[delta])]
    )
    response = agent._interruptible_streaming_api_call({"messages": []})
    calls = response.choices[0].message.tool_calls
    assert len(calls) == 1
    # Raw truncated arguments are preserved verbatim on the mock message...
    assert calls[0].function.name == "write_file"
    assert calls[0].function.arguments == '{"path":"x.txt","content":"hel'
    # ...but the finish reason is upgraded so the loop treats it as truncated.
    assert response.choices[0].finish_reason == "length"
def test_ollama_reused_index_separate_tool_calls(self, agent):
"""Ollama sends every tool call at index 0 with different ids.