mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
fix(gateway): wire clarify callback for messaging sessions
This commit is contained in:
parent
a521005fe5
commit
78633e58de
2 changed files with 301 additions and 2 deletions
129
gateway/run.py
129
gateway/run.py
|
|
@ -28,7 +28,7 @@ from collections import OrderedDict
|
|||
from contextvars import copy_context
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Any, List
|
||||
from typing import Dict, Optional, Any, List, Tuple
|
||||
|
||||
# --- Agent cache tuning ---------------------------------------------------
|
||||
# Bounds the per-session AIAgent cache to prevent unbounded growth in
|
||||
|
|
@ -688,6 +688,10 @@ class GatewayRunner:
|
|||
# Key: session_key, Value: True when a prompt is waiting for user input.
|
||||
self._update_prompt_pending: Dict[str, bool] = {}
|
||||
|
||||
# Track pending clarify-tool prompts per session.
|
||||
# Key: session_key, Value: {question, choices, response, event, user_id}
|
||||
self._pending_clarify: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Persistent Honcho managers keyed by gateway session key.
|
||||
# This preserves write_frequency="session" semantics across short-lived
|
||||
# per-message AIAgent instances.
|
||||
|
|
@ -1007,6 +1011,113 @@ class GatewayRunner:
|
|||
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_clarify_prompt(question: str, choices: Optional[List[str]]) -> str:
|
||||
"""Render a gateway-friendly clarify prompt for messaging platforms."""
|
||||
lines = [f"❓ {question.strip()}"]
|
||||
if choices:
|
||||
for idx, choice in enumerate(choices, start=1):
|
||||
lines.append(f"{idx}. {choice}")
|
||||
lines.append(f"{len(choices) + 1}. Other (type your answer)")
|
||||
lines.append("")
|
||||
lines.append("Reply with the number or the text of your choice.")
|
||||
else:
|
||||
lines.append("")
|
||||
lines.append("Reply with your answer.")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _coerce_clarify_response(raw: str, choices: Optional[List[str]]) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Normalize a user clarify reply into the chosen response text."""
|
||||
text = str(raw or "").strip()
|
||||
if not text:
|
||||
return None, "Please reply with your answer."
|
||||
if not choices:
|
||||
return text, None
|
||||
|
||||
if text.isdigit():
|
||||
idx = int(text)
|
||||
if 1 <= idx <= len(choices):
|
||||
return choices[idx - 1], None
|
||||
if idx == len(choices) + 1:
|
||||
return None, "Type your custom answer in a new message."
|
||||
return None, f"Please reply with 1-{len(choices) + 1}, or type your answer."
|
||||
|
||||
lowered = text.casefold()
|
||||
for choice in choices:
|
||||
if lowered == str(choice).strip().casefold():
|
||||
return choice, None
|
||||
|
||||
return text, None
|
||||
|
||||
def _build_clarify_callback(
|
||||
self,
|
||||
*,
|
||||
source: SessionSource,
|
||||
session_key: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Build a blocking clarify callback for gateway-created agents."""
|
||||
import threading
|
||||
|
||||
adapter = self.adapters.get(source.platform)
|
||||
|
||||
def _callback(question: str, choices: Optional[List[str]]) -> str:
|
||||
if not adapter:
|
||||
raise RuntimeError(f"No adapter available for platform {source.platform}")
|
||||
|
||||
entry = {
|
||||
"question": question,
|
||||
"choices": list(choices or []) or None,
|
||||
"response": None,
|
||||
"event": threading.Event(),
|
||||
"user_id": source.user_id,
|
||||
}
|
||||
self._pending_clarify[session_key] = entry
|
||||
|
||||
async def _send_prompt() -> None:
|
||||
await adapter.send(
|
||||
source.chat_id,
|
||||
self._format_clarify_prompt(question, entry["choices"]),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
fut = asyncio.run_coroutine_threadsafe(_send_prompt(), loop)
|
||||
fut.result(timeout=15)
|
||||
timeout = int(os.getenv("HERMES_CLARIFY_TIMEOUT", "600"))
|
||||
if not entry["event"].wait(timeout):
|
||||
raise TimeoutError(f"clarify timed out after {timeout}s")
|
||||
return str(entry.get("response") or "").strip()
|
||||
finally:
|
||||
self._pending_clarify.pop(session_key, None)
|
||||
|
||||
return _callback
|
||||
|
||||
async def _handle_pending_clarify(
|
||||
self,
|
||||
event: MessageEvent,
|
||||
session_key: str,
|
||||
) -> Optional[str]:
|
||||
"""Resolve a pending clarify prompt from the user's next message."""
|
||||
pending = getattr(self, "_pending_clarify", None) or {}
|
||||
entry = pending.get(session_key)
|
||||
if not entry:
|
||||
return None
|
||||
|
||||
cmd = event.get_command()
|
||||
if cmd in {"status", "stop", "restart", "new", "reset", "help", "model", "background", "queue", "approve", "deny"}:
|
||||
return "__skip__"
|
||||
|
||||
response, error_message = self._coerce_clarify_response(event.text or "", entry.get("choices"))
|
||||
if response is None:
|
||||
return error_message or "Please reply with your answer."
|
||||
|
||||
entry["response"] = response
|
||||
entry["event"].set()
|
||||
return ""
|
||||
|
||||
def _resolve_session_agent_runtime(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -3009,11 +3120,19 @@ class GatewayRunner:
|
|||
self.pairing_store._record_rate_limit(platform_name, source.user_id)
|
||||
return None
|
||||
|
||||
_quick_key = self._session_key_for_source(source)
|
||||
|
||||
# Intercept messages that are responses to a pending clarify prompt.
|
||||
_clarify_result = await self._handle_pending_clarify(event, _quick_key)
|
||||
if _clarify_result == "":
|
||||
return None
|
||||
if _clarify_result and _clarify_result != "__skip__":
|
||||
return _clarify_result
|
||||
|
||||
# Intercept messages that are responses to a pending /update prompt.
|
||||
# The update process (detached) wrote .update_prompt.json; the watcher
|
||||
# forwarded it to the user; now the user's reply goes back via
|
||||
# .update_response so the update process can continue.
|
||||
_quick_key = self._session_key_for_source(source)
|
||||
_update_prompts = getattr(self, "_update_prompt_pending", {})
|
||||
if _update_prompts.get(_quick_key):
|
||||
raw = (event.text or "").strip()
|
||||
|
|
@ -9564,6 +9683,12 @@ class GatewayRunner:
|
|||
agent.stream_delta_callback = _stream_delta_cb
|
||||
agent.interim_assistant_callback = _interim_assistant_cb if _want_interim_messages else None
|
||||
agent.status_callback = _status_callback_sync
|
||||
agent.clarify_callback = self._build_clarify_callback(
|
||||
source=source,
|
||||
session_key=session_key,
|
||||
loop=_loop_for_step,
|
||||
metadata=_status_thread_metadata,
|
||||
)
|
||||
agent.reasoning_config = reasoning_config
|
||||
agent.service_tier = self._service_tier
|
||||
agent.request_overrides = turn_route.get("request_overrides")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue