mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
feat(stt): add register_transcription_provider() plugin hook
Add an opt-in Python plugin surface for speech-to-text backends,
mirroring the TTS hook pattern. New backends (OpenRouter, SenseAudio,
Gemini-STT, custom proprietary engines) can be implemented as plugins
without modifying tools/transcription_tools.py.
Built-ins always win
--------------------
The 6 built-in STT providers (local/faster-whisper, local_command,
groq, openai, mistral, xai) keep their native handlers. Plugins
attempting to register under a built-in name are rejected at
registration time with a warning and re-checked defensively at
dispatch.
Resolution order
----------------
1. stt.provider matches a built-in → built-in dispatch (unchanged)
2. stt.provider matches a registered plugin →
a. if plugin.is_available() returns False → unavailability envelope
identifying the plugin (not the generic "No STT provider"
message — the user explicitly opted into this plugin)
b. otherwise plugin.transcribe() with model + language forwarded
from stt.<provider>.{model,language} config
3. No match → legacy "No STT provider available" error (unchanged)
Per-provider config namespace
-----------------------------
Plugins read their config from stt.<provider> in config.yaml, mirroring
how built-ins read stt.openai.model / stt.mistral.model. The dispatcher
forwards `model` and `language` from this section. Caller's explicit
`model=` argument overrides the config-set model.
Files
-----
- agent/transcription_provider.py: TranscriptionProvider ABC
- agent/transcription_registry.py: register/get/list providers,
built-in shadow guard, _reset_for_tests
- hermes_cli/plugins.py: register_transcription_provider() on
PluginContext
- tools/transcription_tools.py: BUILTIN_STT_PROVIDERS frozenset,
_dispatch_to_plugin_provider() with availability gate, wire-in
after xai branch and before "No STT provider" error
- tests/agent/test_transcription_registry.py: 27 tests
- tests/hermes_cli/test_plugins_transcription_registration.py: 3 tests
- tests/tools/test_transcription_plugin_dispatch.py: 28 tests
(covering built-in short-circuit, plugin dispatch, exception
envelope, non-dict guard, availability gate, language forwarding)
- tests/plugins/transcription/check_parity_vs_main.py: 10-scenario
subprocess-pinned parity harness vs origin/main
- website/docs/user-guide/features/{tts,plugins}.md: docs
Behavior parity
---------------
10 scenarios, 8 OK + 2 expected DIFFs:
no_provider_error → plugin (plugin-installed scenario)
no_provider_error → plugin_unavailable (plugin-installed-unavailable
scenario; PR returns cleaner envelope)
Zero behavior change for users not opting into a plugin.
Issue follow-up to #30398.
This commit is contained in:
parent
2e0ac31a72
commit
2cd952e110
11 changed files with 1831 additions and 1 deletions
193
agent/transcription_provider.py
Normal file
193
agent/transcription_provider.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
"""
|
||||
Transcription Provider ABC
|
||||
==========================
|
||||
|
||||
Defines the pluggable-backend interface for speech-to-text. Providers
|
||||
register instances via
|
||||
:meth:`PluginContext.register_transcription_provider`; the active one
|
||||
(selected via ``stt.provider`` in ``config.yaml``) services every
|
||||
:func:`tools.transcription_tools.transcribe_audio` call **when the
|
||||
configured name is neither a built-in (``local``, ``local_command``,
|
||||
``groq``, ``openai``, ``mistral``, ``xai``) nor disabled**.
|
||||
|
||||
Two coexisting STT extension surfaces — in resolution order:
|
||||
|
||||
1. **Built-in providers** (``BUILTIN_STT_PROVIDERS`` in
|
||||
:mod:`tools.transcription_tools`) — native Python implementations
|
||||
for the 6 backends shipped today (faster-whisper, local_command,
|
||||
Groq, OpenAI, Mistral, xAI). **Always win** — plugins cannot
|
||||
shadow them. The single-env-var shell escape hatch
|
||||
``HERMES_LOCAL_STT_COMMAND`` is preserved via the built-in
|
||||
``local_command`` path.
|
||||
2. **Plugin-registered providers** (this ABC). For new STT backends —
|
||||
OpenRouter, SenseAudio, Gemini-STT, custom proprietary engines —
|
||||
that need a Python implementation without modifying
|
||||
``tools/transcription_tools.py``.
|
||||
|
||||
Built-ins-always-win is enforced at registration time
|
||||
(:func:`agent.transcription_registry.register_provider` rejects names
|
||||
in ``BUILTIN_STT_PROVIDERS`` with a warning) AND at dispatch time
|
||||
(:func:`tools.transcription_tools._dispatch_to_plugin_provider`
|
||||
re-checks defensively).
|
||||
|
||||
Providers live in ``<repo>/plugins/transcription/<name>/`` (built-in
|
||||
plugins, none shipped today) or
|
||||
``~/.hermes/plugins/transcription/<name>/`` (user-installed).
|
||||
|
||||
Response contract
|
||||
-----------------
|
||||
:meth:`TranscriptionProvider.transcribe` returns a dict with keys::
|
||||
|
||||
success bool
|
||||
transcript str transcribed text (empty when success=False)
|
||||
provider str provider name (for diagnostics)
|
||||
error str only when success=False
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ABC
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TranscriptionProvider(abc.ABC):
|
||||
"""Abstract base class for a speech-to-text backend.
|
||||
|
||||
Subclasses must implement :attr:`name` and :meth:`transcribe`.
|
||||
Everything else has sane defaults — override only what your provider
|
||||
needs.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Stable short identifier used in ``stt.provider`` config.
|
||||
|
||||
Lowercase, no spaces. Examples: ``openrouter``, ``sensaudio``,
|
||||
``gemini``, ``deepgram``. Names that collide with a built-in STT
|
||||
provider (``local``, ``local_command``, ``groq``, ``openai``,
|
||||
``mistral``, ``xai``) are rejected at registration time.
|
||||
"""
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable label shown in ``hermes tools``.
|
||||
|
||||
Defaults to ``name.title()``.
|
||||
"""
|
||||
return self.name.title()
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Return True when this provider can service calls.
|
||||
|
||||
Typically checks for a required API key + that the SDK is
|
||||
importable. Default: True (providers with no external
|
||||
dependencies are always available).
|
||||
|
||||
Must NOT raise — used by the picker and ``hermes setup`` for
|
||||
availability displays and should fail gracefully.
|
||||
"""
|
||||
return True
|
||||
|
||||
def list_models(self) -> List[Dict[str, Any]]:
|
||||
"""Return model catalog entries.
|
||||
|
||||
Each entry::
|
||||
|
||||
{
|
||||
"id": "whisper-large-v3-turbo", # required
|
||||
"display": "Whisper Large v3 Turbo", # optional
|
||||
"languages": ["en", "es", "fr"], # optional
|
||||
"max_audio_seconds": 1500, # optional
|
||||
}
|
||||
|
||||
Default: empty list (provider has a single fixed model or
|
||||
doesn't expose model selection).
|
||||
"""
|
||||
return []
|
||||
|
||||
def default_model(self) -> Optional[str]:
|
||||
"""Return the default model id, or None if not applicable."""
|
||||
models = self.list_models()
|
||||
if models:
|
||||
return models[0].get("id")
|
||||
return None
|
||||
|
||||
def get_setup_schema(self) -> Dict[str, Any]:
|
||||
"""Return provider metadata for the ``hermes tools`` picker.
|
||||
|
||||
Used by ``tools_config.py`` to inject this provider as a row in
|
||||
the Speech-to-Text provider list. Shape::
|
||||
|
||||
{
|
||||
"name": "OpenRouter STT", # picker label
|
||||
"badge": "paid", # optional short tag
|
||||
"tag": "Whisper via OpenRouter API", # optional subtitle
|
||||
"env_vars": [ # keys to prompt for
|
||||
{"key": "OPENROUTER_API_KEY",
|
||||
"prompt": "OpenRouter API key",
|
||||
"url": "https://openrouter.ai/keys"},
|
||||
],
|
||||
}
|
||||
|
||||
Default: minimal entry derived from ``display_name`` with no
|
||||
env vars. Override to expose API key prompts and custom badges.
|
||||
"""
|
||||
return {
|
||||
"name": self.display_name,
|
||||
"badge": "",
|
||||
"tag": "",
|
||||
"env_vars": [],
|
||||
}
|
||||
|
||||
@abc.abstractmethod
|
||||
def transcribe(
|
||||
self,
|
||||
file_path: str,
|
||||
*,
|
||||
model: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
**extra: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Transcribe the audio file at ``file_path``.
|
||||
|
||||
Returns a dict with the standard envelope::
|
||||
|
||||
{
|
||||
"success": True,
|
||||
"transcript": "the transcribed text",
|
||||
"provider": "<this provider's name>",
|
||||
}
|
||||
|
||||
or on failure::
|
||||
|
||||
{
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": "human-readable error message",
|
||||
"provider": "<this provider's name>",
|
||||
}
|
||||
|
||||
Implementations should NOT raise — convert exceptions to the
|
||||
error envelope so the dispatcher can deliver a consistent shape
|
||||
to the gateway/CLI caller.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the audio file. The dispatcher
|
||||
has already validated existence + size before calling.
|
||||
model: Model identifier from :meth:`list_models`, or None
|
||||
to use :meth:`default_model`.
|
||||
language: Optional BCP-47 language hint (e.g. ``"en"``,
|
||||
``"ja"``) — providers without language hints should
|
||||
ignore this argument.
|
||||
**extra: Forward-compat parameters future schema versions
|
||||
may expose. Implementations should ignore unknown keys.
|
||||
"""
|
||||
122
agent/transcription_registry.py
Normal file
122
agent/transcription_registry.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""
|
||||
Transcription Provider Registry
|
||||
================================
|
||||
|
||||
Central map of registered STT providers. Populated by plugins at
|
||||
import-time via :meth:`PluginContext.register_transcription_provider`;
|
||||
consumed by :mod:`tools.transcription_tools` to dispatch
|
||||
:func:`transcribe_audio` calls to the active plugin backend **when**
|
||||
the configured ``stt.provider`` name is not a built-in.
|
||||
|
||||
Built-ins-always-win
|
||||
--------------------
|
||||
Plugin names that collide with a built-in STT provider (``local``,
|
||||
``local_command``, ``groq``, ``openai``, ``mistral``, ``xai``) are
|
||||
rejected at registration with a warning. This invariant is also
|
||||
re-checked at dispatch time in
|
||||
:func:`tools.transcription_tools._dispatch_to_plugin_provider`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from agent.transcription_provider import TranscriptionProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Names reserved for native built-in STT handlers. Plugins cannot
|
||||
# register a name in this set — the registration call is rejected with
|
||||
# a warning. **Kept in sync with ``BUILTIN_STT_PROVIDERS`` in
|
||||
# :mod:`tools.transcription_tools`** — a regression test in
|
||||
# ``tests/agent/test_transcription_registry.py::TestBuiltinSync``
|
||||
# fails if the two lists drift. Importing from
|
||||
# ``tools.transcription_tools`` directly would create a circular
|
||||
# dependency (``tools.transcription_tools`` imports
|
||||
# ``agent.transcription_registry`` for dispatch).
|
||||
_BUILTIN_NAMES = frozenset({
|
||||
"local",
|
||||
"local_command",
|
||||
"groq",
|
||||
"openai",
|
||||
"mistral",
|
||||
"xai",
|
||||
})
|
||||
|
||||
|
||||
_providers: Dict[str, TranscriptionProvider] = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def register_provider(provider: TranscriptionProvider) -> None:
|
||||
"""Register a transcription provider.
|
||||
|
||||
Rejects:
|
||||
|
||||
- Non-:class:`TranscriptionProvider` instances (raises :class:`TypeError`).
|
||||
- Empty/whitespace ``.name`` (raises :class:`ValueError`).
|
||||
- Names colliding with a built-in (logs a warning, silently
|
||||
ignores — built-ins-always-win invariant).
|
||||
|
||||
Re-registration (same ``name``) overwrites the previous entry and
|
||||
logs a debug message — makes hot-reload scenarios (tests, dev
|
||||
loops) behave predictably.
|
||||
"""
|
||||
if not isinstance(provider, TranscriptionProvider):
|
||||
raise TypeError(
|
||||
f"register_provider() expects a TranscriptionProvider instance, "
|
||||
f"got {type(provider).__name__}"
|
||||
)
|
||||
name = provider.name
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
raise ValueError("Transcription provider .name must be a non-empty string")
|
||||
key = name.strip().lower()
|
||||
if key in _BUILTIN_NAMES:
|
||||
logger.warning(
|
||||
"Transcription provider '%s' shadows a built-in name; registration "
|
||||
"ignored. Built-in STT providers (%s) always win — pick a different "
|
||||
"name.",
|
||||
key, ", ".join(sorted(_BUILTIN_NAMES)),
|
||||
)
|
||||
return
|
||||
with _lock:
|
||||
existing = _providers.get(key)
|
||||
_providers[key] = provider
|
||||
if existing is not None:
|
||||
logger.debug(
|
||||
"Transcription provider '%s' re-registered (was %r)",
|
||||
key, type(existing).__name__,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Registered transcription provider '%s' (%s)",
|
||||
key, type(provider).__name__,
|
||||
)
|
||||
|
||||
|
||||
def list_providers() -> List[TranscriptionProvider]:
|
||||
"""Return all registered providers, sorted by name."""
|
||||
with _lock:
|
||||
items = list(_providers.values())
|
||||
return sorted(items, key=lambda p: p.name)
|
||||
|
||||
|
||||
def get_provider(name: str) -> Optional[TranscriptionProvider]:
|
||||
"""Return the provider registered under *name*, or None.
|
||||
|
||||
Name matching is case-insensitive and whitespace-tolerant — mirrors
|
||||
how ``tools.transcription_tools._get_provider`` normalizes the
|
||||
configured ``stt.provider`` value.
|
||||
"""
|
||||
if not isinstance(name, str):
|
||||
return None
|
||||
return _providers.get(name.strip().lower())
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
|
||||
"""Clear the registry. **Test-only.**"""
|
||||
with _lock:
|
||||
_providers.clear()
|
||||
|
|
@ -678,6 +678,50 @@ class PluginContext:
|
|||
self.manifest.name, provider.name,
|
||||
)
|
||||
|
||||
# -- transcription (STT) provider registration ---------------------------
|
||||
|
||||
def register_transcription_provider(self, provider) -> None:
|
||||
"""Register a speech-to-text backend.
|
||||
|
||||
``provider`` must be an instance of
|
||||
:class:`agent.transcription_provider.TranscriptionProvider`.
|
||||
The ``provider.name`` attribute is what ``stt.provider`` in
|
||||
``config.yaml`` matches against when routing
|
||||
:func:`tools.transcription_tools.transcribe_audio` calls —
|
||||
**but only when**:
|
||||
|
||||
1. ``provider.name`` is NOT a built-in STT provider name
|
||||
(``local``, ``local_command``, ``groq``, ``openai``,
|
||||
``mistral``, ``xai``). Built-ins always win — the registry
|
||||
rejects shadowing names with a warning.
|
||||
2. There is NO ``stt.providers.<name>: type: command`` entry
|
||||
with the same name. Command-providers win on name
|
||||
collision because config is more local than plugin install
|
||||
— same precedence rule as TTS.
|
||||
|
||||
Coexists with the in-tree dispatcher and the STT
|
||||
command-provider registry rather than replacing them. The 6
|
||||
built-in STT backends keep their native implementations in
|
||||
``tools/transcription_tools.py``; this hook is for *new* Python
|
||||
engines (OpenRouter, SenseAudio, Gemini-STT, custom proprietary
|
||||
backends).
|
||||
"""
|
||||
from agent.transcription_provider import TranscriptionProvider
|
||||
from agent.transcription_registry import register_provider as _register_stt_provider
|
||||
|
||||
if not isinstance(provider, TranscriptionProvider):
|
||||
logger.warning(
|
||||
"Plugin '%s' tried to register a transcription provider that "
|
||||
"does not inherit from TranscriptionProvider. Ignoring.",
|
||||
self.manifest.name,
|
||||
)
|
||||
return
|
||||
_register_stt_provider(provider)
|
||||
logger.info(
|
||||
"Plugin '%s' registered transcription provider: %s",
|
||||
self.manifest.name, provider.name,
|
||||
)
|
||||
|
||||
# -- platform adapter registration ---------------------------------------
|
||||
|
||||
def register_platform(
|
||||
|
|
|
|||
243
tests/agent/test_transcription_registry.py
Normal file
243
tests/agent/test_transcription_registry.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
"""Tests for agent/transcription_registry.py and agent/transcription_provider.py.
|
||||
|
||||
Covers:
|
||||
- Registration happy path
|
||||
- Registration rejection: non-TranscriptionProvider type
|
||||
- Registration rejection: empty/whitespace name
|
||||
- Built-in name shadowing: warning + silent ignore (no exception)
|
||||
- Re-registration: overwrites + logs at debug
|
||||
- Case + whitespace insensitivity on lookup
|
||||
- ABC contract: default implementations work
|
||||
- ABC contract: transcribe() must be implemented
|
||||
- Sync invariant: registry built-ins match tools/transcription_tools.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import transcription_registry
|
||||
from agent.transcription_provider import TranscriptionProvider
|
||||
|
||||
|
||||
class _FakeProvider(TranscriptionProvider):
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "fake",
|
||||
display: Optional[str] = None,
|
||||
available: bool = True,
|
||||
transcribe_impl: Optional[Any] = None,
|
||||
):
|
||||
self._name = name
|
||||
self._display = display
|
||||
self._available = available
|
||||
self._transcribe_impl = transcribe_impl
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._display if self._display is not None else super().display_name
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return self._available
|
||||
|
||||
def transcribe(self, file_path: str, **kw):
|
||||
if self._transcribe_impl is not None:
|
||||
return self._transcribe_impl(file_path, **kw)
|
||||
return {"success": True, "transcript": f"fake({file_path})", "provider": self._name}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
transcription_registry._reset_for_tests()
|
||||
yield
|
||||
transcription_registry._reset_for_tests()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistration:
|
||||
def test_happy_path(self):
|
||||
p = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(p)
|
||||
assert transcription_registry.get_provider("openrouter") is p
|
||||
assert [r.name for r in transcription_registry.list_providers()] == ["openrouter"]
|
||||
|
||||
def test_rejects_non_provider_type(self):
|
||||
with pytest.raises(TypeError, match="expects a TranscriptionProvider instance"):
|
||||
transcription_registry.register_provider("not a provider") # type: ignore[arg-type]
|
||||
assert transcription_registry.list_providers() == []
|
||||
|
||||
def test_rejects_empty_name(self):
|
||||
p = _FakeProvider(name="")
|
||||
with pytest.raises(ValueError, match="non-empty string"):
|
||||
transcription_registry.register_provider(p)
|
||||
assert transcription_registry.list_providers() == []
|
||||
|
||||
def test_rejects_whitespace_name(self):
|
||||
p = _FakeProvider(name=" ")
|
||||
with pytest.raises(ValueError, match="non-empty string"):
|
||||
transcription_registry.register_provider(p)
|
||||
assert transcription_registry.list_providers() == []
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"builtin",
|
||||
["local", "local_command", "groq", "openai", "mistral", "xai"],
|
||||
)
|
||||
def test_rejects_builtin_shadow_with_warning(self, builtin, caplog):
|
||||
p = _FakeProvider(name=builtin)
|
||||
with caplog.at_level(logging.WARNING, logger="agent.transcription_registry"):
|
||||
transcription_registry.register_provider(p)
|
||||
assert "shadows a built-in name" in caplog.text
|
||||
assert builtin in caplog.text
|
||||
assert transcription_registry.get_provider(builtin) is None
|
||||
assert transcription_registry.list_providers() == []
|
||||
|
||||
def test_builtin_shadow_case_insensitive(self, caplog):
|
||||
for variant in ("OPENAI", "OpenAi", " openai ", "oPeNaI"):
|
||||
transcription_registry._reset_for_tests()
|
||||
with caplog.at_level(logging.WARNING, logger="agent.transcription_registry"):
|
||||
transcription_registry.register_provider(_FakeProvider(name=variant))
|
||||
assert transcription_registry.list_providers() == [], (
|
||||
f"variant {variant!r} should have been rejected as a built-in shadow"
|
||||
)
|
||||
|
||||
def test_reregistration_overwrites(self, caplog):
|
||||
p1 = _FakeProvider(name="openrouter")
|
||||
p2 = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(p1)
|
||||
with caplog.at_level(logging.DEBUG, logger="agent.transcription_registry"):
|
||||
transcription_registry.register_provider(p2)
|
||||
assert transcription_registry.get_provider("openrouter") is p2
|
||||
assert "re-registered" in caplog.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLookup:
|
||||
def test_get_provider_missing_returns_none(self):
|
||||
assert transcription_registry.get_provider("nonexistent") is None
|
||||
|
||||
def test_get_provider_non_string_returns_none(self):
|
||||
assert transcription_registry.get_provider(None) is None # type: ignore[arg-type]
|
||||
assert transcription_registry.get_provider(123) is None # type: ignore[arg-type]
|
||||
|
||||
def test_get_provider_case_insensitive(self):
|
||||
p = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(p)
|
||||
assert transcription_registry.get_provider("OPENROUTER") is p
|
||||
assert transcription_registry.get_provider("OpenRouter") is p
|
||||
|
||||
def test_get_provider_whitespace_tolerant(self):
|
||||
p = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(p)
|
||||
assert transcription_registry.get_provider(" openrouter ") is p
|
||||
|
||||
def test_list_providers_sorted(self):
|
||||
transcription_registry.register_provider(_FakeProvider(name="zylo"))
|
||||
transcription_registry.register_provider(_FakeProvider(name="alpha"))
|
||||
transcription_registry.register_provider(_FakeProvider(name="middle"))
|
||||
names = [p.name for p in transcription_registry.list_providers()]
|
||||
assert names == ["alpha", "middle", "zylo"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ABC contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestABCContract:
|
||||
def test_must_implement_transcribe(self):
|
||||
class Incomplete(TranscriptionProvider):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "incomplete"
|
||||
# transcribe NOT implemented
|
||||
|
||||
with pytest.raises(TypeError, match="abstract"):
|
||||
Incomplete() # type: ignore[abstract]
|
||||
|
||||
def test_must_implement_name(self):
|
||||
class Incomplete(TranscriptionProvider):
|
||||
def transcribe(self, file_path, **kw):
|
||||
return {"success": True, "transcript": "", "provider": "incomplete"}
|
||||
# name NOT implemented
|
||||
|
||||
with pytest.raises(TypeError, match="abstract"):
|
||||
Incomplete() # type: ignore[abstract]
|
||||
|
||||
def test_display_name_defaults_to_title(self):
|
||||
p = _FakeProvider(name="openrouter")
|
||||
assert p.display_name == "Openrouter"
|
||||
|
||||
def test_display_name_override_respected(self):
|
||||
p = _FakeProvider(name="openrouter", display="OpenRouter STT")
|
||||
assert p.display_name == "OpenRouter STT"
|
||||
|
||||
def test_is_available_default_true(self):
|
||||
p = _FakeProvider(name="openrouter")
|
||||
assert p.is_available() is True
|
||||
|
||||
def test_list_models_default_empty(self):
|
||||
p = _FakeProvider(name="openrouter")
|
||||
assert p.list_models() == []
|
||||
|
||||
def test_default_model_none_when_no_models(self):
|
||||
p = _FakeProvider(name="openrouter")
|
||||
assert p.default_model() is None
|
||||
|
||||
def test_default_model_first_listed(self):
|
||||
class WithModels(_FakeProvider):
|
||||
def list_models(self):
|
||||
return [{"id": "whisper-large-v3-turbo"}, {"id": "whisper-large-v3"}]
|
||||
|
||||
p = WithModels(name="openrouter")
|
||||
assert p.default_model() == "whisper-large-v3-turbo"
|
||||
|
||||
def test_get_setup_schema_default_minimal(self):
|
||||
p = _FakeProvider(name="openrouter")
|
||||
schema = p.get_setup_schema()
|
||||
assert schema["name"] == "Openrouter"
|
||||
assert schema["env_vars"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync invariant: registry built-ins vs dispatcher built-ins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuiltinSync:
|
||||
"""``_BUILTIN_NAMES`` in agent/transcription_registry.py is duplicated
|
||||
from ``BUILTIN_STT_PROVIDERS`` in tools/transcription_tools.py
|
||||
(importing directly would create a circular dependency). This test
|
||||
fails loudly if the two lists drift — a new built-in added to
|
||||
transcription_tools.py MUST also be added to
|
||||
transcription_registry.py's ``_BUILTIN_NAMES`` or the registry will
|
||||
accept a name the dispatcher will silently route to the wrong
|
||||
handler.
|
||||
"""
|
||||
|
||||
def test_registry_builtins_match_dispatcher_builtins(self):
|
||||
from tools.transcription_tools import BUILTIN_STT_PROVIDERS
|
||||
|
||||
assert transcription_registry._BUILTIN_NAMES == BUILTIN_STT_PROVIDERS, (
|
||||
"agent.transcription_registry._BUILTIN_NAMES and "
|
||||
"tools.transcription_tools.BUILTIN_STT_PROVIDERS have drifted!\n"
|
||||
f" Registry only: {sorted(transcription_registry._BUILTIN_NAMES - BUILTIN_STT_PROVIDERS)}\n"
|
||||
f" Dispatcher only: {sorted(BUILTIN_STT_PROVIDERS - transcription_registry._BUILTIN_NAMES)}\n"
|
||||
"Add the missing names to whichever list is incomplete. "
|
||||
"These two lists exist as a circular-import workaround and "
|
||||
"MUST be kept in sync manually."
|
||||
)
|
||||
148
tests/hermes_cli/test_plugins_transcription_registration.py
Normal file
148
tests/hermes_cli/test_plugins_transcription_registration.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""Tests for PluginContext.register_transcription_provider().
|
||||
|
||||
Exercises the plugin context hook end-to-end: drops a fake plugin into
|
||||
``$HERMES_HOME/plugins/``, runs ``PluginManager().discover_and_load()``,
|
||||
and asserts the registration result.
|
||||
|
||||
Mirrors the shape of ``test_plugins_tts_registration.py`` (companion
|
||||
TTS hook from issue #30398).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def _write_plugin(
|
||||
root: Path,
|
||||
name: str,
|
||||
*,
|
||||
manifest_extra: Dict[str, Any] | None = None,
|
||||
register_body: str = "pass",
|
||||
) -> Path:
|
||||
plugin_dir = root / name
|
||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||
manifest = {
|
||||
"name": name,
|
||||
"version": "0.1.0",
|
||||
"description": f"Test plugin {name}",
|
||||
}
|
||||
if manifest_extra:
|
||||
manifest.update(manifest_extra)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump(manifest))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
f"def register(ctx):\n {register_body}\n"
|
||||
)
|
||||
return plugin_dir
|
||||
|
||||
|
||||
def _enable(hermes_home: Path, name: str) -> None:
|
||||
cfg_path = hermes_home / "config.yaml"
|
||||
cfg: dict = {}
|
||||
if cfg_path.exists():
|
||||
try:
|
||||
cfg = yaml.safe_load(cfg_path.read_text()) or {}
|
||||
except Exception:
|
||||
cfg = {}
|
||||
plugins_cfg = cfg.setdefault("plugins", {})
|
||||
enabled = plugins_cfg.setdefault("enabled", [])
|
||||
if isinstance(enabled, list) and name not in enabled:
|
||||
enabled.append(name)
|
||||
cfg_path.write_text(yaml.safe_dump(cfg))
|
||||
|
||||
|
||||
class TestRegisterTranscriptionProvider:
|
||||
def test_accepts_valid_provider(self):
|
||||
from hermes_cli.plugins import PluginManager
|
||||
|
||||
from agent import transcription_registry
|
||||
transcription_registry._reset_for_tests()
|
||||
|
||||
hermes_home = Path(os.environ["HERMES_HOME"])
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
"my-stt-plugin",
|
||||
register_body=(
|
||||
"from agent.transcription_provider import TranscriptionProvider\n"
|
||||
" class P(TranscriptionProvider):\n"
|
||||
" @property\n"
|
||||
" def name(self): return 'fake-stt'\n"
|
||||
" def transcribe(self, file_path, **kw):\n"
|
||||
" return {'success': True, 'transcript': 'hi', 'provider': 'fake-stt'}\n"
|
||||
" ctx.register_transcription_provider(P())"
|
||||
),
|
||||
)
|
||||
_enable(hermes_home, "my-stt-plugin")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert mgr._plugins["my-stt-plugin"].enabled is True, (
|
||||
f"Plugin failed to load: {mgr._plugins['my-stt-plugin'].error}"
|
||||
)
|
||||
assert transcription_registry.get_provider("fake-stt") is not None
|
||||
|
||||
transcription_registry._reset_for_tests()
|
||||
|
||||
def test_rejects_non_provider(self, caplog):
|
||||
from hermes_cli.plugins import PluginManager
|
||||
|
||||
from agent import transcription_registry
|
||||
transcription_registry._reset_for_tests()
|
||||
|
||||
hermes_home = Path(os.environ["HERMES_HOME"])
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
"bad-stt-plugin",
|
||||
register_body="ctx.register_transcription_provider('not a provider')",
|
||||
)
|
||||
_enable(hermes_home, "bad-stt-plugin")
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert mgr._plugins["bad-stt-plugin"].enabled is True
|
||||
assert transcription_registry.get_provider("not a provider") is None
|
||||
assert transcription_registry.list_providers() == []
|
||||
assert "does not inherit from TranscriptionProvider" in caplog.text
|
||||
|
||||
transcription_registry._reset_for_tests()
|
||||
|
||||
def test_rejects_builtin_shadow(self, caplog):
|
||||
from hermes_cli.plugins import PluginManager
|
||||
|
||||
from agent import transcription_registry
|
||||
transcription_registry._reset_for_tests()
|
||||
|
||||
hermes_home = Path(os.environ["HERMES_HOME"])
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
"shadow-stt-plugin",
|
||||
register_body=(
|
||||
"from agent.transcription_provider import TranscriptionProvider\n"
|
||||
" class P(TranscriptionProvider):\n"
|
||||
" @property\n"
|
||||
" def name(self): return 'openai'\n"
|
||||
" def transcribe(self, file_path, **kw):\n"
|
||||
" return {'success': True, 'transcript': 'hi'}\n"
|
||||
" ctx.register_transcription_provider(P())"
|
||||
),
|
||||
)
|
||||
_enable(hermes_home, "shadow-stt-plugin")
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Plugin still loaded normally — built-in shadowing is a warning,
|
||||
# not an exception. The registry rejects the entry though.
|
||||
assert mgr._plugins["shadow-stt-plugin"].enabled is True
|
||||
assert transcription_registry.get_provider("openai") is None
|
||||
assert "shadows a built-in name" in caplog.text
|
||||
|
||||
transcription_registry._reset_for_tests()
|
||||
0
tests/plugins/transcription/__init__.py
Normal file
0
tests/plugins/transcription/__init__.py
Normal file
344
tests/plugins/transcription/check_parity_vs_main.py
Normal file
344
tests/plugins/transcription/check_parity_vs_main.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
"""Behavior-parity check for the STT plugin hook (follow-up to #30398).
|
||||
|
||||
Spawns one subprocess per (version, scenario) cell — pinned to either
|
||||
``origin/main`` (no plugin hook; ``stt.provider: openrouter`` falls
|
||||
through to the "No STT provider available" error path) or this PR's
|
||||
worktree (plugin hook present; same config routes through the plugin
|
||||
registry when a plugin is registered).
|
||||
|
||||
Each subprocess clears all STT-related env vars + writes a
|
||||
``config.yaml``, then asks the dispatcher how it would route a
|
||||
``transcribe_audio`` call. The emitted shape tuple is::
|
||||
|
||||
{dispatch_kind, provider_name, success}
|
||||
|
||||
Where ``dispatch_kind`` ∈
|
||||
``{"builtin_local", "builtin_groq", "builtin_openai", ...,
|
||||
"plugin", "plugin_unavailable", "no_provider_error", "stt_disabled"}``.
|
||||
|
||||
Acceptable diffs:
|
||||
- ``no_provider_error → plugin`` for the ``plugin-installed`` scenario.
|
||||
- ``no_provider_error → plugin_unavailable`` for the
|
||||
``plugin-installed-unavailable`` scenario (PR returns the cleaner
|
||||
unavailability envelope instead of the generic auto-detect error).
|
||||
|
||||
Run from the PR worktree::
|
||||
|
||||
python tests/plugins/transcription/check_parity_vs_main.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
|
||||
|
||||
def _resolve_main_dir() -> Path:
|
||||
candidate = REPO_ROOT.parent.parent
|
||||
if (candidate / "tools" / "transcription_tools.py").exists() and candidate != REPO_ROOT:
|
||||
return candidate
|
||||
sibling = REPO_ROOT.parent / "hermes-agent-main"
|
||||
if (sibling / "tools" / "transcription_tools.py").exists():
|
||||
return sibling
|
||||
return REPO_ROOT
|
||||
|
||||
|
||||
MAIN_DIR = _resolve_main_dir()
|
||||
PR_DIR = REPO_ROOT
|
||||
assert (PR_DIR / "tools" / "transcription_tools.py").exists(), (
|
||||
f"PR_DIR={PR_DIR} doesn't look like a hermes-agent checkout"
|
||||
)
|
||||
|
||||
|
||||
SUBPROCESS_SCRIPT = r"""
|
||||
import json, os, sys, tempfile
|
||||
sys.path.insert(0, sys.argv[1])
|
||||
|
||||
# Isolated HERMES_HOME so the config write is hermetic.
|
||||
home = tempfile.mkdtemp()
|
||||
os.environ["HERMES_HOME"] = home
|
||||
|
||||
# Clear STT-related env so dispatch decisions are config-driven.
|
||||
for k in (
|
||||
"GROQ_API_KEY", "OPENAI_API_KEY", "VOICE_TOOLS_OPENAI_KEY",
|
||||
"MISTRAL_API_KEY", "XAI_API_KEY",
|
||||
"HERMES_LOCAL_STT_COMMAND",
|
||||
):
|
||||
os.environ.pop(k, None)
|
||||
|
||||
scenario_env = json.loads(sys.argv[2])
|
||||
os.environ.update(scenario_env)
|
||||
|
||||
config_yaml = sys.argv[3]
|
||||
plugin_register = sys.argv[4] # "yes" to register a fake plugin
|
||||
|
||||
config_path = os.path.join(home, "config.yaml")
|
||||
with open(config_path, "w") as f:
|
||||
f.write(config_yaml)
|
||||
|
||||
# Fresh import — must not have anything cached from prior runs.
|
||||
for name in list(sys.modules):
|
||||
if (name.startswith("tools.")
|
||||
or name.startswith("agent.")
|
||||
or name.startswith("plugins.")
|
||||
or name.startswith("hermes_cli.")):
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
# Try importing transcription_registry — only exists on PR side.
|
||||
have_plugin_hook = False
|
||||
try:
|
||||
from agent import transcription_registry
|
||||
from agent.transcription_provider import TranscriptionProvider
|
||||
have_plugin_hook = True
|
||||
|
||||
if plugin_register == "yes":
|
||||
class _FakeProvider(TranscriptionProvider):
|
||||
@property
|
||||
def name(self): return "openrouter"
|
||||
def transcribe(self, file_path, **kw):
|
||||
return {"success": True, "transcript": "plugin transcript", "provider": "openrouter"}
|
||||
|
||||
transcription_registry._reset_for_tests()
|
||||
transcription_registry.register_provider(_FakeProvider())
|
||||
elif plugin_register == "unavailable":
|
||||
class _UnavailablePlugin(TranscriptionProvider):
|
||||
@property
|
||||
def name(self): return "openrouter"
|
||||
def is_available(self): return False
|
||||
def transcribe(self, file_path, **kw):
|
||||
return {"success": True, "transcript": "should not run"}
|
||||
|
||||
transcription_registry._reset_for_tests()
|
||||
transcription_registry.register_provider(_UnavailablePlugin())
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import tools.transcription_tools as tt
|
||||
|
||||
# Use a real (but empty) audio file so _validate_audio_file passes.
|
||||
audio_path = os.path.join(home, "audio.ogg")
|
||||
with open(audio_path, "wb") as f:
|
||||
# Minimal-ish OGG-shaped bytes so the size check passes.
|
||||
f.write(b"OggS" + b"\x00" * 1024)
|
||||
|
||||
# Patch _transcribe_* so the test doesn't actually try cloud APIs.
|
||||
# We're testing dispatch, not the underlying transcription.
|
||||
def _stub(file_path, model_name=None):
|
||||
return {"success": True, "transcript": "stub from " + sys._getframe().f_code.co_name.replace("_stub_", ""),
|
||||
"provider": sys._getframe().f_code.co_name.replace("_stub_", "")}
|
||||
|
||||
# Stub each built-in to a marker so we can identify the branch.
|
||||
class _Stub:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, file_path, model_name=None):
|
||||
return {"success": True, "transcript": "stub", "provider": self.name}
|
||||
|
||||
tt._transcribe_local = _Stub("local")
|
||||
tt._transcribe_local_command = _Stub("local_command")
|
||||
tt._transcribe_groq = _Stub("groq")
|
||||
tt._transcribe_openai = _Stub("openai")
|
||||
tt._transcribe_mistral = _Stub("mistral")
|
||||
tt._transcribe_xai = _Stub("xai")
|
||||
|
||||
# Force _get_provider to honor the explicit config since we don't have
|
||||
# real creds. The provider-resolution gates check _HAS_OPENAI /
|
||||
# _HAS_FASTER_WHISPER which we can't easily set, so we just patch
|
||||
# _get_provider to return whatever the config says.
|
||||
stt_cfg = tt._load_stt_config()
|
||||
explicit = stt_cfg.get("provider")
|
||||
if explicit:
|
||||
# Bypass the gating for test purposes — _get_provider would
|
||||
# otherwise return "none" when the dependency isn't installed.
|
||||
original_get = tt._get_provider
|
||||
def _patched(cfg):
|
||||
if not tt.is_stt_enabled(cfg):
|
||||
return "none"
|
||||
return cfg.get("provider", "none")
|
||||
tt._get_provider = _patched
|
||||
|
||||
try:
|
||||
result = tt.transcribe_audio(audio_path)
|
||||
except Exception as exc:
|
||||
shape = {"dispatch_kind": "exception", "provider_name": None, "success": False,
|
||||
"error_text": repr(exc)}
|
||||
print(json.dumps(shape))
|
||||
sys.exit(0)
|
||||
|
||||
dispatch_kind = "unknown"
|
||||
provider_name = result.get("provider") if isinstance(result, dict) else None
|
||||
success = result.get("success", False) if isinstance(result, dict) else False
|
||||
error_text = result.get("error", "") if isinstance(result, dict) else ""
|
||||
|
||||
if not success and "STT is disabled" in error_text:
|
||||
dispatch_kind = "stt_disabled"
|
||||
elif not success and "is not available" in error_text:
|
||||
dispatch_kind = "plugin_unavailable"
|
||||
elif not success and "No STT provider" in error_text:
|
||||
dispatch_kind = "no_provider_error"
|
||||
elif provider_name in ("local", "local_command", "groq", "openai", "mistral", "xai"):
|
||||
dispatch_kind = "builtin_" + provider_name
|
||||
elif success and provider_name and provider_name not in ("local", "local_command", "groq", "openai", "mistral", "xai"):
|
||||
dispatch_kind = "plugin"
|
||||
else:
|
||||
dispatch_kind = "other"
|
||||
|
||||
shape = {
|
||||
"dispatch_kind": dispatch_kind,
|
||||
"provider_name": provider_name,
|
||||
"success": success,
|
||||
}
|
||||
print(json.dumps(shape))
|
||||
"""
|
||||
|
||||
|
||||
SCENARIOS: list[tuple[str, str, dict[str, str], str]] = [
|
||||
# (label, config.yaml body, scenario_env, plugin_register)
|
||||
("stt-disabled", "stt:\n enabled: false\n", {}, "no"),
|
||||
("explicit-groq", "stt:\n provider: groq\n", {}, "no"),
|
||||
("explicit-openai", "stt:\n provider: openai\n", {}, "no"),
|
||||
("explicit-local", "stt:\n provider: local\n", {}, "no"),
|
||||
("explicit-xai", "stt:\n provider: xai\n", {}, "no"),
|
||||
# Mistral is quarantined → _get_provider returns "none" today, hence no_provider_error.
|
||||
("explicit-mistral-quarantine", "stt:\n provider: mistral\n", {}, "no"),
|
||||
# Unknown name + no plugin → both: no_provider_error
|
||||
("unknown-no-plugin", "stt:\n provider: openrouter\n", {}, "no"),
|
||||
# Unknown name + plugin installed → main: no_provider_error, PR: plugin
|
||||
("plugin-installed", "stt:\n provider: openrouter\n", {}, "yes"),
|
||||
# Unknown name + plugin reports unavailable → main: no_provider_error,
|
||||
# PR: plugin_unavailable (cleaner envelope, names the plugin)
|
||||
("plugin-installed-unavailable", "stt:\n provider: openrouter\n", {}, "unavailable"),
|
||||
# Built-in name + plugin tries to shadow → both: built-in
|
||||
("explicit-openai-with-plugin-registered", "stt:\n provider: openai\n", {}, "yes"),
|
||||
]
|
||||
|
||||
|
||||
def _run_scenario(repo_path: Path, label: str, config_yaml: str, env: dict, plugin_register: str) -> dict:
|
||||
venv_python = repo_path / ".venv" / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = MAIN_DIR / ".venv" / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = MAIN_DIR / "venv" / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = Path("python3")
|
||||
|
||||
out = subprocess.run(
|
||||
[
|
||||
str(venv_python),
|
||||
"-c",
|
||||
SUBPROCESS_SCRIPT,
|
||||
str(repo_path),
|
||||
json.dumps(env),
|
||||
config_yaml,
|
||||
plugin_register,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
if out.returncode != 0:
|
||||
return {
|
||||
"error": "subprocess failed",
|
||||
"stdout": out.stdout[-500:],
|
||||
"stderr": out.stderr[-500:],
|
||||
}
|
||||
try:
|
||||
return json.loads(out.stdout.strip().splitlines()[-1])
|
||||
except Exception as exc:
|
||||
return {"error": f"could not parse output: {exc}", "stdout": out.stdout}
|
||||
|
||||
|
||||
def _reduce(shape: dict) -> dict:
|
||||
return {
|
||||
"dispatch_kind": shape.get("dispatch_kind"),
|
||||
"success": shape.get("success"),
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
print(f"main: {MAIN_DIR}")
|
||||
print(f"pr: {PR_DIR}")
|
||||
print()
|
||||
|
||||
if MAIN_DIR == PR_DIR:
|
||||
print(
|
||||
"WARN: MAIN_DIR == PR_DIR — diffs will be trivially identical.\n"
|
||||
" Set up a sibling 'hermes-agent-main' checkout pinned to "
|
||||
"origin/main to get real parity coverage."
|
||||
)
|
||||
print()
|
||||
|
||||
failures: list[str] = []
|
||||
errors: list[str] = []
|
||||
intentional_diffs: list[tuple[str, dict, dict]] = []
|
||||
for label, config_yaml, env, plugin_register in SCENARIOS:
|
||||
main_shape = _run_scenario(MAIN_DIR, label, config_yaml, env, plugin_register)
|
||||
pr_shape = _run_scenario(PR_DIR, label, config_yaml, env, plugin_register)
|
||||
|
||||
if "error" in main_shape or "error" in pr_shape:
|
||||
print(f" [ERR ] {label}: subprocess failed")
|
||||
print(f" main: {main_shape}")
|
||||
print(f" pr: {pr_shape}")
|
||||
errors.append(label)
|
||||
continue
|
||||
|
||||
main_reduced = _reduce(main_shape)
|
||||
pr_reduced = _reduce(pr_shape)
|
||||
|
||||
if main_reduced == pr_reduced:
|
||||
print(f" [OK] {label}: {main_reduced}")
|
||||
continue
|
||||
|
||||
# On main, "plugin-installed" returns no_provider_error (no
|
||||
# plugin hook); on PR, plugin dispatches. Same shape for
|
||||
# "plugin-installed-unavailable" but PR returns the cleaner
|
||||
# plugin_unavailable envelope. Both diffs are expected.
|
||||
no_provider_to_plugin = (
|
||||
main_reduced.get("dispatch_kind") == "no_provider_error"
|
||||
and pr_reduced.get("dispatch_kind") == "plugin"
|
||||
and label == "plugin-installed"
|
||||
)
|
||||
no_provider_to_unavailable = (
|
||||
main_reduced.get("dispatch_kind") == "no_provider_error"
|
||||
and pr_reduced.get("dispatch_kind") == "plugin_unavailable"
|
||||
and label == "plugin-installed-unavailable"
|
||||
)
|
||||
if no_provider_to_plugin:
|
||||
print(f" [DIFF] {label}: no_provider_error → plugin — expected")
|
||||
intentional_diffs.append((label, main_reduced, pr_reduced))
|
||||
elif no_provider_to_unavailable:
|
||||
print(f" [DIFF] {label}: no_provider_error → plugin_unavailable — expected")
|
||||
intentional_diffs.append((label, main_reduced, pr_reduced))
|
||||
else:
|
||||
print(f" [FAIL] {label}")
|
||||
print(f" main: {main_reduced}")
|
||||
print(f" pr: {pr_reduced}")
|
||||
failures.append(label)
|
||||
|
||||
print()
|
||||
if errors:
|
||||
print(f"SUBPROCESS ERRORS in {len(errors)} scenario(s):")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
if failures:
|
||||
print(f"BEHAVIOUR REGRESSION in {len(failures)} scenario(s):")
|
||||
for f in failures:
|
||||
print(f" - {f}")
|
||||
if intentional_diffs:
|
||||
print(
|
||||
f"INTENTIONAL DIFFS ({len(intentional_diffs)}): "
|
||||
f"no_provider_error → plugin dispatch when a plugin is registered."
|
||||
)
|
||||
if failures or errors:
|
||||
return 1
|
||||
print(f"PARITY OK across {len(SCENARIOS)} scenarios.")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
462
tests/tools/test_transcription_plugin_dispatch.py
Normal file
462
tests/tools/test_transcription_plugin_dispatch.py
Normal file
|
|
@ -0,0 +1,462 @@
|
|||
"""Tests for STT plugin dispatch in tools/transcription_tools.py.
|
||||
|
||||
Covers the resolution invariants of the new plugin dispatcher (follow-up
|
||||
to #30398 — STT pluggability):
|
||||
|
||||
1. Built-in provider names short-circuit — plugins NEVER win over a
|
||||
built-in. Even if a plugin somehow ended up in the registry with a
|
||||
built-in name (which the registry blocks), the dispatcher re-checks
|
||||
defensively.
|
||||
2. Unknown name with no plugin → returns None (caller surfaces the
|
||||
legacy "No STT provider available" error).
|
||||
3. Unknown name with plugin registered → dispatches, returns result.
|
||||
4. Plugin exceptions are caught and converted to the standard error
|
||||
envelope.
|
||||
5. Plugin returning non-dict → caught with error envelope.
|
||||
6. Plugin result has ``provider`` field stamped if missing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import transcription_registry
|
||||
from agent.transcription_provider import TranscriptionProvider
|
||||
from tools import transcription_tools
|
||||
|
||||
|
||||
class _FakeProvider(TranscriptionProvider):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
result: dict | None = None,
|
||||
raise_exc: BaseException | None = None,
|
||||
available: bool = True,
|
||||
available_raises: BaseException | None = None,
|
||||
):
|
||||
self._name = name
|
||||
self._result = result
|
||||
self._raise_exc = raise_exc
|
||||
self._available = available
|
||||
self._available_raises = available_raises
|
||||
self.last_call: dict | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if self._available_raises is not None:
|
||||
raise self._available_raises
|
||||
return self._available
|
||||
|
||||
def transcribe(self, file_path: str, **kw):
|
||||
self.last_call = {"file_path": file_path, "kwargs": dict(kw)}
|
||||
if self._raise_exc is not None:
|
||||
raise self._raise_exc
|
||||
if self._result is not None:
|
||||
return self._result
|
||||
return {"success": True, "transcript": "fake transcript", "provider": self._name}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
transcription_registry._reset_for_tests()
|
||||
yield
|
||||
transcription_registry._reset_for_tests()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in always wins
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuiltinAlwaysWins:
|
||||
"""Built-in STT provider names short-circuit the dispatcher.
|
||||
|
||||
Even with a plugin registered (which the registry would reject —
|
||||
but the dispatcher is defensive), built-in names return None so
|
||||
the caller's elif chain handles them natively.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"builtin",
|
||||
["local", "local_command", "groq", "openai", "mistral", "xai"],
|
||||
)
|
||||
def test_dispatcher_short_circuits_builtin(self, builtin):
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", builtin,
|
||||
)
|
||||
assert result is None, (
|
||||
f"Built-in {builtin!r} must short-circuit plugin dispatch."
|
||||
)
|
||||
|
||||
def test_dispatcher_short_circuits_none(self):
|
||||
"""The ``none`` sentinel from _get_provider() means no provider
|
||||
available — must not reach plugin registry."""
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "none",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_dispatcher_short_circuits_empty(self):
|
||||
assert transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "",
|
||||
) is None
|
||||
|
||||
def test_dispatcher_short_circuits_builtin_case_insensitive(self):
|
||||
for variant in ("OPENAI", "OpenAI", " openai ", "oPeNaI"):
|
||||
assert (
|
||||
transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", variant,
|
||||
) is None
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unknown names
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginDispatch:
|
||||
def test_registered_plugin_called(self):
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter",
|
||||
)
|
||||
assert result is not None
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "fake transcript"
|
||||
assert result["provider"] == "openrouter"
|
||||
assert provider.last_call is not None
|
||||
assert provider.last_call["file_path"] == "/tmp/audio.mp3"
|
||||
|
||||
def test_unregistered_name_returns_none(self):
|
||||
"""Unknown name + no plugin → return None so the caller surfaces
|
||||
the legacy 'No STT provider available' error."""
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "unknown-stt",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_model_kwarg_forwarded(self):
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter", model="whisper-large-v3",
|
||||
)
|
||||
assert provider.last_call["kwargs"]["model"] == "whisper-large-v3"
|
||||
|
||||
def test_language_kwarg_forwarded(self):
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter", language="en",
|
||||
)
|
||||
assert provider.last_call["kwargs"]["language"] == "en"
|
||||
|
||||
def test_provider_exception_converted_to_error_envelope(self):
|
||||
provider = _FakeProvider(name="openrouter", raise_exc=RuntimeError("network down"))
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter",
|
||||
)
|
||||
assert result is not None
|
||||
assert result["success"] is False
|
||||
assert "network down" in result["error"]
|
||||
assert result["transcript"] == ""
|
||||
assert result["provider"] == "openrouter"
|
||||
|
||||
def test_provider_non_dict_result_converted_to_error(self):
|
||||
provider = _FakeProvider(name="openrouter", result="weird string") # type: ignore[arg-type]
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter",
|
||||
)
|
||||
assert result is not None
|
||||
assert result["success"] is False
|
||||
assert "non-dict" in result["error"]
|
||||
assert result["provider"] == "openrouter"
|
||||
|
||||
def test_provider_field_stamped_if_missing(self):
|
||||
"""If a plugin forgets to set ``provider`` in its result, the
|
||||
dispatcher stamps it from the registered name."""
|
||||
provider = _FakeProvider(
|
||||
name="openrouter",
|
||||
result={"success": True, "transcript": "hi"}, # no provider key
|
||||
)
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter",
|
||||
)
|
||||
assert result is not None
|
||||
assert result["provider"] == "openrouter"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end via transcribe_audio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTranscribeAudioE2E:
|
||||
"""transcribe_audio() routes plugin dispatch correctly when the
|
||||
configured name is unknown to the built-in branches.
|
||||
|
||||
Note: we mock _validate_audio_file and _get_provider so the real
|
||||
file-validation and provider-resolution don't fire — we're testing
|
||||
the plugin-dispatch wiring, not those helpers.
|
||||
"""
|
||||
|
||||
def test_unknown_name_with_plugin_dispatches(self):
|
||||
from unittest.mock import patch
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openrouter"}), \
|
||||
patch("tools.transcription_tools.is_stt_enabled", return_value=True), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openrouter"):
|
||||
result = transcription_tools.transcribe_audio("/tmp/audio.mp3")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "fake transcript"
|
||||
assert result["provider"] == "openrouter"
|
||||
|
||||
def test_unknown_name_without_plugin_falls_to_legacy_error(self):
|
||||
"""When no plugin is registered for the unknown name, the
|
||||
dispatcher returns None and transcribe_audio falls through to
|
||||
the legacy 'No STT provider available' error message."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openrouter"}), \
|
||||
patch("tools.transcription_tools.is_stt_enabled", return_value=True), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openrouter"):
|
||||
result = transcription_tools.transcribe_audio("/tmp/audio.mp3")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "No STT provider" in result["error"]
|
||||
|
||||
def test_builtin_name_does_not_consult_plugin_registry(self):
|
||||
"""Even if a plugin's name collides with a built-in (which the
|
||||
registry blocks, but defense in depth matters), transcribe_audio
|
||||
with provider='groq' goes through the legacy elif chain, never
|
||||
the plugin dispatcher."""
|
||||
from unittest.mock import patch
|
||||
# Register a plugin that WOULD respond to 'openrouter' — but
|
||||
# we're asking for 'groq', so it shouldn't be called.
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value={"provider": "groq"}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="groq"), \
|
||||
patch("tools.transcription_tools._transcribe_groq",
|
||||
return_value={"success": True, "transcript": "from groq", "provider": "groq"}) as mock_groq:
|
||||
result = transcription_tools.transcribe_audio("/tmp/audio.mp3")
|
||||
|
||||
assert result["provider"] == "groq"
|
||||
assert result["transcript"] == "from groq"
|
||||
mock_groq.assert_called_once()
|
||||
# Plugin was never called
|
||||
assert provider.last_call is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability gating (codex review feedback on PR #30493)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAvailabilityGate:
|
||||
"""When the configured plugin reports ``is_available() == False``,
|
||||
the dispatcher MUST short-circuit with a clear unavailability
|
||||
envelope instead of routing the call into a plugin that'll crash.
|
||||
|
||||
The user explicitly set ``stt.provider: <plugin>`` so falling
|
||||
through to the generic "No STT provider available" message would
|
||||
be misleading — surface the plugin's own unavailability instead.
|
||||
"""
|
||||
|
||||
def test_unavailable_plugin_returns_envelope_not_none(self):
|
||||
provider = _FakeProvider(name="openrouter", available=False)
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter",
|
||||
)
|
||||
assert result is not None, (
|
||||
"Unavailable plugin must return an envelope, not None — "
|
||||
"otherwise we fall through to the generic auto-detect error "
|
||||
"even though the user explicitly opted into this plugin."
|
||||
)
|
||||
assert result["success"] is False
|
||||
assert result["provider"] == "openrouter"
|
||||
assert "not available" in result["error"]
|
||||
# Plugin's transcribe MUST NOT have been called
|
||||
assert provider.last_call is None
|
||||
|
||||
def test_available_plugin_dispatches_normally(self):
|
||||
provider = _FakeProvider(name="openrouter", available=True)
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter",
|
||||
)
|
||||
assert result["success"] is True
|
||||
assert provider.last_call is not None
|
||||
|
||||
def test_is_available_raising_treated_as_unavailable(self):
|
||||
"""Per the ABC contract ``is_available()`` MUST NOT raise; we
|
||||
defend anyway so a buggy plugin can't break dispatch."""
|
||||
provider = _FakeProvider(
|
||||
name="openrouter",
|
||||
available_raises=RuntimeError("creds check exploded"),
|
||||
)
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
result = transcription_tools._dispatch_to_plugin_provider(
|
||||
"/tmp/audio.mp3", "openrouter",
|
||||
)
|
||||
assert result is not None
|
||||
assert result["success"] is False
|
||||
assert result["provider"] == "openrouter"
|
||||
assert "not available" in result["error"]
|
||||
assert provider.last_call is None
|
||||
|
||||
def test_unavailable_plugin_at_transcribe_audio_level(self):
|
||||
"""End-to-end: ``stt.provider: openrouter`` + plugin reports
|
||||
unavailable → ``transcribe_audio`` returns the unavailability
|
||||
envelope, NOT the generic "No STT provider available" message.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
provider = _FakeProvider(name="openrouter", available=False)
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openrouter"}), \
|
||||
patch("tools.transcription_tools.is_stt_enabled", return_value=True), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openrouter"):
|
||||
result = transcription_tools.transcribe_audio("/tmp/audio.mp3")
|
||||
|
||||
assert result["success"] is False
|
||||
# Must surface the plugin's unavailability — NOT the generic
|
||||
# "No STT provider available" auto-detect-failure message.
|
||||
assert "not available" in result["error"]
|
||||
assert "No STT provider available" not in result["error"]
|
||||
assert result["provider"] == "openrouter"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Language forwarding from config (codex review feedback on PR #30493)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLanguageForwardingFromConfig:
|
||||
"""``transcribe_audio`` must forward ``stt.<provider>.language``
|
||||
from config to the plugin (mirrors how built-ins read
|
||||
``stt.local.language``).
|
||||
"""
|
||||
|
||||
def test_language_read_from_provider_namespaced_config(self):
|
||||
"""``stt.openrouter.language: ja`` reaches the plugin's
|
||||
transcribe() call as language='ja'."""
|
||||
from unittest.mock import patch
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
stt_config = {
|
||||
"provider": "openrouter",
|
||||
"openrouter": {"language": "ja"},
|
||||
}
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \
|
||||
patch("tools.transcription_tools.is_stt_enabled", return_value=True), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openrouter"):
|
||||
transcription_tools.transcribe_audio("/tmp/audio.mp3")
|
||||
|
||||
assert provider.last_call is not None
|
||||
assert provider.last_call["kwargs"]["language"] == "ja"
|
||||
|
||||
def test_model_from_provider_namespaced_config(self):
|
||||
"""``stt.openrouter.model: whisper-large-v3`` reaches the
|
||||
plugin as model='whisper-large-v3' when caller doesn't
|
||||
override."""
|
||||
from unittest.mock import patch
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
stt_config = {
|
||||
"provider": "openrouter",
|
||||
"openrouter": {"model": "whisper-large-v3"},
|
||||
}
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \
|
||||
patch("tools.transcription_tools.is_stt_enabled", return_value=True), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openrouter"):
|
||||
transcription_tools.transcribe_audio("/tmp/audio.mp3")
|
||||
|
||||
assert provider.last_call["kwargs"]["model"] == "whisper-large-v3"
|
||||
|
||||
def test_caller_model_overrides_config_model(self):
|
||||
"""An explicit ``model`` arg to transcribe_audio wins over
|
||||
``stt.<provider>.model`` in config."""
|
||||
from unittest.mock import patch
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
stt_config = {
|
||||
"provider": "openrouter",
|
||||
"openrouter": {"model": "config-model"},
|
||||
}
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \
|
||||
patch("tools.transcription_tools.is_stt_enabled", return_value=True), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openrouter"):
|
||||
transcription_tools.transcribe_audio(
|
||||
"/tmp/audio.mp3", model="explicit-arg-model",
|
||||
)
|
||||
|
||||
assert provider.last_call["kwargs"]["model"] == "explicit-arg-model"
|
||||
|
||||
def test_missing_provider_namespace_passes_none(self):
|
||||
"""No ``stt.<provider>`` subsection → language is None,
|
||||
model falls back to caller arg or None. No crash."""
|
||||
from unittest.mock import patch
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openrouter"}), \
|
||||
patch("tools.transcription_tools.is_stt_enabled", return_value=True), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openrouter"):
|
||||
transcription_tools.transcribe_audio("/tmp/audio.mp3")
|
||||
|
||||
assert provider.last_call["kwargs"]["language"] is None
|
||||
assert provider.last_call["kwargs"]["model"] is None
|
||||
|
||||
def test_non_dict_provider_namespace_does_not_crash(self):
|
||||
"""If someone accidentally writes ``stt.openrouter: "foo"`` (a
|
||||
string instead of a dict), we should not crash — treat as
|
||||
empty config."""
|
||||
from unittest.mock import patch
|
||||
provider = _FakeProvider(name="openrouter")
|
||||
transcription_registry.register_provider(provider)
|
||||
|
||||
stt_config = {"provider": "openrouter", "openrouter": "garbage"}
|
||||
with patch("tools.transcription_tools._validate_audio_file", return_value=None), \
|
||||
patch("tools.transcription_tools._load_stt_config", return_value=stt_config), \
|
||||
patch("tools.transcription_tools.is_stt_enabled", return_value=True), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="openrouter"):
|
||||
result = transcription_tools.transcribe_audio("/tmp/audio.mp3")
|
||||
|
||||
# Should still dispatch successfully (config is just ignored)
|
||||
assert result["success"] is True
|
||||
assert provider.last_call["kwargs"]["language"] is None
|
||||
assert provider.last_call["kwargs"]["model"] is None
|
||||
|
|
@ -217,6 +217,22 @@ def _try_lazy_install_stt() -> bool:
|
|||
return False
|
||||
|
||||
|
||||
# Names of the 6 STT providers with native handlers in this module.
|
||||
# Kept in sync with ``agent.transcription_registry._BUILTIN_NAMES`` —
|
||||
# a regression test fails if they drift. The plugin hook from
|
||||
# issue #30398-style follow-up rejects plugins registering under any
|
||||
# of these names; the dispatcher in ``transcribe_audio`` short-circuits
|
||||
# them defensively as well.
|
||||
BUILTIN_STT_PROVIDERS = frozenset({
|
||||
"local",
|
||||
"local_command",
|
||||
"groq",
|
||||
"openai",
|
||||
"mistral",
|
||||
"xai",
|
||||
})
|
||||
|
||||
|
||||
def _get_provider(stt_config: dict) -> str:
|
||||
"""Determine which STT provider to use.
|
||||
|
||||
|
|
@ -327,6 +343,142 @@ def _get_provider(stt_config: dict) -> str:
|
|||
pass
|
||||
return "none"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin provider dispatch (issue follow-up to #30398 — STT pluggability)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _dispatch_to_plugin_provider(
|
||||
file_path: str,
|
||||
provider: str,
|
||||
*,
|
||||
model: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Route the call to a plugin-registered transcription provider, or
|
||||
return None.
|
||||
|
||||
Returns the transcribe-response dict on dispatch, or ``None`` to
|
||||
fall through to the legacy "No STT provider available" error path.
|
||||
|
||||
Resolution invariants enforced here:
|
||||
|
||||
1. Built-in provider names short-circuit — never reach the plugin
|
||||
registry. The caller (``transcribe_audio``) handles ``local``,
|
||||
``groq``, ``openai``, etc. via its existing elif chain; this
|
||||
function defensively rejects those names so a plugin can't be
|
||||
silently dispatched under a built-in name even if it somehow
|
||||
slipped past the registry's built-in shadow guard.
|
||||
2. Plugin dispatch fires only when ``provider`` matches a
|
||||
registered :class:`TranscriptionProvider` whose ``name`` equals
|
||||
the configured value. Unknown names with no plugin registered
|
||||
return None (caller surfaces the legacy "No STT provider"
|
||||
message).
|
||||
3. Availability gating: when the matched plugin reports
|
||||
``is_available() == False`` (missing API key, missing optional
|
||||
SDK, etc.) this returns an error envelope identifying the
|
||||
plugin as unavailable — **not** ``None`` — because the user
|
||||
explicitly opted into this plugin via ``stt.provider`` and the
|
||||
generic fallthrough message would be misleading.
|
||||
|
||||
Provider exceptions are caught and converted into the standard
|
||||
error envelope (matches the legacy built-in error shapes — the
|
||||
gateway/CLI caller already expects ``{success: False, error:
|
||||
"...", transcript: ""}`` on failure).
|
||||
"""
|
||||
if not provider:
|
||||
return None
|
||||
key = provider.lower().strip()
|
||||
if key in BUILTIN_STT_PROVIDERS or key == "none":
|
||||
return None
|
||||
try:
|
||||
from agent.transcription_registry import get_provider
|
||||
from hermes_cli.plugins import _ensure_plugins_discovered
|
||||
|
||||
_ensure_plugins_discovered()
|
||||
plugin_provider = get_provider(key)
|
||||
if plugin_provider is None:
|
||||
# Long-lived sessions may have discovered plugins before a
|
||||
# bundled backend was patched in or before config changed.
|
||||
# Retry once with a forced refresh before surfacing fall-
|
||||
# through. Mirrors the image_gen / browser dispatcher
|
||||
# recovery pattern.
|
||||
_ensure_plugins_discovered(force=True)
|
||||
plugin_provider = get_provider(key)
|
||||
except Exception as exc: # noqa: BLE001 — discovery failure is non-fatal
|
||||
logger.debug("STT plugin dispatch skipped (discovery failed): %s", exc)
|
||||
return None
|
||||
if plugin_provider is None:
|
||||
return None
|
||||
|
||||
# Availability gate: when a plugin reports it's not configured
|
||||
# (missing API key, missing optional SDK, etc.) surface a clean
|
||||
# error envelope **instead of** falling through to the generic
|
||||
# "No STT provider" message. The user explicitly set
|
||||
# ``stt.provider: <plugin>`` in config — surfacing the plugin's
|
||||
# own availability failure is more actionable than the generic
|
||||
# auto-detect-failure error, and avoids routing the call into a
|
||||
# plugin that's about to crash messily.
|
||||
#
|
||||
# ``is_available()`` MUST NOT raise per the ABC contract; defend
|
||||
# anyway so a buggy plugin can't break dispatch for everyone.
|
||||
try:
|
||||
available = plugin_provider.is_available()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"STT plugin provider '%s' is_available() raised: %s — "
|
||||
"treating as unavailable", key, exc, exc_info=True,
|
||||
)
|
||||
available = False
|
||||
if not available:
|
||||
logger.info(
|
||||
"STT plugin provider '%s' reports not available; returning "
|
||||
"unavailability envelope.", key,
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": (
|
||||
f"STT plugin '{key}' is not available — check that its "
|
||||
"required credentials / dependencies are configured."
|
||||
),
|
||||
"provider": key,
|
||||
}
|
||||
|
||||
logger.info("Transcribing with plugin STT provider '%s'...", key)
|
||||
try:
|
||||
result = plugin_provider.transcribe(
|
||||
file_path,
|
||||
model=model,
|
||||
language=language,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning(
|
||||
"STT plugin provider '%s' raised: %s", key, exc, exc_info=True,
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"STT plugin '{key}' raised: {exc}",
|
||||
"provider": key,
|
||||
}
|
||||
|
||||
# Defensive: plugins should return a dict matching the contract. If
|
||||
# they don't, surface a clear error envelope rather than leaking a
|
||||
# weird object back to the gateway.
|
||||
if not isinstance(result, dict):
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"STT plugin '{key}' returned a non-dict result",
|
||||
"provider": key,
|
||||
}
|
||||
# Stamp provider if the plugin forgot to.
|
||||
result.setdefault("provider", key)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -906,6 +1058,30 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
|
|||
model_name = model or "grok-stt"
|
||||
return _transcribe_xai(file_path, model_name)
|
||||
|
||||
# Plugin-registered STT backend (e.g. OpenRouter, SenseAudio,
|
||||
# Gemini-STT). Fires only when ``provider`` is neither a built-in
|
||||
# nor ``"none"``. The dispatcher enforces built-ins-always-win
|
||||
# defensively. Returns None when no plugin is registered for the
|
||||
# configured name, falling through to the legacy "No STT provider"
|
||||
# error message below.
|
||||
#
|
||||
# Plugin-scoped config namespace mirrors the built-in pattern
|
||||
# (``stt.openai.model``, ``stt.mistral.model``): plugins read their
|
||||
# per-provider config under ``stt.<provider>`` and the dispatcher
|
||||
# forwards ``language`` from there. Top-level ``model`` argument
|
||||
# overrides any config-set model.
|
||||
plugin_cfg = stt_config.get(provider, {}) if isinstance(stt_config.get(provider), dict) else {}
|
||||
plugin_language = plugin_cfg.get("language")
|
||||
plugin_model = model or plugin_cfg.get("model")
|
||||
plugin_result = _dispatch_to_plugin_provider(
|
||||
file_path,
|
||||
provider,
|
||||
model=plugin_model,
|
||||
language=plugin_language,
|
||||
)
|
||||
if plugin_result is not None:
|
||||
return plugin_result
|
||||
|
||||
# No provider available
|
||||
return {
|
||||
"success": False,
|
||||
|
|
|
|||
|
|
@ -235,7 +235,7 @@ The table above shows the four plugin categories, but within "General plugins" t
|
|||
| An **image-generation backend** (DALL·E, SDXL, …) | Backend plugin — `ctx.register_image_gen_provider()` | [Image Generation Provider Plugins](/developer-guide/image-gen-provider-plugin) |
|
||||
| A **video-generation backend** (Veo, Kling, Pixverse, Grok-Imagine, Runway, …) | Backend plugin — `ctx.register_video_gen_provider()` | [Video Generation Provider Plugins](/developer-guide/video-gen-provider-plugin) |
|
||||
| A **TTS backend** (any CLI — Piper, VoxCPM, Kokoro, xtts, voice-cloning scripts, …) | Config-driven (recommended) — declare under `tts.providers.<name>` with `type: command` in `config.yaml`. OR Python backend plugin — `ctx.register_tts_provider()` for Python-SDK / streaming engines that need more than a shell template. | [TTS Setup](/user-guide/features/tts#custom-command-providers) · [Python plugin guide](/user-guide/features/tts#python-plugin-providers) |
|
||||
| An **STT backend** (custom whisper binary, local ASR CLI) | Config-driven — set `HERMES_LOCAL_STT_COMMAND` env var to a shell template | [Voice Message Transcription (STT)](/user-guide/features/tts#voice-message-transcription-stt) |
|
||||
| An **STT backend** (any CLI — whisper.cpp, custom whisper binary, local ASR CLI) | Config-driven (recommended) — declare under `stt.providers.<name>` with `type: command` in `config.yaml`, or set `HERMES_LOCAL_STT_COMMAND` for the legacy single-command escape hatch. OR Python backend plugin — `ctx.register_transcription_provider()` for Python-SDK engines (OpenRouter, SenseAudio, Gemini-STT, etc.). | [STT Setup](/user-guide/features/tts#stt-custom-command-providers) · [Python plugin guide](/user-guide/features/tts#python-plugin-providers-stt) |
|
||||
| **External tools via MCP** (filesystem, GitHub, Linear, Notion, any MCP server) | Config-driven — declare `mcp_servers.<name>` with `command:` / `url:` in `config.yaml`. Hermes auto-discovers the server's tools and registers them alongside built-ins. | [MCP](/user-guide/features/mcp) |
|
||||
| **Additional skill sources** (custom GitHub repos, private skill indexes) | CLI — `hermes skills tap add <repo>` | [Skills Hub](/user-guide/features/skills#skills-hub) · [Publishing a custom tap](/user-guide/features/skills#publishing-a-custom-skill-tap) |
|
||||
| **Gateway event hooks** (fire on `gateway:startup`, `session:start`, `agent:end`, `command:*`) | Drop `HOOK.yaml` + `handler.py` into `~/.hermes/hooks/<name>/` | [Event Hooks](/user-guide/features/hooks#gateway-event-hooks) |
|
||||
|
|
|
|||
|
|
@ -454,3 +454,101 @@ If your configured provider isn't available, Hermes automatically falls back:
|
|||
- **OpenAI key not set** → Falls back to local transcription, then Groq
|
||||
- **Mistral key/SDK not set** → Skipped in auto-detect; falls through to next available provider
|
||||
- **Nothing available** → Voice messages pass through with an accurate note to the user
|
||||
|
||||
### Python plugin providers (STT)
|
||||
|
||||
For STT engines that aren't built-in (OpenRouter, SenseAudio, Gemini-STT, Deepgram, custom proprietary backends), register a Python plugin via `ctx.register_transcription_provider()`. The plugin **coexists with** the 6 built-in providers (`local`, `local_command`, `groq`, `openai`, `mistral`, `xai`) — those keep their native implementations and always win on name collision.
|
||||
|
||||
#### Resolution order
|
||||
|
||||
1. **`stt.provider` is a built-in name** → built-in dispatch. **Always wins.**
|
||||
2. **`stt.provider` matches a plugin-registered `TranscriptionProvider`** → plugin dispatch:
|
||||
- if the plugin's `is_available()` returns `False` (missing creds or SDK), the call surfaces an unavailability error envelope identifying the plugin — **not** the generic "No STT provider available" message.
|
||||
- otherwise the plugin's `transcribe()` is called with `model` (from the public `model=` arg, falling back to `stt.<provider>.model`) and `language` (from `stt.<provider>.language`).
|
||||
3. **No match** → "No STT provider available" error.
|
||||
|
||||
#### Per-provider config namespace
|
||||
|
||||
Plugins read their per-provider configuration from `stt.<provider>` in `config.yaml`, mirroring how built-ins read `stt.openai.model` / `stt.mistral.model`:
|
||||
|
||||
```yaml
|
||||
stt:
|
||||
provider: my-stt
|
||||
my-stt:
|
||||
model: whisper-large-v3
|
||||
language: ja # forwarded as language= to transcribe()
|
||||
# any other plugin-specific keys go here; read them via your
|
||||
# own config.yaml access in __init__/is_available/transcribe
|
||||
```
|
||||
|
||||
The dispatcher forwards `model` and `language` from this section; everything else, the plugin can read itself.
|
||||
|
||||
#### Minimal plugin
|
||||
|
||||
Drop this in `~/.hermes/plugins/my-stt/`:
|
||||
|
||||
`plugin.yaml`:
|
||||
```yaml
|
||||
name: my-stt
|
||||
version: 0.1.0
|
||||
description: "My custom Python STT backend"
|
||||
```
|
||||
|
||||
`__init__.py`:
|
||||
```python
|
||||
from agent.transcription_provider import TranscriptionProvider
|
||||
|
||||
|
||||
class MySTTProvider(TranscriptionProvider):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "my-stt" # what stt.provider matches against
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return "My Custom STT"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
# Return False when credentials/deps are missing — picker skips
|
||||
# this row but the dispatcher still routes here on explicit config.
|
||||
import os
|
||||
return bool(os.environ.get("MY_STT_API_KEY"))
|
||||
|
||||
def transcribe(self, file_path, *, model=None, language=None, **extra):
|
||||
# Return the standard transcribe envelope:
|
||||
# {"success": bool, "transcript": str, "provider": str, "error": str}
|
||||
# Do NOT raise — convert exceptions to the error envelope so the
|
||||
# gateway/CLI caller sees a consistent shape on failure.
|
||||
try:
|
||||
import my_stt_sdk
|
||||
client = my_stt_sdk.Client()
|
||||
text = client.transcribe(open(file_path, "rb"))
|
||||
return {
|
||||
"success": True,
|
||||
"transcript": text,
|
||||
"provider": "my-stt",
|
||||
}
|
||||
except Exception as exc:
|
||||
return {
|
||||
"success": False,
|
||||
"transcript": "",
|
||||
"error": f"my-stt failed: {exc}",
|
||||
"provider": "my-stt",
|
||||
}
|
||||
|
||||
|
||||
def register(ctx):
|
||||
ctx.register_transcription_provider(MySTTProvider())
|
||||
```
|
||||
|
||||
Enable it (`hermes plugins enable my-stt`), set `stt.provider: my-stt` in `config.yaml`, and voice-message transcription will route through your plugin.
|
||||
|
||||
#### Optional hooks
|
||||
|
||||
Override these on your provider class for richer integration:
|
||||
|
||||
- `list_models()` → list of `{id, display, languages, max_audio_seconds}` dicts.
|
||||
- `default_model()` → string returned when the user doesn't override the model.
|
||||
- `get_setup_schema()` → return `{name, badge, tag, env_vars: [{key, prompt, url}]}` to power picker rows in `hermes tools` / `hermes setup` (the picker category for STT is not yet shipped — this metadata is available to plugins for forward compatibility).
|
||||
|
||||
See `agent/transcription_provider.py` for the full ABC including docstrings.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue