fix(gateway): preserve queued voice events for STT

This commit is contained in:
etcircle 2026-04-11 21:36:05 +01:00 committed by Teknium
parent 4bede272cf
commit dd6b5ffa74
No known key found for this signature in database
3 changed files with 269 additions and 196 deletions

View file

@ -352,19 +352,14 @@ def _build_media_placeholder(event) -> str:
return "\n".join(parts)
def _dequeue_pending_text(adapter, session_key: str) -> str | None:
"""Consume and return the text of a pending queued message.
def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None:
"""Consume and return the full pending event for a session.
Preserves media context for captionless photo/document events by
building a placeholder so the message isn't silently dropped.
Queued follow-ups must preserve their media metadata so they can re-enter
the normal image/STT/document preprocessing path instead of being reduced
to a placeholder string.
"""
event = adapter.get_pending_message(session_key)
if not event:
return None
text = event.text
if not text and getattr(event, "media_urls", None):
text = _build_media_placeholder(event)
return text
return adapter.get_pending_message(session_key)
def _check_unavailable_skill(command_name: str) -> str | None:
@ -2775,6 +2770,162 @@ class GatewayRunner:
del self._running_agents[_quick_key]
self._running_agents_ts.pop(_quick_key, None)
async def _prepare_inbound_message_text(
    self,
    *,
    event: MessageEvent,
    source: SessionSource,
    history: List[Dict[str, Any]],
) -> Optional[str]:
    """Prepare inbound event text for the agent.

    Keep the normal inbound path and the queued follow-up path on the same
    preprocessing pipeline so sender attribution, image enrichment, STT,
    document notes, reply context, and @ references all behave the same.

    Returns the fully enriched message text, or ``None`` when the message
    must not reach the agent (currently: a blocked @ context reference).
    """
    history = history or []
    message_text = event.text or ""
    # Shared (multi-user) thread sessions get a "[sender]" prefix so the
    # agent can tell participants apart; DMs and explicitly per-user
    # thread sessions skip the prefix.
    _is_shared_thread = (
        source.chat_type != "dm"
        and source.thread_id
        and not getattr(self.config, "thread_sessions_per_user", False)
    )
    if _is_shared_thread and source.user_name:
        message_text = f"[{source.user_name}] {message_text}"
    if event.media_urls:
        # Partition attachments into image vs. audio paths; documents are
        # handled by the separate DOCUMENT branch below.
        image_paths = []
        audio_paths = []
        for i, path in enumerate(event.media_urls):
            # media_types may be shorter than media_urls; fall back to the
            # overall message type when no per-item MIME type is available.
            mtype = event.media_types[i] if i < len(event.media_types) else ""
            if mtype.startswith("image/") or event.message_type == MessageType.PHOTO:
                image_paths.append(path)
            if mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO):
                audio_paths.append(path)
        if image_paths:
            message_text = await self._enrich_message_with_vision(
                message_text,
                image_paths,
            )
        if audio_paths:
            message_text = await self._enrich_message_with_transcription(
                message_text,
                audio_paths,
            )
            # When transcription fails, the enrichment helper embeds one of
            # these marker phrases in the text; detect that and message the
            # user directly rather than relying on the agent to relay the
            # error clearly.
            _stt_fail_markers = (
                "No STT provider",
                "STT is disabled",
                "can't listen",
                "VOICE_TOOLS_OPENAI_KEY",
            )
            if any(marker in message_text for marker in _stt_fail_markers):
                _stt_adapter = self.adapters.get(source.platform)
                _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
                if _stt_adapter:
                    try:
                        _stt_msg = (
                            "🎤 I received your voice message but can't transcribe it — "
                            "no speech-to-text provider is configured.\n\n"
                            "To enable voice: install faster-whisper "
                            "(`pip install faster-whisper` in the Hermes venv) "
                            "and set `stt.enabled: true` in config.yaml, "
                            "then /restart the gateway."
                        )
                        # Point to the setup skill only if it's installed.
                        if self._has_setup_skill():
                            _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
                        await _stt_adapter.send(
                            source.chat_id,
                            _stt_msg,
                            metadata=_stt_meta,
                        )
                    except Exception:
                        # Best-effort notification: a failed send must not
                        # abort message preprocessing.
                        pass
    if event.media_urls and event.message_type == MessageType.DOCUMENT:
        import mimetypes as _mimetypes
        _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
        for i, path in enumerate(event.media_urls):
            mtype = event.media_types[i] if i < len(event.media_types) else ""
            # Fall back to extension-based detection when the reported MIME
            # type is missing or the generic octet-stream.
            if mtype in ("", "application/octet-stream"):
                import os as _os2
                _ext = _os2.path.splitext(path)[1].lower()
                if _ext in _TEXT_EXTENSIONS:
                    mtype = "text/plain"
                else:
                    guessed, _ = _mimetypes.guess_type(path)
                    if guessed:
                        mtype = guessed
            if not mtype.startswith(("application/", "text/")):
                continue
            import os as _os
            import re as _re
            # Saved filenames appear to follow doc_<12hex>_<original_name>;
            # strip the prefix to recover a display name. TODO confirm the
            # prefix format against the download path in the adapter.
            basename = _os.path.basename(path)
            parts = basename.split("_", 2)
            display_name = parts[2] if len(parts) >= 3 else basename
            # Sanitize the display name to prevent prompt injection via
            # crafted filenames.
            display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
            if mtype.startswith("text/"):
                context_note = (
                    f"[The user sent a text document: '{display_name}'. "
                    f"Its content has been included below. "
                    f"The file is also saved at: {path}]"
                )
            else:
                context_note = (
                    f"[The user sent a document: '{display_name}'. "
                    f"The file is saved at: {path}. "
                    f"Ask the user what they'd like you to do with it.]"
                )
            message_text = f"{context_note}\n\n{message_text}"
    # Inject reply context when the user replies to a message that is not
    # already present in the session history (e.g. quoted from a previous
    # session or a background delivery).
    if getattr(event, "reply_to_text", None) and event.reply_to_message_id:
        reply_snippet = event.reply_to_text[:500]
        found_in_history = any(
            reply_snippet[:200] in (msg.get("content") or "")
            for msg in history
            if msg.get("role") in ("assistant", "user", "tool")
        )
        if not found_in_history:
            message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
    # Expand @ context references (@file:, @folder:, @diff, ...). A blocked
    # expansion notifies the user and returns None so the caller drops the
    # message instead of handing it to the agent.
    if "@" in message_text:
        try:
            from agent.context_references import preprocess_context_references_async
            from agent.model_metadata import get_model_context_length
            _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
            _msg_ctx_len = get_model_context_length(
                self._model,
                base_url=self._base_url or "",
            )
            _ctx_result = await preprocess_context_references_async(
                message_text,
                cwd=_msg_cwd,
                context_length=_msg_ctx_len,
                allowed_root=_msg_cwd,
            )
            if _ctx_result.blocked:
                _adapter = self.adapters.get(source.platform)
                if _adapter:
                    await _adapter.send(
                        source.chat_id,
                        "\n".join(_ctx_result.warnings) or "Context injection refused.",
                    )
                return None
            if _ctx_result.expanded:
                message_text = _ctx_result.message
        except Exception as exc:
            # Expansion is best-effort; fall back to the raw text on failure.
            logger.debug("@ context reference expansion failed: %s", exc)
    return message_text
async def _handle_message_with_agent(self, event, source, _quick_key: str):
"""Inner handler that runs under the _running_agents sentinel guard."""
_msg_start_time = time.time()
@ -3215,149 +3366,13 @@ class GatewayRunner:
# attachments (documents, audio, etc.) are not sent to the vision
# tool even when they appear in the same message.
# -----------------------------------------------------------------
message_text = event.text or ""
# -----------------------------------------------------------------
# Sender attribution for shared thread sessions.
#
# When multiple users share a single thread session (the default for
# threads), prefix each message with [sender name] so the agent can
# tell participants apart. Skip for DMs (single-user by nature) and
# when per-user thread isolation is explicitly enabled.
# -----------------------------------------------------------------
_is_shared_thread = (
source.chat_type != "dm"
and source.thread_id
and not getattr(self.config, "thread_sessions_per_user", False)
message_text = await self._prepare_inbound_message_text(
event=event,
source=source,
history=history,
)
if _is_shared_thread and source.user_name:
message_text = f"[{source.user_name}] {message_text}"
if event.media_urls:
image_paths = []
for i, path in enumerate(event.media_urls):
# Check media_types if available; otherwise infer from message type
mtype = event.media_types[i] if i < len(event.media_types) else ""
is_image = (
mtype.startswith("image/")
or event.message_type == MessageType.PHOTO
)
if is_image:
image_paths.append(path)
if image_paths:
message_text = await self._enrich_message_with_vision(
message_text, image_paths
)
# -----------------------------------------------------------------
# Auto-transcribe voice/audio messages sent by the user
# -----------------------------------------------------------------
if event.media_urls:
audio_paths = []
for i, path in enumerate(event.media_urls):
mtype = event.media_types[i] if i < len(event.media_types) else ""
is_audio = (
mtype.startswith("audio/")
or event.message_type in (MessageType.VOICE, MessageType.AUDIO)
)
if is_audio:
audio_paths.append(path)
if audio_paths:
message_text = await self._enrich_message_with_transcription(
message_text, audio_paths
)
# If STT failed, send a direct message to the user so they
# know voice isn't configured — don't rely on the agent to
# relay the error clearly.
_stt_fail_markers = (
"No STT provider",
"STT is disabled",
"can't listen",
"VOICE_TOOLS_OPENAI_KEY",
)
if any(m in message_text for m in _stt_fail_markers):
_stt_adapter = self.adapters.get(source.platform)
_stt_meta = {"thread_id": source.thread_id} if source.thread_id else None
if _stt_adapter:
try:
_stt_msg = (
"🎤 I received your voice message but can't transcribe it — "
"no speech-to-text provider is configured.\n\n"
"To enable voice: install faster-whisper "
"(`pip install faster-whisper` in the Hermes venv) "
"and set `stt.enabled: true` in config.yaml, "
"then /restart the gateway."
)
# Point to setup skill if it's installed
if self._has_setup_skill():
_stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`"
await _stt_adapter.send(
source.chat_id, _stt_msg,
metadata=_stt_meta,
)
except Exception:
pass
# -----------------------------------------------------------------
# Enrich document messages with context notes for the agent
# -----------------------------------------------------------------
if event.media_urls and event.message_type == MessageType.DOCUMENT:
import mimetypes as _mimetypes
_TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"}
for i, path in enumerate(event.media_urls):
mtype = event.media_types[i] if i < len(event.media_types) else ""
# Fall back to extension-based detection when MIME type is unreliable.
if mtype in ("", "application/octet-stream"):
import os as _os2
_ext = _os2.path.splitext(path)[1].lower()
if _ext in _TEXT_EXTENSIONS:
mtype = "text/plain"
else:
guessed, _ = _mimetypes.guess_type(path)
if guessed:
mtype = guessed
if not mtype.startswith(("application/", "text/")):
continue
# Extract display filename by stripping the doc_{uuid12}_ prefix
import os as _os
basename = _os.path.basename(path)
# Format: doc_<12hex>_<original_filename>
parts = basename.split("_", 2)
display_name = parts[2] if len(parts) >= 3 else basename
# Sanitize to prevent prompt injection via filenames
import re as _re
display_name = _re.sub(r'[^\w.\- ]', '_', display_name)
if mtype.startswith("text/"):
context_note = (
f"[The user sent a text document: '{display_name}'. "
f"Its content has been included below. "
f"The file is also saved at: {path}]"
)
else:
context_note = (
f"[The user sent a document: '{display_name}'. "
f"The file is saved at: {path}. "
f"Ask the user what they'd like you to do with it.]"
)
message_text = f"{context_note}\n\n{message_text}"
# -----------------------------------------------------------------
# Inject reply context when user replies to a message not in history.
# Telegram (and other platforms) let users reply to specific messages,
# but if the quoted message is from a previous session, cron delivery,
# or background task, the agent has no context about what's being
# referenced. Prepend the quoted text so the agent understands. (#1594)
# -----------------------------------------------------------------
if getattr(event, 'reply_to_text', None) and event.reply_to_message_id:
reply_snippet = event.reply_to_text[:500]
found_in_history = any(
reply_snippet[:200] in (msg.get("content") or "")
for msg in history
if msg.get("role") in ("assistant", "user", "tool")
)
if not found_in_history:
message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}'
if message_text is None:
return
try:
# Emit agent:start hook
@ -3369,30 +3384,6 @@ class GatewayRunner:
}
await self.hooks.emit("agent:start", hook_ctx)
# Expand @ context references (@file:, @folder:, @diff, etc.)
if "@" in message_text:
try:
from agent.context_references import preprocess_context_references_async
from agent.model_metadata import get_model_context_length
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
_msg_ctx_len = get_model_context_length(
self._model, base_url=self._base_url or "")
_ctx_result = await preprocess_context_references_async(
message_text, cwd=_msg_cwd,
context_length=_msg_ctx_len, allowed_root=_msg_cwd)
if _ctx_result.blocked:
_adapter = self.adapters.get(source.platform)
if _adapter:
await _adapter.send(
source.chat_id,
"\n".join(_ctx_result.warnings) or "Context injection refused.",
)
return
if _ctx_result.expanded:
message_text = _ctx_result.message
except Exception as exc:
logger.debug("@ context reference expansion failed: %s", exc)
# Run the agent
agent_result = await self._run_agent(
message=message_text,
@ -8057,17 +8048,16 @@ class GatewayRunner:
# Get pending message from adapter.
# Use session_key (not source.chat_id) to match adapter's storage keys.
pending_event = None
pending = None
if result and adapter and session_key:
if result.get("interrupted"):
pending = _dequeue_pending_text(adapter, session_key)
if not pending and result.get("interrupt_message"):
pending = result.get("interrupt_message")
else:
pending = _dequeue_pending_text(adapter, session_key)
if pending:
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
pending_event = _dequeue_pending_event(adapter, session_key)
if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
pending = result.get("interrupt_message")
elif pending_event:
pending = pending_event.text or _build_media_placeholder(pending_event)
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
# Safety net: if the pending text is a slash command (e.g. "/stop",
# "/new"), discard it — commands should never be passed to the agent
# as user input. The primary fix is in base.py (commands bypass the
@ -8085,27 +8075,29 @@ class GatewayRunner:
"commands must not be passed as agent input",
_pending_cmd_word,
)
pending_event = None
pending = None
except Exception:
pass
if self._draining and pending:
if self._draining and (pending_event or pending):
logger.info(
"Discarding pending follow-up for session %s during gateway %s",
session_key[:20] if session_key else "?",
self._status_action_label(),
)
pending_event = None
pending = None
if pending:
if pending_event or pending:
logger.debug("Processing pending message: '%s...'", pending[:40])
# Clear the adapter's interrupt event so the next _run_agent call
# doesn't immediately re-trigger the interrupt before the new agent
# even makes its first API call (this was causing an infinite loop).
if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions:
adapter._active_sessions[session_key].clear()
# Cap recursion depth to prevent resource exhaustion when the
# user sends multiple messages while the agent keeps failing. (#816)
if _interrupt_depth >= self._MAX_INTERRUPT_DEPTH:
@ -8114,9 +8106,10 @@ class GatewayRunner:
"queueing message instead of recursing.",
_interrupt_depth, session_key,
)
# Queue the pending message for normal processing on next turn
adapter = self.adapters.get(source.platform)
if adapter and hasattr(adapter, 'queue_message'):
if adapter and pending_event:
merge_pending_message_event(adapter._pending_messages, session_key, pending_event)
elif adapter and hasattr(adapter, 'queue_message'):
adapter.queue_message(session_key, pending)
return result_holder[0] or {"final_response": response, "messages": history}
@ -8138,16 +8131,30 @@ class GatewayRunner:
# interrupted." is just noise; the user already knows they sent a
# new message).
# Process the pending message with updated history
updated_history = result.get("messages", history)
next_source = source
next_message = pending
next_message_id = None
if pending_event is not None:
next_source = getattr(pending_event, "source", None) or source
next_message = await self._prepare_inbound_message_text(
event=pending_event,
source=next_source,
history=updated_history,
)
if next_message is None:
return result
next_message_id = getattr(pending_event, "message_id", None)
return await self._run_agent(
message=pending,
message=next_message,
context_prompt=context_prompt,
history=updated_history,
source=source,
source=next_source,
session_id=session_id,
session_key=session_key,
_interrupt_depth=_interrupt_depth + 1,
event_message_id=next_message_id,
)
finally:
# Stop progress sender, interrupt monitor, and notification task

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.run import _dequeue_pending_event
from gateway.platforms.base import (
BasePlatformAdapter,
MessageEvent,
@ -79,6 +80,26 @@ class TestQueueMessageStorage:
# Should be consumed (cleared)
assert adapter.get_pending_message(session_key) is None
def test_dequeue_pending_event_preserves_voice_media_metadata(self):
    """A queued captionless voice event comes back whole — media metadata
    intact — and is consumed from the adapter's pending store."""
    stub = _StubAdapter()
    key = "telegram:user:voice"
    queued = MessageEvent(
        text="",
        message_type=MessageType.VOICE,
        source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
        message_id="voice-q1",
        media_urls=["/tmp/voice.ogg"],
        media_types=["audio/ogg"],
    )
    stub._pending_messages[key] = queued

    got = _dequeue_pending_event(stub, key)

    # The exact same event object is returned, with its media untouched.
    assert got is queued
    assert got.media_urls == ["/tmp/voice.ogg"]
    assert got.media_types == ["audio/ogg"]
    # Dequeueing is destructive: nothing remains queued afterwards.
    assert stub.get_pending_message(key) is None
