From 78633e58decaa3b207adaeb318f87fbfbb580f70 Mon Sep 17 00:00:00 2001 From: tymrtn Date: Sun, 19 Apr 2026 15:51:00 +0200 Subject: [PATCH] fix(gateway): wire clarify callback for messaging sessions --- gateway/run.py | 129 +++++++++++++++++- tests/gateway/test_clarify_callback.py | 174 +++++++++++++++++++++++++ 2 files changed, 301 insertions(+), 2 deletions(-) create mode 100644 tests/gateway/test_clarify_callback.py diff --git a/gateway/run.py b/gateway/run.py index 60c57495b44..c5d358b18bc 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -28,7 +28,7 @@ from collections import OrderedDict from contextvars import copy_context from pathlib import Path from datetime import datetime -from typing import Dict, Optional, Any, List +from typing import Dict, Optional, Any, List, Tuple # --- Agent cache tuning --------------------------------------------------- # Bounds the per-session AIAgent cache to prevent unbounded growth in @@ -688,6 +688,10 @@ class GatewayRunner: # Key: session_key, Value: True when a prompt is waiting for user input. self._update_prompt_pending: Dict[str, bool] = {} + # Track pending clarify-tool prompts per session. + # Key: session_key, Value: {question, choices, response, event, user_id} + self._pending_clarify: Dict[str, Dict[str, Any]] = {} + # Persistent Honcho managers keyed by gateway session key. # This preserves write_frequency="session" semantics across short-lived # per-message AIAgent instances. @@ -1007,6 +1011,113 @@ class GatewayRunner: thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False), ) + @staticmethod + def _format_clarify_prompt(question: str, choices: Optional[List[str]]) -> str: + """Render a gateway-friendly clarify prompt for messaging platforms.""" + lines = [f"❓ {question.strip()}"] + if choices: + for idx, choice in enumerate(choices, start=1): + lines.append(f"{idx}. {choice}") + lines.append(f"{len(choices) + 1}. Other (type your answer)") + lines.append("") + lines.append("Reply with the number or the text of your choice.") + else: + lines.append("") + lines.append("Reply with your answer.") + return "\n".join(lines) + + @staticmethod + def _coerce_clarify_response(raw: str, choices: Optional[List[str]]) -> Tuple[Optional[str], Optional[str]]: + """Normalize a user clarify reply into the chosen response text.""" + text = str(raw or "").strip() + if not text: + return None, "Please reply with your answer." + if not choices: + return text, None + + if text.isdigit(): + idx = int(text) + if 1 <= idx <= len(choices): + return choices[idx - 1], None + if idx == len(choices) + 1: + return None, "Type your custom answer in a new message." + return None, f"Please reply with 1-{len(choices) + 1}, or type your answer." + + lowered = text.casefold() + for choice in choices: + if lowered == str(choice).strip().casefold(): + return choice, None + + return text, None + + def _build_clarify_callback( + self, + *, + source: SessionSource, + session_key: str, + loop: asyncio.AbstractEventLoop, + metadata: Optional[Dict[str, Any]] = None, + ): + """Build a blocking clarify callback for gateway-created agents.""" + import threading + + adapter = self.adapters.get(source.platform) + + def _callback(question: str, choices: Optional[List[str]]) -> str: + if not adapter: + raise RuntimeError(f"No adapter available for platform {source.platform}") + + entry = { + "question": question, + "choices": list(choices or []) or None, + "response": None, + "event": threading.Event(), + "user_id": source.user_id, + } + self._pending_clarify[session_key] = entry + + async def _send_prompt() -> None: + await adapter.send( + source.chat_id, + self._format_clarify_prompt(question, entry["choices"]), + metadata=metadata, + ) + + try: + fut = asyncio.run_coroutine_threadsafe(_send_prompt(), loop) + fut.result(timeout=15) + timeout = int(os.getenv("HERMES_CLARIFY_TIMEOUT", "600")) + if not entry["event"].wait(timeout): + raise TimeoutError(f"clarify timed out after {timeout}s") + return str(entry.get("response") or "").strip() + finally: + self._pending_clarify.pop(session_key, None) + + return _callback + + async def _handle_pending_clarify( + self, + event: MessageEvent, + session_key: str, + ) -> Optional[str]: + """Resolve a pending clarify prompt from the user's next message.""" + pending = getattr(self, "_pending_clarify", None) or {} + entry = pending.get(session_key) + if not entry: + return None + + cmd = event.get_command() + if cmd in {"status", "stop", "restart", "new", "reset", "help", "model", "background", "queue", "approve", "deny"}: + return "__skip__" + + response, error_message = self._coerce_clarify_response(event.text or "", entry.get("choices")) + if response is None: + return error_message or "Please reply with your answer." + + entry["response"] = response + entry["event"].set() + return "" + def _resolve_session_agent_runtime( self, *, @@ -3009,11 +3120,19 @@ class GatewayRunner: self.pairing_store._record_rate_limit(platform_name, source.user_id) return None + _quick_key = self._session_key_for_source(source) + + # Intercept messages that are responses to a pending clarify prompt. + _clarify_result = await self._handle_pending_clarify(event, _quick_key) + if _clarify_result == "": + return None + if _clarify_result and _clarify_result != "__skip__": + return _clarify_result + # Intercept messages that are responses to a pending /update prompt. # The update process (detached) wrote .update_prompt.json; the watcher # forwarded it to the user; now the user's reply goes back via # .update_response so the update process can continue. - _quick_key = self._session_key_for_source(source) _update_prompts = getattr(self, "_update_prompt_pending", {}) if _update_prompts.get(_quick_key): raw = (event.text or "").strip() @@ -9564,6 +9683,12 @@ class GatewayRunner: agent.stream_delta_callback = _stream_delta_cb agent.interim_assistant_callback = _interim_assistant_cb if _want_interim_messages else None agent.status_callback = _status_callback_sync + agent.clarify_callback = self._build_clarify_callback( + source=source, + session_key=session_key, + loop=_loop_for_step, + metadata=_status_thread_metadata, + ) agent.reasoning_config = reasoning_config agent.service_tier = self._service_tier agent.request_overrides = turn_route.get("request_overrides") diff --git a/tests/gateway/test_clarify_callback.py b/tests/gateway/test_clarify_callback.py new file mode 100644 index 00000000000..cc0b585caa5 --- /dev/null +++ b/tests/gateway/test_clarify_callback.py @@ -0,0 +1,174 @@ +import asyncio +import threading +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import Platform, StreamingConfig +from gateway.platforms.base import MessageEvent, MessageType +from gateway.run import GatewayRunner +from gateway.session import SessionSource + + +def _make_source(platform=Platform.TELEGRAM): + return SessionSource( + platform=platform, + chat_id="6493121275", + chat_name="Test Chat", + chat_type="dm", + user_id="6493121275", + user_name="Tyler", + thread_id=None, + ) + + +@pytest.mark.asyncio +async def test_gateway_clarify_callback_round_trip(): + runner = GatewayRunner.__new__(GatewayRunner) + adapter = MagicMock() + adapter.send = AsyncMock() + runner.adapters = {Platform.TELEGRAM: adapter} + runner._pending_clarify = {} + source = _make_source() + + callback = runner._build_clarify_callback( + source=source, + session_key="telegram:6493121275", + loop=asyncio.get_running_loop(), + metadata=None, + ) + + result_box = {} + + def worker(): + result_box["result"] = callback("Pick a color", ["red", "blue"]) + + thread = threading.Thread(target=worker) + thread.start() + + for _ in range(20): + if runner._pending_clarify: + break + await asyncio.sleep(0.05) + + for _ in range(20): + if adapter.send.await_count: + break + await asyncio.sleep(0.05) + + assert "telegram:6493121275" in runner._pending_clarify + adapter.send.assert_awaited_once() + sent_text = adapter.send.await_args.args[1] + assert "Pick a color" in sent_text + assert "1. red" in sent_text + assert "2. blue" in sent_text + + entry = runner._pending_clarify["telegram:6493121275"] + entry["response"] = "blue" + entry["event"].set() + + thread.join(timeout=2) + assert result_box["result"] == "blue" + assert "telegram:6493121275" not in runner._pending_clarify + + +@pytest.mark.asyncio +async def test_handle_pending_clarify_consumes_numeric_reply(): + runner = GatewayRunner.__new__(GatewayRunner) + runner._pending_clarify = { + "telegram:6493121275": { + "question": "Pick a color", + "choices": ["red", "blue"], + "response": None, + "event": threading.Event(), + "user_id": "6493121275", + } + } + event = MessageEvent( + text="2", + message_type=MessageType.TEXT, + source=_make_source(), + ) + + result = await runner._handle_pending_clarify(event, "telegram:6493121275") + + assert result == "" + entry = runner._pending_clarify["telegram:6493121275"] + assert entry["response"] == "blue" + assert entry["event"].is_set() + + +@pytest.mark.asyncio +async def test_run_agent_wires_clarify_callback_to_agent(monkeypatch): + runner = GatewayRunner.__new__(GatewayRunner) + runner.adapters = {} + runner.config = MagicMock() + runner.config.streaming = StreamingConfig() + runner._running_agents = {} + runner._running_agents_ts = {} + runner._session_model_overrides = {} + runner._agent_cache = {} + runner._agent_cache_lock = None + runner._provider_routing = { + "only": None, + "ignore": None, + "order": None, + "sort": None, + "require_parameters": False, + "data_collection": None, + } + runner._fallback_model = None + runner._prefill_messages = None + runner._ephemeral_system_prompt = "" + runner._session_db = None + runner._pending_clarify = {} + runner.hooks = MagicMock() + runner.hooks.loaded_hooks = False + runner._load_reasoning_config = lambda: None + runner._load_service_tier = lambda: None + runner._resolve_session_agent_runtime = lambda **kw: ( + "anthropic/claude-sonnet-4", + { + "api_key": "test-key", + "base_url": "https://openrouter.ai/api/v1", + "provider": "openrouter", + "api_mode": "chat_completions", + }, + ) + runner._resolve_turn_agent_config = lambda message, model, runtime: { + "model": model, + "runtime": runtime, + "request_overrides": None, + } + runner._build_clarify_callback = lambda **kw: (lambda question, choices: "blue") + runner._get_proxy_url = lambda: None + + class FakeAgent: + def __init__(self, *args, **kwargs): + self.clarify_callback = None + self.tools = [] + + def run_conversation(self, user_message=None, **kwargs): + return { + "final_response": self.clarify_callback("Pick a color", ["red", "blue"]), + "messages": [], + "api_calls": 1, + "completed": True, + } + + monkeypatch.setattr("gateway.run._load_gateway_config", lambda: {"display": {}}) + + source = _make_source() + with patch("run_agent.AIAgent", FakeAgent), patch( + "hermes_cli.tools_config._get_platform_tools", return_value=[] + ): + result = await runner._run_agent( + message="hello", + context_prompt="", + history=[], + source=source, + session_id="session-1", + session_key="telegram:6493121275", + ) + + assert result["final_response"] == "blue"