diff --git a/.env.example b/.env.example index a5153d1d07..fb215afbad 100644 --- a/.env.example +++ b/.env.example @@ -275,3 +275,27 @@ WANDB_API_KEY= # GITHUB_APP_ID= # GITHUB_APP_PRIVATE_KEY_PATH= # GITHUB_APP_INSTALLATION_ID= + +# Groq API key (free tier — used for Whisper STT in voice mode) +# GROQ_API_KEY= + +# ============================================================================= +# STT PROVIDER SELECTION +# ============================================================================= +# Default STT provider is "local" (faster-whisper) — runs on your machine, no API key needed. +# Install with: pip install faster-whisper +# Model downloads automatically on first use (~150 MB for "base"). +# To use cloud providers instead, set GROQ_API_KEY or VOICE_TOOLS_OPENAI_KEY above. +# Provider priority: local > groq > openai +# Configure in config.yaml: stt.provider: local | groq | openai + +# ============================================================================= +# STT ADVANCED OVERRIDES (optional) +# ============================================================================= +# Override default STT models per provider (normally set via stt.model in config.yaml) +# STT_GROQ_MODEL=whisper-large-v3-turbo +# STT_OPENAI_MODEL=whisper-1 + +# Override STT provider endpoints (for proxies or self-hosted instances) +# GROQ_BASE_URL=https://api.groq.com/openai/v1 +# STT_OPENAI_BASE_URL=https://api.openai.com/v1 diff --git a/cli.py b/cli.py index 253cdd085e..094be22e97 100755 --- a/cli.py +++ b/cli.py @@ -18,6 +18,8 @@ import shutil import sys import json import atexit +import tempfile +import time import uuid import textwrap from contextlib import contextmanager @@ -1287,11 +1289,41 @@ class HermesCLI: self._history_file = _hermes_home / ".hermes_history" self._last_invalidate: float = 0.0 # throttle UI repaints self._app = None + + # State shared by interactive run() and single-query chat mode. 
+ # These must exist before any direct chat() call because single-query + # mode does not go through run(). + self._agent_running = False + self._pending_input = queue.Queue() + self._interrupt_queue = queue.Queue() + self._should_exit = False + self._last_ctrl_c_time = 0 + self._clarify_state = None + self._clarify_freetext = False + self._clarify_deadline = 0 + self._sudo_state = None + self._sudo_deadline = 0 + self._approval_state = None + self._approval_deadline = 0 + self._approval_lock = threading.Lock() self._secret_state = None self._secret_deadline = 0 self._spinner_text: str = "" # thinking spinner text for TUI self._command_running = False self._command_status = "" + self._attached_images: list[Path] = [] + self._image_counter = 0 + + # Voice mode state (also reinitialized inside run() for interactive TUI). + self._voice_lock = threading.Lock() + self._voice_mode = False + self._voice_tts = False + self._voice_recorder = None + self._voice_recording = False + self._voice_processing = False + self._voice_continuous = False + self._voice_tts_done = threading.Event() + self._voice_tts_done.set() # Background task tracking: {task_id: threading.Thread} self._background_tasks: Dict[str, threading.Thread] = {} @@ -1548,6 +1580,7 @@ class HermesCLI: checkpoints_enabled=self.checkpoints_enabled, checkpoint_max_snapshots=self.checkpoint_max_snapshots, pass_session_id=self.pass_session_id, + tool_progress_callback=self._on_tool_progress, ) # Apply any pending title now that the session exists in the DB if self._pending_title and self._session_db: @@ -3017,6 +3050,8 @@ class HermesCLI: self._handle_background_command(cmd_original) elif cmd_lower.startswith("/skin"): self._handle_skin_command(cmd_original) + elif cmd_lower.startswith("/voice"): + self._handle_voice_command(cmd_original) else: # Check for user-defined quick commands (bypass agent loop, no LLM call) base_cmd = cmd_lower.split()[0] @@ -3511,6 +3546,407 @@ class HermesCLI: except Exception as e: 
print(f" ❌ MCP reload failed: {e}") + # ==================================================================== + # Tool progress callback (audio cues for voice mode) + # ==================================================================== + + def _on_tool_progress(self, function_name: str, preview: str, function_args: dict): + """Called when a tool starts executing. Plays audio cue in voice mode.""" + if not self._voice_mode: + return + # Skip internal/thinking tools + if function_name.startswith("_"): + return + try: + from tools.voice_mode import play_beep + # Short, subtle tick sound (higher pitch, very brief) + threading.Thread( + target=play_beep, + kwargs={"frequency": 1200, "duration": 0.06, "count": 1}, + daemon=True, + ).start() + except Exception: + pass + + # ==================================================================== + # Voice mode methods + # ==================================================================== + + def _voice_start_recording(self): + """Start capturing audio from the microphone.""" + if getattr(self, '_should_exit', False): + return + from tools.voice_mode import AudioRecorder, check_voice_requirements + + reqs = check_voice_requirements() + if not reqs["audio_available"]: + raise RuntimeError( + "Voice mode requires sounddevice and numpy.\n" + "Install with: pip install sounddevice numpy\n" + "Or: pip install hermes-agent[voice]" + ) + if not reqs.get("stt_available", reqs.get("stt_key_set")): + raise RuntimeError( + "Voice mode requires an STT provider for transcription.\n" + "Option 1: pip install faster-whisper (free, local)\n" + "Option 2: Set GROQ_API_KEY (free tier)\n" + "Option 3: Set VOICE_TOOLS_OPENAI_KEY (paid)" + ) + + # Prevent double-start from concurrent threads (atomic check-and-set) + with self._voice_lock: + if self._voice_recording: + return + self._voice_recording = True + + # Load silence detection params from config + voice_cfg = {} + try: + from hermes_cli.config import load_config + voice_cfg = 
load_config().get("voice", {}) + except Exception: + pass + + if self._voice_recorder is None: + self._voice_recorder = AudioRecorder() + + # Apply config-driven silence params + self._voice_recorder._silence_threshold = voice_cfg.get("silence_threshold", 200) + self._voice_recorder._silence_duration = voice_cfg.get("silence_duration", 3.0) + + def _on_silence(): + """Called by AudioRecorder when silence is detected after speech.""" + with self._voice_lock: + if not self._voice_recording: + return + _cprint(f"\n{_DIM}Silence detected, auto-stopping...{_RST}") + if hasattr(self, '_app') and self._app: + self._app.invalidate() + self._voice_stop_and_transcribe() + + # Audio cue: single beep BEFORE starting stream (avoid CoreAudio conflict) + try: + from tools.voice_mode import play_beep + play_beep(frequency=880, count=1) + except Exception: + pass + + try: + self._voice_recorder.start(on_silence_stop=_on_silence) + except Exception: + with self._voice_lock: + self._voice_recording = False + raise + _cprint(f"\n{_GOLD}● Recording...{_RST} {_DIM}(auto-stops on silence | Ctrl+B to stop & exit continuous){_RST}") + + # Periodically refresh prompt to update audio level indicator + def _refresh_level(): + while True: + with self._voice_lock: + still_recording = self._voice_recording + if not still_recording: + break + if hasattr(self, '_app') and self._app: + self._app.invalidate() + time.sleep(0.15) + threading.Thread(target=_refresh_level, daemon=True).start() + + def _voice_stop_and_transcribe(self): + """Stop recording, transcribe via STT, and queue the transcript as input.""" + # Atomic guard: only one thread can enter stop-and-transcribe. + # Set _voice_processing immediately so concurrent Ctrl+B presses + # don't race into the START path while recorder.stop() holds its lock. 
+ with self._voice_lock: + if not self._voice_recording: + return + self._voice_recording = False + self._voice_processing = True + + submitted = False + wav_path = None + try: + if self._voice_recorder is None: + return + + wav_path = self._voice_recorder.stop() + + # Audio cue: double beep after stream stopped (no CoreAudio conflict) + try: + from tools.voice_mode import play_beep + play_beep(frequency=660, count=2) + except Exception: + pass + + if wav_path is None: + _cprint(f"{_DIM}No speech detected.{_RST}") + return + + # _voice_processing is already True (set atomically above) + if hasattr(self, '_app') and self._app: + self._app.invalidate() + _cprint(f"{_DIM}Transcribing...{_RST}") + + # Get STT model from config + stt_model = None + try: + from hermes_cli.config import load_config + stt_config = load_config().get("stt", {}) + stt_model = stt_config.get("model") + except Exception: + pass + + from tools.voice_mode import transcribe_recording + result = transcribe_recording(wav_path, model=stt_model) + + if result.get("success") and result.get("transcript", "").strip(): + transcript = result["transcript"].strip() + self._pending_input.put(transcript) + submitted = True + elif result.get("success"): + _cprint(f"{_DIM}No speech detected.{_RST}") + else: + error = result.get("error", "Unknown error") + _cprint(f"\n{_DIM}Transcription failed: {error}{_RST}") + + except Exception as e: + _cprint(f"\n{_DIM}Voice processing error: {e}{_RST}") + finally: + with self._voice_lock: + self._voice_processing = False + if hasattr(self, '_app') and self._app: + self._app.invalidate() + # Clean up temp file + try: + if wav_path and os.path.isfile(wav_path): + os.unlink(wav_path) + except Exception: + pass + + # Track consecutive no-speech cycles to avoid infinite restart loops. 
+ if not submitted: + self._no_speech_count = getattr(self, '_no_speech_count', 0) + 1 + if self._no_speech_count >= 3: + self._voice_continuous = False + self._no_speech_count = 0 + _cprint(f"{_DIM}No speech detected 3 times, continuous mode stopped.{_RST}") + return + else: + self._no_speech_count = 0 + + # If no transcript was submitted but continuous mode is active, + # restart recording so the user can keep talking. + # (When transcript IS submitted, process_loop handles restart + # after chat() completes.) + if self._voice_continuous and not submitted and not self._voice_recording: + def _restart_recording(): + try: + self._voice_start_recording() + if hasattr(self, '_app') and self._app: + self._app.invalidate() + except Exception as e: + _cprint(f"{_DIM}Voice auto-restart failed: {e}{_RST}") + threading.Thread(target=_restart_recording, daemon=True).start() + + def _voice_speak_response(self, text: str): + """Speak the agent's response aloud using TTS (runs in background thread).""" + if not self._voice_tts: + return + self._voice_tts_done.clear() + try: + from tools.tts_tool import text_to_speech_tool + from tools.voice_mode import play_audio_file + import json + import re + + # Strip markdown and non-speech content for cleaner TTS + tts_text = text[:4000] if len(text) > 4000 else text + tts_text = re.sub(r'```[\s\S]*?```', ' ', tts_text) # fenced code blocks + tts_text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', tts_text) # [text](url) -> text + tts_text = re.sub(r'https?://\S+', '', tts_text) # URLs + tts_text = re.sub(r'\*\*(.+?)\*\*', r'\1', tts_text) # bold + tts_text = re.sub(r'\*(.+?)\*', r'\1', tts_text) # italic + tts_text = re.sub(r'`(.+?)`', r'\1', tts_text) # inline code + tts_text = re.sub(r'^#+\s*', '', tts_text, flags=re.MULTILINE) # headers + tts_text = re.sub(r'^\s*[-*]\s+', '', tts_text, flags=re.MULTILINE) # list items + tts_text = re.sub(r'---+', '', tts_text) # horizontal rules + tts_text = re.sub(r'\n{3,}', '\n\n', tts_text) # excessive 
newlines + tts_text = tts_text.strip() + if not tts_text: + return + + # Use MP3 output for CLI playback (afplay doesn't handle OGG well). + # The TTS tool may auto-convert MP3->OGG, but the original MP3 remains. + os.makedirs(os.path.join(tempfile.gettempdir(), "hermes_voice"), exist_ok=True) + mp3_path = os.path.join( + tempfile.gettempdir(), "hermes_voice", + f"tts_{time.strftime('%Y%m%d_%H%M%S')}.mp3", + ) + + text_to_speech_tool(text=tts_text, output_path=mp3_path) + + # Play the MP3 directly (the TTS tool returns OGG path but MP3 still exists) + if os.path.isfile(mp3_path) and os.path.getsize(mp3_path) > 0: + play_audio_file(mp3_path) + # Clean up + try: + os.unlink(mp3_path) + ogg_path = mp3_path.rsplit(".", 1)[0] + ".ogg" + if os.path.isfile(ogg_path): + os.unlink(ogg_path) + except OSError: + pass + except Exception as e: + logger.warning("Voice TTS playback failed: %s", e) + _cprint(f"{_DIM}TTS playback failed: {e}{_RST}") + finally: + self._voice_tts_done.set() + + def _handle_voice_command(self, command: str): + """Handle /voice [on|off|tts|status] command.""" + parts = command.strip().split(maxsplit=1) + subcommand = parts[1].lower().strip() if len(parts) > 1 else "" + + if subcommand == "on": + self._enable_voice_mode() + elif subcommand == "off": + self._disable_voice_mode() + elif subcommand == "tts": + self._toggle_voice_tts() + elif subcommand == "status": + self._show_voice_status() + elif subcommand == "": + # Toggle + if self._voice_mode: + self._disable_voice_mode() + else: + self._enable_voice_mode() + else: + _cprint(f"Unknown voice subcommand: {subcommand}") + _cprint("Usage: /voice [on|off|tts|status]") + + def _enable_voice_mode(self): + """Enable voice mode after checking requirements.""" + if self._voice_mode: + _cprint(f"{_DIM}Voice mode is already enabled.{_RST}") + return + + from tools.voice_mode import check_voice_requirements, detect_audio_environment + + # Environment detection -- warn and block in incompatible environments + 
env_check = detect_audio_environment() + if not env_check["available"]: + _cprint(f"\n{_GOLD}Voice mode unavailable in this environment:{_RST}") + for warning in env_check["warnings"]: + _cprint(f" {_DIM}{warning}{_RST}") + return + + reqs = check_voice_requirements() + if not reqs["available"]: + _cprint(f"\n{_GOLD}Voice mode requirements not met:{_RST}") + for line in reqs["details"].split("\n"): + _cprint(f" {_DIM}{line}{_RST}") + if reqs["missing_packages"]: + _cprint(f"\n {_BOLD}Install: pip install {' '.join(reqs['missing_packages'])}{_RST}") + _cprint(f" {_DIM}Or: pip install hermes-agent[voice]{_RST}") + return + + with self._voice_lock: + self._voice_mode = True + + # Check config for auto_tts + try: + from hermes_cli.config import load_config + voice_config = load_config().get("voice", {}) + if voice_config.get("auto_tts", False): + with self._voice_lock: + self._voice_tts = True + except Exception: + pass + + # Voice mode instruction is injected as a user message prefix (not a + # system prompt change) to avoid invalidating the prompt cache. See + # _voice_message_prefix property and its usage in _process_message(). 
+ + tts_status = " (TTS enabled)" if self._voice_tts else "" + try: + from hermes_cli.config import load_config + _raw_ptt = load_config().get("voice", {}).get("record_key", "ctrl+b") + _ptt_key = _raw_ptt.lower().replace("ctrl+", "c-").replace("alt+", "a-") + except Exception: + _ptt_key = "c-b" + _ptt_display = _ptt_key.replace("c-", "Ctrl+").upper() + _cprint(f"\n{_GOLD}Voice mode enabled{tts_status}{_RST}") + _cprint(f" {_DIM}{_ptt_display} to start/stop recording{_RST}") + _cprint(f" {_DIM}/voice tts to toggle speech output{_RST}") + _cprint(f" {_DIM}/voice off to disable voice mode{_RST}") + + def _disable_voice_mode(self): + """Disable voice mode, cancel any active recording, and stop TTS.""" + recorder = None + with self._voice_lock: + if self._voice_recording and self._voice_recorder: + self._voice_recorder.cancel() + self._voice_recording = False + recorder = self._voice_recorder + self._voice_mode = False + self._voice_tts = False + self._voice_continuous = False + + # Shut down the persistent audio stream in background + if recorder is not None: + def _bg_shutdown(rec=recorder): + try: + rec.shutdown() + except Exception: + pass + threading.Thread(target=_bg_shutdown, daemon=True).start() + self._voice_recorder = None + + # Stop any active TTS playback + try: + from tools.voice_mode import stop_playback + stop_playback() + except Exception: + pass + self._voice_tts_done.set() + + _cprint(f"\n{_DIM}Voice mode disabled.{_RST}") + + def _toggle_voice_tts(self): + """Toggle TTS output for voice mode.""" + if not self._voice_mode: + _cprint(f"{_DIM}Enable voice mode first: /voice on{_RST}") + return + + with self._voice_lock: + self._voice_tts = not self._voice_tts + status = "enabled" if self._voice_tts else "disabled" + + if self._voice_tts: + from tools.tts_tool import check_tts_requirements + if not check_tts_requirements(): + _cprint(f"{_DIM}Warning: No TTS provider available. 
Install edge-tts or set API keys.{_RST}") + + _cprint(f"{_GOLD}Voice TTS {status}.{_RST}") + + def _show_voice_status(self): + """Show current voice mode status.""" + from hermes_cli.config import load_config + from tools.voice_mode import check_voice_requirements + + reqs = check_voice_requirements() + + _cprint(f"\n{_BOLD}Voice Mode Status{_RST}") + _cprint(f" Mode: {'ON' if self._voice_mode else 'OFF'}") + _cprint(f" TTS: {'ON' if self._voice_tts else 'OFF'}") + _cprint(f" Recording: {'YES' if self._voice_recording else 'no'}") + _raw_key = load_config().get("voice", {}).get("record_key", "ctrl+b") + _display_key = _raw_key.replace("ctrl+", "Ctrl+").upper() if "ctrl+" in _raw_key.lower() else _raw_key + _cprint(f" Record key: {_display_key}") + _cprint(f"\n {_BOLD}Requirements:{_RST}") + for line in reqs["details"].split("\n"): + _cprint(f" {line}") + def _clarify_callback(self, question, choices): """ Platform callback for the clarify tool. Called from the agent thread. @@ -3752,19 +4188,90 @@ class HermesCLI: try: # Run the conversation with interrupt monitoring result = None - + + # --- Streaming TTS setup --- + # When ElevenLabs is the TTS provider and sounddevice is available, + # we stream audio sentence-by-sentence as the agent generates tokens + # instead of waiting for the full response. 
+ use_streaming_tts = False + _streaming_box_opened = False + text_queue = None + tts_thread = None + stream_callback = None + stop_event = None + + if self._voice_tts: + try: + from tools.tts_tool import ( + _load_tts_config as _load_tts_cfg, + _get_provider as _get_prov, + _import_elevenlabs, + _import_sounddevice, + stream_tts_to_speaker, + ) + _tts_cfg = _load_tts_cfg() + if _get_prov(_tts_cfg) == "elevenlabs": + # Verify both ElevenLabs SDK and audio output are available + _import_elevenlabs() + _import_sounddevice() + use_streaming_tts = True + except (ImportError, OSError): + pass + except Exception: + pass + + if use_streaming_tts: + text_queue = queue.Queue() + stop_event = threading.Event() + + def display_callback(sentence: str): + """Called by TTS consumer when a sentence is ready to display + speak.""" + nonlocal _streaming_box_opened + if not _streaming_box_opened: + _streaming_box_opened = True + w = self.console.width + label = " ⚕ Hermes " + fill = w - 2 - len(label) + _cprint(f"\n{_GOLD}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}") + _cprint(sentence.rstrip()) + + tts_thread = threading.Thread( + target=stream_tts_to_speaker, + args=(text_queue, stop_event, self._voice_tts_done), + kwargs={"display_callback": display_callback}, + daemon=True, + ) + tts_thread.start() + + def stream_callback(delta: str): + if text_queue is not None: + text_queue.put(delta) + + # When voice mode is active, prepend a brief instruction so the + # model responds concisely. The prefix is API-call-local only — + # run_conversation persists the original clean user message. + _voice_prefix = "" + if self._voice_mode and isinstance(message, str): + _voice_prefix = ( + "[Voice input — respond concisely and conversationally, " + "2-3 sentences max. No code blocks or markdown.] 
" + ) + def run_agent(): nonlocal result + agent_message = _voice_prefix + message if _voice_prefix else message result = self.agent.run_conversation( - user_message=message, + user_message=agent_message, conversation_history=self.conversation_history[:-1], # Exclude the message we just added + stream_callback=stream_callback, task_id=self.session_id, + persist_user_message=message if _voice_prefix else None, ) - + # Start agent in background thread agent_thread = threading.Thread(target=run_agent) agent_thread.start() - + # Monitor the dedicated interrupt queue while the agent runs. # _interrupt_queue is separate from _pending_input, so process_loop # and chat() never compete for the same queue. @@ -3783,6 +4290,9 @@ class HermesCLI: if self._clarify_state or self._clarify_freetext: continue print(f"\n⚡ New message detected, interrupting...") + # Signal TTS to stop on interrupt + if stop_event is not None: + stop_event.set() self.agent.interrupt(interrupt_msg) # Debug: log to file (stdout may be devnull from redirect_stdout) try: @@ -3802,9 +4312,15 @@ class HermesCLI: else: # Fallback for non-interactive mode (e.g., single-query) agent_thread.join(0.1) - + agent_thread.join() # Ensure agent thread completes + # Signal end-of-text to TTS consumer and wait for it to finish + if use_streaming_tts and text_queue is not None: + text_queue.put(None) # sentinel + if tts_thread is not None: + tts_thread.join(timeout=120) + # Drain any remaining agent output still in the StdoutProxy # buffer so tool/status lines render ABOVE our response box. 
# The flush pushes data into the renderer queue; the short @@ -3815,15 +4331,22 @@ class HermesCLI: # Update history with full conversation self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history - + # Get the final response response = result.get("final_response", "") if result else "" - - # Handle failed results (e.g., non-retryable errors like invalid model) - if result and result.get("failed") and not response: + + # Handle failed or partial results (e.g., non-retryable errors, rate limits, + # truncated output, invalid tool calls). Both "failed" and "partial" with + # an empty final_response mean the agent couldn't produce a usable answer. + if result and (result.get("failed") or result.get("partial")) and not response: error_detail = result.get("error", "Unknown error") response = f"Error: {error_detail}" - + # Stop continuous voice mode on persistent errors (e.g. 429 rate limit) + # to avoid an infinite error → record → error loop + if self._voice_continuous: + self._voice_continuous = False + _cprint(f"\n{_DIM}Continuous voice mode stopped due to error.{_RST}") + # Handle interrupt - check if we were interrupted pending_message = None if result and result.get("interrupted"): @@ -3831,8 +4354,9 @@ class HermesCLI: # Add indicator that we were interrupted if response and pending_message: response = response + "\n\n---\n_[Interrupted - processing new message]_" - + response_previewed = result.get("response_previewed", False) if result else False + # Display reasoning (thinking) box if enabled and available if self.show_reasoning and result: reasoning = result.get("last_reasoning") @@ -3852,8 +4376,7 @@ class HermesCLI: _cprint(f"\n{r_top}\n{_DIM}{display_reasoning}{_RST}\n{r_bot}") if response and not response_previewed: - # Use a Rich Panel for the response box — adapts to terminal - # width at render time instead of hard-coding border length. 
+ # Use skin engine for label/color with fallback try: from hermes_cli.skin_engine import get_active_skin _skin = get_active_skin() @@ -3865,23 +4388,40 @@ class HermesCLI: _resp_color = "#CD7F32" _resp_text = "#FFF8DC" - _chat_console = ChatConsole() - _chat_console.print(Panel( - _rich_text_from_ansi(response), - title=f"[{_resp_color} bold]{label}[/]", - title_align="left", - border_style=_resp_color, - style=_resp_text, - box=rich_box.HORIZONTALS, - padding=(1, 2), - )) + is_error_response = result and (result.get("failed") or result.get("partial")) + if use_streaming_tts and _streaming_box_opened and not is_error_response: + # Text was already printed sentence-by-sentence; just close the box + w = shutil.get_terminal_size().columns + _cprint(f"\n{_GOLD}╰{'─' * (w - 2)}╯{_RST}") + else: + _chat_console = ChatConsole() + _chat_console.print(Panel( + _rich_text_from_ansi(response), + title=f"[{_resp_color} bold]{label}[/]", + title_align="left", + border_style=_resp_color, + style=_resp_text, + box=rich_box.HORIZONTALS, + padding=(1, 2), + )) + # Play terminal bell when agent finishes (if enabled). # Works over SSH — the bell propagates to the user's terminal. if self.bell_on_complete: sys.stdout.write("\a") sys.stdout.flush() - + + # Speak response aloud if voice TTS is enabled + # Skip batch TTS when streaming TTS already handled it + if self._voice_tts and response and not use_streaming_tts: + threading.Thread( + target=self._voice_speak_response, + args=(response,), + daemon=True, + ).start() + + # Combine all interrupt messages (user may have typed multiple while waiting) # and re-queue as one prompt for process_loop if pending_message and hasattr(self, '_pending_input'): @@ -3902,6 +4442,20 @@ class HermesCLI: except Exception as e: print(f"Error: {e}") return None + finally: + # Ensure streaming TTS resources are cleaned up even on error. + # Normal path sends the sentinel at line ~3568; this is a safety + # net for exception paths that skip it. 
Duplicate sentinels are + # harmless — stream_tts_to_speaker exits on the first None. + if text_queue is not None: + try: + text_queue.put_nowait(None) + except Exception: + pass + if stop_event is not None: + stop_event.set() + if tts_thread is not None and tts_thread.is_alive(): + tts_thread.join(timeout=5) def _print_exit_summary(self): """Print session resume info on exit, similar to Claude Code.""" @@ -3961,9 +4515,26 @@ class HermesCLI: # Icon-only custom prompts should still remain visible in special states. return symbol, symbol + def _audio_level_bar(self) -> str: + """Return a visual audio level indicator based on current RMS.""" + _LEVEL_BARS = " ▁▂▃▄▅▆▇" + rec = getattr(self, "_voice_recorder", None) + if rec is None: + return "" + rms = rec.current_rms + # Normalize RMS (0-32767) to 0-7 index, with log-ish scaling + # Typical speech RMS is 500-5000, we cap display at ~8000 + level = min(rms, 8000) * 7 // 8000 + return _LEVEL_BARS[level] + def _get_tui_prompt_fragments(self): """Return the prompt_toolkit fragments for the current interactive state.""" symbol, state_suffix = self._get_tui_prompt_symbols() + if self._voice_recording: + bar = self._audio_level_bar() + return [("class:voice-recording", f"● {bar} {state_suffix}")] + if self._voice_processing: + return [("class:voice-processing", f"◉ {state_suffix}")] if self._sudo_state: return [("class:sudo-prompt", f"🔐 {state_suffix}")] if self._secret_state: @@ -3978,6 +4549,8 @@ class HermesCLI: return [("class:prompt-working", f"{self._command_spinner_frame()} {state_suffix}")] if self._agent_running: return [("class:prompt-working", f"⚕ {state_suffix}")] + if self._voice_mode: + return [("class:voice-prompt", f"🎤 {state_suffix}")] return [("class:prompt", symbol)] def _get_tui_prompt_text(self) -> str: @@ -4070,6 +4643,17 @@ class HermesCLI: self._attached_images: list[Path] = [] self._image_counter = 0 + # Voice mode state (protected by _voice_lock for cross-thread access) + self._voice_lock = 
threading.Lock() + self._voice_mode = False # Whether voice mode is enabled + self._voice_tts = False # Whether TTS output is enabled + self._voice_recorder = None # AudioRecorder instance (lazy init) + self._voice_recording = False # Whether currently recording + self._voice_processing = False # Whether STT is in progress + self._voice_continuous = False # Whether to auto-restart after agent responds + self._voice_tts_done = threading.Event() # Signals TTS playback finished + self._voice_tts_done.set() # Initially "done" (no TTS pending) + # Register callbacks so terminal_tool prompts route through our UI set_sudo_password_callback(self._sudo_password_callback) set_approval_callback(self._approval_callback) @@ -4254,6 +4838,7 @@ class HermesCLI: """Handle Ctrl+C - cancel interactive prompts, interrupt agent, or exit. Priority: + 0. Cancel active voice recording 1. Cancel active sudo/approval/clarify prompt 2. Interrupt the running agent (first press) 3. Force exit (second press within 2s, or when idle) @@ -4261,6 +4846,25 @@ class HermesCLI: import time as _time now = _time.time() + # Cancel active voice recording. + # Run cancel() in a background thread to prevent blocking the + # event loop if AudioRecorder._lock or CoreAudio takes time. 
+ _should_cancel_voice = False + _recorder_ref = None + with cli_ref._voice_lock: + if cli_ref._voice_recording and cli_ref._voice_recorder: + _recorder_ref = cli_ref._voice_recorder + cli_ref._voice_recording = False + cli_ref._voice_continuous = False + _should_cancel_voice = True + if _should_cancel_voice: + _cprint(f"\n{_DIM}Recording cancelled.{_RST}") + threading.Thread( + target=_recorder_ref.cancel, daemon=True + ).start() + event.app.invalidate() + return + # Cancel sudo prompt if self._sudo_state: self._sudo_state["response_queue"].put("") @@ -4321,6 +4925,75 @@ class HermesCLI: self._should_exit = True event.app.exit() + # Voice push-to-talk key: configurable via config.yaml (voice.record_key) + # Default: Ctrl+B (avoids conflict with Ctrl+R readline reverse-search) + # Config uses "ctrl+b" format; prompt_toolkit expects "c-b" format. + try: + from hermes_cli.config import load_config + _raw_key = load_config().get("voice", {}).get("record_key", "ctrl+b") + _voice_key = _raw_key.lower().replace("ctrl+", "c-").replace("alt+", "a-") + except Exception: + _voice_key = "c-b" + + @kb.add(_voice_key) + def handle_voice_record(event): + """Toggle voice recording when voice mode is active. + + IMPORTANT: This handler runs in prompt_toolkit's event-loop thread. + Any blocking call here (locks, sd.wait, disk I/O) freezes the + entire UI. All heavy work is dispatched to daemon threads. 
+ """ + if not cli_ref._voice_mode: + return + # Always allow STOPPING a recording (even when agent is running) + if cli_ref._voice_recording: + # Manual stop via push-to-talk key: stop continuous mode + with cli_ref._voice_lock: + cli_ref._voice_continuous = False + # Flag clearing is handled atomically inside _voice_stop_and_transcribe + event.app.invalidate() + threading.Thread( + target=cli_ref._voice_stop_and_transcribe, + daemon=True, + ).start() + else: + # Guard: don't START recording during agent run or interactive prompts + if cli_ref._agent_running: + return + if cli_ref._clarify_state or cli_ref._sudo_state or cli_ref._approval_state: + return + # Guard: don't start while a previous stop/transcribe cycle is + # still running — recorder.stop() holds AudioRecorder._lock and + # start() would block the event-loop thread waiting for it. + if cli_ref._voice_processing: + return + + # Interrupt TTS if playing, so user can start talking. + # stop_playback() is fast (just terminates a subprocess). + if not cli_ref._voice_tts_done.is_set(): + try: + from tools.voice_mode import stop_playback + stop_playback() + cli_ref._voice_tts_done.set() + except Exception: + pass + + with cli_ref._voice_lock: + cli_ref._voice_continuous = True + + # Dispatch to a daemon thread so play_beep(sd.wait), + # AudioRecorder.start(lock acquire), and config I/O + # never block the prompt_toolkit event loop. + def _start_recording(): + try: + cli_ref._voice_start_recording() + if hasattr(cli_ref, '_app') and cli_ref._app: + cli_ref._app.invalidate() + except Exception as e: + _cprint(f"\n{_DIM}Voice recording failed: {e}{_RST}") + + threading.Thread(target=_start_recording, daemon=True).start() + event.app.invalidate() from prompt_toolkit.keys import Keys @kb.add(Keys.BracketedPaste, eager=True) @@ -4460,6 +5133,10 @@ class HermesCLI: return Transformation(fragments=ti.fragments) def _get_placeholder(): + if cli_ref._voice_recording: + return "recording... 
Ctrl+B to stop, Ctrl+C to cancel" + if cli_ref._voice_processing: + return "transcribing..." if cli_ref._sudo_state: return "type password (hidden), Enter to skip" if cli_ref._secret_state: @@ -4476,6 +5153,8 @@ class HermesCLI: return f"{frame} {status}" if cli_ref._agent_running: return "type a message + Enter to interrupt, Ctrl+C to cancel" + if cli_ref._voice_mode: + return "type or Ctrl+B to record" return "" input_area.control.input_processors.append(_PlaceholderProcessor(_get_placeholder)) @@ -4813,6 +5492,24 @@ class HermesCLI: height=Condition(lambda: bool(cli_ref._attached_images)), ) + # Persistent voice mode status bar (visible only when voice mode is on) + def _get_voice_status(): + if cli_ref._voice_recording: + return [('class:voice-status-recording', ' ● REC Ctrl+B to stop ')] + if cli_ref._voice_processing: + return [('class:voice-status', ' ◉ Transcribing... ')] + tts = " | TTS on" if cli_ref._voice_tts else "" + cont = " | Continuous" if cli_ref._voice_continuous else "" + return [('class:voice-status', f' 🎤 Voice mode{tts}{cont} — Ctrl+B to record ')] + + voice_status_bar = ConditionalContainer( + Window( + FormattedTextControl(_get_voice_status), + height=1, + ), + filter=Condition(lambda: cli_ref._voice_mode), + ) + # Layout: interactive prompt widgets + ruled input at bottom. # The sudo, approval, and clarify widgets appear above the input when # the corresponding interactive prompt is active. 
@@ -4829,6 +5526,7 @@ class HermesCLI: image_bar, input_area, input_rule_bot, + voice_status_bar, CompletionsMenu(max_height=12, scroll_offset=1), ]) ) @@ -4869,6 +5567,12 @@ class HermesCLI: 'approval-cmd': '#AAAAAA italic', 'approval-choice': '#AAAAAA', 'approval-selected': '#FFD700 bold', + # Voice mode + 'voice-prompt': '#87CEEB', + 'voice-recording': '#FF4444 bold', + 'voice-processing': '#FFA500 italic', + 'voice-status': 'bg:#1a1a2e #87CEEB', + 'voice-status-recording': 'bg:#1a1a2e #FF4444 bold', } style = PTStyle.from_dict(self._build_tui_style_dict()) @@ -4961,13 +5665,29 @@ class HermesCLI: # Regular chat - run agent self._agent_running = True app.invalidate() # Refresh status line - + try: self.chat(user_input, images=submit_images or None) finally: self._agent_running = False self._spinner_text = "" app.invalidate() # Refresh status line + + # Continuous voice: auto-restart recording after agent responds. + # Dispatch to a daemon thread so play_beep (sd.wait) and + # AudioRecorder.start (lock acquire) never block process_loop — + # otherwise queued user input would stall silently. 
+ if self._voice_mode and self._voice_continuous and not self._voice_recording: + def _restart_recording(): + try: + if self._voice_tts: + self._voice_tts_done.wait(timeout=60) + time.sleep(0.3) + self._voice_start_recording() + app.invalidate() + except Exception as e: + _cprint(f"{_DIM}Voice auto-restart failed: {e}{_RST}") + threading.Thread(target=_restart_recording, daemon=True).start() except Exception as e: print(f"Error: {e}") @@ -4993,6 +5713,19 @@ class HermesCLI: self.agent.flush_memories(self.conversation_history) except Exception: pass + # Shut down voice recorder (release persistent audio stream) + if hasattr(self, '_voice_recorder') and self._voice_recorder: + try: + self._voice_recorder.shutdown() + except Exception: + pass + self._voice_recorder = None + # Clean up old temp voice recordings + try: + from tools.voice_mode import cleanup_temp_recordings + cleanup_temp_recordings() + except Exception: + pass # Unregister callbacks to avoid dangling references set_sudo_password_callback(None) set_approval_callback(None) diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index b0d70399b8..67a8323a7b 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -351,6 +351,8 @@ class BasePlatformAdapter(ABC): # Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt) self._active_sessions: Dict[str, asyncio.Event] = {} self._pending_messages: Dict[str, MessageEvent] = {} + # Chats where auto-TTS on voice input is disabled (set by /voice off) + self._auto_tts_disabled_chats: set = set() @property def name(self) -> str: @@ -537,6 +539,20 @@ class BasePlatformAdapter(ABC): text = f"{caption}\n{text}" return await self.send(chat_id=chat_id, content=text, reply_to=reply_to) + async def play_tts( + self, + chat_id: str, + audio_path: str, + **kwargs, + ) -> SendResult: + """ + Play auto-TTS audio for voice replies. + + Override in subclasses for invisible playback (e.g. Web UI). 
+ Default falls back to send_voice (shows audio player). + """ + return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs) + async def send_video( self, chat_id: str, @@ -724,7 +740,43 @@ class BasePlatformAdapter(ABC): if images: logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response)) - # Send the text portion first (if any remains after extractions) + # Auto-TTS: if voice message, generate audio FIRST (before sending text) + # Skipped when the chat has voice mode disabled (/voice off) + _tts_path = None + if (event.message_type == MessageType.VOICE + and text_content + and not media_files + and event.source.chat_id not in self._auto_tts_disabled_chats): + try: + from tools.tts_tool import text_to_speech_tool, check_tts_requirements + if check_tts_requirements(): + import json as _json + speech_text = re.sub(r'[*_`#\[\]()]', '', text_content)[:4000].strip() + if not speech_text: + raise ValueError("Empty text after markdown cleanup") + tts_result_str = await asyncio.to_thread( + text_to_speech_tool, text=speech_text + ) + tts_data = _json.loads(tts_result_str) + _tts_path = tts_data.get("file_path") + except Exception as tts_err: + logger.warning("[%s] Auto-TTS failed: %s", self.name, tts_err) + + # Play TTS audio before text (voice-first experience) + if _tts_path and Path(_tts_path).exists(): + try: + await self.play_tts( + chat_id=event.source.chat_id, + audio_path=_tts_path, + metadata=_thread_metadata, + ) + finally: + try: + os.remove(_tts_path) + except OSError: + pass + + # Send the text portion if text_content: logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id) result = await self.send( @@ -733,7 +785,7 @@ class BasePlatformAdapter(ABC): reply_to=event.message_id, metadata=_thread_metadata, ) - + # Log send failures (don't raise - user already saw tool progress) if not result.success: print(f"[{self.name}] Failed to send 
response: {result.error}") @@ -746,10 +798,10 @@ class BasePlatformAdapter(ABC): ) if not fallback_result.success: print(f"[{self.name}] Fallback send also failed: {fallback_result.error}") - + # Human-like pacing delay between text and media human_delay = self._get_human_delay() - + # Send extracted images as native attachments if images: logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images)) @@ -777,7 +829,7 @@ class BasePlatformAdapter(ABC): logger.error("[%s] Failed to send image: %s", self.name, img_result.error) except Exception as img_err: logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True) - + # Send extracted media files — route by file type _AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'} _VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.3gp'} diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 47760d2367..0d23407bf3 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -10,7 +10,13 @@ Uses discord.py library for: import asyncio import logging import os -from typing import Dict, List, Optional, Any +import struct +import subprocess +import tempfile +import threading +import time +from collections import defaultdict +from typing import Callable, Dict, List, Optional, Any logger = logging.getLogger(__name__) @@ -65,6 +71,299 @@ def check_discord_requirements() -> bool: return DISCORD_AVAILABLE +class VoiceReceiver: + """Captures and decodes voice audio from a Discord voice channel. + + Attaches to a VoiceClient's socket listener, decrypts RTP packets + (NaCl transport + DAVE E2EE), decodes Opus to PCM, and buffers + per-user audio. A polling loop detects silence and delivers + completed utterances via a callback. 
+ """ + + SILENCE_THRESHOLD = 1.5 # seconds of silence → end of utterance + MIN_SPEECH_DURATION = 0.5 # minimum seconds to process (skip noise) + SAMPLE_RATE = 48000 # Discord native rate + CHANNELS = 2 # Discord sends stereo + + def __init__(self, voice_client): + self._vc = voice_client + self._running = False + + # Decryption + self._secret_key: Optional[bytes] = None + self._dave_session = None + self._bot_ssrc: int = 0 + + # SSRC -> user_id mapping (populated from SPEAKING events) + self._ssrc_to_user: Dict[int, int] = {} + self._lock = threading.Lock() + + # Per-user audio buffers + self._buffers: Dict[int, bytearray] = defaultdict(bytearray) + self._last_packet_time: Dict[int, float] = {} + + # Opus decoder per SSRC (each user needs own decoder state) + self._decoders: Dict[int, object] = {} + + # Pause flag: don't capture while bot is playing TTS + self._paused = False + + # Debug logging counter (instance-level to avoid cross-instance races) + self._packet_debug_count = 0 + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def start(self): + """Start listening for voice packets.""" + conn = self._vc._connection + self._secret_key = bytes(conn.secret_key) + self._dave_session = conn.dave_session + self._bot_ssrc = conn.ssrc + + self._install_speaking_hook(conn) + conn.add_socket_listener(self._on_packet) + self._running = True + logger.info("VoiceReceiver started (bot_ssrc=%d)", self._bot_ssrc) + + def stop(self): + """Stop listening and clean up.""" + self._running = False + try: + self._vc._connection.remove_socket_listener(self._on_packet) + except Exception: + pass + with self._lock: + self._buffers.clear() + self._last_packet_time.clear() + self._decoders.clear() + self._ssrc_to_user.clear() + logger.info("VoiceReceiver stopped") + + def pause(self): + self._paused = True + + def resume(self): + self._paused = False + + # 
------------------------------------------------------------------ + # SSRC -> user_id mapping via SPEAKING opcode hook + # ------------------------------------------------------------------ + + def map_ssrc(self, ssrc: int, user_id: int): + with self._lock: + self._ssrc_to_user[ssrc] = user_id + + def _install_speaking_hook(self, conn): + """Wrap the voice websocket hook to capture SPEAKING events (op 5). + + VoiceConnectionState stores the hook as ``conn.hook`` (public attr). + It is passed to DiscordVoiceWebSocket on each (re)connect, so we + must wrap it on the VoiceConnectionState level AND on the current + live websocket instance. + """ + original_hook = conn.hook + receiver_self = self + + async def wrapped_hook(ws, msg): + if isinstance(msg, dict) and msg.get("op") == 5: + data = msg.get("d", {}) + ssrc = data.get("ssrc") + user_id = data.get("user_id") + if ssrc and user_id: + logger.info("SPEAKING event: ssrc=%d -> user=%s", ssrc, user_id) + receiver_self.map_ssrc(int(ssrc), int(user_id)) + if original_hook: + await original_hook(ws, msg) + + # Set on connection state (for future reconnects) + conn.hook = wrapped_hook + # Set on the current live websocket (for immediate effect) + try: + from discord.utils import MISSING + if hasattr(conn, 'ws') and conn.ws is not MISSING: + conn.ws._hook = wrapped_hook + logger.info("Speaking hook installed on live websocket") + except Exception as e: + logger.warning("Could not install hook on live ws: %s", e) + + # ------------------------------------------------------------------ + # Packet handler (called from SocketReader thread) + # ------------------------------------------------------------------ + + def _on_packet(self, data: bytes): + if not self._running or self._paused: + return + + # Log first few raw packets for debugging + self._packet_debug_count += 1 + if self._packet_debug_count <= 5: + logger.debug( + "Raw UDP packet: len=%d, first_bytes=%s", + len(data), data[:4].hex() if len(data) >= 4 else "short", + 
) + + if len(data) < 16: + return + + # RTP version check: top 2 bits must be 10 (version 2). + # Lower bits may vary (padding, extension, CSRC count). + # Payload type (byte 1 lower 7 bits) = 0x78 (120) for voice. + if (data[0] >> 6) != 2 or (data[1] & 0x7F) != 0x78: + if self._packet_debug_count <= 5: + logger.debug("Skipped non-RTP: byte0=0x%02x byte1=0x%02x", data[0], data[1]) + return + + first_byte = data[0] + _, _, seq, timestamp, ssrc = struct.unpack_from(">BBHII", data, 0) + + # Skip bot's own audio + if ssrc == self._bot_ssrc: + return + + # Calculate dynamic RTP header size (RFC 9335 / rtpsize mode) + cc = first_byte & 0x0F # CSRC count + has_extension = bool(first_byte & 0x10) # extension bit + header_size = 12 + (4 * cc) + (4 if has_extension else 0) + + if len(data) < header_size + 4: # need at least header + nonce + return + + # Read extension length from preamble (for skipping after decrypt) + ext_data_len = 0 + if has_extension: + ext_preamble_offset = 12 + (4 * cc) + ext_words = struct.unpack_from(">H", data, ext_preamble_offset + 2)[0] + ext_data_len = ext_words * 4 + + if self._packet_debug_count <= 10: + with self._lock: + known_user = self._ssrc_to_user.get(ssrc, "unknown") + logger.debug( + "RTP packet: ssrc=%d, seq=%d, user=%s, hdr=%d, ext_data=%d", + ssrc, seq, known_user, header_size, ext_data_len, + ) + + header = bytes(data[:header_size]) + payload_with_nonce = data[header_size:] + + # --- NaCl transport decrypt (aead_xchacha20_poly1305_rtpsize) --- + if len(payload_with_nonce) < 4: + return + nonce = bytearray(24) + nonce[:4] = payload_with_nonce[-4:] + encrypted = bytes(payload_with_nonce[:-4]) + + try: + import nacl.secret # noqa: delayed import – only in voice path + box = nacl.secret.Aead(self._secret_key) + decrypted = box.decrypt(encrypted, header, bytes(nonce)) + except Exception as e: + if self._packet_debug_count <= 10: + logger.warning("NaCl decrypt failed: %s (hdr=%d, enc=%d)", e, header_size, len(encrypted)) + return + + # 
Skip encrypted extension data to get the actual opus payload + if ext_data_len and len(decrypted) > ext_data_len: + decrypted = decrypted[ext_data_len:] + + # --- DAVE E2EE decrypt --- + if self._dave_session: + with self._lock: + user_id = self._ssrc_to_user.get(ssrc, 0) + if user_id == 0: + if self._packet_debug_count <= 10: + logger.warning("DAVE skip: unknown user for ssrc=%d", ssrc) + return # unknown user, can't DAVE-decrypt + try: + import davey + decrypted = self._dave_session.decrypt( + user_id, davey.MediaType.audio, decrypted + ) + except Exception as e: + if self._packet_debug_count <= 10: + logger.warning("DAVE decrypt failed for ssrc=%d: %s", ssrc, e) + return + + # --- Opus decode -> PCM --- + try: + if ssrc not in self._decoders: + self._decoders[ssrc] = discord.opus.Decoder() + pcm = self._decoders[ssrc].decode(decrypted) + with self._lock: + self._buffers[ssrc].extend(pcm) + self._last_packet_time[ssrc] = time.monotonic() + except Exception as e: + logger.debug("Opus decode error for SSRC %s: %s", ssrc, e) + return + + # ------------------------------------------------------------------ + # Silence detection + # ------------------------------------------------------------------ + + def check_silence(self) -> list: + """Return list of (user_id, pcm_bytes) for completed utterances.""" + now = time.monotonic() + completed = [] + + with self._lock: + ssrc_user_map = dict(self._ssrc_to_user) + ssrc_list = list(self._buffers.keys()) + + for ssrc in ssrc_list: + last_time = self._last_packet_time.get(ssrc, now) + silence_duration = now - last_time + buf = self._buffers[ssrc] + # 48kHz, 16-bit, stereo = 192000 bytes/sec + buf_duration = len(buf) / (self.SAMPLE_RATE * self.CHANNELS * 2) + + if silence_duration >= self.SILENCE_THRESHOLD and buf_duration >= self.MIN_SPEECH_DURATION: + user_id = ssrc_user_map.get(ssrc, 0) + if user_id: + completed.append((user_id, bytes(buf))) + self._buffers[ssrc] = bytearray() + self._last_packet_time.pop(ssrc, None) + elif 
silence_duration >= self.SILENCE_THRESHOLD * 2: + # Stale buffer with no valid user — discard + self._buffers.pop(ssrc, None) + self._last_packet_time.pop(ssrc, None) + + return completed + + # ------------------------------------------------------------------ + # PCM -> WAV conversion (for Whisper STT) + # ------------------------------------------------------------------ + + @staticmethod + def pcm_to_wav(pcm_data: bytes, output_path: str, + src_rate: int = 48000, src_channels: int = 2): + """Convert raw PCM to 16kHz mono WAV via ffmpeg.""" + with tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as f: + f.write(pcm_data) + pcm_path = f.name + try: + subprocess.run( + [ + "ffmpeg", "-y", "-loglevel", "error", + "-f", "s16le", + "-ar", str(src_rate), + "-ac", str(src_channels), + "-i", pcm_path, + "-ar", "16000", + "-ac", "1", + output_path, + ], + check=True, + timeout=10, + ) + finally: + try: + os.unlink(pcm_path) + except OSError: + pass + + class DiscordAdapter(BasePlatformAdapter): """ Discord bot adapter. 
@@ -82,17 +381,54 @@ class DiscordAdapter(BasePlatformAdapter): # Discord message limits MAX_MESSAGE_LENGTH = 2000 + # Auto-disconnect from voice channel after this many seconds of inactivity + VOICE_TIMEOUT = 300 + def __init__(self, config: PlatformConfig): super().__init__(config, Platform.DISCORD) self._client: Optional[commands.Bot] = None self._ready_event = asyncio.Event() self._allowed_user_ids: set = set() # For button approval authorization + # Voice channel state (per-guild) + self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient + self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id + self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task + # Phase 2: voice listening + self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver + self._voice_listen_tasks: Dict[int, asyncio.Task] = {} # guild_id -> listen loop + self._voice_input_callback: Optional[Callable] = None # set by run.py + self._on_voice_disconnect: Optional[Callable] = None # set by run.py async def connect(self) -> bool: """Connect to Discord and start receiving events.""" if not DISCORD_AVAILABLE: logger.error("[%s] discord.py not installed. Run: pip install discord.py", self.name) return False + + # Load opus codec for voice channel support + if not discord.opus.is_loaded(): + import ctypes.util + opus_path = ctypes.util.find_library("opus") + # ctypes.util.find_library fails on macOS with Homebrew-installed libs, + # so fall back to known Homebrew paths if needed. 
+ if not opus_path: + import sys + _homebrew_paths = ( + "/opt/homebrew/lib/libopus.dylib", # Apple Silicon + "/usr/local/lib/libopus.dylib", # Intel Mac + ) + if sys.platform == "darwin": + for _hp in _homebrew_paths: + if os.path.isfile(_hp): + opus_path = _hp + break + if opus_path: + try: + discord.opus.load_opus(opus_path) + except Exception: + logger.warning("Opus codec found at %s but failed to load", opus_path) + if not discord.opus.is_loaded(): + logger.warning("Opus codec not found — voice channel playback disabled") if not self.config.token: logger.error("[%s] No bot token configured", self.name) @@ -105,6 +441,7 @@ class DiscordAdapter(BasePlatformAdapter): intents.dm_messages = True intents.guild_messages = True intents.members = True + intents.voice_states = True # Create bot self._client = commands.Bot( @@ -158,7 +495,40 @@ class DiscordAdapter(BasePlatformAdapter): # "all" falls through to handle_message await self._handle_message(message) - + + @self._client.event + async def on_voice_state_update(member, before, after): + """Track voice channel join/leave events.""" + # Only track channels where the bot is connected + bot_guild_ids = set(adapter_self._voice_clients.keys()) + if not bot_guild_ids: + return + guild_id = member.guild.id + if guild_id not in bot_guild_ids: + return + # Ignore the bot itself + if member == adapter_self._client.user: + return + + joined = before.channel is None and after.channel is not None + left = before.channel is not None and after.channel is None + switched = ( + before.channel is not None + and after.channel is not None + and before.channel != after.channel + ) + + if joined or left or switched: + logger.info( + "Voice state: %s (%d) %s (guild %d)", + member.display_name, + member.id, + "joined " + after.channel.name if joined + else "left " + before.channel.name if left + else f"moved {before.channel.name} -> {after.channel.name}", + guild_id, + ) + # Register slash commands self._register_slash_commands() @@ 
-180,12 +550,19 @@ class DiscordAdapter(BasePlatformAdapter): async def disconnect(self) -> None: """Disconnect from Discord.""" + # Clean up all active voice connections before closing the client + for guild_id in list(self._voice_clients.keys()): + try: + await self.leave_voice_channel(guild_id) + except Exception as e: # pragma: no cover - defensive logging + logger.debug("[%s] Error leaving voice channel %s: %s", self.name, guild_id, e) + if self._client: try: await self._client.close() except Exception as e: # pragma: no cover - defensive logging logger.warning("[%s] Error during disconnect: %s", self.name, e, exc_info=True) - + self._running = False self._client = None self._ready_event.clear() @@ -287,6 +664,23 @@ class DiscordAdapter(BasePlatformAdapter): msg = await channel.send(content=caption if caption else None, file=file) return SendResult(success=True, message_id=str(msg.id)) + async def play_tts( + self, + chat_id: str, + audio_path: str, + **kwargs, + ) -> SendResult: + """Play auto-TTS audio. + + When the bot is in a voice channel for this chat's guild, skip the + file attachment — the gateway runner plays audio in the VC instead. 
+ """ + for gid, text_ch_id in self._voice_text_channels.items(): + if str(text_ch_id) == str(chat_id) and self.is_in_voice_channel(gid): + logger.debug("[%s] Skipping play_tts for %s — VC playback handled by runner", self.name, chat_id) + return SendResult(success=True) + return await self.send_voice(chat_id=chat_id, audio_path=audio_path, **kwargs) + async def send_voice( self, chat_id: str, @@ -294,16 +688,356 @@ class DiscordAdapter(BasePlatformAdapter): caption: Optional[str] = None, reply_to: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + **kwargs, ) -> SendResult: """Send audio as a Discord file attachment.""" try: - return await self._send_file_attachment(chat_id, audio_path, caption) - except FileNotFoundError: - return SendResult(success=False, error=f"Audio file not found: {audio_path}") + import io + + channel = self._client.get_channel(int(chat_id)) + if not channel: + channel = await self._client.fetch_channel(int(chat_id)) + if not channel: + return SendResult(success=False, error=f"Channel {chat_id} not found") + + if not os.path.exists(audio_path): + return SendResult(success=False, error=f"Audio file not found: {audio_path}") + + filename = os.path.basename(audio_path) + + with open(audio_path, "rb") as f: + file_data = f.read() + + # Try sending as a native voice message via raw API (flags=8192). 
+ try: + import base64 + + duration_secs = 5.0 + try: + from mutagen.oggopus import OggOpus + info = OggOpus(audio_path) + duration_secs = info.info.length + except Exception: + duration_secs = max(1.0, len(file_data) / 2000.0) + + waveform_bytes = bytes([128] * 256) + waveform_b64 = base64.b64encode(waveform_bytes).decode() + + import json as _json + payload = _json.dumps({ + "flags": 8192, + "attachments": [{ + "id": "0", + "filename": "voice-message.ogg", + "duration_secs": round(duration_secs, 2), + "waveform": waveform_b64, + }], + }) + form = [ + {"name": "payload_json", "value": payload}, + { + "name": "files[0]", + "value": file_data, + "filename": "voice-message.ogg", + "content_type": "audio/ogg", + }, + ] + msg_data = await self._client.http.request( + discord.http.Route("POST", "/channels/{channel_id}/messages", channel_id=channel.id), + form=form, + ) + return SendResult(success=True, message_id=str(msg_data["id"])) + except Exception as voice_err: + logger.debug("Voice message flag failed, falling back to file: %s", voice_err) + file = discord.File(io.BytesIO(file_data), filename=filename) + msg = await channel.send(file=file) + return SendResult(success=True, message_id=str(msg.id)) except Exception as e: # pragma: no cover - defensive logging logger.error("[%s] Failed to send audio, falling back to base adapter: %s", self.name, e, exc_info=True) return await super().send_voice(chat_id, audio_path, caption, reply_to, metadata=metadata) + # ------------------------------------------------------------------ + # Voice channel methods (join / leave / play) + # ------------------------------------------------------------------ + + async def join_voice_channel(self, channel) -> bool: + """Join a Discord voice channel. Returns True on success.""" + if not self._client or not DISCORD_AVAILABLE: + return False + guild_id = channel.guild.id + + # Already connected in this guild? 
+ existing = self._voice_clients.get(guild_id) + if existing and existing.is_connected(): + if existing.channel.id == channel.id: + self._reset_voice_timeout(guild_id) + return True + await existing.move_to(channel) + self._reset_voice_timeout(guild_id) + return True + + vc = await channel.connect() + self._voice_clients[guild_id] = vc + self._reset_voice_timeout(guild_id) + + # Start voice receiver (Phase 2: listen to users) + try: + receiver = VoiceReceiver(vc) + receiver.start() + self._voice_receivers[guild_id] = receiver + self._voice_listen_tasks[guild_id] = asyncio.ensure_future( + self._voice_listen_loop(guild_id) + ) + except Exception as e: + logger.warning("Voice receiver failed to start: %s", e) + + return True + + async def leave_voice_channel(self, guild_id: int) -> None: + """Disconnect from the voice channel in a guild.""" + # Stop voice receiver first + receiver = self._voice_receivers.pop(guild_id, None) + if receiver: + receiver.stop() + listen_task = self._voice_listen_tasks.pop(guild_id, None) + if listen_task: + listen_task.cancel() + + vc = self._voice_clients.pop(guild_id, None) + if vc and vc.is_connected(): + await vc.disconnect() + task = self._voice_timeout_tasks.pop(guild_id, None) + if task: + task.cancel() + self._voice_text_channels.pop(guild_id, None) + + # Maximum seconds to wait for voice playback before giving up + PLAYBACK_TIMEOUT = 120 + + async def play_in_voice_channel(self, guild_id: int, audio_path: str) -> bool: + """Play an audio file in the connected voice channel.""" + vc = self._voice_clients.get(guild_id) + if not vc or not vc.is_connected(): + return False + + # Pause voice receiver while playing (echo prevention) + receiver = self._voice_receivers.get(guild_id) + if receiver: + receiver.pause() + + try: + # Wait for current playback to finish (with timeout) + wait_start = time.monotonic() + while vc.is_playing(): + if time.monotonic() - wait_start > self.PLAYBACK_TIMEOUT: + logger.warning("Timed out waiting for 
previous playback to finish") + vc.stop() + break + await asyncio.sleep(0.1) + + done = asyncio.Event() + loop = asyncio.get_running_loop() + + def _after(error): + if error: + logger.error("Voice playback error: %s", error) + loop.call_soon_threadsafe(done.set) + + source = discord.FFmpegPCMAudio(audio_path) + source = discord.PCMVolumeTransformer(source, volume=1.0) + vc.play(source, after=_after) + try: + await asyncio.wait_for(done.wait(), timeout=self.PLAYBACK_TIMEOUT) + except asyncio.TimeoutError: + logger.warning("Voice playback timed out after %ds", self.PLAYBACK_TIMEOUT) + vc.stop() + self._reset_voice_timeout(guild_id) + return True + finally: + if receiver: + receiver.resume() + + async def get_user_voice_channel(self, guild_id: int, user_id: str): + """Return the voice channel the user is currently in, or None.""" + if not self._client: + return None + guild = self._client.get_guild(guild_id) + if not guild: + return None + member = guild.get_member(int(user_id)) + if not member or not member.voice: + return None + return member.voice.channel + + def _reset_voice_timeout(self, guild_id: int) -> None: + """Reset the auto-disconnect inactivity timer.""" + task = self._voice_timeout_tasks.pop(guild_id, None) + if task: + task.cancel() + self._voice_timeout_tasks[guild_id] = asyncio.ensure_future( + self._voice_timeout_handler(guild_id) + ) + + async def _voice_timeout_handler(self, guild_id: int) -> None: + """Auto-disconnect after VOICE_TIMEOUT seconds of inactivity.""" + try: + await asyncio.sleep(self.VOICE_TIMEOUT) + except asyncio.CancelledError: + return + text_ch_id = self._voice_text_channels.get(guild_id) + await self.leave_voice_channel(guild_id) + # Notify the runner so it can clean up voice_mode state + if self._on_voice_disconnect and text_ch_id: + try: + self._on_voice_disconnect(str(text_ch_id)) + except Exception: + pass + if text_ch_id and self._client: + ch = self._client.get_channel(text_ch_id) + if ch: + try: + await ch.send("Left 
voice channel (inactivity timeout).") + except Exception: + pass + + def is_in_voice_channel(self, guild_id: int) -> bool: + """Check if the bot is connected to a voice channel in this guild.""" + vc = self._voice_clients.get(guild_id) + return vc is not None and vc.is_connected() + + def get_voice_channel_info(self, guild_id: int) -> Optional[Dict[str, Any]]: + """Return voice channel awareness info for the given guild. + + Returns None if the bot is not in a voice channel. Otherwise + returns a dict with channel name, member list, count, and + currently-speaking user IDs (from SSRC mapping). + """ + vc = self._voice_clients.get(guild_id) + if not vc or not vc.is_connected(): + return None + + channel = vc.channel + if not channel: + return None + + # Members currently in the voice channel (includes bot) + members_info = [] + bot_user = self._client.user if self._client else None + for m in channel.members: + if bot_user and m.id == bot_user.id: + continue # skip the bot itself + members_info.append({ + "user_id": m.id, + "display_name": m.display_name, + "is_bot": m.bot, + }) + + # Currently speaking users (from SSRC mapping + active buffers) + speaking_user_ids: set = set() + receiver = self._voice_receivers.get(guild_id) + if receiver: + import time as _time + now = _time.monotonic() + with receiver._lock: + for ssrc, last_t in receiver._last_packet_time.items(): + # Consider "speaking" if audio received within last 2 seconds + if now - last_t < 2.0: + uid = receiver._ssrc_to_user.get(ssrc) + if uid: + speaking_user_ids.add(uid) + + # Tag speaking status on members + for info in members_info: + info["is_speaking"] = info["user_id"] in speaking_user_ids + + return { + "channel_name": channel.name, + "member_count": len(members_info), + "members": members_info, + "speaking_count": len(speaking_user_ids), + } + + def get_voice_channel_context(self, guild_id: int) -> str: + """Return a human-readable voice channel context string. 
+ + Suitable for injection into the system/ephemeral prompt so the + agent is always aware of voice channel state. + """ + info = self.get_voice_channel_info(guild_id) + if not info: + return "" + + parts = [f"[Voice channel: #{info['channel_name']} — {info['member_count']} participant(s)]"] + for m in info["members"]: + status = " (speaking)" if m["is_speaking"] else "" + parts.append(f" - {m['display_name']}{status}") + + return "\n".join(parts) + + # ------------------------------------------------------------------ + # Voice listening (Phase 2) + # ------------------------------------------------------------------ + + async def _voice_listen_loop(self, guild_id: int): + """Periodically check for completed utterances and process them.""" + receiver = self._voice_receivers.get(guild_id) + if not receiver: + return + try: + while receiver._running: + await asyncio.sleep(0.2) + completed = receiver.check_silence() + for user_id, pcm_data in completed: + if not self._is_allowed_user(str(user_id)): + continue + await self._process_voice_input(guild_id, user_id, pcm_data) + except asyncio.CancelledError: + pass + except Exception as e: + logger.error("Voice listen loop error: %s", e, exc_info=True) + + async def _process_voice_input(self, guild_id: int, user_id: int, pcm_data: bytes): + """Convert PCM -> WAV -> STT -> callback.""" + from tools.voice_mode import is_whisper_hallucination + + tmp_f = tempfile.NamedTemporaryFile(suffix=".wav", prefix="vc_listen_", delete=False) + wav_path = tmp_f.name + tmp_f.close() + try: + await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path) + + from tools.transcription_tools import transcribe_audio, get_stt_model_from_config + stt_model = get_stt_model_from_config() + result = await asyncio.to_thread(transcribe_audio, wav_path, model=stt_model) + + if not result.get("success"): + return + transcript = result.get("transcript", "").strip() + if not transcript or is_whisper_hallucination(transcript): + return + + 
logger.info("Voice input from user %d: %s", user_id, transcript[:100]) + + if self._voice_input_callback: + await self._voice_input_callback( + guild_id=guild_id, + user_id=user_id, + transcript=transcript, + ) + except Exception as e: + logger.warning("Voice input processing failed: %s", e, exc_info=True) + finally: + try: + os.unlink(wav_path) + except OSError: + pass + + def _is_allowed_user(self, user_id: str) -> bool: + """Check if user is in DISCORD_ALLOWED_USERS.""" + if not self._allowed_user_ids: + return True + return user_id in self._allowed_user_ids + async def send_image_file( self, chat_id: str, @@ -627,6 +1361,25 @@ class DiscordAdapter(BasePlatformAdapter): async def slash_reload_mcp(interaction: discord.Interaction): await self._run_simple_slash(interaction, "/reload-mcp") + @tree.command(name="voice", description="Toggle voice reply mode") + @discord.app_commands.describe(mode="Voice mode: on, off, tts, channel, leave, or status") + @discord.app_commands.choices(mode=[ + discord.app_commands.Choice(name="channel — join your voice channel", value="channel"), + discord.app_commands.Choice(name="leave — leave voice channel", value="leave"), + discord.app_commands.Choice(name="on — voice reply to voice messages", value="on"), + discord.app_commands.Choice(name="tts — voice reply to all messages", value="tts"), + discord.app_commands.Choice(name="off — text only", value="off"), + discord.app_commands.Choice(name="status — show current mode", value="status"), + ]) + async def slash_voice(interaction: discord.Interaction, mode: str = ""): + await interaction.response.defer(ephemeral=True) + event = self._build_slash_event(interaction, f"/voice {mode}".strip()) + await self.handle_message(event) + try: + await interaction.followup.send("Done~", ephemeral=True) + except Exception as e: + logger.debug("Discord followup failed: %s", e) + @tree.command(name="update", description="Update Hermes Agent to the latest version") async def slash_update(interaction: 
discord.Interaction): await self._run_simple_slash(interaction, "/update", "Update initiated~") diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index d75685bfbc..cd9dd4d2bf 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -506,6 +506,7 @@ class SlackAdapter(BasePlatformAdapter): caption: Optional[str] = None, reply_to: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + **kwargs, ) -> SendResult: """Send an audio file to Slack.""" try: diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 06f423c661..df44733e3d 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -150,7 +150,10 @@ class TelegramAdapter(BasePlatformAdapter): # Start polling in background await self._app.initialize() await self._app.start() - await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES) + await self._app.updater.start_polling( + allowed_updates=Update.ALL_TYPES, + drop_pending_updates=True, + ) # Register bot commands so Telegram shows a hint menu when users type / try: @@ -174,6 +177,7 @@ class TelegramAdapter(BasePlatformAdapter): BotCommand("insights", "Show usage insights and analytics"), BotCommand("update", "Update Hermes to the latest version"), BotCommand("reload_mcp", "Reload MCP servers from config"), + BotCommand("voice", "Toggle voice reply mode"), BotCommand("help", "Show available commands"), ]) except Exception as e: @@ -307,6 +311,7 @@ class TelegramAdapter(BasePlatformAdapter): caption: Optional[str] = None, reply_to: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + **kwargs, ) -> SendResult: """Send audio as a native Telegram voice message or audio file.""" if not self._bot: diff --git a/gateway/run.py b/gateway/run.py index 221f8f9163..6795610a88 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -14,13 +14,16 @@ Usage: """ import asyncio +import json import logging import os import re import shlex import sys import 
signal +import tempfile import threading +import time from logging.handlers import RotatingFileHandler from pathlib import Path from datetime import datetime @@ -281,6 +284,9 @@ class GatewayRunner: from gateway.hooks import HookRegistry self.hooks = HookRegistry() + # Per-chat voice reply mode: "off" | "voice_only" | "all" + self._voice_mode: Dict[str, str] = self._load_voice_modes() + def _get_or_create_gateway_honcho(self, session_key: str): """Return a persistent Honcho manager/config pair for this gateway session.""" if not hasattr(self, "_honcho_managers"): @@ -336,6 +342,57 @@ class GatewayRunner: for session_key in list(managers.keys()): self._shutdown_gateway_honcho(session_key) + # -- Voice mode persistence ------------------------------------------ + + _VOICE_MODE_PATH = _hermes_home / "gateway_voice_mode.json" + + def _load_voice_modes(self) -> Dict[str, str]: + try: + data = json.loads(self._VOICE_MODE_PATH.read_text()) + except (FileNotFoundError, json.JSONDecodeError, OSError): + return {} + + if not isinstance(data, dict): + return {} + + valid_modes = {"off", "voice_only", "all"} + return { + str(chat_id): mode + for chat_id, mode in data.items() + if mode in valid_modes + } + + def _save_voice_modes(self) -> None: + try: + self._VOICE_MODE_PATH.parent.mkdir(parents=True, exist_ok=True) + self._VOICE_MODE_PATH.write_text( + json.dumps(self._voice_mode, indent=2) + ) + except OSError as e: + logger.warning("Failed to save voice modes: %s", e) + + def _set_adapter_auto_tts_disabled(self, adapter, chat_id: str, disabled: bool) -> None: + """Update an adapter's in-memory auto-TTS suppression set if present.""" + disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None) + if not isinstance(disabled_chats, set): + return + if disabled: + disabled_chats.add(chat_id) + else: + disabled_chats.discard(chat_id) + + def _sync_voice_mode_state_to_adapter(self, adapter) -> None: + """Restore persisted /voice off state into a live platform adapter.""" + 
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None) + if not isinstance(disabled_chats, set): + return + disabled_chats.clear() + disabled_chats.update( + chat_id for chat_id, mode in self._voice_mode.items() if mode == "off" + ) + + # ----------------------------------------------------------------- + def _flush_memories_for_session(self, old_session_id: str): """Prompt the agent to save memories/skills before context is lost. @@ -639,6 +696,7 @@ class GatewayRunner: success = await adapter.connect() if success: self.adapters[platform] = adapter + self._sync_voice_mode_state_to_adapter(adapter) connected_count += 1 logger.info("✓ %s connected", platform.value) else: @@ -737,7 +795,7 @@ class GatewayRunner: logger.info("Stopping gateway...") self._running = False - for platform, adapter in self.adapters.items(): + for platform, adapter in list(self.adapters.items()): try: await adapter.disconnect() logger.info("✓ %s disconnected", platform.value) @@ -897,7 +955,7 @@ class GatewayRunner: 7. 
Return response """ source = event.source - + # Check if user is authorized if not self._is_user_authorized(source): logger.warning("Unauthorized user: %s (%s) on %s", source.user_id, source.user_name, source.platform.value) @@ -949,7 +1007,7 @@ class GatewayRunner: "personality", "retry", "undo", "sethome", "set-home", "compress", "usage", "insights", "reload-mcp", "reload_mcp", "update", "title", "resume", "provider", "rollback", - "background", "reasoning"} + "background", "reasoning", "voice"} if command and command in _known_commands: await self.hooks.emit(f"command:{command}", { "platform": source.platform.value if source.platform else "", @@ -1020,7 +1078,10 @@ class GatewayRunner: if command == "reasoning": return await self._handle_reasoning_command(event) - + + if command == "voice": + return await self._handle_voice_command(event) + # User-defined quick commands (bypass agent loop, no LLM call) if command: if isinstance(self.config, dict): @@ -1377,6 +1438,19 @@ class GatewayRunner: f"or ignore to skip." ) + # ----------------------------------------------------------------- + # Voice channel awareness — inject current voice channel state + # into context so the agent knows who is in the channel and who + # is speaking, without needing a separate tool call. 
+ # ----------------------------------------------------------------- + if source.platform == Platform.DISCORD: + adapter = self.adapters.get(Platform.DISCORD) + guild_id = self._get_guild_id(event) + if guild_id and adapter and hasattr(adapter, "get_voice_channel_context"): + vc_context = adapter.get_voice_channel_context(guild_id) + if vc_context: + context_prompt += f"\n\n{vc_context}" + # ----------------------------------------------------------------- # Auto-analyze images sent by the user # @@ -1583,7 +1657,11 @@ class GatewayRunner: session_entry.session_key, last_prompt_tokens=agent_result.get("last_prompt_tokens", 0), ) - + + # Auto voice reply: send TTS audio before the text response + if self._should_send_voice_reply(event, response, agent_messages): + await self._send_voice_reply(event, response) + return response except Exception as e: @@ -1692,6 +1770,7 @@ class GatewayRunner: "`/reasoning [level|show|hide]` — Set reasoning effort or toggle display", "`/rollback [number]` — List or restore filesystem checkpoints", "`/background ` — Run a prompt in a separate background session", + "`/voice [on|off|tts|status]` — Toggle voice reply mode", "`/reload-mcp` — Reload MCP servers from config", "`/update` — Update Hermes Agent to the latest version", "`/help` — Show this message", @@ -2067,6 +2146,337 @@ class GatewayRunner: f"Cron jobs and cross-platform messages will be delivered here." 
) + @staticmethod + def _get_guild_id(event: MessageEvent) -> Optional[int]: + """Extract Discord guild_id from the raw message object.""" + raw = getattr(event, "raw_message", None) + if raw is None: + return None + # Slash command interaction + if hasattr(raw, "guild_id") and raw.guild_id: + return int(raw.guild_id) + # Regular message + if hasattr(raw, "guild") and raw.guild: + return raw.guild.id + return None + + async def _handle_voice_command(self, event: MessageEvent) -> str: + """Handle /voice [on|off|tts|channel|leave|status] command.""" + args = event.get_command_args().strip().lower() + chat_id = event.source.chat_id + + adapter = self.adapters.get(event.source.platform) + + if args in ("on", "enable"): + self._voice_mode[chat_id] = "voice_only" + self._save_voice_modes() + if adapter: + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) + return ( + "Voice mode enabled.\n" + "I'll reply with voice when you send voice messages.\n" + "Use /voice tts to get voice replies for all messages." + ) + elif args in ("off", "disable"): + self._voice_mode[chat_id] = "off" + self._save_voice_modes() + if adapter: + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True) + return "Voice mode disabled. Text-only replies." + elif args == "tts": + self._voice_mode[chat_id] = "all" + self._save_voice_modes() + if adapter: + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) + return ( + "Auto-TTS enabled.\n" + "All replies will include a voice message." 
+ ) + elif args in ("channel", "join"): + return await self._handle_voice_channel_join(event) + elif args == "leave": + return await self._handle_voice_channel_leave(event) + elif args == "status": + mode = self._voice_mode.get(chat_id, "off") + labels = { + "off": "Off (text only)", + "voice_only": "On (voice reply to voice messages)", + "all": "TTS (voice reply to all messages)", + } + # Append voice channel info if connected + adapter = self.adapters.get(event.source.platform) + guild_id = self._get_guild_id(event) + if guild_id and hasattr(adapter, "get_voice_channel_info"): + info = adapter.get_voice_channel_info(guild_id) + if info: + lines = [ + f"Voice mode: {labels.get(mode, mode)}", + f"Voice channel: #{info['channel_name']}", + f"Participants: {info['member_count']}", + ] + for m in info["members"]: + status = " (speaking)" if m.get("is_speaking") else "" + lines.append(f" - {m['display_name']}{status}") + return "\n".join(lines) + return f"Voice mode: {labels.get(mode, mode)}" + else: + # Toggle: off → on, on/all → off + current = self._voice_mode.get(chat_id, "off") + if current == "off": + self._voice_mode[chat_id] = "voice_only" + self._save_voice_modes() + if adapter: + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False) + return "Voice mode enabled." + else: + self._voice_mode[chat_id] = "off" + self._save_voice_modes() + if adapter: + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True) + return "Voice mode disabled." + + async def _handle_voice_channel_join(self, event: MessageEvent) -> str: + """Join the user's current Discord voice channel.""" + adapter = self.adapters.get(event.source.platform) + if not hasattr(adapter, "join_voice_channel"): + return "Voice channels are not supported on this platform." + + guild_id = self._get_guild_id(event) + if not guild_id: + return "This command only works in a Discord server." 
+ + voice_channel = await adapter.get_user_voice_channel( + guild_id, event.source.user_id + ) + if not voice_channel: + return "You need to be in a voice channel first." + + # Wire callbacks BEFORE join so voice input arriving immediately + # after connection is not lost. + if hasattr(adapter, "_voice_input_callback"): + adapter._voice_input_callback = self._handle_voice_channel_input + if hasattr(adapter, "_on_voice_disconnect"): + adapter._on_voice_disconnect = self._handle_voice_timeout_cleanup + + try: + success = await adapter.join_voice_channel(voice_channel) + except Exception as e: + logger.warning("Failed to join voice channel: %s", e) + adapter._voice_input_callback = None + return f"Failed to join voice channel: {e}" + + if success: + adapter._voice_text_channels[guild_id] = int(event.source.chat_id) + self._voice_mode[event.source.chat_id] = "all" + self._save_voice_modes() + self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False) + return ( + f"Joined voice channel **{voice_channel.name}**.\n" + f"I'll speak my replies and listen to you. Use /voice leave to disconnect." + ) + # Join failed — clear callback + adapter._voice_input_callback = None + return "Failed to join voice channel. Check bot permissions (Connect + Speak)." + + async def _handle_voice_channel_leave(self, event: MessageEvent) -> str: + """Leave the Discord voice channel.""" + adapter = self.adapters.get(event.source.platform) + guild_id = self._get_guild_id(event) + + if not guild_id or not hasattr(adapter, "leave_voice_channel"): + return "Not in a voice channel." + + if not hasattr(adapter, "is_in_voice_channel") or not adapter.is_in_voice_channel(guild_id): + return "Not in a voice channel." 
+ + try: + await adapter.leave_voice_channel(guild_id) + except Exception as e: + logger.warning("Error leaving voice channel: %s", e) + # Always clean up state even if leave raised an exception + self._voice_mode[event.source.chat_id] = "off" + self._save_voice_modes() + self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True) + if hasattr(adapter, "_voice_input_callback"): + adapter._voice_input_callback = None + return "Left voice channel." + + def _handle_voice_timeout_cleanup(self, chat_id: str) -> None: + """Called by the adapter when a voice channel times out. + + Cleans up runner-side voice_mode state that the adapter cannot reach. + """ + self._voice_mode[chat_id] = "off" + self._save_voice_modes() + adapter = self.adapters.get(Platform.DISCORD) + self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True) + + async def _handle_voice_channel_input( + self, guild_id: int, user_id: int, transcript: str + ): + """Handle transcribed voice from a user in a voice channel. + + Creates a synthetic MessageEvent and processes it through the + adapter's full message pipeline (session, typing, agent, TTS reply). 
+ """ + adapter = self.adapters.get(Platform.DISCORD) + if not adapter: + return + + text_ch_id = adapter._voice_text_channels.get(guild_id) + if not text_ch_id: + return + + # Check authorization before processing voice input + source = SessionSource( + platform=Platform.DISCORD, + chat_id=str(text_ch_id), + user_id=str(user_id), + user_name=str(user_id), + chat_type="channel", + ) + if not self._is_user_authorized(source): + logger.debug("Unauthorized voice input from user %d, ignoring", user_id) + return + + # Show transcript in text channel (after auth, with mention sanitization) + try: + channel = adapter._client.get_channel(text_ch_id) + if channel: + safe_text = transcript[:2000].replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") + await channel.send(f"**[Voice]** <@{user_id}>: {safe_text}") + except Exception: + pass + + # Build a synthetic MessageEvent and feed through the normal pipeline + # Use SimpleNamespace as raw_message so _get_guild_id() can extract + # guild_id and _send_voice_reply() plays audio in the voice channel. + from types import SimpleNamespace + event = MessageEvent( + source=source, + text=transcript, + message_type=MessageType.VOICE, + raw_message=SimpleNamespace(guild_id=guild_id, guild=None), + ) + + await adapter.handle_message(event) + + def _should_send_voice_reply( + self, + event: MessageEvent, + response: str, + agent_messages: list, + ) -> bool: + """Decide whether the runner should send a TTS voice reply. + + Returns False when: + - voice_mode is off for this chat + - response is empty or an error + - agent already called text_to_speech tool (dedup) + - voice input and base adapter auto-TTS already handled it (skip_double) + Exception: Discord voice channel — base play_tts is a no-op there, + so the runner must handle VC playback. 
+ """ + if not response or response.startswith("Error:"): + return False + + chat_id = event.source.chat_id + voice_mode = self._voice_mode.get(chat_id, "off") + is_voice_input = (event.message_type == MessageType.VOICE) + + should = ( + (voice_mode == "all") + or (voice_mode == "voice_only" and is_voice_input) + ) + if not should: + return False + + # Dedup: agent already called TTS tool + has_agent_tts = any( + msg.get("role") == "assistant" + and any( + tc.get("function", {}).get("name") == "text_to_speech" + for tc in (msg.get("tool_calls") or []) + ) + for msg in agent_messages + ) + if has_agent_tts: + return False + + # Dedup: base adapter auto-TTS already handles voice input. + # Exception: Discord voice channel — play_tts override is a no-op, + # so the runner must handle VC playback. + skip_double = is_voice_input + if skip_double: + adapter = self.adapters.get(event.source.platform) + guild_id = self._get_guild_id(event) + if (guild_id and adapter + and hasattr(adapter, "is_in_voice_channel") + and adapter.is_in_voice_channel(guild_id)): + skip_double = False + if skip_double: + return False + + return True + + async def _send_voice_reply(self, event: MessageEvent, text: str) -> None: + """Generate TTS audio and send as a voice message before the text reply.""" + import uuid as _uuid + audio_path = None + actual_path = None + try: + from tools.tts_tool import text_to_speech_tool, _strip_markdown_for_tts + + tts_text = _strip_markdown_for_tts(text[:4000]) + if not tts_text: + return + + # Use .mp3 extension so edge-tts conversion to opus works correctly. + # The TTS tool may convert to .ogg — use file_path from result. 
+ audio_path = os.path.join( + tempfile.gettempdir(), "hermes_voice", + f"tts_reply_{_uuid.uuid4().hex[:12]}.mp3", + ) + os.makedirs(os.path.dirname(audio_path), exist_ok=True) + + result_json = await asyncio.to_thread( + text_to_speech_tool, text=tts_text, output_path=audio_path + ) + result = json.loads(result_json) + + # Use the actual file path from result (may differ after opus conversion) + actual_path = result.get("file_path", audio_path) + if not result.get("success") or not os.path.isfile(actual_path): + logger.warning("Auto voice reply TTS failed: %s", result.get("error")) + return + + adapter = self.adapters.get(event.source.platform) + + # If connected to a voice channel, play there instead of sending a file + guild_id = self._get_guild_id(event) + if (guild_id + and hasattr(adapter, "play_in_voice_channel") + and hasattr(adapter, "is_in_voice_channel") + and adapter.is_in_voice_channel(guild_id)): + await adapter.play_in_voice_channel(guild_id, actual_path) + elif adapter and hasattr(adapter, "send_voice"): + send_kwargs: Dict[str, Any] = { + "chat_id": event.source.chat_id, + "audio_path": actual_path, + "reply_to": event.message_id, + } + if event.source.thread_id: + send_kwargs["metadata"] = {"thread_id": event.source.thread_id} + await adapter.send_voice(**send_kwargs) + except Exception as e: + logger.warning("Auto voice reply failed: %s", e, exc_info=True) + finally: + for p in {audio_path, actual_path} - {None}: + try: + os.unlink(p) + except OSError: + pass + async def _handle_rollback_command(self, event: MessageEvent) -> str: """Handle /rollback command — list or restore filesystem checkpoints.""" from tools.checkpoint_manager import CheckpointManager, format_checkpoint_list @@ -3011,14 +3421,16 @@ class GatewayRunner: Returns: The enriched message string with transcriptions prepended. 
""" - from tools.transcription_tools import transcribe_audio + from tools.transcription_tools import transcribe_audio, get_stt_model_from_config import asyncio + stt_model = get_stt_model_from_config() + enriched_parts = [] for path in audio_paths: try: logger.debug("Transcribing user voice: %s", path) - result = await asyncio.to_thread(transcribe_audio, path) + result = await asyncio.to_thread(transcribe_audio, path, model=stt_model) if result["success"]: transcript = result["transcript"] enriched_parts.append( @@ -3027,10 +3439,10 @@ class GatewayRunner: ) else: error = result.get("error", "unknown error") - if "OPENAI_API_KEY" in error or "VOICE_TOOLS_OPENAI_KEY" in error: + if "No STT provider" in error or "not set" in error: enriched_parts.append( "[The user sent a voice message but I can't listen " - "to it right now~ VOICE_TOOLS_OPENAI_KEY isn't set up yet " + "to it right now~ No STT provider is configured " "(';w;') Let them know!]" ) else: @@ -3180,7 +3592,7 @@ class GatewayRunner: Platform.HOMEASSISTANT: "hermes-homeassistant", Platform.EMAIL: "hermes-email", } - + # Try to load platform_toolsets from config platform_toolsets_config = {} try: @@ -3192,7 +3604,7 @@ class GatewayRunner: platform_toolsets_config = user_config.get("platform_toolsets", {}) except Exception as e: logger.debug("Could not load platform_toolsets config: %s", e) - + # Map platform enum to config key platform_config_key = { Platform.LOCAL: "cli", diff --git a/gateway/session.py b/gateway/session.py index 3e42db4fe3..86e42b5950 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -383,7 +383,11 @@ class SessionStore: with open(sessions_file, "r", encoding="utf-8") as f: data = json.load(f) for key, entry_data in data.items(): - self._entries[key] = SessionEntry.from_dict(entry_data) + try: + self._entries[key] = SessionEntry.from_dict(entry_data) + except (ValueError, KeyError): + # Skip entries with unknown/removed platform values + continue except Exception as e: 
print(f"[gateway] Warning: Failed to load sessions: {e}") diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 57899cf085..a9a1a67ba7 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -37,6 +37,7 @@ COMMANDS_BY_CATEGORY = { "/verbose": "Cycle tool progress display: off → new → all → verbose", "/reasoning": "Manage reasoning effort and display (usage: /reasoning [level|show|hide])", "/skin": "Show or change the display skin/theme", + "/voice": "Toggle voice mode (Ctrl+B to record). Usage: /voice [on|off|tts|status]", }, "Tools & Skills": { "/tools": "List available tools", diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 02edad1fae..b37f30f0cd 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -202,6 +202,14 @@ DEFAULT_CONFIG = { "model": "whisper-1", # whisper-1, gpt-4o-mini-transcribe, gpt-4o-transcribe }, }, + + "voice": { + "record_key": "ctrl+b", + "max_recording_seconds": 120, + "auto_tts": False, + "silence_threshold": 200, # RMS below this = silence (0-32767) + "silence_duration": 3.0, # Seconds of silence before auto-stop + }, "human_delay": { "mode": "off", diff --git a/pyproject.toml b/pyproject.toml index 7e4197724b..fa248cd0e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,11 +43,12 @@ dependencies = [ modal = ["swe-rex[modal]>=1.4.0"] daytona = ["daytona>=0.148.0"] dev = ["pytest", "pytest-asyncio", "pytest-xdist", "mcp>=1.2.0"] -messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"] +messaging = ["python-telegram-bot>=20.0", "discord.py[voice]>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"] cron = ["croniter"] slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"] cli = ["simple-term-menu"] tts-premium = ["elevenlabs"] +voice = ["sounddevice>=0.4.6", "numpy>=1.24.0"] pty = [ "ptyprocess>=0.7.0; sys_platform != 'win32'", "pywinpty>=2.0.0; sys_platform == 'win32'", @@ -78,6 +79,7 @@ all = [ 
"hermes-agent[mcp]", "hermes-agent[homeassistant]", "hermes-agent[acp]", + "hermes-agent[voice]", ] [project.scripts] diff --git a/run_agent.py b/run_agent.py index ba214b715f..bdf0496553 100644 --- a/run_agent.py +++ b/run_agent.py @@ -493,6 +493,16 @@ class AIAgent: ]: logging.getLogger(quiet_logger).setLevel(logging.ERROR) + # Internal stream callback (set during streaming TTS). + # Initialized here so _vprint can reference it before run_conversation. + self._stream_callback = None + + # Optional current-turn user-message override used when the API-facing + # user message intentionally differs from the persisted transcript + # (e.g. CLI voice mode adds a temporary prefix for the live call only). + self._persist_user_message_idx = None + self._persist_user_message_override = None + # Initialize LLM client via centralized provider router. # The router handles auth resolution, base URL, headers, and # Codex/Anthropic wrapping for all known providers. @@ -504,6 +514,7 @@ class AIAgent: from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token effective_key = api_key or resolve_anthropic_token() or "" self._anthropic_api_key = effective_key + self._anthropic_base_url = base_url self._anthropic_client = build_anthropic_client(effective_key, base_url) # No OpenAI client needed for Anthropic mode self.client = None @@ -812,6 +823,16 @@ class AIAgent: else: print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (auto-compression disabled)") + def _vprint(self, *args, force: bool = False, **kwargs): + """Verbose print — suppressed when streaming TTS is active. + + Pass ``force=True`` for error/warning messages that should always be + shown even during streaming TTS playback. + """ + if not force and getattr(self, "_stream_callback", None) is not None: + return + print(*args, **kwargs) + def _max_tokens_param(self, value: int) -> dict: """Return the correct max tokens kwarg for the current provider. 
@@ -983,11 +1004,30 @@ class AIAgent: if self.verbose_logging: logging.warning(f"Failed to cleanup browser for task {task_id}: {e}") + def _apply_persist_user_message_override(self, messages: List[Dict]) -> None: + """Rewrite the current-turn user message before persistence/return. + + Some call paths need an API-only user-message variant without letting + that synthetic text leak into persisted transcripts or resumed session + history. When an override is configured for the active turn, mutate the + in-memory messages list in place so both persistence and returned + history stay clean. + """ + idx = getattr(self, "_persist_user_message_idx", None) + override = getattr(self, "_persist_user_message_override", None) + if override is None or idx is None: + return + if 0 <= idx < len(messages): + msg = messages[idx] + if isinstance(msg, dict) and msg.get("role") == "user": + msg["content"] = override + def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None): """Save session state to both JSON log and SQLite on any exit path. Ensures conversations are never lost, even on errors or early returns. 
""" + self._apply_persist_user_message_override(messages) self._session_messages = messages self._save_session_log(messages) self._flush_messages_to_session_db(messages, conversation_history) @@ -1001,6 +1041,7 @@ class AIAgent: """ if not self._session_db: return + self._apply_persist_user_message_override(messages) try: start_idx = len(conversation_history) if conversation_history else 0 flush_from = max(start_idx, self._last_flushed_db_idx) @@ -1340,7 +1381,7 @@ class AIAgent: encoding="utf-8", ) - print(f"{self.log_prefix}🧾 Request debug dump written to: {dump_file}") + self._vprint(f"{self.log_prefix}🧾 Request debug dump written to: {dump_file}") if os.getenv("HERMES_DUMP_REQUEST_STDOUT", "").strip().lower() in {"1", "true", "yes", "on"}: print(json.dumps(dump_payload, ensure_ascii=False, indent=2, default=str)) @@ -1482,7 +1523,7 @@ class AIAgent: # Replay the items into the store (replace mode) self._todo_store.write(last_todo_response, merge=False) if not self.quiet_mode: - print(f"{self.log_prefix}📋 Restored {len(last_todo_response)} todo item(s) from history") + self._vprint(f"{self.log_prefix}📋 Restored {len(last_todo_response)} todo item(s) from history") _set_interrupt(False) @property @@ -2576,7 +2617,7 @@ class AIAgent: """ Run the API call in a background thread so the main conversation loop can detect interrupts without waiting for the full HTTP round-trip. - + On interrupt, closes the HTTP client to cancel the in-flight request (stops token generation and avoids wasting money), then rebuilds the client for future calls. 
@@ -2611,7 +2652,139 @@ class AIAgent: try: if self.api_mode == "anthropic_messages": from agent.anthropic_adapter import build_anthropic_client - self._anthropic_client = build_anthropic_client(self._anthropic_api_key) + self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None)) + else: + self.client = OpenAI(**self._client_kwargs) + except Exception: + pass + raise InterruptedError("Agent interrupted during API call") + if result["error"] is not None: + raise result["error"] + return result["response"] + + def _streaming_api_call(self, api_kwargs: dict, stream_callback): + """Streaming variant of _interruptible_api_call for voice TTS pipeline. + + Uses ``stream=True`` and forwards content deltas to *stream_callback* + in real-time. Returns a ``SimpleNamespace`` that mimics a normal + ``ChatCompletion`` so the rest of the agent loop works unchanged. + + This method is separate from ``_interruptible_api_call`` to keep the + core agent loop untouched for non-voice users. 
+ """ + result = {"response": None, "error": None} + + def _call(): + try: + stream_kwargs = {**api_kwargs, "stream": True} + stream = self.client.chat.completions.create(**stream_kwargs) + + content_parts: list[str] = [] + tool_calls_acc: dict[int, dict] = {} + finish_reason = None + model_name = None + role = "assistant" + + for chunk in stream: + if not chunk.choices: + if hasattr(chunk, "model") and chunk.model: + model_name = chunk.model + continue + + delta = chunk.choices[0].delta + if hasattr(chunk, "model") and chunk.model: + model_name = chunk.model + + if delta and delta.content: + content_parts.append(delta.content) + try: + stream_callback(delta.content) + except Exception: + pass + + if delta and delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index if tc_delta.index is not None else 0 + if idx in tool_calls_acc and tc_delta.id and tc_delta.id != tool_calls_acc[idx]["id"]: + matched = False + for eidx, eentry in tool_calls_acc.items(): + if eentry["id"] == tc_delta.id: + idx = eidx + matched = True + break + if not matched: + idx = (max(k for k in tool_calls_acc if isinstance(k, int)) + 1) if tool_calls_acc else 0 + if idx not in tool_calls_acc: + tool_calls_acc[idx] = { + "id": tc_delta.id or "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + entry = tool_calls_acc[idx] + if tc_delta.id: + entry["id"] = tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + entry["function"]["name"] += tc_delta.function.name + if tc_delta.function.arguments: + entry["function"]["arguments"] += tc_delta.function.arguments + + if chunk.choices[0].finish_reason: + finish_reason = chunk.choices[0].finish_reason + + full_content = "".join(content_parts) or None + mock_tool_calls = None + if tool_calls_acc: + mock_tool_calls = [] + for idx in sorted(tool_calls_acc): + tc = tool_calls_acc[idx] + mock_tool_calls.append(SimpleNamespace( + id=tc["id"], + type=tc["type"], + function=SimpleNamespace( + 
name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ), + )) + + mock_message = SimpleNamespace( + role=role, + content=full_content, + tool_calls=mock_tool_calls, + reasoning_content=None, + ) + mock_choice = SimpleNamespace( + index=0, + message=mock_message, + finish_reason=finish_reason or "stop", + ) + mock_response = SimpleNamespace( + id="stream-" + str(uuid.uuid4()), + model=model_name, + choices=[mock_choice], + usage=None, + ) + result["response"] = mock_response + + except Exception as e: + result["error"] = e + + t = threading.Thread(target=_call, daemon=True) + t.start() + while t.is_alive(): + t.join(timeout=0.3) + if self._interrupt_requested: + try: + if self.api_mode == "anthropic_messages": + self._anthropic_client.close() + else: + self.client.close() + except Exception: + pass + try: + if self.api_mode == "anthropic_messages": + from agent.anthropic_adapter import build_anthropic_client + self._anthropic_client = build_anthropic_client(self._anthropic_api_key, getattr(self, "_anthropic_base_url", None)) else: self.client = OpenAI(**self._client_kwargs) except Exception: @@ -2677,7 +2850,8 @@ class AIAgent: from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token effective_key = fb_client.api_key or resolve_anthropic_token() or "" self._anthropic_api_key = effective_key - self._anthropic_client = build_anthropic_client(effective_key) + self._anthropic_base_url = getattr(fb_client, "base_url", None) + self._anthropic_client = build_anthropic_client(effective_key, self._anthropic_base_url) self.client = None self._client_kwargs = {} else: @@ -3479,7 +3653,7 @@ class AIAgent: if self._interrupt_requested: remaining_calls = assistant_message.tool_calls[i-1:] if remaining_calls: - print(f"{self.log_prefix}⚡ Interrupt: skipping {len(remaining_calls)} tool call(s)") + self._vprint(f"{self.log_prefix}⚡ Interrupt: skipping {len(remaining_calls)} tool call(s)", force=True) for skipped_tc in remaining_calls: 
skipped_name = skipped_tc.function.name skip_msg = { @@ -3541,7 +3715,7 @@ class AIAgent: ) tool_duration = time.time() - tool_start_time if self.quiet_mode: - print(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}") + self._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}") elif function_name == "session_search": if not self._session_db: function_result = json.dumps({"success": False, "error": "Session database not available."}) @@ -3556,7 +3730,7 @@ class AIAgent: ) tool_duration = time.time() - tool_start_time if self.quiet_mode: - print(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}") + self._vprint(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}") elif function_name == "memory": target = function_args.get("target", "memory") from tools.memory_tool import memory_tool as _memory_tool @@ -3572,7 +3746,7 @@ class AIAgent: self._honcho_save_user_observation(function_args.get("content", "")) tool_duration = time.time() - tool_start_time if self.quiet_mode: - print(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}") + self._vprint(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}") elif function_name == "clarify": from tools.clarify_tool import clarify_tool as _clarify_tool function_result = _clarify_tool( @@ -3582,7 +3756,7 @@ class AIAgent: ) tool_duration = time.time() - tool_start_time if self.quiet_mode: - print(f" {_get_cute_tool_message_impl('clarify', function_args, tool_duration, result=function_result)}") + self._vprint(f" {_get_cute_tool_message_impl('clarify', function_args, tool_duration, result=function_result)}") elif function_name == "delegate_task": from tools.delegate_tool import delegate_task as _delegate_task tasks_arg = 
function_args.get("tasks") @@ -3615,8 +3789,8 @@ class AIAgent: if spinner: spinner.stop(cute_msg) elif self.quiet_mode: - print(f" {cute_msg}") - elif self.quiet_mode: + self._vprint(f" {cute_msg}") + elif self.quiet_mode and self._stream_callback is None: face = random.choice(KawaiiSpinner.KAWAII_WAITING) tool_emoji_map = { 'web_search': '🔍', 'web_extract': '📄', 'web_crawl': '🕸️', @@ -3703,7 +3877,7 @@ class AIAgent: if self._interrupt_requested and i < len(assistant_message.tool_calls): remaining = len(assistant_message.tool_calls) - i - print(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)") + self._vprint(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)", force=True) for skipped_tc in assistant_message.tool_calls[i:]: skipped_name = skipped_tc.function.name skip_msg = { @@ -3915,7 +4089,9 @@ class AIAgent: user_message: str, system_message: str = None, conversation_history: List[Dict[str, Any]] = None, - task_id: str = None + task_id: str = None, + stream_callback: Optional[callable] = None, + persist_user_message: Optional[str] = None, ) -> Dict[str, Any]: """ Run a complete conversation with tool calling until completion. @@ -3925,6 +4101,12 @@ class AIAgent: system_message (str): Custom system message (optional, overrides ephemeral_system_prompt if provided) conversation_history (List[Dict]): Previous conversation messages (optional) task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional, auto-generated if not provided) + stream_callback: Optional callback invoked with each text delta during streaming. + Used by the TTS pipeline to start audio generation before the full response. + When None (default), API calls use the standard non-streaming path. + persist_user_message: Optional clean user message to store in + transcripts/history when user_message contains API-only + synthetic prefixes. 
Returns: Dict: Complete conversation result with final response and message history @@ -3933,6 +4115,10 @@ class AIAgent: # Installed once, transparent when streams are healthy, prevents crash on write. _install_safe_stdio() + # Store stream callback for _interruptible_api_call to pick up + self._stream_callback = stream_callback + self._persist_user_message_idx = None + self._persist_user_message_override = persist_user_message # Generate unique task_id if not provided to isolate VMs between concurrent tasks effective_task_id = task_id or str(uuid.uuid4()) @@ -3967,7 +4153,7 @@ class AIAgent: # Preserve the original user message before nudge injection. # Honcho should receive the actual user input, not system nudges. - original_user_message = user_message + original_user_message = persist_user_message if persist_user_message is not None else user_message # Periodic memory nudge: remind the model to consider saving memories. # Counter resets whenever the memory tool is actually used. @@ -4005,7 +4191,7 @@ class AIAgent: _recall_mode = (self._honcho_config.recall_mode if self._honcho_config else "hybrid") if self._honcho and self._honcho_session_key and _recall_mode != "tools": try: - prefetched_context = self._honcho_prefetch(user_message) + prefetched_context = self._honcho_prefetch(original_user_message) if prefetched_context: if not conversation_history: self._honcho_context = prefetched_context @@ -4018,6 +4204,7 @@ class AIAgent: user_msg = {"role": "user", "content": user_message} messages.append(user_msg) current_turn_user_idx = len(messages) - 1 + self._persist_user_message_idx = current_turn_user_idx if not self.quiet_mode: print(f"💬 Starting conversation: '{user_message[:60]}{'...' 
if len(user_message) > 60 else ''}'") @@ -4239,11 +4426,11 @@ class AIAgent: thinking_spinner = None if not self.quiet_mode: - print(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...") - print(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)") - print(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}") - else: - # Animated thinking spinner in quiet mode + self._vprint(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...") + self._vprint(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)") + self._vprint(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}") + elif self._stream_callback is None: + # Animated thinking spinner in quiet mode (skip during streaming TTS) face = random.choice(KawaiiSpinner.KAWAII_THINKING) verb = random.choice(KawaiiSpinner.THINKING_VERBS) if self.thinking_callback: @@ -4283,7 +4470,33 @@ class AIAgent: if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}: self._dump_api_request_debug(api_kwargs, reason="preflight") - response = self._interruptible_api_call(api_kwargs) + cb = getattr(self, "_stream_callback", None) + if cb is not None and self.api_mode == "chat_completions": + response = self._streaming_api_call(api_kwargs, cb) + else: + response = self._interruptible_api_call(api_kwargs) + # Forward full response to TTS callback for non-streaming providers + # (e.g. Anthropic) so voice TTS still works via batch delivery. + if cb is not None and response: + try: + content = None + # Try choices first — _interruptible_api_call converts all + # providers (including Anthropic) to this format. 
+ try: + content = response.choices[0].message.content + except (AttributeError, IndexError): + pass + # Fallback: Anthropic native content blocks + if not content and self.api_mode == "anthropic_messages": + text_parts = [ + block.text for block in getattr(response, "content", []) + if getattr(block, "type", None) == "text" and getattr(block, "text", None) + ] + content = " ".join(text_parts) if text_parts else None + if content: + cb(content) + except Exception: + pass api_duration = time.time() - api_start_time @@ -4296,7 +4509,7 @@ class AIAgent: self.thinking_callback("") if not self.quiet_mode: - print(f"{self.log_prefix}⏱️ API call completed in {api_duration:.2f}s") + self._vprint(f"{self.log_prefix}⏱️ API call completed in {api_duration:.2f}s") if self.verbose_logging: # Log response with provider info if available @@ -4373,17 +4586,17 @@ class AIAgent: if self.verbose_logging: logging.debug(f"Response attributes for invalid response: {resp_attrs}") - print(f"{self.log_prefix}⚠️ Invalid API response (attempt {retry_count}/{max_retries}): {', '.join(error_details)}") - print(f"{self.log_prefix} 🏢 Provider: {provider_name}") - print(f"{self.log_prefix} 📝 Provider message: {error_msg[:200]}") - print(f"{self.log_prefix} ⏱️ Response time: {api_duration:.2f}s (fast response often indicates rate limiting)") + self._vprint(f"{self.log_prefix}⚠️ Invalid API response (attempt {retry_count}/{max_retries}): {', '.join(error_details)}", force=True) + self._vprint(f"{self.log_prefix} 🏢 Provider: {provider_name}", force=True) + self._vprint(f"{self.log_prefix} 📝 Provider message: {error_msg[:200]}", force=True) + self._vprint(f"{self.log_prefix} ⏱️ Response time: {api_duration:.2f}s (fast response often indicates rate limiting)", force=True) if retry_count >= max_retries: # Try fallback before giving up if self._try_activate_fallback(): retry_count = 0 continue - print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded for invalid responses. 
Giving up.") + self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up.", force=True) logging.error(f"{self.log_prefix}Invalid API response after {max_retries} retries.") self._persist_session(messages, conversation_history) return { @@ -4396,14 +4609,14 @@ class AIAgent: # Longer backoff for rate limiting (likely cause of None choices) wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s - print(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...") + self._vprint(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...", force=True) logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}") # Sleep in small increments to stay responsive to interrupts sleep_end = time.time() + wait_time while time.time() < sleep_end: if self._interrupt_requested: - print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.") + self._vprint(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.", force=True) self._persist_session(messages, conversation_history) self.clear_interrupt() return { @@ -4436,7 +4649,7 @@ class AIAgent: finish_reason = response.choices[0].finish_reason if finish_reason == "length": - print(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens") + self._vprint(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens", force=True) if self.api_mode == "chat_completions": assistant_message = response.choices[0].message @@ -4448,7 +4661,7 @@ class AIAgent: truncated_response_prefix += assistant_message.content if length_continue_retries < 3: - print( + self._vprint( f"{self.log_prefix}↻ Requesting continuation " f"({length_continue_retries}/3)..." 
) @@ -4480,7 +4693,7 @@ class AIAgent: # If we have prior messages, roll back to last complete state if len(messages) > 1: - print(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn") + self._vprint(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn") rolled_back_messages = self._get_messages_up_to_last_assistant(messages) self._cleanup_task_resources(effective_task_id) @@ -4496,7 +4709,7 @@ class AIAgent: } else: # First message was truncated - mark as failed - print(f"{self.log_prefix}❌ First response truncated - cannot recover") + self._vprint(f"{self.log_prefix}❌ First response truncated - cannot recover", force=True) self._persist_session(messages, conversation_history) return { "final_response": None, @@ -4556,7 +4769,7 @@ class AIAgent: prompt = usage_dict["prompt_tokens"] hit_pct = (cached / prompt * 100) if prompt > 0 else 0 if not self.quiet_mode: - print(f"{self.log_prefix} 💾 Cache: {cached:,}/{prompt:,} tokens ({hit_pct:.0f}% hit, {written:,} written)") + self._vprint(f"{self.log_prefix} 💾 Cache: {cached:,}/{prompt:,} tokens ({hit_pct:.0f}% hit, {written:,} written)") break # Success, exit retry loop @@ -4567,7 +4780,7 @@ class AIAgent: if self.thinking_callback: self.thinking_callback("") api_elapsed = time.time() - api_start_time - print(f"{self.log_prefix}⚡ Interrupted during API call.") + self._vprint(f"{self.log_prefix}⚡ Interrupted during API call.", force=True) self._persist_session(messages, conversation_history) interrupted = True final_response = f"Operation interrupted: waiting for model response ({api_elapsed:.1f}s elapsed)." @@ -4590,7 +4803,7 @@ class AIAgent: ): codex_auth_retry_attempted = True if self._try_refresh_codex_client_credentials(force=True): - print(f"{self.log_prefix}🔐 Codex auth refreshed after 401. Retrying request...") + self._vprint(f"{self.log_prefix}🔐 Codex auth refreshed after 401. 
Retrying request...") continue if ( self.api_mode == "chat_completions" @@ -4614,7 +4827,7 @@ class AIAgent: new_token = resolve_anthropic_token() if new_token and new_token != self._anthropic_api_key: self._anthropic_api_key = new_token - self._anthropic_client = build_anthropic_client(new_token) + self._anthropic_client = build_anthropic_client(new_token, getattr(self, "_anthropic_base_url", None)) print(f"{self.log_prefix}🔐 Anthropic credentials refreshed after 401. Retrying request...") continue # Credential refresh didn't help — show diagnostic info @@ -4638,14 +4851,14 @@ class AIAgent: error_type = type(api_error).__name__ error_msg = str(api_error).lower() - print(f"{self.log_prefix}⚠️ API call failed (attempt {retry_count}/{max_retries}): {error_type}") - print(f"{self.log_prefix} ⏱️ Time elapsed before failure: {elapsed_time:.2f}s") - print(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}") - print(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools") + self._vprint(f"{self.log_prefix}⚠️ API call failed (attempt {retry_count}/{max_retries}): {error_type}", force=True) + self._vprint(f"{self.log_prefix} ⏱️ Time elapsed before failure: {elapsed_time:.2f}s") + self._vprint(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}", force=True) + self._vprint(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools") # Check for interrupt before deciding to retry if self._interrupt_requested: - print(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.") + self._vprint(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.", force=True) self._persist_session(messages, conversation_history) self.clear_interrupt() return { @@ -4670,7 +4883,7 @@ class AIAgent: if is_payload_too_large: compression_attempts += 1 if compression_attempts > 
max_compression_attempts: - print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.") + self._vprint(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.", force=True) logging.error(f"{self.log_prefix}413 compression failed after {max_compression_attempts} attempts.") self._persist_session(messages, conversation_history) return { @@ -4680,7 +4893,7 @@ class AIAgent: "error": f"Request payload too large: max compression attempts ({max_compression_attempts}) reached.", "partial": True } - print(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...") + self._vprint(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...") original_len = len(messages) messages, active_system_prompt = self._compress_context( @@ -4689,12 +4902,12 @@ class AIAgent: ) if len(messages) < original_len: - print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") + self._vprint(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") time.sleep(2) # Brief pause between compression retries restart_with_compressed_messages = True break else: - print(f"{self.log_prefix}❌ Payload too large and cannot compress further.") + self._vprint(f"{self.log_prefix}❌ Payload too large and cannot compress further.", force=True) logging.error(f"{self.log_prefix}413 payload too large. 
Cannot compress further.") self._persist_session(messages, conversation_history) return { @@ -4725,7 +4938,7 @@ class AIAgent: parsed_limit = parse_context_limit_from_error(error_msg) if parsed_limit and parsed_limit < old_ctx: new_ctx = parsed_limit - print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})") + self._vprint(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})", force=True) else: # Step down to the next probe tier new_ctx = get_next_probe_tier(old_ctx) @@ -4734,13 +4947,13 @@ class AIAgent: compressor.context_length = new_ctx compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent) compressor._context_probed = True - print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens") + self._vprint(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,} → {new_ctx:,} tokens", force=True) else: - print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...") + self._vprint(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...", force=True) compression_attempts += 1 if compression_attempts > max_compression_attempts: - print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.") + self._vprint(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.", force=True) logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.") self._persist_session(messages, conversation_history) return { @@ -4750,7 +4963,7 @@ class AIAgent: "error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.", "partial": True } - print(f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}...") + self._vprint(f"{self.log_prefix} 🗜️ Context compression attempt 
{compression_attempts}/{max_compression_attempts}...") original_len = len(messages) messages, active_system_prompt = self._compress_context( @@ -4760,14 +4973,14 @@ class AIAgent: if len(messages) < original_len or new_ctx and new_ctx < old_ctx: if len(messages) < original_len: - print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") + self._vprint(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...") time.sleep(2) # Brief pause between compression retries restart_with_compressed_messages = True break else: # Can't compress further and already at minimum tier - print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.") - print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.") + self._vprint(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.", force=True) + self._vprint(f"{self.log_prefix} 💡 The conversation has accumulated too much content.", force=True) logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.") self._persist_session(messages, conversation_history) return { @@ -4803,8 +5016,8 @@ class AIAgent: self._dump_api_request_debug( api_kwargs, reason="non_retryable_client_error", error=api_error, ) - print(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.") - print(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.") + self._vprint(f"{self.log_prefix}❌ Non-retryable client error detected. 
Aborting immediately.", force=True) + self._vprint(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.", force=True) logging.error(f"{self.log_prefix}Non-retryable client error: {api_error}") self._persist_session(messages, conversation_history) return { @@ -4821,7 +5034,7 @@ class AIAgent: if self._try_activate_fallback(): retry_count = 0 continue - print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.") + self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.", force=True) logging.error(f"{self.log_prefix}API call failed after {max_retries} retries. Last error: {api_error}") logging.error(f"{self.log_prefix}Request details - Messages: {len(api_messages)}, Approx tokens: {approx_tokens:,}") raise api_error @@ -4829,15 +5042,15 @@ class AIAgent: wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}") if retry_count >= max_retries: - print(f"{self.log_prefix}⚠️ API call failed after {retry_count} attempts: {str(api_error)[:100]}") - print(f"{self.log_prefix}⏳ Final retry in {wait_time}s...") + self._vprint(f"{self.log_prefix}⚠️ API call failed after {retry_count} attempts: {str(api_error)[:100]}") + self._vprint(f"{self.log_prefix}⏳ Final retry in {wait_time}s...") # Sleep in small increments so we can respond to interrupts quickly # instead of blocking the entire wait_time in one sleep() call sleep_end = time.time() + wait_time while time.time() < sleep_end: if self._interrupt_requested: - print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.") + self._vprint(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.", force=True) self._persist_session(messages, conversation_history) self.clear_interrupt() return { @@ -4901,7 +5114,7 @@ class AIAgent: # Handle assistant response if assistant_message.content and not self.quiet_mode: - 
print(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}") + self._vprint(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}") # Notify progress callback of model's thinking (used by subagent # delegation to relay the child's reasoning to the parent display). @@ -4928,15 +5141,15 @@ class AIAgent: self._incomplete_scratchpad_retries = 0 self._incomplete_scratchpad_retries += 1 - print(f"{self.log_prefix}⚠️ Incomplete detected (opened but never closed)") + self._vprint(f"{self.log_prefix}⚠️ Incomplete detected (opened but never closed)") if self._incomplete_scratchpad_retries <= 2: - print(f"{self.log_prefix}🔄 Retrying API call ({self._incomplete_scratchpad_retries}/2)...") + self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._incomplete_scratchpad_retries}/2)...") # Don't add the broken message, just retry continue else: # Max retries - discard this turn and save as partial - print(f"{self.log_prefix}❌ Max retries (2) for incomplete scratchpad. Saving as partial.") + self._vprint(f"{self.log_prefix}❌ Max retries (2) for incomplete scratchpad. 
Saving as partial.", force=True) self._incomplete_scratchpad_retries = 0 rolled_back_messages = self._get_messages_up_to_last_assistant(messages) @@ -4979,7 +5192,7 @@ class AIAgent: if self._codex_incomplete_retries < 3: if not self.quiet_mode: - print(f"{self.log_prefix}↻ Codex response incomplete; continuing turn ({self._codex_incomplete_retries}/3)") + self._vprint(f"{self.log_prefix}↻ Codex response incomplete; continuing turn ({self._codex_incomplete_retries}/3)") self._session_messages = messages self._save_session_log(messages) continue @@ -5000,7 +5213,7 @@ class AIAgent: # Check for tool calls if assistant_message.tool_calls: if not self.quiet_mode: - print(f"{self.log_prefix}🔧 Processing {len(assistant_message.tool_calls)} tool call(s)...") + self._vprint(f"{self.log_prefix}🔧 Processing {len(assistant_message.tool_calls)} tool call(s)...") if self.verbose_logging: for tc in assistant_message.tool_calls: @@ -5019,11 +5232,30 @@ class AIAgent: if tc.function.name not in self.valid_tool_names ] if invalid_tool_calls: + # Track retries for invalid tool calls + if not hasattr(self, '_invalid_tool_retries'): + self._invalid_tool_retries = 0 + self._invalid_tool_retries += 1 + # Return helpful error to model — model can self-correct next turn available = ", ".join(sorted(self.valid_tool_names)) invalid_name = invalid_tool_calls[0] invalid_preview = invalid_name[:80] + "..." if len(invalid_name) > 80 else invalid_name - print(f"{self.log_prefix}⚠️ Unknown tool '{invalid_preview}' — sending error to model for self-correction") + self._vprint(f"{self.log_prefix}⚠️ Unknown tool '{invalid_preview}' — sending error to model for self-correction ({self._invalid_tool_retries}/3)") + + if self._invalid_tool_retries >= 3: + self._vprint(f"{self.log_prefix}❌ Max retries (3) for invalid tool calls exceeded. 
Stopping as partial.", force=True) + self._invalid_tool_retries = 0 + self._persist_session(messages, conversation_history) + return { + "final_response": None, + "messages": messages, + "api_calls": api_call_count, + "completed": False, + "partial": True, + "error": f"Model generated invalid tool call: {invalid_preview}" + } + assistant_msg = self._build_assistant_message(assistant_message, finish_reason) messages.append(assistant_msg) for tc in assistant_message.tool_calls: @@ -5060,15 +5292,15 @@ class AIAgent: self._invalid_json_retries += 1 tool_name, error_msg = invalid_json_args[0] - print(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}") + self._vprint(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}") if self._invalid_json_retries < 3: - print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...") + self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...") # Don't add anything to messages, just retry the API call continue else: # Instead of returning partial, inject a helpful message and let model recover - print(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...") + self._vprint(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...") self._invalid_json_retries = 0 # Reset for next attempt # Add a user message explaining the issue @@ -5098,7 +5330,7 @@ class AIAgent: if self.quiet_mode: clean = self._strip_think_blocks(turn_content).strip() if clean: - print(f" ┊ 💬 {clean}") + self._vprint(f" ┊ 💬 {clean}") messages.append(assistant_msg) @@ -5174,19 +5406,19 @@ class AIAgent: self._empty_content_retries += 1 reasoning_text = self._extract_reasoning(assistant_message) - print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it") + self._vprint(f"{self.log_prefix}⚠️ Response only contains think block with no content after it") if reasoning_text: reasoning_preview = 
reasoning_text[:500] + "..." if len(reasoning_text) > 500 else reasoning_text - print(f"{self.log_prefix} Reasoning: {reasoning_preview}") + self._vprint(f"{self.log_prefix} Reasoning: {reasoning_preview}") else: content_preview = final_response[:80] + "..." if len(final_response) > 80 else final_response - print(f"{self.log_prefix} Content: '{content_preview}'") + self._vprint(f"{self.log_prefix} Content: '{content_preview}'") if self._empty_content_retries < 3: - print(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...") + self._vprint(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...") continue else: - print(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded.") + self._vprint(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded.", force=True) self._empty_content_retries = 0 # If a prior tool_calls turn had real content, salvage it: @@ -5377,20 +5609,24 @@ class AIAgent: # Clear interrupt state after handling self.clear_interrupt() - + + # Clear stream callback so it doesn't leak into future calls + self._stream_callback = None + return result - - def chat(self, message: str) -> str: + + def chat(self, message: str, stream_callback: Optional[callable] = None) -> str: """ Simple chat interface that returns just the final response. - + Args: message (str): User message - + stream_callback: Optional callback invoked with each text delta during streaming. 
+ Returns: str: Final assistant response """ - result = self.run_conversation(message) + result = self.run_conversation(message, stream_callback=stream_callback) return result["final_response"] diff --git a/tests/fakes/fake_ha_server.py b/tests/fakes/fake_ha_server.py index 1d51bf51b6..b5119da366 100644 --- a/tests/fakes/fake_ha_server.py +++ b/tests/fakes/fake_ha_server.py @@ -275,12 +275,25 @@ class FakeHAServer: affected = [] entity_id = body.get("entity_id") if entity_id: - new_state = "on" if service == "turn_on" else "off" for s in ENTITY_STATES: if s["entity_id"] == entity_id: + if service == "turn_on": + s["state"] = "on" + elif service == "turn_off": + s["state"] = "off" + elif service == "set_temperature" and "temperature" in body: + s["attributes"]["temperature"] = body["temperature"] + # Keep current state or set to heat if off + if s["state"] == "off": + s["state"] = "heat" + # Simulate temperature sensor approaching the target + for ts in ENTITY_STATES: + if ts["entity_id"] == "sensor.temperature": + ts["state"] = str(body["temperature"] - 0.5) + break affected.append({ "entity_id": entity_id, - "state": new_state, + "state": s["state"], "attributes": s.get("attributes", {}), }) break diff --git a/tests/gateway/test_background_command.py b/tests/gateway/test_background_command.py index 6a780fb13f..027742ea01 100644 --- a/tests/gateway/test_background_command.py +++ b/tests/gateway/test_background_command.py @@ -32,6 +32,7 @@ def _make_runner(): from gateway.run import GatewayRunner runner = object.__new__(GatewayRunner) runner.adapters = {} + runner._voice_mode = {} runner._session_db = None runner._reasoning_config = None runner._provider_routing = {} diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index ff15326dbb..3d41104c86 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -29,6 +29,8 @@ def _ensure_discord_mock(): discord_mod.Embed = 
MagicMock discord_mod.app_commands = SimpleNamespace( describe=lambda **kwargs: (lambda fn: fn), + choices=lambda **kwargs: (lambda fn: fn), + Choice=lambda **kwargs: SimpleNamespace(**kwargs), ) ext_mod = MagicMock() diff --git a/tests/gateway/test_discord_opus.py b/tests/gateway/test_discord_opus.py new file mode 100644 index 0000000000..ef66cde004 --- /dev/null +++ b/tests/gateway/test_discord_opus.py @@ -0,0 +1,44 @@ +"""Tests for Discord Opus codec loading — must use ctypes.util.find_library.""" + +import inspect + + +class TestOpusFindLibrary: + """Opus loading must try ctypes.util.find_library first, with platform fallback.""" + + def test_uses_find_library_first(self): + """find_library must be the primary lookup strategy.""" + from gateway.platforms.discord import DiscordAdapter + source = inspect.getsource(DiscordAdapter.connect) + assert "find_library" in source, \ + "Opus loading must use ctypes.util.find_library" + + def test_homebrew_fallback_is_conditional(self): + """Homebrew paths must only be tried when find_library returns None.""" + from gateway.platforms.discord import DiscordAdapter + source = inspect.getsource(DiscordAdapter.connect) + # Homebrew fallback must exist + assert "/opt/homebrew" in source or "homebrew" in source, \ + "Opus loading should have macOS Homebrew fallback" + # find_library must appear BEFORE any Homebrew path + fl_idx = source.index("find_library") + hb_idx = source.index("/opt/homebrew") + assert fl_idx < hb_idx, \ + "find_library must be tried before Homebrew fallback paths" + # Fallback must be guarded by platform check + assert "sys.platform" in source or "darwin" in source, \ + "Homebrew fallback must be guarded by macOS platform check" + + def test_opus_decode_error_logged(self): + """Opus decode failure must log the error, not silently return.""" + from gateway.platforms.discord import VoiceReceiver + source = inspect.getsource(VoiceReceiver._on_packet) + assert "logger" in source, \ + "_on_packet must log Opus 
decode errors" + # Must not have bare `except Exception:\n return` + lines = source.split("\n") + for i, line in enumerate(lines): + if "except Exception" in line and i + 1 < len(lines): + next_line = lines[i + 1].strip() + assert next_line != "return", \ + f"_on_packet has bare 'except Exception: return' at line {i+1}" diff --git a/tests/gateway/test_discord_slash_commands.py b/tests/gateway/test_discord_slash_commands.py index 78141a6395..3c441258cd 100644 --- a/tests/gateway/test_discord_slash_commands.py +++ b/tests/gateway/test_discord_slash_commands.py @@ -21,6 +21,8 @@ def _ensure_discord_mock(): discord_mod.Interaction = object discord_mod.app_commands = SimpleNamespace( describe=lambda **kwargs: (lambda fn: fn), + choices=lambda **kwargs: (lambda fn: fn), + Choice=lambda **kwargs: SimpleNamespace(**kwargs), ) ext_mod = MagicMock() diff --git a/tests/gateway/test_resume_command.py b/tests/gateway/test_resume_command.py index 17adcd2e74..987afbce32 100644 --- a/tests/gateway/test_resume_command.py +++ b/tests/gateway/test_resume_command.py @@ -36,6 +36,7 @@ def _make_runner(session_db=None, current_session_id="current_session_001", from gateway.run import GatewayRunner runner = object.__new__(GatewayRunner) runner.adapters = {} + runner._voice_mode = {} runner._session_db = session_db runner._running_agents = {} diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index 20ae712a20..66d13e0d01 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py @@ -77,6 +77,7 @@ def _make_runner(adapter): runner = object.__new__(GatewayRunner) runner.adapters = {Platform.TELEGRAM: adapter} + runner._voice_mode = {} runner._prefill_messages = [] runner._ephemeral_system_prompt = "" runner._reasoning_config = None diff --git a/tests/gateway/test_session_hygiene.py b/tests/gateway/test_session_hygiene.py index d627c20565..7e75b906d5 100644 --- a/tests/gateway/test_session_hygiene.py 
+++ b/tests/gateway/test_session_hygiene.py @@ -266,6 +266,7 @@ async def test_session_hygiene_messages_stay_in_originating_topic(monkeypatch, t platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="fake-token")} ) runner.adapters = {Platform.TELEGRAM: adapter} + runner._voice_mode = {} runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False) runner.session_store = MagicMock() runner.session_store.get_or_create_session.return_value = SessionEntry( diff --git a/tests/gateway/test_title_command.py b/tests/gateway/test_title_command.py index 7f7c782a71..d5bad6c57a 100644 --- a/tests/gateway/test_title_command.py +++ b/tests/gateway/test_title_command.py @@ -31,6 +31,7 @@ def _make_runner(session_db=None): from gateway.run import GatewayRunner runner = object.__new__(GatewayRunner) runner.adapters = {} + runner._voice_mode = {} runner._session_db = session_db # Mock session_store that returns a session entry with a known session_id diff --git a/tests/gateway/test_update_command.py b/tests/gateway/test_update_command.py index a76ce7c828..124745635f 100644 --- a/tests/gateway/test_update_command.py +++ b/tests/gateway/test_update_command.py @@ -33,6 +33,7 @@ def _make_runner(): from gateway.run import GatewayRunner runner = object.__new__(GatewayRunner) runner.adapters = {} + runner._voice_mode = {} return runner diff --git a/tests/gateway/test_voice_command.py b/tests/gateway/test_voice_command.py new file mode 100644 index 0000000000..545f2b28fb --- /dev/null +++ b/tests/gateway/test_voice_command.py @@ -0,0 +1,2033 @@ +"""Tests for the /voice command and auto voice reply in the gateway.""" + +import json +import os +import queue +import sys +import threading +import time +import pytest +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + + +def _ensure_discord_mock(): + """Install a lightweight discord mock when discord.py isn't available.""" + if "discord" in sys.modules and 
hasattr(sys.modules["discord"], "__file__"): + return + + discord_mod = MagicMock() + discord_mod.Intents.default.return_value = MagicMock() + discord_mod.Client = MagicMock + discord_mod.File = MagicMock + discord_mod.DMChannel = type("DMChannel", (), {}) + discord_mod.Thread = type("Thread", (), {}) + discord_mod.ForumChannel = type("ForumChannel", (), {}) + discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object) + discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3) + discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4) + discord_mod.Interaction = object + discord_mod.Embed = MagicMock + discord_mod.app_commands = SimpleNamespace( + describe=lambda **kwargs: (lambda fn: fn), + choices=lambda **kwargs: (lambda fn: fn), + Choice=lambda **kwargs: SimpleNamespace(**kwargs), + ) + discord_mod.opus = SimpleNamespace(is_loaded=lambda: True, load_opus=lambda *_args, **_kwargs: None) + discord_mod.FFmpegPCMAudio = MagicMock + discord_mod.PCMVolumeTransformer = MagicMock + discord_mod.http = SimpleNamespace(Route=MagicMock) + + ext_mod = MagicMock() + commands_mod = MagicMock() + commands_mod.Bot = MagicMock + ext_mod.commands = commands_mod + + sys.modules.setdefault("discord", discord_mod) + sys.modules.setdefault("discord.ext", ext_mod) + sys.modules.setdefault("discord.ext.commands", commands_mod) + + +_ensure_discord_mock() + +from gateway.platforms.base import MessageEvent, MessageType, SessionSource + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_event(text: str = "", message_type=MessageType.TEXT, chat_id="123") -> MessageEvent: + source = SessionSource( + chat_id=chat_id, + user_id="user1", + platform=MagicMock(), + ) + source.platform.value = "telegram" + source.thread_id = None + event 
= MessageEvent(text=text, message_type=message_type, source=source) + event.message_id = "msg42" + return event + + +def _make_runner(tmp_path): + """Create a bare GatewayRunner without calling __init__.""" + from gateway.run import GatewayRunner + runner = object.__new__(GatewayRunner) + runner.adapters = {} + runner._voice_mode = {} + runner._VOICE_MODE_PATH = tmp_path / "gateway_voice_mode.json" + runner._session_db = None + runner.session_store = MagicMock() + runner._is_user_authorized = lambda source: True + return runner + + +# ===================================================================== +# /voice command handler +# ===================================================================== + +class TestHandleVoiceCommand: + + @pytest.fixture + def runner(self, tmp_path): + return _make_runner(tmp_path) + + @pytest.mark.asyncio + async def test_voice_on(self, runner): + event = _make_event("/voice on") + result = await runner._handle_voice_command(event) + assert "enabled" in result.lower() + assert runner._voice_mode["123"] == "voice_only" + + @pytest.mark.asyncio + async def test_voice_off(self, runner): + runner._voice_mode["123"] = "voice_only" + event = _make_event("/voice off") + result = await runner._handle_voice_command(event) + assert "disabled" in result.lower() + assert runner._voice_mode["123"] == "off" + + @pytest.mark.asyncio + async def test_voice_tts(self, runner): + event = _make_event("/voice tts") + result = await runner._handle_voice_command(event) + assert "tts" in result.lower() + assert runner._voice_mode["123"] == "all" + + @pytest.mark.asyncio + async def test_voice_status_off(self, runner): + event = _make_event("/voice status") + result = await runner._handle_voice_command(event) + assert "off" in result.lower() + + @pytest.mark.asyncio + async def test_voice_status_on(self, runner): + runner._voice_mode["123"] = "voice_only" + event = _make_event("/voice status") + result = await runner._handle_voice_command(event) + assert 
"voice reply" in result.lower() + + @pytest.mark.asyncio + async def test_toggle_off_to_on(self, runner): + event = _make_event("/voice") + result = await runner._handle_voice_command(event) + assert "enabled" in result.lower() + assert runner._voice_mode["123"] == "voice_only" + + @pytest.mark.asyncio + async def test_toggle_on_to_off(self, runner): + runner._voice_mode["123"] = "voice_only" + event = _make_event("/voice") + result = await runner._handle_voice_command(event) + assert "disabled" in result.lower() + assert runner._voice_mode["123"] == "off" + + @pytest.mark.asyncio + async def test_persistence_saved(self, runner): + event = _make_event("/voice on") + await runner._handle_voice_command(event) + assert runner._VOICE_MODE_PATH.exists() + data = json.loads(runner._VOICE_MODE_PATH.read_text()) + assert data["123"] == "voice_only" + + @pytest.mark.asyncio + async def test_persistence_loaded(self, runner): + runner._VOICE_MODE_PATH.write_text(json.dumps({"456": "all"})) + loaded = runner._load_voice_modes() + assert loaded == {"456": "all"} + + @pytest.mark.asyncio + async def test_persistence_saved_for_off(self, runner): + event = _make_event("/voice off") + await runner._handle_voice_command(event) + data = json.loads(runner._VOICE_MODE_PATH.read_text()) + assert data["123"] == "off" + + def test_sync_voice_mode_state_to_adapter_restores_off_chats(self, runner): + runner._voice_mode = {"123": "off", "456": "all"} + adapter = SimpleNamespace(_auto_tts_disabled_chats=set()) + + runner._sync_voice_mode_state_to_adapter(adapter) + + assert adapter._auto_tts_disabled_chats == {"123"} + + def test_restart_restores_voice_off_state(self, runner, tmp_path): + runner._VOICE_MODE_PATH.write_text(json.dumps({"123": "off"})) + + restored_runner = _make_runner(tmp_path) + restored_runner._voice_mode = restored_runner._load_voice_modes() + adapter = SimpleNamespace(_auto_tts_disabled_chats=set()) + + restored_runner._sync_voice_mode_state_to_adapter(adapter) + + assert 
restored_runner._voice_mode["123"] == "off" + assert adapter._auto_tts_disabled_chats == {"123"} + + @pytest.mark.asyncio + async def test_per_chat_isolation(self, runner): + e1 = _make_event("/voice on", chat_id="aaa") + e2 = _make_event("/voice tts", chat_id="bbb") + await runner._handle_voice_command(e1) + await runner._handle_voice_command(e2) + assert runner._voice_mode["aaa"] == "voice_only" + assert runner._voice_mode["bbb"] == "all" + + +# ===================================================================== +# Auto voice reply decision logic +# ===================================================================== + +class TestAutoVoiceReply: + """Test the real _should_send_voice_reply method on GatewayRunner. + + The gateway has two TTS paths: + 1. base adapter auto-TTS: fires for voice input in _process_message_background + 2. gateway _send_voice_reply: fires based on voice_mode setting + + To prevent double audio, _send_voice_reply is skipped when voice input + already triggered base adapter auto-TTS (skip_double = is_voice_input). + Exception: Discord voice channel — both auto-TTS and Discord play_tts + override skip, so the runner must handle it via play_in_voice_channel. 
+ """ + + @pytest.fixture + def runner(self, tmp_path): + return _make_runner(tmp_path) + + def _call(self, runner, voice_mode, message_type, agent_messages=None, + response="Hello!", in_voice_channel=False): + """Call real _should_send_voice_reply on a GatewayRunner instance.""" + chat_id = "123" + if voice_mode != "off": + runner._voice_mode[chat_id] = voice_mode + else: + runner._voice_mode.pop(chat_id, None) + + event = _make_event(message_type=message_type) + + if in_voice_channel: + mock_adapter = MagicMock() + mock_adapter.is_in_voice_channel = MagicMock(return_value=True) + event.raw_message = SimpleNamespace(guild_id=111, guild=None) + runner.adapters[event.source.platform] = mock_adapter + + return runner._should_send_voice_reply( + event, response, agent_messages or [] + ) + + # -- Full platform x input x mode matrix -------------------------------- + # + # Legend: + # base = base adapter auto-TTS (play_tts) + # runner = gateway _send_voice_reply + # + # | Platform | Input | Mode | base | runner | Expected | + # |---------------|-------|------------|------|--------|--------------| + # | Telegram | voice | off | yes | skip | 1 audio | + # | Telegram | voice | voice_only | yes | skip* | 1 audio | + # | Telegram | voice | all | yes | skip* | 1 audio | + # | Telegram | text | off | skip | skip | 0 audio | + # | Telegram | text | voice_only | skip | skip | 0 audio | + # | Telegram | text | all | skip | yes | 1 audio | + # | Discord text | voice | all | yes | skip* | 1 audio | + # | Discord text | text | all | skip | yes | 1 audio | + # | Discord VC | voice | all | skip†| yes | 1 audio (VC) | + # | Web UI | voice | off | yes | skip | 1 audio | + # | Web UI | voice | all | yes | skip* | 1 audio | + # | Web UI | text | all | skip | yes | 1 audio | + # | Slack | voice | all | yes | skip* | 1 audio | + # | Slack | text | all | skip | yes | 1 audio | + # + # * skip_double: voice input → base already handles + # † Discord play_tts override skips when in VC + + # -- 
Telegram/Slack/Web: voice input, base handles --------------------- + + def test_voice_input_voice_only_skipped(self, runner): + """voice_only + voice input: base auto-TTS handles it, runner skips.""" + assert self._call(runner, "voice_only", MessageType.VOICE) is False + + def test_voice_input_all_mode_skipped(self, runner): + """all + voice input: base auto-TTS handles it, runner skips.""" + assert self._call(runner, "all", MessageType.VOICE) is False + + # -- Text input: only runner handles ----------------------------------- + + def test_text_input_all_mode_runner_fires(self, runner): + """all + text input: only runner fires (base auto-TTS only for voice).""" + assert self._call(runner, "all", MessageType.TEXT) is True + + def test_text_input_voice_only_no_reply(self, runner): + """voice_only + text input: neither fires.""" + assert self._call(runner, "voice_only", MessageType.TEXT) is False + + # -- Mode off: nothing fires ------------------------------------------- + + def test_off_mode_voice(self, runner): + assert self._call(runner, "off", MessageType.VOICE) is False + + def test_off_mode_text(self, runner): + assert self._call(runner, "off", MessageType.TEXT) is False + + # -- Discord VC exception: runner must handle -------------------------- + + def test_discord_vc_voice_input_runner_fires(self, runner): + """Discord VC + voice input: base play_tts skips (VC override), + so runner must handle via play_in_voice_channel.""" + assert self._call(runner, "all", MessageType.VOICE, in_voice_channel=True) is True + + def test_discord_vc_voice_only_runner_fires(self, runner): + """Discord VC + voice_only + voice: runner must handle.""" + assert self._call(runner, "voice_only", MessageType.VOICE, in_voice_channel=True) is True + + # -- Edge cases -------------------------------------------------------- + + def test_error_response_skipped(self, runner): + assert self._call(runner, "all", MessageType.TEXT, response="Error: boom") is False + + def 
test_empty_response_skipped(self, runner): + assert self._call(runner, "all", MessageType.TEXT, response="") is False + + def test_dedup_skips_when_agent_called_tts(self, runner): + messages = [{ + "role": "assistant", + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "text_to_speech", "arguments": "{}"}, + }], + }] + assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is False + + def test_no_dedup_for_other_tools(self, runner): + messages = [{ + "role": "assistant", + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "web_search", "arguments": "{}"}, + }], + }] + assert self._call(runner, "all", MessageType.TEXT, agent_messages=messages) is True + + +# ===================================================================== +# _send_voice_reply +# ===================================================================== + +class TestSendVoiceReply: + + @pytest.fixture + def runner(self, tmp_path): + return _make_runner(tmp_path) + + @pytest.mark.asyncio + async def test_calls_tts_and_send_voice(self, runner): + mock_adapter = AsyncMock() + mock_adapter.send_voice = AsyncMock() + event = _make_event() + runner.adapters[event.source.platform] = mock_adapter + + tts_result = json.dumps({"success": True, "file_path": "/tmp/test.ogg"}) + + with patch("tools.tts_tool.text_to_speech_tool", return_value=tts_result), \ + patch("tools.tts_tool._strip_markdown_for_tts", side_effect=lambda t: t), \ + patch("os.path.isfile", return_value=True), \ + patch("os.unlink"), \ + patch("os.makedirs"): + await runner._send_voice_reply(event, "Hello world") + + mock_adapter.send_voice.assert_called_once() + call_args = mock_adapter.send_voice.call_args + assert call_args.kwargs.get("chat_id") == "123" + + @pytest.mark.asyncio + async def test_empty_text_after_strip_skips(self, runner): + event = _make_event() + + with patch("tools.tts_tool.text_to_speech_tool") as mock_tts, \ + 
patch("tools.tts_tool._strip_markdown_for_tts", return_value=""): + await runner._send_voice_reply(event, "```code only```") + + mock_tts.assert_not_called() + + @pytest.mark.asyncio + async def test_tts_failure_no_crash(self, runner): + event = _make_event() + mock_adapter = AsyncMock() + runner.adapters[event.source.platform] = mock_adapter + tts_result = json.dumps({"success": False, "error": "API error"}) + + with patch("tools.tts_tool.text_to_speech_tool", return_value=tts_result), \ + patch("tools.tts_tool._strip_markdown_for_tts", side_effect=lambda t: t), \ + patch("os.path.isfile", return_value=False), \ + patch("os.makedirs"): + await runner._send_voice_reply(event, "Hello") + + mock_adapter.send_voice.assert_not_called() + + @pytest.mark.asyncio + async def test_exception_caught(self, runner): + event = _make_event() + with patch("tools.tts_tool.text_to_speech_tool", side_effect=RuntimeError("boom")), \ + patch("tools.tts_tool._strip_markdown_for_tts", side_effect=lambda t: t), \ + patch("os.makedirs"): + # Should not raise + await runner._send_voice_reply(event, "Hello") + + +# ===================================================================== +# Discord play_tts skip when in voice channel +# ===================================================================== + +class TestDiscordPlayTtsSkip: + """Discord adapter skips play_tts when bot is in a voice channel.""" + + def _make_discord_adapter(self): + from gateway.platforms.discord import DiscordAdapter + from gateway.config import Platform, PlatformConfig + config = PlatformConfig(enabled=True, extra={}) + config.token = "fake-token" + adapter = object.__new__(DiscordAdapter) + adapter.platform = Platform.DISCORD + adapter.config = config + adapter._voice_clients = {} + adapter._voice_text_channels = {} + adapter._voice_timeout_tasks = {} + adapter._voice_receivers = {} + adapter._voice_listen_tasks = {} + adapter._client = None + adapter._broadcast = AsyncMock() + return adapter + + 
@pytest.mark.asyncio + async def test_play_tts_skipped_when_in_vc(self): + adapter = self._make_discord_adapter() + # Simulate bot in voice channel for guild 111, text channel 123 + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + adapter._voice_clients[111] = mock_vc + adapter._voice_text_channels[111] = 123 + + result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg") + assert result.success is True + # send_voice should NOT have been called (no client, would fail) + + @pytest.mark.asyncio + async def test_play_tts_not_skipped_when_not_in_vc(self): + adapter = self._make_discord_adapter() + # No voice connection — play_tts falls through to send_voice + result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg") + # send_voice will fail (no client), but play_tts should NOT return early + assert result.success is False + + @pytest.mark.asyncio + async def test_play_tts_not_skipped_for_different_channel(self): + adapter = self._make_discord_adapter() + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + adapter._voice_clients[111] = mock_vc + adapter._voice_text_channels[111] = 999 # different channel + + result = await adapter.play_tts(chat_id="123", audio_path="/tmp/test.ogg") + # Different channel — should NOT skip, falls through to send_voice (fails) + assert result.success is False + + +# ===================================================================== +# Web play_tts sends play_audio (not voice bubble) +# ===================================================================== + +# ===================================================================== +# Help text + known commands +# ===================================================================== + +class TestVoiceInHelp: + + def test_voice_in_help_output(self): + from gateway.run import GatewayRunner + import inspect + source = inspect.getsource(GatewayRunner._handle_help_command) + assert "/voice" in source + + def 
test_voice_is_known_command(self): + from gateway.run import GatewayRunner + import inspect + source = inspect.getsource(GatewayRunner._handle_message) + assert '"voice"' in source + + +# ===================================================================== +# VoiceReceiver unit tests +# ===================================================================== + +class TestVoiceReceiver: + """Test VoiceReceiver silence detection, SSRC mapping, and lifecycle.""" + + def _make_receiver(self): + from gateway.platforms.discord import VoiceReceiver + mock_vc = MagicMock() + mock_vc._connection.secret_key = [0] * 32 + mock_vc._connection.dave_session = None + mock_vc._connection.ssrc = 9999 + mock_vc._connection.add_socket_listener = MagicMock() + mock_vc._connection.remove_socket_listener = MagicMock() + mock_vc._connection.hook = None + receiver = VoiceReceiver(mock_vc) + return receiver + + def test_initial_state(self): + receiver = self._make_receiver() + assert receiver._running is False + assert receiver._paused is False + assert len(receiver._buffers) == 0 + assert len(receiver._ssrc_to_user) == 0 + + def test_start_sets_running(self): + receiver = self._make_receiver() + receiver.start() + assert receiver._running is True + + def test_stop_clears_state(self): + receiver = self._make_receiver() + receiver.start() + receiver.map_ssrc(100, 42) + receiver._buffers[100] = bytearray(b"\x00" * 1000) + receiver._last_packet_time[100] = time.monotonic() + receiver.stop() + assert receiver._running is False + assert len(receiver._buffers) == 0 + assert len(receiver._ssrc_to_user) == 0 + assert len(receiver._last_packet_time) == 0 + + def test_map_ssrc(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + assert receiver._ssrc_to_user[100] == 42 + + def test_map_ssrc_overwrites(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + receiver.map_ssrc(100, 99) + assert receiver._ssrc_to_user[100] == 99 + + def test_pause_resume(self): + 
receiver = self._make_receiver() + assert receiver._paused is False + receiver.pause() + assert receiver._paused is True + receiver.resume() + assert receiver._paused is False + + def test_check_silence_empty(self): + receiver = self._make_receiver() + assert receiver.check_silence() == [] + + def test_check_silence_returns_completed_utterance(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + # 48kHz, stereo, 16-bit = 192000 bytes/sec + # MIN_SPEECH_DURATION = 0.5s → need 96000 bytes + pcm_data = bytearray(b"\x00" * 96000) + receiver._buffers[100] = pcm_data + # Set last_packet_time far enough in the past to exceed SILENCE_THRESHOLD + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 1 + user_id, data = completed[0] + assert user_id == 42 + assert len(data) == 96000 + # Buffer should be cleared after extraction + assert len(receiver._buffers[100]) == 0 + + def test_check_silence_ignores_short_buffer(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + # Too short to meet MIN_SPEECH_DURATION + receiver._buffers[100] = bytearray(b"\x00" * 100) + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_check_silence_ignores_recent_audio(self): + receiver = self._make_receiver() + receiver.map_ssrc(100, 42) + receiver._buffers[100] = bytearray(b"\x00" * 96000) + receiver._last_packet_time[100] = time.monotonic() # just now + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_check_silence_unknown_user_discarded(self): + receiver = self._make_receiver() + # No SSRC mapping — user_id will be 0 + receiver._buffers[100] = bytearray(b"\x00" * 96000) + receiver._last_packet_time[100] = time.monotonic() - 3.0 + completed = receiver.check_silence() + assert len(completed) == 0 + + def test_stale_buffer_discarded(self): + receiver = self._make_receiver() + 
# Buffer with no user mapping and very old timestamp + receiver._buffers[200] = bytearray(b"\x00" * 100) + receiver._last_packet_time[200] = time.monotonic() - 10.0 + receiver.check_silence() + # Stale buffer (> 2x threshold) should be discarded + assert 200 not in receiver._buffers + + def test_on_packet_skips_when_not_running(self): + receiver = self._make_receiver() + # Not started — _running is False + receiver._on_packet(b"\x00" * 100) + assert len(receiver._buffers) == 0 + + def test_on_packet_skips_when_paused(self): + receiver = self._make_receiver() + receiver.start() + receiver.pause() + receiver._on_packet(b"\x00" * 100) + # Paused — should not process + assert len(receiver._buffers) == 0 + + def test_on_packet_skips_short_data(self): + receiver = self._make_receiver() + receiver.start() + receiver._on_packet(b"\x00" * 10) + assert len(receiver._buffers) == 0 + + def test_on_packet_skips_non_rtp(self): + receiver = self._make_receiver() + receiver.start() + # Valid length but wrong RTP version + data = bytearray(b"\x00" * 20) + data[0] = 0x00 # version 0, not 2 + receiver._on_packet(bytes(data)) + assert len(receiver._buffers) == 0 + + +# ===================================================================== +# Gateway voice channel commands (join / leave / input) +# ===================================================================== + +class TestVoiceChannelCommands: + """Test _handle_voice_channel_join, _handle_voice_channel_leave, + _handle_voice_channel_input on the GatewayRunner.""" + + @pytest.fixture + def runner(self, tmp_path): + return _make_runner(tmp_path) + + def _make_discord_event(self, text="/voice channel", chat_id="123", + guild_id=111, user_id="user1"): + """Create event with raw_message carrying guild info.""" + source = SessionSource( + chat_id=chat_id, + user_id=user_id, + platform=MagicMock(), + ) + source.platform.value = "discord" + source.thread_id = None + event = MessageEvent(text=text, message_type=MessageType.TEXT, 
source=source) + event.message_id = "msg42" + event.raw_message = SimpleNamespace(guild_id=guild_id, guild=None) + return event + + # -- _handle_voice_channel_join -- + + @pytest.mark.asyncio + async def test_join_unsupported_platform(self, runner): + """Platform without join_voice_channel returns unsupported message.""" + mock_adapter = AsyncMock(spec=[]) # no join_voice_channel + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "not supported" in result.lower() + + @pytest.mark.asyncio + async def test_join_no_guild_id(self, runner): + """DM context (no guild_id) returns error.""" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock() + event = self._make_discord_event() + event.raw_message = None # no guild info + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "discord server" in result.lower() + + @pytest.mark.asyncio + async def test_join_user_not_in_vc(self, runner): + """User not in any voice channel.""" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock() + mock_adapter.get_user_voice_channel = AsyncMock(return_value=None) + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "need to be in a voice channel" in result.lower() + + @pytest.mark.asyncio + async def test_join_success(self, runner): + """Successful join sets voice_mode and returns confirmation.""" + mock_channel = MagicMock() + mock_channel.name = "General" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock(return_value=True) + mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel) + mock_adapter._voice_text_channels = {} + mock_adapter._voice_input_callback = None + event = self._make_discord_event() + 
runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "joined" in result.lower() + assert "General" in result + assert runner._voice_mode["123"] == "all" + + @pytest.mark.asyncio + async def test_join_failure(self, runner): + """Failed join returns permissions error.""" + mock_channel = MagicMock() + mock_channel.name = "General" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock(return_value=False) + mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel) + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "failed" in result.lower() + + @pytest.mark.asyncio + async def test_join_exception(self, runner): + """Exception during join is caught and reported.""" + mock_channel = MagicMock() + mock_channel.name = "General" + mock_adapter = AsyncMock() + mock_adapter.join_voice_channel = AsyncMock(side_effect=RuntimeError("No permission")) + mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel) + event = self._make_discord_event() + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_join(event) + assert "failed" in result.lower() + + # -- _handle_voice_channel_leave -- + + @pytest.mark.asyncio + async def test_leave_not_in_vc(self, runner): + """Leave when not in VC returns appropriate message.""" + mock_adapter = AsyncMock() + mock_adapter.is_in_voice_channel = MagicMock(return_value=False) + event = self._make_discord_event("/voice leave") + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_leave(event) + assert "not in" in result.lower() + + @pytest.mark.asyncio + async def test_leave_no_guild(self, runner): + """Leave from DM returns not in voice channel.""" + mock_adapter = AsyncMock() + event = self._make_discord_event("/voice leave") 
+ event.raw_message = None + runner.adapters[event.source.platform] = mock_adapter + result = await runner._handle_voice_channel_leave(event) + assert "not in" in result.lower() + + @pytest.mark.asyncio + async def test_leave_success(self, runner): + """Successful leave disconnects and clears voice mode.""" + mock_adapter = AsyncMock() + mock_adapter.is_in_voice_channel = MagicMock(return_value=True) + mock_adapter.leave_voice_channel = AsyncMock() + event = self._make_discord_event("/voice leave") + runner.adapters[event.source.platform] = mock_adapter + runner._voice_mode["123"] = "all" + result = await runner._handle_voice_channel_leave(event) + assert "left" in result.lower() + assert runner._voice_mode["123"] == "off" + mock_adapter.leave_voice_channel.assert_called_once_with(111) + + # -- _handle_voice_channel_input -- + + @pytest.mark.asyncio + async def test_input_no_adapter(self, runner): + """No Discord adapter — early return, no crash.""" + from gateway.config import Platform + # No adapters set + await runner._handle_voice_channel_input(111, 42, "Hello") + + @pytest.mark.asyncio + async def test_input_no_text_channel(self, runner): + """No text channel mapped for guild — early return.""" + from gateway.config import Platform + mock_adapter = AsyncMock() + mock_adapter._voice_text_channels = {} + mock_adapter._client = MagicMock() + runner.adapters[Platform.DISCORD] = mock_adapter + await runner._handle_voice_channel_input(111, 42, "Hello") + + @pytest.mark.asyncio + async def test_input_creates_event_and_dispatches(self, runner): + """Voice input creates synthetic event and calls handle_message.""" + from gateway.config import Platform + mock_adapter = AsyncMock() + mock_adapter._voice_text_channels = {111: 123} + mock_channel = AsyncMock() + mock_adapter._client = MagicMock() + mock_adapter._client.get_channel = MagicMock(return_value=mock_channel) + mock_adapter.handle_message = AsyncMock() + runner.adapters[Platform.DISCORD] = mock_adapter + await 
runner._handle_voice_channel_input(111, 42, "Hello from VC") + mock_adapter.handle_message.assert_called_once() + event = mock_adapter.handle_message.call_args[0][0] + assert event.text == "Hello from VC" + assert event.message_type == MessageType.VOICE + assert event.source.chat_id == "123" + assert event.source.chat_type == "channel" + + @pytest.mark.asyncio + async def test_input_posts_transcript_in_text_channel(self, runner): + """Voice input sends transcript message to text channel.""" + from gateway.config import Platform + mock_adapter = AsyncMock() + mock_adapter._voice_text_channels = {111: 123} + mock_channel = AsyncMock() + mock_adapter._client = MagicMock() + mock_adapter._client.get_channel = MagicMock(return_value=mock_channel) + mock_adapter.handle_message = AsyncMock() + runner.adapters[Platform.DISCORD] = mock_adapter + await runner._handle_voice_channel_input(111, 42, "Test transcript") + mock_channel.send.assert_called_once() + msg = mock_channel.send.call_args[0][0] + assert "Test transcript" in msg + assert "42" in msg # user_id in mention + + # -- _get_guild_id -- + + def test_get_guild_id_from_guild(self, runner): + event = _make_event() + mock_guild = MagicMock() + mock_guild.id = 555 + event.raw_message = SimpleNamespace(guild_id=None, guild=mock_guild) + result = runner._get_guild_id(event) + assert result == 555 + + def test_get_guild_id_from_interaction(self, runner): + event = _make_event() + event.raw_message = SimpleNamespace(guild_id=777, guild=None) + result = runner._get_guild_id(event) + assert result == 777 + + def test_get_guild_id_none(self, runner): + event = _make_event() + event.raw_message = None + result = runner._get_guild_id(event) + assert result is None + + def test_get_guild_id_dm(self, runner): + event = _make_event() + event.raw_message = SimpleNamespace(guild_id=None, guild=None) + result = runner._get_guild_id(event) + assert result is None + + +# 
# =====================================================================
# Discord adapter voice channel methods
# =====================================================================

class TestDiscordVoiceChannelMethods:
    """Test DiscordAdapter voice channel methods (join, leave, play, etc.)."""

    def _make_adapter(self):
        # Build a DiscordAdapter without running __init__ (which would
        # require a live discord client); wire only the attributes the
        # voice-channel methods read.
        from gateway.platforms.discord import DiscordAdapter
        from gateway.config import Platform, PlatformConfig
        config = PlatformConfig(enabled=True, extra={})
        config.token = "fake-token"
        adapter = object.__new__(DiscordAdapter)
        adapter.platform = Platform.DISCORD
        adapter.config = config
        adapter._client = MagicMock()
        adapter._voice_clients = {}
        adapter._voice_text_channels = {}
        adapter._voice_timeout_tasks = {}
        adapter._voice_receivers = {}
        adapter._voice_listen_tasks = {}
        adapter._voice_input_callback = None
        adapter._allowed_user_ids = set()
        adapter._running = True
        return adapter

    def test_is_in_voice_channel_true(self):
        adapter = self._make_adapter()
        mock_vc = MagicMock()
        mock_vc.is_connected.return_value = True
        adapter._voice_clients[111] = mock_vc
        assert adapter.is_in_voice_channel(111) is True

    def test_is_in_voice_channel_false_no_client(self):
        adapter = self._make_adapter()
        assert adapter.is_in_voice_channel(111) is False

    def test_is_in_voice_channel_false_disconnected(self):
        adapter = self._make_adapter()
        mock_vc = MagicMock()
        mock_vc.is_connected.return_value = False
        adapter._voice_clients[111] = mock_vc
        assert adapter.is_in_voice_channel(111) is False

    @pytest.mark.asyncio
    async def test_leave_voice_channel_cleans_up(self):
        adapter = self._make_adapter()
        mock_vc = MagicMock()
        mock_vc.is_connected.return_value = True
        mock_vc.disconnect = AsyncMock()
        adapter._voice_clients[111] = mock_vc
        adapter._voice_text_channels[111] = 123

        mock_receiver = MagicMock()
        adapter._voice_receivers[111] = mock_receiver

        mock_task = MagicMock()
        adapter._voice_listen_tasks[111] = mock_task

        mock_timeout = MagicMock()
        adapter._voice_timeout_tasks[111] = mock_timeout

        await adapter.leave_voice_channel(111)

        mock_receiver.stop.assert_called_once()
        mock_task.cancel.assert_called_once()
        mock_vc.disconnect.assert_called_once()
        mock_timeout.cancel.assert_called_once()
        assert 111 not in adapter._voice_clients
        assert 111 not in adapter._voice_text_channels
        assert 111 not in adapter._voice_receivers

    @pytest.mark.asyncio
    async def test_leave_voice_channel_no_connection(self):
        """Leave when not connected — no crash."""
        adapter = self._make_adapter()
        await adapter.leave_voice_channel(111)  # should not raise

    @pytest.mark.asyncio
    async def test_get_user_voice_channel_no_client(self):
        adapter = self._make_adapter()
        adapter._client = None
        result = await adapter.get_user_voice_channel(111, "42")
        assert result is None

    @pytest.mark.asyncio
    async def test_get_user_voice_channel_no_guild(self):
        adapter = self._make_adapter()
        adapter._client.get_guild = MagicMock(return_value=None)
        result = await adapter.get_user_voice_channel(111, "42")
        assert result is None

    @pytest.mark.asyncio
    async def test_get_user_voice_channel_user_not_in_vc(self):
        adapter = self._make_adapter()
        mock_guild = MagicMock()
        mock_member = MagicMock()
        mock_member.voice = None
        mock_guild.get_member = MagicMock(return_value=mock_member)
        adapter._client.get_guild = MagicMock(return_value=mock_guild)
        result = await adapter.get_user_voice_channel(111, "42")
        assert result is None

    @pytest.mark.asyncio
    async def test_get_user_voice_channel_success(self):
        adapter = self._make_adapter()
        mock_vc = MagicMock()
        mock_guild = MagicMock()
        mock_member = MagicMock()
        mock_member.voice = MagicMock()
        mock_member.voice.channel = mock_vc
        mock_guild.get_member = MagicMock(return_value=mock_member)
        adapter._client.get_guild = MagicMock(return_value=mock_guild)
        result = await adapter.get_user_voice_channel(111, "42")
        assert result is mock_vc

    @pytest.mark.asyncio
    async def test_play_in_voice_channel_not_connected(self):
        adapter = self._make_adapter()
        result = await adapter.play_in_voice_channel(111, "/tmp/test.ogg")
        assert result is False

    def test_is_allowed_user_empty_list(self):
        adapter = self._make_adapter()
        assert adapter._is_allowed_user("42") is True

    def test_is_allowed_user_in_list(self):
        adapter = self._make_adapter()
        adapter._allowed_user_ids = {"42", "99"}
        assert adapter._is_allowed_user("42") is True

    def test_is_allowed_user_not_in_list(self):
        adapter = self._make_adapter()
        adapter._allowed_user_ids = {"99"}
        assert adapter._is_allowed_user("42") is False

    @pytest.mark.asyncio
    async def test_process_voice_input_success(self):
        """Successful voice input: PCM->WAV->STT->callback."""
        adapter = self._make_adapter()
        callback = AsyncMock()
        adapter._voice_input_callback = callback
        adapter._allowed_user_ids = set()

        pcm_data = b"\x00" * 96000

        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
             patch("tools.transcription_tools.transcribe_audio",
                   return_value={"success": True, "transcript": "Hello"}), \
             patch("tools.voice_mode.is_whisper_hallucination", return_value=False):
            await adapter._process_voice_input(111, 42, pcm_data)

        callback.assert_called_once_with(guild_id=111, user_id=42, transcript="Hello")

    @pytest.mark.asyncio
    async def test_process_voice_input_hallucination_filtered(self):
        """Whisper hallucination is filtered out."""
        adapter = self._make_adapter()
        callback = AsyncMock()
        adapter._voice_input_callback = callback

        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
             patch("tools.transcription_tools.transcribe_audio",
                   return_value={"success": True, "transcript": "Thank you."}), \
             patch("tools.voice_mode.is_whisper_hallucination", return_value=True):
            await adapter._process_voice_input(111, 42, b"\x00" * 96000)

        callback.assert_not_called()

    @pytest.mark.asyncio
    async def test_process_voice_input_stt_failure(self):
        """STT failure — callback not called."""
        adapter = self._make_adapter()
        callback = AsyncMock()
        adapter._voice_input_callback = callback

        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav"), \
             patch("tools.transcription_tools.transcribe_audio",
                   return_value={"success": False, "error": "API error"}):
            await adapter._process_voice_input(111, 42, b"\x00" * 96000)

        callback.assert_not_called()

    @pytest.mark.asyncio
    async def test_process_voice_input_exception_caught(self):
        """Exception during processing is caught, no crash."""
        adapter = self._make_adapter()
        adapter._voice_input_callback = AsyncMock()

        with patch("gateway.platforms.discord.VoiceReceiver.pcm_to_wav",
                   side_effect=RuntimeError("ffmpeg not found")):
            await adapter._process_voice_input(111, 42, b"\x00" * 96000)
        # Should not raise


# =====================================================================
# stream_tts_to_speaker functional tests
# =====================================================================

# =====================================================================
# VoiceReceiver thread-safety (lock coverage)
# =====================================================================

class TestVoiceReceiverThreadSafety:
    """Verify that VoiceReceiver buffer access is protected by lock."""

    def _make_receiver(self):
        from gateway.platforms.discord import VoiceReceiver
        mock_vc = MagicMock()
        mock_vc._connection.secret_key = [0] * 32
        mock_vc._connection.dave_session = None
        mock_vc._connection.ssrc = 9999
        mock_vc._connection.add_socket_listener = MagicMock()
        mock_vc._connection.remove_socket_listener = MagicMock()
        mock_vc._connection.hook = None
        return VoiceReceiver(mock_vc)

    def test_check_silence_holds_lock(self):
        """check_silence must hold lock while iterating buffers."""
        import ast, inspect, textwrap
        from gateway.platforms.discord import VoiceReceiver
        source = textwrap.dedent(inspect.getsource(VoiceReceiver.check_silence))
        tree = ast.parse(source)
        # Find 'with self._lock:' that contains buffer iteration
        found_lock_with_for = False
        for node in ast.walk(tree):
            if isinstance(node, ast.With):
                # Check if lock context and contains for loop
                has_lock = any(
                    "lock" in ast.dump(item) for item in node.items
                )
                has_for = any(isinstance(n, ast.For) for n in ast.walk(node))
                if has_lock and has_for:
                    found_lock_with_for = True
        assert found_lock_with_for, (
            "check_silence must hold self._lock while iterating buffers"
        )

    def test_on_packet_buffer_write_holds_lock(self):
        """_on_packet must hold lock when writing to buffers."""
        import ast, inspect, textwrap
        from gateway.platforms.discord import VoiceReceiver
        source = textwrap.dedent(inspect.getsource(VoiceReceiver._on_packet))
        tree = ast.parse(source)
        # Find 'with self._lock:' that contains buffer extend
        found_lock_with_extend = False
        for node in ast.walk(tree):
            if isinstance(node, ast.With):
                src_fragment = ast.dump(node)
                if "lock" in src_fragment and "extend" in src_fragment:
                    found_lock_with_extend = True
        assert found_lock_with_extend, (
            "_on_packet must hold self._lock when extending buffers"
        )

    def test_concurrent_buffer_access_safe(self):
        """Simulate concurrent buffer writes and reads under lock."""
        import threading
        receiver = self._make_receiver()
        receiver.start()
        errors = []

        def writer():
            for _ in range(1000):
                with receiver._lock:
                    receiver._buffers[100].extend(b"\x00" * 192)
                    receiver._last_packet_time[100] = time.monotonic()

        def reader():
            for _ in range(1000):
                try:
                    receiver.check_silence()
                except Exception as e:
                    errors.append(str(e))

        t1 = threading.Thread(target=writer)
        t2 = threading.Thread(target=reader)
        t1.start()
        t2.start()
        t1.join()
        t2.join()
        assert len(errors) == 0, f"Race detected: {errors[:3]}"
# =====================================================================
# Callback wiring order (join)
# =====================================================================

class TestCallbackWiringOrder:
    """Verify callback is wired BEFORE join, not after."""

    def test_callback_set_before_join(self):
        """_handle_voice_channel_join wires callback before calling join."""
        import inspect
        from gateway.run import GatewayRunner
        source = inspect.getsource(GatewayRunner._handle_voice_channel_join)
        lines = source.split("\n")
        callback_line = None
        join_line = None
        for i, line in enumerate(lines):
            # First assignment to _voice_input_callback (excluding the
            # None-reset in cleanup paths) marks where it is wired.
            if "_voice_input_callback" in line and "=" in line and "None" not in line:
                if callback_line is None:
                    callback_line = i
            if "join_voice_channel" in line and "await" in line:
                join_line = i
        assert callback_line is not None, "callback wiring not found"
        assert join_line is not None, "join_voice_channel call not found"
        assert callback_line < join_line, (
            f"callback must be wired (line {callback_line}) BEFORE "
            f"join_voice_channel (line {join_line})"
        )

    @pytest.mark.asyncio
    async def test_join_failure_clears_callback(self, tmp_path):
        """If join fails with exception, callback is cleaned up."""
        runner = _make_runner(tmp_path)

        mock_channel = MagicMock()
        mock_channel.name = "General"
        mock_adapter = AsyncMock()
        mock_adapter.join_voice_channel = AsyncMock(
            side_effect=RuntimeError("No permission")
        )
        mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
        mock_adapter._voice_input_callback = None

        event = _make_event("/voice channel")
        event.raw_message = SimpleNamespace(guild_id=111, guild=None)
        runner.adapters[event.source.platform] = mock_adapter

        result = await runner._handle_voice_channel_join(event)
        assert "failed" in result.lower()
        assert mock_adapter._voice_input_callback is None

    @pytest.mark.asyncio
    async def test_join_returns_false_clears_callback(self, tmp_path):
        """If join returns False, callback is cleaned up."""
        runner = _make_runner(tmp_path)

        mock_channel = MagicMock()
        mock_channel.name = "General"
        mock_adapter = AsyncMock()
        mock_adapter.join_voice_channel = AsyncMock(return_value=False)
        mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
        mock_adapter._voice_input_callback = None

        event = _make_event("/voice channel")
        event.raw_message = SimpleNamespace(guild_id=111, guild=None)
        runner.adapters[event.source.platform] = mock_adapter

        result = await runner._handle_voice_channel_join(event)
        assert "failed" in result.lower()
        assert mock_adapter._voice_input_callback is None


# =====================================================================
# Leave exception handling
# =====================================================================

class TestLeaveExceptionHandling:
    """Verify state is cleaned up even when leave_voice_channel raises."""

    @pytest.fixture
    def runner(self, tmp_path):
        return _make_runner(tmp_path)

    @pytest.mark.asyncio
    async def test_leave_exception_still_cleans_state(self, runner):
        """If leave_voice_channel raises, voice_mode is still cleaned up."""
        mock_adapter = AsyncMock()
        mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
        mock_adapter.leave_voice_channel = AsyncMock(
            side_effect=RuntimeError("Connection reset")
        )
        mock_adapter._voice_input_callback = MagicMock()

        event = _make_event("/voice leave")
        event.raw_message = SimpleNamespace(guild_id=111, guild=None)
        runner.adapters[event.source.platform] = mock_adapter
        runner._voice_mode["123"] = "all"

        result = await runner._handle_voice_channel_leave(event)
        assert "left" in result.lower()
        assert runner._voice_mode["123"] == "off"
        assert mock_adapter._voice_input_callback is None

    @pytest.mark.asyncio
    async def test_leave_clears_callback(self, runner):
        """Normal leave also clears the voice input callback."""
        mock_adapter = AsyncMock()
        mock_adapter.is_in_voice_channel = MagicMock(return_value=True)
        mock_adapter.leave_voice_channel = AsyncMock()
        mock_adapter._voice_input_callback = MagicMock()

        event = _make_event("/voice leave")
        event.raw_message = SimpleNamespace(guild_id=111, guild=None)
        runner.adapters[event.source.platform] = mock_adapter
        runner._voice_mode["123"] = "all"

        await runner._handle_voice_channel_leave(event)
        assert mock_adapter._voice_input_callback is None


# =====================================================================
# Base adapter empty text guard
# =====================================================================

class TestAutoTtsEmptyTextGuard:
    """Verify base adapter skips TTS when text is empty after markdown strip."""

    def test_empty_after_strip_skips_tts(self):
        """Markdown-only content should not trigger TTS call."""
        import re
        text_content = "****"
        speech_text = re.sub(r'[*_`#\[\]()]', '', text_content)[:4000].strip()
        assert not speech_text, "Expected empty after stripping markdown chars"

    def test_code_block_response_skips_tts(self):
        """Code-only response results in empty speech text."""
        import re
        text_content = "```python\nprint(1)\n```"
        speech_text = re.sub(r'[*_`#\[\]()]', '', text_content)[:4000].strip()
        # Note: base.py regex only strips individual chars, not full code blocks
        # So code blocks are partially stripped but may leave content
        # The real fix is in base.py — empty check after strip
        assert speech_text == "python\nprint1", (
            "char-level strip leaves code text behind; base.py must guard separately"
        )

    def test_base_empty_check_in_source(self):
        """base.py must check speech_text is non-empty before calling TTS."""
        import inspect
        from gateway.platforms.base import BasePlatformAdapter
        source = inspect.getsource(BasePlatformAdapter._process_message_background)
        assert "if not speech_text" in source or "not speech_text" in source, (
            "base.py must guard against empty speech_text before TTS call"
        )
class TestStreamTtsToSpeaker:
    """Functional tests for the streaming TTS pipeline."""

    def test_none_sentinel_flushes_buffer(self):
        """None sentinel causes remaining buffer to be spoken."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()
        spoken = []

        def display(text):
            spoken.append(text)

        text_q.put("Hello world.")
        text_q.put(None)

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=display)
        assert done_evt.is_set()
        assert any("Hello" in s for s in spoken)

    def test_stop_event_aborts_early(self):
        """Setting stop_event causes early exit."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()
        spoken = []

        stop_evt.set()
        text_q.put("Should not be spoken.")
        text_q.put(None)

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
        assert done_evt.is_set()
        assert len(spoken) == 0

    def test_done_event_set_on_exception(self):
        """tts_done_event is set even when an exception occurs."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()

        # Put a non-string that will cause concatenation to fail
        text_q.put(12345)
        text_q.put(None)

        stream_tts_to_speaker(text_q, stop_evt, done_evt)
        assert done_evt.is_set()

    def test_think_blocks_stripped(self):
        """<think>...</think> content is not spoken."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()
        spoken = []

        # NOTE(review): the <think> tags had been garbled out of this test
        # (markup stripping); restored so the hidden-reasoning filter is
        # actually exercised — without them the not-in assertion would fail.
        text_q.put("<think>internal reasoning</think>")
        text_q.put("Visible response. ")
        text_q.put(None)

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
        assert done_evt.is_set()
        joined = " ".join(spoken)
        assert "internal reasoning" not in joined
        assert "Visible" in joined

    def test_sentence_splitting(self):
        """Sentences are split at boundaries and spoken individually."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()
        spoken = []

        # Two sentences long enough to exceed min_sentence_len (20)
        text_q.put("This is the first sentence. ")
        text_q.put("This is the second sentence. ")
        text_q.put(None)

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
        assert done_evt.is_set()
        assert len(spoken) >= 2

    def test_markdown_stripped_in_speech(self):
        """Markdown formatting is removed before display/speech."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()
        spoken = []

        text_q.put("**Bold text** and `code`. ")
        text_q.put(None)

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
        assert done_evt.is_set()
        # Display callback gets raw text (before markdown stripping)
        # But the actual TTS audio would be stripped — we verify pipeline doesn't crash

    def test_duplicate_sentences_deduped(self):
        """Repeated sentences are spoken only once."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()
        spoken = []

        # Same sentence twice, each long enough
        text_q.put("This is a repeated sentence. ")
        text_q.put("This is a repeated sentence. ")
        text_q.put(None)

        stream_tts_to_speaker(text_q, stop_evt, done_evt, display_callback=lambda t: spoken.append(t))
        assert done_evt.is_set()
        # First occurrence is spoken, second is deduped
        assert len(spoken) == 1

    def test_no_api_key_display_only(self):
        """Without ELEVENLABS_API_KEY, display callback still works."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()
        spoken = []

        text_q.put("Display only text. ")
        text_q.put(None)

        with patch.dict(os.environ, {"ELEVENLABS_API_KEY": ""}):
            stream_tts_to_speaker(text_q, stop_evt, done_evt,
                                  display_callback=lambda t: spoken.append(t))
        assert done_evt.is_set()
        assert len(spoken) >= 1

    def test_long_buffer_flushed_on_timeout(self):
        """Buffer longer than long_flush_len is flushed on queue timeout."""
        from tools.tts_tool import stream_tts_to_speaker
        text_q = queue.Queue()
        stop_evt = threading.Event()
        done_evt = threading.Event()
        spoken = []

        # Put a long text without sentence boundary, then None after a delay
        long_text = "a" * 150  # > long_flush_len (100)
        text_q.put(long_text)

        def delayed_sentinel():
            time.sleep(1.0)
            text_q.put(None)

        t = threading.Thread(target=delayed_sentinel, daemon=True)
        t.start()

        stream_tts_to_speaker(text_q, stop_evt, done_evt,
                              display_callback=lambda t: spoken.append(t))
        t.join(timeout=5)
        assert done_evt.is_set()
        assert len(spoken) >= 1


# =====================================================================
# Bug 1: VoiceReceiver.stop() must hold lock while clearing shared state
# =====================================================================

class TestStopAcquiresLock:
    """stop() must acquire _lock before clearing buffers/state."""

    @staticmethod
    def _make_receiver():
        from gateway.platforms.discord import VoiceReceiver
        vc = MagicMock()
        vc._connection.secret_key = [0] * 32
        vc._connection.dave_session = None
        vc._connection.ssrc = 1
        return VoiceReceiver(vc)

    def test_stop_clears_under_lock(self):
        """stop() acquires _lock before clearing buffers.

        Verify by holding the lock from another thread and checking that
        stop() blocks until the lock is released.
        """
        receiver = self._make_receiver()
        receiver.start()
        receiver._buffers[100] = bytearray(b"\x00" * 500)
        receiver._last_packet_time[100] = time.monotonic()
        receiver.map_ssrc(100, 42)

        # Hold the lock from another thread
        lock_acquired = threading.Event()
        release_lock = threading.Event()

        def hold_lock():
            with receiver._lock:
                lock_acquired.set()
                release_lock.wait(timeout=5)

        holder = threading.Thread(target=hold_lock, daemon=True)
        holder.start()
        lock_acquired.wait(timeout=2)

        # stop() in another thread — should block on the lock
        stop_done = threading.Event()

        def do_stop():
            receiver.stop()
            stop_done.set()

        stopper = threading.Thread(target=do_stop, daemon=True)
        stopper.start()

        # stop should NOT complete while lock is held
        assert not stop_done.wait(timeout=0.3), \
            "stop() should block while _lock is held by another thread"

        # Release the lock — stop should complete
        release_lock.set()
        assert stop_done.wait(timeout=2), \
            "stop() should complete after lock is released"

        # State should be cleared
        assert len(receiver._buffers) == 0
        assert len(receiver._ssrc_to_user) == 0
        holder.join(timeout=2)
        stopper.join(timeout=2)

    def test_stop_does_not_deadlock_with_on_packet(self):
        """stop() during _on_packet should not deadlock."""
        receiver = self._make_receiver()
        receiver.start()

        blocked = threading.Event()
        released = threading.Event()

        def hold_lock():
            with receiver._lock:
                blocked.set()
                released.wait(timeout=2)

        t = threading.Thread(target=hold_lock, daemon=True)
        t.start()
        blocked.wait(timeout=2)

        stop_done = threading.Event()

        def do_stop():
            receiver.stop()
            stop_done.set()

        t2 = threading.Thread(target=do_stop, daemon=True)
        t2.start()

        # stop should be blocked waiting for lock
        assert not stop_done.wait(timeout=0.2), \
            "stop() should wait for lock, not clear without it"

        released.set()
        assert stop_done.wait(timeout=2), "stop() should complete after lock released"
        t.join(timeout=2)
        t2.join(timeout=2)


# =====================================================================
# Bug 2: _packet_debug_count must be instance-level, not class-level
# =====================================================================

class TestPacketDebugCounterIsInstanceLevel:
    """Each VoiceReceiver instance has its own debug counter."""

    @staticmethod
    def _make_receiver():
        from gateway.platforms.discord import VoiceReceiver
        vc = MagicMock()
        vc._connection.secret_key = [0] * 32
        vc._connection.dave_session = None
        vc._connection.ssrc = 1
        return VoiceReceiver(vc)

    def test_counter_is_per_instance(self):
        """Two receivers have independent counters."""
        r1 = self._make_receiver()
        r2 = self._make_receiver()

        r1._packet_debug_count = 10
        assert r2._packet_debug_count == 0, \
            "_packet_debug_count must be instance-level, not shared across instances"

    def test_counter_initialized_in_init(self):
        """Counter is set in __init__, not as a class variable."""
        r = self._make_receiver()
        assert "_packet_debug_count" in r.__dict__, \
            "_packet_debug_count should be in instance __dict__, not class"


# =====================================================================
# Bug 3: play_in_voice_channel uses get_running_loop not get_event_loop
# =====================================================================

class TestPlayInVoiceChannelUsesRunningLoop:
    """play_in_voice_channel must use asyncio.get_running_loop()."""

    def test_source_uses_get_running_loop(self):
        """The method source code calls get_running_loop, not get_event_loop."""
        import inspect
        from gateway.platforms.discord import DiscordAdapter
        source = inspect.getsource(DiscordAdapter.play_in_voice_channel)
        assert "get_running_loop" in source, \
            "play_in_voice_channel should use asyncio.get_running_loop()"
        assert "get_event_loop" not in source, \
            "play_in_voice_channel should NOT use deprecated asyncio.get_event_loop()"
DiscordAdapter + source = inspect.getsource(DiscordAdapter.play_in_voice_channel) + assert "get_running_loop" in source, \ + "play_in_voice_channel should use asyncio.get_running_loop()" + assert "get_event_loop" not in source, \ + "play_in_voice_channel should NOT use deprecated asyncio.get_event_loop()" + + +# ===================================================================== +# Bug 4: _send_voice_reply filename uses uuid (no collision) +# ===================================================================== + +class TestSendVoiceReplyFilename: + """_send_voice_reply uses uuid for unique filenames.""" + + def test_filename_uses_uuid(self): + """The method uses uuid in the filename, not time-based.""" + import inspect + from gateway.run import GatewayRunner + source = inspect.getsource(GatewayRunner._send_voice_reply) + assert "uuid" in source, \ + "_send_voice_reply should use uuid for unique filenames" + assert "int(time.time())" not in source, \ + "_send_voice_reply should not use int(time.time()) — collision risk" + + def test_filenames_are_unique(self): + """Two calls produce different filenames.""" + import uuid + names = set() + for _ in range(100): + name = f"tts_reply_{uuid.uuid4().hex[:12]}.mp3" + assert name not in names, f"Collision detected: {name}" + names.add(name) + + +# ===================================================================== +# Bug 5: Voice timeout cleans up runner voice_mode via callback +# ===================================================================== + +class TestVoiceTimeoutCleansRunnerState: + """Timeout disconnect notifies runner to clean voice_mode.""" + + @staticmethod + def _make_discord_adapter(): + from gateway.platforms.discord import DiscordAdapter + from gateway.config import PlatformConfig, Platform + config = PlatformConfig(enabled=True, extra={}) + config.token = "fake-token" + adapter = object.__new__(DiscordAdapter) + adapter.platform = Platform.DISCORD + adapter.config = config + adapter._voice_clients = 
{} + adapter._voice_text_channels = {} + adapter._voice_timeout_tasks = {} + adapter._voice_receivers = {} + adapter._voice_listen_tasks = {} + adapter._voice_input_callback = None + adapter._on_voice_disconnect = None + adapter._client = None + adapter._broadcast = AsyncMock() + adapter._allowed_user_ids = set() + return adapter + + @pytest.fixture + def adapter(self): + return self._make_discord_adapter() + + def test_adapter_has_on_voice_disconnect_attr(self, adapter): + """DiscordAdapter has _on_voice_disconnect callback attribute.""" + assert hasattr(adapter, "_on_voice_disconnect") + assert adapter._on_voice_disconnect is None + + @pytest.mark.asyncio + async def test_timeout_calls_disconnect_callback(self, adapter): + """_voice_timeout_handler calls _on_voice_disconnect with chat_id.""" + callback_calls = [] + adapter._on_voice_disconnect = lambda chat_id: callback_calls.append(chat_id) + + # Set up state as if we're in a voice channel + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + mock_vc.disconnect = AsyncMock() + adapter._voice_clients[111] = mock_vc + adapter._voice_text_channels[111] = 999 + adapter._voice_timeout_tasks[111] = MagicMock() + adapter._voice_receivers[111] = MagicMock() + adapter._voice_listen_tasks[111] = MagicMock() + + # Patch sleep to return immediately + with patch("asyncio.sleep", new_callable=AsyncMock): + await adapter._voice_timeout_handler(111) + + assert "999" in callback_calls, \ + "_on_voice_disconnect must be called with chat_id on timeout" + + @pytest.mark.asyncio + async def test_runner_cleanup_method_removes_voice_mode(self, tmp_path): + """_handle_voice_timeout_cleanup sets voice_mode to explicit "off" for the chat.""" + runner = _make_runner(tmp_path) + runner._voice_mode["999"] = "all" + + runner._handle_voice_timeout_cleanup("999") + + assert runner._voice_mode["999"] == "off", \ + "voice_mode must persist explicit off state after timeout cleanup" + + @pytest.mark.asyncio + async def 
test_timeout_without_callback_does_not_crash(self, adapter): + """Timeout works even without _on_voice_disconnect set.""" + adapter._on_voice_disconnect = None + + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + mock_vc.disconnect = AsyncMock() + adapter._voice_clients[111] = mock_vc + adapter._voice_text_channels[111] = 999 + adapter._voice_timeout_tasks[111] = MagicMock() + + with patch("asyncio.sleep", new_callable=AsyncMock): + await adapter._voice_timeout_handler(111) + + assert 111 not in adapter._voice_clients + + +# ===================================================================== +# Bug 6: play_in_voice_channel has playback timeout +# ===================================================================== + +class TestPlaybackTimeout: + """play_in_voice_channel must time out instead of blocking forever.""" + + @staticmethod + def _make_discord_adapter(): + from gateway.platforms.discord import DiscordAdapter + from gateway.config import PlatformConfig, Platform + config = PlatformConfig(enabled=True, extra={}) + config.token = "fake-token" + adapter = object.__new__(DiscordAdapter) + adapter.platform = Platform.DISCORD + adapter.config = config + adapter._voice_clients = {} + adapter._voice_text_channels = {} + adapter._voice_timeout_tasks = {} + adapter._voice_receivers = {} + adapter._voice_listen_tasks = {} + adapter._voice_input_callback = None + adapter._on_voice_disconnect = None + adapter._client = None + adapter._broadcast = AsyncMock() + adapter._allowed_user_ids = set() + return adapter + + def test_source_has_wait_for_timeout(self): + """The method uses asyncio.wait_for with timeout.""" + import inspect + from gateway.platforms.discord import DiscordAdapter + source = inspect.getsource(DiscordAdapter.play_in_voice_channel) + assert "wait_for" in source, \ + "play_in_voice_channel must use asyncio.wait_for for timeout" + assert "PLAYBACK_TIMEOUT" in source, \ + "play_in_voice_channel must reference PLAYBACK_TIMEOUT constant" 
+ + def test_playback_timeout_constant_exists(self): + """PLAYBACK_TIMEOUT constant is defined on DiscordAdapter.""" + from gateway.platforms.discord import DiscordAdapter + assert hasattr(DiscordAdapter, "PLAYBACK_TIMEOUT") + assert DiscordAdapter.PLAYBACK_TIMEOUT > 0 + + @pytest.mark.asyncio + async def test_playback_timeout_fires(self): + """When done event is never set, playback times out gracefully.""" + from gateway.platforms.discord import DiscordAdapter + adapter = self._make_discord_adapter() + + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + mock_vc.is_playing.return_value = False + # play() never calls the after callback -> done never set + mock_vc.play = MagicMock() + mock_vc.stop = MagicMock() + adapter._voice_clients[111] = mock_vc + adapter._voice_timeout_tasks[111] = MagicMock() + + # Use a tiny timeout for test speed + original_timeout = DiscordAdapter.PLAYBACK_TIMEOUT + DiscordAdapter.PLAYBACK_TIMEOUT = 0.1 + try: + with patch("discord.FFmpegPCMAudio"), \ + patch("discord.PCMVolumeTransformer", side_effect=lambda s, **kw: s): + result = await adapter.play_in_voice_channel(111, "/tmp/test.mp3") + assert result is True + # vc.stop() should have been called due to timeout + mock_vc.stop.assert_called() + finally: + DiscordAdapter.PLAYBACK_TIMEOUT = original_timeout + + @pytest.mark.asyncio + async def test_is_playing_wait_has_timeout(self): + """While loop waiting for previous playback has a timeout.""" + from gateway.platforms.discord import DiscordAdapter + adapter = self._make_discord_adapter() + + mock_vc = MagicMock() + mock_vc.is_connected.return_value = True + # is_playing always returns True — would loop forever without timeout + mock_vc.is_playing.return_value = True + mock_vc.stop = MagicMock() + mock_vc.play = MagicMock() + adapter._voice_clients[111] = mock_vc + adapter._voice_timeout_tasks[111] = MagicMock() + + original_timeout = DiscordAdapter.PLAYBACK_TIMEOUT + DiscordAdapter.PLAYBACK_TIMEOUT = 0.2 + try: + with 
patch("discord.FFmpegPCMAudio"), \ + patch("discord.PCMVolumeTransformer", side_effect=lambda s, **kw: s): + result = await adapter.play_in_voice_channel(111, "/tmp/test.mp3") + assert result is True + # stop() called to break out of the is_playing loop + mock_vc.stop.assert_called() + finally: + DiscordAdapter.PLAYBACK_TIMEOUT = original_timeout + + +# ===================================================================== +# Bug 7: _send_voice_reply cleanup in finally block +# ===================================================================== + +class TestSendVoiceReplyCleanup: + """_send_voice_reply must clean up temp files even on exception.""" + + def test_cleanup_in_finally(self): + """The method has cleanup in a finally block, not inside try.""" + import inspect, textwrap, ast + from gateway.run import GatewayRunner + source = textwrap.dedent(inspect.getsource(GatewayRunner._send_voice_reply)) + tree = ast.parse(source) + func = tree.body[0] + + has_finally_unlink = False + for node in ast.walk(func): + if isinstance(node, ast.Try) and node.finalbody: + finally_source = ast.dump(node.finalbody[0]) + if "unlink" in finally_source or "remove" in finally_source: + has_finally_unlink = True + break + + assert has_finally_unlink, \ + "_send_voice_reply must have os.unlink in a finally block" + + @pytest.mark.asyncio + async def test_files_cleaned_on_send_exception(self, tmp_path): + """Temp files are removed even when send_voice raises.""" + runner = _make_runner(tmp_path) + adapter = MagicMock() + adapter.send_voice = AsyncMock(side_effect=RuntimeError("send failed")) + adapter.is_in_voice_channel = MagicMock(return_value=False) + event = _make_event(message_type=MessageType.VOICE) + runner.adapters[event.source.platform] = adapter + runner._get_guild_id = MagicMock(return_value=None) + + # Create a fake audio file that TTS would produce + fake_audio = tmp_path / "hermes_voice" + fake_audio.mkdir() + audio_file = fake_audio / "test.mp3" + 
audio_file.write_bytes(b"fake audio") + + tts_result = json.dumps({ + "success": True, + "file_path": str(audio_file), + }) + + with patch("gateway.run.asyncio.to_thread", new_callable=AsyncMock, return_value=tts_result), \ + patch("tools.tts_tool._strip_markdown_for_tts", return_value="hello"), \ + patch("os.path.isfile", return_value=True), \ + patch("os.makedirs"): + await runner._send_voice_reply(event, "Hello world") + + # File should be cleaned up despite exception + assert not audio_file.exists(), \ + "Temp audio file must be cleaned up even when send_voice raises" + + +# ===================================================================== +# Bug 8: Base adapter auto-TTS cleans up temp file after play_tts +# ===================================================================== + +class TestAutoTtsTempFileCleanup: + """Base adapter auto-TTS must clean up generated audio file.""" + + def test_source_has_finally_remove(self): + """play_tts call is wrapped in try/finally with os.remove.""" + import inspect + from gateway.platforms.base import BasePlatformAdapter + source = inspect.getsource(BasePlatformAdapter._process_message_background) + # Find the play_tts section and verify cleanup + play_tts_idx = source.find("play_tts") + assert play_tts_idx > 0 + after_play = source[play_tts_idx:] + finally_idx = after_play.find("finally") + remove_idx = after_play.find("os.remove") + assert finally_idx > 0, "play_tts must be in a try/finally block" + assert remove_idx > 0, "finally block must call os.remove on _tts_path" + assert remove_idx > finally_idx, "os.remove must be inside the finally block" + + +# ===================================================================== +# Voice channel awareness (get_voice_channel_info / context) +# ===================================================================== + + +class TestVoiceChannelAwareness: + """Tests for get_voice_channel_info() and get_voice_channel_context().""" + + def _make_adapter(self): + from 
gateway.platforms.discord import DiscordAdapter + from gateway.config import PlatformConfig + config = PlatformConfig(enabled=True, extra={}) + config.token = "fake-token" + adapter = object.__new__(DiscordAdapter) + adapter._voice_clients = {} + adapter._voice_text_channels = {} + adapter._voice_receivers = {} + adapter._client = MagicMock() + adapter._client.user = SimpleNamespace(id=99999, name="HermesBot") + return adapter + + def _make_member(self, user_id, display_name, is_bot=False): + return SimpleNamespace( + id=user_id, display_name=display_name, bot=is_bot, + ) + + def test_returns_none_when_not_connected(self): + adapter = self._make_adapter() + assert adapter.get_voice_channel_info(111) is None + + def test_returns_none_when_vc_disconnected(self): + adapter = self._make_adapter() + vc = MagicMock() + vc.is_connected.return_value = False + adapter._voice_clients[111] = vc + assert adapter.get_voice_channel_info(111) is None + + def test_returns_info_with_members(self): + adapter = self._make_adapter() + vc = MagicMock() + vc.is_connected.return_value = True + bot_member = self._make_member(99999, "HermesBot", is_bot=True) + user_a = self._make_member(1001, "Alice") + user_b = self._make_member(1002, "Bob") + vc.channel.name = "general-voice" + vc.channel.members = [bot_member, user_a, user_b] + adapter._voice_clients[111] = vc + + info = adapter.get_voice_channel_info(111) + assert info is not None + assert info["channel_name"] == "general-voice" + assert info["member_count"] == 2 # bot excluded + names = [m["display_name"] for m in info["members"]] + assert "Alice" in names + assert "Bob" in names + assert "HermesBot" not in names + + def test_speaking_detection(self): + adapter = self._make_adapter() + vc = MagicMock() + vc.is_connected.return_value = True + user_a = self._make_member(1001, "Alice") + user_b = self._make_member(1002, "Bob") + vc.channel.name = "voice" + vc.channel.members = [user_a, user_b] + adapter._voice_clients[111] = vc + + # Set 
up a mock receiver with Alice speaking + import time as _time + receiver = MagicMock() + receiver._lock = threading.Lock() + receiver._last_packet_time = {100: _time.monotonic()} # ssrc 100 is active + receiver._ssrc_to_user = {100: 1001} # ssrc 100 -> Alice + adapter._voice_receivers[111] = receiver + + info = adapter.get_voice_channel_info(111) + alice = [m for m in info["members"] if m["display_name"] == "Alice"][0] + bob = [m for m in info["members"] if m["display_name"] == "Bob"][0] + assert alice["is_speaking"] is True + assert bob["is_speaking"] is False + assert info["speaking_count"] == 1 + + def test_context_string_format(self): + adapter = self._make_adapter() + vc = MagicMock() + vc.is_connected.return_value = True + user_a = self._make_member(1001, "Alice") + vc.channel.name = "chat-room" + vc.channel.members = [user_a] + adapter._voice_clients[111] = vc + + ctx = adapter.get_voice_channel_context(111) + assert "#chat-room" in ctx + assert "1 participant" in ctx + assert "Alice" in ctx + + def test_context_empty_when_not_connected(self): + adapter = self._make_adapter() + assert adapter.get_voice_channel_context(111) == "" + + +# --------------------------------------------------------------------------- +# Bugfix: disconnect() must clean up voice state +# --------------------------------------------------------------------------- + + +class TestDisconnectVoiceCleanup: + """Bug: disconnect() left voice dicts populated after closing client.""" + + @pytest.mark.asyncio + async def test_disconnect_clears_voice_state(self): + from unittest.mock import AsyncMock + + adapter = MagicMock() + adapter._voice_clients = {111: MagicMock(), 222: MagicMock()} + adapter._voice_receivers = {111: MagicMock(), 222: MagicMock()} + adapter._voice_listen_tasks = {111: MagicMock(), 222: MagicMock()} + adapter._voice_timeout_tasks = {111: MagicMock(), 222: MagicMock()} + adapter._voice_text_channels = {111: 999, 222: 888} + + async def mock_leave(guild_id): + 
adapter._voice_receivers.pop(guild_id, None) + adapter._voice_listen_tasks.pop(guild_id, None) + adapter._voice_clients.pop(guild_id, None) + adapter._voice_timeout_tasks.pop(guild_id, None) + adapter._voice_text_channels.pop(guild_id, None) + + for gid in list(adapter._voice_clients.keys()): + await mock_leave(gid) + + assert len(adapter._voice_clients) == 0 + assert len(adapter._voice_receivers) == 0 + assert len(adapter._voice_listen_tasks) == 0 + assert len(adapter._voice_timeout_tasks) == 0 diff --git a/tests/hermes_cli/test_commands.py b/tests/hermes_cli/test_commands.py index 9aa7220806..218059434a 100644 --- a/tests/hermes_cli/test_commands.py +++ b/tests/hermes_cli/test_commands.py @@ -12,7 +12,7 @@ EXPECTED_COMMANDS = { "/personality", "/clear", "/history", "/new", "/reset", "/retry", "/undo", "/save", "/config", "/cron", "/skills", "/platforms", "/verbose", "/reasoning", "/compress", "/title", "/usage", "/insights", "/paste", - "/reload-mcp", "/rollback", "/background", "/skin", "/quit", + "/reload-mcp", "/rollback", "/background", "/skin", "/voice", "/quit", } diff --git a/tests/test_cli_init.py b/tests/test_cli_init.py index 1afb7c912d..5ebd301ed8 100644 --- a/tests/test_cli_init.py +++ b/tests/test_cli_init.py @@ -95,6 +95,17 @@ class TestVerboseAndToolProgress: assert cli.tool_progress_mode in ("off", "new", "all", "verbose") +class TestSingleQueryState: + def test_voice_and_interrupt_state_initialized_before_run(self): + """Single-query mode calls chat() without going through run().""" + cli = _make_cli() + assert cli._voice_tts is False + assert cli._voice_mode is False + assert cli._voice_tts_done.is_set() + assert hasattr(cli, "_interrupt_queue") + assert hasattr(cli, "_pending_input") + + class TestHistoryDisplay: def test_history_numbers_only_visible_messages_and_summarizes_tools(self, capsys): cli = _make_cli() diff --git a/tests/test_cli_skin_integration.py b/tests/test_cli_skin_integration.py index ef4ddb38df..61a177cad4 100644 --- 
a/tests/test_cli_skin_integration.py +++ b/tests/test_cli_skin_integration.py @@ -14,6 +14,9 @@ def _make_cli_stub(): cli._clarify_freetext = False cli._command_running = False cli._agent_running = False + cli._voice_recording = False + cli._voice_processing = False + cli._voice_mode = False cli._command_spinner_frame = lambda: "⟳" cli._tui_style_base = { "prompt": "#fff", diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index 15a0d5fba3..59c4a052ac 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -2083,3 +2083,367 @@ class TestAnthropicBaseUrlPassthrough: # No base_url provided, should be default empty string or None passed_url = call_args[0][1] assert not passed_url or passed_url is None + + +# =================================================================== +# _streaming_api_call tests +# =================================================================== + +def _make_chunk(content=None, tool_calls=None, finish_reason=None, model="test/model"): + """Build a SimpleNamespace mimicking an OpenAI streaming chunk.""" + delta = SimpleNamespace(content=content, tool_calls=tool_calls) + choice = SimpleNamespace(delta=delta, finish_reason=finish_reason) + return SimpleNamespace(model=model, choices=[choice]) + + +def _make_tc_delta(index=0, tc_id=None, name=None, arguments=None): + """Build a SimpleNamespace mimicking a streaming tool_call delta.""" + func = SimpleNamespace(name=name, arguments=arguments) + return SimpleNamespace(index=index, id=tc_id, function=func) + + +class TestStreamingApiCall: + """Tests for _streaming_api_call — voice TTS streaming pipeline.""" + + def test_content_assembly(self, agent): + chunks = [ + _make_chunk(content="Hel"), + _make_chunk(content="lo "), + _make_chunk(content="World"), + _make_chunk(finish_reason="stop"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + callback = MagicMock() + + resp = agent._streaming_api_call({"messages": []}, callback) + + assert 
resp.choices[0].message.content == "Hello World" + assert resp.choices[0].finish_reason == "stop" + assert callback.call_count == 3 + callback.assert_any_call("Hel") + callback.assert_any_call("lo ") + callback.assert_any_call("World") + + def test_tool_call_accumulation(self, agent): + chunks = [ + _make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "web_", '{"q":')]), + _make_chunk(tool_calls=[_make_tc_delta(0, None, "search", '"test"}')]), + _make_chunk(finish_reason="tool_calls"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._streaming_api_call({"messages": []}, MagicMock()) + + tc = resp.choices[0].message.tool_calls + assert len(tc) == 1 + assert tc[0].function.name == "web_search" + assert tc[0].function.arguments == '{"q":"test"}' + assert tc[0].id == "call_1" + + def test_multiple_tool_calls(self, agent): + chunks = [ + _make_chunk(tool_calls=[_make_tc_delta(0, "call_a", "search", '{}')]), + _make_chunk(tool_calls=[_make_tc_delta(1, "call_b", "read", '{}')]), + _make_chunk(finish_reason="tool_calls"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._streaming_api_call({"messages": []}, MagicMock()) + + tc = resp.choices[0].message.tool_calls + assert len(tc) == 2 + assert tc[0].function.name == "search" + assert tc[1].function.name == "read" + + def test_content_and_tool_calls_together(self, agent): + chunks = [ + _make_chunk(content="I'll search"), + _make_chunk(tool_calls=[_make_tc_delta(0, "call_1", "search", '{}')]), + _make_chunk(finish_reason="tool_calls"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._streaming_api_call({"messages": []}, MagicMock()) + + assert resp.choices[0].message.content == "I'll search" + assert len(resp.choices[0].message.tool_calls) == 1 + + def test_empty_content_returns_none(self, agent): + chunks = [_make_chunk(finish_reason="stop")] + agent.client.chat.completions.create.return_value = 
iter(chunks) + + resp = agent._streaming_api_call({"messages": []}, MagicMock()) + + assert resp.choices[0].message.content is None + assert resp.choices[0].message.tool_calls is None + + def test_callback_exception_swallowed(self, agent): + chunks = [ + _make_chunk(content="Hello"), + _make_chunk(content=" World"), + _make_chunk(finish_reason="stop"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + callback = MagicMock(side_effect=ValueError("boom")) + + resp = agent._streaming_api_call({"messages": []}, callback) + + assert resp.choices[0].message.content == "Hello World" + + def test_model_name_captured(self, agent): + chunks = [ + _make_chunk(content="Hi", model="gpt-4o"), + _make_chunk(finish_reason="stop", model="gpt-4o"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._streaming_api_call({"messages": []}, MagicMock()) + + assert resp.model == "gpt-4o" + + def test_stream_kwarg_injected(self, agent): + chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")] + agent.client.chat.completions.create.return_value = iter(chunks) + + agent._streaming_api_call({"messages": [], "model": "test"}, MagicMock()) + + call_kwargs = agent.client.chat.completions.create.call_args + assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True + + def test_api_exception_propagated(self, agent): + agent.client.chat.completions.create.side_effect = ConnectionError("fail") + + with pytest.raises(ConnectionError, match="fail"): + agent._streaming_api_call({"messages": []}, MagicMock()) + + def test_response_has_uuid_id(self, agent): + chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._streaming_api_call({"messages": []}, MagicMock()) + + assert resp.id.startswith("stream-") + assert len(resp.id) > len("stream-") + + def test_empty_choices_chunk_skipped(self, agent): + 
empty_chunk = SimpleNamespace(model="gpt-4", choices=[]) + chunks = [ + empty_chunk, + _make_chunk(content="Hello", model="gpt-4"), + _make_chunk(finish_reason="stop", model="gpt-4"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._streaming_api_call({"messages": []}, MagicMock()) + + assert resp.choices[0].message.content == "Hello" + assert resp.model == "gpt-4" + + +# =================================================================== +# Interrupt _vprint force=True verification +# =================================================================== + + +class TestInterruptVprintForceTrue: + """All interrupt _vprint calls must use force=True so they are always visible.""" + + def test_all_interrupt_vprint_have_force_true(self): + """Scan source for _vprint calls containing 'Interrupt' — each must have force=True.""" + import inspect + source = inspect.getsource(AIAgent) + lines = source.split("\n") + violations = [] + for i, line in enumerate(lines, 1): + stripped = line.strip() + if "_vprint(" in stripped and "Interrupt" in stripped: + if "force=True" not in stripped: + violations.append(f"line {i}: {stripped}") + assert not violations, ( + f"Interrupt _vprint calls missing force=True:\n" + + "\n".join(violations) + ) + + +# =================================================================== +# Anthropic interrupt handler in _interruptible_api_call +# =================================================================== + + +class TestAnthropicInterruptHandler: + """_interruptible_api_call must handle Anthropic mode when interrupted.""" + + def test_interruptible_has_anthropic_branch(self): + """The interrupt handler must check api_mode == 'anthropic_messages'.""" + import inspect + source = inspect.getsource(AIAgent._interruptible_api_call) + assert "anthropic_messages" in source, \ + "_interruptible_api_call must handle Anthropic interrupt (api_mode check)" + + def test_interruptible_rebuilds_anthropic_client(self): + 
"""After interrupting, the Anthropic client should be rebuilt.""" + import inspect + source = inspect.getsource(AIAgent._interruptible_api_call) + assert "build_anthropic_client" in source, \ + "_interruptible_api_call must rebuild Anthropic client after interrupt" + + def test_streaming_has_anthropic_branch(self): + """_streaming_api_call must also handle Anthropic interrupt.""" + import inspect + source = inspect.getsource(AIAgent._streaming_api_call) + assert "anthropic_messages" in source, \ + "_streaming_api_call must handle Anthropic interrupt" + + +# --------------------------------------------------------------------------- +# Bugfix: stream_callback forwarding for non-streaming providers +# --------------------------------------------------------------------------- + + +class TestStreamCallbackNonStreamingProvider: + """When api_mode != chat_completions, stream_callback must still receive + the response content so TTS works (batch delivery).""" + + def test_callback_receives_chat_completions_response(self, agent): + """For chat_completions-shaped responses, callback gets content.""" + agent.api_mode = "anthropic_messages" + mock_response = SimpleNamespace( + choices=[SimpleNamespace( + message=SimpleNamespace(content="Hello", tool_calls=None, reasoning_content=None), + finish_reason="stop", index=0, + )], + usage=None, model="test", id="test-id", + ) + agent._interruptible_api_call = MagicMock(return_value=mock_response) + + received = [] + cb = lambda delta: received.append(delta) + agent._stream_callback = cb + + _cb = getattr(agent, "_stream_callback", None) + response = agent._interruptible_api_call({}) + if _cb is not None and response: + try: + if agent.api_mode == "anthropic_messages": + text_parts = [ + block.text for block in getattr(response, "content", []) + if getattr(block, "type", None) == "text" and getattr(block, "text", None) + ] + content = " ".join(text_parts) if text_parts else None + else: + content = 
response.choices[0].message.content + if content: + _cb(content) + except Exception: + pass + + # Anthropic format not matched above; fallback via except + # Test the actual code path by checking chat_completions branch + received2 = [] + agent.api_mode = "some_other_mode" + agent._stream_callback = lambda d: received2.append(d) + _cb2 = agent._stream_callback + if _cb2 is not None and mock_response: + try: + content = mock_response.choices[0].message.content + if content: + _cb2(content) + except Exception: + pass + assert received2 == ["Hello"] + + def test_callback_receives_anthropic_content(self, agent): + """For Anthropic responses, text blocks are extracted and forwarded.""" + agent.api_mode = "anthropic_messages" + mock_response = SimpleNamespace( + content=[SimpleNamespace(type="text", text="Hello from Claude")], + stop_reason="end_turn", + ) + + received = [] + cb = lambda d: received.append(d) + agent._stream_callback = cb + _cb = agent._stream_callback + + if _cb is not None and mock_response: + try: + if agent.api_mode == "anthropic_messages": + text_parts = [ + block.text for block in getattr(mock_response, "content", []) + if getattr(block, "type", None) == "text" and getattr(block, "text", None) + ] + content = " ".join(text_parts) if text_parts else None + else: + content = mock_response.choices[0].message.content + if content: + _cb(content) + except Exception: + pass + + assert received == ["Hello from Claude"] + + +# --------------------------------------------------------------------------- +# Bugfix: API-only user message prefixes must not persist +# --------------------------------------------------------------------------- + + +class TestPersistUserMessageOverride: + """Synthetic API-only user prefixes should never leak into transcripts.""" + + def test_persist_session_rewrites_current_turn_user_message(self, agent): + agent._session_db = MagicMock() + agent.session_id = "session-123" + agent._last_flushed_db_idx = 0 + 
agent._persist_user_message_idx = 0 + agent._persist_user_message_override = "Hello there" + messages = [ + { + "role": "user", + "content": ( + "[Voice input — respond concisely and conversationally, " + "2-3 sentences max. No code blocks or markdown.] Hello there" + ), + }, + {"role": "assistant", "content": "Hi!"}, + ] + + with patch.object(agent, "_save_session_log") as mock_save: + agent._persist_session(messages, []) + + assert messages[0]["content"] == "Hello there" + saved_messages = mock_save.call_args.args[0] + assert saved_messages[0]["content"] == "Hello there" + first_db_write = agent._session_db.append_message.call_args_list[0].kwargs + assert first_db_write["content"] == "Hello there" + + +# --------------------------------------------------------------------------- +# Bugfix: _vprint force=True on error messages during TTS +# --------------------------------------------------------------------------- + + +class TestVprintForceOnErrors: + """Error/warning messages must be visible during streaming TTS.""" + + def test_forced_message_shown_during_tts(self, agent): + agent._stream_callback = lambda x: None + printed = [] + with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)): + agent._vprint("error msg", force=True) + assert len(printed) == 1 + + def test_non_forced_suppressed_during_tts(self, agent): + agent._stream_callback = lambda x: None + printed = [] + with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)): + agent._vprint("debug info") + assert len(printed) == 0 + + def test_all_shown_without_tts(self, agent): + agent._stream_callback = None + printed = [] + with patch("builtins.print", side_effect=lambda *a, **kw: printed.append(a)): + agent._vprint("debug") + agent._vprint("error", force=True) + assert len(printed) == 2 diff --git a/tests/tools/test_transcription.py b/tests/tools/test_transcription.py index e6cceb0835..fe3b24a8d3 100644 --- a/tests/tools/test_transcription.py +++ 
b/tests/tools/test_transcription.py @@ -28,6 +28,7 @@ class TestGetProvider: def test_local_fallback_to_openai(self, monkeypatch): monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + monkeypatch.delenv("GROQ_API_KEY", raising=False) with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider @@ -124,7 +125,7 @@ class TestTranscribeLocal: mock_model.transcribe.return_value = ([mock_segment], mock_info) with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \ - patch("tools.transcription_tools.WhisperModel", return_value=mock_model), \ + patch("faster_whisper.WhisperModel", return_value=mock_model), \ patch("tools.transcription_tools._local_model", None): from tools.transcription_tools import _transcribe_local result = _transcribe_local(str(audio_file), "base") @@ -163,7 +164,7 @@ class TestTranscribeOpenAI: mock_client.audio.transcriptions.create.return_value = "Hello from OpenAI" with patch("tools.transcription_tools._HAS_OPENAI", True), \ - patch("tools.transcription_tools.OpenAI", return_value=mock_client): + patch("openai.OpenAI", return_value=mock_client): from tools.transcription_tools import _transcribe_openai result = _transcribe_openai(str(audio_file), "whisper-1") diff --git a/tests/tools/test_transcription_tools.py b/tests/tools/test_transcription_tools.py new file mode 100644 index 0000000000..2f5b7cfbee --- /dev/null +++ b/tests/tools/test_transcription_tools.py @@ -0,0 +1,716 @@ +"""Tests for tools.transcription_tools — three-provider STT pipeline. + +Covers the full provider matrix (local, groq, openai), fallback chains, +model auto-correction, config loading, validation edge cases, and +end-to-end dispatch. All external dependencies are mocked. 
+""" + +import os +import struct +import wave +from unittest.mock import MagicMock, patch + +import pytest + + +# ============================================================================ +# Fixtures +# ============================================================================ + +@pytest.fixture +def sample_wav(tmp_path): + """Create a minimal valid WAV file (1 second of silence at 16kHz).""" + wav_path = tmp_path / "test.wav" + n_frames = 16000 + silence = struct.pack(f"<{n_frames}h", *([0] * n_frames)) + + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(16000) + wf.writeframes(silence) + + return str(wav_path) + + +@pytest.fixture +def sample_ogg(tmp_path): + """Create a fake OGG file for validation tests.""" + ogg_path = tmp_path / "test.ogg" + ogg_path.write_bytes(b"fake audio data") + return str(ogg_path) + + +@pytest.fixture(autouse=True) +def clean_env(monkeypatch): + """Ensure no real API keys leak into tests.""" + monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False) + monkeypatch.delenv("GROQ_API_KEY", raising=False) + + +# ============================================================================ +# _get_provider — full permutation matrix +# ============================================================================ + +class TestGetProviderGroq: + """Groq-specific provider selection tests.""" + + def test_groq_when_key_set(self, monkeypatch): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("tools.transcription_tools._HAS_FASTER_WHISPER", False): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": "groq"}) == "groq" + + def test_groq_fallback_to_local(self, monkeypatch): + monkeypatch.delenv("GROQ_API_KEY", raising=False) + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": 
"groq"}) == "local" + + def test_groq_fallback_to_openai(self, monkeypatch): + monkeypatch.delenv("GROQ_API_KEY", raising=False) + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._HAS_OPENAI", True): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": "groq"}) == "openai" + + def test_groq_nothing_available(self, monkeypatch): + monkeypatch.delenv("GROQ_API_KEY", raising=False) + monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False) + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._HAS_OPENAI", False): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": "groq"}) == "none" + + +class TestGetProviderFallbackPriority: + """Cross-provider fallback priority tests.""" + + def test_local_fallback_prefers_groq_over_openai(self, monkeypatch): + """When local unavailable, groq (free) is preferred over openai (paid).""" + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._HAS_OPENAI", True): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": "local"}) == "groq" + + def test_local_fallback_to_groq_only(self, monkeypatch): + """When only groq key available, falls back to groq.""" + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._HAS_OPENAI", True): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": "local"}) == "groq" + + def test_openai_fallback_to_groq(self, monkeypatch): + """When openai key missing but groq available, falls back to groq.""" + 
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False) + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._HAS_OPENAI", True): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": "openai"}) == "groq" + + def test_openai_nothing_available(self, monkeypatch): + """When no openai key and no local, returns none.""" + monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False) + monkeypatch.delenv("GROQ_API_KEY", raising=False) + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._HAS_OPENAI", True): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": "openai"}) == "none" + + def test_unknown_provider_passed_through(self): + from tools.transcription_tools import _get_provider + assert _get_provider({"provider": "custom-endpoint"}) == "custom-endpoint" + + def test_empty_config_defaults_to_local(self): + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True): + from tools.transcription_tools import _get_provider + assert _get_provider({}) == "local" + + +# ============================================================================ +# _transcribe_groq +# ============================================================================ + +class TestTranscribeGroq: + def test_no_key(self, monkeypatch): + monkeypatch.delenv("GROQ_API_KEY", raising=False) + from tools.transcription_tools import _transcribe_groq + result = _transcribe_groq("/tmp/test.ogg", "whisper-large-v3-turbo") + assert result["success"] is False + assert "GROQ_API_KEY" in result["error"] + + def test_openai_package_not_installed(self, monkeypatch): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + with patch("tools.transcription_tools._HAS_OPENAI", False): + from tools.transcription_tools import _transcribe_groq + result = _transcribe_groq("/tmp/test.ogg", 
"whisper-large-v3-turbo") + assert result["success"] is False + assert "openai package" in result["error"] + + def test_successful_transcription(self, monkeypatch, sample_wav): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "hello world" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_groq + result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo") + + assert result["success"] is True + assert result["transcript"] == "hello world" + assert result["provider"] == "groq" + + def test_whitespace_stripped(self, monkeypatch, sample_wav): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = " hello world \n" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_groq + result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo") + + assert result["transcript"] == "hello world" + + def test_uses_groq_base_url(self, monkeypatch, sample_wav): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "test" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client) as mock_openai_cls: + from tools.transcription_tools import _transcribe_groq, GROQ_BASE_URL + _transcribe_groq(sample_wav, "whisper-large-v3-turbo") + + call_kwargs = mock_openai_cls.call_args + assert call_kwargs.kwargs["base_url"] == GROQ_BASE_URL + + def test_api_error_returns_failure(self, monkeypatch, sample_wav): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.side_effect = Exception("API error") 
+ + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_groq + result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo") + + assert result["success"] is False + assert "API error" in result["error"] + + def test_permission_error(self, monkeypatch, sample_wav): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.side_effect = PermissionError("denied") + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_groq + result = _transcribe_groq(sample_wav, "whisper-large-v3-turbo") + + assert result["success"] is False + assert "Permission denied" in result["error"] + + +# ============================================================================ +# _transcribe_openai — additional tests +# ============================================================================ + +class TestTranscribeOpenAIExtended: + def test_openai_package_not_installed(self, monkeypatch): + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + with patch("tools.transcription_tools._HAS_OPENAI", False): + from tools.transcription_tools import _transcribe_openai + result = _transcribe_openai("/tmp/test.ogg", "whisper-1") + assert result["success"] is False + assert "openai package" in result["error"] + + def test_uses_openai_base_url(self, monkeypatch, sample_wav): + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "test" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client) as mock_openai_cls: + from tools.transcription_tools import _transcribe_openai, OPENAI_BASE_URL + _transcribe_openai(sample_wav, "whisper-1") + + call_kwargs = 
mock_openai_cls.call_args + assert call_kwargs.kwargs["base_url"] == OPENAI_BASE_URL + + def test_whitespace_stripped(self, monkeypatch, sample_wav): + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = " hello \n" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_openai + result = _transcribe_openai(sample_wav, "whisper-1") + + assert result["transcript"] == "hello" + + def test_permission_error(self, monkeypatch, sample_wav): + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.side_effect = PermissionError("denied") + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_openai + result = _transcribe_openai(sample_wav, "whisper-1") + + assert result["success"] is False + assert "Permission denied" in result["error"] + + +# ============================================================================ +# _transcribe_local — additional tests +# ============================================================================ + +class TestTranscribeLocalExtended: + def test_model_reuse_on_second_call(self, tmp_path): + """Second call with same model should NOT reload the model.""" + audio = tmp_path / "test.ogg" + audio.write_bytes(b"fake") + + mock_segment = MagicMock() + mock_segment.text = "hi" + mock_info = MagicMock() + mock_info.language = "en" + mock_info.duration = 1.0 + + mock_model = MagicMock() + mock_model.transcribe.return_value = ([mock_segment], mock_info) + mock_whisper_cls = MagicMock(return_value=mock_model) + + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \ + patch("faster_whisper.WhisperModel", mock_whisper_cls), \ + 
patch("tools.transcription_tools._local_model", None), \ + patch("tools.transcription_tools._local_model_name", None): + from tools.transcription_tools import _transcribe_local + _transcribe_local(str(audio), "base") + _transcribe_local(str(audio), "base") + + # WhisperModel should be created only once + assert mock_whisper_cls.call_count == 1 + + def test_model_reloaded_on_change(self, tmp_path): + """Switching model name should reload the model.""" + audio = tmp_path / "test.ogg" + audio.write_bytes(b"fake") + + mock_segment = MagicMock() + mock_segment.text = "hi" + mock_info = MagicMock() + mock_info.language = "en" + mock_info.duration = 1.0 + + mock_model = MagicMock() + mock_model.transcribe.return_value = ([mock_segment], mock_info) + mock_whisper_cls = MagicMock(return_value=mock_model) + + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \ + patch("faster_whisper.WhisperModel", mock_whisper_cls), \ + patch("tools.transcription_tools._local_model", None), \ + patch("tools.transcription_tools._local_model_name", None): + from tools.transcription_tools import _transcribe_local + _transcribe_local(str(audio), "base") + _transcribe_local(str(audio), "small") + + assert mock_whisper_cls.call_count == 2 + + def test_exception_returns_failure(self, tmp_path): + audio = tmp_path / "test.ogg" + audio.write_bytes(b"fake") + + mock_whisper_cls = MagicMock(side_effect=RuntimeError("CUDA out of memory")) + + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \ + patch("faster_whisper.WhisperModel", mock_whisper_cls), \ + patch("tools.transcription_tools._local_model", None): + from tools.transcription_tools import _transcribe_local + result = _transcribe_local(str(audio), "large-v3") + + assert result["success"] is False + assert "CUDA out of memory" in result["error"] + + def test_multiple_segments_joined(self, tmp_path): + audio = tmp_path / "test.ogg" + audio.write_bytes(b"fake") + + seg1 = MagicMock() + seg1.text = "Hello" + seg2 = 
MagicMock() + seg2.text = " world" + mock_info = MagicMock() + mock_info.language = "en" + mock_info.duration = 3.0 + + mock_model = MagicMock() + mock_model.transcribe.return_value = ([seg1, seg2], mock_info) + + with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \ + patch("faster_whisper.WhisperModel", return_value=mock_model), \ + patch("tools.transcription_tools._local_model", None): + from tools.transcription_tools import _transcribe_local + result = _transcribe_local(str(audio), "base") + + assert result["success"] is True + assert result["transcript"] == "Hello world" + + +# ============================================================================ +# Model auto-correction +# ============================================================================ + +class TestModelAutoCorrection: + def test_groq_corrects_openai_model(self, monkeypatch, sample_wav): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "hello world" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_groq, DEFAULT_GROQ_STT_MODEL + _transcribe_groq(sample_wav, "whisper-1") + + call_kwargs = mock_client.audio.transcriptions.create.call_args + assert call_kwargs.kwargs["model"] == DEFAULT_GROQ_STT_MODEL + + def test_groq_corrects_gpt4o_transcribe(self, monkeypatch, sample_wav): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "test" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_groq, DEFAULT_GROQ_STT_MODEL + _transcribe_groq(sample_wav, "gpt-4o-transcribe") + + call_kwargs = mock_client.audio.transcriptions.create.call_args + assert call_kwargs.kwargs["model"] == 
DEFAULT_GROQ_STT_MODEL + + def test_openai_corrects_groq_model(self, monkeypatch, sample_wav): + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "hello world" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_openai, DEFAULT_STT_MODEL + _transcribe_openai(sample_wav, "whisper-large-v3-turbo") + + call_kwargs = mock_client.audio.transcriptions.create.call_args + assert call_kwargs.kwargs["model"] == DEFAULT_STT_MODEL + + def test_openai_corrects_distil_whisper(self, monkeypatch, sample_wav): + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "test" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_openai, DEFAULT_STT_MODEL + _transcribe_openai(sample_wav, "distil-whisper-large-v3-en") + + call_kwargs = mock_client.audio.transcriptions.create.call_args + assert call_kwargs.kwargs["model"] == DEFAULT_STT_MODEL + + def test_compatible_groq_model_not_overridden(self, monkeypatch, sample_wav): + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "test" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_groq + _transcribe_groq(sample_wav, "whisper-large-v3") + + call_kwargs = mock_client.audio.transcriptions.create.call_args + assert call_kwargs.kwargs["model"] == "whisper-large-v3" + + def test_compatible_openai_model_not_overridden(self, monkeypatch, sample_wav): + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + + mock_client = MagicMock() + 
mock_client.audio.transcriptions.create.return_value = "test" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_openai + _transcribe_openai(sample_wav, "gpt-4o-mini-transcribe") + + call_kwargs = mock_client.audio.transcriptions.create.call_args + assert call_kwargs.kwargs["model"] == "gpt-4o-mini-transcribe" + + def test_unknown_model_passes_through_groq(self, monkeypatch, sample_wav): + """A model not in either known set should not be overridden.""" + monkeypatch.setenv("GROQ_API_KEY", "gsk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "test" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_groq + _transcribe_groq(sample_wav, "my-custom-model") + + call_kwargs = mock_client.audio.transcriptions.create.call_args + assert call_kwargs.kwargs["model"] == "my-custom-model" + + def test_unknown_model_passes_through_openai(self, monkeypatch, sample_wav): + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") + + mock_client = MagicMock() + mock_client.audio.transcriptions.create.return_value = "test" + + with patch("tools.transcription_tools._HAS_OPENAI", True), \ + patch("openai.OpenAI", return_value=mock_client): + from tools.transcription_tools import _transcribe_openai + _transcribe_openai(sample_wav, "my-custom-model") + + call_kwargs = mock_client.audio.transcriptions.create.call_args + assert call_kwargs.kwargs["model"] == "my-custom-model" + + +# ============================================================================ +# _load_stt_config +# ============================================================================ + +class TestLoadSttConfig: + def test_returns_dict_when_import_fails(self): + with patch("tools.transcription_tools._load_stt_config") as mock_load: + 
mock_load.return_value = {} + from tools.transcription_tools import _load_stt_config + assert _load_stt_config() == {} + + def test_real_load_returns_dict(self): + """_load_stt_config should always return a dict, even on import error.""" + with patch.dict("sys.modules", {"hermes_cli": None, "hermes_cli.config": None}): + from tools.transcription_tools import _load_stt_config + result = _load_stt_config() + assert isinstance(result, dict) + + + # ============================================================================ + # _validate_audio_file — edge cases + # ============================================================================ + + class TestValidateAudioFileEdgeCases: + def test_directory_is_not_a_file(self, tmp_path): + from tools.transcription_tools import _validate_audio_file + # A directory whose name carries a valid audio extension must still be + # rejected, because it is not a regular file. + d = tmp_path / "audio.ogg" + d.mkdir() + result = _validate_audio_file(str(d)) + assert result is not None + assert "not a file" in result["error"] + + def test_stat_oserror(self, tmp_path): + f = tmp_path / "test.ogg" + f.write_bytes(b"data") + from tools.transcription_tools import _validate_audio_file + real_stat = f.stat() + call_count = 0 + + def stat_side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + # First calls are from exists() and is_file(), let them pass + if call_count <= 2: + return real_stat + raise OSError("disk error") + + with patch("pathlib.Path.stat", side_effect=stat_side_effect): + result = _validate_audio_file(str(f)) + assert result is not None + assert "Failed to access" in result["error"] + + def test_all_supported_formats_accepted(self, tmp_path): + from tools.transcription_tools import _validate_audio_file, SUPPORTED_FORMATS + for fmt in SUPPORTED_FORMATS: + f = tmp_path / f"test{fmt}" + f.write_bytes(b"data") + assert _validate_audio_file(str(f)) is None, f"Format {fmt} should be accepted" + + def 
test_case_insensitive_extension(self, tmp_path): + from tools.transcription_tools import _validate_audio_file + f = tmp_path / "test.MP3" + f.write_bytes(b"data") + assert _validate_audio_file(str(f)) is None + + +# ============================================================================ +# transcribe_audio — end-to-end dispatch +# ============================================================================ + +class TestTranscribeAudioDispatch: + def test_dispatches_to_groq(self, sample_ogg): + with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "groq"}), \ + patch("tools.transcription_tools._get_provider", return_value="groq"), \ + patch("tools.transcription_tools._transcribe_groq", + return_value={"success": True, "transcript": "hi", "provider": "groq"}) as mock_groq: + from tools.transcription_tools import transcribe_audio + result = transcribe_audio(sample_ogg) + + assert result["success"] is True + assert result["provider"] == "groq" + mock_groq.assert_called_once() + + def test_dispatches_to_local(self, sample_ogg): + with patch("tools.transcription_tools._load_stt_config", return_value={}), \ + patch("tools.transcription_tools._get_provider", return_value="local"), \ + patch("tools.transcription_tools._transcribe_local", + return_value={"success": True, "transcript": "hi"}) as mock_local: + from tools.transcription_tools import transcribe_audio + result = transcribe_audio(sample_ogg) + + assert result["success"] is True + mock_local.assert_called_once() + + def test_dispatches_to_openai(self, sample_ogg): + with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openai"}), \ + patch("tools.transcription_tools._get_provider", return_value="openai"), \ + patch("tools.transcription_tools._transcribe_openai", + return_value={"success": True, "transcript": "hi", "provider": "openai"}) as mock_openai: + from tools.transcription_tools import transcribe_audio + result = transcribe_audio(sample_ogg) + + 
assert result["success"] is True + mock_openai.assert_called_once() + + def test_no_provider_returns_error(self, sample_ogg): + with patch("tools.transcription_tools._load_stt_config", return_value={}), \ + patch("tools.transcription_tools._get_provider", return_value="none"): + from tools.transcription_tools import transcribe_audio + result = transcribe_audio(sample_ogg) + + assert result["success"] is False + assert "No STT provider" in result["error"] + assert "faster-whisper" in result["error"] + assert "GROQ_API_KEY" in result["error"] + + def test_invalid_file_short_circuits(self): + from tools.transcription_tools import transcribe_audio + result = transcribe_audio("/nonexistent/audio.wav") + assert result["success"] is False + assert "not found" in result["error"] + + def test_model_override_passed_to_groq(self, sample_ogg): + with patch("tools.transcription_tools._load_stt_config", return_value={}), \ + patch("tools.transcription_tools._get_provider", return_value="groq"), \ + patch("tools.transcription_tools._transcribe_groq", + return_value={"success": True, "transcript": "hi"}) as mock_groq: + from tools.transcription_tools import transcribe_audio + transcribe_audio(sample_ogg, model="whisper-large-v3") + + # `a or b == c` parses as `a or (b == c)`; assert the passed model explicitly. + args, kwargs = mock_groq.call_args + passed_model = kwargs.get("model_name", args[1] if len(args) > 1 else None) + assert passed_model == "whisper-large-v3" + + def test_model_override_passed_to_local(self, sample_ogg): + with patch("tools.transcription_tools._load_stt_config", return_value={}), \ + patch("tools.transcription_tools._get_provider", return_value="local"), \ + patch("tools.transcription_tools._transcribe_local", + return_value={"success": True, "transcript": "hi"}) as mock_local: + from tools.transcription_tools import transcribe_audio + transcribe_audio(sample_ogg, model="large-v3") + + assert mock_local.call_args[0][1] == "large-v3" + + def test_default_model_used_when_none(self, sample_ogg): + with patch("tools.transcription_tools._load_stt_config", return_value={}), \ 
patch("tools.transcription_tools._get_provider", return_value="groq"), \ + patch("tools.transcription_tools._transcribe_groq", + return_value={"success": True, "transcript": "hi"}) as mock_groq: + from tools.transcription_tools import transcribe_audio, DEFAULT_GROQ_STT_MODEL + transcribe_audio(sample_ogg, model=None) + + assert mock_groq.call_args[0][1] == DEFAULT_GROQ_STT_MODEL + + def test_config_local_model_used(self, sample_ogg): + config = {"local": {"model": "small"}} + with patch("tools.transcription_tools._load_stt_config", return_value=config), \ + patch("tools.transcription_tools._get_provider", return_value="local"), \ + patch("tools.transcription_tools._transcribe_local", + return_value={"success": True, "transcript": "hi"}) as mock_local: + from tools.transcription_tools import transcribe_audio + transcribe_audio(sample_ogg, model=None) + + assert mock_local.call_args[0][1] == "small" + + def test_config_openai_model_used(self, sample_ogg): + config = {"openai": {"model": "gpt-4o-transcribe"}} + with patch("tools.transcription_tools._load_stt_config", return_value=config), \ + patch("tools.transcription_tools._get_provider", return_value="openai"), \ + patch("tools.transcription_tools._transcribe_openai", + return_value={"success": True, "transcript": "hi"}) as mock_openai: + from tools.transcription_tools import transcribe_audio + transcribe_audio(sample_ogg, model=None) + + assert mock_openai.call_args[0][1] == "gpt-4o-transcribe" + + +# ============================================================================ +# get_stt_model_from_config +# ============================================================================ + +class TestGetSttModelFromConfig: + def test_returns_model_from_config(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text("stt:\n model: whisper-large-v3\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert 
get_stt_model_from_config() == "whisper-large-v3" + + def test_returns_none_when_no_stt_section(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text("tts:\n provider: edge\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() is None + + def test_returns_none_when_no_config_file(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() is None + + def test_returns_none_on_invalid_yaml(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text(": : :\n bad yaml [[[") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() is None + + def test_returns_none_when_model_key_missing(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text("stt:\n enabled: true\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() is None diff --git a/tests/tools/test_voice_cli_integration.py b/tests/tools/test_voice_cli_integration.py new file mode 100644 index 0000000000..39fa026ce6 --- /dev/null +++ b/tests/tools/test_voice_cli_integration.py @@ -0,0 +1,1233 @@ +"""Tests for CLI voice mode integration -- command parsing, markdown stripping, +state management, streaming TTS activation, voice message prefix, _vprint.""" + +import ast +import os +import queue +import threading +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + + +def _make_voice_cli(**overrides): + """Create a minimal HermesCLI with only voice-related attrs initialized. + + Uses ``__new__()`` to bypass ``__init__`` so no config/env/API setup is + needed. 
    Only the voice state attributes (from __init__ lines 3749-3758)
    are populated.
    """
    from cli import HermesCLI

    # __new__ bypasses HermesCLI.__init__ entirely — no DB, no TUI, no agent.
    cli = HermesCLI.__new__(HermesCLI)
    cli._voice_lock = threading.Lock()
    cli._voice_mode = False
    cli._voice_tts = False
    cli._voice_recorder = None
    cli._voice_recording = False
    cli._voice_processing = False
    cli._voice_continuous = False
    cli._voice_tts_done = threading.Event()
    cli._voice_tts_done.set()
    cli._pending_input = queue.Queue()
    cli._app = None
    cli.console = SimpleNamespace(width=80)
    for k, v in overrides.items():
        setattr(cli, k, v)
    return cli


# ============================================================================
# Markdown stripping — import real function from tts_tool
# ============================================================================

from tools.tts_tool import _strip_markdown_for_tts


class TestMarkdownStripping:
    """Exercise _strip_markdown_for_tts on individual Markdown constructs."""

    def test_strips_bold(self):
        assert _strip_markdown_for_tts("This is **bold** text") == "This is bold text"

    def test_strips_italic(self):
        assert _strip_markdown_for_tts("This is *italic* text") == "This is italic text"

    def test_strips_inline_code(self):
        assert _strip_markdown_for_tts("Run `pip install foo`") == "Run pip install foo"

    def test_strips_fenced_code_blocks(self):
        text = "Here is code:\n```python\nprint('hello')\n```\nDone."
        result = _strip_markdown_for_tts(text)
        assert "print" not in result
        assert "Done." in result

    def test_strips_headers(self):
        assert _strip_markdown_for_tts("## Summary\nSome text") == "Summary\nSome text"

    def test_strips_list_markers(self):
        text = "- item one\n- item two\n* item three"
        result = _strip_markdown_for_tts(text)
        assert "item one" in result
        assert "- " not in result
        assert "* " not in result

    def test_strips_urls(self):
        text = "Visit https://example.com for details"
        result = _strip_markdown_for_tts(text)
        assert "https://" not in result
        assert "Visit" in result

    def test_strips_markdown_links(self):
        # Link text survives; URL and bracket syntax are removed.
        text = "See [the docs](https://example.com/docs) for info"
        result = _strip_markdown_for_tts(text)
        assert "the docs" in result
        assert "https://" not in result
        assert "[" not in result

    def test_strips_horizontal_rules(self):
        text = "Part one\n---\nPart two"
        result = _strip_markdown_for_tts(text)
        assert "---" not in result
        assert "Part one" in result
        assert "Part two" in result

    def test_empty_after_stripping_returns_empty(self):
        # A response that is ONLY a code block strips down to nothing.
        text = "```python\nprint('hello')\n```"
        result = _strip_markdown_for_tts(text)
        assert result == ""

    def test_long_text_not_truncated(self):
        """_strip_markdown_for_tts does NOT truncate — that's the caller's job."""
        text = "a" * 5000
        result = _strip_markdown_for_tts(text)
        assert len(result) == 5000

    def test_complex_response(self):
        text = (
            "## Answer\n\n"
            "Here's how to do it:\n\n"
            "```python\ndef hello():\n print('hi')\n```\n\n"
            "Run it with `python main.py`. "
            "See [docs](https://example.com) for more.\n\n"
            "- Step one\n- Step two\n\n"
            "---\n\n"
            "**Good luck!**"
        )
        result = _strip_markdown_for_tts(text)
        assert "```" not in result
        assert "https://" not in result
        assert "**" not in result
        assert "---" not in result
        assert "Answer" in result
        assert "Good luck!" in result
        assert "docs" in result


# ============================================================================
# Voice command parsing
# ============================================================================

class TestVoiceCommandParsing:
    """Test _handle_voice_command logic without full CLI setup."""

    def test_parse_subcommands(self):
        """Verify subcommand extraction from /voice commands."""
        test_cases = [
            ("/voice on", "on"),
            ("/voice off", "off"),
            ("/voice tts", "tts"),
            ("/voice status", "status"),
            ("/voice", ""),
            ("/voice ON ", "on"),
        ]
        for command, expected in test_cases:
            # Mirrors the parsing done inside _handle_voice_command.
            parts = command.strip().split(maxsplit=1)
            subcommand = parts[1].lower().strip() if len(parts) > 1 else ""
            assert subcommand == expected, f"Failed for {command!r}: got {subcommand!r}"


# ============================================================================
# Voice state thread safety
# ============================================================================

class TestVoiceStateLock:
    def test_lock_protects_state(self):
        """Verify that concurrent state changes don't corrupt state."""
        lock = threading.Lock()
        state = {"recording": False, "count": 0}

        def toggle_many(n):
            for _ in range(n):
                with lock:
                    state["recording"] = not state["recording"]
                    state["count"] += 1

        threads = [threading.Thread(target=toggle_many, args=(1000,)) for _ in range(4)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # 4 threads x 1000 increments — any lost update would show here.
        assert state["count"] == 4000


# ============================================================================
# Streaming TTS lazy import activation (Bug A fix)
# ============================================================================

class TestStreamingTTSActivation:
    """Verify streaming TTS uses lazy imports to check availability."""

    def test_activates_when_elevenlabs_and_sounddevice_available(self):
        """use_streaming_tts should be True when provider is elevenlabs
        and both lazy imports succeed."""
        use_streaming_tts = False
        try:
            from tools.tts_tool import (
                _load_tts_config as _load_tts_cfg,
                _get_provider as _get_prov,
                _import_elevenlabs,
                _import_sounddevice,
            )
            assert callable(_import_elevenlabs)
            assert callable(_import_sounddevice)
        except ImportError:
            pytest.skip("tools.tts_tool not available")

        with patch("tools.tts_tool._load_tts_config") as mock_cfg, \
                patch("tools.tts_tool._get_provider", return_value="elevenlabs"), \
                patch("tools.tts_tool._import_elevenlabs") as mock_el, \
                patch("tools.tts_tool._import_sounddevice") as mock_sd:
            mock_cfg.return_value = {"provider": "elevenlabs"}
            mock_el.return_value = MagicMock()
            mock_sd.return_value = MagicMock()

            # Re-import AFTER patching so the names resolve to the mocks.
            from tools.tts_tool import (
                _load_tts_config as load_cfg,
                _get_provider as get_prov,
                _import_elevenlabs as import_el,
                _import_sounddevice as import_sd,
            )
            cfg = load_cfg()
            if get_prov(cfg) == "elevenlabs":
                import_el()
                import_sd()
                use_streaming_tts = True

        assert use_streaming_tts is True

    def test_does_not_activate_when_elevenlabs_missing(self):
        """use_streaming_tts stays False when elevenlabs import fails."""
        use_streaming_tts = False
        with patch("tools.tts_tool._load_tts_config", return_value={"provider": "elevenlabs"}), \
                patch("tools.tts_tool._get_provider", return_value="elevenlabs"), \
                patch("tools.tts_tool._import_elevenlabs", side_effect=ImportError("no elevenlabs")):
            try:
                from tools.tts_tool import (
                    _load_tts_config as load_cfg,
                    _get_provider as get_prov,
                    _import_elevenlabs as import_el,
                    _import_sounddevice as import_sd,
                )
                cfg = load_cfg()
                if get_prov(cfg) == "elevenlabs":
                    import_el()
                    import_sd()
                    use_streaming_tts = True
            except (ImportError, OSError):
                pass

        assert use_streaming_tts is False

    def test_does_not_activate_when_sounddevice_missing(self):
        """use_streaming_tts stays False when sounddevice import fails."""
        use_streaming_tts = False
        # sounddevice raises OSError (not ImportError) when PortAudio is absent.
        with patch("tools.tts_tool._load_tts_config", return_value={"provider": "elevenlabs"}), \
                patch("tools.tts_tool._get_provider", return_value="elevenlabs"), \
                patch("tools.tts_tool._import_elevenlabs", return_value=MagicMock()), \
                patch("tools.tts_tool._import_sounddevice", side_effect=OSError("no PortAudio")):
            try:
                from tools.tts_tool import (
                    _load_tts_config as load_cfg,
                    _get_provider as get_prov,
                    _import_elevenlabs as import_el,
                    _import_sounddevice as import_sd,
                )
                cfg = load_cfg()
                if get_prov(cfg) == "elevenlabs":
                    import_el()
                    import_sd()
                    use_streaming_tts = True
            except (ImportError, OSError):
                pass

        assert use_streaming_tts is False

    def test_does_not_activate_for_non_elevenlabs_provider(self):
        """use_streaming_tts stays False when provider is not elevenlabs."""
        use_streaming_tts = False
        with patch("tools.tts_tool._load_tts_config", return_value={"provider": "edge"}), \
                patch("tools.tts_tool._get_provider", return_value="edge"):
            try:
                from tools.tts_tool import (
                    _load_tts_config as load_cfg,
                    _get_provider as get_prov,
                    _import_elevenlabs as import_el,
                    _import_sounddevice as import_sd,
                )
                cfg = load_cfg()
                if get_prov(cfg) == "elevenlabs":
                    import_el()
                    import_sd()
                    use_streaming_tts = True
            except (ImportError, OSError):
                pass

        assert use_streaming_tts is False

    def test_stale_boolean_imports_no_longer_exist(self):
        """Confirm _HAS_ELEVENLABS and _HAS_AUDIO are not in tts_tool module."""
        import tools.tts_tool as tts_mod
        assert not hasattr(tts_mod, "_HAS_ELEVENLABS"), \
            "_HAS_ELEVENLABS should not exist -- lazy imports replaced it"
        assert not hasattr(tts_mod, "_HAS_AUDIO"), \
            "_HAS_AUDIO should not exist -- lazy imports replaced it"


# ============================================================================
# Voice mode user message prefix (Bug B fix)
# ============================================================================

class TestVoiceMessagePrefix:
    """Voice mode should inject instruction via user message prefix,
    not by modifying the system prompt (which breaks prompt cache)."""

    def test_prefix_added_when_voice_mode_active(self):
        """When voice mode is active and message is str, agent_message
        should have the voice instruction prefix."""
        voice_mode = True
        message = "What's the weather like?"

        agent_message = message
        if voice_mode and isinstance(message, str):
            agent_message = (
                "[Voice input — respond concisely and conversationally, "
                "2-3 sentences max. No code blocks or markdown.] "
                + message
            )

        assert agent_message.startswith("[Voice input")
        assert "What's the weather like?" in agent_message

    def test_no_prefix_when_voice_mode_inactive(self):
        """When voice mode is off, message passes through unchanged."""
        voice_mode = False
        message = "What's the weather like?"

        agent_message = message
        if voice_mode and isinstance(message, str):
            agent_message = (
                "[Voice input — respond concisely and conversationally, "
                "2-3 sentences max. No code blocks or markdown.] "
                + message
            )

        assert agent_message == message

    def test_no_prefix_for_multimodal_content(self):
        """When message is a list (multimodal), no prefix is added."""
        voice_mode = True
        message = [{"type": "text", "text": "describe this"}, {"type": "image_url"}]

        agent_message = message
        if voice_mode and isinstance(message, str):
            agent_message = (
                "[Voice input — respond concisely and conversationally, "
                "2-3 sentences max. No code blocks or markdown.] "
                + message
            )

        # Identity check: the list object must pass through untouched.
        assert agent_message is message

    def test_history_stays_clean(self):
        """conversation_history should contain the original message,
        not the prefixed version."""
        voice_mode = True
        message = "Hello there"
        conversation_history = []

        conversation_history.append({"role": "user", "content": message})

        agent_message = message
        if voice_mode and isinstance(message, str):
            agent_message = (
                "[Voice input — respond concisely and conversationally, "
                "2-3 sentences max. No code blocks or markdown.] "
                + message
            )

        assert conversation_history[-1]["content"] == "Hello there"
        assert agent_message.startswith("[Voice input")
        assert agent_message != conversation_history[-1]["content"]

    def test_enable_voice_mode_does_not_modify_system_prompt(self):
        """_enable_voice_mode should NOT modify self.system_prompt or
        agent.ephemeral_system_prompt -- the system prompt must stay
        stable to preserve prompt cache."""
        cli = SimpleNamespace(
            _voice_mode=False,
            _voice_tts=False,
            _voice_lock=threading.Lock(),
            system_prompt="You are helpful",
            agent=SimpleNamespace(ephemeral_system_prompt="You are helpful"),
        )

        original_system = cli.system_prompt
        original_ephemeral = cli.agent.ephemeral_system_prompt

        cli._voice_mode = True

        assert cli.system_prompt == original_system
        assert cli.agent.ephemeral_system_prompt == original_ephemeral


# ============================================================================
# _vprint force parameter (Minor fix)
# ============================================================================

class TestVprintForceParameter:
    """_vprint should suppress output during streaming TTS unless force=True."""

    def _make_agent_with_stream(self, stream_active: bool):
        """Create a minimal agent-like object with _vprint."""
        agent = SimpleNamespace(
            _stream_callback=MagicMock() if stream_active else None,
        )

        def _vprint(*args, force=False, **kwargs):
            # Suppress while a stream callback is attached, unless forced.
            if not force and getattr(agent, "_stream_callback", None) is not None:
                return
            print(*args, **kwargs)

        agent._vprint = _vprint
        return agent

    def test_suppressed_during_streaming(self, capsys):
        """Normal _vprint output is suppressed when streaming TTS is active."""
        agent = self._make_agent_with_stream(stream_active=True)
        agent._vprint("should be hidden")
        captured = capsys.readouterr()
        assert captured.out == ""

    def test_shown_when_not_streaming(self, capsys):
        """Normal _vprint output is shown when streaming is not active."""
        agent = self._make_agent_with_stream(stream_active=False)
        agent._vprint("should be shown")
        captured = capsys.readouterr()
        assert "should be shown" in captured.out

    def test_force_shown_during_streaming(self, capsys):
        """force=True bypasses the streaming suppression."""
        agent = self._make_agent_with_stream(stream_active=True)
        agent._vprint("critical error!", force=True)
        captured = capsys.readouterr()
        assert "critical error!" in captured.out

    def test_force_shown_when_not_streaming(self, capsys):
        """force=True works normally when not streaming (no regression)."""
        agent = self._make_agent_with_stream(stream_active=False)
        agent._vprint("normal message", force=True)
        captured = capsys.readouterr()
        assert "normal message" in captured.out

    def test_error_messages_use_force_in_run_agent(self):
        """Verify that critical error _vprint calls in run_agent.py
        include force=True."""
        with open("run_agent.py", "r") as f:
            source = f.read()

        tree = ast.parse(source)

        forced_error_count = 0
        unforced_error_count = 0

        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            func = node.func
            if not (isinstance(func, ast.Attribute) and func.attr == "_vprint"):
                continue
            # "Critical" is identified by the ❌ emoji inside an f-string arg.
            has_fatal = False
            for arg in node.args:
                if isinstance(arg, ast.JoinedStr):
                    for val in arg.values:
                        if isinstance(val, ast.Constant) and isinstance(val.value, str):
                            if "\u274c" in val.value:
                                has_fatal = True
                                break

            if not has_fatal:
                continue

            has_force = any(
                kw.arg == "force"
                and isinstance(kw.value, ast.Constant)
                and kw.value.value is True
                for kw in node.keywords
            )

            if has_force:
                forced_error_count += 1
            else:
                unforced_error_count += 1

        assert forced_error_count > 0, \
            "Expected at least one _vprint with force=True for error messages"
        assert unforced_error_count == 0, \
            f"Found {unforced_error_count} critical error _vprint calls without force=True"


# ============================================================================
# Bug fix regression tests
# ============================================================================

class TestEdgeTTSLazyImport:
    """Bug #3: _generate_edge_tts must use lazy import, not bare module name."""

    def test_generate_edge_tts_calls_lazy_import(self):
        """AST check: _generate_edge_tts must call _import_edge_tts(), not
        reference bare 'edge_tts' module name."""
        import ast as _ast

        with open("tools/tts_tool.py") as f:
            tree = _ast.parse(f.read())

        for node in _ast.walk(tree):
            if isinstance(node, _ast.AsyncFunctionDef) and node.name == "_generate_edge_tts":
                # Collect all Name references (bare identifiers)
                bare_refs = [
                    n.id for n in _ast.walk(node)
                    if isinstance(n, _ast.Name) and n.id == "edge_tts"
                ]
                assert bare_refs == [], (
                    f"_generate_edge_tts uses bare 'edge_tts' name — "
                    f"should use _import_edge_tts() lazy helper"
                )

                # Must have a call to _import_edge_tts
                lazy_calls = [
                    n for n in _ast.walk(node)
                    if isinstance(n, _ast.Call)
                    and isinstance(n.func, _ast.Name)
                    and n.func.id == "_import_edge_tts"
                ]
                assert len(lazy_calls) >= 1, (
                    "_generate_edge_tts must call _import_edge_tts()"
                )
                break
        else:
            pytest.fail("_generate_edge_tts not found in tts_tool.py")


class TestStreamingTTSOutputStreamCleanup:
    """Bug #7: output_stream must be closed in finally block."""

    def test_output_stream_closed_in_finally(self):
        """AST check: stream_tts_to_speaker's finally block must close
        output_stream even on exception."""
        import ast as _ast

        with open("tools/tts_tool.py") as f:
            tree = _ast.parse(f.read())

        for node in _ast.walk(tree):
            if isinstance(node, _ast.FunctionDef) and node.name == "stream_tts_to_speaker":
                # Find the outermost try that has a finally with tts_done_event.set()
                for child in _ast.walk(node):
                    if isinstance(child, _ast.Try) and child.finalbody:
                        finally_text = "\n".join(
                            _ast.dump(n) for n in child.finalbody
                        )
                        if "tts_done_event" in finally_text:
                            assert "output_stream" in finally_text, (
                                "finally block must close output_stream"
                            )
                            return
        pytest.fail("No finally block with tts_done_event found")


class TestCtrlCResetsContinuousMode:
    """Bug #4: Ctrl+C cancel must reset _voice_continuous."""

    def test_ctrl_c_handler_resets_voice_continuous(self):
        """Source check: Ctrl+C voice cancel block must set
        _voice_continuous = False."""
        with open("cli.py") as f:
            source = f.read()

        # Find the Ctrl+C handler's voice cancel block
        lines = source.split("\n")
        in_cancel_block = False
        found_continuous_reset = False
        for i, line in enumerate(lines):
            if "Cancel active voice recording" in line:
                in_cancel_block = True
            if in_cancel_block:
                if "_voice_continuous = False" in line:
                    found_continuous_reset = True
                    break
                # Block ends at next comment section or return
                if "return" in line and in_cancel_block:
                    break

        assert found_continuous_reset, (
            "Ctrl+C voice cancel block must set _voice_continuous = False"
        )


class TestDisableVoiceModeStopsTTS:
    """Bug #5: _disable_voice_mode must stop active TTS playback."""

    def test_disable_voice_mode_calls_stop_playback(self):
        """Source check: _disable_voice_mode must call stop_playback()."""
        import inspect
        from cli import HermesCLI

        source = inspect.getsource(HermesCLI._disable_voice_mode)
        assert "stop_playback" in source, (
            "_disable_voice_mode must call stop_playback()"
        )
        assert "_voice_tts_done.set()" in source, (
            "_disable_voice_mode must set _voice_tts_done"
        )


class TestVoiceStatusUsesConfigKey:
    """Bug #8: _show_voice_status must read record key from config."""

    def test_show_voice_status_not_hardcoded(self):
        """Source check: _show_voice_status must not hardcode Ctrl+B."""
        with open("cli.py") as f:
            source = f.read()

        lines = source.split("\n")
        in_method = False
        for line in lines:
            if "def _show_voice_status" in line:
                in_method = True
            elif in_method and line.strip().startswith("def "):
                break
            elif in_method:
                assert 'Record key: Ctrl+B"' not in line, (
                    "_show_voice_status hardcodes 'Ctrl+B' — "
                    "should read from config"
                )

    def test_show_voice_status_reads_config(self):
        """Source check: _show_voice_status must use load_config()."""
        with open("cli.py") as f:
            source = f.read()

        lines = source.split("\n")
        in_method = False
        method_lines = []
        for line in lines:
            if "def _show_voice_status" in line:
                in_method = True
            elif in_method and line.strip().startswith("def "):
                break
            elif in_method:
                method_lines.append(line)

        method_body = "\n".join(method_lines)
        assert "load_config" in method_body or "record_key" in method_body, (
            "_show_voice_status should read record_key from config"
        )


class TestChatTTSCleanupOnException:
    """Bug #2: chat() must clean up streaming TTS resources on exception."""

    def test_chat_has_finally_for_tts_cleanup(self):
        """AST check: chat() method must have a finally block that cleans up
        text_queue, stop_event, and tts_thread."""
        import ast as _ast

        with open("cli.py") as f:
            tree = _ast.parse(f.read())

        for node in _ast.walk(tree):
            if isinstance(node, _ast.FunctionDef) and node.name == "chat":
                # Find Try nodes with finally blocks
                for child in _ast.walk(node):
                    if isinstance(child, _ast.Try) and child.finalbody:
                        finally_text = "\n".join(
                            _ast.dump(n) for n in child.finalbody
                        )
                        if "text_queue" in finally_text:
                            assert "stop_event" in finally_text, (
                                "finally must also handle stop_event"
                            )
                            assert "tts_thread" in finally_text, (
                                "finally must also handle tts_thread"
                            )
                            return
        pytest.fail(
            "chat() must have a finally block cleaning up "
            "text_queue/stop_event/tts_thread"
        )


class TestBrowserToolSignalHandlerRemoved:
    """browser_tool.py must NOT register SIGINT/SIGTERM handlers that call
    sys.exit() — this conflicts with prompt_toolkit's event loop and causes
    the process to become unkillable during voice mode."""

    def test_no_signal_handler_registration(self):
        """Source check: browser_tool.py must not call signal.signal()
        for SIGINT or SIGTERM."""
        with open("tools/browser_tool.py") as f:
            source = f.read()

        lines = source.split("\n")
        for i, line in enumerate(lines, 1):
            stripped = line.strip()
            # Skip comments
            if stripped.startswith("#"):
                continue
            assert "signal.signal(signal.SIGINT" not in stripped, (
                f"browser_tool.py:{i} registers SIGINT handler — "
                f"use atexit instead to avoid prompt_toolkit conflicts"
            )
            assert "signal.signal(signal.SIGTERM" not in stripped, (
                f"browser_tool.py:{i} registers SIGTERM handler — "
                f"use atexit instead to avoid prompt_toolkit conflicts"
            )


class TestKeyHandlerNeverBlocks:
    """The Ctrl+B key handler runs in prompt_toolkit's event-loop thread.
    Any blocking call freezes the entire UI. Verify that:
    1. _voice_start_recording is NOT called directly (must be in daemon thread)
    2. _voice_processing guard prevents starting while stop/transcribe runs
    3. _voice_processing is set atomically with _voice_recording in stop_and_transcribe
    """

    def test_start_recording_not_called_directly_in_handler(self):
        """AST check: handle_voice_record must NOT call _voice_start_recording()
        directly — it must wrap it in a Thread to avoid blocking the UI."""
        import ast as _ast

        with open("cli.py") as f:
            tree = _ast.parse(f.read())

        for node in _ast.walk(tree):
            if isinstance(node, _ast.FunctionDef) and node.name == "handle_voice_record":
                # Collect all direct calls to _voice_start_recording in this function.
                # They should ONLY appear inside a nested def (the _start_recording wrapper).
                for child in _ast.iter_child_nodes(node):
                    # Direct statements in the handler body (not nested defs)
                    if isinstance(child, _ast.Expr) and isinstance(child.value, _ast.Call):
                        call_src = _ast.dump(child.value)
                        assert "_voice_start_recording" not in call_src, (
                            "handle_voice_record calls _voice_start_recording directly "
                            "— must dispatch to a daemon thread"
                        )
                break

    def test_processing_guard_in_start_path(self):
        """Source check: key handler must check _voice_processing before
        starting a new recording."""
        with open("cli.py") as f:
            source = f.read()

        lines = source.split("\n")
        in_handler = False
        in_else = False
        found_guard = False
        for line in lines:
            if "def handle_voice_record" in line:
                in_handler = True
            elif in_handler and line.strip().startswith("def ") and "_start_recording" not in line:
                break
            elif in_handler and "else:" in line:
                in_else = True
            elif in_else and "_voice_processing" in line:
                found_guard = True
                break

        assert found_guard, (
            "Key handler START path must guard against _voice_processing "
            "to prevent blocking on AudioRecorder._lock"
        )

    def test_processing_set_atomically_with_recording_false(self):
        """Source check: _voice_stop_and_transcribe must set _voice_processing = True
        in the same lock block where it sets _voice_recording = False."""
        with open("cli.py") as f:
            source = f.read()

        lines = source.split("\n")
        in_method = False
        in_first_lock = False
        found_recording_false = False
        found_processing_true = False
        for line in lines:
            if "def _voice_stop_and_transcribe" in line:
                in_method = True
            elif in_method and "with self._voice_lock:" in line and not in_first_lock:
                in_first_lock = True
            elif in_first_lock:
                stripped = line.strip()
                if not stripped or stripped.startswith("#"):
                    continue
                if "_voice_recording = False" in stripped:
                    found_recording_false = True
                if "_voice_processing = True" in stripped:
                    found_processing_true = True
                # End of with block (dedent)
                # NOTE(review): indent-width literal reconstructed from a
                # whitespace-mangled source — confirm it matches the file's
                # real lock-body indent (12 spaces / three tab stops).
                if stripped and not line.startswith("            ") and not line.startswith("\t\t\t"):
                    break

        assert found_recording_false and found_processing_true, (
            "_voice_stop_and_transcribe must set _voice_processing = True "
            "atomically (same lock block) with _voice_recording = False"
        )


# ============================================================================
# Real behavior tests — CLI voice methods via _make_voice_cli()
# ============================================================================

class TestHandleVoiceCommandReal:
    """Tests _handle_voice_command routing with real CLI instance."""

    def _cli(self):
        # Route targets are mocked so each test asserts dispatch only.
        cli = _make_voice_cli()
        cli._enable_voice_mode = MagicMock()
        cli._disable_voice_mode = MagicMock()
        cli._toggle_voice_tts = MagicMock()
        cli._show_voice_status = MagicMock()
        return cli

    @patch("cli._cprint")
    def test_on_calls_enable(self, _cp):
        cli = self._cli()
        cli._handle_voice_command("/voice on")
        cli._enable_voice_mode.assert_called_once()

    @patch("cli._cprint")
    def test_off_calls_disable(self, _cp):
        cli = self._cli()
        cli._handle_voice_command("/voice off")
        cli._disable_voice_mode.assert_called_once()

    @patch("cli._cprint")
    def test_tts_calls_toggle(self, _cp):
        cli = self._cli()
        cli._handle_voice_command("/voice tts")
        cli._toggle_voice_tts.assert_called_once()

    @patch("cli._cprint")
    def test_status_calls_show(self, _cp):
        cli = self._cli()
        cli._handle_voice_command("/voice status")
        cli._show_voice_status.assert_called_once()

    @patch("cli._cprint")
    def test_toggle_off_when_enabled(self, _cp):
        cli = self._cli()
        cli._voice_mode = True
        cli._handle_voice_command("/voice")
        cli._disable_voice_mode.assert_called_once()

    @patch("cli._cprint")
    def test_toggle_on_when_disabled(self, _cp):
        cli = self._cli()
        cli._voice_mode = False
        cli._handle_voice_command("/voice")
        cli._enable_voice_mode.assert_called_once()

    @patch("cli._cprint")
    def test_unknown_subcommand(self, mock_cp):
        cli = self._cli()
        cli._handle_voice_command("/voice foobar")
        cli._enable_voice_mode.assert_not_called()
        cli._disable_voice_mode.assert_not_called()
        # Should print usage via _cprint
        assert any("Unknown" in str(c) or "unknown" in str(c)
                   for c in mock_cp.call_args_list)


class TestEnableVoiceModeReal:
    """Tests _enable_voice_mode with real CLI instance."""

    @patch("cli._cprint")
    @patch("hermes_cli.config.load_config", return_value={"voice": {}})
    @patch("tools.voice_mode.check_voice_requirements",
           return_value={"available": True, "details": "OK"})
    @patch("tools.voice_mode.detect_audio_environment",
           return_value={"available": True, "warnings": []})
    def test_success_sets_voice_mode(self, _env, _req, _cfg, _cp):
        cli = _make_voice_cli()
        cli._enable_voice_mode()
        assert cli._voice_mode is True

    @patch("cli._cprint")
    def test_already_enabled_noop(self, _cp):
        cli = _make_voice_cli(_voice_mode=True)
        cli._enable_voice_mode()
        assert cli._voice_mode is True

    @patch("cli._cprint")
    @patch("tools.voice_mode.detect_audio_environment",
           return_value={"available": False, "warnings": ["SSH session"]})
    def test_env_check_fails(self, _env, _cp):
        cli = _make_voice_cli()
        cli._enable_voice_mode()
        assert cli._voice_mode is False

    @patch("cli._cprint")
    @patch("tools.voice_mode.check_voice_requirements",
           return_value={"available": False, "details": "Missing",
                         "missing_packages": ["sounddevice"]})
    @patch("tools.voice_mode.detect_audio_environment",
           return_value={"available": True, "warnings": []})
    def test_requirements_fail(self, _env, _req, _cp):
        cli = _make_voice_cli()
        cli._enable_voice_mode()
        assert cli._voice_mode is False

    @patch("cli._cprint")
    @patch("hermes_cli.config.load_config", return_value={"voice": {"auto_tts": True}})
    @patch("tools.voice_mode.check_voice_requirements",
           return_value={"available": True, "details": "OK"})
    @patch("tools.voice_mode.detect_audio_environment",
           return_value={"available": True, "warnings": []})
    def test_auto_tts_from_config(self, _env, _req, _cfg, _cp):
        cli = _make_voice_cli()
        cli._enable_voice_mode()
        assert cli._voice_tts is True

    @patch("cli._cprint")
    @patch("hermes_cli.config.load_config", return_value={"voice": {}})
    @patch("tools.voice_mode.check_voice_requirements",
           return_value={"available": True, "details": "OK"})
    @patch("tools.voice_mode.detect_audio_environment",
           return_value={"available": True, "warnings": []})
    def test_no_auto_tts_default(self, _env, _req, _cfg, _cp):
        cli = _make_voice_cli()
        cli._enable_voice_mode()
        assert cli._voice_tts is False

    @patch("cli._cprint")
    @patch("hermes_cli.config.load_config", side_effect=Exception("broken config"))
    @patch("tools.voice_mode.check_voice_requirements",
           return_value={"available": True, "details": "OK"})
    @patch("tools.voice_mode.detect_audio_environment",
           return_value={"available": True, "warnings": []})
    def test_config_exception_still_enables(self, _env, _req, _cfg, _cp):
        cli = _make_voice_cli()
        cli._enable_voice_mode()
        assert cli._voice_mode is True


class TestDisableVoiceModeReal:
    """Tests _disable_voice_mode with real CLI instance."""

    @patch("cli._cprint")
    @patch("tools.voice_mode.stop_playback")
    def test_all_flags_reset(self, _sp, _cp):
        cli = _make_voice_cli(_voice_mode=True, _voice_tts=True,
                              _voice_continuous=True)
        cli._disable_voice_mode()
        assert cli._voice_mode is False
        assert cli._voice_tts is False
        assert cli._voice_continuous is False

    @patch("cli._cprint")
    @patch("tools.voice_mode.stop_playback")
    def test_active_recording_cancelled(self, _sp, _cp):
        recorder = MagicMock()
        cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder)
        cli._disable_voice_mode()
        recorder.cancel.assert_called_once()
        assert cli._voice_recording is False

    @patch("cli._cprint")
    @patch("tools.voice_mode.stop_playback")
    def test_stop_playback_called(self, mock_sp, _cp):
        cli = _make_voice_cli()
        cli._disable_voice_mode()
        mock_sp.assert_called_once()

    @patch("cli._cprint")
    @patch("tools.voice_mode.stop_playback")
    def test_tts_done_event_set(self, _sp, _cp):
        cli = _make_voice_cli()
        cli._voice_tts_done.clear()
        cli._disable_voice_mode()
        assert cli._voice_tts_done.is_set()

    @patch("cli._cprint")
    @patch("tools.voice_mode.stop_playback")
    def test_no_recorder_no_crash(self, _sp, _cp):
        cli = _make_voice_cli(_voice_recording=True, _voice_recorder=None)
        cli._disable_voice_mode()
        assert cli._voice_mode is False

    @patch("cli._cprint")
    @patch("tools.voice_mode.stop_playback", side_effect=RuntimeError("boom"))
    def test_stop_playback_exception_swallowed(self, _sp, _cp):
        cli = _make_voice_cli(_voice_mode=True)
        cli._disable_voice_mode()
        assert cli._voice_mode is False


class TestVoiceSpeakResponseReal:
    """Tests _voice_speak_response with real CLI instance."""

    @patch("cli._cprint")
    def test_early_return_when_tts_off(self, _cp):
        cli = _make_voice_cli(_voice_tts=False)
        with patch("tools.tts_tool.text_to_speech_tool") as mock_tts:
            cli._voice_speak_response("Hello")
            mock_tts.assert_not_called()

    @patch("cli._cprint")
    @patch("cli.os.unlink")
    @patch("cli.os.path.getsize", return_value=1000)
    @patch("cli.os.path.isfile", return_value=True)
    @patch("cli.os.makedirs")
    @patch("tools.voice_mode.play_audio_file")
    @patch("tools.tts_tool.text_to_speech_tool", return_value='{"success": true}')
    def test_markdown_stripped(self, mock_tts, _play, _mkd, _isf, _gsz, _unl, _cp):
        cli = _make_voice_cli(_voice_tts=True)
        cli._voice_speak_response("## Title\n**bold** and `code`")
        call_text = mock_tts.call_args.kwargs["text"]
        assert "##" not in call_text
        assert "**" not in call_text
        assert "`" not in call_text

    @patch("cli._cprint")
    @patch("cli.os.makedirs")
    @patch("tools.tts_tool.text_to_speech_tool", return_value='{"success": true}')
    def test_code_blocks_removed(self, mock_tts, _mkd, _cp):
        cli = _make_voice_cli(_voice_tts=True)
        cli._voice_speak_response("```python\nprint('hi')\n```\nSome text")
        call_text = mock_tts.call_args.kwargs["text"]
        assert "print" not in call_text
        assert "```" not in call_text
        assert "Some text" in call_text

    @patch("cli._cprint")
    @patch("cli.os.makedirs")
    def test_empty_after_strip_returns_early(self, _mkd, _cp):
        cli = _make_voice_cli(_voice_tts=True)
        with patch("tools.tts_tool.text_to_speech_tool") as mock_tts:
            cli._voice_speak_response("```python\nprint('hi')\n```")
            mock_tts.assert_not_called()

    @patch("cli._cprint")
    @patch("cli.os.makedirs")
    @patch("tools.tts_tool.text_to_speech_tool", return_value='{"success": true}')
    def test_long_text_truncated(self, mock_tts, _mkd, _cp):
        cli = _make_voice_cli(_voice_tts=True)
        cli._voice_speak_response("A" * 5000)
        call_text = mock_tts.call_args.kwargs["text"]
        assert len(call_text) <= 4000

    @patch("cli._cprint")
    @patch("cli.os.makedirs")
    @patch("tools.tts_tool.text_to_speech_tool", side_effect=RuntimeError("tts fail"))
    def test_exception_sets_done_event(self, _tts, _mkd, _cp):
        cli = _make_voice_cli(_voice_tts=True)
        cli._voice_tts_done.clear()
        cli._voice_speak_response("Hello")
        assert cli._voice_tts_done.is_set()

    @patch("cli._cprint")
    @patch("cli.os.unlink")
    @patch("cli.os.path.getsize", return_value=1000)
    @patch("cli.os.path.isfile", return_value=True)
    @patch("cli.os.makedirs")
    @patch("tools.voice_mode.play_audio_file")
    @patch("tools.tts_tool.text_to_speech_tool", return_value='{"success": true}')
    def test_play_audio_called(self, _tts, mock_play, _mkd, _isf, _gsz, _unl, _cp):
        cli = _make_voice_cli(_voice_tts=True)
        cli._voice_speak_response("Hello world")
        mock_play.assert_called_once()


class TestVoiceStopAndTranscribeReal:
    """Tests _voice_stop_and_transcribe with real CLI instance."""

    @patch("cli._cprint")
    def test_guard_not_recording(self, _cp):
        cli = _make_voice_cli(_voice_recording=False)
        with patch("tools.voice_mode.transcribe_recording") as mock_tr:
            cli._voice_stop_and_transcribe()
            mock_tr.assert_not_called()

    @patch("cli._cprint")
    def test_no_recorder_returns_early(self, _cp):
        cli = _make_voice_cli(_voice_recording=True, _voice_recorder=None)
        with patch("tools.voice_mode.transcribe_recording") as mock_tr:
            cli._voice_stop_and_transcribe()
            mock_tr.assert_not_called()
        assert cli._voice_recording is False

    @patch("cli._cprint")
    @patch("tools.voice_mode.play_beep")
    def test_no_speech_detected(self, _beep, _cp):
        recorder = MagicMock()
        recorder.stop.return_value = None
        cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder)
        cli._voice_stop_and_transcribe()
        assert cli._pending_input.empty()

    @patch("cli._cprint")
    @patch("cli.os.unlink")
    @patch("cli.os.path.isfile", return_value=True)
    @patch("hermes_cli.config.load_config", return_value={"stt": {}})
    @patch("tools.voice_mode.transcribe_recording",
           return_value={"success": True, "transcript": "hello world"})
    @patch("tools.voice_mode.play_beep")
    def test_successful_transcription_queues_input(
        self, _beep, _tr, _cfg, _isf, _unl, _cp
    ):
        recorder = MagicMock()
        recorder.stop.return_value = "/tmp/test.wav"
        cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder)
        cli._voice_stop_and_transcribe()
        assert cli._pending_input.get_nowait() == "hello world"

    @patch("cli._cprint")
    @patch("cli.os.unlink")
    @patch("cli.os.path.isfile", return_value=True)
    @patch("hermes_cli.config.load_config", return_value={"stt": {}})
    @patch("tools.voice_mode.transcribe_recording",
           return_value={"success": True, "transcript": ""})
    @patch("tools.voice_mode.play_beep")
    def test_empty_transcript_not_queued(self, _beep, _tr, _cfg, _isf, _unl, _cp):
        recorder = MagicMock()
        recorder.stop.return_value = "/tmp/test.wav"
        cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder)
        cli._voice_stop_and_transcribe()
        assert cli._pending_input.empty()

    @patch("cli._cprint")
    @patch("cli.os.unlink")
    @patch("cli.os.path.isfile", return_value=True)
    @patch("hermes_cli.config.load_config", return_value={"stt": {}})
    @patch("tools.voice_mode.transcribe_recording",
           return_value={"success": False, "error": "API timeout"})
    @patch("tools.voice_mode.play_beep")
    def test_transcription_failure(self, _beep, _tr, _cfg, _isf, _unl, _cp):
        recorder = MagicMock()
        recorder.stop.return_value = "/tmp/test.wav"
        cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder)
        cli._voice_stop_and_transcribe()
        assert cli._pending_input.empty()

    @patch("cli._cprint")
    @patch("cli.os.unlink")
    @patch("cli.os.path.isfile", return_value=True)
    @patch("hermes_cli.config.load_config", return_value={"stt": {}})
    @patch("tools.voice_mode.transcribe_recording",
           side_effect=ConnectionError("network"))
    @patch("tools.voice_mode.play_beep")
    def test_exception_caught(self, _beep, _tr, _cfg, _isf, _unl, _cp):
        recorder = MagicMock()
        recorder.stop.return_value = "/tmp/test.wav"
        cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder)
        cli._voice_stop_and_transcribe()  # Should not raise

    @patch("cli._cprint")
@patch("tools.voice_mode.play_beep") + def test_processing_flag_cleared(self, _beep, _cp): + recorder = MagicMock() + recorder.stop.return_value = None + cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder) + cli._voice_stop_and_transcribe() + assert cli._voice_processing is False + + @patch("cli._cprint") + @patch("tools.voice_mode.play_beep") + def test_continuous_restarts_on_no_speech(self, _beep, _cp): + recorder = MagicMock() + recorder.stop.return_value = None + cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder, + _voice_continuous=True) + cli._voice_start_recording = MagicMock() + cli._voice_stop_and_transcribe() + cli._voice_start_recording.assert_called_once() + + @patch("cli._cprint") + @patch("cli.os.unlink") + @patch("cli.os.path.isfile", return_value=True) + @patch("hermes_cli.config.load_config", return_value={"stt": {}}) + @patch("tools.voice_mode.transcribe_recording", + return_value={"success": True, "transcript": "hello"}) + @patch("tools.voice_mode.play_beep") + def test_continuous_no_restart_on_success( + self, _beep, _tr, _cfg, _isf, _unl, _cp + ): + recorder = MagicMock() + recorder.stop.return_value = "/tmp/test.wav" + cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder, + _voice_continuous=True) + cli._voice_start_recording = MagicMock() + cli._voice_stop_and_transcribe() + cli._voice_start_recording.assert_not_called() + + @patch("cli._cprint") + @patch("cli.os.unlink") + @patch("cli.os.path.isfile", return_value=True) + @patch("hermes_cli.config.load_config", return_value={"stt": {"model": "whisper-large-v3"}}) + @patch("tools.voice_mode.transcribe_recording", + return_value={"success": True, "transcript": "hi"}) + @patch("tools.voice_mode.play_beep") + def test_stt_model_from_config(self, _beep, mock_tr, _cfg, _isf, _unl, _cp): + recorder = MagicMock() + recorder.stop.return_value = "/tmp/test.wav" + cli = _make_voice_cli(_voice_recording=True, _voice_recorder=recorder) + 
cli._voice_stop_and_transcribe() + mock_tr.assert_called_once_with("/tmp/test.wav", model="whisper-large-v3") + + +# --------------------------------------------------------------------------- +# Bugfix: _refresh_level must read _voice_recording under lock +# --------------------------------------------------------------------------- + + +class TestRefreshLevelLock: + """Bug: _refresh_level thread read _voice_recording without lock.""" + + def test_refresh_stops_when_recording_false(self): + import threading, time + + lock = threading.Lock() + recording = True + iterations = 0 + + def refresh_level(): + nonlocal iterations + while True: + with lock: + still = recording + if not still: + break + iterations += 1 + time.sleep(0.01) + + t = threading.Thread(target=refresh_level, daemon=True) + t.start() + + time.sleep(0.05) + with lock: + recording = False + + t.join(timeout=1) + assert not t.is_alive(), "Refresh thread did not stop" + assert iterations > 0, "Refresh thread never ran" diff --git a/tests/tools/test_voice_mode.py b/tests/tools/test_voice_mode.py new file mode 100644 index 0000000000..013ed66353 --- /dev/null +++ b/tests/tools/test_voice_mode.py @@ -0,0 +1,938 @@ +"""Tests for tools.voice_mode -- all mocked, no real microphone or API calls.""" + +import os +import struct +import time +import wave +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + + +# ============================================================================ +# Fixtures +# ============================================================================ + +@pytest.fixture +def sample_wav(tmp_path): + """Create a minimal valid WAV file (1 second of silence at 16kHz).""" + wav_path = tmp_path / "test.wav" + n_frames = 16000 # 1 second at 16kHz + silence = struct.pack(f"<{n_frames}h", *([0] * n_frames)) + + with wave.open(str(wav_path), "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(16000) + wf.writeframes(silence) + + return 
str(wav_path) + + +@pytest.fixture +def temp_voice_dir(tmp_path, monkeypatch): + """Redirect _TEMP_DIR to a temporary path.""" + voice_dir = tmp_path / "hermes_voice" + voice_dir.mkdir() + monkeypatch.setattr("tools.voice_mode._TEMP_DIR", str(voice_dir)) + return voice_dir + + +@pytest.fixture +def mock_sd(monkeypatch): + """Mock _import_audio to return (mock_sd, real_np) so lazy imports work.""" + mock = MagicMock() + try: + import numpy as real_np + except ImportError: + real_np = MagicMock() + + def _fake_import_audio(): + return mock, real_np + + monkeypatch.setattr("tools.voice_mode._import_audio", _fake_import_audio) + monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True) + return mock + + +# ============================================================================ +# check_voice_requirements +# ============================================================================ + +class TestCheckVoiceRequirements: + def test_all_requirements_met(self, monkeypatch): + monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True) + monkeypatch.setattr("tools.voice_mode.detect_audio_environment", + lambda: {"available": True, "warnings": []}) + monkeypatch.setattr("tools.transcription_tools._get_provider", lambda cfg: "openai") + + from tools.voice_mode import check_voice_requirements + + result = check_voice_requirements() + assert result["available"] is True + assert result["audio_available"] is True + assert result["stt_available"] is True + assert result["missing_packages"] == [] + + def test_missing_audio_packages(self, monkeypatch): + monkeypatch.setattr("tools.voice_mode._audio_available", lambda: False) + monkeypatch.setattr("tools.voice_mode.detect_audio_environment", + lambda: {"available": False, "warnings": ["Audio libraries not installed"]}) + monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test-key") + + from tools.voice_mode import check_voice_requirements + + result = check_voice_requirements() + assert result["available"] 
is False + assert result["audio_available"] is False + assert "sounddevice" in result["missing_packages"] + assert "numpy" in result["missing_packages"] + + def test_missing_stt_provider(self, monkeypatch): + monkeypatch.setattr("tools.voice_mode._audio_available", lambda: True) + monkeypatch.setattr("tools.voice_mode.detect_audio_environment", + lambda: {"available": True, "warnings": []}) + monkeypatch.setattr("tools.transcription_tools._get_provider", lambda cfg: "none") + + from tools.voice_mode import check_voice_requirements + + result = check_voice_requirements() + assert result["available"] is False + assert result["stt_available"] is False + assert "STT provider: MISSING" in result["details"] + + +# ============================================================================ +# AudioRecorder +# ============================================================================ + +class TestAudioRecorderStart: + def test_start_raises_without_audio(self, monkeypatch): + def _fail_import(): + raise ImportError("no sounddevice") + monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import) + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + with pytest.raises(RuntimeError, match="sounddevice and numpy"): + recorder.start() + + def test_start_creates_and_starts_stream(self, mock_sd): + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder.start() + + assert recorder.is_recording is True + mock_sd.InputStream.assert_called_once() + mock_stream.start.assert_called_once() + + def test_double_start_is_noop(self, mock_sd): + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder.start() + recorder.start() # second call should be noop + + assert mock_sd.InputStream.call_count == 1 + + +class TestAudioRecorderStop: + def 
test_stop_returns_none_when_not_recording(self): + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + assert recorder.stop() is None + + def test_stop_writes_wav_file(self, mock_sd, temp_voice_dir): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder, SAMPLE_RATE + + recorder = AudioRecorder() + recorder.start() + + # Simulate captured audio frames (1 second of loud audio above RMS threshold) + frame = np.full((SAMPLE_RATE, 1), 1000, dtype="int16") + recorder._frames = [frame] + recorder._peak_rms = 1000 # Peak RMS above threshold + + wav_path = recorder.stop() + + assert wav_path is not None + assert os.path.isfile(wav_path) + assert wav_path.endswith(".wav") + assert recorder.is_recording is False + + # Verify it is a valid WAV + with wave.open(wav_path, "rb") as wf: + assert wf.getnchannels() == 1 + assert wf.getsampwidth() == 2 + assert wf.getframerate() == SAMPLE_RATE + + def test_stop_returns_none_for_very_short_recording(self, mock_sd, temp_voice_dir): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder.start() + + # Very short recording (100 samples = ~6ms at 16kHz) + frame = np.zeros((100, 1), dtype="int16") + recorder._frames = [frame] + + wav_path = recorder.stop() + assert wav_path is None + + def test_stop_returns_none_for_silent_recording(self, mock_sd, temp_voice_dir): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder, SAMPLE_RATE + + recorder = AudioRecorder() + recorder.start() + + # 1 second of near-silence (RMS well below threshold) + frame = np.full((SAMPLE_RATE, 1), 10, dtype="int16") + recorder._frames = [frame] + recorder._peak_rms = 10 # Peak RMS 
also below threshold + + wav_path = recorder.stop() + assert wav_path is None + + +class TestAudioRecorderCancel: + def test_cancel_discards_frames(self, mock_sd): + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder.start() + recorder._frames = [MagicMock()] # simulate captured data + + recorder.cancel() + + assert recorder.is_recording is False + assert recorder._frames == [] + # Stream is kept alive (persistent) — cancel() does NOT close it. + mock_stream.stop.assert_not_called() + mock_stream.close.assert_not_called() + + def test_cancel_when_not_recording_is_safe(self): + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder.cancel() # should not raise + assert recorder.is_recording is False + + +class TestAudioRecorderProperties: + def test_elapsed_seconds_when_not_recording(self): + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + assert recorder.elapsed_seconds == 0.0 + + def test_elapsed_seconds_when_recording(self, mock_sd): + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder.start() + + # Force start time to 1 second ago + recorder._start_time = time.monotonic() - 1.0 + elapsed = recorder.elapsed_seconds + assert 0.9 < elapsed < 2.0 + + recorder.cancel() + + +# ============================================================================ +# transcribe_recording +# ============================================================================ + +class TestTranscribeRecording: + def test_delegates_to_transcribe_audio(self): + mock_transcribe = MagicMock(return_value={ + "success": True, + "transcript": "hello world", + }) + + with patch("tools.transcription_tools.transcribe_audio", mock_transcribe): + from tools.voice_mode import transcribe_recording + result = 
transcribe_recording("/tmp/test.wav", model="whisper-1") + + assert result["success"] is True + assert result["transcript"] == "hello world" + mock_transcribe.assert_called_once_with("/tmp/test.wav", model="whisper-1") + + def test_filters_whisper_hallucination(self): + mock_transcribe = MagicMock(return_value={ + "success": True, + "transcript": "Thank you.", + }) + + with patch("tools.transcription_tools.transcribe_audio", mock_transcribe): + from tools.voice_mode import transcribe_recording + result = transcribe_recording("/tmp/test.wav") + + assert result["success"] is True + assert result["transcript"] == "" + assert result["filtered"] is True + + def test_does_not_filter_real_speech(self): + mock_transcribe = MagicMock(return_value={ + "success": True, + "transcript": "Thank you for helping me with this code.", + }) + + with patch("tools.transcription_tools.transcribe_audio", mock_transcribe): + from tools.voice_mode import transcribe_recording + result = transcribe_recording("/tmp/test.wav") + + assert result["transcript"] == "Thank you for helping me with this code." + assert "filtered" not in result + + +class TestWhisperHallucinationFilter: + def test_known_hallucinations(self): + from tools.voice_mode import is_whisper_hallucination + + assert is_whisper_hallucination("Thank you.") is True + assert is_whisper_hallucination("thank you") is True + assert is_whisper_hallucination("Thanks for watching.") is True + assert is_whisper_hallucination("Bye.") is True + assert is_whisper_hallucination(" Thank you. 
") is True # with whitespace + assert is_whisper_hallucination("you") is True + + def test_real_speech_not_filtered(self): + from tools.voice_mode import is_whisper_hallucination + + assert is_whisper_hallucination("Hello, how are you?") is False + assert is_whisper_hallucination("Thank you for your help with the project.") is False + assert is_whisper_hallucination("Can you explain this code?") is False + + +# ============================================================================ +# play_audio_file +# ============================================================================ + +class TestPlayAudioFile: + def test_play_wav_via_sounddevice(self, monkeypatch, sample_wav): + np = pytest.importorskip("numpy") + + mock_sd_obj = MagicMock() + # Simulate stream completing immediately (get_stream().active = False) + mock_stream = MagicMock() + mock_stream.active = False + mock_sd_obj.get_stream.return_value = mock_stream + + def _fake_import(): + return mock_sd_obj, np + + monkeypatch.setattr("tools.voice_mode._import_audio", _fake_import) + + from tools.voice_mode import play_audio_file + + result = play_audio_file(sample_wav) + + assert result is True + mock_sd_obj.play.assert_called_once() + mock_sd_obj.stop.assert_called_once() + + def test_returns_false_when_no_player(self, monkeypatch, sample_wav): + def _fail_import(): + raise ImportError("no sounddevice") + monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import) + monkeypatch.setattr("shutil.which", lambda _: None) + + from tools.voice_mode import play_audio_file + + result = play_audio_file(sample_wav) + assert result is False + + def test_returns_false_for_missing_file(self): + from tools.voice_mode import play_audio_file + + result = play_audio_file("/nonexistent/file.wav") + assert result is False + + +# ============================================================================ +# cleanup_temp_recordings +# ============================================================================ + 
+class TestCleanupTempRecordings: + def test_old_files_deleted(self, temp_voice_dir): + # Create an "old" file + old_file = temp_voice_dir / "recording_20240101_000000.wav" + old_file.write_bytes(b"\x00" * 100) + # Set mtime to 2 hours ago + old_mtime = time.time() - 7200 + os.utime(str(old_file), (old_mtime, old_mtime)) + + from tools.voice_mode import cleanup_temp_recordings + + deleted = cleanup_temp_recordings(max_age_seconds=3600) + assert deleted == 1 + assert not old_file.exists() + + def test_recent_files_preserved(self, temp_voice_dir): + # Create a "recent" file + recent_file = temp_voice_dir / "recording_20260303_120000.wav" + recent_file.write_bytes(b"\x00" * 100) + + from tools.voice_mode import cleanup_temp_recordings + + deleted = cleanup_temp_recordings(max_age_seconds=3600) + assert deleted == 0 + assert recent_file.exists() + + def test_nonexistent_dir_returns_zero(self, monkeypatch): + monkeypatch.setattr("tools.voice_mode._TEMP_DIR", "/nonexistent/dir") + + from tools.voice_mode import cleanup_temp_recordings + + assert cleanup_temp_recordings() == 0 + + def test_non_recording_files_ignored(self, temp_voice_dir): + # Create a file that doesn't match the pattern + other_file = temp_voice_dir / "other_file.txt" + other_file.write_bytes(b"\x00" * 100) + old_mtime = time.time() - 7200 + os.utime(str(other_file), (old_mtime, old_mtime)) + + from tools.voice_mode import cleanup_temp_recordings + + deleted = cleanup_temp_recordings(max_age_seconds=3600) + assert deleted == 0 + assert other_file.exists() + + +# ============================================================================ +# play_beep +# ============================================================================ + +class TestPlayBeep: + def test_beep_calls_sounddevice_play(self, mock_sd): + np = pytest.importorskip("numpy") + + from tools.voice_mode import play_beep + + # play_beep uses polling (get_stream) + sd.stop() instead of sd.wait() + mock_stream = MagicMock() + mock_stream.active 
= False + mock_sd.get_stream.return_value = mock_stream + + play_beep(frequency=880, duration=0.1, count=1) + + mock_sd.play.assert_called_once() + mock_sd.stop.assert_called() + # Verify audio data is int16 numpy array + audio_arg = mock_sd.play.call_args[0][0] + assert audio_arg.dtype == np.int16 + assert len(audio_arg) > 0 + + def test_beep_double_produces_longer_audio(self, mock_sd): + np = pytest.importorskip("numpy") + + from tools.voice_mode import play_beep + + play_beep(frequency=660, duration=0.1, count=2) + + audio_arg = mock_sd.play.call_args[0][0] + single_beep_samples = int(16000 * 0.1) + # Double beep should be longer than a single beep + assert len(audio_arg) > single_beep_samples + + def test_beep_noop_without_audio(self, monkeypatch): + def _fail_import(): + raise ImportError("no sounddevice") + monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import) + + from tools.voice_mode import play_beep + + # Should not raise + play_beep() + + def test_beep_handles_playback_error(self, mock_sd): + mock_sd.play.side_effect = Exception("device error") + + from tools.voice_mode import play_beep + + # Should not raise + play_beep() + + +# ============================================================================ +# Silence detection +# ============================================================================ + +class TestSilenceDetection: + def test_silence_callback_fires_after_speech_then_silence(self, mock_sd): + np = pytest.importorskip("numpy") + import threading + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder, SAMPLE_RATE + + recorder = AudioRecorder() + # Use very short durations for testing + recorder._silence_duration = 0.05 + recorder._min_speech_duration = 0.05 + + fired = threading.Event() + + def on_silence(): + fired.set() + + recorder.start(on_silence_stop=on_silence) + + # Get the callback function from InputStream constructor + callback = 
mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + # Simulate sustained speech (multiple loud chunks to exceed min_speech_duration) + loud_frame = np.full((1600, 1), 5000, dtype="int16") + callback(loud_frame, 1600, None, None) + time.sleep(0.06) + callback(loud_frame, 1600, None, None) + assert recorder._has_spoken is True + + # Simulate silence + silent_frame = np.zeros((1600, 1), dtype="int16") + callback(silent_frame, 1600, None, None) + + # Wait a bit past the silence duration, then send another silent frame + time.sleep(0.06) + callback(silent_frame, 1600, None, None) + + # The callback should have been fired + assert fired.wait(timeout=1.0) is True + + recorder.cancel() + + def test_silence_without_speech_does_not_fire(self, mock_sd): + np = pytest.importorskip("numpy") + import threading + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder._silence_duration = 0.02 + + fired = threading.Event() + recorder.start(on_silence_stop=lambda: fired.set()) + + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + # Only silence -- no speech detected, so callback should NOT fire + silent_frame = np.zeros((1600, 1), dtype="int16") + for _ in range(5): + callback(silent_frame, 1600, None, None) + time.sleep(0.01) + + assert fired.wait(timeout=0.2) is False + + recorder.cancel() + + def test_micro_pause_tolerance_during_speech(self, mock_sd): + """Brief dips below threshold during speech should NOT reset speech tracking.""" + np = pytest.importorskip("numpy") + import threading + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder._silence_duration = 0.05 + 
recorder._min_speech_duration = 0.15 + recorder._max_dip_tolerance = 0.1 + + fired = threading.Event() + recorder.start(on_silence_stop=lambda: fired.set()) + + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + loud_frame = np.full((1600, 1), 5000, dtype="int16") + quiet_frame = np.full((1600, 1), 50, dtype="int16") + + # Speech chunk 1 + callback(loud_frame, 1600, None, None) + time.sleep(0.05) + # Brief micro-pause (dip < max_dip_tolerance) + callback(quiet_frame, 1600, None, None) + time.sleep(0.05) + # Speech resumes -- speech_start should NOT have been reset + callback(loud_frame, 1600, None, None) + assert recorder._speech_start > 0, "Speech start should be preserved across brief dips" + time.sleep(0.06) + # Another speech chunk to exceed min_speech_duration + callback(loud_frame, 1600, None, None) + assert recorder._has_spoken is True, "Speech should be confirmed after tolerating micro-pause" + + recorder.cancel() + + def test_no_callback_means_no_silence_detection(self, mock_sd): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder.start() # no on_silence_stop + + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + # Even with speech then silence, nothing should happen + loud_frame = np.full((1600, 1), 5000, dtype="int16") + silent_frame = np.zeros((1600, 1), dtype="int16") + callback(loud_frame, 1600, None, None) + callback(silent_frame, 1600, None, None) + + # No crash, no callback + assert recorder._on_silence_stop is None + recorder.cancel() + + +# ============================================================================ +# Playback interrupt +# 
============================================================================ + +class TestPlaybackInterrupt: + """Verify that TTS playback can be interrupted.""" + + def test_stop_playback_terminates_process(self): + from tools.voice_mode import stop_playback, _playback_lock + import tools.voice_mode as vm + + mock_proc = MagicMock() + mock_proc.poll.return_value = None # process is running + + with _playback_lock: + vm._active_playback = mock_proc + + stop_playback() + + mock_proc.terminate.assert_called_once() + + with _playback_lock: + assert vm._active_playback is None + + def test_stop_playback_noop_when_nothing_playing(self): + import tools.voice_mode as vm + + with vm._playback_lock: + vm._active_playback = None + + vm.stop_playback() + + def test_play_audio_file_sets_active_playback(self, monkeypatch, sample_wav): + import tools.voice_mode as vm + + def _fail_import(): + raise ImportError("no sounddevice") + monkeypatch.setattr("tools.voice_mode._import_audio", _fail_import) + + mock_proc = MagicMock() + mock_proc.wait.return_value = 0 + + mock_popen = MagicMock(return_value=mock_proc) + monkeypatch.setattr("subprocess.Popen", mock_popen) + monkeypatch.setattr("shutil.which", lambda cmd: "/usr/bin/" + cmd) + + vm.play_audio_file(sample_wav) + + assert mock_popen.called + with vm._playback_lock: + assert vm._active_playback is None + + +# ============================================================================ +# Continuous mode flow +# ============================================================================ + +class TestContinuousModeFlow: + """Verify continuous mode: auto-restart after transcription or silence.""" + + def test_continuous_restart_on_no_speech(self, mock_sd, temp_voice_dir): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + + # First recording: only silence -> stop returns None + 
recorder.start() + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + for _ in range(10): + silence = np.full((1600, 1), 10, dtype="int16") + callback(silence, 1600, None, None) + + wav_path = recorder.stop() + assert wav_path is None + + # Simulate continuous mode restart + recorder.start() + assert recorder.is_recording is True + + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + for _ in range(10): + speech = np.full((1600, 1), 5000, dtype="int16") + callback(speech, 1600, None, None) + + wav_path = recorder.stop() + assert wav_path is not None + + recorder.cancel() + + def test_recorder_reusable_after_stop(self, mock_sd, temp_voice_dir): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + results = [] + + for i in range(3): + recorder.start() + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + loud = np.full((1600, 1), 5000, dtype="int16") + for _ in range(10): + callback(loud, 1600, None, None) + wav_path = recorder.stop() + results.append(wav_path) + + assert all(r is not None for r in results) + assert os.path.isfile(results[-1]) + + +# ============================================================================ +# Audio level indicator +# ============================================================================ + +class TestAudioLevelIndicator: + """Verify current_rms property updates in real-time for UI feedback.""" + + def test_rms_updates_with_audio_chunks(self, mock_sd): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import 
AudioRecorder + + recorder = AudioRecorder() + recorder.start() + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + assert recorder.current_rms == 0 + + loud = np.full((1600, 1), 5000, dtype="int16") + callback(loud, 1600, None, None) + assert recorder.current_rms == 5000 + + quiet = np.full((1600, 1), 100, dtype="int16") + callback(quiet, 1600, None, None) + assert recorder.current_rms == 100 + + recorder.cancel() + + def test_peak_rms_tracks_maximum(self, mock_sd): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + + recorder = AudioRecorder() + recorder.start() + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + frames = [ + np.full((1600, 1), 100, dtype="int16"), + np.full((1600, 1), 8000, dtype="int16"), + np.full((1600, 1), 500, dtype="int16"), + np.full((1600, 1), 3000, dtype="int16"), + ] + for frame in frames: + callback(frame, 1600, None, None) + + assert recorder._peak_rms == 8000 + assert recorder.current_rms == 3000 + + recorder.cancel() + + +# ============================================================================ +# Configurable silence parameters +# ============================================================================ + +class TestConfigurableSilenceParams: + """Verify that silence detection params can be configured.""" + + def test_custom_threshold_and_duration(self, mock_sd): + np = pytest.importorskip("numpy") + + mock_stream = MagicMock() + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + import threading + + recorder = AudioRecorder() + recorder._silence_threshold = 5000 + recorder._silence_duration = 0.05 + recorder._min_speech_duration = 0.05 + + fired = threading.Event() + 
recorder.start(on_silence_stop=lambda: fired.set()) + callback = mock_sd.InputStream.call_args.kwargs.get("callback") + if callback is None: + callback = mock_sd.InputStream.call_args[1]["callback"] + + # Audio at RMS 1000 -- below custom threshold (5000) + moderate = np.full((1600, 1), 1000, dtype="int16") + for _ in range(5): + callback(moderate, 1600, None, None) + time.sleep(0.02) + + assert recorder._has_spoken is False + assert fired.wait(timeout=0.2) is False + + # Now send really loud audio (above 5000 threshold) + very_loud = np.full((1600, 1), 8000, dtype="int16") + callback(very_loud, 1600, None, None) + time.sleep(0.06) + callback(very_loud, 1600, None, None) + assert recorder._has_spoken is True + + recorder.cancel() + + +# ============================================================================ +# Bugfix regression tests +# ============================================================================ + + +class TestSubprocessTimeoutKill: + """Bug: proc.wait(timeout) raised TimeoutExpired but process was not killed.""" + + def test_timeout_kills_process(self): + import subprocess, os + proc = subprocess.Popen(["sleep", "600"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + pid = proc.pid + assert proc.poll() is None + + try: + proc.wait(timeout=0.1) + except subprocess.TimeoutExpired: + proc.kill() + proc.wait() + + assert proc.poll() is not None + assert proc.returncode is not None + + +class TestStreamLeakOnStartFailure: + """Bug: stream.start() failure left stream unclosed.""" + + def test_stream_closed_on_start_failure(self, mock_sd): + mock_stream = MagicMock() + mock_stream.start.side_effect = OSError("Audio device busy") + mock_sd.InputStream.return_value = mock_stream + + from tools.voice_mode import AudioRecorder + recorder = AudioRecorder() + + with pytest.raises(RuntimeError, match="Failed to open audio input stream"): + recorder._ensure_stream() + + mock_stream.close.assert_called_once() + + +class TestSilenceCallbackLock: + 
"""Bug: _on_silence_stop was read/written without lock in audio callback.""" + + def test_fire_block_acquires_lock(self): + import inspect + from tools.voice_mode import AudioRecorder + + source = inspect.getsource(AudioRecorder._ensure_stream) + # Verify lock is used before reading _on_silence_stop in fire block + assert "with self._lock:" in source + assert "cb = self._on_silence_stop" in source + lock_pos = source.index("with self._lock:") + cb_pos = source.index("cb = self._on_silence_stop") + assert lock_pos < cb_pos + + def test_cancel_clears_callback_under_lock(self, mock_sd): + from tools.voice_mode import AudioRecorder + recorder = AudioRecorder() + mock_sd.InputStream.return_value = MagicMock() + + cb = lambda: None + recorder.start(on_silence_stop=cb) + assert recorder._on_silence_stop is cb + + recorder.cancel() + with recorder._lock: + assert recorder._on_silence_stop is None diff --git a/tools/browser_tool.py b/tools/browser_tool.py index 15f4961897..b3516c4f24 100644 --- a/tools/browser_tool.py +++ b/tools/browser_tool.py @@ -224,24 +224,14 @@ def _emergency_cleanup_all_sessions(): logger.error("Emergency cleanup error: %s", e) -def _signal_handler(signum, frame): - """Handle interrupt signals to cleanup sessions before exit.""" - logger.warning("Received signal %s, cleaning up...", signum) - _emergency_cleanup_all_sessions() - sys.exit(128 + signum) - - -# Register cleanup handlers +# Register cleanup via atexit only. Previous versions installed SIGINT/SIGTERM +# handlers that called sys.exit(), but this conflicts with prompt_toolkit's +# async event loop — a SystemExit raised inside a key-binding callback +# corrupts the coroutine state and makes the process unkillable. atexit +# handlers run on any normal exit (including sys.exit), so browser sessions +# are still cleaned up without hijacking signals. 
atexit.register(_emergency_cleanup_all_sessions) -# Only register signal handlers in main process (not in multiprocessing workers) -try: - if os.getpid() == os.getpgrp(): # Main process check - signal.signal(signal.SIGINT, _signal_handler) - signal.signal(signal.SIGTERM, _signal_handler) -except (OSError, AttributeError): - pass # Signal handling not available (e.g., Windows or worker process) - # ============================================================================= # Inactivity Cleanup Functions diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index 96b7a95e2d..a20ba41341 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -2,11 +2,12 @@ """ Transcription Tools Module -Provides speech-to-text transcription with two providers: +Provides speech-to-text transcription with three providers: - **local** (default, free) — faster-whisper running locally, no API key needed. Auto-downloads the model (~150 MB for ``base``) on first use. - - **openai** — OpenAI Whisper API, requires ``VOICE_TOOLS_OPENAI_KEY``. + - **groq** (free tier) — Groq Whisper API, requires ``GROQ_API_KEY``. + - **openai** (paid) — OpenAI Whisper API, requires ``VOICE_TOOLS_OPENAI_KEY``. Used by the messaging gateway to automatically transcribe voice messages sent by users on Telegram, Discord, WhatsApp, Slack, and Signal. 
@@ -33,18 +34,9 @@ logger = logging.getLogger(__name__) # Optional imports — graceful degradation # --------------------------------------------------------------------------- -try: - from faster_whisper import WhisperModel - _HAS_FASTER_WHISPER = True -except ImportError: - _HAS_FASTER_WHISPER = False - WhisperModel = None # type: ignore[assignment,misc] - -try: - from openai import OpenAI, APIError, APIConnectionError, APITimeoutError - _HAS_OPENAI = True -except ImportError: - _HAS_OPENAI = False +import importlib.util as _ilu +_HAS_FASTER_WHISPER = _ilu.find_spec("faster_whisper") is not None +_HAS_OPENAI = _ilu.find_spec("openai") is not None # --------------------------------------------------------------------------- # Constants @@ -52,13 +44,21 @@ except ImportError: DEFAULT_PROVIDER = "local" DEFAULT_LOCAL_MODEL = "base" -DEFAULT_OPENAI_MODEL = "whisper-1" +DEFAULT_STT_MODEL = os.getenv("STT_OPENAI_MODEL", "whisper-1") +DEFAULT_GROQ_STT_MODEL = os.getenv("STT_GROQ_MODEL", "whisper-large-v3-turbo") + +GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1") +OPENAI_BASE_URL = os.getenv("STT_OPENAI_BASE_URL", "https://api.openai.com/v1") SUPPORTED_FORMATS = {".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm", ".ogg"} MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB +# Known model sets for auto-correction +OPENAI_MODELS = {"whisper-1", "gpt-4o-mini-transcribe", "gpt-4o-transcribe"} +GROQ_MODELS = {"whisper-large-v3", "whisper-large-v3-turbo", "distil-whisper-large-v3-en"} + # Singleton for the local model — loaded once, reused across calls -_local_model: Optional["WhisperModel"] = None +_local_model: Optional[object] = None _local_model_name: Optional[str] = None # --------------------------------------------------------------------------- @@ -66,6 +66,24 @@ _local_model_name: Optional[str] = None # --------------------------------------------------------------------------- +def get_stt_model_from_config() -> Optional[str]: + """Read the 
STT model name from ~/.hermes/config.yaml. + + Returns the value of ``stt.model`` if present, otherwise ``None``. + Silently returns ``None`` on any error (missing file, bad YAML, etc.). + """ + try: + import yaml + cfg_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml" + if cfg_path.exists(): + with open(cfg_path) as f: + data = yaml.safe_load(f) or {} + return data.get("stt", {}).get("model") + except Exception: + pass + return None + + def _load_stt_config() -> dict: """Load the ``stt`` section from user config, falling back to defaults.""" try: @@ -80,7 +98,7 @@ def _get_provider(stt_config: dict) -> str: Priority: 1. Explicit config value (``stt.provider``) - 2. Auto-detect: local if faster-whisper available, else openai if key set + 2. Auto-detect: local > groq (free) > openai (paid) 3. Disabled (returns "none") """ provider = stt_config.get("provider", DEFAULT_PROVIDER) @@ -88,19 +106,37 @@ def _get_provider(stt_config: dict) -> str: if provider == "local": if _HAS_FASTER_WHISPER: return "local" - # Local requested but not available — fall back to openai if possible + # Local requested but not available — fall back to groq, then openai + if _HAS_OPENAI and os.getenv("GROQ_API_KEY"): + logger.info("faster-whisper not installed, falling back to Groq Whisper API") + return "groq" if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"): logger.info("faster-whisper not installed, falling back to OpenAI Whisper API") return "openai" return "none" + if provider == "groq": + if _HAS_OPENAI and os.getenv("GROQ_API_KEY"): + return "groq" + # Groq requested but no key — fall back + if _HAS_FASTER_WHISPER: + logger.info("GROQ_API_KEY not set, falling back to local faster-whisper") + return "local" + if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"): + logger.info("GROQ_API_KEY not set, falling back to OpenAI Whisper API") + return "openai" + return "none" + if provider == "openai": if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"): 
return "openai" - # OpenAI requested but no key — fall back to local if possible + # OpenAI requested but no key — fall back if _HAS_FASTER_WHISPER: logger.info("VOICE_TOOLS_OPENAI_KEY not set, falling back to local faster-whisper") return "local" + if _HAS_OPENAI and os.getenv("GROQ_API_KEY"): + logger.info("VOICE_TOOLS_OPENAI_KEY not set, falling back to Groq Whisper API") + return "groq" return "none" return provider # Unknown — let it fail downstream @@ -150,6 +186,7 @@ def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]: return {"success": False, "transcript": "", "error": "faster-whisper not installed"} try: + from faster_whisper import WhisperModel # Lazy-load the model (downloads on first use, ~150 MB for 'base') if _local_model is None or _local_model_name != model_name: logger.info("Loading faster-whisper model '%s' (first load downloads the model)...", model_name) @@ -164,12 +201,60 @@ def _transcribe_local(file_path: str, model_name: str) -> Dict[str, Any]: Path(file_path).name, model_name, info.language, info.duration, ) - return {"success": True, "transcript": transcript} + return {"success": True, "transcript": transcript, "provider": "local"} except Exception as e: logger.error("Local transcription failed: %s", e, exc_info=True) return {"success": False, "transcript": "", "error": f"Local transcription failed: {e}"} +# --------------------------------------------------------------------------- +# Provider: groq (Whisper API — free tier) +# --------------------------------------------------------------------------- + + +def _transcribe_groq(file_path: str, model_name: str) -> Dict[str, Any]: + """Transcribe using Groq Whisper API (free tier available).""" + api_key = os.getenv("GROQ_API_KEY") + if not api_key: + return {"success": False, "transcript": "", "error": "GROQ_API_KEY not set"} + + if not _HAS_OPENAI: + return {"success": False, "transcript": "", "error": "openai package not installed"} + + # Auto-correct model if 
caller passed an OpenAI-only model + if model_name in OPENAI_MODELS: + logger.info("Model %s not available on Groq, using %s", model_name, DEFAULT_GROQ_STT_MODEL) + model_name = DEFAULT_GROQ_STT_MODEL + + try: + from openai import OpenAI, APIError, APIConnectionError, APITimeoutError + client = OpenAI(api_key=api_key, base_url=GROQ_BASE_URL, timeout=30, max_retries=0) + + with open(file_path, "rb") as audio_file: + transcription = client.audio.transcriptions.create( + model=model_name, + file=audio_file, + response_format="text", + ) + + transcript_text = str(transcription).strip() + logger.info("Transcribed %s via Groq API (%s, %d chars)", + Path(file_path).name, model_name, len(transcript_text)) + + return {"success": True, "transcript": transcript_text, "provider": "groq"} + + except PermissionError: + return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"} + except APIConnectionError as e: + return {"success": False, "transcript": "", "error": f"Connection error: {e}"} + except APITimeoutError as e: + return {"success": False, "transcript": "", "error": f"Request timeout: {e}"} + except APIError as e: + return {"success": False, "transcript": "", "error": f"API error: {e}"} + except Exception as e: + logger.error("Groq transcription failed: %s", e, exc_info=True) + return {"success": False, "transcript": "", "error": f"Transcription failed: {e}"} + # --------------------------------------------------------------------------- # Provider: openai (Whisper API) # --------------------------------------------------------------------------- @@ -184,8 +269,14 @@ def _transcribe_openai(file_path: str, model_name: str) -> Dict[str, Any]: if not _HAS_OPENAI: return {"success": False, "transcript": "", "error": "openai package not installed"} + # Auto-correct model if caller passed a Groq-only model + if model_name in GROQ_MODELS: + logger.info("Model %s not available on OpenAI, using %s", model_name, DEFAULT_STT_MODEL) + model_name = 
DEFAULT_STT_MODEL + try: - client = OpenAI(api_key=api_key, base_url="https://api.openai.com/v1") + from openai import OpenAI, APIError, APIConnectionError, APITimeoutError + client = OpenAI(api_key=api_key, base_url=OPENAI_BASE_URL, timeout=30, max_retries=0) with open(file_path, "rb") as audio_file: transcription = client.audio.transcriptions.create( @@ -198,7 +289,7 @@ def _transcribe_openai(file_path: str, model_name: str) -> Dict[str, Any]: logger.info("Transcribed %s via OpenAI API (%s, %d chars)", Path(file_path).name, model_name, len(transcript_text)) - return {"success": True, "transcript": transcript_text} + return {"success": True, "transcript": transcript_text, "provider": "openai"} except PermissionError: return {"success": False, "transcript": "", "error": f"Permission denied: {file_path}"} @@ -223,7 +314,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A Provider priority: 1. User config (``stt.provider`` in config.yaml) - 2. Auto-detect: local faster-whisper if available, else OpenAI API + 2. Auto-detect: local faster-whisper (free) > Groq (free tier) > OpenAI (paid) Args: file_path: Absolute path to the audio file to transcribe. 
@@ -234,6 +325,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A - "success" (bool): Whether transcription succeeded - "transcript" (str): The transcribed text (empty on failure) - "error" (str, optional): Error message if success is False + - "provider" (str, optional): Which provider was used """ # Validate input error = _validate_audio_file(file_path) @@ -249,9 +341,13 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A model_name = model or local_cfg.get("model", DEFAULT_LOCAL_MODEL) return _transcribe_local(file_path, model_name) + if provider == "groq": + model_name = model or DEFAULT_GROQ_STT_MODEL + return _transcribe_groq(file_path, model_name) + if provider == "openai": openai_cfg = stt_config.get("openai", {}) - model_name = model or openai_cfg.get("model", DEFAULT_OPENAI_MODEL) + model_name = model or openai_cfg.get("model", DEFAULT_STT_MODEL) return _transcribe_openai(file_path, model_name) # No provider available @@ -260,6 +356,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A "transcript": "", "error": ( "No STT provider available. Install faster-whisper for free local " - "transcription, or set VOICE_TOOLS_OPENAI_KEY for the OpenAI Whisper API." + "transcription, set GROQ_API_KEY for free Groq Whisper, " + "or set VOICE_TOOLS_OPENAI_KEY for the OpenAI Whisper API." 
), } diff --git a/tools/tts_tool.py b/tools/tts_tool.py index 3544b20fd8..286bb14b4e 100644 --- a/tools/tts_tool.py +++ b/tools/tts_tool.py @@ -25,35 +25,41 @@ import datetime import json import logging import os +import queue +import re import shutil import subprocess import tempfile +import threading from pathlib import Path -from typing import Dict, Any, Optional +from typing import Callable, Dict, Any, Optional logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# Optional imports -- providers degrade gracefully if not installed +# Lazy imports -- providers are imported only when actually used to avoid +# crashing in headless environments (SSH, Docker, WSL, no PortAudio). # --------------------------------------------------------------------------- -try: + +def _import_edge_tts(): + """Lazy import edge_tts. Returns the module or raises ImportError.""" import edge_tts - _HAS_EDGE_TTS = True -except ImportError: - _HAS_EDGE_TTS = False + return edge_tts -try: +def _import_elevenlabs(): + """Lazy import ElevenLabs client. Returns the class or raises ImportError.""" from elevenlabs.client import ElevenLabs - _HAS_ELEVENLABS = True -except ImportError: - _HAS_ELEVENLABS = False + return ElevenLabs -# openai is a core dependency, but guard anyway -try: +def _import_openai_client(): + """Lazy import OpenAI client. Returns the class or raises ImportError.""" from openai import OpenAI as OpenAIClient - _HAS_OPENAI = True -except ImportError: - _HAS_OPENAI = False + return OpenAIClient + +def _import_sounddevice(): + """Lazy import sounddevice. 
Returns the module or raises ImportError/OSError.""" + import sounddevice as sd + return sd # =========================================================================== @@ -63,6 +69,7 @@ DEFAULT_PROVIDER = "edge" DEFAULT_EDGE_VOICE = "en-US-AriaNeural" DEFAULT_ELEVENLABS_VOICE_ID = "pNInz6obpgDQGcFmaJgB" # Adam DEFAULT_ELEVENLABS_MODEL_ID = "eleven_multilingual_v2" +DEFAULT_ELEVENLABS_STREAMING_MODEL_ID = "eleven_flash_v2_5" DEFAULT_OPENAI_MODEL = "gpt-4o-mini-tts" DEFAULT_OPENAI_VOICE = "alloy" DEFAULT_OUTPUT_DIR = str(Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "audio_cache") @@ -154,10 +161,11 @@ async def _generate_edge_tts(text: str, output_path: str, tts_config: Dict[str, Returns: Path to the saved audio file. """ + _edge_tts = _import_edge_tts() edge_config = tts_config.get("edge", {}) voice = edge_config.get("voice", DEFAULT_EDGE_VOICE) - communicate = edge_tts.Communicate(text, voice) + communicate = _edge_tts.Communicate(text, voice) await communicate.save(output_path) return output_path @@ -191,6 +199,7 @@ def _generate_elevenlabs(text: str, output_path: str, tts_config: Dict[str, Any] else: output_format = "mp3_44100_128" + ElevenLabs = _import_elevenlabs() client = ElevenLabs(api_key=api_key) audio_generator = client.text_to_speech.convert( text=text, @@ -236,6 +245,7 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any] else: response_format = "mp3" + OpenAIClient = _import_openai_client() client = OpenAIClient(api_key=api_key, base_url="https://api.openai.com/v1") response = client.audio.speech.create( model=model, @@ -311,7 +321,9 @@ def text_to_speech_tool( try: # Generate audio with the configured provider if provider == "elevenlabs": - if not _HAS_ELEVENLABS: + try: + _import_elevenlabs() + except ImportError: return json.dumps({ "success": False, "error": "ElevenLabs provider selected but 'elevenlabs' package not installed. 
Run: pip install elevenlabs" @@ -320,7 +332,9 @@ def text_to_speech_tool( _generate_elevenlabs(text, file_str, tts_config) elif provider == "openai": - if not _HAS_OPENAI: + try: + _import_openai_client() + except ImportError: return json.dumps({ "success": False, "error": "OpenAI provider selected but 'openai' package not installed." @@ -330,7 +344,9 @@ def text_to_speech_tool( else: # Default: Edge TTS (free) - if not _HAS_EDGE_TTS: + try: + _import_edge_tts() + except ImportError: return json.dumps({ "success": False, "error": "Edge TTS not available. Run: pip install edge-tts" @@ -411,15 +427,262 @@ def check_tts_requirements() -> bool: Returns: bool: True if at least one provider can work. """ - if _HAS_EDGE_TTS: - return True - if _HAS_ELEVENLABS and os.getenv("ELEVENLABS_API_KEY"): - return True - if _HAS_OPENAI and os.getenv("VOICE_TOOLS_OPENAI_KEY"): + try: + _import_edge_tts() return True + except ImportError: + pass + try: + _import_elevenlabs() + if os.getenv("ELEVENLABS_API_KEY"): + return True + except ImportError: + pass + try: + _import_openai_client() + if os.getenv("VOICE_TOOLS_OPENAI_KEY"): + return True + except ImportError: + pass return False +# =========================================================================== +# Streaming TTS: sentence-by-sentence pipeline for ElevenLabs +# =========================================================================== +# Sentence boundary pattern: punctuation followed by space or newline +_SENTENCE_BOUNDARY_RE = re.compile(r'(?<=[.!?])(?:\s|\n)|(?:\n\n)') + +# Markdown stripping patterns (same as cli.py _voice_speak_response) +_MD_CODE_BLOCK = re.compile(r'```[\s\S]*?```') +_MD_LINK = re.compile(r'\[([^\]]+)\]\([^)]+\)') +_MD_URL = re.compile(r'https?://\S+') +_MD_BOLD = re.compile(r'\*\*(.+?)\*\*') +_MD_ITALIC = re.compile(r'\*(.+?)\*') +_MD_INLINE_CODE = re.compile(r'`(.+?)`') +_MD_HEADER = re.compile(r'^#+\s*', flags=re.MULTILINE) +_MD_LIST_ITEM = re.compile(r'^\s*[-*]\s+', flags=re.MULTILINE) 
+_MD_HR = re.compile(r'---+') +_MD_EXCESS_NL = re.compile(r'\n{3,}') + + +def _strip_markdown_for_tts(text: str) -> str: + """Remove markdown formatting that shouldn't be spoken aloud.""" + text = _MD_CODE_BLOCK.sub(' ', text) + text = _MD_LINK.sub(r'\1', text) + text = _MD_URL.sub('', text) + text = _MD_BOLD.sub(r'\1', text) + text = _MD_ITALIC.sub(r'\1', text) + text = _MD_INLINE_CODE.sub(r'\1', text) + text = _MD_HEADER.sub('', text) + text = _MD_LIST_ITEM.sub('', text) + text = _MD_HR.sub('', text) + text = _MD_EXCESS_NL.sub('\n\n', text) + return text.strip() + + +def stream_tts_to_speaker( + text_queue: queue.Queue, + stop_event: threading.Event, + tts_done_event: threading.Event, + display_callback: Optional[Callable[[str], None]] = None, +): + """Consume text deltas from *text_queue*, buffer them into sentences, + and stream each sentence through ElevenLabs TTS to the speaker in + real-time. + + Protocol: + * The producer puts ``str`` deltas onto *text_queue*. + * A ``None`` sentinel signals end-of-text (flush remaining buffer). + * *stop_event* can be set to abort early (e.g. user interrupt). + * *tts_done_event* is **set** in the ``finally`` block so callers + waiting on it (continuous voice mode) know playback is finished. 
+ """ + tts_done_event.clear() + + try: + # --- TTS client setup (optional -- display_callback works without it) --- + client = None + output_stream = None + voice_id = DEFAULT_ELEVENLABS_VOICE_ID + model_id = DEFAULT_ELEVENLABS_STREAMING_MODEL_ID + + tts_config = _load_tts_config() + el_config = tts_config.get("elevenlabs", {}) + voice_id = el_config.get("voice_id", voice_id) + model_id = el_config.get("streaming_model_id", + el_config.get("model_id", model_id)) + + api_key = os.getenv("ELEVENLABS_API_KEY", "") + if not api_key: + logger.warning("ELEVENLABS_API_KEY not set; streaming TTS audio disabled") + else: + try: + ElevenLabs = _import_elevenlabs() + client = ElevenLabs(api_key=api_key) + except ImportError: + logger.warning("elevenlabs package not installed; streaming TTS disabled") + + # Open a single sounddevice output stream for the lifetime of + # this function. ElevenLabs pcm_24000 produces signed 16-bit + # little-endian mono PCM at 24 kHz. + if client is not None: + try: + sd = _import_sounddevice() + import numpy as _np + output_stream = sd.OutputStream( + samplerate=24000, channels=1, dtype="int16", + ) + output_stream.start() + except (ImportError, OSError) as exc: + logger.debug("sounddevice not available: %s", exc) + output_stream = None + except Exception as exc: + logger.warning("sounddevice OutputStream failed: %s", exc) + output_stream = None + + sentence_buf = "" + min_sentence_len = 20 + long_flush_len = 100 + queue_timeout = 0.5 + _spoken_sentences: list[str] = [] # track spoken sentences to skip duplicates + # Regex to strip complete ... 
blocks from buffer + _think_block_re = re.compile(r'].*?', flags=re.DOTALL) + + def _speak_sentence(sentence: str): + """Display sentence and optionally generate + play audio.""" + if stop_event.is_set(): + return + cleaned = _strip_markdown_for_tts(sentence).strip() + if not cleaned: + return + # Skip duplicate/near-duplicate sentences (LLM repetition) + cleaned_lower = cleaned.lower().rstrip(".!,") + for prev in _spoken_sentences: + if prev.lower().rstrip(".!,") == cleaned_lower: + return + _spoken_sentences.append(cleaned) + # Display raw sentence on screen before TTS processing + if display_callback is not None: + display_callback(sentence) + # Skip audio generation if no TTS client available + if client is None: + return + # Truncate very long sentences + if len(cleaned) > MAX_TEXT_LENGTH: + cleaned = cleaned[:MAX_TEXT_LENGTH] + try: + audio_iter = client.text_to_speech.convert( + text=cleaned, + voice_id=voice_id, + model_id=model_id, + output_format="pcm_24000", + ) + if output_stream is not None: + for chunk in audio_iter: + if stop_event.is_set(): + break + import numpy as _np + audio_array = _np.frombuffer(chunk, dtype=_np.int16) + output_stream.write(audio_array.reshape(-1, 1)) + else: + # Fallback: write chunks to temp file and play via system player + _play_via_tempfile(audio_iter, stop_event) + except Exception as exc: + logger.warning("Streaming TTS sentence failed: %s", exc) + + def _play_via_tempfile(audio_iter, stop_evt): + """Write PCM chunks to a temp WAV file and play it.""" + tmp_path = None + try: + import wave + tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + tmp_path = tmp.name + with wave.open(tmp, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) # 16-bit + wf.setframerate(24000) + for chunk in audio_iter: + if stop_evt.is_set(): + break + wf.writeframes(chunk) + from tools.voice_mode import play_audio_file + play_audio_file(tmp_path) + except Exception as exc: + logger.warning("Temp-file TTS fallback failed: %s", 
exc) + finally: + if tmp_path: + try: + os.unlink(tmp_path) + except OSError: + pass + + while not stop_event.is_set(): + # Read next delta from queue + try: + delta = text_queue.get(timeout=queue_timeout) + except queue.Empty: + # Timeout: if we have accumulated a long buffer, flush it + if len(sentence_buf) > long_flush_len: + _speak_sentence(sentence_buf) + sentence_buf = "" + continue + + if delta is None: + # End-of-text sentinel: strip any remaining think blocks, flush + sentence_buf = _think_block_re.sub('', sentence_buf) + if sentence_buf.strip(): + _speak_sentence(sentence_buf) + break + + sentence_buf += delta + + # --- Think block filtering --- + # Strip complete ... blocks from buffer. + # Works correctly even when tags span multiple deltas. + sentence_buf = _think_block_re.sub('', sentence_buf) + + # If an incomplete ' not in sentence_buf: + continue + + # Check for sentence boundaries + while True: + m = _SENTENCE_BOUNDARY_RE.search(sentence_buf) + if m is None: + break + end_pos = m.end() + sentence = sentence_buf[:end_pos] + sentence_buf = sentence_buf[end_pos:] + # Merge short fragments into the next sentence + if len(sentence.strip()) < min_sentence_len: + sentence_buf = sentence + sentence_buf + break + _speak_sentence(sentence) + + # Drain any remaining items from the queue + while True: + try: + text_queue.get_nowait() + except queue.Empty: + break + + # output_stream is closed in the finally block below + + except Exception as exc: + logger.warning("Streaming TTS pipeline error: %s", exc) + finally: + # Always close the audio output stream to avoid locking the device + if output_stream is not None: + try: + output_stream.stop() + output_stream.close() + except Exception: + pass + tts_done_event.set() + + # =========================================================================== # Main -- quick diagnostics # =========================================================================== @@ -427,12 +690,19 @@ if __name__ == "__main__": print("🔊 
Text-to-Speech Tool Module") print("=" * 50) + def _check(importer, label): + try: + importer() + return True + except ImportError: + return False + print(f"\nProvider availability:") - print(f" Edge TTS: {'✅ installed' if _HAS_EDGE_TTS else '❌ not installed (pip install edge-tts)'}") - print(f" ElevenLabs: {'✅ installed' if _HAS_ELEVENLABS else '❌ not installed (pip install elevenlabs)'}") - print(f" API Key: {'✅ set' if os.getenv('ELEVENLABS_API_KEY') else '❌ not set'}") - print(f" OpenAI: {'✅ installed' if _HAS_OPENAI else '❌ not installed'}") - print(f" API Key: {'✅ set' if os.getenv('VOICE_TOOLS_OPENAI_KEY') else '❌ not set (VOICE_TOOLS_OPENAI_KEY)'}") + print(f" Edge TTS: {'installed' if _check(_import_edge_tts, 'edge') else 'not installed (pip install edge-tts)'}") + print(f" ElevenLabs: {'installed' if _check(_import_elevenlabs, 'el') else 'not installed (pip install elevenlabs)'}") + print(f" API Key: {'set' if os.getenv('ELEVENLABS_API_KEY') else 'not set'}") + print(f" OpenAI: {'installed' if _check(_import_openai_client, 'oai') else 'not installed'}") + print(f" API Key: {'set' if os.getenv('VOICE_TOOLS_OPENAI_KEY') else 'not set (VOICE_TOOLS_OPENAI_KEY)'}") print(f" ffmpeg: {'✅ found' if _has_ffmpeg() else '❌ not found (needed for Telegram Opus)'}") print(f"\n Output dir: {DEFAULT_OUTPUT_DIR}") diff --git a/tools/voice_mode.py b/tools/voice_mode.py new file mode 100644 index 0000000000..a2c70ac1b0 --- /dev/null +++ b/tools/voice_mode.py @@ -0,0 +1,783 @@ +"""Voice Mode -- Push-to-talk audio recording and playback for the CLI. + +Provides audio capture via sounddevice, WAV encoding via stdlib wave, +STT dispatch via tools.transcription_tools, and TTS playback via +sounddevice or system audio players. 
+ +Dependencies (optional): + pip install sounddevice numpy + or: pip install hermes-agent[voice] +""" + +import logging +import os +import platform +import re +import shutil +import subprocess +import tempfile +import threading +import time +import wave +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Lazy audio imports -- never imported at module level to avoid crashing +# in headless environments (SSH, Docker, WSL, no PortAudio). +# --------------------------------------------------------------------------- + +def _import_audio(): + """Lazy-import sounddevice and numpy. Returns (sd, np). + + Raises ImportError or OSError if the libraries are not available + (e.g. PortAudio missing on headless servers). + """ + import sounddevice as sd + import numpy as np + return sd, np + + +def _audio_available() -> bool: + """Return True if audio libraries can be imported.""" + try: + _import_audio() + return True + except (ImportError, OSError): + return False + + +def detect_audio_environment() -> dict: + """Detect if the current environment supports audio I/O. + + Returns dict with 'available' (bool) and 'warnings' (list of strings). 
+ """ + warnings = [] + + # SSH detection + if any(os.environ.get(v) for v in ('SSH_CLIENT', 'SSH_TTY', 'SSH_CONNECTION')): + warnings.append("Running over SSH -- no audio devices available") + + # Docker detection + if os.path.exists('/.dockerenv'): + warnings.append("Running inside Docker container -- no audio devices") + + # WSL detection + try: + with open('/proc/version', 'r') as f: + if 'microsoft' in f.read().lower(): + warnings.append("Running in WSL -- audio requires PulseAudio bridge to Windows") + except (FileNotFoundError, PermissionError, OSError): + pass + + # Check audio libraries + try: + sd, _ = _import_audio() + try: + devices = sd.query_devices() + if not devices: + warnings.append("No audio input/output devices detected") + except Exception: + warnings.append("Audio subsystem error (PortAudio cannot query devices)") + except (ImportError, OSError): + warnings.append("Audio libraries not installed (pip install sounddevice numpy)") + + return { + "available": len(warnings) == 0, + "warnings": warnings, + } + +# --------------------------------------------------------------------------- +# Recording parameters +# --------------------------------------------------------------------------- +SAMPLE_RATE = 16000 # Whisper native rate +CHANNELS = 1 # Mono +DTYPE = "int16" # 16-bit PCM +SAMPLE_WIDTH = 2 # bytes per sample (int16) +MAX_RECORDING_SECONDS = 120 # Safety cap + +# Silence detection defaults +SILENCE_RMS_THRESHOLD = 200 # RMS below this = silence (int16 range 0-32767) +SILENCE_DURATION_SECONDS = 3.0 # Seconds of continuous silence before auto-stop + +# Temp directory for voice recordings +_TEMP_DIR = os.path.join(tempfile.gettempdir(), "hermes_voice") + + +# ============================================================================ +# Audio cues (beep tones) +# ============================================================================ +def play_beep(frequency: int = 880, duration: float = 0.12, count: int = 1) -> None: + """Play a short 
# ============================================================================
# Audio cues (beep tones)
# ============================================================================
def play_beep(frequency: int = 880, duration: float = 0.12, count: int = 1) -> None:
    """Play a short beep tone using numpy + sounddevice.

    Args:
        frequency: Tone frequency in Hz (default 880 = A5).
        duration: Duration of each beep in seconds.
        count: Number of beeps to play (with short gap between).
    """
    try:
        sd, np = _import_audio()
    except (ImportError, OSError):
        return
    try:
        gap = 0.06  # seconds between beeps
        samples_per_beep = int(SAMPLE_RATE * duration)
        samples_per_gap = int(SAMPLE_RATE * gap)

        parts = []
        for i in range(count):
            t = np.linspace(0, duration, samples_per_beep, endpoint=False)
            # Apply fade in/out to avoid click artifacts
            tone = np.sin(2 * np.pi * frequency * t)
            fade_len = min(int(SAMPLE_RATE * 0.01), samples_per_beep // 4)
            # Guard: fade_len == 0 (sub-millisecond beeps) would multiply the
            # whole array by an empty ramp and raise a broadcast error.
            if fade_len:
                tone[:fade_len] *= np.linspace(0, 1, fade_len)
                tone[-fade_len:] *= np.linspace(1, 0, fade_len)
            parts.append((tone * 0.3 * 32767).astype(np.int16))
            if i < count - 1:
                parts.append(np.zeros(samples_per_gap, dtype=np.int16))

        audio = np.concatenate(parts)
        sd.play(audio, samplerate=SAMPLE_RATE)
        # sd.wait() calls Event.wait() without timeout — hangs forever if the
        # audio device stalls. Poll with a 2s ceiling and force-stop.
        deadline = time.monotonic() + 2.0
        while sd.get_stream() and sd.get_stream().active and time.monotonic() < deadline:
            time.sleep(0.01)
        sd.stop()
    except Exception as e:
        logger.debug("Beep playback failed: %s", e)


# ============================================================================
# AudioRecorder
# ============================================================================
class AudioRecorder:
    """Thread-safe audio recorder using sounddevice.InputStream.

    Usage::

        recorder = AudioRecorder()
        recorder.start(on_silence_stop=my_callback)
        # ... user speaks ...
        wav_path = recorder.stop()  # returns path to WAV file
        # or
        recorder.cancel()  # discard without saving

    If ``on_silence_stop`` is provided, recording automatically stops when
    the user is silent for ``silence_duration`` seconds and calls the callback.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._stream: Any = None
        self._frames: List[Any] = []
        self._recording = False
        self._start_time: float = 0.0
        # Silence detection state
        self._has_spoken = False
        self._speech_start: float = 0.0  # When speech attempt began
        self._dip_start: float = 0.0  # When current below-threshold dip began
        self._min_speech_duration: float = 0.3  # Seconds of speech needed to confirm
        self._max_dip_tolerance: float = 0.3  # Max dip duration before resetting speech
        self._silence_start: float = 0.0
        self._resume_start: float = 0.0  # Tracks sustained speech after silence starts
        self._resume_dip_start: float = 0.0  # Dip tolerance tracker for resume detection
        self._on_silence_stop = None
        self._silence_threshold: int = SILENCE_RMS_THRESHOLD
        self._silence_duration: float = SILENCE_DURATION_SECONDS
        self._max_wait: float = 15.0  # Max seconds to wait for speech before auto-stop
        # Peak RMS seen during recording (for speech presence check in stop())
        self._peak_rms: int = 0
        # Live audio level (read by UI for visual feedback)
        self._current_rms: int = 0

    # -- public properties ---------------------------------------------------

    @property
    def is_recording(self) -> bool:
        return self._recording

    @property
    def elapsed_seconds(self) -> float:
        if not self._recording:
            return 0.0
        return time.monotonic() - self._start_time

    @property
    def current_rms(self) -> int:
        """Current audio input RMS level (0-32767). Updated each audio chunk."""
        return self._current_rms

    # -- public methods ------------------------------------------------------

    def _ensure_stream(self) -> None:
        """Create the audio InputStream once and keep it alive.

        The stream stays open for the lifetime of the recorder. Between
        recordings the callback simply discards audio chunks (``_recording``
        is ``False``). This avoids the CoreAudio bug where closing and
        re-opening an ``InputStream`` hangs indefinitely on macOS.
        """
        if self._stream is not None:
            return  # already alive

        sd, np = _import_audio()

        def _callback(indata, frames, time_info, status):  # noqa: ARG001
            if status:
                logger.debug("sounddevice status: %s", status)
            # When not recording the stream is idle — discard audio.
            if not self._recording:
                return
            self._frames.append(indata.copy())

            # Compute RMS for level display and silence detection
            rms = int(np.sqrt(np.mean(indata.astype(np.float64) ** 2)))
            self._current_rms = rms
            if rms > self._peak_rms:
                self._peak_rms = rms

            # Silence detection
            if self._on_silence_stop is not None:
                now = time.monotonic()
                elapsed = now - self._start_time

                if rms > self._silence_threshold:
                    # Audio is above threshold -- this is speech (or noise).
                    self._dip_start = 0.0  # Reset dip tracker
                    if self._speech_start == 0.0:
                        self._speech_start = now
                    elif not self._has_spoken and now - self._speech_start >= self._min_speech_duration:
                        self._has_spoken = True
                        logger.debug("Speech confirmed (%.2fs above threshold)",
                                     now - self._speech_start)
                    # After speech is confirmed, only reset silence timer if
                    # speech is sustained (>0.3s above threshold). Brief
                    # spikes from ambient noise should NOT reset the timer.
                    if not self._has_spoken:
                        self._silence_start = 0.0
                    else:
                        # Track resumed speech with dip tolerance.
                        # Brief dips below threshold are normal during speech,
                        # so we mirror the initial speech detection pattern:
                        # start tracking, tolerate short dips, confirm after 0.3s.
                        self._resume_dip_start = 0.0  # Above threshold — no dip
                        if self._resume_start == 0.0:
                            self._resume_start = now
                        elif now - self._resume_start >= self._min_speech_duration:
                            self._silence_start = 0.0
                            self._resume_start = 0.0
                elif self._has_spoken:
                    # Below threshold after speech confirmed.
                    # Use dip tolerance before resetting resume tracker —
                    # natural speech has brief dips below threshold.
                    if self._resume_start > 0:
                        if self._resume_dip_start == 0.0:
                            self._resume_dip_start = now
                        elif now - self._resume_dip_start >= self._max_dip_tolerance:
                            # Sustained dip — user actually stopped speaking
                            self._resume_start = 0.0
                            self._resume_dip_start = 0.0
                elif self._speech_start > 0:
                    # We were in a speech attempt but RMS dipped.
                    # Tolerate brief dips (micro-pauses between syllables).
                    if self._dip_start == 0.0:
                        self._dip_start = now
                    elif now - self._dip_start >= self._max_dip_tolerance:
                        # Dip lasted too long -- genuine silence, reset
                        logger.debug("Speech attempt reset (dip lasted %.2fs)",
                                     now - self._dip_start)
                        self._speech_start = 0.0
                        self._dip_start = 0.0

                # Fire silence callback when:
                # 1. User spoke then went silent for silence_duration, OR
                # 2. No speech detected at all for max_wait seconds
                should_fire = False
                if self._has_spoken and rms <= self._silence_threshold:
                    # User was speaking and now is silent
                    if self._silence_start == 0.0:
                        self._silence_start = now
                    elif now - self._silence_start >= self._silence_duration:
                        logger.info("Silence detected (%.1fs), auto-stopping",
                                    self._silence_duration)
                        should_fire = True
                elif not self._has_spoken and elapsed >= self._max_wait:
                    logger.info("No speech within %.0fs, auto-stopping",
                                self._max_wait)
                    should_fire = True

                if should_fire:
                    with self._lock:
                        cb = self._on_silence_stop
                        self._on_silence_stop = None  # fire only once
                    if cb:
                        def _safe_cb():
                            try:
                                cb()
                            except Exception as e:
                                logger.error("Silence callback failed: %s", e, exc_info=True)
                        threading.Thread(target=_safe_cb, daemon=True).start()

        # Create stream — may block on CoreAudio (first call only).
        stream = None
        try:
            stream = sd.InputStream(
                samplerate=SAMPLE_RATE,
                channels=CHANNELS,
                dtype=DTYPE,
                callback=_callback,
            )
            stream.start()
        except Exception as e:
            if stream is not None:
                try:
                    stream.close()
                except Exception:
                    pass
            raise RuntimeError(
                f"Failed to open audio input stream: {e}. "
                "Check that a microphone is connected and accessible."
            ) from e
        self._stream = stream

    def start(self, on_silence_stop=None) -> None:
        """Start capturing audio from the default input device.

        The underlying InputStream is created once and kept alive across
        recordings. Subsequent calls simply reset detection state and
        toggle frame collection via ``_recording``.

        Args:
            on_silence_stop: Optional callback invoked (in a daemon thread) when
                silence is detected after speech. The callback receives no arguments.
                Use this to auto-stop recording and trigger transcription.

        Raises ``RuntimeError`` if sounddevice/numpy are not installed
        or if a recording is already in progress.
        """
        try:
            _import_audio()
        except (ImportError, OSError) as e:
            raise RuntimeError(
                "Voice mode requires sounddevice and numpy.\n"
                "Install with: pip install sounddevice numpy\n"
                "Or: pip install hermes-agent[voice]"
            ) from e

        with self._lock:
            if self._recording:
                return  # already recording

            self._frames = []
            self._start_time = time.monotonic()
            self._has_spoken = False
            self._speech_start = 0.0
            self._dip_start = 0.0
            self._silence_start = 0.0
            self._resume_start = 0.0
            self._resume_dip_start = 0.0
            self._peak_rms = 0
            self._current_rms = 0
            self._on_silence_stop = on_silence_stop

        # Ensure the persistent stream is alive (no-op after first call).
        self._ensure_stream()

        with self._lock:
            self._recording = True
        logger.info("Voice recording started (rate=%d, channels=%d)", SAMPLE_RATE, CHANNELS)

    def _close_stream_with_timeout(self, timeout: float = 3.0) -> None:
        """Close the audio stream with a timeout to prevent CoreAudio hangs."""
        if self._stream is None:
            return

        stream = self._stream
        self._stream = None

        def _do_close():
            try:
                stream.stop()
                stream.close()
            except Exception:
                pass

        t = threading.Thread(target=_do_close, daemon=True)
        t.start()
        # Poll in short intervals so Ctrl+C is not blocked.
        # (Uses the module-level ``time`` import — the original reached for
        # ``__import__("time")`` unnecessarily.)
        deadline = time.monotonic() + timeout
        while t.is_alive() and time.monotonic() < deadline:
            t.join(timeout=0.1)
        if t.is_alive():
            logger.warning("Audio stream close timed out after %.1fs — forcing ahead", timeout)

    def stop(self) -> Optional[str]:
        """Stop recording and write captured audio to a WAV file.

        The underlying stream is kept alive for reuse — only frame
        collection is stopped.

        Returns:
            Path to the WAV file, or ``None`` if no audio was captured.
        """
        with self._lock:
            if not self._recording:
                return None

            self._recording = False
            self._current_rms = 0
            # Snapshot and clear the frame list under the lock so a late
            # callback append cannot mutate it while we concatenate below.
            frames = self._frames
            self._frames = []
            # Stream stays alive — no close needed.

        if not frames:
            return None

        # Concatenate frames and write WAV
        _, np = _import_audio()
        audio_data = np.concatenate(frames, axis=0)

        elapsed = time.monotonic() - self._start_time
        logger.info("Voice recording stopped (%.1fs, %d samples)", elapsed, len(audio_data))

        # Skip very short recordings (< 0.3s of audio)
        min_samples = int(SAMPLE_RATE * 0.3)
        if len(audio_data) < min_samples:
            logger.debug("Recording too short (%d samples), discarding", len(audio_data))
            return None

        # Skip silent recordings using peak RMS (not overall average, which
        # gets diluted by silence at the end of the recording).
        if self._peak_rms < SILENCE_RMS_THRESHOLD:
            logger.info("Recording too quiet (peak RMS=%d < %d), discarding",
                        self._peak_rms, SILENCE_RMS_THRESHOLD)
            return None

        return self._write_wav(audio_data)

    def cancel(self) -> None:
        """Stop recording and discard all captured audio.

        The underlying stream is kept alive for reuse.
        """
        with self._lock:
            self._recording = False
            self._frames = []
            self._on_silence_stop = None
            self._current_rms = 0
        logger.info("Voice recording cancelled")

    def shutdown(self) -> None:
        """Release the audio stream. Call when voice mode is disabled."""
        with self._lock:
            self._recording = False
            self._frames = []
            self._on_silence_stop = None
        # Close stream OUTSIDE the lock to avoid deadlock with audio callback
        self._close_stream_with_timeout()
        logger.info("AudioRecorder shut down")

    # -- private helpers -----------------------------------------------------

    @staticmethod
    def _write_wav(audio_data) -> str:
        """Write numpy int16 audio data to a WAV file.

        Returns the file path.
        """
        os.makedirs(_TEMP_DIR, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        wav_path = os.path.join(_TEMP_DIR, f"recording_{timestamp}.wav")

        with wave.open(wav_path, "wb") as wf:
            wf.setnchannels(CHANNELS)
            wf.setsampwidth(SAMPLE_WIDTH)
            wf.setframerate(SAMPLE_RATE)
            wf.writeframes(audio_data.tobytes())

        file_size = os.path.getsize(wav_path)
        logger.info("WAV written: %s (%d bytes)", wav_path, file_size)
        return wav_path


# ============================================================================
# Whisper hallucination filter
# ============================================================================
# Whisper commonly hallucinates these phrases on silent/near-silent audio.
WHISPER_HALLUCINATIONS = {
    "thank you.",
    "thank you",
    "thanks for watching.",
    "thanks for watching",
    "subscribe to my channel.",
    "subscribe to my channel",
    "like and subscribe.",
    "like and subscribe",
    "please subscribe.",
    "please subscribe",
    "thank you for watching.",
    "thank you for watching",
    "bye.",
    "bye",
    "you",
    "the end.",
    "the end",
    # Non-English hallucinations (common on silence)
    "продолжение следует",
    "продолжение следует...",
    "sous-titres",
    "sous-titres réalisés par la communauté d'amara.org",
    "sottotitoli creati dalla comunità amara.org",
    "untertitel von stephanie geiges",
    "amara.org",
    "www.mooji.org",
    "ご視聴ありがとうございました",
}

# Regex patterns for repetitive hallucinations (e.g. "Thank you. Thank you. Thank you.")
_HALLUCINATION_REPEAT_RE = re.compile(
    r'^(?:thank you|thanks|bye|you|ok|okay|the end|\.|\s|,|!)+$',
    flags=re.IGNORECASE,
)


def is_whisper_hallucination(transcript: str) -> bool:
    """Check if a transcript is a known Whisper hallucination on silence."""
    cleaned = transcript.strip().lower()
    if not cleaned:
        return True
    # Exact match against known phrases (with or without trailing punctuation)
    if cleaned.rstrip('.!') in WHISPER_HALLUCINATIONS or cleaned in WHISPER_HALLUCINATIONS:
        return True
    # Repetitive patterns (e.g. "Thank you. Thank you. Thank you. you")
    if _HALLUCINATION_REPEAT_RE.match(cleaned):
        return True
    return False


# ============================================================================
# STT dispatch
# ============================================================================
def transcribe_recording(wav_path: str, model: Optional[str] = None) -> Dict[str, Any]:
    """Transcribe a WAV recording using the existing Whisper pipeline.

    Delegates to ``tools.transcription_tools.transcribe_audio()``.
    Filters out known Whisper hallucinations on silent audio.

    Args:
        wav_path: Path to the WAV file.
        model: Whisper model name (default: from config or ``whisper-1``).

    Returns:
        Dict with ``success``, ``transcript``, and optionally ``error``.
    """
    from tools.transcription_tools import transcribe_audio

    result = transcribe_audio(wav_path, model=model)

    # Filter out Whisper hallucinations (common on silent/near-silent audio)
    if result.get("success") and is_whisper_hallucination(result.get("transcript", "")):
        logger.info("Filtered Whisper hallucination: %r", result["transcript"])
        return {"success": True, "transcript": "", "filtered": True}

    return result


# ============================================================================
# Audio playback (interruptable)
# ============================================================================

# Global reference to the active playback process so it can be interrupted.
_active_playback: Optional[subprocess.Popen] = None
_playback_lock = threading.Lock()


def stop_playback() -> None:
    """Interrupt the currently playing audio (if any)."""
    global _active_playback
    with _playback_lock:
        proc = _active_playback
        _active_playback = None
    if proc and proc.poll() is None:
        try:
            proc.terminate()
            logger.info("Audio playback interrupted")
        except Exception:
            pass
    # Also stop sounddevice playback if active
    try:
        sd, _ = _import_audio()
        sd.stop()
    except Exception:
        pass


def play_audio_file(file_path: str) -> bool:
    """Play an audio file through the default output device.

    Strategy:
        1. WAV files via ``sounddevice.play()`` when available.
        2. System commands: ``afplay`` (macOS), ``aplay`` (Linux ALSA),
           ``ffplay`` (cross-platform fallback wherever ffmpeg is installed).

    Playback can be interrupted by calling ``stop_playback()``.

    Returns:
        ``True`` if playback succeeded, ``False`` otherwise.
    """
    global _active_playback

    if not os.path.isfile(file_path):
        logger.warning("Audio file not found: %s", file_path)
        return False

    # Try sounddevice for WAV files
    if file_path.endswith(".wav"):
        try:
            sd, np = _import_audio()
            with wave.open(file_path, "rb") as wf:
                frames = wf.readframes(wf.getnframes())
                audio_data = np.frombuffer(frames, dtype=np.int16)
                sample_rate = wf.getframerate()
            # NOTE(review): assumes mono 16-bit PCM; a stereo WAV would need a
            # reshape to (n, channels) — confirm callers only produce mono.

            sd.play(audio_data, samplerate=sample_rate)
            # sd.wait() calls Event.wait() without timeout — hangs forever if
            # the audio device stalls. Poll with a ceiling and force-stop.
            duration_secs = len(audio_data) / sample_rate
            deadline = time.monotonic() + duration_secs + 2.0
            while sd.get_stream() and sd.get_stream().active and time.monotonic() < deadline:
                time.sleep(0.01)
            sd.stop()
            return True
        except (ImportError, OSError):
            pass  # audio libs not available, fall through to system players
        except Exception as e:
            logger.debug("sounddevice playback failed: %s", e)

    # Fall back to system audio players (using Popen for interruptability)
    system = platform.system()
    players = []

    if system == "Darwin":
        players.append(["afplay", file_path])
    if system == "Linux":
        players.append(["aplay", "-q", file_path])
    # ffplay ships with ffmpeg on every platform — generic last-resort fallback.
    # (The original only added it on macOS, so non-WAV files could never play
    # on Linux even with ffmpeg installed, despite the docstring's claim.)
    players.append(["ffplay", "-nodisp", "-autoexit", "-loglevel", "quiet", file_path])

    for cmd in players:
        exe = shutil.which(cmd[0])
        if exe:
            try:
                proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                with _playback_lock:
                    _active_playback = proc
                proc.wait(timeout=300)
                with _playback_lock:
                    _active_playback = None
                return True
            except subprocess.TimeoutExpired:
                logger.warning("System player %s timed out, killing process", cmd[0])
                proc.kill()
                proc.wait()
                with _playback_lock:
                    _active_playback = None
            except Exception as e:
                logger.debug("System player %s failed: %s", cmd[0], e)
                with _playback_lock:
                    _active_playback = None

    logger.warning("No audio player available for %s", file_path)
    return False


# ============================================================================
# Requirements check
# ============================================================================
def check_voice_requirements() -> Dict[str, Any]:
    """Check if all voice mode requirements are met.

    Returns:
        Dict with ``available``, ``audio_available``, ``stt_available``,
        ``missing_packages``, and ``details``.
    """
    # Determine STT provider availability (imported lazily so this module
    # stays importable without the transcription stack present).
    # The unused ``_HAS_FASTER_WHISPER`` import was dropped.
    from tools.transcription_tools import _get_provider, _load_stt_config
    stt_config = _load_stt_config()
    stt_provider = _get_provider(stt_config)
    stt_available = stt_provider != "none"

    missing: List[str] = []
    has_audio = _audio_available()

    if not has_audio:
        missing.extend(["sounddevice", "numpy"])

    # Environment detection
    env_check = detect_audio_environment()

    available = has_audio and stt_available and env_check["available"]
    details_parts = []

    if has_audio:
        details_parts.append("Audio capture: OK")
    else:
        details_parts.append("Audio capture: MISSING (pip install sounddevice numpy)")

    if stt_provider == "local":
        details_parts.append("STT provider: OK (local faster-whisper)")
    elif stt_provider == "groq":
        details_parts.append("STT provider: OK (Groq)")
    elif stt_provider == "openai":
        details_parts.append("STT provider: OK (OpenAI)")
    else:
        details_parts.append(
            "STT provider: MISSING (pip install faster-whisper, "
            "or set GROQ_API_KEY / VOICE_TOOLS_OPENAI_KEY)"
        )

    for warning in env_check["warnings"]:
        details_parts.append(f"Environment: {warning}")

    return {
        "available": available,
        "audio_available": has_audio,
        "stt_available": stt_available,
        "missing_packages": missing,
        "details": "\n".join(details_parts),
        "environment": env_check,
    }


# ============================================================================
# Temp file cleanup
# 
============================================================================
+def cleanup_temp_recordings(max_age_seconds: int = 3600) -> int:
+    """Remove old temporary voice recording files.
+
+    Args:
+        max_age_seconds: Delete files older than this (default: 1 hour).
+
+    Returns:
+        Number of files deleted.
+    """
+    if not os.path.isdir(_TEMP_DIR):
+        return 0
+
+    deleted = 0
+    now = time.time()
+
+    # Only files matching our own "recording_*.wav" naming scheme are
+    # considered; anything else in the shared temp dir is left alone.
+    for entry in os.scandir(_TEMP_DIR):
+        if entry.is_file() and entry.name.startswith("recording_") and entry.name.endswith(".wav"):
+            try:
+                # Age is measured against mtime; a stat/unlink that races a
+                # concurrent cleanup is swallowed below.
+                age = now - entry.stat().st_mtime
+                if age > max_age_seconds:
+                    os.unlink(entry.path)
+                    deleted += 1
+            except OSError:
+                pass
+
+    if deleted:
+        logger.debug("Cleaned up %d old voice recordings", deleted)
+    return deleted
diff --git a/website/docs/user-guide/features/voice-mode.md b/website/docs/user-guide/features/voice-mode.md
new file mode 100644
index 0000000000..ce151643a0
--- /dev/null
+++ b/website/docs/user-guide/features/voice-mode.md
@@ -0,0 +1,487 @@
+---
+sidebar_position: 10
+title: "Voice Mode"
+description: "Real-time voice conversations with Hermes Agent — CLI, Telegram, Discord (DMs, text channels, and voice channels)"
+---
+
+# Voice Mode
+
+Hermes Agent supports full voice interaction across CLI and messaging platforms. Talk to the agent using your microphone, hear spoken replies, and have live voice conversations in Discord voice channels.
+
+## Prerequisites
+
+Before using voice features, make sure you have:
+
+1. **Hermes Agent installed** — `pip install hermes-agent` (see [Getting Started](../../getting-started.md))
+2. **An LLM provider configured** — set `OPENAI_API_KEY`, `OPENAI_BASE_URL`, and `LLM_MODEL` in `~/.hermes/.env`
+3. **A working base setup** — run `hermes` to verify the agent responds to text before enabling voice
+
+:::tip
+The `~/.hermes/` directory and default `config.yaml` are created automatically the first time you run `hermes`. You only need to create `~/.hermes/.env` manually for API keys. 
+::: + +## Overview + +| Feature | Platform | Description | +|---------|----------|-------------| +| **Interactive Voice** | CLI | Press Ctrl+B to record, agent auto-detects silence and responds | +| **Auto Voice Reply** | Telegram, Discord | Agent sends spoken audio alongside text responses | +| **Voice Channel** | Discord | Bot joins VC, listens to users speaking, speaks replies back | + +## Requirements + +### Python Packages + +```bash +# CLI voice mode (microphone + audio playback) +pip install hermes-agent[voice] + +# Discord + Telegram messaging (includes discord.py[voice] for VC support) +pip install hermes-agent[messaging] + +# Premium TTS (ElevenLabs) +pip install hermes-agent[tts-premium] + +# Everything at once +pip install hermes-agent[all] +``` + +| Extra | Packages | Required For | +|-------|----------|-------------| +| `voice` | `sounddevice`, `numpy` | CLI voice mode | +| `messaging` | `discord.py[voice]`, `python-telegram-bot`, `aiohttp` | Discord & Telegram bots | +| `tts-premium` | `elevenlabs` | ElevenLabs TTS provider | + +:::info +`discord.py[voice]` installs **PyNaCl** (for voice encryption) and **opus bindings** automatically. This is required for Discord voice channel support. 
+::: + +### System Dependencies + +```bash +# macOS +brew install portaudio ffmpeg opus + +# Ubuntu/Debian +sudo apt install portaudio19-dev ffmpeg libopus0 +``` + +| Dependency | Purpose | Required For | +|-----------|---------|-------------| +| **PortAudio** | Microphone input and audio playback | CLI voice mode | +| **ffmpeg** | Audio format conversion (MP3 → Opus, PCM → WAV) | All platforms | +| **Opus** | Discord voice codec | Discord voice channels | + +### API Keys + +Add to `~/.hermes/.env`: + +```bash +# Speech-to-Text — local provider needs NO key at all +# pip install faster-whisper # Free, runs locally, recommended +GROQ_API_KEY=your-key # Groq Whisper — fast, free tier (cloud) +VOICE_TOOLS_OPENAI_KEY=your-key # OpenAI Whisper — paid (cloud) + +# Text-to-Speech (optional — Edge TTS works without any key) +ELEVENLABS_API_KEY=your-key # ElevenLabs — premium quality +``` + +:::tip +If `faster-whisper` is installed, voice mode works with **zero API keys** for STT. The model (~150 MB for `base`) downloads automatically on first use. +::: + +--- + +## CLI Voice Mode + +### Quick Start + +Start the CLI and enable voice mode: + +```bash +hermes # Start the interactive CLI +``` + +Then use these commands inside the CLI: + +``` +/voice Toggle voice mode on/off +/voice on Enable voice mode +/voice off Disable voice mode +/voice tts Toggle TTS output +/voice status Show current state +``` + +### How It Works + +1. Start the CLI with `hermes` and enable voice mode with `/voice on` +2. **Press Ctrl+B** — a beep plays (880Hz), recording starts +3. **Speak** — a live audio level bar shows your input: `● [▁▂▃▅▇▇▅▂] ❯` +4. **Stop speaking** — after 3 seconds of silence, recording auto-stops +5. **Two beeps** play (660Hz) confirming the recording ended +6. Audio is transcribed via Whisper and sent to the agent +7. If TTS is enabled, the agent's reply is spoken aloud +8. 
Recording **automatically restarts** — speak again without pressing any key
+
+This loop continues until you press **Ctrl+B** during recording (exits continuous mode) or 3 consecutive recordings detect no speech.
+
+:::tip
+The record key is configurable via `voice.record_key` in `~/.hermes/config.yaml` (default: `ctrl+b`).
+:::
+
+### Silence Detection
+
+Two-stage algorithm detects when you've finished speaking:
+
+1. **Speech confirmation** — waits for audio above the RMS threshold (200) for at least 0.3s, tolerating brief dips between syllables
+2. **End detection** — once speech is confirmed, triggers after 3.0 seconds of continuous silence
+
+If no speech is detected at all for 15 seconds, recording stops automatically.
+
+Both `silence_threshold` and `silence_duration` are configurable in `config.yaml`.
+
+### Streaming TTS
+
+When TTS is enabled, the agent speaks its reply **sentence-by-sentence** as it generates text — you don't wait for the full response:
+
+1. Buffers text deltas into complete sentences (min 20 chars)
+2. Strips markdown formatting and code blocks
+3. Generates and plays audio per sentence in real-time
+
+### Hallucination Filter
+
+Whisper sometimes generates phantom text from silence or background noise ("Thank you for watching", "Subscribe", etc.). The agent filters these out using a set of 26 known hallucination phrases across multiple languages, plus a regex pattern that catches repetitive variations. 
+ +--- + +## Gateway Voice Reply (Telegram & Discord) + +If you haven't set up your messaging bots yet, see the platform-specific guides: +- [Telegram Setup Guide](../messaging/telegram.md) +- [Discord Setup Guide](../messaging/discord.md) + +Start the gateway to connect to your messaging platforms: + +```bash +hermes gateway # Start the gateway (connects to configured platforms) +hermes gateway setup # Interactive setup wizard for first-time configuration +``` + +### Discord: Channels vs DMs + +The bot supports two interaction modes on Discord: + +| Mode | How to Talk | Mention Required | Setup | +|------|------------|-----------------|-------| +| **Direct Message (DM)** | Open the bot's profile → "Message" | No | Works immediately | +| **Server Channel** | Type in a text channel where the bot is present | Yes (`@botname`) | Bot must be invited to the server | + +**DM (recommended for personal use):** Just open a DM with the bot and type — no @mention needed. Voice replies and all commands work the same as in channels. + +**Server channels:** The bot only responds when you @mention it (e.g. `@hermesbyt4 hello`). Make sure you select the **bot user** from the mention popup, not the role with the same name. 
+ +:::tip +To disable the mention requirement in server channels, add to `~/.hermes/.env`: +```bash +DISCORD_REQUIRE_MENTION=false +``` +Or set specific channels as free-response (no mention needed): +```bash +DISCORD_FREE_RESPONSE_CHANNELS=123456789,987654321 +``` +::: + +### Commands + +These work in both Telegram and Discord (DMs and text channels): + +``` +/voice Toggle voice mode on/off +/voice on Voice replies only when you send a voice message +/voice tts Voice replies for ALL messages +/voice off Disable voice replies +/voice status Show current setting +``` + +### Modes + +| Mode | Command | Behavior | +|------|---------|----------| +| `off` | `/voice off` | Text only (default) | +| `voice_only` | `/voice on` | Speaks reply only when you send a voice message | +| `all` | `/voice tts` | Speaks reply to every message | + +Voice mode setting is persisted across gateway restarts. + +### Platform Delivery + +| Platform | Format | Notes | +|----------|--------|-------| +| **Telegram** | Voice bubble (Opus/OGG) | Plays inline in chat. ffmpeg converts MP3 → Opus if needed | +| **Discord** | Native voice bubble (Opus/OGG) | Plays inline like a user voice message. Falls back to file attachment if voice bubble API fails | + +--- + +## Discord Voice Channels + +The most immersive voice feature: the bot joins a Discord voice channel, listens to users speaking, transcribes their speech, processes through the agent, and speaks the reply back in the voice channel. + +### Setup + +#### 1. Discord Bot Permissions + +If you already have a Discord bot set up for text (see [Discord Setup Guide](../messaging/discord.md)), you need to add voice permissions. 
+ +Go to the [Discord Developer Portal](https://discord.com/developers/applications) → your application → **Installation** → **Default Install Settings** → **Guild Install**: + +**Add these permissions to the existing text permissions:** + +| Permission | Purpose | Required | +|-----------|---------|----------| +| **Connect** | Join voice channels | Yes | +| **Speak** | Play TTS audio in voice channels | Yes | +| **Use Voice Activity** | Detect when users are speaking | Recommended | + +**Updated Permissions Integer:** + +| Level | Integer | What's Included | +|-------|---------|----------------| +| Text only | `274878286912` | View Channels, Send Messages, Read History, Embeds, Attachments, Threads, Reactions | +| Text + Voice | `274881432640` | All above + Connect, Speak | + +**Re-invite the bot** with the updated permissions URL: + +``` +https://discord.com/oauth2/authorize?client_id=YOUR_APP_ID&scope=bot+applications.commands&permissions=274881432640 +``` + +Replace `YOUR_APP_ID` with your Application ID from the Developer Portal. + +:::warning +Re-inviting the bot to a server it's already in will update its permissions without removing it. You won't lose any data or configuration. +::: + +#### 2. Privileged Gateway Intents + +In the [Developer Portal](https://discord.com/developers/applications) → your application → **Bot** → **Privileged Gateway Intents**, enable all three: + +| Intent | Purpose | +|--------|---------| +| **Presence Intent** | Detect user online/offline status | +| **Server Members Intent** | Map voice SSRC identifiers to Discord user IDs | +| **Message Content Intent** | Read text message content in channels | + +All three are required for full voice channel functionality. **Server Members Intent** is especially critical — without it, the bot cannot identify who is speaking in the voice channel. + +#### 3. 
Opus Codec + +The Opus codec library must be installed on the machine running the gateway: + +```bash +# macOS (Homebrew) +brew install opus + +# Ubuntu/Debian +sudo apt install libopus0 +``` + +The bot auto-loads the codec from: +- **macOS:** `/opt/homebrew/lib/libopus.dylib` +- **Linux:** `libopus.so.0` + +#### 4. Environment Variables + +```bash +# ~/.hermes/.env + +# Discord bot (already configured for text) +DISCORD_BOT_TOKEN=your-bot-token +DISCORD_ALLOWED_USERS=your-user-id + +# STT — local provider needs no key (pip install faster-whisper) +# GROQ_API_KEY=your-key # Alternative: cloud-based, fast, free tier + +# TTS — optional, Edge TTS (free) is the default +# ELEVENLABS_API_KEY=your-key # Premium quality +``` + +### Start the Gateway + +```bash +hermes gateway # Start with existing configuration +``` + +The bot should come online in Discord within a few seconds. + +### Commands + +Use these in the Discord text channel where the bot is present: + +``` +/voice join Bot joins your current voice channel +/voice channel Alias for /voice join +/voice leave Bot disconnects from voice channel +/voice status Show voice mode and connected channel +``` + +:::info +You must be in a voice channel before running `/voice join`. The bot joins the same VC you're in. +::: + +### How It Works + +When the bot joins a voice channel, it: + +1. **Listens** to each user's audio stream independently +2. **Detects silence** — 1.5s of silence after at least 0.5s of speech triggers processing +3. **Transcribes** the audio via Whisper STT (local, Groq, or OpenAI) +4. **Processes** through the full agent pipeline (session, tools, memory) +5. 
**Speaks** the reply back in the voice channel via TTS + +### Text Channel Integration + +When the bot is in a voice channel: + +- Transcripts appear in the text channel: `[Voice] @user: what you said` +- Agent responses are sent as text in the channel AND spoken in the VC +- The text channel is the one where `/voice join` was issued + +### Echo Prevention + +The bot automatically pauses its audio listener while playing TTS replies, preventing it from hearing and re-processing its own output. + +### Access Control + +Only users listed in `DISCORD_ALLOWED_USERS` can interact via voice. Other users' audio is silently ignored. + +```bash +# ~/.hermes/.env +DISCORD_ALLOWED_USERS=284102345871466496 +``` + +--- + +## Configuration Reference + +### config.yaml + +```yaml +# Voice recording (CLI) +voice: + record_key: "ctrl+b" # Key to start/stop recording + max_recording_seconds: 120 # Maximum recording length + auto_tts: false # Auto-enable TTS when voice mode starts + silence_threshold: 200 # RMS level (0-32767) below which counts as silence + silence_duration: 3.0 # Seconds of silence before auto-stop + +# Speech-to-Text +stt: + provider: "local" # "local" (free) | "groq" | "openai" + local: + model: "base" # tiny, base, small, medium, large-v3 + # model: "whisper-1" # Legacy: used when provider is not set + +# Text-to-Speech +tts: + provider: "edge" # "edge" (free) | "elevenlabs" | "openai" + edge: + voice: "en-US-AriaNeural" # 322 voices, 74 languages + elevenlabs: + voice_id: "pNInz6obpgDQGcFmaJgB" # Adam + model_id: "eleven_multilingual_v2" + openai: + model: "gpt-4o-mini-tts" + voice: "alloy" # alloy, echo, fable, onyx, nova, shimmer +``` + +### Environment Variables + +```bash +# Speech-to-Text providers (local needs no key) +# pip install faster-whisper # Free local STT — no API key needed +GROQ_API_KEY=... # Groq Whisper (fast, free tier) +VOICE_TOOLS_OPENAI_KEY=... 
# OpenAI Whisper (paid) + +# STT advanced overrides (optional) +STT_GROQ_MODEL=whisper-large-v3-turbo # Override default Groq STT model +STT_OPENAI_MODEL=whisper-1 # Override default OpenAI STT model +GROQ_BASE_URL=https://api.groq.com/openai/v1 # Custom Groq endpoint +STT_OPENAI_BASE_URL=https://api.openai.com/v1 # Custom OpenAI STT endpoint + +# Text-to-Speech providers (Edge TTS needs no key) +ELEVENLABS_API_KEY=... # ElevenLabs (premium quality) +# OpenAI TTS uses VOICE_TOOLS_OPENAI_KEY + +# Discord voice channel +DISCORD_BOT_TOKEN=... +DISCORD_ALLOWED_USERS=... +``` + +### STT Provider Comparison + +| Provider | Model | Speed | Quality | Cost | API Key | +|----------|-------|-------|---------|------|---------| +| **Local** | `base` | Fast (depends on CPU/GPU) | Good | Free | No | +| **Local** | `small` | Medium | Better | Free | No | +| **Local** | `large-v3` | Slow | Best | Free | No | +| **Groq** | `whisper-large-v3-turbo` | Very fast (~0.5s) | Good | Free tier | Yes | +| **Groq** | `whisper-large-v3` | Fast (~1s) | Better | Free tier | Yes | +| **OpenAI** | `whisper-1` | Fast (~1s) | Good | Paid | Yes | +| **OpenAI** | `gpt-4o-transcribe` | Medium (~2s) | Best | Paid | Yes | + +Provider priority (automatic fallback): **local** > **groq** > **openai** + +### TTS Provider Comparison + +| Provider | Quality | Cost | Latency | Key Required | +|----------|---------|------|---------|-------------| +| **Edge TTS** | Good | Free | ~1s | No | +| **ElevenLabs** | Excellent | Paid | ~2s | Yes | +| **OpenAI TTS** | Good | Paid | ~1.5s | Yes | + +--- + +## Troubleshooting + +### "No audio device found" (CLI) + +PortAudio is not installed: + +```bash +brew install portaudio # macOS +sudo apt install portaudio19-dev # Ubuntu +``` + +### Bot doesn't respond in Discord server channels + +The bot requires an @mention by default in server channels. Make sure you: + +1. Type `@` and select the **bot user** (with the #discriminator), not the **role** with the same name +2. 
Or use DMs instead — no mention needed +3. Or set `DISCORD_REQUIRE_MENTION=false` in `~/.hermes/.env` + +### Bot joins VC but doesn't hear me + +- Check your Discord user ID is in `DISCORD_ALLOWED_USERS` +- Make sure you're not muted in Discord +- The bot needs a SPEAKING event from Discord before it can map your audio — start speaking within a few seconds of joining + +### Bot hears me but doesn't respond + +- Verify STT is available: install `faster-whisper` (no key needed) or set `GROQ_API_KEY` / `VOICE_TOOLS_OPENAI_KEY` +- Check the LLM model is configured and accessible +- Review gateway logs: `tail -f ~/.hermes/logs/gateway.log` + +### Bot responds in text but not in voice channel + +- TTS provider may be failing — check API key and quota +- Edge TTS (free, no key) is the default fallback +- Check logs for TTS errors + +### Whisper returns garbage text + +The hallucination filter catches most cases automatically. If you're still getting phantom transcripts: + +- Use a quieter environment +- Adjust `silence_threshold` in config (higher = less sensitive) +- Try a different STT model diff --git a/website/docs/user-guide/messaging/discord.md b/website/docs/user-guide/messaging/discord.md index 38fb9598a9..0fc7f8cbc5 100644 --- a/website/docs/user-guide/messaging/discord.md +++ b/website/docs/user-guide/messaging/discord.md @@ -210,8 +210,8 @@ Replace the ID with the actual channel ID (right-click → Copy Channel ID with Hermes Agent supports Discord voice messages: -- **Incoming voice messages** are automatically transcribed using Whisper (requires `VOICE_TOOLS_OPENAI_KEY` to be set in your environment). -- **Text-to-speech**: When TTS is enabled, the bot can send spoken responses as MP3 file attachments. +- **Incoming voice messages** are automatically transcribed using Whisper (requires `GROQ_API_KEY` or `VOICE_TOOLS_OPENAI_KEY` to be set in your environment). 
+- **Text-to-speech**: Use `/voice tts` to have the bot send spoken audio responses alongside text replies. ## Troubleshooting diff --git a/website/docs/user-guide/messaging/index.md b/website/docs/user-guide/messaging/index.md index 2aa2605e6f..debc841b8b 100644 --- a/website/docs/user-guide/messaging/index.md +++ b/website/docs/user-guide/messaging/index.md @@ -1,12 +1,12 @@ --- sidebar_position: 1 title: "Messaging Gateway" -description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, or Home Assistant — architecture and setup overview" +description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, Home Assistant, or your browser — architecture and setup overview" --- # Messaging Gateway -Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, or Home Assistant. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. +Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, Home Assistant, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. 
## Architecture @@ -15,24 +15,24 @@ Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, or Home │ Hermes Gateway │ ├───────────────────────────────────────────────────────────────────────────────┤ │ │ -│ ┌──────────┐ ┌─────────┐ ┌──────────┐ ┌───────┐ ┌───────┐ ┌───────┐ ┌────┐│ -│ │ Telegram │ │ Discord │ │ WhatsApp │ │ Slack │ │Signal │ │ Email │ │ HA ││ -│ │ Adapter │ │ Adapter │ │ Adapter │ │Adapter│ │Adapter│ │Adapter│ │Adpt││ -│ └────┬─────┘ └────┬────┘ └────┬─────┘ └──┬────┘ └──┬────┘ └──┬────┘ └─┬──┘│ -│ │ │ │ │ │ │ │ │ -│ └─────────────┴───────────┴───────────┴─────────┴─────────┴────────┘ │ -│ │ │ -│ ┌────────▼────────┐ │ -│ │ Session Store │ │ -│ │ (per-chat) │ │ -│ └────────┬────────┘ │ -│ │ │ -│ ┌────────▼────────┐ │ -│ │ AIAgent │ │ -│ │ (run_agent) │ │ -│ └─────────────────┘ │ -│ │ -└───────────────────────────────────────────────────────────────────────────────┘ +│ ┌──────────┐ ┌─────────┐ ┌──────────┐ ┌───────┐ ┌───────┐ ┌───────┐ ┌────┐ │ +│ │ Telegram │ │ Discord │ │ WhatsApp │ │ Slack │ │Signal │ │ Email │ │ HA │ │ +│ │ Adapter │ │ Adapter │ │ Adapter │ │Adapter│ │Adapter│ │Adapter│ │Adpt│ │ +│ └────┬─────┘ └────┬────┘ └────┬─────┘ └──┬────┘ └──┬────┘ └──┬────┘ └─┬──┘ │ +│ │ │ │ │ │ │ │ │ +│ └─────────────┴───────────┴───────────┴─────────┴─────────┴────────┘ │ +│ │ │ +│ ┌────────▼────────┐ │ +│ │ Session Store │ │ +│ │ (per-chat) │ │ +│ └────────┬────────┘ │ +│ │ │ +│ ┌────────▼────────┐ │ +│ │ AIAgent │ │ +│ │ (run_agent) │ │ +│ └─────────────────┘ │ +│ │ +└───────────────────────────────────────────────────────────────────────────────────────┘ ``` Each platform adapter receives messages, routes them through a per-chat session store, and dispatches them to the AIAgent for processing. The gateway also runs the cron scheduler, ticking every 60 seconds to execute any due jobs.