mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(stt): map cloud-only model names to valid local size for faster-whisper (#2544)
Cherry-picked from PR #2545 by @Mibayy.

The setup wizard could leave `stt.model: "whisper-1"` in config.yaml. When using the local faster-whisper provider, this crashed with "Invalid model size 'whisper-1'". As a result, voice messages were silently ignored.

`_normalize_local_model()` now detects cloud-only names (whisper-1, gpt-4o-transcribe, etc.) and maps them to the default local model with a warning. Valid local sizes (tiny, base, small, medium, large-v3) pass through unchanged.

- Renamed `_normalize_local_command_model` -> `_normalize_local_model` (backward-compat wrapper preserved)
- 6 new tests, including an integration test
- Added lowercase AUTHOR_MAP alias for @Mibayy

Closes #2544
This commit is contained in:
parent
0613f10def
commit
3273f301b7
3 changed files with 88 additions and 2 deletions
|
|
@ -82,6 +82,7 @@ AUTHOR_MAP = {
|
|||
"462836+jplew@users.noreply.github.com": "jplew",
|
||||
"nish3451@users.noreply.github.com": "nish3451",
|
||||
"Mibayy@users.noreply.github.com": "Mibayy",
|
||||
"mibayy@users.noreply.github.com": "Mibayy",
|
||||
"135070653+sgaofen@users.noreply.github.com": "sgaofen",
|
||||
"nocoo@users.noreply.github.com": "nocoo",
|
||||
"30841158+n-WN@users.noreply.github.com": "n-WN",
|
||||
|
|
|
|||
|
|
@ -245,3 +245,67 @@ class TestTranscribeAudio:
|
|||
result = transcribe_audio("/nonexistent/file.ogg")
|
||||
assert result["success"] is False
|
||||
assert "not found" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model name normalisation for local providers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNormalizeLocalModel:
    """Checks that _normalize_local_model() swaps cloud-only names for the local default."""

    def test_openai_model_name_maps_to_default(self):
        """The OpenAI-style name ``whisper-1`` resolves to the default local model."""
        from tools.transcription_tools import _normalize_local_model, DEFAULT_LOCAL_MODEL

        normalized = _normalize_local_model("whisper-1")
        assert normalized == DEFAULT_LOCAL_MODEL

    def test_groq_model_name_maps_to_default(self):
        """The Groq-style name ``whisper-large-v3-turbo`` resolves to the default local model."""
        from tools.transcription_tools import _normalize_local_model, DEFAULT_LOCAL_MODEL

        normalized = _normalize_local_model("whisper-large-v3-turbo")
        assert normalized == DEFAULT_LOCAL_MODEL

    def test_valid_local_model_preserved(self):
        """Names that are already local sizes come back unchanged."""
        from tools.transcription_tools import _normalize_local_model

        local_sizes = ("tiny", "base", "small", "medium", "large-v3")
        for local_size in local_sizes:
            assert _normalize_local_model(local_size) == local_size

    def test_none_maps_to_default(self):
        """A missing model name (``None``) resolves to the default local model."""
        from tools.transcription_tools import _normalize_local_model, DEFAULT_LOCAL_MODEL

        normalized = _normalize_local_model(None)
        assert normalized == DEFAULT_LOCAL_MODEL

    def test_warning_emitted_for_cloud_model(self, caplog):
        """Normalising a cloud-only name logs a warning that mentions that name."""
        import logging
        from tools.transcription_tools import _normalize_local_model

        with caplog.at_level(logging.WARNING, logger="tools.transcription_tools"):
            _normalize_local_model("whisper-1")

        logged_messages = [record.message for record in caplog.records]
        assert any("whisper-1" in text for text in logged_messages)

    def test_local_transcribe_normalises_model(self):
        """transcribe_audio with local provider must not pass 'whisper-1' to WhisperModel."""
        import os
        import tempfile
        from unittest.mock import MagicMock, patch

        # A real (tiny) file on disk so transcribe_audio gets past its
        # file-existence check; removed again in the ``finally`` below.
        with tempfile.NamedTemporaryFile(suffix=".ogg", delete=False) as tmp:
            tmp.write(b"x")
            audio_path = tmp.name

        try:
            # Stand-in for a loaded model: transcribe() yields no segments
            # plus an info object carrying language and duration.
            no_segments = iter([])
            fake_info = MagicMock(language="en", duration=1.0)
            fake_model = MagicMock()
            fake_model.transcribe.return_value = (no_segments, fake_info)

            # Config as the setup wizard could leave it: local provider with
            # a cloud-only model name.
            stt_config = {
                "enabled": True,
                "provider": "local",
                "local": {"model": "whisper-1"},
            }

            with patch("tools.transcription_tools._HAS_FASTER_WHISPER", True), \
                 patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \
                 patch("tools.transcription_tools._local_model", None), \
                 patch("tools.transcription_tools._local_model_name", None), \
                 patch("faster_whisper.WhisperModel", return_value=fake_model) as whisper_cls:
                from tools.transcription_tools import transcribe_audio

                transcribe_audio(audio_path)

                # WhisperModel must NOT have been called with "whisper-1"
                recorded_call = whisper_cls.call_args
                assert recorded_call is not None
                requested_model = recorded_call[0][0]
                assert requested_model != "whisper-1", (
                    "WhisperModel was called with the cloud-only name 'whisper-1'"
                )
        finally:
            os.unlink(audio_path)
|
||||
|
|
|
|||
|
|
@ -154,12 +154,31 @@ def _has_local_command() -> bool:
|
|||
return _get_local_command_template() is not None
|
||||
|
||||
|
||||
def _normalize_local_command_model(model_name: Optional[str]) -> str:
|
||||
def _normalize_local_model(model_name: Optional[str]) -> str:
    """Pick the model name to hand to the local faster-whisper provider.

    Cloud providers use model names (for example OpenAI's ``whisper-1``) that
    are not valid for faster-whisper, which expects sizes such as ``tiny``,
    ``base``, ``small``, ``medium`` or ``large-v*``.

    Args:
        model_name: The configured or requested model name; may be ``None``
            or empty.

    Returns:
        ``DEFAULT_LOCAL_MODEL`` when ``model_name`` is falsy or is listed in
        ``OPENAI_MODELS`` / ``GROQ_MODELS``; otherwise ``model_name`` itself,
        unchanged. Names outside those two collections are not validated here.
    """
    # Nothing configured: use the default silently, there is nothing to warn about.
    if not model_name:
        return DEFAULT_LOCAL_MODEL

    # Anything that is not a known cloud-only name is passed through as-is.
    if model_name not in OPENAI_MODELS and model_name not in GROQ_MODELS:
        return model_name

    # A cloud-only name was configured for the local provider: tell the user
    # why their setting is being overridden, then fall back.
    logger.warning(
        "STT model '%s' is a cloud-only name and cannot be used with the local "
        "provider. Falling back to '%s'. Set stt.local.model to a valid "
        "faster-whisper size (tiny, base, small, medium, large-v3).",
        model_name,
        DEFAULT_LOCAL_MODEL,
    )
    return DEFAULT_LOCAL_MODEL
|
||||
|
||||
|
||||
def _normalize_local_command_model(model_name: Optional[str]) -> str:
    """Backward-compatible wrapper that delegates to ``_normalize_local_model``.

    Args:
        model_name: The configured or requested model name; may be ``None``.

    Returns:
        Whatever ``_normalize_local_model(model_name)`` returns.
    """
    normalized = _normalize_local_model(model_name)
    return normalized
|
||||
|
||||
|
||||
def _get_provider(stt_config: dict) -> str:
|
||||
"""Determine which STT provider to use.
|
||||
|
||||
|
|
@ -596,7 +615,9 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
|
||||
if provider == "local":
|
||||
local_cfg = stt_config.get("local", {})
|
||||
model_name = model or local_cfg.get("model", DEFAULT_LOCAL_MODEL)
|
||||
model_name = _normalize_local_model(
|
||||
model or local_cfg.get("model", DEFAULT_LOCAL_MODEL)
|
||||
)
|
||||
return _transcribe_local(file_path, model_name)
|
||||
|
||||
if provider == "local_command":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue