mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
merge: salvage PR #327 voice mode branch
Merge contributor branch feature/voice-mode onto current main for follow-up fixes.
This commit is contained in:
commit
523a1b6faf
37 changed files with 9248 additions and 228 deletions
24
.env.example
24
.env.example
|
|
@ -275,3 +275,27 @@ WANDB_API_KEY=
|
|||
# GITHUB_APP_ID=
|
||||
# GITHUB_APP_PRIVATE_KEY_PATH=
|
||||
# GITHUB_APP_INSTALLATION_ID=
|
||||
|
||||
# Groq API key (free tier — used for Whisper STT in voice mode)
|
||||
# GROQ_API_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# STT PROVIDER SELECTION
|
||||
# =============================================================================
|
||||
# Default STT provider is "local" (faster-whisper) — runs on your machine, no API key needed.
|
||||
# Install with: pip install faster-whisper
|
||||
# Model downloads automatically on first use (~150 MB for "base").
|
||||
# To use cloud providers instead, set GROQ_API_KEY or VOICE_TOOLS_OPENAI_KEY above.
|
||||
# Provider priority: local > groq > openai
|
||||
# Configure in config.yaml: stt.provider: local | groq | openai
|
||||
|
||||
# =============================================================================
|
||||
# STT ADVANCED OVERRIDES (optional)
|
||||
# =============================================================================
|
||||
# Override default STT models per provider (normally set via stt.model in config.yaml)
|
||||
# STT_GROQ_MODEL=whisper-large-v3-turbo
|
||||
# STT_OPENAI_MODEL=whisper-1
|
||||
|
||||
# Override STT provider endpoints (for proxies or self-hosted instances)
|
||||
# GROQ_BASE_URL=https://api.groq.com/openai/v1
|
||||
# STT_OPENAI_BASE_URL=https://api.openai.com/v1
|
||||
|
|
|
|||
760
cli.py
760
cli.py
|
|
@ -18,6 +18,8 @@ import shutil
|
|||
import sys
|
||||
import json
|
||||
import atexit
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
import textwrap
|
||||
from contextlib import contextmanager
|
||||
|
|
@ -1548,6 +1550,7 @@ class HermesCLI:
|
|||
checkpoints_enabled=self.checkpoints_enabled,
|
||||
checkpoint_max_snapshots=self.checkpoint_max_snapshots,
|
||||
pass_session_id=self.pass_session_id,
|
||||
tool_progress_callback=self._on_tool_progress,
|
||||
)
|
||||
# Apply any pending title now that the session exists in the DB
|
||||
if self._pending_title and self._session_db:
|
||||
|
|
@ -3017,6 +3020,8 @@ class HermesCLI:
|
|||
self._handle_background_command(cmd_original)
|
||||
elif cmd_lower.startswith("/skin"):
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif cmd_lower.startswith("/voice"):
|
||||
self._handle_voice_command(cmd_original)
|
||||
else:
|
||||
# Check for user-defined quick commands (bypass agent loop, no LLM call)
|
||||
base_cmd = cmd_lower.split()[0]
|
||||
|
|
@ -3511,6 +3516,407 @@ class HermesCLI:
|
|||
except Exception as e:
|
||||
print(f" ❌ MCP reload failed: {e}")
|
||||
|
||||
# ====================================================================
|
||||
# Tool progress callback (audio cues for voice mode)
|
||||
# ====================================================================
|
||||
|
||||
def _on_tool_progress(self, function_name: str, preview: str, function_args: dict):
|
||||
"""Called when a tool starts executing. Plays audio cue in voice mode."""
|
||||
if not self._voice_mode:
|
||||
return
|
||||
# Skip internal/thinking tools
|
||||
if function_name.startswith("_"):
|
||||
return
|
||||
try:
|
||||
from tools.voice_mode import play_beep
|
||||
# Short, subtle tick sound (higher pitch, very brief)
|
||||
threading.Thread(
|
||||
target=play_beep,
|
||||
kwargs={"frequency": 1200, "duration": 0.06, "count": 1},
|
||||
daemon=True,
|
||||
).start()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ====================================================================
|
||||
# Voice mode methods
|
||||
# ====================================================================
|
||||
|
||||
def _voice_start_recording(self):
|
||||
"""Start capturing audio from the microphone."""
|
||||
if getattr(self, '_should_exit', False):
|
||||
return
|
||||
from tools.voice_mode import AudioRecorder, check_voice_requirements
|
||||
|
||||
reqs = check_voice_requirements()
|
||||
if not reqs["audio_available"]:
|
||||
raise RuntimeError(
|
||||
"Voice mode requires sounddevice and numpy.\n"
|
||||
"Install with: pip install sounddevice numpy\n"
|
||||
"Or: pip install hermes-agent[voice]"
|
||||
)
|
||||
if not reqs.get("stt_available", reqs.get("stt_key_set")):
|
||||
raise RuntimeError(
|
||||
"Voice mode requires an STT provider for transcription.\n"
|
||||
"Option 1: pip install faster-whisper (free, local)\n"
|
||||
"Option 2: Set GROQ_API_KEY (free tier)\n"
|
||||
"Option 3: Set VOICE_TOOLS_OPENAI_KEY (paid)"
|
||||
)
|
||||
|
||||
# Prevent double-start from concurrent threads (atomic check-and-set)
|
||||
with self._voice_lock:
|
||||
if self._voice_recording:
|
||||
return
|
||||
self._voice_recording = True
|
||||
|
||||
# Load silence detection params from config
|
||||
voice_cfg = {}
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
voice_cfg = load_config().get("voice", {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if self._voice_recorder is None:
|
||||
self._voice_recorder = AudioRecorder()
|
||||
|
||||
# Apply config-driven silence params
|
||||
self._voice_recorder._silence_threshold = voice_cfg.get("silence_threshold", 200)
|
||||
self._voice_recorder._silence_duration = voice_cfg.get("silence_duration", 3.0)
|
||||
|
||||
def _on_silence():
|
||||
"""Called by AudioRecorder when silence is detected after speech."""
|
||||
with self._voice_lock:
|
||||
if not self._voice_recording:
|
||||
return
|
||||
_cprint(f"\n{_DIM}Silence detected, auto-stopping...{_RST}")
|
||||
if hasattr(self, '_app') and self._app:
|
||||
self._app.invalidate()
|
||||
self._voice_stop_and_transcribe()
|
||||
|
||||
# Audio cue: single beep BEFORE starting stream (avoid CoreAudio conflict)
|
||||
try:
|
||||
from tools.voice_mode import play_beep
|
||||
play_beep(frequency=880, count=1)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
self._voice_recorder.start(on_silence_stop=_on_silence)
|
||||
except Exception:
|
||||
with self._voice_lock:
|
||||
self._voice_recording = False
|
||||
raise
|
||||
_cprint(f"\n{_GOLD}● Recording...{_RST} {_DIM}(auto-stops on silence | Ctrl+B to stop & exit continuous){_RST}")
|
||||
|
||||
# Periodically refresh prompt to update audio level indicator
|
||||
def _refresh_level():
|
||||
while True:
|
||||
with self._voice_lock:
|
||||
still_recording = self._voice_recording
|
||||
if not still_recording:
|
||||
break
|
||||
if hasattr(self, '_app') and self._app:
|
||||
self._app.invalidate()
|
||||
time.sleep(0.15)
|
||||
threading.Thread(target=_refresh_level, daemon=True).start()
|
||||
|
||||
def _voice_stop_and_transcribe(self):
|
||||
"""Stop recording, transcribe via STT, and queue the transcript as input."""
|
||||
# Atomic guard: only one thread can enter stop-and-transcribe.
|
||||
# Set _voice_processing immediately so concurrent Ctrl+B presses
|
||||
# don't race into the START path while recorder.stop() holds its lock.
|
||||
with self._voice_lock:
|
||||
if not self._voice_recording:
|
||||
return
|
||||
self._voice_recording = False
|
||||
self._voice_processing = True
|
||||
|
||||
submitted = False
|
||||
wav_path = None
|
||||
try:
|
||||
if self._voice_recorder is None:
|
||||
return
|
||||
|
||||
wav_path = self._voice_recorder.stop()
|
||||
|
||||
# Audio cue: double beep after stream stopped (no CoreAudio conflict)
|
||||
try:
|
||||
from tools.voice_mode import play_beep
|
||||
play_beep(frequency=660, count=2)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if wav_path is None:
|
||||
_cprint(f"{_DIM}No speech detected.{_RST}")
|
||||
return
|
||||
|
||||
# _voice_processing is already True (set atomically above)
|
||||
if hasattr(self, '_app') and self._app:
|
||||
self._app.invalidate()
|
||||
_cprint(f"{_DIM}Transcribing...{_RST}")
|
||||
|
||||
# Get STT model from config
|
||||
stt_model = None
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
stt_config = load_config().get("stt", {})
|
||||
stt_model = stt_config.get("model")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from tools.voice_mode import transcribe_recording
|
||||
result = transcribe_recording(wav_path, model=stt_model)
|
||||
|
||||
if result.get("success") and result.get("transcript", "").strip():
|
||||
transcript = result["transcript"].strip()
|
||||
self._pending_input.put(transcript)
|
||||
submitted = True
|
||||
elif result.get("success"):
|
||||
_cprint(f"{_DIM}No speech detected.{_RST}")
|
||||
else:
|
||||
error = result.get("error", "Unknown error")
|
||||
_cprint(f"\n{_DIM}Transcription failed: {error}{_RST}")
|
||||
|
||||
except Exception as e:
|
||||
_cprint(f"\n{_DIM}Voice processing error: {e}{_RST}")
|
||||
finally:
|
||||
with self._voice_lock:
|
||||
self._voice_processing = False
|
||||
if hasattr(self, '_app') and self._app:
|
||||
self._app.invalidate()
|
||||
# Clean up temp file
|
||||
try:
|
||||
if wav_path and os.path.isfile(wav_path):
|
||||
os.unlink(wav_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Track consecutive no-speech cycles to avoid infinite restart loops.
|
||||
if not submitted:
|
||||
self._no_speech_count = getattr(self, '_no_speech_count', 0) + 1
|
||||
if self._no_speech_count >= 3:
|
||||
self._voice_continuous = False
|
||||
self._no_speech_count = 0
|
||||
_cprint(f"{_DIM}No speech detected 3 times, continuous mode stopped.{_RST}")
|
||||
return
|
||||
else:
|
||||
self._no_speech_count = 0
|
||||
|
||||
# If no transcript was submitted but continuous mode is active,
|
||||
# restart recording so the user can keep talking.
|
||||
# (When transcript IS submitted, process_loop handles restart
|
||||
# after chat() completes.)
|
||||
if self._voice_continuous and not submitted and not self._voice_recording:
|
||||
def _restart_recording():
|
||||
try:
|
||||
self._voice_start_recording()
|
||||
if hasattr(self, '_app') and self._app:
|
||||
self._app.invalidate()
|
||||
except Exception as e:
|
||||
_cprint(f"{_DIM}Voice auto-restart failed: {e}{_RST}")
|
||||
threading.Thread(target=_restart_recording, daemon=True).start()
|
||||
|
||||
def _voice_speak_response(self, text: str):
|
||||
"""Speak the agent's response aloud using TTS (runs in background thread)."""
|
||||
if not self._voice_tts:
|
||||
return
|
||||
self._voice_tts_done.clear()
|
||||
try:
|
||||
from tools.tts_tool import text_to_speech_tool
|
||||
from tools.voice_mode import play_audio_file
|
||||
import json
|
||||
import re
|
||||
|
||||
# Strip markdown and non-speech content for cleaner TTS
|
||||
tts_text = text[:4000] if len(text) > 4000 else text
|
||||
tts_text = re.sub(r'```[\s\S]*?```', ' ', tts_text) # fenced code blocks
|
||||
tts_text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', tts_text) # [text](url) -> text
|
||||
tts_text = re.sub(r'https?://\S+', '', tts_text) # URLs
|
||||
tts_text = re.sub(r'\*\*(.+?)\*\*', r'\1', tts_text) # bold
|
||||
tts_text = re.sub(r'\*(.+?)\*', r'\1', tts_text) # italic
|
||||
tts_text = re.sub(r'`(.+?)`', r'\1', tts_text) # inline code
|
||||
tts_text = re.sub(r'^#+\s*', '', tts_text, flags=re.MULTILINE) # headers
|
||||
tts_text = re.sub(r'^\s*[-*]\s+', '', tts_text, flags=re.MULTILINE) # list items
|
||||
tts_text = re.sub(r'---+', '', tts_text) # horizontal rules
|
||||
tts_text = re.sub(r'\n{3,}', '\n\n', tts_text) # excessive newlines
|
||||
tts_text = tts_text.strip()
|
||||
if not tts_text:
|
||||
return
|
||||
|
||||
# Use MP3 output for CLI playback (afplay doesn't handle OGG well).
|
||||
# The TTS tool may auto-convert MP3->OGG, but the original MP3 remains.
|
||||
os.makedirs(os.path.join(tempfile.gettempdir(), "hermes_voice"), exist_ok=True)
|
||||
mp3_path = os.path.join(
|
||||
tempfile.gettempdir(), "hermes_voice",
|
||||
f"tts_{time.strftime('%Y%m%d_%H%M%S')}.mp3",
|
||||
)
|
||||
|
||||
text_to_speech_tool(text=tts_text, output_path=mp3_path)
|
||||
|
||||
# Play the MP3 directly (the TTS tool returns OGG path but MP3 still exists)
|
||||
if os.path.isfile(mp3_path) and os.path.getsize(mp3_path) > 0:
|
||||
play_audio_file(mp3_path)
|
||||
# Clean up
|
||||
try:
|
||||
os.unlink(mp3_path)
|
||||
ogg_path = mp3_path.rsplit(".", 1)[0] + ".ogg"
|
||||
if os.path.isfile(ogg_path):
|
||||
os.unlink(ogg_path)
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Voice TTS playback failed: %s", e)
|
||||
_cprint(f"{_DIM}TTS playback failed: {e}{_RST}")
|
||||
finally:
|
||||
self._voice_tts_done.set()
|
||||
|
||||
def _handle_voice_command(self, command: str):
|
||||
"""Handle /voice [on|off|tts|status] command."""
|
||||
parts = command.strip().split(maxsplit=1)
|
||||
subcommand = parts[1].lower().strip() if len(parts) > 1 else ""
|
||||
|
||||
if subcommand == "on":
|
||||
self._enable_voice_mode()
|
||||
elif subcommand == "off":
|
||||
self._disable_voice_mode()
|
||||
elif subcommand == "tts":
|
||||
self._toggle_voice_tts()
|
||||
elif subcommand == "status":
|
||||
self._show_voice_status()
|
||||
elif subcommand == "":
|
||||
# Toggle
|
||||
if self._voice_mode:
|
||||
self._disable_voice_mode()
|
||||
else:
|
||||
self._enable_voice_mode()
|
||||
else:
|
||||
_cprint(f"Unknown voice subcommand: {subcommand}")
|
||||
_cprint("Usage: /voice [on|off|tts|status]")
|
||||
|
||||
def _enable_voice_mode(self):
|
||||
"""Enable voice mode after checking requirements."""
|
||||
if self._voice_mode:
|
||||
_cprint(f"{_DIM}Voice mode is already enabled.{_RST}")
|
||||
return
|
||||
|
||||
from tools.voice_mode import check_voice_requirements, detect_audio_environment
|
||||
|
||||
# Environment detection -- warn and block in incompatible environments
|
||||
env_check = detect_audio_environment()
|
||||
if not env_check["available"]:
|
||||
_cprint(f"\n{_GOLD}Voice mode unavailable in this environment:{_RST}")
|
||||
for warning in env_check["warnings"]:
|
||||
_cprint(f" {_DIM}{warning}{_RST}")
|
||||
return
|
||||
|
||||
reqs = check_voice_requirements()
|
||||
if not reqs["available"]:
|
||||
_cprint(f"\n{_GOLD}Voice mode requirements not met:{_RST}")
|
||||
for line in reqs["details"].split("\n"):
|
||||
_cprint(f" {_DIM}{line}{_RST}")
|
||||
if reqs["missing_packages"]:
|
||||
_cprint(f"\n {_BOLD}Install: pip install {' '.join(reqs['missing_packages'])}{_RST}")
|
||||
_cprint(f" {_DIM}Or: pip install hermes-agent[voice]{_RST}")
|
||||
return
|
||||
|
||||
with self._voice_lock:
|
||||
self._voice_mode = True
|
||||
|
||||
# Check config for auto_tts
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
voice_config = load_config().get("voice", {})
|
||||
if voice_config.get("auto_tts", False):
|
||||
with self._voice_lock:
|
||||
self._voice_tts = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Voice mode instruction is injected as a user message prefix (not a
|
||||
# system prompt change) to avoid invalidating the prompt cache. See
|
||||
# _voice_message_prefix property and its usage in _process_message().
|
||||
|
||||
tts_status = " (TTS enabled)" if self._voice_tts else ""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
_raw_ptt = load_config().get("voice", {}).get("record_key", "ctrl+b")
|
||||
_ptt_key = _raw_ptt.lower().replace("ctrl+", "c-").replace("alt+", "a-")
|
||||
except Exception:
|
||||
_ptt_key = "c-b"
|
||||
_ptt_display = _ptt_key.replace("c-", "Ctrl+").upper()
|
||||
_cprint(f"\n{_GOLD}Voice mode enabled{tts_status}{_RST}")
|
||||
_cprint(f" {_DIM}{_ptt_display} to start/stop recording{_RST}")
|
||||
_cprint(f" {_DIM}/voice tts to toggle speech output{_RST}")
|
||||
_cprint(f" {_DIM}/voice off to disable voice mode{_RST}")
|
||||
|
||||
def _disable_voice_mode(self):
|
||||
"""Disable voice mode, cancel any active recording, and stop TTS."""
|
||||
recorder = None
|
||||
with self._voice_lock:
|
||||
if self._voice_recording and self._voice_recorder:
|
||||
self._voice_recorder.cancel()
|
||||
self._voice_recording = False
|
||||
recorder = self._voice_recorder
|
||||
self._voice_mode = False
|
||||
self._voice_tts = False
|
||||
self._voice_continuous = False
|
||||
|
||||
# Shut down the persistent audio stream in background
|
||||
if recorder is not None:
|
||||
def _bg_shutdown(rec=recorder):
|
||||
try:
|
||||
rec.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
threading.Thread(target=_bg_shutdown, daemon=True).start()
|
||||
self._voice_recorder = None
|
||||
|
||||
# Stop any active TTS playback
|
||||
try:
|
||||
from tools.voice_mode import stop_playback
|
||||
stop_playback()
|
||||
except Exception:
|
||||
pass
|
||||
self._voice_tts_done.set()
|
||||
|
||||
_cprint(f"\n{_DIM}Voice mode disabled.{_RST}")
|
||||
|
||||
def _toggle_voice_tts(self):
|
||||
"""Toggle TTS output for voice mode."""
|
||||
if not self._voice_mode:
|
||||
_cprint(f"{_DIM}Enable voice mode first: /voice on{_RST}")
|
||||
return
|
||||
|
||||
with self._voice_lock:
|
||||
self._voice_tts = not self._voice_tts
|
||||
status = "enabled" if self._voice_tts else "disabled"
|
||||
|
||||
if self._voice_tts:
|
||||
from tools.tts_tool import check_tts_requirements
|
||||
if not check_tts_requirements():
|
||||
_cprint(f"{_DIM}Warning: No TTS provider available. Install edge-tts or set API keys.{_RST}")
|
||||
|
||||
_cprint(f"{_GOLD}Voice TTS {status}.{_RST}")
|
||||
|
||||
def _show_voice_status(self):
|
||||
"""Show current voice mode status."""
|
||||
from hermes_cli.config import load_config
|
||||
from tools.voice_mode import check_voice_requirements
|
||||
|
||||
reqs = check_voice_requirements()
|
||||
|
||||
_cprint(f"\n{_BOLD}Voice Mode Status{_RST}")
|
||||
_cprint(f" Mode: {'ON' if self._voice_mode else 'OFF'}")
|
||||
_cprint(f" TTS: {'ON' if self._voice_tts else 'OFF'}")
|
||||
_cprint(f" Recording: {'YES' if self._voice_recording else 'no'}")
|
||||
_raw_key = load_config().get("voice", {}).get("record_key", "ctrl+b")
|
||||
_display_key = _raw_key.replace("ctrl+", "Ctrl+").upper() if "ctrl+" in _raw_key.lower() else _raw_key
|
||||
_cprint(f" Record key: {_display_key}")
|
||||
_cprint(f"\n {_BOLD}Requirements:{_RST}")
|
||||
for line in reqs["details"].split("\n"):
|
||||
_cprint(f" {line}")
|
||||
|
||||
def _clarify_callback(self, question, choices):
|
||||
"""
|
||||
Platform callback for the clarify tool. Called from the agent thread.
|
||||
|
|
@ -3752,19 +4158,90 @@ class HermesCLI:
|
|||
try:
|
||||
# Run the conversation with interrupt monitoring
|
||||
result = None
|
||||
|
||||
|
||||
# --- Streaming TTS setup ---
|
||||
# When ElevenLabs is the TTS provider and sounddevice is available,
|
||||
# we stream audio sentence-by-sentence as the agent generates tokens
|
||||
# instead of waiting for the full response.
|
||||
use_streaming_tts = False
|
||||
_streaming_box_opened = False
|
||||
text_queue = None
|
||||
tts_thread = None
|
||||
stream_callback = None
|
||||
stop_event = None
|
||||
|
||||
if self._voice_tts:
|
||||
try:
|
||||
from tools.tts_tool import (
|
||||
_load_tts_config as _load_tts_cfg,
|
||||
_get_provider as _get_prov,
|
||||
_import_elevenlabs,
|
||||
_import_sounddevice,
|
||||
stream_tts_to_speaker,
|
||||
)
|
||||
_tts_cfg = _load_tts_cfg()
|
||||
if _get_prov(_tts_cfg) == "elevenlabs":
|
||||
# Verify both ElevenLabs SDK and audio output are available
|
||||
_import_elevenlabs()
|
||||
_import_sounddevice()
|
||||
use_streaming_tts = True
|
||||
except (ImportError, OSError):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if use_streaming_tts:
|
||||
text_queue = queue.Queue()
|
||||
stop_event = threading.Event()
|
||||
|
||||
def display_callback(sentence: str):
|
||||
"""Called by TTS consumer when a sentence is ready to display + speak."""
|
||||
nonlocal _streaming_box_opened
|
||||
if not _streaming_box_opened:
|
||||
_streaming_box_opened = True
|
||||
w = self.console.width
|
||||
label = " ⚕ Hermes "
|
||||
fill = w - 2 - len(label)
|
||||
_cprint(f"\n{_GOLD}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}")
|
||||
_cprint(sentence.rstrip())
|
||||
|
||||
tts_thread = threading.Thread(
|
||||
target=stream_tts_to_speaker,
|
||||
args=(text_queue, stop_event, self._voice_tts_done),
|
||||
kwargs={"display_callback": display_callback},
|
||||
daemon=True,
|
||||
)
|
||||
tts_thread.start()
|
||||
|
||||
def stream_callback(delta: str):
|
||||
if text_queue is not None:
|
||||
text_queue.put(delta)
|
||||
|
||||
# When voice mode is active, prepend a brief instruction so the
|
||||
# model responds concisely. The prefix is API-call-local only —
|
||||
# we strip it from the returned history so it never persists to
|
||||
# session DB or resumed sessions.
|
||||
_voice_prefix = ""
|
||||
if self._voice_mode and isinstance(message, str):
|
||||
_voice_prefix = (
|
||||
"[Voice input — respond concisely and conversationally, "
|
||||
"2-3 sentences max. No code blocks or markdown.] "
|
||||
)
|
||||
|
||||
def run_agent():
|
||||
nonlocal result
|
||||
agent_message = _voice_prefix + message if _voice_prefix else message
|
||||
result = self.agent.run_conversation(
|
||||
user_message=message,
|
||||
user_message=agent_message,
|
||||
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
|
||||
stream_callback=stream_callback,
|
||||
task_id=self.session_id,
|
||||
)
|
||||
|
||||
|
||||
# Start agent in background thread
|
||||
agent_thread = threading.Thread(target=run_agent)
|
||||
agent_thread.start()
|
||||
|
||||
|
||||
# Monitor the dedicated interrupt queue while the agent runs.
|
||||
# _interrupt_queue is separate from _pending_input, so process_loop
|
||||
# and chat() never compete for the same queue.
|
||||
|
|
@ -3783,6 +4260,9 @@ class HermesCLI:
|
|||
if self._clarify_state or self._clarify_freetext:
|
||||
continue
|
||||
print(f"\n⚡ New message detected, interrupting...")
|
||||
# Signal TTS to stop on interrupt
|
||||
if stop_event is not None:
|
||||
stop_event.set()
|
||||
self.agent.interrupt(interrupt_msg)
|
||||
# Debug: log to file (stdout may be devnull from redirect_stdout)
|
||||
try:
|
||||
|
|
@ -3802,9 +4282,15 @@ class HermesCLI:
|
|||
else:
|
||||
# Fallback for non-interactive mode (e.g., single-query)
|
||||
agent_thread.join(0.1)
|
||||
|
||||
|
||||
agent_thread.join() # Ensure agent thread completes
|
||||
|
||||
# Signal end-of-text to TTS consumer and wait for it to finish
|
||||
if use_streaming_tts and text_queue is not None:
|
||||
text_queue.put(None) # sentinel
|
||||
if tts_thread is not None:
|
||||
tts_thread.join(timeout=120)
|
||||
|
||||
# Drain any remaining agent output still in the StdoutProxy
|
||||
# buffer so tool/status lines render ABOVE our response box.
|
||||
# The flush pushes data into the renderer queue; the short
|
||||
|
|
@ -3815,15 +4301,29 @@ class HermesCLI:
|
|||
|
||||
# Update history with full conversation
|
||||
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
||||
|
||||
|
||||
# Strip voice prefix from history so it never persists
|
||||
if _voice_prefix and self.conversation_history:
|
||||
for msg in self.conversation_history:
|
||||
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
|
||||
if msg["content"].startswith(_voice_prefix):
|
||||
msg["content"] = msg["content"][len(_voice_prefix):]
|
||||
|
||||
# Get the final response
|
||||
response = result.get("final_response", "") if result else ""
|
||||
|
||||
# Handle failed results (e.g., non-retryable errors like invalid model)
|
||||
if result and result.get("failed") and not response:
|
||||
|
||||
# Handle failed or partial results (e.g., non-retryable errors, rate limits,
|
||||
# truncated output, invalid tool calls). Both "failed" and "partial" with
|
||||
# an empty final_response mean the agent couldn't produce a usable answer.
|
||||
if result and (result.get("failed") or result.get("partial")) and not response:
|
||||
error_detail = result.get("error", "Unknown error")
|
||||
response = f"Error: {error_detail}"
|
||||
|
||||
# Stop continuous voice mode on persistent errors (e.g. 429 rate limit)
|
||||
# to avoid an infinite error → record → error loop
|
||||
if self._voice_continuous:
|
||||
self._voice_continuous = False
|
||||
_cprint(f"\n{_DIM}Continuous voice mode stopped due to error.{_RST}")
|
||||
|
||||
# Handle interrupt - check if we were interrupted
|
||||
pending_message = None
|
||||
if result and result.get("interrupted"):
|
||||
|
|
@ -3831,8 +4331,9 @@ class HermesCLI:
|
|||
# Add indicator that we were interrupted
|
||||
if response and pending_message:
|
||||
response = response + "\n\n---\n_[Interrupted - processing new message]_"
|
||||
|
||||
|
||||
response_previewed = result.get("response_previewed", False) if result else False
|
||||
|
||||
# Display reasoning (thinking) box if enabled and available
|
||||
if self.show_reasoning and result:
|
||||
reasoning = result.get("last_reasoning")
|
||||
|
|
@ -3852,8 +4353,7 @@ class HermesCLI:
|
|||
_cprint(f"\n{r_top}\n{_DIM}{display_reasoning}{_RST}\n{r_bot}")
|
||||
|
||||
if response and not response_previewed:
|
||||
# Use a Rich Panel for the response box — adapts to terminal
|
||||
# width at render time instead of hard-coding border length.
|
||||
# Use skin engine for label/color with fallback
|
||||
try:
|
||||
from hermes_cli.skin_engine import get_active_skin
|
||||
_skin = get_active_skin()
|
||||
|
|
@ -3865,23 +4365,40 @@ class HermesCLI:
|
|||
_resp_color = "#CD7F32"
|
||||
_resp_text = "#FFF8DC"
|
||||
|
||||
_chat_console = ChatConsole()
|
||||
_chat_console.print(Panel(
|
||||
_rich_text_from_ansi(response),
|
||||
title=f"[{_resp_color} bold]{label}[/]",
|
||||
title_align="left",
|
||||
border_style=_resp_color,
|
||||
style=_resp_text,
|
||||
box=rich_box.HORIZONTALS,
|
||||
padding=(1, 2),
|
||||
))
|
||||
is_error_response = result and (result.get("failed") or result.get("partial"))
|
||||
if use_streaming_tts and _streaming_box_opened and not is_error_response:
|
||||
# Text was already printed sentence-by-sentence; just close the box
|
||||
w = shutil.get_terminal_size().columns
|
||||
_cprint(f"\n{_GOLD}╰{'─' * (w - 2)}╯{_RST}")
|
||||
else:
|
||||
_chat_console = ChatConsole()
|
||||
_chat_console.print(Panel(
|
||||
_rich_text_from_ansi(response),
|
||||
title=f"[{_resp_color} bold]{label}[/]",
|
||||
title_align="left",
|
||||
border_style=_resp_color,
|
||||
style=_resp_text,
|
||||
box=rich_box.HORIZONTALS,
|
||||
padding=(1, 2),
|
||||
))
|
||||
|
||||
|
||||
# Play terminal bell when agent finishes (if enabled).
|
||||
# Works over SSH — the bell propagates to the user's terminal.
|
||||
if self.bell_on_complete:
|
||||
sys.stdout.write("\a")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
# Speak response aloud if voice TTS is enabled
|
||||
# Skip batch TTS when streaming TTS already handled it
|
||||
if self._voice_tts and response and not use_streaming_tts:
|
||||
threading.Thread(
|
||||
target=self._voice_speak_response,
|
||||
args=(response,),
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
|
||||
# Combine all interrupt messages (user may have typed multiple while waiting)
|
||||
# and re-queue as one prompt for process_loop
|
||||
if pending_message and hasattr(self, '_pending_input'):
|
||||
|
|
@ -3902,6 +4419,20 @@ class HermesCLI:
|
|||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return None
|
||||
finally:
|
||||
# Ensure streaming TTS resources are cleaned up even on error.
|
||||
# Normal path sends the sentinel at line ~3568; this is a safety
|
||||
# net for exception paths that skip it. Duplicate sentinels are
|
||||
# harmless — stream_tts_to_speaker exits on the first None.
|
||||
if text_queue is not None:
|
||||
try:
|
||||
text_queue.put_nowait(None)
|
||||
except Exception:
|
||||
pass
|
||||
if stop_event is not None:
|
||||
stop_event.set()
|
||||
if tts_thread is not None and tts_thread.is_alive():
|
||||
tts_thread.join(timeout=5)
|
||||
|
||||
def _print_exit_summary(self):
|
||||
"""Print session resume info on exit, similar to Claude Code."""
|
||||
|
|
@ -3961,9 +4492,26 @@ class HermesCLI:
|
|||
# Icon-only custom prompts should still remain visible in special states.
|
||||
return symbol, symbol
|
||||
|
||||
def _audio_level_bar(self) -> str:
|
||||
"""Return a visual audio level indicator based on current RMS."""
|
||||
_LEVEL_BARS = " ▁▂▃▄▅▆▇"
|
||||
rec = getattr(self, "_voice_recorder", None)
|
||||
if rec is None:
|
||||
return ""
|
||||
rms = rec.current_rms
|
||||
# Normalize RMS (0-32767) to 0-7 index, with log-ish scaling
|
||||
# Typical speech RMS is 500-5000, we cap display at ~8000
|
||||
level = min(rms, 8000) * 7 // 8000
|
||||
return _LEVEL_BARS[level]
|
||||
|
||||
def _get_tui_prompt_fragments(self):
|
||||
"""Return the prompt_toolkit fragments for the current interactive state."""
|
||||
symbol, state_suffix = self._get_tui_prompt_symbols()
|
||||
if self._voice_recording:
|
||||
bar = self._audio_level_bar()
|
||||
return [("class:voice-recording", f"● {bar} {state_suffix}")]
|
||||
if self._voice_processing:
|
||||
return [("class:voice-processing", f"◉ {state_suffix}")]
|
||||
if self._sudo_state:
|
||||
return [("class:sudo-prompt", f"🔐 {state_suffix}")]
|
||||
if self._secret_state:
|
||||
|
|
@ -3978,6 +4526,8 @@ class HermesCLI:
|
|||
return [("class:prompt-working", f"{self._command_spinner_frame()} {state_suffix}")]
|
||||
if self._agent_running:
|
||||
return [("class:prompt-working", f"⚕ {state_suffix}")]
|
||||
if self._voice_mode:
|
||||
return [("class:voice-prompt", f"🎤 {state_suffix}")]
|
||||
return [("class:prompt", symbol)]
|
||||
|
||||
def _get_tui_prompt_text(self) -> str:
|
||||
|
|
@ -4070,6 +4620,17 @@ class HermesCLI:
|
|||
self._attached_images: list[Path] = []
|
||||
self._image_counter = 0
|
||||
|
||||
# Voice mode state (protected by _voice_lock for cross-thread access)
|
||||
self._voice_lock = threading.Lock()
|
||||
self._voice_mode = False # Whether voice mode is enabled
|
||||
self._voice_tts = False # Whether TTS output is enabled
|
||||
self._voice_recorder = None # AudioRecorder instance (lazy init)
|
||||
self._voice_recording = False # Whether currently recording
|
||||
self._voice_processing = False # Whether STT is in progress
|
||||
self._voice_continuous = False # Whether to auto-restart after agent responds
|
||||
self._voice_tts_done = threading.Event() # Signals TTS playback finished
|
||||
self._voice_tts_done.set() # Initially "done" (no TTS pending)
|
||||
|
||||
# Register callbacks so terminal_tool prompts route through our UI
|
||||
set_sudo_password_callback(self._sudo_password_callback)
|
||||
set_approval_callback(self._approval_callback)
|
||||
|
|
@ -4254,6 +4815,7 @@ class HermesCLI:
|
|||
"""Handle Ctrl+C - cancel interactive prompts, interrupt agent, or exit.
|
||||
|
||||
Priority:
|
||||
0. Cancel active voice recording
|
||||
1. Cancel active sudo/approval/clarify prompt
|
||||
2. Interrupt the running agent (first press)
|
||||
3. Force exit (second press within 2s, or when idle)
|
||||
|
|
@ -4261,6 +4823,25 @@ class HermesCLI:
|
|||
import time as _time
|
||||
now = _time.time()
|
||||
|
||||
# Cancel active voice recording.
|
||||
# Run cancel() in a background thread to prevent blocking the
|
||||
# event loop if AudioRecorder._lock or CoreAudio takes time.
|
||||
_should_cancel_voice = False
|
||||
_recorder_ref = None
|
||||
with cli_ref._voice_lock:
|
||||
if cli_ref._voice_recording and cli_ref._voice_recorder:
|
||||
_recorder_ref = cli_ref._voice_recorder
|
||||
cli_ref._voice_recording = False
|
||||
cli_ref._voice_continuous = False
|
||||
_should_cancel_voice = True
|
||||
if _should_cancel_voice:
|
||||
_cprint(f"\n{_DIM}Recording cancelled.{_RST}")
|
||||
threading.Thread(
|
||||
target=_recorder_ref.cancel, daemon=True
|
||||
).start()
|
||||
event.app.invalidate()
|
||||
return
|
||||
|
||||
# Cancel sudo prompt
|
||||
if self._sudo_state:
|
||||
self._sudo_state["response_queue"].put("")
|
||||
|
|
@ -4321,6 +4902,75 @@ class HermesCLI:
|
|||
self._should_exit = True
|
||||
event.app.exit()
|
||||
|
||||
# Voice push-to-talk key: configurable via config.yaml (voice.record_key)
|
||||
# Default: Ctrl+B (avoids conflict with Ctrl+R readline reverse-search)
|
||||
# Config uses "ctrl+b" format; prompt_toolkit expects "c-b" format.
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
_raw_key = load_config().get("voice", {}).get("record_key", "ctrl+b")
|
||||
_voice_key = _raw_key.lower().replace("ctrl+", "c-").replace("alt+", "a-")
|
||||
except Exception:
|
||||
_voice_key = "c-b"
|
||||
|
||||
@kb.add(_voice_key)
def handle_voice_record(event):
    """Toggle voice recording when voice mode is active.

    IMPORTANT: This handler runs in prompt_toolkit's event-loop thread.
    Any blocking call here (locks, sd.wait, disk I/O) freezes the
    entire UI. All heavy work is dispatched to daemon threads.

    Flow:
      - recording in progress  -> stop + transcribe (allowed even while
        the agent is running, so the user can always end a capture)
      - idle                   -> start a new recording, unless the agent
        or an interactive prompt currently owns the input
    """
    # Key does nothing unless voice mode has been enabled.
    if not cli_ref._voice_mode:
        return
    # Always allow STOPPING a recording (even when agent is running)
    if cli_ref._voice_recording:
        # Manual stop via push-to-talk key: stop continuous mode
        with cli_ref._voice_lock:
            cli_ref._voice_continuous = False
        # Flag clearing is handled atomically inside _voice_stop_and_transcribe
        event.app.invalidate()
        # stop()+STT can block for seconds; never run it on the UI thread.
        threading.Thread(
            target=cli_ref._voice_stop_and_transcribe,
            daemon=True,
        ).start()
    else:
        # Guard: don't START recording during agent run or interactive prompts
        if cli_ref._agent_running:
            return
        if cli_ref._clarify_state or cli_ref._sudo_state or cli_ref._approval_state:
            return
        # Guard: don't start while a previous stop/transcribe cycle is
        # still running — recorder.stop() holds AudioRecorder._lock and
        # start() would block the event-loop thread waiting for it.
        if cli_ref._voice_processing:
            return

        # Interrupt TTS if playing, so user can start talking.
        # stop_playback() is fast (just terminates a subprocess).
        if not cli_ref._voice_tts_done.is_set():
            try:
                from tools.voice_mode import stop_playback
                stop_playback()
                cli_ref._voice_tts_done.set()
            except Exception:
                pass

        # Key-initiated recording implies continuous (hands-free) mode.
        with cli_ref._voice_lock:
            cli_ref._voice_continuous = True

        # Dispatch to a daemon thread so play_beep(sd.wait),
        # AudioRecorder.start(lock acquire), and config I/O
        # never block the prompt_toolkit event loop.
        def _start_recording():
            try:
                cli_ref._voice_start_recording()
                # _app may not exist yet during early startup; guard access.
                if hasattr(cli_ref, '_app') and cli_ref._app:
                    cli_ref._app.invalidate()
            except Exception as e:
                _cprint(f"\n{_DIM}Voice recording failed: {e}{_RST}")

        threading.Thread(target=_start_recording, daemon=True).start()
        event.app.invalidate()
|
||||
from prompt_toolkit.keys import Keys
|
||||
|
||||
@kb.add(Keys.BracketedPaste, eager=True)
|
||||
|
|
@ -4460,6 +5110,10 @@ class HermesCLI:
|
|||
return Transformation(fragments=ti.fragments)
|
||||
|
||||
def _get_placeholder():
|
||||
if cli_ref._voice_recording:
|
||||
return "recording... Ctrl+B to stop, Ctrl+C to cancel"
|
||||
if cli_ref._voice_processing:
|
||||
return "transcribing..."
|
||||
if cli_ref._sudo_state:
|
||||
return "type password (hidden), Enter to skip"
|
||||
if cli_ref._secret_state:
|
||||
|
|
@ -4476,6 +5130,8 @@ class HermesCLI:
|
|||
return f"{frame} {status}"
|
||||
if cli_ref._agent_running:
|
||||
return "type a message + Enter to interrupt, Ctrl+C to cancel"
|
||||
if cli_ref._voice_mode:
|
||||
return "type or Ctrl+B to record"
|
||||
return ""
|
||||
|
||||
input_area.control.input_processors.append(_PlaceholderProcessor(_get_placeholder))
|
||||
|
|
@ -4813,6 +5469,24 @@ class HermesCLI:
|
|||
height=Condition(lambda: bool(cli_ref._attached_images)),
|
||||
)
|
||||
|
||||
# Persistent voice mode status bar (visible only when voice mode is on)
|
||||
def _get_voice_status():
    """Formatted-text fragments for the persistent voice-mode status bar.

    Shown only while voice mode is on; reflects the current phase
    (recording / transcribing / idle) plus active TTS and continuous flags.
    """
    if cli_ref._voice_recording:
        return [('class:voice-status-recording', ' ● REC Ctrl+B to stop ')]
    if cli_ref._voice_processing:
        return [('class:voice-status', ' ◉ Transcribing... ')]
    # Idle: append optional mode flags to the base label.
    flags = ""
    if cli_ref._voice_tts:
        flags += " | TTS on"
    if cli_ref._voice_continuous:
        flags += " | Continuous"
    return [('class:voice-status', f' 🎤 Voice mode{flags} — Ctrl+B to record ')]
|
||||
|
||||
voice_status_bar = ConditionalContainer(
|
||||
Window(
|
||||
FormattedTextControl(_get_voice_status),
|
||||
height=1,
|
||||
),
|
||||
filter=Condition(lambda: cli_ref._voice_mode),
|
||||
)
|
||||
|
||||
# Layout: interactive prompt widgets + ruled input at bottom.
|
||||
# The sudo, approval, and clarify widgets appear above the input when
|
||||
# the corresponding interactive prompt is active.
|
||||
|
|
@ -4829,6 +5503,7 @@ class HermesCLI:
|
|||
image_bar,
|
||||
input_area,
|
||||
input_rule_bot,
|
||||
voice_status_bar,
|
||||
CompletionsMenu(max_height=12, scroll_offset=1),
|
||||
])
|
||||
)
|
||||
|
|
@ -4869,6 +5544,12 @@ class HermesCLI:
|
|||
'approval-cmd': '#AAAAAA italic',
|
||||
'approval-choice': '#AAAAAA',
|
||||
'approval-selected': '#FFD700 bold',
|
||||
# Voice mode
|
||||
'voice-prompt': '#87CEEB',
|
||||
'voice-recording': '#FF4444 bold',
|
||||
'voice-processing': '#FFA500 italic',
|
||||
'voice-status': 'bg:#1a1a2e #87CEEB',
|
||||
'voice-status-recording': 'bg:#1a1a2e #FF4444 bold',
|
||||
}
|
||||
style = PTStyle.from_dict(self._build_tui_style_dict())
|
||||
|
||||
|
|
@ -4961,13 +5642,29 @@ class HermesCLI:
|
|||
# Regular chat - run agent
|
||||
self._agent_running = True
|
||||
app.invalidate() # Refresh status line
|
||||
|
||||
|
||||
try:
|
||||
self.chat(user_input, images=submit_images or None)
|
||||
finally:
|
||||
self._agent_running = False
|
||||
self._spinner_text = ""
|
||||
app.invalidate() # Refresh status line
|
||||
|
||||
# Continuous voice: auto-restart recording after agent responds.
|
||||
# Dispatch to a daemon thread so play_beep (sd.wait) and
|
||||
# AudioRecorder.start (lock acquire) never block process_loop —
|
||||
# otherwise queued user input would stall silently.
|
||||
if self._voice_mode and self._voice_continuous and not self._voice_recording:
|
||||
def _restart_recording():
|
||||
try:
|
||||
if self._voice_tts:
|
||||
self._voice_tts_done.wait(timeout=60)
|
||||
time.sleep(0.3)
|
||||
self._voice_start_recording()
|
||||
app.invalidate()
|
||||
except Exception as e:
|
||||
_cprint(f"{_DIM}Voice auto-restart failed: {e}{_RST}")
|
||||
threading.Thread(target=_restart_recording, daemon=True).start()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
|
@ -4993,6 +5690,19 @@ class HermesCLI:
|
|||
self.agent.flush_memories(self.conversation_history)
|
||||
except Exception:
|
||||
pass
|
||||
# Shut down voice recorder (release persistent audio stream)
|
||||
if hasattr(self, '_voice_recorder') and self._voice_recorder:
|
||||
try:
|
||||
self._voice_recorder.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
self._voice_recorder = None
|
||||
# Clean up old temp voice recordings
|
||||
try:
|
||||
from tools.voice_mode import cleanup_temp_recordings
|
||||
cleanup_temp_recordings()
|
||||
except Exception:
|
||||
pass
|
||||
# Unregister callbacks to avoid dangling references
|
||||
set_sudo_password_callback(None)
|
||||
set_approval_callback(None)
|
||||
|
|
|
|||
|
|
@ -351,6 +351,8 @@ class BasePlatformAdapter(ABC):
|
|||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
# Chats where auto-TTS on voice input is disabled (set by /voice off)
|
||||
self._auto_tts_disabled_chats: set = set()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -537,6 +539,20 @@ class BasePlatformAdapter(ABC):
|
|||
text = f"{caption}\n{text}"
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
async def play_tts(
    self,
    chat_id: str,
    audio_path: str,
    **kwargs,
) -> SendResult:
    """Deliver auto-TTS audio for a voice reply.

    Subclasses may override this for invisible playback (e.g. the Web
    UI). This base implementation simply delegates to ``send_voice``,
    which presents the audio as a normal player in the chat.
    """
    result = await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs)
    return result
