def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
    """Consume and return the full pending event queued for a session.

    The whole event is handed back (not just its text) so a queued follow-up
    keeps its media metadata and can re-enter the normal image/STT/document
    preprocessing path instead of being reduced to a placeholder string.

    Args:
        adapter: Platform adapter exposing ``get_pending_message``.
        session_key: Key under which the adapter stores pending messages.

    Returns:
        Whatever ``adapter.get_pending_message(session_key)`` returns: the
        pending event, or a falsy value when nothing is queued.
    """
    return adapter.get_pending_message(session_key)


# Method of GatewayRunner; written at module level here because the class
# header lies outside this span.
async def _prepare_inbound_message_text(
    self,
    *,
    event: MessageEvent,
    source: SessionSource,
    history: List[Dict[str, Any]],
) -> Optional[str]:
    """Prepare inbound event text for the agent.

    Keeps the normal inbound path and the queued follow-up path on the same
    preprocessing pipeline. Steps, in order:

    1. Prefix ``[user_name]`` for shared (non-DM) thread sessions.
    2. Enrich with vision output for image attachments.
    3. Enrich with a transcription for voice/audio attachments, and notify
       the user directly when the result carries an STT failure marker.
    4. Prepend a context note for each document attachment.
    5. Prepend the quoted reply text when it is not already in ``history``.
    6. Expand ``@`` context references.

    Args:
        event: The inbound (or dequeued) message event.
        source: Session source the event belongs to.
        history: Conversation history used for the reply-context lookup;
            ``None`` is treated as an empty history.

    Returns:
        The prepared message text, or ``None`` when an ``@`` context
        reference was blocked and the message must not reach the agent.
    """
    history = history or []
    message_text = event.text or ""

    # Shared thread sessions mix several users, so tag each message with its
    # sender. DMs and per-user thread sessions are left untouched.
    is_shared_thread = (
        source.chat_type != "dm"
        and source.thread_id
        and not getattr(self.config, "thread_sessions_per_user", False)
    )
    if is_shared_thread and source.user_name:
        message_text = f"[{source.user_name}] {message_text}"

    media_urls = event.media_urls or []
    # Tolerate a missing/None media_types list (or None entries) instead of
    # raising while classifying attachments.
    media_types = getattr(event, "media_types", None) or []

    def mime_at(index: int) -> str:
        """Return the MIME type recorded for attachment *index*, or ''."""
        return (media_types[index] or "") if index < len(media_types) else ""

    if media_urls:
        is_photo_event = event.message_type == MessageType.PHOTO
        is_audio_event = event.message_type in (MessageType.VOICE, MessageType.AUDIO)
        image_paths = [
            path
            for i, path in enumerate(media_urls)
            if mime_at(i).startswith("image/") or is_photo_event
        ]
        audio_paths = [
            path
            for i, path in enumerate(media_urls)
            if mime_at(i).startswith("audio/") or is_audio_event
        ]

        if image_paths:
            message_text = await self._enrich_message_with_vision(
                message_text,
                image_paths,
            )

        if audio_paths:
            message_text = await self._enrich_message_with_transcription(
                message_text,
                audio_paths,
            )
            # If the enriched text carries an STT failure marker, tell the
            # user directly rather than relying on the agent to relay it.
            stt_fail_markers = (
                "No STT provider",
                "STT is disabled",
                "can't listen",
                "VOICE_TOOLS_OPENAI_KEY",
            )
            if any(marker in message_text for marker in stt_fail_markers):
                stt_adapter = self.adapters.get(source.platform)
                stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
                if stt_adapter:
                    try:
                        stt_msg = (
                            "🎤 I received your voice message but can't transcribe it — "
                            "no speech-to-text provider is configured.\n\n"
                            "To enable voice: install faster-whisper "
                            "(`pip install faster-whisper` in the Hermes venv) "
                            "and set `stt.enabled: true` in config.yaml, "
                            "then /restart the gateway."
                        )
                        if self._has_setup_skill():
                            stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
                        await stt_adapter.send(
                            source.chat_id,
                            stt_msg,
                            metadata=stt_meta,
                        )
                    except Exception as exc:
                        # Best-effort notice: never fail the message over it,
                        # but leave a trace instead of swallowing silently.
                        logger.debug("Failed to send STT setup notice: %s", exc)

    if media_urls and event.message_type == MessageType.DOCUMENT:
        # Imported once per call rather than on every loop iteration.
        import mimetypes as _mimetypes
        import re as _re

        text_extensions = {
            ".txt", ".md", ".csv", ".log", ".json", ".xml",
            ".yaml", ".yml", ".toml", ".ini", ".cfg",
        }
        for i, path in enumerate(media_urls):
            mtype = mime_at(i)
            # Fall back to the file extension when the MIME type is missing
            # or the generic octet-stream.
            if mtype in ("", "application/octet-stream"):
                ext = os.path.splitext(path)[1].lower()
                if ext in text_extensions:
                    mtype = "text/plain"
                else:
                    guessed, _ = _mimetypes.guess_type(path)
                    if guessed:
                        mtype = guessed
            if not mtype.startswith(("application/", "text/")):
                continue

            # Display name: drop the first two "_"-separated prefix fields
            # when present, then replace anything outside a conservative
            # character set so the filename cannot smuggle prompt text.
            basename = os.path.basename(path)
            parts = basename.split("_", 2)
            display_name = parts[2] if len(parts) >= 3 else basename
            display_name = _re.sub(r'[^\w.\- ]', '_', display_name)

            if mtype.startswith("text/"):
                context_note = (
                    f"[The user sent a text document: '{display_name}'. "
                    f"Its content has been included below. "
                    f"The file is also saved at: {path}]"
                )
            else:
                context_note = (
                    f"[The user sent a document: '{display_name}'. "
                    f"The file is saved at: {path}. "
                    f"Ask the user what they'd like you to do with it.]"
                )
            message_text = f"{context_note}\n\n{message_text}"

    # Quote the replied-to message only when the agent cannot already see it
    # in the conversation history.
    if getattr(event, "reply_to_text", None) and event.reply_to_message_id:
        reply_snippet = event.reply_to_text[:500]
        found_in_history = any(
            reply_snippet[:200] in (msg.get("content") or "")
            for msg in history
            if msg.get("role") in ("assistant", "user", "tool")
        )
        if not found_in_history:
            message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'

    if "@" in message_text:
        try:
            from agent.context_references import preprocess_context_references_async
            from agent.model_metadata import get_model_context_length

            msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
            ctx_len = get_model_context_length(
                self._model,
                base_url=self._base_url or "",
            )
            ctx_result = await preprocess_context_references_async(
                message_text,
                cwd=msg_cwd,
                context_length=ctx_len,
                allowed_root=msg_cwd,
            )
            if ctx_result.blocked:
                refusal_adapter = self.adapters.get(source.platform)
                if refusal_adapter:
                    # Guard the notice separately: a failed send must not let
                    # a refused message fall through to the agent.
                    try:
                        await refusal_adapter.send(
                            source.chat_id,
                            "\n".join(ctx_result.warnings) or "Context injection refused.",
                        )
                    except Exception as send_exc:
                        logger.debug("Failed to send context refusal notice: %s", send_exc)
                return None
            if ctx_result.expanded:
                message_text = ctx_result.message
        except Exception as exc:
            # Expansion is optional; on failure the unexpanded text is used.
            logger.debug("@ context reference expansion failed: %s", exc)

    return message_text
