feat(tts): add xAI TTS speed and optimize_streaming_latency config knobs

The xAI TTS REST endpoint (POST /v1/tts) accepts 'speed' (0.7-1.5)
and 'optimize_streaming_latency' (0/1/2) parameters, but the Hermes
built-in xAI provider neither read them from config nor sent them
in the request body. Add them as tts.xai.speed and
tts.xai.optimize_streaming_latency config knobs (with global
tts.speed / tts.optimize_streaming_latency fallbacks).

- speed: float, clamped to 0.7-1.5. 1.0 (the API default) is omitted
  from the request body to preserve the existing minimal-payload
  contract.
- optimize_streaming_latency: int, clamped to 0-2. 0 (best quality,
  the API default) is omitted from the request body.

Resolver order: tts.xai.<knob> overrides the global tts.<knob>.
This commit is contained in:
Carlos Diosdado 2026-06-18 18:58:59 -06:00 committed by Teknium
parent 8b7c89bff2
commit e00b965406
2 changed files with 248 additions and 0 deletions

View file

@ -324,3 +324,207 @@ def test_generate_xai_tts_leaves_text_plain_by_default(tmp_path, monkeypatch):
)
assert captured["json"]["text"] == "Bonjour Monsieur Talbot. Ceci est un test."
def test_generate_xai_tts_omits_speed_and_latency_by_default(tmp_path, monkeypatch):
    """Default configs must not add the new knobs to the request body.

    With neither ``speed`` nor ``optimize_streaming_latency`` configured,
    the JSON payload handed to ``requests.post`` carries neither key.
    """
    sent = {}
    response = Mock()
    response.content = b"mp3"
    response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Record the JSON body so the assertions below can inspect it.
        sent["json"] = json
        return response

    monkeypatch.setattr("requests.post", fake_post)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    tts_config = {"xai": {"voice_id": "ara", "language": "en"}}
    _generate_xai_tts("Hello world.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert "speed" not in body
    assert "optimize_streaming_latency" not in body
def test_generate_xai_tts_sends_speed_when_set(tmp_path, monkeypatch):
    """A provider-level ``speed`` (tts.xai.speed) reaches the POST body."""
    sent = {}
    response = Mock()
    response.content = b"mp3"
    response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Record the JSON body so the assertion below can inspect it.
        sent["json"] = json
        return response

    monkeypatch.setattr("requests.post", fake_post)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    xai_section = {"voice_id": "ara", "language": "en", "speed": 1.5}
    _generate_xai_tts(
        "Hello world.",
        str(tmp_path / "out.mp3"),
        {"xai": xai_section},
    )

    body = sent["json"]
    assert body["speed"] == 1.5
def test_generate_xai_tts_speed_clamped_to_valid_range(tmp_path, monkeypatch):
    """Out-of-band ``speed`` values are clamped into 0.7..1.5 before sending."""
    sent = {}
    response = Mock()
    response.content = b"mp3"
    response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Each call overwrites the previous capture, so the assertion
        # always looks at the most recent request body.
        sent["json"] = json
        return response

    monkeypatch.setattr("requests.post", fake_post)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    # (configured speed, speed expected in the payload)
    cases = (
        (0.1, 0.7),  # below the lower bound -> raised to 0.7
        (3.0, 1.5),  # above the upper bound -> lowered to 1.5
    )
    for configured, expected in cases:
        _generate_xai_tts(
            "Hello.",
            str(tmp_path / "out.mp3"),
            {"xai": {"voice_id": "eve", "language": "en", "speed": configured}},
        )
        assert sent["json"]["speed"] == expected
def test_generate_xai_tts_omits_speed_when_exactly_default(tmp_path, monkeypatch):
    """An explicit ``speed`` of 1.0 (the API default) is left out of the payload."""
    sent = {}
    response = Mock()
    response.content = b"mp3"
    response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Record the JSON body so the assertion below can inspect it.
        sent["json"] = json
        return response

    monkeypatch.setattr("requests.post", fake_post)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    xai_section = {"voice_id": "eve", "language": "en", "speed": 1.0}
    _generate_xai_tts("Hello.", str(tmp_path / "out.mp3"), {"xai": xai_section})

    body = sent["json"]
    assert "speed" not in body
def test_generate_xai_tts_sends_optimize_streaming_latency_when_set(tmp_path, monkeypatch):
    """A provider-level ``optimize_streaming_latency`` reaches the POST body."""
    sent = {}
    response = Mock()
    response.content = b"mp3"
    response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Record the JSON body so the assertion below can inspect it.
        sent["json"] = json
        return response

    monkeypatch.setattr("requests.post", fake_post)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    xai_section = {
        "voice_id": "ara",
        "language": "en",
        "optimize_streaming_latency": 2,
    }
    _generate_xai_tts(
        "Hello world.",
        str(tmp_path / "out.mp3"),
        {"xai": xai_section},
    )

    body = sent["json"]
    assert body["optimize_streaming_latency"] == 2
def test_generate_xai_tts_optimize_streaming_latency_omitted_at_default(tmp_path, monkeypatch):
    """An explicit ``optimize_streaming_latency`` of 0 (the API default) is not sent."""
    sent = {}
    response = Mock()
    response.content = b"mp3"
    response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Record the JSON body so the assertion below can inspect it.
        sent["json"] = json
        return response

    monkeypatch.setattr("requests.post", fake_post)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    xai_section = {
        "voice_id": "ara",
        "language": "en",
        "optimize_streaming_latency": 0,
    }
    _generate_xai_tts(
        "Hello world.",
        str(tmp_path / "out.mp3"),
        {"xai": xai_section},
    )

    body = sent["json"]
    assert "optimize_streaming_latency" not in body
def test_generate_xai_tts_global_speed_used_as_fallback(tmp_path, monkeypatch):
    """With no ``speed`` under ``xai``, the top-level ``speed`` is sent instead."""
    sent = {}
    response = Mock()
    response.content = b"mp3"
    response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Record the JSON body so the assertion below can inspect it.
        sent["json"] = json
        return response

    monkeypatch.setattr("requests.post", fake_post)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    # `speed` lives only at the top level; the provider section has none.
    tts_config = {
        "speed": 0.8,
        "xai": {"voice_id": "ara", "language": "en"},
    }
    _generate_xai_tts("Hello.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert body["speed"] == 0.8
def test_generate_xai_tts_provider_speed_overrides_global(tmp_path, monkeypatch):
    """When both are set, the ``speed`` under ``xai`` beats the top-level one."""
    sent = {}
    response = Mock()
    response.content = b"mp3"
    response.raise_for_status.return_value = None

    def fake_post(url, headers, json, timeout):
        # Record the JSON body so the assertion below can inspect it.
        sent["json"] = json
        return response

    monkeypatch.setattr("requests.post", fake_post)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    # Both levels set `speed`; the payload must carry the provider value.
    tts_config = {
        "speed": 1.5,
        "xai": {"voice_id": "ara", "language": "en", "speed": 0.7},
    }
    _generate_xai_tts("Hello.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert body["speed"] == 0.7

View file

@ -187,6 +187,13 @@ DEFAULT_XAI_SAMPLE_RATE = 24000
DEFAULT_XAI_BIT_RATE = 128000
DEFAULT_XAI_AUTO_SPEECH_TAGS = False
DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1"
# xAI TTS `speed` knob. The API accepts 0.7..1.5; configured values outside
# that band are clamped to MIN/MAX by `_generate_xai_tts` rather than sent raw.
# 1.0 is the API default, so a resolved speed equal to
# DEFAULT_XAI_SPEED_DEFAULT is left out of the request body (omitted => default).
DEFAULT_XAI_SPEED_MIN = 0.7
DEFAULT_XAI_SPEED_MAX = 1.5
DEFAULT_XAI_SPEED_DEFAULT = 1.0
# xAI TTS `optimize_streaming_latency` knob. The API accepts 0, 1, or 2;
# 0 (best quality) is the API default, so a resolved value equal to this
# constant is left out of the request body (omitted => default).
# Values >0 trade quality for time-to-first-audio.
# NOTE(review): the 0..2 clamp bounds are literals in `_generate_xai_tts`,
# not constants here — keep them in sync if the API range ever changes.
DEFAULT_XAI_OPTIMIZE_STREAMING_LATENCY_DEFAULT = 0
DEFAULT_GEMINI_TTS_MODEL = "gemini-2.5-flash-preview-tts"
DEFAULT_GEMINI_TTS_VOICE = "Kore"
DEFAULT_GEMINI_TTS_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
@ -1184,6 +1191,31 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -
xai_config.get("auto_speech_tags", xai_config.get("speech_tags")),
DEFAULT_XAI_AUTO_SPEECH_TAGS,
)
# ``tts.xai.speed`` overrides global ``tts.speed``; the xAI TTS API
# accepts 0.7..1.5 (1.0 = normal). Out-of-range values are clamped so a
# misconfigured agent can't 400 the request — the API would reject
# anything outside the band.
speed = xai_config.get("speed", tts_config.get("speed"))
if speed is not None and speed != "":
try:
speed = float(speed)
except (TypeError, ValueError):
speed = None
if speed is not None:
speed = max(DEFAULT_XAI_SPEED_MIN, min(DEFAULT_XAI_SPEED_MAX, speed))
# ``tts.xai.optimize_streaming_latency`` is 0, 1, or 2 (xAI-specific;
# trades chunk-boundary quality for time-to-first-audio).
optimize_streaming_latency = xai_config.get(
"optimize_streaming_latency",
tts_config.get("optimize_streaming_latency"),
)
if optimize_streaming_latency is not None and optimize_streaming_latency != "":
try:
optimize_streaming_latency = int(optimize_streaming_latency)
except (TypeError, ValueError):
optimize_streaming_latency = None
if optimize_streaming_latency is not None:
optimize_streaming_latency = max(0, min(2, optimize_streaming_latency))
if auto_speech_tags:
text = _apply_xai_auto_speech_tags(text)
base_url = str(
@ -1212,6 +1244,18 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -
if codec == "mp3" and bit_rate:
output_format["bit_rate"] = bit_rate
payload["output_format"] = output_format
# Only attach `speed` when the caller asked for something other than the
# API default (1.0). Keeps the existing minimal-payload contract for
# users who never touch the knob.
if speed is not None and speed != DEFAULT_XAI_SPEED_DEFAULT:
payload["speed"] = speed
# Only attach `optimize_streaming_latency` when the caller explicitly
# opts in to a non-default value (anything other than 0).
if (
optimize_streaming_latency is not None
and optimize_streaming_latency != DEFAULT_XAI_OPTIMIZE_STREAMING_LATENCY_DEFAULT
):
payload["optimize_streaming_latency"] = optimize_streaming_latency
response = requests.post(
f"{base_url}/tts",