|
||||
|
||||
async def send_video(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
|
@ -724,7 +740,43 @@ class BasePlatformAdapter(ABC):
|
|||
if images:
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
# Auto-TTS: if voice message, generate audio FIRST (before sending text)
|
||||
# Skipped when the chat has voice mode disabled (/voice off)
|
||||
_tts_path = None
|
||||
if (event.message_type == MessageType.VOICE
|
||||
and text_content
|
||||
and not media_files
|
||||
and event.source.chat_id not in self._auto_tts_disabled_chats):
|
||||
try:
|
||||
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
|
||||
if check_tts_requirements():
|
||||
import json as _json
|
||||
speech_text = re.sub(r'[*_`#\[\]()]', '', text_content)[:4000].strip()
|
||||
if not speech_text:
|
||||
raise ValueError("Empty text after markdown cleanup")
|
||||
tts_result_str = await asyncio.to_thread(
|
||||
text_to_speech_tool, text=speech_text
|
||||
)
|
||||
tts_data = _json.loads(tts_result_str)
|
||||
_tts_path = tts_data.get("file_path")
|
||||
except Exception as tts_err:
|
||||
logger.warning("[%s] Auto-TTS failed: %s", self.name, tts_err)
|
||||
|
||||
# Play TTS audio before text (voice-first experience)
|
||||
if _tts_path and Path(_tts_path).exists():
|
||||
try:
|
||||
await self.play_tts(
|
||||
chat_id=event.source.chat_id,
|
||||
audio_path=_tts_path,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
os.remove(_tts_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Send the text portion
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self.send(
|
||||
|
|
@ -733,7 +785,7 @@ class BasePlatformAdapter(ABC):
|
|||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
|
|
@ -746,10 +798,10 @@ class BasePlatformAdapter(ABC):
|
|||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
# Send extracted images as native attachments
|
||||
if images:
|
||||
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
|
||||
|
|
@ -777,7 +829,7 @@ class BasePlatformAdapter(ABC):
|
|||
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
|
||||
except Exception as img_err:
|
||||
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
|
||||
|
||||
|
||||
# Send extracted media files — route by file type
|
||||
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
||||
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,13 @@ Uses discord.py library for:
|
|||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
import struct
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -65,6 +71,299 @@ def check_discord_requirements() -> bool:
|
|||
return DISCORD_AVAILABLE
|
||||
|
||||
|
||||
class VoiceReceiver:
    """Captures and decodes voice audio from a Discord voice channel.

    Attaches to a VoiceClient's socket listener, decrypts RTP packets
    (NaCl transport + DAVE E2EE), decodes Opus to PCM, and buffers
    per-user audio. A polling loop detects silence and delivers
    completed utterances via a callback.

    Threading: ``_on_packet`` is invoked from discord.py's SocketReader
    thread, while ``check_silence`` is polled from elsewhere; shared
    maps are guarded by ``_lock``.
    NOTE(review): ``_decoders`` and the per-ssrc buffer reads in
    ``check_silence`` are touched outside the lock — appears to rely on
    CPython container atomicity; confirm before restructuring.
    """

    SILENCE_THRESHOLD = 1.5  # seconds of silence → end of utterance
    MIN_SPEECH_DURATION = 0.5  # minimum seconds to process (skip noise)
    SAMPLE_RATE = 48000  # Discord native rate
    CHANNELS = 2  # Discord sends stereo

    def __init__(self, voice_client):
        # VoiceClient we attach to; its _connection exposes keys/socket.
        self._vc = voice_client
        self._running = False

        # Decryption material, captured from the connection in start()
        self._secret_key: Optional[bytes] = None
        self._dave_session = None
        self._bot_ssrc: int = 0

        # SSRC -> user_id mapping (populated from SPEAKING events)
        self._ssrc_to_user: Dict[int, int] = {}
        self._lock = threading.Lock()

        # Per-user audio buffers (raw PCM) and last-packet timestamps
        self._buffers: Dict[int, bytearray] = defaultdict(bytearray)
        self._last_packet_time: Dict[int, float] = {}

        # Opus decoder per SSRC (each user needs own decoder state)
        self._decoders: Dict[int, object] = {}

        # Pause flag: don't capture while bot is playing TTS
        self._paused = False

        # Debug logging counter (instance-level to avoid cross-instance races)
        self._packet_debug_count = 0

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self):
        """Start listening for voice packets.

        Snapshots the transport secret key, DAVE session, and the bot's
        own SSRC from the voice connection, installs the SPEAKING-event
        hook, and registers the raw-packet socket listener.
        """
        conn = self._vc._connection
        self._secret_key = bytes(conn.secret_key)
        self._dave_session = conn.dave_session
        self._bot_ssrc = conn.ssrc

        self._install_speaking_hook(conn)
        conn.add_socket_listener(self._on_packet)
        self._running = True
        logger.info("VoiceReceiver started (bot_ssrc=%d)", self._bot_ssrc)

    def stop(self):
        """Stop listening and clean up all per-user state."""
        self._running = False
        try:
            self._vc._connection.remove_socket_listener(self._on_packet)
        except Exception:
            # Connection may already be torn down; removal is best-effort.
            pass
        with self._lock:
            self._buffers.clear()
            self._last_packet_time.clear()
            self._decoders.clear()
            self._ssrc_to_user.clear()
        logger.info("VoiceReceiver stopped")

    def pause(self):
        # Suspend capture (used while the bot itself is playing TTS).
        self._paused = True

    def resume(self):
        # Resume capture after TTS playback ends.
        self._paused = False

    # ------------------------------------------------------------------
    # SSRC -> user_id mapping via SPEAKING opcode hook
    # ------------------------------------------------------------------

    def map_ssrc(self, ssrc: int, user_id: int):
        """Record which Discord user owns a given RTP SSRC."""
        with self._lock:
            self._ssrc_to_user[ssrc] = user_id

    def _install_speaking_hook(self, conn):
        """Wrap the voice websocket hook to capture SPEAKING events (op 5).

        VoiceConnectionState stores the hook as ``conn.hook`` (public attr).
        It is passed to DiscordVoiceWebSocket on each (re)connect, so we
        must wrap it on the VoiceConnectionState level AND on the current
        live websocket instance.
        """
        original_hook = conn.hook
        receiver_self = self

        async def wrapped_hook(ws, msg):
            # Opcode 5 = SPEAKING; carries the ssrc/user_id pairing.
            if isinstance(msg, dict) and msg.get("op") == 5:
                data = msg.get("d", {})
                ssrc = data.get("ssrc")
                user_id = data.get("user_id")
                if ssrc and user_id:
                    logger.info("SPEAKING event: ssrc=%d -> user=%s", ssrc, user_id)
                    receiver_self.map_ssrc(int(ssrc), int(user_id))
            # Always delegate to the original hook so discord.py keeps working.
            if original_hook:
                await original_hook(ws, msg)

        # Set on connection state (for future reconnects)
        conn.hook = wrapped_hook
        # Set on the current live websocket (for immediate effect)
        try:
            from discord.utils import MISSING
            if hasattr(conn, 'ws') and conn.ws is not MISSING:
                conn.ws._hook = wrapped_hook
                logger.info("Speaking hook installed on live websocket")
        except Exception as e:
            logger.warning("Could not install hook on live ws: %s", e)

    # ------------------------------------------------------------------
    # Packet handler (called from SocketReader thread)
    # ------------------------------------------------------------------

    def _on_packet(self, data: bytes):
        """Decrypt, decode, and buffer one raw UDP packet.

        Pipeline: RTP header parse -> NaCl transport decrypt -> optional
        DAVE E2EE decrypt -> Opus decode -> append PCM to the per-SSRC
        buffer. Invalid / foreign / undecryptable packets are dropped.
        """
        if not self._running or self._paused:
            return

        # Log first few raw packets for debugging
        self._packet_debug_count += 1
        if self._packet_debug_count <= 5:
            logger.debug(
                "Raw UDP packet: len=%d, first_bytes=%s",
                len(data), data[:4].hex() if len(data) >= 4 else "short",
            )

        # Too short to hold an RTP header + nonce tail.
        if len(data) < 16:
            return

        # RTP version check: top 2 bits must be 10 (version 2).
        # Lower bits may vary (padding, extension, CSRC count).
        # Payload type (byte 1 lower 7 bits) = 0x78 (120) for voice.
        if (data[0] >> 6) != 2 or (data[1] & 0x7F) != 0x78:
            if self._packet_debug_count <= 5:
                logger.debug("Skipped non-RTP: byte0=0x%02x byte1=0x%02x", data[0], data[1])
            return

        first_byte = data[0]
        _, _, seq, timestamp, ssrc = struct.unpack_from(">BBHII", data, 0)

        # Skip bot's own audio
        if ssrc == self._bot_ssrc:
            return

        # Calculate dynamic RTP header size (RFC 9335 / rtpsize mode)
        cc = first_byte & 0x0F  # CSRC count
        has_extension = bool(first_byte & 0x10)  # extension bit
        header_size = 12 + (4 * cc) + (4 if has_extension else 0)

        if len(data) < header_size + 4:  # need at least header + nonce
            return

        # Read extension length from preamble (for skipping after decrypt)
        ext_data_len = 0
        if has_extension:
            ext_preamble_offset = 12 + (4 * cc)
            ext_words = struct.unpack_from(">H", data, ext_preamble_offset + 2)[0]
            ext_data_len = ext_words * 4

        if self._packet_debug_count <= 10:
            with self._lock:
                known_user = self._ssrc_to_user.get(ssrc, "unknown")
            logger.debug(
                "RTP packet: ssrc=%d, seq=%d, user=%s, hdr=%d, ext_data=%d",
                ssrc, seq, known_user, header_size, ext_data_len,
            )

        header = bytes(data[:header_size])
        payload_with_nonce = data[header_size:]

        # --- NaCl transport decrypt (aead_xchacha20_poly1305_rtpsize) ---
        # The 4-byte nonce counter is appended to the ciphertext; it is
        # expanded into a zero-padded 24-byte XChaCha20 nonce.
        if len(payload_with_nonce) < 4:
            return
        nonce = bytearray(24)
        nonce[:4] = payload_with_nonce[-4:]
        encrypted = bytes(payload_with_nonce[:-4])

        try:
            import nacl.secret  # noqa: delayed import – only in voice path
            box = nacl.secret.Aead(self._secret_key)
            # RTP header doubles as AEAD associated data.
            decrypted = box.decrypt(encrypted, header, bytes(nonce))
        except Exception as e:
            if self._packet_debug_count <= 10:
                logger.warning("NaCl decrypt failed: %s (hdr=%d, enc=%d)", e, header_size, len(encrypted))
            return

        # Skip encrypted extension data to get the actual opus payload
        if ext_data_len and len(decrypted) > ext_data_len:
            decrypted = decrypted[ext_data_len:]

        # --- DAVE E2EE decrypt ---
        # Only possible once the SSRC->user mapping is known, since DAVE
        # decryption is keyed per sender user_id.
        if self._dave_session:
            with self._lock:
                user_id = self._ssrc_to_user.get(ssrc, 0)
            if user_id == 0:
                if self._packet_debug_count <= 10:
                    logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc)
                return  # unknown user, can't DAVE-decrypt
            try:
                import davey
                decrypted = self._dave_session.decrypt(
                    user_id, davey.MediaType.audio, decrypted
                )
            except Exception as e:
                if self._packet_debug_count <= 10:
                    logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e)
                return

        # --- Opus decode -> PCM ---
        try:
            # One decoder per SSRC: Opus decoders are stateful per stream.
            if ssrc not in self._decoders:
                self._decoders[ssrc] = discord.opus.Decoder()
            pcm = self._decoders[ssrc].decode(decrypted)
            with self._lock:
                self._buffers[ssrc].extend(pcm)
                self._last_packet_time[ssrc] = time.monotonic()
        except Exception as e:
            logger.debug("Opus decode error for SSRC %s: %s", ssrc, e)
            return

    # ------------------------------------------------------------------
    # Silence detection
    # ------------------------------------------------------------------

    def check_silence(self) -> list:
        """Return list of (user_id, pcm_bytes) for completed utterances.

        An utterance is complete when its SSRC has been silent for
        SILENCE_THRESHOLD seconds and the buffered audio is at least
        MIN_SPEECH_DURATION long. Buffers whose SSRC never resolved to
        a user are discarded after twice the silence threshold.
        """
        now = time.monotonic()
        completed = []

        # Snapshot shared maps under the lock; iterate outside it.
        with self._lock:
            ssrc_user_map = dict(self._ssrc_to_user)
            ssrc_list = list(self._buffers.keys())

        for ssrc in ssrc_list:
            last_time = self._last_packet_time.get(ssrc, now)
            silence_duration = now - last_time
            buf = self._buffers[ssrc]
            # 48kHz, 16-bit, stereo = 192000 bytes/sec
            buf_duration = len(buf) / (self.SAMPLE_RATE * self.CHANNELS * 2)

            if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION:
                user_id = ssrc_user_map.get(ssrc, 0)
                if user_id:
                    completed.append((user_id, bytes(buf)))
                    self._buffers[ssrc] = bytearray()
                    self._last_packet_time.pop(ssrc, None)
                elif silence_duration >= self.SILENCE_THRESHOLD * 2:
                    # Stale buffer with no valid user — discard
                    self._buffers.pop(ssrc, None)
                    self._last_packet_time.pop(ssrc, None)

        return completed

    # ------------------------------------------------------------------
    # PCM -> WAV conversion (for Whisper STT)
    # ------------------------------------------------------------------

    @staticmethod
    def pcm_to_wav(pcm_data: bytes, output_path: str,
                   src_rate: int = 48000, src_channels: int = 2):
        """Convert raw PCM to 16kHz mono WAV via ffmpeg.

        Writes the PCM to a temp file, shells out to ffmpeg to resample
        s16le stereo@48k down to mono@16k (Whisper's preferred input),
        and always removes the temp file afterwards.
        """
        with tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as f:
            f.write(pcm_data)
            pcm_path = f.name
        try:
            subprocess.run(
                [
                    "ffmpeg", "-y", "-loglevel", "error",
                    "-f", "s16le",
                    "-ar", str(src_rate),
                    "-ac", str(src_channels),
                    "-i", pcm_path,
                    "-ar", "16000",
                    "-ac", "1",
                    output_path,
                ],
                check=True,
                timeout=10,
            )
        finally:
            try:
                os.unlink(pcm_path)
            except OSError:
                pass
|
||||
|
||||
|
||||
class DiscordAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Discord bot adapter.
|
||||
|
|
@ -82,17 +381,54 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
|
||||
# Auto-disconnect from voice channel after this many seconds of inactivity
|
||||
VOICE_TIMEOUT = 300
|
||||
|
||||
def __init__(self, config: PlatformConfig):
    """Initialize adapter state; no network I/O happens until connect()."""
    super().__init__(config, Platform.DISCORD)
    self._client: Optional[commands.Bot] = None  # discord.py bot, created in connect()
    self._ready_event = asyncio.Event()  # set once the gateway reports ready
    self._allowed_user_ids: set = set()  # For button approval authorization
    # Voice channel state (per-guild)
    self._voice_clients: Dict[int, Any] = {}  # guild_id -> VoiceClient
    self._voice_text_channels: Dict[int, int] = {}  # guild_id -> text_channel_id
    self._voice_timeout_tasks: Dict[int, asyncio.Task] = {}  # guild_id -> timeout task
    # Phase 2: voice listening
    self._voice_receivers: Dict[int, VoiceReceiver] = {}  # guild_id -> VoiceReceiver
    self._voice_listen_tasks: Dict[int, asyncio.Task] = {}  # guild_id -> listen loop
    self._voice_input_callback: Optional[Callable] = None  # set by run.py
    self._on_voice_disconnect: Optional[Callable] = None  # set by run.py
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
logger.error("[%s] discord.py not installed. Run: pip install discord.py", self.name)
|
||||
return False
|
||||
|
||||
# Load opus codec for voice channel support
|
||||
if not discord.opus.is_loaded():
|
||||
import ctypes.util
|
||||
opus_path = ctypes.util.find_library("opus")
|
||||
# ctypes.util.find_library fails on macOS with Homebrew-installed libs,
|
||||
# so fall back to known Homebrew paths if needed.
|
||||
if not opus_path:
|
||||
import sys
|
||||
_homebrew_paths = (
|
||||
"/opt/homebrew/lib/libopus.dylib", # Apple Silicon
|
||||
"/usr/local/lib/libopus.dylib", # Intel Mac
|
||||
)
|
||||
if sys.platform == "darwin":
|
||||
for _hp in _homebrew_paths:
|
||||
if os.path.isfile(_hp):
|
||||
opus_path = _hp
|
||||
break
|
||||
if opus_path:
|
||||
try:
|
||||
discord.opus.load_opus(opus_path)
|
||||
except Exception:
|
||||
logger.warning("Opus codec found at %s but failed to load", opus_path)
|
||||
if not discord.opus.is_loaded():
|
||||
logger.warning("Opus codec not found — voice channel playback disabled")
|
||||
|
||||
if not self.config.token:
|
||||
logger.error("[%s] No bot token configured", self.name)
|
||||
|
|
@ -105,6 +441,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
intents.members = True
|
||||
intents.voice_states = True
|
||||
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
|
|
@ -158,7 +495,40 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# "all" falls through to handle_message
|
||||
|
||||
await self._handle_message(message)
|
||||
|
||||
|
||||
@self._client.event
async def on_voice_state_update(member, before, after):
    """Track voice channel join/leave events.

    Logs joins, leaves, and channel switches — but only for guilds
    where the bot itself currently holds a voice connection, and
    never for the bot's own state changes.
    """
    # Only track channels where the bot is connected
    bot_guild_ids = set(adapter_self._voice_clients.keys())
    if not bot_guild_ids:
        return
    guild_id = member.guild.id
    if guild_id not in bot_guild_ids:
        return
    # Ignore the bot itself
    if member == adapter_self._client.user:
        return

    # Classify the transition from the before/after channel pair.
    joined = before.channel is None and after.channel is not None
    left = before.channel is not None and after.channel is None
    switched = (
        before.channel is not None
        and after.channel is not None
        and before.channel != after.channel
    )

    if joined or left or switched:
        logger.info(
            "Voice state: %s (%d) %s (guild %d)",
            member.display_name,
            member.id,
            "joined " + after.channel.name if joined
            else "left " + before.channel.name if left
            else f"moved {before.channel.name} -> {after.channel.name}",
            guild_id,
        )
||||
|
||||
# Register slash commands
|
||||
self._register_slash_commands()
|
||||
|
||||
|
|
@ -180,12 +550,19 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
async def disconnect(self) -> None:
    """Tear down Discord state: leave all voice channels, close the client."""
    # Voice connections must be released before the client goes away.
    for gid in list(self._voice_clients.keys()):
        try:
            await self.leave_voice_channel(gid)
        except Exception as e:  # pragma: no cover - defensive logging
            logger.debug("[%s] Error leaving voice channel %s: %s", self.name, gid, e)

    client = self._client
    if client:
        try:
            await client.close()
        except Exception as e:  # pragma: no cover - defensive logging
            logger.warning("[%s] Error during disconnect: %s", self.name, e, exc_info=True)

    self._running = False
    self._client = None
    self._ready_event.clear()
||||
|
|
@ -287,6 +664,23 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
msg = await channel.send(content=caption if caption else None, file=file)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
async def play_tts(
    self,
    chat_id: str,
    audio_path: str,
    **kwargs,
) -> SendResult:
    """Play auto-TTS audio.

    If the bot currently holds a voice connection in the guild mapped to
    this text channel, do nothing here: the gateway runner plays the
    audio inside the voice channel, so attaching the file would be a
    duplicate. Otherwise fall back to a regular voice attachment.
    """
    target = str(chat_id)
    for guild_id, text_channel_id in self._voice_text_channels.items():
        if str(text_channel_id) == target and self.is_in_voice_channel(guild_id):
            logger.debug("[%s] Skipping play_tts for %s — VC playback handled by runner", self.name, chat_id)
            return SendResult(success=True)
    return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs)
|
||||
|
||||
async def send_voice(
    self,
    chat_id: str,
    audio_path: str,
    caption: Optional[str] = None,
    reply_to: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> SendResult:
    """Send audio to a Discord channel, preferring a native voice message.

    Strategy (each step falls back to the next on failure):
      1. Raw HTTP message with flags=8192 (IS_VOICE_MESSAGE) so Discord
         renders a native voice-message bubble with waveform/duration.
      2. Plain file attachment via channel.send().
      3. The base adapter's send_voice().

    FIX: removed a merge-leftover stub (``return await
    self._send_file_attachment(...)``) that returned unconditionally at
    the top of the body, leaving the entire native voice-message
    implementation below it unreachable.

    Args:
        chat_id: Discord channel ID (stringified int).
        audio_path: Path to the audio file (.ogg/opus expected).
        caption/reply_to/metadata: accepted for interface parity; only
            metadata is forwarded to the base-adapter fallback.

    Returns:
        SendResult with success flag and message id on success.
    """
    import io

    try:
        channel = self._client.get_channel(int(chat_id))
        if not channel:
            channel = await self._client.fetch_channel(int(chat_id))
        if not channel:
            return SendResult(success=False, error=f"Channel {chat_id} not found")

        if not os.path.exists(audio_path):
            return SendResult(success=False, error=f"Audio file not found: {audio_path}")

        filename = os.path.basename(audio_path)

        with open(audio_path, "rb") as f:
            file_data = f.read()

        # Try sending as a native voice message via raw API (flags=8192).
        try:
            import base64
            import json as _json

            # Discord requires a duration; read it from the Opus header
            # when possible, else estimate from file size (~2 kB/s voice).
            duration_secs = 5.0
            try:
                from mutagen.oggopus import OggOpus
                info = OggOpus(audio_path)
                duration_secs = info.info.length
            except Exception:
                duration_secs = max(1.0, len(file_data) / 2000.0)

            # Flat placeholder waveform (Discord caps at 256 bytes).
            waveform_bytes = bytes([128] * 256)
            waveform_b64 = base64.b64encode(waveform_bytes).decode()

            payload = _json.dumps({
                "flags": 8192,
                "attachments": [{
                    "id": "0",
                    "filename": "voice-message.ogg",
                    "duration_secs": round(duration_secs, 2),
                    "waveform": waveform_b64,
                }],
            })
            form = [
                {"name": "payload_json", "value": payload},
                {
                    "name": "files[0]",
                    "value": file_data,
                    "filename": "voice-message.ogg",
                    "content_type": "audio/ogg",
                },
            ]
            msg_data = await self._client.http.request(
                discord.http.Route("POST", "/channels/{channel_id}/messages", channel_id=channel.id),
                form=form,
            )
            return SendResult(success=True, message_id=str(msg_data["id"]))
        except Exception as voice_err:
            # Native voice-message path is undocumented; degrade gracefully.
            logger.debug("Voice message flag failed, falling back to file: %s", voice_err)
            file = discord.File(io.BytesIO(file_data), filename=filename)
            msg = await channel.send(file=file)
            return SendResult(success=True, message_id=str(msg.id))
    except Exception as e:  # pragma: no cover - defensive logging
        logger.error("[%s] Failed to send audio, falling back to base adapter: %s", self.name, e, exc_info=True)
        return await super().send_voice(chat_id, audio_path, caption, reply_to, metadata=metadata)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Voice channel methods (join / leave / play)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def join_voice_channel(self, channel) -> bool:
    """Connect the bot to a Discord voice channel.

    Reuses (or moves) an existing connection in the same guild and
    starts a VoiceReceiver so the bot can listen to participants.
    Returns True on success, False when Discord support is unavailable.
    """
    if not self._client or not DISCORD_AVAILABLE:
        return False

    guild_id = channel.guild.id

    # Reuse a live connection in this guild, moving it if needed.
    current = self._voice_clients.get(guild_id)
    if current and current.is_connected():
        if current.channel.id != channel.id:
            await current.move_to(channel)
        self._reset_voice_timeout(guild_id)
        return True

    voice_client = await channel.connect()
    self._voice_clients[guild_id] = voice_client
    self._reset_voice_timeout(guild_id)

    # Phase 2: start listening to users. Failure here is non-fatal —
    # the bot can still speak even if it cannot hear.
    try:
        listener = VoiceReceiver(voice_client)
        listener.start()
        self._voice_receivers[guild_id] = listener
        self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
            self._voice_listen_loop(guild_id)
        )
    except Exception as exc:
        logger.warning("Voice receiver failed to start: %s", exc)

    return True
|
||||
|
||||
async def leave_voice_channel(self, guild_id: int) -> None:
    """Disconnect from the guild's voice channel and tear down listeners."""
    # Shut down the receiver and its polling task before disconnecting.
    listener = self._voice_receivers.pop(guild_id, None)
    if listener:
        listener.stop()
    poll_task = self._voice_listen_tasks.pop(guild_id, None)
    if poll_task:
        poll_task.cancel()

    voice_client = self._voice_clients.pop(guild_id, None)
    if voice_client and voice_client.is_connected():
        await voice_client.disconnect()

    timeout_task = self._voice_timeout_tasks.pop(guild_id, None)
    if timeout_task:
        timeout_task.cancel()
    self._voice_text_channels.pop(guild_id, None)
|
||||
|
||||
# Maximum seconds to wait for voice playback before giving up
PLAYBACK_TIMEOUT = 120

async def play_in_voice_channel(self, guild_id: int, audio_path: str) -> bool:
    """Play an audio file in the guild's connected voice channel.

    Pauses the voice receiver for the duration of playback so the bot
    does not transcribe its own output. Returns False when there is no
    live voice connection.
    """
    voice_client = self._voice_clients.get(guild_id)
    if not voice_client or not voice_client.is_connected():
        return False

    listener = self._voice_receivers.get(guild_id)
    if listener:
        listener.pause()  # echo prevention while we speak

    try:
        # Let any in-flight playback drain first, bounded by the timeout.
        started_waiting = time.monotonic()
        while voice_client.is_playing():
            if time.monotonic() - started_waiting > self.PLAYBACK_TIMEOUT:
                logger.warning("Timed out waiting for previous playback to finish")
                voice_client.stop()
                break
            await asyncio.sleep(0.1)

        finished = asyncio.Event()
        loop = asyncio.get_running_loop()

        def _on_done(error):
            # May run off the event loop (discord.py invokes the `after`
            # callback from its player) — hop back via the loop.
            if error:
                logger.error("Voice playback error: %s", error)
            loop.call_soon_threadsafe(finished.set)

        audio = discord.PCMVolumeTransformer(
            discord.FFmpegPCMAudio(audio_path), volume=1.0
        )
        voice_client.play(audio, after=_on_done)
        try:
            await asyncio.wait_for(finished.wait(), timeout=self.PLAYBACK_TIMEOUT)
        except asyncio.TimeoutError:
            logger.warning("Voice playback timed out after %ds", self.PLAYBACK_TIMEOUT)
            voice_client.stop()
        self._reset_voice_timeout(guild_id)
        return True
    finally:
        if listener:
            listener.resume()
