feat: pre-call sanitization and post-call tool guardrails (#1732)

Salvage of PR #1321 by @alireza78a (cherry-picked concept, reimplemented
against current main).

Phase 1 — Pre-call message sanitization:
  _sanitize_api_messages() now runs unconditionally before every LLM call.
  Previously gated on context_compressor being present, so sessions loaded
  from disk or running without compression could accumulate dangling
  tool_call/tool_result pairs causing API errors.

Phase 2a — Delegate task cap:
  _cap_delegate_task_calls() truncates excess delegate_task calls per turn
  to MAX_CONCURRENT_CHILDREN. The existing cap in delegate_tool.py only
  limits the task array within a single call; this catches multiple
  separate delegate_task tool_calls in one turn.

Phase 2b — Tool call deduplication:
  _deduplicate_tool_calls() drops duplicate (tool_name, arguments) pairs
  within a single turn when models stutter.

All three are static methods on AIAgent, independently testable.
29 tests covering happy paths and edge cases.
This commit is contained in:
Teknium 2026-03-17 04:24:27 -07:00 committed by GitHub
parent fb20a9e120
commit 85993fbb5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 394 additions and 7 deletions

View file

@ -1957,7 +1957,124 @@ class AIAgent:
prompt_parts.append(PLATFORM_HINTS[platform_key])
return "\n\n".join(prompt_parts)
# =========================================================================
# Pre/post-call guardrails (inspired by PR #1321 — @alireza78a)
# =========================================================================
@staticmethod
def _get_tool_call_id_static(tc) -> str:
"""Extract call ID from a tool_call entry (dict or object)."""
if isinstance(tc, dict):
return tc.get("id", "") or ""
return getattr(tc, "id", "") or ""
@staticmethod
def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Fix orphaned tool_call / tool_result pairs before every LLM call.
Runs unconditionally not gated on whether the context compressor
is present so orphans from session loading or manual message
manipulation are always caught.
"""
surviving_call_ids: set = set()
for msg in messages:
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = AIAgent._get_tool_call_id_static(tc)
if cid:
surviving_call_ids.add(cid)
result_call_ids: set = set()
for msg in messages:
if msg.get("role") == "tool":
cid = msg.get("tool_call_id")
if cid:
result_call_ids.add(cid)
# 1. Drop tool results with no matching assistant call
orphaned_results = result_call_ids - surviving_call_ids
if orphaned_results:
messages = [
m for m in messages
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
]
logger.debug(
"Pre-call sanitizer: removed %d orphaned tool result(s)",
len(orphaned_results),
)
# 2. Inject stub results for calls whose result was dropped
missing_results = surviving_call_ids - result_call_ids
if missing_results:
patched: List[Dict[str, Any]] = []
for msg in messages:
patched.append(msg)
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = AIAgent._get_tool_call_id_static(tc)
if cid in missing_results:
patched.append({
"role": "tool",
"content": "[Result unavailable — see context summary above]",
"tool_call_id": cid,
})
messages = patched
logger.debug(
"Pre-call sanitizer: added %d stub tool result(s)",
len(missing_results),
)
return messages
@staticmethod
def _cap_delegate_task_calls(tool_calls: list) -> list:
"""Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN.
The delegate_tool caps the task list inside a single call, but the
model can emit multiple separate delegate_task tool_calls in one
turn. This truncates the excess, preserving all non-delegate calls.
Returns the original list if no truncation was needed.
"""
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task")
if delegate_count <= MAX_CONCURRENT_CHILDREN:
return tool_calls
kept_delegates = 0
truncated = []
for tc in tool_calls:
if tc.function.name == "delegate_task":
if kept_delegates < MAX_CONCURRENT_CHILDREN:
truncated.append(tc)
kept_delegates += 1
else:
truncated.append(tc)
logger.warning(
"Truncated %d excess delegate_task call(s) to enforce "
"MAX_CONCURRENT_CHILDREN=%d limit",
delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN,
)
return truncated
@staticmethod
def _deduplicate_tool_calls(tool_calls: list) -> list:
"""Remove duplicate (tool_name, arguments) pairs within a single turn.
Only the first occurrence of each unique pair is kept.
Returns the original list if no duplicates were found.
"""
seen: set = set()
unique: list = []
for tc in tool_calls:
key = (tc.function.name, tc.function.arguments)
if key not in seen:
seen.add(key)
unique.append(tc)
else:
logger.warning("Removed duplicate tool call: %s", tc.function.name)
return unique if len(unique) < len(tool_calls) else tool_calls
def _repair_tool_call(self, tool_name: str) -> str | None:
"""Attempt to repair a mismatched tool name before aborting.
@ -4992,11 +5109,10 @@ class AIAgent:
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
# Safety net: strip orphaned tool results / add stubs for missing
# results before sending to the API. The compressor handles this
# during compression, but orphans can also sneak in from session
# loading or manual message manipulation.
if hasattr(self, 'context_compressor') and self.context_compressor:
api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
# results before sending to the API. Runs unconditionally — not
# gated on context_compressor — so orphans from session loading or
# manual message manipulation are always caught.
api_messages = self._sanitize_api_messages(api_messages)
# Calculate approximate request size for logging
total_chars = sum(len(str(msg)) for msg in api_messages)
@ -6026,7 +6142,15 @@ class AIAgent:
# Reset retry counter on successful JSON validation
self._invalid_json_retries = 0
# ── Post-call guardrails ──────────────────────────
assistant_message.tool_calls = self._cap_delegate_task_calls(
assistant_message.tool_calls
)
assistant_message.tool_calls = self._deduplicate_tool_calls(
assistant_message.tool_calls
)
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
# If this turn has both content AND tool_calls, capture the content

