mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
fix(gateway): wire clarify callback for messaging sessions
This commit is contained in:
parent
a521005fe5
commit
78633e58de
2 changed files with 301 additions and 2 deletions
129
gateway/run.py
129
gateway/run.py
|
|
@ -28,7 +28,7 @@ from collections import OrderedDict
|
||||||
from contextvars import copy_context
|
from contextvars import copy_context
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Optional, Any, List
|
from typing import Dict, Optional, Any, List, Tuple
|
||||||
|
|
||||||
# --- Agent cache tuning ---------------------------------------------------
|
# --- Agent cache tuning ---------------------------------------------------
|
||||||
# Bounds the per-session AIAgent cache to prevent unbounded growth in
|
# Bounds the per-session AIAgent cache to prevent unbounded growth in
|
||||||
|
|
@ -688,6 +688,10 @@ class GatewayRunner:
|
||||||
# Key: session_key, Value: True when a prompt is waiting for user input.
|
# Key: session_key, Value: True when a prompt is waiting for user input.
|
||||||
self._update_prompt_pending: Dict[str, bool] = {}
|
self._update_prompt_pending: Dict[str, bool] = {}
|
||||||
|
|
||||||
|
# Track pending clarify-tool prompts per session.
|
||||||
|
# Key: session_key, Value: {question, choices, response, event, user_id}
|
||||||
|
self._pending_clarify: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
# Persistent Honcho managers keyed by gateway session key.
|
# Persistent Honcho managers keyed by gateway session key.
|
||||||
# This preserves write_frequency="session" semantics across short-lived
|
# This preserves write_frequency="session" semantics across short-lived
|
||||||
# per-message AIAgent instances.
|
# per-message AIAgent instances.
|
||||||
|
|
@ -1007,6 +1011,113 @@ class GatewayRunner:
|
||||||
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
|
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_clarify_prompt(question: str, choices: Optional[List[str]]) -> str:
|
||||||
|
"""Render a gateway-friendly clarify prompt for messaging platforms."""
|
||||||
|
lines = [f"❓ {question.strip()}"]
|
||||||
|
if choices:
|
||||||
|
for idx, choice in enumerate(choices, start=1):
|
||||||
|
lines.append(f"{idx}. {choice}")
|
||||||
|
lines.append(f"{len(choices) + 1}. Other (type your answer)")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("Reply with the number or the text of your choice.")
|
||||||
|
else:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("Reply with your answer.")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_clarify_response(raw: str, choices: Optional[List[str]]) -> Tuple[Optional[str], Optional[str]]:
|
||||||
|
"""Normalize a user clarify reply into the chosen response text."""
|
||||||
|
text = str(raw or "").strip()
|
||||||
|
if not text:
|
||||||
|
return None, "Please reply with your answer."
|
||||||
|
if not choices:
|
||||||
|
return text, None
|
||||||
|
|
||||||
|
if text.isdigit():
|
||||||
|
idx = int(text)
|
||||||
|
if 1 <= idx <= len(choices):
|
||||||
|
return choices[idx - 1], None
|
||||||
|
if idx == len(choices) + 1:
|
||||||
|
return None, "Type your custom answer in a new message."
|
||||||
|
return None, f"Please reply with 1-{len(choices) + 1}, or type your answer."
|
||||||
|
|
||||||
|
lowered = text.casefold()
|
||||||
|
for choice in choices:
|
||||||
|
if lowered == str(choice).strip().casefold():
|
||||||
|
return choice, None
|
||||||
|
|
||||||
|
return text, None
|
||||||
|
|
||||||
|
def _build_clarify_callback(
    self,
    *,
    source: SessionSource,
    session_key: str,
    loop: asyncio.AbstractEventLoop,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Create a blocking clarify callback for an agent started by the gateway.

    The returned callable takes ``(question, choices)``. It records a pending
    entry in ``self._pending_clarify`` under ``session_key``, sends the
    formatted prompt through the adapter for ``source.platform`` by
    scheduling the send on ``loop``, then blocks the calling thread until
    the entry's event is set or the wait times out.

    Args:
        source: Where the prompt is sent (``platform``, ``chat_id``) and
            whose ``user_id`` is recorded on the entry.
        session_key: Key under which the pending entry is stored.
        loop: Event loop on which the adapter send is scheduled.
        metadata: Passed through to ``adapter.send`` as ``metadata``.

    Returns:
        The callback. It returns the stripped response text, raises
        ``RuntimeError`` when no adapter exists for the platform, and raises
        ``TimeoutError`` when no answer arrives within
        ``HERMES_CLARIFY_TIMEOUT`` seconds (default 600).
    """
    import threading

    adapter = self.adapters.get(source.platform)

    def _callback(question: str, choices: Optional[List[str]]) -> str:
        if not adapter:
            raise RuntimeError(f"No adapter available for platform {source.platform}")

        # An empty choice list is stored as None so the prompt is free-form.
        offered = list(choices or []) or None
        entry = {
            "question": question,
            "choices": offered,
            "response": None,
            "event": threading.Event(),
            "user_id": source.user_id,
        }
        self._pending_clarify[session_key] = entry

        async def _send_prompt() -> None:
            prompt_text = self._format_clarify_prompt(question, entry["choices"])
            await adapter.send(source.chat_id, prompt_text, metadata=metadata)

        try:
            # Hand the send to the gateway loop and wait up to 15s for it.
            send_future = asyncio.run_coroutine_threadsafe(_send_prompt(), loop)
            send_future.result(timeout=15)

            timeout = int(os.getenv("HERMES_CLARIFY_TIMEOUT", "600"))
            answered = entry["event"].wait(timeout)
            if not answered:
                raise TimeoutError(f"clarify timed out after {timeout}s")

            reply = entry.get("response") or ""
            return str(reply).strip()
        finally:
            # Drop the pending entry on every exit path.
            self._pending_clarify.pop(session_key, None)

    return _callback
|
||||||
|
|
||||||
|
async def _handle_pending_clarify(
    self,
    event: MessageEvent,
    session_key: str,
) -> Optional[str]:
    """Try to consume an incoming message as the answer to a pending clarify.

    Args:
        event: The incoming message; its command and text are inspected.
        session_key: Key used to look up the pending entry.

    Returns:
        ``None`` when nothing is pending for ``session_key``; ``"__skip__"``
        when the message is one of the listed commands; ``""`` when the
        reply was accepted and the entry's event was set; otherwise an
        error message describing how to reply.
    """
    waiting = (getattr(self, "_pending_clarify", None) or {}).get(session_key)
    if not waiting:
        return None

    # These commands are never treated as an answer to the prompt.
    passthrough_commands = {
        "status",
        "stop",
        "restart",
        "new",
        "reset",
        "help",
        "model",
        "background",
        "queue",
        "approve",
        "deny",
    }
    if event.get_command() in passthrough_commands:
        return "__skip__"

    answer, problem = self._coerce_clarify_response(
        event.text or "", waiting.get("choices")
    )
    if answer is None:
        return problem or "Please reply with your answer."

    # Store the answer before setting the event so the waiter sees it.
    waiting["response"] = answer
    waiting["event"].set()
    return ""
|
||||||
|
|
||||||
def _resolve_session_agent_runtime(
|
def _resolve_session_agent_runtime(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|
@ -3009,11 +3120,19 @@ class GatewayRunner:
|
||||||
self.pairing_store._record_rate_limit(platform_name, source.user_id)
|
self.pairing_store._record_rate_limit(platform_name, source.user_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
_quick_key = self._session_key_for_source(source)
|
||||||
|
|
||||||
|
# Intercept messages that are responses to a pending clarify prompt.
|
||||||
|
_clarify_result = await self._handle_pending_clarify(event, _quick_key)
|
||||||
|
if _clarify_result == "":
|
||||||
|
return None
|
||||||
|
if _clarify_result and _clarify_result != "__skip__":
|
||||||
|
return _clarify_result
|
||||||
|
|
||||||
# Intercept messages that are responses to a pending /update prompt.
|
# Intercept messages that are responses to a pending /update prompt.
|
||||||
# The update process (detached) wrote .update_prompt.json; the watcher
|
# The update process (detached) wrote .update_prompt.json; the watcher
|
||||||
# forwarded it to the user; now the user's reply goes back via
|
# forwarded it to the user; now the user's reply goes back via
|
||||||
# .update_response so the update process can continue.
|
# .update_response so the update process can continue.
|
||||||
_quick_key = self._session_key_for_source(source)
|
|
||||||
_update_prompts = getattr(self, "_update_prompt_pending", {})
|
_update_prompts = getattr(self, "_update_prompt_pending", {})
|
||||||
if _update_prompts.get(_quick_key):
|
if _update_prompts.get(_quick_key):
|
||||||
raw = (event.text or "").strip()
|
raw = (event.text or "").strip()
|
||||||
|
|
@ -9564,6 +9683,12 @@ class GatewayRunner:
|
||||||
agent.stream_delta_callback = _stream_delta_cb
|
agent.stream_delta_callback = _stream_delta_cb
|
||||||
agent.interim_assistant_callback = _interim_assistant_cb if _want_interim_messages else None
|
agent.interim_assistant_callback = _interim_assistant_cb if _want_interim_messages else None
|
||||||
agent.status_callback = _status_callback_sync
|
agent.status_callback = _status_callback_sync
|
||||||
|
agent.clarify_callback = self._build_clarify_callback(
|
||||||
|
source=source,
|
||||||
|
session_key=session_key,
|
||||||
|
loop=_loop_for_step,
|
||||||
|
metadata=_status_thread_metadata,
|
||||||
|
)
|
||||||
agent.reasoning_config = reasoning_config
|
agent.reasoning_config = reasoning_config
|
||||||
agent.service_tier = self._service_tier
|
agent.service_tier = self._service_tier
|
||||||
agent.request_overrides = turn_route.get("request_overrides")
|
agent.request_overrides = turn_route.get("request_overrides")
|
||||||
|
|
|
||||||
174
tests/gateway/test_clarify_callback.py
Normal file
174
tests/gateway/test_clarify_callback.py
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gateway.config import Platform, StreamingConfig
|
||||||
|
from gateway.platforms.base import MessageEvent, MessageType
|
||||||
|
from gateway.run import GatewayRunner
|
||||||
|
from gateway.session import SessionSource
|
||||||
|
|
||||||
|
|
||||||
|
def _make_source(platform=Platform.TELEGRAM):
    """Return a SessionSource for a fixed DM chat, used as a test fixture."""
    fields = dict(
        platform=platform,
        chat_id="6493121275",
        chat_name="Test Chat",
        chat_type="dm",
        user_id="6493121275",
        user_name="Tyler",
        thread_id=None,
    )
    return SessionSource(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_gateway_clarify_callback_round_trip():
    """The callback sends the prompt, blocks, and returns the stored answer."""
    session_key = "telegram:6493121275"

    adapter = MagicMock()
    adapter.send = AsyncMock()

    runner = GatewayRunner.__new__(GatewayRunner)
    runner.adapters = {Platform.TELEGRAM: adapter}
    runner._pending_clarify = {}
    source = _make_source()

    callback = runner._build_clarify_callback(
        source=source,
        session_key=session_key,
        loop=asyncio.get_running_loop(),
        metadata=None,
    )

    outcome = {}

    def worker():
        # Runs off the event loop because the callback blocks its thread.
        outcome["result"] = callback("Pick a color", ["red", "blue"])

    thread = threading.Thread(target=worker)
    thread.start()

    async def _wait_until(ready):
        # Poll up to 20 times, sleeping 50ms between unsuccessful checks.
        for _ in range(20):
            if ready():
                return
            await asyncio.sleep(0.05)

    await _wait_until(lambda: runner._pending_clarify)
    await _wait_until(lambda: adapter.send.await_count)

    assert session_key in runner._pending_clarify
    adapter.send.assert_awaited_once()
    sent_text = adapter.send.await_args.args[1]
    for expected in ("Pick a color", "1. red", "2. blue"):
        assert expected in sent_text

    # Answer the prompt the way the gateway would and release the worker.
    pending = runner._pending_clarify[session_key]
    pending["response"] = "blue"
    pending["event"].set()

    thread.join(timeout=2)
    assert outcome["result"] == "blue"
    assert session_key not in runner._pending_clarify
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_handle_pending_clarify_consumes_numeric_reply():
    """A numeric reply resolves to the matching choice and sets the event."""
    session_key = "telegram:6493121275"
    pending_entry = {
        "question": "Pick a color",
        "choices": ["red", "blue"],
        "response": None,
        "event": threading.Event(),
        "user_id": "6493121275",
    }

    runner = GatewayRunner.__new__(GatewayRunner)
    runner._pending_clarify = {session_key: pending_entry}

    reply = MessageEvent(
        text="2",
        message_type=MessageType.TEXT,
        source=_make_source(),
    )

    result = await runner._handle_pending_clarify(reply, session_key)

    # An empty string means the message was consumed as the answer.
    assert result == ""
    stored = runner._pending_clarify[session_key]
    assert stored["response"] == "blue"
    assert stored["event"].is_set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_agent_wires_clarify_callback_to_agent(monkeypatch):
    """_run_agent installs the built clarify callback on the agent it runs."""
    runner = GatewayRunner.__new__(GatewayRunner)

    runner.adapters = {}
    runner.config = MagicMock()
    runner.config.streaming = StreamingConfig()

    # Plain state attributes set on the bare instance (no __init__ was run).
    plain_state = {
        "_running_agents": {},
        "_running_agents_ts": {},
        "_session_model_overrides": {},
        "_agent_cache": {},
        "_agent_cache_lock": None,
        "_provider_routing": {
            "only": None,
            "ignore": None,
            "order": None,
            "sort": None,
            "require_parameters": False,
            "data_collection": None,
        },
        "_fallback_model": None,
        "_prefill_messages": None,
        "_ephemeral_system_prompt": "",
        "_session_db": None,
        "_pending_clarify": {},
    }
    for attr_name, attr_value in plain_state.items():
        setattr(runner, attr_name, attr_value)

    runner.hooks = MagicMock()
    runner.hooks.loaded_hooks = False

    def _no_value():
        return None

    def _fake_runtime(**kw):
        return (
            "anthropic/claude-sonnet-4",
            {
                "api_key": "test-key",
                "base_url": "https://openrouter.ai/api/v1",
                "provider": "openrouter",
                "api_mode": "chat_completions",
            },
        )

    def _fake_turn_config(message, model, runtime):
        return {
            "model": model,
            "runtime": runtime,
            "request_overrides": None,
        }

    def _fake_clarify_builder(**kw):
        # Stand-in for the real builder: the callback always answers "blue".
        return lambda question, choices: "blue"

    runner._load_reasoning_config = _no_value
    runner._load_service_tier = _no_value
    runner._resolve_session_agent_runtime = _fake_runtime
    runner._resolve_turn_agent_config = _fake_turn_config
    runner._build_clarify_callback = _fake_clarify_builder
    runner._get_proxy_url = _no_value

    class FakeAgent:
        def __init__(self, *args, **kwargs):
            self.clarify_callback = None
            self.tools = []

        def run_conversation(self, user_message=None, **kwargs):
            # The final response is whatever the wired callback returns.
            answer = self.clarify_callback("Pick a color", ["red", "blue"])
            return {
                "final_response": answer,
                "messages": [],
                "api_calls": 1,
                "completed": True,
            }

    monkeypatch.setattr("gateway.run._load_gateway_config", lambda: {"display": {}})

    source = _make_source()
    agent_patch = patch("run_agent.AIAgent", FakeAgent)
    tools_patch = patch(
        "hermes_cli.tools_config._get_platform_tools", return_value=[]
    )
    with agent_patch, tools_patch:
        result = await runner._run_agent(
            message="hello",
            context_prompt="",
            history=[],
            source=source,
            session_id="session-1",
            session_key="telegram:6493121275",
        )

    assert result["final_response"] == "blue"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue