fix(agent): persist tool calls before turn-end flush

Co-authored-by: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
konsisumer 2026-06-20 10:10:30 +02:00
parent aaa2e2cb88
commit 190b01c553
3 changed files with 314 additions and 7 deletions

View file

@ -4050,6 +4050,19 @@ def run_conversation(
messages.append(assistant_msg)
agent._emit_interim_assistant_message(assistant_msg)
try:
# Persist the assistant tool-call turn before any tool
# side effects run. If a destructive tool restarts or
# terminates Hermes mid-turn, resume logic still sees the
# exact tool-call block that already executed.
agent._flush_messages_to_session_db(messages, conversation_history)
except Exception as exc:
logger.warning(
"Incremental tool-call persistence failed before execution "
"(session=%s): %s",
agent.session_id or "none",
exc,
)
# Close any open streaming display (response box, reasoning
# box) before tool execution begins. Intermediate turns may

View file

@ -69,6 +69,25 @@ def _budget_for_agent(agent) -> BudgetConfig:
_MAX_TOOL_WORKERS = 8
def _flush_session_db_after_tool_progress(
agent,
messages: list,
*,
stage: str,
) -> None:
"""Best-effort incremental SessionDB flush for tool-call progress.
Tool execution can perform side effects that terminate or restart the
current Hermes process before the normal turn-end persistence path runs.
Flush the already-appended assistant/tool messages immediately so the
transcript survives destructive-but-valid tool calls.
"""
try:
agent._flush_messages_to_session_db(messages)
except Exception as exc:
logger.warning("Incremental tool-call persistence failed after %s: %s", stage, exc)
def _ra():
"""Lazy reference to ``run_agent`` so patches like ``run_agent._set_interrupt`` work."""
import run_agent
@ -279,6 +298,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
f"[Tool execution cancelled — {tc.function.name} was skipped due to user interrupt]",
tc.id,
))
_flush_session_db_after_tool_progress(
agent,
messages,
stage=f"cancelled tool result {tc.function.name}",
)
return
# ── Parse args + pre-execution bookkeeping ───────────────────────
@ -768,6 +792,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
# String results pass through unchanged.
_tool_content = agent._tool_result_content_for_active_model(name, function_result)
messages.append(make_tool_result_message(name, _tool_content, tc.id))
_flush_session_db_after_tool_progress(
agent,
messages,
stage=f"tool result {name}",
)
# ── Per-tool /steer drain ───────────────────────────────────
# Same as the sequential path: drain between each collected
@ -803,13 +832,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
agent._vprint(f"{agent.log_prefix}⚡ Interrupt: skipping {len(remaining_calls)} tool call(s)", force=True)
for skipped_tc in remaining_calls:
skipped_name = skipped_tc.function.name
skip_msg = {
"role": "tool",
"name": skipped_name,
"content": f"[Tool execution cancelled — {skipped_name} was skipped due to user interrupt]",
"tool_call_id": skipped_tc.id,
}
messages.append(skip_msg)
messages.append(make_tool_result_message(
skipped_name,
f"[Tool execution cancelled — {skipped_name} was skipped due to user interrupt]",
skipped_tc.id,
))
_flush_session_db_after_tool_progress(
agent,
messages,
stage=f"cancelled tool result {skipped_name}",
)
break
function_name = tool_call.function.name
@ -1402,6 +1434,11 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
# (see parallel path for rationale). String results pass through.
_tool_content = agent._tool_result_content_for_active_model(function_name, function_result)
messages.append(make_tool_result_message(function_name, _tool_content, tool_call.id))
_flush_session_db_after_tool_progress(
agent,
messages,
stage=f"tool result {function_name}",
)
# ── Per-tool /steer drain ───────────────────────────────────
# Drain pending steer BETWEEN individual tool calls so the
@ -1428,6 +1465,11 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
f"[Tool execution skipped — {skipped_name} was not started. User sent a new message]",
skipped_tc.id,
))
_flush_session_db_after_tool_progress(
agent,
messages,
stage=f"skipped tool result {skipped_name}",
)
break
if agent.tool_delay > 0 and i < len(assistant_message.tool_calls):

View file

