Add opt-in xAI TTS speech tag pauses

This commit is contained in:
Julien Talbot 2026-04-28 12:11:37 +04:00 committed by Teknium
parent 5af4b73f87
commit ca192cfb77
2 changed files with 161 additions and 0 deletions

View file

@ -0,0 +1,81 @@
"""Tests for xAI TTS speech-tag handling."""
from unittest.mock import Mock
from tools.tts_tool import _apply_xai_auto_speech_tags, _generate_xai_tts
def test_apply_xai_auto_speech_tags_adds_light_pause_after_first_sentence():
    """A plain two-sentence reply gets one ``[pause]`` after its first sentence."""
    spoken = "Bonjour Monsieur Talbot. Ceci est un test de réponse vocale."
    expected = "Bonjour Monsieur Talbot. [pause] Ceci est un test de réponse vocale."

    assert _apply_xai_auto_speech_tags(spoken) == expected
def test_apply_xai_auto_speech_tags_preserves_explicit_tags():
    """Text that already carries speech tags must come back unchanged."""
    already_tagged = "Bonjour. [pause] <whisper>Déjà balisé.</whisper>"

    result = _apply_xai_auto_speech_tags(already_tagged)

    assert result == already_tagged
def test_apply_xai_auto_speech_tags_preserves_all_documented_xai_tags():
    """Inline and wrapping tags other than ``[pause]`` also block auto-tagging."""
    already_tagged = (
        "Bonjour Monsieur Talbot. [sigh] <slow>Je parle lentement.</slow> "
        "<emphasis>Important.</emphasis>"
    )

    result = _apply_xai_auto_speech_tags(already_tagged)

    assert result == already_tagged
def test_generate_xai_tts_sends_auto_speech_tags_when_enabled(tmp_path, monkeypatch):
    """With ``auto_speech_tags`` enabled, the posted text contains the pause tag."""
    captured = {}

    class FakeResponse:
        # Minimal stand-in for the HTTP response: a body and a no-op status check.
        content = b"mp3"

        def raise_for_status(self):
            pass

    def fake_post(url, headers, json, timeout):
        # Record every argument of the outgoing request for the assertions below.
        captured.update(url=url, headers=headers, json=json, timeout=timeout)
        return FakeResponse()

    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")
    monkeypatch.setattr("requests.post", fake_post)

    out = tmp_path / "out.mp3"
    tts_config = {"xai": {"voice_id": "ara", "language": "fr", "auto_speech_tags": True}}
    _generate_xai_tts("Bonjour Monsieur Talbot. Ceci est un test.", str(out), tts_config)

    assert out.read_bytes() == b"mp3"
    assert captured["url"] == "https://api.x.ai/v1/tts"

    payload = captured["json"]
    assert payload["voice_id"] == "ara"
    assert payload["language"] == "fr"
    assert payload["text"] == "Bonjour Monsieur Talbot. [pause] Ceci est un test."
def test_generate_xai_tts_leaves_text_plain_by_default(tmp_path, monkeypatch):
    """Without ``auto_speech_tags`` in the config, the text is posted unmodified."""
    captured = {}

    fake_response = Mock(content=b"mp3")
    fake_response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Only the request body matters for this test.
        captured["json"] = json
        return fake_response

    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")
    monkeypatch.setattr("requests.post", fake_post)

    plain_text = "Bonjour Monsieur Talbot. Ceci est un test."
    tts_config = {"xai": {"voice_id": "ara", "language": "fr"}}
    _generate_xai_tts(plain_text, str(tmp_path / "out.mp3"), tts_config)

    assert captured["json"]["text"] == "Bonjour Monsieur Talbot. Ceci est un test."

View file

