From 160bb565b4ec05b89c57808f2b8d425b39591475 Mon Sep 17 00:00:00 2001 From: Cdddo Date: Thu, 18 Jun 2026 20:51:37 -0600 Subject: [PATCH] feat(tts): expose speaker_id on built-in Piper provider The built-in Piper provider (tts.provider: piper, Python piper-tts package) already constructs piper.SynthesisConfig for the advanced tuning knobs, but did not forward speaker_id from the user config. This wires tts.piper.speaker_id through to SynthesisConfig.speaker_id so multi-speaker ONNX models (e.g. libritts_r) can be addressed via config without dropping to the command-provider path. Changes: - Add speaker_id to the has_advanced tuple so setting it triggers SynthesisConfig construction (same gating as the other knobs). - Pass speaker_id=speaker_id to SynthesisConfig. Defaults to 0 (Piper's own default; single-speaker models ignore the field). - Tolerant parse: any value that is not already an int (strings -- including numeric ones such as "2" -- floats, lists, dicts, None) is dropped to 0 instead of raising; no string-to-int coercion is attempted. Booleans are likewise dropped to 0 rather than accepted (True/False would silently coerce to 1/0 and hide a config mistake). Mirrors the same shape as the command-provider's _resolve_command_tts_optional_number helper. speaker_id is applied per-call via syn_config.speaker_id, so the PiperVoice cache key is intentionally left as just (model, cuda) -- the same loaded model serves all speakers. Tests cover the config knob, the tolerant parse, and the no-reload invariant. sentence_silence is intentionally not added here: the Python piper-tts SynthesisConfig does not expose that field (CLI-only). 
--- tests/tools/test_tts_piper.py | 93 ++++++++++++++++++++++++++++++++++- tools/tts_tool.py | 22 ++++++++- 2 files changed, 113 insertions(+), 2 deletions(-) diff --git a/tests/tools/test_tts_piper.py b/tests/tools/test_tts_piper.py index c30b26dc9b9..78567adf9bb 100644 --- a/tests/tools/test_tts_piper.py +++ b/tests/tools/test_tts_piper.py @@ -8,6 +8,7 @@ without requiring the ``piper-tts`` package to actually be installed import json import sys +import types from pathlib import Path from unittest.mock import MagicMock, patch @@ -219,7 +220,7 @@ class TestGeneratePiperTts: # The SynthesisConfig import happens inline inside _generate_piper_tts # via ``from piper import SynthesisConfig``. Inject a fake piper - # module so that import resolves. + # module so that that import resolves. monkeypatch.setitem(sys.modules, "piper", FakePiperModule) config = { @@ -239,6 +240,96 @@ class TestGeneratePiperTts: assert kwargs["length_scale"] == 2.0 assert kwargs["volume"] == 0.8 + def test_speaker_id_passed_through_to_synconfig(self, tmp_path, monkeypatch): + """speaker_id flows from config to SynthesisConfig when set.""" + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + fake_syn_cls = MagicMock() + monkeypatch.setitem(sys.modules, "piper", types.SimpleNamespace(SynthesisConfig=fake_syn_cls)) + + config = {"piper": {"voice": str(model), "speaker_id": 2}} + tts_tool._generate_piper_tts("hi", str(tmp_path / "out.wav"), config) + + fake_syn_cls.assert_called_once() + assert fake_syn_cls.call_args.kwargs["speaker_id"] == 2 + + def test_speaker_id_alone_triggers_synconfig(self, tmp_path, monkeypatch): + """Setting ONLY speaker_id (no other advanced knobs) still constructs SynthesisConfig. + + Regression guard: has_advanced must include speaker_id, otherwise + this knob gets silently dropped on the simplest configuration. 
+ """ + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + fake_syn_cls = MagicMock() + monkeypatch.setitem(sys.modules, "piper", types.SimpleNamespace(SynthesisConfig=fake_syn_cls)) + + config = {"piper": {"voice": str(model), "speaker_id": 1}} + tts_tool._generate_piper_tts("hi", str(tmp_path / "out.wav"), config) + + fake_syn_cls.assert_called_once() + + def test_speaker_id_default_zero_when_unset(self, tmp_path, monkeypatch): + """No speaker_id in config → SynthesisConfig.speaker_id == 0 (Piper's default).""" + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + fake_syn_cls = MagicMock() + monkeypatch.setitem(sys.modules, "piper", types.SimpleNamespace(SynthesisConfig=fake_syn_cls)) + + config = {"piper": {"voice": str(model), "length_scale": 1.5}} + tts_tool._generate_piper_tts("hi", str(tmp_path / "out.wav"), config) + + assert fake_syn_cls.call_args.kwargs["speaker_id"] == 0 + + def test_speaker_id_bool_rejected_to_zero(self, tmp_path, monkeypatch): + """True/False would coerce to 1/0 and hide a config mistake — reject outright.""" + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + fake_syn_cls = MagicMock() + monkeypatch.setitem(sys.modules, "piper", types.SimpleNamespace(SynthesisConfig=fake_syn_cls)) + + for bad in (True, False): + fake_syn_cls.reset_mock() + config = {"piper": {"voice": str(model), "speaker_id": bad}} + tts_tool._generate_piper_tts("hi", str(tmp_path / f"out-{bad}.wav"), config) + assert fake_syn_cls.call_args.kwargs["speaker_id"] == 0 + + def test_speaker_id_non_int_dropped_to_zero(self, tmp_path, monkeypatch): + """Unparseable config (string, list, dict) drops to 0 instead of raising.""" + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + fake_syn_cls = 
MagicMock() + monkeypatch.setitem(sys.modules, "piper", types.SimpleNamespace(SynthesisConfig=fake_syn_cls)) + + for bad in ("two", [1, 2], {"k": 1}, None): + fake_syn_cls.reset_mock() + config = {"piper": {"voice": str(model), "speaker_id": bad}} + tts_tool._generate_piper_tts("hi", str(tmp_path / f"out-{type(bad).__name__}.wav"), config) + assert fake_syn_cls.call_args.kwargs["speaker_id"] == 0 + + def test_speaker_id_does_not_invalidate_voice_cache(self, tmp_path, monkeypatch): + """Switching speaker_id between calls must NOT trigger a model reload. + + PiperVoice is bound to a model, not a speaker — speaker is applied + per-call via syn_config.speaker_id. The voice cache should serve the + same PiperVoice instance for the same (model, cuda) regardless of + how many distinct speaker_ids the user cycles through. + """ + model = self._prepare_voice_files(tmp_path) + monkeypatch.setattr(tts_tool, "_import_piper", lambda: _StubPiperVoice) + + for speaker in (0, 1, 2, 3): + config = {"piper": {"voice": str(model), "speaker_id": speaker}} + tts_tool._generate_piper_tts("hi", str(tmp_path / f"out-{speaker}.wav"), config) + + # Only one PiperVoice.load() call across four calls with different speakers. + assert _StubPiperVoice.loaded == [str(model)] + # --------------------------------------------------------------------------- # text_to_speech_tool end-to-end (provider == "piper") diff --git a/tools/tts_tool.py b/tools/tts_tool.py index c6e7c22de0f..02fe4e5bda5 100644 --- a/tools/tts_tool.py +++ b/tools/tts_tool.py @@ -1889,6 +1889,18 @@ def _generate_piper_tts(text: str, output_path: str, tts_config: Dict[str, Any]) model_path = _resolve_piper_voice_path(voice_name, download_dir) + # Tolerant speaker_id parse: drop bad input (non-int strings, lists, dicts) + # to 0 (Piper's own default). Booleans are rejected outright — True/False + # would silently coerce to 1/0 and hide a config mistake. 
+ _raw_speaker = piper_config.get("speaker_id", 0) + if isinstance(_raw_speaker, bool) or not isinstance(_raw_speaker, int): + speaker_id = 0 + else: + speaker_id = _raw_speaker + + # speaker_id is applied per-call via syn_config.speaker_id — the same + # PiperVoice instance serves all speakers, so it stays out of the cache + # key. Multi-speaker workflows share one model load. cache_key = f"{model_path}::cuda={use_cuda}" global _piper_voice_cache if cache_key not in _piper_voice_cache: @@ -1903,7 +1915,14 @@ def _generate_piper_tts(text: str, output_path: str, tts_config: Dict[str, Any]) syn_config = None has_advanced = any( k in piper_config - for k in ("length_scale", "noise_scale", "noise_w_scale", "volume", "normalize_audio") + for k in ( + "length_scale", + "noise_scale", + "noise_w_scale", + "volume", + "normalize_audio", + "speaker_id", + ) ) if has_advanced: try: @@ -1914,6 +1933,7 @@ def _generate_piper_tts(text: str, output_path: str, tts_config: Dict[str, Any]) noise_w_scale=float(piper_config.get("noise_w_scale", 0.8)), volume=float(piper_config.get("volume", 1.0)), normalize_audio=bool(piper_config.get("normalize_audio", True)), + speaker_id=speaker_id, ) except ImportError: logger.warning(