mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
refactor: extract get_stt_model_from_config helper to eliminate DRY violation
Duplicated YAML config parsing for stt.model existed in gateway/run.py and gateway/platforms/discord.py. Moved to a single helper in transcription_tools.py and added 5 tests covering all edge cases.
This commit is contained in:
parent
3260413cc7
commit
2c84979d77
4 changed files with 67 additions and 25 deletions
|
|
@ -880,18 +880,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
await asyncio.to_thread(VoiceReceiver.pcm_to_wav, pcm_data, wav_path)
|
||||
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
# Read STT model from config.yaml
|
||||
stt_model = None
|
||||
try:
|
||||
import yaml as _y
|
||||
from pathlib import Path as _P
|
||||
_cfg = _P(os.getenv("HERMES_HOME", _P.home() / ".hermes")) / "config.yaml"
|
||||
if _cfg.exists():
|
||||
with open(_cfg) as _f:
|
||||
stt_model = (_y.safe_load(_f) or {}).get("stt", {}).get("model")
|
||||
except Exception:
|
||||
pass
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
stt_model = get_stt_model_from_config()
|
||||
result = await asyncio.to_thread(transcribe_audio, wav_path, model=stt_model)
|
||||
|
||||
if not result.get("success"):
|
||||
|
|
|
|||
|
|
@ -3323,20 +3323,10 @@ class GatewayRunner:
|
|||
Returns:
|
||||
The enriched message string with transcriptions prepended.
|
||||
"""
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
from tools.transcription_tools import transcribe_audio, get_stt_model_from_config
|
||||
import asyncio
|
||||
|
||||
# Read STT model from config.yaml (same key the CLI uses)
|
||||
stt_model = None
|
||||
try:
|
||||
import yaml as _y
|
||||
_cfg = _hermes_home / "config.yaml"
|
||||
if _cfg.exists():
|
||||
with open(_cfg) as _f:
|
||||
_data = _y.safe_load(_f) or {}
|
||||
stt_model = _data.get("stt", {}).get("model")
|
||||
except Exception:
|
||||
pass
|
||||
stt_model = get_stt_model_from_config()
|
||||
|
||||
enriched_parts = []
|
||||
for path in audio_paths:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Tests for tools.transcription_tools -- provider resolution and model correction."""
|
||||
"""Tests for tools.transcription_tools -- provider resolution, model correction, config helper."""
|
||||
|
||||
import os
|
||||
import struct
|
||||
|
|
@ -197,3 +197,47 @@ class TestTranscribeAudioSuccess:
|
|||
result = transcribe_audio(sample_wav)
|
||||
|
||||
assert result["transcript"] == "hello world"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# get_stt_model_from_config
|
||||
# ============================================================================
|
||||
|
||||
class TestGetSttModelFromConfig:
|
||||
def test_returns_model_from_config(self, tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("stt:\n model: whisper-large-v3\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() == "whisper-large-v3"
|
||||
|
||||
def test_returns_none_when_no_stt_section(self, tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("tts:\n provider: edge\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
|
||||
def test_returns_none_when_no_config_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
|
||||
def test_returns_none_on_invalid_yaml(self, tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text(": : :\n bad yaml [[[")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
|
||||
def test_returns_none_when_model_key_missing(self, tmp_path, monkeypatch):
|
||||
cfg = tmp_path / "config.yaml"
|
||||
cfg.write_text("stt:\n enabled: true\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
|
|
|
|||
|
|
@ -41,6 +41,24 @@ GROQ_BASE_URL = os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1")
|
|||
OPENAI_BASE_URL = os.getenv("STT_OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
|
||||
|
||||
def get_stt_model_from_config() -> Optional[str]:
|
||||
"""Read the STT model name from ~/.hermes/config.yaml.
|
||||
|
||||
Returns the value of ``stt.model`` if present, otherwise ``None``.
|
||||
Silently returns ``None`` on any error (missing file, bad YAML, etc.).
|
||||
"""
|
||||
try:
|
||||
import yaml
|
||||
cfg_path = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
return data.get("stt", {}).get("model")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_stt_provider() -> Tuple[Optional[str], Optional[str], str]:
|
||||
"""Resolve which STT provider to use based on available API keys.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue