mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-06 07:51:53 +00:00
feat(tts): add register_tts_provider() plugin hook (closes #30398)
Adds a `TTSProvider(ABC)` + `register_tts_provider()` extension point to the plugin context API, **alongside** the existing config-driven `tts.providers.<name>: type: command` registry from PR #17843. This is additive — the command-provider surface stays as the primary way to add a TTS backend. The hook covers cases the shell-template grammar can't reasonably express: - Native Python SDKs without a CLI (Cartesia, Fish Audio, etc.) - Streaming synthesis (chunked Opus → voice-bubble delivery) - Voice metadata API for the `hermes tools` picker - OAuth-refreshing auth flows None of the 10 inline built-in providers (`edge`, `openai`, `elevenlabs`, `minimax`, `gemini`, `mistral`, `xai`, `piper`, `kittentts`, `neutts`) are migrated to plugins. They stay inline. The hook is for *new* engines that aren't built-in. ## Resolution order The dispatcher's resolution order is the load-bearing invariant: 1. `tts.provider` is a built-in name → built-in dispatch. **Always wins.** 2. `tts.provider` matches `tts.providers.<name>` with `command:` set → command-provider dispatch (PR #17843). 3. `tts.provider` matches a plugin-registered `TTSProvider` → plugin dispatch (new). 4. No match → falls through to Edge TTS default (legacy behavior). Built-ins-always-win is enforced at THREE layers: - Registry: `register_provider()` rejects shadowing names with a warning. - Dispatcher: `_dispatch_to_plugin_provider()` short-circuits built-in names defensively before consulting the registry. - Picker: `_plugin_tts_providers()` filters built-in shadows out of the `hermes tools` row list defensively. Command-providers-win-over-plugins is enforced at TWO layers: - The caller in `text_to_speech_tool` checks `_resolve_command_provider_config` first. - `_dispatch_to_plugin_provider` re-checks for a same-name command config defensively so a refactor of the caller can't silently break the invariant. ## New files - `agent/tts_provider.py` — `TTSProvider(ABC)` with `synthesize()` (required), `list_voices()`, `list_models()`, `get_setup_schema()`, `stream()`, `voice_compatible` (all optional with sane defaults). Mirrors `agent/image_gen_provider.py` shape. - `agent/tts_registry.py` — `register_provider`/`get_provider`/`list_providers` with `_BUILTIN_NAMES` reject-shadowing invariant. Mirrors `agent/image_gen_registry.py` shape. - `plugins/tts/...` directory ready for community plugins (none shipped). ## Modified files - `hermes_cli/plugins.py` — `register_tts_provider()` method on `PluginContext`. Matches the gating shape of `register_image_gen_provider()` / `register_browser_provider()`. - `tools/tts_tool.py` — `_dispatch_to_plugin_provider()` + `_plugin_provider_is_voice_compatible()` + walrus-elif wiring into the main dispatcher. Built-in elif chain untouched. - `hermes_cli/tools_config.py` — `_plugin_tts_providers()` injects plugin rows into the Text-to-Speech picker category alongside the 10 hardcoded built-in rows. ## Tests - `tests/agent/test_tts_registry.py` — 47 tests covering registration, lookup, ABC contract, helpers, AND a `TestBuiltinSync` regression test that fails if `agent.tts_registry._BUILTIN_NAMES` drifts from `tools.tts_tool.BUILTIN_TTS_PROVIDERS` (kept duplicated due to circular import constraints). - `tests/tools/test_tts_plugin_dispatch.py` — 35 tests covering built-in-always-wins, command-wins-over-plugin, plugin dispatch, exception passthrough, voice_compatible helper. - `tests/hermes_cli/test_tts_picker.py` — 10 tests covering the picker surface, builtin shadowing defense, integration with `_visible_providers`. - `tests/hermes_cli/test_plugins_tts_registration.py` — 3 end-to-end tests via `PluginManager.discover_and_load()`. - `tests/plugins/tts/check_parity_vs_main.py` — 9-scenario subprocess parity harness vs `origin/main`. The only intentional diff is `fallback_edge → plugin` for the `plugin-installed` scenario. ## Verification - 95/95 new tests pass. - 170/170 pre-existing TTS tests (test_tts_command_providers, test_tts_max_text_length, test_tts_speed, etc.) pass unchanged. - Parity harness against `origin/main`: 8 OK + 1 expected DIFF. - E2E smoke: a registered plugin's `synthesize()` is called via `text_to_speech_tool` with the standard JSON envelope returned. - Ruff clean on all touched files. ## Docs - `website/docs/user-guide/features/tts.md` — new "Python plugin providers" section with a decision table (command-provider vs plugin), minimal plugin example, and the optional-hook reference. - `website/docs/user-guide/features/plugins.md` — TTS row updated to mention both surfaces (command-provider primary, plugin for SDK/streaming). Closes #30398
This commit is contained in:
parent
782681f904
commit
00ec0b617c
13 changed files with 2037 additions and 1 deletions
312
tests/agent/test_tts_registry.py
Normal file
312
tests/agent/test_tts_registry.py
Normal file
|
|
@ -0,0 +1,312 @@
|
|||
"""Tests for agent/tts_registry.py and agent/tts_provider.py.
|
||||
|
||||
Covers:
|
||||
- Registration happy path
|
||||
- Registration rejection: non-TTSProvider type
|
||||
- Registration rejection: empty/whitespace name
|
||||
- Built-in name shadowing: warning + silent ignore (no exception)
|
||||
- Re-registration: overwrites + logs at debug
|
||||
- Case + whitespace insensitivity on lookup
|
||||
- ABC contract: default implementations work
|
||||
- ABC contract: synthesize() must be implemented
|
||||
- ABC contract: stream() raises NotImplementedError by default
|
||||
- resolve_output_format helper coerces invalid input
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import tts_registry
|
||||
from agent.tts_provider import (
|
||||
DEFAULT_OUTPUT_FORMAT,
|
||||
VALID_OUTPUT_FORMATS,
|
||||
TTSProvider,
|
||||
resolve_output_format,
|
||||
)
|
||||
|
||||
|
||||
class _FakeProvider(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "fake",
|
||||
display: Optional[str] = None,
|
||||
voice_compat: bool = False,
|
||||
synthesize_impl: Optional[Any] = None,
|
||||
):
|
||||
self._name = name
|
||||
self._display = display
|
||||
self._voice_compat = voice_compat
|
||||
self._synthesize_impl = synthesize_impl
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._display if self._display is not None else super().display_name
|
||||
|
||||
@property
|
||||
def voice_compatible(self) -> bool:
|
||||
return self._voice_compat
|
||||
|
||||
def synthesize(self, text: str, output_path: str, **kw):
|
||||
if self._synthesize_impl is not None:
|
||||
return self._synthesize_impl(text, output_path, **kw)
|
||||
return output_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
tts_registry._reset_for_tests()
|
||||
yield
|
||||
tts_registry._reset_for_tests()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistration:
|
||||
def test_happy_path(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
tts_registry.register_provider(p)
|
||||
assert tts_registry.get_provider("cartesia") is p
|
||||
assert [r.name for r in tts_registry.list_providers()] == ["cartesia"]
|
||||
|
||||
def test_rejects_non_provider_type(self):
|
||||
with pytest.raises(TypeError, match="expects a TTSProvider instance"):
|
||||
tts_registry.register_provider("not a provider") # type: ignore[arg-type]
|
||||
assert tts_registry.list_providers() == []
|
||||
|
||||
def test_rejects_empty_name(self):
|
||||
p = _FakeProvider(name="")
|
||||
with pytest.raises(ValueError, match="non-empty string"):
|
||||
tts_registry.register_provider(p)
|
||||
assert tts_registry.list_providers() == []
|
||||
|
||||
def test_rejects_whitespace_name(self):
|
||||
p = _FakeProvider(name=" ")
|
||||
with pytest.raises(ValueError, match="non-empty string"):
|
||||
tts_registry.register_provider(p)
|
||||
assert tts_registry.list_providers() == []
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"builtin",
|
||||
["edge", "openai", "elevenlabs", "minimax", "gemini",
|
||||
"mistral", "xai", "piper", "kittentts", "neutts"],
|
||||
)
|
||||
def test_rejects_builtin_shadow_with_warning(self, builtin, caplog):
|
||||
"""Built-in names always win — plugin registration is silently ignored
|
||||
but a warning is logged so the operator can see what happened.
|
||||
"""
|
||||
p = _FakeProvider(name=builtin)
|
||||
with caplog.at_level(logging.WARNING, logger="agent.tts_registry"):
|
||||
tts_registry.register_provider(p)
|
||||
assert "shadows a built-in name" in caplog.text
|
||||
assert builtin in caplog.text
|
||||
assert tts_registry.get_provider(builtin) is None
|
||||
assert tts_registry.list_providers() == []
|
||||
|
||||
def test_builtin_shadow_case_insensitive(self, caplog):
|
||||
"""``EDGE``/``Edge``/`` edge `` all collide with the ``edge`` built-in."""
|
||||
for variant in ("EDGE", "Edge", " edge ", "eDgE"):
|
||||
tts_registry._reset_for_tests()
|
||||
with caplog.at_level(logging.WARNING, logger="agent.tts_registry"):
|
||||
tts_registry.register_provider(_FakeProvider(name=variant))
|
||||
assert tts_registry.list_providers() == [], (
|
||||
f"variant {variant!r} should have been rejected as a built-in shadow"
|
||||
)
|
||||
|
||||
def test_reregistration_overwrites(self, caplog):
|
||||
p1 = _FakeProvider(name="cartesia")
|
||||
p2 = _FakeProvider(name="cartesia")
|
||||
tts_registry.register_provider(p1)
|
||||
with caplog.at_level(logging.DEBUG, logger="agent.tts_registry"):
|
||||
tts_registry.register_provider(p2)
|
||||
assert tts_registry.get_provider("cartesia") is p2
|
||||
assert "re-registered" in caplog.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLookup:
|
||||
def test_get_provider_missing_returns_none(self):
|
||||
assert tts_registry.get_provider("nonexistent") is None
|
||||
|
||||
def test_get_provider_non_string_returns_none(self):
|
||||
assert tts_registry.get_provider(None) is None # type: ignore[arg-type]
|
||||
assert tts_registry.get_provider(123) is None # type: ignore[arg-type]
|
||||
|
||||
def test_get_provider_case_insensitive(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
tts_registry.register_provider(p)
|
||||
assert tts_registry.get_provider("CARTESIA") is p
|
||||
assert tts_registry.get_provider("Cartesia") is p
|
||||
|
||||
def test_get_provider_whitespace_tolerant(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
tts_registry.register_provider(p)
|
||||
assert tts_registry.get_provider(" cartesia ") is p
|
||||
|
||||
def test_list_providers_sorted(self):
|
||||
tts_registry.register_provider(_FakeProvider(name="zylo"))
|
||||
tts_registry.register_provider(_FakeProvider(name="alpha"))
|
||||
tts_registry.register_provider(_FakeProvider(name="middle"))
|
||||
names = [p.name for p in tts_registry.list_providers()]
|
||||
assert names == ["alpha", "middle", "zylo"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ABC contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestABCContract:
|
||||
def test_must_implement_synthesize(self):
|
||||
class Incomplete(TTSProvider):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "incomplete"
|
||||
# synthesize NOT implemented
|
||||
|
||||
with pytest.raises(TypeError, match="abstract"):
|
||||
Incomplete() # type: ignore[abstract]
|
||||
|
||||
def test_must_implement_name(self):
|
||||
class Incomplete(TTSProvider):
|
||||
def synthesize(self, text, output_path, **kw):
|
||||
return output_path
|
||||
# name NOT implemented
|
||||
|
||||
with pytest.raises(TypeError, match="abstract"):
|
||||
Incomplete() # type: ignore[abstract]
|
||||
|
||||
def test_display_name_defaults_to_title(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
assert p.display_name == "Cartesia"
|
||||
|
||||
def test_display_name_override_respected(self):
|
||||
p = _FakeProvider(name="cartesia", display="Cartesia AI")
|
||||
assert p.display_name == "Cartesia AI"
|
||||
|
||||
def test_is_available_default_true(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
assert p.is_available() is True
|
||||
|
||||
def test_list_voices_default_empty(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
assert p.list_voices() == []
|
||||
|
||||
def test_list_models_default_empty(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
assert p.list_models() == []
|
||||
|
||||
def test_default_model_none_when_no_models(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
assert p.default_model() is None
|
||||
|
||||
def test_default_voice_none_when_no_voices(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
assert p.default_voice() is None
|
||||
|
||||
def test_default_model_first_listed(self):
|
||||
class WithModels(_FakeProvider):
|
||||
def list_models(self):
|
||||
return [{"id": "sonic-2"}, {"id": "sonic-1"}]
|
||||
|
||||
p = WithModels(name="cartesia")
|
||||
assert p.default_model() == "sonic-2"
|
||||
|
||||
def test_default_voice_first_listed(self):
|
||||
class WithVoices(_FakeProvider):
|
||||
def list_voices(self):
|
||||
return [{"id": "voice-aria"}, {"id": "voice-jasper"}]
|
||||
|
||||
p = WithVoices(name="cartesia")
|
||||
assert p.default_voice() == "voice-aria"
|
||||
|
||||
def test_get_setup_schema_default_minimal(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
schema = p.get_setup_schema()
|
||||
assert schema["name"] == "Cartesia"
|
||||
assert schema["env_vars"] == []
|
||||
|
||||
def test_stream_raises_not_implemented_by_default(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
with pytest.raises(NotImplementedError, match="does not implement streaming"):
|
||||
next(p.stream("hello"))
|
||||
|
||||
def test_voice_compatible_default_false(self):
|
||||
p = _FakeProvider(name="cartesia")
|
||||
assert p.voice_compatible is False
|
||||
|
||||
def test_voice_compatible_override(self):
|
||||
p = _FakeProvider(name="cartesia", voice_compat=True)
|
||||
assert p.voice_compatible is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveOutputFormat:
|
||||
@pytest.mark.parametrize("valid", sorted(VALID_OUTPUT_FORMATS))
|
||||
def test_valid_passes_through(self, valid):
|
||||
assert resolve_output_format(valid) == valid
|
||||
|
||||
def test_uppercase_normalized(self):
|
||||
assert resolve_output_format("MP3") == "mp3"
|
||||
assert resolve_output_format("Opus") == "opus"
|
||||
|
||||
def test_whitespace_stripped(self):
|
||||
assert resolve_output_format(" wav ") == "wav"
|
||||
|
||||
def test_invalid_returns_default(self):
|
||||
assert resolve_output_format("aiff") == DEFAULT_OUTPUT_FORMAT
|
||||
assert resolve_output_format("") == DEFAULT_OUTPUT_FORMAT
|
||||
|
||||
def test_none_returns_default(self):
|
||||
assert resolve_output_format(None) == DEFAULT_OUTPUT_FORMAT
|
||||
|
||||
def test_non_string_returns_default(self):
|
||||
assert resolve_output_format(123) == DEFAULT_OUTPUT_FORMAT # type: ignore[arg-type]
|
||||
assert resolve_output_format([]) == DEFAULT_OUTPUT_FORMAT # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync invariant: registry's built-in list vs dispatcher's built-in list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuiltinSync:
|
||||
"""``_BUILTIN_NAMES`` in agent/tts_registry.py is duplicated from
|
||||
``BUILTIN_TTS_PROVIDERS`` in tools/tts_tool.py (importing directly
|
||||
would create a circular dependency). This test fails loudly if the
|
||||
two lists drift — a new built-in added to tts_tool.py MUST also be
|
||||
added to tts_registry.py's _BUILTIN_NAMES or the registry will
|
||||
accept a name the dispatcher will silently route to the wrong
|
||||
handler.
|
||||
"""
|
||||
|
||||
def test_registry_builtins_match_dispatcher_builtins(self):
|
||||
from tools.tts_tool import BUILTIN_TTS_PROVIDERS
|
||||
|
||||
assert tts_registry._BUILTIN_NAMES == BUILTIN_TTS_PROVIDERS, (
|
||||
"agent.tts_registry._BUILTIN_NAMES and "
|
||||
"tools.tts_tool.BUILTIN_TTS_PROVIDERS have drifted!\n"
|
||||
f" Registry only: {sorted(tts_registry._BUILTIN_NAMES - BUILTIN_TTS_PROVIDERS)}\n"
|
||||
f" Dispatcher only: {sorted(BUILTIN_TTS_PROVIDERS - tts_registry._BUILTIN_NAMES)}\n"
|
||||
"Add the missing names to whichever list is incomplete. "
|
||||
"These two lists exist as a circular-import workaround and "
|
||||
"MUST be kept in sync manually."
|
||||
)
|
||||
156
tests/hermes_cli/test_plugins_tts_registration.py
Normal file
156
tests/hermes_cli/test_plugins_tts_registration.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
"""Tests for PluginContext.register_tts_provider() (issue #30398).
|
||||
|
||||
Exercises the plugin context hook end-to-end: drops a fake plugin into
|
||||
``$HERMES_HOME/plugins/``, runs ``PluginManager().discover_and_load()``,
|
||||
and asserts the registration result.
|
||||
|
||||
Mirrors the structure of
|
||||
``tests/hermes_cli/test_plugin_scanner_recursion.py::TestRegisterImageGenProvider``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def _write_plugin(
|
||||
root: Path,
|
||||
name: str,
|
||||
*,
|
||||
manifest_extra: Dict[str, Any] | None = None,
|
||||
register_body: str = "pass",
|
||||
) -> Path:
|
||||
plugin_dir = root / name
|
||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||
manifest = {
|
||||
"name": name,
|
||||
"version": "0.1.0",
|
||||
"description": f"Test plugin {name}",
|
||||
}
|
||||
if manifest_extra:
|
||||
manifest.update(manifest_extra)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump(manifest))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
f"def register(ctx):\n {register_body}\n"
|
||||
)
|
||||
return plugin_dir
|
||||
|
||||
|
||||
def _enable(hermes_home: Path, name: str) -> None:
|
||||
cfg_path = hermes_home / "config.yaml"
|
||||
cfg: dict = {}
|
||||
if cfg_path.exists():
|
||||
try:
|
||||
cfg = yaml.safe_load(cfg_path.read_text()) or {}
|
||||
except Exception:
|
||||
cfg = {}
|
||||
plugins_cfg = cfg.setdefault("plugins", {})
|
||||
enabled = plugins_cfg.setdefault("enabled", [])
|
||||
if isinstance(enabled, list) and name not in enabled:
|
||||
enabled.append(name)
|
||||
cfg_path.write_text(yaml.safe_dump(cfg))
|
||||
|
||||
|
||||
class TestRegisterTTSProvider:
|
||||
"""End-to-end: a fake plugin registers via the hook, ends up in the registry."""
|
||||
|
||||
def test_accepts_valid_provider(self):
|
||||
from hermes_cli.plugins import PluginManager
|
||||
|
||||
from agent import tts_registry
|
||||
tts_registry._reset_for_tests()
|
||||
|
||||
hermes_home = Path(os.environ["HERMES_HOME"])
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
"my-tts-plugin",
|
||||
register_body=(
|
||||
"from agent.tts_provider import TTSProvider\n"
|
||||
" class P(TTSProvider):\n"
|
||||
" @property\n"
|
||||
" def name(self): return 'fake-tts'\n"
|
||||
" def synthesize(self, text, output_path, **kw):\n"
|
||||
" return output_path\n"
|
||||
" ctx.register_tts_provider(P())"
|
||||
),
|
||||
)
|
||||
_enable(hermes_home, "my-tts-plugin")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert mgr._plugins["my-tts-plugin"].enabled is True, (
|
||||
f"Plugin failed to load: {mgr._plugins['my-tts-plugin'].error}"
|
||||
)
|
||||
assert tts_registry.get_provider("fake-tts") is not None
|
||||
|
||||
tts_registry._reset_for_tests()
|
||||
|
||||
def test_rejects_non_provider(self, caplog):
|
||||
"""A plugin that passes a non-TTSProvider gets a warning, no exception."""
|
||||
from hermes_cli.plugins import PluginManager
|
||||
|
||||
from agent import tts_registry
|
||||
tts_registry._reset_for_tests()
|
||||
|
||||
hermes_home = Path(os.environ["HERMES_HOME"])
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
"bad-tts-plugin",
|
||||
register_body="ctx.register_tts_provider('not a provider')",
|
||||
)
|
||||
_enable(hermes_home, "bad-tts-plugin")
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Plugin loaded (register returned normally), but registry empty.
|
||||
assert mgr._plugins["bad-tts-plugin"].enabled is True
|
||||
assert tts_registry.get_provider("not a provider") is None
|
||||
assert tts_registry.list_providers() == []
|
||||
assert "does not inherit from TTSProvider" in caplog.text
|
||||
|
||||
tts_registry._reset_for_tests()
|
||||
|
||||
def test_rejects_builtin_shadow(self, caplog):
|
||||
"""A plugin trying to register a name colliding with a built-in is silently
|
||||
rejected by the underlying registry — both with a registry-level warning
|
||||
AND with the registry remaining empty (plugin still loads OK).
|
||||
"""
|
||||
from hermes_cli.plugins import PluginManager
|
||||
|
||||
from agent import tts_registry
|
||||
tts_registry._reset_for_tests()
|
||||
|
||||
hermes_home = Path(os.environ["HERMES_HOME"])
|
||||
_write_plugin(
|
||||
hermes_home / "plugins",
|
||||
"shadow-tts-plugin",
|
||||
register_body=(
|
||||
"from agent.tts_provider import TTSProvider\n"
|
||||
" class P(TTSProvider):\n"
|
||||
" @property\n"
|
||||
" def name(self): return 'edge'\n"
|
||||
" def synthesize(self, text, output_path, **kw):\n"
|
||||
" return output_path\n"
|
||||
" ctx.register_tts_provider(P())"
|
||||
),
|
||||
)
|
||||
_enable(hermes_home, "shadow-tts-plugin")
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Plugin still loaded normally — built-in shadowing is a warning,
|
||||
# not an exception. The registry rejects the entry though.
|
||||
assert mgr._plugins["shadow-tts-plugin"].enabled is True
|
||||
assert tts_registry.get_provider("edge") is None
|
||||
assert "shadows a built-in name" in caplog.text
|
||||
|
||||
tts_registry._reset_for_tests()
|
||||
187
tests/hermes_cli/test_tts_picker.py
Normal file
187
tests/hermes_cli/test_tts_picker.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for the TTS plugin picker surface in hermes_cli/tools_config.py (issue #30398).
|
||||
|
||||
Covers ``_plugin_tts_providers()`` and the ``_visible_providers()``
|
||||
integration that injects plugin rows into the Text-to-Speech category.
|
||||
|
||||
Mirrors the structure of existing image_gen / browser picker tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import tts_registry
|
||||
from agent.tts_provider import TTSProvider
|
||||
from hermes_cli import tools_config
|
||||
|
||||
|
||||
class _FakeTTSProvider(TTSProvider):
|
||||
def __init__(self, name: str, schema: dict | None = None):
|
||||
self._name = name
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def synthesize(self, text, output_path, **kw):
|
||||
return output_path
|
||||
|
||||
def get_setup_schema(self):
|
||||
if self._schema is not None:
|
||||
return self._schema
|
||||
return super().get_setup_schema()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
tts_registry._reset_for_tests()
|
||||
yield
|
||||
tts_registry._reset_for_tests()
|
||||
|
||||
|
||||
class TestPluginTTSProviders:
|
||||
"""``_plugin_tts_providers()`` returns picker-row dicts."""
|
||||
|
||||
def test_empty_when_no_plugins(self):
|
||||
assert tools_config._plugin_tts_providers() == []
|
||||
|
||||
def test_returns_row_for_registered_plugin(self):
|
||||
tts_registry.register_provider(
|
||||
_FakeTTSProvider(
|
||||
name="cartesia",
|
||||
schema={
|
||||
"name": "Cartesia",
|
||||
"badge": "paid",
|
||||
"tag": "Ultra-low-latency streaming",
|
||||
"env_vars": [
|
||||
{"key": "CARTESIA_API_KEY", "prompt": "Cartesia API key",
|
||||
"url": "https://play.cartesia.ai/console"},
|
||||
],
|
||||
},
|
||||
)
|
||||
)
|
||||
rows = tools_config._plugin_tts_providers()
|
||||
assert len(rows) == 1
|
||||
row = rows[0]
|
||||
assert row["name"] == "Cartesia"
|
||||
assert row["badge"] == "paid"
|
||||
assert row["tag"] == "Ultra-low-latency streaming"
|
||||
assert row["env_vars"][0]["key"] == "CARTESIA_API_KEY"
|
||||
# Selecting this row writes ``tts.provider: cartesia`` — same
|
||||
# write path as a hardcoded row.
|
||||
assert row["tts_provider"] == "cartesia"
|
||||
assert row["tts_plugin_name"] == "cartesia"
|
||||
|
||||
def test_filters_builtin_shadow_defensively(self):
|
||||
"""Even if a plugin slipped past the registry's built-in check
|
||||
(e.g. via direct ``agent.tts_registry.register_provider`` rather
|
||||
than the ``ctx.register_tts_provider`` hook), the picker layer
|
||||
filters it out so the picker invariant holds."""
|
||||
# Use lower-level call to bypass the warning + skip in
|
||||
# register_provider (the registry's built-in guard).
|
||||
# Note: this is intentionally pathological — production code
|
||||
# paths go through the hook which catches this first.
|
||||
provider = _FakeTTSProvider(name="edge")
|
||||
tts_registry._providers["edge"] = provider # type: ignore[index]
|
||||
try:
|
||||
rows = tools_config._plugin_tts_providers()
|
||||
assert rows == [], (
|
||||
"Picker must filter built-in name shadows even when the "
|
||||
"registry has been bypassed."
|
||||
)
|
||||
finally:
|
||||
tts_registry._providers.pop("edge", None) # type: ignore[arg-type]
|
||||
|
||||
def test_skips_providers_with_no_name(self):
|
||||
"""Defense in depth: a provider with no .name attribute is skipped
|
||||
rather than crashing the picker."""
|
||||
|
||||
class _NoName:
|
||||
display_name = "Bogus"
|
||||
def get_setup_schema(self):
|
||||
return {"name": "Bogus"}
|
||||
|
||||
tts_registry._providers["bogus"] = _NoName() # type: ignore[assignment]
|
||||
try:
|
||||
rows = tools_config._plugin_tts_providers()
|
||||
# Provider has no .name so the picker filters it out
|
||||
assert all(r.get("tts_plugin_name") != "bogus" for r in rows)
|
||||
finally:
|
||||
tts_registry._providers.pop("bogus", None) # type: ignore[arg-type]
|
||||
|
||||
def test_skips_providers_whose_schema_raises(self):
|
||||
class _ExplodingSchema(_FakeTTSProvider):
|
||||
def get_setup_schema(self):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
tts_registry.register_provider(_ExplodingSchema(name="exploding"))
|
||||
tts_registry.register_provider(_FakeTTSProvider(name="working"))
|
||||
rows = tools_config._plugin_tts_providers()
|
||||
assert [r["tts_plugin_name"] for r in rows] == ["working"]
|
||||
|
||||
def test_minimal_schema_uses_display_name(self):
|
||||
"""A provider with no setup_schema override gets a row built from
|
||||
``display_name`` and ``name`` only."""
|
||||
tts_registry.register_provider(_FakeTTSProvider(name="minimal"))
|
||||
rows = tools_config._plugin_tts_providers()
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["name"] == "Minimal" # display_name default
|
||||
assert rows[0]["tts_provider"] == "minimal"
|
||||
assert rows[0]["env_vars"] == []
|
||||
|
||||
def test_post_setup_passthrough(self):
|
||||
tts_registry.register_provider(
|
||||
_FakeTTSProvider(
|
||||
name="my-tts",
|
||||
schema={
|
||||
"name": "My TTS",
|
||||
"post_setup": "my_post_install_hook",
|
||||
"env_vars": [],
|
||||
},
|
||||
)
|
||||
)
|
||||
rows = tools_config._plugin_tts_providers()
|
||||
assert rows[0].get("post_setup") == "my_post_install_hook"
|
||||
|
||||
|
||||
class TestVisibleProvidersInjectsTTSPlugins:
|
||||
"""``_visible_providers()`` injects plugin rows into the Text-to-Speech
|
||||
category alongside the hardcoded built-in rows."""
|
||||
|
||||
def test_tts_category_includes_plugin_rows(self):
|
||||
tts_registry.register_provider(_FakeTTSProvider(name="cartesia"))
|
||||
|
||||
tts_cat = tools_config.TOOL_CATEGORIES["tts"]
|
||||
visible = tools_config._visible_providers(tts_cat, config={})
|
||||
|
||||
names = [row.get("name") for row in visible]
|
||||
# Hardcoded rows (sample — check at least one is present)
|
||||
assert "Microsoft Edge TTS" in names
|
||||
# Plugin row injected at the end
|
||||
assert "Cartesia" in names
|
||||
|
||||
# Plugin row has tts_provider key for write-path compat
|
||||
plugin_rows = [r for r in visible if r.get("tts_plugin_name")]
|
||||
assert len(plugin_rows) == 1
|
||||
assert plugin_rows[0]["tts_provider"] == "cartesia"
|
||||
|
||||
def test_other_categories_unaffected_by_tts_plugins(self):
|
||||
"""Registering a TTS plugin must not leak into the Image Generation
|
||||
or Browser pickers."""
|
||||
tts_registry.register_provider(_FakeTTSProvider(name="cartesia"))
|
||||
|
||||
img_cat = tools_config.TOOL_CATEGORIES["image_gen"]
|
||||
visible = tools_config._visible_providers(img_cat, config={})
|
||||
names = [row.get("name") for row in visible]
|
||||
assert "Cartesia" not in names
|
||||
|
||||
def test_tts_category_without_plugins_only_hardcoded(self):
|
||||
"""No plugins → picker shows exactly the hardcoded rows."""
|
||||
tts_cat = tools_config.TOOL_CATEGORIES["tts"]
|
||||
visible = tools_config._visible_providers(tts_cat, config={})
|
||||
names = [row.get("name") for row in visible]
|
||||
# No row has the plugin marker
|
||||
assert all(not row.get("tts_plugin_name") for row in visible)
|
||||
# Hardcoded rows still present (sample one of the always-visible ones)
|
||||
assert "Microsoft Edge TTS" in names
|
||||
0
tests/plugins/tts/__init__.py
Normal file
0
tests/plugins/tts/__init__.py
Normal file
328
tests/plugins/tts/check_parity_vs_main.py
Normal file
328
tests/plugins/tts/check_parity_vs_main.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
"""Behavior-parity check for the TTS plugin hook (issue #30398).
|
||||
|
||||
Spawns one subprocess per (version, scenario) cell — pinned to either
|
||||
``origin/main`` (no plugin hook; ``tts.provider: cartesia`` falls
|
||||
through to the Edge TTS default branch) or this PR's worktree (plugin
|
||||
hook present; same config routes through the plugin registry when a
|
||||
plugin is registered).
|
||||
|
||||
Each subprocess clears all TTS-related env vars + writes a
|
||||
``config.yaml``, then resolves how the dispatcher would route a
|
||||
``text_to_speech`` call. The emitted shape tuple is::
|
||||
|
||||
{dispatch_kind, provider_name, voice_compat}
|
||||
|
||||
Where ``dispatch_kind`` ∈
|
||||
``{"builtin_edge", "builtin_openai", "builtin_elevenlabs", ...,
|
||||
"command", "plugin", "fallback_edge", "error"}``:
|
||||
|
||||
* ``builtin_<name>`` — config selects a built-in handler that exists
|
||||
on both main and PR (no diff expected)
|
||||
* ``command`` — config selects a ``tts.providers.<name>: type: command``
|
||||
entry (PR #17843; no diff expected)
|
||||
* ``plugin`` — config selects a plugin-registered provider (PR only)
|
||||
* ``fallback_edge`` — config selects an unknown name with no matching
|
||||
plugin or command entry → Edge TTS default fallback
|
||||
* ``error`` — explicit fatal error (e.g. mistral quarantine)
|
||||
|
||||
The parent process diffs the reduced shape per scenario. The only
|
||||
acceptable diff is ``fallback_edge → plugin`` for the
|
||||
``unknown-name-with-plugin-installed`` scenario — everything else is
|
||||
a regression.
|
||||
|
||||
Run from the PR worktree (it auto-resolves ``MAIN_DIR`` from the parent
|
||||
of the worktree directory, or falls back to a sibling
|
||||
``hermes-agent-main`` checkout)::
|
||||
|
||||
python tests/plugins/tts/check_parity_vs_main.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
|
||||
|
||||
def _resolve_main_dir() -> Path:
|
||||
candidate = REPO_ROOT.parent.parent
|
||||
if (candidate / "tools" / "tts_tool.py").exists() and candidate != REPO_ROOT:
|
||||
return candidate
|
||||
sibling = REPO_ROOT.parent / "hermes-agent-main"
|
||||
if (sibling / "tools" / "tts_tool.py").exists():
|
||||
return sibling
|
||||
return REPO_ROOT
|
||||
|
||||
|
||||
MAIN_DIR = _resolve_main_dir()
|
||||
PR_DIR = REPO_ROOT
|
||||
assert (PR_DIR / "tools" / "tts_tool.py").exists(), (
|
||||
f"PR_DIR={PR_DIR} doesn't look like a hermes-agent checkout"
|
||||
)
|
||||
|
||||
|
||||
# The subprocess script — runs INSIDE either the main checkout or PR
|
||||
# checkout, so the import paths resolve to the version of the code
|
||||
# under test. We never call the real ``text_to_speech_tool`` because
|
||||
# that would require audio synthesis; instead we ask the resolution
|
||||
# layer what it WOULD do.
|
||||
SUBPROCESS_SCRIPT = r"""
|
||||
import json, os, sys, tempfile
|
||||
sys.path.insert(0, sys.argv[1])
|
||||
|
||||
# Isolated HERMES_HOME so the config write is hermetic.
|
||||
home = tempfile.mkdtemp()
|
||||
os.environ["HERMES_HOME"] = home
|
||||
|
||||
# Clear TTS-related env so dispatch decisions are config-driven.
|
||||
for k in (
|
||||
"ELEVENLABS_API_KEY", "OPENAI_API_KEY", "VOICE_TOOLS_OPENAI_KEY",
|
||||
"MINIMAX_API_KEY", "XAI_API_KEY", "GEMINI_API_KEY",
|
||||
):
|
||||
os.environ.pop(k, None)
|
||||
|
||||
scenario_env = json.loads(sys.argv[2])
|
||||
os.environ.update(scenario_env)
|
||||
|
||||
config_yaml = sys.argv[3]
|
||||
plugin_register = sys.argv[4] # "yes" to register a fake plugin
|
||||
|
||||
config_path = os.path.join(home, "config.yaml")
|
||||
with open(config_path, "w") as f:
|
||||
f.write(config_yaml)
|
||||
|
||||
# Fresh import — must not have anything cached from prior runs.
|
||||
for name in list(sys.modules):
|
||||
if (name.startswith("tools.")
|
||||
or name.startswith("agent.")
|
||||
or name.startswith("plugins.")
|
||||
or name.startswith("hermes_cli.")):
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
# Try importing tts_registry — only exists on PR side.
|
||||
have_plugin_hook = False
|
||||
try:
|
||||
from agent import tts_registry
|
||||
from agent.tts_provider import TTSProvider
|
||||
have_plugin_hook = True
|
||||
|
||||
if plugin_register == "yes":
|
||||
class _FakeProvider(TTSProvider):
|
||||
@property
|
||||
def name(self): return "cartesia"
|
||||
def synthesize(self, text, output_path, **kw):
|
||||
return output_path
|
||||
|
||||
tts_registry._reset_for_tests()
|
||||
tts_registry.register_provider(_FakeProvider())
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import tools.tts_tool as tts_tool
|
||||
|
||||
# Read the config the same way text_to_speech_tool() does.
|
||||
tts_config = tts_tool._load_tts_config()
|
||||
provider = tts_tool._get_provider(tts_config)
|
||||
|
||||
dispatch_kind = None
|
||||
provider_name = provider
|
||||
voice_compat = False
|
||||
error_text = None
|
||||
|
||||
try:
|
||||
# Mistral is the one branch that returns a fatal error.
|
||||
if provider == "mistral":
|
||||
dispatch_kind = "error"
|
||||
error_text = "mistral quarantine"
|
||||
elif tts_tool._resolve_command_provider_config(provider, tts_config) is not None:
|
||||
dispatch_kind = "command"
|
||||
elif have_plugin_hook and provider not in tts_tool.BUILTIN_TTS_PROVIDERS:
|
||||
# On PR side: check plugin dispatch.
|
||||
plugin_path = tts_tool._dispatch_to_plugin_provider(
|
||||
"test", os.path.join(home, "out.mp3"), provider, tts_config,
|
||||
)
|
||||
if plugin_path is not None:
|
||||
dispatch_kind = "plugin"
|
||||
voice_compat = tts_tool._plugin_provider_is_voice_compatible(provider)
|
||||
else:
|
||||
# Falls through to Edge TTS default on the PR side too.
|
||||
dispatch_kind = "fallback_edge"
|
||||
elif provider in tts_tool.BUILTIN_TTS_PROVIDERS:
|
||||
dispatch_kind = "builtin_" + provider
|
||||
else:
|
||||
# On main side: unknown names fall through to Edge default.
|
||||
dispatch_kind = "fallback_edge"
|
||||
except Exception as exc:
|
||||
dispatch_kind = "exception"
|
||||
error_text = repr(exc)
|
||||
|
||||
shape = {
|
||||
"dispatch_kind": dispatch_kind,
|
||||
"provider_name": provider_name,
|
||||
"voice_compat": bool(voice_compat),
|
||||
"error_present": error_text is not None,
|
||||
}
|
||||
print(json.dumps(shape))
|
||||
"""
|
||||
|
||||
|
||||
SCENARIOS: list[tuple[str, str, dict[str, str], str]] = [
|
||||
# (label, config.yaml body, scenario_env, plugin_register)
|
||||
|
||||
# Scenario 1: unset tts.provider → both: Edge default
|
||||
("unset-defaults-to-edge", "", {}, "no"),
|
||||
|
||||
# Scenario 2: built-in name → both: that built-in
|
||||
("explicit-edge", "tts:\n provider: edge\n", {}, "no"),
|
||||
("explicit-openai", "tts:\n provider: openai\n", {}, "no"),
|
||||
("explicit-elevenlabs", "tts:\n provider: elevenlabs\n", {}, "no"),
|
||||
|
||||
# Scenario 3: command-type provider → both: command dispatch
|
||||
(
|
||||
"command-provider",
|
||||
"tts:\n provider: my-piper\n providers:\n my-piper:\n type: command\n command: 'piper -m model.onnx -f {output_path} < {input_path}'\n",
|
||||
{},
|
||||
"no",
|
||||
),
|
||||
|
||||
# Scenario 4: unknown name with NO plugin installed → both: fallback to Edge
|
||||
("unknown-no-plugin", "tts:\n provider: cartesia\n", {}, "no"),
|
||||
|
||||
# Scenario 5: unknown name WITH plugin installed
|
||||
# main: fallback_edge (no plugin hook exists)
|
||||
# PR: plugin (cartesia)
|
||||
# This is the ONLY acceptable diff in the harness.
|
||||
("plugin-installed", "tts:\n provider: cartesia\n", {}, "yes"),
|
||||
|
||||
# Scenario 6: built-in name + plugin tries to shadow → both: built-in
|
||||
# The plugin registers under name "cartesia", not "edge", so this is
|
||||
# effectively the same as scenario 2 — but we exercise the with-plugin
|
||||
# path to ensure the built-in branch still takes priority.
|
||||
("explicit-edge-with-plugin-registered", "tts:\n provider: edge\n", {}, "yes"),
|
||||
|
||||
# Scenario 7: mistral quarantine — both surface the explicit error
|
||||
("mistral-quarantine", "tts:\n provider: mistral\n", {}, "no"),
|
||||
]
|
||||
|
||||
|
||||
def _run_scenario(repo_path: Path, label: str, config_yaml: str, env: dict, plugin_register: str) -> dict:
|
||||
venv_python = repo_path / ".venv" / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = MAIN_DIR / ".venv" / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = MAIN_DIR / "venv" / "bin" / "python"
|
||||
if not venv_python.exists():
|
||||
venv_python = Path("python3")
|
||||
|
||||
out = subprocess.run(
|
||||
[
|
||||
str(venv_python),
|
||||
"-c",
|
||||
SUBPROCESS_SCRIPT,
|
||||
str(repo_path),
|
||||
json.dumps(env),
|
||||
config_yaml,
|
||||
plugin_register,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
if out.returncode != 0:
|
||||
return {
|
||||
"error": "subprocess failed",
|
||||
"stdout": out.stdout[-500:],
|
||||
"stderr": out.stderr[-500:],
|
||||
}
|
||||
try:
|
||||
return json.loads(out.stdout.strip().splitlines()[-1])
|
||||
except Exception as exc:
|
||||
return {"error": f"could not parse output: {exc}", "stdout": out.stdout}
|
||||
|
||||
|
||||
def _reduce(shape: dict) -> dict:
|
||||
"""Reduce to the parts that matter for user-visible parity."""
|
||||
return {
|
||||
"dispatch_kind": shape.get("dispatch_kind"),
|
||||
"provider_name": shape.get("provider_name"),
|
||||
"error_present": shape.get("error_present"),
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
print(f"main: {MAIN_DIR}")
|
||||
print(f"pr: {PR_DIR}")
|
||||
print()
|
||||
|
||||
if MAIN_DIR == PR_DIR:
|
||||
print(
|
||||
"WARN: MAIN_DIR == PR_DIR — diffs will be trivially identical.\n"
|
||||
" Set up a sibling 'hermes-agent-main' checkout pinned to "
|
||||
"origin/main to get real parity coverage."
|
||||
)
|
||||
print()
|
||||
|
||||
failures: list[str] = []
|
||||
errors: list[str] = []
|
||||
intentional_diffs: list[tuple[str, dict, dict]] = []
|
||||
for label, config_yaml, env, plugin_register in SCENARIOS:
|
||||
main_shape = _run_scenario(MAIN_DIR, label, config_yaml, env, plugin_register)
|
||||
pr_shape = _run_scenario(PR_DIR, label, config_yaml, env, plugin_register)
|
||||
|
||||
if "error" in main_shape or "error" in pr_shape:
|
||||
print(f" [ERR ] {label}: subprocess failed")
|
||||
print(f" main: {main_shape}")
|
||||
print(f" pr: {pr_shape}")
|
||||
errors.append(label)
|
||||
continue
|
||||
|
||||
main_reduced = _reduce(main_shape)
|
||||
pr_reduced = _reduce(pr_shape)
|
||||
|
||||
if main_reduced == pr_reduced:
|
||||
print(f" [OK] {label}: {main_reduced}")
|
||||
continue
|
||||
|
||||
# On main, "plugin-installed" scenario returns fallback_edge
|
||||
# (no plugin hook); on PR, it routes to the plugin. That's the
|
||||
# only acceptable diff.
|
||||
fallback_to_plugin = (
|
||||
main_reduced.get("dispatch_kind") == "fallback_edge"
|
||||
and pr_reduced.get("dispatch_kind") == "plugin"
|
||||
and label == "plugin-installed"
|
||||
)
|
||||
if fallback_to_plugin:
|
||||
print(f" [DIFF] {label}: fallback_edge → plugin — expected")
|
||||
intentional_diffs.append((label, main_reduced, pr_reduced))
|
||||
else:
|
||||
print(f" [FAIL] {label}")
|
||||
print(f" main: {main_reduced}")
|
||||
print(f" pr: {pr_reduced}")
|
||||
failures.append(label)
|
||||
|
||||
print()
|
||||
if errors:
|
||||
print(f"SUBPROCESS ERRORS in {len(errors)} scenario(s):")
|
||||
for e in errors:
|
||||
print(f" - {e}")
|
||||
if failures:
|
||||
print(f"BEHAVIOUR REGRESSION in {len(failures)} scenario(s):")
|
||||
for f in failures:
|
||||
print(f" - {f}")
|
||||
if intentional_diffs:
|
||||
print(
|
||||
f"INTENTIONAL DIFFS ({len(intentional_diffs)}): "
|
||||
f"fallback_edge → plugin dispatch when a plugin is registered."
|
||||
)
|
||||
if failures or errors:
|
||||
return 1
|
||||
print(f"PARITY OK across {len(SCENARIOS)} scenarios.")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
323
tests/tools/test_tts_plugin_dispatch.py
Normal file
323
tests/tools/test_tts_plugin_dispatch.py
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
"""Tests for TTS plugin dispatch in tools/tts_tool.py (issue #30398).
|
||||
|
||||
Covers the three core invariants of the plugin dispatcher:
|
||||
|
||||
1. Built-in provider names short-circuit — plugins NEVER win over a
|
||||
built-in. Even if a plugin somehow ended up in the registry with a
|
||||
built-in name (which the registry already blocks), the dispatcher
|
||||
re-checks defensively.
|
||||
2. Command-type providers declared under ``tts.providers.<name>: type:
|
||||
command`` (PR #17843) win over a plugin with the same name. Config
|
||||
is more local than plugin install.
|
||||
3. Plugin dispatch fires only when the configured provider is neither
|
||||
a built-in nor a command-type entry, AND a plugin is registered
|
||||
under that name. Unknown names fall through.
|
||||
|
||||
Also exercises:
|
||||
- Plugin exceptions surface to the outer error envelope (don't crash)
|
||||
- Plugin returning a different path is honored
|
||||
- voice_compatible: True triggers ffmpeg opus conversion path
|
||||
- voice_compatible: False keeps the file as-is
|
||||
|
||||
The dispatcher is exercised in isolation — we don't actually call
|
||||
``text_to_speech_tool`` because that would require real audio file
|
||||
writes. Each test directly calls
|
||||
``tools.tts_tool._dispatch_to_plugin_provider`` / the predicate
|
||||
helpers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from agent import tts_registry
|
||||
from agent.tts_provider import TTSProvider
|
||||
from tools import tts_tool
|
||||
|
||||
|
||||
class _FakeTTSProvider(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
voice_compat: bool = False,
|
||||
raise_exc: Optional[BaseException] = None,
|
||||
return_path: Optional[str] = None,
|
||||
):
|
||||
self._name = name
|
||||
self._voice_compat = voice_compat
|
||||
self._raise_exc = raise_exc
|
||||
self._return_path = return_path
|
||||
# Recorded for assertions
|
||||
self.last_call: Optional[dict] = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def voice_compatible(self) -> bool:
|
||||
return self._voice_compat
|
||||
|
||||
def synthesize(self, text, output_path, **kw):
|
||||
self.last_call = {
|
||||
"text": text,
|
||||
"output_path": output_path,
|
||||
"kwargs": dict(kw),
|
||||
}
|
||||
if self._raise_exc is not None:
|
||||
raise self._raise_exc
|
||||
return self._return_path if self._return_path is not None else output_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_registry():
|
||||
tts_registry._reset_for_tests()
|
||||
yield
|
||||
tts_registry._reset_for_tests()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resolution invariants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuiltinAlwaysWins:
|
||||
"""Built-in TTS provider names short-circuit the dispatcher.
|
||||
|
||||
Even with a plugin registered (which the registry would reject —
|
||||
but the dispatcher is defensive), built-in names return None so
|
||||
the caller's elif chain handles them natively.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"builtin",
|
||||
["edge", "openai", "elevenlabs", "minimax", "gemini",
|
||||
"mistral", "xai", "piper", "kittentts", "neutts"],
|
||||
)
|
||||
def test_dispatcher_short_circuits_builtin(self, builtin):
|
||||
result = tts_tool._dispatch_to_plugin_provider(
|
||||
text="hello",
|
||||
output_path="/tmp/out.mp3",
|
||||
provider=builtin,
|
||||
tts_config={},
|
||||
)
|
||||
assert result is None, (
|
||||
f"Built-in {builtin!r} must short-circuit plugin dispatch. "
|
||||
"If this test fails, the dispatcher would silently let a "
|
||||
"plugin with a built-in name shadow the native handler — "
|
||||
"violating the precedence rule from PR #17843."
|
||||
)
|
||||
|
||||
def test_dispatcher_short_circuits_builtin_case_insensitive(self):
|
||||
for variant in ("EDGE", "Edge", " edge ", "eDgE"):
|
||||
assert (
|
||||
tts_tool._dispatch_to_plugin_provider(
|
||||
text="hello", output_path="/tmp/x.mp3",
|
||||
provider=variant, tts_config={},
|
||||
) is None
|
||||
)
|
||||
|
||||
|
||||
class TestCommandProviderWins:
|
||||
"""A same-name ``tts.providers.<name>: type: command`` config beats a plugin.
|
||||
|
||||
Locality: a user's command-provider config is more specific than
|
||||
whichever plugin happens to be installed.
|
||||
"""
|
||||
|
||||
def test_command_config_beats_plugin(self):
|
||||
tts_registry.register_provider(_FakeTTSProvider(name="my-tts"))
|
||||
|
||||
result = tts_tool._dispatch_to_plugin_provider(
|
||||
text="hello",
|
||||
output_path="/tmp/out.mp3",
|
||||
provider="my-tts",
|
||||
tts_config={
|
||||
"providers": {
|
||||
"my-tts": {
|
||||
"type": "command",
|
||||
"command": "echo 'hi' > {output_path}",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
# Plugin path returns None → caller falls back to command
|
||||
# provider dispatch (handled by the outer text_to_speech_tool
|
||||
# via _resolve_command_provider_config).
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestPluginDispatch:
|
||||
"""Happy path: configured name matches a registered plugin, dispatcher fires."""
|
||||
|
||||
def test_registered_plugin_called(self):
|
||||
provider = _FakeTTSProvider(name="cartesia")
|
||||
tts_registry.register_provider(provider)
|
||||
|
||||
result = tts_tool._dispatch_to_plugin_provider(
|
||||
text="hello world",
|
||||
output_path="/tmp/out.mp3",
|
||||
provider="cartesia",
|
||||
tts_config={},
|
||||
)
|
||||
assert result == "/tmp/out.mp3"
|
||||
assert provider.last_call is not None
|
||||
assert provider.last_call["text"] == "hello world"
|
||||
assert provider.last_call["output_path"] == "/tmp/out.mp3"
|
||||
|
||||
def test_unregistered_name_returns_none(self):
|
||||
result = tts_tool._dispatch_to_plugin_provider(
|
||||
text="hello",
|
||||
output_path="/tmp/out.mp3",
|
||||
provider="unknown-tts",
|
||||
tts_config={},
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_voice_model_speed_format_forwarded(self):
|
||||
provider = _FakeTTSProvider(name="cartesia")
|
||||
tts_registry.register_provider(provider)
|
||||
|
||||
result = tts_tool._dispatch_to_plugin_provider(
|
||||
text="hello",
|
||||
output_path="/tmp/out.opus",
|
||||
provider="cartesia",
|
||||
tts_config={
|
||||
"voice": "voice-aria",
|
||||
"model": "sonic-2",
|
||||
"speed": 1.2,
|
||||
"output_format": "opus",
|
||||
},
|
||||
)
|
||||
assert result == "/tmp/out.opus"
|
||||
kwargs = provider.last_call["kwargs"]
|
||||
assert kwargs["voice"] == "voice-aria"
|
||||
assert kwargs["model"] == "sonic-2"
|
||||
assert kwargs["speed"] == 1.2
|
||||
assert kwargs["format"] == "opus"
|
||||
|
||||
def test_empty_string_voice_passed_as_none(self):
|
||||
"""Empty-string config values are normalized to None so providers can
|
||||
fall back to their own defaults (matches the ABC contract)."""
|
||||
provider = _FakeTTSProvider(name="cartesia")
|
||||
tts_registry.register_provider(provider)
|
||||
|
||||
tts_tool._dispatch_to_plugin_provider(
|
||||
text="hello",
|
||||
output_path="/tmp/out.mp3",
|
||||
provider="cartesia",
|
||||
tts_config={"voice": "", "model": ""},
|
||||
)
|
||||
kwargs = provider.last_call["kwargs"]
|
||||
assert kwargs["voice"] is None
|
||||
assert kwargs["model"] is None
|
||||
|
||||
def test_provider_returning_different_path_honored(self):
|
||||
"""If a provider rewrites the output path (e.g. format-driven extension
|
||||
change), the dispatcher returns the new path."""
|
||||
provider = _FakeTTSProvider(name="cartesia", return_path="/tmp/rewritten.opus")
|
||||
tts_registry.register_provider(provider)
|
||||
|
||||
result = tts_tool._dispatch_to_plugin_provider(
|
||||
text="hi",
|
||||
output_path="/tmp/out.mp3",
|
||||
provider="cartesia",
|
||||
tts_config={},
|
||||
)
|
||||
assert result == "/tmp/rewritten.opus"
|
||||
|
||||
def test_provider_returning_none_falls_back_to_output_path(self):
|
||||
"""Defensive: a provider returning None means the dispatcher should
|
||||
report the caller-supplied output_path (matches the ABC contract — the
|
||||
provider is supposed to write to output_path)."""
|
||||
provider = _FakeTTSProvider(name="cartesia", return_path=None)
|
||||
# Override the default-output-path behavior to return None explicitly
|
||||
provider._return_path = None
|
||||
|
||||
class _ReturnsNone(_FakeTTSProvider):
|
||||
def synthesize(self, text, output_path, **kw):
|
||||
return None # type: ignore[return-value]
|
||||
|
||||
provider2 = _ReturnsNone(name="weird")
|
||||
tts_registry.register_provider(provider2)
|
||||
|
||||
result = tts_tool._dispatch_to_plugin_provider(
|
||||
text="hi",
|
||||
output_path="/tmp/out.mp3",
|
||||
provider="weird",
|
||||
tts_config={},
|
||||
)
|
||||
assert result == "/tmp/out.mp3"
|
||||
|
||||
def test_provider_exception_bubbles_up(self):
|
||||
"""Plugin exceptions are NOT swallowed by the dispatcher — they bubble
|
||||
up so the outer ``text_to_speech_tool`` try/except converts them to
|
||||
the standard error envelope. Matches command-provider failure
|
||||
behavior."""
|
||||
provider = _FakeTTSProvider(
|
||||
name="cartesia",
|
||||
raise_exc=RuntimeError("network down"),
|
||||
)
|
||||
tts_registry.register_provider(provider)
|
||||
|
||||
with pytest.raises(RuntimeError, match="network down"):
|
||||
tts_tool._dispatch_to_plugin_provider(
|
||||
text="hi",
|
||||
output_path="/tmp/out.mp3",
|
||||
provider="cartesia",
|
||||
tts_config={},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# voice_compatible flag
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVoiceCompatibleHelper:
|
||||
def test_voice_compatible_true(self):
|
||||
tts_registry.register_provider(
|
||||
_FakeTTSProvider(name="cartesia", voice_compat=True)
|
||||
)
|
||||
assert tts_tool._plugin_provider_is_voice_compatible("cartesia") is True
|
||||
|
||||
def test_voice_compatible_false_by_default(self):
|
||||
tts_registry.register_provider(_FakeTTSProvider(name="cartesia"))
|
||||
assert tts_tool._plugin_provider_is_voice_compatible("cartesia") is False
|
||||
|
||||
def test_unregistered_provider_returns_false(self):
|
||||
assert tts_tool._plugin_provider_is_voice_compatible("unknown") is False
|
||||
|
||||
def test_empty_provider_name_returns_false(self):
|
||||
assert tts_tool._plugin_provider_is_voice_compatible("") is False
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"builtin",
|
||||
["edge", "openai", "elevenlabs", "minimax", "gemini",
|
||||
"mistral", "xai", "piper", "kittentts", "neutts"],
|
||||
)
|
||||
def test_builtin_names_return_false(self, builtin):
|
||||
"""voice_compatible helper short-circuits built-ins so they go
|
||||
through the legacy code path that handles their format quirks."""
|
||||
assert tts_tool._plugin_provider_is_voice_compatible(builtin) is False
|
||||
|
||||
def test_voice_compatible_case_insensitive(self):
|
||||
tts_registry.register_provider(
|
||||
_FakeTTSProvider(name="cartesia", voice_compat=True)
|
||||
)
|
||||
assert tts_tool._plugin_provider_is_voice_compatible("CARTESIA") is True
|
||||
assert tts_tool._plugin_provider_is_voice_compatible(" cartesia ") is True
|
||||
|
||||
def test_provider_property_exception_returns_false(self):
|
||||
"""A buggy ``voice_compatible`` property raising must not crash the
|
||||
TTS pipeline."""
|
||||
|
||||
class _ExplodingProvider(_FakeTTSProvider):
|
||||
@property
|
||||
def voice_compatible(self) -> bool:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
tts_registry.register_provider(_ExplodingProvider(name="cartesia"))
|
||||
assert tts_tool._plugin_provider_is_voice_compatible("cartesia") is False
|
||||
Loading…
Add table
Add a link
Reference in a new issue