mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(api): structured run events via /v1/runs SSE endpoint
Add POST /v1/runs to start async agent runs and GET /v1/runs/{run_id}/events
for SSE streaming of typed lifecycle events (tool.started, tool.completed,
message.delta, reasoning.available, run.completed, run.failed).
Changes the internal tool_progress_callback signature from positional
(tool_name, preview, args) to event-type-first
(event_type, tool_name, preview, args, **kwargs). Existing consumers
filter on event_type and remain backward-compatible.
Adds concurrency limit (_MAX_CONCURRENT_RUNS=10) and orphaned run sweep.
Fixes logic inversion in cli.py _on_tool_progress where the original PR
would have displayed internal tools instead of non-internal ones.
Co-authored-by: Mibayy <mibayy@users.noreply.github.com>
This commit is contained in:
parent
e167ad8f61
commit
cc2b56b26a
11 changed files with 337 additions and 44 deletions
|
|
@ -54,14 +54,18 @@ def make_tool_progress_cb(
|
|||
|
||||
Signature expected by AIAgent::
|
||||
|
||||
tool_progress_callback(name: str, preview: str, args: dict)
|
||||
tool_progress_callback(event_type: str, name: str, preview: str, args: dict, **kwargs)
|
||||
|
||||
Emits ``ToolCallStart`` for each tool invocation and tracks IDs in a FIFO
|
||||
Emits ``ToolCallStart`` for ``tool.started`` events and tracks IDs in a FIFO
|
||||
queue per tool name so duplicate/parallel same-name calls still complete
|
||||
against the correct ACP tool call.
|
||||
against the correct ACP tool call. Other event types (``tool.completed``,
|
||||
``reasoning.available``) are silently ignored.
|
||||
"""
|
||||
|
||||
def _tool_progress(name: str, preview: str, args: Any = None) -> None:
|
||||
def _tool_progress(event_type: str, name: str = None, preview: str = None, args: Any = None, **kwargs) -> None:
|
||||
# Only emit ACP ToolCallStart for tool.started; ignore other event types
|
||||
if event_type != "tool.started":
|
||||
return
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import acp
|
|||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AuthenticateResponse,
|
||||
AuthMethodAgent,
|
||||
AvailableCommand,
|
||||
AvailableCommandsUpdate,
|
||||
ClientCapabilities,
|
||||
|
|
@ -43,6 +42,12 @@ from acp.schema import (
|
|||
Usage,
|
||||
)
|
||||
|
||||
# AuthMethodAgent was renamed from AuthMethod in agent-client-protocol 0.9.0
|
||||
try:
|
||||
from acp.schema import AuthMethodAgent
|
||||
except ImportError:
|
||||
from acp.schema import AuthMethod as AuthMethodAgent # type: ignore[attr-defined]
|
||||
|
||||
from acp_adapter.auth import detect_provider, has_provider
|
||||
from acp_adapter.events import (
|
||||
make_message_cb,
|
||||
|
|
|
|||
11
cli.py
11
cli.py
|
|
@ -5457,14 +5457,17 @@ class HermesCLI:
|
|||
# Tool progress callback (audio cues for voice mode)
|
||||
# ====================================================================
|
||||
|
||||
def _on_tool_progress(self, function_name: str, preview: str, function_args: dict):
|
||||
"""Called when a tool starts executing.
|
||||
def _on_tool_progress(self, event_type: str, function_name: str = None, preview: str = None, function_args: dict = None, **kwargs):
|
||||
"""Called on tool lifecycle events (tool.started, tool.completed, reasoning.available, etc.).
|
||||
|
||||
Updates the TUI spinner widget so the user can see what the agent
|
||||
is doing during tool execution (fills the gap between thinking
|
||||
spinner and next response). Also plays audio cue in voice mode.
|
||||
"""
|
||||
if not function_name.startswith("_"):
|
||||
# Only act on tool.started; ignore tool.completed, reasoning.available, etc.
|
||||
if event_type != "tool.started":
|
||||
return
|
||||
if function_name and not function_name.startswith("_"):
|
||||
from agent.display import get_tool_emoji
|
||||
emoji = get_tool_emoji(function_name)
|
||||
label = preview or function_name
|
||||
|
|
@ -5477,7 +5480,7 @@ class HermesCLI:
|
|||
|
||||
if not self._voice_mode:
|
||||
return
|
||||
if function_name.startswith("_"):
|
||||
if not function_name or function_name.startswith("_"):
|
||||
return
|
||||
try:
|
||||
from tools.voice_mode import play_beep
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ Exposes an HTTP server with endpoints:
|
|||
- GET /v1/responses/{response_id} — Retrieve a stored response
|
||||
- DELETE /v1/responses/{response_id} — Delete a stored response
|
||||
- GET /v1/models — lists hermes-agent as an available model
|
||||
- POST /v1/runs — start a run, returns run_id immediately (202)
|
||||
- GET /v1/runs/{run_id}/events — SSE stream of structured lifecycle events
|
||||
- GET /health — health check
|
||||
|
||||
Any OpenAI-compatible frontend (Open WebUI, LobeChat, LibreChat,
|
||||
|
|
@ -300,6 +302,10 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
self._runner: Optional["web.AppRunner"] = None
|
||||
self._site: Optional["web.TCPSite"] = None
|
||||
self._response_store = ResponseStore()
|
||||
# Active run streams: run_id -> asyncio.Queue of SSE event dicts
|
||||
self._run_streams: Dict[str, "asyncio.Queue[Optional[Dict]]"] = {}
|
||||
# Creation timestamps for orphaned-run TTL sweep
|
||||
self._run_streams_created: Dict[str, float] = {}
|
||||
self._session_db: Optional[Any] = None # Lazy-init SessionDB for session continuity
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -1287,6 +1293,236 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
|
||||
return await loop.run_in_executor(None, _run)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# /v1/runs — structured event streaming
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_MAX_CONCURRENT_RUNS = 10 # Prevent unbounded resource allocation
|
||||
_RUN_STREAM_TTL = 300 # seconds before orphaned runs are swept
|
||||
|
||||
def _make_run_event_callback(self, run_id: str, loop: "asyncio.AbstractEventLoop"):
|
||||
"""Return a tool_progress_callback that pushes structured events to the run's SSE queue."""
|
||||
def _push(event: Dict[str, Any]) -> None:
|
||||
q = self._run_streams.get(run_id)
|
||||
if q is None:
|
||||
return
|
||||
try:
|
||||
loop.call_soon_threadsafe(q.put_nowait, event)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _callback(event_type: str, tool_name: str = None, preview: str = None, args=None, **kwargs):
|
||||
ts = time.time()
|
||||
if event_type == "tool.started":
|
||||
_push({
|
||||
"event": "tool.started",
|
||||
"run_id": run_id,
|
||||
"timestamp": ts,
|
||||
"tool": tool_name,
|
||||
"preview": preview,
|
||||
})
|
||||
elif event_type == "tool.completed":
|
||||
_push({
|
||||
"event": "tool.completed",
|
||||
"run_id": run_id,
|
||||
"timestamp": ts,
|
||||
"tool": tool_name,
|
||||
"duration": round(kwargs.get("duration", 0), 3),
|
||||
"error": kwargs.get("is_error", False),
|
||||
})
|
||||
elif event_type == "reasoning.available":
|
||||
_push({
|
||||
"event": "reasoning.available",
|
||||
"run_id": run_id,
|
||||
"timestamp": ts,
|
||||
"text": preview or "",
|
||||
})
|
||||
# _thinking and subagent_progress are intentionally not forwarded
|
||||
|
||||
return _callback
|
||||
|
||||
async def _handle_runs(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /v1/runs — start an agent run, return run_id immediately."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
|
||||
# Enforce concurrency limit
|
||||
if len(self._run_streams) >= self._MAX_CONCURRENT_RUNS:
|
||||
return web.json_response(
|
||||
_openai_error(f"Too many concurrent runs (max {self._MAX_CONCURRENT_RUNS})", code="rate_limit_exceeded"),
|
||||
status=429,
|
||||
)
|
||||
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return web.json_response(_openai_error("Invalid JSON"), status=400)
|
||||
|
||||
raw_input = body.get("input")
|
||||
if not raw_input:
|
||||
return web.json_response(_openai_error("Missing 'input' field"), status=400)
|
||||
|
||||
user_message = raw_input if isinstance(raw_input, str) else (raw_input[-1].get("content", "") if isinstance(raw_input, list) else "")
|
||||
if not user_message:
|
||||
return web.json_response(_openai_error("No user message found in input"), status=400)
|
||||
|
||||
run_id = f"run_{uuid.uuid4().hex}"
|
||||
loop = asyncio.get_running_loop()
|
||||
q: "asyncio.Queue[Optional[Dict]]" = asyncio.Queue()
|
||||
self._run_streams[run_id] = q
|
||||
self._run_streams_created[run_id] = time.time()
|
||||
|
||||
event_cb = self._make_run_event_callback(run_id, loop)
|
||||
|
||||
# Also wire stream_delta_callback so message.delta events flow through
|
||||
def _text_cb(delta: Optional[str]) -> None:
|
||||
if delta is None:
|
||||
return
|
||||
try:
|
||||
loop.call_soon_threadsafe(q.put_nowait, {
|
||||
"event": "message.delta",
|
||||
"run_id": run_id,
|
||||
"timestamp": time.time(),
|
||||
"delta": delta,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
instructions = body.get("instructions")
|
||||
previous_response_id = body.get("previous_response_id")
|
||||
conversation_history: List[Dict[str, str]] = []
|
||||
if previous_response_id:
|
||||
stored = self._response_store.get(previous_response_id)
|
||||
if stored:
|
||||
conversation_history = list(stored.get("conversation_history", []))
|
||||
if instructions is None:
|
||||
instructions = stored.get("instructions")
|
||||
|
||||
session_id = body.get("session_id") or run_id
|
||||
ephemeral_system_prompt = instructions
|
||||
|
||||
async def _run_and_close():
|
||||
try:
|
||||
agent = self._create_agent(
|
||||
ephemeral_system_prompt=ephemeral_system_prompt,
|
||||
session_id=session_id,
|
||||
stream_delta_callback=_text_cb,
|
||||
tool_progress_callback=event_cb,
|
||||
)
|
||||
def _run_sync():
|
||||
r = agent.run_conversation(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
)
|
||||
u = {
|
||||
"input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0,
|
||||
"output_tokens": getattr(agent, "session_completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(agent, "session_total_tokens", 0) or 0,
|
||||
}
|
||||
return r, u
|
||||
|
||||
result, usage = await asyncio.get_running_loop().run_in_executor(None, _run_sync)
|
||||
final_response = result.get("final_response", "") if isinstance(result, dict) else ""
|
||||
q.put_nowait({
|
||||
"event": "run.completed",
|
||||
"run_id": run_id,
|
||||
"timestamp": time.time(),
|
||||
"output": final_response,
|
||||
"usage": usage,
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.exception("[api_server] run %s failed", run_id)
|
||||
try:
|
||||
q.put_nowait({
|
||||
"event": "run.failed",
|
||||
"run_id": run_id,
|
||||
"timestamp": time.time(),
|
||||
"error": str(exc),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# Sentinel: signal SSE stream to close
|
||||
try:
|
||||
q.put_nowait(None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
task = asyncio.create_task(_run_and_close())
|
||||
try:
|
||||
self._background_tasks.add(task)
|
||||
except TypeError:
|
||||
pass
|
||||
if hasattr(task, "add_done_callback"):
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
return web.json_response({"run_id": run_id, "status": "started"}, status=202)
|
||||
|
||||
async def _handle_run_events(self, request: "web.Request") -> "web.StreamResponse":
|
||||
"""GET /v1/runs/{run_id}/events — SSE stream of structured agent lifecycle events."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
|
||||
run_id = request.match_info["run_id"]
|
||||
|
||||
# Allow subscribing slightly before the run is registered (race condition window)
|
||||
for _ in range(20):
|
||||
if run_id in self._run_streams:
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
else:
|
||||
return web.json_response(_openai_error(f"Run not found: {run_id}", code="run_not_found"), status=404)
|
||||
|
||||
q = self._run_streams[run_id]
|
||||
|
||||
response = web.StreamResponse(
|
||||
status=200,
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
await response.prepare(request)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
event = await asyncio.wait_for(q.get(), timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
await response.write(b": keepalive\n\n")
|
||||
continue
|
||||
if event is None:
|
||||
# Run finished — send final SSE comment and close
|
||||
await response.write(b": stream closed\n\n")
|
||||
break
|
||||
payload = f"data: {json.dumps(event)}\n\n"
|
||||
await response.write(payload.encode())
|
||||
except Exception as exc:
|
||||
logger.debug("[api_server] SSE stream error for run %s: %s", run_id, exc)
|
||||
finally:
|
||||
self._run_streams.pop(run_id, None)
|
||||
self._run_streams_created.pop(run_id, None)
|
||||
|
||||
return response
|
||||
|
||||
async def _sweep_orphaned_runs(self) -> None:
|
||||
"""Periodically clean up run streams that were never consumed."""
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
now = time.time()
|
||||
stale = [
|
||||
run_id
|
||||
for run_id, created_at in list(self._run_streams_created.items())
|
||||
if now - created_at > self._RUN_STREAM_TTL
|
||||
]
|
||||
for run_id in stale:
|
||||
logger.debug("[api_server] sweeping orphaned run %s", run_id)
|
||||
self._run_streams.pop(run_id, None)
|
||||
self._run_streams_created.pop(run_id, None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BasePlatformAdapter interface
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -1317,6 +1553,17 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
self._app.router.add_post("/api/jobs/{job_id}/pause", self._handle_pause_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/resume", self._handle_resume_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/run", self._handle_run_job)
|
||||
# Structured event streaming
|
||||
self._app.router.add_post("/v1/runs", self._handle_runs)
|
||||
self._app.router.add_get("/v1/runs/{run_id}/events", self._handle_run_events)
|
||||
# Start background sweep to clean up orphaned (unconsumed) run streams
|
||||
sweep_task = asyncio.create_task(self._sweep_orphaned_runs())
|
||||
try:
|
||||
self._background_tasks.add(sweep_task)
|
||||
except TypeError:
|
||||
pass
|
||||
if hasattr(sweep_task, "add_done_callback"):
|
||||
sweep_task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
# Port conflict detection — fail fast if port is already in use
|
||||
import socket as _socket
|
||||
|
|
|
|||
|
|
@ -6000,11 +6000,15 @@ class GatewayRunner:
|
|||
last_progress_msg = [None] # Track last message for dedup
|
||||
repeat_count = [0] # How many times the same message repeated
|
||||
|
||||
def progress_callback(tool_name: str, preview: str = None, args: dict = None):
|
||||
"""Callback invoked by agent when a tool is called."""
|
||||
def progress_callback(event_type: str, tool_name: str = None, preview: str = None, args: dict = None, **kwargs):
|
||||
"""Callback invoked by agent on tool lifecycle events."""
|
||||
if not progress_queue:
|
||||
return
|
||||
|
||||
|
||||
# Only act on tool.started events (ignore tool.completed, reasoning.available, etc.)
|
||||
if event_type not in ("tool.started",):
|
||||
return
|
||||
|
||||
# "new" mode: only report when tool changes
|
||||
if progress_mode == "new" and tool_name == last_tool[0]:
|
||||
return
|
||||
|
|
|
|||
36
run_agent.py
36
run_agent.py
|
|
@ -6056,7 +6056,7 @@ class AIAgent:
|
|||
if self.tool_progress_callback:
|
||||
try:
|
||||
preview = _build_tool_preview(name, args)
|
||||
self.tool_progress_callback(name, preview, args)
|
||||
self.tool_progress_callback("tool.started", name, preview, args)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
|
|
@ -6121,6 +6121,15 @@ class AIAgent:
|
|||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
||||
|
||||
if self.tool_progress_callback:
|
||||
try:
|
||||
self.tool_progress_callback(
|
||||
"tool.completed", function_name, None, None,
|
||||
duration=tool_duration, is_error=is_error,
|
||||
)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
||||
|
|
@ -6220,7 +6229,7 @@ class AIAgent:
|
|||
if self.tool_progress_callback:
|
||||
try:
|
||||
preview = _build_tool_preview(function_name, function_args)
|
||||
self.tool_progress_callback(function_name, preview, function_args)
|
||||
self.tool_progress_callback("tool.started", function_name, preview, function_args)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
|
|
@ -6407,6 +6416,15 @@ class AIAgent:
|
|||
if _is_error_result:
|
||||
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
||||
|
||||
if self.tool_progress_callback:
|
||||
try:
|
||||
self.tool_progress_callback(
|
||||
"tool.completed", function_name, None, None,
|
||||
duration=tool_duration, is_error=_is_error_result,
|
||||
)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
||||
|
|
@ -8263,21 +8281,25 @@ class AIAgent:
|
|||
|
||||
# Notify progress callback of model's thinking (used by subagent
|
||||
# delegation to relay the child's reasoning to the parent display).
|
||||
# Guard: only fire for subagents (_delegate_depth >= 1) to avoid
|
||||
# spamming gateway platforms with the main agent's every thought.
|
||||
if (assistant_message.content and self.tool_progress_callback
|
||||
and getattr(self, '_delegate_depth', 0) > 0):
|
||||
if (assistant_message.content and self.tool_progress_callback):
|
||||
_think_text = assistant_message.content.strip()
|
||||
# Strip reasoning XML tags that shouldn't leak to parent display
|
||||
_think_text = re.sub(
|
||||
r'</?(?:REASONING_SCRATCHPAD|think|reasoning)>', '', _think_text
|
||||
).strip()
|
||||
# For subagents: relay first line to parent display (existing behaviour).
|
||||
# For all agents with a structured callback: emit reasoning.available event.
|
||||
first_line = _think_text.split('\n')[0][:80] if _think_text else ""
|
||||
if first_line:
|
||||
if first_line and getattr(self, '_delegate_depth', 0) > 0:
|
||||
try:
|
||||
self.tool_progress_callback("_thinking", first_line)
|
||||
except Exception:
|
||||
pass
|
||||
elif _think_text:
|
||||
try:
|
||||
self.tool_progress_callback("reasoning.available", "_thinking", _think_text[:500], None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check for incomplete <REASONING_SCRATCHPAD> (opened but never closed)
|
||||
# This means the model ran out of output tokens mid-reasoning — retry up to 2 times
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class TestToolProgressCallback:
|
|||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb("terminal", "$ ls -la", {"command": "ls -la"})
|
||||
cb("tool.started", "terminal", "$ ls -la", {"command": "ls -la"})
|
||||
|
||||
# Should have tracked the tool call ID
|
||||
assert "terminal" in tool_call_ids
|
||||
|
|
@ -75,7 +75,7 @@ class TestToolProgressCallback:
|
|||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb("read_file", "Reading /etc/hosts", '{"path": "/etc/hosts"}')
|
||||
cb("tool.started", "read_file", "Reading /etc/hosts", '{"path": "/etc/hosts"}')
|
||||
|
||||
assert "read_file" in tool_call_ids
|
||||
|
||||
|
|
@ -91,7 +91,7 @@ class TestToolProgressCallback:
|
|||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb("terminal", "$ echo hi", None)
|
||||
cb("tool.started", "terminal", "$ echo hi", None)
|
||||
|
||||
assert "terminal" in tool_call_ids
|
||||
|
||||
|
|
@ -108,8 +108,8 @@ class TestToolProgressCallback:
|
|||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
progress_cb("terminal", "$ ls", {"command": "ls"})
|
||||
progress_cb("terminal", "$ pwd", {"command": "pwd"})
|
||||
progress_cb("tool.started", "terminal", "$ ls", {"command": "ls"})
|
||||
progress_cb("tool.started", "terminal", "$ pwd", {"command": "pwd"})
|
||||
assert len(tool_call_ids["terminal"]) == 2
|
||||
|
||||
step_cb(1, [{"name": "terminal", "result": "ok-1"}])
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ class TestMcpRegistrationE2E:
|
|||
# 1) Agent fires tool_progress_callback (ToolCallStart)
|
||||
if agent.tool_progress_callback:
|
||||
agent.tool_progress_callback(
|
||||
"terminal", "$ echo hello", {"command": "echo hello"}
|
||||
"tool.started", "terminal", "$ echo hello", {"command": "echo hello"}
|
||||
)
|
||||
|
||||
# 2) Agent fires step_callback with tool results (ToolCallUpdate)
|
||||
|
|
@ -197,8 +197,8 @@ class TestMcpRegistrationE2E:
|
|||
agent = state.agent
|
||||
# Fire two tool calls
|
||||
if agent.tool_progress_callback:
|
||||
agent.tool_progress_callback("read_file", "read: /etc/hosts", {"path": "/etc/hosts"})
|
||||
agent.tool_progress_callback("web_search", "web search: test", {"query": "test"})
|
||||
agent.tool_progress_callback("tool.started", "read_file", "read: /etc/hosts", {"path": "/etc/hosts"})
|
||||
agent.tool_progress_callback("tool.started", "web_search", "web search: test", {"query": "test"})
|
||||
|
||||
if agent.step_callback:
|
||||
agent.step_callback(1, [
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class TestBuildChildProgressCallback:
|
|||
cb = _build_child_progress_callback(0, parent)
|
||||
assert cb is not None
|
||||
|
||||
cb("web_search", "quantum computing")
|
||||
cb("tool.started", "web_search", "quantum computing", {})
|
||||
output = buf.getvalue()
|
||||
assert "web_search" in output
|
||||
assert "quantum computing" in output
|
||||
|
|
@ -131,11 +131,11 @@ class TestBuildChildProgressCallback:
|
|||
|
||||
# Send 4 tool calls — shouldn't flush yet (BATCH_SIZE = 5)
|
||||
for i in range(4):
|
||||
cb(f"tool_{i}", f"arg_{i}")
|
||||
cb("tool.started", f"tool_{i}", f"arg_{i}", {})
|
||||
parent_cb.assert_not_called()
|
||||
|
||||
# 5th call should trigger flush
|
||||
cb("tool_4", "arg_4")
|
||||
cb("tool.started", "tool_4", "arg_4", {})
|
||||
parent_cb.assert_called_once()
|
||||
call_args = parent_cb.call_args
|
||||
assert "tool_0" in call_args[0][1]
|
||||
|
|
@ -207,7 +207,7 @@ class TestBuildChildProgressCallback:
|
|||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, parent, task_count=1)
|
||||
cb("web_search", "test")
|
||||
cb("tool.started", "web_search", "test", {})
|
||||
|
||||
output = buf.getvalue()
|
||||
assert "[" not in output
|
||||
|
|
@ -330,9 +330,9 @@ class TestBatchFlush:
|
|||
cb = _build_child_progress_callback(0, parent)
|
||||
|
||||
# Send 3 tools (below batch size of 5)
|
||||
cb("web_search", "query1")
|
||||
cb("read_file", "file.txt")
|
||||
cb("write_file", "out.txt")
|
||||
cb("tool.started", "web_search", "query1", {})
|
||||
cb("tool.started", "read_file", "file.txt", {})
|
||||
cb("tool.started", "write_file", "out.txt", {})
|
||||
parent_cb.assert_not_called()
|
||||
|
||||
# Flush should send the remaining 3
|
||||
|
|
@ -365,7 +365,7 @@ class TestBatchFlush:
|
|||
parent.tool_progress_callback = None
|
||||
|
||||
cb = _build_child_progress_callback(0, parent)
|
||||
cb("web_search", "test")
|
||||
cb("tool.started", "web_search", "test", {})
|
||||
cb._flush() # Should not crash
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -60,9 +60,9 @@ class FakeAgent:
|
|||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.tool_progress_callback("terminal", "pwd")
|
||||
self.tool_progress_callback("tool.started", "terminal", "pwd", {})
|
||||
time.sleep(0.35)
|
||||
self.tool_progress_callback("browser_navigate", "https://example.com")
|
||||
self.tool_progress_callback("tool.started", "browser_navigate", "https://example.com", {})
|
||||
time.sleep(0.35)
|
||||
return {
|
||||
"final_response": "done",
|
||||
|
|
|
|||
|
|
@ -98,11 +98,15 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||
_BATCH_SIZE = 5
|
||||
_batch: List[str] = []
|
||||
|
||||
def _callback(tool_name: str, preview: str = None):
|
||||
# Special "_thinking" event: model produced text content (reasoning)
|
||||
if tool_name == "_thinking":
|
||||
def _callback(event_type: str, tool_name: str = None, preview: str = None, args=None, **kwargs):
|
||||
# event_type is one of: "tool.started", "tool.completed",
|
||||
# "reasoning.available", "_thinking", "subagent_progress"
|
||||
|
||||
# "_thinking" / reasoning events
|
||||
if event_type in ("_thinking", "reasoning.available"):
|
||||
text = preview or tool_name or ""
|
||||
if spinner:
|
||||
short = (preview[:55] + "...") if preview and len(preview) > 55 else (preview or "")
|
||||
short = (text[:55] + "...") if len(text) > 55 else text
|
||||
try:
|
||||
spinner.print_above(f" {prefix}├─ 💭 \"{short}\"")
|
||||
except Exception as e:
|
||||
|
|
@ -110,11 +114,15 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||
# Don't relay thinking to gateway (too noisy for chat)
|
||||
return
|
||||
|
||||
# Regular tool call event
|
||||
# tool.completed — no display needed here (spinner shows on started)
|
||||
if event_type == "tool.completed":
|
||||
return
|
||||
|
||||
# tool.started — display and batch for parent relay
|
||||
if spinner:
|
||||
short = (preview[:35] + "...") if preview and len(preview) > 35 else (preview or "")
|
||||
from agent.display import get_tool_emoji
|
||||
emoji = get_tool_emoji(tool_name)
|
||||
emoji = get_tool_emoji(tool_name or "")
|
||||
line = f" {prefix}├─ {emoji} {tool_name}"
|
||||
if short:
|
||||
line += f" \"{short}\""
|
||||
|
|
@ -124,7 +132,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in
|
|||
logger.debug("Spinner print_above failed: %s", e)
|
||||
|
||||
if parent_cb:
|
||||
_batch.append(tool_name)
|
||||
_batch.append(tool_name or "")
|
||||
if len(_batch) >= _BATCH_SIZE:
|
||||
summary = ", ".join(_batch)
|
||||
try:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue