feat(stt): add stt.providers.<name> command-provider registry

Mirror of the TTS command-provider registry (PR #17843) for STT. Lets any
shell-driven ASR engine — Doubao ASR, NVIDIA Parakeet, whisper.cpp builds,
SenseVoice, curl pipelines — become an STT backend with zero Python.
Complements the legacy HERMES_LOCAL_STT_COMMAND escape hatch (preserved
untouched via the built-in local_command path) and the
register_transcription_provider() Python plugin hook also shipped in this
PR.

Resolution order (mirrors TTS exactly):

  1. Built-in (local, local_command, groq, openai, mistral, xai)
     → native handler. Always wins.
  2. stt.providers.<name>: type: command  → command-provider runner.
  3. Plugin-registered TranscriptionProvider → plugin dispatch.
  4. No match → 'No STT provider available'.

Files
-----
- tools/transcription_tools.py: BUILTIN_STT_PROVIDERS frozenset retained;
  added _resolve_command_stt_provider_config, _transcribe_command_stt,
  and local helpers for template rendering, shell-quote context, and
  process-tree termination. Helpers are documented as mirrors of their
  tts_tool.py counterparts (kept local to avoid cross-tool private
  import). Wire-in is one insertion point in transcribe_audio() after
  the xai elif and before the plugin dispatcher. Plugin dispatcher
  additionally defensively short-circuits when a same-name command
  config exists (command-wins-over-plugin invariant).

- tests/tools/test_transcription_command_providers.py: 50 new tests
  covering resolution (builtin precedence, type/command gating,
  case-insensitive lookup, legacy stt.<name> back-compat), helpers
  (timeout fallback, format validation, iter, has-any), template
  rendering (shell-quote contexts, doubled-brace preservation),
  end-to-end via _transcribe_command_stt (output_path read, stdout
  fallback, timeout, nonzero exit envelope, model override,
  language precedence), and dispatcher integration via the real
  transcribe_audio() including command-wins-over-plugin and
  builtin-shadow-rejection.

- tests/plugins/transcription/check_parity_vs_main.py: extended from
  10 to 13 scenarios. New cases: command-provider-installed,
  command-vs-plugin-same-name (verifies command wins precedence),
  explicit-openai-with-command-shadow (verifies built-in wins).
  Adds command_provider dispatch_kind detection via transcript prefix
  (CMD: vs PLUGIN:) so command-provider scenarios can be distinguished
  from plugin scenarios even when sharing a provider name.

- website/docs/user-guide/features/tts.md: new 'STT custom command
  providers' section symmetric to the TTS section — example config,
  placeholder grammar table (input_path / output_path / output_dir /
  format / language / model), transcript-read-back semantics (file
  first, then stdout fallback), optional keys table, behavior notes,
  security note. Updated 'Python plugin providers (STT)' to include
  the new 'When to pick which (STT)' decision table and updated
  resolution-order section (now 4 layers instead of 3).

Verification
------------
189/189 STT targeted tests + 50/50 new command-provider tests pass.
Combined sweep: tests/tools/ 5576/5576, tests/agent/ + tests/hermes_cli/
8623/8623 — zero regressions across 14,199 tests.

Parity harness: 13 scenarios, 9 OK + 4 expected diffs
(no_provider_error → plugin, plugin_unavailable, command_provider × 2).

E2E live-verified in an isolated HERMES_HOME with a real .wav file:

  command:                    → dispatched to stt.providers.my-fake-cli
  plugin:                     → dispatched to registered TranscriptionProvider
  command-wins-over-plugin:   → command provider beats same-name plugin
  builtin-wins-over-command:  → built-in OpenAI handler fires;
                                stt.providers.openai: type: command
                                does NOT hijack it.
This commit is contained in:
teknium1 2026-05-24 23:22:50 -07:00 committed by Teknium
parent 2cd952e110
commit d3ffbc6409
4 changed files with 1323 additions and 14 deletions

View file

@ -233,6 +233,503 @@ BUILTIN_STT_PROVIDERS = frozenset({
})
# ---------------------------------------------------------------------------
# Command-provider registry (``stt.providers.<name>: type: command``)
# ---------------------------------------------------------------------------
#
# Mirrors the TTS command-provider registry shipped in PR #17843 — same
# placeholder grammar, same shell-quote-aware rendering, same process-tree
# termination on timeout. Lets any whisper CLI / ASR CLI / curl pipeline
# become an STT backend with zero Python.
#
# Resolution order:
# 1. Built-in (``local``, ``local_command``, ``groq``, ``openai``,
# ``mistral``, ``xai``) → native handler. **Always wins.**
# 2. ``stt.providers.<name>: type: command`` → command-provider runner.
# 3. Plugin-registered TranscriptionProvider → plugin dispatch.
# 4. No match → "No STT provider available".
#
# The single-env-var ``HERMES_LOCAL_STT_COMMAND`` escape hatch is preserved
# untouched via the built-in ``local_command`` path. Use the command-provider
# registry when you want MULTIPLE shell-driven STT engines, or you want a
# named provider you can pick via ``stt.provider`` in config.yaml.
DEFAULT_COMMAND_STT_TIMEOUT_SECONDS = 300
DEFAULT_COMMAND_STT_LANGUAGE = "en"
DEFAULT_COMMAND_STT_OUTPUT_FORMAT = "txt"
COMMAND_STT_OUTPUT_FORMATS = frozenset({"txt", "json", "srt", "vtt"})
def _get_stt_section(stt_config: Dict[str, Any], name: str) -> Dict[str, Any]:
"""Return an stt sub-section if it's a dict, else an empty dict."""
if not isinstance(stt_config, dict):
return {}
section = stt_config.get(name)
return section if isinstance(section, dict) else {}
def _get_named_stt_provider_config(
stt_config: Dict[str, Any],
name: str,
) -> Dict[str, Any]:
"""Return the config dict for a user-declared STT command provider.
Looks up ``stt.providers.<name>`` first (the canonical location), and
falls back to ``stt.<name>`` so users who followed the built-in layout
still work. Returns an empty dict when the provider is not declared.
Built-in names are NOT special-cased here the caller short-circuits
them before this is consulted, AND ``_is_command_stt_provider_config``
requires an explicit ``command:`` value, so a built-in section like
``stt.openai`` (which has ``model``/``language`` but no ``command``)
can't accidentally be treated as a command provider.
"""
providers = _get_stt_section(stt_config, "providers")
section = providers.get(name) if isinstance(providers, dict) else None
if isinstance(section, dict):
return section
# Back-compat: allow ``stt.<name>`` for user-declared providers too,
# but only when the name is not a built-in (so a user's ``stt.openai``
# block still means the OpenAI provider, not a custom command).
if name.lower() not in BUILTIN_STT_PROVIDERS:
legacy = _get_stt_section(stt_config, name)
if legacy:
return legacy
return {}
def _is_command_stt_provider_config(config: Dict[str, Any]) -> bool:
"""Return True when *config* declares a command-type STT provider."""
if not isinstance(config, dict):
return False
ptype = str(config.get("type") or "").strip().lower()
if ptype and ptype != "command":
return False
command = config.get("command")
return isinstance(command, str) and bool(command.strip())
def _resolve_command_stt_provider_config(
provider: str,
stt_config: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
"""Return the provider config if *provider* resolves to a command type.
Built-in provider names are rejected (they have native handlers).
Returns None when the name is a built-in, ``"none"``, unknown, or not
a command type.
"""
if not provider:
return None
key = provider.lower().strip()
if key in BUILTIN_STT_PROVIDERS or key == "none":
return None
config = _get_named_stt_provider_config(stt_config, key)
if _is_command_stt_provider_config(config):
return config
return None
def _iter_command_stt_providers(stt_config: Dict[str, Any]):
"""Yield (name, config) pairs for every declared command-type STT provider."""
if not isinstance(stt_config, dict):
return
providers = _get_stt_section(stt_config, "providers")
for name, cfg in (providers or {}).items():
if isinstance(name, str) and name.lower() not in BUILTIN_STT_PROVIDERS:
if _is_command_stt_provider_config(cfg):
yield name, cfg
def _has_any_command_stt_provider(stt_config: Optional[Dict[str, Any]] = None) -> bool:
"""Return True when any command-type STT provider is configured."""
if stt_config is None:
stt_config = _load_stt_config()
for _name, _cfg in _iter_command_stt_providers(stt_config):
return True
return False
def _get_command_stt_timeout(config: Dict[str, Any]) -> float:
"""Return timeout in seconds, falling back when invalid."""
raw = config.get("timeout", config.get("timeout_seconds", DEFAULT_COMMAND_STT_TIMEOUT_SECONDS))
try:
value = float(raw)
except (TypeError, ValueError):
return float(DEFAULT_COMMAND_STT_TIMEOUT_SECONDS)
if value <= 0:
return float(DEFAULT_COMMAND_STT_TIMEOUT_SECONDS)
return value
def _get_command_stt_output_format(config: Dict[str, Any]) -> str:
"""Return the validated output format (txt/json/srt/vtt)."""
raw = (
config.get("format")
or config.get("output_format")
or DEFAULT_COMMAND_STT_OUTPUT_FORMAT
)
fmt = str(raw).lower().strip().lstrip(".")
return fmt if fmt in COMMAND_STT_OUTPUT_FORMATS else DEFAULT_COMMAND_STT_OUTPUT_FORMAT
def _shell_quote_context_stt(command_template: str, position: int) -> Optional[str]:
"""Return the shell quote character active right before *position*.
Mirrors ``tools.tts_tool._shell_quote_context`` kept local to avoid
cross-module import of a private helper. Returns ``"'"`` / ``'"'`` when
inside a quoted region, ``None`` for bare context.
"""
quote: Optional[str] = None
escaped = False
i = 0
while i < position:
char = command_template[i]
if quote == "'":
if char == "'":
quote = None
elif quote == '"':
if escaped:
escaped = False
elif char == "\\":
escaped = True
elif char == '"':
quote = None
elif char == "'":
quote = "'"
elif char == '"':
quote = '"'
elif char == "\\":
i += 1
i += 1
return quote
def _quote_command_stt_placeholder(value: str, quote_context: Optional[str]) -> str:
"""Quote a placeholder value for its position in a shell command template.
Mirrors ``tools.tts_tool._quote_command_tts_placeholder``.
"""
if quote_context == "'":
return value.replace("'", r"'\''")
if quote_context == '"':
return (
value
.replace("\\", "\\\\")
.replace('"', r'\"')
.replace("$", r"\$")
.replace("`", r"\`")
)
if os.name == "nt":
return subprocess.list2cmdline([value])
return shlex.quote(value)
def _render_command_stt_template(
command_template: str,
placeholders: Dict[str, str],
) -> str:
"""Replace supported placeholders while preserving ``{{`` / ``}}``.
Mirrors ``tools.tts_tool._render_command_tts_template``. Placeholders
are shell-quote-aware: ``{voice}`` inside single quotes gets
single-quote-safe escaping, inside double quotes gets ``$``/`` ` ``/`` " ``
escaping, outside quotes gets ``shlex.quote``. Doubled braces ``{{`` and
``}}`` are preserved as literal ``{`` / ``}`` for users who want to
embed JSON snippets in their command.
"""
import re
names = "|".join(re.escape(name) for name in placeholders)
pattern = re.compile(
rf"(?<!\$)(?:\{{\{{(?P<double>{names})\}}\}}|\{{(?P<single>{names})\}})"
)
replacements: list[tuple[str, str]] = []
def replace_match(match: "re.Match[str]") -> str:
name = match.group("double") or match.group("single")
token = f"__HERMES_STT_PLACEHOLDER_{len(replacements)}__"
replacements.append((
token,
_quote_command_stt_placeholder(
placeholders[name],
_shell_quote_context_stt(command_template, match.start()),
),
))
return token
rendered = pattern.sub(replace_match, command_template)
rendered = rendered.replace("{{", "{").replace("}}", "}")
for token, value in replacements:
rendered = rendered.replace(token, value)
return rendered
def _terminate_command_stt_process_tree(proc: subprocess.Popen) -> None:
"""Best-effort termination of a shell process and all of its children.
Mirrors ``tools.tts_tool._terminate_command_tts_process_tree``.
"""
if proc.poll() is not None:
return
if os.name == "nt":
try:
subprocess.run(
["taskkill", "/F", "/T", "/PID", str(proc.pid)],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
timeout=5,
)
except Exception:
proc.kill()
return
try:
import psutil # type: ignore
except ImportError:
# psutil is optional — fall back to single-process terminate/kill
proc.terminate()
try:
proc.wait(timeout=2)
except subprocess.TimeoutExpired:
proc.kill()
return
try:
parent = psutil.Process(proc.pid)
for child in parent.children(recursive=True):
try:
child.terminate()
except psutil.NoSuchProcess:
pass
parent.terminate()
except psutil.NoSuchProcess:
return
except Exception:
proc.terminate()
try:
proc.wait(timeout=2)
return
except subprocess.TimeoutExpired:
pass
try:
parent = psutil.Process(proc.pid)
for child in parent.children(recursive=True):
try:
child.kill()
except psutil.NoSuchProcess:
pass
parent.kill()
except psutil.NoSuchProcess:
return
except Exception:
proc.kill()
def _run_command_stt(command: str, timeout: float) -> subprocess.CompletedProcess:
"""Run a command-provider shell command with process-tree timeout cleanup.
Mirrors ``tools.tts_tool._run_command_tts``.
"""
popen_kwargs: Dict[str, Any] = {
"shell": True,
"stdout": subprocess.PIPE,
"stderr": subprocess.PIPE,
"text": True,
}
if os.name == "nt":
popen_kwargs["creationflags"] = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0)
else:
popen_kwargs["start_new_session"] = True
proc = subprocess.Popen(command, **popen_kwargs)
try:
stdout, stderr = proc.communicate(timeout=timeout)
except subprocess.TimeoutExpired as exc:
_terminate_command_stt_process_tree(proc)
try:
stdout, stderr = proc.communicate(timeout=1)
except Exception:
stdout = getattr(exc, "output", None)
stderr = getattr(exc, "stderr", None)
raise subprocess.TimeoutExpired(
command,
timeout,
output=stdout,
stderr=stderr,
) from exc
if proc.returncode:
raise subprocess.CalledProcessError(
proc.returncode,
command,
output=stdout,
stderr=stderr,
)
return subprocess.CompletedProcess(command, proc.returncode, stdout, stderr)
def _read_command_stt_output(output_path: Path, stdout: str, fmt: str) -> str:
"""Return the transcript text from a command-provider invocation.
Resolution:
1. If ``output_path`` exists and is non-empty read it (raw text).
2. Else if ``stdout`` is non-empty use stdout (lets users write
curl-style one-liners that emit transcript to stdout instead of
writing a file).
3. Else raise RuntimeError (no usable output produced).
For JSON format, we still return the raw bytes extracting a
``text`` field is out of scope; users either configure ``format: txt``
or post-process JSON downstream. (Same trade-off as TTS: the runner
doesn't try to be clever about output shape.)
"""
if output_path.exists():
try:
content = output_path.read_text(encoding="utf-8").strip()
except UnicodeDecodeError:
content = output_path.read_bytes().decode("utf-8", errors="replace").strip()
if content:
return content
if stdout and stdout.strip():
return stdout.strip()
raise RuntimeError(
f"Command STT provider wrote no output file at {output_path} "
f"and produced no stdout"
)
def _transcribe_command_stt(
file_path: str,
provider_name: str,
config: Dict[str, Any],
stt_config: Dict[str, Any],
model_override: Optional[str] = None,
) -> Dict[str, Any]:
"""Transcribe via a user-declared ``stt.providers.<name>: type: command``.
Placeholder grammar:
| Placeholder | Substituted with |
|-------------------|-----------------------------------------------------------|
| ``{input_path}`` | absolute path to the audio file (original location) |
| ``{output_path}`` | absolute path the provider should write its transcript to |
| ``{output_dir}`` | parent dir of ``{output_path}`` |
| ``{format}`` | configured output format (``txt`` / ``json`` / ``srt`` / ``vtt``) |
| ``{language}`` | configured language code (default ``en``) |
| ``{model}`` | configured model id (empty when not set) |
All placeholders are shell-quote-aware (see ``_render_command_stt_template``).
Doubled braces ``{{`` and ``}}`` are preserved as literal braces.
Returns the standard transcribe-response envelope (``success``,
``transcript``, ``provider``, ``error``).
"""
command_template = str(config.get("command") or "").strip()
if not command_template:
return {
"success": False,
"transcript": "",
"provider": provider_name,
"error": f"stt.providers.{provider_name}.command is not configured",
}
audio = Path(file_path).expanduser()
if not audio.exists():
return {
"success": False,
"transcript": "",
"provider": provider_name,
"error": f"Audio file not found: {file_path}",
}
timeout = _get_command_stt_timeout(config)
output_format = _get_command_stt_output_format(config)
language = (
config.get("language")
or stt_config.get("language")
or DEFAULT_COMMAND_STT_LANGUAGE
)
model = model_override or config.get("model") or ""
try:
with tempfile.TemporaryDirectory(prefix=f"hermes-cmd-stt-{provider_name}-") as tmpdir:
output_path = Path(tmpdir) / f"transcript.{output_format}"
placeholders = {
"input_path": str(audio.resolve()),
"output_path": str(output_path),
"output_dir": str(output_path.parent),
"format": output_format,
"language": str(language),
"model": str(model),
}
command = _render_command_stt_template(command_template, placeholders)
logger.info(
"Transcribing %s via command STT provider '%s'...",
audio.name, provider_name,
)
try:
result = _run_command_stt(command, timeout)
except subprocess.TimeoutExpired:
return {
"success": False,
"transcript": "",
"provider": provider_name,
"error": (
f"STT command provider '{provider_name}' timed out after "
f"{timeout:g}s"
),
}
except subprocess.CalledProcessError as exc:
detail_parts = []
if exc.stderr:
detail_parts.append(f"stderr: {exc.stderr.strip()}")
if exc.stdout:
detail_parts.append(f"stdout: {exc.stdout.strip()}")
detail = "; ".join(detail_parts) or "no command output"
return {
"success": False,
"transcript": "",
"provider": provider_name,
"error": (
f"STT command provider '{provider_name}' exited with code "
f"{exc.returncode}: {detail}"
),
}
try:
transcript_text = _read_command_stt_output(
output_path, result.stdout or "", output_format,
)
except RuntimeError as exc:
return {
"success": False,
"transcript": "",
"provider": provider_name,
"error": str(exc),
}
except OSError as exc:
return {
"success": False,
"transcript": "",
"provider": provider_name,
"error": f"STT command provider '{provider_name}' failed: {exc}",
}
logger.info(
"Transcribed %s via command STT provider '%s' (%d chars)",
audio.name, provider_name, len(transcript_text),
)
return {
"success": True,
"transcript": transcript_text,
"provider": provider_name,
}
def _get_provider(stt_config: dict) -> str:
"""Determine which STT provider to use.
@ -352,6 +849,7 @@ def _get_provider(stt_config: dict) -> str:
def _dispatch_to_plugin_provider(
file_path: str,
provider: str,
stt_config: Optional[Dict[str, Any]] = None,
*,
model: Optional[str] = None,
language: Optional[str] = None,
@ -370,12 +868,17 @@ def _dispatch_to_plugin_provider(
function defensively rejects those names so a plugin can't be
silently dispatched under a built-in name even if it somehow
slipped past the registry's built-in shadow guard.
2. Plugin dispatch fires only when ``provider`` matches a
2. Same-name command-type provider declared under
``stt.providers.<name>: type: command`` wins over a plugin. The
caller short-circuits to the command runner before reaching us,
but we re-verify here so a refactor of the caller can't silently
break the invariant (matches TTS PR #17843 precedence rule).
3. Plugin dispatch fires only when ``provider`` matches a
registered :class:`TranscriptionProvider` whose ``name`` equals
the configured value. Unknown names with no plugin registered
return None (caller surfaces the legacy "No STT provider"
message).
3. Availability gating: when the matched plugin reports
4. Availability gating: when the matched plugin reports
``is_available() == False`` (missing API key, missing optional
SDK, etc.) this returns an error envelope identifying the
plugin as unavailable **not** ``None`` because the user
@ -392,6 +895,13 @@ def _dispatch_to_plugin_provider(
key = provider.lower().strip()
if key in BUILTIN_STT_PROVIDERS or key == "none":
return None
# Defense in depth: command-provider check should already have
# short-circuited the caller. If a same-name command config exists,
# bail so the command path wins.
if stt_config is not None and _is_command_stt_provider_config(
_get_named_stt_provider_config(stt_config, key)
):
return None
try:
from agent.transcription_registry import get_provider
from hermes_cli.plugins import _ensure_plugins_discovered
@ -1058,9 +1568,26 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
model_name = model or "grok-stt"
return _transcribe_xai(file_path, model_name)
# User-declared command-type provider
# (``stt.providers.<name>: type: command``). Fires after the built-in
# elif chain — built-in names short-circuit upstream so a user's
# ``stt.providers.openai.command`` can't override the real OpenAI
# handler — and BEFORE the plugin dispatcher, because config is more
# local than a plugin install (same precedence rule as TTS PR #17843).
command_provider_config = _resolve_command_stt_provider_config(provider, stt_config)
if command_provider_config is not None:
return _transcribe_command_stt(
file_path,
provider,
command_provider_config,
stt_config,
model_override=model,
)
# Plugin-registered STT backend (e.g. OpenRouter, SenseAudio,
# Gemini-STT). Fires only when ``provider`` is neither a built-in
# nor ``"none"``. The dispatcher enforces built-ins-always-win
# nor ``"none"`` AND there is no same-name command provider. The
# dispatcher enforces built-ins-always-win + command-wins-over-plugin
# defensively. Returns None when no plugin is registered for the
# configured name, falling through to the legacy "No STT provider"
# error message below.
@ -1076,6 +1603,7 @@ def transcribe_audio(file_path: str, model: Optional[str] = None) -> Dict[str, A
plugin_result = _dispatch_to_plugin_provider(
file_path,
provider,
stt_config,
model=plugin_model,
language=plugin_language,
)