mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge branch 'main' of github.com:NousResearch/hermes-agent into feat/ink-refactor
This commit is contained in:
commit
f81dba0da2
128 changed files with 8357 additions and 842 deletions
|
|
@ -873,12 +873,37 @@ def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
|
|||
if provider is None:
|
||||
session_info = _create_local_session(task_id)
|
||||
else:
|
||||
session_info = provider.create_session(task_id)
|
||||
if session_info.get("cdp_url"):
|
||||
# Some cloud providers (including Browser-Use v3) return an HTTP
|
||||
# CDP discovery URL instead of a raw websocket endpoint.
|
||||
session_info = dict(session_info)
|
||||
session_info["cdp_url"] = _resolve_cdp_override(str(session_info["cdp_url"]))
|
||||
try:
|
||||
session_info = provider.create_session(task_id)
|
||||
# Validate cloud provider returned a usable session
|
||||
if not session_info or not isinstance(session_info, dict):
|
||||
raise ValueError(f"Cloud provider returned invalid session: {session_info!r}")
|
||||
if session_info.get("cdp_url"):
|
||||
# Some cloud providers (including Browser-Use v3) return an HTTP
|
||||
# CDP discovery URL instead of a raw websocket endpoint.
|
||||
session_info = dict(session_info)
|
||||
session_info["cdp_url"] = _resolve_cdp_override(str(session_info["cdp_url"]))
|
||||
except Exception as e:
|
||||
provider_name = type(provider).__name__
|
||||
logger.warning(
|
||||
"Cloud provider %s failed (%s); attempting fallback to local "
|
||||
"Chromium for task %s",
|
||||
provider_name, e, task_id,
|
||||
exc_info=True,
|
||||
)
|
||||
try:
|
||||
session_info = _create_local_session(task_id)
|
||||
except Exception as local_error:
|
||||
raise RuntimeError(
|
||||
f"Cloud provider {provider_name} failed ({e}) and local "
|
||||
f"fallback also failed ({local_error})"
|
||||
) from e
|
||||
# Mark session as degraded for observability
|
||||
if isinstance(session_info, dict):
|
||||
session_info = dict(session_info)
|
||||
session_info["fallback_from_cloud"] = True
|
||||
session_info["fallback_reason"] = str(e)
|
||||
session_info["fallback_provider"] = provider_name
|
||||
|
||||
with _cleanup_lock:
|
||||
# Double-check: another thread may have created a session while we
|
||||
|
|
|
|||
|
|
@ -988,7 +988,8 @@ def execute_code(
|
|||
# (terminal.env_passthrough) are passed through.
|
||||
_SAFE_ENV_PREFIXES = ("PATH", "HOME", "USER", "LANG", "LC_", "TERM",
|
||||
"TMPDIR", "TMP", "TEMP", "SHELL", "LOGNAME",
|
||||
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA")
|
||||
"XDG_", "PYTHONPATH", "VIRTUAL_ENV", "CONDA",
|
||||
"HERMES_")
|
||||
_SECRET_SUBSTRINGS = ("KEY", "TOKEN", "SECRET", "PASSWORD", "CREDENTIAL",
|
||||
"PASSWD", "AUTH")
|
||||
try:
|
||||
|
|
@ -1015,10 +1016,13 @@ def execute_code(
|
|||
_existing_pp = child_env.get("PYTHONPATH", "")
|
||||
child_env["PYTHONPATH"] = _hermes_root + (os.pathsep + _existing_pp if _existing_pp else "")
|
||||
# Inject user's configured timezone so datetime.now() in sandboxed
|
||||
# code reflects the correct wall-clock time.
|
||||
# code reflects the correct wall-clock time. Only TZ is set —
|
||||
# HERMES_TIMEZONE is an internal Hermes setting and must not leak
|
||||
# into child processes.
|
||||
_tz_name = os.getenv("HERMES_TIMEZONE", "").strip()
|
||||
if _tz_name:
|
||||
child_env["TZ"] = _tz_name
|
||||
child_env.pop("HERMES_TIMEZONE", None)
|
||||
|
||||
# Per-profile HOME isolation: redirect system tool configs into
|
||||
# {HERMES_HOME}/home/ when that directory exists.
|
||||
|
|
|
|||
|
|
@ -807,21 +807,61 @@ def delegate_task(
|
|||
)
|
||||
futures[future] = i
|
||||
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
entry = future.result()
|
||||
except Exception as exc:
|
||||
idx = futures[future]
|
||||
entry = {
|
||||
"task_index": idx,
|
||||
"status": "error",
|
||||
"summary": None,
|
||||
"error": str(exc),
|
||||
"api_calls": 0,
|
||||
"duration_seconds": 0,
|
||||
}
|
||||
results.append(entry)
|
||||
completed_count += 1
|
||||
# Poll futures with interrupt checking. as_completed() blocks
|
||||
# until ALL futures finish — if a child agent gets stuck,
|
||||
# the parent blocks forever even after interrupt propagation.
|
||||
# Instead, use wait() with a short timeout so we can bail
|
||||
# when the parent is interrupted.
|
||||
pending = set(futures.keys())
|
||||
while pending:
|
||||
if getattr(parent_agent, "_interrupt_requested", False) is True:
|
||||
# Parent interrupted — collect whatever finished and
|
||||
# abandon the rest. Children already received the
|
||||
# interrupt signal; we just can't wait forever.
|
||||
for f in pending:
|
||||
idx = futures[f]
|
||||
if f.done():
|
||||
try:
|
||||
entry = f.result()
|
||||
except Exception as exc:
|
||||
entry = {
|
||||
"task_index": idx,
|
||||
"status": "error",
|
||||
"summary": None,
|
||||
"error": str(exc),
|
||||
"api_calls": 0,
|
||||
"duration_seconds": 0,
|
||||
}
|
||||
else:
|
||||
entry = {
|
||||
"task_index": idx,
|
||||
"status": "interrupted",
|
||||
"summary": None,
|
||||
"error": "Parent agent interrupted — child did not finish in time",
|
||||
"api_calls": 0,
|
||||
"duration_seconds": 0,
|
||||
}
|
||||
results.append(entry)
|
||||
completed_count += 1
|
||||
break
|
||||
|
||||
from concurrent.futures import wait as _cf_wait, FIRST_COMPLETED
|
||||
done, pending = _cf_wait(pending, timeout=0.5, return_when=FIRST_COMPLETED)
|
||||
for future in done:
|
||||
try:
|
||||
entry = future.result()
|
||||
except Exception as exc:
|
||||
idx = futures[future]
|
||||
entry = {
|
||||
"task_index": idx,
|
||||
"status": "error",
|
||||
"summary": None,
|
||||
"error": str(exc),
|
||||
"api_calls": 0,
|
||||
"duration_seconds": 0,
|
||||
}
|
||||
results.append(entry)
|
||||
completed_count += 1
|
||||
|
||||
# Print per-task completion line above the spinner
|
||||
idx = entry["task_index"]
|
||||
|
|
|
|||
|
|
@ -1166,6 +1166,14 @@ class MCPServerTask:
|
|||
|
||||
_servers: Dict[str, MCPServerTask] = {}
|
||||
|
||||
# Circuit breaker: consecutive error counts per server. After
|
||||
# _CIRCUIT_BREAKER_THRESHOLD consecutive failures, the handler returns
|
||||
# a "server unreachable" message that tells the model to stop retrying,
|
||||
# preventing the 90-iteration burn loop described in #10447.
|
||||
# Reset to 0 on any successful call.
|
||||
_server_error_counts: Dict[str, int] = {}
|
||||
_CIRCUIT_BREAKER_THRESHOLD = 3
|
||||
|
||||
# Dedicated event loop running in a background daemon thread.
|
||||
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_mcp_thread: Optional[threading.Thread] = None
|
||||
|
|
@ -1356,9 +1364,23 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
|||
"""
|
||||
|
||||
def _handler(args: dict, **kwargs) -> str:
|
||||
# Circuit breaker: if this server has failed too many times
|
||||
# consecutively, short-circuit with a clear message so the model
|
||||
# stops retrying and uses alternative approaches (#10447).
|
||||
if _server_error_counts.get(server_name, 0) >= _CIRCUIT_BREAKER_THRESHOLD:
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"MCP server '{server_name}' is unreachable after "
|
||||
f"{_CIRCUIT_BREAKER_THRESHOLD} consecutive failures. "
|
||||
f"Do NOT retry this tool — use alternative approaches "
|
||||
f"or ask the user to check the MCP server."
|
||||
)
|
||||
}, ensure_ascii=False)
|
||||
|
||||
with _lock:
|
||||
server = _servers.get(server_name)
|
||||
if not server or not server.session:
|
||||
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
|
||||
return json.dumps({
|
||||
"error": f"MCP server '{server_name}' is not connected"
|
||||
}, ensure_ascii=False)
|
||||
|
|
@ -1399,10 +1421,21 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
|
|||
return json.dumps({"result": text_result}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
return _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
result = _run_on_mcp_loop(_call(), timeout=tool_timeout)
|
||||
# Check if the MCP tool itself returned an error
|
||||
try:
|
||||
parsed = json.loads(result)
|
||||
if "error" in parsed:
|
||||
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
|
||||
else:
|
||||
_server_error_counts[server_name] = 0 # success — reset
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
_server_error_counts[server_name] = 0 # non-JSON = success
|
||||
return result
|
||||
except InterruptedError:
|
||||
return _interrupted_call_result()
|
||||
except Exception as exc:
|
||||
_server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
|
||||
logger.error(
|
||||
"MCP tool %s/%s call failed: %s",
|
||||
server_name, tool_name, exc,
|
||||
|
|
|
|||
|
|
@ -345,7 +345,7 @@ class ProcessRegistry:
|
|||
pty_env = _sanitize_subprocess_env(os.environ, env_vars)
|
||||
pty_env["PYTHONUNBUFFERED"] = "1"
|
||||
pty_proc = _PtyProcessCls.spawn(
|
||||
[user_shell, "-lic", command],
|
||||
[user_shell, "-lic", f"set +m; {command}"],
|
||||
cwd=session.cwd,
|
||||
env=pty_env,
|
||||
dimensions=(30, 120),
|
||||
|
|
@ -386,7 +386,7 @@ class ProcessRegistry:
|
|||
bg_env = _sanitize_subprocess_env(os.environ, env_vars)
|
||||
bg_env["PYTHONUNBUFFERED"] = "1"
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", command],
|
||||
[user_shell, "-lic", f"set +m; {command}"],
|
||||
text=True,
|
||||
cwd=session.cwd,
|
||||
env=bg_env,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Sends a message to a user or channel on any connected messaging platform
|
|||
human-friendly channel names to IDs. Works in both CLI and gateway contexts.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -48,6 +49,49 @@ def _error(message: str) -> dict:
|
|||
return {"error": _sanitize_error_text(message)}
|
||||
|
||||
|
||||
def _telegram_retry_delay(exc: Exception, attempt: int) -> float | None:
|
||||
retry_after = getattr(exc, "retry_after", None)
|
||||
if retry_after is not None:
|
||||
try:
|
||||
return max(float(retry_after), 0.0)
|
||||
except (TypeError, ValueError):
|
||||
return 1.0
|
||||
|
||||
text = str(exc).lower()
|
||||
if "timed out" in text or "timeout" in text:
|
||||
return None
|
||||
if (
|
||||
"bad gateway" in text
|
||||
or "502" in text
|
||||
or "too many requests" in text
|
||||
or "429" in text
|
||||
or "service unavailable" in text
|
||||
or "503" in text
|
||||
or "gateway timeout" in text
|
||||
or "504" in text
|
||||
):
|
||||
return float(2 ** attempt)
|
||||
return None
|
||||
|
||||
|
||||
async def _send_telegram_message_with_retry(bot, *, attempts: int = 3, **kwargs):
|
||||
for attempt in range(attempts):
|
||||
try:
|
||||
return await bot.send_message(**kwargs)
|
||||
except Exception as exc:
|
||||
delay = _telegram_retry_delay(exc, attempt)
|
||||
if delay is None or attempt >= attempts - 1:
|
||||
raise
|
||||
logger.warning(
|
||||
"Transient Telegram send failure (attempt %d/%d), retrying in %.1fs: %s",
|
||||
attempt + 1,
|
||||
attempts,
|
||||
delay,
|
||||
_sanitize_error_text(exc),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
|
||||
SEND_MESSAGE_SCHEMA = {
|
||||
"name": "send_message",
|
||||
"description": (
|
||||
|
|
@ -327,10 +371,16 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
|||
"""
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import BasePlatformAdapter, utf16_len
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
from gateway.platforms.slack import SlackAdapter
|
||||
|
||||
# Telegram adapter import is optional (requires python-telegram-bot)
|
||||
try:
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
_telegram_available = True
|
||||
except ImportError:
|
||||
_telegram_available = False
|
||||
|
||||
# Feishu adapter import is optional (requires lark-oapi)
|
||||
try:
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
|
@ -349,7 +399,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
|||
|
||||
# Platform message length limits (from adapter class attributes)
|
||||
_MAX_LENGTHS = {
|
||||
Platform.TELEGRAM: TelegramAdapter.MAX_MESSAGE_LENGTH,
|
||||
Platform.TELEGRAM: TelegramAdapter.MAX_MESSAGE_LENGTH if _telegram_available else 4096,
|
||||
Platform.DISCORD: DiscordAdapter.MAX_MESSAGE_LENGTH,
|
||||
Platform.SLACK: SlackAdapter.MAX_MESSAGE_LENGTH,
|
||||
}
|
||||
|
|
@ -369,6 +419,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
|||
# --- Telegram: special handling for media attachments ---
|
||||
if platform == Platform.TELEGRAM:
|
||||
last_result = None
|
||||
disable_link_previews = bool(getattr(pconfig, "extra", {}) and pconfig.extra.get("disable_link_previews"))
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_last = (i == len(chunks) - 1)
|
||||
result = await _send_telegram(
|
||||
|
|
@ -377,6 +428,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
|||
chunk,
|
||||
media_files=media_files if is_last else [],
|
||||
thread_id=thread_id,
|
||||
disable_link_previews=disable_link_previews,
|
||||
)
|
||||
if isinstance(result, dict) and result.get("error"):
|
||||
return result
|
||||
|
|
@ -404,11 +456,28 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
|||
last_result = result
|
||||
return last_result
|
||||
|
||||
# --- Matrix: use the native adapter helper when media is present ---
|
||||
if platform == Platform.MATRIX and media_files:
|
||||
last_result = None
|
||||
for i, chunk in enumerate(chunks):
|
||||
is_last = (i == len(chunks) - 1)
|
||||
result = await _send_matrix_via_adapter(
|
||||
pconfig,
|
||||
chat_id,
|
||||
chunk,
|
||||
media_files=media_files if is_last else [],
|
||||
thread_id=thread_id,
|
||||
)
|
||||
if isinstance(result, dict) and result.get("error"):
|
||||
return result
|
||||
last_result = result
|
||||
return last_result
|
||||
|
||||
# --- Non-Telegram/Discord platforms ---
|
||||
if media_files and not message.strip():
|
||||
return {
|
||||
"error": (
|
||||
f"send_message MEDIA delivery is currently only supported for telegram, discord, and weixin; "
|
||||
f"send_message MEDIA delivery is currently only supported for telegram, discord, matrix, and weixin; "
|
||||
f"target {platform.value} had only media attachments"
|
||||
)
|
||||
}
|
||||
|
|
@ -416,7 +485,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
|||
if media_files:
|
||||
warning = (
|
||||
f"MEDIA attachments were omitted for {platform.value}; "
|
||||
"native send_message media delivery is currently only supported for telegram, discord, and weixin"
|
||||
"native send_message media delivery is currently only supported for telegram, discord, matrix, and weixin"
|
||||
)
|
||||
|
||||
last_result = None
|
||||
|
|
@ -461,7 +530,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
|
|||
return last_result
|
||||
|
||||
|
||||
async def _send_telegram(token, chat_id, message, media_files=None, thread_id=None):
|
||||
async def _send_telegram(token, chat_id, message, media_files=None, thread_id=None, disable_link_previews=False):
|
||||
"""Send via Telegram Bot API (one-shot, no polling needed).
|
||||
|
||||
Applies markdown→MarkdownV2 formatting (same as the gateway adapter)
|
||||
|
|
@ -497,13 +566,16 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
|
|||
thread_kwargs = {}
|
||||
if thread_id is not None:
|
||||
thread_kwargs["message_thread_id"] = int(thread_id)
|
||||
if disable_link_previews:
|
||||
thread_kwargs["disable_web_page_preview"] = True
|
||||
|
||||
last_msg = None
|
||||
warnings = []
|
||||
|
||||
if formatted.strip():
|
||||
try:
|
||||
last_msg = await bot.send_message(
|
||||
last_msg = await _send_telegram_message_with_retry(
|
||||
bot,
|
||||
chat_id=int_chat_id, text=formatted,
|
||||
parse_mode=send_parse_mode, **thread_kwargs
|
||||
)
|
||||
|
|
@ -523,7 +595,8 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
|
|||
plain = message
|
||||
else:
|
||||
plain = message
|
||||
last_msg = await bot.send_message(
|
||||
last_msg = await _send_telegram_message_with_retry(
|
||||
bot,
|
||||
chat_id=int_chat_id, text=plain,
|
||||
parse_mode=None, **thread_kwargs
|
||||
)
|
||||
|
|
@ -907,6 +980,66 @@ async def _send_matrix(token, extra, chat_id, message):
|
|||
return _error(f"Matrix send failed: {e}")
|
||||
|
||||
|
||||
async def _send_matrix_via_adapter(pconfig, chat_id, message, media_files=None, thread_id=None):
|
||||
"""Send via the Matrix adapter so native Matrix media uploads are preserved."""
|
||||
try:
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
except ImportError:
|
||||
return {"error": "Matrix dependencies not installed. Run: pip install 'mautrix[encryption]'"}
|
||||
|
||||
media_files = media_files or []
|
||||
|
||||
try:
|
||||
adapter = MatrixAdapter(pconfig)
|
||||
connected = await adapter.connect()
|
||||
if not connected:
|
||||
return _error("Matrix connect failed")
|
||||
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
last_result = None
|
||||
|
||||
if message.strip():
|
||||
last_result = await adapter.send(chat_id, message, metadata=metadata)
|
||||
if not last_result.success:
|
||||
return _error(f"Matrix send failed: {last_result.error}")
|
||||
|
||||
for media_path, is_voice in media_files:
|
||||
if not os.path.exists(media_path):
|
||||
return _error(f"Media file not found: {media_path}")
|
||||
|
||||
ext = os.path.splitext(media_path)[1].lower()
|
||||
if ext in _IMAGE_EXTS:
|
||||
last_result = await adapter.send_image_file(chat_id, media_path, metadata=metadata)
|
||||
elif ext in _VIDEO_EXTS:
|
||||
last_result = await adapter.send_video(chat_id, media_path, metadata=metadata)
|
||||
elif ext in _VOICE_EXTS and is_voice:
|
||||
last_result = await adapter.send_voice(chat_id, media_path, metadata=metadata)
|
||||
elif ext in _AUDIO_EXTS:
|
||||
last_result = await adapter.send_voice(chat_id, media_path, metadata=metadata)
|
||||
else:
|
||||
last_result = await adapter.send_document(chat_id, media_path, metadata=metadata)
|
||||
|
||||
if not last_result.success:
|
||||
return _error(f"Matrix media send failed: {last_result.error}")
|
||||
|
||||
if last_result is None:
|
||||
return {"error": "No deliverable text or media remained after processing MEDIA tags"}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"platform": "matrix",
|
||||
"chat_id": chat_id,
|
||||
"message_id": last_result.message_id,
|
||||
}
|
||||
except Exception as e:
|
||||
return _error(f"Matrix send failed: {e}")
|
||||
finally:
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _send_homeassistant(token, extra, chat_id, message):
|
||||
"""Send via Home Assistant notify service."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1263,6 +1263,7 @@ def skill_view(name: str, file_path: str = None, task_id: str = None) -> str:
|
|||
"related_skills": related_skills,
|
||||
"content": content,
|
||||
"path": rel_path,
|
||||
"skill_dir": str(skill_dir) if skill_dir else None,
|
||||
"linked_files": linked_files if linked_files else None,
|
||||
"usage_hint": "To view linked files, call skill_view(name, file_path) where file_path is e.g. 'references/api.md' or 'assets/config.yaml'"
|
||||
if linked_files
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ from hermes_constants import display_hermes_home
|
|||
logger = logging.getLogger(__name__)
|
||||
from tools.managed_tool_gateway import resolve_managed_tool_gateway
|
||||
from tools.tool_backend_helpers import managed_nous_tools_enabled, resolve_openai_audio_api_key
|
||||
from tools.xai_http import hermes_xai_user_agent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy imports -- providers are imported only when actually used to avoid
|
||||
|
|
@ -93,6 +94,11 @@ DEFAULT_MINIMAX_VOICE_ID = "English_Graceful_Lady"
|
|||
DEFAULT_MINIMAX_BASE_URL = "https://api.minimax.io/v1/t2a_v2"
|
||||
DEFAULT_MISTRAL_TTS_MODEL = "voxtral-mini-tts-2603"
|
||||
DEFAULT_MISTRAL_TTS_VOICE_ID = "c69964a6-ab8b-4f8a-9465-ec0925096ec8" # Paul - Neutral
|
||||
DEFAULT_XAI_VOICE_ID = "eve"
|
||||
DEFAULT_XAI_LANGUAGE = "en"
|
||||
DEFAULT_XAI_SAMPLE_RATE = 24000
|
||||
DEFAULT_XAI_BIT_RATE = 128000
|
||||
DEFAULT_XAI_BASE_URL = "https://api.x.ai/v1"
|
||||
|
||||
def _get_default_output_dir() -> str:
|
||||
from hermes_constants import get_hermes_dir
|
||||
|
|
@ -299,6 +305,71 @@ def _generate_openai_tts(text: str, output_path: str, tts_config: Dict[str, Any]
|
|||
close()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Provider: xAI TTS
|
||||
# ===========================================================================
|
||||
def _generate_xai_tts(text: str, output_path: str, tts_config: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generate audio using xAI TTS.
|
||||
|
||||
xAI exposes a dedicated /v1/tts endpoint instead of the OpenAI audio.speech
|
||||
API shape, so this is implemented as a separate backend.
|
||||
"""
|
||||
import requests
|
||||
|
||||
api_key = os.getenv("XAI_API_KEY", "").strip()
|
||||
if not api_key:
|
||||
raise ValueError("XAI_API_KEY not set. Get one at https://console.x.ai/")
|
||||
|
||||
xai_config = tts_config.get("xai", {})
|
||||
voice_id = str(xai_config.get("voice_id", DEFAULT_XAI_VOICE_ID)).strip() or DEFAULT_XAI_VOICE_ID
|
||||
language = str(xai_config.get("language", DEFAULT_XAI_LANGUAGE)).strip() or DEFAULT_XAI_LANGUAGE
|
||||
sample_rate = int(xai_config.get("sample_rate", DEFAULT_XAI_SAMPLE_RATE))
|
||||
bit_rate = int(xai_config.get("bit_rate", DEFAULT_XAI_BIT_RATE))
|
||||
base_url = str(
|
||||
xai_config.get("base_url")
|
||||
or os.getenv("XAI_BASE_URL")
|
||||
or DEFAULT_XAI_BASE_URL
|
||||
).strip().rstrip("/")
|
||||
|
||||
# Match the documented minimal POST /v1/tts shape by default. Only send
|
||||
# output_format when Hermes actually needs a non-default format/override.
|
||||
codec = "wav" if output_path.endswith(".wav") else "mp3"
|
||||
payload: Dict[str, Any] = {
|
||||
"text": text,
|
||||
"voice_id": voice_id,
|
||||
"language": language,
|
||||
}
|
||||
if (
|
||||
codec != "mp3"
|
||||
or sample_rate != DEFAULT_XAI_SAMPLE_RATE
|
||||
or (codec == "mp3" and bit_rate != DEFAULT_XAI_BIT_RATE)
|
||||
):
|
||||
output_format: Dict[str, Any] = {"codec": codec}
|
||||
if sample_rate:
|
||||
output_format["sample_rate"] = sample_rate
|
||||
if codec == "mp3" and bit_rate:
|
||||
output_format["bit_rate"] = bit_rate
|
||||
payload["output_format"] = output_format
|
||||
|
||||
response = requests.post(
|
||||
f"{base_url}/tts",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": hermes_xai_user_agent(),
|
||||
},
|
||||
json=payload,
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Provider: MiniMax TTS
|
||||
# ===========================================================================
|
||||
|
|
@ -600,6 +671,10 @@ def text_to_speech_tool(
|
|||
logger.info("Generating speech with MiniMax TTS...")
|
||||
_generate_minimax_tts(text, file_str, tts_config)
|
||||
|
||||
elif provider == "xai":
|
||||
logger.info("Generating speech with xAI TTS...")
|
||||
_generate_xai_tts(text, file_str, tts_config)
|
||||
|
||||
elif provider == "mistral":
|
||||
try:
|
||||
_import_mistral_client()
|
||||
|
|
@ -661,7 +736,7 @@ def text_to_speech_tool(
|
|||
# Try Opus conversion for Telegram compatibility
|
||||
# Edge TTS outputs MP3, NeuTTS outputs WAV — both need ffmpeg conversion
|
||||
voice_compatible = False
|
||||
if provider in ("edge", "neutts", "minimax") and not file_str.endswith(".ogg"):
|
||||
if provider in ("edge", "neutts", "minimax", "xai") and not file_str.endswith(".ogg"):
|
||||
opus_path = _convert_to_opus(file_str)
|
||||
if opus_path:
|
||||
file_str = opus_path
|
||||
|
|
@ -734,6 +809,8 @@ def check_tts_requirements() -> bool:
|
|||
pass
|
||||
if os.getenv("MINIMAX_API_KEY"):
|
||||
return True
|
||||
if os.getenv("XAI_API_KEY"):
|
||||
return True
|
||||
try:
|
||||
_import_mistral_client()
|
||||
if os.getenv("MISTRAL_API_KEY"):
|
||||
|
|
|
|||
12
tools/xai_http.py
Normal file
12
tools/xai_http.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Shared helpers for direct xAI HTTP integrations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def hermes_xai_user_agent() -> str:
    """Build the ``User-Agent`` value Hermes sends on direct xAI HTTP requests.

    Returns:
        ``"Hermes-Agent/<version>"``, where ``<version>`` is
        ``hermes_cli.__version__`` or ``"unknown"`` if it cannot be imported.
    """
    try:
        from hermes_cli import __version__ as release
    except Exception:
        # Version lookup is best-effort; any import failure falls back to a
        # placeholder rather than breaking the HTTP call.
        release = "unknown"
    return "Hermes-Agent/{}".format(release)
|
||||
Loading…
Add table
Add a link
Reference in a new issue