- # - # When multiple users share a single thread session (the default for - # threads), prefix each message with [sender name] so the agent can - # tell participants apart. Skip for DMs (single-user by nature) and - # when per-user thread isolation is explicitly enabled. - # ----------------------------------------------------------------- - _is_shared_thread = ( - source.chat_type != "dm" - and source.thread_id - and not getattr(self.config, "thread_sessions_per_user", False) + message_text = await self._prepare_inbound_message_text( + event=event, + source=source, + history=history, ) - if _is_shared_thread and source.user_name: - message_text = f"[{source.user_name}] {message_text}" - - if event.media_urls: - image_paths = [] - for i, path in enumerate(event.media_urls): - # Check media_types if available; otherwise infer from message type - mtype = event.media_types[i] if i < len(event.media_types) else "" - is_image = ( - mtype.startswith("image/") - or event.message_type == MessageType.PHOTO - ) - if is_image: - image_paths.append(path) - if image_paths: - message_text = await self._enrich_message_with_vision( - message_text, image_paths - ) - - # ----------------------------------------------------------------- - # Auto-transcribe voice/audio messages sent by the user - # ----------------------------------------------------------------- - if event.media_urls: - audio_paths = [] - for i, path in enumerate(event.media_urls): - mtype = event.media_types[i] if i < len(event.media_types) else "" - is_audio = ( - mtype.startswith("audio/") - or event.message_type in (MessageType.VOICE, MessageType.AUDIO) - ) - if is_audio: - audio_paths.append(path) - if audio_paths: - message_text = await self._enrich_message_with_transcription( - message_text, audio_paths - ) - # If STT failed, send a direct message to the user so they - # know voice isn't configured — don't rely on the agent to - # relay the error clearly. 
- _stt_fail_markers = ( - "No STT provider", - "STT is disabled", - "can't listen", - "VOICE_TOOLS_OPENAI_KEY", - ) - if any(m in message_text for m in _stt_fail_markers): - _stt_adapter = self.adapters.get(source.platform) - _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None - if _stt_adapter: - try: - _stt_msg = ( - "🎤 I received your voice message but can't transcribe it — " - "no speech-to-text provider is configured.\n\n" - "To enable voice: install faster-whisper " - "(`pip install faster-whisper` in the Hermes venv) " - "and set `stt.enabled: true` in config.yaml, " - "then /restart the gateway." - ) - # Point to setup skill if it's installed - if self._has_setup_skill(): - _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`" - await _stt_adapter.send( - source.chat_id, _stt_msg, - metadata=_stt_meta, - ) - except Exception: - pass - - # ----------------------------------------------------------------- - # Enrich document messages with context notes for the agent - # ----------------------------------------------------------------- - if event.media_urls and event.message_type == MessageType.DOCUMENT: - import mimetypes as _mimetypes - _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"} - for i, path in enumerate(event.media_urls): - mtype = event.media_types[i] if i < len(event.media_types) else "" - # Fall back to extension-based detection when MIME type is unreliable. 
- if mtype in ("", "application/octet-stream"): - import os as _os2 - _ext = _os2.path.splitext(path)[1].lower() - if _ext in _TEXT_EXTENSIONS: - mtype = "text/plain" - else: - guessed, _ = _mimetypes.guess_type(path) - if guessed: - mtype = guessed - if not mtype.startswith(("application/", "text/")): - continue - # Extract display filename by stripping the doc_{uuid12}_ prefix - import os as _os - basename = _os.path.basename(path) - # Format: doc_<12hex>_ - parts = basename.split("_", 2) - display_name = parts[2] if len(parts) >= 3 else basename - # Sanitize to prevent prompt injection via filenames - import re as _re - display_name = _re.sub(r'[^\w.\- ]', '_', display_name) - - if mtype.startswith("text/"): - context_note = ( - f"[The user sent a text document: '{display_name}'. " - f"Its content has been included below. " - f"The file is also saved at: {path}]" - ) - else: - context_note = ( - f"[The user sent a document: '{display_name}'. " - f"The file is saved at: {path}. " - f"Ask the user what they'd like you to do with it.]" - ) - message_text = f"{context_note}\n\n{message_text}" - - # ----------------------------------------------------------------- - # Inject reply context when user replies to a message not in history. - # Telegram (and other platforms) let users reply to specific messages, - # but if the quoted message is from a previous session, cron delivery, - # or background task, the agent has no context about what's being - # referenced. Prepend the quoted text so the agent understands. 
(#1594) - # ----------------------------------------------------------------- - if getattr(event, 'reply_to_text', None) and event.reply_to_message_id: - reply_snippet = event.reply_to_text[:500] - found_in_history = any( - reply_snippet[:200] in (msg.get("content") or "") - for msg in history - if msg.get("role") in ("assistant", "user", "tool") - ) - if not found_in_history: - message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}' + if message_text is None: + return try: # Emit agent:start hook @@ -3369,30 +3384,6 @@ class GatewayRunner: } await self.hooks.emit("agent:start", hook_ctx) - # Expand @ context references (@file:, @folder:, @diff, etc.) - if "@" in message_text: - try: - from agent.context_references import preprocess_context_references_async - from agent.model_metadata import get_model_context_length - _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~")) - _msg_ctx_len = get_model_context_length( - self._model, base_url=self._base_url or "") - _ctx_result = await preprocess_context_references_async( - message_text, cwd=_msg_cwd, - context_length=_msg_ctx_len, allowed_root=_msg_cwd) - if _ctx_result.blocked: - _adapter = self.adapters.get(source.platform) - if _adapter: - await _adapter.send( - source.chat_id, - "\n".join(_ctx_result.warnings) or "Context injection refused.", - ) - return - if _ctx_result.expanded: - message_text = _ctx_result.message - except Exception as exc: - logger.debug("@ context reference expansion failed: %s", exc) - # Run the agent agent_result = await self._run_agent( message=message_text, @@ -8057,17 +8048,16 @@ class GatewayRunner: # Get pending message from adapter. # Use session_key (not source.chat_id) to match adapter's storage keys. 
+ pending_event = None pending = None if result and adapter and session_key: - if result.get("interrupted"): - pending = _dequeue_pending_text(adapter, session_key) - if not pending and result.get("interrupt_message"): - pending = result.get("interrupt_message") - else: - pending = _dequeue_pending_text(adapter, session_key) - if pending: - logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) - + pending_event = _dequeue_pending_event(adapter, session_key) + if result.get("interrupted") and not pending_event and result.get("interrupt_message"): + pending = result.get("interrupt_message") + elif pending_event: + pending = pending_event.text or _build_media_placeholder(pending_event) + logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) + # Safety net: if the pending text is a slash command (e.g. "/stop", # "/new"), discard it — commands should never be passed to the agent # as user input. The primary fix is in base.py (commands bypass the @@ -8085,27 +8075,29 @@ class GatewayRunner: "commands must not be passed as agent input", _pending_cmd_word, ) + pending_event = None pending = None except Exception: pass - if self._draining and pending: + if self._draining and (pending_event or pending): logger.info( "Discarding pending follow-up for session %s during gateway %s", session_key[:20] if session_key else "?", self._status_action_label(), ) + pending_event = None pending = None - if pending: + if pending_event or pending: logger.debug("Processing pending message: '%s...'", pending[:40]) - + # Clear the adapter's interrupt event so the next _run_agent call # doesn't immediately re-trigger the interrupt before the new agent # even makes its first API call (this was causing an infinite loop). 
if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions: adapter._active_sessions[session_key].clear() - + # Cap recursion depth to prevent resource exhaustion when the # user sends multiple messages while the agent keeps failing. (#816) if _interrupt_depth >= self._MAX_INTERRUPT_DEPTH: @@ -8114,9 +8106,10 @@ class GatewayRunner: "queueing message instead of recursing.", _interrupt_depth, session_key, ) - # Queue the pending message for normal processing on next turn adapter = self.adapters.get(source.platform) - if adapter and hasattr(adapter, 'queue_message'): + if adapter and pending_event: + merge_pending_message_event(adapter._pending_messages, session_key, pending_event) + elif adapter and hasattr(adapter, 'queue_message'): adapter.queue_message(session_key, pending) return result_holder[0] or {"final_response": response, "messages": history} @@ -8138,16 +8131,30 @@ class GatewayRunner: # interrupted." is just noise; the user already knows they sent a # new message). 
# Method of TestQueueMessageStorage (tests/gateway/test_queue_consumption.py);
# the class header lies outside this span.
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
    """A dequeued voice event is the same object, media fields intact.

    Also checks that dequeuing consumes the entry: a second lookup on the
    adapter returns ``None``.
    """
    from gateway.run import _dequeue_pending_event

    key = "telegram:user:voice"
    voice_paths = ["/tmp/voice.ogg"]
    voice_mimes = ["audio/ogg"]
    queued = MessageEvent(
        text="",
        message_type=MessageType.VOICE,
        source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
        message_id="voice-q1",
        media_urls=list(voice_paths),
        media_types=list(voice_mimes),
    )
    stub = _StubAdapter()
    stub._pending_messages[key] = queued

    dequeued = _dequeue_pending_event(stub, key)

    # Identity, not just equality: the full event must come back untouched.
    assert dequeued is queued
    assert dequeued.media_urls == voice_paths
    assert dequeued.media_types == voice_mimes
    assert stub.get_pending_message(key) is None