@ -167,6 +167,7 @@ DEFAULT_XAI_VOICE_ID = "eve"
# Fallbacks read by _generate_xai_tts when the "xai" config omits the key.
DEFAULT_XAI_LANGUAGE = "en"
DEFAULT_XAI_SAMPLE_RATE = 24000  # presumably Hz; confirm against the xAI TTS API
DEFAULT_XAI_BIT_RATE = 128000  # presumably bits per second; confirm against the xAI TTS API
# Automatic speech-tag insertion is opt-in, so it stays off unless configured.
DEFAULT_XAI_AUTO_SPEECH_TAGS = False
# NOTE(review): the base-URL fallback chain is only partly visible here; confirm where this default applies.
DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1"
# NOTE(review): Gemini TTS defaults; their consumers are outside this view.
DEFAULT_GEMINI_TTS_MODEL = "gemini-2.5-flash-preview-tts"
DEFAULT_GEMINI_TTS_VOICE = "Kore"
@ -892,6 +893,79 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]
# ===========================================================================
# Provider: xAI TTS
# ===========================================================================
_XAI_INLINE_SPEECH_TAGS = (
"pause",
"long-pause",
"hum-tune",
"laugh",
"chuckle",
"giggle",
"cry",
"tsk",
"tongue-click",
"lip-smack",
"breath",
"inhale",
"exhale",
"sigh",
)
_XAI_WRAPPING_SPEECH_TAGS = (
"soft",
"whisper",
"loud",
"build-intensity",
"decrease-intensity",
"higher-pitch",
"lower-pitch",
"slow",
"fast",
"sing-song",
"singing",
"laugh-speak",
"emphasis",
)
_XAI_SPEECH_TAG_RE = re.compile(
r"(\[(?:" + "|".join(_XAI_INLINE_SPEECH_TAGS) + r")\]|</?(?:" + "|".join(_XAI_WRAPPING_SPEECH_TAGS) + r")>)",
flags=re.IGNORECASE,
)
_XAI_FIRST_SENTENCE_RE = re.compile(r"^(.{12,120}?[.!?…])\s+(?=\S)", flags=re.DOTALL)
def _xai_bool_config(value: Any, default: bool = False) -> bool:
"""Coerce common YAML/env bool spellings without treating random strings as true."""
if isinstance(value, bool):
return value
if value is None:
return default
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "on", "enabled"}:
return True
if normalized in {"0", "false", "no", "off", "disabled"}:
return False
return default
def _apply_xai_auto_speech_tags(text: str) -> str:
    """Add light xAI speech tags for more natural voice-mode replies.

    The transform is intentionally conservative: it only inserts ``[pause]``
    tags. It never fabricates laughter or whispering, and text that already
    contains any recognised xAI speech tag is returned untouched.

    Steps, in order:

    1. Blank-line paragraph breaks become `` [pause] ``.
    2. Remaining single line breaks become spaces.
    3. One ``[pause]`` is inserted after the first sentence, unless a pause
       already sits there (the sentence ended a paragraph in step 1).
    4. Runs of whitespace collapse to a single space.

    Args:
        text: Reply text about to be sent to xAI TTS.

    Returns:
        The text with pause tags added, or *text* unchanged when it is blank
        or already contains a speech tag.
    """
    clean = text.strip()
    if not clean or _XAI_SPEECH_TAG_RE.search(clean):
        return text

    clean = re.sub(r"\n\s*\n+", " [pause] ", clean)
    clean = re.sub(r"\s*\n\s*", " ", clean)

    first_sentence = _XAI_FIRST_SENTENCE_RE.match(clean)
    if first_sentence is not None:
        rest_start = first_sentence.end()
        # A paragraph break right after the first sentence has already produced
        # a pause here; inserting another would yield "[pause] [pause]".
        if not clean.startswith("[pause]", rest_start):
            clean = first_sentence.group(1) + " [pause] " + clean[rest_start:]

    clean = re.sub(r"\s{2,}", " ", clean).strip()
    return clean
def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
"""
Generate audio using xAI TTS.
@ -913,6 +987,12 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -
language = str(xai_config.get("language", DEFAULT_XAI_LANGUAGE)).strip() or DEFAULT_XAI_LANGUAGE
sample_rate = int(xai_config.get("sample_rate", DEFAULT_XAI_SAMPLE_RATE))
bit_rate = int(xai_config.get("bit_rate", DEFAULT_XAI_BIT_RATE))
auto_speech_tags = _xai_bool_config(
xai_config.get("auto_speech_tags", xai_config.get("speech_tags")),
DEFAULT_XAI_AUTO_SPEECH_TAGS,
)
if auto_speech_tags:
text = _apply_xai_auto_speech_tags(text)
base_url = str(
xai_config.get("base_url")
or creds.get("base_url")