fix(gateway): wire clarify callback for messaging sessions

This commit is contained in:
tymrtn 2026-04-19 15:51:00 +02:00
parent a521005fe5
commit 78633e58de
2 changed files with 301 additions and 2 deletions

View file

@ -28,7 +28,7 @@ from collections import OrderedDict
from contextvars import copy_context
from pathlib import Path
from datetime import datetime
from typing import Dict, Optional, Any, List
from typing import Dict, Optional, Any, List, Tuple
# --- Agent cache tuning ---------------------------------------------------
# Bounds the per-session AIAgent cache to prevent unbounded growth in
@ -688,6 +688,10 @@ class GatewayRunner:
# Key: session_key, Value: True when a prompt is waiting for user input.
self._update_prompt_pending: Dict[str, bool] = {}
# Track pending clarify-tool prompts per session.
# Key: session_key, Value: {question, choices, response, event, user_id}
self._pending_clarify: Dict[str, Dict[str, Any]] = {}
# Persistent Honcho managers keyed by gateway session key.
# This preserves write_frequency="session" semantics across short-lived
# per-message AIAgent instances.
@ -1007,6 +1011,113 @@ class GatewayRunner:
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
)
@staticmethod
def _format_clarify_prompt(question: str, choices: Optional[List[str]]) -> str:
    """Build the plain-text clarify prompt sent to a messaging platform.

    The stripped question comes first. When ``choices`` is non-empty each
    choice is listed on its own numbered line (starting at 1), followed by
    an extra "Other" slot, a blank line and the reply instructions. Without
    choices only a blank line and a free-text instruction follow.
    """
    heading = question.strip()
    if not choices:
        return "\n".join([heading, "", "Reply with your answer."])

    numbered = [
        f"{number}. {option}" for number, option in enumerate(choices, start=1)
    ]
    # The slot right after the last real choice lets the user go free-form.
    other_slot = len(choices) + 1
    return "\n".join(
        [
            heading,
            *numbered,
            f"{other_slot}. Other (type your answer)",
            "",
            "Reply with the number or the text of your choice.",
        ]
    )
@staticmethod
def _coerce_clarify_response(raw: str, choices: Optional[List[str]]) -> Tuple[Optional[str], Optional[str]]:
    """Normalize a user clarify reply into the chosen response text.

    Args:
        raw: The user's reply text (``None``/empty is treated as blank).
        choices: The choices offered in the prompt, or ``None``/empty for a
            free-text question.

    Returns:
        A ``(response, error_message)`` pair. ``response`` is the resolved
        answer and ``error_message`` is ``None`` on success; on failure
        ``response`` is ``None`` and ``error_message`` is the text to send
        back to the user.

    Resolution order when choices exist:
        1. A number in ``1..len(choices)`` selects that choice.
        2. The number ``len(choices) + 1`` ("Other") asks for a follow-up.
        3. A case-insensitive exact match returns the matching choice.
        4. Any other number is rejected with the valid range.
        5. Anything else is accepted verbatim as a free-text answer.
    """
    text = str(raw or "").strip()
    if not text:
        return None, "Please reply with your answer."
    if not choices:
        return text, None

    other_index = len(choices) + 1

    # isdecimal() only accepts characters int() can parse; isdigit() also
    # accepts things like superscript "²", which make int() raise ValueError.
    is_number = text.isdecimal()
    if is_number:
        try:
            idx = int(text)
        except ValueError:
            # Absurdly long digit strings can exceed the interpreter's
            # int-conversion limit; they are out of range by definition.
            idx = -1
        if 1 <= idx <= len(choices):
            return choices[idx - 1], None
        if idx == other_index:
            return None, "Type your custom answer in a new message."

    # The prompt invites replying with the choice text, so try an exact
    # (case-insensitive) match before rejecting an out-of-range number:
    # a choice such as "2024" must stay selectable by typing it.
    lowered = text.casefold()
    for choice in choices:
        if lowered == str(choice).strip().casefold():
            return choice, None

    if is_number:
        return None, f"Please reply with 1-{other_index}, or type your answer."
    return text, None