def test_queue_does_not_set_interrupt_event(self):
"""The whole point of /queue — no interrupt signal."""
adapter = _StubAdapter()

View file

@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch
import pytest
import yaml
from gateway.config import GatewayConfig, load_gateway_config
from gateway.config import GatewayConfig, Platform, load_gateway_config
from gateway.platforms.base import MessageEvent, MessageType
from gateway.session import SessionSource
def test_gateway_config_stt_disabled_from_dict_nested():
@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag
assert "No STT provider is configured" not in result
assert "trouble transcribing" in result
assert "caption" in result
@pytest.mark.asyncio
async def test_prepare_inbound_message_text_transcribes_queued_voice_event():
    """A queued voice event re-entering the shared preprocessing path is
    transcribed instead of being collapsed into a media placeholder."""
    from gateway.run import GatewayRunner

    runner = GatewayRunner.__new__(GatewayRunner)
    runner.config = GatewayConfig(stt_enabled=True)
    runner.adapters = {}
    runner._model = "test-model"
    runner._base_url = ""
    runner._has_setup_skill = lambda: False

    origin = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="123",
        chat_type="dm",
    )
    voice_event = MessageEvent(
        text="",
        message_type=MessageType.VOICE,
        source=origin,
        media_urls=["/tmp/queued-voice.ogg"],
        media_types=["audio/ogg"],
    )

    fake_stt_result = {
        "success": True,
        "transcript": "queued voice transcript",
        "provider": "local_command",
    }
    with patch(
        "tools.transcription_tools.transcribe_audio",
        return_value=fake_stt_result,
    ):
        prepared = await runner._prepare_inbound_message_text(
            event=voice_event,
            source=origin,
            history=[],
        )

    assert prepared is not None
    assert "queued voice transcript" in prepared
    assert "voice message" in prepared.lower()