diff --git a/README.md b/README.md index ab158fc2bd..622910b3a9 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ **The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM. -Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in. +Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [NVIDIA NIM](https://build.nvidia.com) (Nemotron), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in. 
diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 4f17461662..4860b16acd 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -94,6 +94,17 @@ def _normalize_aux_provider(provider: Optional[str]) -> str: return "custom" return _PROVIDER_ALIASES.get(normalized, normalized) + +_FIXED_TEMPERATURE_MODELS: Dict[str, float] = { + "kimi-for-coding": 0.6, +} + + +def _fixed_temperature_for_model(model: Optional[str]) -> Optional[float]: + """Return a required temperature override for models with strict contracts.""" + normalized = (model or "").strip().lower() + return _FIXED_TEMPERATURE_MODELS.get(normalized) + # Default auxiliary models for direct API-key providers (cheap/fast for side tasks) _API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = { "gemini": "gemini-3-flash-preview", @@ -2293,6 +2304,10 @@ def _build_call_kwargs( "timeout": timeout, } + fixed_temperature = _fixed_temperature_for_model(model) + if fixed_temperature is not None: + temperature = fixed_temperature + # Opus 4.7+ rejects any non-default temperature/top_p/top_k — silently # drop here so auxiliary callers that hardcode temperature (e.g. 0.3 on # flush_memories, 0 on structured-JSON extraction) don't 400 the moment diff --git a/agent/gemini_cloudcode_adapter.py b/agent/gemini_cloudcode_adapter.py index 36ba288eb4..ed687bffd6 100644 --- a/agent/gemini_cloudcode_adapter.py +++ b/agent/gemini_cloudcode_adapter.py @@ -747,18 +747,149 @@ class GeminiCloudCodeClient: def _gemini_http_error(response: httpx.Response) -> CodeAssistError: + """Translate an httpx response into a CodeAssistError with rich metadata. + + Parses Google's error envelope (``{"error": {"code", "message", "status", + "details": [...]}}``) so the agent's error classifier can reason about + the failure — ``status_code`` enables the rate_limit / auth classification + paths, and ``response`` lets the main loop honor ``Retry-After`` just + like it does for OpenAI SDK exceptions. 
+ + Also lifts a few recognizable Google conditions into human-readable + messages so the user sees something better than a 500-char JSON dump: + + MODEL_CAPACITY_EXHAUSTED → "Gemini model capacity exhausted for + {model}. This is a Google-side throttle..." + RESOURCE_EXHAUSTED w/o reason → quota-style message + 404 → "Model not found at cloudcode-pa..." + """ status = response.status_code + + # Parse the body once, surviving any weird encodings. + body_text = "" + body_json: Dict[str, Any] = {} try: - body = response.text[:500] + body_text = response.text except Exception: - body = "" - # Let run_agent's retry logic see auth errors as rotatable via `api_key` + body_text = "" + if body_text: + try: + parsed = json.loads(body_text) + if isinstance(parsed, dict): + body_json = parsed + except (ValueError, TypeError): + body_json = {} + + # Dig into Google's error envelope. Shape is: + # {"error": {"code": 429, "message": "...", "status": "RESOURCE_EXHAUSTED", + # "details": [{"@type": ".../ErrorInfo", "reason": "MODEL_CAPACITY_EXHAUSTED", + # "metadata": {...}}, + # {"@type": ".../RetryInfo", "retryDelay": "30s"}]}} + err_obj = body_json.get("error") if isinstance(body_json, dict) else None + if not isinstance(err_obj, dict): + err_obj = {} + err_status = str(err_obj.get("status") or "").strip() + err_message = str(err_obj.get("message") or "").strip() + err_details_list = err_obj.get("details") if isinstance(err_obj.get("details"), list) else [] + + # Extract google.rpc.ErrorInfo reason + metadata. There may be more + # than one ErrorInfo (rare), so we pick the first one with a reason. 
+ error_reason = "" + error_metadata: Dict[str, Any] = {} + retry_delay_seconds: Optional[float] = None + for detail in err_details_list: + if not isinstance(detail, dict): + continue + type_url = str(detail.get("@type") or "") + if not error_reason and type_url.endswith("/google.rpc.ErrorInfo"): + reason = detail.get("reason") + if isinstance(reason, str) and reason: + error_reason = reason + md = detail.get("metadata") + if isinstance(md, dict): + error_metadata = md + elif retry_delay_seconds is None and type_url.endswith("/google.rpc.RetryInfo"): + # retryDelay is a google.protobuf.Duration string like "30s" or "1.5s". + delay_raw = detail.get("retryDelay") + if isinstance(delay_raw, str) and delay_raw.endswith("s"): + try: + retry_delay_seconds = float(delay_raw[:-1]) + except ValueError: + pass + elif isinstance(delay_raw, (int, float)): + retry_delay_seconds = float(delay_raw) + + # Fall back to the Retry-After header if the body didn't include RetryInfo. + if retry_delay_seconds is None: + try: + header_val = response.headers.get("Retry-After") or response.headers.get("retry-after") + except Exception: + header_val = None + if header_val: + try: + retry_delay_seconds = float(header_val) + except (TypeError, ValueError): + retry_delay_seconds = None + + # Classify the error code. ``code_assist_rate_limited`` stays the default + # for 429s; a more specific reason tag helps downstream callers (e.g. tests, + # logs) without changing the rate_limit classification path. code = f"code_assist_http_{status}" if status == 401: code = "code_assist_unauthorized" elif status == 429: code = "code_assist_rate_limited" + if error_reason == "MODEL_CAPACITY_EXHAUSTED": + code = "code_assist_capacity_exhausted" + + # Build a human-readable message. Keep the status + a raw-body tail for + # debugging, but lead with a friendlier summary when we recognize the + # Google signal. 
+ model_hint = "" + if isinstance(error_metadata, dict): + model_hint = str(error_metadata.get("model") or error_metadata.get("modelId") or "").strip() + + if status == 429 and error_reason == "MODEL_CAPACITY_EXHAUSTED": + target = model_hint or "this Gemini model" + message = ( + f"Gemini capacity exhausted for {target} (Google-side throttle, " + f"not a Hermes issue). Try a different Gemini model or set a " + f"fallback_providers entry to a non-Gemini provider." + ) + if retry_delay_seconds is not None: + message += f" Google suggests retrying in {retry_delay_seconds:g}s." + elif status == 429 and err_status == "RESOURCE_EXHAUSTED": + message = ( + f"Gemini quota exhausted ({err_message or 'RESOURCE_EXHAUSTED'}). " + f"Check /gquota for remaining daily requests." + ) + if retry_delay_seconds is not None: + message += f" Retry suggested in {retry_delay_seconds:g}s." + elif status == 404: + # Google returns 404 when a model has been retired or renamed. + target = model_hint or (err_message or "model") + message = ( + f"Code Assist 404: {target} is not available at " + f"cloudcode-pa.googleapis.com. It may have been renamed or " + f"retired. Check hermes_cli/models.py for the current list." + ) + elif err_message: + # Generic fallback with the parsed message. + message = f"Code Assist HTTP {status} ({err_status or 'error'}): {err_message}" + else: + # Last-ditch fallback — raw body snippet. 
+ message = f"Code Assist returned HTTP {status}: {body_text[:500]}" + return CodeAssistError( - f"Code Assist returned HTTP {status}: {body}", + message, code=code, + status_code=status, + response=response, + retry_after=retry_delay_seconds, + details={ + "status": err_status, + "reason": error_reason, + "metadata": error_metadata, + "message": err_message, + }, ) diff --git a/agent/google_code_assist.py b/agent/google_code_assist.py index 1acf3ea135..eba09b8f46 100644 --- a/agent/google_code_assist.py +++ b/agent/google_code_assist.py @@ -68,9 +68,45 @@ _ONBOARDING_POLL_INTERVAL_SECONDS = 5.0 class CodeAssistError(RuntimeError): - def __init__(self, message: str, *, code: str = "code_assist_error") -> None: + """Exception raised by the Code Assist (``cloudcode-pa``) integration. + + Carries HTTP status / response / retry-after metadata so the agent's + ``error_classifier._extract_status_code`` and the main loop's Retry-After + handling (which walks ``error.response.headers``) pick up the right + signals. Without these, 429s from the OAuth path look like opaque + ``RuntimeError`` and skip the rate-limit path. + """ + + def __init__( + self, + message: str, + *, + code: str = "code_assist_error", + status_code: Optional[int] = None, + response: Any = None, + retry_after: Optional[float] = None, + details: Optional[Dict[str, Any]] = None, + ) -> None: super().__init__(message) self.code = code + # ``status_code`` is picked up by ``agent.error_classifier._extract_status_code`` + # so a 429 from Code Assist classifies as FailoverReason.rate_limit and + # triggers the main loop's fallback_providers chain the same way SDK + # errors do. + self.status_code = status_code + # ``response`` is the underlying ``httpx.Response`` (or a shim with a + # ``.headers`` mapping and ``.json()`` method). The main loop reads + # ``error.response.headers["Retry-After"]`` to honor Google's retry + # hints when the backend throttles us. 
+ self.response = response + # Parsed ``Retry-After`` seconds (kept separately for convenience — + # Google returns retry hints in both the header and the error body's + # ``google.rpc.RetryInfo`` details, and we pick whichever we found). + self.retry_after = retry_after + # Parsed structured error details from the Google error envelope + # (e.g. ``{"reason": "MODEL_CAPACITY_EXHAUSTED", "status": "RESOURCE_EXHAUSTED"}``). + # Useful for logging and for tests that want to assert on specifics. + self.details = details or {} class ProjectIdRequiredError(CodeAssistError): diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 089fd132ac..81bac6c92f 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -38,6 +38,7 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({ "mimo", "xiaomi-mimo", "arcee-ai", "arceeai", "xai", "x-ai", "x.ai", "grok", + "nvidia", "nim", "nvidia-nim", "nemotron", "qwen-portal", }) @@ -124,7 +125,6 @@ DEFAULT_CONTEXT_LENGTHS = { "gemini": 1048576, # Gemma (open models served via AI Studio) "gemma-4-31b": 256000, - "gemma-4-26b": 256000, "gemma-3": 131072, "gemma": 8192, # fallback for older gemma models # DeepSeek @@ -158,6 +158,8 @@ DEFAULT_CONTEXT_LENGTHS = { "grok": 131072, # catch-all (grok-beta, unknown grok-*) # Kimi "kimi": 262144, + # Nemotron — NVIDIA's open-weights series (128K context across all sizes) + "nemotron": 131072, # Arcee "trinity": 262144, # OpenRouter @@ -240,6 +242,7 @@ _URL_TO_PROVIDER: Dict[str, str] = { "api.fireworks.ai": "fireworks", "opencode.ai": "opencode-go", "api.x.ai": "xai", + "integrate.api.nvidia.com": "nvidia", "api.xiaomimimo.com": "xiaomi", "xiaomimimo.com": "xiaomi", "ollama.com": "ollama-cloud", diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 8c0484abd0..20b54b7887 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -24,6 +24,7 @@ model: # "minimax" - MiniMax global (requires: MINIMAX_API_KEY) # "minimax-cn" - MiniMax China (requires: 
MINIMAX_CN_API_KEY) # "huggingface" - Hugging Face Inference (requires: HF_TOKEN) + # "nvidia" - NVIDIA NIM / build.nvidia.com (requires: NVIDIA_API_KEY) # "xiaomi" - Xiaomi MiMo (requires: XIAOMI_API_KEY) # "arcee" - Arcee AI Trinity models (requires: ARCEEAI_API_KEY) # "ollama-cloud" - Ollama Cloud (requires: OLLAMA_API_KEY — https://ollama.com/settings) diff --git a/cron/scheduler.py b/cron/scheduler.py index 28c9057137..db5991c6f0 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -65,7 +65,15 @@ _HOME_TARGET_ENV_VARS = { "wecom": "WECOM_HOME_CHANNEL", "weixin": "WEIXIN_HOME_CHANNEL", "bluebubbles": "BLUEBUBBLES_HOME_CHANNEL", - "qqbot": "QQ_HOME_CHANNEL", + "qqbot": "QQBOT_HOME_CHANNEL", +} + +# Legacy env var names kept for back-compat. Each entry is the current +# primary env var → the previous name. _get_home_target_chat_id falls +# back to the legacy name if the primary is unset, so users who set the +# old name before the rename keep working until they migrate. +_LEGACY_HOME_TARGET_ENV_VARS = { + "QQBOT_HOME_CHANNEL": "QQ_HOME_CHANNEL", } from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run @@ -100,7 +108,12 @@ def _get_home_target_chat_id(platform_name: str) -> str: env_var = _HOME_TARGET_ENV_VARS.get(platform_name.lower()) if not env_var: return "" - return os.getenv(env_var, "") + value = os.getenv(env_var, "") + if not value: + legacy = _LEGACY_HOME_TARGET_ENV_VARS.get(env_var) + if legacy: + value = os.getenv(legacy, "") + return value def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[dict]: diff --git a/gateway/config.py b/gateway/config.py index 1258e08990..2d74073234 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -258,6 +258,13 @@ class GatewayConfig: # Streaming configuration streaming: StreamingConfig = field(default_factory=StreamingConfig) + # Session store pruning: drop SessionEntry records older than this many + # days from the in-memory dict and sessions.json. 
Keeps the store from + # growing unbounded in gateways serving many chats/threads/users over + # months. Pruning is invisible to users — if they resume, they get a + # fresh session exactly as if the reset policy had fired. 0 = disabled. + session_store_max_age_days: int = 90 + def get_connected_platforms(self) -> List[Platform]: """Return list of platforms that are enabled and configured.""" connected = [] @@ -365,6 +372,7 @@ class GatewayConfig: "thread_sessions_per_user": self.thread_sessions_per_user, "unauthorized_dm_behavior": self.unauthorized_dm_behavior, "streaming": self.streaming.to_dict(), + "session_store_max_age_days": self.session_store_max_age_days, } @classmethod @@ -412,6 +420,13 @@ class GatewayConfig: "pair", ) + try: + session_store_max_age_days = int(data.get("session_store_max_age_days", 90)) + if session_store_max_age_days < 0: + session_store_max_age_days = 0 + except (TypeError, ValueError): + session_store_max_age_days = 90 + return cls( platforms=platforms, default_reset_policy=default_policy, @@ -426,6 +441,7 @@ class GatewayConfig: thread_sessions_per_user=_coerce_bool(thread_sessions_per_user, False), unauthorized_dm_behavior=unauthorized_dm_behavior, streaming=StreamingConfig.from_dict(data.get("streaming", {})), + session_store_max_age_days=session_store_max_age_days, ) def get_unauthorized_dm_behavior(self, platform: Optional[Platform] = None) -> str: @@ -1213,12 +1229,24 @@ def _apply_env_overrides(config: GatewayConfig) -> None: qq_group_allowed = os.getenv("QQ_GROUP_ALLOWED_USERS", "").strip() if qq_group_allowed: extra["group_allow_from"] = qq_group_allowed - qq_home = os.getenv("QQ_HOME_CHANNEL", "").strip() + qq_home = os.getenv("QQBOT_HOME_CHANNEL", "").strip() + qq_home_name_env = "QQBOT_HOME_CHANNEL_NAME" + if not qq_home: + # Back-compat: accept the pre-rename name and log a one-time warning. 
+ legacy_home = os.getenv("QQ_HOME_CHANNEL", "").strip() + if legacy_home: + qq_home = legacy_home + qq_home_name_env = "QQ_HOME_CHANNEL_NAME" + import logging + logging.getLogger(__name__).warning( + "QQ_HOME_CHANNEL is deprecated; rename to QQBOT_HOME_CHANNEL " + "in your .env for consistency with the platform key." + ) if qq_home: config.platforms[Platform.QQBOT].home_channel = HomeChannel( platform=Platform.QQBOT, chat_id=qq_home, - name=os.getenv("QQ_HOME_CHANNEL_NAME", "Home"), + name=os.getenv("QQBOT_HOME_CHANNEL_NAME") or os.getenv(qq_home_name_env, "Home"), ) # Session settings diff --git a/gateway/platforms/qqbot/__init__.py b/gateway/platforms/qqbot/__init__.py new file mode 100644 index 0000000000..7119dd979e --- /dev/null +++ b/gateway/platforms/qqbot/__init__.py @@ -0,0 +1,57 @@ +""" +QQBot platform package. + +Re-exports the main adapter symbols from ``adapter.py`` (the original +``qqbot.py``) so that **all existing import paths remain unchanged**:: + + from gateway.platforms.qqbot import QQAdapter # works + from gateway.platforms.qqbot import check_qq_requirements # works + +New modules: + - ``constants`` — shared constants (API URLs, timeouts, message types) + - ``utils`` — User-Agent builder, config helpers + - ``crypto`` — AES-256-GCM key generation and decryption + - ``onboard`` — QR-code scan-to-configure flow +""" + +# -- Adapter (original qqbot.py) ------------------------------------------ +from .adapter import ( # noqa: F401 + QQAdapter, + QQCloseError, + check_qq_requirements, + _coerce_list, + _ssrf_redirect_guard, +) + +# -- Onboard (QR-code scan-to-configure) ----------------------------------- +from .onboard import ( # noqa: F401 + BindStatus, + create_bind_task, + poll_bind_result, + build_connect_url, +) +from .crypto import decrypt_secret, generate_bind_key # noqa: F401 + +# -- Utils ----------------------------------------------------------------- +from .utils import build_user_agent, get_api_headers, coerce_list # noqa: F401 + 
+__all__ = [ + # adapter + "QQAdapter", + "QQCloseError", + "check_qq_requirements", + "_coerce_list", + "_ssrf_redirect_guard", + # onboard + "BindStatus", + "create_bind_task", + "poll_bind_result", + "build_connect_url", + # crypto + "decrypt_secret", + "generate_bind_key", + # utils + "build_user_agent", + "get_api_headers", + "coerce_list", +] diff --git a/gateway/platforms/qqbot.py b/gateway/platforms/qqbot/adapter.py similarity index 72% rename from gateway/platforms/qqbot.py rename to gateway/platforms/qqbot/adapter.py index 32252be12b..ced7442711 100644 --- a/gateway/platforms/qqbot.py +++ b/gateway/platforms/qqbot/adapter.py @@ -46,6 +46,7 @@ from urllib.parse import urlparse try: import aiohttp + AIOHTTP_AVAILABLE = True except ImportError: AIOHTTP_AVAILABLE = False @@ -53,6 +54,7 @@ except ImportError: try: import httpx + HTTPX_AVAILABLE = True except ImportError: HTTPX_AVAILABLE = False @@ -83,39 +85,40 @@ class QQCloseError(Exception): self.code = int(code) if code else None self.reason = str(reason) if reason else "" super().__init__(f"WebSocket closed (code={self.code}, reason={self.reason})") + + # --------------------------------------------------------------------------- -# Constants +# Constants — imported from the shared constants module. 
# --------------------------------------------------------------------------- -API_BASE = "https://api.sgroup.qq.com" -TOKEN_URL = "https://bots.qq.com/app/getAppAccessToken" -GATEWAY_URL_PATH = "/gateway" - -DEFAULT_API_TIMEOUT = 30.0 -FILE_UPLOAD_TIMEOUT = 120.0 -CONNECT_TIMEOUT_SECONDS = 20.0 - -RECONNECT_BACKOFF = [2, 5, 10, 30, 60] -MAX_RECONNECT_ATTEMPTS = 100 -RATE_LIMIT_DELAY = 60 # seconds -QUICK_DISCONNECT_THRESHOLD = 5.0 # seconds -MAX_QUICK_DISCONNECT_COUNT = 3 - -MAX_MESSAGE_LENGTH = 4000 -DEDUP_WINDOW_SECONDS = 300 -DEDUP_MAX_SIZE = 1000 - -# QQ Bot message types -MSG_TYPE_TEXT = 0 -MSG_TYPE_MARKDOWN = 2 -MSG_TYPE_MEDIA = 7 -MSG_TYPE_INPUT_NOTIFY = 6 - -# QQ Bot file media types -MEDIA_TYPE_IMAGE = 1 -MEDIA_TYPE_VIDEO = 2 -MEDIA_TYPE_VOICE = 3 -MEDIA_TYPE_FILE = 4 +from gateway.platforms.qqbot.constants import ( + API_BASE, + TOKEN_URL, + GATEWAY_URL_PATH, + DEFAULT_API_TIMEOUT, + FILE_UPLOAD_TIMEOUT, + CONNECT_TIMEOUT_SECONDS, + RECONNECT_BACKOFF, + MAX_RECONNECT_ATTEMPTS, + RATE_LIMIT_DELAY, + QUICK_DISCONNECT_THRESHOLD, + MAX_QUICK_DISCONNECT_COUNT, + MAX_MESSAGE_LENGTH, + DEDUP_WINDOW_SECONDS, + DEDUP_MAX_SIZE, + MSG_TYPE_TEXT, + MSG_TYPE_MARKDOWN, + MSG_TYPE_MEDIA, + MSG_TYPE_INPUT_NOTIFY, + MEDIA_TYPE_IMAGE, + MEDIA_TYPE_VIDEO, + MEDIA_TYPE_VOICE, + MEDIA_TYPE_FILE, +) +from gateway.platforms.qqbot.utils import ( + coerce_list as _coerce_list_impl, + build_user_agent, +) def check_qq_requirements() -> bool: @@ -125,24 +128,30 @@ def check_qq_requirements() -> bool: def _coerce_list(value: Any) -> List[str]: """Coerce config values into a trimmed string list.""" - if value is None: - return [] - if isinstance(value, str): - return [item.strip() for item in value.split(",") if item.strip()] - if isinstance(value, (list, tuple, set)): - return [str(item).strip() for item in value if str(item).strip()] - return [str(value).strip()] if str(value).strip() else [] + return _coerce_list_impl(value) # 
--------------------------------------------------------------------------- # QQAdapter # --------------------------------------------------------------------------- + class QQAdapter(BasePlatformAdapter): """QQ Bot adapter backed by the official QQ Bot WebSocket Gateway + REST API.""" # QQ Bot API does not support editing sent messages. SUPPORTS_MESSAGE_EDITING = False + MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH + _TYPING_INPUT_SECONDS = 60 # input_notify duration reported to QQ + _TYPING_DEBOUNCE_SECONDS = 50 # refresh before it expires + + @property + def _log_tag(self) -> str: + """Log prefix including app_id for multi-instance disambiguation.""" + app_id = getattr(self, "_app_id", None) + if app_id: + return f"QQBot:{app_id}" + return "QQBot" def _fail_pending(self, reason: str) -> None: """Fail all pending response futures.""" @@ -151,21 +160,25 @@ class QQAdapter(BasePlatformAdapter): fut.set_exception(RuntimeError(reason)) self._pending_responses.clear() - MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH - def __init__(self, config: PlatformConfig): super().__init__(config, Platform.QQBOT) extra = config.extra or {} self._app_id = str(extra.get("app_id") or os.getenv("QQ_APP_ID", "")).strip() - self._client_secret = str(extra.get("client_secret") or os.getenv("QQ_CLIENT_SECRET", "")).strip() + self._client_secret = str( + extra.get("client_secret") or os.getenv("QQ_CLIENT_SECRET", "") + ).strip() self._markdown_support = bool(extra.get("markdown_support", True)) # Auth/ACL policies self._dm_policy = str(extra.get("dm_policy", "open")).strip().lower() - self._allow_from = _coerce_list(extra.get("allow_from") or extra.get("allowFrom")) + self._allow_from = _coerce_list( + extra.get("allow_from") or extra.get("allowFrom") + ) self._group_policy = str(extra.get("group_policy", "open")).strip().lower() - self._group_allow_from = _coerce_list(extra.get("group_allow_from") or extra.get("groupAllowFrom")) + self._group_allow_from = _coerce_list( + 
extra.get("group_allow_from") or extra.get("groupAllowFrom") + ) # Connection state self._session: Optional[aiohttp.ClientSession] = None @@ -182,6 +195,11 @@ class QQAdapter(BasePlatformAdapter): self._pending_responses: Dict[str, asyncio.Future] = {} self._seen_messages: Dict[str, float] = {} + # Last inbound message ID per chat — used by send_typing + self._last_msg_id: Dict[str, str] = {} + # Typing debounce: chat_id → last send_typing timestamp + self._typing_sent_at: Dict[str, float] = {} + # Token cache self._access_token: Optional[str] = None self._token_expires_at: float = 0.0 @@ -207,23 +225,21 @@ class QQAdapter(BasePlatformAdapter): if not AIOHTTP_AVAILABLE: message = "QQ startup failed: aiohttp not installed" self._set_fatal_error("qq_missing_dependency", message, retryable=True) - logger.warning("[%s] %s. Run: pip install aiohttp", self.name, message) + logger.warning("[%s] %s. Run: pip install aiohttp", self._log_tag, message) return False if not HTTPX_AVAILABLE: message = "QQ startup failed: httpx not installed" self._set_fatal_error("qq_missing_dependency", message, retryable=True) - logger.warning("[%s] %s. Run: pip install httpx", self.name, message) + logger.warning("[%s] %s. Run: pip install httpx", self._log_tag, message) return False if not self._app_id or not self._client_secret: message = "QQ startup failed: QQ_APP_ID and QQ_CLIENT_SECRET are required" self._set_fatal_error("qq_missing_credentials", message, retryable=True) - logger.warning("[%s] %s", self.name, message) + logger.warning("[%s] %s", self._log_tag, message) return False # Prevent duplicate connections with the same credentials - if not self._acquire_platform_lock( - "qqbot-appid", self._app_id, "QQBot app ID" - ): + if not self._acquire_platform_lock("qqbot-appid", self._app_id, "QQBot app ID"): return False try: @@ -238,7 +254,7 @@ class QQAdapter(BasePlatformAdapter): # 2. 
Get WebSocket gateway URL gateway_url = await self._get_gateway_url() - logger.info("[%s] Gateway URL: %s", self.name, gateway_url) + logger.info("[%s] Gateway URL: %s", self._log_tag, gateway_url) # 3. Open WebSocket await self._open_ws(gateway_url) @@ -247,12 +263,12 @@ class QQAdapter(BasePlatformAdapter): self._listen_task = asyncio.create_task(self._listen_loop()) self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) self._mark_connected() - logger.info("[%s] Connected", self.name) + logger.info("[%s] Connected", self._log_tag) return True except Exception as exc: message = f"QQ startup failed: {exc}" self._set_fatal_error("qq_connect_error", message, retryable=True) - logger.error("[%s] %s", self.name, message, exc_info=True) + logger.error("[%s] %s", self._log_tag, message, exc_info=True) await self._cleanup() self._release_platform_lock() return False @@ -280,7 +296,7 @@ class QQAdapter(BasePlatformAdapter): await self._cleanup() self._release_platform_lock() - logger.info("[%s] Disconnected", self.name) + logger.info("[%s] Disconnected", self._log_tag) async def _cleanup(self) -> None: """Close WebSocket, HTTP session, and client.""" @@ -329,12 +345,16 @@ class QQAdapter(BasePlatformAdapter): token = data.get("access_token") if not token: - raise RuntimeError(f"QQ Bot token response missing access_token: {data}") + raise RuntimeError( + f"QQ Bot token response missing access_token: {data}" + ) expires_in = int(data.get("expires_in", 7200)) self._access_token = token self._token_expires_at = time.time() + expires_in - logger.info("[%s] Access token refreshed, expires in %ds", self.name, expires_in) + logger.info( + "[%s] Access token refreshed, expires in %ds", self._log_tag, expires_in + ) return self._access_token async def _get_gateway_url(self) -> str: @@ -343,7 +363,10 @@ class QQAdapter(BasePlatformAdapter): try: resp = await self._http_client.get( f"{API_BASE}{GATEWAY_URL_PATH}", - headers={"Authorization": f"QQBot {token}"}, + headers={ 
+ "Authorization": f"QQBot {token}", + "User-Agent": build_user_agent(), + }, timeout=DEFAULT_API_TIMEOUT, ) resp.raise_for_status() @@ -373,9 +396,12 @@ class QQAdapter(BasePlatformAdapter): self._session = aiohttp.ClientSession() self._ws = await self._session.ws_connect( gateway_url, + headers={ + "User-Agent": build_user_agent(), + }, timeout=CONNECT_TIMEOUT_SECONDS, ) - logger.info("[%s] WebSocket connected to %s", self.name, gateway_url) + logger.info("[%s] WebSocket connected to %s", self._log_tag, gateway_url) async def _listen_loop(self) -> None: """Read WebSocket events and reconnect on errors. @@ -404,23 +430,34 @@ class QQAdapter(BasePlatformAdapter): return code = exc.code - logger.warning("[%s] WebSocket closed: code=%s reason=%s", - self.name, code, exc.reason) + logger.warning( + "[%s] WebSocket closed: code=%s reason=%s", + self._log_tag, + code, + exc.reason, + ) # Quick disconnect detection (permission issues, misconfiguration) duration = time.monotonic() - connect_time if duration < QUICK_DISCONNECT_THRESHOLD and connect_time > 0: quick_disconnect_count += 1 - logger.info("[%s] Quick disconnect (%.1fs), count: %d", - self.name, duration, quick_disconnect_count) + logger.info( + "[%s] Quick disconnect (%.1fs), count: %d", + self._log_tag, + duration, + quick_disconnect_count, + ) if quick_disconnect_count >= MAX_QUICK_DISCONNECT_COUNT: logger.error( "[%s] Too many quick disconnects. 
" "Check: 1) AppID/Secret correct 2) Bot permissions on QQ Open Platform", - self.name, + self._log_tag, + ) + self._set_fatal_error( + "qq_quick_disconnect", + "Too many quick disconnects — check bot permissions", + retryable=True, ) - self._set_fatal_error("qq_quick_disconnect", - "Too many quick disconnects — check bot permissions", retryable=True) return else: quick_disconnect_count = 0 @@ -431,13 +468,21 @@ class QQAdapter(BasePlatformAdapter): # Stop reconnecting for fatal codes if code in (4914, 4915): desc = "offline/sandbox-only" if code == 4914 else "banned" - logger.error("[%s] Bot is %s. Check QQ Open Platform.", self.name, desc) - self._set_fatal_error(f"qq_{desc}", f"Bot is {desc}", retryable=False) + logger.error( + "[%s] Bot is %s. Check QQ Open Platform.", self._log_tag, desc + ) + self._set_fatal_error( + f"qq_{desc}", f"Bot is {desc}", retryable=False + ) return # Rate limited if code == 4008: - logger.info("[%s] Rate limited (4008), waiting %ds", self.name, RATE_LIMIT_DELAY) + logger.info( + "[%s] Rate limited (4008), waiting %ds", + self._log_tag, + RATE_LIMIT_DELAY, + ) if backoff_idx >= MAX_RECONNECT_ATTEMPTS: return await asyncio.sleep(RATE_LIMIT_DELAY) @@ -450,14 +495,38 @@ class QQAdapter(BasePlatformAdapter): # Token invalid → clear cached token so _ensure_token() refreshes if code == 4004: - logger.info("[%s] Invalid token (4004), will refresh and reconnect", self.name) + logger.info( + "[%s] Invalid token (4004), will refresh and reconnect", + self._log_tag, + ) self._access_token = None self._token_expires_at = 0.0 # Session invalid → clear session, will re-identify on next Hello - if code in (4006, 4007, 4009, 4900, 4901, 4902, 4903, 4904, 4905, - 4906, 4907, 4908, 4909, 4910, 4911, 4912, 4913): - logger.info("[%s] Session error (%d), clearing session for re-identify", self.name, code) + if code in ( + 4006, + 4007, + 4009, + 4900, + 4901, + 4902, + 4903, + 4904, + 4905, + 4906, + 4907, + 4908, + 4909, + 4910, + 4911, + 4912, + 4913, 
+ ): + logger.info( + "[%s] Session error (%d), clearing session for re-identify", + self._log_tag, + code, + ) self._session_id = None self._last_seq = None @@ -470,12 +539,12 @@ class QQAdapter(BasePlatformAdapter): except Exception as exc: if not self._running: return - logger.warning("[%s] WebSocket error: %s", self.name, exc) + logger.warning("[%s] WebSocket error: %s", self._log_tag, exc) self._mark_disconnected() self._fail_pending("Connection interrupted") if backoff_idx >= MAX_RECONNECT_ATTEMPTS: - logger.error("[%s] Max reconnect attempts reached", self.name) + logger.error("[%s] Max reconnect attempts reached", self._log_tag) return if await self._reconnect(backoff_idx): @@ -487,7 +556,12 @@ class QQAdapter(BasePlatformAdapter): async def _reconnect(self, backoff_idx: int) -> bool: """Attempt to reconnect the WebSocket. Returns True on success.""" delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)] - logger.info("[%s] Reconnecting in %ds (attempt %d)...", self.name, delay, backoff_idx + 1) + logger.info( + "[%s] Reconnecting in %ds (attempt %d)...", + self._log_tag, + delay, + backoff_idx + 1, + ) await asyncio.sleep(delay) self._heartbeat_interval = 30.0 # reset until Hello @@ -496,10 +570,10 @@ class QQAdapter(BasePlatformAdapter): gateway_url = await self._get_gateway_url() await self._open_ws(gateway_url) self._mark_connected() - logger.info("[%s] Reconnected", self.name) + logger.info("[%s] Reconnected", self._log_tag) return True except Exception as exc: - logger.warning("[%s] Reconnect failed: %s", self.name, exc) + logger.warning("[%s] Reconnect failed: %s", self._log_tag, exc) return False async def _read_events(self) -> None: @@ -536,7 +610,7 @@ class QQAdapter(BasePlatformAdapter): # d should be the latest sequence number received, or null await self._ws.send_json({"op": 1, "d": self._last_seq}) except Exception as exc: - logger.debug("[%s] Heartbeat failed: %s", self.name, exc) + logger.debug("[%s] Heartbeat failed: %s", 
self._log_tag, exc) except asyncio.CancelledError: pass @@ -554,7 +628,11 @@ class QQAdapter(BasePlatformAdapter): "op": 2, "d": { "token": f"QQBot {token}", - "intents": (1 << 25) | (1 << 30) | (1 << 12), # C2C_GROUP_AT_MESSAGES + PUBLIC_GUILD_MESSAGES + DIRECT_MESSAGE + "intents": (1 << 25) + | (1 << 30) + | ( + 1 << 12 + ), # C2C_GROUP_AT_MESSAGES + PUBLIC_GUILD_MESSAGES + DIRECT_MESSAGE "shard": [0, 1], "properties": { "$os": "macOS", @@ -566,11 +644,13 @@ class QQAdapter(BasePlatformAdapter): try: if self._ws and not self._ws.closed: await self._ws.send_json(identify_payload) - logger.info("[%s] Identify sent", self.name) + logger.info("[%s] Identify sent", self._log_tag) else: - logger.warning("[%s] Cannot send Identify: WebSocket not connected", self.name) + logger.warning( + "[%s] Cannot send Identify: WebSocket not connected", self._log_tag + ) except Exception as exc: - logger.error("[%s] Failed to send Identify: %s", self.name, exc) + logger.error("[%s] Failed to send Identify: %s", self._log_tag, exc) async def _send_resume(self) -> None: """Send op 6 Resume to re-authenticate after a reconnection. 
@@ -589,12 +669,18 @@ class QQAdapter(BasePlatformAdapter): try: if self._ws and not self._ws.closed: await self._ws.send_json(resume_payload) - logger.info("[%s] Resume sent (session_id=%s, seq=%s)", - self.name, self._session_id, self._last_seq) + logger.info( + "[%s] Resume sent (session_id=%s, seq=%s)", + self._log_tag, + self._session_id, + self._last_seq, + ) else: - logger.warning("[%s] Cannot send Resume: WebSocket not connected", self.name) + logger.warning( + "[%s] Cannot send Resume: WebSocket not connected", self._log_tag + ) except Exception as exc: - logger.error("[%s] Failed to send Resume: %s", self.name, exc) + logger.error("[%s] Failed to send Resume: %s", self._log_tag, exc) # If resume fails, clear session and fall back to identify on next Hello self._session_id = None self._last_seq = None @@ -627,8 +713,12 @@ class QQAdapter(BasePlatformAdapter): interval_ms = d_data.get("heartbeat_interval", 30000) # Send heartbeats at 80% of the server interval to stay safe self._heartbeat_interval = interval_ms / 1000.0 * 0.8 - logger.debug("[%s] Hello received, heartbeat_interval=%dms (sending every %.1fs)", - self.name, interval_ms, self._heartbeat_interval) + logger.debug( + "[%s] Hello received, heartbeat_interval=%dms (sending every %.1fs)", + self._log_tag, + interval_ms, + self._heartbeat_interval, + ) # Authenticate: send Resume if we have a session, else Identify. # Use _create_task which is safe when no event loop is running (tests). 
if self._session_id and self._last_seq is not None: @@ -642,26 +732,30 @@ class QQAdapter(BasePlatformAdapter): if t == "READY": self._handle_ready(d) elif t == "RESUMED": - logger.info("[%s] Session resumed", self.name) - elif t in ("C2C_MESSAGE_CREATE", "GROUP_AT_MESSAGE_CREATE", - "DIRECT_MESSAGE_CREATE", "GUILD_MESSAGE_CREATE", - "GUILD_AT_MESSAGE_CREATE"): + logger.info("[%s] Session resumed", self._log_tag) + elif t in ( + "C2C_MESSAGE_CREATE", + "GROUP_AT_MESSAGE_CREATE", + "DIRECT_MESSAGE_CREATE", + "GUILD_MESSAGE_CREATE", + "GUILD_AT_MESSAGE_CREATE", + ): asyncio.create_task(self._on_message(t, d)) else: - logger.debug("[%s] Unhandled dispatch: %s", self.name, t) + logger.debug("[%s] Unhandled dispatch: %s", self._log_tag, t) return # op 11 = Heartbeat ACK if op == 11: return - logger.debug("[%s] Unknown op: %s", self.name, op) + logger.debug("[%s] Unknown op: %s", self._log_tag, op) def _handle_ready(self, d: Any) -> None: """Handle the READY event — store session_id for resume.""" if isinstance(d, dict): self._session_id = d.get("session_id") - logger.info("[%s] Ready, session_id=%s", self.name, self._session_id) + logger.info("[%s] Ready, session_id=%s", self._log_tag, self._session_id) # ------------------------------------------------------------------ # JSON helpers @@ -672,7 +766,7 @@ class QQAdapter(BasePlatformAdapter): try: payload = json.loads(raw) except Exception: - logger.debug("[%s] Failed to parse JSON: %r", "QQBot", raw) + logger.warning("[QQBot] Failed to parse JSON: %r", raw) return None return payload if isinstance(payload, dict) else None @@ -687,6 +781,12 @@ class QQAdapter(BasePlatformAdapter): # Inbound message handling # ------------------------------------------------------------------ + async def handle_message(self, event: MessageEvent) -> None: + """Cache the last message ID per chat, then delegate to base.""" + if event.message_id and event.source.chat_id: + self._last_msg_id[event.source.chat_id] = event.message_id + await 
super().handle_message(event) + async def _on_message(self, event_type: str, d: Any) -> None: """Process an inbound QQ Bot message event.""" if not isinstance(d, dict): @@ -695,7 +795,9 @@ class QQAdapter(BasePlatformAdapter): # Extract common fields msg_id = str(d.get("id", "")) if not msg_id or self._is_duplicate(msg_id): - logger.debug("[%s] Duplicate or missing message id: %s", self.name, msg_id) + logger.debug( + "[%s] Duplicate or missing message id: %s", self._log_tag, msg_id + ) return timestamp = str(d.get("timestamp", "")) @@ -713,7 +815,12 @@ class QQAdapter(BasePlatformAdapter): await self._handle_dm_message(d, msg_id, content, author, timestamp) async def _handle_c2c_message( - self, d: Dict[str, Any], msg_id: str, content: str, author: Dict[str, Any], timestamp: str + self, + d: Dict[str, Any], + msg_id: str, + content: str, + author: Dict[str, Any], + timestamp: str, ) -> None: """Handle a C2C (private) message event.""" user_openid = str(author.get("user_openid", "")) @@ -724,17 +831,28 @@ class QQAdapter(BasePlatformAdapter): text = content attachments_raw = d.get("attachments") - logger.info("[QQ] C2C message: id=%s content=%r attachments=%s", - msg_id, content[:50] if content else "", - f"{len(attachments_raw) if isinstance(attachments_raw, list) else 0} items" - if attachments_raw else "None") + logger.info( + "[%s] C2C message: id=%s content=%r attachments=%s", + self._log_tag, + msg_id, + content[:50] if content else "", + ( + f"{len(attachments_raw) if isinstance(attachments_raw, list) else 0} items" + if attachments_raw + else "None" + ), + ) if attachments_raw and isinstance(attachments_raw, list): for _i, _att in enumerate(attachments_raw): if isinstance(_att, dict): - logger.info("[QQ] attachment[%d]: content_type=%s url=%s filename=%s", - _i, _att.get("content_type", ""), - str(_att.get("url", ""))[:80], - _att.get("filename", "")) + logger.info( + "[%s] attachment[%d]: content_type=%s url=%s filename=%s", + self._log_tag, + _i, + 
_att.get("content_type", ""), + str(_att.get("url", ""))[:80], + _att.get("filename", ""), + ) # Process all attachments uniformly (images, voice, files) att_result = await self._process_attachments(attachments_raw) @@ -746,13 +864,23 @@ class QQAdapter(BasePlatformAdapter): # Append voice transcripts to the text body if voice_transcripts: voice_block = "\n".join(voice_transcripts) - text = (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + text = ( + (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + ) # Append non-media attachment info if attachment_info: - text = (text + "\n\n" + attachment_info).strip() if text.strip() else attachment_info + text = ( + (text + "\n\n" + attachment_info).strip() + if text.strip() + else attachment_info + ) - logger.info("[QQ] After processing: images=%d, voice=%d", - len(image_urls), len(voice_transcripts)) + logger.info( + "[%s] After processing: images=%d, voice=%d", + self._log_tag, + len(image_urls), + len(voice_transcripts), + ) if not text.strip() and not image_urls: return @@ -775,13 +903,20 @@ class QQAdapter(BasePlatformAdapter): await self.handle_message(event) async def _handle_group_message( - self, d: Dict[str, Any], msg_id: str, content: str, author: Dict[str, Any], timestamp: str + self, + d: Dict[str, Any], + msg_id: str, + content: str, + author: Dict[str, Any], + timestamp: str, ) -> None: """Handle a group @-message event.""" group_openid = str(d.get("group_openid", "")) if not group_openid: return - if not self._is_group_allowed(group_openid, str(author.get("member_openid", ""))): + if not self._is_group_allowed( + group_openid, str(author.get("member_openid", "")) + ): return # Strip the @bot mention prefix from content @@ -795,9 +930,15 @@ class QQAdapter(BasePlatformAdapter): # Append voice transcripts if voice_transcripts: voice_block = "\n".join(voice_transcripts) - text = (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + text = ( + (text 
+ "\n\n" + voice_block).strip() if text.strip() else voice_block + ) if attachment_info: - text = (text + "\n\n" + attachment_info).strip() if text.strip() else attachment_info + text = ( + (text + "\n\n" + attachment_info).strip() + if text.strip() + else attachment_info + ) if not text.strip() and not image_urls: return @@ -820,7 +961,12 @@ class QQAdapter(BasePlatformAdapter): await self.handle_message(event) async def _handle_guild_message( - self, d: Dict[str, Any], msg_id: str, content: str, author: Dict[str, Any], timestamp: str + self, + d: Dict[str, Any], + msg_id: str, + content: str, + author: Dict[str, Any], + timestamp: str, ) -> None: """Handle a guild/channel message event.""" channel_id = str(d.get("channel_id", "")) @@ -839,9 +985,15 @@ class QQAdapter(BasePlatformAdapter): if voice_transcripts: voice_block = "\n".join(voice_transcripts) - text = (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + text = ( + (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + ) if attachment_info: - text = (text + "\n\n" + attachment_info).strip() if text.strip() else attachment_info + text = ( + (text + "\n\n" + attachment_info).strip() + if text.strip() + else attachment_info + ) if not text.strip() and not image_urls: return @@ -865,7 +1017,12 @@ class QQAdapter(BasePlatformAdapter): await self.handle_message(event) async def _handle_dm_message( - self, d: Dict[str, Any], msg_id: str, content: str, author: Dict[str, Any], timestamp: str + self, + d: Dict[str, Any], + msg_id: str, + content: str, + author: Dict[str, Any], + timestamp: str, ) -> None: """Handle a guild DM message event.""" guild_id = str(d.get("guild_id", "")) @@ -881,9 +1038,15 @@ class QQAdapter(BasePlatformAdapter): if voice_transcripts: voice_block = "\n".join(voice_transcripts) - text = (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + text = ( + (text + "\n\n" + voice_block).strip() if text.strip() else voice_block + ) if 
attachment_info: - text = (text + "\n\n" + attachment_info).strip() if text.strip() else attachment_info + text = ( + (text + "\n\n" + attachment_info).strip() + if text.strip() + else attachment_info + ) if not text.strip() and not image_urls: return @@ -909,7 +1072,6 @@ class QQAdapter(BasePlatformAdapter): # Attachment processing # ------------------------------------------------------------------ - @staticmethod def _detect_message_type(media_urls: list, media_types: list): """Determine MessageType from attachment content types.""" @@ -926,11 +1088,16 @@ class QQAdapter(BasePlatformAdapter): return MessageType.PHOTO # Unknown content type with an attachment — don't assume PHOTO # to prevent non-image files from being sent to vision analysis. - logger.debug("[QQ] Unknown media content_type '%s', defaulting to TEXT", first_type) + logger.debug( + "[%s] Unknown media content_type '%s', defaulting to TEXT", + self._log_tag, + first_type, + ) return MessageType.TEXT async def _process_attachments( - self, attachments: Any, + self, + attachments: Any, ) -> Dict[str, Any]: """Process inbound attachments (all message types). 
@@ -944,8 +1111,12 @@ class QQAdapter(BasePlatformAdapter): - attachment_info: str — text description of non-image, non-voice attachments """ if not isinstance(attachments, list): - return {"image_urls": [], "image_media_types": [], - "voice_transcripts": [], "attachment_info": ""} + return { + "image_urls": [], + "image_media_types": [], + "voice_transcripts": [], + "attachment_info": "", + } image_urls: List[str] = [] image_media_types: List[str] = [] @@ -967,30 +1138,39 @@ class QQAdapter(BasePlatformAdapter): url = "" continue - logger.debug("[QQ] Processing attachment: content_type=%s, url=%s, filename=%s", - ct, url[:80], filename) + logger.debug( + "[%s] Processing attachment: content_type=%s, url=%s, filename=%s", + self._log_tag, + ct, + url[:80], + filename, + ) if self._is_voice_content_type(ct, filename): # Voice: use QQ's asr_refer_text first, then voice_wav_url, then STT. asr_refer = ( str(att.get("asr_refer_text", "")).strip() - if isinstance(att.get("asr_refer_text"), str) else "" + if isinstance(att.get("asr_refer_text"), str) + else "" ) voice_wav_url = ( str(att.get("voice_wav_url", "")).strip() - if isinstance(att.get("voice_wav_url"), str) else "" + if isinstance(att.get("voice_wav_url"), str) + else "" ) transcript = await self._stt_voice_attachment( - url, ct, filename, + url, + ct, + filename, asr_refer_text=asr_refer or None, voice_wav_url=voice_wav_url or None, ) if transcript: voice_transcripts.append(f"[Voice] {transcript}") - logger.info("[QQ] Voice transcript: %s", transcript) + logger.debug("[%s] Voice transcript: %s", self._log_tag, transcript) else: - logger.warning("[QQ] Voice STT failed for %s", url[:60]) + logger.warning("[%s] Voice STT failed for %s", self._log_tag, url[:60]) voice_transcripts.append("[Voice] [语音识别失败]") elif ct.startswith("image/"): # Image: download and cache locally. 
@@ -1000,9 +1180,13 @@ class QQAdapter(BasePlatformAdapter): image_urls.append(cached_path) image_media_types.append(ct or "image/jpeg") elif cached_path: - logger.warning("[QQ] Cached image path does not exist: %s", cached_path) + logger.warning( + "[%s] Cached image path does not exist: %s", + self._log_tag, + cached_path, + ) except Exception as exc: - logger.debug("[QQ] Failed to cache image: %s", exc) + logger.debug("[%s] Failed to cache image: %s", self._log_tag, exc) else: # Other attachments (video, file, etc.): record as text. try: @@ -1010,7 +1194,7 @@ class QQAdapter(BasePlatformAdapter): if cached_path: other_attachments.append(f"[Attachment: {filename or ct}]") except Exception as exc: - logger.debug("[QQ] Failed to cache attachment: %s", exc) + logger.debug("[%s] Failed to cache attachment: %s", self._log_tag, exc) attachment_info = "\n".join(other_attachments) if other_attachments else "" return { @@ -1023,6 +1207,7 @@ class QQAdapter(BasePlatformAdapter): async def _download_and_cache(self, url: str, content_type: str) -> Optional[str]: """Download a URL and cache it locally.""" from tools.url_safety import is_safe_url + if not is_safe_url(url): raise ValueError(f"Blocked unsafe URL: {url[:80]}") @@ -1031,12 +1216,16 @@ class QQAdapter(BasePlatformAdapter): try: resp = await self._http_client.get( - url, timeout=30.0, headers=self._qq_media_headers(), + url, + timeout=30.0, + headers=self._qq_media_headers(), ) resp.raise_for_status() data = resp.content except Exception as exc: - logger.debug("[%s] Download failed for %s: %s", self.name, url[:80], exc) + logger.debug( + "[%s] Download failed for %s: %s", self._log_tag, url[:80], exc + ) return None if content_type.startswith("image/"): @@ -1057,7 +1246,17 @@ class QQAdapter(BasePlatformAdapter): fn = filename.strip().lower() if ct == "voice" or ct.startswith("audio/"): return True - _VOICE_EXTENSIONS = (".silk", ".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac", ".speex", ".flac") + _VOICE_EXTENSIONS 
= ( + ".silk", + ".amr", + ".mp3", + ".wav", + ".ogg", + ".m4a", + ".aac", + ".speex", + ".flac", + ) if any(fn.endswith(ext) for ext in _VOICE_EXTENSIONS): return True return False @@ -1074,13 +1273,13 @@ class QQAdapter(BasePlatformAdapter): return {} async def _stt_voice_attachment( - self, - url: str, - content_type: str, - filename: str, - *, - asr_refer_text: Optional[str] = None, - voice_wav_url: Optional[str] = None, + self, + url: str, + content_type: str, + filename: str, + *, + asr_refer_text: Optional[str] = None, + voice_wav_url: Optional[str] = None, ) -> Optional[str]: """Download a voice attachment, convert to wav, and transcribe. @@ -1093,7 +1292,9 @@ class QQAdapter(BasePlatformAdapter): """ # 1. Use QQ's built-in ASR text if available if asr_refer_text: - logger.info("[QQ] STT: using QQ asr_refer_text: %r", asr_refer_text[:100]) + logger.debug( + "[%s] STT: using QQ asr_refer_text: %r", self._log_tag, asr_refer_text[:100] + ) return asr_refer_text # Determine which URL to download (prefer voice_wav_url — already WAV) @@ -1104,7 +1305,7 @@ class QQAdapter(BasePlatformAdapter): voice_wav_url = f"https:{voice_wav_url}" download_url = voice_wav_url is_pre_wav = True - logger.info("[QQ] STT: using voice_wav_url (pre-converted WAV)") + logger.debug("[%s] STT: using voice_wav_url (pre-converted WAV)", self._log_tag) from tools.url_safety import is_safe_url if not is_safe_url(download_url): @@ -1114,40 +1315,65 @@ class QQAdapter(BasePlatformAdapter): try: # 2. 
Download audio (QQ CDN requires Authorization header) if not self._http_client: - logger.warning("[QQ] STT: no HTTP client") + logger.warning("[%s] STT: no HTTP client", self._log_tag) return None download_headers = self._qq_media_headers() - logger.info("[QQ] STT: downloading voice from %s (pre_wav=%s, headers=%s)", - download_url[:80], is_pre_wav, bool(download_headers)) + logger.debug( + "[%s] STT: downloading voice from %s (pre_wav=%s, headers=%s)", + self._log_tag, + download_url[:80], + is_pre_wav, + bool(download_headers), + ) resp = await self._http_client.get( - download_url, timeout=30.0, headers=download_headers, follow_redirects=True, + download_url, + timeout=30.0, + headers=download_headers, + follow_redirects=True, ) resp.raise_for_status() audio_data = resp.content - logger.info("[QQ] STT: downloaded %d bytes, content_type=%s", - len(audio_data), resp.headers.get("content-type", "unknown")) + logger.debug( + "[%s] STT: downloaded %d bytes, content_type=%s", + self._log_tag, + len(audio_data), + resp.headers.get("content-type", "unknown"), + ) if len(audio_data) < 10: - logger.warning("[QQ] STT: downloaded data too small (%d bytes), skipping", len(audio_data)) + logger.warning( + "[%s] STT: downloaded data too small (%d bytes), skipping", + self._log_tag, + len(audio_data), + ) return None # 3. 
Convert to wav (skip if we already have a pre-converted WAV) if is_pre_wav: import tempfile + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: tmp.write(audio_data) wav_path = tmp.name - logger.info("[QQ] STT: using pre-converted WAV directly (%d bytes)", len(audio_data)) + logger.debug( + "[%s] STT: using pre-converted WAV directly (%d bytes)", + self._log_tag, + len(audio_data), + ) else: - logger.info("[QQ] STT: converting to wav, filename=%r", filename) + logger.debug( + "[%s] STT: converting to wav, filename=%r", self._log_tag, filename + ) wav_path = await self._convert_audio_to_wav_file(audio_data, filename) if not wav_path or not Path(wav_path).exists(): - logger.warning("[QQ] STT: ffmpeg conversion produced no output") + logger.warning( + "[%s] STT: ffmpeg conversion produced no output", self._log_tag + ) return None # 4. Call STT API - logger.info("[QQ] STT: calling ASR on %s", wav_path) + logger.debug("[%s] STT: calling ASR on %s", self._log_tag, wav_path) transcript = await self._call_stt(wav_path) # 5. 
Cleanup temp file @@ -1157,15 +1383,22 @@ class QQAdapter(BasePlatformAdapter): pass if transcript: - logger.info("[QQ] STT success: %r", transcript[:100]) + logger.debug("[%s] STT success: %r", self._log_tag, transcript[:100]) else: - logger.warning("[QQ] STT: ASR returned empty transcript") + logger.warning("[%s] STT: ASR returned empty transcript", self._log_tag) return transcript except (httpx.HTTPStatusError, httpx.TransportError, IOError) as exc: - logger.warning("[QQ] STT failed for voice attachment: %s: %s", type(exc).__name__, exc) + logger.warning( + "[%s] STT failed for voice attachment: %s: %s", + self._log_tag, + type(exc).__name__, + exc, + ) return None - async def _convert_audio_to_wav_file(self, audio_data: bytes, filename: str) -> Optional[str]: + async def _convert_audio_to_wav_file( + self, audio_data: bytes, filename: str + ) -> Optional[str]: """Convert audio bytes to a temp .wav file using pilk (SILK) or ffmpeg. QQ voice messages are typically SILK format which ffmpeg cannot decode. @@ -1175,9 +1408,18 @@ class QQAdapter(BasePlatformAdapter): """ import tempfile - ext = Path(filename).suffix.lower() if Path(filename).suffix else self._guess_ext_from_data(audio_data) - logger.info("[QQ] STT: audio_data size=%d, ext=%r, first_20_bytes=%r", - len(audio_data), ext, audio_data[:20]) + ext = ( + Path(filename).suffix.lower() + if Path(filename).suffix + else self._guess_ext_from_data(audio_data) + ) + logger.info( + "[%s] STT: audio_data size=%d, ext=%r, first_20_bytes=%r", + self._log_tag, + len(audio_data), + ext, + audio_data[:20], + ) with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_src: tmp_src.write(audio_data) @@ -1229,8 +1471,7 @@ class QQAdapter(BasePlatformAdapter): """Check if bytes look like a SILK audio file.""" return data[:4] == b"#!SILK" or data[:2] == b"\x02!" 
or data[:9] == b"#!SILK_V3" - @staticmethod - async def _convert_silk_to_wav(src_path: str, wav_path: str) -> Optional[str]: + async def _convert_silk_to_wav(self, src_path: str, wav_path: str) -> Optional[str]: """Convert audio file to WAV using the pilk library. Tries the file as-is first, then as .silk if the extension differs. @@ -1239,31 +1480,43 @@ class QQAdapter(BasePlatformAdapter): try: import pilk except ImportError: - logger.warning("[QQ] pilk not installed — cannot decode SILK audio. Run: pip install pilk") + logger.warning( + "[%s] pilk not installed — cannot decode SILK audio. Run: pip install pilk", + self._log_tag, + ) return None # Try converting the file as-is try: pilk.silk_to_wav(src_path, wav_path, rate=16000) if Path(wav_path).exists() and Path(wav_path).stat().st_size > 44: - logger.info("[QQ] pilk converted %s to wav (%d bytes)", - Path(src_path).name, Path(wav_path).stat().st_size) + logger.debug( + "[%s] pilk converted %s to wav (%d bytes)", + self._log_tag, + Path(src_path).name, + Path(wav_path).stat().st_size, + ) return wav_path except Exception as exc: - logger.debug("[QQ] pilk direct conversion failed: %s", exc) + logger.debug("[%s] pilk direct conversion failed: %s", self._log_tag, exc) # Try renaming to .silk and converting (pilk checks the extension) silk_path = src_path.rsplit(".", 1)[0] + ".silk" try: import shutil + shutil.copy2(src_path, silk_path) pilk.silk_to_wav(silk_path, wav_path, rate=16000) if Path(wav_path).exists() and Path(wav_path).stat().st_size > 44: - logger.info("[QQ] pilk converted %s (as .silk) to wav (%d bytes)", - Path(src_path).name, Path(wav_path).stat().st_size) + logger.debug( + "[%s] pilk converted %s (as .silk) to wav (%d bytes)", + self._log_tag, + Path(src_path).name, + Path(wav_path).stat().st_size, + ) return wav_path except Exception as exc: - logger.debug("[QQ] pilk .silk conversion failed: %s", exc) + logger.debug("[%s] pilk .silk conversion failed: %s", self._log_tag, exc) finally: try: 
os.unlink(silk_path) @@ -1272,8 +1525,7 @@ class QQAdapter(BasePlatformAdapter): return None - @staticmethod - async def _convert_raw_to_wav(audio_data: bytes, wav_path: str) -> Optional[str]: + async def _convert_raw_to_wav(self, audio_data: bytes, wav_path: str) -> Optional[str]: """Last resort: try writing audio data as raw PCM 16-bit mono 16kHz WAV. This will produce garbage if the data isn't raw PCM, but at least @@ -1281,6 +1533,7 @@ class QQAdapter(BasePlatformAdapter): """ try: import wave + with wave.open(wav_path, "w") as wf: wf.setnchannels(1) wf.setsampwidth(2) @@ -1288,33 +1541,52 @@ class QQAdapter(BasePlatformAdapter): wf.writeframes(audio_data) return wav_path except Exception as exc: - logger.debug("[QQ] raw PCM fallback failed: %s", exc) + logger.debug("[%s] raw PCM fallback failed: %s", self._log_tag, exc) return None - @staticmethod - async def _convert_ffmpeg_to_wav(src_path: str, wav_path: str) -> Optional[str]: + async def _convert_ffmpeg_to_wav(self, src_path: str, wav_path: str) -> Optional[str]: """Convert audio file to WAV using ffmpeg.""" try: proc = await asyncio.create_subprocess_exec( - "ffmpeg", "-y", "-i", src_path, "-ar", "16000", "-ac", "1", wav_path, + "ffmpeg", + "-y", + "-i", + src_path, + "-ar", + "16000", + "-ac", + "1", + wav_path, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE, ) await asyncio.wait_for(proc.wait(), timeout=30) if proc.returncode != 0: stderr = await proc.stderr.read() if proc.stderr else b"" - logger.warning("[QQ] ffmpeg failed for %s: %s", - Path(src_path).name, stderr[:200].decode(errors="replace")) + logger.warning( + "[%s] ffmpeg failed for %s: %s", + self._log_tag, + Path(src_path).name, + stderr[:200].decode(errors="replace"), + ) return None except (asyncio.TimeoutError, FileNotFoundError) as exc: - logger.warning("[QQ] ffmpeg conversion error: %s", exc) + logger.warning("[%s] ffmpeg conversion error: %s", self._log_tag, exc) return None if not Path(wav_path).exists() or 
Path(wav_path).stat().st_size <= 44: - logger.warning("[QQ] ffmpeg produced no/small output for %s", Path(src_path).name) + logger.warning( + "[%s] ffmpeg produced no/small output for %s", + self._log_tag, + Path(src_path).name, + ) return None - logger.info("[QQ] ffmpeg converted %s to wav (%d bytes)", - Path(src_path).name, Path(wav_path).stat().st_size) + logger.debug( + "[%s] ffmpeg converted %s to wav (%d bytes)", + self._log_tag, + Path(src_path).name, + Path(wav_path).stat().st_size, + ) return wav_path def _resolve_stt_config(self) -> Optional[Dict[str, str]]: @@ -1353,7 +1625,8 @@ class QQAdapter(BasePlatformAdapter): return { "base_url": base_url, "api_key": api_key, - "model": model or ("glm-asr" if provider in ("zai", "glm") else "whisper-1"), + "model": model + or ("glm-asr" if provider in ("zai", "glm") else "whisper-1"), } # 2. QQ-specific env vars (set by `hermes setup gateway` / `hermes gateway`) @@ -1381,7 +1654,10 @@ class QQAdapter(BasePlatformAdapter): """ stt_cfg = self._resolve_stt_config() if not stt_cfg: - logger.warning("[QQ] STT not configured (no stt config or QQ_STT_API_KEY)") + logger.warning( + "[%s] STT not configured (no stt config or QQ_STT_API_KEY)", + self._log_tag, + ) return None base_url = stt_cfg["base_url"] @@ -1411,17 +1687,37 @@ class QQAdapter(BasePlatformAdapter): return text.strip() return None except (httpx.HTTPStatusError, IOError) as exc: - logger.warning("[QQ] STT API call failed (model=%s, base=%s): %s", - model, base_url[:50], exc) + logger.warning( + "[%s] STT API call failed (model=%s, base=%s): %s", + self._log_tag, + model, + base_url[:50], + exc, + ) return None - async def _convert_audio_to_wav(self, audio_data: bytes, source_url: str) -> Optional[str]: + async def _convert_audio_to_wav( + self, audio_data: bytes, source_url: str + ) -> Optional[str]: """Convert audio bytes to .wav using pilk (SILK) or ffmpeg, caching the result.""" import tempfile # Determine source format from magic bytes or URL - ext = 
Path(urlparse(source_url).path).suffix.lower() if urlparse(source_url).path else "" - if not ext or ext not in (".silk", ".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac", ".flac"): + ext = ( + Path(urlparse(source_url).path).suffix.lower() + if urlparse(source_url).path + else "" + ) + if not ext or ext not in ( + ".silk", + ".amr", + ".mp3", + ".wav", + ".ogg", + ".m4a", + ".aac", + ".flac", + ): ext = self._guess_ext_from_data(audio_data) with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp_src: @@ -1437,8 +1733,12 @@ class QQAdapter(BasePlatformAdapter): result = await self._convert_ffmpeg_to_wav(src_path, wav_path) if not result: - logger.warning("[%s] audio conversion failed for %s (format=%s)", - self.name, source_url[:60], ext) + logger.warning( + "[%s] audio conversion failed for %s (format=%s)", + self._log_tag, + source_url[:60], + ext, + ) return cache_document_from_bytes(audio_data, f"qq_voice{ext}") except Exception: return cache_document_from_bytes(audio_data, f"qq_voice{ext}") @@ -1454,7 +1754,7 @@ class QQAdapter(BasePlatformAdapter): os.unlink(wav_path) return cache_document_from_bytes(wav_data, "qq_voice.wav") except Exception as exc: - logger.debug("[%s] Failed to read converted wav: %s", self.name, exc) + logger.debug("[%s] Failed to read converted wav: %s", self._log_tag, exc) return None # ------------------------------------------------------------------ @@ -1462,11 +1762,11 @@ class QQAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ async def _api_request( - self, - method: str, - path: str, - body: Optional[Dict[str, Any]] = None, - timeout: float = DEFAULT_API_TIMEOUT, + self, + method: str, + path: str, + body: Optional[Dict[str, Any]] = None, + timeout: float = DEFAULT_API_TIMEOUT, ) -> Dict[str, Any]: """Make an authenticated REST API request to QQ Bot API.""" if not self._http_client: @@ -1476,6 +1776,7 @@ class QQAdapter(BasePlatformAdapter): headers = { "Authorization": 
f"QQBot {token}", "Content-Type": "application/json", + "User-Agent": build_user_agent(), } try: @@ -1497,17 +1798,21 @@ class QQAdapter(BasePlatformAdapter): raise RuntimeError(f"QQ Bot API timeout [{path}]: {exc}") from exc async def _upload_media( - self, - target_type: str, - target_id: str, - file_type: int, - url: Optional[str] = None, - file_data: Optional[str] = None, - srv_send_msg: bool = False, - file_name: Optional[str] = None, + self, + target_type: str, + target_id: str, + file_type: int, + url: Optional[str] = None, + file_data: Optional[str] = None, + srv_send_msg: bool = False, + file_name: Optional[str] = None, ) -> Dict[str, Any]: """Upload media and return file_info.""" - path = f"/v2/users/{target_id}/files" if target_type == "c2c" else f"/v2/groups/{target_id}/files" + path = ( + f"/v2/users/{target_id}/files" + if target_type == "c2c" + else f"/v2/groups/{target_id}/files" + ) body: Dict[str, Any] = { "file_type": file_type, @@ -1524,11 +1829,16 @@ class QQAdapter(BasePlatformAdapter): last_exc = None for attempt in range(3): try: - return await self._api_request("POST", path, body, timeout=FILE_UPLOAD_TIMEOUT) + return await self._api_request( + "POST", path, body, timeout=FILE_UPLOAD_TIMEOUT + ) except RuntimeError as exc: last_exc = exc err_msg = str(exc) - if any(kw in err_msg for kw in ("400", "401", "Invalid", "timeout", "Timeout")): + if any( + kw in err_msg + for kw in ("400", "401", "Invalid", "timeout", "Timeout") + ): raise if attempt < 2: await asyncio.sleep(1.5 * (attempt + 1)) @@ -1551,23 +1861,23 @@ class QQAdapter(BasePlatformAdapter): Returns True if reconnected, False if still disconnected. 
""" logger.info("[%s] Not connected — waiting for reconnection (up to %.0fs)", - self.name, self._RECONNECT_WAIT_SECONDS) + self._log_tag, self._RECONNECT_WAIT_SECONDS) waited = 0.0 while waited < self._RECONNECT_WAIT_SECONDS: await asyncio.sleep(self._RECONNECT_POLL_INTERVAL) waited += self._RECONNECT_POLL_INTERVAL if self.is_connected: - logger.info("[%s] Reconnected after %.1fs", self.name, waited) + logger.info("[%s] Reconnected after %.1fs", self._log_tag, waited) return True - logger.warning("[%s] Still not connected after %.0fs", self.name, self._RECONNECT_WAIT_SECONDS) + logger.warning("[%s] Still not connected after %.0fs", self._log_tag, self._RECONNECT_WAIT_SECONDS) return False async def send( - self, - chat_id: str, - content: str, - reply_to: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send a text or markdown message to a QQ user or group. 
@@ -1596,7 +1906,10 @@ class QQAdapter(BasePlatformAdapter): return last_result async def _send_chunk( - self, chat_id: str, content: str, reply_to: Optional[str] = None, + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, ) -> SendResult: """Send a single chunk with retry + exponential backoff.""" last_exc: Optional[Exception] = None @@ -1611,28 +1924,39 @@ class QQAdapter(BasePlatformAdapter): elif chat_type == "guild": return await self._send_guild_text(chat_id, content, reply_to) else: - return SendResult(success=False, error=f"Unknown chat type for {chat_id}") + return SendResult( + success=False, error=f"Unknown chat type for {chat_id}" + ) except Exception as exc: last_exc = exc err = str(exc).lower() # Permanent errors — don't retry - if any(k in err for k in ("invalid", "forbidden", "not found", "bad request")): + if any( + k in err + for k in ("invalid", "forbidden", "not found", "bad request") + ): break # Transient — back off and retry if attempt < 2: delay = 1.0 * (2 ** attempt) - logger.warning("[%s] send retry %d/3 after %.1fs: %s", - self.name, attempt + 1, delay, exc) + logger.warning( + "[%s] send retry %d/3 after %.1fs: %s", + self._log_tag, + attempt + 1, + delay, + exc, + ) await asyncio.sleep(delay) error_msg = str(last_exc) if last_exc else "Unknown error" - logger.error("[%s] Send failed: %s", self.name, error_msg) - retryable = not any(k in error_msg.lower() - for k in ("invalid", "forbidden", "not found")) + logger.error("[%s] Send failed: %s", self._log_tag, error_msg) + retryable = not any( + k in error_msg.lower() for k in ("invalid", "forbidden", "not found") + ) return SendResult(success=False, error=error_msg, retryable=retryable) async def _send_c2c_text( - self, openid: str, content: str, reply_to: Optional[str] = None + self, openid: str, content: str, reply_to: Optional[str] = None ) -> SendResult: """Send text to a C2C user via REST API.""" msg_seq = self._next_msg_seq(reply_to or openid) @@ -1645,7 +1969,7 
@@ class QQAdapter(BasePlatformAdapter): return SendResult(success=True, message_id=msg_id, raw_response=data) async def _send_group_text( - self, group_openid: str, content: str, reply_to: Optional[str] = None + self, group_openid: str, content: str, reply_to: Optional[str] = None ) -> SendResult: """Send text to a group via REST API.""" msg_seq = self._next_msg_seq(reply_to or group_openid) @@ -1653,15 +1977,17 @@ class QQAdapter(BasePlatformAdapter): if reply_to: body["msg_id"] = reply_to - data = await self._api_request("POST", f"/v2/groups/{group_openid}/messages", body) + data = await self._api_request( + "POST", f"/v2/groups/{group_openid}/messages", body + ) msg_id = str(data.get("id", uuid.uuid4().hex[:12])) return SendResult(success=True, message_id=msg_id, raw_response=data) async def _send_guild_text( - self, channel_id: str, content: str, reply_to: Optional[str] = None + self, channel_id: str, content: str, reply_to: Optional[str] = None ) -> SendResult: """Send text to a guild channel via REST API.""" - body: Dict[str, Any] = {"content": content[:self.MAX_MESSAGE_LENGTH]} + body: Dict[str, Any] = {"content": content[: self.MAX_MESSAGE_LENGTH]} if reply_to: body["msg_id"] = reply_to @@ -1669,19 +1995,21 @@ class QQAdapter(BasePlatformAdapter): msg_id = str(data.get("id", uuid.uuid4().hex[:12])) return SendResult(success=True, message_id=msg_id, raw_response=data) - def _build_text_body(self, content: str, reply_to: Optional[str] = None) -> Dict[str, Any]: + def _build_text_body( + self, content: str, reply_to: Optional[str] = None + ) -> Dict[str, Any]: """Build the message body for C2C/group text sending.""" msg_seq = self._next_msg_seq(reply_to or "default") if self._markdown_support: body: Dict[str, Any] = { - "markdown": {"content": content[:self.MAX_MESSAGE_LENGTH]}, + "markdown": {"content": content[: self.MAX_MESSAGE_LENGTH]}, "msg_type": MSG_TYPE_MARKDOWN, "msg_seq": msg_seq, } else: body = { - "content": content[:self.MAX_MESSAGE_LENGTH], + 
"content": content[: self.MAX_MESSAGE_LENGTH], "msg_type": MSG_TYPE_TEXT, "msg_seq": msg_seq, } @@ -1698,84 +2026,103 @@ class QQAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ async def send_image( - self, - chat_id: str, - image_url: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send an image natively via QQ Bot API upload.""" del metadata - result = await self._send_media(chat_id, image_url, MEDIA_TYPE_IMAGE, "image", caption, reply_to) + result = await self._send_media( + chat_id, image_url, MEDIA_TYPE_IMAGE, "image", caption, reply_to + ) if result.success or not self._is_url(image_url): return result # Fallback to text URL - logger.warning("[%s] Image send failed, falling back to text: %s", self.name, result.error) + logger.warning( + "[%s] Image send failed, falling back to text: %s", + self._log_tag, + result.error, + ) fallback = f"{caption}\n{image_url}" if caption else image_url return await self.send(chat_id=chat_id, content=fallback, reply_to=reply_to) async def send_image_file( - self, - chat_id: str, - image_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, - **kwargs, + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, ) -> SendResult: """Send a local image file natively.""" del kwargs - return await self._send_media(chat_id, image_path, MEDIA_TYPE_IMAGE, "image", caption, reply_to) + return await self._send_media( + chat_id, image_path, MEDIA_TYPE_IMAGE, "image", caption, reply_to + ) async def send_voice( - self, - chat_id: str, - audio_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, - **kwargs, + self, + chat_id: str, + 
audio_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, ) -> SendResult: """Send a voice message natively.""" del kwargs - return await self._send_media(chat_id, audio_path, MEDIA_TYPE_VOICE, "voice", caption, reply_to) + return await self._send_media( + chat_id, audio_path, MEDIA_TYPE_VOICE, "voice", caption, reply_to + ) async def send_video( - self, - chat_id: str, - video_path: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, - **kwargs, + self, + chat_id: str, + video_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, ) -> SendResult: """Send a video natively.""" del kwargs - return await self._send_media(chat_id, video_path, MEDIA_TYPE_VIDEO, "video", caption, reply_to) + return await self._send_media( + chat_id, video_path, MEDIA_TYPE_VIDEO, "video", caption, reply_to + ) async def send_document( - self, - chat_id: str, - file_path: str, - caption: Optional[str] = None, - file_name: Optional[str] = None, - reply_to: Optional[str] = None, - **kwargs, + self, + chat_id: str, + file_path: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + **kwargs, ) -> SendResult: """Send a file/document natively.""" del kwargs - return await self._send_media(chat_id, file_path, MEDIA_TYPE_FILE, "file", caption, reply_to, - file_name=file_name) + return await self._send_media( + chat_id, + file_path, + MEDIA_TYPE_FILE, + "file", + caption, + reply_to, + file_name=file_name, + ) async def _send_media( - self, - chat_id: str, - media_source: str, - file_type: int, - kind: str, - caption: Optional[str] = None, - reply_to: Optional[str] = None, - file_name: Optional[str] = None, + self, + chat_id: str, + media_source: str, + file_type: int, + kind: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + file_name: Optional[str] = None, ) -> SendResult: """Upload media and send as a native message.""" if 
not self.is_connected: @@ -1784,20 +2131,30 @@ class QQAdapter(BasePlatformAdapter): try: # Resolve media source - data, content_type, resolved_name = await self._load_media(media_source, file_name) + data, content_type, resolved_name = await self._load_media( + media_source, file_name + ) # Route chat_type = self._guess_chat_type(chat_id) - target_path = f"/v2/users/{chat_id}/files" if chat_type == "c2c" else f"/v2/groups/{chat_id}/files" + target_path = ( + f"/v2/users/{chat_id}/files" + if chat_type == "c2c" + else f"/v2/groups/{chat_id}/files" + ) if chat_type == "guild": # Guild channels don't support native media upload in the same way # Send as URL fallback - return SendResult(success=False, error="Guild media send not supported via this path") + return SendResult( + success=False, error="Guild media send not supported via this path" + ) # Upload upload = await self._upload_media( - chat_type, chat_id, file_type, + chat_type, + chat_id, + file_type, file_data=data if not self._is_url(media_source) else None, url=media_source if self._is_url(media_source) else None, srv_send_msg=False, @@ -1806,7 +2163,9 @@ class QQAdapter(BasePlatformAdapter): file_info = upload.get("file_info") if not file_info: - return SendResult(success=False, error=f"Upload returned no file_info: {upload}") + return SendResult( + success=False, error=f"Upload returned no file_info: {upload}" + ) # Send media message msg_seq = self._next_msg_seq(chat_id) @@ -1816,13 +2175,17 @@ class QQAdapter(BasePlatformAdapter): "msg_seq": msg_seq, } if caption: - body["content"] = caption[:self.MAX_MESSAGE_LENGTH] + body["content"] = caption[: self.MAX_MESSAGE_LENGTH] if reply_to: body["msg_id"] = reply_to send_data = await self._api_request( "POST", - f"/v2/users/{chat_id}/messages" if chat_type == "c2c" else f"/v2/groups/{chat_id}/messages", + ( + f"/v2/users/{chat_id}/messages" + if chat_type == "c2c" + else f"/v2/groups/{chat_id}/messages" + ), body, ) return SendResult( @@ -1831,11 +2194,11 @@ 
class QQAdapter(BasePlatformAdapter): raw_response=send_data, ) except Exception as exc: - logger.error("[%s] Media send failed: %s", self.name, exc) + logger.error("[%s] Media send failed: %s", self._log_tag, exc) return SendResult(success=False, error=str(exc)) async def _load_media( - self, source: str, file_name: Optional[str] = None + self, source: str, file_name: Optional[str] = None ) -> Tuple[str, str, str]: """Load media from URL or local path. Returns (base64_or_url, content_type, filename).""" source = str(source).strip() @@ -1866,7 +2229,9 @@ class QQAdapter(BasePlatformAdapter): raw = local_path.read_bytes() resolved_name = file_name or local_path.name - content_type = mimetypes.guess_type(str(local_path))[0] or "application/octet-stream" + content_type = ( + mimetypes.guess_type(str(local_path))[0] or "application/octet-stream" + ) b64 = base64.b64encode(raw).decode("ascii") return b64, content_type, resolved_name @@ -1875,27 +2240,44 @@ class QQAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ async def send_typing(self, chat_id: str, metadata=None) -> None: - """Send an input notify to a C2C user (only supported for C2C).""" - del metadata + """Send an input notify to a C2C user (only supported for C2C). + Debounced to one request per ~50s (the API sets a 60s indicator). + The QQ API requires the originating message ID — retrieved from + ``_last_msg_id`` which is populated by ``_on_message``. 
+ """ if not self.is_connected: return - # Only C2C supports input notify chat_type = self._guess_chat_type(chat_id) if chat_type != "c2c": return + msg_id = self._last_msg_id.get(chat_id) + if not msg_id: + return + + # Debounce — skip if we sent recently + now = time.time() + last_sent = self._typing_sent_at.get(chat_id, 0.0) + if now - last_sent < self._TYPING_DEBOUNCE_SECONDS: + return + try: msg_seq = self._next_msg_seq(chat_id) body = { "msg_type": MSG_TYPE_INPUT_NOTIFY, - "input_notify": {"input_type": 1, "input_second": 60}, + "msg_id": msg_id, + "input_notify": { + "input_type": 1, + "input_second": self._TYPING_INPUT_SECONDS, + }, "msg_seq": msg_seq, } await self._api_request("POST", f"/v2/users/{chat_id}/messages", body) + self._typing_sent_at[chat_id] = now except Exception as exc: - logger.debug("[%s] send_typing failed: %s", self.name, exc) + logger.debug("[%s] send_typing failed: %s", self._log_tag, exc) # ------------------------------------------------------------------ # Format @@ -1942,7 +2324,8 @@ class QQAdapter(BasePlatformAdapter): """Strip the @bot mention prefix from group message content.""" # QQ group @-messages may have the bot's QQ/ID as prefix import re - stripped = re.sub(r'^@\S+\s*', '', content.strip()) + + stripped = re.sub(r"^@\S+\s*", "", content.strip()) return stripped def _is_dm_allowed(self, user_id: str) -> bool: diff --git a/gateway/platforms/qqbot/constants.py b/gateway/platforms/qqbot/constants.py new file mode 100644 index 0000000000..ddae3c133e --- /dev/null +++ b/gateway/platforms/qqbot/constants.py @@ -0,0 +1,74 @@ +"""QQBot package-level constants shared across adapter, onboard, and other modules.""" + +from __future__ import annotations + +import os + +# --------------------------------------------------------------------------- +# QQBot adapter version — bump on functional changes to the adapter package. 
+# --------------------------------------------------------------------------- + +QQBOT_VERSION = "1.1.0" + +# --------------------------------------------------------------------------- +# API endpoints +# --------------------------------------------------------------------------- + +# The portal domain is configurable via QQ_PORTAL_HOST for corporate proxies +# or test environments. Default: q.qq.com (production). +PORTAL_HOST = os.getenv("QQ_PORTAL_HOST", "q.qq.com") + +API_BASE = "https://api.sgroup.qq.com" +TOKEN_URL = "https://bots.qq.com/app/getAppAccessToken" +GATEWAY_URL_PATH = "/gateway" + +# QR-code onboard endpoints (on the portal host) +ONBOARD_CREATE_PATH = "/lite/create_bind_task" +ONBOARD_POLL_PATH = "/lite/poll_bind_result" +QR_URL_TEMPLATE = ( + "https://q.qq.com/qqbot/openclaw/connect.html" + "?task_id={task_id}&_wv=2&source=hermes" +) + +# --------------------------------------------------------------------------- +# Timeouts & retry +# --------------------------------------------------------------------------- + +DEFAULT_API_TIMEOUT = 30.0 +FILE_UPLOAD_TIMEOUT = 120.0 +CONNECT_TIMEOUT_SECONDS = 20.0 + +RECONNECT_BACKOFF = [2, 5, 10, 30, 60] +MAX_RECONNECT_ATTEMPTS = 100 +RATE_LIMIT_DELAY = 60  # seconds +QUICK_DISCONNECT_THRESHOLD = 5.0  # seconds +MAX_QUICK_DISCONNECT_COUNT = 3 + +ONBOARD_POLL_INTERVAL = 2.0  # seconds between poll_bind_result calls +ONBOARD_API_TIMEOUT = 10.0 + +# --------------------------------------------------------------------------- +# Message limits +# --------------------------------------------------------------------------- + +MAX_MESSAGE_LENGTH = 4000 +DEDUP_WINDOW_SECONDS = 300 +DEDUP_MAX_SIZE = 1000 + +# --------------------------------------------------------------------------- +# QQ Bot message types +# --------------------------------------------------------------------------- + +MSG_TYPE_TEXT = 0 +MSG_TYPE_MARKDOWN = 2 +MSG_TYPE_MEDIA = 7 +MSG_TYPE_INPUT_NOTIFY = 6 + +# 
--------------------------------------------------------------------------- +# QQ Bot file media types +# --------------------------------------------------------------------------- + +MEDIA_TYPE_IMAGE = 1 +MEDIA_TYPE_VIDEO = 2 +MEDIA_TYPE_VOICE = 3 +MEDIA_TYPE_FILE = 4 diff --git a/gateway/platforms/qqbot/crypto.py b/gateway/platforms/qqbot/crypto.py new file mode 100644 index 0000000000..426bd29de5 --- /dev/null +++ b/gateway/platforms/qqbot/crypto.py @@ -0,0 +1,45 @@ +"""AES-256-GCM utilities for QQBot scan-to-configure credential decryption.""" + +from __future__ import annotations + +import base64 +import os + + +def generate_bind_key() -> str: + """Generate a 256-bit random AES key and return it as base64. + + The key is passed to ``create_bind_task`` so the server can encrypt + the bot's *client_secret* before returning it. Only this CLI holds + the key, ensuring the secret never travels in plaintext. + """ + return base64.b64encode(os.urandom(32)).decode() + + +def decrypt_secret(encrypted_base64: str, key_base64: str) -> str: + """Decrypt a base64-encoded AES-256-GCM ciphertext. + + Ciphertext layout (after base64-decoding):: + + IV (12 bytes) ‖ ciphertext (N bytes) ‖ AuthTag (16 bytes) + + Args: + encrypted_base64: The ``bot_encrypt_secret`` value from + ``poll_bind_result``. + key_base64: The base64 AES key generated by + :func:`generate_bind_key`. + + Returns: + The decrypted *client_secret* as a UTF-8 string. 
+ """ + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + key = base64.b64decode(key_base64) + raw = base64.b64decode(encrypted_base64) + + iv = raw[:12] + ciphertext_with_tag = raw[12:] # AESGCM expects ciphertext + tag concatenated + + aesgcm = AESGCM(key) + plaintext = aesgcm.decrypt(iv, ciphertext_with_tag, None) + return plaintext.decode("utf-8") diff --git a/gateway/platforms/qqbot/onboard.py b/gateway/platforms/qqbot/onboard.py new file mode 100644 index 0000000000..65750b3f10 --- /dev/null +++ b/gateway/platforms/qqbot/onboard.py @@ -0,0 +1,124 @@ +""" +QQBot scan-to-configure (QR code onboard) module. + +Calls the ``q.qq.com`` ``create_bind_task`` / ``poll_bind_result`` APIs to +generate a QR-code URL and poll for scan completion. On success the caller +receives the bot's *app_id*, *client_secret* (decrypted locally), and the +scanner's *user_openid* — enough to fully configure the QQBot gateway. + +Reference: https://bot.q.qq.com/wiki/develop/api-v2/ +""" + +from __future__ import annotations + +import logging +from enum import IntEnum +from typing import Tuple +from urllib.parse import quote + +from .constants import ( + ONBOARD_API_TIMEOUT, + ONBOARD_CREATE_PATH, + ONBOARD_POLL_PATH, + PORTAL_HOST, + QR_URL_TEMPLATE, +) +from .crypto import generate_bind_key +from .utils import get_api_headers + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Bind status +# --------------------------------------------------------------------------- + + +class BindStatus(IntEnum): + """Status codes returned by ``poll_bind_result``.""" + + NONE = 0 + PENDING = 1 + COMPLETED = 2 + EXPIRED = 3 + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def create_bind_task( + timeout: float = ONBOARD_API_TIMEOUT, +) -> Tuple[str, str]: + """Create a bind 
task and return *(task_id, aes_key_base64)*. + + The AES key is generated locally and sent to the server so it can + encrypt the bot credentials before returning them. + + Raises: + RuntimeError: If the API returns a non-zero ``retcode``. + """ + import httpx + + url = f"https://{PORTAL_HOST}{ONBOARD_CREATE_PATH}" + key = generate_bind_key() + + async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client: + resp = await client.post(url, json={"key": key}, headers=get_api_headers()) + resp.raise_for_status() + data = resp.json() + + if data.get("retcode") != 0: + raise RuntimeError(data.get("msg", "create_bind_task failed")) + + task_id = data.get("data", {}).get("task_id") + if not task_id: + raise RuntimeError("create_bind_task: missing task_id in response") + + logger.debug("create_bind_task ok: task_id=%s", task_id) + return task_id, key + + +async def poll_bind_result( + task_id: str, + timeout: float = ONBOARD_API_TIMEOUT, +) -> Tuple[BindStatus, str, str, str]: + """Poll the bind result for *task_id*. + + Returns: + A 4-tuple of ``(status, bot_appid, bot_encrypt_secret, user_openid)``. + + * ``bot_encrypt_secret`` is AES-256-GCM encrypted — decrypt it with + :func:`~gateway.platforms.qqbot.crypto.decrypt_secret` using the + key from :func:`create_bind_task`. + * ``user_openid`` is the OpenID of the person who scanned the code + (available when ``status == COMPLETED``). + + Raises: + RuntimeError: If the API returns a non-zero ``retcode``. 
+ """ + import httpx + + url = f"https://{PORTAL_HOST}{ONBOARD_POLL_PATH}" + + async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client: + resp = await client.post(url, json={"task_id": task_id}, headers=get_api_headers()) + resp.raise_for_status() + data = resp.json() + + if data.get("retcode") != 0: + raise RuntimeError(data.get("msg", "poll_bind_result failed")) + + d = data.get("data", {}) + return ( + BindStatus(d.get("status", 0)), + str(d.get("bot_appid", "")), + d.get("bot_encrypt_secret", ""), + d.get("user_openid", ""), + ) + + +def build_connect_url(task_id: str) -> str: + """Build the QR-code target URL for a given *task_id*.""" + return QR_URL_TEMPLATE.format(task_id=quote(task_id)) diff --git a/gateway/platforms/qqbot/utils.py b/gateway/platforms/qqbot/utils.py new file mode 100644 index 0000000000..873e58d2a5 --- /dev/null +++ b/gateway/platforms/qqbot/utils.py @@ -0,0 +1,71 @@ +"""QQBot shared utilities — User-Agent, HTTP helpers, config coercion.""" + +from __future__ import annotations + +import platform +import sys +from typing import Any, Dict, List + +from .constants import QQBOT_VERSION + + +# --------------------------------------------------------------------------- +# User-Agent +# --------------------------------------------------------------------------- + +def _get_hermes_version() -> str: + """Return the hermes-agent package version, or 'dev' if unavailable.""" + try: + from importlib.metadata import version + return version("hermes-agent") + except Exception: + return "dev" + + +def build_user_agent() -> str: + """Build a descriptive User-Agent string. 
+ + Format:: + + QQBotAdapter/{version} (Python/{py_version}; {os_name}; Hermes/{hermes_version}) + + Example:: + + QQBotAdapter/1.0.0 (Python/3.11.15; darwin; Hermes/0.9.0) + """ + py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + os_name = platform.system().lower() + hermes_version = _get_hermes_version() + return f"QQBotAdapter/{QQBOT_VERSION} (Python/{py_version}; {os_name}; Hermes/{hermes_version})" + + +def get_api_headers() -> Dict[str, str]: + """Return standard HTTP headers for QQBot API requests. + + Includes ``Content-Type``, ``Accept``, and a dynamic ``User-Agent``. + ``q.qq.com`` requires ``Accept: application/json`` — without it, + the server returns a JavaScript anti-bot challenge page. + """ + return { + "Content-Type": "application/json", + "Accept": "application/json", + "User-Agent": build_user_agent(), + } + + +# --------------------------------------------------------------------------- +# Config helpers +# --------------------------------------------------------------------------- + +def coerce_list(value: Any) -> List[str]: + """Coerce config values into a trimmed string list. + + Accepts comma-separated strings, lists, tuples, sets, or single values. 
+ """ + if value is None: + return [] + if isinstance(value, str): + return [item.strip() for item in value.split(",") if item.strip()] + if isinstance(value, (list, tuple, set)): + return [str(item).strip() for item in value if str(item).strip()] + return [str(value).strip()] if str(value).strip() else [] diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 2f4ec93294..5b1fef1337 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -118,6 +118,84 @@ def _strip_mdv2(text: str) -> str: return cleaned +# --------------------------------------------------------------------------- +# Markdown table → code block conversion +# --------------------------------------------------------------------------- +# Telegram's MarkdownV2 has no table syntax — '|' is just an escaped literal, +# so pipe tables render as noisy backslash-pipe text with no alignment. +# Wrapping the table in a fenced code block makes Telegram render it as +# monospace preformatted text with columns intact. + +# Matches a GFM table delimiter row: optional outer pipes, cells containing +# only dashes (with optional leading/trailing colons for alignment) separated +# by '|'. Requires at least one internal '|' so lone '---' horizontal rules +# are NOT matched. +_TABLE_SEPARATOR_RE = re.compile( + r'^\s*\|?\s*:?-+:?\s*(?:\|\s*:?-+:?\s*){1,}\|?\s*$' +) + + +def _is_table_row(line: str) -> bool: + """Return True if *line* could plausibly be a table data row.""" + stripped = line.strip() + return bool(stripped) and '|' in stripped + + +def _wrap_markdown_tables(text: str) -> str: + """Wrap GFM-style pipe tables in ``` fences so Telegram renders them. + + Detected by a row containing '|' immediately followed by a delimiter + row matching :data:`_TABLE_SEPARATOR_RE`. Subsequent pipe-containing + non-blank lines are consumed as the table body and included in the + wrapped block. Tables inside existing fenced code blocks are left + alone. 
+ """ + if '|' not in text or '-' not in text: + return text + + lines = text.split('\n') + out: list[str] = [] + in_fence = False + i = 0 + while i < len(lines): + line = lines[i] + stripped = line.lstrip() + + # Track existing fenced code blocks — never touch content inside. + if stripped.startswith('```'): + in_fence = not in_fence + out.append(line) + i += 1 + continue + if in_fence: + out.append(line) + i += 1 + continue + + # Look for a header row (contains '|') immediately followed by a + # delimiter row. + if ( + '|' in line + and i + 1 < len(lines) + and _TABLE_SEPARATOR_RE.match(lines[i + 1]) + ): + table_block = [line, lines[i + 1]] + j = i + 2 + while j < len(lines) and _is_table_row(lines[j]): + table_block.append(lines[j]) + j += 1 + out.append('```') + out.extend(table_block) + out.append('```') + i = j + continue + + out.append(line) + i += 1 + + return '\n'.join(out) + + class TelegramAdapter(BasePlatformAdapter): """ Telegram bot adapter. @@ -1916,6 +1994,12 @@ class TelegramAdapter(BasePlatformAdapter): text = content + # 0) Pre-wrap GFM-style pipe tables in ``` fences. Telegram can't + # render tables natively, but fenced code blocks render as + # monospace preformatted text with columns intact. The wrapped + # tables then flow through step (1) below as protected regions. + text = _wrap_markdown_tables(text) + # 1) Protect fenced code blocks (``` ... ```) # Per MarkdownV2 spec, \ and ` inside pre/code must be escaped. def _protect_fenced(m): diff --git a/gateway/run.py b/gateway/run.py index 170c6f87de..ea747321f9 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2178,6 +2178,30 @@ class GatewayRunner: ) except Exception as _e: logger.debug("Idle agent sweep failed: %s", _e) + + # Periodically prune stale SessionStore entries. The + # in-memory dict (and sessions.json) would otherwise grow + # unbounded in gateways serving many rotating chats / + # threads / users over long time windows. 
Pruning is + # invisible to users — a resumed session just gets a + # fresh session_id, exactly as if the reset policy fired. + _last_prune_ts = getattr(self, "_last_session_store_prune_ts", 0.0) + _prune_interval = 3600.0 # once per hour + if time.time() - _last_prune_ts > _prune_interval: + try: + _max_age = int( + getattr(self.config, "session_store_max_age_days", 0) or 0 + ) + if _max_age > 0: + _pruned = self.session_store.prune_old_entries(_max_age) + if _pruned: + logger.info( + "SessionStore prune: dropped %d stale entries", + _pruned, + ) + except Exception as _e: + logger.debug("SessionStore prune failed: %s", _e) + self._last_session_store_prune_ts = time.time() except Exception as e: logger.debug("Session expiry watcher error: %s", e) # Sleep in small increments so we can stop quickly @@ -2384,6 +2408,7 @@ class GatewayRunner: self.adapters.clear() self._running_agents.clear() + self._running_agents_ts.clear() self._pending_messages.clear() self._pending_approvals.clear() if hasattr(self, '_busy_ack_ts'): @@ -2408,6 +2433,20 @@ class GatewayRunner: except Exception: pass + # Close SQLite session DBs so the WAL write lock is released. + # Without this, --replace and similar restart flows leave the + # old gateway's connection holding the WAL lock until Python + # actually exits — causing 'database is locked' errors when + # the new gateway tries to open the same file. 
+ for _db_holder in (self, getattr(self, "session_store", None)): + _db = getattr(_db_holder, "_db", None) if _db_holder else None + if _db is None or not hasattr(_db, "close"): + continue + try: + _db.close() + except Exception as _e: + logger.debug("SessionDB close error: %s", _e) + from gateway.status import remove_pid_file remove_pid_file() @@ -2906,9 +2945,7 @@ class GatewayRunner: _quick_key[:30], _stale_age, _stale_idle, _raw_stale_timeout, _stale_detail, ) - del self._running_agents[_quick_key] - self._running_agents_ts.pop(_quick_key, None) - self._busy_ack_ts.pop(_quick_key, None) + self._release_running_agent_state(_quick_key) if _quick_key in self._running_agents: if event.get_command() == "status": @@ -2936,8 +2973,7 @@ class GatewayRunner: if adapter and hasattr(adapter, 'get_pending_message'): adapter.get_pending_message(_quick_key) # consume and discard self._pending_messages.pop(_quick_key, None) - if _quick_key in self._running_agents: - del self._running_agents[_quick_key] + self._release_running_agent_state(_quick_key) logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key[:20]) return "⚡ Stopped. You can continue this session." @@ -2959,8 +2995,7 @@ class GatewayRunner: self._pending_messages.pop(_quick_key, None) # Clean up the running agent entry so the reset handler # doesn't think an agent is still active. - if _quick_key in self._running_agents: - del self._running_agents[_quick_key] + self._release_running_agent_state(_quick_key) return await self._handle_reset_command(event) # /queue — queue without interrupting @@ -3041,8 +3076,7 @@ class GatewayRunner: # Agent is being set up but not ready yet. if event.get_command() == "stop": # Force-clean the sentinel so the session is unlocked. 
- if _quick_key in self._running_agents: - del self._running_agents[_quick_key] + self._release_running_agent_state(_quick_key) logger.info("HARD STOP (pending) for session %s — sentinel cleared", _quick_key[:20]) return "⚡ Force-stopped. The agent was still starting — session unlocked." # Queue the message so it will be picked up after the @@ -3361,8 +3395,13 @@ class GatewayRunner: # (exception, command fallthrough, etc.) the sentinel must # not linger or the session would be permanently locked out. if self._running_agents.get(_quick_key) is _AGENT_PENDING_SENTINEL: - del self._running_agents[_quick_key] - self._running_agents_ts.pop(_quick_key, None) + self._release_running_agent_state(_quick_key) + else: + # Agent path already cleaned _running_agents; make sure + # the paired metadata dicts are gone too. + self._running_agents_ts.pop(_quick_key, None) + if hasattr(self, "_busy_ack_ts"): + self._busy_ack_ts.pop(_quick_key, None) async def _prepare_inbound_message_text( self, @@ -4668,16 +4707,14 @@ class GatewayRunner: agent = self._running_agents.get(session_key) if agent is _AGENT_PENDING_SENTINEL: # Force-clean the sentinel so the session is unlocked. - if session_key in self._running_agents: - del self._running_agents[session_key] + self._release_running_agent_state(session_key) logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20]) return "⚡ Stopped. The agent hadn't started yet — you can continue this session." if agent: agent.interrupt("Stop requested") # Force-clean the session lock so a truly hung agent doesn't # keep it locked forever. - if session_key in self._running_agents: - del self._running_agents[session_key] + self._release_running_agent_state(session_key) return "⚡ Stopped. You can continue this session." else: return "No active task to stop." 
@@ -6593,8 +6630,7 @@ class GatewayRunner: logger.debug("Memory flush on resume failed: %s", e) # Clear any running agent for this session key - if session_key in self._running_agents: - del self._running_agents[session_key] + self._release_running_agent_state(session_key) # Switch the session entry to point at the old session new_entry = self.session_store.switch_session(session_key, target_id) @@ -8010,6 +8046,30 @@ class GatewayRunner: override = self._session_model_overrides.get(session_key) return override is not None and override.get("model") == agent_model + def _release_running_agent_state(self, session_key: str) -> None: + """Pop ALL per-running-agent state entries for ``session_key``. + + Replaces ad-hoc ``del self._running_agents[key]`` calls scattered + across the gateway. Those sites had drifted: some popped only + ``_running_agents``; some also ``_running_agents_ts``; only one + path also cleared ``_busy_ack_ts``. Each missed entry was a + small, persistent leak — a (str_key → float) tuple per session + per gateway lifetime. + + Use this at every site that ends a running turn, regardless of + cause (normal completion, /stop, /reset, /resume, sentinel + cleanup, stale-eviction). Per-session state that PERSISTS + across turns (``_session_model_overrides``, ``_voice_mode``, + ``_pending_approvals``, ``_update_prompt_pending``) is NOT + touched here — those have their own lifecycles. 
+ """ + if not session_key: + return + self._running_agents.pop(session_key, None) + self._running_agents_ts.pop(session_key, None) + if hasattr(self, "_busy_ack_ts"): + self._busy_ack_ts.pop(session_key, None) + def _evict_cached_agent(self, session_key: str) -> None: """Remove a cached agent for a session (called on /new, /model, etc).""" _lock = getattr(self, "_agent_cache_lock", None) @@ -9845,10 +9905,8 @@ class GatewayRunner: # Clean up tracking tracking_task.cancel() - if session_key and session_key in self._running_agents: - del self._running_agents[session_key] if session_key: - self._running_agents_ts.pop(session_key, None) + self._release_running_agent_state(session_key) if self._draining: self._update_runtime_status("draining") diff --git a/gateway/session.py b/gateway/session.py index f057d1cfc0..4cb623128c 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -802,6 +802,57 @@ class SessionStore: return True return False + def prune_old_entries(self, max_age_days: int) -> int: + """Drop SessionEntry records older than max_age_days. + + Pruning is based on ``updated_at`` (last activity), not ``created_at``. + A session that's been active within the window is kept regardless of + how old it is. Entries marked ``suspended`` are kept — the user + explicitly paused them for later resume. Entries held by an active + process (via has_active_processes_fn) are also kept so long-running + background work isn't orphaned. + + Pruning is functionally identical to a natural reset-policy expiry: + the transcript in SQLite stays, but the session_key → session_id + mapping is dropped and the user starts a fresh session on return. + + ``max_age_days <= 0`` disables pruning; returns 0 immediately. + Returns the number of entries removed. 
+ """ + if max_age_days is None or max_age_days <= 0: + return 0 + from datetime import timedelta + + cutoff = _now() - timedelta(days=max_age_days) + removed_keys: list[str] = [] + + with self._lock: + self._ensure_loaded_locked() + for key, entry in list(self._entries.items()): + if entry.suspended: + continue + # Never prune sessions with an active background process + # attached — the user may still be waiting on output. + if self._has_active_processes_fn is not None: + try: + if self._has_active_processes_fn(entry.session_id): + continue + except Exception: + pass + if entry.updated_at < cutoff: + removed_keys.append(key) + for key in removed_keys: + self._entries.pop(key, None) + if removed_keys: + self._save() + + if removed_keys: + logger.info( + "SessionStore pruned %d entries older than %d days", + len(removed_keys), max_age_days, + ) + return len(removed_keys) + def suspend_recently_active(self, max_age_seconds: int = 120) -> int: """Mark recently-active sessions as suspended. 
diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index e79a6dca6d..421836c23c 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -233,6 +233,14 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = { api_key_env_vars=("XAI_API_KEY",), base_url_env_var="XAI_BASE_URL", ), + "nvidia": ProviderConfig( + id="nvidia", + name="NVIDIA NIM", + auth_type="api_key", + inference_base_url="https://integrate.api.nvidia.com/v1", + api_key_env_vars=("NVIDIA_API_KEY",), + base_url_env_var="NVIDIA_BASE_URL", + ), "ai-gateway": ProviderConfig( id="ai-gateway", name="Vercel AI Gateway", diff --git a/hermes_cli/config.py b/hermes_cli/config.py index c7df033701..1670156b27 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -44,7 +44,8 @@ _EXTRA_ENV_KEYS = frozenset({ "WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY", "WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS", "BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD", - "QQ_APP_ID", "QQ_CLIENT_SECRET", "QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME", + "QQ_APP_ID", "QQ_CLIENT_SECRET", "QQBOT_HOME_CHANNEL", "QQBOT_HOME_CHANNEL_NAME", + "QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME", # legacy aliases (pre-rename, still read for back-compat) "QQ_ALLOWED_USERS", "QQ_GROUP_ALLOWED_USERS", "QQ_ALLOW_ALL_USERS", "QQ_MARKDOWN_SUPPORT", "QQ_STT_API_KEY", "QQ_STT_BASE_URL", "QQ_STT_MODEL", "TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT", @@ -861,6 +862,22 @@ OPTIONAL_ENV_VARS = { "category": "provider", "advanced": True, }, + "NVIDIA_API_KEY": { + "description": "NVIDIA NIM API key (build.nvidia.com or local NIM endpoint)", + "prompt": "NVIDIA NIM API key", + "url": "https://build.nvidia.com/", + "password": True, + "category": "provider", + "advanced": True, + }, + "NVIDIA_BASE_URL": { + "description": "NVIDIA NIM base URL override (e.g. 
http://localhost:8000/v1 for local NIM)", + "prompt": "NVIDIA NIM base URL (leave empty for default)", + "url": None, + "password": False, + "category": "provider", + "advanced": True, + }, "GLM_API_KEY": { "description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)", "prompt": "Z.AI / GLM API key", @@ -1518,12 +1535,12 @@ OPTIONAL_ENV_VARS = { "prompt": "Allow All QQ Users", "category": "messaging", }, - "QQ_HOME_CHANNEL": { + "QQBOT_HOME_CHANNEL": { "description": "Default QQ channel/group for cron delivery and notifications", "prompt": "QQ Home Channel", "category": "messaging", }, - "QQ_HOME_CHANNEL_NAME": { + "QQBOT_HOME_CHANNEL_NAME": { "description": "Display name for the QQ home channel", "prompt": "QQ Home Channel Name", "category": "messaging", diff --git a/hermes_cli/doctor.py b/hermes_cli/doctor.py index d044ddf4cf..28c4af1fa8 100644 --- a/hermes_cli/doctor.py +++ b/hermes_cli/doctor.py @@ -825,6 +825,7 @@ def run_doctor(args): ("Arcee AI", ("ARCEEAI_API_KEY",), "https://api.arcee.ai/api/v1/models", "ARCEE_BASE_URL", True), ("DeepSeek", ("DEEPSEEK_API_KEY",), "https://api.deepseek.com/v1/models", "DEEPSEEK_BASE_URL", True), ("Hugging Face", ("HF_TOKEN",), "https://router.huggingface.co/v1/models", "HF_BASE_URL", True), + ("NVIDIA NIM", ("NVIDIA_API_KEY",), "https://integrate.api.nvidia.com/v1/models", "NVIDIA_BASE_URL", True), ("Alibaba/DashScope", ("DASHSCOPE_API_KEY",), "https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models", "DASHSCOPE_BASE_URL", True), # MiniMax: the /anthropic endpoint doesn't support /models, but the /v1 endpoint does. 
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL", True), diff --git a/hermes_cli/dump.py b/hermes_cli/dump.py index a520790857..ae8ecc6419 100644 --- a/hermes_cli/dump.py +++ b/hermes_cli/dump.py @@ -296,6 +296,7 @@ def run_dump(args): ("DEEPSEEK_API_KEY", "deepseek"), ("DASHSCOPE_API_KEY", "dashscope"), ("HF_TOKEN", "huggingface"), + ("NVIDIA_API_KEY", "nvidia"), ("AI_GATEWAY_API_KEY", "ai_gateway"), ("OPENCODE_ZEN_API_KEY", "opencode_zen"), ("OPENCODE_GO_API_KEY", "opencode_go"), diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 585bbe4460..f5ebcf031c 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -1998,7 +1998,7 @@ _PLATFORMS = [ {"name": "QQ_ALLOWED_USERS", "prompt": "Allowed user OpenIDs (comma-separated, leave empty for open access)", "password": False, "is_allowlist": True, "help": "Optional — restrict DM access to specific user OpenIDs."}, - {"name": "QQ_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False, + {"name": "QQBOT_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False, "help": "OpenID to deliver cron results and notifications to."}, ], }, @@ -2625,6 +2625,215 @@ def _setup_feishu(): print_info(f" Bot: {bot_name}") +def _setup_qqbot(): + """Interactive setup for QQ Bot — scan-to-configure or manual credentials.""" + print() + print(color(" ─── 🐧 QQ Bot Setup ───", Colors.CYAN)) + + existing_app_id = get_env_value("QQ_APP_ID") + existing_secret = get_env_value("QQ_CLIENT_SECRET") + if existing_app_id and existing_secret: + print() + print_success("QQ Bot is already configured.") + if not prompt_yes_no(" Reconfigure QQ Bot?", False): + return + + # ── Choose setup method ── + print() + method_choices = [ + "Scan QR code to add bot automatically (recommended)", + "Enter existing App ID and App Secret manually", + ] + method_idx = prompt_choice(" How would you like to set up 
QQ Bot?", method_choices, 0) + + credentials = None + used_qr = False + + if method_idx == 0: + # ── QR scan-to-configure ── + try: + credentials = _qqbot_qr_flow() + except KeyboardInterrupt: + print() + print_warning(" QQ Bot setup cancelled.") + return + if credentials: + used_qr = True + if not credentials: + print_info(" QR setup did not complete. Continuing with manual input.") + + # ── Manual credential input ── + if not credentials: + print() + print_info(" Go to https://q.qq.com to register a QQ Bot application.") + print_info(" Note your App ID and App Secret from the application page.") + print() + app_id = prompt(" App ID", password=False) + if not app_id: + print_warning(" Skipped — QQ Bot won't work without an App ID.") + return + app_secret = prompt(" App Secret", password=True) + if not app_secret: + print_warning(" Skipped — QQ Bot won't work without an App Secret.") + return + credentials = {"app_id": app_id.strip(), "client_secret": app_secret.strip(), "user_openid": ""} + + # ── Save core credentials ── + save_env_value("QQ_APP_ID", credentials["app_id"]) + save_env_value("QQ_CLIENT_SECRET", credentials["client_secret"]) + + user_openid = credentials.get("user_openid", "") + + # ── DM security policy ── + print() + access_choices = [ + "Use DM pairing approval (recommended)", + "Allow all direct messages", + "Only allow listed user OpenIDs", + ] + access_idx = prompt_choice(" How should direct messages be authorized?", access_choices, 0) + if access_idx == 0: + save_env_value("QQ_ALLOW_ALL_USERS", "false") + if user_openid: + print() + if prompt_yes_no(f" Add yourself ({user_openid}) to the allow list?", True): + save_env_value("QQ_ALLOWED_USERS", user_openid) + print_success(f" Allow list set to {user_openid}") + else: + save_env_value("QQ_ALLOWED_USERS", "") + else: + save_env_value("QQ_ALLOWED_USERS", "") + print_success(" DM pairing enabled.") + print_info(" Unknown users can request access; approve with `hermes pairing approve`.") + elif 
access_idx == 1: + save_env_value("QQ_ALLOW_ALL_USERS", "true") + save_env_value("QQ_ALLOWED_USERS", "") + print_warning(" Open DM access enabled for QQ Bot.") + else: + default_allow = user_openid or "" + allowlist = prompt(" Allowed user OpenIDs (comma-separated)", default_allow, password=False).replace(" ", "") + save_env_value("QQ_ALLOW_ALL_USERS", "false") + save_env_value("QQ_ALLOWED_USERS", allowlist) + print_success(" Allowlist saved.") + + # ── Home channel ── + if user_openid: + print() + if prompt_yes_no(f" Use your QQ user ID ({user_openid}) as the home channel?", True): + save_env_value("QQBOT_HOME_CHANNEL", user_openid) + print_success(f" Home channel set to {user_openid}") + else: + print() + home_channel = prompt(" Home channel OpenID (for cron/notifications, or empty)", password=False) + if home_channel: + save_env_value("QQBOT_HOME_CHANNEL", home_channel.strip()) + print_success(f" Home channel set to {home_channel.strip()}") + + print() + print_success("🐧 QQ Bot configured!") + print_info(f" App ID: {credentials['app_id']}") + + +def _qqbot_render_qr(url: str) -> bool: + """Try to render a QR code in the terminal. Returns True if successful.""" + try: + import qrcode as _qr + qr = _qr.QRCode(border=1,error_correction=_qr.constants.ERROR_CORRECT_L) + qr.add_data(url) + qr.make(fit=True) + qr.print_ascii(invert=True) + return True + except Exception: + return False + + +def _qqbot_qr_flow(): + """Run the QR-code scan-to-configure flow. + + Returns a dict with app_id, client_secret, user_openid on success, + or None on failure/cancel. 
+ """ + try: + from gateway.platforms.qqbot import ( + create_bind_task, poll_bind_result, build_connect_url, + decrypt_secret, BindStatus, + ) + from gateway.platforms.qqbot.constants import ONBOARD_POLL_INTERVAL + except Exception as exc: + print_error(f" QQBot onboard import failed: {exc}") + return None + + import asyncio + import time + + MAX_REFRESHES = 3 + refresh_count = 0 + + while refresh_count <= MAX_REFRESHES: + loop = asyncio.new_event_loop() + + # ── Create bind task ── + try: + task_id, aes_key = loop.run_until_complete(create_bind_task()) + except Exception as e: + print_warning(f" Failed to create bind task: {e}") + loop.close() + return None + + url = build_connect_url(task_id) + + # ── Display QR code + URL ── + print() + if _qqbot_render_qr(url): + print(f" Scan the QR code above, or open this URL directly:\n {url}") + else: + print(f" Open this URL in QQ on your phone:\n {url}") + print_info(" Tip: pip install qrcode to show a scannable QR code here") + + # ── Poll loop (silent — keep QR visible at bottom) ── + try: + while True: + try: + status, app_id, encrypted_secret, user_openid = loop.run_until_complete( + poll_bind_result(task_id) + ) + except Exception: + time.sleep(ONBOARD_POLL_INTERVAL) + continue + + if status == BindStatus.COMPLETED: + client_secret = decrypt_secret(encrypted_secret, aes_key) + print() + print_success(f" QR scan complete! (App ID: {app_id})") + if user_openid: + print_info(f" Scanner's OpenID: {user_openid}") + return { + "app_id": app_id, + "client_secret": client_secret, + "user_openid": user_openid, + } + + if status == BindStatus.EXPIRED: + refresh_count += 1 + if refresh_count > MAX_REFRESHES: + print() + print_warning(f" QR code expired {MAX_REFRESHES} times — giving up.") + return None + print() + print_warning(f" QR code expired, refreshing... 
({refresh_count}/{MAX_REFRESHES})") + loop.close() + break # outer while creates a new task + + time.sleep(ONBOARD_POLL_INTERVAL) + except KeyboardInterrupt: + loop.close() + raise + finally: + loop.close() + + return None + + def _setup_signal(): """Interactive setup for Signal messenger.""" import shutil @@ -2806,6 +3015,8 @@ def gateway_setup(): _setup_dingtalk() elif platform["key"] == "feishu": _setup_feishu() + elif platform["key"] == "qqbot": + _setup_qqbot() else: _setup_standard_platform(platform) diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 195fd53c4f..2fb27dd2da 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1532,6 +1532,7 @@ def select_provider_and_model(args=None): "huggingface", "xiaomi", "arcee", + "nvidia", "ollama-cloud", ): _model_flow_api_key_provider(config, selected_provider, current_model) @@ -5875,6 +5876,7 @@ For more help on a command: "kilocode", "xiaomi", "arcee", + "nvidia", ], default=None, help="Inference provider (default: auto)", diff --git a/hermes_cli/models.py b/hermes_cli/models.py index fe2a0c433a..cbbeef62d4 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -135,7 +135,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = { "gemini-2.5-flash-lite", # Gemma open models (also served via AI Studio) "gemma-4-31b-it", - "gemma-4-26b-it", ], "google-gemini-cli": [ "gemini-2.5-pro", @@ -155,6 +154,20 @@ _PROVIDER_MODELS: dict[str, list[str]] = { "grok-4.20-reasoning", "grok-4-1-fast-reasoning", ], + "nvidia": [ + # NVIDIA flagship reasoning models + "nvidia/nemotron-3-super-120b-a12b", + "nvidia/nemotron-3-nano-30b-a3b", + "nvidia/llama-3.3-nemotron-super-49b-v1.5", + # Third-party agentic models hosted on build.nvidia.com + # (map to OpenRouter defaults — users get familiar picks on NIM) + "qwen/qwen3.5-397b-a17b", + "deepseek-ai/deepseek-v3.2", + "moonshotai/kimi-k2.5", + "minimaxai/minimax-m2.5", + "z-ai/glm5", + "openai/gpt-oss-120b", + ], "kimi-coding": [ "kimi-k2.5", "kimi-for-coding", @@ -536,6 
+549,7 @@ CANONICAL_PROVIDERS: list[ProviderEntry] = [ ProviderEntry("anthropic", "Anthropic", "Anthropic (Claude models — API key or Claude Code)"), ProviderEntry("openai-codex", "OpenAI Codex", "OpenAI Codex"), ProviderEntry("xiaomi", "Xiaomi MiMo", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"), + ProviderEntry("nvidia", "NVIDIA NIM", "NVIDIA NIM (Nemotron models — build.nvidia.com or local NIM)"), ProviderEntry("qwen-oauth", "Qwen OAuth (Portal)", "Qwen OAuth (reuses local Qwen CLI login)"), ProviderEntry("copilot", "GitHub Copilot", "GitHub Copilot (uses GITHUB_TOKEN or gh auth token)"), ProviderEntry("copilot-acp", "GitHub Copilot ACP", "GitHub Copilot ACP (spawns `copilot --acp --stdio`)"), @@ -618,6 +632,10 @@ _PROVIDER_ALIASES = { "grok": "xai", "x-ai": "xai", "x.ai": "xai", + "nim": "nvidia", + "nvidia-nim": "nvidia", + "build-nvidia": "nvidia", + "nemotron": "nvidia", "ollama": "custom", # bare "ollama" = local; use "ollama-cloud" for cloud "ollama_cloud": "ollama-cloud", } diff --git a/hermes_cli/providers.py b/hermes_cli/providers.py index b2dda20be5..a71055cfe4 100644 --- a/hermes_cli/providers.py +++ b/hermes_cli/providers.py @@ -137,6 +137,11 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = { base_url_override="https://api.x.ai/v1", base_url_env_var="XAI_BASE_URL", ), + "nvidia": HermesOverlay( + transport="openai_chat", + base_url_override="https://integrate.api.nvidia.com/v1", + base_url_env_var="NVIDIA_BASE_URL", + ), "xiaomi": HermesOverlay( transport="openai_chat", base_url_env_var="XIAOMI_BASE_URL", @@ -191,6 +196,12 @@ ALIASES: Dict[str, str] = { "x.ai": "xai", "grok": "xai", + # nvidia + "nim": "nvidia", + "nvidia-nim": "nvidia", + "build-nvidia": "nvidia", + "nemotron": "nvidia", + # kimi-for-coding (models.dev ID) "kimi": "kimi-for-coding", "kimi-coding": "kimi-for-coding", diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index b5efb52a88..95c9cae77e 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -91,7 +91,7 @@ 
_DEFAULT_PROVIDER_MODELS = { "gemini": [ "gemini-3.1-pro-preview", "gemini-3-flash-preview", "gemini-3.1-flash-lite-preview", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite", - "gemma-4-31b-it", "gemma-4-26b-it", + "gemma-4-31b-it", ], "zai": ["glm-5.1", "glm-5", "glm-4.7", "glm-4.5", "glm-4.5-flash"], "kimi-coding": ["kimi-k2.5", "kimi-k2-thinking", "kimi-k2-turbo-preview"], @@ -2005,52 +2005,6 @@ def _setup_wecom_callback(): _gw_setup() -def _setup_qqbot(): - """Configure QQ Bot gateway.""" - print_header("QQ Bot") - existing = get_env_value("QQ_APP_ID") - if existing: - print_info("QQ Bot: already configured") - if not prompt_yes_no("Reconfigure QQ Bot?", False): - return - - print_info("Connects Hermes to QQ via the Official QQ Bot API (v2).") - print_info(" Requires a QQ Bot application at q.qq.com") - print_info(" Reference: https://bot.q.qq.com/wiki/develop/api-v2/") - print() - - app_id = prompt("QQ Bot App ID") - if not app_id: - print_warning("App ID is required — skipping QQ Bot setup") - return - save_env_value("QQ_APP_ID", app_id.strip()) - - client_secret = prompt("QQ Bot App Secret", password=True) - if not client_secret: - print_warning("App Secret is required — skipping QQ Bot setup") - return - save_env_value("QQ_CLIENT_SECRET", client_secret) - print_success("QQ Bot credentials saved") - - print() - print_info("🔒 Security: Restrict who can DM your bot") - print_info(" Use QQ user OpenIDs (found in event payloads)") - print() - allowed_users = prompt("Allowed user OpenIDs (comma-separated, leave empty for open access)") - if allowed_users: - save_env_value("QQ_ALLOWED_USERS", allowed_users.replace(" ", "")) - print_success("QQ Bot allowlist configured") - else: - print_info("⚠️ No allowlist set — anyone can DM the bot!") - - print() - print_info("📬 Home Channel: OpenID for cron job delivery and notifications.") - home_channel = prompt("Home channel OpenID (leave empty to set later)") - if home_channel: - 
save_env_value("QQ_HOME_CHANNEL", home_channel) - - print() - print_success("QQ Bot configured!") def _setup_bluebubbles(): @@ -2119,12 +2073,9 @@ def _setup_bluebubbles(): def _setup_qqbot(): - """Configure QQ Bot (Official API v2) via standard platform setup.""" - from hermes_cli.gateway import _PLATFORMS - qq_platform = next((p for p in _PLATFORMS if p["key"] == "qqbot"), None) - if qq_platform: - from hermes_cli.gateway import _setup_standard_platform - _setup_standard_platform(qq_platform) + """Configure QQ Bot (Official API v2) via gateway setup.""" + from hermes_cli.gateway import _setup_qqbot as _gateway_setup_qqbot + _gateway_setup_qqbot() def _setup_webhooks(): @@ -2264,7 +2215,9 @@ def setup_gateway(config: dict): missing_home.append("Slack") if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"): missing_home.append("BlueBubbles") - if get_env_value("QQ_APP_ID") and not get_env_value("QQ_HOME_CHANNEL"): + if get_env_value("QQ_APP_ID") and not ( + get_env_value("QQBOT_HOME_CHANNEL") or get_env_value("QQ_HOME_CHANNEL") + ): missing_home.append("QQBot") if missing_home: diff --git a/hermes_cli/status.py b/hermes_cli/status.py index 2e34ae9c36..bc3290d56e 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -317,7 +317,7 @@ def show_status(args): "WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None), "Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"), "BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"), - "QQBot": ("QQ_APP_ID", "QQ_HOME_CHANNEL"), + "QQBot": ("QQ_APP_ID", "QQBOT_HOME_CHANNEL"), } for name, (token_var, home_var) in platforms.items(): @@ -327,6 +327,9 @@ def show_status(args): home_channel = "" if home_var: home_channel = os.getenv(home_var, "") + # Back-compat: QQBot home channel was renamed from QQ_HOME_CHANNEL to QQBOT_HOME_CHANNEL + if not home_channel and home_var == "QQBOT_HOME_CHANNEL": + home_channel = os.getenv("QQ_HOME_CHANNEL", "") status = "configured" if 
has_token else "not configured" if home_channel: diff --git a/run_agent.py b/run_agent.py index 64572001b3..bb8cfa459d 100644 --- a/run_agent.py +++ b/run_agent.py @@ -7208,14 +7208,22 @@ class AIAgent: # Use auxiliary client for the flush call when available -- # it's cheaper and avoids Codex Responses API incompatibility. - from agent.auxiliary_client import call_llm as _call_llm + from agent.auxiliary_client import ( + call_llm as _call_llm, + _fixed_temperature_for_model, + ) _aux_available = True + # Use the fixed-temperature override (e.g. kimi-for-coding → 0.6) if + # the model has a strict contract; otherwise the historical 0.3 default. + _flush_temperature = _fixed_temperature_for_model(self.model) + if _flush_temperature is None: + _flush_temperature = 0.3 try: response = _call_llm( task="flush_memories", messages=api_messages, tools=[memory_tool_def], - temperature=0.3, + temperature=_flush_temperature, max_tokens=5120, # timeout resolved from auxiliary.flush_memories.timeout config ) @@ -7227,7 +7235,7 @@ class AIAgent: # No auxiliary client -- use the Codex Responses path directly codex_kwargs = self._build_api_kwargs(api_messages) codex_kwargs["tools"] = self._responses_tools([memory_tool_def]) - codex_kwargs["temperature"] = 0.3 + codex_kwargs["temperature"] = _flush_temperature if "max_output_tokens" in codex_kwargs: codex_kwargs["max_output_tokens"] = 5120 response = self._run_codex_stream(codex_kwargs) @@ -7246,7 +7254,7 @@ class AIAgent: "model": self.model, "messages": api_messages, "tools": [memory_tool_def], - "temperature": 0.3, + "temperature": _flush_temperature, **self._max_tokens_param(5120), } from agent.auxiliary_client import _get_task_timeout diff --git a/scripts/release.py b/scripts/release.py index 028f75ba64..c6d906436b 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -256,6 +256,8 @@ AUTHOR_MAP = { "anthhub@163.com": "anthhub", "shenuu@gmail.com": "shenuu", "xiayh17@gmail.com": "xiayh0107", + "asurla@nvidia.com": 
"anniesurla", + "limkuan24@gmail.com": "WideLee", } diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index 5d79f96dea..1778855ddd 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -696,6 +696,95 @@ class TestIsConnectionError: assert _is_connection_error(err) is False +class TestKimiForCodingTemperature: + """kimi-for-coding now requires temperature=0.6 exactly.""" + + def test_build_call_kwargs_forces_fixed_temperature(self): + from agent.auxiliary_client import _build_call_kwargs + + kwargs = _build_call_kwargs( + provider="kimi-coding", + model="kimi-for-coding", + messages=[{"role": "user", "content": "hello"}], + temperature=0.3, + ) + + assert kwargs["temperature"] == 0.6 + + def test_build_call_kwargs_injects_temperature_when_missing(self): + from agent.auxiliary_client import _build_call_kwargs + + kwargs = _build_call_kwargs( + provider="kimi-coding", + model="kimi-for-coding", + messages=[{"role": "user", "content": "hello"}], + temperature=None, + ) + + assert kwargs["temperature"] == 0.6 + + def test_auto_routed_kimi_for_coding_sync_call_uses_fixed_temperature(self): + client = MagicMock() + client.base_url = "https://api.kimi.com/coding/v1" + response = MagicMock() + client.chat.completions.create.return_value = response + + with patch( + "agent.auxiliary_client._get_cached_client", + return_value=(client, "kimi-for-coding"), + ), patch( + "agent.auxiliary_client._resolve_task_provider_model", + return_value=("auto", "kimi-for-coding", None, None, None), + ): + result = call_llm( + task="session_search", + messages=[{"role": "user", "content": "hello"}], + temperature=0.1, + ) + + assert result is response + kwargs = client.chat.completions.create.call_args.kwargs + assert kwargs["model"] == "kimi-for-coding" + assert kwargs["temperature"] == 0.6 + + @pytest.mark.asyncio + async def test_auto_routed_kimi_for_coding_async_call_uses_fixed_temperature(self): + client = 
MagicMock() + client.base_url = "https://api.kimi.com/coding/v1" + response = MagicMock() + client.chat.completions.create = AsyncMock(return_value=response) + + with patch( + "agent.auxiliary_client._get_cached_client", + return_value=(client, "kimi-for-coding"), + ), patch( + "agent.auxiliary_client._resolve_task_provider_model", + return_value=("auto", "kimi-for-coding", None, None, None), + ): + result = await async_call_llm( + task="session_search", + messages=[{"role": "user", "content": "hello"}], + temperature=0.1, + ) + + assert result is response + kwargs = client.chat.completions.create.call_args.kwargs + assert kwargs["model"] == "kimi-for-coding" + assert kwargs["temperature"] == 0.6 + + def test_non_kimi_model_still_preserves_temperature(self): + from agent.auxiliary_client import _build_call_kwargs + + kwargs = _build_call_kwargs( + provider="kimi-coding", + model="kimi-k2.5", + messages=[{"role": "user", "content": "hello"}], + temperature=0.3, + ) + + assert kwargs["temperature"] == 0.3 + + # --------------------------------------------------------------------------- # async_call_llm payment / connection fallback (#7512 bug 2) # --------------------------------------------------------------------------- diff --git a/tests/agent/test_gemini_cloudcode.py b/tests/agent/test_gemini_cloudcode.py index cf5e80f08a..c9d2b87df8 100644 --- a/tests/agent/test_gemini_cloudcode.py +++ b/tests/agent/test_gemini_cloudcode.py @@ -826,6 +826,160 @@ class TestGeminiCloudCodeClient: finally: client.close() + +class TestGeminiHttpErrorParsing: + """Regression coverage for _gemini_http_error Google-envelope parsing. + + These are the paths that users actually hit during Google-side throttling + (April 2026: gemini-2.5-pro MODEL_CAPACITY_EXHAUSTED, gemma-4-26b-it + returning 404). The error needs to carry status_code + response so the + main loop's error_classifier and Retry-After logic work. 
+ """ + + @staticmethod + def _fake_response(status: int, body: dict | str = "", headers=None): + """Minimal httpx.Response stand-in (duck-typed for _gemini_http_error).""" + class _FakeResponse: + def __init__(self): + self.status_code = status + if isinstance(body, dict): + self.text = json.dumps(body) + else: + self.text = body + self.headers = headers or {} + return _FakeResponse() + + def test_model_capacity_exhausted_produces_friendly_message(self): + from agent.gemini_cloudcode_adapter import _gemini_http_error + + body = { + "error": { + "code": 429, + "message": "Resource has been exhausted (e.g. check quota).", + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "MODEL_CAPACITY_EXHAUSTED", + "domain": "googleapis.com", + "metadata": {"model": "gemini-2.5-pro"}, + }, + { + "@type": "type.googleapis.com/google.rpc.RetryInfo", + "retryDelay": "30s", + }, + ], + } + } + err = _gemini_http_error(self._fake_response(429, body)) + assert err.status_code == 429 + assert err.code == "code_assist_capacity_exhausted" + assert err.retry_after == 30.0 + assert err.details["reason"] == "MODEL_CAPACITY_EXHAUSTED" + # Message must be user-friendly, not a raw JSON dump. + message = str(err) + assert "gemini-2.5-pro" in message + assert "capacity exhausted" in message.lower() + assert "30s" in message + # response attr is preserved for run_agent's Retry-After header path. 
+ assert err.response is not None + + def test_resource_exhausted_without_reason(self): + from agent.gemini_cloudcode_adapter import _gemini_http_error + + body = { + "error": { + "code": 429, + "message": "Quota exceeded for requests per minute.", + "status": "RESOURCE_EXHAUSTED", + } + } + err = _gemini_http_error(self._fake_response(429, body)) + assert err.status_code == 429 + assert err.code == "code_assist_rate_limited" + message = str(err) + assert "quota" in message.lower() + + def test_404_model_not_found_produces_model_retired_message(self): + from agent.gemini_cloudcode_adapter import _gemini_http_error + + body = { + "error": { + "code": 404, + "message": "models/gemma-4-26b-it is not found for API version v1internal", + "status": "NOT_FOUND", + } + } + err = _gemini_http_error(self._fake_response(404, body)) + assert err.status_code == 404 + message = str(err) + assert "not available" in message.lower() or "retired" in message.lower() + # Error message should reference the actual model text from Google. 
+ assert "gemma-4-26b-it" in message + + def test_unauthorized_preserves_status_code(self): + from agent.gemini_cloudcode_adapter import _gemini_http_error + + err = _gemini_http_error(self._fake_response( + 401, {"error": {"code": 401, "message": "Invalid token", "status": "UNAUTHENTICATED"}}, + )) + assert err.status_code == 401 + assert err.code == "code_assist_unauthorized" + + def test_retry_after_header_fallback(self): + """If the body has no RetryInfo detail, fall back to Retry-After header.""" + from agent.gemini_cloudcode_adapter import _gemini_http_error + + resp = self._fake_response( + 429, + {"error": {"code": 429, "message": "Rate limited", "status": "RESOURCE_EXHAUSTED"}}, + headers={"Retry-After": "45"}, + ) + err = _gemini_http_error(resp) + assert err.retry_after == 45.0 + + def test_malformed_body_still_produces_structured_error(self): + """Non-JSON body must not swallow status_code — we still want the classifier path.""" + from agent.gemini_cloudcode_adapter import _gemini_http_error + + err = _gemini_http_error(self._fake_response(500, "internal error")) + assert err.status_code == 500 + # Raw body snippet must still be there for debugging. + assert "500" in str(err) + + def test_status_code_flows_through_error_classifier(self): + """End-to-end: CodeAssistError from a 429 must classify as rate_limit. + + This is the whole point of adding status_code to CodeAssistError — + _extract_status_code must see it and FailoverReason.rate_limit must + fire, so the main loop triggers fallback_providers. 
+ """ + from agent.gemini_cloudcode_adapter import _gemini_http_error + from agent.error_classifier import classify_api_error, FailoverReason + + body = { + "error": { + "code": 429, + "message": "Resource has been exhausted", + "status": "RESOURCE_EXHAUSTED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "MODEL_CAPACITY_EXHAUSTED", + "metadata": {"model": "gemini-2.5-pro"}, + } + ], + } + } + err = _gemini_http_error(self._fake_response(429, body)) + + classified = classify_api_error( + err, provider="google-gemini-cli", model="gemini-2.5-pro", + ) + assert classified.status_code == 429 + assert classified.reason == FailoverReason.rate_limit + + # ============================================================================= # Provider registration # ============================================================================= diff --git a/tests/conftest.py b/tests/conftest.py index 27950118e1..c5b367266e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -229,6 +229,15 @@ def _hermetic_environment(tmp_path, monkeypatch): monkeypatch.setenv("LC_ALL", "C.UTF-8") monkeypatch.setenv("PYTHONHASHSEED", "0") + # 4b. Disable AWS IMDS lookups. Without this, any test that ends up + # calling has_aws_credentials() / resolve_aws_auth_env_var() + # (e.g. provider auto-detect, status command, cron run_job) burns + # ~2s waiting for the metadata service at 169.254.169.254 to time + # out. Tests don't run on EC2 — IMDS is always unreachable here. + monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true") + monkeypatch.setenv("AWS_METADATA_SERVICE_TIMEOUT", "1") + monkeypatch.setenv("AWS_METADATA_SERVICE_NUM_ATTEMPTS", "1") + # 5. Reset plugin singleton so tests don't leak plugins from # ~/.hermes/plugins/ (which, per step 3, is now empty — but the # singleton might still be cached from a previous test). 
diff --git a/tests/gateway/test_qqbot.py b/tests/gateway/test_qqbot.py index 18b1b59b75..a5aeb62516 100644 --- a/tests/gateway/test_qqbot.py +++ b/tests/gateway/test_qqbot.py @@ -179,7 +179,7 @@ class TestVoiceAttachmentSSRFProtection: from gateway.platforms.qqbot import QQAdapter, _ssrf_redirect_guard client = mock.AsyncMock() - with mock.patch("gateway.platforms.qqbot.httpx.AsyncClient", return_value=client) as async_client_cls: + with mock.patch("gateway.platforms.qqbot.adapter.httpx.AsyncClient", return_value=client) as async_client_cls: adapter = QQAdapter(_make_config(app_id="a", client_secret="b")) adapter._ensure_token = mock.AsyncMock(side_effect=RuntimeError("stop after client creation")) diff --git a/tests/gateway/test_session_state_cleanup.py b/tests/gateway/test_session_state_cleanup.py new file mode 100644 index 0000000000..3c708736c3 --- /dev/null +++ b/tests/gateway/test_session_state_cleanup.py @@ -0,0 +1,231 @@ +"""Regression tests for _release_running_agent_state and SessionDB shutdown. + +Before this change, running-agent state lived in three dicts that drifted +out of sync: + + self._running_agents — AIAgent instance per session key + self._running_agents_ts — start timestamp per session key + self._busy_ack_ts — last busy-ack timestamp per session key + +Six cleanup sites did ``del self._running_agents[key]`` without touching +the other two; one site only popped ``_running_agents`` and +``_running_agents_ts``; and only the stale-eviction site cleaned all +three. Each missed entry was a small persistent leak. + +Also: SessionDB connections were never closed on gateway shutdown, +leaving WAL locks in place until Python actually exited. 
+""" + +import threading +from unittest.mock import MagicMock + +import pytest + + +def _make_runner(): + """Bare GatewayRunner wired with just the state the helper touches.""" + from gateway.run import GatewayRunner + + runner = GatewayRunner.__new__(GatewayRunner) + runner._running_agents = {} + runner._running_agents_ts = {} + runner._busy_ack_ts = {} + return runner + + +class TestReleaseRunningAgentStateUnit: + def test_pops_all_three_dicts(self): + runner = _make_runner() + runner._running_agents["k"] = MagicMock() + runner._running_agents_ts["k"] = 123.0 + runner._busy_ack_ts["k"] = 456.0 + + runner._release_running_agent_state("k") + + assert "k" not in runner._running_agents + assert "k" not in runner._running_agents_ts + assert "k" not in runner._busy_ack_ts + + def test_idempotent_on_missing_key(self): + """Calling twice (or on an absent key) must not raise.""" + runner = _make_runner() + runner._release_running_agent_state("missing") + runner._release_running_agent_state("missing") # still fine + + def test_noop_on_empty_session_key(self): + """Empty string / None key is treated as a no-op.""" + runner = _make_runner() + runner._running_agents[""] = "guard" + runner._release_running_agent_state("") + # Empty key not processed — guard value survives. 
+ assert runner._running_agents[""] == "guard" + + def test_preserves_other_sessions(self): + runner = _make_runner() + for k in ("a", "b", "c"): + runner._running_agents[k] = MagicMock() + runner._running_agents_ts[k] = 1.0 + runner._busy_ack_ts[k] = 1.0 + + runner._release_running_agent_state("b") + + assert set(runner._running_agents.keys()) == {"a", "c"} + assert set(runner._running_agents_ts.keys()) == {"a", "c"} + assert set(runner._busy_ack_ts.keys()) == {"a", "c"} + + def test_handles_missing_busy_ack_attribute(self): + """Backward-compatible with older runners lacking _busy_ack_ts.""" + runner = _make_runner() + del runner._busy_ack_ts # simulate older version + runner._running_agents["k"] = MagicMock() + runner._running_agents_ts["k"] = 1.0 + + runner._release_running_agent_state("k") # should not raise + + assert "k" not in runner._running_agents + assert "k" not in runner._running_agents_ts + + def test_concurrent_release_is_safe(self): + """Multiple threads releasing different keys concurrently.""" + runner = _make_runner() + for i in range(50): + k = f"s{i}" + runner._running_agents[k] = MagicMock() + runner._running_agents_ts[k] = float(i) + runner._busy_ack_ts[k] = float(i) + + def worker(keys): + for k in keys: + runner._release_running_agent_state(k) + + threads = [ + threading.Thread(target=worker, args=([f"s{i}" for i in range(start, 50, 5)],)) + for start in range(5) + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + assert not t.is_alive() + + assert runner._running_agents == {} + assert runner._running_agents_ts == {} + assert runner._busy_ack_ts == {} + + +class TestNoMoreBareDeleteSites: + """Regression: all bare `del self._running_agents[key]` sites were + converted to use the helper. If a future contributor reverts one, + this test flags it. Docstrings / comments mentioning the old + pattern are allowed. 
+ """ + + def test_no_bare_del_of_running_agents_in_gateway_run(self): + from pathlib import Path + import re + + gateway_run = (Path(__file__).parent.parent.parent / "gateway" / "run.py").read_text() + # Match `del self._running_agents[...]` that is NOT inside a + # triple-quoted docstring. We scan non-docstring lines only. + lines = gateway_run.splitlines() + + in_docstring = False + docstring_delim = None + offenders = [] + for idx, line in enumerate(lines, start=1): + stripped = line.strip() + if not in_docstring: + if stripped.startswith('"""') or stripped.startswith("'''"): + delim = stripped[:3] + # single-line docstring? + if stripped.count(delim) >= 2: + continue + in_docstring = True + docstring_delim = delim + continue + if re.search(r"\bdel\s+self\._running_agents\[", line): + offenders.append((idx, line.rstrip())) + else: + if docstring_delim and docstring_delim in stripped: + in_docstring = False + docstring_delim = None + + assert offenders == [], ( + "Found bare `del self._running_agents[...]` sites in gateway/run.py. " + "Use self._release_running_agent_state(session_key) instead so " + "_running_agents_ts and _busy_ack_ts are popped in lockstep.\n" + + "\n".join(f" line {n}: {l}" for n, l in offenders) + ) + + +class TestSessionDbCloseOnShutdown: + """_stop_impl should call .close() on both self._session_db and + self.session_store._db to release SQLite WAL locks before the new + gateway (during --replace restart) tries to open the same file. + """ + + def test_stop_impl_closes_both_session_dbs(self): + """Run the exact shutdown block that closes SessionDBs and verify + .close() was called on both holders.""" + from gateway.run import GatewayRunner + + runner = GatewayRunner.__new__(GatewayRunner) + + runner_db = MagicMock() + store_db = MagicMock() + + runner._db = runner_db + runner.session_store = MagicMock() + runner.session_store._db = store_db + + # Replicate the exact production loop from _stop_impl. 
+ for _db_holder in (runner, getattr(runner, "session_store", None)): + _db = getattr(_db_holder, "_db", None) if _db_holder else None + if _db is None or not hasattr(_db, "close"): + continue + _db.close() + + runner_db.close.assert_called_once() + store_db.close.assert_called_once() + + def test_shutdown_tolerates_missing_session_store(self): + """Gateway without a session_store attribute must not crash on shutdown.""" + from gateway.run import GatewayRunner + + runner = GatewayRunner.__new__(GatewayRunner) + runner._db = MagicMock() + # Deliberately no session_store attribute. + + for _db_holder in (runner, getattr(runner, "session_store", None)): + _db = getattr(_db_holder, "_db", None) if _db_holder else None + if _db is None or not hasattr(_db, "close"): + continue + _db.close() + + runner._db.close.assert_called_once() + + def test_shutdown_tolerates_close_raising(self): + """A close() that raises must not prevent subsequent cleanup.""" + from gateway.run import GatewayRunner + + runner = GatewayRunner.__new__(GatewayRunner) + flaky_db = MagicMock() + flaky_db.close.side_effect = RuntimeError("simulated lock error") + healthy_db = MagicMock() + + runner._db = flaky_db + runner.session_store = MagicMock() + runner.session_store._db = healthy_db + + # Same pattern as production: try/except around each close(). + for _db_holder in (runner, getattr(runner, "session_store", None)): + _db = getattr(_db_holder, "_db", None) if _db_holder else None + if _db is None or not hasattr(_db, "close"): + continue + try: + _db.close() + except Exception: + pass + + flaky_db.close.assert_called_once() + healthy_db.close.assert_called_once() diff --git a/tests/gateway/test_session_store_prune.py b/tests/gateway/test_session_store_prune.py new file mode 100644 index 0000000000..9b1dca2971 --- /dev/null +++ b/tests/gateway/test_session_store_prune.py @@ -0,0 +1,270 @@ +"""Tests for SessionStore.prune_old_entries and the gateway watcher that calls it. 
+ +The SessionStore in-memory dict (and its backing sessions.json) grew +unbounded — every unique (platform, chat_id, thread_id, user_id) tuple +ever seen was kept forever, regardless of how stale it became. These +tests pin the prune behaviour: + + * Entries older than max_age_days (by updated_at) are removed + * Entries marked ``suspended`` are preserved (user-paused) + * Entries with an active process attached are preserved + * max_age_days <= 0 disables pruning entirely + * sessions.json is rewritten with the post-prune dict + * The ``updated_at`` field — not ``created_at`` — drives the decision + (so a long-running-but-still-active session isn't pruned) +""" + +import json +import threading +from datetime import datetime, timedelta +from unittest.mock import patch + +import pytest + +from gateway.config import GatewayConfig, Platform, SessionResetPolicy +from gateway.session import SessionEntry, SessionStore + + +def _make_store(tmp_path, max_age_days: int = 90, has_active_processes_fn=None): + """Build a SessionStore bypassing SQLite/disk-load side effects.""" + config = GatewayConfig( + default_reset_policy=SessionResetPolicy(mode="none"), + session_store_max_age_days=max_age_days, + ) + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore( + sessions_dir=tmp_path, + config=config, + has_active_processes_fn=has_active_processes_fn, + ) + store._db = None + store._loaded = True + return store + + +def _entry(key: str, age_days: float, *, suspended: bool = False, + session_id: str | None = None) -> SessionEntry: + now = datetime.now() + return SessionEntry( + session_key=key, + session_id=session_id or f"sid_{key}", + created_at=now - timedelta(days=age_days + 30), # arbitrary older + updated_at=now - timedelta(days=age_days), + platform=Platform.TELEGRAM, + chat_type="dm", + suspended=suspended, + ) + + +class TestPruneBasics: + def test_prune_removes_entries_past_max_age(self, tmp_path): + store = _make_store(tmp_path) + 
store._entries["old"] = _entry("old", age_days=100) + store._entries["fresh"] = _entry("fresh", age_days=5) + + removed = store.prune_old_entries(max_age_days=90) + + assert removed == 1 + assert "old" not in store._entries + assert "fresh" in store._entries + + def test_prune_uses_updated_at_not_created_at(self, tmp_path): + """A session created long ago but updated recently must be kept.""" + store = _make_store(tmp_path) + now = datetime.now() + entry = SessionEntry( + session_key="long-lived", + session_id="sid", + created_at=now - timedelta(days=365), # ancient + updated_at=now - timedelta(days=3), # but just chatted + platform=Platform.TELEGRAM, + chat_type="dm", + ) + store._entries["long-lived"] = entry + + removed = store.prune_old_entries(max_age_days=30) + + assert removed == 0 + assert "long-lived" in store._entries + + def test_prune_disabled_when_max_age_is_zero(self, tmp_path): + store = _make_store(tmp_path, max_age_days=0) + for i in range(5): + store._entries[f"s{i}"] = _entry(f"s{i}", age_days=365) + + assert store.prune_old_entries(0) == 0 + assert len(store._entries) == 5 + + def test_prune_disabled_when_max_age_is_negative(self, tmp_path): + store = _make_store(tmp_path) + store._entries["s"] = _entry("s", age_days=365) + + assert store.prune_old_entries(-1) == 0 + assert "s" in store._entries + + def test_prune_skips_suspended_entries(self, tmp_path): + """/stop-suspended sessions must be kept for later resume.""" + store = _make_store(tmp_path) + store._entries["suspended"] = _entry( + "suspended", age_days=1000, suspended=True + ) + store._entries["idle"] = _entry("idle", age_days=1000) + + removed = store.prune_old_entries(max_age_days=90) + + assert removed == 1 + assert "suspended" in store._entries + assert "idle" not in store._entries + + def test_prune_skips_entries_with_active_processes(self, tmp_path): + """Sessions with active bg processes aren't pruned even if old.""" + active_session_ids = {"sid_active"} + + def 
_has_active(session_id: str) -> bool: + return session_id in active_session_ids + + store = _make_store(tmp_path, has_active_processes_fn=_has_active) + store._entries["active"] = _entry( + "active", age_days=1000, session_id="sid_active" + ) + store._entries["idle"] = _entry( + "idle", age_days=1000, session_id="sid_idle" + ) + + removed = store.prune_old_entries(max_age_days=90) + + assert removed == 1 + assert "active" in store._entries + assert "idle" not in store._entries + + def test_prune_does_not_write_disk_when_no_removals(self, tmp_path): + """If nothing is evictable, _save() should NOT be called.""" + store = _make_store(tmp_path) + store._entries["fresh1"] = _entry("fresh1", age_days=1) + store._entries["fresh2"] = _entry("fresh2", age_days=2) + + save_calls = [] + store._save = lambda: save_calls.append(1) + + assert store.prune_old_entries(max_age_days=90) == 0 + assert save_calls == [] + + def test_prune_writes_disk_after_removal(self, tmp_path): + store = _make_store(tmp_path) + store._entries["stale"] = _entry("stale", age_days=500) + store._entries["fresh"] = _entry("fresh", age_days=1) + + save_calls = [] + store._save = lambda: save_calls.append(1) + + store.prune_old_entries(max_age_days=90) + assert save_calls == [1] + + def test_prune_is_thread_safe(self, tmp_path): + """Prune acquires _lock internally; concurrent update_session is safe.""" + store = _make_store(tmp_path) + for i in range(20): + age = 1000 if i % 2 == 0 else 1 + store._entries[f"s{i}"] = _entry(f"s{i}", age_days=age) + + results = [] + + def _pruner(): + results.append(store.prune_old_entries(max_age_days=90)) + + def _reader(): + # Mimic a concurrent update_session reader iterating under lock. 
+ with store._lock: + list(store._entries.keys()) + + threads = [threading.Thread(target=_pruner)] + threads += [threading.Thread(target=_reader) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=5) + assert not t.is_alive() + + # Exactly one pruner ran; removed exactly the 10 stale entries. + assert results == [10] + assert len(store._entries) == 10 + for i in range(20): + if i % 2 == 1: # fresh + assert f"s{i}" in store._entries + + +class TestPrunePersistsToDisk: + def test_prune_rewrites_sessions_json(self, tmp_path): + """After prune, sessions.json on disk reflects the new dict.""" + config = GatewayConfig( + default_reset_policy=SessionResetPolicy(mode="none"), + session_store_max_age_days=90, + ) + store = SessionStore(sessions_dir=tmp_path, config=config) + store._db = None + # Force-populate without calling get_or_create to avoid DB side-effects + store._entries["stale"] = _entry("stale", age_days=500) + store._entries["fresh"] = _entry("fresh", age_days=1) + store._loaded = True + store._save() + + # Verify pre-prune state on disk. + saved_pre = json.loads((tmp_path / "sessions.json").read_text()) + assert set(saved_pre.keys()) == {"stale", "fresh"} + + # Prune and check disk. 
+ store.prune_old_entries(max_age_days=90) + saved_post = json.loads((tmp_path / "sessions.json").read_text()) + assert set(saved_post.keys()) == {"fresh"} + + +class TestGatewayConfigSerialization: + def test_session_store_max_age_days_defaults_to_90(self): + cfg = GatewayConfig() + assert cfg.session_store_max_age_days == 90 + + def test_session_store_max_age_days_roundtrips(self): + cfg = GatewayConfig(session_store_max_age_days=30) + restored = GatewayConfig.from_dict(cfg.to_dict()) + assert restored.session_store_max_age_days == 30 + + def test_session_store_max_age_days_missing_defaults_90(self): + """Loading an old config (pre-this-field) falls back to default.""" + restored = GatewayConfig.from_dict({}) + assert restored.session_store_max_age_days == 90 + + def test_session_store_max_age_days_negative_coerced_to_zero(self): + """A negative value (accidental or hostile) becomes 0 (disabled).""" + restored = GatewayConfig.from_dict({"session_store_max_age_days": -5}) + assert restored.session_store_max_age_days == 0 + + def test_session_store_max_age_days_bad_type_falls_back(self): + """Non-int values fall back to the default, not a crash.""" + restored = GatewayConfig.from_dict({"session_store_max_age_days": "nope"}) + assert restored.session_store_max_age_days == 90 + + +class TestGatewayWatcherCallsPrune: + """The session_expiry_watcher should call prune_old_entries once per hour.""" + + def test_prune_gate_fires_on_first_tick(self): + """First watcher tick has _last_prune_ts=0, so the gate opens.""" + import time as _t + + last_ts = 0.0 + prune_interval = 3600.0 + now = _t.time() + + # Mirror the production gate check in _session_expiry_watcher. 
+ should_prune = (now - last_ts) > prune_interval + assert should_prune is True + + def test_prune_gate_suppresses_within_interval(self): + import time as _t + + last_ts = _t.time() - 600 # 10 minutes ago + prune_interval = 3600.0 + now = _t.time() + + should_prune = (now - last_ts) > prune_interval + assert should_prune is False diff --git a/tests/gateway/test_telegram_format.py b/tests/gateway/test_telegram_format.py index 1bd889b7c8..ce7e02a474 100644 --- a/tests/gateway/test_telegram_format.py +++ b/tests/gateway/test_telegram_format.py @@ -34,7 +34,12 @@ def _ensure_telegram_mock(): _ensure_telegram_mock() -from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2 # noqa: E402 +from gateway.platforms.telegram import ( # noqa: E402 + TelegramAdapter, + _escape_mdv2, + _strip_mdv2, + _wrap_markdown_tables, +) # --------------------------------------------------------------------------- @@ -535,6 +540,152 @@ class TestStripMdv2: assert _strip_mdv2("||hidden text||") == "hidden text" +# ========================================================================= +# Markdown table auto-wrap +# ========================================================================= + + +class TestWrapMarkdownTables: + """_wrap_markdown_tables wraps GFM pipe tables in ``` fences so + Telegram renders them as monospace preformatted text instead of the + noisy backslash-pipe mess MarkdownV2 produces.""" + + def test_basic_table_wrapped(self): + text = ( + "Scores:\n\n" + "| Player | Score |\n" + "|--------|-------|\n" + "| Alice | 150 |\n" + "| Bob | 120 |\n" + "\nEnd." 
+ ) + out = _wrap_markdown_tables(text) + # Table is now wrapped in a fence + assert "```\n| Player | Score |" in out + assert "| Bob | 120 |\n```" in out + # Surrounding prose is preserved + assert out.startswith("Scores:") + assert out.endswith("End.") + + def test_bare_pipe_table_wrapped(self): + """Tables without outer pipes (GFM allows this) are still detected.""" + text = "head1 | head2\n--- | ---\na | b\nc | d" + out = _wrap_markdown_tables(text) + assert out.startswith("```\n") + assert out.rstrip().endswith("```") + assert "head1 | head2" in out + + def test_alignment_separators(self): + """Separator rows with :--- / ---: / :---: alignment markers match.""" + text = ( + "| Name | Age | City |\n" + "|:-----|----:|:----:|\n" + "| Ada | 30 | NYC |" + ) + out = _wrap_markdown_tables(text) + assert out.count("```") == 2 + + def test_two_consecutive_tables_wrapped_separately(self): + text = ( + "| A | B |\n" + "|---|---|\n" + "| 1 | 2 |\n" + "\n" + "| X | Y |\n" + "|---|---|\n" + "| 9 | 8 |" + ) + out = _wrap_markdown_tables(text) + # Four fences total — one opening + closing per table + assert out.count("```") == 4 + + def test_plain_text_with_pipes_not_wrapped(self): + """A bare pipe in prose must NOT trigger wrapping.""" + text = "Use the | pipe operator to chain commands." + assert _wrap_markdown_tables(text) == text + + def test_horizontal_rule_not_wrapped(self): + """A lone '---' horizontal rule must not be mistaken for a separator.""" + text = "Section A\n\n---\n\nSection B" + assert _wrap_markdown_tables(text) == text + + def test_existing_code_block_with_pipes_left_alone(self): + """A table already inside a fenced code block must not be re-wrapped.""" + text = ( + "```\n" + "| a | b |\n" + "|---|---|\n" + "| 1 | 2 |\n" + "```" + ) + assert _wrap_markdown_tables(text) == text + + def test_no_pipe_character_short_circuits(self): + text = "Plain **bold** text with no table." 
+ assert _wrap_markdown_tables(text) == text + + def test_no_dash_short_circuits(self): + text = "a | b\nc | d" # has pipes but no '-' separator row + assert _wrap_markdown_tables(text) == text + + def test_single_column_separator_not_matched(self): + """Single-column tables (rare) are not detected — we require at + least one internal pipe in the separator row to avoid false + positives on formatting rules.""" + text = "| a |\n| - |\n| b |" + assert _wrap_markdown_tables(text) == text + + +class TestFormatMessageTables: + """End-to-end: a pipe table passes through format_message with its + pipes and dashes left alone inside the fence, not mangled by MarkdownV2 + escaping.""" + + def test_table_rendered_as_code_block(self, adapter): + text = ( + "Data:\n\n" + "| Col1 | Col2 |\n" + "|------|------|\n" + "| A | B |\n" + ) + out = adapter.format_message(text) + # Pipes inside the fenced block are NOT escaped + assert "```\n| Col1 | Col2 |" in out + assert "\\|" not in out.split("```")[1] + # Dashes in separator not escaped inside fence + assert "\\-" not in out.split("```")[1] + + def test_text_after_table_still_formatted(self, adapter): + text = ( + "| A | B |\n" + "|---|---|\n" + "| 1 | 2 |\n" + "\n" + "Nice **work** team!" + ) + out = adapter.format_message(text) + # MarkdownV2 bold conversion still happens outside the table + assert "*work*" in out + # Exclamation outside fence is escaped + assert "\\!" 
in out + + def test_multiple_tables_in_single_message(self, adapter): + text = ( + "First:\n" + "| A | B |\n" + "|---|---|\n" + "| 1 | 2 |\n" + "\n" + "Second:\n" + "| X | Y |\n" + "|---|---|\n" + "| 9 | 8 |\n" + ) + out = adapter.format_message(text) + # Two separate fenced blocks in the output + assert out.count("```") == 4 + + @pytest.mark.asyncio async def test_send_escapes_chunk_indicator_for_markdownv2(adapter): adapter.MAX_MESSAGE_LENGTH = 80 diff --git a/tests/hermes_cli/test_api_key_providers.py b/tests/hermes_cli/test_api_key_providers.py index 97deab89e4..c56edc4bb2 100644 --- a/tests/hermes_cli/test_api_key_providers.py +++ b/tests/hermes_cli/test_api_key_providers.py @@ -33,6 +33,7 @@ class TestProviderRegistry: ("huggingface", "Hugging Face", "api_key"), ("zai", "Z.AI / GLM", "api_key"), ("xai", "xAI", "api_key"), + ("nvidia", "NVIDIA NIM", "api_key"), ("kimi-coding", "Kimi / Moonshot", "api_key"), ("minimax", "MiniMax", "api_key"), ("minimax-cn", "MiniMax (China)", "api_key"), @@ -57,6 +58,12 @@ class TestProviderRegistry: assert pconfig.base_url_env_var == "XAI_BASE_URL" assert pconfig.inference_base_url == "https://api.x.ai/v1" + def test_nvidia_env_vars(self): + pconfig = PROVIDER_REGISTRY["nvidia"] + assert pconfig.api_key_env_vars == ("NVIDIA_API_KEY",) + assert pconfig.base_url_env_var == "NVIDIA_BASE_URL" + assert pconfig.inference_base_url == "https://integrate.api.nvidia.com/v1" + def test_copilot_env_vars(self): pconfig = PROVIDER_REGISTRY["copilot"] assert pconfig.api_key_env_vars == ("COPILOT_GITHUB_TOKEN", "GH_TOKEN", "GITHUB_TOKEN") diff --git a/tests/hermes_cli/test_gemini_provider.py b/tests/hermes_cli/test_gemini_provider.py index b448ca513f..089a5cf98d 100644 --- a/tests/hermes_cli/test_gemini_provider.py +++ b/tests/hermes_cli/test_gemini_provider.py @@ -178,10 +178,6 @@ class TestGeminiContextLength: ctx = get_model_context_length("gemma-4-31b-it", provider="gemini") assert ctx == 256000 - def test_gemma_4_26b_context(self): - ctx 
= get_model_context_length("gemma-4-26b-it", provider="gemini") - assert ctx == 256000 - def test_gemini_3_context(self): ctx = get_model_context_length("gemini-3.1-pro-preview", provider="gemini") assert ctx == 1048576 diff --git a/tests/hermes_cli/test_update_gateway_restart.py b/tests/hermes_cli/test_update_gateway_restart.py index f3f2a0444a..6e10d56222 100644 --- a/tests/hermes_cli/test_update_gateway_restart.py +++ b/tests/hermes_cli/test_update_gateway_restart.py @@ -13,9 +13,29 @@ from unittest.mock import patch, MagicMock import pytest import hermes_cli.gateway as gateway_cli +import hermes_cli.main as cli_main from hermes_cli.main import cmd_update +# --------------------------------------------------------------------------- +# Skip the real-time sleeps inside cmd_update's restart-verification path +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _no_restart_verify_sleep(monkeypatch): + """hermes_cli/main.py uses time.sleep(3) after systemctl restart to + verify the service survived. Tests mock subprocess.run — nothing + actually restarts — so the 3s wait is dead time. + + main.py does ``import time as _time`` at both module level (line 167) + and inside functions (lines 3281, 4384, 4401). Patching the global + ``time.sleep`` affects only the duration of this test. 
+ """ + import time as _real_time + monkeypatch.setattr(_real_time, "sleep", lambda *_a, **_k: None) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/plugins/test_retaindb_plugin.py b/tests/plugins/test_retaindb_plugin.py index 9ad801769b..5d517bce77 100644 --- a/tests/plugins/test_retaindb_plugin.py +++ b/tests/plugins/test_retaindb_plugin.py @@ -31,6 +31,31 @@ def _isolate_env(tmp_path, monkeypatch): monkeypatch.delenv("RETAINDB_PROJECT", raising=False) +@pytest.fixture(autouse=True) +def _cap_retaindb_sleeps(monkeypatch): + """Cap production-code sleeps so background-thread tests run fast. + + The retaindb ``_WriteQueue._flush_row`` does ``time.sleep(2)`` after + errors. Across multiple tests that trigger the retry path, that adds + up. Cap the module's bound ``time.sleep`` to 0.05s — tests don't care + about the exact retry delay, only that it happens. The test file's + own ``time.sleep`` stays real since it uses a different reference. 
+ """ + try: + from plugins.memory import retaindb as _retaindb + except ImportError: + return + + real_sleep = _retaindb.time.sleep + + def _capped_sleep(seconds): + return real_sleep(min(float(seconds), 0.05)) + + import types as _types + fake_time = _types.SimpleNamespace(sleep=_capped_sleep, time=_retaindb.time.time) + monkeypatch.setattr(_retaindb, "time", fake_time) + + # We need the repo root on sys.path so the plugin can import agent.memory_provider import sys _repo_root = str(Path(__file__).resolve().parents[2]) @@ -130,16 +155,18 @@ class TestWriteQueue: def test_enqueue_creates_row(self, tmp_path): q, client, db_path = self._make_queue(tmp_path) q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}]) - # Give the writer thread a moment to process - time.sleep(1) + # shutdown() blocks until the writer thread drains the queue — no need + # to pre-sleep (the old 1s sleep was a just-in-case wait, but shutdown + # does the right thing). q.shutdown() # If ingest succeeded, the row should be deleted client.ingest_session.assert_called_once() def test_enqueue_persists_to_sqlite(self, tmp_path): client = MagicMock() - # Make ingest hang so the row stays in SQLite - client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(5)) + # Make ingest slow so the row is still in SQLite when we peek. + # 0.5s is plenty — the test just needs the flush to still be in-flight. 
+ client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(0.5)) db_path = tmp_path / "test_queue.db" q = _WriteQueue(client, db_path) q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}]) @@ -154,8 +181,7 @@ class TestWriteQueue: def test_flush_deletes_row_on_success(self, tmp_path): q, client, db_path = self._make_queue(tmp_path) q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}]) - time.sleep(1) - q.shutdown() + q.shutdown() # blocks until drain # Row should be gone conn = sqlite3.connect(str(db_path)) rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0] @@ -168,14 +194,20 @@ class TestWriteQueue: db_path = tmp_path / "test_queue.db" q = _WriteQueue(client, db_path) q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}]) - time.sleep(3) # Allow retry + sleep(2) in _flush_row + # Poll for the error to be recorded (max 2s), instead of a fixed 3s wait. + deadline = time.time() + 2.0 + last_error = None + while time.time() < deadline: + conn = sqlite3.connect(str(db_path)) + row = conn.execute("SELECT last_error FROM pending").fetchone() + conn.close() + if row and row[0]: + last_error = row[0] + break + time.sleep(0.05) q.shutdown() - # Row should still exist with error recorded - conn = sqlite3.connect(str(db_path)) - row = conn.execute("SELECT last_error FROM pending").fetchone() - conn.close() - assert row is not None - assert "API down" in row[0] + assert last_error is not None + assert "API down" in last_error def test_thread_local_connection_reuse(self, tmp_path): q, _, _ = self._make_queue(tmp_path) @@ -193,14 +225,27 @@ class TestWriteQueue: client1.ingest_session = MagicMock(side_effect=RuntimeError("fail")) q1 = _WriteQueue(client1, db_path) q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}]) - time.sleep(3) + # Wait until the error is recorded (poll with short interval). 
+ deadline = time.time() + 2.0 + while time.time() < deadline: + conn = sqlite3.connect(str(db_path)) + row = conn.execute("SELECT last_error FROM pending").fetchone() + conn.close() + if row and row[0]: + break + time.sleep(0.05) q1.shutdown() # Now create a new queue — it should replay the pending rows client2 = MagicMock() client2.ingest_session = MagicMock(return_value={"status": "ok"}) q2 = _WriteQueue(client2, db_path) - time.sleep(2) + # Poll for the replay to happen. + deadline = time.time() + 2.0 + while time.time() < deadline: + if client2.ingest_session.called: + break + time.sleep(0.05) q2.shutdown() # The replayed row should have been ingested via client2 diff --git a/tests/run_agent/conftest.py b/tests/run_agent/conftest.py new file mode 100644 index 0000000000..9b431869bf --- /dev/null +++ b/tests/run_agent/conftest.py @@ -0,0 +1,34 @@ +"""Fast-path fixtures shared across tests/run_agent/. + +Many tests in this directory exercise the retry/backoff paths in the +agent loop. Production code uses ``jittered_backoff(base_delay=5.0)`` +with a ``while time.time() < sleep_end`` loop — a single retry test +spends 5+ seconds of real wall-clock time on backoff waits. + +Mocking ``jittered_backoff`` to return 0.0 collapses the while-loop +to a no-op (``time.time() < time.time() + 0`` is false immediately), +which handles the most common case without touching ``time.sleep``. + +We deliberately DO NOT mock ``time.sleep`` here — some tests +(test_interrupt_propagation, test_primary_runtime_restore, etc.) use +the real ``time.sleep`` for threading coordination or assert that it +was called with specific values. Tests that want to additionally +fast-path direct ``time.sleep(N)`` calls in production code should +monkeypatch ``run_agent.time.sleep`` locally (see +``test_anthropic_error_handling.py`` for the pattern). 
+""" + +from __future__ import annotations + +import pytest + + +@pytest.fixture(autouse=True) +def _fast_retry_backoff(monkeypatch): + """Short-circuit retry backoff for all tests in this directory.""" + try: + import run_agent + except ImportError: + return + + monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0) diff --git a/tests/run_agent/test_413_compression.py b/tests/run_agent/test_413_compression.py index e8835c6412..8bd357d3d2 100644 --- a/tests/run_agent/test_413_compression.py +++ b/tests/run_agent/test_413_compression.py @@ -19,6 +19,24 @@ import pytest from agent.context_compressor import SUMMARY_PREFIX from run_agent import AIAgent +import run_agent + + +# --------------------------------------------------------------------------- +# Fast backoff for compression retry tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _no_compression_sleep(monkeypatch): + """Short-circuit the 2s time.sleep between compression retries. + + Production code has ``time.sleep(2)`` in multiple places after a 413/context + compression, for rate-limit smoothing. Tests assert behavior, not timing. 
+ """ + import time as _time + monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None) + monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0) # --------------------------------------------------------------------------- diff --git a/tests/run_agent/test_anthropic_error_handling.py b/tests/run_agent/test_anthropic_error_handling.py index 00055928e0..cdf3372544 100644 --- a/tests/run_agent/test_anthropic_error_handling.py +++ b/tests/run_agent/test_anthropic_error_handling.py @@ -27,6 +27,39 @@ from gateway.config import Platform from gateway.session import SessionSource +# --------------------------------------------------------------------------- +# Fast backoff for tests that exercise the retry loop +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _no_backoff_wait(monkeypatch): + """Short-circuit retry backoff so tests don't block on real wall-clock waits. + + The production code uses jittered_backoff() with a 5s base delay plus a + tight time.sleep(0.2) loop. Without this patch, each 429/500/529 retry + test burns ~10s of real time on CI — across six tests that's ~60s for + behavior we're not asserting against timing. + + Tests assert retry counts and final results, never wait durations. + """ + import asyncio as _asyncio + import time as _time + + monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0) + monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None) + + # Also fast-path asyncio.sleep — the gateway's _run_agent path has + # several await asyncio.sleep(...) calls that add real wall-clock time. + _real_asyncio_sleep = _asyncio.sleep + + async def _fast_sleep(delay=0, *args, **kwargs): + # Yield to the event loop but skip the actual delay. 
+ await _real_asyncio_sleep(0) + + monkeypatch.setattr(_asyncio, "sleep", _fast_sleep) + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- diff --git a/tests/run_agent/test_exit_cleanup_interrupt.py b/tests/run_agent/test_exit_cleanup_interrupt.py index 6a5d7b363a..1e5d8431c3 100644 --- a/tests/run_agent/test_exit_cleanup_interrupt.py +++ b/tests/run_agent/test_exit_cleanup_interrupt.py @@ -13,6 +13,24 @@ from unittest.mock import MagicMock, patch, call import pytest +@pytest.fixture(autouse=True) +def _mock_runtime_provider(monkeypatch): + """run_job calls resolve_runtime_provider which can try real network + auto-detection (~4s of socket timeouts in hermetic CI). Mock it out + since these tests don't care about provider resolution — the agent + is mocked too.""" + import hermes_cli.runtime_provider as rp + def _fake_resolve(*args, **kwargs): + return { + "provider": "openrouter", + "api_key": "test-key", + "base_url": "https://openrouter.ai/api/v1", + "model": "test/model", + "api_mode": "chat_completions", + } + monkeypatch.setattr(rp, "resolve_runtime_provider", _fake_resolve) + + class TestCronJobCleanup: """cron/scheduler.py — end_session + close in the finally block.""" diff --git a/tests/run_agent/test_fallback_model.py b/tests/run_agent/test_fallback_model.py index 6491bd686d..d2aec022ef 100644 --- a/tests/run_agent/test_fallback_model.py +++ b/tests/run_agent/test_fallback_model.py @@ -11,6 +11,16 @@ from unittest.mock import MagicMock, patch import pytest from run_agent import AIAgent +import run_agent + + +@pytest.fixture(autouse=True) +def _no_fallback_wait(monkeypatch): + """Short-circuit time.sleep in fallback/recovery paths so tests don't + block on the ``min(3 + retry_count, 8)`` wait before a primary retry.""" + import time as _time + monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None) + monkeypatch.setattr(run_agent, 
"jittered_backoff", lambda *a, **k: 0.0) def _make_tool_defs(*names: str) -> list: diff --git a/tests/run_agent/test_run_agent_codex_responses.py b/tests/run_agent/test_run_agent_codex_responses.py index 4ff00018d2..81213aaf67 100644 --- a/tests/run_agent/test_run_agent_codex_responses.py +++ b/tests/run_agent/test_run_agent_codex_responses.py @@ -12,6 +12,15 @@ sys.modules.setdefault("fal_client", types.SimpleNamespace()) import run_agent +@pytest.fixture(autouse=True) +def _no_codex_backoff(monkeypatch): + """Short-circuit retry backoff so Codex retry tests don't block on real + wall-clock waits (5s jittered_backoff base delay + tight time.sleep loop).""" + import time as _time + monkeypatch.setattr(run_agent, "jittered_backoff", lambda *a, **k: 0.0) + monkeypatch.setattr(_time, "sleep", lambda *_a, **_k: None) + + def _patch_agent_bootstrap(monkeypatch): monkeypatch.setattr( run_agent, diff --git a/tests/test_timezone.py b/tests/test_timezone.py index 1af60cbfa2..ffb831617d 100644 --- a/tests/test_timezone.py +++ b/tests/test_timezone.py @@ -159,18 +159,34 @@ class TestCodeExecutionTZ: return _json.dumps({"error": f"unexpected tool call: {function_name}"}) def test_tz_injected_when_configured(self): - """When HERMES_TIMEZONE is set, child process sees TZ env var.""" + """When HERMES_TIMEZONE is set, child process sees TZ env var. + + Verified alongside leak-prevention + empty-TZ handling in one + subprocess call so we don't pay 3x the subprocess startup cost + (each execute_code spawns a real Python subprocess ~3s). 
+ """ import json as _json os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata" + # One subprocess, three things checked: + # 1) TZ is injected as "Asia/Kolkata" + # 2) HERMES_TIMEZONE itself does NOT leak into the child env + probe = ( + 'import os; ' + 'print("TZ=" + os.environ.get("TZ", "NOT_SET")); ' + 'print("HERMES_TIMEZONE=" + os.environ.get("HERMES_TIMEZONE", "NOT_SET"))' + ) with patch("model_tools.handle_function_call", side_effect=self._mock_handle): result = _json.loads(self._execute_code( - code='import os; print(os.environ.get("TZ", "NOT_SET"))', - task_id="tz-test", + code=probe, + task_id="tz-combined-test", enabled_tools=[], )) assert result["status"] == "success" - assert "Asia/Kolkata" in result["output"] + assert "TZ=Asia/Kolkata" in result["output"] + assert "HERMES_TIMEZONE=NOT_SET" in result["output"], ( + "HERMES_TIMEZONE should not leak into child env (only TZ)" + ) def test_tz_not_injected_when_empty(self): """When HERMES_TIMEZONE is not set, child process has no TZ.""" @@ -186,20 +202,6 @@ class TestCodeExecutionTZ: assert result["status"] == "success" assert "NOT_SET" in result["output"] - def test_hermes_timezone_not_leaked_to_child(self): - """HERMES_TIMEZONE itself must NOT appear in child env (only TZ).""" - import json as _json - os.environ["HERMES_TIMEZONE"] = "Asia/Kolkata" - - with patch("model_tools.handle_function_call", side_effect=self._mock_handle): - result = _json.loads(self._execute_code( - code='import os; print(os.environ.get("HERMES_TIMEZONE", "NOT_SET"))', - task_id="tz-leak-test", - enabled_tools=[], - )) - assert result["status"] == "success" - assert "NOT_SET" in result["output"] - # ========================================================================= # Cron timezone-aware scheduling diff --git a/tests/tools/test_local_env_blocklist.py b/tests/tools/test_local_env_blocklist.py index b196cea781..0377d59b36 100644 --- a/tests/tools/test_local_env_blocklist.py +++ b/tests/tools/test_local_env_blocklist.py @@ -86,6 
+86,7 @@ class TestProviderEnvBlocklist: "MINIMAX_API_KEY": "mm-key", "MINIMAX_CN_API_KEY": "mmcn-key", "DEEPSEEK_API_KEY": "deepseek-key", + "NVIDIA_API_KEY": "nvidia-key", } result_env = _run_with_env(extra_os_env=registry_vars) diff --git a/website/docs/getting-started/quickstart.md b/website/docs/getting-started/quickstart.md index 1f721586c9..428d23b7ce 100644 --- a/website/docs/getting-started/quickstart.md +++ b/website/docs/getting-started/quickstart.md @@ -61,6 +61,7 @@ hermes setup # Or configure everything at once | **OpenCode Zen** | Pay-as-you-go access to curated models | Set `OPENCODE_ZEN_API_KEY` | | **OpenCode Go** | $10/month subscription for open models | Set `OPENCODE_GO_API_KEY` | | **DeepSeek** | Direct DeepSeek API access | Set `DEEPSEEK_API_KEY` | +| **NVIDIA NIM** | Nemotron models via build.nvidia.com or local NIM | Set `NVIDIA_API_KEY` (optional: `NVIDIA_BASE_URL`) | | **GitHub Copilot** | GitHub Copilot subscription (GPT-5.x, Claude, Gemini, etc.) | OAuth via `hermes model`, or `COPILOT_GITHUB_TOKEN` / `GH_TOKEN` | | **GitHub Copilot ACP** | Copilot ACP agent backend (spawns local `copilot` CLI) | `hermes model` (requires `copilot` CLI + `copilot login`) | | **Vercel AI Gateway** | Vercel AI Gateway routing | Set `AI_GATEWAY_API_KEY` | diff --git a/website/docs/integrations/providers.md b/website/docs/integrations/providers.md index e3d0ad8284..750ad671cd 100644 --- a/website/docs/integrations/providers.md +++ b/website/docs/integrations/providers.md @@ -295,6 +295,30 @@ When using xAI as a provider (any base URL containing `x.ai`), Hermes automatica No configuration is needed — caching activates automatically when an xAI endpoint is detected and a session ID is available. This reduces latency and cost for multi-turn conversations. +### NVIDIA NIM + +Nemotron and other open source models via [build.nvidia.com](https://build.nvidia.com) (free API key) or a local NIM endpoint. 
+ +```bash +# Cloud (build.nvidia.com) +hermes chat --provider nvidia --model nvidia/nemotron-3-super-120b-a12b +# Requires: NVIDIA_API_KEY in ~/.hermes/.env + +# Local NIM endpoint — override base URL +NVIDIA_BASE_URL=http://localhost:8000/v1 hermes chat --provider nvidia --model nvidia/nemotron-3-super-120b-a12b +``` + +Or set it permanently in `config.yaml`: +```yaml +model: + provider: "nvidia" + default: "nvidia/nemotron-3-super-120b-a12b" +``` + +:::tip Local NIM +For on-prem deployments (DGX Spark, local GPU), set `NVIDIA_BASE_URL=http://localhost:8000/v1`. NIM exposes the same OpenAI-compatible chat completions API as build.nvidia.com, so switching between cloud and local is a one-line env-var change. +::: + ### Hugging Face Inference Providers [Hugging Face Inference Providers](https://huggingface.co/docs/inference-providers) routes to 20+ open models through a unified OpenAI-compatible endpoint (`router.huggingface.co/v1`). Requests are automatically routed to the fastest available backend (Groq, Together, SambaNova, etc.) with automatic failover. diff --git a/website/docs/reference/environment-variables.md b/website/docs/reference/environment-variables.md index 6aa8197dbb..ead884ba7b 100644 --- a/website/docs/reference/environment-variables.md +++ b/website/docs/reference/environment-variables.md @@ -290,7 +290,7 @@ For cloud sandbox backends, persistence is filesystem-oriented. `TERMINAL_LIFETI | `QQ_ALLOWED_USERS` | Comma-separated QQ user openIDs allowed to message the bot | | `QQ_GROUP_ALLOWED_USERS` | Comma-separated QQ group IDs for group @-message access | | `QQ_ALLOW_ALL_USERS` | Allow all users (`true`/`false`, overrides `QQ_ALLOWED_USERS`) | -| `QQ_HOME_CHANNEL` | QQ user/group openID for cron delivery and notifications | +| `QQBOT_HOME_CHANNEL` | QQ user/group openID for cron delivery and notifications | | `MATTERMOST_URL` | Mattermost server URL (e.g. 
`https://mm.example.com`) | | `MATTERMOST_TOKEN` | Bot token or personal access token for Mattermost | | `MATTERMOST_ALLOWED_USERS` | Comma-separated Mattermost user IDs allowed to message the bot | diff --git a/website/docs/user-guide/features/fallback-providers.md b/website/docs/user-guide/features/fallback-providers.md index 1e2b2a8035..12fde185d4 100644 --- a/website/docs/user-guide/features/fallback-providers.md +++ b/website/docs/user-guide/features/fallback-providers.md @@ -47,6 +47,7 @@ Both `provider` and `model` are **required**. If either is missing, the fallback | MiniMax | `minimax` | `MINIMAX_API_KEY` | | MiniMax (China) | `minimax-cn` | `MINIMAX_CN_API_KEY` | | DeepSeek | `deepseek` | `DEEPSEEK_API_KEY` | +| NVIDIA NIM | `nvidia` | `NVIDIA_API_KEY` (optional: `NVIDIA_BASE_URL`) | | OpenCode Zen | `opencode-zen` | `OPENCODE_ZEN_API_KEY` | | OpenCode Go | `opencode-go` | `OPENCODE_GO_API_KEY` | | Kilo Code | `kilocode` | `KILOCODE_API_KEY` | diff --git a/website/docs/user-guide/messaging/qqbot.md b/website/docs/user-guide/messaging/qqbot.md index 686fd862e8..d9da90d586 100644 --- a/website/docs/user-guide/messaging/qqbot.md +++ b/website/docs/user-guide/messaging/qqbot.md @@ -48,8 +48,8 @@ QQ_CLIENT_SECRET=your-app-secret |---|---|---| | `QQ_APP_ID` | QQ Bot App ID (required) | — | | `QQ_CLIENT_SECRET` | QQ Bot App Secret (required) | — | -| `QQ_HOME_CHANNEL` | OpenID for cron/notification delivery | — | -| `QQ_HOME_CHANNEL_NAME` | Display name for home channel | `Home` | +| `QQBOT_HOME_CHANNEL` | OpenID for cron/notification delivery | — | +| `QQBOT_HOME_CHANNEL_NAME` | Display name for home channel | `Home` | | `QQ_ALLOWED_USERS` | Comma-separated user OpenIDs for DM access | open (all users) | | `QQ_ALLOW_ALL_USERS` | Set to `true` to allow all DMs | `false` | | `QQ_MARKDOWN_SUPPORT` | Enable QQ markdown (msg_type 2) | `true` | @@ -113,7 +113,7 @@ This usually means: - Verify the bot's **intents** are enabled at q.qq.com - Check `QQ_ALLOWED_USERS` 
if DM access is restricted - For group messages, ensure the bot is **@mentioned** (group policy may require allowlisting) -- Check `QQ_HOME_CHANNEL` for cron/notification delivery +- Check `QQBOT_HOME_CHANNEL` for cron/notification delivery ### Connection errors
A real terminal interfaceFull TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.