mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(steer): /steer <prompt> injects a mid-run note after the next tool call (#12116)
* feat(steer): /steer <prompt> injects a mid-run note after the next tool call Adds a new slash command that sits between /queue (turn boundary) and interrupt. /steer <text> stashes the message on the running agent and the agent loop appends it to the LAST tool result's content once the current tool batch finishes. The model sees it as part of the tool output on its next iteration. No interrupt is fired, no new user turn is inserted, and no prompt cache invalidation happens beyond the normal per-turn tool-result churn. Message-role alternation is preserved — we only modify an existing role:"tool" message's content. Wiring ------ - hermes_cli/commands.py: register /steer + add to ACTIVE_SESSION_BYPASS_COMMANDS. - run_agent.py: add _pending_steer state, AIAgent.steer(), _drain_pending_steer(), _apply_pending_steer_to_tool_results(); drain at end of both parallel and sequential tool executors; clear on interrupt; return leftover as result['pending_steer'] if the agent exits before another tool batch. - cli.py: /steer handler — route to agent.steer() when running, fall back to the regular queue otherwise; deliver result['pending_steer'] as next turn. - gateway/run.py: running-agent intercept calls running_agent.steer(); idle-agent path strips the prefix and forwards as a regular user message. - tui_gateway/server.py: new session.steer JSON-RPC method. - ui-tui: SessionSteerResponse type + local /steer slash command that calls session.steer when ui.busy, otherwise enqueues for the next turn. Fallbacks --------- - Agent exits mid-steer → surfaces in run_conversation result as pending_steer so CLI/gateway deliver it as the next user turn instead of silently dropping it. - All tools skipped after interrupt → re-stashes pending_steer for the caller. - No active agent → /steer reduces to sending the text as a normal message. Tests ----- - tests/run_agent/test_steer.py — accept/reject, concatenation, drain, last-tool-result injection, multimodal list content, thread safety, cleared-on-interrupt, registry membership, bypass-set membership. - tests/gateway/test_steer_command.py — running agent, pending sentinel, missing steer() method, rejected payload, empty payload. - tests/gateway/test_command_bypass_active_session.py — /steer bypasses the Level-1 base adapter guard. - tests/test_tui_gateway_server.py — session.steer RPC paths. 72/72 targeted tests pass under scripts/run_tests.sh. * feat(steer): register /steer in Discord's native slash tree Discord's app_commands tree is a curated subset of slash commands (not derived from COMMAND_REGISTRY like Telegram/Slack). /steer already works there as plain text (routes through handle_message → base adapter bypass → runner), but registering it here adds Discord's native autocomplete + argument hint UI so users can discover and type it like any other first-class command.
This commit is contained in:
parent
f9667331e5
commit
2edebedc9e
12 changed files with 826 additions and 2 deletions
32
cli.py
32
cli.py
|
|
@ -5720,6 +5720,30 @@ class HermesCLI:
|
|||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "steer":
|
||||
# Inject a message after the next tool call without interrupting.
|
||||
# If the agent is actively running, push the text into the agent's
|
||||
# pending_steer slot — the drain hook in _execute_tool_calls_*
|
||||
# will append it to the next tool result's content. If no agent
|
||||
# is running, fall back to queue semantics (same as /queue).
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /steer <prompt>")
|
||||
elif self._agent_running and self.agent is not None and hasattr(self.agent, "steer"):
|
||||
try:
|
||||
accepted = self.agent.steer(payload)
|
||||
except Exception as exc:
|
||||
_cprint(f" Steer failed: {exc}")
|
||||
else:
|
||||
if accepted:
|
||||
_cprint(f" ⏩ Steer queued — arrives after the next tool call: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
else:
|
||||
_cprint(" Steer rejected (empty payload).")
|
||||
else:
|
||||
# No active run — treat as a normal next-turn message.
|
||||
self._pending_input.put(payload)
|
||||
_cprint(f" No agent running; queued as next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "skin":
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif canonical == "voice":
|
||||
|
|
@ -8245,6 +8269,14 @@ class HermesCLI:
|
|||
print(f"\n⚡ Sending after interrupt: '{preview}'")
|
||||
self._pending_input.put(combined)
|
||||
|
||||
# If a /steer was left over (agent finished before another tool
|
||||
# batch could absorb it), deliver it as the next user turn.
|
||||
_leftover_steer = result.get("pending_steer") if result else None
|
||||
if _leftover_steer and hasattr(self, '_pending_input'):
|
||||
preview = _leftover_steer[:60] + ("..." if len(_leftover_steer) > 60 else "")
|
||||
print(f"\n⏩ Delivering leftover /steer as next turn: '{preview}'")
|
||||
self._pending_input.put(_leftover_steer)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1994,6 +1994,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
async def slash_stop(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/stop", "Stop requested~")
|
||||
|
||||
@tree.command(name="steer", description="Inject a message after the next tool call (no interrupt)")
|
||||
@discord.app_commands.describe(prompt="Text to inject into the agent's next tool result")
|
||||
async def slash_steer(interaction: discord.Interaction, prompt: str):
|
||||
await self._run_simple_slash(interaction, f"/steer {prompt}".strip())
|
||||
|
||||
@tree.command(name="compress", description="Compress conversation context")
|
||||
async def slash_compress(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/compress")
|
||||
|
|
|
|||
|
|
@ -3019,6 +3019,54 @@ class GatewayRunner:
|
|||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "Queued for the next turn."
|
||||
|
||||
# /steer <prompt> — inject mid-run after the next tool call.
|
||||
# Unlike /queue (turn boundary), /steer lands BETWEEN tool-call
|
||||
# iterations inside the same agent run, by appending to the
|
||||
# last tool result's content. No interrupt, no new user turn,
|
||||
# no role-alternation violation.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "steer":
|
||||
steer_text = event.get_command_args().strip()
|
||||
if not steer_text:
|
||||
return "Usage: /steer <prompt>"
|
||||
running_agent = self._running_agents.get(_quick_key)
|
||||
if running_agent is _AGENT_PENDING_SENTINEL:
|
||||
# Agent hasn't started yet — queue as turn-boundary fallback.
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT
|
||||
queued_event = _ME(
|
||||
text=steer_text,
|
||||
message_type=_MT.TEXT,
|
||||
source=event.source,
|
||||
message_id=event.message_id,
|
||||
channel_prompt=event.channel_prompt,
|
||||
)
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "Agent still starting — /steer queued for the next turn."
|
||||
if running_agent and hasattr(running_agent, "steer"):
|
||||
try:
|
||||
accepted = running_agent.steer(steer_text)
|
||||
except Exception as exc:
|
||||
logger.warning("Steer failed for session %s: %s", _quick_key[:20], exc)
|
||||
return f"⚠️ Steer failed: {exc}"
|
||||
if accepted:
|
||||
preview = steer_text[:60] + ("..." if len(steer_text) > 60 else "")
|
||||
return f"⏩ Steer queued — arrives after the next tool call: '{preview}'"
|
||||
return "Steer rejected (empty payload)."
|
||||
# Running agent is missing or lacks steer() — fall back to queue.
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT
|
||||
queued_event = _ME(
|
||||
text=steer_text,
|
||||
message_type=_MT.TEXT,
|
||||
source=event.source,
|
||||
message_id=event.message_id,
|
||||
channel_prompt=event.channel_prompt,
|
||||
)
|
||||
adapter._pending_messages[_quick_key] = queued_event
|
||||
return "No active agent — /steer queued for the next turn."
|
||||
|
||||
# /model must not be used while the agent is running.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "model":
|
||||
return "Agent is running — wait or /stop first, then switch models."
|
||||
|
|
@ -3260,6 +3308,21 @@ class GatewayRunner:
|
|||
if canonical == "btw":
|
||||
return await self._handle_btw_command(event)
|
||||
|
||||
if canonical == "steer":
|
||||
# No active agent — /steer has no tool call to inject into.
|
||||
# Strip the prefix so downstream treats it as a normal user
|
||||
# message. If the payload is empty, surface the usage hint.
|
||||
steer_payload = event.get_command_args().strip()
|
||||
if not steer_payload:
|
||||
return "Usage: /steer <prompt> (no agent is running; sending as a normal message)"
|
||||
try:
|
||||
event.text = steer_payload
|
||||
except Exception:
|
||||
pass
|
||||
# Do NOT return — fall through to _handle_message_with_agent
|
||||
# at the end of this function so the rewritten text is sent
|
||||
# to the agent as a regular user turn.
|
||||
|
||||
if canonical == "voice":
|
||||
return await self._handle_voice_command(event)
|
||||
|
||||
|
|
|
|||
|
|
@ -91,6 +91,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
|||
aliases=("tasks",)),
|
||||
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
|
||||
aliases=("q",), args_hint="<prompt>"),
|
||||
CommandDef("steer", "Inject a message after the next tool call without interrupting", "Session",
|
||||
args_hint="<prompt>"),
|
||||
CommandDef("status", "Show session info", "Session"),
|
||||
CommandDef("profile", "Show active profile name and home directory", "Info"),
|
||||
CommandDef("sethome", "Set this chat as the home channel", "Session",
|
||||
|
|
@ -275,6 +277,7 @@ ACTIVE_SESSION_BYPASS_COMMANDS: frozenset[str] = frozenset(
|
|||
"queue",
|
||||
"restart",
|
||||
"status",
|
||||
"steer",
|
||||
"stop",
|
||||
"update",
|
||||
}
|
||||
|
|
|
|||
152
run_agent.py
152
run_agent.py
|
|
@ -832,6 +832,16 @@ class AIAgent:
|
|||
self._interrupt_thread_signal_pending = False
|
||||
self._client_lock = threading.RLock()
|
||||
|
||||
# /steer mechanism — inject a user note into the next tool result
|
||||
# without interrupting the agent. Unlike interrupt(), steer() does
|
||||
# NOT set _interrupt_requested; it waits for the current tool batch
|
||||
# to finish naturally, then the drain hook appends the text to the
|
||||
# last tool result's content so the model sees it on its next
|
||||
# iteration. Message-role alternation is preserved (we modify an
|
||||
# existing tool message rather than inserting a new user turn).
|
||||
self._pending_steer: Optional[str] = None
|
||||
self._pending_steer_lock = threading.Lock()
|
||||
|
||||
# Concurrent-tool worker thread tracking. `_execute_tool_calls_concurrent`
|
||||
# runs each tool on its own ThreadPoolExecutor worker — those worker
|
||||
# threads have tids distinct from `_execution_thread_id`, so
|
||||
|
|
@ -3265,6 +3275,129 @@ class AIAgent:
|
|||
_set_interrupt(False, _wtid)
|
||||
except Exception:
|
||||
pass
|
||||
# A hard interrupt supersedes any pending /steer — the steer was
|
||||
# meant for the agent's next tool-call iteration, which will no
|
||||
# longer happen. Drop it instead of surprising the user with a
|
||||
# late injection on the post-interrupt turn.
|
||||
_steer_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _steer_lock is not None:
|
||||
with _steer_lock:
|
||||
self._pending_steer = None
|
||||
|
||||
def steer(self, text: str) -> bool:
|
||||
"""
|
||||
Inject a user message into the next tool result without interrupting.
|
||||
|
||||
Unlike interrupt(), this does NOT stop the current tool call. The
|
||||
text is stashed and the agent loop appends it to the LAST tool
|
||||
result's content once the current tool batch finishes. The model
|
||||
sees the steer as part of the tool output on its next iteration.
|
||||
|
||||
Thread-safe: callable from gateway/CLI/TUI threads. Multiple calls
|
||||
before the drain point concatenate with newlines.
|
||||
|
||||
Args:
|
||||
text: The user text to inject. Empty strings are ignored.
|
||||
|
||||
Returns:
|
||||
True if the steer was accepted, False if the text was empty.
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return False
|
||||
cleaned = text.strip()
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is None:
|
||||
# Test stubs that built AIAgent via object.__new__ skip __init__.
|
||||
# Fall back to direct attribute set; no concurrent callers expected
|
||||
# in those stubs.
|
||||
existing = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = (existing + "\n" + cleaned) if existing else cleaned
|
||||
return True
|
||||
with _lock:
|
||||
if self._pending_steer:
|
||||
self._pending_steer = self._pending_steer + "\n" + cleaned
|
||||
else:
|
||||
self._pending_steer = cleaned
|
||||
return True
|
||||
|
||||
def _drain_pending_steer(self) -> Optional[str]:
|
||||
"""Return the pending steer text (if any) and clear the slot.
|
||||
|
||||
Safe to call from the agent execution thread after appending tool
|
||||
results. Returns None when no steer is pending.
|
||||
"""
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is None:
|
||||
text = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = None
|
||||
return text
|
||||
with _lock:
|
||||
text = self._pending_steer
|
||||
self._pending_steer = None
|
||||
return text
|
||||
|
||||
def _apply_pending_steer_to_tool_results(self, messages: list, num_tool_msgs: int) -> None:
|
||||
"""Append any pending /steer text to the last tool result in this turn.
|
||||
|
||||
Called at the end of a tool-call batch, before the next API call.
|
||||
The steer is appended to the last ``role:"tool"`` message's content
|
||||
with a clear marker so the model understands it came from the user
|
||||
and NOT from the tool itself. Role alternation is preserved —
|
||||
nothing new is inserted, we only modify existing content.
|
||||
|
||||
Args:
|
||||
messages: The running messages list.
|
||||
num_tool_msgs: Number of tool results appended in this batch;
|
||||
used to locate the tail slice safely.
|
||||
"""
|
||||
if num_tool_msgs <= 0 or not messages:
|
||||
return
|
||||
steer_text = self._drain_pending_steer()
|
||||
if not steer_text:
|
||||
return
|
||||
# Find the last tool-role message in the recent tail. Skipping
|
||||
# non-tool messages defends against future code appending
|
||||
# something else at the boundary.
|
||||
target_idx = None
|
||||
for j in range(len(messages) - 1, max(len(messages) - num_tool_msgs - 1, -1), -1):
|
||||
msg = messages[j]
|
||||
if isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
target_idx = j
|
||||
break
|
||||
if target_idx is None:
|
||||
# No tool result in this batch (e.g. all skipped by interrupt);
|
||||
# put the steer back so the caller's fallback path can deliver
|
||||
# it as a normal next-turn user message.
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is not None:
|
||||
with _lock:
|
||||
if self._pending_steer:
|
||||
self._pending_steer = self._pending_steer + "\n" + steer_text
|
||||
else:
|
||||
self._pending_steer = steer_text
|
||||
else:
|
||||
existing = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = (existing + "\n" + steer_text) if existing else steer_text
|
||||
return
|
||||
marker = f"\n\n[USER STEER (injected mid-run, not tool output): {steer_text}]"
|
||||
existing_content = messages[target_idx].get("content", "")
|
||||
if not isinstance(existing_content, str):
|
||||
# Anthropic multimodal content blocks — preserve them and append
|
||||
# a text block at the end.
|
||||
try:
|
||||
blocks = list(existing_content) if existing_content else []
|
||||
blocks.append({"type": "text", "text": marker.lstrip()})
|
||||
messages[target_idx]["content"] = blocks
|
||||
except Exception:
|
||||
# Fall back to string replacement if content shape is unexpected.
|
||||
messages[target_idx]["content"] = f"{existing_content}{marker}"
|
||||
else:
|
||||
messages[target_idx]["content"] = existing_content + marker
|
||||
logger.info(
|
||||
"Delivered /steer to agent after tool batch (%d chars): %s",
|
||||
len(steer_text),
|
||||
steer_text[:120] + ("..." if len(steer_text) > 120 else ""),
|
||||
)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
|
|
@ -7951,6 +8084,13 @@ class AIAgent:
|
|||
turn_tool_msgs = messages[-num_tools:]
|
||||
enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id))
|
||||
|
||||
# ── /steer injection ──────────────────────────────────────────────
|
||||
# Append any pending user steer text to the last tool result so the
|
||||
# agent sees it on its next iteration. Runs AFTER budget enforcement
|
||||
# so the steer marker is never truncated. See steer() for details.
|
||||
if num_tools > 0:
|
||||
self._apply_pending_steer_to_tool_results(messages, num_tools)
|
||||
|
||||
def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute tool calls sequentially (original behavior). Used for single calls or interactive tools."""
|
||||
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
|
||||
|
|
@ -8330,6 +8470,12 @@ class AIAgent:
|
|||
if num_tools_seq > 0:
|
||||
enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id))
|
||||
|
||||
# ── /steer injection ──────────────────────────────────────────────
|
||||
# See _execute_tool_calls_parallel for the rationale. Same hook,
|
||||
# applied to sequential execution as well.
|
||||
if num_tools_seq > 0:
|
||||
self._apply_pending_steer_to_tool_results(messages, num_tools_seq)
|
||||
|
||||
|
||||
|
||||
def _handle_max_iterations(self, messages: list, api_call_count: int) -> str:
|
||||
|
|
@ -11610,6 +11756,12 @@ class AIAgent:
|
|||
"cost_status": self.session_cost_status,
|
||||
"cost_source": self.session_cost_source,
|
||||
}
|
||||
# If a /steer landed after the final assistant turn (no more tool
|
||||
# batches to drain into), hand it back to the caller so it can be
|
||||
# delivered as the next user turn instead of being silently lost.
|
||||
_leftover_steer = self._drain_pending_steer()
|
||||
if _leftover_steer:
|
||||
result["pending_steer"] = _leftover_steer
|
||||
self._response_was_previewed = False
|
||||
|
||||
# Include interrupt message if one triggered the interrupt
|
||||
|
|
|
|||
|
|
@ -200,6 +200,25 @@ class TestCommandBypassActiveSession:
|
|||
"/background response was not sent back to the user"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_steer_bypasses_guard(self):
|
||||
"""/steer must bypass the Level-1 active-session guard so it reaches
|
||||
the gateway runner's /steer handler and injects into the running
|
||||
agent instead of being queued as user text for the next turn.
|
||||
"""
|
||||
adapter = _make_adapter()
|
||||
sk = _session_key()
|
||||
adapter._active_sessions[sk] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(_make_event("/steer also check auth.log"))
|
||||
|
||||
assert sk not in adapter._pending_messages, (
|
||||
"/steer was queued as a pending message instead of being dispatched"
|
||||
)
|
||||
assert any("handled:steer" in r for r in adapter.sent_responses), (
|
||||
"/steer response was not sent back to the user"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_bypasses_guard(self):
|
||||
"""/help must bypass so it is not silently dropped as pending slash text."""
|
||||
|
|
|
|||
191
tests/gateway/test_steer_command.py
Normal file
191
tests/gateway/test_steer_command.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""Tests for the gateway /steer command handler.
|
||||
|
||||
/steer injects a user message into the agent's next tool result without
|
||||
interrupting. The gateway runner must:
|
||||
|
||||
1. When an agent IS running → call ``agent.steer(text)``, do NOT set
|
||||
``_interrupt_requested``, do NOT touch ``_pending_messages``.
|
||||
2. When the agent is the PENDING sentinel → fall back to /queue
|
||||
semantics (store in ``adapter._pending_messages``).
|
||||
3. When no agent is active → strip the slash prefix and let the normal
|
||||
prompt pipeline handle it as a regular user message.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=_make_source(),
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
def _make_runner(session_entry: SessionEntry):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock()
|
||||
adapter._pending_messages = {}
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = session_entry
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner.session_store.has_any_sessions.return_value = True
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
runner._should_send_voice_reply = lambda *_args, **_kwargs: False
|
||||
runner._send_voice_reply = AsyncMock()
|
||||
runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
|
||||
runner._emit_gateway_run_progress = AsyncMock()
|
||||
return runner, adapter
|
||||
|
||||
|
||||
def _session_entry() -> SessionEntry:
|
||||
return SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
total_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_steer_calls_agent_steer_and_does_not_interrupt():
|
||||
"""When an agent is running, /steer must call agent.steer(text) and
|
||||
leave interrupt state untouched."""
|
||||
runner, adapter = _make_runner(_session_entry())
|
||||
sk = build_session_key(_make_source())
|
||||
|
||||
running_agent = MagicMock()
|
||||
running_agent.steer.return_value = True
|
||||
runner._running_agents[sk] = running_agent
|
||||
|
||||
result = await runner._handle_message(_make_event("/steer also check auth.log"))
|
||||
|
||||
# The handler replied with a confirmation
|
||||
assert result is not None
|
||||
assert "steer" in result.lower() or "queued" in result.lower()
|
||||
# The agent's steer() was called with the payload (prefix stripped)
|
||||
running_agent.steer.assert_called_once_with("also check auth.log")
|
||||
# Critically: interrupt was NOT called
|
||||
running_agent.interrupt.assert_not_called()
|
||||
# And no user-text queueing happened — the steer doesn't go into
|
||||
# _pending_messages (that would be turn-boundary /queue semantics).
|
||||
assert runner._pending_messages == {}
|
||||
assert adapter._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_steer_without_payload_returns_usage():
|
||||
runner, _adapter = _make_runner(_session_entry())
|
||||
sk = build_session_key(_make_source())
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[sk] = running_agent
|
||||
|
||||
result = await runner._handle_message(_make_event("/steer"))
|
||||
|
||||
assert result is not None
|
||||
assert "Usage" in result or "usage" in result
|
||||
running_agent.steer.assert_not_called()
|
||||
running_agent.interrupt.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_steer_with_pending_sentinel_falls_back_to_queue():
|
||||
"""When the agent hasn't finished booting (sentinel), /steer should
|
||||
queue as a turn-boundary follow-up instead of crashing."""
|
||||
from gateway.run import _AGENT_PENDING_SENTINEL
|
||||
|
||||
runner, adapter = _make_runner(_session_entry())
|
||||
sk = build_session_key(_make_source())
|
||||
runner._running_agents[sk] = _AGENT_PENDING_SENTINEL
|
||||
|
||||
result = await runner._handle_message(_make_event("/steer wait up"))
|
||||
|
||||
assert result is not None
|
||||
assert "queued" in result.lower() or "starting" in result.lower()
|
||||
# The fallback put the text into the adapter's pending queue.
|
||||
assert sk in adapter._pending_messages
|
||||
assert adapter._pending_messages[sk].text == "wait up"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_steer_agent_without_steer_method_falls_back():
|
||||
"""If the running agent somehow lacks the steer() method (older build,
|
||||
test stub), the handler must not explode — fall back to /queue."""
|
||||
runner, adapter = _make_runner(_session_entry())
|
||||
sk = build_session_key(_make_source())
|
||||
|
||||
# A bare object that does NOT have steer() — use a spec'd Mock so
|
||||
# hasattr(agent, "steer") returns False.
|
||||
running_agent = MagicMock(spec=[])
|
||||
runner._running_agents[sk] = running_agent
|
||||
|
||||
result = await runner._handle_message(_make_event("/steer fallback"))
|
||||
|
||||
assert result is not None
|
||||
# Must mention queueing since steer wasn't available
|
||||
assert "queued" in result.lower()
|
||||
assert sk in adapter._pending_messages
|
||||
assert adapter._pending_messages[sk].text == "fallback"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_steer_rejected_payload_returns_rejection_message():
|
||||
"""If agent.steer() returns False (e.g. empty after strip — though
|
||||
the gateway already guards this), surface a rejection message."""
|
||||
runner, _adapter = _make_runner(_session_entry())
|
||||
sk = build_session_key(_make_source())
|
||||
|
||||
running_agent = MagicMock()
|
||||
running_agent.steer.return_value = False
|
||||
runner._running_agents[sk] = running_agent
|
||||
|
||||
result = await runner._handle_message(_make_event("/steer hello"))
|
||||
|
||||
assert result is not None
|
||||
assert "rejected" in result.lower() or "empty" in result.lower()
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
pytest.main([__file__, "-v"])
|
||||
228
tests/run_agent/test_steer.py
Normal file
228
tests/run_agent/test_steer.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""Tests for AIAgent.steer() — mid-run user message injection.
|
||||
|
||||
/steer lets the user add a note to the agent's next tool result without
|
||||
interrupting the current tool call. The agent sees the note inline with
|
||||
tool output on its next iteration, preserving message-role alternation
|
||||
and prompt-cache integrity.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
def _bare_agent() -> AIAgent:
|
||||
"""Build an AIAgent without running __init__, then install the steer
|
||||
state manually — matches the existing object.__new__ stub pattern
|
||||
used elsewhere in the test suite.
|
||||
"""
|
||||
agent = object.__new__(AIAgent)
|
||||
agent._pending_steer = None
|
||||
agent._pending_steer_lock = threading.Lock()
|
||||
return agent
|
||||
|
||||
|
||||
class TestSteerAcceptance:
|
||||
def test_accepts_non_empty_text(self):
|
||||
agent = _bare_agent()
|
||||
assert agent.steer("go ahead and check the logs") is True
|
||||
assert agent._pending_steer == "go ahead and check the logs"
|
||||
|
||||
def test_rejects_empty_string(self):
|
||||
agent = _bare_agent()
|
||||
assert agent.steer("") is False
|
||||
assert agent._pending_steer is None
|
||||
|
||||
def test_rejects_whitespace_only(self):
|
||||
agent = _bare_agent()
|
||||
assert agent.steer(" \n\t ") is False
|
||||
assert agent._pending_steer is None
|
||||
|
||||
def test_rejects_none(self):
|
||||
agent = _bare_agent()
|
||||
assert agent.steer(None) is False # type: ignore[arg-type]
|
||||
assert agent._pending_steer is None
|
||||
|
||||
def test_strips_surrounding_whitespace(self):
|
||||
agent = _bare_agent()
|
||||
assert agent.steer(" hello world \n") is True
|
||||
assert agent._pending_steer == "hello world"
|
||||
|
||||
def test_concatenates_multiple_steers_with_newlines(self):
|
||||
agent = _bare_agent()
|
||||
agent.steer("first note")
|
||||
agent.steer("second note")
|
||||
agent.steer("third note")
|
||||
assert agent._pending_steer == "first note\nsecond note\nthird note"
|
||||
|
||||
|
||||
class TestSteerDrain:
|
||||
def test_drain_returns_and_clears(self):
|
||||
agent = _bare_agent()
|
||||
agent.steer("hello")
|
||||
assert agent._drain_pending_steer() == "hello"
|
||||
assert agent._pending_steer is None
|
||||
|
||||
def test_drain_on_empty_returns_none(self):
|
||||
agent = _bare_agent()
|
||||
assert agent._drain_pending_steer() is None
|
||||
|
||||
|
||||
class TestSteerInjection:
|
||||
def test_appends_to_last_tool_result(self):
|
||||
agent = _bare_agent()
|
||||
agent.steer("please also check auth.log")
|
||||
messages = [
|
||||
{"role": "user", "content": "what's in /var/log?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "a"}, {"id": "b"}]},
|
||||
{"role": "tool", "content": "ls output A", "tool_call_id": "a"},
|
||||
{"role": "tool", "content": "ls output B", "tool_call_id": "b"},
|
||||
]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
|
||||
# The LAST tool result is modified; earlier ones are untouched.
|
||||
assert messages[2]["content"] == "ls output A"
|
||||
assert "ls output B" in messages[3]["content"]
|
||||
assert "[USER STEER" in messages[3]["content"]
|
||||
assert "please also check auth.log" in messages[3]["content"]
|
||||
# And pending_steer is consumed.
|
||||
assert agent._pending_steer is None
|
||||
|
||||
def test_no_op_when_no_steer_pending(self):
|
||||
agent = _bare_agent()
|
||||
messages = [
|
||||
{"role": "assistant", "tool_calls": [{"id": "a"}]},
|
||||
{"role": "tool", "content": "output", "tool_call_id": "a"},
|
||||
]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
|
||||
assert messages[-1]["content"] == "output" # unchanged
|
||||
|
||||
def test_no_op_when_num_tool_msgs_zero(self):
|
||||
agent = _bare_agent()
|
||||
agent.steer("steer")
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=0)
|
||||
# Steer should remain pending (nothing to drain into)
|
||||
assert agent._pending_steer == "steer"
|
||||
|
||||
def test_marker_is_unambiguous_about_origin(self):
|
||||
"""The injection marker must make clear the text is from the user
|
||||
and not tool output — this is the cache-safe way to signal
|
||||
provenance without violating message-role alternation.
|
||||
"""
|
||||
agent = _bare_agent()
|
||||
agent.steer("stop after next step")
|
||||
messages = [{"role": "tool", "content": "x", "tool_call_id": "1"}]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
|
||||
content = messages[-1]["content"]
|
||||
assert "USER STEER" in content
|
||||
assert "not tool output" in content.lower() or "injected mid-run" in content.lower()
|
||||
|
||||
def test_multimodal_content_list_preserved(self):
|
||||
"""Anthropic-style list content should be preserved, with the steer
|
||||
appended as a text block."""
|
||||
agent = _bare_agent()
|
||||
agent.steer("extra note")
|
||||
original_blocks = [{"type": "text", "text": "existing output"}]
|
||||
messages = [
|
||||
{"role": "tool", "content": list(original_blocks), "tool_call_id": "1"}
|
||||
]
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1)
|
||||
new_content = messages[-1]["content"]
|
||||
assert isinstance(new_content, list)
|
||||
assert len(new_content) == 2
|
||||
assert new_content[0] == {"type": "text", "text": "existing output"}
|
||||
assert new_content[1]["type"] == "text"
|
||||
assert "extra note" in new_content[1]["text"]
|
||||
|
||||
def test_restashed_when_no_tool_result_in_batch(self):
|
||||
"""If the 'batch' contains no tool-role messages (e.g. all skipped
|
||||
after an interrupt), the steer should be put back into the pending
|
||||
slot so the caller's fallback path can deliver it."""
|
||||
agent = _bare_agent()
|
||||
agent.steer("ping")
|
||||
messages = [
|
||||
{"role": "user", "content": "x"},
|
||||
{"role": "assistant", "content": "y"},
|
||||
]
|
||||
# Claim there were N tool msgs, but the tail has none — simulates
|
||||
# the interrupt-cancelled case.
|
||||
agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2)
|
||||
# Messages untouched
|
||||
assert messages[-1]["content"] == "y"
|
||||
# And the steer is back in pending so the fallback can grab it
|
||||
assert agent._pending_steer == "ping"
|
||||
|
||||
|
||||
class TestSteerThreadSafety:
|
||||
def test_concurrent_steer_calls_preserve_all_text(self):
|
||||
agent = _bare_agent()
|
||||
N = 200
|
||||
|
||||
def worker(idx: int) -> None:
|
||||
agent.steer(f"note-{idx}")
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
text = agent._drain_pending_steer()
|
||||
assert text is not None
|
||||
# Every single note must be preserved — none dropped by the lock.
|
||||
lines = text.split("\n")
|
||||
assert len(lines) == N
|
||||
assert set(lines) == {f"note-{i}" for i in range(N)}
|
||||
|
||||
|
||||
class TestSteerClearedOnInterrupt:
|
||||
def test_clear_interrupt_drops_pending_steer(self):
|
||||
"""A hard interrupt supersedes any pending steer — the agent's
|
||||
next tool iteration won't happen, so delivering the steer later
|
||||
would be surprising."""
|
||||
agent = _bare_agent()
|
||||
# Minimal surface needed by clear_interrupt()
|
||||
agent._interrupt_requested = True
|
||||
agent._interrupt_message = None
|
||||
agent._interrupt_thread_signal_pending = False
|
||||
agent._execution_thread_id = None
|
||||
agent._tool_worker_threads = None
|
||||
agent._tool_worker_threads_lock = None
|
||||
|
||||
agent.steer("will be dropped")
|
||||
assert agent._pending_steer == "will be dropped"
|
||||
|
||||
agent.clear_interrupt()
|
||||
assert agent._pending_steer is None
|
||||
|
||||
|
||||
class TestSteerCommandRegistry:
|
||||
def test_steer_in_command_registry(self):
|
||||
"""The /steer slash command must be registered so it reaches all
|
||||
platforms (CLI, gateway, TUI autocomplete, Telegram/Slack menus).
|
||||
"""
|
||||
from hermes_cli.commands import resolve_command, ACTIVE_SESSION_BYPASS_COMMANDS
|
||||
|
||||
cmd = resolve_command("steer")
|
||||
assert cmd is not None
|
||||
assert cmd.name == "steer"
|
||||
assert cmd.category == "Session"
|
||||
assert cmd.args_hint == "<prompt>"
|
||||
|
||||
def test_steer_in_bypass_set(self):
|
||||
"""When the agent is running, /steer MUST bypass the Level-1
|
||||
base-adapter queue so it reaches the gateway runner's /steer
|
||||
handler. Otherwise it would be queued as user text and only
|
||||
delivered at turn end — defeating the whole point.
|
||||
"""
|
||||
from hermes_cli.commands import ACTIVE_SESSION_BYPASS_COMMANDS, should_bypass_active_session
|
||||
|
||||
assert "steer" in ACTIVE_SESSION_BYPASS_COMMANDS
|
||||
assert should_bypass_active_session("steer") is True
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
pytest.main([__file__, "-v"])
|
||||
|
|
@ -438,3 +438,74 @@ def test_rollback_restore_resolves_number_and_file_path():
|
|||
assert resp["result"]["success"] is True
|
||||
assert calls["args"][1] == "bbb222"
|
||||
assert calls["args"][2] == "src/app.tsx"
|
||||
|
||||
|
||||
# ── session.steer ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_session_steer_calls_agent_steer_when_agent_supports_it():
|
||||
"""The TUI RPC method must call agent.steer(text) and return a
|
||||
queued status without touching interrupt state.
|
||||
"""
|
||||
calls = {}
|
||||
|
||||
class _Agent:
|
||||
def steer(self, text):
|
||||
calls["steer_text"] = text
|
||||
return True
|
||||
|
||||
def interrupt(self, *args, **kwargs):
|
||||
calls["interrupt_called"] = True
|
||||
|
||||
server._sessions["sid"] = _session(agent=_Agent())
|
||||
try:
|
||||
resp = server.handle_request(
|
||||
{
|
||||
"id": "1",
|
||||
"method": "session.steer",
|
||||
"params": {"session_id": "sid", "text": "also check auth.log"},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
server._sessions.pop("sid", None)
|
||||
|
||||
assert "result" in resp, resp
|
||||
assert resp["result"]["status"] == "queued"
|
||||
assert resp["result"]["text"] == "also check auth.log"
|
||||
assert calls["steer_text"] == "also check auth.log"
|
||||
assert "interrupt_called" not in calls # must NOT interrupt
|
||||
|
||||
|
||||
def test_session_steer_rejects_empty_text():
|
||||
server._sessions["sid"] = _session(agent=types.SimpleNamespace(steer=lambda t: True))
|
||||
try:
|
||||
resp = server.handle_request(
|
||||
{
|
||||
"id": "1",
|
||||
"method": "session.steer",
|
||||
"params": {"session_id": "sid", "text": " "},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
server._sessions.pop("sid", None)
|
||||
|
||||
assert "error" in resp, resp
|
||||
assert resp["error"]["code"] == 4002
|
||||
|
||||
|
||||
def test_session_steer_errors_when_agent_has_no_steer_method():
|
||||
server._sessions["sid"] = _session(agent=types.SimpleNamespace()) # no steer()
|
||||
try:
|
||||
resp = server.handle_request(
|
||||
{
|
||||
"id": "1",
|
||||
"method": "session.steer",
|
||||
"params": {"session_id": "sid", "text": "hi"},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
server._sessions.pop("sid", None)
|
||||
|
||||
assert "error" in resp, resp
|
||||
assert resp["error"]["code"] == 4010
|
||||
|
||||
|
|
|
|||
|
|
@ -1340,6 +1340,31 @@ def _(rid, params: dict) -> dict:
|
|||
return _ok(rid, {"status": "interrupted"})
|
||||
|
||||
|
||||
@method("session.steer")
|
||||
def _(rid, params: dict) -> dict:
|
||||
"""Inject a user message into the next tool result without interrupting.
|
||||
|
||||
Mirrors AIAgent.steer(). Safe to call while a turn is running — the text
|
||||
lands on the last tool result of the next tool batch and the model sees
|
||||
it on its next iteration. No interrupt, no new user turn, no role
|
||||
alternation violation.
|
||||
"""
|
||||
text = (params.get("text") or "").strip()
|
||||
if not text:
|
||||
return _err(rid, 4002, "text is required")
|
||||
session, err = _sess_nowait(params, rid)
|
||||
if err:
|
||||
return err
|
||||
agent = session.get("agent")
|
||||
if agent is None or not hasattr(agent, "steer"):
|
||||
return _err(rid, 4010, "agent does not support steer")
|
||||
try:
|
||||
accepted = agent.steer(text)
|
||||
except Exception as exc:
|
||||
return _err(rid, 5000, f"steer failed: {exc}")
|
||||
return _ok(rid, {"status": "queued" if accepted else "rejected", "text": text})
|
||||
|
||||
|
||||
@method("terminal.resize")
|
||||
def _(rid, params: dict) -> dict:
|
||||
session, err = _sess_nowait(params, rid)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { dailyFortune, randomFortune } from '../../../content/fortunes.js'
|
||||
import { HOTKEYS } from '../../../content/hotkeys.js'
|
||||
import { nextDetailsMode, parseDetailsMode } from '../../../domain/details.js'
|
||||
import type { ConfigGetValueResponse, ConfigSetResponse, SessionUndoResponse } from '../../../gatewayTypes.js'
|
||||
import type { ConfigGetValueResponse, ConfigSetResponse, SessionSteerResponse, SessionUndoResponse } from '../../../gatewayTypes.js'
|
||||
import { writeOsc52Clipboard } from '../../../lib/osc52.js'
|
||||
import type { DetailsMode, Msg, PanelSection } from '../../../types.js'
|
||||
import { patchOverlayState } from '../../overlayStore.js'
|
||||
|
|
@ -245,6 +245,36 @@ export const coreCommands: SlashCommand[] = [
|
|||
}
|
||||
},
|
||||
|
||||
{
|
||||
help: 'inject a message after the next tool call (no interrupt)',
|
||||
name: 'steer',
|
||||
run: (arg, ctx) => {
|
||||
const payload = arg?.trim() ?? ''
|
||||
|
||||
if (!payload) {
|
||||
return ctx.transcript.sys('usage: /steer <prompt>')
|
||||
}
|
||||
|
||||
// If the agent isn't running, fall back to the queue so the user's
|
||||
// message isn't lost — identical semantics to the gateway handler.
|
||||
if (!ctx.ui.busy || !ctx.sid) {
|
||||
ctx.composer.enqueue(payload)
|
||||
ctx.transcript.sys(`no active turn — queued for next: "${payload.slice(0, 50)}${payload.length > 50 ? '…' : ''}"`)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.gateway.rpc<SessionSteerResponse>('session.steer', { session_id: ctx.sid, text: payload }).then(
|
||||
ctx.guarded<SessionSteerResponse>(r => {
|
||||
if (r?.status === 'queued') {
|
||||
ctx.transcript.sys(`⏩ steer queued — arrives after next tool call: "${payload.slice(0, 50)}${payload.length > 50 ? '…' : ''}"`)
|
||||
} else {
|
||||
ctx.transcript.sys('steer rejected')
|
||||
}
|
||||
})
|
||||
).catch(ctx.guardedErr)
|
||||
}
|
||||
},
|
||||
|
||||
{
|
||||
help: 'undo last exchange',
|
||||
name: 'undo',
|
||||
|
|
|
|||
|
|
@ -152,6 +152,11 @@ export interface SessionInterruptResponse {
|
|||
ok?: boolean
|
||||
}
|
||||
|
||||
export interface SessionSteerResponse {
|
||||
status?: 'queued' | 'rejected'
|
||||
text?: string
|
||||
}
|
||||
|
||||
// ── Prompt / submission ──────────────────────────────────────────────
|
||||
|
||||
export interface PromptSubmitResponse {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue