diff --git a/agent/transcription_provider.py b/agent/transcription_provider.py new file mode 100644 index 00000000000..2586b8cc43a --- /dev/null +++ b/agent/transcription_provider.py @@ -0,0 +1,193 @@ +""" +Transcription Provider ABC +========================== + +Defines the pluggable-backend interface for speech-to-text. Providers +register instances via +:meth:`PluginContext.register_transcription_provider`; the active one +(selected via ``stt.provider`` in ``config.yaml``) services every +:func:`tools.transcription_tools.transcribe_audio` call **when the +configured name is neither a built-in (``local``, ``local_command``, +``groq``, ``openai``, ``mistral``, ``xai``) nor disabled**. + +Two coexisting STT extension surfaces — in resolution order: + +1. **Built-in providers** (``BUILTIN_STT_PROVIDERS`` in + :mod:`tools.transcription_tools`) — native Python implementations + for the 6 backends shipped today (faster-whisper, local_command, + Groq, OpenAI, Mistral, xAI). **Always win** — plugins cannot + shadow them. The single-env-var shell escape hatch + ``HERMES_LOCAL_STT_COMMAND`` is preserved via the built-in + ``local_command`` path. +2. **Plugin-registered providers** (this ABC). For new STT backends — + OpenRouter, SenseAudio, Gemini-STT, custom proprietary engines — + that need a Python implementation without modifying + ``tools/transcription_tools.py``. + +Built-ins-always-win is enforced at registration time +(:func:`agent.transcription_registry.register_provider` rejects names +in ``BUILTIN_STT_PROVIDERS`` with a warning) AND at dispatch time +(:func:`tools.transcription_tools._dispatch_to_plugin_provider` +re-checks defensively). + +Providers live in ``/plugins/transcription//`` (built-in +plugins, none shipped today) or +``~/.hermes/plugins/transcription//`` (user-installed). + +Response contract +----------------- +:meth:`TranscriptionProvider.transcribe` returns a dict with keys:: + + success bool + transcript str transcribed text (empty when success=False) + provider str provider name (for diagnostics) + error str only when success=False +""" + +from __future__ import annotations + +import abc +import logging +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# ABC +# --------------------------------------------------------------------------- + + +class TranscriptionProvider(abc.ABC): + """Abstract base class for a speech-to-text backend. + + Subclasses must implement :attr:`name` and :meth:`transcribe`. + Everything else has sane defaults — override only what your provider + needs. + """ + + @property + @abc.abstractmethod + def name(self) -> str: + """Stable short identifier used in ``stt.provider`` config. + + Lowercase, no spaces. Examples: ``openrouter``, ``sensaudio``, + ``gemini``, ``deepgram``. Names that collide with a built-in STT + provider (``local``, ``local_command``, ``groq``, ``openai``, + ``mistral``, ``xai``) are rejected at registration time. + """ + + @property + def display_name(self) -> str: + """Human-readable label shown in ``hermes tools``. + + Defaults to ``name.title()``. + """ + return self.name.title() + + def is_available(self) -> bool: + """Return True when this provider can service calls. + + Typically checks for a required API key + that the SDK is + importable. Default: True (providers with no external + dependencies are always available). + + Must NOT raise — used by the picker and ``hermes setup`` for + availability displays and should fail gracefully. + """ + return True + + def list_models(self) -> List[Dict[str, Any]]: + """Return model catalog entries. + + Each entry:: + + { + "id": "whisper-large-v3-turbo", # required + "display": "Whisper Large v3 Turbo", # optional + "languages": ["en", "es", "fr"], # optional + "max_audio_seconds": 1500, # optional + } + + Default: empty list (provider has a single fixed model or + doesn't expose model selection). + """ + return [] + + def default_model(self) -> Optional[str]: + """Return the default model id, or None if not applicable.""" + models = self.list_models() + if models: + return models[0].get("id") + return None + + def get_setup_schema(self) -> Dict[str, Any]: + """Return provider metadata for the ``hermes tools`` picker. + + Used by ``tools_config.py`` to inject this provider as a row in + the Speech-to-Text provider list. Shape:: + + { + "name": "OpenRouter STT", # picker label + "badge": "paid", # optional short tag + "tag": "Whisper via OpenRouter API", # optional subtitle + "env_vars": [ # keys to prompt for + {"key": "OPENROUTER_API_KEY", + "prompt": "OpenRouter API key", + "url": "https://openrouter.ai/keys"}, + ], + } + + Default: minimal entry derived from ``display_name`` with no + env vars. Override to expose API key prompts and custom badges. + """ + return { + "name": self.display_name, + "badge": "", + "tag": "", + "env_vars": [], + } + + @abc.abstractmethod + def transcribe( + self, + file_path: str, + *, + model: Optional[str] = None, + language: Optional[str] = None, + **extra: Any, + ) -> Dict[str, Any]: + """Transcribe the audio file at ``file_path``. + + Returns a dict with the standard envelope:: + + { + "success": True, + "transcript": "the transcribed text", + "provider": "", + } + + or on failure:: + + { + "success": False, + "transcript": "", + "error": "human-readable error message", + "provider": "", + } + + Implementations should NOT raise — convert exceptions to the + error envelope so the dispatcher can deliver a consistent shape + to the gateway/CLI caller. + + Args: + file_path: Absolute path to the audio file. The dispatcher + has already validated existence + size before calling. + model: Model identifier from :meth:`list_models`, or None + to use :meth:`default_model`. + language: Optional BCP-47 language hint (e.g. ``"en"``, + ``"ja"``) — providers without language hints should + ignore this argument. + **extra: Forward-compat parameters future schema versions + may expose. Implementations should ignore unknown keys. + """ diff --git a/agent/transcription_registry.py b/agent/transcription_registry.py new file mode 100644 index 00000000000..d84f93b19e4 --- /dev/null +++ b/agent/transcription_registry.py @@ -0,0 +1,122 @@ +""" +Transcription Provider Registry +================================ + +Central map of registered STT providers. Populated by plugins at +import-time via :meth:`PluginContext.register_transcription_provider`; +consumed by :mod:`tools.transcription_tools` to dispatch +:func:`transcribe_audio` calls to the active plugin backend **when** +the configured ``stt.provider`` name is not a built-in. + +Built-ins-always-win +-------------------- +Plugin names that collide with a built-in STT provider (``local``, +``local_command``, ``groq``, ``openai``, ``mistral``, ``xai``) are +rejected at registration with a warning. This invariant is also +re-checked at dispatch time in +:func:`tools.transcription_tools._dispatch_to_plugin_provider`. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Dict, List, Optional + +from agent.transcription_provider import TranscriptionProvider + +logger = logging.getLogger(__name__) + + +# Names reserved for native built-in STT handlers. Plugins cannot +# register a name in this set — the registration call is rejected with +# a warning. **Kept in sync with ``BUILTIN_STT_PROVIDERS`` in +# :mod:`tools.transcription_tools`** — a regression test in +# ``tests/agent/test_transcription_registry.py::TestBuiltinSync`` +# fails if the two lists drift. Importing from +# ``tools.transcription_tools`` directly would create a circular +# dependency (``tools.transcription_tools`` imports +# ``agent.transcription_registry`` for dispatch). +_BUILTIN_NAMES = frozenset({ + "local", + "local_command", + "groq", + "openai", + "mistral", + "xai", +}) + + +_providers: Dict[str, TranscriptionProvider] = {} +_lock = threading.Lock() + + +def register_provider(provider: TranscriptionProvider) -> None: + """Register a transcription provider. + + Rejects: + + - Non-:class:`TranscriptionProvider` instances (raises :class:`TypeError`). + - Empty/whitespace ``.name`` (raises :class:`ValueError`). + - Names colliding with a built-in (logs a warning, silently + ignores — built-ins-always-win invariant). + + Re-registration (same ``name``) overwrites the previous entry and + logs a debug message — makes hot-reload scenarios (tests, dev + loops) behave predictably. + """ + if not isinstance(provider, TranscriptionProvider): + raise TypeError( + f"register_provider() expects a TranscriptionProvider instance, " + f"got {type(provider).__name__}" + ) + name = provider.name + if not isinstance(name, str) or not name.strip(): + raise ValueError("Transcription provider .name must be a non-empty string") + key = name.strip().lower() + if key in _BUILTIN_NAMES: + logger.warning( + "Transcription provider '%s' shadows a built-in name; registration " + "ignored. Built-in STT providers (%s) always win — pick a different " + "name.", + key, ", ".join(sorted(_BUILTIN_NAMES)), + ) + return + with _lock: + existing = _providers.get(key) + _providers[key] = provider + if existing is not None: + logger.debug( + "Transcription provider '%s' re-registered (was %r)", + key, type(existing).__name__, + ) + else: + logger.debug( + "Registered transcription provider '%s' (%s)", + key, type(provider).__name__, + ) + + +def list_providers() -> List[TranscriptionProvider]: + """Return all registered providers, sorted by name.""" + with _lock: + items = list(_providers.values()) + return sorted(items, key=lambda p: p.name) + + +def get_provider(name: str) -> Optional[TranscriptionProvider]: + """Return the provider registered under *name*, or None. + + Name matching is case-insensitive and whitespace-tolerant — mirrors + how ``tools.transcription_tools._get_provider`` normalizes the + configured ``stt.provider`` value. + """ + if not isinstance(name, str): + return None + return _providers.get(name.strip().lower()) + + +def _reset_for_tests() -> None: + """Clear the registry. **Test-only.**""" + with _lock: + _providers.clear() diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index 2218172aa58..bd6367a44c8 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -678,6 +678,50 @@ class PluginContext: self.manifest.name, provider.name, ) + # -- transcription (STT) provider registration --------------------------- + + def register_transcription_provider(self, provider) -> None: + """Register a speech-to-text backend. + + ``provider`` must be an instance of + :class:`agent.transcription_provider.TranscriptionProvider`. + The ``provider.name`` attribute is what ``stt.provider`` in + ``config.yaml`` matches against when routing + :func:`tools.transcription_tools.transcribe_audio` calls — + **but only when**: + + 1. ``provider.name`` is NOT a built-in STT provider name + (``local``, ``local_command``, ``groq``, ``openai``, + ``mistral``, ``xai``). Built-ins always win — the registry + rejects shadowing names with a warning. + 2. There is NO ``stt.providers.: type: command`` entry + with the same name. Command-providers win on name + collision because config is more local than plugin install + — same precedence rule as TTS. + + Coexists with the in-tree dispatcher and the STT + command-provider registry rather than replacing them. The 6 + built-in STT backends keep their native implementations in + ``tools/transcription_tools.py``; this hook is for *new* Python + engines (OpenRouter, SenseAudio, Gemini-STT, custom proprietary + backends). + """ + from agent.transcription_provider import TranscriptionProvider + from agent.transcription_registry import register_provider as _register_stt_provider + + if not isinstance(provider, TranscriptionProvider): + logger.warning( + "Plugin '%s' tried to register a transcription provider that " + "does not inherit from TranscriptionProvider. Ignoring.", + self.manifest.name, + ) + return + _register_stt_provider(provider) + logger.info( + "Plugin '%s' registered transcription provider: %s", + self.manifest.name, provider.name, + ) + # -- platform adapter registration --------------------------------------- def register_platform( diff --git a/tests/agent/test_transcription_registry.py b/tests/agent/test_transcription_registry.py new file mode 100644 index 00000000000..9c3b93f0d2c --- /dev/null +++ b/tests/agent/test_transcription_registry.py @@ -0,0 +1,243 @@ +"""Tests for agent/transcription_registry.py and agent/transcription_provider.py. + +Covers: +- Registration happy path +- Registration rejection: non-TranscriptionProvider type +- Registration rejection: empty/whitespace name +- Built-in name shadowing: warning + silent ignore (no exception) +- Re-registration: overwrites + logs at debug +- Case + whitespace insensitivity on lookup +- ABC contract: default implementations work +- ABC contract: transcribe() must be implemented +- Sync invariant: registry built-ins match tools/transcription_tools.py +""" + +from __future__ import annotations + +import logging +from typing import Any, Optional + +import pytest + +from agent import transcription_registry +from agent.transcription_provider import TranscriptionProvider + + +class _FakeProvider(TranscriptionProvider): + def __init__( + self, + name: str = "fake", + display: Optional[str] = None, + available: bool = True, + transcribe_impl: Optional[Any] = None, + ): + self._name = name + self._display = display + self._available = available + self._transcribe_impl = transcribe_impl + + @property + def name(self) -> str: + return self._name + + @property + def display_name(self) -> str: + return self._display if self._display is not None else super().display_name + + def is_available(self) -> bool: + return self._available + + def transcribe(self, file_path: str, **kw): + if self._transcribe_impl is not None: + return self._transcribe_impl(file_path, **kw) + return {"success": True, "transcript": f"fake({file_path})", "provider": self._name} + + +@pytest.fixture(autouse=True) +def _reset_registry(): + transcription_registry._reset_for_tests() + yield + transcription_registry._reset_for_tests() + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + + +class TestRegistration: + def test_happy_path(self): + p = _FakeProvider(name="openrouter") + transcription_registry.register_provider(p) + assert transcription_registry.get_provider("openrouter") is p + assert [r.name for r in transcription_registry.list_providers()] == ["openrouter"] + + def test_rejects_non_provider_type(self): + with pytest.raises(TypeError, match="expects a TranscriptionProvider instance"): + transcription_registry.register_provider("not a provider") # type: ignore[arg-type] + assert transcription_registry.list_providers() == [] + + def test_rejects_empty_name(self): + p = _FakeProvider(name="") + with pytest.raises(ValueError, match="non-empty string"): + transcription_registry.register_provider(p) + assert transcription_registry.list_providers() == [] + + def test_rejects_whitespace_name(self): + p = _FakeProvider(name=" ") + with pytest.raises(ValueError, match="non-empty string"): + transcription_registry.register_provider(p) + assert transcription_registry.list_providers() == [] + + @pytest.mark.parametrize( + "builtin", + ["local", "local_command", "groq", "openai", "mistral", "xai"], + ) + def test_rejects_builtin_shadow_with_warning(self, builtin, caplog): + p = _FakeProvider(name=builtin) + with caplog.at_level(logging.WARNING, logger="agent.transcription_registry"): + transcription_registry.register_provider(p) + assert "shadows a built-in name" in caplog.text + assert builtin in caplog.text + assert transcription_registry.get_provider(builtin) is None + assert transcription_registry.list_providers() == [] + + def test_builtin_shadow_case_insensitive(self, caplog): + for variant in ("OPENAI", "OpenAi", " openai ", "oPeNaI"): + transcription_registry._reset_for_tests() + with caplog.at_level(logging.WARNING, logger="agent.transcription_registry"): + transcription_registry.register_provider(_FakeProvider(name=variant)) + assert transcription_registry.list_providers() == [], ( + f"variant {variant!r} should have been rejected as a built-in shadow" + ) + + def test_reregistration_overwrites(self, caplog): + p1 = _FakeProvider(name="openrouter") + p2 = _FakeProvider(name="openrouter") + transcription_registry.register_provider(p1) + with caplog.at_level(logging.DEBUG, logger="agent.transcription_registry"): + transcription_registry.register_provider(p2) + assert transcription_registry.get_provider("openrouter") is p2 + assert "re-registered" in caplog.text + + +# --------------------------------------------------------------------------- +# Lookup +# --------------------------------------------------------------------------- + + +class TestLookup: + def test_get_provider_missing_returns_none(self): + assert transcription_registry.get_provider("nonexistent") is None + + def test_get_provider_non_string_returns_none(self): + assert transcription_registry.get_provider(None) is None # type: ignore[arg-type] + assert transcription_registry.get_provider(123) is None # type: ignore[arg-type] + + def test_get_provider_case_insensitive(self): + p = _FakeProvider(name="openrouter") + transcription_registry.register_provider(p) + assert transcription_registry.get_provider("OPENROUTER") is p + assert transcription_registry.get_provider("OpenRouter") is p + + def test_get_provider_whitespace_tolerant(self): + p = _FakeProvider(name="openrouter") + transcription_registry.register_provider(p) + assert transcription_registry.get_provider(" openrouter ") is p + + def test_list_providers_sorted(self): + transcription_registry.register_provider(_FakeProvider(name="zylo")) + transcription_registry.register_provider(_FakeProvider(name="alpha")) + transcription_registry.register_provider(_FakeProvider(name="middle")) + names = [p.name for p in transcription_registry.list_providers()] + assert names == ["alpha", "middle", "zylo"] + + +# --------------------------------------------------------------------------- +# ABC contract +# --------------------------------------------------------------------------- + + +class TestABCContract: + def test_must_implement_transcribe(self): + class Incomplete(TranscriptionProvider): + @property + def name(self) -> str: + return "incomplete" + # transcribe NOT implemented + + with pytest.raises(TypeError, match="abstract"): + Incomplete() # type: ignore[abstract] + + def test_must_implement_name(self): + class Incomplete(TranscriptionProvider): + def transcribe(self, file_path, **kw): + return {"success": True, "transcript": "", "provider": "incomplete"} + # name NOT implemented + + with pytest.raises(TypeError, match="abstract"): + Incomplete() # type: ignore[abstract] + + def test_display_name_defaults_to_title(self): + p = _FakeProvider(name="openrouter") + assert p.display_name == "Openrouter" + + def test_display_name_override_respected(self): + p = _FakeProvider(name="openrouter", display="OpenRouter STT") + assert p.display_name == "OpenRouter STT" + + def test_is_available_default_true(self): + p = _FakeProvider(name="openrouter") + assert p.is_available() is True + + def test_list_models_default_empty(self): + p = _FakeProvider(name="openrouter") + assert p.list_models() == [] + + def test_default_model_none_when_no_models(self): + p = _FakeProvider(name="openrouter") + assert p.default_model() is None + + def test_default_model_first_listed(self): + class WithModels(_FakeProvider): + def list_models(self): + return [{"id": "whisper-large-v3-turbo"}, {"id": "whisper-large-v3"}] + + p = WithModels(name="openrouter") + assert p.default_model() == "whisper-large-v3-turbo" + + def test_get_setup_schema_default_minimal(self): + p = _FakeProvider(name="openrouter") + schema = p.get_setup_schema() + assert schema["name"] == "Openrouter" + assert schema["env_vars"] == [] + + +# --------------------------------------------------------------------------- +# Sync invariant: registry built-ins vs dispatcher built-ins +# --------------------------------------------------------------------------- + + +class TestBuiltinSync: + """``_BUILTIN_NAMES`` in agent/transcription_registry.py is duplicated + from ``BUILTIN_STT_PROVIDERS`` in tools/transcription_tools.py + (importing directly would create a circular dependency). This test + fails loudly if the two lists drift — a new built-in added to + transcription_tools.py MUST also be added to + transcription_registry.py's ``_BUILTIN_NAMES`` or the registry will + accept a name the dispatcher will silently route to the wrong + handler. + """ + + def test_registry_builtins_match_dispatcher_builtins(self): + from tools.transcription_tools import BUILTIN_STT_PROVIDERS + + assert transcription_registry._BUILTIN_NAMES == BUILTIN_STT_PROVIDERS, ( + "agent.transcription_registry._BUILTIN_NAMES and " + "tools.transcription_tools.BUILTIN_STT_PROVIDERS have drifted!\n" + f" Registry only: {sorted(transcription_registry._BUILTIN_NAMES - BUILTIN_STT_PROVIDERS)}\n" + f" Dispatcher only: {sorted(BUILTIN_STT_PROVIDERS - transcription_registry._BUILTIN_NAMES)}\n" + "Add the missing names to whichever list is incomplete. " + "These two lists exist as a circular-import workaround and " + "MUST be kept in sync manually." + ) diff --git a/tests/hermes_cli/test_plugins_transcription_registration.py b/tests/hermes_cli/test_plugins_transcription_registration.py new file mode 100644 index 00000000000..5f6ab4a2f78 --- /dev/null +++ b/tests/hermes_cli/test_plugins_transcription_registration.py @@ -0,0 +1,148 @@ +"""Tests for PluginContext.register_transcription_provider(). + +Exercises the plugin context hook end-to-end: drops a fake plugin into +``$HERMES_HOME/plugins/``, runs ``PluginManager().discover_and_load()``, +and asserts the registration result. + +Mirrors the shape of ``test_plugins_tts_registration.py`` (companion +TTS hook from issue #30398). +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Dict + +import yaml + + +def _write_plugin( + root: Path, + name: str, + *, + manifest_extra: Dict[str, Any] | None = None, + register_body: str = "pass", +) -> Path: + plugin_dir = root / name + plugin_dir.mkdir(parents=True, exist_ok=True) + manifest = { + "name": name, + "version": "0.1.0", + "description": f"Test plugin {name}", + } + if manifest_extra: + manifest.update(manifest_extra) + (plugin_dir / "plugin.yaml").write_text(yaml.dump(manifest)) + (plugin_dir / "__init__.py").write_text( + f"def register(ctx):\n {register_body}\n" + ) + return plugin_dir + + +def _enable(hermes_home: Path, name: str) -> None: + cfg_path = hermes_home / "config.yaml" + cfg: dict = {} + if cfg_path.exists(): + try: + cfg = yaml.safe_load(cfg_path.read_text()) or {} + except Exception: + cfg = {} + plugins_cfg = cfg.setdefault("plugins", {}) + enabled = plugins_cfg.setdefault("enabled", []) + if isinstance(enabled, list) and name not in enabled: + enabled.append(name) + cfg_path.write_text(yaml.safe_dump(cfg)) + + +class TestRegisterTranscriptionProvider: + def test_accepts_valid_provider(self): + from hermes_cli.plugins import PluginManager + + from agent import transcription_registry + transcription_registry._reset_for_tests() + + hermes_home = Path(os.environ["HERMES_HOME"]) + _write_plugin( + hermes_home / "plugins", + "my-stt-plugin", + register_body=( + "from agent.transcription_provider import TranscriptionProvider\n" + " class P(TranscriptionProvider):\n" + " @property\n" + " def name(self): return 'fake-stt'\n" + " def transcribe(self, file_path, **kw):\n" + " return {'success': True, 'transcript': 'hi', 'provider': 'fake-stt'}\n" + " ctx.register_transcription_provider(P())" + ), + ) + _enable(hermes_home, "my-stt-plugin") + + mgr = PluginManager() + mgr.discover_and_load() + + assert mgr._plugins["my-stt-plugin"].enabled is True, ( + f"Plugin failed to load: {mgr._plugins['my-stt-plugin'].error}" + ) + assert transcription_registry.get_provider("fake-stt") is not None + + transcription_registry._reset_for_tests() + + def test_rejects_non_provider(self, caplog): + from hermes_cli.plugins import PluginManager + + from agent import transcription_registry + transcription_registry._reset_for_tests() + + hermes_home = Path(os.environ["HERMES_HOME"]) + _write_plugin( + hermes_home / "plugins", + "bad-stt-plugin", + register_body="ctx.register_transcription_provider('not a provider')", + ) + _enable(hermes_home, "bad-stt-plugin") + + with caplog.at_level("WARNING"): + mgr = PluginManager() + mgr.discover_and_load() + + assert mgr._plugins["bad-stt-plugin"].enabled is True + assert transcription_registry.get_provider("not a provider") is None + assert transcription_registry.list_providers() == [] + assert "does not inherit from TranscriptionProvider" in caplog.text + + transcription_registry._reset_for_tests() + + def test_rejects_builtin_shadow(self, caplog): + from hermes_cli.plugins import PluginManager + + from agent import transcription_registry + transcription_registry._reset_for_tests() + + hermes_home = Path(os.environ["HERMES_HOME"]) + _write_plugin( + hermes_home / "plugins", + "shadow-stt-plugin", + register_body=( + "from agent.transcription_provider import TranscriptionProvider\n" + " class P(TranscriptionProvider):\n" + " @property\n" + " def name(self): return 'openai'\n" + " def transcribe(self, file_path, **kw):\n" + " return {'success': True, 'transcript': 'hi'}\n" + " ctx.register_transcription_provider(P())" + ), + ) + _enable(hermes_home, "shadow-stt-plugin") + + with caplog.at_level("WARNING"): + mgr = PluginManager() + mgr.discover_and_load() + + # Plugin still loaded normally — built-in shadowing is a warning, + # not an exception. The registry rejects the entry though. + assert mgr._plugins["shadow-stt-plugin"].enabled is True + assert transcription_registry.get_provider("openai") is None + assert "shadows a built-in name" in caplog.text + + transcription_registry._reset_for_tests() diff --git a/tests/plugins/transcription/__init__.py b/tests/plugins/transcription/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/plugins/transcription/check_parity_vs_main.py b/tests/plugins/transcription/check_parity_vs_main.py new file mode 100644 index 00000000000..2a0ac85dc8d --- /dev/null +++ b/tests/plugins/transcription/check_parity_vs_main.py @@ -0,0 +1,344 @@ +"""Behavior-parity check for the STT plugin hook (follow-up to #30398). + +Spawns one subprocess per (version, scenario) cell — pinned to either +``origin/main`` (no plugin hook; ``stt.provider: openrouter`` falls +through to the "No STT provider available" error path) or this PR's +worktree (plugin hook present; same config routes through the plugin +registry when a plugin is registered). + +Each subprocess clears all STT-related env vars + writes a +``config.yaml``, then asks the dispatcher how it would route a +``transcribe_audio`` call. The emitted shape tuple is:: + + {dispatch_kind, provider_name, success} + +Where ``dispatch_kind`` ∈ +``{"builtin_local", "builtin_groq", "builtin_openai", ..., +"plugin", "plugin_unavailable", "no_provider_error", "stt_disabled"}``. + +Acceptable diffs: +- ``no_provider_error → plugin`` for the ``plugin-installed`` scenario. +- ``no_provider_error → plugin_unavailable`` for the + ``plugin-installed-unavailable`` scenario (PR returns the cleaner + unavailability envelope instead of the generic auto-detect error). + +Run from the PR worktree:: + + python tests/plugins/transcription/check_parity_vs_main.py +""" + +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[3] + + +def _resolve_main_dir() -> Path: + candidate = REPO_ROOT.parent.parent + if (candidate / "tools" / "transcription_tools.py").exists() and candidate != REPO_ROOT: + return candidate + sibling = REPO_ROOT.parent / "hermes-agent-main" + if (sibling / "tools" / "transcription_tools.py").exists(): + return sibling + return REPO_ROOT + + +MAIN_DIR = _resolve_main_dir() +PR_DIR = REPO_ROOT +assert (PR_DIR / "tools" / "transcription_tools.py").exists(), ( + f"PR_DIR={PR_DIR} doesn't look like a hermes-agent checkout" +) + + +SUBPROCESS_SCRIPT = r""" +import json, os, sys, tempfile +sys.path.insert(0, sys.argv[1]) + +# Isolated HERMES_HOME so the config write is hermetic. +home = tempfile.mkdtemp() +os.environ["HERMES_HOME"] = home + +# Clear STT-related env so dispatch decisions are config-driven. +for k in ( + "GROQ_API_KEY", "OPENAI_API_KEY", "VOICE_TOOLS_OPENAI_KEY", + "MISTRAL_API_KEY", "XAI_API_KEY", + "HERMES_LOCAL_STT_COMMAND", +): + os.environ.pop(k, None) + +scenario_env = json.loads(sys.argv[2]) +os.environ.update(scenario_env) + +config_yaml = sys.argv[3] +plugin_register = sys.argv[4] # "yes" to register a fake plugin + +config_path = os.path.join(home, "config.yaml") +with open(config_path, "w") as f: + f.write(config_yaml) + +# Fresh import — must not have anything cached from prior runs. +for name in list(sys.modules): + if (name.startswith("tools.") + or name.startswith("agent.") + or name.startswith("plugins.") + or name.startswith("hermes_cli.")): + sys.modules.pop(name, None) + +# Try importing transcription_registry — only exists on PR side. +have_plugin_hook = False +try: + from agent import transcription_registry + from agent.transcription_provider import TranscriptionProvider + have_plugin_hook = True + + if plugin_register == "yes": + class _FakeProvider(TranscriptionProvider): + @property + def name(self): return "openrouter" + def transcribe(self, file_path, **kw): + return {"success": True, "transcript": "plugin transcript", "provider": "openrouter"} + + transcription_registry._reset_for_tests() + transcription_registry.register_provider(_FakeProvider()) + elif plugin_register == "unavailable": + class _UnavailablePlugin(TranscriptionProvider): + @property + def name(self): return "openrouter" + def is_available(self): return False + def transcribe(self, file_path, **kw): + return {"success": True, "transcript": "should not run"} + + transcription_registry._reset_for_tests() + transcription_registry.register_provider(_UnavailablePlugin()) +except ImportError: + pass + +import tools.transcription_tools as tt + +# Use a real (but empty) audio file so _validate_audio_file passes. +audio_path = os.path.join(home, "audio.ogg") +with open(audio_path, "wb") as f: + # Minimal-ish OGG-shaped bytes so the size check passes. + f.write(b"OggS" + b"\x00" * 1024) + +# Patch _transcribe_* so the test doesn't actually try cloud APIs. +# We're testing dispatch, not the underlying transcription. +def _stub(file_path, model_name=None): + return {"success": True, "transcript": "stub from " + sys._getframe().f_code.co_name.replace("_stub_", ""), + "provider": sys._getframe().f_code.co_name.replace("_stub_", "")} + +# Stub each built-in to a marker so we can identify the branch. +class _Stub: + def __init__(self, name): + self.name = name + def __call__(self, file_path, model_name=None): + return {"success": True, "transcript": "stub", "provider": self.name} + +tt._transcribe_local = _Stub("local") +tt._transcribe_local_command = _Stub("local_command") +tt._transcribe_groq = _Stub("groq") +tt._transcribe_openai = _Stub("openai") +tt._transcribe_mistral = _Stub("mistral") +tt._transcribe_xai = _Stub("xai") + +# Force _get_provider to honor the explicit config since we don't have +# real creds. The provider-resolution gates check _HAS_OPENAI / +# _HAS_FASTER_WHISPER which we can't easily set, so we just patch +# _get_provider to return whatever the config says. +stt_cfg = tt._load_stt_config() +explicit = stt_cfg.get("provider") +if explicit: + # Bypass the gating for test purposes — _get_provider would + # otherwise return "none" when the dependency isn't installed. + original_get = tt._get_provider + def _patched(cfg): + if not tt.is_stt_enabled(cfg): + return "none" + return cfg.get("provider", "none") + tt._get_provider = _patched + +try: + result = tt.transcribe_audio(audio_path) +except Exception as exc: + shape = {"dispatch_kind": "exception", "provider_name": None, "success": False, + "error_text": repr(exc)} + print(json.dumps(shape)) + sys.exit(0) + +dispatch_kind = "unknown" +provider_name = result.get("provider") if isinstance(result, dict) else None +success = result.get("success", False) if isinstance(result, dict) else False +error_text = result.get("error", "") if isinstance(result, dict) else "" + +if not success and "STT is disabled" in error_text: + dispatch_kind = "stt_disabled" +elif not success and "is not available" in error_text: + dispatch_kind = "plugin_unavailable" +elif not success and "No STT provider" in error_text: + dispatch_kind = "no_provider_error" +elif provider_name in ("local", "local_command", "groq", "openai", "mistral", "xai"): + dispatch_kind = "builtin_" + provider_name +elif success and provider_name and provider_name not in ("local", "local_command", "groq", "openai", "mistral", "xai"): + dispatch_kind = "plugin" +else: + dispatch_kind = "other" + +shape = { + "dispatch_kind": dispatch_kind, + "provider_name": provider_name, + "success": success, +} +print(json.dumps(shape)) +""" + + +SCENARIOS: list[tuple[str, str, dict[str, str], str]] = [ + # (label, config.yaml body, scenario_env, plugin_register) + ("stt-disabled", "stt:\n enabled: false\n", {}, "no"), + ("explicit-groq", "stt:\n provider: groq\n", {}, "no"), + ("explicit-openai", "stt:\n provider: openai\n", {}, "no"), + ("explicit-local", "stt:\n provider: local\n", {}, "no"), + ("explicit-xai", "stt:\n provider: xai\n", {}, "no"), + # Mistral is quarantined → _get_provider returns "none" today, hence no_provider_error. + ("explicit-mistral-quarantine", "stt:\n provider: mistral\n", {}, "no"), + # Unknown name + no plugin → both: no_provider_error + ("unknown-no-plugin", "stt:\n provider: openrouter\n", {}, "no"), + # Unknown name + plugin installed → main: no_provider_error, PR: plugin + ("plugin-installed", "stt:\n provider: openrouter\n", {}, "yes"), + # Unknown name + plugin reports unavailable → main: no_provider_error, + # PR: plugin_unavailable (cleaner envelope, names the plugin) + ("plugin-installed-unavailable", "stt:\n provider: openrouter\n", {}, "unavailable"), + # Built-in name + plugin tries to shadow → both: built-in + ("explicit-openai-with-plugin-registered", "stt:\n provider: openai\n", {}, "yes"), +] + + +def _run_scenario(repo_path: Path, label: str, config_yaml: str, env: dict, plugin_register: str) -> dict: + venv_python = repo_path / ".venv" / "bin" / "python" + if not venv_python.exists(): + venv_python = MAIN_DIR / ".venv" / "bin" / "python" + if not venv_python.exists(): + venv_python = MAIN_DIR / "venv" / "bin" / "python" + if not venv_python.exists(): + venv_python = Path("python3") + + out = subprocess.run( + [ + str(venv_python), + "-c", + SUBPROCESS_SCRIPT, + str(repo_path), + json.dumps(env), + config_yaml, + plugin_register, + ], + capture_output=True, + text=True, + timeout=60, + ) + if out.returncode != 0: + return { + "error": "subprocess failed", + "stdout": out.stdout[-500:], + "stderr": out.stderr[-500:], + } + try: + return json.loads(out.stdout.strip().splitlines()[-1]) + except Exception as exc: + return {"error": f"could not parse output: {exc}", "stdout": out.stdout} + + +def _reduce(shape: dict) -> dict: + return { + "dispatch_kind": shape.get("dispatch_kind"), + "success": shape.get("success"), + } + + +def main() -> int: + print(f"main: {MAIN_DIR}") + print(f"pr: {PR_DIR}") + print() + + if MAIN_DIR == PR_DIR: + print( + "WARN: MAIN_DIR == PR_DIR — diffs will be trivially identical.\n" + " Set up a sibling 'hermes-agent-main' checkout pinned to " + "origin/main to get real parity coverage." + ) + print() + + failures: list[str] = [] + errors: list[str] = [] + intentional_diffs: list[tuple[str, dict, dict]] = [] + for label, config_yaml, env, plugin_register in SCENARIOS: + main_shape = _run_scenario(MAIN_DIR, label, config_yaml, env, plugin_register) + pr_shape = _run_scenario(PR_DIR, label, config_yaml, env, plugin_register) + + if "error" in main_shape or "error" in pr_shape: + print(f" [ERR ] {label}: subprocess failed") + print(f" main: {main_shape}") + print(f" pr: {pr_shape}") + errors.append(label) + continue + + main_reduced = _reduce(main_shape) + pr_reduced = _reduce(pr_shape) + + if main_reduced == pr_reduced: + print(f" [OK] {label}: {main_reduced}") + continue + + # On main, "plugin-installed" returns no_provider_error (no + # plugin hook); on PR, plugin dispatches. Same shape for + # "plugin-installed-unavailable" but PR returns the cleaner + # plugin_unavailable envelope. Both diffs are expected. + no_provider_to_plugin = ( + main_reduced.get("dispatch_kind") == "no_provider_error" + and pr_reduced.get("dispatch_kind") == "plugin" + and label == "plugin-installed" + ) + no_provider_to_unavailable = ( + main_reduced.get("dispatch_kind") == "no_provider_error" + and pr_reduced.get("dispatch_kind") == "plugin_unavailable" + and label == "plugin-installed-unavailable" + ) + if no_provider_to_plugin: + print(f" [DIFF] {label}: no_provider_error → plugin — expected") + intentional_diffs.append((label, main_reduced, pr_reduced)) + elif no_provider_to_unavailable: + print(f" [DIFF] {label}: no_provider_error → plugin_unavailable — expected") + intentional_diffs.append((label, main_reduced, pr_reduced)) + else: + print(f" [FAIL] {label}") + print(f" main: {main_reduced}") + print(f" pr: {pr_reduced}") + failures.append(label) + + print() + if errors: + print(f"SUBPROCESS ERRORS in {len(errors)} scenario(s):") + for e in errors: + print(f" - {e}") + if failures: + print(f"BEHAVIOUR REGRESSION in {len(failures)} scenario(s):") + for f in failures: + print(f" - {f}") + if intentional_diffs: + print( + f"INTENTIONAL DIFFS ({len(intentional_diffs)}): " + f"no_provider_error → plugin dispatch when a plugin is registered." + ) + if failures or errors: + return 1 + print(f"PARITY OK across {len(SCENARIOS)} scenarios.") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/tools/test_transcription_plugin_dispatch.py b/tests/tools/test_transcription_plugin_dispatch.py new file mode 100644 index 00000000000..83424676952 --- /dev/null +++ b/tests/tools/test_transcription_plugin_dispatch.py @@ -0,0 +1,462 @@ +"""Tests for STT plugin dispatch in tools/transcription_tools.py. + +Covers the resolution invariants of the new plugin dispatcher (follow-up +to #30398 — STT pluggability): + +1. Built-in provider names short-circuit — plugins NEVER win over a + built-in. Even if a plugin somehow ended up in the registry with a + built-in name (which the registry blocks), the dispatcher re-checks + defensively. +2. Unknown name with no plugin → returns None (caller surfaces the + legacy "No STT provider available" error). +3. Unknown name with plugin registered → dispatches, returns result. +4. Plugin exceptions are caught and converted to the standard error + envelope. +5. Plugin returning non-dict → caught with error envelope. +6. Plugin result has ``provider`` field stamped if missing. +""" + +from __future__ import annotations + +import pytest + +from agent import transcription_registry +from agent.transcription_provider import TranscriptionProvider +from tools import transcription_tools + + +class _FakeProvider(TranscriptionProvider): + def __init__( + self, + name: str, + result: dict | None = None, + raise_exc: BaseException | None = None, + available: bool = True, + available_raises: BaseException | None = None, + ): + self._name = name + self._result = result + self._raise_exc = raise_exc + self._available = available + self._available_raises = available_raises + self.last_call: dict | None = None + + @property + def name(self) -> str: + return self._name + + def is_available(self) -> bool: + if self._available_raises is not None: + raise self._available_raises + return self._available + + def transcribe(self, file_path: str, **kw): + self.last_call = {"file_path": file_path, "kwargs": dict(kw)} + if self._raise_exc is not None: + raise self._raise_exc + if self._result is not None: + return self._result + return {"success": True, "transcript": "fake transcript", "provider": self._name} + + +@pytest.fixture(autouse=True) +def _reset_registry(): + transcription_registry._reset_for_tests() + yield + transcription_registry._reset_for_tests() + + +# --------------------------------------------------------------------------- +# Built-in always wins +# --------------------------------------------------------------------------- + + +class TestBuiltinAlwaysWins: + """Built-in STT provider names short-circuit the dispatcher. + + Even with a plugin registered (which the registry would reject — + but the dispatcher is defensive), built-in names return None so + the caller's elif chain handles them natively. + """ + + @pytest.mark.parametrize( + "builtin", + ["local", "local_command", "groq", "openai", "mistral", "xai"], + ) + def test_dispatcher_short_circuits_builtin(self, builtin): + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", builtin, + ) + assert result is None, ( + f"Built-in {builtin!r} must short-circuit plugin dispatch." + ) + + def test_dispatcher_short_circuits_none(self): + """The ``none`` sentinel from _get_provider() means no provider + available — must not reach plugin registry.""" + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "none", + ) + assert result is None + + def test_dispatcher_short_circuits_empty(self): + assert transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "", + ) is None + + def test_dispatcher_short_circuits_builtin_case_insensitive(self): + for variant in ("OPENAI", "OpenAI", " openai ", "oPeNaI"): + assert ( + transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", variant, + ) is None + ) + + +# --------------------------------------------------------------------------- +# Unknown names +# --------------------------------------------------------------------------- + + +class TestPluginDispatch: + def test_registered_plugin_called(self): + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", + ) + assert result is not None + assert result["success"] is True + assert result["transcript"] == "fake transcript" + assert result["provider"] == "openrouter" + assert provider.last_call is not None + assert provider.last_call["file_path"] == "/tmp/audio.mp3" + + def test_unregistered_name_returns_none(self): + """Unknown name + no plugin → return None so the caller surfaces + the legacy 'No STT provider available' error.""" + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "unknown-stt", + ) + assert result is None + + def test_model_kwarg_forwarded(self): + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", model="whisper-large-v3", + ) + assert provider.last_call["kwargs"]["model"] == "whisper-large-v3" + + def test_language_kwarg_forwarded(self): + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", language="en", + ) + assert provider.last_call["kwargs"]["language"] == "en" + + def test_provider_exception_converted_to_error_envelope(self): + provider = _FakeProvider(name="openrouter", raise_exc=RuntimeError("network down")) + transcription_registry.register_provider(provider) + + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", + ) + assert result is not None + assert result["success"] is False + assert "network down" in result["error"] + assert result["transcript"] == "" + assert result["provider"] == "openrouter" + + def test_provider_non_dict_result_converted_to_error(self): + provider = _FakeProvider(name="openrouter", result="weird string") # type: ignore[arg-type] + transcription_registry.register_provider(provider) + + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", + ) + assert result is not None + assert result["success"] is False + assert "non-dict" in result["error"] + assert result["provider"] == "openrouter" + + def test_provider_field_stamped_if_missing(self): + """If a plugin forgets to set ``provider`` in its result, the + dispatcher stamps it from the registered name.""" + provider = _FakeProvider( + name="openrouter", + result={"success": True, "transcript": "hi"}, # no provider key + ) + transcription_registry.register_provider(provider) + + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", + ) + assert result is not None + assert result["provider"] == "openrouter" + + +# --------------------------------------------------------------------------- +# End-to-end via transcribe_audio +# --------------------------------------------------------------------------- + + +class TestTranscribeAudioE2E: + """transcribe_audio() routes plugin dispatch correctly when the + configured name is unknown to the built-in branches. + + Note: we mock _validate_audio_file and _get_provider so the real + file-validation and provider-resolution don't fire — we're testing + the plugin-dispatch wiring, not those helpers. + """ + + def test_unknown_name_with_plugin_dispatches(self): + from unittest.mock import patch + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openrouter"}), \ + patch("tools.transcription_tools.is_stt_enabled", return_value=True), \ + patch("tools.transcription_tools._get_provider", return_value="openrouter"): + result = transcription_tools.transcribe_audio("/tmp/audio.mp3") + + assert result["success"] is True + assert result["transcript"] == "fake transcript" + assert result["provider"] == "openrouter" + + def test_unknown_name_without_plugin_falls_to_legacy_error(self): + """When no plugin is registered for the unknown name, the + dispatcher returns None and transcribe_audio falls through to + the legacy 'No STT provider available' error message.""" + from unittest.mock import patch + + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openrouter"}), \ + patch("tools.transcription_tools.is_stt_enabled", return_value=True), \ + patch("tools.transcription_tools._get_provider", return_value="openrouter"): + result = transcription_tools.transcribe_audio("/tmp/audio.mp3") + + assert result["success"] is False + assert "No STT provider" in result["error"] + + def test_builtin_name_does_not_consult_plugin_registry(self): + """Even if a plugin's name collides with a built-in (which the + registry blocks, but defense in depth matters), transcribe_audio + with provider='groq' goes through the legacy elif chain, never + the plugin dispatcher.""" + from unittest.mock import patch + # Register a plugin that WOULD respond to 'openrouter' — but + # we're asking for 'groq', so it shouldn't be called. + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value={"provider": "groq"}), \ + patch("tools.transcription_tools._get_provider", return_value="groq"), \ + patch("tools.transcription_tools._transcribe_groq", + return_value={"success": True, "transcript": "from groq", "provider": "groq"}) as mock_groq: + result = transcription_tools.transcribe_audio("/tmp/audio.mp3") + + assert result["provider"] == "groq" + assert result["transcript"] == "from groq" + mock_groq.assert_called_once() + # Plugin was never called + assert provider.last_call is None + + +# --------------------------------------------------------------------------- +# Availability gating (codex review feedback on PR #30493) +# --------------------------------------------------------------------------- + + +class TestAvailabilityGate: + """When the configured plugin reports ``is_available() == False``, + the dispatcher MUST short-circuit with a clear unavailability + envelope instead of routing the call into a plugin that'll crash. + + The user explicitly set ``stt.provider: `` so falling + through to the generic "No STT provider available" message would + be misleading — surface the plugin's own unavailability instead. + """ + + def test_unavailable_plugin_returns_envelope_not_none(self): + provider = _FakeProvider(name="openrouter", available=False) + transcription_registry.register_provider(provider) + + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", + ) + assert result is not None, ( + "Unavailable plugin must return an envelope, not None — " + "otherwise we fall through to the generic auto-detect error " + "even though the user explicitly opted into this plugin." + ) + assert result["success"] is False + assert result["provider"] == "openrouter" + assert "not available" in result["error"] + # Plugin's transcribe MUST NOT have been called + assert provider.last_call is None + + def test_available_plugin_dispatches_normally(self): + provider = _FakeProvider(name="openrouter", available=True) + transcription_registry.register_provider(provider) + + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", + ) + assert result["success"] is True + assert provider.last_call is not None + + def test_is_available_raising_treated_as_unavailable(self): + """Per the ABC contract ``is_available()`` MUST NOT raise; we + defend anyway so a buggy plugin can't break dispatch.""" + provider = _FakeProvider( + name="openrouter", + available_raises=RuntimeError("creds check exploded"), + ) + transcription_registry.register_provider(provider) + + result = transcription_tools._dispatch_to_plugin_provider( + "/tmp/audio.mp3", "openrouter", + ) + assert result is not None + assert result["success"] is False + assert result["provider"] == "openrouter" + assert "not available" in result["error"] + assert provider.last_call is None + + def test_unavailable_plugin_at_transcribe_audio_level(self): + """End-to-end: ``stt.provider: openrouter`` + plugin reports + unavailable → ``transcribe_audio`` returns the unavailability + envelope, NOT the generic "No STT provider available" message. + """ + from unittest.mock import patch + provider = _FakeProvider(name="openrouter", available=False) + transcription_registry.register_provider(provider) + + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openrouter"}), \ + patch("tools.transcription_tools.is_stt_enabled", return_value=True), \ + patch("tools.transcription_tools._get_provider", return_value="openrouter"): + result = transcription_tools.transcribe_audio("/tmp/audio.mp3") + + assert result["success"] is False + # Must surface the plugin's unavailability — NOT the generic + # "No STT provider available" auto-detect-failure message. + assert "not available" in result["error"] + assert "No STT provider available" not in result["error"] + assert result["provider"] == "openrouter" + + +# --------------------------------------------------------------------------- +# Language forwarding from config (codex review feedback on PR #30493) +# --------------------------------------------------------------------------- + + +class TestLanguageForwardingFromConfig: + """``transcribe_audio`` must forward ``stt..language`` + from config to the plugin (mirrors how built-ins read + ``stt.local.language``). + """ + + def test_language_read_from_provider_namespaced_config(self): + """``stt.openrouter.language: ja`` reaches the plugin's + transcribe() call as language='ja'.""" + from unittest.mock import patch + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + stt_config = { + "provider": "openrouter", + "openrouter": {"language": "ja"}, + } + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \ + patch("tools.transcription_tools.is_stt_enabled", return_value=True), \ + patch("tools.transcription_tools._get_provider", return_value="openrouter"): + transcription_tools.transcribe_audio("/tmp/audio.mp3") + + assert provider.last_call is not None + assert provider.last_call["kwargs"]["language"] == "ja" + + def test_model_from_provider_namespaced_config(self): + """``stt.openrouter.model: whisper-large-v3`` reaches the + plugin as model='whisper-large-v3' when caller doesn't + override.""" + from unittest.mock import patch + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + stt_config = { + "provider": "openrouter", + "openrouter": {"model": "whisper-large-v3"}, + } + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \ + patch("tools.transcription_tools.is_stt_enabled", return_value=True), \ + patch("tools.transcription_tools._get_provider", return_value="openrouter"): + transcription_tools.transcribe_audio("/tmp/audio.mp3") + + assert provider.last_call["kwargs"]["model"] == "whisper-large-v3" + + def test_caller_model_overrides_config_model(self): + """An explicit ``model`` arg to transcribe_audio wins over + ``stt..model`` in config.""" + from unittest.mock import patch + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + stt_config = { + "provider": "openrouter", + "openrouter": {"model": "config-model"}, + } + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \ + patch("tools.transcription_tools.is_stt_enabled", return_value=True), \ + patch("tools.transcription_tools._get_provider", return_value="openrouter"): + transcription_tools.transcribe_audio( + "/tmp/audio.mp3", model="explicit-arg-model", + ) + + assert provider.last_call["kwargs"]["model"] == "explicit-arg-model" + + def test_missing_provider_namespace_passes_none(self): + """No ``stt.`` subsection → language is None, + model falls back to caller arg or None. No crash.""" + from unittest.mock import patch + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openrouter"}), \ + patch("tools.transcription_tools.is_stt_enabled", return_value=True), \ + patch("tools.transcription_tools._get_provider", return_value="openrouter"): + transcription_tools.transcribe_audio("/tmp/audio.mp3") + + assert provider.last_call["kwargs"]["language"] is None + assert provider.last_call["kwargs"]["model"] is None + + def test_non_dict_provider_namespace_does_not_crash(self): + """If someone accidentally writes ``stt.openrouter: "foo"`` (a + string instead of a dict), we should not crash — treat as + empty config.""" + from unittest.mock import patch + provider = _FakeProvider(name="openrouter") + transcription_registry.register_provider(provider) + + stt_config = {"provider": "openrouter", "openrouter": "garbage"} + with patch("tools.transcription_tools._validate_audio_file", return_value=None), \ + patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \ + patch("tools.transcription_tools.is_stt_enabled", return_value=True), \ + patch("tools.transcription_tools._get_provider", return_value="openrouter"): + result = transcription_tools.transcribe_audio("/tmp/audio.mp3") + + # Should still dispatch successfully (config is just ignored) + assert result["success"] is True + assert provider.last_call["kwargs"]["language"] is None + assert provider.last_call["kwargs"]["model"] is None diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index a9af32023f3..a9c59ea9bfb 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -217,6 +217,22 @@ def _try_lazy_install_stt() -> bool: return False +# Names of the 6 STT providers with native handlers in this module. +# Kept in sync with ``agent.transcription_registry._BUILTIN_NAMES`` — +# a regression test fails if they drift. The plugin hook from +# issue #30398-style follow-up rejects plugins registering under any +# of these names; the dispatcher in ``transcribe_audio`` short-circuits +# them defensively as well. +BUILTIN_STT_PROVIDERS = frozenset({ + "local", + "local_command", + "groq", + "openai", + "mistral", + "xai", +}) + + def _get_provider(stt_config: dict) -> str: """Determine which STT provider to use. @@ -327,6 +343,142 @@ def _get_provider(stt_config: dict) -> str: pass return "none" + +# --------------------------------------------------------------------------- +# Plugin provider dispatch (issue follow-up to #30398 — STT pluggability) +# --------------------------------------------------------------------------- + + +def _dispatch_to_plugin_provider( + file_path: str, + provider: str, + *, + model: Optional[str] = None, + language: Optional[str] = None, +) -> Optional[Dict[str, Any]]: + """Route the call to a plugin-registered transcription provider, or + return None. + + Returns the transcribe-response dict on dispatch, or ``None`` to + fall through to the legacy "No STT provider available" error path. + + Resolution invariants enforced here: + + 1. Built-in provider names short-circuit — never reach the plugin + registry. The caller (``transcribe_audio``) handles ``local``, + ``groq``, ``openai``, etc. via its existing elif chain; this + function defensively rejects those names so a plugin can't be + silently dispatched under a built-in name even if it somehow + slipped past the registry's built-in shadow guard. + 2. Plugin dispatch fires only when ``provider`` matches a + registered :class:`TranscriptionProvider` whose ``name`` equals + the configured value. Unknown names with no plugin registered + return None (caller surfaces the legacy "No STT provider" + message). + 3. Availability gating: when the matched plugin reports + ``is_available() == False`` (missing API key, missing optional + SDK, etc.) this returns an error envelope identifying the + plugin as unavailable — **not** ``None`` — because the user + explicitly opted into this plugin via ``stt.provider`` and the + generic fallthrough message would be misleading. + + Provider exceptions are caught and converted into the standard + error envelope (matches the legacy built-in error shapes — the + gateway/CLI caller already expects ``{success: False, error: + "...", transcript: ""}`` on failure). + """ + if not provider: + return None + key = provider.lower().strip() + if key in BUILTIN_STT_PROVIDERS or key == "none": + return None + try: + from agent.transcription_registry import get_provider + from hermes_cli.plugins import _ensure_plugins_discovered + + _ensure_plugins_discovered() + plugin_provider = get_provider(key) + if plugin_provider is None: + # Long-lived sessions may have discovered plugins before a + # bundled backend was patched in or before config changed. + # Retry once with a forced refresh before surfacing fall- + # through. Mirrors the image_gen / browser dispatcher + # recovery pattern. + _ensure_plugins_discovered(force=True) + plugin_provider = get_provider(key) + except Exception as exc: # noqa: BLE001 — discovery failure is non-fatal + logger.debug("STT plugin dispatch skipped (discovery failed): %s", exc) + return None + if plugin_provider is None: + return None + + # Availability gate: when a plugin reports it's not configured + # (missing API key, missing optional SDK, etc.) surface a clean + # error envelope **instead of** falling through to the generic + # "No STT provider" message. The user explicitly set + # ``stt.provider: `` in config — surfacing the plugin's + # own availability failure is more actionable than the generic + # auto-detect-failure error, and avoids routing the call into a + # plugin that's about to crash messily. + # + # ``is_available()`` MUST NOT raise per the ABC contract; defend + # anyway so a buggy plugin can't break dispatch for everyone. + try: + available = plugin_provider.is_available() + except Exception as exc: # noqa: BLE001 + logger.warning( + "STT plugin provider '%s' is_available() raised: %s — " + "treating as unavailable", key, exc, exc_info=True, + ) + available = False + if not available: + logger.info( + "STT plugin provider '%s' reports not available; returning " + "unavailability envelope.", key, + ) + return { + "success": False, + "transcript": "", + "error": ( + f"STT plugin '{key}' is not available — check that its " + "required credentials / dependencies are configured." + ), + "provider": key, + } + + logger.info("Transcribing with plugin STT provider '%s'...", key) + try: + result = plugin_provider.transcribe( + file_path, + model=model, + language=language, + ) + except Exception as exc: # noqa: BLE001 + logger.warning( + "STT plugin provider '%s' raised: %s", key, exc, exc_info=True, + ) + return { + "success": False, + "transcript": "", + "error": f"STT plugin '{key}' raised: {exc}", + "provider": key, + } + + # Defensive: plugins should return a dict matching the contract. If + # they don't, surface a clear error envelope rather than leaking a + # weird object back to the gateway. + if not isinstance(result, dict): + return { + "success": False, + "transcript": "", + "error": f"STT plugin '{key}' returned a non-dict result", + "provider": key, + } + # Stamp provider if the plugin forgot to. + result.setdefault("provider", key) + return result + + # --------------------------------------------------------------------------- # Shared validation # --------------------------------------------------------------------------- @@ -906,6 +1058,30 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A model_name = model or "grok-stt" return _transcribe_xai(file_path, model_name) + # Plugin-registered STT backend (e.g. OpenRouter, SenseAudio, + # Gemini-STT). Fires only when ``provider`` is neither a built-in + # nor ``"none"``. The dispatcher enforces built-ins-always-win + # defensively. Returns None when no plugin is registered for the + # configured name, falling through to the legacy "No STT provider" + # error message below. + # + # Plugin-scoped config namespace mirrors the built-in pattern + # (``stt.openai.model``, ``stt.mistral.model``): plugins read their + # per-provider config under ``stt.`` and the dispatcher + # forwards ``language`` from there. Top-level ``model`` argument + # overrides any config-set model. + plugin_cfg = stt_config.get(provider, {}) if isinstance(stt_config.get(provider), dict) else {} + plugin_language = plugin_cfg.get("language") + plugin_model = model or plugin_cfg.get("model") + plugin_result = _dispatch_to_plugin_provider( + file_path, + provider, + model=plugin_model, + language=plugin_language, + ) + if plugin_result is not None: + return plugin_result + # No provider available return { "success": False, diff --git a/website/docs/user-guide/features/plugins.md b/website/docs/user-guide/features/plugins.md index 0aade649cf7..781fa5e8f06 100644 --- a/website/docs/user-guide/features/plugins.md +++ b/website/docs/user-guide/features/plugins.md @@ -235,7 +235,7 @@ The table above shows the four plugin categories, but within "General plugins" t | An **image-generation backend** (DALL·E, SDXL, …) | Backend plugin — `ctx.register_image_gen_provider()` | [Image Generation Provider Plugins](/developer-guide/image-gen-provider-plugin) | | A **video-generation backend** (Veo, Kling, Pixverse, Grok-Imagine, Runway, …) | Backend plugin — `ctx.register_video_gen_provider()` | [Video Generation Provider Plugins](/developer-guide/video-gen-provider-plugin) | | A **TTS backend** (any CLI — Piper, VoxCPM, Kokoro, xtts, voice-cloning scripts, …) | Config-driven (recommended) — declare under `tts.providers.` with `type: command` in `config.yaml`. OR Python backend plugin — `ctx.register_tts_provider()` for Python-SDK / streaming engines that need more than a shell template. | [TTS Setup](/user-guide/features/tts#custom-command-providers) · [Python plugin guide](/user-guide/features/tts#python-plugin-providers) | -| An **STT backend** (custom whisper binary, local ASR CLI) | Config-driven — set `HERMES_LOCAL_STT_COMMAND` env var to a shell template | [Voice Message Transcription (STT)](/user-guide/features/tts#voice-message-transcription-stt) | +| An **STT backend** (any CLI — whisper.cpp, custom whisper binary, local ASR CLI) | Config-driven (recommended) — declare under `stt.providers.` with `type: command` in `config.yaml`, or set `HERMES_LOCAL_STT_COMMAND` for the legacy single-command escape hatch. OR Python backend plugin — `ctx.register_transcription_provider()` for Python-SDK engines (OpenRouter, SenseAudio, Gemini-STT, etc.). | [STT Setup](/user-guide/features/tts#stt-custom-command-providers) · [Python plugin guide](/user-guide/features/tts#python-plugin-providers-stt) | | **External tools via MCP** (filesystem, GitHub, Linear, Notion, any MCP server) | Config-driven — declare `mcp_servers.` with `command:` / `url:` in `config.yaml`. Hermes auto-discovers the server's tools and registers them alongside built-ins. | [MCP](/user-guide/features/mcp) | | **Additional skill sources** (custom GitHub repos, private skill indexes) | CLI — `hermes skills tap add ` | [Skills Hub](/user-guide/features/skills#skills-hub) · [Publishing a custom tap](/user-guide/features/skills#publishing-a-custom-skill-tap) | | **Gateway event hooks** (fire on `gateway:startup`, `session:start`, `agent:end`, `command:*`) | Drop `HOOK.yaml` + `handler.py` into `~/.hermes/hooks//` | [Event Hooks](/user-guide/features/hooks#gateway-event-hooks) | diff --git a/website/docs/user-guide/features/tts.md b/website/docs/user-guide/features/tts.md index 588a5119e51..35d7ade0b59 100644 --- a/website/docs/user-guide/features/tts.md +++ b/website/docs/user-guide/features/tts.md @@ -454,3 +454,101 @@ If your configured provider isn't available, Hermes automatically falls back: - **OpenAI key not set** → Falls back to local transcription, then Groq - **Mistral key/SDK not set** → Skipped in auto-detect; falls through to next available provider - **Nothing available** → Voice messages pass through with an accurate note to the user + +### Python plugin providers (STT) + +For STT engines that aren't built-in (OpenRouter, SenseAudio, Gemini-STT, Deepgram, custom proprietary backends), register a Python plugin via `ctx.register_transcription_provider()`. The plugin **coexists with** the 6 built-in providers (`local`, `local_command`, `groq`, `openai`, `mistral`, `xai`) — those keep their native implementations and always win on name collision. + +#### Resolution order + +1. **`stt.provider` is a built-in name** → built-in dispatch. **Always wins.** +2. **`stt.provider` matches a plugin-registered `TranscriptionProvider`** → plugin dispatch: + - if the plugin's `is_available()` returns `False` (missing creds or SDK), the call surfaces an unavailability error envelope identifying the plugin — **not** the generic "No STT provider available" message. + - otherwise the plugin's `transcribe()` is called with `model` (from the public `model=` arg, falling back to `stt..model`) and `language` (from `stt..language`). +3. **No match** → "No STT provider available" error. + +#### Per-provider config namespace + +Plugins read their per-provider configuration from `stt.` in `config.yaml`, mirroring how built-ins read `stt.openai.model` / `stt.mistral.model`: + +```yaml +stt: + provider: my-stt + my-stt: + model: whisper-large-v3 + language: ja # forwarded as language= to transcribe() + # any other plugin-specific keys go here; read them via your + # own config.yaml access in __init__/is_available/transcribe +``` + +The dispatcher forwards `model` and `language` from this section; everything else, the plugin can read itself. + +#### Minimal plugin + +Drop this in `~/.hermes/plugins/my-stt/`: + +`plugin.yaml`: +```yaml +name: my-stt +version: 0.1.0 +description: "My custom Python STT backend" +``` + +`__init__.py`: +```python +from agent.transcription_provider import TranscriptionProvider + + +class MySTTProvider(TranscriptionProvider): + @property + def name(self) -> str: + return "my-stt" # what stt.provider matches against + + @property + def display_name(self) -> str: + return "My Custom STT" + + def is_available(self) -> bool: + # Return False when credentials/deps are missing — picker skips + # this row but the dispatcher still routes here on explicit config. + import os + return bool(os.environ.get("MY_STT_API_KEY")) + + def transcribe(self, file_path, *, model=None, language=None, **extra): + # Return the standard transcribe envelope: + # {"success": bool, "transcript": str, "provider": str, "error": str} + # Do NOT raise — convert exceptions to the error envelope so the + # gateway/CLI caller sees a consistent shape on failure. + try: + import my_stt_sdk + client = my_stt_sdk.Client() + text = client.transcribe(open(file_path, "rb")) + return { + "success": True, + "transcript": text, + "provider": "my-stt", + } + except Exception as exc: + return { + "success": False, + "transcript": "", + "error": f"my-stt failed: {exc}", + "provider": "my-stt", + } + + +def register(ctx): + ctx.register_transcription_provider(MySTTProvider()) +``` + +Enable it (`hermes plugins enable my-stt`), set `stt.provider: my-stt` in `config.yaml`, and voice-message transcription will route through your plugin. + +#### Optional hooks + +Override these on your provider class for richer integration: + +- `list_models()` → list of `{id, display, languages, max_audio_seconds}` dicts. +- `default_model()` → string returned when the user doesn't override the model. +- `get_setup_schema()` → return `{name, badge, tag, env_vars: [{key, prompt, url}]}` to power picker rows in `hermes tools` / `hermes setup` (the picker category for STT is not yet shipped — this metadata is available to plugins for forward compatibility). + +See `agent/transcription_provider.py` for the full ABC including docstrings.