diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 346e6e851..a0a2d7d8a 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -684,7 +684,11 @@ platform_toolsets: stt: enabled: true # provider: "local" # auto-detected if omitted - model: "whisper-1" # whisper-1 (cheapest) | gpt-4o-mini-transcribe | gpt-4o-transcribe + local: + model: "base" # tiny | base | small | medium | large-v3 | turbo + # language: "" # auto-detect; set to "en", "es", "fr", etc. to force + openai: + model: "whisper-1" # whisper-1 | gpt-4o-mini-transcribe | gpt-4o-transcribe # mistral: # model: "voxtral-mini-latest" # voxtral-mini-latest | voxtral-mini-2602 diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index a51f94095..34a51e721 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -1260,9 +1260,8 @@ class DiscordAdapter(BasePlatformAdapter): try: await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path) - from tools.transcription_tools import transcribe_audio, get_stt_model_from_config - stt_model = get_stt_model_from_config() - result = await asyncio.to_thread(transcribe_audio, wav_path, model=stt_model) + from tools.transcription_tools import transcribe_audio + result = await asyncio.to_thread(transcribe_audio, wav_path) if not result.get("success"): return diff --git a/gateway/run.py b/gateway/run.py index 9aae8217d..9e9bb8fce 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -6099,16 +6099,14 @@ class GatewayRunner: return f"{disabled_note}\n\n{user_text}" return disabled_note - from tools.transcription_tools import transcribe_audio, get_stt_model_from_config + from tools.transcription_tools import transcribe_audio import asyncio - stt_model = get_stt_model_from_config() - enriched_parts = [] for path in audio_paths: try: logger.debug("Transcribing user voice: %s", path) - result = await asyncio.to_thread(transcribe_audio, path, model=stt_model) + result = await asyncio.to_thread(transcribe_audio, path) if result["success"]: transcript = result["transcript"] enriched_parts.append( diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 93aa1cc0c..4944e4293 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -612,7 +612,7 @@ DEFAULT_CONFIG = { }, # Config schema version - bump this when adding new required fields - "_config_version": 13, + "_config_version": 14, } # ============================================================================= @@ -1767,6 +1767,56 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A except Exception: pass + # ── Version 13 → 14: migrate legacy flat stt.model to provider section ── + # Old configs (and cli-config.yaml.example) had a flat `stt.model` key + # that was provider-agnostic. When the provider was "local" this caused + # OpenAI model names (e.g. "whisper-1") to be fed to faster-whisper, + # crashing with "Invalid model size". Move the value into the correct + # provider-specific section and remove the flat key. + if current_ver < 14: + # Read raw config (no defaults merged) to check what the user actually + # wrote, then apply changes to the merged config for saving. + raw = read_raw_config() + raw_stt = raw.get("stt", {}) + if isinstance(raw_stt, dict) and "model" in raw_stt: + legacy_model = raw_stt["model"] + provider = raw_stt.get("provider", "local") + config = load_config() + stt = config.get("stt", {}) + # Remove the legacy flat key + stt.pop("model", None) + # Place it in the appropriate provider section only if the + # user didn't already set a model there + if provider in ("local", "local_command"): + # Don't migrate an OpenAI model name into the local section + _local_models = { + "tiny.en", "tiny", "base.en", "base", "small.en", "small", + "medium.en", "medium", "large-v1", "large-v2", "large-v3", + "large", "distil-large-v2", "distil-medium.en", + "distil-small.en", "distil-large-v3", "distil-large-v3.5", + "large-v3-turbo", "turbo", + } + if legacy_model in _local_models: + # Check raw config — only set if user didn't already + # have a nested local.model + raw_local = raw_stt.get("local", {}) + if not isinstance(raw_local, dict) or "model" not in raw_local: + local_cfg = stt.setdefault("local", {}) + local_cfg["model"] = legacy_model + # else: drop it — it was an OpenAI model name, local section + # already defaults to "base" via DEFAULT_CONFIG + else: + # Cloud provider — put it in that provider's section only + # if user didn't already set a nested model + raw_provider = raw_stt.get(provider, {}) + if not isinstance(raw_provider, dict) or "model" not in raw_provider: + provider_cfg = stt.setdefault(provider, {}) + provider_cfg["model"] = legacy_model + config["stt"] = stt + save_config(config) + if not quiet: + print(f" ✓ Migrated legacy stt.model to provider-specific config") + if current_ver < latest_ver and not quiet: print(f"Config version: {current_ver} → {latest_ver}") diff --git a/tests/gateway/test_stt_config.py b/tests/gateway/test_stt_config.py index 436afd7c1..a49e40215 100644 --- a/tests/gateway/test_stt_config.py +++ b/tests/gateway/test_stt_config.py @@ -40,9 +40,6 @@ async def test_enrich_message_with_transcription_skips_when_stt_disabled(): with patch( "tools.transcription_tools.transcribe_audio", side_effect=AssertionError("transcribe_audio should not be called when STT is disabled"), - ), patch( - "tools.transcription_tools.get_stt_model_from_config", - return_value=None, ): result = await runner._enrich_message_with_transcription( "caption", @@ -63,9 +60,6 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag with patch( "tools.transcription_tools.transcribe_audio", return_value={"success": False, "error": "VOICE_TOOLS_OPENAI_KEY not set"}, - ), patch( - "tools.transcription_tools.get_stt_model_from_config", - return_value=None, ): result = await runner._enrich_message_with_transcription( "caption", diff --git a/tests/tools/test_transcription_tools.py b/tests/tools/test_transcription_tools.py index f781c32bd..88a33298e 100644 --- a/tests/tools/test_transcription_tools.py +++ b/tests/tools/test_transcription_tools.py @@ -822,27 +822,54 @@ class TestTranscribeAudioDispatch: # ============================================================================ class TestGetSttModelFromConfig: - def test_returns_model_from_config(self, tmp_path, monkeypatch): + """get_stt_model_from_config is provider-aware: it reads the model from the + correct provider-specific section (stt.local.model, stt.openai.model, etc.) + and only honours the legacy flat stt.model key for cloud providers.""" + + def test_returns_local_model_from_nested_config(self, tmp_path, monkeypatch): cfg = tmp_path / "config.yaml" - cfg.write_text("stt:\n model: whisper-large-v3\n") + cfg.write_text("stt:\n provider: local\n local:\n model: large-v3\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() == "large-v3" + + def test_returns_openai_model_from_nested_config(self, tmp_path, monkeypatch): + cfg = tmp_path / "config.yaml" + cfg.write_text("stt:\n provider: openai\n openai:\n model: gpt-4o-transcribe\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + assert get_stt_model_from_config() == "gpt-4o-transcribe" + + def test_legacy_flat_key_ignored_for_local_provider(self, tmp_path, monkeypatch): + """Legacy stt.model should NOT be used when provider is local, to prevent + OpenAI model names (whisper-1) from being fed to faster-whisper.""" + cfg = tmp_path / "config.yaml" + cfg.write_text("stt:\n provider: local\n model: whisper-1\n") + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.transcription_tools import get_stt_model_from_config + result = get_stt_model_from_config() + assert result != "whisper-1", "Legacy stt.model should be ignored for local provider" + + def test_legacy_flat_key_honoured_for_cloud_provider(self, tmp_path, monkeypatch): + """Legacy stt.model should still work for cloud providers that don't + have a section in DEFAULT_CONFIG (e.g. groq).""" + cfg = tmp_path / "config.yaml" + cfg.write_text("stt:\n provider: groq\n model: whisper-large-v3\n") monkeypatch.setenv("HERMES_HOME", str(tmp_path)) from tools.transcription_tools import get_stt_model_from_config assert get_stt_model_from_config() == "whisper-large-v3" - def test_returns_none_when_no_stt_section(self, tmp_path, monkeypatch): - cfg = tmp_path / "config.yaml" - cfg.write_text("tts:\n provider: edge\n") + def test_defaults_to_local_model_when_no_config_file(self, tmp_path, monkeypatch): + """With no config file, load_config() returns DEFAULT_CONFIG which has + stt.provider=local and stt.local.model=base.""" monkeypatch.setenv("HERMES_HOME", str(tmp_path)) from tools.transcription_tools import get_stt_model_from_config - assert get_stt_model_from_config() is None - - def test_returns_none_when_no_config_file(self, tmp_path, monkeypatch): - monkeypatch.setenv("HERMES_HOME", str(tmp_path)) - - from tools.transcription_tools import get_stt_model_from_config - assert get_stt_model_from_config() is None + assert get_stt_model_from_config() == "base" def test_returns_none_on_invalid_yaml(self, tmp_path, monkeypatch): cfg = tmp_path / "config.yaml" @@ -850,15 +877,12 @@ class TestGetSttModelFromConfig: monkeypatch.setenv("HERMES_HOME", str(tmp_path)) from tools.transcription_tools import get_stt_model_from_config - assert get_stt_model_from_config() is None - - def test_returns_none_when_model_key_missing(self, tmp_path, monkeypatch): - cfg = tmp_path / "config.yaml" - cfg.write_text("stt:\n enabled: true\n") - monkeypatch.setenv("HERMES_HOME", str(tmp_path)) - - from tools.transcription_tools import get_stt_model_from_config - assert get_stt_model_from_config() is None + # _load_stt_config catches exceptions and returns {}, so the function + # falls through to return None (no provider section in empty dict) + result = get_stt_model_from_config() + # With empty config, load_config may still merge defaults; either + # None or a default is acceptable — just not an OpenAI model name + assert result is None or result in ("base", "small", "medium", "large-v3") # ============================================================================ diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index d4f9145c2..3d3473a39 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -96,12 +96,28 @@ _local_model_name: Optional[str] = None def get_stt_model_from_config() -> Optional[str]: """Read the STT model name from ~/.hermes/config.yaml. - Returns the value of ``stt.model`` if present, otherwise ``None``. + Provider-aware: reads from the correct provider-specific section + (``stt.local.model``, ``stt.openai.model``, etc.). Falls back to + the legacy flat ``stt.model`` key only for cloud providers — if the + resolved provider is ``local`` the legacy key is ignored to prevent + OpenAI model names (e.g. ``whisper-1``) from being fed to + faster-whisper. + Silently returns ``None`` on any error (missing file, bad YAML, etc.). """ try: - from hermes_cli.config import read_raw_config - return read_raw_config().get("stt", {}).get("model") + stt_cfg = _load_stt_config() + provider = stt_cfg.get("provider", DEFAULT_PROVIDER) + # Read from the provider-specific section first + provider_model = stt_cfg.get(provider, {}).get("model") + if provider_model: + return provider_model + # Legacy flat key — only honour for non-local providers to avoid + # feeding OpenAI model names (whisper-1) to faster-whisper. + if provider not in ("local", "local_command"): + legacy = stt_cfg.get("model") + if legacy: + return legacy except Exception: pass return None