@ -0,0 +1,252 @@
"""Behavior contracts for incremental tool-call persistence (#49045).
A destructive or process-terminating tool that runs during tool execution
must not lose the just-executed assistant(tool_calls) block or the tool
results that were produced before it fired. These tests pin the contract:
1. run_conversation flushes the assistant tool-call turn to the session
DB BEFORE handing control to _execute_tool_calls (so a tool that
restarts/kills the process never orphans the tool-call block).
2. The SEQUENTIAL tool path flushes each tool result to the session DB
immediately after appending it BEFORE the next tool dispatches.
3. The CONCURRENT tool path flushes each tool result in append order.
These exercise the REAL production dispatch surfaces:
* sequential -> ``run_agent.handle_function_call`` (tool_executor ~1256/1298)
* concurrent -> ``agent._invoke_tool`` (tool_executor ~539)
Mocking the genuine dispatch surface keeps the tests deterministic (no real
``web_search`` / network) AND mutation-survivable: the ordering assertions
read snapshots captured at flush time, so removing any production flush call
makes the corresponding assertion fail.
"""
import copy
from types import SimpleNamespace
from pathlib import Path
import tempfile
from unittest.mock import MagicMock, patch
from agent.tool_dispatch_helpers import make_tool_result_message
from run_agent import AIAgent
def _make_tool_defs(*names: str) -> list:
return [
{
"type": "function",
"function": {
"name": name,
"description": f"{name} tool",
"parameters": {"type": "object", "properties": {}},
},
}
for name in names
]
def _make_agent():
hermes_home = Path(tempfile.mkdtemp(prefix="hermes-test-home-"))
(hermes_home / "logs").mkdir(parents=True, exist_ok=True)
with (
patch(
"run_agent.get_tool_definitions",
return_value=_make_tool_defs("web_search"),
),
patch("run_agent.check_toolset_requirements", return_value={}),
patch("run_agent.OpenAI"),
patch("run_agent._hermes_home", hermes_home),
patch("agent.model_metadata.fetch_model_metadata", return_value={}),
):
agent = AIAgent(
api_key="test-key",
base_url="https://openrouter.ai/api/v1",
quiet_mode=True,
skip_context_files=True,
skip_memory=True,
)
agent.client = MagicMock()
agent._cached_system_prompt = "You are helpful."
agent._use_prompt_caching = False
agent.tool_delay = 0
agent.compression_enabled = False
agent.save_trajectories = False
return agent
def _mock_tool_call(name="web_search", arguments="{}", call_id="call_1"):
return SimpleNamespace(
id=call_id,
type="function",
function=SimpleNamespace(name=name, arguments=arguments),
)
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
msg = SimpleNamespace(content=content, tool_calls=tool_calls)
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
return SimpleNamespace(choices=[choice], model="test/model", usage=None)
# ---------------------------------------------------------------------------
# Contract 1: run_conversation persists the assistant tool-call block BEFORE
# tool execution begins.
# ---------------------------------------------------------------------------
def test_run_conversation_flushes_assistant_tool_call_before_execution():
agent = _make_agent()
tool_call = _mock_tool_call(call_id="c1")
agent.client.chat.completions.create.side_effect = [
_mock_response(content="", finish_reason="tool_calls", tool_calls=[tool_call]),
_mock_response(content="done", finish_reason="stop"),
]
# Record a deep snapshot of the message list at every flush so the
# assertion does not depend on later mutations.
flush_snapshots: list[list] = []
def _record_flush(messages, conversation_history=None):
flush_snapshots.append(copy.deepcopy(messages))
agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)
# Capture observations at execute time into module-level lists rather than
# asserting inside _execute_tool_calls — run_conversation's outer loop
# swallows exceptions, so an in-callback assertion would never surface.
executed = {"count": 0}
snapshot_at_execute: list = []
def _fake_execute(assistant_message, messages, effective_task_id, api_call_count=0):
executed["count"] += 1
# Record the DB state observed at the moment tool execution begins.
snapshot_at_execute.append(
copy.deepcopy(flush_snapshots[-1]) if flush_snapshots else None
)
# Simulate the tool producing a result (as the real path would).
messages.append(make_tool_result_message("web_search", "search result", "c1"))
with (
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
patch.object(agent, "_execute_tool_calls", side_effect=_fake_execute),
):
result = agent.run_conversation("search something")
assert executed["count"] == 1, "_execute_tool_calls was never reached"
# The assistant tool-call block MUST have been flushed before execution.
last = snapshot_at_execute[0]
assert last is not None, "no flush occurred before tool execution"
assert last[-1]["role"] == "assistant"
assert last[-1]["tool_calls"][0]["id"] == "c1"
assert result["final_response"] == "done"
# ---------------------------------------------------------------------------
# Contract 2: the SEQUENTIAL path flushes each tool result immediately, BEFORE
# the next tool dispatches. Dispatch goes through run_agent.handle_function_call
# (the real production surface), which we mock for determinism.
# ---------------------------------------------------------------------------
def test_execute_tool_calls_sequential_flushes_each_tool_result_before_next_dispatch():
agent = _make_agent()
tool_calls = [
_mock_tool_call(name="web_search", call_id="c1"),
_mock_tool_call(name="web_search", call_id="c2"),
]
messages: list = []
assistant_message = SimpleNamespace(content="", tool_calls=tool_calls)
# Ordered event log interleaving real dispatches and DB flushes.
events: list = []
def _fake_dispatch(function_name, function_args, effective_task_id, **kwargs):
# The result for call N must have been flushed before call N+1 fires.
events.append(("dispatch", kwargs.get("tool_call_id")))
return f"result-{kwargs.get('tool_call_id')}"
def _record_flush(flush_messages, conversation_history=None):
# Snapshot the tail tool result that triggered this flush.
tail = flush_messages[-1]
events.append(("flush", tail.get("role"), tail.get("tool_call_id")))
agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)
with (
patch("run_agent.handle_function_call", side_effect=_fake_dispatch) as disp,
patch(
"agent.tool_executor.maybe_persist_tool_result",
side_effect=lambda **kwargs: kwargs["content"],
),
):
agent._execute_tool_calls_sequential(assistant_message, messages, "task-1")
# The mock proves we exercised the REAL sequential dispatch surface.
assert disp.call_count == 2, "sequential path did not dispatch via handle_function_call"
# Both tool results landed, in order.
assert [m["role"] for m in messages] == ["tool", "tool"]
assert [m["tool_call_id"] for m in messages] == ["c1", "c2"]
# Ordering contract: each tool result is flushed AFTER its own dispatch
# and BEFORE the next dispatch. Expected interleaving:
# dispatch c1 -> flush c1 -> dispatch c2 -> flush c2
assert events == [
("dispatch", "c1"),
("flush", "tool", "c1"),
("dispatch", "c2"),
("flush", "tool", "c2"),
]
# ---------------------------------------------------------------------------
# Contract 3: the CONCURRENT path flushes each collected tool result in append
# order. Dispatch goes through agent._invoke_tool (the real concurrent
# surface), which we mock for determinism.
# ---------------------------------------------------------------------------
def test_execute_tool_calls_concurrent_flushes_each_tool_result_in_order():
agent = _make_agent()
tool_calls = [
_mock_tool_call(name="web_search", call_id="c1"),
_mock_tool_call(name="web_search", call_id="c2"),
]
messages: list = []
assistant_message = SimpleNamespace(content="", tool_calls=tool_calls)
invoked_ids: list = []
def _fake_invoke(function_name, function_args, effective_task_id, tool_call_id, **kwargs):
invoked_ids.append(tool_call_id)
return f"result-{tool_call_id}"
# Each flush must observe exactly one more tool result than the previous
# flush, in append order — i.e. the tail tool_call_id sequence is c1, c2.
flushed_tool_ids: list = []
flush_lengths: list = []
def _record_flush(flush_messages, conversation_history=None):
flushed_tool_ids.append(flush_messages[-1]["tool_call_id"])
flush_lengths.append(len([m for m in flush_messages if m.get("role") == "tool"]))
agent._flush_messages_to_session_db = MagicMock(side_effect=_record_flush)
with (
patch.object(agent, "_invoke_tool", side_effect=_fake_invoke) as inv,
patch(
"agent.tool_executor.maybe_persist_tool_result",
side_effect=lambda **kwargs: kwargs["content"],
),
):
agent._execute_tool_calls_concurrent(assistant_message, messages, "task-1")
# Proves the real concurrent dispatch surface was exercised.
assert inv.call_count == 2, "concurrent path did not dispatch via _invoke_tool"
assert sorted(invoked_ids) == ["c1", "c2"]
# Results appended in deterministic order.
assert [m["tool_call_id"] for m in messages] == ["c1", "c2"]
# Each tool result was flushed exactly once, in append order, with the
# running tool count growing by one each time (1 then 2). Removing either
# production flush call breaks one of these assertions.
assert flushed_tool_ids == ["c1", "c2"]
assert flush_lengths == [1, 2]