View file

@ -0,0 +1,263 @@
"""Unit tests for AIAgent pre/post-LLM-call guardrails.
Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a):
- _sanitize_api_messages() Phase 1: orphaned tool pair repair
- _cap_delegate_task_calls() Phase 2a: subagent concurrency limit
- _deduplicate_tool_calls() Phase 2b: identical call deduplication
"""
import types
from run_agent import AIAgent
from tools.delegate_tool import MAX_CONCURRENT_CHILDREN
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace:
"""Create a minimal tool_call SimpleNamespace mirroring the OpenAI SDK object."""
tc = types.SimpleNamespace()
tc.function = types.SimpleNamespace(name=name, arguments=arguments)
return tc
def tool_result(call_id: str, content: str = "ok") -> dict:
return {"role": "tool", "tool_call_id": call_id, "content": content}
def assistant_dict_call(call_id: str, name: str = "terminal") -> dict:
"""Dict-style tool_call (as stored in message history)."""
return {"id": call_id, "function": {"name": name, "arguments": "{}"}}
# ---------------------------------------------------------------------------
# Phase 1 — _sanitize_api_messages
# ---------------------------------------------------------------------------
class TestSanitizeApiMessages:
def test_orphaned_result_removed(self):
msgs = [
{"role": "assistant", "tool_calls": [assistant_dict_call("c1")]},
tool_result("c1"),
tool_result("c_ORPHAN"),
]
out = AIAgent._sanitize_api_messages(msgs)
assert len(out) == 2
assert all(m.get("tool_call_id") != "c_ORPHAN" for m in out)
def test_orphaned_call_gets_stub_result(self):
msgs = [
{"role": "assistant", "tool_calls": [assistant_dict_call("c2")]},
]
out = AIAgent._sanitize_api_messages(msgs)
assert len(out) == 2
stub = out[1]
assert stub["role"] == "tool"
assert stub["tool_call_id"] == "c2"
assert stub["content"]
def test_clean_messages_pass_through(self):
msgs = [
{"role": "user", "content": "hello"},
{"role": "assistant", "tool_calls": [assistant_dict_call("c3")]},
tool_result("c3"),
{"role": "assistant", "content": "done"},
]
out = AIAgent._sanitize_api_messages(msgs)
assert out == msgs
def test_mixed_orphaned_result_and_orphaned_call(self):
msgs = [
{"role": "assistant", "tool_calls": [
assistant_dict_call("c4"),
assistant_dict_call("c5"),
]},
tool_result("c4"),
tool_result("c_DANGLING"),
]
out = AIAgent._sanitize_api_messages(msgs)
ids = [m.get("tool_call_id") for m in out if m.get("role") == "tool"]
assert "c_DANGLING" not in ids
assert "c4" in ids
assert "c5" in ids
def test_empty_list_is_safe(self):
assert AIAgent._sanitize_api_messages([]) == []
def test_no_tool_messages(self):
msgs = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
]
out = AIAgent._sanitize_api_messages(msgs)
assert out == msgs
def test_sdk_object_tool_calls(self):
tc_obj = types.SimpleNamespace(id="c6", function=types.SimpleNamespace(
name="terminal", arguments="{}"
))
msgs = [
{"role": "assistant", "tool_calls": [tc_obj]},
]
out = AIAgent._sanitize_api_messages(msgs)
assert len(out) == 2
assert out[1]["tool_call_id"] == "c6"
# ---------------------------------------------------------------------------
# Phase 2a — _cap_delegate_task_calls
# ---------------------------------------------------------------------------
class TestCapDelegateTaskCalls:
def test_excess_delegates_truncated(self):
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
out = AIAgent._cap_delegate_task_calls(tcs)
delegate_count = sum(1 for tc in out if tc.function.name == "delegate_task")
assert delegate_count == MAX_CONCURRENT_CHILDREN
def test_non_delegate_calls_preserved(self):
tcs = (
[make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)]
+ [make_tc("terminal"), make_tc("web_search")]
)
out = AIAgent._cap_delegate_task_calls(tcs)
names = [tc.function.name for tc in out]
assert "terminal" in names
assert "web_search" in names
def test_at_limit_passes_through(self):
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)]
out = AIAgent._cap_delegate_task_calls(tcs)
assert out is tcs
def test_below_limit_passes_through(self):
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)]
out = AIAgent._cap_delegate_task_calls(tcs)
assert out is tcs
def test_no_delegate_calls_unchanged(self):
tcs = [make_tc("terminal"), make_tc("web_search")]
out = AIAgent._cap_delegate_task_calls(tcs)
assert out is tcs
def test_empty_list_safe(self):
assert AIAgent._cap_delegate_task_calls([]) == []
def test_original_list_not_mutated(self):
tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)]
original_len = len(tcs)
AIAgent._cap_delegate_task_calls(tcs)
assert len(tcs) == original_len
def test_interleaved_order_preserved(self):
delegates = [make_tc("delegate_task", f'{{"task":"{i}"}}')
for i in range(MAX_CONCURRENT_CHILDREN + 1)]
t1 = make_tc("terminal", '{"cmd":"ls"}')
w1 = make_tc("web_search", '{"q":"x"}')
tcs = [delegates[0], t1, delegates[1], w1] + delegates[2:]
out = AIAgent._cap_delegate_task_calls(tcs)
expected = [delegates[0], t1, delegates[1], w1] + delegates[2:MAX_CONCURRENT_CHILDREN]
assert len(out) == len(expected)
for i, (actual, exp) in enumerate(zip(out, expected)):
assert actual is exp, f"mismatch at index {i}"
# ---------------------------------------------------------------------------
# Phase 2b — _deduplicate_tool_calls
# ---------------------------------------------------------------------------
class TestDeduplicateToolCalls:
def test_duplicate_pair_deduplicated(self):
tcs = [
make_tc("web_search", '{"query":"foo"}'),
make_tc("web_search", '{"query":"foo"}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert len(out) == 1
def test_multiple_duplicates(self):
tcs = [
make_tc("web_search", '{"q":"a"}'),
make_tc("web_search", '{"q":"a"}'),
make_tc("terminal", '{"cmd":"ls"}'),
make_tc("terminal", '{"cmd":"ls"}'),
make_tc("terminal", '{"cmd":"pwd"}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert len(out) == 3
def test_same_tool_different_args_kept(self):
tcs = [
make_tc("terminal", '{"cmd":"ls"}'),
make_tc("terminal", '{"cmd":"pwd"}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert out is tcs
def test_different_tools_same_args_kept(self):
tcs = [
make_tc("tool_a", '{"x":1}'),
make_tc("tool_b", '{"x":1}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert out is tcs
def test_clean_list_unchanged(self):
tcs = [
make_tc("web_search", '{"q":"x"}'),
make_tc("terminal", '{"cmd":"ls"}'),
]
out = AIAgent._deduplicate_tool_calls(tcs)
assert out is tcs
def test_empty_list_safe(self):
assert AIAgent._deduplicate_tool_calls([]) == []
def test_first_occurrence_kept(self):
tc1 = make_tc("terminal", '{"cmd":"ls"}')
tc2 = make_tc("terminal", '{"cmd":"ls"}')
out = AIAgent._deduplicate_tool_calls([tc1, tc2])
assert len(out) == 1
assert out[0] is tc1
def test_original_list_not_mutated(self):
tcs = [
make_tc("web_search", '{"q":"dup"}'),
make_tc("web_search", '{"q":"dup"}'),
]
original_len = len(tcs)
AIAgent._deduplicate_tool_calls(tcs)
assert len(tcs) == original_len
# ---------------------------------------------------------------------------
# _get_tool_call_id_static
# ---------------------------------------------------------------------------
class TestGetToolCallIdStatic:
def test_dict_with_valid_id(self):
assert AIAgent._get_tool_call_id_static({"id": "call_123"}) == "call_123"
def test_dict_with_none_id(self):
assert AIAgent._get_tool_call_id_static({"id": None}) == ""
def test_dict_without_id_key(self):
assert AIAgent._get_tool_call_id_static({"function": {}}) == ""
def test_object_with_valid_id(self):
tc = types.SimpleNamespace(id="call_456")
assert AIAgent._get_tool_call_id_static(tc) == "call_456"
def test_object_with_none_id(self):
tc = types.SimpleNamespace(id=None)
assert AIAgent._get_tool_call_id_static(tc) == ""
def test_object_without_id_attr(self):
tc = types.SimpleNamespace()
assert AIAgent._get_tool_call_id_static(tc) == ""