fix(gateway): wire clarify callback for messaging sessions

This commit is contained in:
tymrtn 2026-04-19 15:51:00 +02:00
parent a521005fe5
commit 78633e58de
2 changed files with 301 additions and 2 deletions

View file

@@ -28,7 +28,7 @@ from collections import OrderedDict
from contextvars import copy_context from contextvars import copy_context
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
from typing import Dict, Optional, Any, List from typing import Dict, Optional, Any, List, Tuple
# --- Agent cache tuning --------------------------------------------------- # --- Agent cache tuning ---------------------------------------------------
# Bounds the per-session AIAgent cache to prevent unbounded growth in # Bounds the per-session AIAgent cache to prevent unbounded growth in
@@ -688,6 +688,10 @@ class GatewayRunner:
# Key: session_key, Value: True when a prompt is waiting for user input. # Key: session_key, Value: True when a prompt is waiting for user input.
self._update_prompt_pending: Dict[str, bool] = {} self._update_prompt_pending: Dict[str, bool] = {}
# Track pending clarify-tool prompts per session.
# Key: session_key, Value: {question, choices, response, event, user_id}
self._pending_clarify: Dict[str, Dict[str, Any]] = {}
# Persistent Honcho managers keyed by gateway session key. # Persistent Honcho managers keyed by gateway session key.
# This preserves write_frequency="session" semantics across short-lived # This preserves write_frequency="session" semantics across short-lived
# per-message AIAgent instances. # per-message AIAgent instances.
@@ -1007,6 +1011,113 @@ class GatewayRunner:
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False), thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
) )
@staticmethod
def _format_clarify_prompt(question: str, choices: Optional[List[str]]) -> str:
    """Build the text of a clarify prompt for a messaging platform.

    Args:
        question: The question to ask; surrounding whitespace is stripped.
        choices: Options to offer. When non-empty they are numbered from 1
            and followed by an extra "Other" entry; when ``None`` or empty
            the prompt simply asks for a free-form answer.

    Returns:
        The prompt as a single newline-joined string.
    """
    header = question.strip()
    if not choices:
        return "\n".join([header, "", "Reply with your answer."])

    numbered = [f"{number}. {label}" for number, label in enumerate(choices, start=1)]
    other_entry = f"{len(choices) + 1}. Other (type your answer)"
    return "\n".join(
        [
            header,
            *numbered,
            other_entry,
            "",
            "Reply with the number or the text of your choice.",
        ]
    )
@staticmethod
def _coerce_clarify_response(raw: str, choices: Optional[List[str]]) -> Tuple[Optional[str], Optional[str]]:
    """Normalize a user clarify reply into the chosen response text.

    Args:
        raw: Text of the user's reply; ``None``/empty is treated as blank.
        choices: Options that were offered, or ``None``/empty for a
            free-form question.

    Returns:
        A ``(response, error_message)`` pair:

        * ``(text, None)`` when the reply is accepted. A numeric reply in
          ``1..len(choices)`` maps to that choice, a reply equal to a
          choice (ignoring case and surrounding whitespace) maps to the
          choice as originally spelled, and any other text is returned
          as a custom answer.
        * ``(None, message)`` when the reply cannot be used yet (blank
          reply, the "Other" number, or a number outside the offered
          range); ``message`` is the hint to show the user.
    """
    text = str(raw or "").strip()
    if not text:
        return None, "Please reply with your answer."
    if not choices:
        return text, None

    if text.isdigit():
        other_index = len(choices) + 1
        out_of_range = f"Please reply with 1-{other_index}, or type your answer."
        try:
            idx = int(text)
        except ValueError:
            # str.isdigit() is true for characters int() cannot parse
            # (e.g. superscript "²"), and int() may refuse very long digit
            # strings. The reply is user-controlled, so answer with the
            # range hint instead of letting the exception escape.
            return None, out_of_range
        if 1 <= idx <= len(choices):
            return choices[idx - 1], None
        if idx == other_index:
            return None, "Type your custom answer in a new message."
        return None, out_of_range

    lowered = text.casefold()
    for choice in choices:
        if lowered == str(choice).strip().casefold():
            return choice, None
    # Anything else is taken verbatim as a custom answer.
    return text, None
def _build_clarify_callback(
    self,
    *,
    source: SessionSource,
    session_key: str,
    loop: asyncio.AbstractEventLoop,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Build a blocking clarify callback for gateway-created agents.

    The returned ``_callback(question, choices)`` registers a pending entry
    in ``self._pending_clarify[session_key]``, sends the formatted prompt
    through the platform adapter on ``loop``, and then blocks the calling
    thread until the entry's ``threading.Event`` is set or the timeout
    expires. Because it blocks on work scheduled onto ``loop``, it has to
    be invoked from a thread other than the one running ``loop``.

    Args:
        source: Origin of the session; supplies ``platform``, ``chat_id``
            and ``user_id``.
        session_key: Key under which the pending prompt is registered.
        loop: Event loop on which the adapter's ``send`` coroutine runs.
        metadata: Passed through to ``adapter.send`` unchanged.

    Returns:
        A callable ``(question, choices) -> str`` yielding the stripped
        response text recorded on the entry (``""`` if none was recorded).

    Raises (from the returned callback):
        RuntimeError: No adapter is registered for ``source.platform``.
        TimeoutError: No reply arrived within ``HERMES_CLARIFY_TIMEOUT``
            seconds (default 600).
        Exception: Whatever sending the prompt raised, including the
            15-second send timeout.
    """
    import threading

    adapter = self.adapters.get(source.platform)

    def _callback(question: str, choices: Optional[List[str]]) -> str:
        if not adapter:
            raise RuntimeError(f"No adapter available for platform {source.platform}")

        entry = {
            "question": question,
            "choices": list(choices or []) or None,
            "response": None,
            "event": threading.Event(),
            "user_id": source.user_id,
        }
        # Register before sending so a fast reply cannot arrive while no
        # entry exists for the session.
        self._pending_clarify[session_key] = entry

        async def _send_prompt() -> None:
            await adapter.send(
                source.chat_id,
                self._format_clarify_prompt(question, entry["choices"]),
                metadata=metadata,
            )

        try:
            fut = asyncio.run_coroutine_threadsafe(_send_prompt(), loop)
            try:
                fut.result(timeout=15)
            except Exception:
                # Do not leave the send running after giving up on it: a
                # late delivery would show a prompt nobody is waiting on.
                # cancel() is a no-op when the future already finished.
                fut.cancel()
                raise

            timeout = int(os.getenv("HERMES_CLARIFY_TIMEOUT", "600"))
            if not entry["event"].wait(timeout):
                raise TimeoutError(f"clarify timed out after {timeout}s")
            return str(entry.get("response") or "").strip()
        finally:
            # Remove only our own registration; an overlapping callback for
            # the same session may have replaced it with a newer entry.
            if self._pending_clarify.get(session_key) is entry:
                self._pending_clarify.pop(session_key, None)

    return _callback
async def _handle_pending_clarify(
    self,
    event: MessageEvent,
    session_key: str,
) -> Optional[str]:
    """Resolve a pending clarify prompt from the user's next message.

    Args:
        event: Incoming message; its command (``event.get_command()``) and
            ``event.text`` are inspected.
        session_key: Session whose pending clarify entry should be checked.

    Returns:
        * ``None`` - no clarify prompt is waiting for this session (or it
          has already been answered); handle the message normally.
        * ``"__skip__"`` - the message is one of the gateway commands
          listed below and must not be consumed as an answer.
        * ``""`` - the message was accepted as the answer and the waiting
          callback has been released.
        * any other string - a hint to send back to the user; the prompt
          stays pending.
    """
    pending = getattr(self, "_pending_clarify", None) or {}
    entry = pending.get(session_key)
    if not entry:
        return None

    # The answer is already recorded and the waiting thread just has not
    # removed the entry yet. Consuming this message too would overwrite the
    # recorded response and swallow the message, so let it flow normally.
    if entry["event"].is_set():
        return None

    # NOTE(review): entry["user_id"] is stored with the prompt but is not
    # compared with the sender here - confirm whether replies from other
    # users sharing this session key are meant to be accepted.
    cmd = event.get_command()
    if cmd in {"status", "stop", "restart", "new", "reset", "help", "model", "background", "queue", "approve", "deny"}:
        return "__skip__"

    response, error_message = self._coerce_clarify_response(event.text or "", entry.get("choices"))
    if response is None:
        return error_message or "Please reply with your answer."

    # Publish the answer before signalling so the waiter always reads it.
    entry["response"] = response
    entry["event"].set()
    return ""
def _resolve_session_agent_runtime( def _resolve_session_agent_runtime(
self, self,
*, *,
@@ -3009,11 +3120,19 @@ class GatewayRunner:
self.pairing_store._record_rate_limit(platform_name, source.user_id) self.pairing_store._record_rate_limit(platform_name, source.user_id)
return None return None
_quick_key = self._session_key_for_source(source)
# Intercept messages that are responses to a pending clarify prompt.
_clarify_result = await self._handle_pending_clarify(event, _quick_key)
if _clarify_result == "":
return None
if _clarify_result and _clarify_result != "__skip__":
return _clarify_result
# Intercept messages that are responses to a pending /update prompt. # Intercept messages that are responses to a pending /update prompt.
# The update process (detached) wrote .update_prompt.json; the watcher # The update process (detached) wrote .update_prompt.json; the watcher
# forwarded it to the user; now the user's reply goes back via # forwarded it to the user; now the user's reply goes back via
# .update_response so the update process can continue. # .update_response so the update process can continue.
_quick_key = self._session_key_for_source(source)
_update_prompts = getattr(self, "_update_prompt_pending", {}) _update_prompts = getattr(self, "_update_prompt_pending", {})
if _update_prompts.get(_quick_key): if _update_prompts.get(_quick_key):
raw = (event.text or "").strip() raw = (event.text or "").strip()
@@ -9564,6 +9683,12 @@ class GatewayRunner:
agent.stream_delta_callback = _stream_delta_cb agent.stream_delta_callback = _stream_delta_cb
agent.interim_assistant_callback = _interim_assistant_cb if _want_interim_messages else None agent.interim_assistant_callback = _interim_assistant_cb if _want_interim_messages else None
agent.status_callback = _status_callback_sync agent.status_callback = _status_callback_sync
agent.clarify_callback = self._build_clarify_callback(
source=source,
session_key=session_key,
loop=_loop_for_step,
metadata=_status_thread_metadata,
)
agent.reasoning_config = reasoning_config agent.reasoning_config = reasoning_config
agent.service_tier = self._service_tier agent.service_tier = self._service_tier
agent.request_overrides = turn_route.get("request_overrides") agent.request_overrides = turn_route.get("request_overrides")

View file

@@ -0,0 +1,174 @@
import asyncio
import threading
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, StreamingConfig
from gateway.platforms.base import MessageEvent, MessageType
from gateway.run import GatewayRunner
from gateway.session import SessionSource
def _make_source(platform=Platform.TELEGRAM):
    """Return the direct-message ``SessionSource`` shared by these tests."""
    fields = {
        "platform": platform,
        "chat_id": "6493121275",
        "chat_name": "Test Chat",
        "chat_type": "dm",
        "user_id": "6493121275",
        "user_name": "Tyler",
        "thread_id": None,
    }
    return SessionSource(**fields)
@pytest.mark.asyncio
async def test_gateway_clarify_callback_round_trip():
    """The callback sends the prompt, blocks, and returns the recorded answer.

    The callback runs in a worker thread (it blocks), while this coroutine
    keeps the event loop turning so the scheduled ``adapter.send`` can run.
    """
    session_key = "telegram:6493121275"

    runner = GatewayRunner.__new__(GatewayRunner)  # bypass __init__
    adapter = MagicMock()
    adapter.send = AsyncMock()
    runner.adapters = {Platform.TELEGRAM: adapter}
    runner._pending_clarify = {}

    source = _make_source()
    callback = runner._build_clarify_callback(
        source=source,
        session_key=session_key,
        loop=asyncio.get_running_loop(),
        metadata=None,
    )

    result_box = {}

    def worker():
        result_box["result"] = callback("Pick a color", ["red", "blue"])

    # Daemon thread: if an assertion below fails before the event is set,
    # the blocked callback must not keep the test process alive until its
    # clarify timeout elapses.
    thread = threading.Thread(target=worker, daemon=True)
    thread.start()

    # Poll (up to ~1s each) for the pending entry and for the prompt send.
    for _ in range(20):
        if runner._pending_clarify:
            break
        await asyncio.sleep(0.05)

    for _ in range(20):
        if adapter.send.await_count:
            break
        await asyncio.sleep(0.05)

    assert session_key in runner._pending_clarify
    adapter.send.assert_awaited_once()
    sent_text = adapter.send.await_args.args[1]
    assert "Pick a color" in sent_text
    assert "1. red" in sent_text
    assert "2. blue" in sent_text

    # Answer the prompt the way the message handler would.
    entry = runner._pending_clarify[session_key]
    entry["response"] = "blue"
    entry["event"].set()

    thread.join(timeout=2)
    assert not thread.is_alive(), "clarify callback did not return after the event was set"
    assert result_box["result"] == "blue"
    assert session_key not in runner._pending_clarify
@pytest.mark.asyncio
async def test_handle_pending_clarify_consumes_numeric_reply():
    """A numeric reply picks the matching choice and releases the waiter."""
    session_key = "telegram:6493121275"
    pending_entry = {
        "question": "Pick a color",
        "choices": ["red", "blue"],
        "response": None,
        "event": threading.Event(),
        "user_id": "6493121275",
    }

    runner = GatewayRunner.__new__(GatewayRunner)  # bypass __init__
    runner._pending_clarify = {session_key: pending_entry}

    reply = MessageEvent(
        text="2",
        message_type=MessageType.TEXT,
        source=_make_source(),
    )

    outcome = await runner._handle_pending_clarify(reply, session_key)

    # "" signals that the message was consumed as the clarify answer.
    assert outcome == ""
    stored = runner._pending_clarify[session_key]
    assert stored["response"] == "blue"
    assert stored["event"].is_set()
@pytest.mark.asyncio
async def test_run_agent_wires_clarify_callback_to_agent(monkeypatch):
    """``_run_agent`` installs the runner-built clarify callback on the agent.

    The builder is stubbed to return a callback that always answers
    "blue"; the fake agent echoes its callback's answer as the final
    response, so seeing "blue" in the result proves the callback was wired.
    """

    class FakeAgent:
        """Agent stand-in whose reply is whatever its clarify callback returns."""

        def __init__(self, *args, **kwargs):
            self.clarify_callback = None
            self.tools = []

        def run_conversation(self, user_message=None, **kwargs):
            answer = self.clarify_callback("Pick a color", ["red", "blue"])
            return {
                "final_response": answer,
                "messages": [],
                "api_calls": 1,
                "completed": True,
            }

    def _fake_runtime(**kw):
        return (
            "anthropic/claude-sonnet-4",
            {
                "api_key": "test-key",
                "base_url": "https://openrouter.ai/api/v1",
                "provider": "openrouter",
                "api_mode": "chat_completions",
            },
        )

    def _fake_turn_config(message, model, runtime):
        return {
            "model": model,
            "runtime": runtime,
            "request_overrides": None,
        }

    def _fake_clarify_builder(**kw):
        def _always_blue(question, choices):
            return "blue"

        return _always_blue

    def _nothing():
        return None

    runner = GatewayRunner.__new__(GatewayRunner)  # bypass __init__

    # Plain state the run path reads, set in the same order as before.
    runner.adapters = {}
    runner.config = MagicMock()
    runner.config.streaming = StreamingConfig()
    plain_state = {
        "_running_agents": {},
        "_running_agents_ts": {},
        "_session_model_overrides": {},
        "_agent_cache": {},
        "_agent_cache_lock": None,
        "_provider_routing": {
            "only": None,
            "ignore": None,
            "order": None,
            "sort": None,
            "require_parameters": False,
            "data_collection": None,
        },
        "_fallback_model": None,
        "_prefill_messages": None,
        "_ephemeral_system_prompt": "",
        "_session_db": None,
        "_pending_clarify": {},
    }
    for attr_name, attr_value in plain_state.items():
        setattr(runner, attr_name, attr_value)

    runner.hooks = MagicMock()
    runner.hooks.loaded_hooks = False

    # Collaborator methods replaced with canned stubs.
    runner._load_reasoning_config = _nothing
    runner._load_service_tier = _nothing
    runner._resolve_session_agent_runtime = _fake_runtime
    runner._resolve_turn_agent_config = _fake_turn_config
    runner._build_clarify_callback = _fake_clarify_builder
    runner._get_proxy_url = _nothing

    monkeypatch.setattr("gateway.run._load_gateway_config", lambda: {"display": {}})

    source = _make_source()
    agent_patch = patch("run_agent.AIAgent", FakeAgent)
    tools_patch = patch("hermes_cli.tools_config._get_platform_tools", return_value=[])
    with agent_patch, tools_patch:
        result = await runner._run_agent(
            message="hello",
            context_prompt="",
            history=[],
            source=source,
            session_id="session-1",
            session_key="telegram:6493121275",
        )

    assert result["final_response"] == "blue"