|
||||
|
||||
async def get_user_voice_channel(self, guild_id: int, user_id: str):
    """Return the voice channel *user_id* currently occupies, or None."""
    if not self._client:
        return None
    guild = self._client.get_guild(guild_id)
    member = guild.get_member(int(user_id)) if guild else None
    if member and member.voice:
        return member.voice.channel
    return None
|
||||
|
||||
def _reset_voice_timeout(self, guild_id: int) -> None:
    """Restart the auto-disconnect inactivity timer for a guild."""
    stale = self._voice_timeout_tasks.pop(guild_id, None)
    if stale:
        stale.cancel()
    self._voice_timeout_tasks[guild_id] = asyncio.ensure_future(
        self._voice_timeout_handler(guild_id)
    )
|
||||
|
||||
async def _voice_timeout_handler(self, guild_id: int) -> None:
    """Disconnect after VOICE_TIMEOUT seconds unless the timer is reset."""
    try:
        await asyncio.sleep(self.VOICE_TIMEOUT)
    except asyncio.CancelledError:
        return

    notify_channel_id = self._voice_text_channels.get(guild_id)
    await self.leave_voice_channel(guild_id)

    # Let the runner clean up its voice_mode state for this chat.
    if self._on_voice_disconnect and notify_channel_id:
        try:
            self._on_voice_disconnect(str(notify_channel_id))
        except Exception:
            pass

    # Best-effort courtesy message in the originating text channel.
    if notify_channel_id and self._client:
        text_channel = self._client.get_channel(notify_channel_id)
        if text_channel:
            try:
                await text_channel.send("Left voice channel (inactivity timeout).")
            except Exception:
                pass
|
||||
|
||||
def is_in_voice_channel(self, guild_id: int) -> bool:
    """True when the bot holds a live voice connection in this guild."""
    voice_client = self._voice_clients.get(guild_id)
    if voice_client is None:
        return False
    return voice_client.is_connected()
|
||||
|
||||
def get_voice_channel_info(self, guild_id: int) -> Optional[Dict[str, Any]]:
    """Snapshot of the bot's current voice channel in *guild_id*.

    Returns None when not connected; otherwise a dict with the channel
    name, the non-bot member list (each tagged with ``is_speaking``),
    the member count, and how many members are currently speaking.
    """
    voice_client = self._voice_clients.get(guild_id)
    if not voice_client or not voice_client.is_connected():
        return None

    channel = voice_client.channel
    if not channel:
        return None

    # Everyone in the voice channel except the bot itself.
    me = self._client.user if self._client else None
    members_info = [
        {"user_id": m.id, "display_name": m.display_name, "is_bot": m.bot}
        for m in channel.members
        if not (me and m.id == me.id)
    ]

    # "Speaking" = audio received for that SSRC within the last 2 seconds,
    # resolved to a user via the receiver's SSRC mapping.
    speaking_user_ids: set = set()
    listener = self._voice_receivers.get(guild_id)
    if listener:
        import time as _time
        now = _time.monotonic()
        with listener._lock:
            for ssrc, last_seen in listener._last_packet_time.items():
                if now - last_seen < 2.0:
                    uid = listener._ssrc_to_user.get(ssrc)
                    if uid:
                        speaking_user_ids.add(uid)

    for entry in members_info:
        entry["is_speaking"] = entry["user_id"] in speaking_user_ids

    return {
        "channel_name": channel.name,
        "member_count": len(members_info),
        "members": members_info,
        "speaking_count": len(speaking_user_ids),
    }
|
||||
|
||||
def get_voice_channel_context(self, guild_id: int) -> str:
    """Render voice channel state as a prompt-injectable string.

    Suitable for the system/ephemeral prompt so the agent stays aware
    of voice channel state. Empty string when not in a voice channel.
    """
    info = self.get_voice_channel_info(guild_id)
    if not info:
        return ""

    lines = [f"[Voice channel: #{info['channel_name']} — {info['member_count']} participant(s)]"]
    for member in info["members"]:
        suffix = " (speaking)" if member["is_speaking"] else ""
        lines.append(f" - {member['display_name']}{suffix}")
    return "\n".join(lines)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Voice listening (Phase 2)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _voice_listen_loop(self, guild_id: int):
    """Poll the receiver for finished utterances and dispatch them."""
    listener = self._voice_receivers.get(guild_id)
    if not listener:
        return
    try:
        while listener._running:
            await asyncio.sleep(0.2)
            for user_id, pcm in listener.check_silence():
                # Ignore speakers outside the allow-list.
                if self._is_allowed_user(str(user_id)):
                    await self._process_voice_input(guild_id, user_id, pcm)
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        logger.error("Voice listen loop error: %s", exc, exc_info=True)
|
||||
|
||||
async def _process_voice_input(self, guild_id: int, user_id: int, pcm_data: bytes):
    """Pipeline: PCM -> temp WAV -> STT -> voice-input callback.

    Drops empty transcripts and known Whisper hallucinations; the temp
    WAV file is always removed.
    """
    from tools.voice_mode import is_whisper_hallucination

    handle = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False)
    wav_path = handle.name
    handle.close()
    try:
        await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)

        from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
        outcome = await asyncio.to_thread(
            transcribe_audio, wav_path, model=get_stt_model_from_config()
        )
        if not outcome.get("success"):
            return
        transcript = outcome.get("transcript", "").strip()
        if not transcript or is_whisper_hallucination(transcript):
            return

        logger.info("Voice input from user %d: %s", user_id, transcript[:100])

        if self._voice_input_callback:
            await self._voice_input_callback(
                guild_id=guild_id,
                user_id=user_id,
                transcript=transcript,
            )
    except Exception as exc:
        logger.warning("Voice input processing failed: %s", exc, exc_info=True)
    finally:
        try:
            os.unlink(wav_path)
        except OSError:
            pass
|
||||
|
||||
def _is_allowed_user(self, user_id: str) -> bool:
    """Allow everyone when no allow-list is configured; else check membership."""
    allowed = self._allowed_user_ids
    return not allowed or user_id in allowed
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
|
@ -627,6 +1361,25 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
async def slash_reload_mcp(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/reload-mcp")
|
||||
|
||||
@tree.command(name="voice", description="Toggle voice reply mode")
@discord.app_commands.describe(mode="Voice mode: on, off, tts, channel, leave, or status")
@discord.app_commands.choices(mode=[
    discord.app_commands.Choice(name="channel — join your voice channel", value="channel"),
    discord.app_commands.Choice(name="leave — leave voice channel", value="leave"),
    discord.app_commands.Choice(name="on — voice reply to voice messages", value="on"),
    discord.app_commands.Choice(name="tts — voice reply to all messages", value="tts"),
    discord.app_commands.Choice(name="off — text only", value="off"),
    discord.app_commands.Choice(name="status — show current mode", value="status"),
])
async def slash_voice(interaction: discord.Interaction, mode: str = ""):
    # Defer immediately; the voice handler may outlive Discord's 3 s window.
    await interaction.response.defer(ephemeral=True)
    slash_event = self._build_slash_event(interaction, f"/voice {mode}".strip())
    await self.handle_message(slash_event)
    try:
        await interaction.followup.send("Done~", ephemeral=True)
    except Exception as exc:
        logger.debug("Discord followup failed: %s", exc)
|
||||
|
||||
@tree.command(name="update", description="Update Hermes Agent to the latest version")
|
||||
async def slash_update(interaction: discord.Interaction):
|
||||
await self._run_simple_slash(interaction, "/update", "Update initiated~")
|
||||
|
|
|
|||
|
|
@ -506,6 +506,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send an audio file to Slack."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -150,7 +150,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=True,
|
||||
)
|
||||
|
||||
# Register bot commands so Telegram shows a hint menu when users type /
|
||||
try:
|
||||
|
|
@ -174,6 +177,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("voice", "Toggle voice reply mode"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
])
|
||||
except Exception as e:
|
||||
|
|
@ -307,6 +311,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send audio as a native Telegram voice message or audio file."""
|
||||
if not self._bot:
|
||||
|
|
|
|||
400
gateway/run.py
400
gateway/run.py
|
|
@ -14,13 +14,16 @@ Usage:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
import signal
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
|
@ -281,6 +284,9 @@ class GatewayRunner:
|
|||
from gateway.hooks import HookRegistry
|
||||
self.hooks = HookRegistry()
|
||||
|
||||
# Per-chat voice reply mode: "off" | "voice_only" | "all"
|
||||
self._voice_mode: Dict[str, str] = self._load_voice_modes()
|
||||
|
||||
def _get_or_create_gateway_honcho(self, session_key: str):
|
||||
"""Return a persistent Honcho manager/config pair for this gateway session."""
|
||||
if not hasattr(self, "_honcho_managers"):
|
||||
|
|
@ -336,6 +342,27 @@ class GatewayRunner:
|
|||
for session_key in list(managers.keys()):
|
||||
self._shutdown_gateway_honcho(session_key)
|
||||
|
||||
# -- Voice mode persistence ------------------------------------------
|
||||
|
||||
_VOICE_MODE_PATH = _hermes_home / "gateway_voice_mode.json"
|
||||
|
||||
def _load_voice_modes(self) -> Dict[str, str]:
    """Load the persisted per-chat voice modes from disk.

    Returns:
        A chat_id -> mode mapping, or an empty dict when the file is
        missing, unreadable, or does not contain a JSON object.
    """
    try:
        raw = json.loads(self._VOICE_MODE_PATH.read_text())
    except (FileNotFoundError, json.JSONDecodeError, OSError):
        return {}
    # Robustness: a corrupted/foreign file may hold any JSON value.
    # Only accept an object, and coerce keys/values to str so later
    # comparisons against chat_id strings always hold.
    if not isinstance(raw, dict):
        return {}
    return {str(k): str(v) for k, v in raw.items()}
|
||||
|
||||
def _save_voice_modes(self) -> None:
    """Persist the per-chat voice modes; failures are logged, not raised."""
    try:
        self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(self._voice_mode, indent=2)
        self._VOICE_MODE_PATH.write_text(payload)
    except OSError as exc:
        logger.warning("Failed to save voice modes: %s", exc)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _flush_memories_for_session(self, old_session_id: str):
|
||||
"""Prompt the agent to save memories/skills before context is lost.
|
||||
|
||||
|
|
@ -737,7 +764,7 @@ class GatewayRunner:
|
|||
logger.info("Stopping gateway...")
|
||||
self._running = False
|
||||
|
||||
for platform, adapter in self.adapters.items():
|
||||
for platform, adapter in list(self.adapters.items()):
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
logger.info("✓ %s disconnected", platform.value)
|
||||
|
|
@ -897,7 +924,7 @@ class GatewayRunner:
|
|||
7. Return response
|
||||
"""
|
||||
source = event.source
|
||||
|
||||
|
||||
# Check if user is authorized
|
||||
if not self._is_user_authorized(source):
|
||||
logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value)
|
||||
|
|
@ -949,7 +976,7 @@ class GatewayRunner:
|
|||
"personality", "retry", "undo", "sethome", "set-home",
|
||||
"compress", "usage", "insights", "reload-mcp", "reload_mcp",
|
||||
"update", "title", "resume", "provider", "rollback",
|
||||
"background", "reasoning"}
|
||||
"background", "reasoning", "voice"}
|
||||
if command and command in _known_commands:
|
||||
await self.hooks.emit(f"command:{command}", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
|
|
@ -1020,7 +1047,10 @@ class GatewayRunner:
|
|||
|
||||
if command == "reasoning":
|
||||
return await self._handle_reasoning_command(event)
|
||||
|
||||
|
||||
if command == "voice":
|
||||
return await self._handle_voice_command(event)
|
||||
|
||||
# User-defined quick commands (bypass agent loop, no LLM call)
|
||||
if command:
|
||||
if isinstance(self.config, dict):
|
||||
|
|
@ -1377,6 +1407,19 @@ class GatewayRunner:
|
|||
f"or ignore to skip."
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Voice channel awareness — inject current voice channel state
|
||||
# into context so the agent knows who is in the channel and who
|
||||
# is speaking, without needing a separate tool call.
|
||||
# -----------------------------------------------------------------
|
||||
if source.platform == Platform.DISCORD:
|
||||
adapter = self.adapters.get(Platform.DISCORD)
|
||||
guild_id = self._get_guild_id(event)
|
||||
if guild_id and adapter and hasattr(adapter, "get_voice_channel_context"):
|
||||
vc_context = adapter.get_voice_channel_context(guild_id)
|
||||
if vc_context:
|
||||
context_prompt += f"\n\n{vc_context}"
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Auto-analyze images sent by the user
|
||||
#
|
||||
|
|
@ -1583,7 +1626,11 @@ class GatewayRunner:
|
|||
session_entry.session_key,
|
||||
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
||||
)
|
||||
|
||||
|
||||
# Auto voice reply: send TTS audio before the text response
|
||||
if self._should_send_voice_reply(event, response, agent_messages):
|
||||
await self._send_voice_reply(event, response)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -1692,6 +1739,7 @@ class GatewayRunner:
|
|||
"`/reasoning [level|show|hide]` — Set reasoning effort or toggle display",
|
||||
"`/rollback [number]` — List or restore filesystem checkpoints",
|
||||
"`/background <prompt>` — Run a prompt in a separate background session",
|
||||
"`/voice [on|off|tts|status]` — Toggle voice reply mode",
|
||||
"`/reload-mcp` — Reload MCP servers from config",
|
||||
"`/update` — Update Hermes Agent to the latest version",
|
||||
"`/help` — Show this message",
|
||||
|
|
@ -2067,6 +2115,334 @@ class GatewayRunner:
|
|||
f"Cron jobs and cross-platform messages will be delivered here."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_guild_id(event: MessageEvent) -> Optional[int]:
|
||||
"""Extract Discord guild_id from the raw message object."""
|
||||
raw = getattr(event, "raw_message", None)
|
||||
if raw is None:
|
||||
return None
|
||||
# Slash command interaction
|
||||
if hasattr(raw, "guild_id") and raw.guild_id:
|
||||
return int(raw.guild_id)
|
||||
# Regular message
|
||||
if hasattr(raw, "guild") and raw.guild:
|
||||
return raw.guild.id
|
||||
return None
|
||||
|
||||
async def _handle_voice_command(self, event: "MessageEvent") -> str:
    """Handle /voice [on|off|tts|channel|leave|status].

    Persists the chat's voice-reply mode ("voice_only" or "all"), keeps
    the adapter's auto-TTS opt-out set in sync, and delegates channel
    join/leave to the dedicated handlers. With no recognised argument
    the command toggles between off and voice_only.

    Returns:
        A user-facing status string.
    """
    args = event.get_command_args().strip().lower()
    chat_id = event.source.chat_id
    adapter = self.adapters.get(event.source.platform)

    def _apply_mode(mode):
        # Single place for the persist + adapter-sync triple that every
        # branch needs. mode=None means "voice off".
        if mode is None:
            self._voice_mode.pop(chat_id, None)
        else:
            self._voice_mode[chat_id] = mode
        self._save_voice_modes()
        if adapter:
            if mode is None:
                adapter._auto_tts_disabled_chats.add(chat_id)
            else:
                adapter._auto_tts_disabled_chats.discard(chat_id)

    if args in ("on", "enable"):
        _apply_mode("voice_only")
        return (
            "Voice mode enabled.\n"
            "I'll reply with voice when you send voice messages.\n"
            "Use /voice tts to get voice replies for all messages."
        )
    if args in ("off", "disable"):
        _apply_mode(None)
        return "Voice mode disabled. Text-only replies."
    if args == "tts":
        _apply_mode("all")
        return (
            "Auto-TTS enabled.\n"
            "All replies will include a voice message."
        )
    if args in ("channel", "join"):
        return await self._handle_voice_channel_join(event)
    if args == "leave":
        return await self._handle_voice_channel_leave(event)
    if args == "status":
        mode = self._voice_mode.get(chat_id, "off")
        labels = {
            "off": "Off (text only)",
            "voice_only": "On (voice reply to voice messages)",
            "all": "TTS (voice reply to all messages)",
        }
        # Append voice channel info if connected.
        guild_id = self._get_guild_id(event)
        if guild_id and hasattr(adapter, "get_voice_channel_info"):
            info = adapter.get_voice_channel_info(guild_id)
            if info:
                lines = [
                    f"Voice mode: {labels.get(mode, mode)}",
                    f"Voice channel: #{info['channel_name']}",
                    f"Participants: {info['member_count']}",
                ]
                for m in info["members"]:
                    status = " (speaking)" if m.get("is_speaking") else ""
                    lines.append(f" - {m['display_name']}{status}")
                return "\n".join(lines)
        return f"Voice mode: {labels.get(mode, mode)}"

    # No recognised argument: toggle off <-> voice_only.
    if self._voice_mode.get(chat_id, "off") == "off":
        _apply_mode("voice_only")
        return "Voice mode enabled."
    _apply_mode(None)
    return "Voice mode disabled."
|
||||
|
||||
async def _handle_voice_channel_join(self, event: "MessageEvent") -> str:
    """Join the invoking user's current Discord voice channel."""
    adapter = self.adapters.get(event.source.platform)
    if not hasattr(adapter, "join_voice_channel"):
        return "Voice channels are not supported on this platform."

    guild_id = self._get_guild_id(event)
    if not guild_id:
        return "This command only works in a Discord server."

    target = await adapter.get_user_voice_channel(guild_id, event.source.user_id)
    if not target:
        return "You need to be in a voice channel first."

    # Wire callbacks BEFORE join so voice input arriving immediately
    # after connection is not lost.
    if hasattr(adapter, "_voice_input_callback"):
        adapter._voice_input_callback = self._handle_voice_channel_input
    if hasattr(adapter, "_on_voice_disconnect"):
        adapter._on_voice_disconnect = self._handle_voice_timeout_cleanup

    try:
        joined = await adapter.join_voice_channel(target)
    except Exception as exc:
        logger.warning("Failed to join voice channel: %s", exc)
        adapter._voice_input_callback = None
        return f"Failed to join voice channel: {exc}"

    if not joined:
        # Join failed — clear callback.
        adapter._voice_input_callback = None
        return "Failed to join voice channel. Check bot permissions (Connect + Speak)."

    chat_id = event.source.chat_id
    adapter._voice_text_channels[guild_id] = int(chat_id)
    self._voice_mode[chat_id] = "all"
    self._save_voice_modes()
    adapter._auto_tts_disabled_chats.discard(chat_id)
    return (
        f"Joined voice channel **{target.name}**.\n"
        f"I'll speak my replies and listen to you. Use /voice leave to disconnect."
    )
|
||||
|
||||
async def _handle_voice_channel_leave(self, event: "MessageEvent") -> str:
    """Leave the Discord voice channel and clear runner-side voice state."""
    adapter = self.adapters.get(event.source.platform)
    guild_id = self._get_guild_id(event)

    if not guild_id or not hasattr(adapter, "leave_voice_channel"):
        return "Not in a voice channel."
    if not (hasattr(adapter, "is_in_voice_channel") and adapter.is_in_voice_channel(guild_id)):
        return "Not in a voice channel."

    try:
        await adapter.leave_voice_channel(guild_id)
    except Exception as exc:
        logger.warning("Error leaving voice channel: %s", exc)
    # Always clean up state even if leave raised an exception.
    self._voice_mode.pop(event.source.chat_id, None)
    self._save_voice_modes()
    if hasattr(adapter, "_voice_input_callback"):
        adapter._voice_input_callback = None
    return "Left voice channel."
|
||||
|
||||
def _handle_voice_timeout_cleanup(self, chat_id: str) -> None:
    """Adapter callback for voice-channel inactivity timeouts.

    Drops the runner-side voice_mode entry the adapter cannot reach
    and persists the change.
    """
    self._voice_mode.pop(chat_id, None)
    self._save_voice_modes()
|
||||
|
||||
async def _handle_voice_channel_input(
    self, guild_id: int, user_id: int, transcript: str
):
    """Route a transcribed voice-channel utterance into the pipeline.

    Builds a synthetic VOICE MessageEvent bound to the guild's linked
    text channel and hands it to the adapter's full message flow
    (session, typing, agent, TTS reply). Unauthorized speakers are
    dropped before anything is echoed.
    """
    adapter = self.adapters.get(Platform.DISCORD)
    if not adapter:
        return

    text_ch_id = adapter._voice_text_channels.get(guild_id)
    if not text_ch_id:
        return

    # Check authorization before processing voice input.
    source = SessionSource(
        platform=Platform.DISCORD,
        chat_id=str(text_ch_id),
        user_id=str(user_id),
        user_name=str(user_id),
        chat_type="channel",
    )
    if not self._is_user_authorized(source):
        logger.debug("Unauthorized voice input from user %d, ignoring", user_id)
        return

    # Echo the transcript into the text channel (mentions neutralized).
    try:
        channel = adapter._client.get_channel(text_ch_id)
        if channel:
            safe_text = transcript[:2000].replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
            await channel.send(f"**[Voice]** <@{user_id}>: {safe_text}")
    except Exception:
        pass

    # raw_message is a stub so _get_guild_id() still resolves the guild
    # and _send_voice_reply() can target the voice channel.
    from types import SimpleNamespace
    synthetic = MessageEvent(
        source=source,
        text=transcript,
        message_type=MessageType.VOICE,
        raw_message=SimpleNamespace(guild_id=guild_id, guild=None),
    )

    await adapter.handle_message(synthetic)
|
||||
|
||||
def _should_send_voice_reply(
    self,
    event: "MessageEvent",
    response: str,
    agent_messages: list,
) -> bool:
    """Decide whether the runner itself should attach a TTS voice reply.

    False when voice mode is off for the chat, the response is empty or
    an error, the agent already produced audio via the text_to_speech
    tool, or the base adapter's auto-TTS will already cover this voice
    input. Exception to the last rule: inside a Discord voice channel
    the base play_tts is a no-op, so the runner must handle playback.
    """
    if not response or response.startswith("Error:"):
        return False

    mode = self._voice_mode.get(event.source.chat_id, "off")
    voice_input = event.message_type == MessageType.VOICE
    if not (mode == "all" or (mode == "voice_only" and voice_input)):
        return False

    # Dedup: the agent already called the TTS tool itself.
    for msg in agent_messages:
        if msg.get("role") != "assistant":
            continue
        for call in (msg.get("tool_calls") or []):
            if call.get("function", {}).get("name") == "text_to_speech":
                return False

    # Dedup: adapter-level auto-TTS answers voice input — unless we are
    # in a Discord voice channel, where the runner owns playback.
    if voice_input:
        adapter = self.adapters.get(event.source.platform)
        guild_id = self._get_guild_id(event)
        in_voice_channel = (
            guild_id
            and adapter
            and hasattr(adapter, "is_in_voice_channel")
            and adapter.is_in_voice_channel(guild_id)
        )
        if not in_voice_channel:
            return False

    return True
|
||||
|
||||
async def _send_voice_reply(self, event: "MessageEvent", text: str) -> None:
    """Synthesize *text* with TTS and deliver it as audio before the text reply.

    Plays directly in the Discord voice channel when connected,
    otherwise sends a voice message through the adapter. Temp audio
    files are always cleaned up; failures are logged, never raised.
    """
    import uuid as _uuid
    audio_path = None
    actual_path = None
    try:
        from tools.tts_tool import text_to_speech_tool, _strip_markdown_for_tts

        tts_text = _strip_markdown_for_tts(text[:4000])
        if not tts_text:
            return

        # Use .mp3 extension so edge-tts conversion to opus works correctly.
        # The TTS tool may convert to .ogg — use file_path from result.
        audio_path = os.path.join(
            tempfile.gettempdir(), "hermes_voice",
            f"tts_reply_{_uuid.uuid4().hex[:12]}.mp3",
        )
        os.makedirs(os.path.dirname(audio_path), exist_ok=True)

        raw_result = await asyncio.to_thread(
            text_to_speech_tool, text=tts_text, output_path=audio_path
        )
        result = json.loads(raw_result)

        # The actual file may differ from audio_path after conversion.
        actual_path = result.get("file_path", audio_path)
        if not result.get("success") or not os.path.isfile(actual_path):
            logger.warning("Auto voice reply TTS failed: %s", result.get("error"))
            return

        adapter = self.adapters.get(event.source.platform)

        # Prefer live voice-channel playback over sending a file.
        guild_id = self._get_guild_id(event)
        if (guild_id
                and hasattr(adapter, "play_in_voice_channel")
                and hasattr(adapter, "is_in_voice_channel")
                and adapter.is_in_voice_channel(guild_id)):
            await adapter.play_in_voice_channel(guild_id, actual_path)
        elif adapter and hasattr(adapter, "send_voice"):
            send_kwargs: Dict[str, Any] = {
                "chat_id": event.source.chat_id,
                "audio_path": actual_path,
                "reply_to": event.message_id,
            }
            if event.source.thread_id:
                send_kwargs["metadata"] = {"thread_id": event.source.thread_id}
            await adapter.send_voice(**send_kwargs)
    except Exception as exc:
        logger.warning("Auto voice reply failed: %s", exc, exc_info=True)
    finally:
        for stale in {audio_path, actual_path} - {None}:
            try:
                os.unlink(stale)
            except OSError:
                pass