def _build_clarify_callback(
    self,
    *,
    source: SessionSource,
    session_key: str,
    loop: asyncio.AbstractEventLoop,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Build a blocking clarify callback for gateway-created agents.

    The returned ``_callback(question, choices)`` registers a pending entry
    in ``self._pending_clarify`` under ``session_key``, schedules the prompt
    on ``loop`` via the platform adapter, and then blocks the calling thread
    until the entry's event is set or the timeout elapses. Because it blocks
    on a future scheduled onto ``loop``, it must be invoked from a thread
    other than the one running ``loop``.

    Args:
        source: Origin of the session; supplies ``platform``, ``chat_id``
            and ``user_id``.
        session_key: Key under which the pending prompt is registered.
        loop: Event loop on which the adapter's ``send`` coroutine runs.
        metadata: Optional metadata forwarded to ``adapter.send``.

    Returns:
        A callable ``(question, choices) -> str`` returning the stripped
        user response.

    The callback raises ``RuntimeError`` when no adapter exists for the
    platform, and ``TimeoutError`` when sending the prompt or waiting for
    the reply times out. The reply timeout comes from the
    ``HERMES_CLARIFY_TIMEOUT`` environment variable (seconds, default 600);
    an unparsable value falls back to the default.
    """
    import threading

    default_timeout = 600
    send_timeout = 15
    adapter = self.adapters.get(source.platform)

    def _callback(question: str, choices: Optional[List[str]]) -> str:
        if not adapter:
            raise RuntimeError(f"No adapter available for platform {source.platform}")

        entry = {
            "question": question,
            "choices": list(choices or []) or None,
            "response": None,
            "event": threading.Event(),
            "user_id": source.user_id,
        }
        # Register before sending so a fast reply cannot slip past the
        # pending-prompt interception.
        self._pending_clarify[session_key] = entry

        async def _send_prompt() -> None:
            await adapter.send(
                source.chat_id,
                self._format_clarify_prompt(question, entry["choices"]),
                metadata=metadata,
            )

        try:
            fut = asyncio.run_coroutine_threadsafe(_send_prompt(), loop)
            try:
                fut.result(timeout=send_timeout)
            finally:
                # If we gave up waiting, make sure the prompt is not
                # delivered later when nobody is listening for the reply.
                if not fut.done():
                    fut.cancel()

            try:
                timeout = int(os.getenv("HERMES_CLARIFY_TIMEOUT", str(default_timeout)))
            except ValueError:
                # A malformed setting must not abort an already-sent prompt.
                timeout = default_timeout

            if not entry["event"].wait(timeout):
                raise TimeoutError(f"clarify timed out after {timeout}s")
            return str(entry.get("response") or "").strip()
        finally:
            # Only remove the entry this call registered; a newer prompt for
            # the same session must stay pending.
            if self._pending_clarify.get(session_key) is entry:
                self._pending_clarify.pop(session_key, None)

    return _callback
async def _handle_pending_clarify(
    self,
    event: MessageEvent,
    session_key: str,
) -> Optional[str]:
    """Try to consume ``event`` as the answer to a pending clarify prompt.

    Return values:
        ``None``       -- no clarify prompt is pending for ``session_key``.
        ``"__skip__"`` -- the message is one of the listed gateway commands
                          and is deliberately not treated as an answer.
        ``""``         -- the answer was stored on the pending entry and its
                          event was set.
        other text     -- the reply could not be resolved; the text is the
                          error message produced for the user.
    """
    registry = getattr(self, "_pending_clarify", None) or {}
    waiting = registry.get(session_key)
    if not waiting:
        return None

    # These commands are never interpreted as clarify answers.
    passthrough_commands = {
        "status",
        "stop",
        "restart",
        "new",
        "reset",
        "help",
        "model",
        "background",
        "queue",
        "approve",
        "deny",
    }
    if event.get_command() in passthrough_commands:
        return "__skip__"

    answer, problem = self._coerce_clarify_response(
        event.text or "",
        waiting.get("choices"),
    )
    if answer is not None:
        # Publish the answer first, then signal the event for the waiter.
        waiting["response"] = answer
        waiting["event"].set()
        return ""
    return problem or "Please reply with your answer."
def _resolve_session_agent_runtime(
self,
*,
@ -3009,11 +3120,19 @@ class GatewayRunner:
self.pairing_store._record_rate_limit(platform_name, source.user_id)
return None
_quick_key = self._session_key_for_source(source)
# Intercept messages that are responses to a pending clarify prompt.
_clarify_result = await self._handle_pending_clarify(event, _quick_key)
if _clarify_result == "":
return None
if _clarify_result and _clarify_result != "__skip__":
return _clarify_result
# Intercept messages that are responses to a pending /update prompt.
# The update process (detached) wrote .update_prompt.json; the watcher
# forwarded it to the user; now the user's reply goes back via
# .update_response so the update process can continue.
_quick_key = self._session_key_for_source(source)
_update_prompts = getattr(self, "_update_prompt_pending", {})
if _update_prompts.get(_quick_key):
raw = (event.text or "").strip()
@ -9564,6 +9683,12 @@ class GatewayRunner:
agent.stream_delta_callback = _stream_delta_cb
agent.interim_assistant_callback = _interim_assistant_cb if _want_interim_messages else None
agent.status_callback = _status_callback_sync
agent.clarify_callback = self._build_clarify_callback(
source=source,
session_key=session_key,
loop=_loop_for_step,
metadata=_status_thread_metadata,
)
agent.reasoning_config = reasoning_config
agent.service_tier = self._service_tier
agent.request_overrides = turn_route.get("request_overrides")