diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 1fda14229..5cdbf9eec 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -880,18 +880,8 @@ class DiscordAdapter(BasePlatformAdapter): try: await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path) - from tools.transcription_tools import transcribe_audio - # Read STT model from config.yaml - stt_model = None - try: - import yaml as _y - from pathlib import Path as _P - _cfg = _P(os.getenv("HERMES_HOME", _P.home() / ".hermes")) / "config.yaml" - if _cfg.exists(): - with open(_cfg) as _f: - stt_model = (_y.safe_load(_f) or {}).get("stt", {}).get("model") - except Exception: - pass + from tools.transcription_tools import transcribe_audio, get_stt_model_from_config + stt_model = get_stt_model_from_config() result = await asyncio.to_thread(transcribe_audio, wav_path, model=stt_model) if not result.get("success"): diff --git a/gateway/run.py b/gateway/run.py index 157ed9d99..6cddd491b 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3323,20 +3323,10 @@ class GatewayRunner: Returns: The enriched message string with transcriptions prepended. """ - from tools.transcription_tools import transcribe_audio + from tools.transcription_tools import transcribe_audio, get_stt_model_from_config import asyncio - # Read STT model from config.yaml (same key the CLI uses) - stt_model = None - try: - import yaml as _y - _cfg = _hermes_home / "config.yaml" - if _cfg.exists(): - with open(_cfg) as _f: - _data = _y.safe_load(_f) or {} - stt_model = _data.get("stt", {}).get("model") - except Exception: - pass + stt_model = get_stt_model_from_config() enriched_parts = [] for path in audio_paths: diff --git a/tests/tools/test_transcription_tools.py b/tests/tools/test_transcription_tools.py index 6750f28d3..177d5d85e 100644 --- a/tests/tools/test_transcription_tools.py +++ b/tests/tools/test_transcription_tools.py @@ -1,4 +1,4 @@ -"""Tests for tools.transcription_tools -- provider resolution and model correction.""" +"""Tests for tools.transcription_tools -- provider resolution, model correction, config helper.""" import os import struct @@ -197,3 +197,47 @@ class TestTranscribeAudioSuccess: result = transcribe_audio(sample_wav) assert result["transcript"] == "hello world" + + +# ============================================================================ +# get_stt_model_from_config +# ============================================================================ + +class TestGetSttModelFromConfig: + def test_returns_model_from_config(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text("stt:\n model: whisper-large-v3\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() == "whisper-large-v3" + + def test_returns_none_when_no_stt_section(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text("tts:\n provider: edge\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() is None + + def test_returns_none_when_no_config_file(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() is None + + def test_returns_none_on_invalid_yaml(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text(": : :\n bad yaml [[[") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() is None + + def test_returns_none_when_model_key_missing(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text("stt:\n enabled: true\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() is None diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index c962f77c3..e67acffb5 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -41,6 +41,24 @@ GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1") OPENAI_BASE_URL = os.getenv("STT_OPENAI_BASE_URL", "https://api.openai.com/v1") +def get_stt_model_from_config() -> Optional[str]: + """Read the STT model name from ~/.hermes/config.yaml. + + Returns the value of ``stt.model`` if present, otherwise ``None``. + Silently returns ``None`` on any error (missing file, bad YAML, etc.). + """ + try: + import yaml + cfg_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml" + if cfg_path.exists(): + with open(cfg_path) as f: + data = yaml.safe_load(f) or {} + return data.get("stt", {}).get("model") + except Exception: + pass + return None + + def _resolve_stt_provider() -> Tuple[Optional[str], Optional[str], str]: """Resolve which STT provider to use based on available API keys.