@pytest.mark.asyncio
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
    """A captionless voice event is transcribed by the shared preprocessing.

    ``transcribe_audio`` is patched to return a successful result; the
    prepared text must contain that transcript and mention a voice message.
    """
    from gateway.config import GatewayConfig, Platform
    from gateway.platforms.base import MessageEvent, MessageType
    from gateway.run import GatewayRunner
    from gateway.session import SessionSource

    dm_source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="123",
        chat_type="dm",
    )
    voice_event = MessageEvent(
        text="",
        message_type=MessageType.VOICE,
        source=dm_source,
        media_urls=["/tmp/queued-voice.ogg"],
        media_types=["audio/ogg"],
    )

    # Skip __init__ and set the attributes assigned in the original test.
    gateway = GatewayRunner.__new__(GatewayRunner)
    gateway.config = GatewayConfig(stt_enabled=True)
    gateway.adapters = {}
    gateway._model = "test-model"
    gateway._base_url = ""
    gateway._has_setup_skill = lambda: False

    fake_transcription = {
        "success": True,
        "transcript": "queued voice transcript",
        "provider": "local_command",
    }
    with patch(
        "tools.transcription_tools.transcribe_audio",
        return_value=fake_transcription,
    ):
        prepared = await gateway._prepare_inbound_message_text(
            event=voice_event,
            source=dm_source,
            history=[],
        )

    assert prepared is not None
    assert "queued voice transcript" in prepared
    assert "voice message" in prepared.lower()