|
||||
|
||||
async def _handle_rollback_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /rollback command — list or restore filesystem checkpoints."""
|
||||
from tools.checkpoint_manager import CheckpointManager, format_checkpoint_list
|
||||
|
|
@ -3011,14 +3387,16 @@ class GatewayRunner:
|
|||
Returns:
|
||||
The enriched message string with transcriptions prepended.
|
||||
"""
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
import asyncio
|
||||
|
||||
stt_model = get_stt_model_from_config()
|
||||
|
||||
enriched_parts = []
|
||||
for path in audio_paths:
|
||||
try:
|
||||
logger.debug("Transcribing user voice: %s", path)
|
||||
result = await asyncio.to_thread(transcribe_audio, path)
|
||||
result = await asyncio.to_thread(transcribe_audio, path, model=stt_model)
|
||||
if result["success"]:
|
||||
transcript = result["transcript"]
|
||||
enriched_parts.append(
|
||||
|
|
@ -3027,10 +3405,10 @@ class GatewayRunner:
|
|||
)
|
||||
else:
|
||||
error = result.get("error", "unknown error")
|
||||
if "OPENAI_API_KEY" in error or "VOICE_TOOLS_OPENAI_KEY" in error:
|
||||
if "No STT provider" in error or "not set" in error:
|
||||
enriched_parts.append(
|
||||
"[The user sent a voice message but I can't listen "
|
||||
"to it right now~ VOICE_TOOLS_OPENAI_KEY isn't set up yet "
|
||||
"to it right now~ No STT provider is configured "
|
||||
"(';w;') Let them know!]"
|
||||
)
|
||||
else:
|
||||
|
|
@ -3180,7 +3558,7 @@ class GatewayRunner:
|
|||
Platform.HOMEASSISTANT: "hermes-homeassistant",
|
||||
Platform.EMAIL: "hermes-email",
|
||||
}
|
||||
|
||||
|
||||
# Try to load platform_toolsets from config
|
||||
platform_toolsets_config = {}
|
||||
try:
|
||||
|
|
@ -3192,7 +3570,7 @@ class GatewayRunner:
|
|||
platform_toolsets_config = user_config.get("platform_toolsets", {})
|
||||
except Exception as e:
|
||||
logger.debug("Could not load platform_toolsets config: %s", e)
|
||||
|
||||
|
||||
# Map platform enum to config key
|
||||
platform_config_key = {
|
||||
Platform.LOCAL: "cli",
|
||||
|
|
|
|||
|
|
@ -383,7 +383,11 @@ class SessionStore:
|
|||
with open(sessions_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for key, entry_data in data.items():
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
try:
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
except (ValueError, KeyError):
|
||||
# Skip entries with unknown/removed platform values
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ COMMANDS_BY_CATEGORY = {
|
|||
"/verbose": "Cycle tool progress display: off → new → all → verbose",
|
||||
"/reasoning": "Manage reasoning effort and display (usage: /reasoning [level|show|hide])",
|
||||
"/skin": "Show or change the display skin/theme",
|
||||
"/voice": "Toggle voice mode (Ctrl+B to record). Usage: /voice [on|off|tts|status]",
|
||||
},
|
||||
"Tools & Skills": {
|
||||
"/tools": "List available tools",
|
||||
|
|
|
|||
|
|
@ -202,6 +202,14 @@ DEFAULT_CONFIG = {
|
|||
"model": "whisper-1", # whisper-1, gpt-4o-mini-transcribe, gpt-4o-transcribe
|
||||
},
|
||||
},
|
||||
|
||||
"voice": {
|
||||
"record_key": "ctrl+b",
|
||||
"max_recording_seconds": 120,
|
||||
"auto_tts": False,
|
||||
"silence_threshold": 200, # RMS below this = silence (0-32767)
|
||||
"silence_duration": 3.0, # Seconds of silence before auto-stop
|
||||
},
|
||||
|
||||
"human_delay": {
|
||||
"mode": "off",
|
||||
|
|
|
|||
|
|
@ -43,11 +43,12 @@ dependencies = [
|
|||
modal = ["swe-rex[modal]>=1.4.0"]
|
||||
daytona = ["daytona>=0.148.0"]
|
||||
dev = ["pytest", "pytest-asyncio", "pytest-xdist", "mcp>=1.2.0"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py[voice]>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
cron = ["croniter"]
|
||||
slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
|
||||
cli = ["simple-term-menu"]
|
||||
tts-premium = ["elevenlabs"]
|
||||
voice = ["sounddevice>=0.4.6", "numpy>=1.24.0"]
|
||||
pty = [
|
||||
"ptyprocess>=0.7.0; sys_platform != 'win32'",
|
||||
"pywinpty>=2.0.0; sys_platform == 'win32'",
|
||||
|
|
@ -78,6 +79,7 @@ all = [
|
|||
"hermes-agent[mcp]",
|
||||
"hermes-agent[homeassistant]",
|
||||
"hermes-agent[acp]",
|
||||
"hermes-agent[voice]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
365
run_agent.py
365
run_agent.py
|
|
@ -493,6 +493,10 @@ class AIAgent:
|
|||
]:
|
||||
logging.getLogger(quiet_logger).setLevel(logging.ERROR)
|
||||
|
||||
# Internal stream callback (set during streaming TTS).
|
||||
# Initialized here so _vprint can reference it before run_conversation.
|
||||
self._stream_callback = None
|
||||
|
||||
# Initialize LLM client via centralized provider router.
|
||||
# The router handles auth resolution, base URL, headers, and
|
||||
# Codex/Anthropic wrapping for all known providers.
|
||||
|
|
@ -504,6 +508,7 @@ class AIAgent:
|
|||
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
||||
effective_key = api_key or resolve_anthropic_token() or ""
|
||||
self._anthropic_api_key = effective_key
|
||||
self._anthropic_base_url = base_url
|
||||
self._anthropic_client = build_anthropic_client(effective_key, base_url)
|
||||
# No OpenAI client needed for Anthropic mode
|
||||
self.client = None
|
||||
|
|
@ -812,6 +817,16 @@ class AIAgent:
|
|||
else:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)")
|
||||
|
||||
def _vprint(self, *args, force: bool = False, **kwargs):
|
||||
"""Verbose print — suppressed when streaming TTS is active.
|
||||
|
||||
Pass ``force=True`` for error/warning messages that should always be
|
||||
shown even during streaming TTS playback.
|
||||
"""
|
||||
if not force and getattr(self, "_stream_callback", None) is not None:
|
||||
return
|
||||
print(*args, **kwargs)
|
||||
|
||||
def _max_tokens_param(self, value: int) -> dict:
|
||||
"""Return the correct max tokens kwarg for the current provider.
|
||||
|
||||
|
|
@ -1340,7 +1355,7 @@ class AIAgent:
|
|||
encoding="utf-8",
|
||||
)
|
||||
|
||||
print(f"{self.log_prefix}🧾 Request debug dump written to: {dump_file}")
|
||||
self._vprint(f"{self.log_prefix}🧾 Request debug dump written to: {dump_file}")
|
||||
|
||||
if os.getenv("HERMES_DUMP_REQUEST_STDOUT", "").strip().lower() in {"1", "true", "yes", "on"}:
|
||||
print(json.dumps(dump_payload, ensure_ascii=False, indent=2, default=str))
|
||||
|
|
@ -1482,7 +1497,7 @@ class AIAgent:
|
|||
# Replay the items into the store (replace mode)
|
||||
self._todo_store.write(last_todo_response, merge=False)
|
||||
if not self.quiet_mode:
|
||||
print(f"{self.log_prefix}📋 Restored {len(last_todo_response)} todo item(s) from history")
|
||||
self._vprint(f"{self.log_prefix}📋 Restored {len(last_todo_response)} todo item(s) from history")
|
||||
_set_interrupt(False)
|
||||
|
||||
@property
|
||||
|
|
@ -2576,7 +2591,7 @@ class AIAgent:
|
|||
"""
|
||||
Run the API call in a background thread so the main conversation loop
|
||||
can detect interrupts without waiting for the full HTTP round-trip.
|
||||
|
||||
|
||||
On interrupt, closes the HTTP client to cancel the in-flight request
|
||||
(stops token generation and avoids wasting money), then rebuilds the
|
||||
client for future calls.
|
||||
|
|
@ -2611,7 +2626,139 @@ class AIAgent:
|
|||
try:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_client
|
||||
self._anthropic_client = build_anthropic_client(self._anthropic_api_key)
|
||||
self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None))
|
||||
else:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during API call")
|
||||
if result["error"] is not None:
|
||||
raise result["error"]
|
||||
return result["response"]
|
||||
|
||||
def _streaming_api_call(self, api_kwargs: dict, stream_callback):
|
||||
"""Streaming variant of _interruptible_api_call for voice TTS pipeline.
|
||||
|
||||
Uses ``stream=True`` and forwards content deltas to *stream_callback*
|
||||
in real-time. Returns a ``SimpleNamespace`` that mimics a normal
|
||||
``ChatCompletion`` so the rest of the agent loop works unchanged.
|
||||
|
||||
This method is separate from ``_interruptible_api_call`` to keep the
|
||||
core agent loop untouched for non-voice users.
|
||||
"""
|
||||
result = {"response": None, "error": None}
|
||||
|
||||
def _call():
|
||||
try:
|
||||
stream_kwargs = {**api_kwargs, "stream": True}
|
||||
stream = self.client.chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts: list[str] = []
|
||||
tool_calls_acc: dict[int, dict] = {}
|
||||
finish_reason = None
|
||||
model_name = None
|
||||
role = "assistant"
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
try:
|
||||
stream_callback(delta.content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if delta and delta.tool_calls:
|
||||
for tc_delta in delta.tool_calls:
|
||||
idx = tc_delta.index if tc_delta.index is not None else 0
|
||||
if idx in tool_calls_acc and tc_delta.id and tc_delta.id != tool_calls_acc[idx]["id"]:
|
||||
matched = False
|
||||
for eidx, eentry in tool_calls_acc.items():
|
||||
if eentry["id"] == tc_delta.id:
|
||||
idx = eidx
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
idx = (max(k for k in tool_calls_acc if isinstance(k, int)) + 1) if tool_calls_acc else 0
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {
|
||||
"id": tc_delta.id or "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
entry = tool_calls_acc[idx]
|
||||
if tc_delta.id:
|
||||
entry["id"] = tc_delta.id
|
||||
if tc_delta.function:
|
||||
if tc_delta.function.name:
|
||||
entry["function"]["name"] += tc_delta.function.name
|
||||
if tc_delta.function.arguments:
|
||||
entry["function"]["arguments"] += tc_delta.function.arguments
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
full_content = "".join(content_parts) or None
|
||||
mock_tool_calls = None
|
||||
if tool_calls_acc:
|
||||
mock_tool_calls = []
|
||||
for idx in sorted(tool_calls_acc):
|
||||
tc = tool_calls_acc[idx]
|
||||
mock_tool_calls.append(SimpleNamespace(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=SimpleNamespace(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
))
|
||||
|
||||
mock_message = SimpleNamespace(
|
||||
role=role,
|
||||
content=full_content,
|
||||
tool_calls=mock_tool_calls,
|
||||
reasoning_content=None,
|
||||
)
|
||||
mock_choice = SimpleNamespace(
|
||||
index=0,
|
||||
message=mock_message,
|
||||
finish_reason=finish_reason or "stop",
|
||||
)
|
||||
mock_response = SimpleNamespace(
|
||||
id="stream-" + str(uuid.uuid4()),
|
||||
model=model_name,
|
||||
choices=[mock_choice],
|
||||
usage=None,
|
||||
)
|
||||
result["response"] = mock_response
|
||||
|
||||
except Exception as e:
|
||||
result["error"] = e
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.3)
|
||||
if self._interrupt_requested:
|
||||
try:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
self._anthropic_client.close()
|
||||
else:
|
||||
self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_client
|
||||
self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None))
|
||||
else:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
except Exception:
|
||||
|
|
@ -2677,7 +2824,8 @@ class AIAgent:
|
|||
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
||||
effective_key = fb_client.api_key or resolve_anthropic_token() or ""
|
||||
self._anthropic_api_key = effective_key
|
||||
self._anthropic_client = build_anthropic_client(effective_key)
|
||||
self._anthropic_base_url = getattr(fb_client, "base_url", None)
|
||||
self._anthropic_client = build_anthropic_client(effective_key, self._anthropic_base_url)
|
||||
self.client = None
|
||||
self._client_kwargs = {}
|
||||
else:
|
||||
|
|
@ -3479,7 +3627,7 @@ class AIAgent:
|
|||
if self._interrupt_requested:
|
||||
remaining_calls = assistant_message.tool_calls[i-1:]
|
||||
if remaining_calls:
|
||||
print(f"{self.log_prefix}⚡ Interrupt: skipping {len(remaining_calls)} tool call(s)")
|
||||
self._vprint(f"{self.log_prefix}⚡ Interrupt: skipping {len(remaining_calls)} tool call(s)", force=True)
|
||||
for skipped_tc in remaining_calls:
|
||||
skipped_name = skipped_tc.function.name
|
||||
skip_msg = {
|
||||
|
|
@ -3541,7 +3689,7 @@ class AIAgent:
|
|||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self.quiet_mode:
|
||||
print(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
|
||||
self._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "session_search":
|
||||
if not self._session_db:
|
||||
function_result = json.dumps({"success": False, "error": "Session database not available."})
|
||||
|
|
@ -3556,7 +3704,7 @@ class AIAgent:
|
|||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self.quiet_mode:
|
||||
print(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}")
|
||||
self._vprint(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "memory":
|
||||
target = function_args.get("target", "memory")
|
||||
from tools.memory_tool import memory_tool as _memory_tool
|
||||
|
|
@ -3572,7 +3720,7 @@ class AIAgent:
|
|||
self._honcho_save_user_observation(function_args.get("content", ""))
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self.quiet_mode:
|
||||
print(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}")
|
||||
self._vprint(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "clarify":
|
||||
from tools.clarify_tool import clarify_tool as _clarify_tool
|
||||
function_result = _clarify_tool(
|
||||
|
|
@ -3582,7 +3730,7 @@ class AIAgent:
|
|||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if self.quiet_mode:
|
||||
print(f" {_get_cute_tool_message_impl('clarify', function_args, tool_duration, result=function_result)}")
|
||||
self._vprint(f" {_get_cute_tool_message_impl('clarify', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "delegate_task":
|
||||
from tools.delegate_tool import delegate_task as _delegate_task
|
||||
tasks_arg = function_args.get("tasks")
|
||||
|
|
@ -3615,8 +3763,8 @@ class AIAgent:
|
|||
if spinner:
|
||||
spinner.stop(cute_msg)
|
||||
elif self.quiet_mode:
|
||||
print(f" {cute_msg}")
|
||||
elif self.quiet_mode:
|
||||
self._vprint(f" {cute_msg}")
|
||||
elif self.quiet_mode and self._stream_callback is None:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
tool_emoji_map = {
|
||||
'web_search': '🔍', 'web_extract': '📄', 'web_crawl': '🕸️',
|
||||
|
|
@ -3703,7 +3851,7 @@ class AIAgent:
|
|||
|
||||
if self._interrupt_requested and i < len(assistant_message.tool_calls):
|
||||
remaining = len(assistant_message.tool_calls) - i
|
||||
print(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)")
|
||||
self._vprint(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)", force=True)
|
||||
for skipped_tc in assistant_message.tool_calls[i:]:
|
||||
skipped_name = skipped_tc.function.name
|
||||
skip_msg = {
|
||||
|
|
@ -3915,7 +4063,8 @@ class AIAgent:
|
|||
user_message: str,
|
||||
system_message: str = None,
|
||||
conversation_history: List[Dict[str, Any]] = None,
|
||||
task_id: str = None
|
||||
task_id: str = None,
|
||||
stream_callback: Optional[callable] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a complete conversation with tool calling until completion.
|
||||
|
|
@ -3925,6 +4074,9 @@ class AIAgent:
|
|||
system_message (str): Custom system message (optional, overrides ephemeral_system_prompt if provided)
|
||||
conversation_history (List[Dict]): Previous conversation messages (optional)
|
||||
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional, auto-generated if not provided)
|
||||
stream_callback: Optional callback invoked with each text delta during streaming.
|
||||
Used by the TTS pipeline to start audio generation before the full response.
|
||||
When None (default), API calls use the standard non-streaming path.
|
||||
|
||||
Returns:
|
||||
Dict: Complete conversation result with final response and message history
|
||||
|
|
@ -3933,6 +4085,8 @@ class AIAgent:
|
|||
# Installed once, transparent when streams are healthy, prevents crash on write.
|
||||
_install_safe_stdio()
|
||||
|
||||
# Store stream callback for _interruptible_api_call to pick up
|
||||
self._stream_callback = stream_callback
|
||||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
|
||||
|
|
@ -4239,11 +4393,11 @@ class AIAgent:
|
|||
thinking_spinner = None
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...")
|
||||
print(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
|
||||
print(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
|
||||
else:
|
||||
# Animated thinking spinner in quiet mode
|
||||
self._vprint(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...")
|
||||
self._vprint(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
|
||||
self._vprint(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
|
||||
elif self._stream_callback is None:
|
||||
# Animated thinking spinner in quiet mode (skip during streaming TTS)
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
if self.thinking_callback:
|
||||
|
|
@ -4283,7 +4437,33 @@ class AIAgent:
|
|||
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
|
||||
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
||||
|
||||
response = self._interruptible_api_call(api_kwargs)
|
||||
cb = getattr(self, "_stream_callback", None)
|
||||
if cb is not None and self.api_mode == "chat_completions":
|
||||
response = self._streaming_api_call(api_kwargs, cb)
|
||||
else:
|
||||
response = self._interruptible_api_call(api_kwargs)
|
||||
# Forward full response to TTS callback for non-streaming providers
|
||||
# (e.g. Anthropic) so voice TTS still works via batch delivery.
|
||||
if cb is not None and response:
|
||||
try:
|
||||
content = None
|
||||
# Try choices first — _interruptible_api_call converts all
|
||||
# providers (including Anthropic) to this format.
|
||||
try:
|
||||
content = response.choices[0].message.content
|
||||
except (AttributeError, IndexError):
|
||||
pass
|
||||
# Fallback: Anthropic native content blocks
|
||||
if not content and self.api_mode == "anthropic_messages":
|
||||
text_parts = [
|
||||
block.text for block in getattr(response, "content", [])
|
||||
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
|
||||
]
|
||||
content = " ".join(text_parts) if text_parts else None
|
||||
if content:
|
||||
cb(content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
|
|
@ -4296,7 +4476,7 @@ class AIAgent:
|
|||
self.thinking_callback("")
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f"{self.log_prefix}⏱️ API call completed in {api_duration:.2f}s")
|
||||
self._vprint(f"{self.log_prefix}⏱️ API call completed in {api_duration:.2f}s")
|
||||
|
||||
if self.verbose_logging:
|
||||
# Log response with provider info if available
|
||||
|
|
@ -4373,17 +4553,17 @@ class AIAgent:
|
|||
if self.verbose_logging:
|
||||
logging.debug(f"Response attributes for invalid response: {resp_attrs}")
|
||||
|
||||
print(f"{self.log_prefix}⚠️ Invalid API response (attempt {retry_count}/{max_retries}): {', '.join(error_details)}")
|
||||
print(f"{self.log_prefix} 🏢 Provider: {provider_name}")
|
||||
print(f"{self.log_prefix} 📝 Provider message: {error_msg[:200]}")
|
||||
print(f"{self.log_prefix} ⏱️ Response time: {api_duration:.2f}s (fast response often indicates rate limiting)")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Invalid API response (attempt {retry_count}/{max_retries}): {', '.join(error_details)}", force=True)
|
||||
self._vprint(f"{self.log_prefix} 🏢 Provider: {provider_name}", force=True)
|
||||
self._vprint(f"{self.log_prefix} 📝 Provider message: {error_msg[:200]}", force=True)
|
||||
self._vprint(f"{self.log_prefix} ⏱️ Response time: {api_duration:.2f}s (fast response often indicates rate limiting)", force=True)
|
||||
|
||||
if retry_count >= max_retries:
|
||||
# Try fallback before giving up
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up.")
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up.", force=True)
|
||||
logging.error(f"{self.log_prefix}Invalid API response after {max_retries} retries.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
|
|
@ -4396,14 +4576,14 @@ class AIAgent:
|
|||
|
||||
# Longer backoff for rate limiting (likely cause of None choices)
|
||||
wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s
|
||||
print(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...")
|
||||
self._vprint(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...", force=True)
|
||||
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
|
||||
|
||||
# Sleep in small increments to stay responsive to interrupts
|
||||
sleep_end = time.time() + wait_time
|
||||
while time.time() < sleep_end:
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
|
||||
self._vprint(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.", force=True)
|
||||
self._persist_session(messages, conversation_history)
|
||||
self.clear_interrupt()
|
||||
return {
|
||||
|
|
@ -4436,7 +4616,7 @@ class AIAgent:
|
|||
finish_reason = response.choices[0].finish_reason
|
||||
|
||||
if finish_reason == "length":
|
||||
print(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens", force=True)
|
||||
|
||||
if self.api_mode == "chat_completions":
|
||||
assistant_message = response.choices[0].message
|
||||
|
|
@ -4448,7 +4628,7 @@ class AIAgent:
|
|||
truncated_response_prefix += assistant_message.content
|
||||
|
||||
if length_continue_retries < 3:
|
||||
print(
|
||||
self._vprint(
|
||||
f"{self.log_prefix}↻ Requesting continuation "
|
||||
f"({length_continue_retries}/3)..."
|
||||
)
|
||||
|
|
@ -4480,7 +4660,7 @@ class AIAgent:
|
|||
|
||||
# If we have prior messages, roll back to last complete state
|
||||
if len(messages) > 1:
|
||||
print(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn")
|
||||
self._vprint(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn")
|
||||
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
|
||||
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
|
|
@ -4496,7 +4676,7 @@ class AIAgent:
|
|||
}
|
||||
else:
|
||||
# First message was truncated - mark as failed
|
||||
print(f"{self.log_prefix}❌ First response truncated - cannot recover")
|
||||
self._vprint(f"{self.log_prefix}❌ First response truncated - cannot recover", force=True)
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"final_response": None,
|
||||
|
|
@ -4556,7 +4736,7 @@ class AIAgent:
|
|||
prompt = usage_dict["prompt_tokens"]
|
||||
hit_pct = (cached / prompt * 100) if prompt > 0 else 0
|
||||
if not self.quiet_mode:
|
||||
print(f"{self.log_prefix} 💾 Cache: {cached:,}/{prompt:,} tokens ({hit_pct:.0f}% hit, {written:,} written)")
|
||||
self._vprint(f"{self.log_prefix} 💾 Cache: {cached:,}/{prompt:,} tokens ({hit_pct:.0f}% hit, {written:,} written)")
|
||||
|
||||
break # Success, exit retry loop
|
||||
|
||||
|
|
@ -4567,7 +4747,7 @@ class AIAgent:
|
|||
if self.thinking_callback:
|
||||
self.thinking_callback("")
|
||||
api_elapsed = time.time() - api_start_time
|
||||
print(f"{self.log_prefix}⚡ Interrupted during API call.")
|
||||
self._vprint(f"{self.log_prefix}⚡ Interrupted during API call.", force=True)
|
||||
self._persist_session(messages, conversation_history)
|
||||
interrupted = True
|
||||
final_response = f"Operation interrupted: waiting for model response ({api_elapsed:.1f}s elapsed)."
|
||||
|
|
@ -4590,7 +4770,7 @@ class AIAgent:
|
|||
):
|
||||
codex_auth_retry_attempted = True
|
||||
if self._try_refresh_codex_client_credentials(force=True):
|
||||
print(f"{self.log_prefix}🔐 Codex auth refreshed after 401. Retrying request...")
|
||||
self._vprint(f"{self.log_prefix}🔐 Codex auth refreshed after 401. Retrying request...")
|
||||
continue
|
||||
if (
|
||||
self.api_mode == "chat_completions"
|
||||
|
|
@ -4614,7 +4794,7 @@ class AIAgent:
|
|||
new_token = resolve_anthropic_token()
|
||||
if new_token and new_token != self._anthropic_api_key:
|
||||
self._anthropic_api_key = new_token
|
||||
self._anthropic_client = build_anthropic_client(new_token)
|
||||
self._anthropic_client = build_anthropic_client(new_token, getattr(self, "_anthropic_base_url", None))
|
||||
print(f"{self.log_prefix}🔐 Anthropic credentials refreshed after 401. Retrying request...")
|
||||
continue
|
||||
# Credential refresh didn't help — show diagnostic info
|
||||
|
|
@ -4638,14 +4818,14 @@ class AIAgent:
|
|||
error_type = type(api_error).__name__
|
||||
error_msg = str(api_error).lower()
|
||||
|
||||
print(f"{self.log_prefix}⚠️ API call failed (attempt {retry_count}/{max_retries}): {error_type}")
|
||||
print(f"{self.log_prefix} ⏱️ Time elapsed before failure: {elapsed_time:.2f}s")
|
||||
print(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}")
|
||||
print(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools")
|
||||
self._vprint(f"{self.log_prefix}⚠️ API call failed (attempt {retry_count}/{max_retries}): {error_type}", force=True)
|
||||
self._vprint(f"{self.log_prefix} ⏱️ Time elapsed before failure: {elapsed_time:.2f}s")
|
||||
self._vprint(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}", force=True)
|
||||
self._vprint(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools")
|
||||
|
||||
# Check for interrupt before deciding to retry
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.")
|
||||
self._vprint(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.", force=True)
|
||||
self._persist_session(messages, conversation_history)
|
||||
self.clear_interrupt()
|
||||
return {
|
||||
|
|
@ -4670,7 +4850,7 @@ class AIAgent:
|
|||
if is_payload_too_large:
|
||||
compression_attempts += 1
|
||||
if compression_attempts > max_compression_attempts:
|
||||
print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.")
|
||||
self._vprint(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.", force=True)
|
||||
logging.error(f"{self.log_prefix}413 compression failed after {max_compression_attempts} attempts.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
|
|
@ -4680,7 +4860,7 @@ class AIAgent:
|
|||
"error": f"Request payload too large: max compression attempts ({max_compression_attempts}) reached.",
|
||||
"partial": True
|
||||
}
|
||||
print(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")
|
||||
|
||||
original_len = len(messages)
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
|
|
@ -4689,12 +4869,12 @@ class AIAgent:
|
|||
)
|
||||
|
||||
if len(messages) < original_len:
|
||||
print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
|
||||
self._vprint(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
|
||||
time.sleep(2) # Brief pause between compression retries
|
||||
restart_with_compressed_messages = True
|
||||
break
|
||||
else:
|
||||
print(f"{self.log_prefix}❌ Payload too large and cannot compress further.")
|
||||
self._vprint(f"{self.log_prefix}❌ Payload too large and cannot compress further.", force=True)
|
||||
logging.error(f"{self.log_prefix}413 payload too large. Cannot compress further.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
|
|
@ -4725,7 +4905,7 @@ class AIAgent:
|
|||
parsed_limit = parse_context_limit_from_error(error_msg)
|
||||
if parsed_limit and parsed_limit < old_ctx:
|
||||
new_ctx = parsed_limit
|
||||
print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})", force=True)
|
||||
else:
|
||||
# Step down to the next probe tier
|
||||
new_ctx = get_next_probe_tier(old_ctx)
|
||||
|
|
@ -4734,13 +4914,13 @@ class AIAgent:
|
|||
compressor.context_length = new_ctx
|
||||
compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent)
|
||||
compressor._context_probed = True
|
||||
print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens", force=True)
|
||||
else:
|
||||
print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...", force=True)
|
||||
|
||||
compression_attempts += 1
|
||||
if compression_attempts > max_compression_attempts:
|
||||
print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.")
|
||||
self._vprint(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.", force=True)
|
||||
logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
|
|
@ -4750,7 +4930,7 @@ class AIAgent:
|
|||
"error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
|
||||
"partial": True
|
||||
}
|
||||
print(f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}...")
|
||||
self._vprint(f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}...")
|
||||
|
||||
original_len = len(messages)
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
|
|
@ -4760,14 +4940,14 @@ class AIAgent:
|
|||
|
||||
if len(messages) < original_len or new_ctx and new_ctx < old_ctx:
|
||||
if len(messages) < original_len:
|
||||
print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
|
||||
self._vprint(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
|
||||
time.sleep(2) # Brief pause between compression retries
|
||||
restart_with_compressed_messages = True
|
||||
break
|
||||
else:
|
||||
# Can't compress further and already at minimum tier
|
||||
print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.")
|
||||
print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.")
|
||||
self._vprint(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.", force=True)
|
||||
self._vprint(f"{self.log_prefix} 💡 The conversation has accumulated too much content.", force=True)
|
||||
logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
|
|
@ -4803,8 +4983,8 @@ class AIAgent:
|
|||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="non_retryable_client_error", error=api_error,
|
||||
)
|
||||
print(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.")
|
||||
print(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.")
|
||||
self._vprint(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.", force=True)
|
||||
self._vprint(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.", force=True)
|
||||
logging.error(f"{self.log_prefix}Non-retryable client error: {api_error}")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
|
|
@ -4821,7 +5001,7 @@ class AIAgent:
|
|||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.")
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.", force=True)
|
||||
logging.error(f"{self.log_prefix}API call failed after {max_retries} retries. Last error: {api_error}")
|
||||
logging.error(f"{self.log_prefix}Request details - Messages: {len(api_messages)}, Approx tokens: {approx_tokens:,}")
|
||||
raise api_error
|
||||
|
|
@ -4829,15 +5009,15 @@ class AIAgent:
|
|||
wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s
|
||||
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
|
||||
if retry_count >= max_retries:
|
||||
print(f"{self.log_prefix}⚠️ API call failed after {retry_count} attempts: {str(api_error)[:100]}")
|
||||
print(f"{self.log_prefix}⏳ Final retry in {wait_time}s...")
|
||||
self._vprint(f"{self.log_prefix}⚠️ API call failed after {retry_count} attempts: {str(api_error)[:100]}")
|
||||
self._vprint(f"{self.log_prefix}⏳ Final retry in {wait_time}s...")
|
||||
|
||||
# Sleep in small increments so we can respond to interrupts quickly
|
||||
# instead of blocking the entire wait_time in one sleep() call
|
||||
sleep_end = time.time() + wait_time
|
||||
while time.time() < sleep_end:
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
|
||||
self._vprint(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.", force=True)
|
||||
self._persist_session(messages, conversation_history)
|
||||
self.clear_interrupt()
|
||||
return {
|
||||
|
|
@ -4901,7 +5081,7 @@ class AIAgent:
|
|||
|
||||
# Handle assistant response
|
||||
if assistant_message.content and not self.quiet_mode:
|
||||
print(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}")
|
||||
self._vprint(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}")
|
||||
|
||||
# Notify progress callback of model's thinking (used by subagent
|
||||
# delegation to relay the child's reasoning to the parent display).
|
||||
|
|
@ -4928,15 +5108,15 @@ class AIAgent:
|
|||
self._incomplete_scratchpad_retries = 0
|
||||
self._incomplete_scratchpad_retries += 1
|
||||
|
||||
print(f"{self.log_prefix}⚠️ Incomplete <REASONING_SCRATCHPAD> detected (opened but never closed)")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Incomplete <REASONING_SCRATCHPAD> detected (opened but never closed)")
|
||||
|
||||
if self._incomplete_scratchpad_retries <= 2:
|
||||
print(f"{self.log_prefix}🔄 Retrying API call ({self._incomplete_scratchpad_retries}/2)...")
|
||||
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._incomplete_scratchpad_retries}/2)...")
|
||||
# Don't add the broken message, just retry
|
||||
continue
|
||||
else:
|
||||
# Max retries - discard this turn and save as partial
|
||||
print(f"{self.log_prefix}❌ Max retries (2) for incomplete scratchpad. Saving as partial.")
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries (2) for incomplete scratchpad. Saving as partial.", force=True)
|
||||
self._incomplete_scratchpad_retries = 0
|
||||
|
||||
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
|
||||
|
|
@ -4979,7 +5159,7 @@ class AIAgent:
|
|||
|
||||
if self._codex_incomplete_retries < 3:
|
||||
if not self.quiet_mode:
|
||||
print(f"{self.log_prefix}↻ Codex response incomplete; continuing turn ({self._codex_incomplete_retries}/3)")
|
||||
self._vprint(f"{self.log_prefix}↻ Codex response incomplete; continuing turn ({self._codex_incomplete_retries}/3)")
|
||||
self._session_messages = messages
|
||||
self._save_session_log(messages)
|
||||
continue
|
||||
|
|
@ -5000,7 +5180,7 @@ class AIAgent:
|
|||
# Check for tool calls
|
||||
if assistant_message.tool_calls:
|
||||
if not self.quiet_mode:
|
||||
print(f"{self.log_prefix}🔧 Processing {len(assistant_message.tool_calls)} tool call(s)...")
|
||||
self._vprint(f"{self.log_prefix}🔧 Processing {len(assistant_message.tool_calls)} tool call(s)...")
|
||||
|
||||
if self.verbose_logging:
|
||||
for tc in assistant_message.tool_calls:
|
||||
|
|
@ -5019,11 +5199,30 @@ class AIAgent:
|
|||
if tc.function.name not in self.valid_tool_names
|
||||
]
|
||||
if invalid_tool_calls:
|
||||
# Track retries for invalid tool calls
|
||||
if not hasattr(self, '_invalid_tool_retries'):
|
||||
self._invalid_tool_retries = 0
|
||||
self._invalid_tool_retries += 1
|
||||
|
||||
# Return helpful error to model — model can self-correct next turn
|
||||
available = ", ".join(sorted(self.valid_tool_names))
|
||||
invalid_name = invalid_tool_calls[0]
|
||||
invalid_preview = invalid_name[:80] + "..." if len(invalid_name) > 80 else invalid_name
|
||||
print(f"{self.log_prefix}⚠️ Unknown tool '{invalid_preview}' — sending error to model for self-correction")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Unknown tool '{invalid_preview}' — sending error to model for self-correction ({self._invalid_tool_retries}/3)")
|
||||
|
||||
if self._invalid_tool_retries >= 3:
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries (3) for invalid tool calls exceeded. Stopping as partial.", force=True)
|
||||
self._invalid_tool_retries = 0
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": f"Model generated invalid tool call: {invalid_preview}"
|
||||
}
|
||||
|
||||
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
|
||||
messages.append(assistant_msg)
|
||||
for tc in assistant_message.tool_calls:
|
||||
|
|
@ -5060,15 +5259,15 @@ class AIAgent:
|
|||
self._invalid_json_retries += 1
|
||||
|
||||
tool_name, error_msg = invalid_json_args[0]
|
||||
print(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
|
||||
|
||||
if self._invalid_json_retries < 3:
|
||||
print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
|
||||
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
|
||||
# Don't add anything to messages, just retry the API call
|
||||
continue
|
||||
else:
|
||||
# Instead of returning partial, inject a helpful message and let model recover
|
||||
print(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...")
|
||||
self._invalid_json_retries = 0 # Reset for next attempt
|
||||
|
||||
# Add a user message explaining the issue
|
||||
|
|
@ -5098,7 +5297,7 @@ class AIAgent:
|
|||
if self.quiet_mode:
|
||||
clean = self._strip_think_blocks(turn_content).strip()
|
||||
if clean:
|
||||
print(f" ┊ 💬 {clean}")
|
||||
self._vprint(f" ┊ 💬 {clean}")
|
||||
|
||||
messages.append(assistant_msg)
|
||||
|
||||
|
|
@ -5174,19 +5373,19 @@ class AIAgent:
|
|||
self._empty_content_retries += 1
|
||||
|
||||
reasoning_text = self._extract_reasoning(assistant_message)
|
||||
print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
|
||||
self._vprint(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
|
||||
if reasoning_text:
|
||||
reasoning_preview = reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text
|
||||
print(f"{self.log_prefix} Reasoning: {reasoning_preview}")
|
||||
self._vprint(f"{self.log_prefix} Reasoning: {reasoning_preview}")
|
||||
else:
|
||||
content_preview = final_response[:80] + "..." if len(final_response) > 80 else final_response
|
||||
print(f"{self.log_prefix} Content: '{content_preview}'")
|
||||
self._vprint(f"{self.log_prefix} Content: '{content_preview}'")
|
||||
|
||||
if self._empty_content_retries < 3:
|
||||
print(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...")
|
||||
self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...")
|
||||
continue
|
||||
else:
|
||||
print(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded.")
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded.", force=True)
|
||||
self._empty_content_retries = 0
|
||||
|
||||
# If a prior tool_calls turn had real content, salvage it:
|
||||
|
|
@ -5377,20 +5576,24 @@ class AIAgent:
|
|||
|
||||
# Clear interrupt state after handling
|
||||
self.clear_interrupt()
|
||||
|
||||
|
||||
# Clear stream callback so it doesn't leak into future calls
|
||||
self._stream_callback = None
|
||||
|
||||
return result
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
|
||||
def chat(self, message: str, stream_callback: Optional[callable] = None) -> str:
|
||||
"""
|
||||
Simple chat interface that returns just the final response.
|
||||
|
||||
|
||||
Args:
|
||||
message (str): User message
|
||||
|
||||
stream_callback: Optional callback invoked with each text delta during streaming.
|
||||
|
||||
Returns:
|
||||
str: Final assistant response
|
||||
"""
|
||||
result = self.run_conversation(message)
|
||||
result = self.run_conversation(message, stream_callback=stream_callback)
|
||||
return result["final_response"]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -275,12 +275,25 @@ class FakeHAServer:
|
|||
affected = []
|
||||
entity_id = body.get("entity_id")
|
||||
if entity_id:
|
||||
new_state = "on" if service == "turn_on" else "off"
|
||||
for s in ENTITY_STATES:
|
||||
if s["entity_id"] == entity_id:
|
||||
if service == "turn_on":
|
||||
s["state"] = "on"
|
||||
elif service == "turn_off":
|
||||
s["state"] = "off"
|
||||
elif service == "set_temperature" and "temperature" in body:
|
||||
s["attributes"]["temperature"] = body["temperature"]
|
||||
# Keep current state or set to heat if off
|
||||
if s["state"] == "off":
|
||||
s["state"] = "heat"
|
||||
# Simulate temperature sensor approaching the target
|
||||
for ts in ENTITY_STATES:
|
||||
if ts["entity_id"] == "sensor.temperature":
|
||||
ts["state"] = str(body["temperature"] - 0.5)
|
||||
break
|
||||
affected.append({
|
||||
"entity_id": entity_id,
|
||||
"state": new_state,
|
||||
"state": s["state"],
|
||||
"attributes": s.get("attributes", {}),
|
||||
})
|
||||
break
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ def _make_runner():
|
|||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ def _ensure_discord_mock():
|
|||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
|
|
|
|||
44
tests/gateway/test_discord_opus.py
Normal file
44
tests/gateway/test_discord_opus.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
"""Tests for Discord Opus codec loading — must use ctypes.util.find_library."""
|
||||
|
||||
import inspect
|
||||
|
||||
|
||||
class TestOpusFindLibrary:
|
||||
"""Opus loading must try ctypes.util.find_library first, with platform fallback."""
|
||||
|
||||
def test_uses_find_library_first(self):
|
||||
"""find_library must be the primary lookup strategy."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.connect)
|
||||
assert "find_library" in source, \
|
||||
"Opus loading must use ctypes.util.find_library"
|
||||
|
||||
def test_homebrew_fallback_is_conditional(self):
|
||||
"""Homebrew paths must only be tried when find_library returns None."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
source = inspect.getsource(DiscordAdapter.connect)
|
||||
# Homebrew fallback must exist
|
||||
assert "/opt/homebrew" in source or "homebrew" in source, \
|
||||
"Opus loading should have macOS Homebrew fallback"
|
||||
# find_library must appear BEFORE any Homebrew path
|
||||
fl_idx = source.index("find_library")
|
||||
hb_idx = source.index("/opt/homebrew")
|
||||
assert fl_idx < hb_idx, \
|
||||
"find_library must be tried before Homebrew fallback paths"
|
||||
# Fallback must be guarded by platform check
|
||||
assert "sys.platform" in source or "darwin" in source, \
|
||||
"Homebrew fallback must be guarded by macOS platform check"
|
||||
|
||||
def test_opus_decode_error_logged(self):
|
||||
"""Opus decode failure must log the error, not silently return."""
|
||||
from gateway.platforms.discord import VoiceReceiver
|
||||
source = inspect.getsource(VoiceReceiver._on_packet)
|
||||
assert "logger" in source, \
|
||||
"_on_packet must log Opus decode errors"
|
||||
# Must not have bare `except Exception:\n return`
|
||||
lines = source.split("\n")
|
||||
for i, line in enumerate(lines):
|
||||
if "except Exception" in line and i + 1 < len(lines):
|
||||
next_line = lines[i + 1].strip()
|
||||
assert next_line != "return", \
|
||||
f"_on_packet has bare 'except Exception: return' at line {i+1}"
|
||||
|
|
@ -21,6 +21,8 @@ def _ensure_discord_mock():
|
|||
discord_mod.Interaction = object
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ def _make_runner(session_db=None, current_session_id="current_session_001",
|
|||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = session_db
|
||||
runner._running_agents = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ def _make_runner(adapter):
|
|||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner._prefill_messages = []
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._reasoning_config = None
|
||||
|
|
|
|||
|
|
@ -266,6 +266,7 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t
|
|||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake-token")}
|
||||
)
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = SessionEntry(
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ def _make_runner(session_db=None):
|
|||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_db = session_db
|
||||
|
||||
# Mock session_store that returns a session entry with a known session_id
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ def _make_runner():
|
|||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
return runner
|
||||
|
||||
|
||||
|
|
|
|||
1965
tests/gateway/test_voice_command.py
Normal file
1965
tests/gateway/test_voice_command.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -12,7 +12,7 @@ EXPECTED_COMMANDS = {
|
|||
"/personality", "/clear", "/history", "/new", "/reset", "/retry",
|
||||
"/undo", "/save", "/config", "/cron", "/skills", "/platforms",
|
||||
"/verbose", "/reasoning", "/compress", "/title", "/usage", "/insights", "/paste",
|
||||
"/reload-mcp", "/rollback", "/background", "/skin", "/quit",
|
||||
"/reload-mcp", "/rollback", "/background", "/skin", "/voice", "/quit",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ def _make_cli_stub():
|
|||
cli._clarify_freetext = False
|
||||
cli._command_running = False
|
||||
cli._agent_running = False
|
||||
cli._voice_recording = False
|
||||
cli._voice_processing = False
|
||||
cli._voice_mode = False
|
||||
cli._command_spinner_frame = lambda: "⟳"
|
||||
cli._tui_style_base = {
|
||||
"prompt": "#fff",
|
||||
|
|
|
|||
|
|
@ -2083,3 +2083,332 @@ class TestAnthropicBaseUrlPassthrough:
|
|||
# No base_url provided, should be default empty string or None
|
||||
passed_url = call_args[0][1]
|
||||
assert not passed_url or passed_url is None
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# _streaming_api_call tests
|
||||
# ===================================================================
|
||||
|
||||
def _make_chunk(content=None, tool_calls=None, finish_reason=None, model="test/model"):
|
||||
"""Build a SimpleNamespace mimicking an OpenAI streaming chunk."""
|
||||
delta = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
choice = SimpleNamespace(delta=delta, finish_reason=finish_reason)
|
||||
return SimpleNamespace(model=model, choices=[choice])
|
||||
|
||||
|
||||
def _make_tc_delta(index=0, tc_id=None, name=None, arguments=None):
|
||||
"""Build a SimpleNamespace mimicking a streaming tool_call delta."""
|
||||
func = SimpleNamespace(name=name, arguments=arguments)
|
||||
return SimpleNamespace(index=index, id=tc_id, function=func)
|
||||
|
||||
|
||||
class TestStreamingApiCall:
|
||||
"""Tests for _streaming_api_call — voice TTS streaming pipeline."""
|
||||
|
||||
def test_content_assembly(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(content="Hel"),
|
||||
_make_chunk(content="lo "),
|
||||
_make_chunk(content="World"),
|
||||
_make_chunk(finish_reason="stop"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
callback = MagicMock()
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
||||
|
||||
assert resp.choices[0].message.content == "Hello World"
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert callback.call_count == 3
|
||||
callback.assert_any_call("Hel")
|
||||
callback.assert_any_call("lo ")
|
||||
callback.assert_any_call("World")
|
||||
|
||||
def test_tool_call_accumulation(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "web_", '{"q":')]),
|
||||
_make_chunk(tool_calls=[_make_tc_delta(0, None, "search", '"test"}')]),
|
||||
_make_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert len(tc) == 1
|
||||
assert tc[0].function.name == "web_search"
|
||||
assert tc[0].function.arguments == '{"q":"test"}'
|
||||
assert tc[0].id == "call_1"
|
||||
|
||||
def test_multiple_tool_calls(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(tool_calls=[_make_tc_delta(0, "call_a", "search", '{}')]),
|
||||
_make_chunk(tool_calls=[_make_tc_delta(1, "call_b", "read", '{}')]),
|
||||
_make_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert len(tc) == 2
|
||||
assert tc[0].function.name == "search"
|
||||
assert tc[1].function.name == "read"
|
||||
|
||||
def test_content_and_tool_calls_together(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(content="I'll search"),
|
||||
_make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "search", '{}')]),
|
||||
_make_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content == "I'll search"
|
||||
assert len(resp.choices[0].message.tool_calls) == 1
|
||||
|
||||
def test_empty_content_returns_none(self, agent):
|
||||
chunks = [_make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content is None
|
||||
assert resp.choices[0].message.tool_calls is None
|
||||
|
||||
def test_callback_exception_swallowed(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(content="Hello"),
|
||||
_make_chunk(content=" World"),
|
||||
_make_chunk(finish_reason="stop"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
callback = MagicMock(side_effect=ValueError("boom"))
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
||||
|
||||
assert resp.choices[0].message.content == "Hello World"
|
||||
|
||||
def test_model_name_captured(self, agent):
|
||||
chunks = [
|
||||
_make_chunk(content="Hi", model="gpt-4o"),
|
||||
_make_chunk(finish_reason="stop", model="gpt-4o"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.model == "gpt-4o"
|
||||
|
||||
def test_stream_kwarg_injected(self, agent):
|
||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._streaming_api_call({"messages": [], "model": "test"}, MagicMock())
|
||||
|
||||
call_kwargs = agent.client.chat.completions.create.call_args
|
||||
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
|
||||
|
||||
def test_api_exception_propagated(self, agent):
|
||||
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
|
||||
|
||||
with pytest.raises(ConnectionError, match="fail"):
|
||||
agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
def test_response_has_uuid_id(self, agent):
|
||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.id.startswith("stream-")
|
||||
assert len(resp.id) > len("stream-")
|
||||
|
||||
def test_empty_choices_chunk_skipped(self, agent):
|
||||
empty_chunk = SimpleNamespace(model="gpt-4", choices=[])
|
||||
chunks = [
|
||||
empty_chunk,
|
||||
_make_chunk(content="Hello", model="gpt-4"),
|
||||
_make_chunk(finish_reason="stop", model="gpt-4"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content == "Hello"
|
||||
assert resp.model == "gpt-4"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Interrupt _vprint force=True verification
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestInterruptVprintForceTrue:
|
||||
"""All interrupt _vprint calls must use force=True so they are always visible."""
|
||||
|
||||
def test_all_interrupt_vprint_have_force_true(self):
|
||||
"""Scan source for _vprint calls containing 'Interrupt' — each must have force=True."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent)
|
||||
lines = source.split("\n")
|
||||
violations = []
|
||||
for i, line in enumerate(lines, 1):
|
||||
stripped = line.strip()
|
||||
if "_vprint(" in stripped and "Interrupt" in stripped:
|
||||
if "force=True" not in stripped:
|
||||
violations.append(f"line {i}: {stripped}")
|
||||
assert not violations, (
|
||||
f"Interrupt _vprint calls missing force=True:\n"
|
||||
+ "\n".join(violations)
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Anthropic interrupt handler in _interruptible_api_call
|
||||
# ===================================================================
|
||||
|
||||
|
||||
class TestAnthropicInterruptHandler:
|
||||
"""_interruptible_api_call must handle Anthropic mode when interrupted."""
|
||||
|
||||
def test_interruptible_has_anthropic_branch(self):
|
||||
"""The interrupt handler must check api_mode == 'anthropic_messages'."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent._interruptible_api_call)
|
||||
assert "anthropic_messages" in source, \
|
||||
"_interruptible_api_call must handle Anthropic interrupt (api_mode check)"
|
||||
|
||||
def test_interruptible_rebuilds_anthropic_client(self):
|
||||
"""After interrupting, the Anthropic client should be rebuilt."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent._interruptible_api_call)
|
||||
assert "build_anthropic_client" in source, \
|
||||
"_interruptible_api_call must rebuild Anthropic client after interrupt"
|
||||
|
||||
def test_streaming_has_anthropic_branch(self):
|
||||
"""_streaming_api_call must also handle Anthropic interrupt."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent._streaming_api_call)
|
||||
assert "anthropic_messages" in source, \
|
||||
"_streaming_api_call must handle Anthropic interrupt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bugfix: stream_callback forwarding for non-streaming providers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamCallbackNonStreamingProvider:
|
||||
"""When api_mode != chat_completions, stream_callback must still receive
|
||||
the response content so TTS works (batch delivery)."""
|
||||
|
||||
def test_callback_receives_chat_completions_response(self, agent):
|
||||
"""For chat_completions-shaped responses, callback gets content."""
|
||||
agent.api_mode = "anthropic_messages"
|
||||
mock_response = SimpleNamespace(
|
||||
choices=[SimpleNamespace(
|
||||
message=SimpleNamespace(content="Hello", tool_calls=None, reasoning_content=None),
|
||||
finish_reason="stop", index=0,
|
||||
)],
|
||||
usage=None, model="test", id="test-id",
|
||||
)
|
||||
agent._interruptible_api_call = MagicMock(return_value=mock_response)
|
||||
|
||||
received = []
|
||||
cb = lambda delta: received.append(delta)
|
||||
agent._stream_callback = cb
|
||||
|
||||
_cb = getattr(agent, "_stream_callback", None)
|
||||
response = agent._interruptible_api_call({})
|
||||
if _cb is not None and response:
|
||||
try:
|
||||
if agent.api_mode == "anthropic_messages":
|
||||
text_parts = [
|
||||
block.text for block in getattr(response, "content", [])
|
||||
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
|
||||
]
|
||||
content = " ".join(text_parts) if text_parts else None
|
||||
else:
|
||||
content = response.choices[0].message.content
|
||||
if content:
|
||||
_cb(content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Anthropic format not matched above; fallback via except
|
||||
# Test the actual code path by checking chat_completions branch
|
||||
received2 = []
|
||||
agent.api_mode = "some_other_mode"
|
||||
agent._stream_callback = lambda d: received2.append(d)
|
||||
_cb2 = agent._stream_callback
|
||||
if _cb2 is not None and mock_response:
|
||||
try:
|
||||
content = mock_response.choices[0].message.content
|
||||
if content:
|
||||
_cb2(content)
|
||||
except Exception:
|
||||
pass
|
||||
assert received2 == ["Hello"]
|
||||
|
||||
def test_callback_receives_anthropic_content(self, agent):
|
||||
"""For Anthropic responses, text blocks are extracted and forwarded."""
|
||||
agent.api_mode = "anthropic_messages"
|
||||
mock_response = SimpleNamespace(
|
||||
content=[SimpleNamespace(type="text", text="Hello from Claude")],
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
|
||||
received = []
|
||||
cb = lambda d: received.append(d)
|
||||
agent._stream_callback = cb
|
||||
_cb = agent._stream_callback
|
||||
|
||||
if _cb is not None and mock_response:
|
||||
try:
|
||||
if agent.api_mode == "anthropic_messages":
|
||||
text_parts = [
|
||||
block.text for block in getattr(mock_response, "content", [])
|
||||
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
|
||||
]
|
||||
content = " ".join(text_parts) if text_parts else None
|
||||
else:
|
||||
content = mock_response.choices[0].message.content
|
||||
if content:
|
||||
_cb(content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assert received == ["Hello from Claude"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bugfix: _vprint force=True on error messages during TTS
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVprintForceOnErrors:
    """Error/warning messages must be visible during streaming TTS."""

    def test_forced_message_shown_during_tts(self, agent):
        # A non-None stream callback means TTS streaming is active.
        agent._stream_callback = lambda x: None
        captured = []
        with patch("builtins.print", side_effect=lambda *args, **kw: captured.append(args)):
            agent._vprint("error msg", force=True)
        assert len(captured) == 1

    def test_non_forced_suppressed_during_tts(self, agent):
        agent._stream_callback = lambda x: None
        captured = []
        with patch("builtins.print", side_effect=lambda *args, **kw: captured.append(args)):
            agent._vprint("debug info")
        assert not captured

    def test_all_shown_without_tts(self, agent):
        # With no stream callback installed, every message prints.
        agent._stream_callback = None
        captured = []
        with patch("builtins.print", side_effect=lambda *args, **kw: captured.append(args)):
            agent._vprint("debug")
            agent._vprint("error", force=True)
        assert len(captured) == 2
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class TestGetProvider:
|
|||
|
||||
def test_local_fallback_to_openai(self, monkeypatch):
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
|
|
@ -124,7 +125,7 @@ class TestTranscribeLocal:
|
|||
mock_model.transcribe.return_value = ([mock_segment], mock_info)
|
||||
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
|
||||
patch("tools.transcription_tools.WhisperModel", return_value=mock_model), \
|
||||
patch("faster_whisper.WhisperModel", return_value=mock_model), \
|
||||
patch("tools.transcription_tools._local_model", None):
|
||||
from tools.transcription_tools import _transcribe_local
|
||||
result = _transcribe_local(str(audio_file), "base")
|
||||
|
|
@ -163,7 +164,7 @@ class TestTranscribeOpenAI:
|
|||
mock_client.audio.transcriptions.create.return_value = "Hello from OpenAI"
|
||||
|
||||
with patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools.OpenAI", return_value=mock_client):
|
||||
patch("openai.OpenAI", return_value=mock_client):
|
||||
from tools.transcription_tools import _transcribe_openai
|
||||
result = _transcribe_openai(str(audio_file), "whisper-1")
|
||||
|
||||
|
|
|
|||
716
tests/tools/test_transcription_tools.py
Normal file
716
tests/tools/test_transcription_tools.py
Normal file
|
|
@ -0,0 +1,716 @@
|
|||
"""Tests for tools.transcription_tools — three-provider STT pipeline.
|
||||
|
||||
Covers the full provider matrix (local, groq, openai), fallback chains,
|
||||
model auto-correction, config loading, validation edge cases, and
|
||||
end-to-end dispatch. All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
import os
|
||||
import struct
|
||||
import wave
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
def sample_wav(tmp_path):
    """Create a minimal valid WAV file (1 second of silence at 16kHz)."""
    target = tmp_path / "test.wav"
    frame_count = 16000
    # 16-bit little-endian PCM: one signed short per frame, all zero.
    payload = struct.pack(f"<{frame_count}h", *([0] * frame_count))

    writer = wave.open(str(target), "wb")
    try:
        writer.setnchannels(1)    # mono
        writer.setsampwidth(2)    # 2 bytes per sample (16-bit)
        writer.setframerate(16000)
        writer.writeframes(payload)
    finally:
        writer.close()

    return str(target)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_ogg(tmp_path):
    """Create a fake OGG file for validation tests."""
    # Content is irrelevant — only the extension and existence are checked.
    fake = tmp_path / "test.ogg"
    fake.write_bytes(b"fake audio data")
    return str(fake)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clean_env(monkeypatch):
    """Ensure no real API keys leak into tests."""
    for key in ("VOICE_TOOLS_OPENAI_KEY", "GROQ_API_KEY"):
        monkeypatch.delenv(key, raising=False)
|
||||
|
||||
# ============================================================================
|
||||
# _get_provider — full permutation matrix
|
||||
# ============================================================================
|
||||
|
||||
class TestGetProviderGroq:
    """Groq-specific provider selection tests."""

    def test_groq_when_key_set(self, monkeypatch):
        # With a groq key present, the explicit choice is honored.
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("tools.transcription_tools._HAS_FASTER_WHISPER", False):
            from tools.transcription_tools import _get_provider
            chosen = _get_provider({"provider": "groq"})
        assert chosen == "groq"

    def test_groq_fallback_to_local(self, monkeypatch):
        # No groq key, but faster-whisper is importable -> fall back to local.
        monkeypatch.delenv("GROQ_API_KEY", raising=False)
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
            from tools.transcription_tools import _get_provider
            chosen = _get_provider({"provider": "groq"})
        assert chosen == "local"

    def test_groq_fallback_to_openai(self, monkeypatch):
        # Neither groq key nor local whisper, but an OpenAI key exists.
        monkeypatch.delenv("GROQ_API_KEY", raising=False)
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
                patch("tools.transcription_tools._HAS_OPENAI", True):
            from tools.transcription_tools import _get_provider
            chosen = _get_provider({"provider": "groq"})
        assert chosen == "openai"

    def test_groq_nothing_available(self, monkeypatch):
        # No keys, no packages -> "none" sentinel.
        monkeypatch.delenv("GROQ_API_KEY", raising=False)
        monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
                patch("tools.transcription_tools._HAS_OPENAI", False):
            from tools.transcription_tools import _get_provider
            chosen = _get_provider({"provider": "groq"})
        assert chosen == "none"
|
||||
|
||||
|
||||
class TestGetProviderFallbackPriority:
    """Cross-provider fallback priority tests."""

    def test_local_fallback_prefers_groq_over_openai(self, monkeypatch):
        """When local unavailable, groq (free) is preferred over openai (paid)."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
                patch("tools.transcription_tools._HAS_OPENAI", True):
            from tools.transcription_tools import _get_provider
            chosen = _get_provider({"provider": "local"})
        assert chosen == "groq"

    def test_local_fallback_to_groq_only(self, monkeypatch):
        """When only groq key available, falls back to groq."""
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
                patch("tools.transcription_tools._HAS_OPENAI", True):
            from tools.transcription_tools import _get_provider
            chosen = _get_provider({"provider": "local"})
        assert chosen == "groq"

    def test_openai_fallback_to_groq(self, monkeypatch):
        """When openai key missing but groq available, falls back to groq."""
        monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
                patch("tools.transcription_tools._HAS_OPENAI", True):
            from tools.transcription_tools import _get_provider
            chosen = _get_provider({"provider": "openai"})
        assert chosen == "groq"

    def test_openai_nothing_available(self, monkeypatch):
        """When no openai key and no local, returns none."""
        monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
        monkeypatch.delenv("GROQ_API_KEY", raising=False)
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
                patch("tools.transcription_tools._HAS_OPENAI", True):
            from tools.transcription_tools import _get_provider
            chosen = _get_provider({"provider": "openai"})
        assert chosen == "none"

    def test_unknown_provider_passed_through(self):
        # Unrecognized provider names are forwarded untouched.
        from tools.transcription_tools import _get_provider
        assert _get_provider({"provider": "custom-endpoint"}) == "custom-endpoint"

    def test_empty_config_defaults_to_local(self):
        # With no config at all, "local" is the default when whisper is present.
        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True):
            from tools.transcription_tools import _get_provider
            assert _get_provider({}) == "local"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_groq
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeGroq:
    """_transcribe_groq: key handling, client wiring, and error reporting."""

    @staticmethod
    def _stub_client(response=None, error=None):
        """Fake OpenAI-compatible client whose transcription call either
        returns *response* or raises *error*."""
        client = MagicMock()
        create = client.audio.transcriptions.create
        if error is not None:
            create.side_effect = error
        else:
            create.return_value = response
        return client

    def test_no_key(self, monkeypatch):
        monkeypatch.delenv("GROQ_API_KEY", raising=False)
        from tools.transcription_tools import _transcribe_groq
        outcome = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo")
        assert outcome["success"] is False
        assert "GROQ_API_KEY" in outcome["error"]

    def test_openai_package_not_installed(self, monkeypatch):
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        with patch("tools.transcription_tools._HAS_OPENAI", False):
            from tools.transcription_tools import _transcribe_groq
            outcome = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo")
        assert outcome["success"] is False
        assert "openai package" in outcome["error"]

    def test_successful_transcription(self, monkeypatch, sample_wav):
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        stub = self._stub_client(response="hello world")
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub):
            from tools.transcription_tools import _transcribe_groq
            outcome = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")

        assert outcome["success"] is True
        assert outcome["transcript"] == "hello world"
        assert outcome["provider"] == "groq"

    def test_whitespace_stripped(self, monkeypatch, sample_wav):
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        stub = self._stub_client(response=" hello world \n")
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub):
            from tools.transcription_tools import _transcribe_groq
            outcome = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")

        assert outcome["transcript"] == "hello world"

    def test_uses_groq_base_url(self, monkeypatch, sample_wav):
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        stub = self._stub_client(response="test")
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub) as client_cls:
            from tools.transcription_tools import _transcribe_groq, GROQ_BASE_URL
            _transcribe_groq(sample_wav, "whisper-large-v3-turbo")

        # The client must be pointed at groq's OpenAI-compatible endpoint.
        assert client_cls.call_args.kwargs["base_url"] == GROQ_BASE_URL

    def test_api_error_returns_failure(self, monkeypatch, sample_wav):
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        stub = self._stub_client(error=Exception("API error"))
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub):
            from tools.transcription_tools import _transcribe_groq
            outcome = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")

        assert outcome["success"] is False
        assert "API error" in outcome["error"]

    def test_permission_error(self, monkeypatch, sample_wav):
        monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
        stub = self._stub_client(error=PermissionError("denied"))
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub):
            from tools.transcription_tools import _transcribe_groq
            outcome = _transcribe_groq(sample_wav, "whisper-large-v3-turbo")

        assert outcome["success"] is False
        assert "Permission denied" in outcome["error"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_openai — additional tests
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeOpenAIExtended:
    """_transcribe_openai: missing dependency, endpoint wiring, output cleanup, errors."""

    @staticmethod
    def _stub_client(response=None, error=None):
        """Fake OpenAI client whose transcription call returns *response*
        or raises *error*."""
        client = MagicMock()
        create = client.audio.transcriptions.create
        if error is not None:
            create.side_effect = error
        else:
            create.return_value = response
        return client

    def test_openai_package_not_installed(self, monkeypatch):
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        with patch("tools.transcription_tools._HAS_OPENAI", False):
            from tools.transcription_tools import _transcribe_openai
            outcome = _transcribe_openai("/tmp/test.ogg", "whisper-1")
        assert outcome["success"] is False
        assert "openai package" in outcome["error"]

    def test_uses_openai_base_url(self, monkeypatch, sample_wav):
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        stub = self._stub_client(response="test")
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub) as client_cls:
            from tools.transcription_tools import _transcribe_openai, OPENAI_BASE_URL
            _transcribe_openai(sample_wav, "whisper-1")

        assert client_cls.call_args.kwargs["base_url"] == OPENAI_BASE_URL

    def test_whitespace_stripped(self, monkeypatch, sample_wav):
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        stub = self._stub_client(response=" hello \n")
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub):
            from tools.transcription_tools import _transcribe_openai
            outcome = _transcribe_openai(sample_wav, "whisper-1")

        assert outcome["transcript"] == "hello"

    def test_permission_error(self, monkeypatch, sample_wav):
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
        stub = self._stub_client(error=PermissionError("denied"))
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub):
            from tools.transcription_tools import _transcribe_openai
            outcome = _transcribe_openai(sample_wav, "whisper-1")

        assert outcome["success"] is False
        assert "Permission denied" in outcome["error"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_local — additional tests
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeLocalExtended:
    """_transcribe_local: model caching, reloads, failures, and segment joining."""

    @staticmethod
    def _speech_model(texts, duration=1.0):
        """Fake WhisperModel whose transcribe() yields one segment per text."""
        segments = []
        for chunk in texts:
            seg = MagicMock()
            seg.text = chunk
            segments.append(seg)
        info = MagicMock()
        info.language = "en"
        info.duration = duration
        model = MagicMock()
        model.transcribe.return_value = (segments, info)
        return model

    def test_model_reuse_on_second_call(self, tmp_path):
        """Second call with same model should NOT reload the model."""
        clip = tmp_path / "test.ogg"
        clip.write_bytes(b"fake")
        factory = MagicMock(return_value=self._speech_model(["hi"]))

        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
                patch("faster_whisper.WhisperModel", factory), \
                patch("tools.transcription_tools._local_model", None), \
                patch("tools.transcription_tools._local_model_name", None):
            from tools.transcription_tools import _transcribe_local
            _transcribe_local(str(clip), "base")
            _transcribe_local(str(clip), "base")

        # The cached model must be reused; the factory runs exactly once.
        assert factory.call_count == 1

    def test_model_reloaded_on_change(self, tmp_path):
        """Switching model name should reload the model."""
        clip = tmp_path / "test.ogg"
        clip.write_bytes(b"fake")
        factory = MagicMock(return_value=self._speech_model(["hi"]))

        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
                patch("faster_whisper.WhisperModel", factory), \
                patch("tools.transcription_tools._local_model", None), \
                patch("tools.transcription_tools._local_model_name", None):
            from tools.transcription_tools import _transcribe_local
            _transcribe_local(str(clip), "base")
            _transcribe_local(str(clip), "small")

        assert factory.call_count == 2

    def test_exception_returns_failure(self, tmp_path):
        # A model-load failure (e.g. OOM) must surface as a failed result.
        clip = tmp_path / "test.ogg"
        clip.write_bytes(b"fake")
        factory = MagicMock(side_effect=RuntimeError("CUDA out of memory"))

        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
                patch("faster_whisper.WhisperModel", factory), \
                patch("tools.transcription_tools._local_model", None):
            from tools.transcription_tools import _transcribe_local
            outcome = _transcribe_local(str(clip), "large-v3")

        assert outcome["success"] is False
        assert "CUDA out of memory" in outcome["error"]

    def test_multiple_segments_joined(self, tmp_path):
        clip = tmp_path / "test.ogg"
        clip.write_bytes(b"fake")
        model = self._speech_model(["Hello", " world"], duration=3.0)

        with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
                patch("faster_whisper.WhisperModel", return_value=model), \
                patch("tools.transcription_tools._local_model", None):
            from tools.transcription_tools import _transcribe_local
            outcome = _transcribe_local(str(clip), "base")

        assert outcome["success"] is True
        assert outcome["transcript"] == "Hello world"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Model auto-correction
|
||||
# ============================================================================
|
||||
|
||||
class TestModelAutoCorrection:
    """Model names belonging to the wrong provider are swapped for that
    provider's default; compatible and unknown names pass through untouched."""

    @staticmethod
    def _model_sent(monkeypatch, env_var, env_val, func_name, audio_path, requested, response):
        """Run one stubbed transcription; return the model name the client received."""
        monkeypatch.setenv(env_var, env_val)
        stub = MagicMock()
        stub.audio.transcriptions.create.return_value = response
        with patch("tools.transcription_tools._HAS_OPENAI", True), \
                patch("openai.OpenAI", return_value=stub):
            import tools.transcription_tools as tt
            getattr(tt, func_name)(audio_path, requested)
        return stub.audio.transcriptions.create.call_args.kwargs["model"]

    def test_groq_corrects_openai_model(self, monkeypatch, sample_wav):
        from tools.transcription_tools import DEFAULT_GROQ_STT_MODEL
        sent = self._model_sent(monkeypatch, "GROQ_API_KEY", "gsk-test",
                                "_transcribe_groq", sample_wav, "whisper-1", "hello world")
        assert sent == DEFAULT_GROQ_STT_MODEL

    def test_groq_corrects_gpt4o_transcribe(self, monkeypatch, sample_wav):
        from tools.transcription_tools import DEFAULT_GROQ_STT_MODEL
        sent = self._model_sent(monkeypatch, "GROQ_API_KEY", "gsk-test",
                                "_transcribe_groq", sample_wav, "gpt-4o-transcribe", "test")
        assert sent == DEFAULT_GROQ_STT_MODEL

    def test_openai_corrects_groq_model(self, monkeypatch, sample_wav):
        from tools.transcription_tools import DEFAULT_STT_MODEL
        sent = self._model_sent(monkeypatch, "VOICE_TOOLS_OPENAI_KEY", "sk-test",
                                "_transcribe_openai", sample_wav, "whisper-large-v3-turbo", "hello world")
        assert sent == DEFAULT_STT_MODEL

    def test_openai_corrects_distil_whisper(self, monkeypatch, sample_wav):
        from tools.transcription_tools import DEFAULT_STT_MODEL
        sent = self._model_sent(monkeypatch, "VOICE_TOOLS_OPENAI_KEY", "sk-test",
                                "_transcribe_openai", sample_wav, "distil-whisper-large-v3-en", "test")
        assert sent == DEFAULT_STT_MODEL

    def test_compatible_groq_model_not_overridden(self, monkeypatch, sample_wav):
        sent = self._model_sent(monkeypatch, "GROQ_API_KEY", "gsk-test",
                                "_transcribe_groq", sample_wav, "whisper-large-v3", "test")
        assert sent == "whisper-large-v3"

    def test_compatible_openai_model_not_overridden(self, monkeypatch, sample_wav):
        sent = self._model_sent(monkeypatch, "VOICE_TOOLS_OPENAI_KEY", "sk-test",
                                "_transcribe_openai", sample_wav, "gpt-4o-mini-transcribe", "test")
        assert sent == "gpt-4o-mini-transcribe"

    def test_unknown_model_passes_through_groq(self, monkeypatch, sample_wav):
        """A model not in either known set should not be overridden."""
        sent = self._model_sent(monkeypatch, "GROQ_API_KEY", "gsk-test",
                                "_transcribe_groq", sample_wav, "my-custom-model", "test")
        assert sent == "my-custom-model"

    def test_unknown_model_passes_through_openai(self, monkeypatch, sample_wav):
        sent = self._model_sent(monkeypatch, "VOICE_TOOLS_OPENAI_KEY", "sk-test",
                                "_transcribe_openai", sample_wav, "my-custom-model", "test")
        assert sent == "my-custom-model"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _load_stt_config
|
||||
# ============================================================================
|
||||
|
||||
class TestLoadSttConfig:
    """_load_stt_config must degrade gracefully when the CLI config layer is missing."""

    def test_returns_dict_when_import_fails(self):
        """Loader must not raise when the hermes_cli config modules cannot be imported.

        Bugfix: the original test patched ``_load_stt_config`` itself and then
        asserted on its own mock's return value — a tautology that passed for
        any implementation. This version exercises the real function with the
        config import broken.
        """
        # A None entry in sys.modules makes ``import hermes_cli...`` raise
        # ImportError, simulating a missing/broken CLI install.
        with patch.dict("sys.modules", {"hermes_cli": None, "hermes_cli.config": None}):
            from tools.transcription_tools import _load_stt_config
            result = _load_stt_config()
        # Fallback is expected to be a dict (presumably empty — TODO confirm
        # against tools.transcription_tools._load_stt_config's except branch).
        assert isinstance(result, dict)

    def test_real_load_returns_dict(self):
        """_load_stt_config should always return a dict, even on import error."""
        with patch.dict("sys.modules", {"hermes_cli": None, "hermes_cli.config": None}):
            from tools.transcription_tools import _load_stt_config
            result = _load_stt_config()
        assert isinstance(result, dict)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _validate_audio_file — edge cases
|
||||
# ============================================================================
|
||||
|
||||
class TestValidateAudioFileEdgeCases:
    """_validate_audio_file: paths that exist but are unusable, plus format checks."""

    def test_directory_is_not_a_file(self, tmp_path):
        """A directory carrying an audio extension must be rejected."""
        from tools.transcription_tools import _validate_audio_file
        bogus = tmp_path / "audio.ogg"
        bogus.mkdir()
        outcome = _validate_audio_file(str(bogus))
        assert outcome is not None
        assert "not a file" in outcome["error"]

    def test_stat_oserror(self, tmp_path):
        """An OSError from stat() surfaces as an access-failure error."""
        from tools.transcription_tools import _validate_audio_file
        target = tmp_path / "test.ogg"
        target.write_bytes(b"data")
        genuine = target.stat()
        seen = {"count": 0}

        def flaky_stat(*args, **kwargs):
            # exists()/is_file() each stat once; let those pass, then fail.
            seen["count"] += 1
            if seen["count"] <= 2:
                return genuine
            raise OSError("disk error")

        with patch("pathlib.Path.stat", side_effect=flaky_stat):
            outcome = _validate_audio_file(str(target))
        assert outcome is not None
        assert "Failed to access" in outcome["error"]

    def test_all_supported_formats_accepted(self, tmp_path):
        from tools.transcription_tools import _validate_audio_file, SUPPORTED_FORMATS
        for fmt in SUPPORTED_FORMATS:
            candidate = tmp_path / f"test{fmt}"
            candidate.write_bytes(b"data")
            assert _validate_audio_file(str(candidate)) is None, f"Format {fmt} should be accepted"

    def test_case_insensitive_extension(self, tmp_path):
        from tools.transcription_tools import _validate_audio_file
        shouty = tmp_path / "test.MP3"
        shouty.write_bytes(b"data")
        assert _validate_audio_file(str(shouty)) is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# transcribe_audio — end-to-end dispatch
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeAudioDispatch:
    """transcribe_audio(): provider dispatch, model resolution, and input validation."""

    def test_dispatches_to_groq(self, sample_ogg):
        """When the provider resolves to groq, the groq backend is called once."""
        with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "groq"}), \
             patch("tools.transcription_tools._get_provider", return_value="groq"), \
             patch("tools.transcription_tools._transcribe_groq",
                   return_value={"success": True, "transcript": "hi", "provider": "groq"}) as mock_groq:
            from tools.transcription_tools import transcribe_audio
            result = transcribe_audio(sample_ogg)

        assert result["success"] is True
        assert result["provider"] == "groq"
        mock_groq.assert_called_once()

    def test_dispatches_to_local(self, sample_ogg):
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="local"), \
             patch("tools.transcription_tools._transcribe_local",
                   return_value={"success": True, "transcript": "hi"}) as mock_local:
            from tools.transcription_tools import transcribe_audio
            result = transcribe_audio(sample_ogg)

        assert result["success"] is True
        mock_local.assert_called_once()

    def test_dispatches_to_openai(self, sample_ogg):
        with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openai"}), \
             patch("tools.transcription_tools._get_provider", return_value="openai"), \
             patch("tools.transcription_tools._transcribe_openai",
                   return_value={"success": True, "transcript": "hi", "provider": "openai"}) as mock_openai:
            from tools.transcription_tools import transcribe_audio
            result = transcribe_audio(sample_ogg)

        assert result["success"] is True
        mock_openai.assert_called_once()

    def test_no_provider_returns_error(self, sample_ogg):
        """When nothing is available the error message names every setup option."""
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="none"):
            from tools.transcription_tools import transcribe_audio
            result = transcribe_audio(sample_ogg)

        assert result["success"] is False
        assert "No STT provider" in result["error"]
        assert "faster-whisper" in result["error"]
        assert "GROQ_API_KEY" in result["error"]

    def test_invalid_file_short_circuits(self):
        """Validation failures return before any provider is consulted."""
        from tools.transcription_tools import transcribe_audio
        result = transcribe_audio("/nonexistent/audio.wav")
        assert result["success"] is False
        assert "not found" in result["error"]

    def test_model_override_passed_to_groq(self, sample_ogg):
        """An explicit ``model=`` argument must reach the groq backend verbatim.

        Bugfix: the original assertion read
        ``assert kwargs.get("model_name") or mock_groq.call_args[0][1] == "..."``,
        which, by operator precedence, passes whenever ``model_name`` is any
        truthy value — regardless of the model actually sent. Assert the
        positional argument directly, matching the sibling tests.
        """
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="groq"), \
             patch("tools.transcription_tools._transcribe_groq",
                   return_value={"success": True, "transcript": "hi"}) as mock_groq:
            from tools.transcription_tools import transcribe_audio
            transcribe_audio(sample_ogg, model="whisper-large-v3")

        assert mock_groq.call_args[0][1] == "whisper-large-v3"

    def test_model_override_passed_to_local(self, sample_ogg):
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="local"), \
             patch("tools.transcription_tools._transcribe_local",
                   return_value={"success": True, "transcript": "hi"}) as mock_local:
            from tools.transcription_tools import transcribe_audio
            transcribe_audio(sample_ogg, model="large-v3")

        assert mock_local.call_args[0][1] == "large-v3"

    def test_default_model_used_when_none(self, sample_ogg):
        """With no override and no config, the provider default model is used."""
        with patch("tools.transcription_tools._load_stt_config", return_value={}), \
             patch("tools.transcription_tools._get_provider", return_value="groq"), \
             patch("tools.transcription_tools._transcribe_groq",
                   return_value={"success": True, "transcript": "hi"}) as mock_groq:
            from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL
            transcribe_audio(sample_ogg, model=None)

        assert mock_groq.call_args[0][1] == DEFAULT_GROQ_STT_MODEL

    def test_config_local_model_used(self, sample_ogg):
        """A per-provider model from config wins over the built-in default."""
        config = {"local": {"model": "small"}}
        with patch("tools.transcription_tools._load_stt_config", return_value=config), \
             patch("tools.transcription_tools._get_provider", return_value="local"), \
             patch("tools.transcription_tools._transcribe_local",
                   return_value={"success": True, "transcript": "hi"}) as mock_local:
            from tools.transcription_tools import transcribe_audio
            transcribe_audio(sample_ogg, model=None)

        assert mock_local.call_args[0][1] == "small"

    def test_config_openai_model_used(self, sample_ogg):
        config = {"openai": {"model": "gpt-4o-transcribe"}}
        with patch("tools.transcription_tools._load_stt_config", return_value=config), \
             patch("tools.transcription_tools._get_provider", return_value="openai"), \
             patch("tools.transcription_tools._transcribe_openai",
                   return_value={"success": True, "transcript": "hi"}) as mock_openai:
            from tools.transcription_tools import transcribe_audio
            transcribe_audio(sample_ogg, model=None)

        assert mock_openai.call_args[0][1] == "gpt-4o-transcribe"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# get_stt_model_from_config
|
||||
# ============================================================================
|
||||
|
||||
class TestGetSttModelFromConfig:
    """get_stt_model_from_config(): every way config.yaml can (fail to) supply a model."""

    @staticmethod
    def _set_home(tmp_path, monkeypatch, yaml_text=None):
        """Point HERMES_HOME at tmp_path, first writing config.yaml if text is given."""
        if yaml_text is not None:
            (tmp_path / "config.yaml").write_text(yaml_text)
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    def test_returns_model_from_config(self, tmp_path, monkeypatch):
        self._set_home(tmp_path, monkeypatch, "stt:\n model: whisper-large-v3\n")

        from tools.transcription_tools import get_stt_model_from_config

        assert get_stt_model_from_config() == "whisper-large-v3"

    def test_returns_none_when_no_stt_section(self, tmp_path, monkeypatch):
        # Config exists but only has a tts section.
        self._set_home(tmp_path, monkeypatch, "tts:\n provider: edge\n")

        from tools.transcription_tools import get_stt_model_from_config

        assert get_stt_model_from_config() is None

    def test_returns_none_when_no_config_file(self, tmp_path, monkeypatch):
        # HERMES_HOME is valid but contains no config.yaml at all.
        self._set_home(tmp_path, monkeypatch)

        from tools.transcription_tools import get_stt_model_from_config

        assert get_stt_model_from_config() is None

    def test_returns_none_on_invalid_yaml(self, tmp_path, monkeypatch):
        # Unparseable YAML must degrade to "no model", not raise.
        self._set_home(tmp_path, monkeypatch, ": : :\n bad yaml [[[")

        from tools.transcription_tools import get_stt_model_from_config

        assert get_stt_model_from_config() is None

    def test_returns_none_when_model_key_missing(self, tmp_path, monkeypatch):
        # stt section present, but without a model key.
        self._set_home(tmp_path, monkeypatch, "stt:\n enabled: true\n")

        from tools.transcription_tools import get_stt_model_from_config

        assert get_stt_model_from_config() is None
|
||||
1233
tests/tools/test_voice_cli_integration.py
Normal file
1233
tests/tools/test_voice_cli_integration.py
Normal file
File diff suppressed because it is too large
Load diff
938
tests/tools/test_voice_mode.py
Normal file
938
tests/tools/test_voice_mode.py
Normal file
|
|
@ -0,0 +1,938 @@
|
|||
"""Tests for tools.voice_mode -- all mocked, no real microphone or API calls."""
|
||||
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
def sample_wav(tmp_path):
    """Write a 1-second, 16 kHz, mono, 16-bit WAV of pure silence and return its path."""
    path = tmp_path / "test.wav"
    frame_count = 16000  # one second at 16 kHz

    with wave.open(str(path), "wb") as handle:
        handle.setnchannels(1)
        handle.setsampwidth(2)
        handle.setframerate(16000)
        handle.writeframes(struct.pack(f"<{frame_count}h", *([0] * frame_count)))

    return str(path)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_voice_dir(tmp_path, monkeypatch):
    """Point the voice module's _TEMP_DIR at an isolated per-test directory."""
    scratch = tmp_path / "hermes_voice"
    scratch.mkdir()
    monkeypatch.setattr("tools.voice_mode._TEMP_DIR", str(scratch))
    return scratch
|
||||
|
||||
|
||||
@pytest.fixture
def mock_sd(monkeypatch):
    """Patch _import_audio to hand back (MagicMock sounddevice, real numpy).

    Keeping the real numpy lets tests build genuine int16 frames while the
    sounddevice side stays fully mocked.
    """
    fake_sd = MagicMock()
    try:
        import numpy as np_mod
    except ImportError:
        np_mod = MagicMock()  # numpy-less environments still get a stand-in

    monkeypatch.setattr("tools.voice_mode._import_audio", lambda: (fake_sd, np_mod))
    monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
    return fake_sd
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# check_voice_requirements
|
||||
# ============================================================================
|
||||
|
||||
class TestCheckVoiceRequirements:
    """check_voice_requirements(): audio stack and STT provider readiness probe."""

    def test_all_requirements_met(self, monkeypatch):
        monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
        monkeypatch.setattr(
            "tools.voice_mode.detect_audio_environment",
            lambda: {"available": True, "warnings": []},
        )
        monkeypatch.setattr("tools.transcription_tools._get_provider", lambda cfg: "openai")

        from tools.voice_mode import check_voice_requirements

        report = check_voice_requirements()

        assert report["available"] is True
        assert report["audio_available"] is True
        assert report["stt_available"] is True
        assert report["missing_packages"] == []

    def test_missing_audio_packages(self, monkeypatch):
        monkeypatch.setattr("tools.voice_mode._audio_available", lambda: False)
        monkeypatch.setattr(
            "tools.voice_mode.detect_audio_environment",
            lambda: {"available": False, "warnings": ["Audio libraries not installed"]},
        )
        monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test-key")  # STT side is satisfied

        from tools.voice_mode import check_voice_requirements

        report = check_voice_requirements()

        assert report["available"] is False
        assert report["audio_available"] is False
        for pkg in ("sounddevice", "numpy"):
            assert pkg in report["missing_packages"]

    def test_missing_stt_provider(self, monkeypatch):
        monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True)
        monkeypatch.setattr(
            "tools.voice_mode.detect_audio_environment",
            lambda: {"available": True, "warnings": []},
        )
        monkeypatch.setattr("tools.transcription_tools._get_provider", lambda cfg: "none")

        from tools.voice_mode import check_voice_requirements

        report = check_voice_requirements()

        assert report["available"] is False
        assert report["stt_available"] is False
        assert "STT provider: MISSING" in report["details"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AudioRecorder
|
||||
# ============================================================================
|
||||
|
||||
class TestAudioRecorderStart:
    """start(): failure without audio libs, stream lifecycle, idempotence."""

    def test_start_raises_without_audio(self, monkeypatch):
        """start() surfaces a RuntimeError when sounddevice/numpy are absent."""
        monkeypatch.setattr(
            "tools.voice_mode._import_audio",
            MagicMock(side_effect=ImportError("no sounddevice")),
        )

        from tools.voice_mode import AudioRecorder

        with pytest.raises(RuntimeError, match="sounddevice and numpy"):
            AudioRecorder().start()

    def test_start_creates_and_starts_stream(self, mock_sd):
        stream = MagicMock()
        mock_sd.InputStream.return_value = stream

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        rec.start()

        mock_sd.InputStream.assert_called_once()
        stream.start.assert_called_once()
        assert rec.is_recording is True

    def test_double_start_is_noop(self, mock_sd):
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        for _ in range(2):  # the second start() must not open another stream
            rec.start()

        assert mock_sd.InputStream.call_count == 1
|
||||
|
||||
|
||||
class TestAudioRecorderStop:
    """stop(): WAV emission, plus rejection of too-short or too-quiet takes."""

    def test_stop_returns_none_when_not_recording(self):
        from tools.voice_mode import AudioRecorder

        assert AudioRecorder().stop() is None

    def test_stop_writes_wav_file(self, mock_sd, temp_voice_dir):
        np = pytest.importorskip("numpy")
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder, SAMPLE_RATE

        rec = AudioRecorder()
        rec.start()
        # One second of loud audio, comfortably above the RMS gate.
        rec._frames = [np.full((SAMPLE_RATE, 1), 1000, dtype="int16")]
        rec._peak_rms = 1000

        path = rec.stop()

        assert path is not None
        assert path.endswith(".wav")
        assert os.path.isfile(path)
        assert rec.is_recording is False

        # The emitted file must be a well-formed mono 16-bit WAV at SAMPLE_RATE.
        with wave.open(path, "rb") as wf:
            assert wf.getnchannels() == 1
            assert wf.getsampwidth() == 2
            assert wf.getframerate() == SAMPLE_RATE

    def test_stop_returns_none_for_very_short_recording(self, mock_sd, temp_voice_dir):
        np = pytest.importorskip("numpy")
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        rec.start()
        # ~6 ms of audio at 16 kHz -- far below any usable utterance length.
        rec._frames = [np.zeros((100, 1), dtype="int16")]

        assert rec.stop() is None

    def test_stop_returns_none_for_silent_recording(self, mock_sd, temp_voice_dir):
        np = pytest.importorskip("numpy")
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder, SAMPLE_RATE

        rec = AudioRecorder()
        rec.start()
        # A full second of near-silence: both frame RMS and peak RMS under the gate.
        rec._frames = [np.full((SAMPLE_RATE, 1), 10, dtype="int16")]
        rec._peak_rms = 10

        assert rec.stop() is None
|
||||
|
||||
|
||||
class TestAudioRecorderCancel:
    """cancel(): drops buffered audio but keeps the persistent stream open."""

    def test_cancel_discards_frames(self, mock_sd):
        stream = MagicMock()
        mock_sd.InputStream.return_value = stream

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        rec.start()
        rec._frames = [MagicMock()]  # pretend some audio was captured

        rec.cancel()

        assert rec._frames == []
        assert rec.is_recording is False
        # The stream stays alive for reuse; cancel() never stops or closes it.
        stream.stop.assert_not_called()
        stream.close.assert_not_called()

    def test_cancel_when_not_recording_is_safe(self):
        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        rec.cancel()  # must be a silent no-op
        assert rec.is_recording is False
|
||||
|
||||
|
||||
class TestAudioRecorderProperties:
    """elapsed_seconds: zero when idle, wall-clock based while recording."""

    def test_elapsed_seconds_when_not_recording(self):
        from tools.voice_mode import AudioRecorder

        assert AudioRecorder().elapsed_seconds == 0.0

    def test_elapsed_seconds_when_recording(self, mock_sd):
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        rec.start()
        try:
            # Pretend recording began one second ago.
            rec._start_time = time.monotonic() - 1.0
            assert 0.9 < rec.elapsed_seconds < 2.0
        finally:
            rec.cancel()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# transcribe_recording
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeRecording:
    """transcribe_recording(): delegation plus Whisper-hallucination filtering."""

    @staticmethod
    def _stub_transcriber(text):
        """Return a MagicMock standing in for transcribe_audio with a fixed transcript."""
        return MagicMock(return_value={"success": True, "transcript": text})

    def test_delegates_to_transcribe_audio(self):
        stub = self._stub_transcriber("hello world")

        with patch("tools.transcription_tools.transcribe_audio", stub):
            from tools.voice_mode import transcribe_recording
            result = transcribe_recording("/tmp/test.wav", model="whisper-1")

        stub.assert_called_once_with("/tmp/test.wav", model="whisper-1")
        assert result["success"] is True
        assert result["transcript"] == "hello world"

    def test_filters_whisper_hallucination(self):
        # "Thank you." on its own is a classic Whisper silence hallucination.
        stub = self._stub_transcriber("Thank you.")

        with patch("tools.transcription_tools.transcribe_audio", stub):
            from tools.voice_mode import transcribe_recording
            result = transcribe_recording("/tmp/test.wav")

        assert result["success"] is True
        assert result["transcript"] == ""
        assert result["filtered"] is True

    def test_does_not_filter_real_speech(self):
        # A longer sentence containing "Thank you" must pass through untouched.
        stub = self._stub_transcriber("Thank you for helping me with this code.")

        with patch("tools.transcription_tools.transcribe_audio", stub):
            from tools.voice_mode import transcribe_recording
            result = transcribe_recording("/tmp/test.wav")

        assert result["transcript"] == "Thank you for helping me with this code."
        assert "filtered" not in result
|
||||
|
||||
|
||||
class TestWhisperHallucinationFilter:
    """is_whisper_hallucination(): classic Whisper filler phrases vs real speech."""

    def test_known_hallucinations(self):
        from tools.voice_mode import is_whisper_hallucination

        for phrase in (
            "Thank you.",
            "thank you",
            "Thanks for watching.",
            "Bye.",
            " Thank you. ",  # surrounding whitespace must not defeat the filter
            "you",
        ):
            assert is_whisper_hallucination(phrase) is True

    def test_real_speech_not_filtered(self):
        from tools.voice_mode import is_whisper_hallucination

        for phrase in (
            "Hello, how are you?",
            "Thank you for your help with the project.",
            "Can you explain this code?",
        ):
            assert is_whisper_hallucination(phrase) is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# play_audio_file
|
||||
# ============================================================================
|
||||
|
||||
class TestPlayAudioFile:
    """play_audio_file(): sounddevice path, missing-player fallback, missing file."""

    def test_play_wav_via_sounddevice(self, monkeypatch, sample_wav):
        np = pytest.importorskip("numpy")

        fake_sd = MagicMock()
        # get_stream().active == False -> the playback wait loop exits at once.
        fake_sd.get_stream.return_value = MagicMock(active=False)
        monkeypatch.setattr("tools.voice_mode._import_audio", lambda: (fake_sd, np))

        from tools.voice_mode import play_audio_file

        assert play_audio_file(sample_wav) is True
        fake_sd.play.assert_called_once()
        fake_sd.stop.assert_called_once()

    def test_returns_false_when_no_player(self, monkeypatch, sample_wav):
        monkeypatch.setattr(
            "tools.voice_mode._import_audio",
            MagicMock(side_effect=ImportError("no sounddevice")),
        )
        monkeypatch.setattr("shutil.which", lambda _: None)  # no CLI player either

        from tools.voice_mode import play_audio_file

        assert play_audio_file(sample_wav) is False

    def test_returns_false_for_missing_file(self):
        from tools.voice_mode import play_audio_file

        assert play_audio_file("/nonexistent/file.wav") is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# cleanup_temp_recordings
|
||||
# ============================================================================
|
||||
|
||||
class TestCleanupTempRecordings:
    """cleanup_temp_recordings(): age-based pruning of recording_*.wav files."""

    @staticmethod
    def _backdate(path, seconds):
        """Rewind a file's atime/mtime by the given number of seconds."""
        stamp = time.time() - seconds
        os.utime(str(path), (stamp, stamp))

    def test_old_files_deleted(self, temp_voice_dir):
        stale = temp_voice_dir / "recording_20240101_000000.wav"
        stale.write_bytes(b"\x00" * 100)
        self._backdate(stale, 7200)  # two hours old

        from tools.voice_mode import cleanup_temp_recordings

        assert cleanup_temp_recordings(max_age_seconds=3600) == 1
        assert not stale.exists()

    def test_recent_files_preserved(self, temp_voice_dir):
        fresh = temp_voice_dir / "recording_20260303_120000.wav"
        fresh.write_bytes(b"\x00" * 100)  # mtime is "now"

        from tools.voice_mode import cleanup_temp_recordings

        assert cleanup_temp_recordings(max_age_seconds=3600) == 0
        assert fresh.exists()

    def test_nonexistent_dir_returns_zero(self, monkeypatch):
        monkeypatch.setattr("tools.voice_mode._TEMP_DIR", "/nonexistent/dir")

        from tools.voice_mode import cleanup_temp_recordings

        assert cleanup_temp_recordings() == 0

    def test_non_recording_files_ignored(self, temp_voice_dir):
        # Old, but not matching the recording_*.wav pattern -- must survive.
        bystander = temp_voice_dir / "other_file.txt"
        bystander.write_bytes(b"\x00" * 100)
        self._backdate(bystander, 7200)

        from tools.voice_mode import cleanup_temp_recordings

        assert cleanup_temp_recordings(max_age_seconds=3600) == 0
        assert bystander.exists()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# play_beep
|
||||
# ============================================================================
|
||||
|
||||
class TestPlayBeep:
    """play_beep(): tone synthesis, multi-beep length, and graceful degradation."""

    def test_beep_calls_sounddevice_play(self, mock_sd):
        """A single beep plays an int16 numpy buffer and then stops the device."""
        np = pytest.importorskip("numpy")

        from tools.voice_mode import play_beep

        # play_beep uses polling (get_stream) + sd.stop() instead of sd.wait()
        mock_stream = MagicMock()
        mock_stream.active = False
        mock_sd.get_stream.return_value = mock_stream

        play_beep(frequency=880, duration=0.1, count=1)

        mock_sd.play.assert_called_once()
        mock_sd.stop.assert_called()
        # Verify audio data is an int16 numpy array with actual samples.
        audio_arg = mock_sd.play.call_args[0][0]
        assert audio_arg.dtype == np.int16
        assert len(audio_arg) > 0

    def test_beep_double_produces_longer_audio(self, mock_sd):
        """count=2 must synthesize more samples than a single beep would."""
        np = pytest.importorskip("numpy")

        from tools.voice_mode import play_beep, SAMPLE_RATE

        play_beep(frequency=660, duration=0.1, count=2)

        audio_arg = mock_sd.play.call_args[0][0]
        # CONSISTENCY FIX: use the module's SAMPLE_RATE (as the other tests in
        # this file do) instead of a hard-coded 16000, so this test cannot
        # silently diverge from the production constant.
        single_beep_samples = int(SAMPLE_RATE * 0.1)
        # Double beep (two tones plus gap) must be longer than a single beep.
        assert len(audio_arg) > single_beep_samples

    def test_beep_noop_without_audio(self, monkeypatch):
        """Missing audio libraries downgrade play_beep to a silent no-op."""
        def _fail_import():
            raise ImportError("no sounddevice")
        monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import)

        from tools.voice_mode import play_beep

        play_beep()  # must not raise

    def test_beep_handles_playback_error(self, mock_sd):
        """Device errors during playback are swallowed, not propagated."""
        mock_sd.play.side_effect = Exception("device error")

        from tools.voice_mode import play_beep

        play_beep()  # must not raise
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Silence detection
|
||||
# ============================================================================
|
||||
|
||||
class TestSilenceDetection:
    """Silence auto-stop: the recorder's audio callback tracks speech vs silence.

    These tests drive the InputStream callback by hand with synthetic int16
    frames and real time.sleep() gaps, so the exact call/sleep ordering is
    load-bearing -- the recorder measures wall-clock time between chunks.
    """

    def test_silence_callback_fires_after_speech_then_silence(self, mock_sd):
        np = pytest.importorskip("numpy")
        import threading

        mock_stream = MagicMock()
        mock_sd.InputStream.return_value = mock_stream

        from tools.voice_mode import AudioRecorder, SAMPLE_RATE

        recorder = AudioRecorder()
        # Use very short durations for testing
        recorder._silence_duration = 0.05
        recorder._min_speech_duration = 0.05

        fired = threading.Event()

        def on_silence():
            fired.set()

        recorder.start(on_silence_stop=on_silence)

        # Get the callback function from InputStream constructor
        # (keyword first; fall back to the positional kwargs dict form).
        callback = mock_sd.InputStream.call_args.kwargs.get("callback")
        if callback is None:
            callback = mock_sd.InputStream.call_args[1]["callback"]

        # Simulate sustained speech (multiple loud chunks to exceed min_speech_duration)
        loud_frame = np.full((1600, 1), 5000, dtype="int16")
        callback(loud_frame, 1600, None, None)
        time.sleep(0.06)
        callback(loud_frame, 1600, None, None)
        assert recorder._has_spoken is True

        # Simulate silence
        silent_frame = np.zeros((1600, 1), dtype="int16")
        callback(silent_frame, 1600, None, None)

        # Wait a bit past the silence duration, then send another silent frame
        time.sleep(0.06)
        callback(silent_frame, 1600, None, None)

        # The callback should have been fired
        assert fired.wait(timeout=1.0) is True

        recorder.cancel()

    def test_silence_without_speech_does_not_fire(self, mock_sd):
        """Silence alone -- with no confirmed speech first -- must not auto-stop."""
        np = pytest.importorskip("numpy")
        import threading

        mock_stream = MagicMock()
        mock_sd.InputStream.return_value = mock_stream

        from tools.voice_mode import AudioRecorder

        recorder = AudioRecorder()
        recorder._silence_duration = 0.02

        fired = threading.Event()
        recorder.start(on_silence_stop=lambda: fired.set())

        callback = mock_sd.InputStream.call_args.kwargs.get("callback")
        if callback is None:
            callback = mock_sd.InputStream.call_args[1]["callback"]

        # Only silence -- no speech detected, so callback should NOT fire
        silent_frame = np.zeros((1600, 1), dtype="int16")
        for _ in range(5):
            callback(silent_frame, 1600, None, None)
            time.sleep(0.01)

        assert fired.wait(timeout=0.2) is False

        recorder.cancel()

    def test_micro_pause_tolerance_during_speech(self, mock_sd):
        """Brief dips below threshold during speech should NOT reset speech tracking."""
        np = pytest.importorskip("numpy")
        import threading

        mock_stream = MagicMock()
        mock_sd.InputStream.return_value = mock_stream

        from tools.voice_mode import AudioRecorder

        recorder = AudioRecorder()
        # Tuned so: dip (0.05s) < _max_dip_tolerance (0.1s) < _min_speech_duration (0.15s).
        recorder._silence_duration = 0.05
        recorder._min_speech_duration = 0.15
        recorder._max_dip_tolerance = 0.1

        fired = threading.Event()
        recorder.start(on_silence_stop=lambda: fired.set())

        callback = mock_sd.InputStream.call_args.kwargs.get("callback")
        if callback is None:
            callback = mock_sd.InputStream.call_args[1]["callback"]

        loud_frame = np.full((1600, 1), 5000, dtype="int16")
        quiet_frame = np.full((1600, 1), 50, dtype="int16")

        # Speech chunk 1
        callback(loud_frame, 1600, None, None)
        time.sleep(0.05)
        # Brief micro-pause (dip < max_dip_tolerance)
        callback(quiet_frame, 1600, None, None)
        time.sleep(0.05)
        # Speech resumes -- speech_start should NOT have been reset
        callback(loud_frame, 1600, None, None)
        assert recorder._speech_start > 0, "Speech start should be preserved across brief dips"
        time.sleep(0.06)
        # Another speech chunk to exceed min_speech_duration
        callback(loud_frame, 1600, None, None)
        assert recorder._has_spoken is True, "Speech should be confirmed after tolerating micro-pause"

        recorder.cancel()

    def test_no_callback_means_no_silence_detection(self, mock_sd):
        """start() without on_silence_stop disables silence detection entirely."""
        np = pytest.importorskip("numpy")

        mock_stream = MagicMock()
        mock_sd.InputStream.return_value = mock_stream

        from tools.voice_mode import AudioRecorder

        recorder = AudioRecorder()
        recorder.start()  # no on_silence_stop

        callback = mock_sd.InputStream.call_args.kwargs.get("callback")
        if callback is None:
            callback = mock_sd.InputStream.call_args[1]["callback"]

        # Even with speech then silence, nothing should happen
        loud_frame = np.full((1600, 1), 5000, dtype="int16")
        silent_frame = np.zeros((1600, 1), dtype="int16")
        callback(loud_frame, 1600, None, None)
        callback(silent_frame, 1600, None, None)

        # No crash, no callback
        assert recorder._on_silence_stop is None
        recorder.cancel()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Playback interrupt
|
||||
# ============================================================================
|
||||
|
||||
class TestPlaybackInterrupt:
    """stop_playback(): terminating, no-op, and _active_playback bookkeeping."""

    def test_stop_playback_terminates_process(self):
        import tools.voice_mode as vm
        from tools.voice_mode import stop_playback, _playback_lock

        proc = MagicMock()
        proc.poll.return_value = None  # process still running

        with _playback_lock:
            vm._active_playback = proc

        stop_playback()

        proc.terminate.assert_called_once()
        with _playback_lock:
            assert vm._active_playback is None

    def test_stop_playback_noop_when_nothing_playing(self):
        import tools.voice_mode as vm

        with vm._playback_lock:
            vm._active_playback = None

        vm.stop_playback()  # must not raise

    def test_play_audio_file_sets_active_playback(self, monkeypatch, sample_wav):
        import tools.voice_mode as vm

        # Force the external-player code path by making sounddevice unavailable.
        monkeypatch.setattr(
            "tools.voice_mode._import_audio",
            MagicMock(side_effect=ImportError("no sounddevice")),
        )

        proc = MagicMock()
        proc.wait.return_value = 0
        popen = MagicMock(return_value=proc)
        monkeypatch.setattr("subprocess.Popen", popen)
        monkeypatch.setattr("shutil.which", lambda cmd: "/usr/bin/" + cmd)

        vm.play_audio_file(sample_wav)

        assert popen.called
        # Playback finished, so the module-level handle must be cleared again.
        with vm._playback_lock:
            assert vm._active_playback is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Continuous mode flow
|
||||
# ============================================================================
|
||||
|
||||
class TestContinuousModeFlow:
    """Continuous mode: the recorder restarts cleanly after each stop()."""

    @staticmethod
    def _callback_of(mock_sd):
        """Fetch the audio callback handed to the most recent InputStream()."""
        cb = mock_sd.InputStream.call_args.kwargs.get("callback")
        return cb if cb is not None else mock_sd.InputStream.call_args[1]["callback"]

    def test_continuous_restart_on_no_speech(self, mock_sd, temp_voice_dir):
        np = pytest.importorskip("numpy")
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()

        # Take 1: only low-level noise -> stop() rejects it and returns None.
        rec.start()
        cb = self._callback_of(mock_sd)
        hush = np.full((1600, 1), 10, dtype="int16")
        for _ in range(10):
            cb(hush, 1600, None, None)
        assert rec.stop() is None

        # Take 2 (the continuous-mode restart): real speech -> a WAV comes out.
        rec.start()
        assert rec.is_recording is True
        cb = self._callback_of(mock_sd)
        voice = np.full((1600, 1), 5000, dtype="int16")
        for _ in range(10):
            cb(voice, 1600, None, None)
        assert rec.stop() is not None

        rec.cancel()

    def test_recorder_reusable_after_stop(self, mock_sd, temp_voice_dir):
        np = pytest.importorskip("numpy")
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        voice = np.full((1600, 1), 5000, dtype="int16")
        paths = []
        # Three full record/stop cycles on the same instance.
        for _ in range(3):
            rec.start()
            cb = self._callback_of(mock_sd)
            for _ in range(10):
                cb(voice, 1600, None, None)
            paths.append(rec.stop())

        assert all(p is not None for p in paths)
        assert os.path.isfile(paths[-1])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audio level indicator
|
||||
# ============================================================================
|
||||
|
||||
class TestAudioLevelIndicator:
    """current_rms / _peak_rms: live level metering for the recording UI."""

    @staticmethod
    def _callback_of(mock_sd):
        """Fetch the audio callback handed to the most recent InputStream()."""
        cb = mock_sd.InputStream.call_args.kwargs.get("callback")
        return cb if cb is not None else mock_sd.InputStream.call_args[1]["callback"]

    def test_rms_updates_with_audio_chunks(self, mock_sd):
        np = pytest.importorskip("numpy")
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        rec.start()
        cb = self._callback_of(mock_sd)

        assert rec.current_rms == 0  # nothing captured yet

        cb(np.full((1600, 1), 5000, dtype="int16"), 1600, None, None)
        assert rec.current_rms == 5000

        cb(np.full((1600, 1), 100, dtype="int16"), 1600, None, None)
        assert rec.current_rms == 100  # tracks the latest chunk, not the max

        rec.cancel()

    def test_peak_rms_tracks_maximum(self, mock_sd):
        np = pytest.importorskip("numpy")
        mock_sd.InputStream.return_value = MagicMock()

        from tools.voice_mode import AudioRecorder

        rec = AudioRecorder()
        rec.start()
        cb = self._callback_of(mock_sd)

        for level in (100, 8000, 500, 3000):
            cb(np.full((1600, 1), level, dtype="int16"), 1600, None, None)

        assert rec._peak_rms == 8000    # highest level seen across all chunks
        assert rec.current_rms == 3000  # most recent chunk's level

        rec.cancel()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configurable silence parameters
|
||||
# ============================================================================
|
||||
|
||||
class TestConfigurableSilenceParams:
|
||||
"""Verify that silence detection params can be configured."""
|
||||
|
||||
def test_custom_threshold_and_duration(self, mock_sd):
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
import threading
|
||||
|
||||
recorder = AudioRecorder()
|
||||
recorder._silence_threshold = 5000
|
||||
recorder._silence_duration = 0.05
|
||||
recorder._min_speech_duration = 0.05
|
||||
|
||||
fired = threading.Event()
|
||||
recorder.start(on_silence_stop=lambda: fired.set())
|
||||
callback = mock_sd.InputStream.call_args.kwargs.get("callback")
|
||||
if callback is None:
|
||||
callback = mock_sd.InputStream.call_args[1]["callback"]
|
||||
|
||||
# Audio at RMS 1000 -- below custom threshold (5000)
|
||||
moderate = np.full((1600, 1), 1000, dtype="int16")
|
||||
for _ in range(5):
|
||||
callback(moderate, 1600, None, None)
|
||||
time.sleep(0.02)
|
||||
|
||||
assert recorder._has_spoken is False
|
||||
assert fired.wait(timeout=0.2) is False
|
||||
|
||||
# Now send really loud audio (above 5000 threshold)
|
||||
very_loud = np.full((1600, 1), 8000, dtype="int16")
|
||||
callback(very_loud, 1600, None, None)
|
||||
time.sleep(0.06)
|
||||
callback(very_loud, 1600, None, None)
|
||||
assert recorder._has_spoken is True
|
||||
|
||||
recorder.cancel()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Bugfix regression tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSubprocessTimeoutKill:
|
||||
"""Bug: proc.wait(timeout) raised TimeoutExpired but process was not killed."""
|
||||
|
||||
def test_timeout_kills_process(self):
|
||||
import subprocess, os
|
||||
proc = subprocess.Popen(["sleep", "600"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
pid = proc.pid
|
||||
assert proc.poll() is None
|
||||
|
||||
try:
|
||||
proc.wait(timeout=0.1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait()
|
||||
|
||||
assert proc.poll() is not None
|
||||
assert proc.returncode is not None
|
||||
|
||||
|
||||
class TestStreamLeakOnStartFailure:
|
||||
"""Bug: stream.start() failure left stream unclosed."""
|
||||
|
||||
def test_stream_closed_on_start_failure(self, mock_sd):
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.start.side_effect = OSError("Audio device busy")
|
||||
mock_sd.InputStream.return_value = mock_stream
|
||||
|
||||
from tools.voice_mode import AudioRecorder
|
||||
recorder = AudioRecorder()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to open audio input stream"):
|
||||
recorder._ensure_stream()
|
||||
|
||||
mock_stream.close.assert_called_once()
|
||||
|
||||
|
||||
class TestSilenceCallbackLock:
|
||||
"""Bug: _on_silence_stop was read/written without lock in audio callback."""
|
||||
|
||||
def test_fire_block_acquires_lock(self):
|
||||
import inspect
|
||||
from tools.voice_mode import AudioRecorder
|
||||
|
||||
source = inspect.getsource(AudioRecorder._ensure_stream)
|
||||
# Verify lock is used before reading _on_silence_stop in fire block
|
||||
assert "with self._lock:" in source
|
||||
assert "cb = self._on_silence_stop" in source
|
||||
lock_pos = source.index("with self._lock:")
|
||||
cb_pos = source.index("cb = self._on_silence_stop")
|
||||
assert lock_pos < cb_pos
|
||||
|
||||
def test_cancel_clears_callback_under_lock(self, mock_sd):
|
||||
from tools.voice_mode import AudioRecorder
|
||||
recorder = AudioRecorder()
|
||||
mock_sd.InputStream.return_value = MagicMock()
|
||||
|
||||
cb = lambda: None
|
||||
recorder.start(on_silence_stop=cb)
|
||||
assert recorder._on_silence_stop is cb
|
||||
|
||||
recorder.cancel()
|
||||
with recorder._lock:
|
||||
assert recorder._on_silence_stop is None
|
||||
|
|
@ -224,24 +224,14 @@ def _emergency_cleanup_all_sessions():
|
|||
logger.error("Emergency cleanup error: %s", e)
|
||||
|
||||
|
||||
def _signal_handler(signum, frame):
|
||||
"""Handle interrupt signals to cleanup sessions before exit."""
|
||||
logger.warning("Received signal %s, cleaning up...", signum)
|
||||
_emergency_cleanup_all_sessions()
|
||||
sys.exit(128 + signum)
|
||||
|
||||
|
||||
# Register cleanup handlers
|
||||
# Register cleanup via atexit only. Previous versions installed SIGINT/SIGTERM
|
||||
# handlers that called sys.exit(), but this conflicts with prompt_toolkit's
|
||||
# async event loop — a SystemExit raised inside a key-binding callback
|
||||
# corrupts the coroutine state and makes the process unkillable. atexit
|
||||
# handlers run on any normal exit (including sys.exit), so browser sessions
|
||||
# are still cleaned up without hijacking signals.
|
||||
atexit.register(_emergency_cleanup_all_sessions)
|
||||
|
||||
# Only register signal handlers in main process (not in multiprocessing workers)
|
||||
try:
|
||||
if os.getpid() == os.getpgrp(): # Main process check
|
||||
signal.signal(signal.SIGINT, _signal_handler)
|
||||
signal.signal(signal.SIGTERM, _signal_handler)
|
||||
except (OSError, AttributeError):
|
||||
pass # Signal handling not available (e.g., Windows or worker process)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Inactivity Cleanup Functions
|
||||
|
|
|
|||
|
|
@ -2,11 +2,12 @@
|
|||
"""
|
||||
Transcription Tools Module
|
||||
|
||||
Provides speech-to-text transcription with two providers:
|
||||
Provides speech-to-text transcription with three providers:
|
||||
|
||||
- **local** (default, free) — faster-whisper running locally, no API key needed.
|
||||
Auto-downloads the model (~150 MB for ``base``) on first use.
|
||||
- **openai** — OpenAI Whisper API, requires ``VOICE_TOOLS_OPENAI_KEY``.
|
||||
- **groq** (free tier) — Groq Whisper API, requires ``GROQ_API_KEY``.
|
||||
- **openai** (paid) — OpenAI Whisper API, requires ``VOICE_TOOLS_OPENAI_KEY``.
|
||||
|
||||
Used by the messaging gateway to automatically transcribe voice messages
|
||||
sent by users on Telegram, Discord, WhatsApp, Slack, and Signal.
|
||||
|
|
@ -33,18 +34,9 @@ logger = logging.getLogger(__name__)
|
|||
# Optional imports — graceful degradation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
_HAS_FASTER_WHISPER = True
|
||||
except ImportError:
|
||||
_HAS_FASTER_WHISPER = False
|
||||
WhisperModel = None # type: ignore[assignment,misc]
|
||||
|
||||
try:
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
_HAS_OPENAI = True
|
||||
except ImportError:
|
||||
_HAS_OPENAI = False
|
||||
import importlib.util as _ilu
|
||||
_HAS_FASTER_WHISPER = _ilu.find_spec("faster_whisper") is not None
|
||||
_HAS_OPENAI = _ilu.find_spec("openai") is not None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
|
|
@ -52,13 +44,21 @@ except ImportError:
|
|||
|
||||
DEFAULT_PROVIDER = "local"
|
||||
DEFAULT_LOCAL_MODEL = "base"
|
||||
DEFAULT_OPENAI_MODEL = "whisper-1"
|
||||
DEFAULT_STT_MODEL = os.getenv("STT_OPENAI_MODEL", "whisper-1")
|
||||
DEFAULT_GROQ_STT_MODEL = os.getenv("STT_GROQ_MODEL", "whisper-large-v3-turbo")
|
||||
|
||||
GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
|
||||
OPENAI_BASE_URL = os.getenv("STT_OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
|
||||
SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg"}
|
||||
MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB
|
||||
|
||||
# Known model sets for auto-correction
|
||||
OPENAI_MODELS = {"whisper-1", "gpt-4o-mini-transcribe", "gpt-4o-transcribe"}
|
||||
GROQ_MODELS = {"whisper-large-v3", "whisper-large-v3-turbo", "distil-whisper-large-v3-en"}
|
||||
|
||||
# Singleton for the local model — loaded once, reused across calls
|
||||
_local_model: Optional["WhisperModel"] = None
|
||||
_local_model: Optional[object] = None
|
||||
_local_model_name: Optional[str] = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -66,6 +66,24 @@ _local_model_name: Optional[str] = None
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_stt_model_from_config() -> Optional[str]:
|
||||
"""Read the STT model name from ~/.hermes/config.yaml.
|
||||
|
||||
Returns the value of ``stt.model`` if present, otherwise ``None``.
|
||||
Silently returns ``None`` on any error (missing file, bad YAML, etc.).
|
||||
"""
|
||||
try:
|
||||
import yaml
|
||||
cfg_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
return data.get("stt", {}).get("model")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _load_stt_config() -> dict:
|
||||
"""Load the ``stt`` section from user config, falling back to defaults."""
|
||||
try:
|
||||
|
|
@ -80,7 +98,7 @@ def _get_provider(stt_config: dict) -> str:
|
|||
|
||||
Priority:
|
||||
1. Explicit config value (``stt.provider``)
|
||||
2. Auto-detect: local if faster-whisper available, else openai if key set
|
||||
2. Auto-detect: local > groq (free) > openai (paid)
|
||||
3. Disabled (returns "none")
|
||||
"""
|
||||
provider = stt_config.get("provider", DEFAULT_PROVIDER)
|
||||
|
|
@ -88,19 +106,37 @@ def _get_provider(stt_config: dict) -> str:
|
|||
if provider == "local":
|
||||
if _HAS_FASTER_WHISPER:
|
||||
return "local"
|
||||
# Local requested but not available — fall back to openai if possible
|
||||
# Local requested but not available — fall back to groq, then openai
|
||||
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
|
||||
logger.info("faster-whisper not installed, falling back to Groq Whisper API")
|
||||
return "groq"
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
logger.info("faster-whisper not installed, falling back to OpenAI Whisper API")
|
||||
return "openai"
|
||||
return "none"
|
||||
|
||||
if provider == "groq":
|
||||
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
|
||||
return "groq"
|
||||
# Groq requested but no key — fall back
|
||||
if _HAS_FASTER_WHISPER:
|
||||
logger.info("GROQ_API_KEY not set, falling back to local faster-whisper")
|
||||
return "local"
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
logger.info("GROQ_API_KEY not set, falling back to OpenAI Whisper API")
|
||||
return "openai"
|
||||
return "none"
|
||||
|
||||
if provider == "openai":
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
return "openai"
|
||||
# OpenAI requested but no key — fall back to local if possible
|
||||
# OpenAI requested but no key — fall back
|
||||
if _HAS_FASTER_WHISPER:
|
||||
logger.info("VOICE_TOOLS_OPENAI_KEY not set, falling back to local faster-whisper")
|
||||
return "local"
|
||||
if _HAS_OPENAI and os.getenv("GROQ_API_KEY"):
|
||||
logger.info("VOICE_TOOLS_OPENAI_KEY not set, falling back to Groq Whisper API")
|
||||
return "groq"
|
||||
return "none"
|
||||
|
||||
return provider # Unknown — let it fail downstream
|
||||
|
|
@ -150,6 +186,7 @@ def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]:
|
|||
return {"success": False, "transcript": "", "error": "faster-whisper not installed"}
|
||||
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
# Lazy-load the model (downloads on first use, ~150 MB for 'base')
|
||||
if _local_model is None or _local_model_name != model_name:
|
||||
logger.info("Loading faster-whisper model '%s' (first load downloads the model)...", model_name)
|
||||
|
|
@ -164,12 +201,60 @@ def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]:
|
|||
Path(file_path).name, model_name, info.language, info.duration,
|
||||
)
|
||||
|
||||
return {"success": True, "transcript": transcript}
|
||||
return {"success": True, "transcript": transcript, "provider": "local"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Local transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Local transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: groq (Whisper API — free tier)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _transcribe_groq(file_path: str, model_name: str) -> Dict[str, Any]:
|
||||
"""Transcribe using Groq Whisper API (free tier available)."""
|
||||
api_key = os.getenv("GROQ_API_KEY")
|
||||
if not api_key:
|
||||
return {"success": False, "transcript": "", "error": "GROQ_API_KEY not set"}
|
||||
|
||||
if not _HAS_OPENAI:
|
||||
return {"success": False, "transcript": "", "error": "openai package not installed"}
|
||||
|
||||
# Auto-correct model if caller passed an OpenAI-only model
|
||||
if model_name in OPENAI_MODELS:
|
||||
logger.info("Model %s not available on Groq, using %s", model_name, DEFAULT_GROQ_STT_MODEL)
|
||||
model_name = DEFAULT_GROQ_STT_MODEL
|
||||
|
||||
try:
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
client = OpenAI(api_key=api_key, base_url=GROQ_BASE_URL, timeout=30, max_retries=0)
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model=model_name,
|
||||
file=audio_file,
|
||||
response_format="text",
|
||||
)
|
||||
|
||||
transcript_text = str(transcription).strip()
|
||||
logger.info("Transcribed %s via Groq API (%s, %d chars)",
|
||||
Path(file_path).name, model_name, len(transcript_text))
|
||||
|
||||
return {"success": True, "transcript": transcript_text, "provider": "groq"}
|
||||
|
||||
except PermissionError:
|
||||
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
|
||||
except APIConnectionError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Connection error: {e}"}
|
||||
except APITimeoutError as e:
|
||||
return {"success": False, "transcript": "", "error": f"Request timeout: {e}"}
|
||||
except APIError as e:
|
||||
return {"success": False, "transcript": "", "error": f"API error: {e}"}
|
||||
except Exception as e:
|
||||
logger.error("Groq transcription failed: %s", e, exc_info=True)
|
||||
return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider: openai (Whisper API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -184,8 +269,14 @@ def _transcribe_openai(file_path: str, model_name: str) -> Dict[str, Any]:
|
|||
if not _HAS_OPENAI:
|
||||
return {"success": False, "transcript": "", "error": "openai package not installed"}
|
||||
|
||||
# Auto-correct model if caller passed a Groq-only model
|
||||
if model_name in GROQ_MODELS:
|
||||
logger.info("Model %s not available on OpenAI, using %s", model_name, DEFAULT_STT_MODEL)
|
||||
model_name = DEFAULT_STT_MODEL
|
||||
|
||||
try:
|
||||
client = OpenAI(api_key=api_key, base_url="https://api.openai.com/v1")
|
||||
from openai import OpenAI, APIError, APIConnectionError, APITimeoutError
|
||||
client = OpenAI(api_key=api_key, base_url=OPENAI_BASE_URL, timeout=30, max_retries=0)
|
||||
|
||||
with open(file_path, "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
|
|
@ -198,7 +289,7 @@ def _transcribe_openai(file_path: str, model_name: str) -> Dict[str, Any]:
|
|||
logger.info("Transcribed %s via OpenAI API (%s, %d chars)",
|
||||
Path(file_path).name, model_name, len(transcript_text))
|
||||
|
||||
return {"success": True, "transcript": transcript_text}
|
||||
return {"success": True, "transcript": transcript_text, "provider": "openai"}
|
||||
|
||||
except PermissionError:
|
||||
return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"}
|
||||
|
|
@ -223,7 +314,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
|
||||
Provider priority:
|
||||
1. User config (``stt.provider`` in config.yaml)
|
||||
2. Auto-detect: local faster-whisper if available, else OpenAI API
|
||||
2. Auto-detect: local faster-whisper (free) > Groq (free tier) > OpenAI (paid)
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the audio file to transcribe.
|
||||
|
|
@ -234,6 +325,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
- "success" (bool): Whether transcription succeeded
|
||||
- "transcript" (str): The transcribed text (empty on failure)
|
||||
- "error" (str, optional): Error message if success is False
|
||||
- "provider" (str, optional): Which provider was used
|
||||
"""
|
||||
# Validate input
|
||||
error = _validate_audio_file(file_path)
|
||||
|
|
@ -249,9 +341,13 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
model_name = model or local_cfg.get("model", DEFAULT_LOCAL_MODEL)
|
||||
return _transcribe_local(file_path, model_name)
|
||||
|
||||
if provider == "groq":
|
||||
model_name = model or DEFAULT_GROQ_STT_MODEL
|
||||
return _transcribe_groq(file_path, model_name)
|
||||
|
||||
if provider == "openai":
|
||||
openai_cfg = stt_config.get("openai", {})
|
||||
model_name = model or openai_cfg.get("model", DEFAULT_OPENAI_MODEL)
|
||||
model_name = model or openai_cfg.get("model", DEFAULT_STT_MODEL)
|
||||
return _transcribe_openai(file_path, model_name)
|
||||
|
||||
# No provider available
|
||||
|
|
@ -260,6 +356,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
"transcript": "",
|
||||
"error": (
|
||||
"No STT provider available. Install faster-whisper for free local "
|
||||
"transcription, or set VOICE_TOOLS_OPENAI_KEY for the OpenAI Whisper API."
|
||||
"transcription, set GROQ_API_KEY for free Groq Whisper, "
|
||||
"or set VOICE_TOOLS_OPENAI_KEY for the OpenAI Whisper API."
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,35 +25,41 @@ import datetime
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Callable, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Optional imports -- providers degrade gracefully if not installed
|
||||
# Lazy imports -- providers are imported only when actually used to avoid
|
||||
# crashing in headless environments (SSH, Docker, WSL, no PortAudio).
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
|
||||
def _import_edge_tts():
|
||||
"""Lazy import edge_tts. Returns the module or raises ImportError."""
|
||||
import edge_tts
|
||||
_HAS_EDGE_TTS = True
|
||||
except ImportError:
|
||||
_HAS_EDGE_TTS = False
|
||||
return edge_tts
|
||||
|
||||
try:
|
||||
def _import_elevenlabs():
|
||||
"""Lazy import ElevenLabs client. Returns the class or raises ImportError."""
|
||||
from elevenlabs.client import ElevenLabs
|
||||
_HAS_ELEVENLABS = True
|
||||
except ImportError:
|
||||
_HAS_ELEVENLABS = False
|
||||
return ElevenLabs
|
||||
|
||||
# openai is a core dependency, but guard anyway
|
||||
try:
|
||||
def _import_openai_client():
|
||||
"""Lazy import OpenAI client. Returns the class or raises ImportError."""
|
||||
from openai import OpenAI as OpenAIClient
|
||||
_HAS_OPENAI = True
|
||||
except ImportError:
|
||||
_HAS_OPENAI = False
|
||||
return OpenAIClient
|
||||
|
||||
def _import_sounddevice():
|
||||
"""Lazy import sounddevice. Returns the module or raises ImportError/OSError."""
|
||||
import sounddevice as sd
|
||||
return sd
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
|
|
@ -63,6 +69,7 @@ DEFAULT_PROVIDER = "edge"
|
|||
DEFAULT_EDGE_VOICE = "en-US-AriaNeural"
|
||||
DEFAULT_ELEVENLABS_VOICE_ID = "pNInz6obpgDQGcFmaJgB" # Adam
|
||||
DEFAULT_ELEVENLABS_MODEL_ID = "eleven_multilingual_v2"
|
||||
DEFAULT_ELEVENLABS_STREAMING_MODEL_ID = "eleven_flash_v2_5"
|
||||
DEFAULT_OPENAI_MODEL = "gpt-4o-mini-tts"
|
||||
DEFAULT_OPENAI_VOICE = "alloy"
|
||||
DEFAULT_OUTPUT_DIR = str(Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "audio_cache")
|
||||
|
|
@ -154,10 +161,11 @@ async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str,
|
|||
Returns:
|
||||
Path to the saved audio file.
|
||||
"""
|
||||
_edge_tts = _import_edge_tts()
|
||||
edge_config = tts_config.get("edge", {})
|
||||
voice = edge_config.get("voice", DEFAULT_EDGE_VOICE)
|
||||
|
||||
communicate = edge_tts.Communicate(text, voice)
|
||||
communicate = _edge_tts.Communicate(text, voice)
|
||||
await communicate.save(output_path)
|
||||
return output_path
|
||||
|
||||
|
|
@ -191,6 +199,7 @@ def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any]
|
|||
else:
|
||||
output_format = "mp3_44100_128"
|
||||
|
||||
ElevenLabs = _import_elevenlabs()
|
||||
client = ElevenLabs(api_key=api_key)
|
||||
audio_generator = client.text_to_speech.convert(
|
||||
text=text,
|
||||
|
|
@ -236,6 +245,7 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]
|
|||
else:
|
||||
response_format = "mp3"
|
||||
|
||||
OpenAIClient = _import_openai_client()
|
||||
client = OpenAIClient(api_key=api_key, base_url="https://api.openai.com/v1")
|
||||
response = client.audio.speech.create(
|
||||
model=model,
|
||||
|
|
@ -311,7 +321,9 @@ def text_to_speech_tool(
|
|||
try:
|
||||
# Generate audio with the configured provider
|
||||
if provider == "elevenlabs":
|
||||
if not _HAS_ELEVENLABS:
|
||||
try:
|
||||
_import_elevenlabs()
|
||||
except ImportError:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "ElevenLabs provider selected but 'elevenlabs' package not installed. Run: pip install elevenlabs"
|
||||
|
|
@ -320,7 +332,9 @@ def text_to_speech_tool(
|
|||
_generate_elevenlabs(text, file_str, tts_config)
|
||||
|
||||
elif provider == "openai":
|
||||
if not _HAS_OPENAI:
|
||||
try:
|
||||
_import_openai_client()
|
||||
except ImportError:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "OpenAI provider selected but 'openai' package not installed."
|
||||
|
|
@ -330,7 +344,9 @@ def text_to_speech_tool(
|
|||
|
||||
else:
|
||||
# Default: Edge TTS (free)
|
||||
if not _HAS_EDGE_TTS:
|
||||
try:
|
||||
_import_edge_tts()
|
||||
except ImportError:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Edge TTS not available. Run: pip install edge-tts"
|
||||
|
|
@ -411,15 +427,262 @@ def check_tts_requirements() -> bool:
|
|||
Returns:
|
||||
bool: True if at least one provider can work.
|
||||
"""
|
||||
if _HAS_EDGE_TTS:
|
||||
return True
|
||||
if _HAS_ELEVENLABS and os.getenv("ELEVENLABS_API_KEY"):
|
||||
return True
|
||||
if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
try:
|
||||
_import_edge_tts()
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
_import_elevenlabs()
|
||||
if os.getenv("ELEVENLABS_API_KEY"):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
_import_openai_client()
|
||||
if os.getenv("VOICE_TOOLS_OPENAI_KEY"):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Streaming TTS: sentence-by-sentence pipeline for ElevenLabs
|
||||
# ===========================================================================
|
||||
# Sentence boundary pattern: punctuation followed by space or newline
|
||||
_SENTENCE_BOUNDARY_RE = re.compile(r'(?<=[.!?])(?:\s|\n)|(?:\n\n)')
|
||||
|
||||
# Markdown stripping patterns (same as cli.py _voice_speak_response)
|
||||
_MD_CODE_BLOCK = re.compile(r'```[\s\S]*?```')
|
||||
_MD_LINK = re.compile(r'\[([^\]]+)\]\([^)]+\)')
|
||||
_MD_URL = re.compile(r'https?://\S+')
|
||||
_MD_BOLD = re.compile(r'\*\*(.+?)\*\*')
|
||||
_MD_ITALIC = re.compile(r'\*(.+?)\*')
|
||||
_MD_INLINE_CODE = re.compile(r'`(.+?)`')
|
||||
_MD_HEADER = re.compile(r'^#+\s*', flags=re.MULTILINE)
|
||||
_MD_LIST_ITEM = re.compile(r'^\s*[-*]\s+', flags=re.MULTILINE)
|
||||
_MD_HR = re.compile(r'---+')
|
||||
_MD_EXCESS_NL = re.compile(r'\n{3,}')
|
||||
|
||||
|
||||
def _strip_markdown_for_tts(text: str) -> str:
|
||||
"""Remove markdown formatting that shouldn't be spoken aloud."""
|
||||
text = _MD_CODE_BLOCK.sub(' ', text)
|
||||
text = _MD_LINK.sub(r'\1', text)
|
||||
text = _MD_URL.sub('', text)
|
||||
text = _MD_BOLD.sub(r'\1', text)
|
||||
text = _MD_ITALIC.sub(r'\1', text)
|
||||
text = _MD_INLINE_CODE.sub(r'\1', text)
|
||||
text = _MD_HEADER.sub('', text)
|
||||
text = _MD_LIST_ITEM.sub('', text)
|
||||
text = _MD_HR.sub('', text)
|
||||
text = _MD_EXCESS_NL.sub('\n\n', text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def stream_tts_to_speaker(
|
||||
text_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
tts_done_event: threading.Event,
|
||||
display_callback: Optional[Callable[[str], None]] = None,
|
||||
):
|
||||
"""Consume text deltas from *text_queue*, buffer them into sentences,
|
||||
and stream each sentence through ElevenLabs TTS to the speaker in
|
||||
real-time.
|
||||
|
||||
Protocol:
|
||||
* The producer puts ``str`` deltas onto *text_queue*.
|
||||
* A ``None`` sentinel signals end-of-text (flush remaining buffer).
|
||||
* *stop_event* can be set to abort early (e.g. user interrupt).
|
||||
* *tts_done_event* is **set** in the ``finally`` block so callers
|
||||
waiting on it (continuous voice mode) know playback is finished.
|
||||
"""
|
||||
tts_done_event.clear()
|
||||
|
||||
try:
|
||||
# --- TTS client setup (optional -- display_callback works without it) ---
|
||||
client = None
|
||||
output_stream = None
|
||||
voice_id = DEFAULT_ELEVENLABS_VOICE_ID
|
||||
model_id = DEFAULT_ELEVENLABS_STREAMING_MODEL_ID
|
||||
|
||||
tts_config = _load_tts_config()
|
||||
el_config = tts_config.get("elevenlabs", {})
|
||||
voice_id = el_config.get("voice_id", voice_id)
|
||||
model_id = el_config.get("streaming_model_id",
|
||||
el_config.get("model_id", model_id))
|
||||
|
||||
api_key = os.getenv("ELEVENLABS_API_KEY", "")
|
||||
if not api_key:
|
||||
logger.warning("ELEVENLABS_API_KEY not set; streaming TTS audio disabled")
|
||||
else:
|
||||
try:
|
||||
ElevenLabs = _import_elevenlabs()
|
||||
client = ElevenLabs(api_key=api_key)
|
||||
except ImportError:
|
||||
logger.warning("elevenlabs package not installed; streaming TTS disabled")
|
||||
|
||||
# Open a single sounddevice output stream for the lifetime of
|
||||
# this function. ElevenLabs pcm_24000 produces signed 16-bit
|
||||
# little-endian mono PCM at 24 kHz.
|
||||
if client is not None:
|
||||
try:
|
||||
sd = _import_sounddevice()
|
||||
import numpy as _np
|
||||
output_stream = sd.OutputStream(
|
||||
samplerate=24000, channels=1, dtype="int16",
|
||||
)
|
||||
output_stream.start()
|
||||
except (ImportError, OSError) as exc:
|
||||
logger.debug("sounddevice not available: %s", exc)
|
||||
output_stream = None
|
||||
except Exception as exc:
|
||||
logger.warning("sounddevice OutputStream failed: %s", exc)
|
||||
output_stream = None
|
||||
|
||||
sentence_buf = ""
|
||||
min_sentence_len = 20
|
||||
long_flush_len = 100
|
||||
queue_timeout = 0.5
|
||||
_spoken_sentences: list[str] = [] # track spoken sentences to skip duplicates
|
||||
# Regex to strip complete <think>...</think> blocks from buffer
|
||||
_think_block_re = re.compile(r'<think[\s>].*?</think>', flags=re.DOTALL)
|
||||
|
||||
def _speak_sentence(sentence: str):
|
||||
"""Display sentence and optionally generate + play audio."""
|
||||
if stop_event.is_set():
|
||||
return
|
||||
cleaned = _strip_markdown_for_tts(sentence).strip()
|
||||
if not cleaned:
|
||||
return
|
||||
# Skip duplicate/near-duplicate sentences (LLM repetition)
|
||||
cleaned_lower = cleaned.lower().rstrip(".!,")
|
||||
for prev in _spoken_sentences:
|
||||
if prev.lower().rstrip(".!,") == cleaned_lower:
|
||||
return
|
||||
_spoken_sentences.append(cleaned)
|
||||
# Display raw sentence on screen before TTS processing
|
||||
if display_callback is not None:
|
||||
display_callback(sentence)
|
||||
# Skip audio generation if no TTS client available
|
||||
if client is None:
|
||||
return
|
||||
# Truncate very long sentences
|
||||
if len(cleaned) > MAX_TEXT_LENGTH:
|
||||
cleaned = cleaned[:MAX_TEXT_LENGTH]
|
||||
try:
|
||||
audio_iter = client.text_to_speech.convert(
|
||||
text=cleaned,
|
||||
voice_id=voice_id,
|
||||
model_id=model_id,
|
||||
output_format="pcm_24000",
|
||||
)
|
||||
if output_stream is not None:
|
||||
for chunk in audio_iter:
|
||||
if stop_event.is_set():
|
||||
break
|
||||
import numpy as _np
|
||||
audio_array = _np.frombuffer(chunk, dtype=_np.int16)
|
||||
output_stream.write(audio_array.reshape(-1, 1))
|
||||
else:
|
||||
# Fallback: write chunks to temp file and play via system player
|
||||
_play_via_tempfile(audio_iter, stop_event)
|
||||
except Exception as exc:
|
||||
logger.warning("Streaming TTS sentence failed: %s", exc)
|
||||
|
||||
def _play_via_tempfile(audio_iter, stop_evt):
    """Write PCM chunks to a temp WAV file and play it via the system player.

    Fallback path used when no live sounddevice output stream is available.
    The PCM format matches the ElevenLabs ``pcm_24000`` output: mono,
    16-bit, 24 kHz.

    Args:
        audio_iter: Iterable of raw 16-bit PCM byte chunks.
        stop_evt: threading.Event; when set, writing stops early.
    """
    tmp_path = None
    try:
        import wave

        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp_path = tmp.name
        try:
            with wave.open(tmp, "wb") as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)       # 16-bit
                wf.setframerate(24000)
                for chunk in audio_iter:
                    if stop_evt.is_set():
                        break
                    wf.writeframes(chunk)
        finally:
            # wave.open() on a file *object* does not close (or reliably
            # flush) the underlying file. Close it before handing the path
            # to an external player: otherwise the player may read a
            # truncated file, and on Windows cannot open it at all while
            # the handle is still held.
            tmp.close()

        from tools.voice_mode import play_audio_file
        play_audio_file(tmp_path)
    except Exception as exc:
        logger.warning("Temp-file TTS fallback failed: %s", exc)
    finally:
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
||||
|
||||
while not stop_event.is_set():
|
||||
# Read next delta from queue
|
||||
try:
|
||||
delta = text_queue.get(timeout=queue_timeout)
|
||||
except queue.Empty:
|
||||
# Timeout: if we have accumulated a long buffer, flush it
|
||||
if len(sentence_buf) > long_flush_len:
|
||||
_speak_sentence(sentence_buf)
|
||||
sentence_buf = ""
|
||||
continue
|
||||
|
||||
if delta is None:
|
||||
# End-of-text sentinel: strip any remaining think blocks, flush
|
||||
sentence_buf = _think_block_re.sub('', sentence_buf)
|
||||
if sentence_buf.strip():
|
||||
_speak_sentence(sentence_buf)
|
||||
break
|
||||
|
||||
sentence_buf += delta
|
||||
|
||||
# --- Think block filtering ---
|
||||
# Strip complete <think>...</think> blocks from buffer.
|
||||
# Works correctly even when tags span multiple deltas.
|
||||
sentence_buf = _think_block_re.sub('', sentence_buf)
|
||||
|
||||
# If an incomplete <think tag is at the end, wait for more data
|
||||
# before extracting sentences (the closing tag may arrive next).
|
||||
if '<think' in sentence_buf and '</think>' not in sentence_buf:
|
||||
continue
|
||||
|
||||
# Check for sentence boundaries
|
||||
while True:
|
||||
m = _SENTENCE_BOUNDARY_RE.search(sentence_buf)
|
||||
if m is None:
|
||||
break
|
||||
end_pos = m.end()
|
||||
sentence = sentence_buf[:end_pos]
|
||||
sentence_buf = sentence_buf[end_pos:]
|
||||
# Merge short fragments into the next sentence
|
||||
if len(sentence.strip()) < min_sentence_len:
|
||||
sentence_buf = sentence + sentence_buf
|
||||
break
|
||||
_speak_sentence(sentence)
|
||||
|
||||
# Drain any remaining items from the queue
|
||||
while True:
|
||||
try:
|
||||
text_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# output_stream is closed in the finally block below
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Streaming TTS pipeline error: %s", exc)
|
||||
finally:
|
||||
# Always close the audio output stream to avoid locking the device
|
||||
if output_stream is not None:
|
||||
try:
|
||||
output_stream.stop()
|
||||
output_stream.close()
|
||||
except Exception:
|
||||
pass
|
||||
tts_done_event.set()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Main -- quick diagnostics
|
||||
# ===========================================================================
|
||||
|
|
@ -427,12 +690,19 @@ if __name__ == "__main__":
|
|||
print("🔊 Text-to-Speech Tool Module")
|
||||
print("=" * 50)
|
||||
|
||||
def _check(importer, label):
|
||||
try:
|
||||
importer()
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
print(f"\nProvider availability:")
|
||||
print(f" Edge TTS: {'✅ installed' if _HAS_EDGE_TTS else '❌ not installed (pip install edge-tts)'}")
|
||||
print(f" ElevenLabs: {'✅ installed' if _HAS_ELEVENLABS else '❌ not installed (pip install elevenlabs)'}")
|
||||
print(f" API Key: {'✅ set' if os.getenv('ELEVENLABS_API_KEY') else '❌ not set'}")
|
||||
print(f" OpenAI: {'✅ installed' if _HAS_OPENAI else '❌ not installed'}")
|
||||
print(f" API Key: {'✅ set' if os.getenv('VOICE_TOOLS_OPENAI_KEY') else '❌ not set (VOICE_TOOLS_OPENAI_KEY)'}")
|
||||
print(f" Edge TTS: {'installed' if _check(_import_edge_tts, 'edge') else 'not installed (pip install edge-tts)'}")
|
||||
print(f" ElevenLabs: {'installed' if _check(_import_elevenlabs, 'el') else 'not installed (pip install elevenlabs)'}")
|
||||
print(f" API Key: {'set' if os.getenv('ELEVENLABS_API_KEY') else 'not set'}")
|
||||
print(f" OpenAI: {'installed' if _check(_import_openai_client, 'oai') else 'not installed'}")
|
||||
print(f" API Key: {'set' if os.getenv('VOICE_TOOLS_OPENAI_KEY') else 'not set (VOICE_TOOLS_OPENAI_KEY)'}")
|
||||
print(f" ffmpeg: {'✅ found' if _has_ffmpeg() else '❌ not found (needed for Telegram Opus)'}")
|
||||
print(f"\n Output dir: {DEFAULT_OUTPUT_DIR}")
|
||||
|
||||
|
|
|
|||
783
tools/voice_mode.py
Normal file
783
tools/voice_mode.py
Normal file
|
|
@ -0,0 +1,783 @@
|
|||
"""Voice Mode -- Push-to-talk audio recording and playback for the CLI.
|
||||
|
||||
Provides audio capture via sounddevice, WAV encoding via stdlib wave,
|
||||
STT dispatch via tools.transcription_tools, and TTS playback via
|
||||
sounddevice or system audio players.
|
||||
|
||||
Dependencies (optional):
|
||||
pip install sounddevice numpy
|
||||
or: pip install hermes-agent[voice]
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy audio imports -- never imported at module level to avoid crashing
|
||||
# in headless environments (SSH, Docker, WSL, no PortAudio).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _import_audio():
|
||||
"""Lazy-import sounddevice and numpy. Returns (sd, np).
|
||||
|
||||
Raises ImportError or OSError if the libraries are not available
|
||||
(e.g. PortAudio missing on headless servers).
|
||||
"""
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
return sd, np
|
||||
|
||||
|
||||
def _audio_available() -> bool:
    """Return True if audio libraries can be imported."""
    try:
        _import_audio()
    except (ImportError, OSError):
        return False
    return True
|
||||
|
||||
|
||||
def detect_audio_environment() -> dict:
    """Detect if the current environment supports audio I/O.

    Returns dict with 'available' (bool) and 'warnings' (list of strings).
    """
    warnings = []

    # SSH detection
    ssh_vars = ('SSH_CLIENT', 'SSH_TTY', 'SSH_CONNECTION')
    if any(os.environ.get(name) for name in ssh_vars):
        warnings.append("Running over SSH -- no audio devices available")

    # Docker detection
    if os.path.exists('/.dockerenv'):
        warnings.append("Running inside Docker container -- no audio devices")

    # WSL detection (kernel version string mentions Microsoft)
    try:
        with open('/proc/version', 'r') as fh:
            kernel_info = fh.read().lower()
        if 'microsoft' in kernel_info:
            warnings.append("Running in WSL -- audio requires PulseAudio bridge to Windows")
    except (FileNotFoundError, PermissionError, OSError):
        pass

    # Check audio libraries and whether PortAudio can enumerate devices
    try:
        sd, _ = _import_audio()
    except (ImportError, OSError):
        warnings.append("Audio libraries not installed (pip install sounddevice numpy)")
    else:
        try:
            if not sd.query_devices():
                warnings.append("No audio input/output devices detected")
        except Exception:
            warnings.append("Audio subsystem error (PortAudio cannot query devices)")

    return {
        "available": not warnings,
        "warnings": warnings,
    }
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recording parameters
|
||||
# ---------------------------------------------------------------------------
|
||||
SAMPLE_RATE = 16000  # Hz; Whisper's native sample rate
CHANNELS = 1  # mono capture
DTYPE = "int16"  # 16-bit signed PCM
SAMPLE_WIDTH = 2  # bytes per sample (int16)
MAX_RECORDING_SECONDS = 120  # Safety cap on a single recording

# Silence detection defaults
SILENCE_RMS_THRESHOLD = 200  # RMS below this = silence (int16 range 0-32767)
SILENCE_DURATION_SECONDS = 3.0  # Seconds of continuous silence before auto-stop

# Temp directory for voice recordings (see cleanup_temp_recordings)
_TEMP_DIR = os.path.join(tempfile.gettempdir(), "hermes_voice")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audio cues (beep tones)
|
||||
# ============================================================================
|
||||
def play_beep(frequency: int = 880, duration: float = 0.12, count: int = 1) -> None:
    """Play a short beep tone using numpy + sounddevice.

    Silently does nothing when the audio stack is unavailable or when
    ``count`` < 1.

    Args:
        frequency: Tone frequency in Hz (default 880 = A5).
        duration: Duration of each beep in seconds.
        count: Number of beeps to play (with short gap between).
    """
    if count < 1:
        # Nothing to play; also avoids np.concatenate([]) raising ValueError.
        return
    try:
        sd, np = _import_audio()
    except (ImportError, OSError):
        return
    try:
        gap = 0.06  # seconds between beeps
        samples_per_beep = int(SAMPLE_RATE * duration)
        samples_per_gap = int(SAMPLE_RATE * gap)

        parts = []
        for i in range(count):
            t = np.linspace(0, duration, samples_per_beep, endpoint=False)
            tone = np.sin(2 * np.pi * frequency * t)
            # Apply fade in/out to avoid click artifacts. Guard fade_len > 0:
            # for very short beeps fade_len is 0, and tone[-0:] would address
            # the whole array while the empty linspace fails to broadcast.
            fade_len = min(int(SAMPLE_RATE * 0.01), samples_per_beep // 4)
            if fade_len > 0:
                tone[:fade_len] *= np.linspace(0, 1, fade_len)
                tone[-fade_len:] *= np.linspace(1, 0, fade_len)
            parts.append((tone * 0.3 * 32767).astype(np.int16))
            if i < count - 1:
                parts.append(np.zeros(samples_per_gap, dtype=np.int16))

        audio = np.concatenate(parts)
        sd.play(audio, samplerate=SAMPLE_RATE)
        # sd.wait() calls Event.wait() without timeout — hangs forever if the
        # audio device stalls. Poll with a 2s ceiling and force-stop.
        deadline = time.monotonic() + 2.0
        while sd.get_stream() and sd.get_stream().active and time.monotonic() < deadline:
            time.sleep(0.01)
        sd.stop()
    except Exception as e:
        logger.debug("Beep playback failed: %s", e)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AudioRecorder
|
||||
# ============================================================================
|
||||
class AudioRecorder:
    """Thread-safe audio recorder using sounddevice.InputStream.

    Usage::

        recorder = AudioRecorder()
        recorder.start(on_silence_stop=my_callback)
        # ... user speaks ...
        wav_path = recorder.stop()  # returns path to WAV file
        # or
        recorder.cancel()  # discard without saving

    If ``on_silence_stop`` is provided, recording automatically stops when
    the user is silent for ``silence_duration`` seconds and calls the callback.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._stream: Any = None
        self._frames: List[Any] = []
        self._recording = False
        self._start_time: float = 0.0
        # Silence detection state
        self._has_spoken = False
        self._speech_start: float = 0.0  # When speech attempt began
        self._dip_start: float = 0.0  # When current below-threshold dip began
        self._min_speech_duration: float = 0.3  # Seconds of speech needed to confirm
        self._max_dip_tolerance: float = 0.3  # Max dip duration before resetting speech
        self._silence_start: float = 0.0
        self._resume_start: float = 0.0  # Tracks sustained speech after silence starts
        self._resume_dip_start: float = 0.0  # Dip tolerance tracker for resume detection
        self._on_silence_stop = None
        self._silence_threshold: int = SILENCE_RMS_THRESHOLD
        self._silence_duration: float = SILENCE_DURATION_SECONDS
        self._max_wait: float = 15.0  # Max seconds to wait for speech before auto-stop
        # Peak RMS seen during recording (for speech presence check in stop())
        self._peak_rms: int = 0
        # Live audio level (read by UI for visual feedback)
        self._current_rms: int = 0

    # -- public properties ---------------------------------------------------

    @property
    def is_recording(self) -> bool:
        return self._recording

    @property
    def elapsed_seconds(self) -> float:
        """Seconds since start() was called; 0.0 when not recording."""
        if not self._recording:
            return 0.0
        return time.monotonic() - self._start_time

    @property
    def current_rms(self) -> int:
        """Current audio input RMS level (0-32767). Updated each audio chunk."""
        return self._current_rms

    # -- public methods ------------------------------------------------------

    def _ensure_stream(self) -> None:
        """Create the audio InputStream once and keep it alive.

        The stream stays open for the lifetime of the recorder. Between
        recordings the callback simply discards audio chunks (``_recording``
        is ``False``). This avoids the CoreAudio bug where closing and
        re-opening an ``InputStream`` hangs indefinitely on macOS.
        """
        if self._stream is not None:
            return  # already alive

        sd, np = _import_audio()

        def _callback(indata, frames, time_info, status):  # noqa: ARG001
            if status:
                logger.debug("sounddevice status: %s", status)
            # When not recording the stream is idle — discard audio.
            if not self._recording:
                return
            self._frames.append(indata.copy())

            # Compute RMS for level display and silence detection
            rms = int(np.sqrt(np.mean(indata.astype(np.float64) ** 2)))
            self._current_rms = rms
            if rms > self._peak_rms:
                self._peak_rms = rms

            # Silence detection
            if self._on_silence_stop is not None:
                now = time.monotonic()
                elapsed = now - self._start_time

                if rms > self._silence_threshold:
                    # Audio is above threshold -- this is speech (or noise).
                    self._dip_start = 0.0  # Reset dip tracker
                    if self._speech_start == 0.0:
                        self._speech_start = now
                    elif not self._has_spoken and now - self._speech_start >= self._min_speech_duration:
                        self._has_spoken = True
                        logger.debug("Speech confirmed (%.2fs above threshold)",
                                     now - self._speech_start)
                    # After speech is confirmed, only reset silence timer if
                    # speech is sustained (>0.3s above threshold). Brief
                    # spikes from ambient noise should NOT reset the timer.
                    if not self._has_spoken:
                        self._silence_start = 0.0
                    else:
                        # Track resumed speech with dip tolerance.
                        # Brief dips below threshold are normal during speech,
                        # so we mirror the initial speech detection pattern:
                        # start tracking, tolerate short dips, confirm after 0.3s.
                        self._resume_dip_start = 0.0  # Above threshold — no dip
                        if self._resume_start == 0.0:
                            self._resume_start = now
                        elif now - self._resume_start >= self._min_speech_duration:
                            self._silence_start = 0.0
                            self._resume_start = 0.0
                elif self._has_spoken:
                    # Below threshold after speech confirmed.
                    # Use dip tolerance before resetting resume tracker —
                    # natural speech has brief dips below threshold.
                    if self._resume_start > 0:
                        if self._resume_dip_start == 0.0:
                            self._resume_dip_start = now
                        elif now - self._resume_dip_start >= self._max_dip_tolerance:
                            # Sustained dip — user actually stopped speaking
                            self._resume_start = 0.0
                            self._resume_dip_start = 0.0
                elif self._speech_start > 0:
                    # We were in a speech attempt but RMS dipped.
                    # Tolerate brief dips (micro-pauses between syllables).
                    if self._dip_start == 0.0:
                        self._dip_start = now
                    elif now - self._dip_start >= self._max_dip_tolerance:
                        # Dip lasted too long -- genuine silence, reset
                        logger.debug("Speech attempt reset (dip lasted %.2fs)",
                                     now - self._dip_start)
                        self._speech_start = 0.0
                        self._dip_start = 0.0

                # Fire silence callback when:
                # 1. User spoke then went silent for silence_duration, OR
                # 2. No speech detected at all for max_wait seconds
                should_fire = False
                if self._has_spoken and rms <= self._silence_threshold:
                    # User was speaking and now is silent
                    if self._silence_start == 0.0:
                        self._silence_start = now
                    elif now - self._silence_start >= self._silence_duration:
                        logger.info("Silence detected (%.1fs), auto-stopping",
                                    self._silence_duration)
                        should_fire = True
                elif not self._has_spoken and elapsed >= self._max_wait:
                    logger.info("No speech within %.0fs, auto-stopping",
                                self._max_wait)
                    should_fire = True

                if should_fire:
                    with self._lock:
                        cb = self._on_silence_stop
                        self._on_silence_stop = None  # fire only once
                    if cb:
                        # Run the callback off the audio thread so a slow
                        # handler cannot stall PortAudio's callback.
                        def _safe_cb():
                            try:
                                cb()
                            except Exception as e:
                                logger.error("Silence callback failed: %s", e, exc_info=True)
                        threading.Thread(target=_safe_cb, daemon=True).start()

        # Create stream — may block on CoreAudio (first call only).
        stream = None
        try:
            stream = sd.InputStream(
                samplerate=SAMPLE_RATE,
                channels=CHANNELS,
                dtype=DTYPE,
                callback=_callback,
            )
            stream.start()
        except Exception as e:
            if stream is not None:
                try:
                    stream.close()
                except Exception:
                    pass
            raise RuntimeError(
                f"Failed to open audio input stream: {e}. "
                "Check that a microphone is connected and accessible."
            ) from e
        self._stream = stream

    def start(self, on_silence_stop=None) -> None:
        """Start capturing audio from the default input device.

        The underlying InputStream is created once and kept alive across
        recordings. Subsequent calls simply reset detection state and
        toggle frame collection via ``_recording``.

        Args:
            on_silence_stop: Optional callback invoked (in a daemon thread) when
                silence is detected after speech. The callback receives no arguments.
                Use this to auto-stop recording and trigger transcription.

        Raises ``RuntimeError`` if sounddevice/numpy are not installed
        or if a recording is already in progress.
        """
        try:
            _import_audio()
        except (ImportError, OSError) as e:
            raise RuntimeError(
                "Voice mode requires sounddevice and numpy.\n"
                "Install with: pip install sounddevice numpy\n"
                "Or: pip install hermes-agent[voice]"
            ) from e

        with self._lock:
            if self._recording:
                return  # already recording

            self._frames = []
            self._start_time = time.monotonic()
            self._has_spoken = False
            self._speech_start = 0.0
            self._dip_start = 0.0
            self._silence_start = 0.0
            self._resume_start = 0.0
            self._resume_dip_start = 0.0
            self._peak_rms = 0
            self._current_rms = 0
            self._on_silence_stop = on_silence_stop

        # Ensure the persistent stream is alive (no-op after first call).
        self._ensure_stream()

        with self._lock:
            self._recording = True
        logger.info("Voice recording started (rate=%d, channels=%d)", SAMPLE_RATE, CHANNELS)

    def _close_stream_with_timeout(self, timeout: float = 3.0) -> None:
        """Close the audio stream with a timeout to prevent CoreAudio hangs."""
        if self._stream is None:
            return

        stream = self._stream
        self._stream = None

        def _do_close():
            try:
                stream.stop()
                stream.close()
            except Exception:
                pass

        t = threading.Thread(target=_do_close, daemon=True)
        t.start()
        # Poll in short intervals so Ctrl+C is not blocked.
        # (time is already imported at module level — no need for __import__.)
        deadline = time.monotonic() + timeout
        while t.is_alive() and time.monotonic() < deadline:
            t.join(timeout=0.1)
        if t.is_alive():
            logger.warning("Audio stream close timed out after %.1fs — forcing ahead", timeout)

    def stop(self) -> Optional[str]:
        """Stop recording and write captured audio to a WAV file.

        The underlying stream is kept alive for reuse — only frame
        collection is stopped.

        Returns:
            Path to the WAV file, or ``None`` if no audio was captured.
        """
        with self._lock:
            if not self._recording:
                return None

            self._recording = False
            self._current_rms = 0
            # Stream stays alive — no close needed.

        if not self._frames:
            return None

        # Concatenate frames and write WAV
        _, np = _import_audio()
        audio_data = np.concatenate(self._frames, axis=0)
        self._frames = []

        elapsed = time.monotonic() - self._start_time
        logger.info("Voice recording stopped (%.1fs, %d samples)", elapsed, len(audio_data))

        # Skip very short recordings (< 0.3s of audio)
        min_samples = int(SAMPLE_RATE * 0.3)
        if len(audio_data) < min_samples:
            logger.debug("Recording too short (%d samples), discarding", len(audio_data))
            return None

        # Skip silent recordings using peak RMS (not overall average, which
        # gets diluted by silence at the end of the recording).
        if self._peak_rms < SILENCE_RMS_THRESHOLD:
            logger.info("Recording too quiet (peak RMS=%d < %d), discarding",
                        self._peak_rms, SILENCE_RMS_THRESHOLD)
            return None

        return self._write_wav(audio_data)

    def cancel(self) -> None:
        """Stop recording and discard all captured audio.

        The underlying stream is kept alive for reuse.
        """
        with self._lock:
            self._recording = False
            self._frames = []
            self._on_silence_stop = None
            self._current_rms = 0
        logger.info("Voice recording cancelled")

    def shutdown(self) -> None:
        """Release the audio stream. Call when voice mode is disabled."""
        with self._lock:
            self._recording = False
            self._frames = []
            self._on_silence_stop = None
        # Close stream OUTSIDE the lock to avoid deadlock with audio callback
        self._close_stream_with_timeout()
        logger.info("AudioRecorder shut down")

    # -- private helpers -----------------------------------------------------

    @staticmethod
    def _write_wav(audio_data) -> str:
        """Write numpy int16 audio data to a WAV file.

        Returns the file path.
        """
        os.makedirs(_TEMP_DIR, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        wav_path = os.path.join(_TEMP_DIR, f"recording_{timestamp}.wav")

        with wave.open(wav_path, "wb") as wf:
            wf.setnchannels(CHANNELS)
            wf.setsampwidth(SAMPLE_WIDTH)
            wf.setframerate(SAMPLE_RATE)
            wf.writeframes(audio_data.tobytes())

        file_size = os.path.getsize(wav_path)
        logger.info("WAV written: %s (%d bytes)", wav_path, file_size)
        return wav_path
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Whisper hallucination filter
|
||||
# ============================================================================
|
||||
# Whisper commonly hallucinates these phrases on silent/near-silent audio.
|
||||
WHISPER_HALLUCINATIONS = {
    "thank you.",
    "thank you",
    "thanks for watching.",
    "thanks for watching",
    "subscribe to my channel.",
    "subscribe to my channel",
    "like and subscribe.",
    "like and subscribe",
    "please subscribe.",
    "please subscribe",
    "thank you for watching.",
    "thank you for watching",
    "bye.",
    "bye",
    "you",
    "the end.",
    "the end",
    # Non-English hallucinations (common on silence)
    "продолжение следует",
    "продолжение следует...",
    "sous-titres",
    "sous-titres réalisés par la communauté d'amara.org",
    "sottotitoli creati dalla comunità amara.org",
    "untertitel von stephanie geiges",
    "amara.org",
    "www.mooji.org",
    "ご視聴ありがとうございました",
}

# Regex patterns for repetitive hallucinations (e.g. "Thank you. Thank you. Thank you.")
_HALLUCINATION_REPEAT_RE = re.compile(
    r'^(?:thank you|thanks|bye|you|ok|okay|the end|\.|\s|,|!)+$',
    flags=re.IGNORECASE,
)


def is_whisper_hallucination(transcript: str) -> bool:
    """Check if a transcript is a known Whisper hallucination on silence."""
    text = transcript.strip().lower()
    # Empty or whitespace-only transcripts count as hallucinations too.
    if not text:
        return True
    # Exact match against known phrases (with and without trailing ./!).
    if text in WHISPER_HALLUCINATIONS or text.rstrip('.!') in WHISPER_HALLUCINATIONS:
        return True
    # Repetitive patterns (e.g. "Thank you. Thank you. Thank you. you")
    return _HALLUCINATION_REPEAT_RE.match(text) is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# STT dispatch
|
||||
# ============================================================================
|
||||
def transcribe_recording(wav_path: str, model: Optional[str] = None) -> Dict[str, Any]:
    """Transcribe a WAV recording using the existing Whisper pipeline.

    Delegates to ``tools.transcription_tools.transcribe_audio()``.
    Filters out known Whisper hallucinations on silent audio.

    Args:
        wav_path: Path to the WAV file.
        model: Whisper model name (default: from config or ``whisper-1``).

    Returns:
        Dict with ``success``, ``transcript``, and optionally ``error``.
    """
    from tools.transcription_tools import transcribe_audio

    result = transcribe_audio(wav_path, model=model)

    # Treat known silence hallucinations as an empty (but successful) transcript.
    if result.get("success"):
        transcript = result.get("transcript", "")
        if is_whisper_hallucination(transcript):
            logger.info("Filtered Whisper hallucination: %r", transcript)
            return {"success": True, "transcript": "", "filtered": True}

    return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Audio playback (interruptable)
|
||||
# ============================================================================
|
||||
|
||||
# Global reference to the active playback process so it can be interrupted.
_active_playback: Optional[subprocess.Popen] = None
_playback_lock = threading.Lock()


def stop_playback() -> None:
    """Interrupt the currently playing audio (if any)."""
    global _active_playback
    with _playback_lock:
        active, _active_playback = _active_playback, None
    if active is not None and active.poll() is None:
        try:
            active.terminate()
            logger.info("Audio playback interrupted")
        except Exception:
            pass
    # Also stop sounddevice playback if active
    try:
        sd, _ = _import_audio()
        sd.stop()
    except Exception:
        pass
|
||||
|
||||
|
||||
def play_audio_file(file_path: str) -> bool:
    """Play an audio file through the default output device.

    Strategy:
    1. WAV files via ``sounddevice.play()`` when available.
    2. System commands: ``afplay`` (macOS), ``ffplay`` (cross-platform),
       ``aplay`` (Linux ALSA).

    Playback can be interrupted by calling ``stop_playback()``.

    Returns:
        ``True`` if playback succeeded, ``False`` otherwise.
    """
    global _active_playback

    if not os.path.isfile(file_path):
        logger.warning("Audio file not found: %s", file_path)
        return False

    # Try sounddevice for WAV files
    if file_path.endswith(".wav"):
        try:
            sd, np = _import_audio()
            with wave.open(file_path, "rb") as wf:
                frames = wf.readframes(wf.getnframes())
                n_channels = wf.getnchannels()
                sample_rate = wf.getframerate()
            # NOTE(review): assumes 16-bit PCM samples (sampwidth == 2);
            # other widths will sound wrong here and should be confirmed
            # against the TTS producers, which all emit 16-bit PCM.
            audio_data = np.frombuffer(frames, dtype=np.int16)
            if n_channels > 1:
                # De-interleave multi-channel WAVs; otherwise sounddevice
                # would treat interleaved samples as mono and play the file
                # at the wrong speed.
                audio_data = audio_data.reshape(-1, n_channels)

            sd.play(audio_data, samplerate=sample_rate)
            # sd.wait() calls Event.wait() without timeout — hangs forever if
            # the audio device stalls. Poll with a ceiling and force-stop.
            duration_secs = audio_data.shape[0] / sample_rate
            deadline = time.monotonic() + duration_secs + 2.0
            while sd.get_stream() and sd.get_stream().active and time.monotonic() < deadline:
                time.sleep(0.01)
            sd.stop()
            return True
        except (ImportError, OSError):
            pass  # audio libs not available, fall through to system players
        except Exception as e:
            logger.debug("sounddevice playback failed: %s", e)

    # Fall back to system audio players (using Popen for interruptability)
    system = platform.system()
    players = []

    if system == "Darwin":
        players.append(["afplay", file_path])
    # ffplay is advertised as cross-platform above — offer it on every OS,
    # not just macOS (fixes the docstring/code mismatch and gives Windows
    # and Linux a non-WAV fallback when ffmpeg is installed).
    players.append(["ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", file_path])
    if system == "Linux":
        players.append(["aplay", "-q", file_path])

    for cmd in players:
        exe = shutil.which(cmd[0])
        if exe:
            try:
                proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                with _playback_lock:
                    _active_playback = proc
                proc.wait(timeout=300)
                with _playback_lock:
                    _active_playback = None
                return True
            except subprocess.TimeoutExpired:
                logger.warning("System player %s timed out, killing process", cmd[0])
                proc.kill()
                proc.wait()
                with _playback_lock:
                    _active_playback = None
            except Exception as e:
                logger.debug("System player %s failed: %s", cmd[0], e)
                with _playback_lock:
                    _active_playback = None

    logger.warning("No audio player available for %s", file_path)
    return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Requirements check
|
||||
# ============================================================================
|
||||
def check_voice_requirements() -> Dict[str, Any]:
    """Check if all voice mode requirements are met.

    Returns:
        Dict with ``available``, ``audio_available``, ``stt_available``,
        ``missing_packages``, and ``details``.
    """
    # Determine STT provider availability
    from tools.transcription_tools import _get_provider, _load_stt_config, _HAS_FASTER_WHISPER

    stt_provider = _get_provider(_load_stt_config())
    stt_available = stt_provider != "none"

    has_audio = _audio_available()
    missing: List[str] = [] if has_audio else ["sounddevice", "numpy"]

    # Environment detection
    env_check = detect_audio_environment()

    details_parts = [
        "Audio capture: OK" if has_audio
        else "Audio capture: MISSING (pip install sounddevice numpy)"
    ]

    provider_labels = {
        "local": "STT provider: OK (local faster-whisper)",
        "groq": "STT provider: OK (Groq)",
        "openai": "STT provider: OK (OpenAI)",
    }
    details_parts.append(provider_labels.get(
        stt_provider,
        "STT provider: MISSING (pip install faster-whisper, "
        "or set GROQ_API_KEY / VOICE_TOOLS_OPENAI_KEY)",
    ))

    details_parts.extend(f"Environment: {warning}" for warning in env_check["warnings"])

    return {
        "available": has_audio and stt_available and env_check["available"],
        "audio_available": has_audio,
        "stt_available": stt_available,
        "missing_packages": missing,
        "details": "\n".join(details_parts),
        "environment": env_check,
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Temp file cleanup
|
||||
# ============================================================================
|
||||
def cleanup_temp_recordings(max_age_seconds: int = 3600) -> int:
    """Remove old temporary voice recording files.

    Scans ``_TEMP_DIR`` for files matching the ``recording_*.wav``
    naming pattern and deletes those whose modification time is older
    than *max_age_seconds*. Per-file errors are swallowed so a file
    that is still open elsewhere (or already gone) never aborts the
    sweep.

    Args:
        max_age_seconds: Delete files older than this (default: 1 hour).

    Returns:
        Number of files deleted.
    """
    if not os.path.isdir(_TEMP_DIR):
        return 0

    deleted = 0
    now = time.time()

    # Use scandir as a context manager so the underlying directory
    # handle is closed deterministically instead of at GC time.
    with os.scandir(_TEMP_DIR) as entries:
        for entry in entries:
            if entry.is_file() and entry.name.startswith("recording_") and entry.name.endswith(".wav"):
                try:
                    age = now - entry.stat().st_mtime
                    if age > max_age_seconds:
                        os.unlink(entry.path)
                        deleted += 1
                except OSError:
                    # Best-effort cleanup: skip files we cannot stat/delete.
                    pass

    if deleted:
        logger.debug("Cleaned up %d old voice recordings", deleted)
    return deleted
|
||||
487
website/docs/user-guide/features/voice-mode.md
Normal file
487
website/docs/user-guide/features/voice-mode.md
Normal file
|
|
@ -0,0 +1,487 @@
|
|||
---
|
||||
sidebar_position: 10
|
||||
title: "Voice Mode"
|
||||
description: "Real-time voice conversations with Hermes Agent — CLI, Telegram, Discord (DMs, text channels, and voice channels)"
|
||||
---
|
||||
|
||||
# Voice Mode
|
||||
|
||||
Hermes Agent supports full voice interaction across CLI and messaging platforms. Talk to the agent using your microphone, hear spoken replies, and have live voice conversations in Discord voice channels.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before using voice features, make sure you have:
|
||||
|
||||
1. **Hermes Agent installed** — `pip install hermes-agent` (see [Getting Started](../../getting-started.md))
|
||||
2. **An LLM provider configured** — set `OPENAI_API_KEY`, `OPENAI_BASE_URL`, and `LLM_MODEL` in `~/.hermes/.env`
|
||||
3. **A working base setup** — run `hermes` to verify the agent responds to text before enabling voice
|
||||
|
||||
:::tip
|
||||
The `~/.hermes/` directory and default `config.yaml` are created automatically the first time you run `hermes`. You only need to create `~/.hermes/.env` manually for API keys.
|
||||
:::
|
||||
|
||||
## Overview
|
||||
|
||||
| Feature | Platform | Description |
|
||||
|---------|----------|-------------|
|
||||
| **Interactive Voice** | CLI | Press Ctrl+B to record, agent auto-detects silence and responds |
|
||||
| **Auto Voice Reply** | Telegram, Discord | Agent sends spoken audio alongside text responses |
|
||||
| **Voice Channel** | Discord | Bot joins VC, listens to users speaking, speaks replies back |
|
||||
|
||||
## Requirements
|
||||
|
||||
### Python Packages
|
||||
|
||||
```bash
|
||||
# CLI voice mode (microphone + audio playback)
|
||||
pip install hermes-agent[voice]
|
||||
|
||||
# Discord + Telegram messaging (includes discord.py[voice] for VC support)
|
||||
pip install hermes-agent[messaging]
|
||||
|
||||
# Premium TTS (ElevenLabs)
|
||||
pip install hermes-agent[tts-premium]
|
||||
|
||||
# Everything at once
|
||||
pip install hermes-agent[all]
|
||||
```
|
||||
|
||||
| Extra | Packages | Required For |
|
||||
|-------|----------|-------------|
|
||||
| `voice` | `sounddevice`, `numpy` | CLI voice mode |
|
||||
| `messaging` | `discord.py[voice]`, `python-telegram-bot`, `aiohttp` | Discord & Telegram bots |
|
||||
| `tts-premium` | `elevenlabs` | ElevenLabs TTS provider |
|
||||
|
||||
:::info
|
||||
`discord.py[voice]` installs **PyNaCl** (for voice encryption) and **opus bindings** automatically. This is required for Discord voice channel support.
|
||||
:::
|
||||
|
||||
### System Dependencies
|
||||
|
||||
```bash
|
||||
# macOS
|
||||
brew install portaudio ffmpeg opus
|
||||
|
||||
# Ubuntu/Debian
|
||||
sudo apt install portaudio19-dev ffmpeg libopus0
|
||||
```
|
||||
|
||||
| Dependency | Purpose | Required For |
|
||||
|-----------|---------|-------------|
|
||||
| **PortAudio** | Microphone input and audio playback | CLI voice mode |
|
||||
| **ffmpeg** | Audio format conversion (MP3 → Opus, PCM → WAV) | All platforms |
|
||||
| **Opus** | Discord voice codec | Discord voice channels |
|
||||
|
||||
### API Keys
|
||||
|
||||
Add to `~/.hermes/.env`:
|
||||
|
||||
```bash
|
||||
# Speech-to-Text — local provider needs NO key at all
|
||||
# pip install faster-whisper # Free, runs locally, recommended
|
||||
GROQ_API_KEY=your-key # Groq Whisper — fast, free tier (cloud)
|
||||
VOICE_TOOLS_OPENAI_KEY=your-key # OpenAI Whisper — paid (cloud)
|
||||
|
||||
# Text-to-Speech (optional — Edge TTS works without any key)
|
||||
ELEVENLABS_API_KEY=your-key # ElevenLabs — premium quality
|
||||
```
|
||||
|
||||
:::tip
|
||||
If `faster-whisper` is installed, voice mode works with **zero API keys** for STT. The model (~150 MB for `base`) downloads automatically on first use.
|
||||
:::
|
||||
|
||||
---
|
||||
|
||||
## CLI Voice Mode
|
||||
|
||||
### Quick Start
|
||||
|
||||
Start the CLI and enable voice mode:
|
||||
|
||||
```bash
|
||||
hermes # Start the interactive CLI
|
||||
```
|
||||
|
||||
Then use these commands inside the CLI:
|
||||
|
||||
```
|
||||
/voice Toggle voice mode on/off
|
||||
/voice on Enable voice mode
|
||||
/voice off Disable voice mode
|
||||
/voice tts Toggle TTS output
|
||||
/voice status Show current state
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Start the CLI with `hermes` and enable voice mode with `/voice on`
|
||||
2. **Press Ctrl+B** — a beep plays (880Hz), recording starts
|
||||
3. **Speak** — a live audio level bar shows your input: `● [▁▂▃▅▇▇▅▂] ❯`
|
||||
4. **Stop speaking** — after 3 seconds of silence, recording auto-stops
|
||||
5. **Two beeps** play (660Hz) confirming the recording ended
|
||||
6. Audio is transcribed via Whisper and sent to the agent
|
||||
7. If TTS is enabled, the agent's reply is spoken aloud
|
||||
8. Recording **automatically restarts** — speak again without pressing any key
|
||||
|
||||
This loop continues until you press **Ctrl+B** during recording (exits continuous mode) or 3 consecutive recordings detect no speech.
|
||||
|
||||
:::tip
|
||||
The record key is configurable via `voice.record_key` in `~/.hermes/config.yaml` (default: `ctrl+b`).
|
||||
:::
|
||||
|
||||
### Silence Detection
|
||||
|
||||
Two-stage algorithm detects when you've finished speaking:
|
||||
|
||||
1. **Speech confirmation** — waits for audio above the RMS threshold (200) for at least 0.3s, tolerating brief dips between syllables
|
||||
2. **End detection** — once speech is confirmed, triggers after 3.0 seconds of continuous silence
|
||||
|
||||
If no speech is detected at all for 15 seconds, recording stops automatically.
|
||||
|
||||
Both `silence_threshold` and `silence_duration` are configurable in `config.yaml`.
|
||||
|
||||
### Streaming TTS
|
||||
|
||||
When TTS is enabled, the agent speaks its reply **sentence-by-sentence** as it generates text — you don't wait for the full response:
|
||||
|
||||
1. Buffers text deltas into complete sentences (min 20 chars)
|
||||
2. Strips markdown formatting and `<think>` blocks
|
||||
3. Generates and plays audio per sentence in real-time
|
||||
|
||||
### Hallucination Filter
|
||||
|
||||
Whisper sometimes generates phantom text from silence or background noise ("Thank you for watching", "Subscribe", etc.). The agent filters these out using a set of 26 known hallucination phrases across multiple languages, plus a regex pattern that catches repetitive variations.
|
||||
|
||||
---
|
||||
|
||||
## Gateway Voice Reply (Telegram & Discord)
|
||||
|
||||
If you haven't set up your messaging bots yet, see the platform-specific guides:
|
||||
- [Telegram Setup Guide](../messaging/telegram.md)
|
||||
- [Discord Setup Guide](../messaging/discord.md)
|
||||
|
||||
Start the gateway to connect to your messaging platforms:
|
||||
|
||||
```bash
|
||||
hermes gateway # Start the gateway (connects to configured platforms)
|
||||
hermes gateway setup # Interactive setup wizard for first-time configuration
|
||||
```
|
||||
|
||||
### Discord: Channels vs DMs
|
||||
|
||||
The bot supports two interaction modes on Discord:
|
||||
|
||||
| Mode | How to Talk | Mention Required | Setup |
|
||||
|------|------------|-----------------|-------|
|
||||
| **Direct Message (DM)** | Open the bot's profile → "Message" | No | Works immediately |
|
||||
| **Server Channel** | Type in a text channel where the bot is present | Yes (`@botname`) | Bot must be invited to the server |
|
||||
|
||||
**DM (recommended for personal use):** Just open a DM with the bot and type — no @mention needed. Voice replies and all commands work the same as in channels.
|
||||
|
||||
**Server channels:** The bot only responds when you @mention it (e.g. `@hermesbyt4 hello`). Make sure you select the **bot user** from the mention popup, not the role with the same name.
|
||||
|
||||
:::tip
|
||||
To disable the mention requirement in server channels, add to `~/.hermes/.env`:
|
||||
```bash
|
||||
DISCORD_REQUIRE_MENTION=false
|
||||
```
|
||||
Or set specific channels as free-response (no mention needed):
|
||||
```bash
|
||||
DISCORD_FREE_RESPONSE_CHANNELS=123456789,987654321
|
||||
```
|
||||
:::
|
||||
|
||||
### Commands
|
||||
|
||||
These work in both Telegram and Discord (DMs and text channels):
|
||||
|
||||
```
|
||||
/voice Toggle voice mode on/off
|
||||
/voice on Voice replies only when you send a voice message
|
||||
/voice tts Voice replies for ALL messages
|
||||
/voice off Disable voice replies
|
||||
/voice status Show current setting
|
||||
```
|
||||
|
||||
### Modes
|
||||
|
||||
| Mode | Command | Behavior |
|
||||
|------|---------|----------|
|
||||
| `off` | `/voice off` | Text only (default) |
|
||||
| `voice_only` | `/voice on` | Speaks reply only when you send a voice message |
|
||||
| `all` | `/voice tts` | Speaks reply to every message |
|
||||
|
||||
Voice mode setting is persisted across gateway restarts.
|
||||
|
||||
### Platform Delivery
|
||||
|
||||
| Platform | Format | Notes |
|
||||
|----------|--------|-------|
|
||||
| **Telegram** | Voice bubble (Opus/OGG) | Plays inline in chat. ffmpeg converts MP3 → Opus if needed |
|
||||
| **Discord** | Native voice bubble (Opus/OGG) | Plays inline like a user voice message. Falls back to file attachment if voice bubble API fails |
|
||||
|
||||
---
|
||||
|
||||
## Discord Voice Channels
|
||||
|
||||
The most immersive voice feature: the bot joins a Discord voice channel, listens to users speaking, transcribes their speech, processes through the agent, and speaks the reply back in the voice channel.
|
||||
|
||||
### Setup
|
||||
|
||||
#### 1. Discord Bot Permissions
|
||||
|
||||
If you already have a Discord bot set up for text (see [Discord Setup Guide](../messaging/discord.md)), you need to add voice permissions.
|
||||
|
||||
Go to the [Discord Developer Portal](https://discord.com/developers/applications) → your application → **Installation** → **Default Install Settings** → **Guild Install**:
|
||||
|
||||
**Add these permissions to the existing text permissions:**
|
||||
|
||||
| Permission | Purpose | Required |
|
||||
|-----------|---------|----------|
|
||||
| **Connect** | Join voice channels | Yes |
|
||||
| **Speak** | Play TTS audio in voice channels | Yes |
|
||||
| **Use Voice Activity** | Detect when users are speaking | Recommended |
|
||||
|
||||
**Updated Permissions Integer:**
|
||||
|
||||
| Level | Integer | What's Included |
|
||||
|-------|---------|----------------|
|
||||
| Text only | `274878286912` | View Channels, Send Messages, Read History, Embeds, Attachments, Threads, Reactions |
|
||||
| Text + Voice | `274881432640` | All above + Connect, Speak |
|
||||
|
||||
**Re-invite the bot** with the updated permissions URL:
|
||||
|
||||
```
|
||||
https://discord.com/oauth2/authorize?client_id=YOUR_APP_ID&scope=bot+applications.commands&permissions=274881432640
|
||||
```
|
||||
|
||||
Replace `YOUR_APP_ID` with your Application ID from the Developer Portal.
|
||||
|
||||
:::warning
|
||||
Re-inviting the bot to a server it's already in will update its permissions without removing it. You won't lose any data or configuration.
|
||||
:::
|
||||
|
||||
#### 2. Privileged Gateway Intents
|
||||
|
||||
In the [Developer Portal](https://discord.com/developers/applications) → your application → **Bot** → **Privileged Gateway Intents**, enable all three:
|
||||
|
||||
| Intent | Purpose |
|
||||
|--------|---------|
|
||||
| **Presence Intent** | Detect user online/offline status |
|
||||
| **Server Members Intent** | Map voice SSRC identifiers to Discord user IDs |
|
||||
| **Message Content Intent** | Read text message content in channels |
|
||||
|
||||
All three are required for full voice channel functionality. **Server Members Intent** is especially critical — without it, the bot cannot identify who is speaking in the voice channel.
|
||||
|
||||
#### 3. Opus Codec
|
||||
|
||||
The Opus codec library must be installed on the machine running the gateway:
|
||||
|
||||
```bash
|
||||
# macOS (Homebrew)
|
||||
brew install opus
|
||||
|
||||
# Ubuntu/Debian
|
||||
sudo apt install libopus0
|
||||
```
|
||||
|
||||
The bot auto-loads the codec from:
|
||||
- **macOS:** `/opt/homebrew/lib/libopus.dylib`
|
||||
- **Linux:** `libopus.so.0`
|
||||
|
||||
#### 4. Environment Variables
|
||||
|
||||
```bash
|
||||
# ~/.hermes/.env
|
||||
|
||||
# Discord bot (already configured for text)
|
||||
DISCORD_BOT_TOKEN=your-bot-token
|
||||
DISCORD_ALLOWED_USERS=your-user-id
|
||||
|
||||
# STT — local provider needs no key (pip install faster-whisper)
|
||||
# GROQ_API_KEY=your-key # Alternative: cloud-based, fast, free tier
|
||||
|
||||
# TTS — optional, Edge TTS (free) is the default
|
||||
# ELEVENLABS_API_KEY=your-key # Premium quality
|
||||
```
|
||||
|
||||
### Start the Gateway
|
||||
|
||||
```bash
|
||||
hermes gateway # Start with existing configuration
|
||||
```
|
||||
|
||||
The bot should come online in Discord within a few seconds.
|
||||
|
||||
### Commands
|
||||
|
||||
Use these in the Discord text channel where the bot is present:
|
||||
|
||||
```
|
||||
/voice join Bot joins your current voice channel
|
||||
/voice channel Alias for /voice join
|
||||
/voice leave Bot disconnects from voice channel
|
||||
/voice status Show voice mode and connected channel
|
||||
```
|
||||
|
||||
:::info
|
||||
You must be in a voice channel before running `/voice join`. The bot joins the same VC you're in.
|
||||
:::
|
||||
|
||||
### How It Works
|
||||
|
||||
When the bot joins a voice channel, it:
|
||||
|
||||
1. **Listens** to each user's audio stream independently
|
||||
2. **Detects silence** — 1.5s of silence after at least 0.5s of speech triggers processing
|
||||
3. **Transcribes** the audio via Whisper STT (local, Groq, or OpenAI)
|
||||
4. **Processes** through the full agent pipeline (session, tools, memory)
|
||||
5. **Speaks** the reply back in the voice channel via TTS
|
||||
|
||||
### Text Channel Integration
|
||||
|
||||
When the bot is in a voice channel:
|
||||
|
||||
- Transcripts appear in the text channel: `[Voice] @user: what you said`
|
||||
- Agent responses are sent as text in the channel AND spoken in the VC
|
||||
- The text channel is the one where `/voice join` was issued
|
||||
|
||||
### Echo Prevention
|
||||
|
||||
The bot automatically pauses its audio listener while playing TTS replies, preventing it from hearing and re-processing its own output.
|
||||
|
||||
### Access Control
|
||||
|
||||
Only users listed in `DISCORD_ALLOWED_USERS` can interact via voice. Other users' audio is silently ignored.
|
||||
|
||||
```bash
|
||||
# ~/.hermes/.env
|
||||
DISCORD_ALLOWED_USERS=284102345871466496
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### config.yaml
|
||||
|
||||
```yaml
|
||||
# Voice recording (CLI)
|
||||
voice:
|
||||
record_key: "ctrl+b" # Key to start/stop recording
|
||||
max_recording_seconds: 120 # Maximum recording length
|
||||
auto_tts: false # Auto-enable TTS when voice mode starts
|
||||
silence_threshold: 200 # RMS level (0-32767) below which counts as silence
|
||||
silence_duration: 3.0 # Seconds of silence before auto-stop
|
||||
|
||||
# Speech-to-Text
|
||||
stt:
|
||||
provider: "local" # "local" (free) | "groq" | "openai"
|
||||
local:
|
||||
model: "base" # tiny, base, small, medium, large-v3
|
||||
# model: "whisper-1" # Legacy: used when provider is not set
|
||||
|
||||
# Text-to-Speech
|
||||
tts:
|
||||
provider: "edge" # "edge" (free) | "elevenlabs" | "openai"
|
||||
edge:
|
||||
voice: "en-US-AriaNeural" # 322 voices, 74 languages
|
||||
elevenlabs:
|
||||
voice_id: "pNInz6obpgDQGcFmaJgB" # Adam
|
||||
model_id: "eleven_multilingual_v2"
|
||||
openai:
|
||||
model: "gpt-4o-mini-tts"
|
||||
voice: "alloy" # alloy, echo, fable, onyx, nova, shimmer
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Speech-to-Text providers (local needs no key)
|
||||
# pip install faster-whisper # Free local STT — no API key needed
|
||||
GROQ_API_KEY=... # Groq Whisper (fast, free tier)
|
||||
VOICE_TOOLS_OPENAI_KEY=... # OpenAI Whisper (paid)
|
||||
|
||||
# STT advanced overrides (optional)
|
||||
STT_GROQ_MODEL=whisper-large-v3-turbo # Override default Groq STT model
|
||||
STT_OPENAI_MODEL=whisper-1 # Override default OpenAI STT model
|
||||
GROQ_BASE_URL=https://api.groq.com/openai/v1 # Custom Groq endpoint
|
||||
STT_OPENAI_BASE_URL=https://api.openai.com/v1 # Custom OpenAI STT endpoint
|
||||
|
||||
# Text-to-Speech providers (Edge TTS needs no key)
|
||||
ELEVENLABS_API_KEY=... # ElevenLabs (premium quality)
|
||||
# OpenAI TTS uses VOICE_TOOLS_OPENAI_KEY
|
||||
|
||||
# Discord voice channel
|
||||
DISCORD_BOT_TOKEN=...
|
||||
DISCORD_ALLOWED_USERS=...
|
||||
```
|
||||
|
||||
### STT Provider Comparison
|
||||
|
||||
| Provider | Model | Speed | Quality | Cost | API Key |
|
||||
|----------|-------|-------|---------|------|---------|
|
||||
| **Local** | `base` | Fast (depends on CPU/GPU) | Good | Free | No |
|
||||
| **Local** | `small` | Medium | Better | Free | No |
|
||||
| **Local** | `large-v3` | Slow | Best | Free | No |
|
||||
| **Groq** | `whisper-large-v3-turbo` | Very fast (~0.5s) | Good | Free tier | Yes |
|
||||
| **Groq** | `whisper-large-v3` | Fast (~1s) | Better | Free tier | Yes |
|
||||
| **OpenAI** | `whisper-1` | Fast (~1s) | Good | Paid | Yes |
|
||||
| **OpenAI** | `gpt-4o-transcribe` | Medium (~2s) | Best | Paid | Yes |
|
||||
|
||||
Provider priority (automatic fallback): **local** > **groq** > **openai**
|
||||
|
||||
### TTS Provider Comparison
|
||||
|
||||
| Provider | Quality | Cost | Latency | Key Required |
|
||||
|----------|---------|------|---------|-------------|
|
||||
| **Edge TTS** | Good | Free | ~1s | No |
|
||||
| **ElevenLabs** | Excellent | Paid | ~2s | Yes |
|
||||
| **OpenAI TTS** | Good | Paid | ~1.5s | Yes |
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "No audio device found" (CLI)
|
||||
|
||||
PortAudio is not installed:
|
||||
|
||||
```bash
|
||||
brew install portaudio # macOS
|
||||
sudo apt install portaudio19-dev # Ubuntu
|
||||
```
|
||||
|
||||
### Bot doesn't respond in Discord server channels
|
||||
|
||||
The bot requires an @mention by default in server channels. Make sure you:
|
||||
|
||||
1. Type `@` and select the **bot user** (with the #discriminator), not the **role** with the same name
|
||||
2. Or use DMs instead — no mention needed
|
||||
3. Or set `DISCORD_REQUIRE_MENTION=false` in `~/.hermes/.env`
|
||||
|
||||
### Bot joins VC but doesn't hear me
|
||||
|
||||
- Check your Discord user ID is in `DISCORD_ALLOWED_USERS`
|
||||
- Make sure you're not muted in Discord
|
||||
- The bot needs a SPEAKING event from Discord before it can map your audio — start speaking within a few seconds of joining
|
||||
|
||||
### Bot hears me but doesn't respond
|
||||
|
||||
- Verify STT is available: install `faster-whisper` (no key needed) or set `GROQ_API_KEY` / `VOICE_TOOLS_OPENAI_KEY`
|
||||
- Check the LLM model is configured and accessible
|
||||
- Review gateway logs: `tail -f ~/.hermes/logs/gateway.log`
|
||||
|
||||
### Bot responds in text but not in voice channel
|
||||
|
||||
- TTS provider may be failing — check API key and quota
|
||||
- Edge TTS (free, no key) is the default fallback
|
||||
- Check logs for TTS errors
|
||||
|
||||
### Whisper returns garbage text
|
||||
|
||||
The hallucination filter catches most cases automatically. If you're still getting phantom transcripts:
|
||||
|
||||
- Use a quieter environment
|
||||
- Adjust `silence_threshold` in config (higher = less sensitive)
|
||||
- Try a different STT model
|
||||
|
|
@ -210,8 +210,8 @@ Replace the ID with the actual channel ID (right-click → Copy Channel ID with
|
|||
|
||||
Hermes Agent supports Discord voice messages:
|
||||
|
||||
- **Incoming voice messages** are automatically transcribed using Whisper (requires `VOICE_TOOLS_OPENAI_KEY` to be set in your environment).
|
||||
- **Text-to-speech**: When TTS is enabled, the bot can send spoken responses as MP3 file attachments.
|
||||
- **Incoming voice messages** are automatically transcribed using Whisper (works with no API key when `faster-whisper` is installed locally, or set `GROQ_API_KEY` / `VOICE_TOOLS_OPENAI_KEY` for cloud STT).
|
||||
- **Text-to-speech**: Use `/voice tts` to have the bot send spoken audio responses alongside text replies.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
---
|
||||
sidebar_position: 1
|
||||
title: "Messaging Gateway"
|
||||
description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, or Home Assistant — architecture and setup overview"
|
||||
description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, Home Assistant, or your browser — architecture and setup overview"
|
||||
---
|
||||
|
||||
# Messaging Gateway
|
||||
|
||||
Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, or Home Assistant. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages.
|
||||
Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, Home Assistant, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages.
|
||||
|
||||
## Architecture
|
||||
|
||||
|
|
@ -15,24 +15,24 @@ Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, or Home
|
|||
│ Hermes Gateway │
|
||||
├───────────────────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌──────────┐ ┌─────────┐ ┌──────────┐ ┌───────┐ ┌───────┐ ┌───────┐ ┌────┐│
|
||||
│ │ Telegram │ │ Discord │ │ WhatsApp │ │ Slack │ │Signal │ │ Email │ │ HA ││
|
||||
│ │ Adapter │ │ Adapter │ │ Adapter │ │Adapter│ │Adapter│ │Adapter│ │Adpt││
|
||||
│ └────┬─────┘ └────┬────┘ └────┬─────┘ └──┬────┘ └──┬────┘ └──┬────┘ └─┬──┘│
|
||||
│ │ │ │ │ │ │ │ │
|
||||
│ └─────────────┴───────────┴───────────┴─────────┴─────────┴────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────▼────────┐ │
|
||||
│ │ Session Store │ │
|
||||
│ │ (per-chat) │ │
|
||||
│ └────────┬────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────▼────────┐ │
|
||||
│ │ AIAgent │ │
|
||||
│ │ (run_agent) │ │
|
||||
│ └─────────────────┘ │
|
||||
│ │
|
||||
└───────────────────────────────────────────────────────────────────────────────┘
|
||||
│ ┌──────────┐ ┌─────────┐ ┌──────────┐ ┌───────┐ ┌───────┐ ┌───────┐ ┌────┐ │
|
||||
│ │ Telegram │ │ Discord │ │ WhatsApp │ │ Slack │ │Signal │ │ Email │ │ HA │ │
|
||||
│ │ Adapter │ │ Adapter │ │ Adapter │ │Adapter│ │Adapter│ │Adapter│ │Adpt│ │
|
||||
│ └────┬─────┘ └────┬────┘ └────┬─────┘ └──┬────┘ └──┬────┘ └──┬────┘ └─┬──┘ │
|
||||
│ │ │ │ │ │ │ │ │
|
||||
│ └─────────────┴───────────┴───────────┴─────────┴─────────┴────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────▼────────┐ │
|
||||
│ │ Session Store │ │
|
||||
│ │ (per-chat) │ │
|
||||
│ └────────┬────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────▼────────┐ │
|
||||
│ │ AIAgent │ │
|
||||
│ │ (run_agent) │ │
|
||||
│ └─────────────────┘ │
|
||||
│ │
|
||||
└───────────────────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
Each platform adapter receives messages, routes them through a per-chat session store, and dispatches them to the AIAgent for processing. The gateway also runs the cron scheduler, ticking every 60 seconds to execute any due jobs.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue