From e00b96540633e15a8972558033e96dade70804cc Mon Sep 17 00:00:00 2001 From: Carlos Diosdado Date: Thu, 18 Jun 2026 18:58:59 -0600 Subject: [PATCH] feat(tts): add xAI TTS speed and optimize_streaming_latency config knobs The xAI TTS REST endpoint (POST /v1/tts) accepts 'speed' (0.7-1.5) and 'optimize_streaming_latency' (0/1/2) parameters, but the Hermes built-in xAI provider was reading neither from config nor sending either in the request body. Add them as tts.xai.speed and tts.xai.optimize_streaming_latency config knobs (with global tts.speed / tts.optimize_streaming_latency fallbacks). - speed: float, clamped to 0.7-1.5. 1.0 (the API default) is omitted from the request body to preserve the existing minimal-payload contract. - optimize_streaming_latency: int, clamped to 0-2. 0 (best quality, the API default) is omitted from the request body. Resolver order: tts.xai.{knob} overrides the global tts.{knob}. --- tests/tools/test_tts_xai_speech_tags.py | 204 ++++++++++++++++++++++++ tools/tts_tool.py | 44 +++++ 2 files changed, 248 insertions(+) diff --git a/tests/tools/test_tts_xai_speech_tags.py b/tests/tools/test_tts_xai_speech_tags.py index d54fe7a5c92..4343a387f7a 100644 --- a/tests/tools/test_tts_xai_speech_tags.py +++ b/tests/tools/test_tts_xai_speech_tags.py @@ -324,3 +324,207 @@ def test_generate_xai_tts_leaves_text_plain_by_default(tmp_path, monkeypatch): ) assert captured["json"]["text"] == "Bonjour Monsieur Talbot. Ceci est un test." + + +def test_generate_xai_tts_omits_speed_and_latency_by_default(tmp_path, monkeypatch): + """No speed / optimize_streaming_latency in the request body unless + the user explicitly sets them. Keeps the existing minimal-payload + contract for default configs. 
+ """ + captured = {} + + fake_response = Mock() + fake_response.content = b"mp3" + fake_response.raise_for_status.return_value = None + + def fake_post(url, headers, json, timeout): + captured["json"] = json + return fake_response + + monkeypatch.setenv("XAI_API_KEY", "test-xai-key") + monkeypatch.setattr("requests.post", fake_post) + + _generate_xai_tts( + "Hello world.", + str(tmp_path / "out.mp3"), + {"xai": {"voice_id": "ara", "language": "en"}}, + ) + + assert "speed" not in captured["json"] + assert "optimize_streaming_latency" not in captured["json"] + + +def test_generate_xai_tts_sends_speed_when_set(tmp_path, monkeypatch): + """tts.xai.speed flows into the POST body.""" + captured = {} + + fake_response = Mock() + fake_response.content = b"mp3" + fake_response.raise_for_status.return_value = None + + def fake_post(url, headers, json, timeout): + captured["json"] = json + return fake_response + + monkeypatch.setenv("XAI_API_KEY", "test-xai-key") + monkeypatch.setattr("requests.post", fake_post) + + _generate_xai_tts( + "Hello world.", + str(tmp_path / "out.mp3"), + {"xai": {"voice_id": "ara", "language": "en", "speed": 1.5}}, + ) + + assert captured["json"]["speed"] == 1.5 + + +def test_generate_xai_tts_speed_clamped_to_valid_range(tmp_path, monkeypatch): + """speed values outside xAI's 0.7..1.5 band are clamped, not sent raw.""" + captured = {} + + fake_response = Mock() + fake_response.content = b"mp3" + fake_response.raise_for_status.return_value = None + + def fake_post(url, headers, json, timeout): + captured["json"] = json + return fake_response + + monkeypatch.setenv("XAI_API_KEY", "test-xai-key") + monkeypatch.setattr("requests.post", fake_post) + + # Below 0.7 -> 0.7 + _generate_xai_tts( + "Hello.", + str(tmp_path / "out.mp3"), + {"xai": {"voice_id": "eve", "language": "en", "speed": 0.1}}, + ) + assert captured["json"]["speed"] == 0.7 + + # Above 1.5 -> 1.5 + _generate_xai_tts( + "Hello.", + str(tmp_path / "out.mp3"), + {"xai": {"voice_id": 
"eve", "language": "en", "speed": 3.0}}, + ) + assert captured["json"]["speed"] == 1.5 + + +def test_generate_xai_tts_omits_speed_when_exactly_default(tmp_path, monkeypatch): + """speed == 1.0 is the API default; the field stays out of the payload.""" + captured = {} + + fake_response = Mock() + fake_response.content = b"mp3" + fake_response.raise_for_status.return_value = None + + def fake_post(url, headers, json, timeout): + captured["json"] = json + return fake_response + + monkeypatch.setenv("XAI_API_KEY", "test-xai-key") + monkeypatch.setattr("requests.post", fake_post) + + _generate_xai_tts( + "Hello.", + str(tmp_path / "out.mp3"), + {"xai": {"voice_id": "eve", "language": "en", "speed": 1.0}}, + ) + + assert "speed" not in captured["json"] + + +def test_generate_xai_tts_sends_optimize_streaming_latency_when_set(tmp_path, monkeypatch): + """tts.xai.optimize_streaming_latency flows into the POST body.""" + captured = {} + + fake_response = Mock() + fake_response.content = b"mp3" + fake_response.raise_for_status.return_value = None + + def fake_post(url, headers, json, timeout): + captured["json"] = json + return fake_response + + monkeypatch.setenv("XAI_API_KEY", "test-xai-key") + monkeypatch.setattr("requests.post", fake_post) + + _generate_xai_tts( + "Hello world.", + str(tmp_path / "out.mp3"), + {"xai": {"voice_id": "ara", "language": "en", "optimize_streaming_latency": 2}}, + ) + + assert captured["json"]["optimize_streaming_latency"] == 2 + + +def test_generate_xai_tts_optimize_streaming_latency_omitted_at_default(tmp_path, monkeypatch): + """optimize_streaming_latency == 0 is the API default; field is not sent.""" + captured = {} + + fake_response = Mock() + fake_response.content = b"mp3" + fake_response.raise_for_status.return_value = None + + def fake_post(url, headers, json, timeout): + captured["json"] = json + return fake_response + + monkeypatch.setenv("XAI_API_KEY", "test-xai-key") + monkeypatch.setattr("requests.post", fake_post) + + 
_generate_xai_tts( + "Hello world.", + str(tmp_path / "out.mp3"), + {"xai": {"voice_id": "ara", "language": "en", "optimize_streaming_latency": 0}}, + ) + + assert "optimize_streaming_latency" not in captured["json"] + + +def test_generate_xai_tts_global_speed_used_as_fallback(tmp_path, monkeypatch): + """Global tts.speed is the fallback when tts.xai.speed is unset.""" + captured = {} + + fake_response = Mock() + fake_response.content = b"mp3" + fake_response.raise_for_status.return_value = None + + def fake_post(url, headers, json, timeout): + captured["json"] = json + return fake_response + + monkeypatch.setenv("XAI_API_KEY", "test-xai-key") + monkeypatch.setattr("requests.post", fake_post) + + _generate_xai_tts( + "Hello.", + str(tmp_path / "out.mp3"), + {"speed": 0.8, "xai": {"voice_id": "ara", "language": "en"}}, + ) + + assert captured["json"]["speed"] == 0.8 + + +def test_generate_xai_tts_provider_speed_overrides_global(tmp_path, monkeypatch): + """tts.xai.speed wins over the global tts.speed fallback.""" + captured = {} + + fake_response = Mock() + fake_response.content = b"mp3" + fake_response.raise_for_status.return_value = None + + def fake_post(url, headers, json, timeout): + captured["json"] = json + return fake_response + + monkeypatch.setenv("XAI_API_KEY", "test-xai-key") + monkeypatch.setattr("requests.post", fake_post) + + _generate_xai_tts( + "Hello.", + str(tmp_path / "out.mp3"), + {"speed": 1.5, "xai": {"voice_id": "ara", "language": "en", "speed": 0.7}}, + ) + + assert captured["json"]["speed"] == 0.7 diff --git a/tools/tts_tool.py b/tools/tts_tool.py index 808d21e85e3..d803086983e 100644 --- a/tools/tts_tool.py +++ b/tools/tts_tool.py @@ -187,6 +187,13 @@ DEFAULT_XAI_SAMPLE_RATE = 24000 DEFAULT_XAI_BIT_RATE = 128000 DEFAULT_XAI_AUTO_SPEECH_TAGS = False DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1" +# xAI TTS `speed` accepts 0.7..1.5; 1.0 is the API default (omitted => default). 
+DEFAULT_XAI_SPEED_MIN = 0.7 +DEFAULT_XAI_SPEED_MAX = 1.5 +DEFAULT_XAI_SPEED_DEFAULT = 1.0 +# xAI TTS `optimize_streaming_latency` accepts 0, 1, or 2; 0 (best quality) is +# the API default (omitted => default). Values >0 trade quality for time-to-first-audio. +DEFAULT_XAI_OPTIMIZE_STREAMING_LATENCY_DEFAULT = 0 DEFAULT_GEMINI_TTS_MODEL = "gemini-2.5-flash-preview-tts" DEFAULT_GEMINI_TTS_VOICE = "Kore" DEFAULT_GEMINI_TTS_BASE_URL = "https://generativelanguage.googleapis.com/v1beta" @@ -1184,6 +1191,31 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) - xai_config.get("auto_speech_tags", xai_config.get("speech_tags")), DEFAULT_XAI_AUTO_SPEECH_TAGS, ) + # ``tts.xai.speed`` overrides global ``tts.speed``; the xAI TTS API + # accepts 0.7..1.5 (1.0 = normal). Out-of-range values are clamped so a + # misconfigured agent can't 400 the request — the API would reject + # anything outside the band. + speed = xai_config.get("speed", tts_config.get("speed")) + if speed is not None and speed != "": + try: + speed = float(speed) + except (TypeError, ValueError): + speed = None + if speed is not None: + speed = max(DEFAULT_XAI_SPEED_MIN, min(DEFAULT_XAI_SPEED_MAX, speed)) + # ``tts.xai.optimize_streaming_latency`` is 0, 1, or 2 (xAI-specific; + # trades chunk-boundary quality for time-to-first-audio). 
+ optimize_streaming_latency = xai_config.get( + "optimize_streaming_latency", + tts_config.get("optimize_streaming_latency"), + ) + if optimize_streaming_latency is not None and optimize_streaming_latency != "": + try: + optimize_streaming_latency = int(optimize_streaming_latency) + except (TypeError, ValueError): + optimize_streaming_latency = None + if optimize_streaming_latency is not None: + optimize_streaming_latency = max(0, min(2, optimize_streaming_latency)) if auto_speech_tags: text = _apply_xai_auto_speech_tags(text) base_url = str( @@ -1212,6 +1244,18 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) - if codec == "mp3" and bit_rate: output_format["bit_rate"] = bit_rate payload["output_format"] = output_format + # Only attach `speed` when the caller asked for something other than the + # API default (1.0). Keeps the existing minimal-payload contract for + # users who never touch the knob. + if speed is not None and speed != DEFAULT_XAI_SPEED_DEFAULT: + payload["speed"] = speed + # Only attach `optimize_streaming_latency` when the caller explicitly + # opts in to a non-default value (anything other than 0). + if ( + optimize_streaming_latency is not None + and optimize_streaming_latency != DEFAULT_XAI_OPTIMIZE_STREAMING_LATENCY_DEFAULT + ): + payload["optimize_streaming_latency"] = optimize_streaming_latency response = requests.post( f"{base_url}/tts",