mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-21 10:22:18 +00:00
feat(tts): add xAI TTS speed and optimize_streaming_latency config knobs
The xAI TTS REST endpoint (POST /v1/tts) accepts the parameters 'speed' (0.7-1.5) and 'optimize_streaming_latency' (0/1/2), but the Hermes built-in xAI provider neither read them from config nor sent them in the request body. This commit adds them as the config knobs tts.xai.speed and tts.xai.optimize_streaming_latency, with the global tts.speed / tts.optimize_streaming_latency as fallbacks.
- speed: float, clamped to 0.7-1.5. A value of 1.0 (the API default) is omitted from the request body to preserve the existing minimal-payload contract.
- optimize_streaming_latency: int, clamped to 0-2. A value of 0 (best quality, the API default) is omitted from the request body.
Resolver order: tts.xai.<knob> overrides the global tts.<knob>.
This commit is contained in:
parent
8b7c89bff2
commit
e00b965406
2 changed files with 248 additions and 0 deletions
|
|
@ -324,3 +324,207 @@ def test_generate_xai_tts_leaves_text_plain_by_default(tmp_path, monkeypatch):
|
|||
)
|
||||
|
||||
assert captured["json"]["text"] == "Bonjour Monsieur Talbot. Ceci est un test."
|
||||
|
||||
|
||||
def test_generate_xai_tts_omits_speed_and_latency_by_default(tmp_path, monkeypatch):
    """A config that sets neither knob produces a payload without them.

    Neither ``speed`` nor ``optimize_streaming_latency`` may appear in the
    POST body when the config leaves both unset, so default configs keep
    sending the same minimal payload as before.
    """
    sent = {}

    response_stub = Mock(content=b"mp3")
    response_stub.raise_for_status.return_value = None

    def post_stub(url, headers, json, timeout):
        # Parameter names are kept as-is in case the caller passes keywords.
        sent["json"] = json
        return response_stub

    monkeypatch.setattr("requests.post", post_stub)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    tts_config = {"xai": {"voice_id": "ara", "language": "en"}}
    _generate_xai_tts("Hello world.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert "speed" not in body
    assert "optimize_streaming_latency" not in body
|
||||
|
||||
|
||||
def test_generate_xai_tts_sends_speed_when_set(tmp_path, monkeypatch):
    """A ``speed`` set under the ``xai`` section is forwarded in the POST body."""
    sent = {}

    response_stub = Mock(content=b"mp3")
    response_stub.raise_for_status.return_value = None

    def post_stub(url, headers, json, timeout):
        # Parameter names are kept as-is in case the caller passes keywords.
        sent["json"] = json
        return response_stub

    monkeypatch.setattr("requests.post", post_stub)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    tts_config = {"xai": {"voice_id": "ara", "language": "en", "speed": 1.5}}
    _generate_xai_tts("Hello world.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert body["speed"] == 1.5
|
||||
|
||||
|
||||
def test_generate_xai_tts_speed_clamped_to_valid_range(tmp_path, monkeypatch):
    """Out-of-band ``speed`` values reach the API clamped to 0.7..1.5."""
    sent = {}

    response_stub = Mock(content=b"mp3")
    response_stub.raise_for_status.return_value = None

    def post_stub(url, headers, json, timeout):
        # Parameter names are kept as-is in case the caller passes keywords.
        sent["json"] = json
        return response_stub

    monkeypatch.setattr("requests.post", post_stub)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    # First a value below the floor (0.1 -> 0.7), then one above the
    # ceiling (3.0 -> 1.5); the order matches the two original scenarios.
    for requested, expected in ((0.1, 0.7), (3.0, 1.5)):
        tts_config = {
            "xai": {"voice_id": "eve", "language": "en", "speed": requested},
        }
        _generate_xai_tts("Hello.", str(tmp_path / "out.mp3"), tts_config)
        assert sent["json"]["speed"] == expected
|
||||
|
||||
|
||||
def test_generate_xai_tts_omits_speed_when_exactly_default(tmp_path, monkeypatch):
    """An explicit ``speed`` of 1.0 (the API default) is left out of the payload."""
    sent = {}

    response_stub = Mock(content=b"mp3")
    response_stub.raise_for_status.return_value = None

    def post_stub(url, headers, json, timeout):
        # Parameter names are kept as-is in case the caller passes keywords.
        sent["json"] = json
        return response_stub

    monkeypatch.setattr("requests.post", post_stub)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    tts_config = {"xai": {"voice_id": "eve", "language": "en", "speed": 1.0}}
    _generate_xai_tts("Hello.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert "speed" not in body
|
||||
|
||||
|
||||
def test_generate_xai_tts_sends_optimize_streaming_latency_when_set(tmp_path, monkeypatch):
    """An ``optimize_streaming_latency`` set under ``xai`` is forwarded in the POST body."""
    sent = {}

    response_stub = Mock(content=b"mp3")
    response_stub.raise_for_status.return_value = None

    def post_stub(url, headers, json, timeout):
        # Parameter names are kept as-is in case the caller passes keywords.
        sent["json"] = json
        return response_stub

    monkeypatch.setattr("requests.post", post_stub)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    tts_config = {
        "xai": {"voice_id": "ara", "language": "en", "optimize_streaming_latency": 2},
    }
    _generate_xai_tts("Hello world.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert body["optimize_streaming_latency"] == 2
|
||||
|
||||
|
||||
def test_generate_xai_tts_optimize_streaming_latency_omitted_at_default(tmp_path, monkeypatch):
    """An explicit ``optimize_streaming_latency`` of 0 (the API default) is not sent."""
    sent = {}

    response_stub = Mock(content=b"mp3")
    response_stub.raise_for_status.return_value = None

    def post_stub(url, headers, json, timeout):
        # Parameter names are kept as-is in case the caller passes keywords.
        sent["json"] = json
        return response_stub

    monkeypatch.setattr("requests.post", post_stub)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    tts_config = {
        "xai": {"voice_id": "ara", "language": "en", "optimize_streaming_latency": 0},
    }
    _generate_xai_tts("Hello world.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert "optimize_streaming_latency" not in body
|
||||
|
||||
|
||||
def test_generate_xai_tts_global_speed_used_as_fallback(tmp_path, monkeypatch):
    """With no ``speed`` under ``xai``, the top-level ``speed`` is what gets sent."""
    sent = {}

    response_stub = Mock(content=b"mp3")
    response_stub.raise_for_status.return_value = None

    def post_stub(url, headers, json, timeout):
        # Parameter names are kept as-is in case the caller passes keywords.
        sent["json"] = json
        return response_stub

    monkeypatch.setattr("requests.post", post_stub)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    tts_config = {"speed": 0.8, "xai": {"voice_id": "ara", "language": "en"}}
    _generate_xai_tts("Hello.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert body["speed"] == 0.8
|
||||
|
||||
|
||||
def test_generate_xai_tts_provider_speed_overrides_global(tmp_path, monkeypatch):
    """When both are set, the ``xai`` section's ``speed`` beats the top-level one."""
    sent = {}

    response_stub = Mock(content=b"mp3")
    response_stub.raise_for_status.return_value = None

    def post_stub(url, headers, json, timeout):
        # Parameter names are kept as-is in case the caller passes keywords.
        sent["json"] = json
        return response_stub

    monkeypatch.setattr("requests.post", post_stub)
    monkeypatch.setenv("XAI_API_KEY", "test-xai-key")

    tts_config = {
        "speed": 1.5,
        "xai": {"voice_id": "ara", "language": "en", "speed": 0.7},
    }
    _generate_xai_tts("Hello.", str(tmp_path / "out.mp3"), tts_config)

    body = sent["json"]
    assert body["speed"] == 0.7
|
||||
|
|
|
|||
|
|
@ -187,6 +187,13 @@ DEFAULT_XAI_SAMPLE_RATE = 24000
|
|||
DEFAULT_XAI_BIT_RATE = 128000
|
||||
DEFAULT_XAI_AUTO_SPEECH_TAGS = False
|
||||
DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1"
|
||||
# xAI TTS `speed` accepts 0.7..1.5; 1.0 is the API default (omitted => default).
|
||||
DEFAULT_XAI_SPEED_MIN = 0.7
|
||||
DEFAULT_XAI_SPEED_MAX = 1.5
|
||||
DEFAULT_XAI_SPEED_DEFAULT = 1.0
|
||||
# xAI TTS `optimize_streaming_latency` accepts 0, 1, or 2; 0 (best quality) is
|
||||
# the API default (omitted => default). Values >0 trade quality for time-to-first-audio.
|
||||
DEFAULT_XAI_OPTIMIZE_STREAMING_LATENCY_DEFAULT = 0
|
||||
DEFAULT_GEMINI_TTS_MODEL = "gemini-2.5-flash-preview-tts"
|
||||
DEFAULT_GEMINI_TTS_VOICE = "Kore"
|
||||
DEFAULT_GEMINI_TTS_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
|
@ -1184,6 +1191,31 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -
|
|||
xai_config.get("auto_speech_tags", xai_config.get("speech_tags")),
|
||||
DEFAULT_XAI_AUTO_SPEECH_TAGS,
|
||||
)
|
||||
# ``tts.xai.speed`` overrides global ``tts.speed``; the xAI TTS API
|
||||
# accepts 0.7..1.5 (1.0 = normal). Out-of-range values are clamped so a
|
||||
# misconfigured agent can't 400 the request — the API would reject
|
||||
# anything outside the band.
|
||||
speed = xai_config.get("speed", tts_config.get("speed"))
|
||||
if speed is not None and speed != "":
|
||||
try:
|
||||
speed = float(speed)
|
||||
except (TypeError, ValueError):
|
||||
speed = None
|
||||
if speed is not None:
|
||||
speed = max(DEFAULT_XAI_SPEED_MIN, min(DEFAULT_XAI_SPEED_MAX, speed))
|
||||
# ``tts.xai.optimize_streaming_latency`` is 0, 1, or 2 (xAI-specific;
|
||||
# trades chunk-boundary quality for time-to-first-audio).
|
||||
optimize_streaming_latency = xai_config.get(
|
||||
"optimize_streaming_latency",
|
||||
tts_config.get("optimize_streaming_latency"),
|
||||
)
|
||||
if optimize_streaming_latency is not None and optimize_streaming_latency != "":
|
||||
try:
|
||||
optimize_streaming_latency = int(optimize_streaming_latency)
|
||||
except (TypeError, ValueError):
|
||||
optimize_streaming_latency = None
|
||||
if optimize_streaming_latency is not None:
|
||||
optimize_streaming_latency = max(0, min(2, optimize_streaming_latency))
|
||||
if auto_speech_tags:
|
||||
text = _apply_xai_auto_speech_tags(text)
|
||||
base_url = str(
|
||||
|
|
@ -1212,6 +1244,18 @@ def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -
|
|||
if codec == "mp3" and bit_rate:
|
||||
output_format["bit_rate"] = bit_rate
|
||||
payload["output_format"] = output_format
|
||||
# Only attach `speed` when the caller asked for something other than the
|
||||
# API default (1.0). Keeps the existing minimal-payload contract for
|
||||
# users who never touch the knob.
|
||||
if speed is not None and speed != DEFAULT_XAI_SPEED_DEFAULT:
|
||||
payload["speed"] = speed
|
||||
# Only attach `optimize_streaming_latency` when the caller explicitly
|
||||
# opts in to a non-default value (anything other than 0).
|
||||
if (
|
||||
optimize_streaming_latency is not None
|
||||
and optimize_streaming_latency != DEFAULT_XAI_OPTIMIZE_STREAMING_LATENCY_DEFAULT
|
||||
):
|
||||
payload["optimize_streaming_latency"] = optimize_streaming_latency
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/tts",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue