mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-07 08:02:23 +00:00
Merge branch 'main' into fix/show-reasoning-per-platform
This commit is contained in:
commit
ed7b42f889
401 changed files with 66696 additions and 1966 deletions
4
.envrc
4
.envrc
|
|
@ -1 +1,5 @@
|
|||
watch_file pyproject.toml uv.lock
|
||||
watch_file ui-tui/package-lock.json ui-tui/package.json
|
||||
watch_file flake.nix flake.lock nix/devShell.nix nix/tui.nix nix/package.nix nix/python.nix
|
||||
|
||||
use flake
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -60,5 +60,6 @@ mini-swe-agent/
|
|||
|
||||
# Nix
|
||||
.direnv/
|
||||
.nix-stamps/
|
||||
result
|
||||
website/static/api/skills-index.json
|
||||
|
|
|
|||
66
AGENTS.md
66
AGENTS.md
|
|
@ -56,6 +56,19 @@ hermes-agent/
|
|||
│ ├── run.py # Main loop, slash commands, message dispatch
|
||||
│ ├── session.py # SessionStore — conversation persistence
|
||||
│ └── platforms/ # Adapters: telegram, discord, slack, whatsapp, homeassistant, signal, qqbot
|
||||
├── ui-tui/ # Ink (React) terminal UI — `hermes --tui`
|
||||
│ ├── src/entry.tsx # TTY gate + render()
|
||||
│ ├── src/app.tsx # Main state machine and UI
|
||||
│ ├── src/gatewayClient.ts # Child process + JSON-RPC bridge
|
||||
│ ├── src/app/ # Decomposed app logic (event handler, slash handler, stores, hooks)
|
||||
│ ├── src/components/ # Ink components (branding, markdown, prompts, pickers, etc.)
|
||||
│ ├── src/hooks/ # useCompletion, useInputHistory, useQueue, useVirtualHistory
|
||||
│ └── src/lib/ # Pure helpers (history, osc52, text, rpc, messages)
|
||||
├── tui_gateway/ # Python JSON-RPC backend for the TUI
|
||||
│ ├── entry.py # stdio entrypoint
|
||||
│ ├── server.py # RPC handlers and session logic
|
||||
│ ├── render.py # Optional rich/ANSI bridge
|
||||
│ └── slash_worker.py # Persistent HermesCLI subprocess for slash commands
|
||||
├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains integration)
|
||||
├── cron/ # Scheduler (jobs.py, scheduler.py)
|
||||
├── environments/ # RL training environments (Atropos)
|
||||
|
|
@ -179,6 +192,59 @@ if canonical == "mycommand":
|
|||
|
||||
---
|
||||
|
||||
## TUI Architecture (ui-tui + tui_gateway)
|
||||
|
||||
The TUI is a full replacement for the classic (prompt_toolkit) CLI, activated via `hermes --tui` or `HERMES_TUI=1`.
|
||||
|
||||
### Process Model
|
||||
|
||||
```
|
||||
hermes --tui
|
||||
└─ Node (Ink) ──stdio JSON-RPC── Python (tui_gateway)
|
||||
│ └─ AIAgent + tools + sessions
|
||||
└─ renders transcript, composer, prompts, activity
|
||||
```
|
||||
|
||||
TypeScript owns the screen. Python owns sessions, tools, model calls, and slash command logic.
|
||||
|
||||
### Transport
|
||||
|
||||
Newline-delimited JSON-RPC over stdio. Requests from Ink, events from Python. See `tui_gateway/server.py` for the full method/event catalog.
|
||||
|
||||
### Key Surfaces
|
||||
|
||||
| Surface | Ink component | Gateway method |
|
||||
|---------|---------------|----------------|
|
||||
| Chat streaming | `app.tsx` + `messageLine.tsx` | `prompt.submit` → `message.delta/complete` |
|
||||
| Tool activity | `thinking.tsx` | `tool.start/progress/complete` |
|
||||
| Approvals | `prompts.tsx` | `approval.respond` ← `approval.request` |
|
||||
| Clarify/sudo/secret | `prompts.tsx`, `maskedPrompt.tsx` | `clarify/sudo/secret.respond` |
|
||||
| Session picker | `sessionPicker.tsx` | `session.list/resume` |
|
||||
| Slash commands | Local handler + fallthrough | `slash.exec` → `_SlashWorker`, `command.dispatch` |
|
||||
| Completions | `useCompletion` hook | `complete.slash`, `complete.path` |
|
||||
| Theming | `theme.ts` + `branding.tsx` | `gateway.ready` with skin data |
|
||||
|
||||
### Slash Command Flow
|
||||
|
||||
1. Built-in client commands (`/help`, `/quit`, `/clear`, `/resume`, `/copy`, `/paste`, etc.) handled locally in `app.tsx`
|
||||
2. Everything else → `slash.exec` (runs in persistent `_SlashWorker` subprocess) → `command.dispatch` fallback
|
||||
|
||||
### Dev Commands
|
||||
|
||||
```bash
|
||||
cd ui-tui
|
||||
npm install # first time
|
||||
npm run dev # watch mode (rebuilds hermes-ink + tsx --watch)
|
||||
npm start # production
|
||||
npm run build # full build (hermes-ink + tsc)
|
||||
npm run type-check # typecheck only (tsc --noEmit)
|
||||
npm run lint # eslint
|
||||
npm run fmt # prettier
|
||||
npm test # vitest
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
Requires changes in **2 files**:
|
||||
|
|
|
|||
11
README.md
11
README.md
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
|
||||
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [NVIDIA NIM](https://build.nvidia.com) (Nemotron), [Xiaomi MiMo](https://platform.xiaomimimo.com), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), [Hugging Face](https://huggingface.co), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
|
||||
<table>
|
||||
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
|
||||
|
|
@ -141,11 +141,18 @@ See `hermes claw migrate --help` for all options, or use the `openclaw-migration
|
|||
|
||||
We welcome contributions! See the [Contributing Guide](https://hermes-agent.nousresearch.com/docs/developer-guide/contributing) for development setup, code style, and PR process.
|
||||
|
||||
Quick start for contributors:
|
||||
Quick start for contributors — clone and go with `setup-hermes.sh`:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
./setup-hermes.sh # installs uv, creates venv, installs .[all], symlinks ~/.local/bin/hermes
|
||||
./hermes # auto-detects the venv, no need to `source` first
|
||||
```
|
||||
|
||||
Manual path (equivalent to the above):
|
||||
|
||||
```bash
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv venv venv --python 3.11
|
||||
source venv/bin/activate
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ def make_tool_progress_cb(
|
|||
session_id: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
tool_call_ids: Dict[str, Deque[str]],
|
||||
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||
) -> Callable:
|
||||
"""Create a ``tool_progress_callback`` for AIAgent.
|
||||
|
||||
|
|
@ -84,6 +85,16 @@ def make_tool_progress_cb(
|
|||
tool_call_ids[name] = queue
|
||||
queue.append(tc_id)
|
||||
|
||||
snapshot = None
|
||||
if name in {"write_file", "patch", "skill_manage"}:
|
||||
try:
|
||||
from agent.display import capture_local_edit_snapshot
|
||||
|
||||
snapshot = capture_local_edit_snapshot(name, args)
|
||||
except Exception:
|
||||
logger.debug("Failed to capture ACP edit snapshot for %s", name, exc_info=True)
|
||||
tool_call_meta[tc_id] = {"args": args, "snapshot": snapshot}
|
||||
|
||||
update = build_tool_start(tc_id, name, args)
|
||||
_send_update(conn, session_id, loop, update)
|
||||
|
||||
|
|
@ -119,6 +130,7 @@ def make_step_cb(
|
|||
session_id: str,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
tool_call_ids: Dict[str, Deque[str]],
|
||||
tool_call_meta: Dict[str, Dict[str, Any]],
|
||||
) -> Callable:
|
||||
"""Create a ``step_callback`` for AIAgent.
|
||||
|
||||
|
|
@ -132,10 +144,12 @@ def make_step_cb(
|
|||
for tool_info in prev_tools:
|
||||
tool_name = None
|
||||
result = None
|
||||
function_args = None
|
||||
|
||||
if isinstance(tool_info, dict):
|
||||
tool_name = tool_info.get("name") or tool_info.get("function_name")
|
||||
result = tool_info.get("result") or tool_info.get("output")
|
||||
function_args = tool_info.get("arguments") or tool_info.get("args")
|
||||
elif isinstance(tool_info, str):
|
||||
tool_name = tool_info
|
||||
|
||||
|
|
@ -145,8 +159,13 @@ def make_step_cb(
|
|||
tool_call_ids[tool_name] = queue
|
||||
if tool_name and queue:
|
||||
tc_id = queue.popleft()
|
||||
meta = tool_call_meta.pop(tc_id, {})
|
||||
update = build_tool_complete(
|
||||
tc_id, tool_name, result=str(result) if result is not None else None
|
||||
tc_id,
|
||||
tool_name,
|
||||
result=str(result) if result is not None else None,
|
||||
function_args=function_args or meta.get("args"),
|
||||
snapshot=meta.get("snapshot"),
|
||||
)
|
||||
_send_update(conn, session_id, loop, update)
|
||||
if not queue:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from acp.schema import (
|
|||
McpServerHttp,
|
||||
McpServerSse,
|
||||
McpServerStdio,
|
||||
ModelInfo,
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
|
|
@ -36,6 +37,7 @@ from acp.schema import (
|
|||
SessionCapabilities,
|
||||
SessionForkCapabilities,
|
||||
SessionListCapabilities,
|
||||
SessionModelState,
|
||||
SessionResumeCapabilities,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
|
|
@ -147,6 +149,98 @@ class HermesACPAgent(acp.Agent):
|
|||
self._conn = conn
|
||||
logger.info("ACP client connected")
|
||||
|
||||
@staticmethod
|
||||
def _encode_model_choice(provider: str | None, model: str | None) -> str:
|
||||
"""Encode a model selection so ACP clients can keep provider context."""
|
||||
raw_model = str(model or "").strip()
|
||||
if not raw_model:
|
||||
return ""
|
||||
raw_provider = str(provider or "").strip().lower()
|
||||
if not raw_provider:
|
||||
return raw_model
|
||||
return f"{raw_provider}:{raw_model}"
|
||||
|
||||
def _build_model_state(self, state: SessionState) -> SessionModelState | None:
|
||||
"""Return the ACP model selector payload for editors like Zed."""
|
||||
model = str(state.model or getattr(state.agent, "model", "") or "").strip()
|
||||
provider = getattr(state.agent, "provider", None) or detect_provider() or "openrouter"
|
||||
|
||||
try:
|
||||
from hermes_cli.models import curated_models_for_provider, normalize_provider, provider_label
|
||||
|
||||
normalized_provider = normalize_provider(provider)
|
||||
provider_name = provider_label(normalized_provider)
|
||||
available_models: list[ModelInfo] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
for model_id, description in curated_models_for_provider(normalized_provider):
|
||||
rendered_model = str(model_id or "").strip()
|
||||
if not rendered_model:
|
||||
continue
|
||||
choice_id = self._encode_model_choice(normalized_provider, rendered_model)
|
||||
if choice_id in seen_ids:
|
||||
continue
|
||||
desc_parts = [f"Provider: {provider_name}"]
|
||||
if description:
|
||||
desc_parts.append(str(description).strip())
|
||||
if rendered_model == model:
|
||||
desc_parts.append("current")
|
||||
available_models.append(
|
||||
ModelInfo(
|
||||
model_id=choice_id,
|
||||
name=rendered_model,
|
||||
description=" • ".join(part for part in desc_parts if part),
|
||||
)
|
||||
)
|
||||
seen_ids.add(choice_id)
|
||||
|
||||
current_model_id = self._encode_model_choice(normalized_provider, model)
|
||||
if current_model_id and current_model_id not in seen_ids:
|
||||
available_models.insert(
|
||||
0,
|
||||
ModelInfo(
|
||||
model_id=current_model_id,
|
||||
name=model,
|
||||
description=f"Provider: {provider_name} • current",
|
||||
),
|
||||
)
|
||||
|
||||
if available_models:
|
||||
return SessionModelState(
|
||||
available_models=available_models,
|
||||
current_model_id=current_model_id or available_models[0].model_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not build ACP model state", exc_info=True)
|
||||
|
||||
if not model:
|
||||
return None
|
||||
|
||||
fallback_choice = self._encode_model_choice(provider, model)
|
||||
return SessionModelState(
|
||||
available_models=[ModelInfo(model_id=fallback_choice, name=model)],
|
||||
current_model_id=fallback_choice,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_model_selection(raw_model: str, current_provider: str) -> tuple[str, str]:
|
||||
"""Resolve ``provider:model`` input into the provider and normalized model id."""
|
||||
target_provider = current_provider
|
||||
new_model = raw_model.strip()
|
||||
|
||||
try:
|
||||
from hermes_cli.models import detect_provider_for_model, parse_model_input
|
||||
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
|
||||
return target_provider, new_model
|
||||
|
||||
async def _register_session_mcp_servers(
|
||||
self,
|
||||
state: SessionState,
|
||||
|
|
@ -273,7 +367,10 @@ class HermesACPAgent(acp.Agent):
|
|||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("New session %s (cwd=%s)", state.session_id, cwd)
|
||||
self._schedule_available_commands_update(state.session_id)
|
||||
return NewSessionResponse(session_id=state.session_id)
|
||||
return NewSessionResponse(
|
||||
session_id=state.session_id,
|
||||
models=self._build_model_state(state),
|
||||
)
|
||||
|
||||
async def load_session(
|
||||
self,
|
||||
|
|
@ -289,7 +386,7 @@ class HermesACPAgent(acp.Agent):
|
|||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Loaded session %s", session_id)
|
||||
self._schedule_available_commands_update(session_id)
|
||||
return LoadSessionResponse()
|
||||
return LoadSessionResponse(models=self._build_model_state(state))
|
||||
|
||||
async def resume_session(
|
||||
self,
|
||||
|
|
@ -305,7 +402,7 @@ class HermesACPAgent(acp.Agent):
|
|||
await self._register_session_mcp_servers(state, mcp_servers)
|
||||
logger.info("Resumed session %s", state.session_id)
|
||||
self._schedule_available_commands_update(state.session_id)
|
||||
return ResumeSessionResponse()
|
||||
return ResumeSessionResponse(models=self._build_model_state(state))
|
||||
|
||||
async def cancel(self, session_id: str, **kwargs: Any) -> None:
|
||||
state = self.session_manager.get_session(session_id)
|
||||
|
|
@ -340,11 +437,20 @@ class HermesACPAgent(acp.Agent):
|
|||
cwd: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ListSessionsResponse:
|
||||
infos = self.session_manager.list_sessions()
|
||||
sessions = [
|
||||
SessionInfo(session_id=s["session_id"], cwd=s["cwd"])
|
||||
for s in infos
|
||||
]
|
||||
infos = self.session_manager.list_sessions(cwd=cwd)
|
||||
sessions = []
|
||||
for s in infos:
|
||||
updated_at = s.get("updated_at")
|
||||
if updated_at is not None and not isinstance(updated_at, str):
|
||||
updated_at = str(updated_at)
|
||||
sessions.append(
|
||||
SessionInfo(
|
||||
session_id=s["session_id"],
|
||||
cwd=s["cwd"],
|
||||
title=s.get("title"),
|
||||
updated_at=updated_at,
|
||||
)
|
||||
)
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
# ---- Prompt (core) ------------------------------------------------------
|
||||
|
|
@ -389,12 +495,13 @@ class HermesACPAgent(acp.Agent):
|
|||
state.cancel_event.clear()
|
||||
|
||||
tool_call_ids: dict[str, Deque[str]] = defaultdict(deque)
|
||||
tool_call_meta: dict[str, dict[str, Any]] = {}
|
||||
previous_approval_cb = None
|
||||
|
||||
if conn:
|
||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids)
|
||||
tool_progress_cb = make_tool_progress_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||
thinking_cb = make_thinking_cb(conn, session_id, loop)
|
||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids)
|
||||
step_cb = make_step_cb(conn, session_id, loop, tool_call_ids, tool_call_meta)
|
||||
message_cb = make_message_cb(conn, session_id, loop)
|
||||
approval_cb = make_approval_callback(conn.request_permission, loop, session_id)
|
||||
else:
|
||||
|
|
@ -449,6 +556,19 @@ class HermesACPAgent(acp.Agent):
|
|||
self.session_manager.save_session(session_id)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if final_response:
|
||||
try:
|
||||
from agent.title_generator import maybe_auto_title
|
||||
|
||||
maybe_auto_title(
|
||||
self.session_manager._get_db(),
|
||||
session_id,
|
||||
user_text,
|
||||
final_response,
|
||||
state.history,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to auto-title ACP session %s", session_id, exc_info=True)
|
||||
if final_response and conn:
|
||||
update = acp.update_agent_message_text(final_response)
|
||||
await conn.session_update(session_id, update)
|
||||
|
|
@ -556,27 +676,15 @@ class HermesACPAgent(acp.Agent):
|
|||
provider = getattr(state.agent, "provider", None) or "auto"
|
||||
return f"Current model: {model}\nProvider: {provider}"
|
||||
|
||||
new_model = args.strip()
|
||||
target_provider = None
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
|
||||
# Auto-detect provider for the requested model
|
||||
try:
|
||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
target_provider, new_model = self._resolve_model_selection(args, current_provider)
|
||||
|
||||
state.model = new_model
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=state.session_id,
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
requested_provider=target_provider or current_provider,
|
||||
requested_provider=target_provider,
|
||||
)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
||||
|
|
@ -678,20 +786,30 @@ class HermesACPAgent(acp.Agent):
|
|||
"""Switch the model for a session (called by ACP protocol)."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
state.model = model_id
|
||||
current_provider = getattr(state.agent, "provider", None)
|
||||
current_base_url = getattr(state.agent, "base_url", None)
|
||||
current_api_mode = getattr(state.agent, "api_mode", None)
|
||||
requested_provider, resolved_model = self._resolve_model_selection(
|
||||
model_id,
|
||||
current_provider or "openrouter",
|
||||
)
|
||||
state.model = resolved_model
|
||||
provider_changed = bool(current_provider and requested_provider != current_provider)
|
||||
current_base_url = None if provider_changed else getattr(state.agent, "base_url", None)
|
||||
current_api_mode = None if provider_changed else getattr(state.agent, "api_mode", None)
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=state.cwd,
|
||||
model=model_id,
|
||||
requested_provider=current_provider,
|
||||
model=resolved_model,
|
||||
requested_provider=requested_provider,
|
||||
base_url=current_base_url,
|
||||
api_mode=current_api_mode,
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
logger.info(
|
||||
"Session %s: model switched to %s via provider %s",
|
||||
session_id,
|
||||
resolved_model,
|
||||
requested_provider,
|
||||
)
|
||||
return SetSessionModelResponse()
|
||||
logger.warning("Session %s: model switch requested for missing session", session_id)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -13,8 +13,12 @@ from hermes_constants import get_hermes_home
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
|
@ -22,6 +26,64 @@ from typing import Any, Dict, List, Optional
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_cwd_for_compare(cwd: str | None) -> str:
|
||||
raw = str(cwd or ".").strip()
|
||||
if not raw:
|
||||
raw = "."
|
||||
expanded = os.path.expanduser(raw)
|
||||
|
||||
# Normalize Windows drive paths into the equivalent WSL mount form so
|
||||
# ACP history filters match the same workspace across Windows and WSL.
|
||||
match = re.match(r"^([A-Za-z]):[\\/](.*)$", expanded)
|
||||
if match:
|
||||
drive = match.group(1).lower()
|
||||
tail = match.group(2).replace("\\", "/")
|
||||
expanded = f"/mnt/{drive}/{tail}"
|
||||
elif re.match(r"^/mnt/[A-Za-z]/", expanded):
|
||||
expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}"
|
||||
|
||||
return os.path.normpath(expanded)
|
||||
|
||||
|
||||
def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str:
|
||||
explicit = str(title or "").strip()
|
||||
if explicit:
|
||||
return explicit
|
||||
preview_text = str(preview or "").strip()
|
||||
if preview_text:
|
||||
return preview_text
|
||||
leaf = os.path.basename(str(cwd or "").rstrip("/\\"))
|
||||
return leaf or "New thread"
|
||||
|
||||
|
||||
def _format_updated_at(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value
|
||||
try:
|
||||
return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _updated_at_sort_key(value: Any) -> float:
|
||||
if value is None:
|
||||
return float("-inf")
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
return float("-inf")
|
||||
try:
|
||||
return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
|
||||
except Exception:
|
||||
try:
|
||||
return float(raw)
|
||||
except Exception:
|
||||
return float("-inf")
|
||||
|
||||
|
||||
def _acp_stderr_print(*args, **kwargs) -> None:
|
||||
"""Best-effort human-readable output sink for ACP stdio sessions.
|
||||
|
||||
|
|
@ -162,47 +224,78 @@ class SessionManager:
|
|||
logger.info("Forked ACP session %s -> %s", session_id, new_id)
|
||||
return state
|
||||
|
||||
def list_sessions(self) -> List[Dict[str, Any]]:
|
||||
def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]:
|
||||
"""Return lightweight info dicts for all sessions (memory + database)."""
|
||||
normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None
|
||||
db = self._get_db()
|
||||
persisted_rows: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if db is not None:
|
||||
try:
|
||||
for row in db.list_sessions_rich(source="acp", limit=1000):
|
||||
persisted_rows[str(row["id"])] = dict(row)
|
||||
except Exception:
|
||||
logger.debug("Failed to load ACP sessions from DB", exc_info=True)
|
||||
|
||||
# Collect in-memory sessions first.
|
||||
with self._lock:
|
||||
seen_ids = set(self._sessions.keys())
|
||||
results = [
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
"model": s.model,
|
||||
"history_len": len(s.history),
|
||||
}
|
||||
for s in self._sessions.values()
|
||||
]
|
||||
results = []
|
||||
for s in self._sessions.values():
|
||||
history_len = len(s.history)
|
||||
if history_len <= 0:
|
||||
continue
|
||||
if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd:
|
||||
continue
|
||||
persisted = persisted_rows.get(s.session_id, {})
|
||||
preview = next(
|
||||
(
|
||||
str(msg.get("content") or "").strip()
|
||||
for msg in s.history
|
||||
if msg.get("role") == "user" and str(msg.get("content") or "").strip()
|
||||
),
|
||||
persisted.get("preview") or "",
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
"cwd": s.cwd,
|
||||
"model": s.model,
|
||||
"history_len": history_len,
|
||||
"title": _build_session_title(persisted.get("title"), preview, s.cwd),
|
||||
"updated_at": _format_updated_at(
|
||||
persisted.get("last_active") or persisted.get("started_at") or time.time()
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Merge any persisted sessions not currently in memory.
|
||||
db = self._get_db()
|
||||
if db is not None:
|
||||
try:
|
||||
rows = db.search_sessions(source="acp", limit=1000)
|
||||
for row in rows:
|
||||
sid = row["id"]
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": row.get("message_count") or 0,
|
||||
})
|
||||
except Exception:
|
||||
logger.debug("Failed to list ACP sessions from DB", exc_info=True)
|
||||
for sid, row in persisted_rows.items():
|
||||
if sid in seen_ids:
|
||||
continue
|
||||
message_count = int(row.get("message_count") or 0)
|
||||
if message_count <= 0:
|
||||
continue
|
||||
# Extract cwd from model_config JSON.
|
||||
session_cwd = "."
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
session_cwd = json.loads(mc).get("cwd", ".")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd:
|
||||
continue
|
||||
results.append({
|
||||
"session_id": sid,
|
||||
"cwd": session_cwd,
|
||||
"model": row.get("model") or "",
|
||||
"history_len": message_count,
|
||||
"title": _build_session_title(row.get("title"), row.get("preview"), session_cwd),
|
||||
"updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")),
|
||||
})
|
||||
|
||||
results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True)
|
||||
return results
|
||||
|
||||
def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
|
@ -96,6 +97,170 @@ def build_tool_title(tool_name: str, args: Dict[str, Any]) -> str:
|
|||
return tool_name
|
||||
|
||||
|
||||
def _build_patch_mode_content(patch_text: str) -> List[Any]:
|
||||
"""Parse V4A patch mode input into ACP diff blocks when possible."""
|
||||
if not patch_text:
|
||||
return [acp.tool_content(acp.text_block(""))]
|
||||
|
||||
try:
|
||||
from tools.patch_parser import OperationType, parse_v4a_patch
|
||||
|
||||
operations, error = parse_v4a_patch(patch_text)
|
||||
if error or not operations:
|
||||
return [acp.tool_content(acp.text_block(patch_text))]
|
||||
|
||||
content: List[Any] = []
|
||||
for op in operations:
|
||||
if op.operation == OperationType.UPDATE:
|
||||
old_chunks: list[str] = []
|
||||
new_chunks: list[str] = []
|
||||
for hunk in op.hunks:
|
||||
old_lines = [line.content for line in hunk.lines if line.prefix in (" ", "-")]
|
||||
new_lines = [line.content for line in hunk.lines if line.prefix in (" ", "+")]
|
||||
if old_lines or new_lines:
|
||||
old_chunks.append("\n".join(old_lines))
|
||||
new_chunks.append("\n".join(new_lines))
|
||||
|
||||
old_text = "\n...\n".join(chunk for chunk in old_chunks if chunk)
|
||||
new_text = "\n...\n".join(chunk for chunk in new_chunks if chunk)
|
||||
if old_text or new_text:
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
old_text=old_text or None,
|
||||
new_text=new_text or "",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.ADD:
|
||||
added_lines = [line.content for hunk in op.hunks for line in hunk.lines if line.prefix == "+"]
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
new_text="\n".join(added_lines),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.DELETE:
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=op.file_path,
|
||||
old_text=f"Delete file: {op.file_path}",
|
||||
new_text="",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if op.operation == OperationType.MOVE:
|
||||
content.append(
|
||||
acp.tool_content(acp.text_block(f"Move file: {op.file_path} -> {op.new_path}"))
|
||||
)
|
||||
|
||||
return content or [acp.tool_content(acp.text_block(patch_text))]
|
||||
except Exception:
|
||||
return [acp.tool_content(acp.text_block(patch_text))]
|
||||
|
||||
|
||||
def _strip_diff_prefix(path: str) -> str:
|
||||
raw = str(path or "").strip()
|
||||
if raw.startswith(("a/", "b/")):
|
||||
return raw[2:]
|
||||
return raw
|
||||
|
||||
|
||||
def _parse_unified_diff_content(diff_text: str) -> List[Any]:
|
||||
"""Convert unified diff text into ACP diff content blocks."""
|
||||
if not diff_text:
|
||||
return []
|
||||
|
||||
content: List[Any] = []
|
||||
current_old_path: Optional[str] = None
|
||||
current_new_path: Optional[str] = None
|
||||
old_lines: list[str] = []
|
||||
new_lines: list[str] = []
|
||||
|
||||
def _flush() -> None:
|
||||
nonlocal current_old_path, current_new_path, old_lines, new_lines
|
||||
if current_old_path is None and current_new_path is None:
|
||||
return
|
||||
path = current_new_path if current_new_path and current_new_path != "/dev/null" else current_old_path
|
||||
if not path or path == "/dev/null":
|
||||
current_old_path = None
|
||||
current_new_path = None
|
||||
old_lines = []
|
||||
new_lines = []
|
||||
return
|
||||
content.append(
|
||||
acp.tool_diff_content(
|
||||
path=_strip_diff_prefix(path),
|
||||
old_text="\n".join(old_lines) if old_lines else None,
|
||||
new_text="\n".join(new_lines),
|
||||
)
|
||||
)
|
||||
current_old_path = None
|
||||
current_new_path = None
|
||||
old_lines = []
|
||||
new_lines = []
|
||||
|
||||
for line in diff_text.splitlines():
|
||||
if line.startswith("--- "):
|
||||
_flush()
|
||||
current_old_path = line[4:].strip()
|
||||
continue
|
||||
if line.startswith("+++ "):
|
||||
current_new_path = line[4:].strip()
|
||||
continue
|
||||
if line.startswith("@@"):
|
||||
continue
|
||||
if current_old_path is None and current_new_path is None:
|
||||
continue
|
||||
if line.startswith("+"):
|
||||
new_lines.append(line[1:])
|
||||
elif line.startswith("-"):
|
||||
old_lines.append(line[1:])
|
||||
elif line.startswith(" "):
|
||||
shared = line[1:]
|
||||
old_lines.append(shared)
|
||||
new_lines.append(shared)
|
||||
|
||||
_flush()
|
||||
return content
|
||||
|
||||
|
||||
def _build_tool_complete_content(
|
||||
tool_name: str,
|
||||
result: Optional[str],
|
||||
*,
|
||||
function_args: Optional[Dict[str, Any]] = None,
|
||||
snapshot: Any = None,
|
||||
) -> List[Any]:
|
||||
"""Build structured ACP completion content, falling back to plain text."""
|
||||
display_result = result or ""
|
||||
if len(display_result) > 5000:
|
||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||
|
||||
if tool_name in {"write_file", "patch", "skill_manage"}:
|
||||
try:
|
||||
from agent.display import extract_edit_diff
|
||||
|
||||
diff_text = extract_edit_diff(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
if isinstance(diff_text, str) and diff_text.strip():
|
||||
diff_content = _parse_unified_diff_content(diff_text)
|
||||
if diff_content:
|
||||
return diff_content
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return [acp.tool_content(acp.text_block(display_result))]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build ACP content objects for tool-call events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -119,9 +284,8 @@ def build_tool_start(
|
|||
new = arguments.get("new_string", "")
|
||||
content = [acp.tool_diff_content(path=path, new_text=new, old_text=old)]
|
||||
else:
|
||||
# Patch mode — show the patch content as text
|
||||
patch_text = arguments.get("patch", "")
|
||||
content = [acp.tool_content(acp.text_block(patch_text))]
|
||||
content = _build_patch_mode_content(patch_text)
|
||||
return acp.start_tool_call(
|
||||
tool_call_id, title, kind=kind, content=content, locations=locations,
|
||||
raw_input=arguments,
|
||||
|
|
@ -178,16 +342,17 @@ def build_tool_complete(
|
|||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: Optional[str] = None,
|
||||
function_args: Optional[Dict[str, Any]] = None,
|
||||
snapshot: Any = None,
|
||||
) -> ToolCallProgress:
|
||||
"""Create a ToolCallUpdate (progress) event for a completed tool call."""
|
||||
kind = get_tool_kind(tool_name)
|
||||
|
||||
# Truncate very large results for the UI
|
||||
display_result = result or ""
|
||||
if len(display_result) > 5000:
|
||||
display_result = display_result[:4900] + f"\n... ({len(result)} chars total, truncated)"
|
||||
|
||||
content = [acp.tool_content(acp.text_block(display_result))]
|
||||
content = _build_tool_complete_content(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
return acp.update_tool_call(
|
||||
tool_call_id,
|
||||
kind=kind,
|
||||
|
|
|
|||
|
|
@ -94,6 +94,17 @@ def _normalize_aux_provider(provider: Optional[str]) -> str:
|
|||
return "custom"
|
||||
return _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
|
||||
_FIXED_TEMPERATURE_MODELS: Dict[str, float] = {
|
||||
"kimi-for-coding": 0.6,
|
||||
}
|
||||
|
||||
|
||||
def _fixed_temperature_for_model(model: Optional[str]) -> Optional[float]:
|
||||
"""Return a required temperature override for models with strict contracts."""
|
||||
normalized = (model or "").strip().lower()
|
||||
return _FIXED_TEMPERATURE_MODELS.get(normalized)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"gemini": "gemini-3-flash-preview",
|
||||
|
|
@ -1064,8 +1075,6 @@ _AUTO_PROVIDER_LABELS = {
|
|||
"_resolve_api_key_provider": "api-key",
|
||||
}
|
||||
|
||||
_AGGREGATOR_PROVIDERS = frozenset({"openrouter", "nous"})
|
||||
|
||||
_MAIN_RUNTIME_FIELDS = ("provider", "model", "base_url", "api_key", "api_mode")
|
||||
|
||||
|
||||
|
|
@ -1196,11 +1205,15 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option
|
|||
"""Full auto-detection chain.
|
||||
|
||||
Priority:
|
||||
1. If the user's main provider is NOT an aggregator (OpenRouter / Nous),
|
||||
use their main provider + main model directly. This ensures users on
|
||||
Alibaba, DeepSeek, ZAI, etc. get auxiliary tasks handled by the same
|
||||
provider they already have credentials for — no OpenRouter key needed.
|
||||
2. OpenRouter → Nous → custom → Codex → API-key providers (original chain).
|
||||
1. User's main provider + main model, regardless of provider type.
|
||||
This means auxiliary tasks (compression, vision, web extraction,
|
||||
session search, etc.) use the same model the user configured for
|
||||
chat. Users on OpenRouter/Nous get their chosen chat model; users
|
||||
on DeepSeek/ZAI/Alibaba get theirs; etc. Running aux tasks on the
|
||||
user's picked model keeps behavior predictable — no surprise
|
||||
switches to a cheap fallback model for side tasks.
|
||||
2. OpenRouter → Nous → custom → Codex → API-key providers (fallback
|
||||
chain, only used when the main provider has no working client).
|
||||
"""
|
||||
global auxiliary_is_nous, _stale_base_url_warned
|
||||
auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins
|
||||
|
|
@ -1230,11 +1243,16 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option
|
|||
)
|
||||
_stale_base_url_warned = True
|
||||
|
||||
# ── Step 1: non-aggregator main provider → use main model directly ──
|
||||
# ── Step 1: main provider + main model → use them directly ──
|
||||
#
|
||||
# This is the primary aux backend for every user. "auto" means
|
||||
# "use my main chat model for side tasks as well" — including users
|
||||
# on aggregators (OpenRouter, Nous) who previously got routed to a
|
||||
# cheap provider-side default. Explicit per-task overrides set via
|
||||
# config.yaml (auxiliary.<task>.provider) still win over this.
|
||||
main_provider = runtime_provider or _read_main_provider()
|
||||
main_model = runtime_model or _read_main_model()
|
||||
if (main_provider and main_model
|
||||
and main_provider not in _AGGREGATOR_PROVIDERS
|
||||
and main_provider not in ("auto", "")):
|
||||
resolved_provider = main_provider
|
||||
explicit_base_url = None
|
||||
|
|
@ -1817,34 +1835,31 @@ def resolve_vision_provider_client(
|
|||
|
||||
if requested == "auto":
|
||||
# Vision auto-detection order:
|
||||
# 1. Active provider + model (user's main chat config)
|
||||
# 2. OpenRouter (known vision-capable default model)
|
||||
# 3. Nous Portal (known vision-capable default model)
|
||||
# 1. User's main provider + main model (including aggregators).
|
||||
# _PROVIDER_VISION_MODELS provides per-provider vision model
|
||||
# overrides when the provider has a dedicated multimodal model
|
||||
# that differs from the chat model (e.g. xiaomi → mimo-v2-omni,
|
||||
# zai → glm-5v-turbo).
|
||||
# 2. OpenRouter (vision-capable aggregator fallback)
|
||||
# 3. Nous Portal (vision-capable aggregator fallback)
|
||||
# 4. Stop
|
||||
main_provider = _read_main_provider()
|
||||
main_model = _read_main_model()
|
||||
if main_provider and main_provider not in ("auto", ""):
|
||||
if main_provider in _VISION_AUTO_PROVIDER_ORDER:
|
||||
# Known strict backend — use its defaults.
|
||||
sync_client, default_model = _resolve_strict_vision_backend(main_provider)
|
||||
if sync_client is not None:
|
||||
return _finalize(main_provider, sync_client, default_model)
|
||||
else:
|
||||
# Exotic provider (DeepSeek, Alibaba, Xiaomi, named custom, etc.)
|
||||
# Use provider-specific vision model if available, otherwise main model.
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, vision_model,
|
||||
api_mode=resolved_api_mode)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using active provider %s (%s)",
|
||||
main_provider, rpc_model or vision_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or vision_model)
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, vision_model,
|
||||
api_mode=resolved_api_mode)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using main provider %s (%s)",
|
||||
main_provider, rpc_model or vision_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or vision_model)
|
||||
|
||||
# Fall back through aggregators.
|
||||
# Fall back through aggregators (uses their dedicated vision model,
|
||||
# not the user's main model) when main provider has no client.
|
||||
for candidate in _VISION_AUTO_PROVIDER_ORDER:
|
||||
if candidate == main_provider:
|
||||
continue # already tried above
|
||||
|
|
@ -2293,6 +2308,10 @@ def _build_call_kwargs(
|
|||
"timeout": timeout,
|
||||
}
|
||||
|
||||
fixed_temperature = _fixed_temperature_for_model(model)
|
||||
if fixed_temperature is not None:
|
||||
temperature = fixed_temperature
|
||||
|
||||
# Opus 4.7+ rejects any non-default temperature/top_p/top_k — silently
|
||||
# drop here so auxiliary callers that hardcode temperature (e.g. 0.3 on
|
||||
# flush_memories, 0 on structured-JSON extraction) don't 400 the moment
|
||||
|
|
|
|||
|
|
@ -1130,6 +1130,14 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
|||
state = _load_provider_state(auth_store, "nous")
|
||||
if state:
|
||||
active_sources.add("device_code")
|
||||
# Prefer a user-supplied label embedded in the singleton state
|
||||
# (set by persist_nous_credentials(label=...) when the user ran
|
||||
# `hermes auth add nous --label <name>`). Fall back to the
|
||||
# auto-derived token fingerprint for logins that didn't supply one.
|
||||
custom_label = str(state.get("label") or "").strip()
|
||||
seeded_label = custom_label or label_from_token(
|
||||
state.get("access_token", ""), "device_code"
|
||||
)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
|
|
@ -1148,7 +1156,7 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
|
|||
"agent_key": state.get("agent_key"),
|
||||
"agent_key_expires_at": state.get("agent_key_expires_at"),
|
||||
"tls": state.get("tls") if isinstance(state.get("tls"), dict) else None,
|
||||
"label": label_from_token(state.get("access_token", ""), "device_code"),
|
||||
"label": seeded_label,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -747,18 +747,149 @@ class GeminiCloudCodeClient:
|
|||
|
||||
|
||||
def _gemini_http_error(response: httpx.Response) -> CodeAssistError:
|
||||
"""Translate an httpx response into a CodeAssistError with rich metadata.
|
||||
|
||||
Parses Google's error envelope (``{"error": {"code", "message", "status",
|
||||
"details": [...]}}``) so the agent's error classifier can reason about
|
||||
the failure — ``status_code`` enables the rate_limit / auth classification
|
||||
paths, and ``response`` lets the main loop honor ``Retry-After`` just
|
||||
like it does for OpenAI SDK exceptions.
|
||||
|
||||
Also lifts a few recognizable Google conditions into human-readable
|
||||
messages so the user sees something better than a 500-char JSON dump:
|
||||
|
||||
MODEL_CAPACITY_EXHAUSTED → "Gemini model capacity exhausted for
|
||||
<model>. This is a Google-side throttle..."
|
||||
RESOURCE_EXHAUSTED w/o reason → quota-style message
|
||||
404 → "Model <name> not found at cloudcode-pa..."
|
||||
"""
|
||||
status = response.status_code
|
||||
|
||||
# Parse the body once, surviving any weird encodings.
|
||||
body_text = ""
|
||||
body_json: Dict[str, Any] = {}
|
||||
try:
|
||||
body = response.text[:500]
|
||||
body_text = response.text
|
||||
except Exception:
|
||||
body = ""
|
||||
# Let run_agent's retry logic see auth errors as rotatable via `api_key`
|
||||
body_text = ""
|
||||
if body_text:
|
||||
try:
|
||||
parsed = json.loads(body_text)
|
||||
if isinstance(parsed, dict):
|
||||
body_json = parsed
|
||||
except (ValueError, TypeError):
|
||||
body_json = {}
|
||||
|
||||
# Dig into Google's error envelope. Shape is:
|
||||
# {"error": {"code": 429, "message": "...", "status": "RESOURCE_EXHAUSTED",
|
||||
# "details": [{"@type": ".../ErrorInfo", "reason": "MODEL_CAPACITY_EXHAUSTED",
|
||||
# "metadata": {...}},
|
||||
# {"@type": ".../RetryInfo", "retryDelay": "30s"}]}}
|
||||
err_obj = body_json.get("error") if isinstance(body_json, dict) else None
|
||||
if not isinstance(err_obj, dict):
|
||||
err_obj = {}
|
||||
err_status = str(err_obj.get("status") or "").strip()
|
||||
err_message = str(err_obj.get("message") or "").strip()
|
||||
err_details_list = err_obj.get("details") if isinstance(err_obj.get("details"), list) else []
|
||||
|
||||
# Extract google.rpc.ErrorInfo reason + metadata. There may be more
|
||||
# than one ErrorInfo (rare), so we pick the first one with a reason.
|
||||
error_reason = ""
|
||||
error_metadata: Dict[str, Any] = {}
|
||||
retry_delay_seconds: Optional[float] = None
|
||||
for detail in err_details_list:
|
||||
if not isinstance(detail, dict):
|
||||
continue
|
||||
type_url = str(detail.get("@type") or "")
|
||||
if not error_reason and type_url.endswith("/google.rpc.ErrorInfo"):
|
||||
reason = detail.get("reason")
|
||||
if isinstance(reason, str) and reason:
|
||||
error_reason = reason
|
||||
md = detail.get("metadata")
|
||||
if isinstance(md, dict):
|
||||
error_metadata = md
|
||||
elif retry_delay_seconds is None and type_url.endswith("/google.rpc.RetryInfo"):
|
||||
# retryDelay is a google.protobuf.Duration string like "30s" or "1.5s".
|
||||
delay_raw = detail.get("retryDelay")
|
||||
if isinstance(delay_raw, str) and delay_raw.endswith("s"):
|
||||
try:
|
||||
retry_delay_seconds = float(delay_raw[:-1])
|
||||
except ValueError:
|
||||
pass
|
||||
elif isinstance(delay_raw, (int, float)):
|
||||
retry_delay_seconds = float(delay_raw)
|
||||
|
||||
# Fall back to the Retry-After header if the body didn't include RetryInfo.
|
||||
if retry_delay_seconds is None:
|
||||
try:
|
||||
header_val = response.headers.get("Retry-After") or response.headers.get("retry-after")
|
||||
except Exception:
|
||||
header_val = None
|
||||
if header_val:
|
||||
try:
|
||||
retry_delay_seconds = float(header_val)
|
||||
except (TypeError, ValueError):
|
||||
retry_delay_seconds = None
|
||||
|
||||
# Classify the error code. ``code_assist_rate_limited`` stays the default
|
||||
# for 429s; a more specific reason tag helps downstream callers (e.g. tests,
|
||||
# logs) without changing the rate_limit classification path.
|
||||
code = f"code_assist_http_{status}"
|
||||
if status == 401:
|
||||
code = "code_assist_unauthorized"
|
||||
elif status == 429:
|
||||
code = "code_assist_rate_limited"
|
||||
if error_reason == "MODEL_CAPACITY_EXHAUSTED":
|
||||
code = "code_assist_capacity_exhausted"
|
||||
|
||||
# Build a human-readable message. Keep the status + a raw-body tail for
|
||||
# debugging, but lead with a friendlier summary when we recognize the
|
||||
# Google signal.
|
||||
model_hint = ""
|
||||
if isinstance(error_metadata, dict):
|
||||
model_hint = str(error_metadata.get("model") or error_metadata.get("modelId") or "").strip()
|
||||
|
||||
if status == 429 and error_reason == "MODEL_CAPACITY_EXHAUSTED":
|
||||
target = model_hint or "this Gemini model"
|
||||
message = (
|
||||
f"Gemini capacity exhausted for {target} (Google-side throttle, "
|
||||
f"not a Hermes issue). Try a different Gemini model or set a "
|
||||
f"fallback_providers entry to a non-Gemini provider."
|
||||
)
|
||||
if retry_delay_seconds is not None:
|
||||
message += f" Google suggests retrying in {retry_delay_seconds:g}s."
|
||||
elif status == 429 and err_status == "RESOURCE_EXHAUSTED":
|
||||
message = (
|
||||
f"Gemini quota exhausted ({err_message or 'RESOURCE_EXHAUSTED'}). "
|
||||
f"Check /gquota for remaining daily requests."
|
||||
)
|
||||
if retry_delay_seconds is not None:
|
||||
message += f" Retry suggested in {retry_delay_seconds:g}s."
|
||||
elif status == 404:
|
||||
# Google returns 404 when a model has been retired or renamed.
|
||||
target = model_hint or (err_message or "model")
|
||||
message = (
|
||||
f"Code Assist 404: {target} is not available at "
|
||||
f"cloudcode-pa.googleapis.com. It may have been renamed or "
|
||||
f"retired. Check hermes_cli/models.py for the current list."
|
||||
)
|
||||
elif err_message:
|
||||
# Generic fallback with the parsed message.
|
||||
message = f"Code Assist HTTP {status} ({err_status or 'error'}): {err_message}"
|
||||
else:
|
||||
# Last-ditch fallback — raw body snippet.
|
||||
message = f"Code Assist returned HTTP {status}: {body_text[:500]}"
|
||||
|
||||
return CodeAssistError(
|
||||
f"Code Assist returned HTTP {status}: {body}",
|
||||
message,
|
||||
code=code,
|
||||
status_code=status,
|
||||
response=response,
|
||||
retry_after=retry_delay_seconds,
|
||||
details={
|
||||
"status": err_status,
|
||||
"reason": error_reason,
|
||||
"metadata": error_metadata,
|
||||
"message": err_message,
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -68,9 +68,45 @@ _ONBOARDING_POLL_INTERVAL_SECONDS = 5.0
|
|||
|
||||
|
||||
class CodeAssistError(RuntimeError):
|
||||
def __init__(self, message: str, *, code: str = "code_assist_error") -> None:
|
||||
"""Exception raised by the Code Assist (``cloudcode-pa``) integration.
|
||||
|
||||
Carries HTTP status / response / retry-after metadata so the agent's
|
||||
``error_classifier._extract_status_code`` and the main loop's Retry-After
|
||||
handling (which walks ``error.response.headers``) pick up the right
|
||||
signals. Without these, 429s from the OAuth path look like opaque
|
||||
``RuntimeError`` and skip the rate-limit path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
code: str = "code_assist_error",
|
||||
status_code: Optional[int] = None,
|
||||
response: Any = None,
|
||||
retry_after: Optional[float] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
# ``status_code`` is picked up by ``agent.error_classifier._extract_status_code``
|
||||
# so a 429 from Code Assist classifies as FailoverReason.rate_limit and
|
||||
# triggers the main loop's fallback_providers chain the same way SDK
|
||||
# errors do.
|
||||
self.status_code = status_code
|
||||
# ``response`` is the underlying ``httpx.Response`` (or a shim with a
|
||||
# ``.headers`` mapping and ``.json()`` method). The main loop reads
|
||||
# ``error.response.headers["Retry-After"]`` to honor Google's retry
|
||||
# hints when the backend throttles us.
|
||||
self.response = response
|
||||
# Parsed ``Retry-After`` seconds (kept separately for convenience —
|
||||
# Google returns retry hints in both the header and the error body's
|
||||
# ``google.rpc.RetryInfo`` details, and we pick whichever we found).
|
||||
self.retry_after = retry_after
|
||||
# Parsed structured error details from the Google error envelope
|
||||
# (e.g. ``{"reason": "MODEL_CAPACITY_EXHAUSTED", "status": "RESOURCE_EXHAUSTED"}``).
|
||||
# Useful for logging and for tests that want to assert on specifics.
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class ProjectIdRequiredError(CodeAssistError):
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
|||
"mimo", "xiaomi-mimo",
|
||||
"arcee-ai", "arceeai",
|
||||
"xai", "x-ai", "x.ai", "grok",
|
||||
"nvidia", "nim", "nvidia-nim", "nemotron",
|
||||
"qwen-portal",
|
||||
})
|
||||
|
||||
|
|
@ -124,7 +125,6 @@ DEFAULT_CONTEXT_LENGTHS = {
|
|||
"gemini": 1048576,
|
||||
# Gemma (open models served via AI Studio)
|
||||
"gemma-4-31b": 256000,
|
||||
"gemma-4-26b": 256000,
|
||||
"gemma-3": 131072,
|
||||
"gemma": 8192, # fallback for older gemma models
|
||||
# DeepSeek
|
||||
|
|
@ -158,6 +158,8 @@ DEFAULT_CONTEXT_LENGTHS = {
|
|||
"grok": 131072, # catch-all (grok-beta, unknown grok-*)
|
||||
# Kimi
|
||||
"kimi": 262144,
|
||||
# Nemotron — NVIDIA's open-weights series (128K context across all sizes)
|
||||
"nemotron": 131072,
|
||||
# Arcee
|
||||
"trinity": 262144,
|
||||
# OpenRouter
|
||||
|
|
@ -240,6 +242,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
|||
"api.fireworks.ai": "fireworks",
|
||||
"opencode.ai": "opencode-go",
|
||||
"api.x.ai": "xai",
|
||||
"integrate.api.nvidia.com": "nvidia",
|
||||
"api.xiaomimimo.com": "xiaomi",
|
||||
"xiaomimimo.com": "xiaomi",
|
||||
"ollama.com": "ollama-cloud",
|
||||
|
|
|
|||
|
|
@ -654,7 +654,7 @@ def build_skills_system_prompt(
|
|||
):
|
||||
continue
|
||||
skills_by_category.setdefault(category, []).append(
|
||||
(skill_name, entry.get("description", ""))
|
||||
(frontmatter_name, entry.get("description", ""))
|
||||
)
|
||||
category_descriptions = {
|
||||
str(k): str(v)
|
||||
|
|
@ -679,7 +679,7 @@ def build_skills_system_prompt(
|
|||
):
|
||||
continue
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
(entry["frontmatter_name"], entry["description"])
|
||||
)
|
||||
|
||||
# Read category-level DESCRIPTION.md files
|
||||
|
|
@ -722,9 +722,10 @@ def build_skills_system_prompt(
|
|||
continue
|
||||
entry = _build_snapshot_entry(skill_file, ext_dir, frontmatter, desc)
|
||||
skill_name = entry["skill_name"]
|
||||
if skill_name in seen_skill_names:
|
||||
frontmatter_name = entry["frontmatter_name"]
|
||||
if frontmatter_name in seen_skill_names:
|
||||
continue
|
||||
if entry["frontmatter_name"] in disabled or skill_name in disabled:
|
||||
if frontmatter_name in disabled or skill_name in disabled:
|
||||
continue
|
||||
if not _skill_should_show(
|
||||
extract_skill_conditions(frontmatter),
|
||||
|
|
@ -732,9 +733,9 @@ def build_skills_system_prompt(
|
|||
available_toolsets,
|
||||
):
|
||||
continue
|
||||
seen_skill_names.add(skill_name)
|
||||
seen_skill_names.add(frontmatter_name)
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
(frontmatter_name, entry["description"])
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Error reading external skill %s: %s", skill_file, e)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ model:
|
|||
# "minimax" - MiniMax global (requires: MINIMAX_API_KEY)
|
||||
# "minimax-cn" - MiniMax China (requires: MINIMAX_CN_API_KEY)
|
||||
# "huggingface" - Hugging Face Inference (requires: HF_TOKEN)
|
||||
# "nvidia" - NVIDIA NIM / build.nvidia.com (requires: NVIDIA_API_KEY)
|
||||
# "xiaomi" - Xiaomi MiMo (requires: XIAOMI_API_KEY)
|
||||
# "arcee" - Arcee AI Trinity models (requires: ARCEEAI_API_KEY)
|
||||
# "ollama-cloud" - Ollama Cloud (requires: OLLAMA_API_KEY — https://ollama.com/settings)
|
||||
|
|
|
|||
142
cli.py
142
cli.py
|
|
@ -18,6 +18,8 @@ import os
|
|||
import shutil
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import base64
|
||||
import atexit
|
||||
import tempfile
|
||||
import time
|
||||
|
|
@ -78,6 +80,42 @@ _project_env = Path(__file__).parent / '.env'
|
|||
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
||||
|
||||
|
||||
_REASONING_TAGS = (
|
||||
"REASONING_SCRATCHPAD",
|
||||
"think",
|
||||
"reasoning",
|
||||
"THINKING",
|
||||
"thinking",
|
||||
)
|
||||
|
||||
|
||||
def _strip_reasoning_tags(text: str) -> str:
|
||||
cleaned = text
|
||||
for tag in _REASONING_TAGS:
|
||||
cleaned = re.sub(rf"<{tag}>.*?</{tag}>\s*", "", cleaned, flags=re.DOTALL)
|
||||
cleaned = re.sub(rf"<{tag}>.*$", "", cleaned, flags=re.DOTALL)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def _assistant_content_as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
str(part.get("text", ""))
|
||||
for part in content
|
||||
if isinstance(part, dict) and part.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _assistant_copy_text(content: Any) -> str:
|
||||
return _strip_reasoning_tags(_assistant_content_as_text(content))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration Loading
|
||||
# =============================================================================
|
||||
|
|
@ -1172,6 +1210,10 @@ def _resolve_attachment_path(raw_path: str) -> Path | None:
|
|||
return None
|
||||
|
||||
expanded = os.path.expandvars(os.path.expanduser(token))
|
||||
if os.name != "nt":
|
||||
normalized = expanded.replace("\\", "/")
|
||||
if len(normalized) >= 3 and normalized[1] == ":" and normalized[2] == "/" and normalized[0].isalpha():
|
||||
expanded = f"/mnt/{normalized[0].lower()}/{normalized[3:]}"
|
||||
path = Path(expanded)
|
||||
if not path.is_absolute():
|
||||
base_dir = Path(os.getenv("TERMINAL_CWD", os.getcwd()))
|
||||
|
|
@ -1254,10 +1296,12 @@ def _detect_file_drop(user_input: str) -> "dict | None":
|
|||
or stripped.startswith("~")
|
||||
or stripped.startswith("./")
|
||||
or stripped.startswith("../")
|
||||
or (len(stripped) >= 3 and stripped[1] == ":" and stripped[2] in ("\\", "/") and stripped[0].isalpha())
|
||||
or stripped.startswith('"/')
|
||||
or stripped.startswith('"~')
|
||||
or stripped.startswith("'/")
|
||||
or stripped.startswith("'~")
|
||||
or (len(stripped) >= 4 and stripped[0] in ("'", '"') and stripped[2] == ":" and stripped[3] in ("\\", "/") and stripped[1].isalpha())
|
||||
)
|
||||
if not starts_like_path:
|
||||
return None
|
||||
|
|
@ -3125,21 +3169,6 @@ class HermesCLI:
|
|||
MAX_ASST_LEN = 200 # truncate assistant text
|
||||
MAX_ASST_LINES = 3 # max lines of assistant text
|
||||
|
||||
def _strip_reasoning(text: str) -> str:
|
||||
"""Remove <REASONING_SCRATCHPAD>...</REASONING_SCRATCHPAD> blocks
|
||||
from displayed text (reasoning model internal thoughts)."""
|
||||
import re
|
||||
cleaned = re.sub(
|
||||
r"<REASONING_SCRATCHPAD>.*?</REASONING_SCRATCHPAD>\s*",
|
||||
"", text, flags=re.DOTALL,
|
||||
)
|
||||
# Also strip unclosed reasoning tags at the end
|
||||
cleaned = re.sub(
|
||||
r"<REASONING_SCRATCHPAD>.*$",
|
||||
"", cleaned, flags=re.DOTALL,
|
||||
)
|
||||
return cleaned.strip()
|
||||
|
||||
# Collect displayable entries (skip system, tool-result messages)
|
||||
entries = [] # list of (role, display_text)
|
||||
_last_asst_idx = None # index of last assistant entry
|
||||
|
|
@ -3171,7 +3200,7 @@ class HermesCLI:
|
|||
|
||||
elif role == "assistant":
|
||||
text = "" if content is None else str(content)
|
||||
text = _strip_reasoning(text)
|
||||
text = _strip_reasoning_tags(text)
|
||||
parts = []
|
||||
full_parts = [] # un-truncated version
|
||||
if text:
|
||||
|
|
@ -3510,6 +3539,26 @@ class HermesCLI:
|
|||
killed = process_registry.kill_all()
|
||||
print(f" ✅ Stopped {killed} process(es).")
|
||||
|
||||
def _handle_agents_command(self):
|
||||
"""Handle /agents — show background processes and agent status."""
|
||||
from tools.process_registry import format_uptime_short, process_registry
|
||||
|
||||
processes = process_registry.list_sessions()
|
||||
running = [p for p in processes if p.get("status") == "running"]
|
||||
finished = [p for p in processes if p.get("status") != "running"]
|
||||
|
||||
_cprint(f" Running processes: {len(running)}")
|
||||
for p in running:
|
||||
cmd = p.get("command", "")[:80]
|
||||
up = format_uptime_short(p.get("uptime_seconds", 0))
|
||||
_cprint(f" {p.get('session_id', '?')} · {up} · {cmd}")
|
||||
|
||||
if finished:
|
||||
_cprint(f" Recently finished: {len(finished)}")
|
||||
|
||||
agent_running = getattr(self, "_agent_running", False)
|
||||
_cprint(f" Agent: {'running' if agent_running else 'idle'}")
|
||||
|
||||
def _handle_paste_command(self):
|
||||
"""Handle /paste — explicitly check clipboard for an image.
|
||||
|
||||
|
|
@ -3535,6 +3584,61 @@ class HermesCLI:
|
|||
else:
|
||||
_cprint(f" {_DIM}(._.) No image found in clipboard{_RST}")
|
||||
|
||||
def _write_osc52_clipboard(self, text: str) -> None:
|
||||
"""Copy *text* to terminal clipboard via OSC 52."""
|
||||
payload = base64.b64encode(text.encode("utf-8")).decode("ascii")
|
||||
seq = f"\x1b]52;c;{payload}\x07"
|
||||
out = getattr(self, "_app", None)
|
||||
output = getattr(out, "output", None) if out else None
|
||||
if output and hasattr(output, "write_raw"):
|
||||
output.write_raw(seq)
|
||||
output.flush()
|
||||
return
|
||||
if output and hasattr(output, "write"):
|
||||
output.write(seq)
|
||||
output.flush()
|
||||
return
|
||||
sys.stdout.write(seq)
|
||||
sys.stdout.flush()
|
||||
|
||||
def _handle_copy_command(self, cmd_original: str) -> None:
|
||||
"""Handle /copy [number] — copy assistant output to clipboard."""
|
||||
parts = cmd_original.split(maxsplit=1)
|
||||
arg = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
assistant = [m for m in self.conversation_history if m.get("role") == "assistant"]
|
||||
if not assistant:
|
||||
_cprint(" Nothing to copy yet.")
|
||||
return
|
||||
|
||||
if arg:
|
||||
try:
|
||||
idx = int(arg) - 1
|
||||
except ValueError:
|
||||
_cprint(" Usage: /copy [number]")
|
||||
return
|
||||
if idx < 0 or idx >= len(assistant):
|
||||
_cprint(f" Invalid response number. Use 1-{len(assistant)}.")
|
||||
return
|
||||
else:
|
||||
idx = len(assistant) - 1
|
||||
while idx >= 0 and not _assistant_copy_text(assistant[idx].get("content")):
|
||||
idx -= 1
|
||||
if idx < 0:
|
||||
_cprint(" Nothing to copy in assistant responses yet.")
|
||||
return
|
||||
|
||||
text = _assistant_copy_text(assistant[idx].get("content"))
|
||||
if not text:
|
||||
_cprint(" Nothing to copy in that assistant response.")
|
||||
return
|
||||
|
||||
try:
|
||||
self._write_osc52_clipboard(text)
|
||||
_cprint(f" Copied assistant response #{idx + 1} to clipboard")
|
||||
except Exception as e:
|
||||
_cprint(f" Clipboard copy failed: {e}")
|
||||
|
||||
def _handle_image_command(self, cmd_original: str):
|
||||
"""Handle /image <path> — attach a local image file for the next prompt."""
|
||||
raw_args = (cmd_original.split(None, 1)[1].strip() if " " in cmd_original else "")
|
||||
|
|
@ -3671,7 +3775,7 @@ class HermesCLI:
|
|||
skin = get_active_skin()
|
||||
separator_color = skin.get_color("banner_dim", "#B8860B")
|
||||
accent_color = skin.get_color("ui_accent", "#FFBF00")
|
||||
label_color = skin.get_color("ui_label", "#4dd0e1")
|
||||
label_color = skin.get_color("ui_label", "#DAA520")
|
||||
except Exception:
|
||||
separator_color, accent_color, label_color = "#B8860B", "#FFBF00", "cyan"
|
||||
toolsets_info = ""
|
||||
|
|
@ -5553,6 +5657,8 @@ class HermesCLI:
|
|||
self._show_usage()
|
||||
elif canonical == "insights":
|
||||
self._show_insights(cmd_original)
|
||||
elif canonical == "copy":
|
||||
self._handle_copy_command(cmd_original)
|
||||
elif canonical == "debug":
|
||||
self._handle_debug_command()
|
||||
elif canonical == "paste":
|
||||
|
|
@ -5596,6 +5702,8 @@ class HermesCLI:
|
|||
self._handle_snapshot_command(cmd_original)
|
||||
elif canonical == "stop":
|
||||
self._handle_stop_command()
|
||||
elif canonical == "agents":
|
||||
self._handle_agents_command()
|
||||
elif canonical == "background":
|
||||
self._handle_background_command(cmd_original)
|
||||
elif canonical == "btw":
|
||||
|
|
|
|||
|
|
@ -65,7 +65,15 @@ _HOME_TARGET_ENV_VARS = {
|
|||
"wecom": "WECOM_HOME_CHANNEL",
|
||||
"weixin": "WEIXIN_HOME_CHANNEL",
|
||||
"bluebubbles": "BLUEBUBBLES_HOME_CHANNEL",
|
||||
"qqbot": "QQ_HOME_CHANNEL",
|
||||
"qqbot": "QQBOT_HOME_CHANNEL",
|
||||
}
|
||||
|
||||
# Legacy env var names kept for back-compat. Each entry is the current
|
||||
# primary env var → the previous name. _get_home_target_chat_id falls
|
||||
# back to the legacy name if the primary is unset, so users who set the
|
||||
# old name before the rename keep working until they migrate.
|
||||
_LEGACY_HOME_TARGET_ENV_VARS = {
|
||||
"QQBOT_HOME_CHANNEL": "QQ_HOME_CHANNEL",
|
||||
}
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
|
|
@ -100,7 +108,12 @@ def _get_home_target_chat_id(platform_name: str) -> str:
|
|||
env_var = _HOME_TARGET_ENV_VARS.get(platform_name.lower())
|
||||
if not env_var:
|
||||
return ""
|
||||
return os.getenv(env_var, "")
|
||||
value = os.getenv(env_var, "")
|
||||
if not value:
|
||||
legacy = _LEGACY_HOME_TARGET_ENV_VARS.get(env_var)
|
||||
if legacy:
|
||||
value = os.getenv(legacy, "")
|
||||
return value
|
||||
|
||||
|
||||
def _resolve_single_delivery_target(job: dict, deliver_value: str) -> Optional[dict]:
|
||||
|
|
|
|||
108
docs/plans/2026-04-01-ink-gateway-tui-migration-plan.md
Normal file
108
docs/plans/2026-04-01-ink-gateway-tui-migration-plan.md
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# Ink Gateway TUI Migration — Post-mortem
|
||||
|
||||
Planned: 2026-04-01 · Delivered: 2026-04 · Status: shipped, classic (prompt_toolkit) CLI still present
|
||||
|
||||
## What Shipped
|
||||
|
||||
Three layers, same repo, Python runtime unchanged.
|
||||
|
||||
```
|
||||
ui-tui (Node/TS) ──stdio JSON-RPC──▶ tui_gateway (Py) ──▶ AIAgent (run_agent.py)
|
||||
```
|
||||
|
||||
### Backend — `tui_gateway/`
|
||||
|
||||
```
|
||||
tui_gateway/
|
||||
├── entry.py # subprocess entrypoint, stdio read/write loop
|
||||
├── server.py # everything: sessions dict, @method handlers, _emit
|
||||
├── render.py # stream renderer, diff rendering, message rendering
|
||||
├── slash_worker.py # subprocess that runs hermes_cli slash commands
|
||||
└── __init__.py
|
||||
```
|
||||
|
||||
`server.py` owns the full runtime-control surface: session store (`_sessions: dict[str, dict]`), method registry (`@method("…")` decorator), event emitter (`_emit`), agent lifecycle (`_make_agent`, `_init_session`, `_wire_callbacks`), approval/sudo/clarify round-trips, and JSON-RPC dispatch.
|
||||
|
||||
Protocol methods (`@method(...)` in `server.py`):
|
||||
|
||||
- session: `session.{create, resume, list, close, interrupt, usage, history, compress, branch, title, save, undo}`
|
||||
- prompt: `prompt.{submit, background, btw}`
|
||||
- tools: `tools.{list, show, configure}`
|
||||
- slash: `slash.exec`, `command.{dispatch, resolve}`, `commands.catalog`, `complete.{path, slash}`
|
||||
- approvals: `approval.respond`, `sudo.respond`, `clarify.respond`, `secret.respond`
|
||||
- config/state: `config.{get, set, show}`, `model.options`, `reload.mcp`
|
||||
- ops: `shell.exec`, `cli.exec`, `terminal.resize`, `input.detect_drop`, `clipboard.paste`, `paste.collapse`, `image.attach`, `process.stop`
|
||||
- misc: `agents.list`, `skills.manage`, `plugins.list`, `cron.manage`, `insights.get`, `rollback.{list, diff, restore}`, `browser.manage`
|
||||
|
||||
Protocol events (`_emit(…)` → handled in `ui-tui/src/app/createGatewayEventHandler.ts`):
|
||||
|
||||
- lifecycle: `gateway.{ready, stderr}`, `session.info`, `skin.changed`
|
||||
- stream: `message.{start, delta, complete}`, `thinking.delta`, `reasoning.{delta, available}`, `status.update`
|
||||
- tools: `tool.{start, progress, complete, generating}`, `subagent.{start, thinking, tool, progress, complete}`
|
||||
- interactive: `approval.request`, `sudo.request`, `clarify.request`, `secret.request`
|
||||
- async: `background.complete`, `btw.complete`, `error`
|
||||
|
||||
### Frontend — `ui-tui/src/`
|
||||
|
||||
```
|
||||
src/
|
||||
├── entry.tsx # node bootstrap: bootBanner → spawn python → dynamic-import Ink → render(<App/>)
|
||||
├── app.tsx # <GatewayProvider> wraps <AppLayout>
|
||||
├── bootBanner.ts # raw-ANSI banner to stdout in ~2ms, pre-React
|
||||
├── gatewayClient.ts # JSON-RPC client over child_process stdio
|
||||
├── gatewayTypes.ts # typed RPC responses + GatewayEvent union
|
||||
├── theme.ts # DEFAULT_THEME + fromSkin
|
||||
│
|
||||
├── app/ # hooks + stores — the orchestration layer
|
||||
│ ├── uiStore.ts # nanostore: sid, info, busy, usage, theme, status…
|
||||
│ ├── turnStore.ts # nanostore: per-turn activity / reasoning / tools
|
||||
│ ├── turnController.ts # imperative singleton for stream-time operations
|
||||
│ ├── overlayStore.ts # nanostore: modal/overlay state
|
||||
│ ├── useMainApp.ts # top-level composition hook
|
||||
│ ├── useSessionLifecycle.ts # session.create/resume/close/reset
|
||||
│ ├── useSubmission.ts # shell/slash/prompt dispatch + interpolation
|
||||
│ ├── useConfigSync.ts # config.get + mtime poll
|
||||
│ ├── useComposerState.ts # input buffer, paste snippets, editor mode
|
||||
│ ├── useInputHandlers.ts # key bindings
|
||||
│ ├── createGatewayEventHandler.ts # event-stream dispatcher
|
||||
│ ├── createSlashHandler.ts # slash command router (registry + python fallback)
|
||||
│ └── slash/commands/ # core.ts, ops.ts, session.ts — TS-owned slash commands
|
||||
│
|
||||
├── components/ # AppLayout, AppChrome, AppOverlays, MessageLine, Thinking, Markdown, pickers, prompts, Banner, SessionPanel
|
||||
├── config/ # env, limits, timing constants
|
||||
├── content/ # charms, faces, fortunes, hotkeys, placeholders, verbs
|
||||
├── domain/ # details, messages, paths, roles, slash, usage, viewport
|
||||
├── protocol/ # interpolation, paste regex
|
||||
├── hooks/ # useCompletion, useInputHistory, useQueue, useVirtualHistory
|
||||
└── lib/ # history, messages, osc52, rpc, text
|
||||
```
|
||||
|
||||
### CLI entry points — `hermes_cli/main.py`
|
||||
|
||||
- `hermes --tui` → `node dist/entry.js` (auto-builds when `.ts`/`.tsx` newer than `dist/entry.js`)
|
||||
- `hermes --tui --dev` → `tsx src/entry.tsx` (skip build)
|
||||
- `HERMES_TUI_DIR=…` → external prebuilt dist (nix, distro packaging)
|
||||
|
||||
## Diverged From Original Plan
|
||||
|
||||
| Plan | Reality | Why |
|
||||
|---|---|---|
|
||||
| `tui_gateway/{controller,session_state,events,protocol}.py` | all collapsed into `server.py` | no second consumer ever emerged, keeping one file cheaper than four |
|
||||
| `ui-tui/src/main.tsx` | split into `entry.tsx` (bootstrap) + `app.tsx` (shell) | boot banner + early python spawn wanted a pre-React moment |
|
||||
| `ui-tui/src/state/store.ts` | three nanostores (`uiStore`, `turnStore`, `overlayStore`) | separate lifetimes: ui persists, turn resets per reply, overlay is modal |
|
||||
| `approval.requested` / `sudo.requested` / `clarify.requested` | `*.request` (no `-ed`) | cosmetic |
|
||||
| `session.cancel` | dropped | `session.interrupt` covers it |
|
||||
| `HERMES_EXPERIMENTAL_TUI=1`, `display.experimental_tui: true`, `/tui on/off/status` | none shipped | `--tui` went from opt-in to first-class without an experimental phase |
|
||||
|
||||
## Post-migration Additions (not in original plan)
|
||||
|
||||
- **Async `session.create`** — returns sid in ~1ms, agent builds on a background thread, `session.info` broadcasts when ready; `_wait_agent()` gates every agent-touching handler via `_sess`
|
||||
- **`bootBanner`** — raw-ANSI logo painted to stdout at T≈2ms, before Ink loads; `<AlternateScreen>` wipes it seamlessly when React mounts
|
||||
- **Selection uniform bg** — `theme.color.selectionBg` wired via `useSelection().setSelectionBgColor`; replaces SGR-inverse per-cell swap that fragmented over amber/gold fg
|
||||
- **Slash command registry** — TS-owned commands in `app/slash/commands/{core,ops,session}.ts`, everything else falls through to `slash.exec` (python worker)
|
||||
- **Turn store + controller split** — imperative singleton (`turnController`) holds refs/timers, nanostore (`turnStore`) holds render-visible state
|
||||
|
||||
## What's Still Open
|
||||
|
||||
- **Classic CLI not deleted.** `cli.py` still has ~80 `prompt_toolkit` references; classic REPL is still the default when `--tui` is absent. The original plan's "Cut 4 · prompt_toolkit removal later" hasn't happened.
|
||||
- **No config-file opt-in.** `HERMES_EXPERIMENTAL_TUI` and `display.experimental_tui` were never built; only the CLI flag exists. Fine for now — if we want "default to TUI", a single line in `main.py` flips it.
|
||||
|
|
@ -6,6 +6,11 @@
|
|||
# All fields are optional — missing values inherit from the default skin.
|
||||
# Activate with: /skin <name> or display.skin: <name> in config.yaml
|
||||
#
|
||||
# Keys are marked:
|
||||
# (both) — applies to both the classic CLI and the TUI
|
||||
# (classic) — classic CLI only (see hermes --tui in user-guide/tui.md)
|
||||
# (tui) — TUI only
|
||||
#
|
||||
# See hermes_cli/skin_engine.py for the full schema reference.
|
||||
# ============================================================================
|
||||
|
||||
|
|
@ -14,43 +19,47 @@ name: example
|
|||
description: An example custom skin — copy and modify this template
|
||||
|
||||
# ── Colors ──────────────────────────────────────────────────────────────────
|
||||
# Hex color values for Rich markup. These control the CLI's visual palette.
|
||||
# Hex color values. These control the visual palette.
|
||||
colors:
|
||||
# Banner panel (the startup welcome box)
|
||||
# Banner panel (the startup welcome box) — (both)
|
||||
banner_border: "#CD7F32" # Panel border
|
||||
banner_title: "#FFD700" # Panel title text
|
||||
banner_accent: "#FFBF00" # Section headers (Available Tools, Skills, etc.)
|
||||
banner_dim: "#B8860B" # Dim/muted text (separators, model info)
|
||||
banner_text: "#FFF8DC" # Body text (tool names, skill names)
|
||||
|
||||
# UI elements
|
||||
ui_accent: "#FFBF00" # General accent color
|
||||
# UI elements — (both)
|
||||
ui_accent: "#FFBF00" # General accent (falls back to banner_accent)
|
||||
ui_label: "#4dd0e1" # Labels
|
||||
ui_ok: "#4caf50" # Success indicators
|
||||
ui_error: "#ef5350" # Error indicators
|
||||
ui_warn: "#ffa726" # Warning indicators
|
||||
|
||||
# Input area
|
||||
prompt: "#FFF8DC" # Prompt text color
|
||||
input_rule: "#CD7F32" # Horizontal rule around input
|
||||
prompt: "#FFF8DC" # Prompt text / `❯` glyph color (both)
|
||||
input_rule: "#CD7F32" # Horizontal rule above input (classic)
|
||||
|
||||
# Response box
|
||||
response_border: "#FFD700" # Response box border (ANSI color)
|
||||
# Response box — (classic)
|
||||
response_border: "#FFD700" # Response box border
|
||||
|
||||
# Session display
|
||||
session_label: "#DAA520" # Session label
|
||||
session_border: "#8B8682" # Session ID dim color
|
||||
# Session display — (both)
|
||||
session_label: "#DAA520" # "Session: " label
|
||||
session_border: "#8B8682" # Session ID text
|
||||
|
||||
# TUI surfaces
|
||||
status_bar_bg: "#1a1a2e" # Status / usage bar background
|
||||
voice_status_bg: "#1a1a2e" # Voice-mode badge background
|
||||
completion_menu_bg: "#1a1a2e" # Completion list background
|
||||
completion_menu_current_bg: "#333355" # Active completion row background
|
||||
completion_menu_meta_bg: "#1a1a2e" # Completion meta column background
|
||||
completion_menu_meta_current_bg: "#333355" # Active completion meta background
|
||||
# TUI / CLI surfaces — (classic: status bar, voice badge, completion meta)
|
||||
status_bar_bg: "#1a1a2e" # Status / usage bar background (classic)
|
||||
voice_status_bg: "#1a1a2e" # Voice-mode badge background (classic)
|
||||
completion_menu_bg: "#1a1a2e" # Completion list background (both)
|
||||
completion_menu_current_bg: "#333355" # Active completion row background (both)
|
||||
completion_menu_meta_bg: "#1a1a2e" # Completion meta column bg (classic)
|
||||
completion_menu_meta_current_bg: "#333355" # Active meta bg (classic)
|
||||
|
||||
# Drag-to-select background — (tui)
|
||||
selection_bg: "#3a3a55" # Uniform selection highlight in the TUI
|
||||
|
||||
# ── Spinner ─────────────────────────────────────────────────────────────────
|
||||
# Customize the animated spinner shown during API calls and tool execution.
|
||||
# (classic) — the TUI uses its own animated indicators; spinner config here
|
||||
# is only read by the classic prompt_toolkit CLI.
|
||||
spinner:
|
||||
# Faces shown while waiting for the API response
|
||||
waiting_faces:
|
||||
|
|
@ -78,17 +87,17 @@ spinner:
|
|||
# - ["⟪▲", "▲⟫"]
|
||||
|
||||
# ── Branding ────────────────────────────────────────────────────────────────
|
||||
# Text strings used throughout the CLI interface.
|
||||
# Text strings used throughout the interface.
|
||||
branding:
|
||||
agent_name: "Hermes Agent" # Banner title, about display
|
||||
welcome: "Welcome! Type your message or /help for commands."
|
||||
goodbye: "Goodbye! ⚕" # Exit message
|
||||
response_label: " ⚕ Hermes " # Response box header label
|
||||
prompt_symbol: "❯ " # Input prompt symbol
|
||||
help_header: "(^_^)? Available Commands" # /help header text
|
||||
agent_name: "Hermes Agent" # (both) Banner title, about display
|
||||
welcome: "Welcome! Type your message or /help for commands." # (both)
|
||||
goodbye: "Goodbye! ⚕" # (both) Exit message
|
||||
response_label: " ⚕ Hermes " # (classic) Response box header label
|
||||
prompt_symbol: "❯ " # (both) Input prompt glyph
|
||||
help_header: "(^_^)? Available Commands" # (both) /help overlay title
|
||||
|
||||
# ── Tool Output ─────────────────────────────────────────────────────────────
|
||||
# Character used as the prefix for tool output lines.
|
||||
# Character used as the prefix for tool output lines. (both)
|
||||
# Default is "┊" (thin dotted vertical line). Some alternatives:
|
||||
# "╎" (light triple dash vertical)
|
||||
# "▏" (left one-eighth block)
|
||||
|
|
|
|||
21
flake.lock
generated
21
flake.lock
generated
|
|
@ -36,6 +36,26 @@
|
|||
"type": "github"
|
||||
}
|
||||
},
|
||||
"npm-lockfile-fix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1775903712,
|
||||
"narHash": "sha256-2GV79U6iVH4gKAPWYrxUReB0S41ty/Y3dBLquU8AlaA=",
|
||||
"owner": "jeslie0",
|
||||
"repo": "npm-lockfile-fix",
|
||||
"rev": "c6093acb0c0548e0f9b8b3d82918823721930fe8",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "jeslie0",
|
||||
"repo": "npm-lockfile-fix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
|
|
@ -124,6 +144,7 @@
|
|||
"inputs": {
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"npm-lockfile-fix": "npm-lockfile-fix",
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix_2",
|
||||
"uv2nix": "uv2nix_2"
|
||||
|
|
|
|||
13
flake.nix
13
flake.nix
|
|
@ -19,11 +19,20 @@
|
|||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
npm-lockfile-fix = {
|
||||
url = "github:jeslie0/npm-lockfile-fix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
};
|
||||
|
||||
outputs = inputs:
|
||||
outputs =
|
||||
inputs:
|
||||
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
|
||||
systems = [ "x86_64-linux" "aarch64-linux" "aarch64-darwin" ];
|
||||
systems = [
|
||||
"x86_64-linux"
|
||||
"aarch64-linux"
|
||||
"aarch64-darwin"
|
||||
];
|
||||
|
||||
imports = [
|
||||
./nix/packages.nix
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
|||
|
||||
|
||||
def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
"""Enumerate all text channels the Discord bot can see."""
|
||||
"""Enumerate all text channels and forum channels the Discord bot can see."""
|
||||
channels = []
|
||||
client = getattr(adapter, "_client", None)
|
||||
if not client:
|
||||
|
|
@ -119,6 +119,15 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
|
|||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
})
|
||||
# Forum channels (type 15) — creating a message auto-spawns a thread post.
|
||||
forums = getattr(guild, "forum_channels", None) or []
|
||||
for ch in forums:
|
||||
channels.append({
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "forum",
|
||||
})
|
||||
# Also include DM-capable users we've interacted with is not
|
||||
# feasible via guild enumeration; those come from sessions.
|
||||
|
||||
|
|
@ -191,6 +200,15 @@ def load_directory() -> Dict[str, Any]:
|
|||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
|
||||
def lookup_channel_type(platform_name: str, chat_id: str) -> Optional[str]:
|
||||
"""Return the channel ``type`` string (e.g. ``"channel"``, ``"forum"``) for *chat_id*, or *None* if unknown."""
|
||||
directory = load_directory()
|
||||
for ch in directory.get("platforms", {}).get(platform_name, []):
|
||||
if ch.get("id") == chat_id:
|
||||
return ch.get("type")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_channel_name(platform_name: str, name: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a human-friendly channel name to a numeric ID.
|
||||
|
|
|
|||
|
|
@ -258,6 +258,13 @@ class GatewayConfig:
|
|||
# Streaming configuration
|
||||
streaming: StreamingConfig = field(default_factory=StreamingConfig)
|
||||
|
||||
# Session store pruning: drop SessionEntry records older than this many
|
||||
# days from the in-memory dict and sessions.json. Keeps the store from
|
||||
# growing unbounded in gateways serving many chats/threads/users over
|
||||
# months. Pruning is invisible to users — if they resume, they get a
|
||||
# fresh session exactly as if the reset policy had fired. 0 = disabled.
|
||||
session_store_max_age_days: int = 90
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
|
|
@ -365,6 +372,7 @@ class GatewayConfig:
|
|||
"thread_sessions_per_user": self.thread_sessions_per_user,
|
||||
"unauthorized_dm_behavior": self.unauthorized_dm_behavior,
|
||||
"streaming": self.streaming.to_dict(),
|
||||
"session_store_max_age_days": self.session_store_max_age_days,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -412,6 +420,13 @@ class GatewayConfig:
|
|||
"pair",
|
||||
)
|
||||
|
||||
try:
|
||||
session_store_max_age_days = int(data.get("session_store_max_age_days", 90))
|
||||
if session_store_max_age_days < 0:
|
||||
session_store_max_age_days = 0
|
||||
except (TypeError, ValueError):
|
||||
session_store_max_age_days = 90
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
|
|
@ -426,6 +441,7 @@ class GatewayConfig:
|
|||
thread_sessions_per_user=_coerce_bool(thread_sessions_per_user, False),
|
||||
unauthorized_dm_behavior=unauthorized_dm_behavior,
|
||||
streaming=StreamingConfig.from_dict(data.get("streaming", {})),
|
||||
session_store_max_age_days=session_store_max_age_days,
|
||||
)
|
||||
|
||||
def get_unauthorized_dm_behavior(self, platform: Optional[Platform] = None) -> str:
|
||||
|
|
@ -1213,12 +1229,24 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
|||
qq_group_allowed = os.getenv("QQ_GROUP_ALLOWED_USERS", "").strip()
|
||||
if qq_group_allowed:
|
||||
extra["group_allow_from"] = qq_group_allowed
|
||||
qq_home = os.getenv("QQ_HOME_CHANNEL", "").strip()
|
||||
qq_home = os.getenv("QQBOT_HOME_CHANNEL", "").strip()
|
||||
qq_home_name_env = "QQBOT_HOME_CHANNEL_NAME"
|
||||
if not qq_home:
|
||||
# Back-compat: accept the pre-rename name and log a one-time warning.
|
||||
legacy_home = os.getenv("QQ_HOME_CHANNEL", "").strip()
|
||||
if legacy_home:
|
||||
qq_home = legacy_home
|
||||
qq_home_name_env = "QQ_HOME_CHANNEL_NAME"
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"QQ_HOME_CHANNEL is deprecated; rename to QQBOT_HOME_CHANNEL "
|
||||
"in your .env for consistency with the platform key."
|
||||
)
|
||||
if qq_home:
|
||||
config.platforms[Platform.QQBOT].home_channel = HomeChannel(
|
||||
platform=Platform.QQBOT,
|
||||
chat_id=qq_home,
|
||||
name=os.getenv("QQ_HOME_CHANNEL_NAME", "Home"),
|
||||
name=os.getenv("QQBOT_HOME_CHANNEL_NAME") or os.getenv(qq_home_name_env, "Home"),
|
||||
)
|
||||
|
||||
# Session settings
|
||||
|
|
|
|||
|
|
@ -1045,16 +1045,40 @@ class BasePlatformAdapter(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
# Default: the adapter treats ``finalize=True`` on edit_message as a
|
||||
# no-op and is happy to have the stream consumer skip redundant final
|
||||
# edits. Subclasses that *require* an explicit finalize call to close
|
||||
# out the message lifecycle (e.g. rich card / AI assistant surfaces
|
||||
# such as DingTalk AI Cards) override this to True (class attribute or
|
||||
# property) so the stream consumer knows not to short-circuit.
|
||||
REQUIRES_EDIT_FINALIZE: bool = False
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Edit a previously sent message. Optional — platforms that don't
|
||||
support editing return success=False and callers fall back to
|
||||
sending a new message.
|
||||
|
||||
``finalize`` signals that this is the last edit in a streaming
|
||||
sequence. Most platforms (Telegram, Slack, Discord, Matrix,
|
||||
etc.) treat it as a no-op because their edit APIs have no notion
|
||||
of message lifecycle state — an edit is an edit. Platforms that
|
||||
render streaming updates with a distinct "in progress" state and
|
||||
require explicit closure (e.g. rich card / AI assistant surfaces
|
||||
such as DingTalk AI Cards) use it to finalize the message and
|
||||
transition the UI out of the streaming indicator — those should
|
||||
also set ``REQUIRES_EDIT_FINALIZE = True`` so callers route a
|
||||
final edit through even when content is unchanged. Callers
|
||||
should set ``finalize=True`` on the final edit of a streamed
|
||||
response (typically when ``got_done`` fires in the stream
|
||||
consumer) and leave it ``False`` on intermediate edits.
|
||||
"""
|
||||
return SendResult(success=False, error="Not supported")
|
||||
|
||||
|
|
@ -1579,7 +1603,9 @@ class BasePlatformAdapter(ABC):
|
|||
# session lifecycle and its cleanup races with the running task
|
||||
# (see PR #4926).
|
||||
cmd = event.get_command()
|
||||
if cmd in ("approve", "deny", "status", "stop", "new", "reset", "background", "restart", "queue", "q"):
|
||||
from hermes_cli.commands import should_bypass_active_session
|
||||
|
||||
if should_bypass_active_session(cmd):
|
||||
logger.debug(
|
||||
"[%s] Command '/%s' bypassing active-session guard for %s",
|
||||
self.name, cmd, session_key,
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -857,6 +857,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
When metadata contains a thread_id, the message is sent to that
|
||||
thread instead of the parent channel identified by chat_id.
|
||||
|
||||
Forum channels (type 15) reject direct messages — a thread post is
|
||||
created automatically.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
|
@ -882,6 +885,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Forum channels reject channel.send() — create a thread post instead.
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._send_to_forum(channel, content)
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
|
@ -945,6 +952,120 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def _send_to_forum(self, forum_channel: Any, content: str) -> SendResult:
|
||||
"""Create a thread post in a forum channel with the message as starter content.
|
||||
|
||||
Forum channels (type 15) don't support direct messages. Instead we
|
||||
POST to /channels/{forum_id}/threads with a thread name derived from
|
||||
the first line of the message. Any follow-up chunk failures are
|
||||
reported in ``raw_response['warnings']`` so the caller can surface
|
||||
partial-send issues.
|
||||
"""
|
||||
from tools.send_message_tool import _derive_forum_thread_name
|
||||
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
thread_name = _derive_forum_thread_name(content)
|
||||
|
||||
starter_content = chunks[0] if chunks else thread_name
|
||||
|
||||
try:
|
||||
thread = await forum_channel.create_thread(
|
||||
name=thread_name,
|
||||
content=starter_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to create forum thread in %s: %s", self.name, forum_channel.id, e)
|
||||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||||
|
||||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||||
starter_msg = getattr(thread, "message", None)
|
||||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||||
|
||||
# Send remaining chunks into the newly created thread. Track any
|
||||
# per-chunk failures so the caller sees partial-send outcomes.
|
||||
message_ids = [message_id]
|
||||
warnings: list[str] = []
|
||||
for chunk in chunks[1:]:
|
||||
try:
|
||||
msg = await thread_channel.send(content=chunk)
|
||||
message_ids.append(str(msg.id))
|
||||
except Exception as e:
|
||||
warning = f"Failed to send follow-up chunk to forum thread {thread_id}: {e}"
|
||||
logger.warning("[%s] %s", self.name, warning)
|
||||
warnings.append(warning)
|
||||
|
||||
raw_response: Dict[str, Any] = {"message_ids": message_ids, "thread_id": thread_id}
|
||||
if warnings:
|
||||
raw_response["warnings"] = warnings
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0],
|
||||
raw_response=raw_response,
|
||||
)
|
||||
|
||||
async def _forum_post_file(
|
||||
self,
|
||||
forum_channel: Any,
|
||||
*,
|
||||
thread_name: Optional[str] = None,
|
||||
content: str = "",
|
||||
file: Any = None,
|
||||
files: Optional[list] = None,
|
||||
) -> SendResult:
|
||||
"""Create a forum thread whose starter message carries file attachments.
|
||||
|
||||
Used by the send_voice / send_image_file / send_document paths when
|
||||
the target channel is a forum (type 15). ``create_thread`` on a
|
||||
ForumChannel accepts the same file/files/content kwargs as
|
||||
``channel.send``, creating the thread and starter message atomically.
|
||||
"""
|
||||
from tools.send_message_tool import _derive_forum_thread_name
|
||||
|
||||
if not thread_name:
|
||||
# Prefer the text content, fall back to the first attached
|
||||
# filename, fall back to the generic default.
|
||||
hint = content or ""
|
||||
if not hint.strip():
|
||||
if file is not None:
|
||||
hint = getattr(file, "filename", "") or ""
|
||||
elif files:
|
||||
hint = getattr(files[0], "filename", "") or ""
|
||||
thread_name = _derive_forum_thread_name(hint) if hint.strip() else "New Post"
|
||||
|
||||
kwargs: Dict[str, Any] = {"name": thread_name}
|
||||
if content:
|
||||
kwargs["content"] = content
|
||||
if file is not None:
|
||||
kwargs["file"] = file
|
||||
if files:
|
||||
kwargs["files"] = files
|
||||
|
||||
try:
|
||||
thread = await forum_channel.create_thread(**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"[%s] Failed to create forum thread with file in %s: %s",
|
||||
self.name,
|
||||
getattr(forum_channel, "id", "?"),
|
||||
e,
|
||||
)
|
||||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||||
|
||||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||||
starter_msg = getattr(thread, "message", None)
|
||||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_id,
|
||||
raw_response={"thread_id": thread_id},
|
||||
)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
|
@ -975,7 +1096,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local file as a Discord attachment."""
|
||||
"""Send a local file as a Discord attachment.
|
||||
|
||||
Forum channels (type 15) get a new thread whose starter message
|
||||
carries the file — they reject direct POST /messages.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
|
@ -988,6 +1113,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
filename = file_name or os.path.basename(file_path)
|
||||
with open(file_path, "rb") as fh:
|
||||
file = discord.File(fh, filename=filename)
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
msg = await channel.send(content=caption if caption else None, file=file)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
|
|
@ -1036,6 +1167,18 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
with open(audio_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Forum channels (type 15) reject direct POST /messages — the
|
||||
# native voice flag path also targets /messages so it would fail
|
||||
# too. Create a thread post with the audio as the starter
|
||||
# attachment instead.
|
||||
if self._is_forum_parent(channel):
|
||||
forum_file = discord.File(io.BytesIO(file_data), filename=filename)
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=forum_file,
|
||||
)
|
||||
|
||||
# Try sending as a native voice message via raw API (flags=8192).
|
||||
try:
|
||||
import base64
|
||||
|
|
@ -1488,6 +1631,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
import io
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
|
|
@ -1550,6 +1700,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
import io
|
||||
file = discord.File(io.BytesIO(animation_data), filename="animation.gif")
|
||||
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._forum_post_file(
|
||||
channel,
|
||||
content=(caption or "").strip(),
|
||||
file=file,
|
||||
)
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
|
|
|
|||
|
|
@ -1228,6 +1228,10 @@ class FeishuAdapter(BasePlatformAdapter):
|
|||
.register_p2_im_chat_member_bot_deleted_v1(self._on_bot_removed_from_chat)
|
||||
.register_p2_im_chat_access_event_bot_p2p_chat_entered_v1(self._on_p2p_chat_entered)
|
||||
.register_p2_im_message_recalled_v1(self._on_message_recalled)
|
||||
.register_p2_customized_event(
|
||||
"drive.notice.comment_add_v1",
|
||||
self._on_drive_comment_event,
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
|
|
@ -1965,6 +1969,25 @@ class FeishuAdapter(BasePlatformAdapter):
|
|||
def _on_message_recalled(self, data: Any) -> None:
|
||||
logger.debug("[Feishu] Message recalled by user")
|
||||
|
||||
def _on_drive_comment_event(self, data: Any) -> None:
|
||||
"""Handle drive document comment notification (drive.notice.comment_add_v1).
|
||||
|
||||
Delegates to :mod:`gateway.platforms.feishu_comment` for parsing,
|
||||
logging, and reaction. Scheduling follows the same
|
||||
``run_coroutine_threadsafe`` pattern used by ``_on_message_event``.
|
||||
"""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
|
||||
loop = self._loop
|
||||
if not self._loop_accepts_callbacks(loop):
|
||||
logger.warning("[Feishu] Dropping drive comment event before adapter loop is ready")
|
||||
return
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
handle_drive_comment_event(self._client, data, self_open_id=self._bot_open_id),
|
||||
loop,
|
||||
)
|
||||
future.add_done_callback(self._log_background_failure)
|
||||
|
||||
def _on_reaction_event(self, event_type: str, data: Any) -> None:
|
||||
"""Route user reactions on bot messages as synthetic text events."""
|
||||
event = getattr(data, "event", None)
|
||||
|
|
@ -2590,6 +2613,8 @@ class FeishuAdapter(BasePlatformAdapter):
|
|||
self._on_reaction_event(event_type, data)
|
||||
elif event_type == "card.action.trigger":
|
||||
self._on_card_action_trigger(data)
|
||||
elif event_type == "drive.notice.comment_add_v1":
|
||||
self._on_drive_comment_event(data)
|
||||
else:
|
||||
logger.debug("[Feishu] Ignoring webhook event type: %s", event_type or "unknown")
|
||||
return web.json_response({"code": 0, "msg": "ok"})
|
||||
|
|
|
|||
1383
gateway/platforms/feishu_comment.py
Normal file
1383
gateway/platforms/feishu_comment.py
Normal file
File diff suppressed because it is too large
Load diff
429
gateway/platforms/feishu_comment_rules.py
Normal file
429
gateway/platforms/feishu_comment_rules.py
Normal file
|
|
@ -0,0 +1,429 @@
|
|||
"""
|
||||
Feishu document comment access-control rules.
|
||||
|
||||
3-tier rule resolution: exact doc > wildcard "*" > top-level > code defaults.
|
||||
Each field (enabled/policy/allow_from) falls back independently.
|
||||
Config: ~/.hermes/feishu_comment_rules.json (mtime-cached, hot-reload).
|
||||
Pairing store: ~/.hermes/feishu_comment_pairing.json.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Paths
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Uses the canonical ``get_hermes_home()`` helper (HERMES_HOME-aware and
|
||||
# profile-safe). Resolved at import time; this module is lazy-imported by
|
||||
# the Feishu comment event handler, which runs long after profile overrides
|
||||
# have been applied, so freezing paths here is safe.
|
||||
|
||||
RULES_FILE = get_hermes_home() / "feishu_comment_rules.json"
|
||||
PAIRING_FILE = get_hermes_home() / "feishu_comment_pairing.json"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VALID_POLICIES = ("allowlist", "pairing")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommentDocumentRule:
|
||||
"""Per-document rule. ``None`` means 'inherit from lower tier'."""
|
||||
enabled: Optional[bool] = None
|
||||
policy: Optional[str] = None
|
||||
allow_from: Optional[frozenset] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommentsConfig:
|
||||
"""Top-level comment access config."""
|
||||
enabled: bool = True
|
||||
policy: str = "pairing"
|
||||
allow_from: frozenset = field(default_factory=frozenset)
|
||||
documents: Dict[str, CommentDocumentRule] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedCommentRule:
|
||||
"""Fully resolved rule after field-by-field fallback."""
|
||||
enabled: bool
|
||||
policy: str
|
||||
allow_from: frozenset
|
||||
match_source: str # e.g. "exact:docx:xxx" | "wildcard" | "top" | "default"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mtime-cached file loading
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _MtimeCache:
|
||||
"""Generic mtime-based file cache. ``stat()`` per access, re-read only on change."""
|
||||
|
||||
def __init__(self, path: Path):
|
||||
self._path = path
|
||||
self._mtime: float = 0.0
|
||||
self._data: Optional[dict] = None
|
||||
|
||||
def load(self) -> dict:
|
||||
try:
|
||||
st = self._path.stat()
|
||||
mtime = st.st_mtime
|
||||
except FileNotFoundError:
|
||||
self._mtime = 0.0
|
||||
self._data = {}
|
||||
return {}
|
||||
|
||||
if mtime == self._mtime and self._data is not None:
|
||||
return self._data
|
||||
|
||||
try:
|
||||
with open(self._path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict):
|
||||
data = {}
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("[Feishu-Rules] Failed to read %s, using empty config", self._path)
|
||||
data = {}
|
||||
|
||||
self._mtime = mtime
|
||||
self._data = data
|
||||
return data
|
||||
|
||||
|
||||
_rules_cache = _MtimeCache(RULES_FILE)
|
||||
_pairing_cache = _MtimeCache(PAIRING_FILE)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_frozenset(raw: Any) -> Optional[frozenset]:
|
||||
"""Parse a list of strings into a frozenset; return None if key absent."""
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, (list, tuple)):
|
||||
return frozenset(str(u).strip() for u in raw if str(u).strip())
|
||||
return None
|
||||
|
||||
|
||||
def _parse_document_rule(raw: dict) -> CommentDocumentRule:
|
||||
enabled = raw.get("enabled")
|
||||
if enabled is not None:
|
||||
enabled = bool(enabled)
|
||||
policy = raw.get("policy")
|
||||
if policy is not None:
|
||||
policy = str(policy).strip().lower()
|
||||
if policy not in _VALID_POLICIES:
|
||||
policy = None
|
||||
allow_from = _parse_frozenset(raw.get("allow_from"))
|
||||
return CommentDocumentRule(enabled=enabled, policy=policy, allow_from=allow_from)
|
||||
|
||||
|
||||
def load_config() -> CommentsConfig:
|
||||
"""Load comment rules from disk (mtime-cached)."""
|
||||
raw = _rules_cache.load()
|
||||
if not raw:
|
||||
return CommentsConfig()
|
||||
|
||||
documents: Dict[str, CommentDocumentRule] = {}
|
||||
raw_docs = raw.get("documents", {})
|
||||
if isinstance(raw_docs, dict):
|
||||
for key, rule_raw in raw_docs.items():
|
||||
if isinstance(rule_raw, dict):
|
||||
documents[str(key)] = _parse_document_rule(rule_raw)
|
||||
|
||||
policy = str(raw.get("policy", "pairing")).strip().lower()
|
||||
if policy not in _VALID_POLICIES:
|
||||
policy = "pairing"
|
||||
|
||||
return CommentsConfig(
|
||||
enabled=raw.get("enabled", True),
|
||||
policy=policy,
|
||||
allow_from=_parse_frozenset(raw.get("allow_from")) or frozenset(),
|
||||
documents=documents,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rule resolution (§8.4 field-by-field fallback)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def has_wiki_keys(cfg: CommentsConfig) -> bool:
|
||||
"""Check if any document rule key starts with 'wiki:'."""
|
||||
return any(k.startswith("wiki:") for k in cfg.documents)
|
||||
|
||||
|
||||
def resolve_rule(
|
||||
cfg: CommentsConfig,
|
||||
file_type: str,
|
||||
file_token: str,
|
||||
wiki_token: str = "",
|
||||
) -> ResolvedCommentRule:
|
||||
"""Resolve effective rule: exact doc → wiki key → wildcard → top-level → defaults."""
|
||||
exact_key = f"{file_type}:{file_token}"
|
||||
|
||||
exact = cfg.documents.get(exact_key)
|
||||
exact_src = f"exact:{exact_key}"
|
||||
if exact is None and wiki_token:
|
||||
wiki_key = f"wiki:{wiki_token}"
|
||||
exact = cfg.documents.get(wiki_key)
|
||||
exact_src = f"exact:{wiki_key}"
|
||||
|
||||
wildcard = cfg.documents.get("*")
|
||||
|
||||
layers = []
|
||||
if exact is not None:
|
||||
layers.append((exact, exact_src))
|
||||
if wildcard is not None:
|
||||
layers.append((wildcard, "wildcard"))
|
||||
|
||||
def _pick(field_name: str):
|
||||
for layer, source in layers:
|
||||
val = getattr(layer, field_name)
|
||||
if val is not None:
|
||||
return val, source
|
||||
return getattr(cfg, field_name), "top"
|
||||
|
||||
enabled, en_src = _pick("enabled")
|
||||
policy, pol_src = _pick("policy")
|
||||
allow_from, _ = _pick("allow_from")
|
||||
|
||||
# match_source = highest-priority tier that contributed any field
|
||||
priority_order = {"exact": 0, "wildcard": 1, "top": 2}
|
||||
best_src = min(
|
||||
[en_src, pol_src],
|
||||
key=lambda s: priority_order.get(s.split(":")[0], 3),
|
||||
)
|
||||
|
||||
return ResolvedCommentRule(
|
||||
enabled=enabled,
|
||||
policy=policy,
|
||||
allow_from=allow_from,
|
||||
match_source=best_src,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pairing store
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_pairing_approved() -> set:
|
||||
"""Return set of approved user open_ids (mtime-cached)."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if isinstance(approved, dict):
|
||||
return set(approved.keys())
|
||||
if isinstance(approved, list):
|
||||
return set(str(u) for u in approved if u)
|
||||
return set()
|
||||
|
||||
|
||||
def _save_pairing(data: dict) -> None:
|
||||
PAIRING_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = PAIRING_FILE.with_suffix(".tmp")
|
||||
with open(tmp, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
tmp.replace(PAIRING_FILE)
|
||||
# Invalidate cache so next load picks up change
|
||||
_pairing_cache._mtime = 0.0
|
||||
_pairing_cache._data = None
|
||||
|
||||
|
||||
def pairing_add(user_open_id: str) -> bool:
|
||||
"""Add a user to the pairing-approved list. Returns True if newly added."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if not isinstance(approved, dict):
|
||||
approved = {}
|
||||
if user_open_id in approved:
|
||||
return False
|
||||
approved[user_open_id] = {"approved_at": time.time()}
|
||||
data["approved"] = approved
|
||||
_save_pairing(data)
|
||||
return True
|
||||
|
||||
|
||||
def pairing_remove(user_open_id: str) -> bool:
|
||||
"""Remove a user from the pairing-approved list. Returns True if removed."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
if not isinstance(approved, dict):
|
||||
return False
|
||||
if user_open_id not in approved:
|
||||
return False
|
||||
del approved[user_open_id]
|
||||
data["approved"] = approved
|
||||
_save_pairing(data)
|
||||
return True
|
||||
|
||||
|
||||
def pairing_list() -> Dict[str, Any]:
|
||||
"""Return the approved dict {user_open_id: {approved_at: ...}}."""
|
||||
data = _pairing_cache.load()
|
||||
approved = data.get("approved", {})
|
||||
return dict(approved) if isinstance(approved, dict) else {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access check (public API for feishu_comment.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def is_user_allowed(rule: ResolvedCommentRule, user_open_id: str) -> bool:
|
||||
"""Check if user passes the resolved rule's policy gate."""
|
||||
if user_open_id in rule.allow_from:
|
||||
return True
|
||||
if rule.policy == "pairing":
|
||||
return user_open_id in _load_pairing_approved()
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _print_status() -> None:
|
||||
cfg = load_config()
|
||||
print(f"Rules file: {RULES_FILE}")
|
||||
print(f" exists: {RULES_FILE.exists()}")
|
||||
print(f"Pairing file: {PAIRING_FILE}")
|
||||
print(f" exists: {PAIRING_FILE.exists()}")
|
||||
print()
|
||||
print(f"Top-level:")
|
||||
print(f" enabled: {cfg.enabled}")
|
||||
print(f" policy: {cfg.policy}")
|
||||
print(f" allow_from: {sorted(cfg.allow_from) if cfg.allow_from else '[]'}")
|
||||
print()
|
||||
if cfg.documents:
|
||||
print(f"Document rules ({len(cfg.documents)}):")
|
||||
for key, rule in sorted(cfg.documents.items()):
|
||||
parts = []
|
||||
if rule.enabled is not None:
|
||||
parts.append(f"enabled={rule.enabled}")
|
||||
if rule.policy is not None:
|
||||
parts.append(f"policy={rule.policy}")
|
||||
if rule.allow_from is not None:
|
||||
parts.append(f"allow_from={sorted(rule.allow_from)}")
|
||||
print(f" [{key}] {', '.join(parts) if parts else '(empty — inherits all)'}")
|
||||
else:
|
||||
print("Document rules: (none)")
|
||||
print()
|
||||
approved = pairing_list()
|
||||
print(f"Pairing approved ({len(approved)}):")
|
||||
for uid, meta in sorted(approved.items()):
|
||||
ts = meta.get("approved_at", 0)
|
||||
print(f" {uid} (approved_at={ts})")
|
||||
|
||||
|
||||
def _do_check(doc_key: str, user_open_id: str) -> None:
|
||||
cfg = load_config()
|
||||
parts = doc_key.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
print(f"Error: doc_key must be 'fileType:fileToken', got '{doc_key}'")
|
||||
return
|
||||
file_type, file_token = parts
|
||||
rule = resolve_rule(cfg, file_type, file_token)
|
||||
allowed = is_user_allowed(rule, user_open_id)
|
||||
print(f"Document: {doc_key}")
|
||||
print(f"User: {user_open_id}")
|
||||
print(f"Resolved rule:")
|
||||
print(f" enabled: {rule.enabled}")
|
||||
print(f" policy: {rule.policy}")
|
||||
print(f" allow_from: {sorted(rule.allow_from) if rule.allow_from else '[]'}")
|
||||
print(f" match_source: {rule.match_source}")
|
||||
print(f"Result: {'ALLOWED' if allowed else 'DENIED'}")
|
||||
|
||||
|
||||
def _main() -> int:
|
||||
import sys
|
||||
|
||||
try:
|
||||
from hermes_cli.env_loader import load_hermes_dotenv
|
||||
load_hermes_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
usage = (
|
||||
"Usage: python -m gateway.platforms.feishu_comment_rules <command> [args]\n"
|
||||
"\n"
|
||||
"Commands:\n"
|
||||
" status Show rules config and pairing state\n"
|
||||
" check <fileType:token> <user> Simulate access check\n"
|
||||
" pairing add <user_open_id> Add user to pairing-approved list\n"
|
||||
" pairing remove <user_open_id> Remove user from pairing-approved list\n"
|
||||
" pairing list List pairing-approved users\n"
|
||||
"\n"
|
||||
f"Rules config file: {RULES_FILE}\n"
|
||||
" Edit this JSON file directly to configure policies and document rules.\n"
|
||||
" Changes take effect on the next comment event (no restart needed).\n"
|
||||
)
|
||||
|
||||
args = sys.argv[1:]
|
||||
if not args:
|
||||
print(usage)
|
||||
return 1
|
||||
|
||||
cmd = args[0]
|
||||
|
||||
if cmd == "status":
|
||||
_print_status()
|
||||
|
||||
elif cmd == "check":
|
||||
if len(args) < 3:
|
||||
print("Usage: check <fileType:fileToken> <user_open_id>")
|
||||
return 1
|
||||
_do_check(args[1], args[2])
|
||||
|
||||
elif cmd == "pairing":
|
||||
if len(args) < 2:
|
||||
print("Usage: pairing <add|remove|list> [args]")
|
||||
return 1
|
||||
sub = args[1]
|
||||
if sub == "add":
|
||||
if len(args) < 3:
|
||||
print("Usage: pairing add <user_open_id>")
|
||||
return 1
|
||||
if pairing_add(args[2]):
|
||||
print(f"Added: {args[2]}")
|
||||
else:
|
||||
print(f"Already approved: {args[2]}")
|
||||
elif sub == "remove":
|
||||
if len(args) < 3:
|
||||
print("Usage: pairing remove <user_open_id>")
|
||||
return 1
|
||||
if pairing_remove(args[2]):
|
||||
print(f"Removed: {args[2]}")
|
||||
else:
|
||||
print(f"Not in approved list: {args[2]}")
|
||||
elif sub == "list":
|
||||
approved = pairing_list()
|
||||
if not approved:
|
||||
print("(no approved users)")
|
||||
for uid, meta in sorted(approved.items()):
|
||||
print(f" {uid} approved_at={meta.get('approved_at', '?')}")
|
||||
else:
|
||||
print(f"Unknown pairing subcommand: {sub}")
|
||||
return 1
|
||||
else:
|
||||
print(f"Unknown command: {cmd}\n")
|
||||
print(usage)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(_main())
|
||||
57
gateway/platforms/qqbot/__init__.py
Normal file
57
gateway/platforms/qqbot/__init__.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
"""
|
||||
QQBot platform package.
|
||||
|
||||
Re-exports the main adapter symbols from ``adapter.py`` (the original
|
||||
``qqbot.py``) so that **all existing import paths remain unchanged**::
|
||||
|
||||
from gateway.platforms.qqbot import QQAdapter # works
|
||||
from gateway.platforms.qqbot import check_qq_requirements # works
|
||||
|
||||
New modules:
|
||||
- ``constants`` — shared constants (API URLs, timeouts, message types)
|
||||
- ``utils`` — User-Agent builder, config helpers
|
||||
- ``crypto`` — AES-256-GCM key generation and decryption
|
||||
- ``onboard`` — QR-code scan-to-configure flow
|
||||
"""
|
||||
|
||||
# -- Adapter (original qqbot.py) ------------------------------------------
|
||||
from .adapter import ( # noqa: F401
|
||||
QQAdapter,
|
||||
QQCloseError,
|
||||
check_qq_requirements,
|
||||
_coerce_list,
|
||||
_ssrf_redirect_guard,
|
||||
)
|
||||
|
||||
# -- Onboard (QR-code scan-to-configure) -----------------------------------
|
||||
from .onboard import ( # noqa: F401
|
||||
BindStatus,
|
||||
create_bind_task,
|
||||
poll_bind_result,
|
||||
build_connect_url,
|
||||
)
|
||||
from .crypto import decrypt_secret, generate_bind_key # noqa: F401
|
||||
|
||||
# -- Utils -----------------------------------------------------------------
|
||||
from .utils import build_user_agent, get_api_headers, coerce_list # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
# adapter
|
||||
"QQAdapter",
|
||||
"QQCloseError",
|
||||
"check_qq_requirements",
|
||||
"_coerce_list",
|
||||
"_ssrf_redirect_guard",
|
||||
# onboard
|
||||
"BindStatus",
|
||||
"create_bind_task",
|
||||
"poll_bind_result",
|
||||
"build_connect_url",
|
||||
# crypto
|
||||
"decrypt_secret",
|
||||
"generate_bind_key",
|
||||
# utils
|
||||
"build_user_agent",
|
||||
"get_api_headers",
|
||||
"coerce_list",
|
||||
]
|
||||
File diff suppressed because it is too large
Load diff
74
gateway/platforms/qqbot/constants.py
Normal file
74
gateway/platforms/qqbot/constants.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
"""QQBot package-level constants shared across adapter, onboard, and other modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQBot adapter version — bump on functional changes to the adapter package.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
QQBOT_VERSION = "1.1.0"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# The portal domain is configurable via QQ_API_HOST for corporate proxies
|
||||
# or test environments. Default: q.qq.com (production).
|
||||
PORTAL_HOST = os.getenv("QQ_PORTAL_HOST", "q.qq.com")
|
||||
|
||||
API_BASE = "https://api.sgroup.qq.com"
|
||||
TOKEN_URL = "https://bots.qq.com/app/getAppAccessToken"
|
||||
GATEWAY_URL_PATH = "/gateway"
|
||||
|
||||
# QR-code onboard endpoints (on the portal host)
|
||||
ONBOARD_CREATE_PATH = "/lite/create_bind_task"
|
||||
ONBOARD_POLL_PATH = "/lite/poll_bind_result"
|
||||
QR_URL_TEMPLATE = (
|
||||
"https://q.qq.com/qqbot/openclaw/connect.html"
|
||||
"?task_id={task_id}&_wv=2&source=hermes"
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Timeouts & retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_API_TIMEOUT = 30.0
|
||||
FILE_UPLOAD_TIMEOUT = 120.0
|
||||
CONNECT_TIMEOUT_SECONDS = 20.0
|
||||
|
||||
RECONNECT_BACKOFF = [2, 5, 10, 30, 60]
|
||||
MAX_RECONNECT_ATTEMPTS = 100
|
||||
RATE_LIMIT_DELAY = 60 # seconds
|
||||
QUICK_DISCONNECT_THRESHOLD = 5.0 # seconds
|
||||
MAX_QUICK_DISCONNECT_COUNT = 3
|
||||
|
||||
ONBOARD_POLL_INTERVAL = 2.0 # seconds between poll_bind_result calls
|
||||
ONBOARD_API_TIMEOUT = 10.0
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message limits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MAX_MESSAGE_LENGTH = 4000
|
||||
DEDUP_WINDOW_SECONDS = 300
|
||||
DEDUP_MAX_SIZE = 1000
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQ Bot message types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MSG_TYPE_TEXT = 0
|
||||
MSG_TYPE_MARKDOWN = 2
|
||||
MSG_TYPE_MEDIA = 7
|
||||
MSG_TYPE_INPUT_NOTIFY = 6
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# QQ Bot file media types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MEDIA_TYPE_IMAGE = 1
|
||||
MEDIA_TYPE_VIDEO = 2
|
||||
MEDIA_TYPE_VOICE = 3
|
||||
MEDIA_TYPE_FILE = 4
|
||||
45
gateway/platforms/qqbot/crypto.py
Normal file
45
gateway/platforms/qqbot/crypto.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
"""AES-256-GCM utilities for QQBot scan-to-configure credential decryption."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
|
||||
def generate_bind_key() -> str:
|
||||
"""Generate a 256-bit random AES key and return it as base64.
|
||||
|
||||
The key is passed to ``create_bind_task`` so the server can encrypt
|
||||
the bot's *client_secret* before returning it. Only this CLI holds
|
||||
the key, ensuring the secret never travels in plaintext.
|
||||
"""
|
||||
return base64.b64encode(os.urandom(32)).decode()
|
||||
|
||||
|
||||
def decrypt_secret(encrypted_base64: str, key_base64: str) -> str:
|
||||
"""Decrypt a base64-encoded AES-256-GCM ciphertext.
|
||||
|
||||
Ciphertext layout (after base64-decoding)::
|
||||
|
||||
IV (12 bytes) ‖ ciphertext (N bytes) ‖ AuthTag (16 bytes)
|
||||
|
||||
Args:
|
||||
encrypted_base64: The ``bot_encrypt_secret`` value from
|
||||
``poll_bind_result``.
|
||||
key_base64: The base64 AES key generated by
|
||||
:func:`generate_bind_key`.
|
||||
|
||||
Returns:
|
||||
The decrypted *client_secret* as a UTF-8 string.
|
||||
"""
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
key = base64.b64decode(key_base64)
|
||||
raw = base64.b64decode(encrypted_base64)
|
||||
|
||||
iv = raw[:12]
|
||||
ciphertext_with_tag = raw[12:] # AESGCM expects ciphertext + tag concatenated
|
||||
|
||||
aesgcm = AESGCM(key)
|
||||
plaintext = aesgcm.decrypt(iv, ciphertext_with_tag, None)
|
||||
return plaintext.decode("utf-8")
|
||||
124
gateway/platforms/qqbot/onboard.py
Normal file
124
gateway/platforms/qqbot/onboard.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""
|
||||
QQBot scan-to-configure (QR code onboard) module.
|
||||
|
||||
Calls the ``q.qq.com`` ``create_bind_task`` / ``poll_bind_result`` APIs to
|
||||
generate a QR-code URL and poll for scan completion. On success the caller
|
||||
receives the bot's *app_id*, *client_secret* (decrypted locally), and the
|
||||
scanner's *user_openid* — enough to fully configure the QQBot gateway.
|
||||
|
||||
Reference: https://bot.q.qq.com/wiki/develop/api-v2/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from enum import IntEnum
|
||||
from typing import Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from .constants import (
|
||||
ONBOARD_API_TIMEOUT,
|
||||
ONBOARD_CREATE_PATH,
|
||||
ONBOARD_POLL_PATH,
|
||||
PORTAL_HOST,
|
||||
QR_URL_TEMPLATE,
|
||||
)
|
||||
from .crypto import generate_bind_key
|
||||
from .utils import get_api_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bind status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BindStatus(IntEnum):
|
||||
"""Status codes returned by ``poll_bind_result``."""
|
||||
|
||||
NONE = 0
|
||||
PENDING = 1
|
||||
COMPLETED = 2
|
||||
EXPIRED = 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_bind_task(
|
||||
timeout: float = ONBOARD_API_TIMEOUT,
|
||||
) -> Tuple[str, str]:
|
||||
"""Create a bind task and return *(task_id, aes_key_base64)*.
|
||||
|
||||
The AES key is generated locally and sent to the server so it can
|
||||
encrypt the bot credentials before returning them.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the API returns a non-zero ``retcode``.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
url = f"https://{PORTAL_HOST}{ONBOARD_CREATE_PATH}"
|
||||
key = generate_bind_key()
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, json={"key": key}, headers=get_api_headers())
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("retcode") != 0:
|
||||
raise RuntimeError(data.get("msg", "create_bind_task failed"))
|
||||
|
||||
task_id = data.get("data", {}).get("task_id")
|
||||
if not task_id:
|
||||
raise RuntimeError("create_bind_task: missing task_id in response")
|
||||
|
||||
logger.debug("create_bind_task ok: task_id=%s", task_id)
|
||||
return task_id, key
|
||||
|
||||
|
||||
async def poll_bind_result(
|
||||
task_id: str,
|
||||
timeout: float = ONBOARD_API_TIMEOUT,
|
||||
) -> Tuple[BindStatus, str, str, str]:
|
||||
"""Poll the bind result for *task_id*.
|
||||
|
||||
Returns:
|
||||
A 4-tuple of ``(status, bot_appid, bot_encrypt_secret, user_openid)``.
|
||||
|
||||
* ``bot_encrypt_secret`` is AES-256-GCM encrypted — decrypt it with
|
||||
:func:`~gateway.platforms.qqbot.crypto.decrypt_secret` using the
|
||||
key from :func:`create_bind_task`.
|
||||
* ``user_openid`` is the OpenID of the person who scanned the code
|
||||
(available when ``status == COMPLETED``).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the API returns a non-zero ``retcode``.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
url = f"https://{PORTAL_HOST}{ONBOARD_POLL_PATH}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True) as client:
|
||||
resp = await client.post(url, json={"task_id": task_id}, headers=get_api_headers())
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("retcode") != 0:
|
||||
raise RuntimeError(data.get("msg", "poll_bind_result failed"))
|
||||
|
||||
d = data.get("data", {})
|
||||
return (
|
||||
BindStatus(d.get("status", 0)),
|
||||
str(d.get("bot_appid", "")),
|
||||
d.get("bot_encrypt_secret", ""),
|
||||
d.get("user_openid", ""),
|
||||
)
|
||||
|
||||
|
||||
def build_connect_url(task_id: str) -> str:
|
||||
"""Build the QR-code target URL for a given *task_id*."""
|
||||
return QR_URL_TEMPLATE.format(task_id=quote(task_id))
|
||||
71
gateway/platforms/qqbot/utils.py
Normal file
71
gateway/platforms/qqbot/utils.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""QQBot shared utilities — User-Agent, HTTP helpers, config coercion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .constants import QQBOT_VERSION
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User-Agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _get_hermes_version() -> str:
|
||||
"""Return the hermes-agent package version, or 'dev' if unavailable."""
|
||||
try:
|
||||
from importlib.metadata import version
|
||||
return version("hermes-agent")
|
||||
except Exception:
|
||||
return "dev"
|
||||
|
||||
|
||||
def build_user_agent() -> str:
|
||||
"""Build a descriptive User-Agent string.
|
||||
|
||||
Format::
|
||||
|
||||
QQBotAdapter/<qqbot_version> (Python/<py_version>; <os>; Hermes/<hermes_version>)
|
||||
|
||||
Example::
|
||||
|
||||
QQBotAdapter/1.0.0 (Python/3.11.15; darwin; Hermes/0.9.0)
|
||||
"""
|
||||
py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
|
||||
os_name = platform.system().lower()
|
||||
hermes_version = _get_hermes_version()
|
||||
return f"QQBotAdapter/{QQBOT_VERSION} (Python/{py_version}; {os_name}; Hermes/{hermes_version})"
|
||||
|
||||
|
||||
def get_api_headers() -> Dict[str, str]:
|
||||
"""Return standard HTTP headers for QQBot API requests.
|
||||
|
||||
Includes ``Content-Type``, ``Accept``, and a dynamic ``User-Agent``.
|
||||
``q.qq.com`` requires ``Accept: application/json`` — without it,
|
||||
the server returns a JavaScript anti-bot challenge page.
|
||||
"""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": build_user_agent(),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def coerce_list(value: Any) -> List[str]:
|
||||
"""Coerce config values into a trimmed string list.
|
||||
|
||||
Accepts comma-separated strings, lists, tuples, sets, or single values.
|
||||
"""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, str):
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [str(item).strip() for item in value if str(item).strip()]
|
||||
return [str(value).strip()] if str(value).strip() else []
|
||||
|
|
@ -118,6 +118,84 @@ def _strip_mdv2(text: str) -> str:
|
|||
return cleaned
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Markdown table → code block conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
# Telegram's MarkdownV2 has no table syntax — '|' is just an escaped literal,
|
||||
# so pipe tables render as noisy backslash-pipe text with no alignment.
|
||||
# Wrapping the table in a fenced code block makes Telegram render it as
|
||||
# monospace preformatted text with columns intact.
|
||||
|
||||
# Matches a GFM table delimiter row: optional outer pipes, cells containing
|
||||
# only dashes (with optional leading/trailing colons for alignment) separated
|
||||
# by '|'. Requires at least one internal '|' so lone '---' horizontal rules
|
||||
# are NOT matched.
|
||||
_TABLE_SEPARATOR_RE = re.compile(
|
||||
r'^\s*\|?\s*:?-+:?\s*(?:\|\s*:?-+:?\s*){1,}\|?\s*$'
|
||||
)
|
||||
|
||||
|
||||
def _is_table_row(line: str) -> bool:
|
||||
"""Return True if *line* could plausibly be a table data row."""
|
||||
stripped = line.strip()
|
||||
return bool(stripped) and '|' in stripped
|
||||
|
||||
|
||||
def _wrap_markdown_tables(text: str) -> str:
|
||||
"""Wrap GFM-style pipe tables in ``` fences so Telegram renders them.
|
||||
|
||||
Detected by a row containing '|' immediately followed by a delimiter
|
||||
row matching :data:`_TABLE_SEPARATOR_RE`. Subsequent pipe-containing
|
||||
non-blank lines are consumed as the table body and included in the
|
||||
wrapped block. Tables inside existing fenced code blocks are left
|
||||
alone.
|
||||
"""
|
||||
if '|' not in text or '-' not in text:
|
||||
return text
|
||||
|
||||
lines = text.split('\n')
|
||||
out: list[str] = []
|
||||
in_fence = False
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i]
|
||||
stripped = line.lstrip()
|
||||
|
||||
# Track existing fenced code blocks — never touch content inside.
|
||||
if stripped.startswith('```'):
|
||||
in_fence = not in_fence
|
||||
out.append(line)
|
||||
i += 1
|
||||
continue
|
||||
if in_fence:
|
||||
out.append(line)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Look for a header row (contains '|') immediately followed by a
|
||||
# delimiter row.
|
||||
if (
|
||||
'|' in line
|
||||
and i + 1 < len(lines)
|
||||
and _TABLE_SEPARATOR_RE.match(lines[i + 1])
|
||||
):
|
||||
table_block = [line, lines[i + 1]]
|
||||
j = i + 2
|
||||
while j < len(lines) and _is_table_row(lines[j]):
|
||||
table_block.append(lines[j])
|
||||
j += 1
|
||||
out.append('```')
|
||||
out.extend(table_block)
|
||||
out.append('```')
|
||||
i = j
|
||||
continue
|
||||
|
||||
out.append(line)
|
||||
i += 1
|
||||
|
||||
return '\n'.join(out)
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
|
|
@ -1916,6 +1994,12 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
|
||||
text = content
|
||||
|
||||
# 0) Pre-wrap GFM-style pipe tables in ``` fences. Telegram can't
|
||||
# render tables natively, but fenced code blocks render as
|
||||
# monospace preformatted text with columns intact. The wrapped
|
||||
# tables then flow through step (1) below as protected regions.
|
||||
text = _wrap_markdown_tables(text)
|
||||
|
||||
# 1) Protect fenced code blocks (``` ... ```)
|
||||
# Per MarkdownV2 spec, \ and ` inside pre/code must be escaped.
|
||||
def _protect_fenced(m):
|
||||
|
|
|
|||
|
|
@ -180,6 +180,8 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_WECOM_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._device_id = uuid.uuid4().hex
|
||||
self._last_chat_req_ids: Dict[str, str] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
|
|
@ -277,7 +279,11 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
{
|
||||
"cmd": APP_CMD_SUBSCRIBE,
|
||||
"headers": {"req_id": req_id},
|
||||
"body": {"bot_id": self._bot_id, "secret": self._secret},
|
||||
"body": {
|
||||
"bot_id": self._bot_id,
|
||||
"secret": self._secret,
|
||||
"device_id": self._device_id,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -496,6 +502,11 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
logger.debug("[%s] DM sender %s blocked by policy", self.name, sender_id)
|
||||
return
|
||||
|
||||
# Cache the inbound req_id after policy checks so proactive sends to
|
||||
# this chat can fall back to APP_CMD_RESPONSE (required for groups —
|
||||
# WeCom AI Bots cannot initiate APP_CMD_SEND in group chats).
|
||||
self._remember_chat_req_id(chat_id, self._payload_req_id(payload))
|
||||
|
||||
text, reply_text = self._extract_text(body)
|
||||
media_urls, media_types = await self._extract_media(body)
|
||||
message_type = self._derive_message_type(body, text, media_types)
|
||||
|
|
@ -847,6 +858,23 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
while len(self._reply_req_ids) > DEDUP_MAX_SIZE:
|
||||
self._reply_req_ids.pop(next(iter(self._reply_req_ids)))
|
||||
|
||||
def _remember_chat_req_id(self, chat_id: str, req_id: str) -> None:
|
||||
"""Cache the most recent inbound req_id per chat.
|
||||
|
||||
Used as a fallback reply target when we need to send into a group
|
||||
without an explicit ``reply_to`` — WeCom AI Bots are blocked from
|
||||
APP_CMD_SEND in groups and must use APP_CMD_RESPONSE bound to some
|
||||
prior req_id. Bounded like _reply_req_ids so long-running gateways
|
||||
don't leak memory across many chats.
|
||||
"""
|
||||
normalized_chat_id = str(chat_id or "").strip()
|
||||
normalized_req_id = str(req_id or "").strip()
|
||||
if not normalized_chat_id or not normalized_req_id:
|
||||
return
|
||||
self._last_chat_req_ids[normalized_chat_id] = normalized_req_id
|
||||
while len(self._last_chat_req_ids) > DEDUP_MAX_SIZE:
|
||||
self._last_chat_req_ids.pop(next(iter(self._last_chat_req_ids)))
|
||||
|
||||
def _reply_req_id_for_message(self, reply_to: Optional[str]) -> Optional[str]:
|
||||
normalized = str(reply_to or "").strip()
|
||||
if not normalized or normalized.startswith("quote:"):
|
||||
|
|
@ -1163,19 +1191,15 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
self._raise_for_wecom_error(response, "send media message")
|
||||
return response
|
||||
|
||||
async def _send_reply_stream(self, reply_req_id: str, content: str) -> Dict[str, Any]:
|
||||
async def _send_reply_markdown(self, reply_req_id: str, content: str) -> Dict[str, Any]:
|
||||
response = await self._send_reply_request(
|
||||
reply_req_id,
|
||||
{
|
||||
"msgtype": "stream",
|
||||
"stream": {
|
||||
"id": self._new_req_id("stream"),
|
||||
"finish": True,
|
||||
"content": content[:self.MAX_MESSAGE_LENGTH],
|
||||
},
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"content": content[:self.MAX_MESSAGE_LENGTH]},
|
||||
},
|
||||
)
|
||||
self._raise_for_wecom_error(response, "send reply stream")
|
||||
self._raise_for_wecom_error(response, "send reply markdown")
|
||||
return response
|
||||
|
||||
async def _send_reply_media_message(
|
||||
|
|
@ -1235,6 +1259,9 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
return SendResult(success=False, error=prepared["reject_reason"])
|
||||
|
||||
reply_req_id = self._reply_req_id_for_message(reply_to)
|
||||
if not reply_req_id and chat_id in self._last_chat_req_ids:
|
||||
reply_req_id = self._last_chat_req_ids[chat_id]
|
||||
|
||||
try:
|
||||
upload_result = await self._upload_media_bytes(
|
||||
prepared["data"],
|
||||
|
|
@ -1302,8 +1329,12 @@ class WeComAdapter(BasePlatformAdapter):
|
|||
|
||||
try:
|
||||
reply_req_id = self._reply_req_id_for_message(reply_to)
|
||||
|
||||
if not reply_req_id and chat_id in self._last_chat_req_ids:
|
||||
reply_req_id = self._last_chat_req_ids[chat_id]
|
||||
|
||||
if reply_req_id:
|
||||
response = await self._send_reply_stream(reply_req_id, content)
|
||||
response = await self._send_reply_markdown(reply_req_id, content)
|
||||
else:
|
||||
response = await self._send_request(
|
||||
APP_CMD_SEND,
|
||||
|
|
|
|||
261
gateway/run.py
261
gateway/run.py
|
|
@ -2186,6 +2186,30 @@ class GatewayRunner:
|
|||
)
|
||||
except Exception as _e:
|
||||
logger.debug("Idle agent sweep failed: %s", _e)
|
||||
|
||||
# Periodically prune stale SessionStore entries. The
|
||||
# in-memory dict (and sessions.json) would otherwise grow
|
||||
# unbounded in gateways serving many rotating chats /
|
||||
# threads / users over long time windows. Pruning is
|
||||
# invisible to users — a resumed session just gets a
|
||||
# fresh session_id, exactly as if the reset policy fired.
|
||||
_last_prune_ts = getattr(self, "_last_session_store_prune_ts", 0.0)
|
||||
_prune_interval = 3600.0 # once per hour
|
||||
if time.time() - _last_prune_ts > _prune_interval:
|
||||
try:
|
||||
_max_age = int(
|
||||
getattr(self.config, "session_store_max_age_days", 0) or 0
|
||||
)
|
||||
if _max_age > 0:
|
||||
_pruned = self.session_store.prune_old_entries(_max_age)
|
||||
if _pruned:
|
||||
logger.info(
|
||||
"SessionStore prune: dropped %d stale entries",
|
||||
_pruned,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("SessionStore prune failed: %s", _e)
|
||||
self._last_session_store_prune_ts = time.time()
|
||||
except Exception as e:
|
||||
logger.debug("Session expiry watcher error: %s", e)
|
||||
# Sleep in small increments so we can stop quickly
|
||||
|
|
@ -2392,6 +2416,7 @@ class GatewayRunner:
|
|||
|
||||
self.adapters.clear()
|
||||
self._running_agents.clear()
|
||||
self._running_agents_ts.clear()
|
||||
self._pending_messages.clear()
|
||||
self._pending_approvals.clear()
|
||||
if hasattr(self, '_busy_ack_ts'):
|
||||
|
|
@ -2416,6 +2441,20 @@ class GatewayRunner:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
# Close SQLite session DBs so the WAL write lock is released.
|
||||
# Without this, --replace and similar restart flows leave the
|
||||
# old gateway's connection holding the WAL lock until Python
|
||||
# actually exits — causing 'database is locked' errors when
|
||||
# the new gateway tries to open the same file.
|
||||
for _db_holder in (self, getattr(self, "session_store", None)):
|
||||
_db = getattr(_db_holder, "_db", None) if _db_holder else None
|
||||
if _db is None or not hasattr(_db, "close"):
|
||||
continue
|
||||
try:
|
||||
_db.close()
|
||||
except Exception as _e:
|
||||
logger.debug("SessionDB close error: %s", _e)
|
||||
|
||||
from gateway.status import remove_pid_file
|
||||
remove_pid_file()
|
||||
|
||||
|
|
@ -2914,16 +2953,17 @@ class GatewayRunner:
|
|||
_quick_key[:30], _stale_age, _stale_idle,
|
||||
_raw_stale_timeout, _stale_detail,
|
||||
)
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
self._busy_ack_ts.pop(_quick_key, None)
|
||||
self._release_running_agent_state(_quick_key)
|
||||
|
||||
if _quick_key in self._running_agents:
|
||||
if event.get_command() == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
# Resolve the command once for all early-intercept checks below.
|
||||
from hermes_cli.commands import resolve_command as _resolve_cmd_inner
|
||||
from hermes_cli.commands import (
|
||||
resolve_command as _resolve_cmd_inner,
|
||||
should_bypass_active_session as _should_bypass_active_inner,
|
||||
)
|
||||
_evt_cmd = event.get_command()
|
||||
_cmd_def_inner = _resolve_cmd_inner(_evt_cmd) if _evt_cmd else None
|
||||
|
||||
|
|
@ -2944,8 +2984,7 @@ class GatewayRunner:
|
|||
if adapter and hasattr(adapter, 'get_pending_message'):
|
||||
adapter.get_pending_message(_quick_key) # consume and discard
|
||||
self._pending_messages.pop(_quick_key, None)
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
logger.info("STOP for session %s — agent interrupted, session lock released", _quick_key[:20])
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
|
||||
|
|
@ -2967,8 +3006,7 @@ class GatewayRunner:
|
|||
self._pending_messages.pop(_quick_key, None)
|
||||
# Clean up the running agent entry so the reset handler
|
||||
# doesn't think an agent is still active.
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
return await self._handle_reset_command(event)
|
||||
|
||||
# /queue <prompt> — queue without interrupting
|
||||
|
|
@ -3002,11 +3040,29 @@ class GatewayRunner:
|
|||
return await self._handle_approve_command(event)
|
||||
return await self._handle_deny_command(event)
|
||||
|
||||
# /agents (/tasks alias) should be query-only and never interrupt.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "agents":
|
||||
return await self._handle_agents_command(event)
|
||||
|
||||
# /background must bypass the running-agent guard — it starts a
|
||||
# parallel task and must never interrupt the active conversation.
|
||||
if _cmd_def_inner and _cmd_def_inner.name == "background":
|
||||
return await self._handle_background_command(event)
|
||||
|
||||
# Gateway-handled info/control commands must never fall through to
|
||||
# the interrupt path. If they are queued as pending text, the
|
||||
# slash-command safety net discards them before the user sees any
|
||||
# response.
|
||||
if _cmd_def_inner and _should_bypass_active_inner(_cmd_def_inner.name):
|
||||
if _cmd_def_inner.name == "help":
|
||||
return await self._handle_help_command(event)
|
||||
if _cmd_def_inner.name == "commands":
|
||||
return await self._handle_commands_command(event)
|
||||
if _cmd_def_inner.name == "profile":
|
||||
return await self._handle_profile_command(event)
|
||||
if _cmd_def_inner.name == "update":
|
||||
return await self._handle_update_command(event)
|
||||
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
|
|
@ -3045,8 +3101,7 @@ class GatewayRunner:
|
|||
# Agent is being set up but not ready yet.
|
||||
if event.get_command() == "stop":
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
if _quick_key in self._running_agents:
|
||||
del self._running_agents[_quick_key]
|
||||
self._release_running_agent_state(_quick_key)
|
||||
logger.info("HARD STOP (pending) for session %s — sentinel cleared", _quick_key[:20])
|
||||
return "⚡ Force-stopped. The agent was still starting — session unlocked."
|
||||
# Queue the message so it will be picked up after the
|
||||
|
|
@ -3110,6 +3165,9 @@ class GatewayRunner:
|
|||
if canonical == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
if canonical == "agents":
|
||||
return await self._handle_agents_command(event)
|
||||
|
||||
if canonical == "restart":
|
||||
return await self._handle_restart_command(event)
|
||||
|
||||
|
|
@ -3362,8 +3420,13 @@ class GatewayRunner:
|
|||
# (exception, command fallthrough, etc.) the sentinel must
|
||||
# not linger or the session would be permanently locked out.
|
||||
if self._running_agents.get(_quick_key) is _AGENT_PENDING_SENTINEL:
|
||||
del self._running_agents[_quick_key]
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
self._release_running_agent_state(_quick_key)
|
||||
else:
|
||||
# Agent path already cleaned _running_agents; make sure
|
||||
# the paired metadata dicts are gone too.
|
||||
self._running_agents_ts.pop(_quick_key, None)
|
||||
if hasattr(self, "_busy_ack_ts"):
|
||||
self._busy_ack_ts.pop(_quick_key, None)
|
||||
|
||||
async def _prepare_inbound_message_text(
|
||||
self,
|
||||
|
|
@ -4560,6 +4623,96 @@ class GatewayRunner:
|
|||
])
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_agents_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /agents command - list active agents and running tasks."""
|
||||
from tools.process_registry import format_uptime_short, process_registry
|
||||
|
||||
now = time.time()
|
||||
current_session_key = self._session_key_for_source(event.source)
|
||||
|
||||
running_agents: dict = getattr(self, "_running_agents", {}) or {}
|
||||
running_started: dict = getattr(self, "_running_agents_ts", {}) or {}
|
||||
|
||||
agent_rows: list[dict] = []
|
||||
for session_key, agent in running_agents.items():
|
||||
started = float(running_started.get(session_key, now))
|
||||
elapsed = max(0, int(now - started))
|
||||
is_pending = agent is _AGENT_PENDING_SENTINEL
|
||||
agent_rows.append(
|
||||
{
|
||||
"session_key": session_key,
|
||||
"elapsed": elapsed,
|
||||
"state": "starting" if is_pending else "running",
|
||||
"session_id": "" if is_pending else str(getattr(agent, "session_id", "") or ""),
|
||||
"model": "" if is_pending else str(getattr(agent, "model", "") or ""),
|
||||
}
|
||||
)
|
||||
|
||||
agent_rows.sort(key=lambda row: row["elapsed"], reverse=True)
|
||||
|
||||
running_processes: list[dict] = []
|
||||
try:
|
||||
running_processes = [
|
||||
p for p in process_registry.list_sessions()
|
||||
if p.get("status") == "running"
|
||||
]
|
||||
except Exception:
|
||||
running_processes = []
|
||||
|
||||
background_tasks = [
|
||||
t for t in (getattr(self, "_background_tasks", set()) or set())
|
||||
if hasattr(t, "done") and not t.done()
|
||||
]
|
||||
|
||||
lines = [
|
||||
"🤖 **Active Agents & Tasks**",
|
||||
"",
|
||||
f"**Active agents:** {len(agent_rows)}",
|
||||
]
|
||||
|
||||
if agent_rows:
|
||||
for idx, row in enumerate(agent_rows[:12], 1):
|
||||
current = " · this chat" if row["session_key"] == current_session_key else ""
|
||||
sid = f" · `{row['session_id']}`" if row["session_id"] else ""
|
||||
model = f" · `{row['model']}`" if row["model"] else ""
|
||||
lines.append(
|
||||
f"{idx}. `{row['session_key']}` · {row['state']} · "
|
||||
f"{format_uptime_short(row['elapsed'])}{sid}{model}{current}"
|
||||
)
|
||||
if len(agent_rows) > 12:
|
||||
lines.append(f"... and {len(agent_rows) - 12} more")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
f"**Running background processes:** {len(running_processes)}",
|
||||
]
|
||||
)
|
||||
if running_processes:
|
||||
for proc in running_processes[:12]:
|
||||
cmd = " ".join(str(proc.get("command", "")).split())
|
||||
if len(cmd) > 90:
|
||||
cmd = cmd[:87] + "..."
|
||||
lines.append(
|
||||
f"- `{proc.get('session_id', '?')}` · "
|
||||
f"{format_uptime_short(int(proc.get('uptime_seconds', 0)))} · `{cmd}`"
|
||||
)
|
||||
if len(running_processes) > 12:
|
||||
lines.append(f"... and {len(running_processes) - 12} more")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
f"**Gateway async jobs:** {len(background_tasks)}",
|
||||
]
|
||||
)
|
||||
|
||||
if not agent_rows and not running_processes and not background_tasks:
|
||||
lines.append("")
|
||||
lines.append("No active agents or running tasks.")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_stop_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /stop command - interrupt a running agent.
|
||||
|
|
@ -4579,16 +4732,14 @@ class GatewayRunner:
|
|||
agent = self._running_agents.get(session_key)
|
||||
if agent is _AGENT_PENDING_SENTINEL:
|
||||
# Force-clean the sentinel so the session is unlocked.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
logger.info("STOP (pending) for session %s — sentinel cleared", session_key[:20])
|
||||
return "⚡ Stopped. The agent hadn't started yet — you can continue this session."
|
||||
if agent:
|
||||
agent.interrupt("Stop requested")
|
||||
# Force-clean the session lock so a truly hung agent doesn't
|
||||
# keep it locked forever.
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
return "⚡ Stopped. You can continue this session."
|
||||
else:
|
||||
return "No active task to stop."
|
||||
|
|
@ -6505,8 +6656,7 @@ class GatewayRunner:
|
|||
logger.debug("Memory flush on resume failed: %s", e)
|
||||
|
||||
# Clear any running agent for this session key
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
self._release_running_agent_state(session_key)
|
||||
|
||||
# Switch the session entry to point at the old session
|
||||
new_entry = self.session_store.switch_session(session_key, target_id)
|
||||
|
|
@ -7922,6 +8072,30 @@ class GatewayRunner:
|
|||
override = self._session_model_overrides.get(session_key)
|
||||
return override is not None and override.get("model") == agent_model
|
||||
|
||||
def _release_running_agent_state(self, session_key: str) -> None:
|
||||
"""Pop ALL per-running-agent state entries for ``session_key``.
|
||||
|
||||
Replaces ad-hoc ``del self._running_agents[key]`` calls scattered
|
||||
across the gateway. Those sites had drifted: some popped only
|
||||
``_running_agents``; some also ``_running_agents_ts``; only one
|
||||
path also cleared ``_busy_ack_ts``. Each missed entry was a
|
||||
small, persistent leak — a (str_key → float) tuple per session
|
||||
per gateway lifetime.
|
||||
|
||||
Use this at every site that ends a running turn, regardless of
|
||||
cause (normal completion, /stop, /reset, /resume, sentinel
|
||||
cleanup, stale-eviction). Per-session state that PERSISTS
|
||||
across turns (``_session_model_overrides``, ``_voice_mode``,
|
||||
``_pending_approvals``, ``_update_prompt_pending``) is NOT
|
||||
touched here — those have their own lifecycles.
|
||||
"""
|
||||
if not session_key:
|
||||
return
|
||||
self._running_agents.pop(session_key, None)
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
if hasattr(self, "_busy_ack_ts"):
|
||||
self._busy_ack_ts.pop(session_key, None)
|
||||
|
||||
def _evict_cached_agent(self, session_key: str) -> None:
|
||||
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
|
|
@ -9757,10 +9931,8 @@ class GatewayRunner:
|
|||
|
||||
# Clean up tracking
|
||||
tracking_task.cancel()
|
||||
if session_key and session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
if session_key:
|
||||
self._running_agents_ts.pop(session_key, None)
|
||||
self._release_running_agent_state(session_key)
|
||||
if self._draining:
|
||||
self._update_runtime_status("draining")
|
||||
|
||||
|
|
@ -9889,6 +10061,16 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
|||
"Replacing existing gateway instance (PID %d) with --replace.",
|
||||
existing_pid,
|
||||
)
|
||||
# Record a takeover marker so the target's shutdown handler
|
||||
# recognises its SIGTERM as a planned takeover and exits 0
|
||||
# (rather than exit 1, which would trigger systemd's
|
||||
# Restart=on-failure and start a flap loop against us).
|
||||
# Best-effort — proceed even if the write fails.
|
||||
try:
|
||||
from gateway.status import write_takeover_marker
|
||||
write_takeover_marker(existing_pid)
|
||||
except Exception as e:
|
||||
logger.debug("Could not write takeover marker: %s", e)
|
||||
try:
|
||||
terminate_pid(existing_pid, force=False)
|
||||
except ProcessLookupError:
|
||||
|
|
@ -9898,6 +10080,13 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
|||
"Permission denied killing PID %d. Cannot replace.",
|
||||
existing_pid,
|
||||
)
|
||||
# Marker is scoped to a specific target; clean it up on
|
||||
# give-up so it doesn't grief an unrelated future shutdown.
|
||||
try:
|
||||
from gateway.status import clear_takeover_marker
|
||||
clear_takeover_marker()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
# Wait up to 10 seconds for the old process to exit
|
||||
for _ in range(20):
|
||||
|
|
@ -9918,6 +10107,13 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
|||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
remove_pid_file()
|
||||
# Clean up any takeover marker the old process didn't consume
|
||||
# (e.g. SIGKILL'd before its shutdown handler could read it).
|
||||
try:
|
||||
from gateway.status import clear_takeover_marker
|
||||
clear_takeover_marker()
|
||||
except Exception:
|
||||
pass
|
||||
# Also release all scoped locks left by the old process.
|
||||
# Stopped (Ctrl+Z) processes don't release locks on exit,
|
||||
# leaving stale lock files that block the new gateway from starting.
|
||||
|
|
@ -9985,8 +10181,27 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
|||
# Set up signal handlers
|
||||
def shutdown_signal_handler():
|
||||
nonlocal _signal_initiated_shutdown
|
||||
_signal_initiated_shutdown = True
|
||||
logger.info("Received SIGTERM/SIGINT — initiating shutdown")
|
||||
# Planned --replace takeover check: when a sibling gateway is
|
||||
# taking over via --replace, it wrote a marker naming this PID
|
||||
# before sending SIGTERM. If present, treat the signal as a
|
||||
# planned shutdown and exit 0 so systemd's Restart=on-failure
|
||||
# doesn't revive us (which would flap-fight the replacer when
|
||||
# both services are enabled, e.g. hermes.service + hermes-
|
||||
# gateway.service from pre-rename installs).
|
||||
planned_takeover = False
|
||||
try:
|
||||
from gateway.status import consume_takeover_marker_for_self
|
||||
planned_takeover = consume_takeover_marker_for_self()
|
||||
except Exception as e:
|
||||
logger.debug("Takeover marker check failed: %s", e)
|
||||
|
||||
if planned_takeover:
|
||||
logger.info(
|
||||
"Received SIGTERM as a planned --replace takeover — exiting cleanly"
|
||||
)
|
||||
else:
|
||||
_signal_initiated_shutdown = True
|
||||
logger.info("Received SIGTERM/SIGINT — initiating shutdown")
|
||||
# Diagnostic: log all hermes-related processes so we can identify
|
||||
# what triggered the signal (hermes update, hermes gateway restart,
|
||||
# a stale detached subprocess, etc.).
|
||||
|
|
|
|||
|
|
@ -802,6 +802,57 @@ class SessionStore:
|
|||
return True
|
||||
return False
|
||||
|
||||
def prune_old_entries(self, max_age_days: int) -> int:
|
||||
"""Drop SessionEntry records older than max_age_days.
|
||||
|
||||
Pruning is based on ``updated_at`` (last activity), not ``created_at``.
|
||||
A session that's been active within the window is kept regardless of
|
||||
how old it is. Entries marked ``suspended`` are kept — the user
|
||||
explicitly paused them for later resume. Entries held by an active
|
||||
process (via has_active_processes_fn) are also kept so long-running
|
||||
background work isn't orphaned.
|
||||
|
||||
Pruning is functionally identical to a natural reset-policy expiry:
|
||||
the transcript in SQLite stays, but the session_key → session_id
|
||||
mapping is dropped and the user starts a fresh session on return.
|
||||
|
||||
``max_age_days <= 0`` disables pruning; returns 0 immediately.
|
||||
Returns the number of entries removed.
|
||||
"""
|
||||
if max_age_days is None or max_age_days <= 0:
|
||||
return 0
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = _now() - timedelta(days=max_age_days)
|
||||
removed_keys: list[str] = []
|
||||
|
||||
with self._lock:
|
||||
self._ensure_loaded_locked()
|
||||
for key, entry in list(self._entries.items()):
|
||||
if entry.suspended:
|
||||
continue
|
||||
# Never prune sessions with an active background process
|
||||
# attached — the user may still be waiting on output.
|
||||
if self._has_active_processes_fn is not None:
|
||||
try:
|
||||
if self._has_active_processes_fn(entry.session_id):
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
if entry.updated_at < cutoff:
|
||||
removed_keys.append(key)
|
||||
for key in removed_keys:
|
||||
self._entries.pop(key, None)
|
||||
if removed_keys:
|
||||
self._save()
|
||||
|
||||
if removed_keys:
|
||||
logger.info(
|
||||
"SessionStore pruned %d entries older than %d days",
|
||||
len(removed_keys), max_age_days,
|
||||
)
|
||||
return len(removed_keys)
|
||||
|
||||
def suspend_recently_active(self, max_age_seconds: int = 120) -> int:
|
||||
"""Mark recently-active sessions as suspended.
|
||||
|
||||
|
|
|
|||
|
|
@ -188,8 +188,8 @@ def _write_json_file(path: Path, payload: dict[str, Any]) -> None:
|
|||
path.write_text(json.dumps(payload))
|
||||
|
||||
|
||||
def _read_pid_record() -> Optional[dict]:
|
||||
pid_path = _get_pid_path()
|
||||
def _read_pid_record(pid_path: Optional[Path] = None) -> Optional[dict]:
|
||||
pid_path = pid_path or _get_pid_path()
|
||||
if not pid_path.exists():
|
||||
return None
|
||||
|
||||
|
|
@ -212,6 +212,18 @@ def _read_pid_record() -> Optional[dict]:
|
|||
return None
|
||||
|
||||
|
||||
def _cleanup_invalid_pid_path(pid_path: Path, *, cleanup_stale: bool) -> None:
|
||||
if not cleanup_stale:
|
||||
return
|
||||
try:
|
||||
if pid_path == _get_pid_path():
|
||||
remove_pid_file()
|
||||
else:
|
||||
pid_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def write_pid_file() -> None:
|
||||
"""Write the current process PID and metadata to the gateway PID file."""
|
||||
_write_json_file(_get_pid_path(), _build_pid_record())
|
||||
|
|
@ -413,43 +425,179 @@ def release_all_scoped_locks() -> int:
|
|||
return removed
|
||||
|
||||
|
||||
def get_running_pid() -> Optional[int]:
|
||||
# ── --replace takeover marker ─────────────────────────────────────────
|
||||
#
|
||||
# When a new gateway starts with ``--replace``, it SIGTERMs the existing
|
||||
# gateway so it can take over the bot token. PR #5646 made SIGTERM exit
|
||||
# the gateway with code 1 so ``Restart=on-failure`` can revive it after
|
||||
# unexpected kills — but that also means a --replace takeover target
|
||||
# exits 1, which tricks systemd into reviving it 30 seconds later,
|
||||
# starting a flap loop against the replacer when both services are
|
||||
# enabled in the user's systemd (e.g. ``hermes.service`` + ``hermes-
|
||||
# gateway.service``).
|
||||
#
|
||||
# The takeover marker breaks the loop: the replacer writes a short-lived
|
||||
# file naming the target PID + start_time BEFORE sending SIGTERM.
|
||||
# The target's shutdown handler reads the marker and, if it names
|
||||
# this process, treats the SIGTERM as a planned takeover and exits 0.
|
||||
# The marker is unlinked after the target has consumed it, so a stale
|
||||
# marker left by a crashed replacer can grief at most one future
|
||||
# shutdown on the same PID — and only within _TAKEOVER_MARKER_TTL_S.
|
||||
|
||||
_TAKEOVER_MARKER_FILENAME = ".gateway-takeover.json"
|
||||
_TAKEOVER_MARKER_TTL_S = 60 # Marker older than this is treated as stale
|
||||
|
||||
|
||||
def _get_takeover_marker_path() -> Path:
|
||||
"""Return the path to the --replace takeover marker file."""
|
||||
home = get_hermes_home()
|
||||
return home / _TAKEOVER_MARKER_FILENAME
|
||||
|
||||
|
||||
def write_takeover_marker(target_pid: int) -> bool:
|
||||
"""Record that ``target_pid`` is being replaced by the current process.
|
||||
|
||||
Captures the target's ``start_time`` so that PID reuse after the
|
||||
target exits cannot later match the marker. Also records the
|
||||
replacer's PID and a UTC timestamp for TTL-based staleness checks.
|
||||
|
||||
Returns True on successful write, False on any failure. The caller
|
||||
should proceed with the SIGTERM even if the write fails (the marker
|
||||
is a best-effort signal, not a correctness requirement).
|
||||
"""
|
||||
try:
|
||||
target_start_time = _get_process_start_time(target_pid)
|
||||
record = {
|
||||
"target_pid": target_pid,
|
||||
"target_start_time": target_start_time,
|
||||
"replacer_pid": os.getpid(),
|
||||
"written_at": _utc_now_iso(),
|
||||
}
|
||||
_write_json_file(_get_takeover_marker_path(), record)
|
||||
return True
|
||||
except (OSError, PermissionError):
|
||||
return False
|
||||
|
||||
|
||||
def consume_takeover_marker_for_self() -> bool:
|
||||
"""Check & unlink the takeover marker if it names the current process.
|
||||
|
||||
Returns True only when a valid (non-stale) marker names this PID +
|
||||
start_time. A returning True indicates the current SIGTERM is a
|
||||
planned --replace takeover; the caller should exit 0 instead of
|
||||
signalling ``_signal_initiated_shutdown``.
|
||||
|
||||
Always unlinks the marker on match (and on detected staleness) so
|
||||
subsequent unrelated signals don't re-trigger.
|
||||
"""
|
||||
path = _get_takeover_marker_path()
|
||||
record = _read_json_file(path)
|
||||
if not record:
|
||||
return False
|
||||
|
||||
# Any malformed or stale marker → drop it and return False
|
||||
try:
|
||||
target_pid = int(record["target_pid"])
|
||||
target_start_time = record.get("target_start_time")
|
||||
written_at = record.get("written_at") or ""
|
||||
except (KeyError, TypeError, ValueError):
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
|
||||
# TTL guard: a stale marker older than _TAKEOVER_MARKER_TTL_S is ignored.
|
||||
stale = False
|
||||
try:
|
||||
written_dt = datetime.fromisoformat(written_at)
|
||||
age = (datetime.now(timezone.utc) - written_dt).total_seconds()
|
||||
if age > _TAKEOVER_MARKER_TTL_S:
|
||||
stale = True
|
||||
except (TypeError, ValueError):
|
||||
stale = True # Unparseable timestamp — treat as stale
|
||||
|
||||
if stale:
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
return False
|
||||
|
||||
# Does the marker name THIS process?
|
||||
our_pid = os.getpid()
|
||||
our_start_time = _get_process_start_time(our_pid)
|
||||
matches = (
|
||||
target_pid == our_pid
|
||||
and target_start_time is not None
|
||||
and our_start_time is not None
|
||||
and target_start_time == our_start_time
|
||||
)
|
||||
|
||||
# Consume the marker whether it matched or not — a marker that doesn't
|
||||
# match our identity is stale-for-us anyway.
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def clear_takeover_marker() -> None:
|
||||
"""Remove the takeover marker unconditionally. Safe to call repeatedly."""
|
||||
try:
|
||||
_get_takeover_marker_path().unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def get_running_pid(
|
||||
pid_path: Optional[Path] = None,
|
||||
*,
|
||||
cleanup_stale: bool = True,
|
||||
) -> Optional[int]:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
Checks the PID file and verifies the process is actually alive.
|
||||
Cleans up stale PID files automatically.
|
||||
"""
|
||||
record = _read_pid_record()
|
||||
resolved_pid_path = pid_path or _get_pid_path()
|
||||
record = _read_pid_record(resolved_pid_path)
|
||||
if not record:
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
try:
|
||||
pid = int(record["pid"])
|
||||
except (KeyError, TypeError, ValueError):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
try:
|
||||
os.kill(pid, 0) # signal 0 = existence check, no actual signal sent
|
||||
except (ProcessLookupError, PermissionError):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
recorded_start = record.get("start_time")
|
||||
current_start = _get_process_start_time(pid)
|
||||
if recorded_start is not None and current_start is not None and current_start != recorded_start:
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
if not _looks_like_gateway_process(pid):
|
||||
if not _record_looks_like_gateway(record):
|
||||
remove_pid_file()
|
||||
_cleanup_invalid_pid_path(resolved_pid_path, cleanup_stale=cleanup_stale)
|
||||
return None
|
||||
|
||||
return pid
|
||||
|
||||
|
||||
def is_gateway_running() -> bool:
|
||||
def is_gateway_running(
|
||||
pid_path: Optional[Path] = None,
|
||||
*,
|
||||
cleanup_stale: bool = True,
|
||||
) -> bool:
|
||||
"""Check if the gateway daemon is currently running."""
|
||||
return get_running_pid() is not None
|
||||
return get_running_pid(pid_path, cleanup_stale=cleanup_stale) is not None
|
||||
|
|
|
|||
|
|
@ -100,6 +100,14 @@ class GatewayStreamConsumer:
|
|||
self._flood_strikes = 0 # Consecutive flood-control edit failures
|
||||
self._current_edit_interval = self.cfg.edit_interval # Adaptive backoff
|
||||
self._final_response_sent = False
|
||||
# Cache adapter lifecycle capability: only platforms that need an
|
||||
# explicit finalize call (e.g. DingTalk AI Cards) force us to make
|
||||
# a redundant final edit. Everyone else keeps the fast path.
|
||||
# Use ``is True`` (not ``bool(...)``) so MagicMock attribute access
|
||||
# in tests doesn't incorrectly enable this path.
|
||||
self._adapter_requires_finalize: bool = (
|
||||
getattr(adapter, "REQUIRES_EDIT_FINALIZE", False) is True
|
||||
)
|
||||
|
||||
# Think-block filter state (mirrors CLI's _stream_delta tag suppression)
|
||||
self._in_think_block = False
|
||||
|
|
@ -361,7 +369,16 @@ class GatewayStreamConsumer:
|
|||
if not got_done and not got_segment_break and commentary_text is None:
|
||||
display_text += self.cfg.cursor
|
||||
|
||||
current_update_visible = await self._send_or_edit(display_text)
|
||||
# Segment break: finalize the current message so platforms
|
||||
# that need explicit closure (e.g. DingTalk AI Cards) don't
|
||||
# leave the previous segment stuck in a loading state when
|
||||
# the next segment (tool progress, next chunk) creates a
|
||||
# new message below it. got_done has its own finalize
|
||||
# path below so we don't finalize here for it.
|
||||
current_update_visible = await self._send_or_edit(
|
||||
display_text,
|
||||
finalize=got_segment_break,
|
||||
)
|
||||
self._last_edit_time = time.monotonic()
|
||||
|
||||
if got_done:
|
||||
|
|
@ -372,10 +389,22 @@ class GatewayStreamConsumer:
|
|||
if self._accumulated:
|
||||
if self._fallback_final_send:
|
||||
await self._send_fallback_final(self._accumulated)
|
||||
elif current_update_visible:
|
||||
elif (
|
||||
current_update_visible
|
||||
and not self._adapter_requires_finalize
|
||||
):
|
||||
# Mid-stream edit above already delivered the
|
||||
# final accumulated content. Skip the redundant
|
||||
# final edit — but only for adapters that don't
|
||||
# need an explicit finalize signal.
|
||||
self._final_response_sent = True
|
||||
elif self._message_id:
|
||||
self._final_response_sent = await self._send_or_edit(self._accumulated)
|
||||
# Either the mid-stream edit didn't run (no
|
||||
# visible update this tick) OR the adapter needs
|
||||
# explicit finalize=True to close the stream.
|
||||
self._final_response_sent = await self._send_or_edit(
|
||||
self._accumulated, finalize=True,
|
||||
)
|
||||
elif not self._already_sent:
|
||||
self._final_response_sent = await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
|
|
@ -633,12 +662,15 @@ class GatewayStreamConsumer:
|
|||
logger.error("Commentary send error: %s", e)
|
||||
return False
|
||||
|
||||
async def _send_or_edit(self, text: str) -> bool:
|
||||
async def _send_or_edit(self, text: str, *, finalize: bool = False) -> bool:
|
||||
"""Send or edit the streaming message.
|
||||
|
||||
Returns True if the text was successfully delivered (sent or edited),
|
||||
False otherwise. Callers like the overflow split loop use this to
|
||||
decide whether to advance past the delivered chunk.
|
||||
|
||||
``finalize`` is True when this is the last edit in a streaming
|
||||
sequence.
|
||||
"""
|
||||
# Strip MEDIA: directives so they don't appear as visible text.
|
||||
# Media files are delivered as native attachments after the stream
|
||||
|
|
@ -672,14 +704,22 @@ class GatewayStreamConsumer:
|
|||
try:
|
||||
if self._message_id is not None:
|
||||
if self._edit_supported:
|
||||
# Skip if text is identical to what we last sent
|
||||
if text == self._last_sent_text:
|
||||
# Skip if text is identical to what we last sent.
|
||||
# Exception: adapters that require an explicit finalize
|
||||
# call (REQUIRES_EDIT_FINALIZE) must still receive the
|
||||
# finalize=True edit even when content is unchanged, so
|
||||
# their streaming UI can transition out of the in-
|
||||
# progress state. Everyone else short-circuits.
|
||||
if text == self._last_sent_text and not (
|
||||
finalize and self._adapter_requires_finalize
|
||||
):
|
||||
return True
|
||||
# Edit existing message
|
||||
result = await self.adapter.edit_message(
|
||||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=text,
|
||||
finalize=finalize,
|
||||
)
|
||||
if result.success:
|
||||
self._already_sent = True
|
||||
|
|
|
|||
|
|
@ -233,6 +233,14 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
|||
api_key_env_vars=("XAI_API_KEY",),
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"nvidia": ProviderConfig(
|
||||
id="nvidia",
|
||||
name="NVIDIA NIM",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://integrate.api.nvidia.com/v1",
|
||||
api_key_env_vars=("NVIDIA_API_KEY",),
|
||||
base_url_env_var="NVIDIA_BASE_URL",
|
||||
),
|
||||
"ai-gateway": ProviderConfig(
|
||||
id="ai-gateway",
|
||||
name="Vercel AI Gateway",
|
||||
|
|
@ -2151,6 +2159,62 @@ def refresh_nous_oauth_from_state(
|
|||
)
|
||||
|
||||
|
||||
NOUS_DEVICE_CODE_SOURCE = "device_code"
|
||||
|
||||
|
||||
def persist_nous_credentials(
|
||||
creds: Dict[str, Any],
|
||||
*,
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
"""Persist minted Nous OAuth credentials as the singleton provider state
|
||||
and ensure the credential pool is in sync.
|
||||
|
||||
Nous credentials are read at runtime from two independent locations:
|
||||
|
||||
- ``providers.nous``: singleton state read by
|
||||
``resolve_nous_runtime_credentials()`` during 401 recovery and by
|
||||
``_seed_from_singletons()`` during pool load.
|
||||
- ``credential_pool.nous``: used by the runtime ``pool.select()`` path.
|
||||
|
||||
Historically ``hermes auth add nous`` wrote a ``manual:device_code`` pool
|
||||
entry only, skipping ``providers.nous``. When the 24h agent_key TTL
|
||||
expired, the recovery path read the empty singleton state and raised
|
||||
``AuthError`` silently (``logger.debug`` at INFO level).
|
||||
|
||||
This helper writes ``providers.nous`` then calls ``load_pool("nous")`` so
|
||||
``_seed_from_singletons`` materialises the canonical ``device_code`` pool
|
||||
entry from the singleton. Re-running login upserts the same entry in
|
||||
place; the pool never accumulates duplicate device_code rows.
|
||||
|
||||
``label`` is an optional user-chosen display name (from
|
||||
``hermes auth add nous --label <name>``). It gets embedded in the
|
||||
singleton state so that ``_seed_from_singletons`` uses it as the pool
|
||||
entry's label on every subsequent ``load_pool("nous")`` instead of the
|
||||
auto-derived token fingerprint. When ``None``, the auto-derived label
|
||||
via ``label_from_token`` is used (unchanged default behaviour).
|
||||
|
||||
Returns the upserted :class:`PooledCredential` entry (or ``None`` if
|
||||
seeding somehow produced no match — shouldn't happen).
|
||||
"""
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
state = dict(creds)
|
||||
if label and str(label).strip():
|
||||
state["label"] = str(label).strip()
|
||||
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
_save_provider_state(auth_store, "nous", state)
|
||||
_save_auth_store(auth_store)
|
||||
|
||||
pool = load_pool("nous")
|
||||
return next(
|
||||
(e for e in pool.entries() if e.source == NOUS_DEVICE_CODE_SOURCE),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def resolve_nous_runtime_credentials(
|
||||
*,
|
||||
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
|
|
|
|||
|
|
@ -217,19 +217,15 @@ def auth_add_command(args) -> None:
|
|||
ca_bundle=getattr(args, "ca_bundle", None),
|
||||
min_key_ttl_seconds=max(60, int(getattr(args, "min_key_ttl_seconds", 5 * 60))),
|
||||
)
|
||||
label = (getattr(args, "label", None) or "").strip() or label_from_token(
|
||||
creds.get("access_token", ""),
|
||||
_oauth_default_label(provider, len(pool.entries()) + 1),
|
||||
# Honor `--label <name>` so nous matches other providers' UX. The
|
||||
# helper embeds this into providers.nous so that label_from_token
|
||||
# doesn't overwrite it on every subsequent load_pool("nous").
|
||||
custom_label = (getattr(args, "label", None) or "").strip() or None
|
||||
entry = auth_mod.persist_nous_credentials(creds, label=custom_label)
|
||||
shown_label = entry.label if entry is not None else label_from_token(
|
||||
creds.get("access_token", ""), _oauth_default_label(provider, 1),
|
||||
)
|
||||
entry = PooledCredential.from_dict(provider, {
|
||||
**creds,
|
||||
"label": label,
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"source": f"{SOURCE_MANUAL}:device_code",
|
||||
"base_url": creds.get("inference_base_url"),
|
||||
})
|
||||
pool.add_entry(entry)
|
||||
print(f'Added {provider} OAuth credential #{len(pool.entries())}: "{entry.label}"')
|
||||
print(f'Saved {provider} OAuth device-code credentials: "{shown_label}"')
|
||||
return
|
||||
|
||||
if provider == "openai-codex":
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ CLI tools that ship with the platform (or are commonly installed).
|
|||
|
||||
Platform support:
|
||||
macOS — osascript (always available), pngpaste (if installed)
|
||||
Windows — PowerShell via .NET System.Windows.Forms.Clipboard
|
||||
WSL2 — powershell.exe via .NET System.Windows.Forms.Clipboard
|
||||
Windows — PowerShell via WinForms, Get-Clipboard, file-drop fallback
|
||||
WSL2 — powershell.exe via WinForms, Get-Clipboard, file-drop fallback
|
||||
Linux — wl-paste (Wayland), xclip (X11)
|
||||
"""
|
||||
|
||||
|
|
@ -46,10 +46,11 @@ def has_clipboard_image() -> bool:
|
|||
return _macos_has_image()
|
||||
if sys.platform == "win32":
|
||||
return _windows_has_image()
|
||||
if _is_wsl():
|
||||
return _wsl_has_image()
|
||||
if os.environ.get("WAYLAND_DISPLAY"):
|
||||
return _wayland_has_image()
|
||||
# Match _linux_save fallthrough order: WSL → Wayland → X11
|
||||
if _is_wsl() and _wsl_has_image():
|
||||
return True
|
||||
if os.environ.get("WAYLAND_DISPLAY") and _wayland_has_image():
|
||||
return True
|
||||
return _xclip_has_image()
|
||||
|
||||
|
||||
|
|
@ -135,6 +136,114 @@ _PS_EXTRACT_IMAGE = (
|
|||
"[System.Convert]::ToBase64String($ms.ToArray())"
|
||||
)
|
||||
|
||||
_PS_CHECK_IMAGE_GET_CLIPBOARD = (
|
||||
"try { "
|
||||
"$img = Get-Clipboard -Format Image -ErrorAction Stop;"
|
||||
"if ($null -ne $img) { 'True' } else { 'False' }"
|
||||
"} catch { 'False' }"
|
||||
)
|
||||
|
||||
_PS_EXTRACT_IMAGE_GET_CLIPBOARD = (
|
||||
"try { "
|
||||
"Add-Type -AssemblyName System.Drawing;"
|
||||
"Add-Type -AssemblyName PresentationCore;"
|
||||
"Add-Type -AssemblyName WindowsBase;"
|
||||
"$img = Get-Clipboard -Format Image -ErrorAction Stop;"
|
||||
"if ($null -eq $img) { exit 1 }"
|
||||
"$ms = New-Object System.IO.MemoryStream;"
|
||||
"if ($img -is [System.Drawing.Image]) {"
|
||||
"$img.Save($ms, [System.Drawing.Imaging.ImageFormat]::Png)"
|
||||
"} elseif ($img -is [System.Windows.Media.Imaging.BitmapSource]) {"
|
||||
"$enc = New-Object System.Windows.Media.Imaging.PngBitmapEncoder;"
|
||||
"$enc.Frames.Add([System.Windows.Media.Imaging.BitmapFrame]::Create($img));"
|
||||
"$enc.Save($ms)"
|
||||
"} else { exit 2 }"
|
||||
"[System.Convert]::ToBase64String($ms.ToArray())"
|
||||
"} catch { exit 1 }"
|
||||
)
|
||||
|
||||
_FILEDROP_IMAGE_EXTS = "'.png','.jpg','.jpeg','.gif','.webp','.bmp','.tiff','.tif'"
|
||||
|
||||
_PS_CHECK_FILEDROP_IMAGE = (
|
||||
"try { "
|
||||
"$files = Get-Clipboard -Format FileDropList -ErrorAction Stop;"
|
||||
f"$exts = @({_FILEDROP_IMAGE_EXTS});"
|
||||
"$hit = $files | Where-Object { $exts -contains ([System.IO.Path]::GetExtension($_).ToLowerInvariant()) } | Select-Object -First 1;"
|
||||
"if ($null -ne $hit) { 'True' } else { 'False' }"
|
||||
"} catch { 'False' }"
|
||||
)
|
||||
|
||||
_PS_EXTRACT_FILEDROP_IMAGE = (
|
||||
"try { "
|
||||
"$files = Get-Clipboard -Format FileDropList -ErrorAction Stop;"
|
||||
f"$exts = @({_FILEDROP_IMAGE_EXTS});"
|
||||
"$hit = $files | Where-Object { $exts -contains ([System.IO.Path]::GetExtension($_).ToLowerInvariant()) } | Select-Object -First 1;"
|
||||
"if ($null -eq $hit) { exit 1 }"
|
||||
"[System.Convert]::ToBase64String([System.IO.File]::ReadAllBytes($hit))"
|
||||
"} catch { exit 1 }"
|
||||
)
|
||||
|
||||
_POWERSHELL_HAS_IMAGE_SCRIPTS = (
|
||||
_PS_CHECK_IMAGE,
|
||||
_PS_CHECK_IMAGE_GET_CLIPBOARD,
|
||||
_PS_CHECK_FILEDROP_IMAGE,
|
||||
)
|
||||
|
||||
_POWERSHELL_EXTRACT_IMAGE_SCRIPTS = (
|
||||
_PS_EXTRACT_IMAGE,
|
||||
_PS_EXTRACT_IMAGE_GET_CLIPBOARD,
|
||||
_PS_EXTRACT_FILEDROP_IMAGE,
|
||||
)
|
||||
|
||||
|
||||
def _run_powershell(exe: str, script: str, timeout: int) -> subprocess.CompletedProcess:
|
||||
return subprocess.run(
|
||||
[exe, "-NoProfile", "-NonInteractive", "-Command", script],
|
||||
capture_output=True, text=True, timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
def _write_base64_image(dest: Path, b64_data: str) -> bool:
|
||||
image_bytes = base64.b64decode(b64_data, validate=True)
|
||||
dest.write_bytes(image_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
|
||||
def _powershell_has_image(exe: str, *, timeout: int, label: str) -> bool:
|
||||
for script in _POWERSHELL_HAS_IMAGE_SCRIPTS:
|
||||
try:
|
||||
r = _run_powershell(exe, script, timeout=timeout)
|
||||
if r.returncode == 0 and "True" in r.stdout:
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s not found — clipboard unavailable", exe)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("%s clipboard image check failed: %s", label, e)
|
||||
return False
|
||||
|
||||
|
||||
def _powershell_save_image(exe: str, dest: Path, *, timeout: int, label: str) -> bool:
|
||||
for script in _POWERSHELL_EXTRACT_IMAGE_SCRIPTS:
|
||||
try:
|
||||
r = _run_powershell(exe, script, timeout=timeout)
|
||||
if r.returncode != 0:
|
||||
continue
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
continue
|
||||
|
||||
if _write_base64_image(dest, b64_data):
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.debug("%s not found — clipboard unavailable", exe)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.debug("%s clipboard image extraction failed: %s", label, e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
|
||||
|
||||
# ── Native Windows ────────────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -175,15 +284,7 @@ def _windows_has_image() -> bool:
|
|||
ps = _get_ps_exe()
|
||||
if ps is None:
|
||||
return False
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[ps, "-NoProfile", "-NonInteractive", "-Command", _PS_CHECK_IMAGE],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return r.returncode == 0 and "True" in r.stdout
|
||||
except Exception as e:
|
||||
logger.debug("Windows clipboard image check failed: %s", e)
|
||||
return False
|
||||
return _powershell_has_image(ps, timeout=5, label="Windows")
|
||||
|
||||
|
||||
def _windows_save(dest: Path) -> bool:
|
||||
|
|
@ -192,26 +293,7 @@ def _windows_save(dest: Path) -> bool:
|
|||
if ps is None:
|
||||
logger.debug("No PowerShell found — Windows clipboard image paste unavailable")
|
||||
return False
|
||||
try:
|
||||
r = subprocess.run(
|
||||
[ps, "-NoProfile", "-NonInteractive", "-Command", _PS_EXTRACT_IMAGE],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if r.returncode != 0:
|
||||
return False
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
return False
|
||||
|
||||
png_bytes = base64.b64decode(b64_data)
|
||||
dest.write_bytes(png_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Windows clipboard image extraction failed: %s", e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
return _powershell_save_image(ps, dest, timeout=15, label="Windows")
|
||||
|
||||
|
||||
# ── Linux ────────────────────────────────────────────────────────────────
|
||||
|
|
@ -235,45 +317,12 @@ def _linux_save(dest: Path) -> bool:
|
|||
|
||||
def _wsl_has_image() -> bool:
|
||||
"""Check if Windows clipboard has an image (via powershell.exe)."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_CHECK_IMAGE],
|
||||
capture_output=True, text=True, timeout=8,
|
||||
)
|
||||
return r.returncode == 0 and "True" in r.stdout
|
||||
except FileNotFoundError:
|
||||
logger.debug("powershell.exe not found — WSL clipboard unavailable")
|
||||
except Exception as e:
|
||||
logger.debug("WSL clipboard check failed: %s", e)
|
||||
return False
|
||||
return _powershell_has_image("powershell.exe", timeout=8, label="WSL")
|
||||
|
||||
|
||||
def _wsl_save(dest: Path) -> bool:
|
||||
"""Extract clipboard image via powershell.exe → base64 → decode to PNG."""
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
|
||||
_PS_EXTRACT_IMAGE],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if r.returncode != 0:
|
||||
return False
|
||||
|
||||
b64_data = r.stdout.strip()
|
||||
if not b64_data:
|
||||
return False
|
||||
|
||||
png_bytes = base64.b64decode(b64_data)
|
||||
dest.write_bytes(png_bytes)
|
||||
return dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.debug("powershell.exe not found — WSL clipboard unavailable")
|
||||
except Exception as e:
|
||||
logger.debug("WSL clipboard extraction failed: %s", e)
|
||||
dest.unlink(missing_ok=True)
|
||||
return False
|
||||
return _powershell_save_image("powershell.exe", dest, timeout=15, label="WSL")
|
||||
|
||||
|
||||
# ── Wayland (wl-paste) ──────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -87,6 +87,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
|||
aliases=("bg",), args_hint="<prompt>"),
|
||||
CommandDef("btw", "Ephemeral side question using session context (no tools, not persisted)", "Session",
|
||||
args_hint="<question>"),
|
||||
CommandDef("agents", "Show active agents and running tasks", "Session",
|
||||
aliases=("tasks",)),
|
||||
CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session",
|
||||
aliases=("q",), args_hint="<prompt>"),
|
||||
CommandDef("status", "Show session info", "Session"),
|
||||
|
|
@ -99,7 +101,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
|||
# Configuration
|
||||
CommandDef("config", "Show current configuration", "Configuration",
|
||||
cli_only=True),
|
||||
CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--global]"),
|
||||
CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--provider name] [--global]"),
|
||||
CommandDef("provider", "Show available providers and current provider",
|
||||
"Configuration"),
|
||||
CommandDef("gquota", "Show Google Gemini Code Assist quota usage", "Info"),
|
||||
|
|
@ -120,7 +122,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
|||
args_hint="[normal|fast|status]",
|
||||
subcommands=("normal", "fast", "status", "on", "off")),
|
||||
CommandDef("skin", "Show or change the display skin/theme", "Configuration",
|
||||
cli_only=True, args_hint="[name]"),
|
||||
args_hint="[name]"),
|
||||
CommandDef("voice", "Toggle voice mode", "Configuration",
|
||||
args_hint="[on|off|tts|status]", subcommands=("on", "off", "tts", "status")),
|
||||
|
||||
|
|
@ -155,7 +157,9 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
|||
args_hint="[days]"),
|
||||
CommandDef("platforms", "Show gateway/messaging platform status", "Info",
|
||||
cli_only=True, aliases=("gateway",)),
|
||||
CommandDef("paste", "Check clipboard for an image and attach it", "Info",
|
||||
CommandDef("copy", "Copy the last assistant response to clipboard", "Info",
|
||||
cli_only=True, args_hint="[number]"),
|
||||
CommandDef("paste", "Attach clipboard image from your clipboard", "Info",
|
||||
cli_only=True),
|
||||
CommandDef("image", "Attach a local image file for your next prompt", "Info",
|
||||
cli_only=True, args_hint="<path>"),
|
||||
|
|
@ -254,6 +258,35 @@ GATEWAY_KNOWN_COMMANDS: frozenset[str] = frozenset(
|
|||
)
|
||||
|
||||
|
||||
# Commands that must never be queued behind an active gateway session.
|
||||
# These are explicit control/info commands handled by the gateway itself;
|
||||
# if they get queued as pending text, the safety net in gateway.run will
|
||||
# discard them before they ever reach the user.
|
||||
ACTIVE_SESSION_BYPASS_COMMANDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"agents",
|
||||
"approve",
|
||||
"background",
|
||||
"commands",
|
||||
"deny",
|
||||
"help",
|
||||
"new",
|
||||
"profile",
|
||||
"queue",
|
||||
"restart",
|
||||
"status",
|
||||
"stop",
|
||||
"update",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def should_bypass_active_session(command_name: str | None) -> bool:
|
||||
"""Return True when a slash command must bypass active-session queuing."""
|
||||
cmd = resolve_command(command_name) if command_name else None
|
||||
return bool(cmd and cmd.name in ACTIVE_SESSION_BYPASS_COMMANDS)
|
||||
|
||||
|
||||
def _resolve_config_gates() -> set[str]:
|
||||
"""Return canonical names of commands whose ``gateway_config_gate`` is truthy.
|
||||
|
||||
|
|
@ -1044,6 +1077,51 @@ class SlashCommandCompleter(Completer):
|
|||
display_meta=f"{fp} {meta}" if meta else fp,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _skin_completions(sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /skin from available skins."""
|
||||
try:
|
||||
from hermes_cli.skin_engine import list_skins
|
||||
for s in list_skins():
|
||||
name = s["name"]
|
||||
if name.startswith(sub_lower) and name != sub_lower:
|
||||
yield Completion(
|
||||
name,
|
||||
start_position=-len(sub_text),
|
||||
display=name,
|
||||
display_meta=s.get("description", "") or s.get("source", ""),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _personality_completions(sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /personality from configured personalities."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
personalities = load_config().get("agent", {}).get("personalities", {})
|
||||
if "none".startswith(sub_lower) and "none" != sub_lower:
|
||||
yield Completion(
|
||||
"none",
|
||||
start_position=-len(sub_text),
|
||||
display="none",
|
||||
display_meta="clear personality overlay",
|
||||
)
|
||||
for name, prompt in personalities.items():
|
||||
if name.startswith(sub_lower) and name != sub_lower:
|
||||
if isinstance(prompt, dict):
|
||||
meta = prompt.get("description") or prompt.get("system_prompt", "")[:50]
|
||||
else:
|
||||
meta = str(prompt)[:50]
|
||||
yield Completion(
|
||||
name,
|
||||
start_position=-len(sub_text),
|
||||
display=name,
|
||||
display_meta=meta,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _model_completions(self, sub_text: str, sub_lower: str):
|
||||
"""Yield completions for /model from config aliases + built-in aliases."""
|
||||
seen = set()
|
||||
|
|
@ -1098,10 +1176,17 @@ class SlashCommandCompleter(Completer):
|
|||
sub_text = parts[1] if len(parts) > 1 else ""
|
||||
sub_lower = sub_text.lower()
|
||||
|
||||
# Dynamic model alias completions for /model
|
||||
if " " not in sub_text and base_cmd == "/model":
|
||||
yield from self._model_completions(sub_text, sub_lower)
|
||||
return
|
||||
# Dynamic completions for commands with runtime lists
|
||||
if " " not in sub_text:
|
||||
if base_cmd == "/model":
|
||||
yield from self._model_completions(sub_text, sub_lower)
|
||||
return
|
||||
if base_cmd == "/skin":
|
||||
yield from self._skin_completions(sub_text, sub_lower)
|
||||
return
|
||||
if base_cmd == "/personality":
|
||||
yield from self._personality_completions(sub_text, sub_lower)
|
||||
return
|
||||
|
||||
# Static subcommand completions
|
||||
if " " not in sub_text and base_cmd in SUBCOMMANDS and self._command_allowed(base_cmd):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ This module provides:
|
|||
- hermes config wizard - Re-run setup wizard
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
|
|
@ -26,6 +27,7 @@ from typing import Dict, Any, Optional, List, Tuple
|
|||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH: Dict[str, Any] = {}
|
||||
# Env var names written to .env that aren't in OPTIONAL_ENV_VARS
|
||||
# (managed by setup/provider flows directly).
|
||||
_EXTRA_ENV_KEYS = frozenset({
|
||||
|
|
@ -44,7 +46,8 @@ _EXTRA_ENV_KEYS = frozenset({
|
|||
"WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY",
|
||||
"WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS",
|
||||
"BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_PASSWORD",
|
||||
"QQ_APP_ID", "QQ_CLIENT_SECRET", "QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME",
|
||||
"QQ_APP_ID", "QQ_CLIENT_SECRET", "QQBOT_HOME_CHANNEL", "QQBOT_HOME_CHANNEL_NAME",
|
||||
"QQ_HOME_CHANNEL", "QQ_HOME_CHANNEL_NAME", # legacy aliases (pre-rename, still read for back-compat)
|
||||
"QQ_ALLOWED_USERS", "QQ_GROUP_ALLOWED_USERS", "QQ_ALLOW_ALL_USERS", "QQ_MARKDOWN_SUPPORT",
|
||||
"QQ_STT_API_KEY", "QQ_STT_BASE_URL", "QQ_STT_MODEL",
|
||||
"TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT",
|
||||
|
|
@ -417,6 +420,7 @@ DEFAULT_CONFIG = {
|
|||
"command_timeout": 30, # Timeout for browser commands in seconds (screenshot, navigate, etc.)
|
||||
"record_sessions": False, # Auto-record browser sessions as WebM videos
|
||||
"allow_private_urls": False, # Allow navigating to private/internal IPs (localhost, 192.168.x.x, etc.)
|
||||
"cdp_url": "", # Optional persistent CDP endpoint for attaching to an existing Chromium/Chrome
|
||||
"camofox": {
|
||||
# When true, Hermes sends a stable profile-scoped userId to Camofox
|
||||
# so the server maps it to a persistent Firefox profile automatically.
|
||||
|
|
@ -537,6 +541,13 @@ DEFAULT_CONFIG = {
|
|||
"api_key": "",
|
||||
"timeout": 30,
|
||||
},
|
||||
"title_generation": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
"timeout": 30,
|
||||
},
|
||||
},
|
||||
|
||||
"display": {
|
||||
|
|
@ -861,6 +872,22 @@ OPTIONAL_ENV_VARS = {
|
|||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"NVIDIA_API_KEY": {
|
||||
"description": "NVIDIA NIM API key (build.nvidia.com or local NIM endpoint)",
|
||||
"prompt": "NVIDIA NIM API key",
|
||||
"url": "https://build.nvidia.com/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"NVIDIA_BASE_URL": {
|
||||
"description": "NVIDIA NIM base URL override (e.g. http://localhost:8000/v1 for local NIM)",
|
||||
"prompt": "NVIDIA NIM base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"GLM_API_KEY": {
|
||||
"description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)",
|
||||
"prompt": "Z.AI / GLM API key",
|
||||
|
|
@ -1518,12 +1545,12 @@ OPTIONAL_ENV_VARS = {
|
|||
"prompt": "Allow All QQ Users",
|
||||
"category": "messaging",
|
||||
},
|
||||
"QQ_HOME_CHANNEL": {
|
||||
"QQBOT_HOME_CHANNEL": {
|
||||
"description": "Default QQ channel/group for cron delivery and notifications",
|
||||
"prompt": "QQ Home Channel",
|
||||
"category": "messaging",
|
||||
},
|
||||
"QQ_HOME_CHANNEL_NAME": {
|
||||
"QQBOT_HOME_CHANNEL_NAME": {
|
||||
"description": "Display name for the QQ home channel",
|
||||
"prompt": "QQ Home Channel Name",
|
||||
"category": "messaging",
|
||||
|
|
@ -2610,6 +2637,85 @@ def _expand_env_vars(obj):
|
|||
return obj
|
||||
|
||||
|
||||
def _items_by_unique_name(items):
|
||||
"""Return a name-indexed dict only when all items have unique string names."""
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
indexed = {}
|
||||
for item in items:
|
||||
if not isinstance(item, dict) or not isinstance(item.get("name"), str):
|
||||
return None
|
||||
name = item["name"]
|
||||
if name in indexed:
|
||||
return None
|
||||
indexed[name] = item
|
||||
return indexed
|
||||
|
||||
|
||||
def _preserve_env_ref_templates(current, raw, loaded_expanded=None):
|
||||
"""Restore raw ``${VAR}`` templates when a value is otherwise unchanged.
|
||||
|
||||
``load_config()`` expands env refs for runtime use. When a caller later
|
||||
persists that config after modifying some unrelated setting, keep the
|
||||
original on-disk template instead of writing the expanded plaintext
|
||||
secret back to ``config.yaml``.
|
||||
|
||||
Prefer preserving the raw template when ``current`` still matches either
|
||||
the value previously returned by ``load_config()`` for this config path or
|
||||
the current environment expansion of ``raw``. This handles env-var
|
||||
rotation between load and save while still treating mixed literal/template
|
||||
string edits as caller-owned once their rendered value diverges.
|
||||
"""
|
||||
if isinstance(current, str) and isinstance(raw, str) and re.search(r"\${[^}]+}", raw):
|
||||
if current == raw:
|
||||
return raw
|
||||
if isinstance(loaded_expanded, str) and current == loaded_expanded:
|
||||
return raw
|
||||
if _expand_env_vars(raw) == current:
|
||||
return raw
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(raw, dict):
|
||||
return {
|
||||
key: _preserve_env_ref_templates(
|
||||
value,
|
||||
raw.get(key),
|
||||
loaded_expanded.get(key) if isinstance(loaded_expanded, dict) else None,
|
||||
)
|
||||
for key, value in current.items()
|
||||
}
|
||||
|
||||
if isinstance(current, list) and isinstance(raw, list):
|
||||
# Prefer matching named config objects (e.g. custom_providers) by name
|
||||
# so harmless reordering doesn't drop the original template. If names
|
||||
# are duplicated, fall back to positional matching instead of silently
|
||||
# shadowing one entry.
|
||||
current_by_name = _items_by_unique_name(current)
|
||||
raw_by_name = _items_by_unique_name(raw)
|
||||
loaded_by_name = _items_by_unique_name(loaded_expanded)
|
||||
if current_by_name is not None and raw_by_name is not None:
|
||||
return [
|
||||
_preserve_env_ref_templates(
|
||||
item,
|
||||
raw_by_name.get(item.get("name")),
|
||||
loaded_by_name.get(item.get("name")) if loaded_by_name is not None else None,
|
||||
)
|
||||
for item in current
|
||||
]
|
||||
return [
|
||||
_preserve_env_ref_templates(
|
||||
item,
|
||||
raw[index] if index < len(raw) else None,
|
||||
loaded_expanded[index]
|
||||
if isinstance(loaded_expanded, list) and index < len(loaded_expanded)
|
||||
else None,
|
||||
)
|
||||
for index, item in enumerate(current)
|
||||
]
|
||||
|
||||
return current
|
||||
|
||||
|
||||
def _normalize_root_model_keys(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Move stale root-level provider/base_url into model section.
|
||||
|
||||
|
|
@ -2677,7 +2783,6 @@ def read_raw_config() -> Dict[str, Any]:
|
|||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
"""Load configuration from ~/.hermes/config.yaml."""
|
||||
import copy
|
||||
ensure_hermes_home()
|
||||
config_path = get_config_path()
|
||||
|
||||
|
|
@ -2698,8 +2803,11 @@ def load_config() -> Dict[str, Any]:
|
|||
config = _deep_merge(config, user_config)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load config: {e}")
|
||||
|
||||
return _expand_env_vars(_normalize_root_model_keys(_normalize_max_turns_config(config)))
|
||||
|
||||
normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
expanded = _expand_env_vars(normalized)
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH[str(config_path)] = copy.deepcopy(expanded)
|
||||
return expanded
|
||||
|
||||
|
||||
_SECURITY_COMMENT = """
|
||||
|
|
@ -2808,7 +2916,15 @@ def save_config(config: Dict[str, Any]):
|
|||
|
||||
ensure_hermes_home()
|
||||
config_path = get_config_path()
|
||||
normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
current_normalized = _normalize_root_model_keys(_normalize_max_turns_config(config))
|
||||
normalized = current_normalized
|
||||
raw_existing = _normalize_root_model_keys(_normalize_max_turns_config(read_raw_config()))
|
||||
if raw_existing:
|
||||
normalized = _preserve_env_ref_templates(
|
||||
normalized,
|
||||
raw_existing,
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH.get(str(config_path)),
|
||||
)
|
||||
|
||||
# Build optional commented-out sections for features that are off by
|
||||
# default or only relevant when explicitly configured.
|
||||
|
|
@ -2826,6 +2942,7 @@ def save_config(config: Dict[str, Any]):
|
|||
extra_content="".join(parts) if parts else None,
|
||||
)
|
||||
_secure_file(config_path)
|
||||
_LAST_EXPANDED_CONFIG_BY_PATH[str(config_path)] = copy.deepcopy(current_normalized)
|
||||
|
||||
|
||||
def load_env() -> Dict[str, str]:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@ Currently supports:
|
|||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
|
@ -31,6 +34,119 @@ _MAX_LOG_BYTES = 512_000
|
|||
_AUTO_DELETE_SECONDS = 21600
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pending-deletion tracking (replaces the old fork-and-sleep subprocess).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _pending_file() -> Path:
|
||||
"""Path to ``~/.hermes/pastes/pending.json``.
|
||||
|
||||
Each entry: ``{"url": "...", "expire_at": <unix_ts>}``. Scheduled
|
||||
DELETEs used to be handled by spawning a detached Python process per
|
||||
paste that slept for 6 hours; those accumulated forever if the user
|
||||
ran ``hermes debug share`` repeatedly. We now persist the schedule
|
||||
to disk and sweep expired entries on the next debug invocation.
|
||||
"""
|
||||
return get_hermes_home() / "pastes" / "pending.json"
|
||||
|
||||
|
||||
def _load_pending() -> list[dict]:
|
||||
path = _pending_file()
|
||||
if not path.exists():
|
||||
return []
|
||||
try:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
if isinstance(data, list):
|
||||
# Filter to well-formed entries only
|
||||
return [
|
||||
e for e in data
|
||||
if isinstance(e, dict) and "url" in e and "expire_at" in e
|
||||
]
|
||||
except (OSError, ValueError, json.JSONDecodeError):
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
def _save_pending(entries: list[dict]) -> None:
|
||||
path = _pending_file()
|
||||
try:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(".json.tmp")
|
||||
tmp.write_text(json.dumps(entries, indent=2), encoding="utf-8")
|
||||
os.replace(tmp, path)
|
||||
except OSError:
|
||||
# Non-fatal — worst case the user has to run ``hermes debug delete``
|
||||
# manually.
|
||||
pass
|
||||
|
||||
|
||||
def _record_pending(urls: list[str], delay_seconds: int = _AUTO_DELETE_SECONDS) -> None:
|
||||
"""Record *urls* for deletion at ``now + delay_seconds``.
|
||||
|
||||
Only paste.rs URLs are recorded (dpaste.com auto-expires). Entries
|
||||
are merged into any existing pending.json.
|
||||
"""
|
||||
paste_rs_urls = [u for u in urls if _extract_paste_id(u)]
|
||||
if not paste_rs_urls:
|
||||
return
|
||||
|
||||
entries = _load_pending()
|
||||
# Dedupe by URL: keep the later expire_at if same URL appears twice
|
||||
by_url: dict[str, float] = {e["url"]: float(e["expire_at"]) for e in entries}
|
||||
expire_at = time.time() + delay_seconds
|
||||
for u in paste_rs_urls:
|
||||
by_url[u] = max(expire_at, by_url.get(u, 0.0))
|
||||
merged = [{"url": u, "expire_at": ts} for u, ts in by_url.items()]
|
||||
_save_pending(merged)
|
||||
|
||||
|
||||
def _sweep_expired_pastes(now: Optional[float] = None) -> tuple[int, int]:
|
||||
"""Synchronously DELETE any pending pastes whose ``expire_at`` has passed.
|
||||
|
||||
Returns ``(deleted, remaining)``. Best-effort: failed deletes stay in
|
||||
the pending file and will be retried on the next sweep. Silent —
|
||||
intended to be called from every ``hermes debug`` invocation with
|
||||
minimal noise.
|
||||
"""
|
||||
entries = _load_pending()
|
||||
if not entries:
|
||||
return (0, 0)
|
||||
|
||||
current = time.time() if now is None else now
|
||||
deleted = 0
|
||||
remaining: list[dict] = []
|
||||
|
||||
for entry in entries:
|
||||
try:
|
||||
expire_at = float(entry.get("expire_at", 0))
|
||||
except (TypeError, ValueError):
|
||||
continue # drop malformed entries
|
||||
if expire_at > current:
|
||||
remaining.append(entry)
|
||||
continue
|
||||
|
||||
url = entry.get("url", "")
|
||||
try:
|
||||
if delete_paste(url):
|
||||
deleted += 1
|
||||
continue
|
||||
except Exception:
|
||||
# Network hiccup, 404 (already gone), etc. — drop the entry
|
||||
# after a grace period; don't retry forever.
|
||||
pass
|
||||
|
||||
# Retain failed deletes for up to 24h past expiration, then give up.
|
||||
if expire_at + 86400 > current:
|
||||
remaining.append(entry)
|
||||
else:
|
||||
deleted += 1 # count as reaped (paste.rs will GC eventually)
|
||||
|
||||
if deleted:
|
||||
_save_pending(remaining)
|
||||
|
||||
return (deleted, len(remaining))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Privacy / delete helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -90,37 +206,19 @@ def delete_paste(url: str) -> bool:
|
|||
|
||||
|
||||
def _schedule_auto_delete(urls: list[str], delay_seconds: int = _AUTO_DELETE_SECONDS):
|
||||
"""Spawn a detached process to delete paste.rs pastes after *delay_seconds*.
|
||||
"""Record *urls* for deletion ``delay_seconds`` from now.
|
||||
|
||||
The child process is fully detached (``start_new_session=True``) so it
|
||||
survives the parent exiting (important for CLI mode). Only paste.rs
|
||||
URLs are attempted — dpaste.com pastes auto-expire on their own.
|
||||
Previously this spawned a detached Python subprocess per call that slept
|
||||
for 6 hours and then issued DELETE requests. Those subprocesses leaked —
|
||||
every ``hermes debug share`` invocation added ~20 MB of resident Python
|
||||
interpreters that never exited until the sleep completed.
|
||||
|
||||
The replacement is stateless: we append to ``~/.hermes/pastes/pending.json``
|
||||
and rely on opportunistic sweeps (``_sweep_expired_pastes``) called from
|
||||
every ``hermes debug`` invocation. If the user never runs ``hermes debug``
|
||||
again, paste.rs's own retention policy handles cleanup.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
paste_rs_urls = [u for u in urls if _extract_paste_id(u)]
|
||||
if not paste_rs_urls:
|
||||
return
|
||||
|
||||
# Build a tiny inline Python script. No imports beyond stdlib.
|
||||
url_list = ", ".join(f'"{u}"' for u in paste_rs_urls)
|
||||
script = (
|
||||
"import time, urllib.request; "
|
||||
f"time.sleep({delay_seconds}); "
|
||||
f"[urllib.request.urlopen(urllib.request.Request(u, method='DELETE', "
|
||||
f"headers={{'User-Agent': 'hermes-agent/auto-delete'}}), timeout=15) "
|
||||
f"for u in [{url_list}]]"
|
||||
)
|
||||
|
||||
try:
|
||||
subprocess.Popen(
|
||||
[sys.executable, "-c", script],
|
||||
start_new_session=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
except Exception:
|
||||
pass # Best-effort; manual delete still available.
|
||||
_record_pending(urls, delay_seconds=delay_seconds)
|
||||
|
||||
|
||||
def _delete_hint(url: str) -> str:
|
||||
|
|
@ -455,6 +553,16 @@ def run_debug_delete(args):
|
|||
|
||||
def run_debug(args):
|
||||
"""Route debug subcommands."""
|
||||
# Opportunistic sweep of expired pastes on every ``hermes debug`` call.
|
||||
# Replaces the old per-paste sleeping subprocess that used to leak as
|
||||
# one orphaned Python interpreter per scheduled deletion. Silent and
|
||||
# best-effort — any failure is swallowed so ``hermes debug`` stays
|
||||
# reliable even when offline.
|
||||
try:
|
||||
_sweep_expired_pastes()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
subcmd = getattr(args, "debug_command", None)
|
||||
if subcmd == "share":
|
||||
run_debug_share(args)
|
||||
|
|
|
|||
|
|
@ -825,6 +825,7 @@ def run_doctor(args):
|
|||
("Arcee AI", ("ARCEEAI_API_KEY",), "https://api.arcee.ai/api/v1/models", "ARCEE_BASE_URL", True),
|
||||
("DeepSeek", ("DEEPSEEK_API_KEY",), "https://api.deepseek.com/v1/models", "DEEPSEEK_BASE_URL", True),
|
||||
("Hugging Face", ("HF_TOKEN",), "https://router.huggingface.co/v1/models", "HF_BASE_URL", True),
|
||||
("NVIDIA NIM", ("NVIDIA_API_KEY",), "https://integrate.api.nvidia.com/v1/models", "NVIDIA_BASE_URL", True),
|
||||
("Alibaba/DashScope", ("DASHSCOPE_API_KEY",), "https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models", "DASHSCOPE_BASE_URL", True),
|
||||
# MiniMax: the /anthropic endpoint doesn't support /models, but the /v1 endpoint does.
|
||||
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL", True),
|
||||
|
|
|
|||
|
|
@ -43,41 +43,20 @@ def _redact(value: str) -> str:
|
|||
|
||||
def _gateway_status() -> str:
|
||||
"""Return a short gateway status string."""
|
||||
if sys.platform.startswith("linux"):
|
||||
from hermes_constants import is_container
|
||||
if is_container():
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
pids = find_gateway_pids()
|
||||
if pids:
|
||||
return f"running (docker, pid {pids[0]})"
|
||||
return "stopped (docker)"
|
||||
except Exception:
|
||||
return "stopped (docker)"
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
svc = get_service_name()
|
||||
except Exception:
|
||||
svc = "hermes-gateway"
|
||||
try:
|
||||
r = subprocess.run(
|
||||
["systemctl", "--user", "is-active", svc],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return "running (systemd)" if r.stdout.strip() == "active" else "stopped"
|
||||
except Exception:
|
||||
return "unknown"
|
||||
elif sys.platform == "darwin":
|
||||
try:
|
||||
from hermes_cli.gateway import get_launchd_label
|
||||
r = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
return "loaded (launchd)" if r.returncode == 0 else "not loaded"
|
||||
except Exception:
|
||||
return "unknown"
|
||||
return "N/A"
|
||||
try:
|
||||
from hermes_cli.gateway import get_gateway_runtime_snapshot
|
||||
|
||||
snapshot = get_gateway_runtime_snapshot()
|
||||
if snapshot.running:
|
||||
mode = snapshot.manager
|
||||
if snapshot.has_process_service_mismatch:
|
||||
mode = "manual"
|
||||
return f"running ({mode}, pid {snapshot.gateway_pids[0]})"
|
||||
if snapshot.service_installed and not snapshot.service_running:
|
||||
return f"stopped ({snapshot.manager})"
|
||||
return f"stopped ({snapshot.manager})"
|
||||
except Exception:
|
||||
return "unknown" if sys.platform.startswith(("linux", "darwin")) else "N/A"
|
||||
|
||||
|
||||
def _count_skills(hermes_home: Path) -> int:
|
||||
|
|
@ -296,6 +275,7 @@ def run_dump(args):
|
|||
("DEEPSEEK_API_KEY", "deepseek"),
|
||||
("DASHSCOPE_API_KEY", "dashscope"),
|
||||
("HF_TOKEN", "huggingface"),
|
||||
("NVIDIA_API_KEY", "nvidia"),
|
||||
("AI_GATEWAY_API_KEY", "ai_gateway"),
|
||||
("OPENCODE_ZEN_API_KEY", "opencode_zen"),
|
||||
("OPENCODE_GO_API_KEY", "opencode_go"),
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import shutil
|
|||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
|
@ -41,6 +42,23 @@ from hermes_cli.colors import Colors, color
|
|||
# Process Management (for manual gateway runs)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GatewayRuntimeSnapshot:
|
||||
manager: str
|
||||
service_installed: bool = False
|
||||
service_running: bool = False
|
||||
gateway_pids: tuple[int, ...] = ()
|
||||
service_scope: str | None = None
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
return self.service_running or bool(self.gateway_pids)
|
||||
|
||||
@property
|
||||
def has_process_service_mismatch(self) -> bool:
|
||||
return self.service_installed and self.running and not self.service_running
|
||||
|
||||
def _get_service_pids() -> set:
|
||||
"""Return PIDs currently managed by systemd or launchd gateway services.
|
||||
|
||||
|
|
@ -157,20 +175,22 @@ def _request_gateway_self_restart(pid: int) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
def _append_unique_pid(pids: list[int], pid: int | None, exclude_pids: set[int]) -> None:
|
||||
if pid is None or pid <= 0:
|
||||
return
|
||||
if pid == os.getpid() or pid in exclude_pids or pid in pids:
|
||||
return
|
||||
pids.append(pid)
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
|
||||
def _scan_gateway_pids(exclude_pids: set[int], all_profiles: bool = False) -> list[int]:
|
||||
"""Best-effort process-table scan for gateway PIDs.
|
||||
|
||||
This supplements the profile-scoped PID file so status views can still spot
|
||||
a live gateway when the PID file is stale/missing, and ``--all`` sweeps can
|
||||
discover gateways outside the current profile.
|
||||
"""
|
||||
_exclude = exclude_pids or set()
|
||||
pids = [pid for pid in _get_service_pids() if pid not in _exclude]
|
||||
pids: list[int] = []
|
||||
patterns = [
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli.main --profile",
|
||||
|
|
@ -203,20 +223,24 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
|||
if is_windows():
|
||||
result = subprocess.run(
|
||||
["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"],
|
||||
capture_output=True, text=True, timeout=10
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
current_cmd = ""
|
||||
for line in result.stdout.split('\n'):
|
||||
for line in result.stdout.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("CommandLine="):
|
||||
current_cmd = line[len("CommandLine="):]
|
||||
elif line.startswith("ProcessId="):
|
||||
pid_str = line[len("ProcessId="):]
|
||||
if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)):
|
||||
if any(p in current_cmd for p in patterns) and (
|
||||
all_profiles or _matches_current_profile(current_cmd)
|
||||
):
|
||||
try:
|
||||
pid = int(pid_str)
|
||||
if pid != os.getpid() and pid not in pids and pid not in _exclude:
|
||||
pids.append(pid)
|
||||
_append_unique_pid(pids, int(pid_str), exclude_pids)
|
||||
except ValueError:
|
||||
pass
|
||||
current_cmd = ""
|
||||
|
|
@ -227,9 +251,11 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
|||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
if result.returncode != 0:
|
||||
return []
|
||||
for line in result.stdout.split("\n"):
|
||||
stripped = line.strip()
|
||||
if not stripped or 'grep' in stripped:
|
||||
if not stripped or "grep" in stripped:
|
||||
continue
|
||||
|
||||
pid = None
|
||||
|
|
@ -251,16 +277,137 @@ def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = Fals
|
|||
|
||||
if pid is None:
|
||||
continue
|
||||
if pid == os.getpid() or pid in pids or pid in _exclude:
|
||||
continue
|
||||
if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)):
|
||||
pids.append(pid)
|
||||
if any(pattern in command for pattern in patterns) and (
|
||||
all_profiles or _matches_current_profile(command)
|
||||
):
|
||||
_append_unique_pid(pids, pid, exclude_pids)
|
||||
except (OSError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
return []
|
||||
|
||||
return pids
|
||||
|
||||
|
||||
def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list:
|
||||
"""Find PIDs of running gateway processes.
|
||||
|
||||
Args:
|
||||
exclude_pids: PIDs to exclude from the result (e.g. service-managed
|
||||
PIDs that should not be killed during a stale-process sweep).
|
||||
all_profiles: When ``True``, return gateway PIDs across **all**
|
||||
profiles (the pre-7923 global behaviour). ``hermes update``
|
||||
needs this because a code update affects every profile.
|
||||
When ``False`` (default), only PIDs belonging to the current
|
||||
Hermes profile are returned.
|
||||
"""
|
||||
_exclude = set(exclude_pids or set())
|
||||
pids: list[int] = []
|
||||
if not all_profiles:
|
||||
try:
|
||||
from gateway.status import get_running_pid
|
||||
|
||||
_append_unique_pid(pids, get_running_pid(), _exclude)
|
||||
except Exception:
|
||||
pass
|
||||
for pid in _get_service_pids():
|
||||
_append_unique_pid(pids, pid, _exclude)
|
||||
for pid in _scan_gateway_pids(_exclude, all_profiles=all_profiles):
|
||||
_append_unique_pid(pids, pid, _exclude)
|
||||
return pids
|
||||
|
||||
|
||||
def _probe_systemd_service_running(system: bool = False) -> tuple[bool, bool]:
|
||||
selected_system = _select_systemd_scope(system)
|
||||
unit_exists = get_systemd_unit_path(system=selected_system).exists()
|
||||
if not unit_exists:
|
||||
return selected_system, False
|
||||
try:
|
||||
result = _run_systemctl(
|
||||
["is-active", get_service_name()],
|
||||
system=selected_system,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except (RuntimeError, subprocess.TimeoutExpired):
|
||||
return selected_system, False
|
||||
return selected_system, result.stdout.strip() == "active"
|
||||
|
||||
|
||||
def _probe_launchd_service_running() -> bool:
|
||||
if not get_launchd_plist_path().exists():
|
||||
return False
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
return result.returncode == 0
|
||||
|
||||
|
||||
def get_gateway_runtime_snapshot(system: bool = False) -> GatewayRuntimeSnapshot:
|
||||
"""Return a unified view of gateway liveness for the current profile."""
|
||||
gateway_pids = tuple(find_gateway_pids())
|
||||
if is_termux():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="Termux / manual process",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
from hermes_constants import is_container
|
||||
|
||||
if is_linux() and is_container():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="docker (foreground)",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
if supports_systemd_services():
|
||||
selected_system, service_running = _probe_systemd_service_running(system=system)
|
||||
scope_label = _service_scope_label(selected_system)
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager=f"systemd ({scope_label})",
|
||||
service_installed=get_systemd_unit_path(system=selected_system).exists(),
|
||||
service_running=service_running,
|
||||
gateway_pids=gateway_pids,
|
||||
service_scope=scope_label,
|
||||
)
|
||||
|
||||
if is_macos():
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="launchd",
|
||||
service_installed=get_launchd_plist_path().exists(),
|
||||
service_running=_probe_launchd_service_running(),
|
||||
gateway_pids=gateway_pids,
|
||||
service_scope="launchd",
|
||||
)
|
||||
|
||||
return GatewayRuntimeSnapshot(
|
||||
manager="manual process",
|
||||
gateway_pids=gateway_pids,
|
||||
)
|
||||
|
||||
|
||||
def _format_gateway_pids(pids: tuple[int, ...] | list[int], *, limit: int | None = 3) -> str:
|
||||
rendered = [str(pid) for pid in pids[:limit] if pid > 0] if limit is not None else [str(pid) for pid in pids if pid > 0]
|
||||
if limit is not None and len(pids) > limit:
|
||||
rendered.append("...")
|
||||
return ", ".join(rendered)
|
||||
|
||||
|
||||
def _print_gateway_process_mismatch(snapshot: GatewayRuntimeSnapshot) -> None:
|
||||
if not snapshot.has_process_service_mismatch:
|
||||
return
|
||||
print()
|
||||
print("⚠ Gateway process is running for this profile, but the service is not active")
|
||||
print(f" PID(s): {_format_gateway_pids(snapshot.gateway_pids, limit=None)}")
|
||||
print(" This is usually a manual foreground/tmux/nohup run, so `hermes gateway`")
|
||||
print(" can refuse to start another copy until this process stops.")
|
||||
|
||||
|
||||
def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None,
|
||||
all_profiles: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed.
|
||||
|
|
@ -340,25 +487,44 @@ def _wsl_systemd_operational() -> bool:
|
|||
WSL2 with ``systemd=true`` in wsl.conf has working systemd.
|
||||
WSL2 without it (or WSL1) does not — systemctl commands fail.
|
||||
"""
|
||||
return _systemd_operational(system=True)
|
||||
|
||||
|
||||
def _systemd_operational(system: bool = False) -> bool:
|
||||
"""Return True when the requested systemd scope is usable."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["systemctl", "is-system-running"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
result = _run_systemctl(
|
||||
["is-system-running"],
|
||||
system=system,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
# "running", "degraded", "starting" all mean systemd is PID 1
|
||||
status = result.stdout.strip().lower()
|
||||
return status in ("running", "degraded", "starting", "initializing")
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired, OSError):
|
||||
except (RuntimeError, subprocess.TimeoutExpired, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _container_systemd_operational() -> bool:
|
||||
"""Return True when a container exposes working user or system systemd."""
|
||||
if _systemd_operational(system=False):
|
||||
return True
|
||||
if _systemd_operational(system=True):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def supports_systemd_services() -> bool:
|
||||
if not is_linux() or is_termux() or is_container():
|
||||
if not is_linux() or is_termux():
|
||||
return False
|
||||
if shutil.which("systemctl") is None:
|
||||
return False
|
||||
if is_wsl():
|
||||
return _wsl_systemd_operational()
|
||||
if is_container():
|
||||
return _container_systemd_operational()
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -521,6 +687,195 @@ def has_conflicting_systemd_units() -> bool:
|
|||
return len(get_installed_systemd_scopes()) > 1
|
||||
|
||||
|
||||
# Legacy service names from older Hermes installs that predate the
|
||||
# hermes-gateway rename. Kept as an explicit allowlist (NOT a glob) so
|
||||
# profile units (hermes-gateway-*.service) and unrelated third-party
|
||||
# "hermes" units are never matched.
|
||||
_LEGACY_SERVICE_NAMES: tuple[str, ...] = ("hermes.service",)
|
||||
|
||||
# ExecStart content markers that identify a unit as running our gateway.
|
||||
# A legacy unit is only flagged when its file contains one of these.
|
||||
_LEGACY_UNIT_EXECSTART_MARKERS: tuple[str, ...] = (
|
||||
"hermes_cli.main gateway",
|
||||
"hermes_cli/main.py gateway",
|
||||
"gateway/run.py",
|
||||
" hermes gateway ",
|
||||
"/hermes gateway ",
|
||||
)
|
||||
|
||||
|
||||
def _legacy_unit_search_paths() -> list[tuple[bool, Path]]:
|
||||
"""Return ``[(is_system, base_dir), ...]`` — directories to scan for legacy units.
|
||||
|
||||
Factored out so tests can monkeypatch the search roots without touching
|
||||
real filesystem paths.
|
||||
"""
|
||||
return [
|
||||
(False, Path.home() / ".config" / "systemd" / "user"),
|
||||
(True, Path("/etc/systemd/system")),
|
||||
]
|
||||
|
||||
|
||||
def _find_legacy_hermes_units() -> list[tuple[str, Path, bool]]:
|
||||
"""Return ``[(unit_name, unit_path, is_system)]`` for legacy Hermes gateway units.
|
||||
|
||||
Detects unit files installed by older Hermes versions that used a
|
||||
different service name (e.g. ``hermes.service`` before the rename to
|
||||
``hermes-gateway.service``). When both a legacy unit and the current
|
||||
``hermes-gateway.service`` are active, they fight over the same bot
|
||||
token — the PR #5646 signal-recovery change turns this into a 30-second
|
||||
SIGTERM flap loop.
|
||||
|
||||
Safety guards:
|
||||
|
||||
* Explicit allowlist of legacy names (no globbing). Profile units such
|
||||
as ``hermes-gateway-coder.service`` and unrelated third-party
|
||||
``hermes-*`` services are never matched.
|
||||
* ExecStart content check — only flag units that invoke our gateway
|
||||
entrypoint. A user-created ``hermes.service`` running an unrelated
|
||||
binary is left untouched.
|
||||
* Results are returned purely for caller inspection; this function
|
||||
never mutates or removes anything.
|
||||
"""
|
||||
results: list[tuple[str, Path, bool]] = []
|
||||
for is_system, base in _legacy_unit_search_paths():
|
||||
for name in _LEGACY_SERVICE_NAMES:
|
||||
unit_path = base / name
|
||||
try:
|
||||
if not unit_path.exists():
|
||||
continue
|
||||
text = unit_path.read_text(encoding="utf-8", errors="ignore")
|
||||
except (OSError, PermissionError):
|
||||
continue
|
||||
if not any(marker in text for marker in _LEGACY_UNIT_EXECSTART_MARKERS):
|
||||
# Not our gateway — leave alone
|
||||
continue
|
||||
results.append((name, unit_path, is_system))
|
||||
return results
|
||||
|
||||
|
||||
def has_legacy_hermes_units() -> bool:
|
||||
"""Return True when any legacy Hermes gateway unit files exist."""
|
||||
return bool(_find_legacy_hermes_units())
|
||||
|
||||
|
||||
def print_legacy_unit_warning() -> None:
|
||||
"""Warn about legacy Hermes gateway unit files if any are installed.
|
||||
|
||||
Idempotent: prints nothing when no legacy units are detected. Safe to
|
||||
call from any status/install/setup path.
|
||||
"""
|
||||
legacy = _find_legacy_hermes_units()
|
||||
if not legacy:
|
||||
return
|
||||
print_warning("Legacy Hermes gateway unit(s) detected from an older install:")
|
||||
for name, path, is_system in legacy:
|
||||
scope = "system" if is_system else "user"
|
||||
print_info(f" {path} ({scope} scope)")
|
||||
print_info(" These run alongside the current hermes-gateway service and")
|
||||
print_info(" cause SIGTERM flap loops — both try to use the same bot token.")
|
||||
print_info(" Remove them with:")
|
||||
print_info(" hermes gateway migrate-legacy")
|
||||
|
||||
|
||||
def remove_legacy_hermes_units(
|
||||
interactive: bool = True,
|
||||
dry_run: bool = False,
|
||||
) -> tuple[int, list[Path]]:
|
||||
"""Stop, disable, and remove legacy Hermes gateway unit files.
|
||||
|
||||
Iterates over whatever ``_find_legacy_hermes_units()`` returns — which is
|
||||
an explicit allowlist of legacy names (not a glob). Profile units and
|
||||
unrelated third-party services are never touched.
|
||||
|
||||
Args:
|
||||
interactive: When True, prompt before removing. When False, remove
|
||||
without asking (used when another prompt has already confirmed,
|
||||
e.g. from the install flow).
|
||||
dry_run: When True, list what would be removed and return.
|
||||
|
||||
Returns:
|
||||
``(removed_count, remaining_paths)`` — remaining includes units we
|
||||
couldn't remove (typically system-scope when not running as root).
|
||||
"""
|
||||
legacy = _find_legacy_hermes_units()
|
||||
if not legacy:
|
||||
print("No legacy Hermes gateway units found.")
|
||||
return 0, []
|
||||
|
||||
user_units = [(n, p) for n, p, is_sys in legacy if not is_sys]
|
||||
system_units = [(n, p) for n, p, is_sys in legacy if is_sys]
|
||||
|
||||
print()
|
||||
print("Legacy Hermes gateway unit(s) found:")
|
||||
for name, path, is_system in legacy:
|
||||
scope = "system" if is_system else "user"
|
||||
print(f" {path} ({scope} scope)")
|
||||
print()
|
||||
|
||||
if dry_run:
|
||||
print("(dry-run — nothing removed)")
|
||||
return 0, [p for _, p, _ in legacy]
|
||||
|
||||
if interactive and not prompt_yes_no("Remove these legacy units?", True):
|
||||
print("Skipped. Run again with: hermes gateway migrate-legacy")
|
||||
return 0, [p for _, p, _ in legacy]
|
||||
|
||||
removed = 0
|
||||
remaining: list[Path] = []
|
||||
|
||||
# User-scope removal
|
||||
for name, path in user_units:
|
||||
try:
|
||||
_run_systemctl(["stop", name], system=False, check=False, timeout=90)
|
||||
_run_systemctl(["disable", name], system=False, check=False, timeout=30)
|
||||
path.unlink(missing_ok=True)
|
||||
print(f" ✓ Removed {path}")
|
||||
removed += 1
|
||||
except (OSError, RuntimeError) as e:
|
||||
print(f" ⚠ Could not remove {path}: {e}")
|
||||
remaining.append(path)
|
||||
|
||||
if user_units:
|
||||
try:
|
||||
_run_systemctl(["daemon-reload"], system=False, check=False, timeout=30)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# System-scope removal (needs root)
|
||||
if system_units:
|
||||
if os.geteuid() != 0:
|
||||
print()
|
||||
print_warning("System-scope legacy units require root to remove.")
|
||||
print_info(" Re-run with: sudo hermes gateway migrate-legacy")
|
||||
for _, path in system_units:
|
||||
remaining.append(path)
|
||||
else:
|
||||
for name, path in system_units:
|
||||
try:
|
||||
_run_systemctl(["stop", name], system=True, check=False, timeout=90)
|
||||
_run_systemctl(["disable", name], system=True, check=False, timeout=30)
|
||||
path.unlink(missing_ok=True)
|
||||
print(f" ✓ Removed {path}")
|
||||
removed += 1
|
||||
except (OSError, RuntimeError) as e:
|
||||
print(f" ⚠ Could not remove {path}: {e}")
|
||||
remaining.append(path)
|
||||
|
||||
try:
|
||||
_run_systemctl(["daemon-reload"], system=True, check=False, timeout=30)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
print()
|
||||
if remaining:
|
||||
print_warning(f"{len(remaining)} legacy unit(s) still present — see messages above.")
|
||||
else:
|
||||
print_success(f"Removed {removed} legacy unit(s).")
|
||||
|
||||
return removed, remaining
|
||||
|
||||
|
||||
def print_systemd_scope_conflict_warning() -> None:
|
||||
scopes = get_installed_systemd_scopes()
|
||||
if len(scopes) < 2:
|
||||
|
|
@ -1054,6 +1409,19 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
|||
if system:
|
||||
_require_root_for_system_service("install")
|
||||
|
||||
# Offer to remove legacy units (hermes.service from pre-rename installs)
|
||||
# before installing the new hermes-gateway.service. If both remain, they
|
||||
# flap-fight for the Telegram bot token on every gateway startup.
|
||||
# Only removes units matching _LEGACY_SERVICE_NAMES + our ExecStart
|
||||
# signature — profile units are never touched.
|
||||
if has_legacy_hermes_units():
|
||||
print()
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
if prompt_yes_no("Remove the legacy unit(s) before installing?", True):
|
||||
remove_legacy_hermes_units(interactive=False)
|
||||
print()
|
||||
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
scope_flag = " --system" if system else ""
|
||||
|
||||
|
|
@ -1092,6 +1460,7 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
|||
_ensure_linger_enabled()
|
||||
|
||||
print_systemd_scope_conflict_warning()
|
||||
print_legacy_unit_warning()
|
||||
|
||||
|
||||
def systemd_uninstall(system: bool = False):
|
||||
|
|
@ -1215,6 +1584,10 @@ def systemd_status(deep: bool = False, system: bool = False):
|
|||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if not systemd_unit_is_current(system=system):
|
||||
print("⚠ Installed gateway service definition is outdated")
|
||||
print(f" Run: {'sudo ' if system else ''}hermes gateway restart{scope_flag} # auto-refreshes the unit")
|
||||
|
|
@ -1998,7 +2371,7 @@ _PLATFORMS = [
|
|||
{"name": "QQ_ALLOWED_USERS", "prompt": "Allowed user OpenIDs (comma-separated, leave empty for open access)", "password": False,
|
||||
"is_allowlist": True,
|
||||
"help": "Optional — restrict DM access to specific user OpenIDs."},
|
||||
{"name": "QQ_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False,
|
||||
{"name": "QQBOT_HOME_CHANNEL", "prompt": "Home channel (user/group OpenID for cron delivery, or empty)", "password": False,
|
||||
"help": "OpenID to deliver cron results and notifications to."},
|
||||
],
|
||||
},
|
||||
|
|
@ -2625,6 +2998,215 @@ def _setup_feishu():
|
|||
print_info(f" Bot: {bot_name}")
|
||||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Interactive setup for QQ Bot — scan-to-configure or manual credentials."""
|
||||
print()
|
||||
print(color(" ─── 🐧 QQ Bot Setup ───", Colors.CYAN))
|
||||
|
||||
existing_app_id = get_env_value("QQ_APP_ID")
|
||||
existing_secret = get_env_value("QQ_CLIENT_SECRET")
|
||||
if existing_app_id and existing_secret:
|
||||
print()
|
||||
print_success("QQ Bot is already configured.")
|
||||
if not prompt_yes_no(" Reconfigure QQ Bot?", False):
|
||||
return
|
||||
|
||||
# ── Choose setup method ──
|
||||
print()
|
||||
method_choices = [
|
||||
"Scan QR code to add bot automatically (recommended)",
|
||||
"Enter existing App ID and App Secret manually",
|
||||
]
|
||||
method_idx = prompt_choice(" How would you like to set up QQ Bot?", method_choices, 0)
|
||||
|
||||
credentials = None
|
||||
used_qr = False
|
||||
|
||||
if method_idx == 0:
|
||||
# ── QR scan-to-configure ──
|
||||
try:
|
||||
credentials = _qqbot_qr_flow()
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
print_warning(" QQ Bot setup cancelled.")
|
||||
return
|
||||
if credentials:
|
||||
used_qr = True
|
||||
if not credentials:
|
||||
print_info(" QR setup did not complete. Continuing with manual input.")
|
||||
|
||||
# ── Manual credential input ──
|
||||
if not credentials:
|
||||
print()
|
||||
print_info(" Go to https://q.qq.com to register a QQ Bot application.")
|
||||
print_info(" Note your App ID and App Secret from the application page.")
|
||||
print()
|
||||
app_id = prompt(" App ID", password=False)
|
||||
if not app_id:
|
||||
print_warning(" Skipped — QQ Bot won't work without an App ID.")
|
||||
return
|
||||
app_secret = prompt(" App Secret", password=True)
|
||||
if not app_secret:
|
||||
print_warning(" Skipped — QQ Bot won't work without an App Secret.")
|
||||
return
|
||||
credentials = {"app_id": app_id.strip(), "client_secret": app_secret.strip(), "user_openid": ""}
|
||||
|
||||
# ── Save core credentials ──
|
||||
save_env_value("QQ_APP_ID", credentials["app_id"])
|
||||
save_env_value("QQ_CLIENT_SECRET", credentials["client_secret"])
|
||||
|
||||
user_openid = credentials.get("user_openid", "")
|
||||
|
||||
# ── DM security policy ──
|
||||
print()
|
||||
access_choices = [
|
||||
"Use DM pairing approval (recommended)",
|
||||
"Allow all direct messages",
|
||||
"Only allow listed user OpenIDs",
|
||||
]
|
||||
access_idx = prompt_choice(" How should direct messages be authorized?", access_choices, 0)
|
||||
if access_idx == 0:
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "false")
|
||||
if user_openid:
|
||||
print()
|
||||
if prompt_yes_no(f" Add yourself ({user_openid}) to the allow list?", True):
|
||||
save_env_value("QQ_ALLOWED_USERS", user_openid)
|
||||
print_success(f" Allow list set to {user_openid}")
|
||||
else:
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
else:
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
print_success(" DM pairing enabled.")
|
||||
print_info(" Unknown users can request access; approve with `hermes pairing approve`.")
|
||||
elif access_idx == 1:
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "true")
|
||||
save_env_value("QQ_ALLOWED_USERS", "")
|
||||
print_warning(" Open DM access enabled for QQ Bot.")
|
||||
else:
|
||||
default_allow = user_openid or ""
|
||||
allowlist = prompt(" Allowed user OpenIDs (comma-separated)", default_allow, password=False).replace(" ", "")
|
||||
save_env_value("QQ_ALLOW_ALL_USERS", "false")
|
||||
save_env_value("QQ_ALLOWED_USERS", allowlist)
|
||||
print_success(" Allowlist saved.")
|
||||
|
||||
# ── Home channel ──
|
||||
if user_openid:
|
||||
print()
|
||||
if prompt_yes_no(f" Use your QQ user ID ({user_openid}) as the home channel?", True):
|
||||
save_env_value("QQBOT_HOME_CHANNEL", user_openid)
|
||||
print_success(f" Home channel set to {user_openid}")
|
||||
else:
|
||||
print()
|
||||
home_channel = prompt(" Home channel OpenID (for cron/notifications, or empty)", password=False)
|
||||
if home_channel:
|
||||
save_env_value("QQBOT_HOME_CHANNEL", home_channel.strip())
|
||||
print_success(f" Home channel set to {home_channel.strip()}")
|
||||
|
||||
print()
|
||||
print_success("🐧 QQ Bot configured!")
|
||||
print_info(f" App ID: {credentials['app_id']}")
|
||||
|
||||
|
||||
def _qqbot_render_qr(url: str) -> bool:
|
||||
"""Try to render a QR code in the terminal. Returns True if successful."""
|
||||
try:
|
||||
import qrcode as _qr
|
||||
qr = _qr.QRCode(border=1,error_correction=_qr.constants.ERROR_CORRECT_L)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _qqbot_qr_flow():
|
||||
"""Run the QR-code scan-to-configure flow.
|
||||
|
||||
Returns a dict with app_id, client_secret, user_openid on success,
|
||||
or None on failure/cancel.
|
||||
"""
|
||||
try:
|
||||
from gateway.platforms.qqbot import (
|
||||
create_bind_task, poll_bind_result, build_connect_url,
|
||||
decrypt_secret, BindStatus,
|
||||
)
|
||||
from gateway.platforms.qqbot.constants import ONBOARD_POLL_INTERVAL
|
||||
except Exception as exc:
|
||||
print_error(f" QQBot onboard import failed: {exc}")
|
||||
return None
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
MAX_REFRESHES = 3
|
||||
refresh_count = 0
|
||||
|
||||
while refresh_count <= MAX_REFRESHES:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
# ── Create bind task ──
|
||||
try:
|
||||
task_id, aes_key = loop.run_until_complete(create_bind_task())
|
||||
except Exception as e:
|
||||
print_warning(f" Failed to create bind task: {e}")
|
||||
loop.close()
|
||||
return None
|
||||
|
||||
url = build_connect_url(task_id)
|
||||
|
||||
# ── Display QR code + URL ──
|
||||
print()
|
||||
if _qqbot_render_qr(url):
|
||||
print(f" Scan the QR code above, or open this URL directly:\n {url}")
|
||||
else:
|
||||
print(f" Open this URL in QQ on your phone:\n {url}")
|
||||
print_info(" Tip: pip install qrcode to show a scannable QR code here")
|
||||
|
||||
# ── Poll loop (silent — keep QR visible at bottom) ──
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
status, app_id, encrypted_secret, user_openid = loop.run_until_complete(
|
||||
poll_bind_result(task_id)
|
||||
)
|
||||
except Exception:
|
||||
time.sleep(ONBOARD_POLL_INTERVAL)
|
||||
continue
|
||||
|
||||
if status == BindStatus.COMPLETED:
|
||||
client_secret = decrypt_secret(encrypted_secret, aes_key)
|
||||
print()
|
||||
print_success(f" QR scan complete! (App ID: {app_id})")
|
||||
if user_openid:
|
||||
print_info(f" Scanner's OpenID: {user_openid}")
|
||||
return {
|
||||
"app_id": app_id,
|
||||
"client_secret": client_secret,
|
||||
"user_openid": user_openid,
|
||||
}
|
||||
|
||||
if status == BindStatus.EXPIRED:
|
||||
refresh_count += 1
|
||||
if refresh_count > MAX_REFRESHES:
|
||||
print()
|
||||
print_warning(f" QR code expired {MAX_REFRESHES} times — giving up.")
|
||||
return None
|
||||
print()
|
||||
print_warning(f" QR code expired, refreshing... ({refresh_count}/{MAX_REFRESHES})")
|
||||
loop.close()
|
||||
break # outer while creates a new task
|
||||
|
||||
time.sleep(ONBOARD_POLL_INTERVAL)
|
||||
except KeyboardInterrupt:
|
||||
loop.close()
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Interactive setup for Signal messenger."""
|
||||
import shutil
|
||||
|
|
@ -2762,6 +3344,10 @@ def gateway_setup():
|
|||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if supports_systemd_services() and has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if service_installed and service_running:
|
||||
print_success("Gateway service is installed and running.")
|
||||
elif service_installed:
|
||||
|
|
@ -2806,6 +3392,8 @@ def gateway_setup():
|
|||
_setup_dingtalk()
|
||||
elif platform["key"] == "feishu":
|
||||
_setup_feishu()
|
||||
elif platform["key"] == "qqbot":
|
||||
_setup_qqbot()
|
||||
else:
|
||||
_setup_standard_platform(platform)
|
||||
|
||||
|
|
@ -3165,15 +3753,18 @@ def gateway_command(args):
|
|||
elif subcmd == "status":
|
||||
deep = getattr(args, 'deep', False)
|
||||
system = getattr(args, 'system', False)
|
||||
snapshot = get_gateway_runtime_snapshot(system=system)
|
||||
|
||||
# Check for service first
|
||||
if supports_systemd_services() and (get_systemd_unit_path(system=False).exists() or get_systemd_unit_path(system=True).exists()):
|
||||
systemd_status(deep, system=system)
|
||||
_print_gateway_process_mismatch(snapshot)
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
launchd_status(deep)
|
||||
_print_gateway_process_mismatch(snapshot)
|
||||
else:
|
||||
# Check for manually running processes
|
||||
pids = find_gateway_pids()
|
||||
pids = list(snapshot.gateway_pids)
|
||||
if pids:
|
||||
print(f"✓ Gateway is running (PID: {', '.join(map(str, pids))})")
|
||||
print(" (Running manually, not as a system service)")
|
||||
|
|
@ -3214,3 +3805,14 @@ def gateway_command(args):
|
|||
else:
|
||||
print(" hermes gateway install # Install as user service")
|
||||
print(" sudo hermes gateway install --system # Install as boot-time system service")
|
||||
|
||||
elif subcmd == "migrate-legacy":
|
||||
# Stop, disable, and remove legacy Hermes gateway unit files from
|
||||
# pre-rename installs (e.g. hermes.service). Profile units and
|
||||
# unrelated third-party services are never touched.
|
||||
dry_run = getattr(args, 'dry_run', False)
|
||||
yes = getattr(args, 'yes', False)
|
||||
if not supports_systemd_services() and not is_macos():
|
||||
print("Legacy unit migration only applies to systemd-based Linux hosts.")
|
||||
return
|
||||
remove_legacy_hermes_units(interactive=not yes, dry_run=dry_run)
|
||||
|
|
|
|||
3066
hermes_cli/main.py
3066
hermes_cli/main.py
File diff suppressed because it is too large
Load diff
|
|
@ -692,12 +692,12 @@ def switch_model(
|
|||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
validation = {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": None,
|
||||
"message": f"Could not validate `{new_model}`: {e}",
|
||||
}
|
||||
|
||||
if not validation.get("accepted"):
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
|
|||
# Fallback OpenRouter snapshot used when the live catalog is unavailable.
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.7", "recommended"),
|
||||
("moonshotai/kimi-k2.5", "recommended"),
|
||||
("anthropic/claude-opus-4.7", ""),
|
||||
("anthropic/claude-opus-4.6", ""),
|
||||
("anthropic/claude-sonnet-4.6", ""),
|
||||
("qwen/qwen3.6-plus", ""),
|
||||
|
|
@ -49,7 +50,6 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
|||
("z-ai/glm-5.1", ""),
|
||||
("z-ai/glm-5v-turbo", ""),
|
||||
("z-ai/glm-5-turbo", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("x-ai/grok-4.20", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b:free", "free"),
|
||||
|
|
@ -75,6 +75,7 @@ def _codex_curated_models() -> list[str]:
|
|||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"nous": [
|
||||
"moonshotai/kimi-k2.5",
|
||||
"xiaomi/mimo-v2-pro",
|
||||
"anthropic/claude-opus-4.7",
|
||||
"anthropic/claude-opus-4.6",
|
||||
|
|
@ -96,7 +97,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
|||
"z-ai/glm-5.1",
|
||||
"z-ai/glm-5v-turbo",
|
||||
"z-ai/glm-5-turbo",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"x-ai/grok-4.20-beta",
|
||||
"nvidia/nemotron-3-super-120b-a12b",
|
||||
"nvidia/nemotron-3-super-120b-a12b:free",
|
||||
|
|
@ -135,7 +135,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
|||
"gemini-2.5-flash-lite",
|
||||
# Gemma open models (also served via AI Studio)
|
||||
"gemma-4-31b-it",
|
||||
"gemma-4-26b-it",
|
||||
],
|
||||
"google-gemini-cli": [
|
||||
"gemini-2.5-pro",
|
||||
|
|
@ -155,9 +154,23 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
|||
"grok-4.20-reasoning",
|
||||
"grok-4-1-fast-reasoning",
|
||||
],
|
||||
"nvidia": [
|
||||
# NVIDIA flagship reasoning models
|
||||
"nvidia/nemotron-3-super-120b-a12b",
|
||||
"nvidia/nemotron-3-nano-30b-a3b",
|
||||
"nvidia/llama-3.3-nemotron-super-49b-v1.5",
|
||||
# Third-party agentic models hosted on build.nvidia.com
|
||||
# (map to OpenRouter defaults — users get familiar picks on NIM)
|
||||
"qwen/qwen3.5-397b-a17b",
|
||||
"deepseek-ai/deepseek-v3.2",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"minimaxai/minimax-m2.5",
|
||||
"z-ai/glm5",
|
||||
"openai/gpt-oss-120b",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-for-coding",
|
||||
"kimi-k2.5",
|
||||
"kimi-for-coding",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-thinking-turbo",
|
||||
"kimi-k2-turbo-preview",
|
||||
|
|
@ -212,6 +225,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
|||
"trinity-mini",
|
||||
],
|
||||
"opencode-zen": [
|
||||
"kimi-k2.5",
|
||||
"gpt-5.4-pro",
|
||||
"gpt-5.4",
|
||||
"gpt-5.3-codex",
|
||||
|
|
@ -243,16 +257,15 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
|||
"glm-5",
|
||||
"glm-4.7",
|
||||
"glm-4.6",
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2",
|
||||
"qwen3-coder",
|
||||
"big-pickle",
|
||||
],
|
||||
"opencode-go": [
|
||||
"kimi-k2.5",
|
||||
"glm-5.1",
|
||||
"glm-5",
|
||||
"kimi-k2.5",
|
||||
"mimo-v2-pro",
|
||||
"mimo-v2-omni",
|
||||
"minimax-m2.7",
|
||||
|
|
@ -285,21 +298,21 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
|||
# to https://dashscope-intl.aliyuncs.com/compatible-mode/v1 (OpenAI-compat)
|
||||
# or https://dashscope-intl.aliyuncs.com/apps/anthropic (Anthropic-compat).
|
||||
"alibaba": [
|
||||
"kimi-k2.5",
|
||||
"qwen3.5-plus",
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-coder-next",
|
||||
# Third-party models available on coding-intl
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"kimi-k2.5",
|
||||
"MiniMax-M2.5",
|
||||
],
|
||||
# Curated HF model list — only agentic models that map to OpenRouter defaults.
|
||||
"huggingface": [
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"Qwen/Qwen3.5-397B-A17B",
|
||||
"Qwen/Qwen3.5-35B-A3B",
|
||||
"deepseek-ai/DeepSeek-V3.2",
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"MiniMaxAI/MiniMax-M2.5",
|
||||
"zai-org/GLM-5",
|
||||
"XiaomiMiMo/MiMo-V2-Flash",
|
||||
|
|
@ -536,6 +549,7 @@ CANONICAL_PROVIDERS: list[ProviderEntry] = [
|
|||
ProviderEntry("anthropic", "Anthropic", "Anthropic (Claude models — API key or Claude Code)"),
|
||||
ProviderEntry("openai-codex", "OpenAI Codex", "OpenAI Codex"),
|
||||
ProviderEntry("xiaomi", "Xiaomi MiMo", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"),
|
||||
ProviderEntry("nvidia", "NVIDIA NIM", "NVIDIA NIM (Nemotron models — build.nvidia.com or local NIM)"),
|
||||
ProviderEntry("qwen-oauth", "Qwen OAuth (Portal)", "Qwen OAuth (reuses local Qwen CLI login)"),
|
||||
ProviderEntry("copilot", "GitHub Copilot", "GitHub Copilot (uses GITHUB_TOKEN or gh auth token)"),
|
||||
ProviderEntry("copilot-acp", "GitHub Copilot ACP", "GitHub Copilot ACP (spawns `copilot --acp --stdio`)"),
|
||||
|
|
@ -618,6 +632,10 @@ _PROVIDER_ALIASES = {
|
|||
"grok": "xai",
|
||||
"x-ai": "xai",
|
||||
"x.ai": "xai",
|
||||
"nim": "nvidia",
|
||||
"nvidia-nim": "nvidia",
|
||||
"build-nvidia": "nvidia",
|
||||
"nemotron": "nvidia",
|
||||
"ollama": "custom", # bare "ollama" = local; use "ollama-cloud" for cloud
|
||||
"ollama_cloud": "ollama-cloud",
|
||||
}
|
||||
|
|
@ -2032,8 +2050,8 @@ def validate_requested_model(
|
|||
)
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
}
|
||||
|
|
@ -2046,8 +2064,8 @@ def validate_requested_model(
|
|||
message += f"\n If this server expects `/v1`, try base URL: `{probe.get('suggested_base_url')}`"
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
}
|
||||
|
|
@ -2081,12 +2099,11 @@ def validate_requested_model(
|
|||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Note: `{requested}` was not found in the OpenAI Codex model listing. "
|
||||
f"It may still work if your account has access to it."
|
||||
f"Model `{requested}` was not found in the OpenAI Codex model listing."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
|
|
@ -2125,16 +2142,15 @@ def validate_requested_model(
|
|||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Note: `{requested}` was not found in this provider's model listing. "
|
||||
f"It may still work if your plan supports it."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
return {
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Model `{requested}` was not found in this provider's model listing."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
|
||||
# api_models is None — couldn't reach API. Accept and persist,
|
||||
# but warn so typos don't silently break things.
|
||||
|
|
@ -2176,8 +2192,8 @@ def validate_requested_model(
|
|||
|
||||
provider_label = _PROVIDER_LABELS.get(normalized, normalized)
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Could not reach the {provider_label} API to validate `{requested}`. "
|
||||
|
|
|
|||
|
|
@ -300,19 +300,10 @@ def _read_config_model(profile_dir: Path) -> tuple:
|
|||
|
||||
def _check_gateway_running(profile_dir: Path) -> bool:
|
||||
"""Check if a gateway is running for a given profile directory."""
|
||||
pid_file = profile_dir / "gateway.pid"
|
||||
if not pid_file.exists():
|
||||
return False
|
||||
try:
|
||||
raw = pid_file.read_text().strip()
|
||||
if not raw:
|
||||
return False
|
||||
data = json.loads(raw) if raw.startswith("{") else {"pid": int(raw)}
|
||||
pid = int(data["pid"])
|
||||
os.kill(pid, 0) # existence check
|
||||
return True
|
||||
except (json.JSONDecodeError, KeyError, ValueError, TypeError,
|
||||
ProcessLookupError, PermissionError, OSError):
|
||||
from gateway.status import get_running_pid
|
||||
return get_running_pid(profile_dir / "gateway.pid", cleanup_stale=False) is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -137,6 +137,11 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
|||
base_url_override="https://api.x.ai/v1",
|
||||
base_url_env_var="XAI_BASE_URL",
|
||||
),
|
||||
"nvidia": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_override="https://integrate.api.nvidia.com/v1",
|
||||
base_url_env_var="NVIDIA_BASE_URL",
|
||||
),
|
||||
"xiaomi": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
base_url_env_var="XIAOMI_BASE_URL",
|
||||
|
|
@ -191,6 +196,12 @@ ALIASES: Dict[str, str] = {
|
|||
"x.ai": "xai",
|
||||
"grok": "xai",
|
||||
|
||||
# nvidia
|
||||
"nim": "nvidia",
|
||||
"nvidia-nim": "nvidia",
|
||||
"build-nvidia": "nvidia",
|
||||
"nemotron": "nvidia",
|
||||
|
||||
# kimi-for-coding (models.dev ID)
|
||||
"kimi": "kimi-for-coding",
|
||||
"kimi-coding": "kimi-for-coding",
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ _DEFAULT_PROVIDER_MODELS = {
|
|||
"gemini": [
|
||||
"gemini-3.1-pro-preview", "gemini-3-flash-preview", "gemini-3.1-flash-lite-preview",
|
||||
"gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite",
|
||||
"gemma-4-31b-it", "gemma-4-26b-it",
|
||||
"gemma-4-31b-it",
|
||||
],
|
||||
"zai": ["glm-5.1", "glm-5", "glm-4.7", "glm-4.5", "glm-4.5-flash"],
|
||||
"kimi-coding": ["kimi-k2.5", "kimi-k2-thinking", "kimi-k2-turbo-preview"],
|
||||
|
|
@ -2005,52 +2005,6 @@ def _setup_wecom_callback():
|
|||
_gw_setup()
|
||||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Configure QQ Bot gateway."""
|
||||
print_header("QQ Bot")
|
||||
existing = get_env_value("QQ_APP_ID")
|
||||
if existing:
|
||||
print_info("QQ Bot: already configured")
|
||||
if not prompt_yes_no("Reconfigure QQ Bot?", False):
|
||||
return
|
||||
|
||||
print_info("Connects Hermes to QQ via the Official QQ Bot API (v2).")
|
||||
print_info(" Requires a QQ Bot application at q.qq.com")
|
||||
print_info(" Reference: https://bot.q.qq.com/wiki/develop/api-v2/")
|
||||
print()
|
||||
|
||||
app_id = prompt("QQ Bot App ID")
|
||||
if not app_id:
|
||||
print_warning("App ID is required — skipping QQ Bot setup")
|
||||
return
|
||||
save_env_value("QQ_APP_ID", app_id.strip())
|
||||
|
||||
client_secret = prompt("QQ Bot App Secret", password=True)
|
||||
if not client_secret:
|
||||
print_warning("App Secret is required — skipping QQ Bot setup")
|
||||
return
|
||||
save_env_value("QQ_CLIENT_SECRET", client_secret)
|
||||
print_success("QQ Bot credentials saved")
|
||||
|
||||
print()
|
||||
print_info("🔒 Security: Restrict who can DM your bot")
|
||||
print_info(" Use QQ user OpenIDs (found in event payloads)")
|
||||
print()
|
||||
allowed_users = prompt("Allowed user OpenIDs (comma-separated, leave empty for open access)")
|
||||
if allowed_users:
|
||||
save_env_value("QQ_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("QQ Bot allowlist configured")
|
||||
else:
|
||||
print_info("⚠️ No allowlist set — anyone can DM the bot!")
|
||||
|
||||
print()
|
||||
print_info("📬 Home Channel: OpenID for cron job delivery and notifications.")
|
||||
home_channel = prompt("Home channel OpenID (leave empty to set later)")
|
||||
if home_channel:
|
||||
save_env_value("QQ_HOME_CHANNEL", home_channel)
|
||||
|
||||
print()
|
||||
print_success("QQ Bot configured!")
|
||||
|
||||
|
||||
def _setup_bluebubbles():
|
||||
|
|
@ -2119,12 +2073,9 @@ def _setup_bluebubbles():
|
|||
|
||||
|
||||
def _setup_qqbot():
|
||||
"""Configure QQ Bot (Official API v2) via standard platform setup."""
|
||||
from hermes_cli.gateway import _PLATFORMS
|
||||
qq_platform = next((p for p in _PLATFORMS if p["key"] == "qqbot"), None)
|
||||
if qq_platform:
|
||||
from hermes_cli.gateway import _setup_standard_platform
|
||||
_setup_standard_platform(qq_platform)
|
||||
"""Configure QQ Bot (Official API v2) via gateway setup."""
|
||||
from hermes_cli.gateway import _setup_qqbot as _gateway_setup_qqbot
|
||||
_gateway_setup_qqbot()
|
||||
|
||||
|
||||
def _setup_webhooks():
|
||||
|
|
@ -2264,7 +2215,9 @@ def setup_gateway(config: dict):
|
|||
missing_home.append("Slack")
|
||||
if get_env_value("BLUEBUBBLES_SERVER_URL") and not get_env_value("BLUEBUBBLES_HOME_CHANNEL"):
|
||||
missing_home.append("BlueBubbles")
|
||||
if get_env_value("QQ_APP_ID") and not get_env_value("QQ_HOME_CHANNEL"):
|
||||
if get_env_value("QQ_APP_ID") and not (
|
||||
get_env_value("QQBOT_HOME_CHANNEL") or get_env_value("QQ_HOME_CHANNEL")
|
||||
):
|
||||
missing_home.append("QQBot")
|
||||
|
||||
if missing_home:
|
||||
|
|
@ -2289,8 +2242,10 @@ def setup_gateway(config: dict):
|
|||
_is_service_running,
|
||||
supports_systemd_services,
|
||||
has_conflicting_systemd_units,
|
||||
has_legacy_hermes_units,
|
||||
install_linux_gateway_from_setup,
|
||||
print_systemd_scope_conflict_warning,
|
||||
print_legacy_unit_warning,
|
||||
systemd_start,
|
||||
systemd_restart,
|
||||
launchd_install,
|
||||
|
|
@ -2308,6 +2263,10 @@ def setup_gateway(config: dict):
|
|||
print_systemd_scope_conflict_warning()
|
||||
print()
|
||||
|
||||
if supports_systemd and has_legacy_hermes_units():
|
||||
print_legacy_unit_warning()
|
||||
print()
|
||||
|
||||
if service_running:
|
||||
if prompt_yes_no(" Restart the gateway to pick up changes?", True):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -515,6 +515,90 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
|||
c.print()
|
||||
|
||||
|
||||
def browse_skills(page: int = 1, page_size: int = 20, source: str = "all") -> dict:
|
||||
"""Paginated hub browse for programmatic callers (e.g. TUI gateway).
|
||||
|
||||
Returns ``{"items": [...], "page": int, "total_pages": int, "total": int}``.
|
||||
"""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router
|
||||
|
||||
page_size = max(1, min(page_size, 100))
|
||||
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
|
||||
_PER_SOURCE_LIMIT = {"official": 100, "skills-sh": 100, "well-known": 25, "github": 100, "clawhub": 50,
|
||||
"claude-marketplace": 50, "lobehub": 50}
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
all_results: list = []
|
||||
for src in sources:
|
||||
sid = src.source_id()
|
||||
if source != "all" and sid != source and sid != "official":
|
||||
continue
|
||||
try:
|
||||
limit = _PER_SOURCE_LIMIT.get(sid, 50)
|
||||
all_results.extend(src.search("", limit=limit))
|
||||
except Exception:
|
||||
continue
|
||||
if not all_results:
|
||||
return {"items": [], "page": 1, "total_pages": 1, "total": 0}
|
||||
seen: dict = {}
|
||||
for r in all_results:
|
||||
rank = _TRUST_RANK.get(r.trust_level, 0)
|
||||
if r.name not in seen or rank > _TRUST_RANK.get(seen[r.name].trust_level, 0):
|
||||
seen[r.name] = r
|
||||
deduped = list(seen.values())
|
||||
deduped.sort(key=lambda r: (-_TRUST_RANK.get(r.trust_level, 0), r.source != "official", r.name.lower()))
|
||||
total = len(deduped)
|
||||
total_pages = max(1, (total + page_size - 1) // page_size)
|
||||
page = max(1, min(page, total_pages))
|
||||
start = (page - 1) * page_size
|
||||
page_items = deduped[start : min(start + page_size, total)]
|
||||
return {
|
||||
"items": [{"name": r.name, "description": r.description, "source": r.source,
|
||||
"trust": r.trust_level} for r in page_items],
|
||||
"page": page,
|
||||
"total_pages": total_pages,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
|
||||
def inspect_skill(identifier: str) -> Optional[dict]:
|
||||
"""Skill metadata (+ SKILL.md preview) for programmatic callers."""
|
||||
from tools.skills_hub import GitHubAuth, create_source_router
|
||||
|
||||
class _Q:
|
||||
def print(self, *a, **k):
|
||||
pass
|
||||
|
||||
c = _Q()
|
||||
auth = GitHubAuth()
|
||||
sources = create_source_router(auth)
|
||||
ident = identifier
|
||||
if "/" not in ident:
|
||||
ident = _resolve_short_name(ident, sources, c)
|
||||
if not ident:
|
||||
return None
|
||||
meta, bundle, _ = _resolve_source_meta_and_bundle(ident, sources)
|
||||
if not meta:
|
||||
return None
|
||||
out: dict = {
|
||||
"name": meta.name,
|
||||
"description": meta.description,
|
||||
"source": meta.source,
|
||||
"identifier": meta.identifier,
|
||||
"tags": list(meta.tags) if meta.tags else [],
|
||||
}
|
||||
if bundle and "SKILL.md" in bundle.files:
|
||||
content = bundle.files["SKILL.md"]
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
lines = content.split("\n")
|
||||
preview = "\n".join(lines[:50])
|
||||
if len(lines) > 50:
|
||||
preview += f"\n\n... ({len(lines) - 50} more lines)"
|
||||
out["skill_md_preview"] = preview
|
||||
return out
|
||||
|
||||
|
||||
def do_list(source_filter: str = "all", console: Optional[Console] = None) -> None:
|
||||
"""List installed skills, distinguishing hub, builtin, and local skills."""
|
||||
from tools.skills_hub import HubLockFile, ensure_hub_dirs
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ All fields are optional. Missing values inherit from the ``default`` skin.
|
|||
banner_dim: "#B8860B" # Dim/muted text (separators, labels)
|
||||
banner_text: "#FFF8DC" # Body text (tool names, skill names)
|
||||
ui_accent: "#FFBF00" # General UI accent
|
||||
ui_label: "#4dd0e1" # UI labels
|
||||
ui_label: "#DAA520" # UI labels (warm gold; teal clashed w/ default banner gold)
|
||||
ui_ok: "#4caf50" # Success indicators
|
||||
ui_error: "#ef5350" # Error indicators
|
||||
ui_warn: "#ffa726" # Warning indicators
|
||||
|
|
@ -163,7 +163,7 @@ _BUILTIN_SKINS: Dict[str, Dict[str, Any]] = {
|
|||
"banner_dim": "#B8860B",
|
||||
"banner_text": "#FFF8DC",
|
||||
"ui_accent": "#FFBF00",
|
||||
"ui_label": "#4dd0e1",
|
||||
"ui_label": "#DAA520",
|
||||
"ui_ok": "#4caf50",
|
||||
"ui_error": "#ef5350",
|
||||
"ui_warn": "#ffa726",
|
||||
|
|
|
|||
|
|
@ -317,7 +317,7 @@ def show_status(args):
|
|||
"WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None),
|
||||
"Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"),
|
||||
"BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
|
||||
"QQBot": ("QQ_APP_ID", "QQ_HOME_CHANNEL"),
|
||||
"QQBot": ("QQ_APP_ID", "QQBOT_HOME_CHANNEL"),
|
||||
}
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
|
|
@ -327,6 +327,9 @@ def show_status(args):
|
|||
home_channel = ""
|
||||
if home_var:
|
||||
home_channel = os.getenv(home_var, "")
|
||||
# Back-compat: QQBot home channel was renamed from QQ_HOME_CHANNEL to QQBOT_HOME_CHANNEL
|
||||
if not home_channel and home_var == "QQBOT_HOME_CHANNEL":
|
||||
home_channel = os.getenv("QQ_HOME_CHANNEL", "")
|
||||
|
||||
status = "configured" if has_token else "not configured"
|
||||
if home_channel:
|
||||
|
|
@ -339,73 +342,36 @@ def show_status(args):
|
|||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if _is_termux():
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
gateway_pids = find_gateway_pids()
|
||||
except Exception:
|
||||
gateway_pids = []
|
||||
is_running = bool(gateway_pids)
|
||||
|
||||
try:
|
||||
from hermes_cli.gateway import get_gateway_runtime_snapshot, _format_gateway_pids
|
||||
|
||||
snapshot = get_gateway_runtime_snapshot()
|
||||
is_running = snapshot.running
|
||||
print(f" Status: {check_mark(is_running)} {'running' if is_running else 'stopped'}")
|
||||
print(" Manager: Termux / manual process")
|
||||
if gateway_pids:
|
||||
rendered = ", ".join(str(pid) for pid in gateway_pids[:3])
|
||||
if len(gateway_pids) > 3:
|
||||
rendered += ", ..."
|
||||
print(f" PID(s): {rendered}")
|
||||
else:
|
||||
print(f" Manager: {snapshot.manager}")
|
||||
if snapshot.gateway_pids:
|
||||
print(f" PID(s): {_format_gateway_pids(snapshot.gateway_pids)}")
|
||||
if snapshot.has_process_service_mismatch:
|
||||
print(" Service: installed but not managing the current running gateway")
|
||||
elif _is_termux() and not snapshot.gateway_pids:
|
||||
print(" Start with: hermes gateway")
|
||||
print(" Note: Android may stop background jobs when Termux is suspended")
|
||||
|
||||
elif sys.platform.startswith('linux'):
|
||||
from hermes_constants import is_container
|
||||
if is_container():
|
||||
# Docker/Podman: no systemd — check for running gateway processes
|
||||
try:
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
gateway_pids = find_gateway_pids()
|
||||
is_active = len(gateway_pids) > 0
|
||||
except Exception:
|
||||
is_active = False
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(" Manager: docker (foreground)")
|
||||
elif snapshot.service_installed and not snapshot.service_running:
|
||||
print(" Service: installed but stopped")
|
||||
except Exception:
|
||||
if _is_termux():
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: Termux / manual process")
|
||||
elif sys.platform.startswith('linux'):
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: systemd/manual")
|
||||
elif sys.platform == 'darwin':
|
||||
print(f" Status: {color('unknown', Colors.DIM)}")
|
||||
print(" Manager: launchd")
|
||||
else:
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
_gw_svc = get_service_name()
|
||||
except Exception:
|
||||
_gw_svc = "hermes-gateway"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", _gw_svc],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
is_active = result.stdout.strip() == "active"
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
is_active = False
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == 'darwin':
|
||||
from hermes_cli.gateway import get_launchd_label
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", get_launchd_label()],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
is_loaded = result.returncode == 0
|
||||
except subprocess.TimeoutExpired:
|
||||
is_loaded = False
|
||||
print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}")
|
||||
print(" Manager: launchd")
|
||||
else:
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(" Manager: (not supported on this platform)")
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(" Manager: (not supported on this platform)")
|
||||
|
||||
# =========================================================================
|
||||
# Cron Jobs
|
||||
|
|
|
|||
|
|
@ -1444,38 +1444,8 @@ def _nous_poller(session_id: str) -> None:
|
|||
auth_state, min_key_ttl_seconds=300, timeout_seconds=15.0,
|
||||
force_refresh=False, force_mint=True,
|
||||
)
|
||||
# Save into credential pool same as auth_commands.py does
|
||||
from agent.credential_pool import (
|
||||
PooledCredential,
|
||||
load_pool,
|
||||
AUTH_TYPE_OAUTH,
|
||||
SOURCE_MANUAL,
|
||||
)
|
||||
pool = load_pool("nous")
|
||||
entry = PooledCredential.from_dict("nous", {
|
||||
**full_state,
|
||||
"label": "dashboard device_code",
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"source": f"{SOURCE_MANUAL}:dashboard_device_code",
|
||||
"base_url": full_state.get("inference_base_url"),
|
||||
})
|
||||
pool.add_entry(entry)
|
||||
# Also persist to auth store so get_nous_auth_status() sees it
|
||||
# (matches what _login_nous in auth.py does for the CLI flow).
|
||||
try:
|
||||
from hermes_cli.auth import (
|
||||
_load_auth_store, _save_provider_state, _save_auth_store,
|
||||
_auth_store_lock,
|
||||
)
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
_save_provider_state(auth_store, "nous", full_state)
|
||||
_save_auth_store(auth_store)
|
||||
except Exception as store_exc:
|
||||
_log.warning(
|
||||
"oauth/device: credential pool saved but auth store write failed "
|
||||
"(session=%s): %s", session_id, store_exc,
|
||||
)
|
||||
from hermes_cli.auth import persist_nous_credentials
|
||||
persist_nous_credentials(full_state)
|
||||
with _oauth_sessions_lock:
|
||||
sess["status"] = "approved"
|
||||
_log.info("oauth/device: nous login completed (session=%s)", session_id)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ def get_hermes_home() -> Path:
|
|||
Reads HERMES_HOME env var, falls back to ~/.hermes.
|
||||
This is the single source of truth — all other copies should import this.
|
||||
"""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
val = os.environ.get("HERMES_HOME", "").strip()
|
||||
return Path(val) if val else Path.home() / ".hermes"
|
||||
|
||||
|
||||
def get_default_hermes_root() -> Path:
|
||||
|
|
|
|||
|
|
@ -43,6 +43,15 @@ from dotenv import load_dotenv
|
|||
load_dotenv()
|
||||
|
||||
|
||||
def _effective_temperature_for_model(model: str) -> Optional[float]:
|
||||
"""Return a fixed temperature for models with strict sampling contracts."""
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
return None
|
||||
return _fixed_temperature_for_model(model)
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
|
@ -442,12 +451,17 @@ Complete the user's task step by step."""
|
|||
|
||||
# Make API call
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
tools=self.tools,
|
||||
timeout=300.0
|
||||
)
|
||||
api_kwargs = {
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
"tools": self.tools,
|
||||
"timeout": 300.0,
|
||||
}
|
||||
fixed_temperature = _effective_temperature_for_model(self.model)
|
||||
if fixed_temperature is not None:
|
||||
api_kwargs["temperature"] = fixed_temperature
|
||||
|
||||
response = self.client.chat.completions.create(**api_kwargs)
|
||||
except Exception as e:
|
||||
self.logger.error(f"API call failed: {e}")
|
||||
break
|
||||
|
|
|
|||
|
|
@ -103,6 +103,28 @@ json.dump(sorted(leaf_paths(DEFAULT_CONFIG)), sys.stdout, indent=2)
|
|||
echo "ok" > $out/result
|
||||
'';
|
||||
|
||||
# Verify bundled TUI is present and compiled
|
||||
bundled-tui = pkgs.runCommand "hermes-bundled-tui" { } ''
|
||||
set -e
|
||||
echo "=== Checking bundled TUI ==="
|
||||
test -d ${hermes-agent}/ui-tui || (echo "FAIL: ui-tui directory missing"; exit 1)
|
||||
echo "PASS: ui-tui directory exists"
|
||||
|
||||
test -f ${hermes-agent}/ui-tui/dist/entry.js || (echo "FAIL: compiled entry.js missing"; exit 1)
|
||||
echo "PASS: compiled entry.js present"
|
||||
|
||||
test -d ${hermes-agent}/ui-tui/node_modules || (echo "FAIL: node_modules missing"; exit 1)
|
||||
echo "PASS: node_modules present"
|
||||
|
||||
grep -q "HERMES_TUI_DIR" ${hermes-agent}/bin/hermes || \
|
||||
(echo "FAIL: HERMES_TUI_DIR not in wrapper"; exit 1)
|
||||
echo "PASS: HERMES_TUI_DIR set in wrapper"
|
||||
|
||||
echo "=== All bundled TUI checks passed ==="
|
||||
mkdir -p $out
|
||||
echo "ok" > $out/result
|
||||
'';
|
||||
|
||||
# Verify HERMES_MANAGED guard works on all mutation commands
|
||||
managed-guard = pkgs.runCommand "hermes-managed-guard" { } ''
|
||||
set -e
|
||||
|
|
|
|||
|
|
@ -1,49 +1,26 @@
|
|||
# nix/devShell.nix — Fast dev shell with stamp-file optimization
|
||||
# nix/devShell.nix — Dev shell that delegates setup to each package
|
||||
#
|
||||
# Each package in inputsFrom exposes passthru.devShellHook — a bash snippet
|
||||
# with stamp-checked setup logic. This file collects and runs them all.
|
||||
{ inputs, ... }: {
|
||||
perSystem = { pkgs, ... }:
|
||||
perSystem = { pkgs, system, ... }:
|
||||
let
|
||||
python = pkgs.python311;
|
||||
hermes-agent = inputs.self.packages.${system}.default;
|
||||
hermes-tui = inputs.self.packages.${system}.tui;
|
||||
packages = [ hermes-agent hermes-tui ];
|
||||
in {
|
||||
devShells.default = pkgs.mkShell {
|
||||
inputsFrom = packages;
|
||||
packages = with pkgs; [
|
||||
python uv nodejs_20 ripgrep git openssh ffmpeg
|
||||
python311 uv nodejs_22 ripgrep git openssh ffmpeg
|
||||
];
|
||||
|
||||
shellHook = ''
|
||||
shellHook = let
|
||||
hooks = map (p: p.passthru.devShellHook or "") packages;
|
||||
combined = pkgs.lib.concatStringsSep "\n" (builtins.filter (h: h != "") hooks);
|
||||
in ''
|
||||
echo "Hermes Agent dev shell"
|
||||
|
||||
# Composite stamp: changes when nix python or uv change
|
||||
STAMP_VALUE="${python}:${pkgs.uv}"
|
||||
STAMP_FILE=".venv/.nix-stamp"
|
||||
|
||||
# Create venv if missing
|
||||
if [ ! -d .venv ]; then
|
||||
echo "Creating Python 3.11 venv..."
|
||||
uv venv .venv --python ${python}/bin/python3
|
||||
fi
|
||||
|
||||
source .venv/bin/activate
|
||||
|
||||
# Only install if stamp is stale or missing
|
||||
if [ ! -f "$STAMP_FILE" ] || [ "$(cat "$STAMP_FILE")" != "$STAMP_VALUE" ]; then
|
||||
echo "Installing Python dependencies..."
|
||||
uv pip install -e ".[all]"
|
||||
if [ -d mini-swe-agent ]; then
|
||||
uv pip install -e ./mini-swe-agent 2>/dev/null || true
|
||||
fi
|
||||
if [ -d tinker-atropos ]; then
|
||||
uv pip install -e ./tinker-atropos 2>/dev/null || true
|
||||
fi
|
||||
|
||||
# Install npm deps
|
||||
if [ -f package.json ] && [ ! -d node_modules ]; then
|
||||
echo "Installing npm dependencies..."
|
||||
npm install
|
||||
fi
|
||||
|
||||
echo "$STAMP_VALUE" > "$STAMP_FILE"
|
||||
fi
|
||||
|
||||
${combined}
|
||||
echo "Ready. Run 'hermes' to start."
|
||||
'';
|
||||
};
|
||||
|
|
|
|||
112
nix/packages.nix
112
nix/packages.nix
|
|
@ -1,54 +1,108 @@
|
|||
# nix/packages.nix — Hermes Agent package built with uv2nix
|
||||
{ inputs, ... }: {
|
||||
perSystem = { pkgs, system, ... }:
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ pkgs, inputs', ... }:
|
||||
let
|
||||
hermesVenv = pkgs.callPackage ./python.nix {
|
||||
inherit (inputs) uv2nix pyproject-nix pyproject-build-systems;
|
||||
};
|
||||
|
||||
hermesTui = pkgs.callPackage ./tui.nix {
|
||||
npm-lockfile-fix = inputs'.npm-lockfile-fix.packages.default;
|
||||
};
|
||||
|
||||
# Import bundled skills, excluding runtime caches
|
||||
bundledSkills = pkgs.lib.cleanSourceWith {
|
||||
src = ../skills;
|
||||
filter = path: _type:
|
||||
!(pkgs.lib.hasInfix "/index-cache/" path);
|
||||
filter = path: _type: !(pkgs.lib.hasInfix "/index-cache/" path);
|
||||
};
|
||||
|
||||
runtimeDeps = with pkgs; [
|
||||
nodejs_20 ripgrep git openssh ffmpeg tirith
|
||||
nodejs_22
|
||||
ripgrep
|
||||
git
|
||||
openssh
|
||||
ffmpeg
|
||||
tirith
|
||||
];
|
||||
|
||||
runtimePath = pkgs.lib.makeBinPath runtimeDeps;
|
||||
in {
|
||||
packages.default = pkgs.stdenv.mkDerivation {
|
||||
pname = "hermes-agent";
|
||||
version = (builtins.fromTOML (builtins.readFile ../pyproject.toml)).project.version;
|
||||
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
# Lockfile hashes for dev shell stamps
|
||||
pyprojectHash = builtins.hashString "sha256" (builtins.readFile ../pyproject.toml);
|
||||
uvLockHash =
|
||||
if builtins.pathExists ../uv.lock then
|
||||
builtins.hashString "sha256" (builtins.readFile ../uv.lock)
|
||||
else
|
||||
"none";
|
||||
in
|
||||
{
|
||||
packages = {
|
||||
default = pkgs.stdenv.mkDerivation {
|
||||
pname = "hermes-agent";
|
||||
version = (fromTOML (builtins.readFile ../pyproject.toml)).project.version;
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
|
||||
mkdir -p $out/share/hermes-agent $out/bin
|
||||
cp -r ${bundledSkills} $out/share/hermes-agent/skills
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
${pkgs.lib.concatMapStringsSep "\n" (name: ''
|
||||
makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \
|
||||
--suffix PATH : "${runtimePath}" \
|
||||
--set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills
|
||||
'') [ "hermes" "hermes-agent" "hermes-acp" ]}
|
||||
mkdir -p $out/share/hermes-agent $out/bin
|
||||
cp -r ${bundledSkills} $out/share/hermes-agent/skills
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
# copy pre-built TUI (same layout as dev: ui-tui/dist/ + node_modules/)
|
||||
mkdir -p $out/ui-tui
|
||||
cp -r ${hermesTui}/lib/hermes-tui/* $out/ui-tui/
|
||||
|
||||
meta = with pkgs.lib; {
|
||||
description = "AI agent with advanced tool-calling capabilities";
|
||||
homepage = "https://github.com/NousResearch/hermes-agent";
|
||||
mainProgram = "hermes";
|
||||
license = licenses.mit;
|
||||
platforms = platforms.unix;
|
||||
${pkgs.lib.concatMapStringsSep "\n"
|
||||
(name: ''
|
||||
makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \
|
||||
--suffix PATH : "${runtimePath}" \
|
||||
--set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills \
|
||||
--set HERMES_TUI_DIR $out/ui-tui \
|
||||
--set HERMES_PYTHON ${hermesVenv}/bin/python3
|
||||
'')
|
||||
[
|
||||
"hermes"
|
||||
"hermes-agent"
|
||||
"hermes-acp"
|
||||
]
|
||||
}
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
passthru.devShellHook = ''
|
||||
STAMP=".nix-stamps/hermes-agent"
|
||||
STAMP_VALUE="${pyprojectHash}:${uvLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-agent: installing Python dependencies..."
|
||||
uv venv .venv --python ${pkgs.python311}/bin/python3 2>/dev/null || true
|
||||
source .venv/bin/activate
|
||||
uv pip install -e ".[all]"
|
||||
[ -d mini-swe-agent ] && uv pip install -e ./mini-swe-agent 2>/dev/null || true
|
||||
[ -d tinker-atropos ] && uv pip install -e ./tinker-atropos 2>/dev/null || true
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
else
|
||||
source .venv/bin/activate
|
||||
export HERMES_PYTHON=${hermesVenv}/bin/python3
|
||||
fi
|
||||
'';
|
||||
|
||||
meta = with pkgs.lib; {
|
||||
description = "AI agent with advanced tool-calling capabilities";
|
||||
homepage = "https://github.com/NousResearch/hermes-agent";
|
||||
mainProgram = "hermes";
|
||||
license = licenses.mit;
|
||||
platforms = platforms.unix;
|
||||
};
|
||||
};
|
||||
|
||||
tui = hermesTui;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
|
|
|||
82
nix/tui.nix
Normal file
82
nix/tui.nix
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
# nix/tui.nix — Hermes TUI (Ink/React) compiled with tsc and bundled
|
||||
{ pkgs, npm-lockfile-fix, ... }:
|
||||
let
|
||||
src = ../ui-tui;
|
||||
npmDeps = pkgs.fetchNpmDeps {
|
||||
inherit src;
|
||||
hash = "sha256-zsUPmbC6oMUO10EhS3ptvDjwlfpCSEmrkjyeORw7fac=";
|
||||
};
|
||||
|
||||
packageJson = builtins.fromJSON (builtins.readFile (src + "/package.json"));
|
||||
version = packageJson.version;
|
||||
|
||||
npmLockHash = builtins.hashString "sha256" (builtins.readFile ../ui-tui/package-lock.json);
|
||||
in
|
||||
pkgs.buildNpmPackage {
|
||||
pname = "hermes-tui";
|
||||
inherit src npmDeps version;
|
||||
|
||||
doCheck = false;
|
||||
|
||||
postPatch = ''
|
||||
# fetchNpmDeps strips the trailing newline; match it so the diff passes
|
||||
sed -i -z 's/\n$//' package-lock.json
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
mkdir -p $out/lib/hermes-tui
|
||||
|
||||
cp -r dist $out/lib/hermes-tui/dist
|
||||
|
||||
# runtime node_modules
|
||||
cp -r node_modules $out/lib/hermes-tui/node_modules
|
||||
|
||||
# @hermes/ink is a file: dependency, we need to copy it in fr
|
||||
rm -f $out/lib/hermes-tui/node_modules/@hermes/ink
|
||||
cp -r packages/hermes-ink $out/lib/hermes-tui/node_modules/@hermes/ink
|
||||
|
||||
# package.json needed for "type": "module" resolution
|
||||
cp package.json $out/lib/hermes-tui/
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
nativeBuildInputs = [
|
||||
(pkgs.writeShellScriptBin "update_tui_lockfile" ''
|
||||
set -euox pipefail
|
||||
|
||||
# get root of repo
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
# cd into ui-tui and reinstall
|
||||
cd "$REPO_ROOT/ui-tui"
|
||||
rm -rf node_modules/
|
||||
npm cache clean --force
|
||||
CI=true npm install # ci env var to suppress annoying unicode install banner lag
|
||||
${pkgs.lib.getExe npm-lockfile-fix} ./package-lock.json
|
||||
|
||||
NIX_FILE="$REPO_ROOT/nix/tui.nix"
|
||||
# compute the new hash
|
||||
sed -i "s/hash = \"[^\"]*\";/hash = \"\";/" $NIX_FILE
|
||||
NIX_OUTPUT=$(nix build .#tui 2>&1 || true)
|
||||
NEW_HASH=$(echo "$NIX_OUTPUT" | grep 'got:' | awk '{print $2}')
|
||||
echo got new hash $NEW_HASH
|
||||
sed -i "s|hash = \"[^\"]*\";|hash = \"$NEW_HASH\";|" $NIX_FILE
|
||||
nix build .#tui
|
||||
echo "Updated npm hash in $NIX_FILE to $NEW_HASH"
|
||||
'')
|
||||
];
|
||||
|
||||
passthru.devShellHook = ''
|
||||
STAMP=".nix-stamps/hermes-tui"
|
||||
STAMP_VALUE="${npmLockHash}"
|
||||
if [ ! -f "$STAMP" ] || [ "$(cat "$STAMP")" != "$STAMP_VALUE" ]; then
|
||||
echo "hermes-tui: installing npm dependencies..."
|
||||
cd ui-tui && CI=true npm install --silent --no-fund --no-audit 2>/dev/null && cd ..
|
||||
mkdir -p .nix-stamps
|
||||
echo "$STAMP_VALUE" > "$STAMP"
|
||||
fi
|
||||
'';
|
||||
}
|
||||
|
|
@ -76,8 +76,8 @@ termux = [
|
|||
"hermes-agent[honcho]",
|
||||
"hermes-agent[acp]",
|
||||
]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
||||
feishu = ["lark-oapi>=1.5.3,<2"]
|
||||
dingtalk = ["dingtalk-stream>=0.20,<1", "alibabacloud-dingtalk>=2.0.0", "qrcode>=7.0,<8"]
|
||||
feishu = ["lark-oapi>=1.5.3,<2", "qrcode>=7.0,<8"]
|
||||
web = ["fastapi>=0.104.0,<1", "uvicorn[standard]>=0.24.0,<1"]
|
||||
rl = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30",
|
||||
|
|
@ -126,7 +126,7 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector
|
|||
hermes_cli = ["web_dist/**/*"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["agent", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "cron", "acp_adapter", "plugins", "plugins.*"]
|
||||
include = ["agent", "tools", "tools.*", "hermes_cli", "gateway", "gateway.*", "tui_gateway", "tui_gateway.*", "cron", "acp_adapter", "plugins", "plugins.*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
|
|
|
|||
162
run_agent.py
162
run_agent.py
|
|
@ -353,12 +353,50 @@ def _sanitize_surrogates(text: str) -> str:
|
|||
return text
|
||||
|
||||
|
||||
def _sanitize_structure_surrogates(payload: Any) -> bool:
|
||||
"""Replace surrogate code points in nested dict/list payloads in-place.
|
||||
|
||||
Mirror of ``_sanitize_structure_non_ascii`` but for surrogate recovery.
|
||||
Used to scrub nested structured fields (e.g. ``reasoning_details`` — an
|
||||
array of dicts with ``summary``/``text`` strings) that flat per-field
|
||||
checks don't reach. Returns True if any surrogates were replaced.
|
||||
"""
|
||||
found = False
|
||||
|
||||
def _walk(node):
|
||||
nonlocal found
|
||||
if isinstance(node, dict):
|
||||
for key, value in node.items():
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
node[key] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
_walk(value)
|
||||
elif isinstance(node, list):
|
||||
for idx, value in enumerate(node):
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
node[idx] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
_walk(value)
|
||||
|
||||
_walk(payload)
|
||||
return found
|
||||
|
||||
|
||||
def _sanitize_messages_surrogates(messages: list) -> bool:
|
||||
"""Sanitize surrogate characters from all string content in a messages list.
|
||||
|
||||
Walks message dicts in-place. Returns True if any surrogates were found
|
||||
and replaced, False otherwise. Covers content/text, name, and tool call
|
||||
metadata/arguments so retries don't fail on a non-content field.
|
||||
and replaced, False otherwise. Covers content/text, name, tool call
|
||||
metadata/arguments, AND any additional string or nested structured fields
|
||||
(``reasoning``, ``reasoning_content``, ``reasoning_details``, etc.) so
|
||||
retries don't fail on a non-content field. Byte-level reasoning models
|
||||
(xiaomi/mimo, kimi, glm) can emit lone surrogates in reasoning output
|
||||
that flow through to ``api_messages["reasoning_content"]`` on the next
|
||||
turn and crash json.dumps inside the OpenAI SDK.
|
||||
"""
|
||||
found = False
|
||||
for msg in messages:
|
||||
|
|
@ -398,6 +436,21 @@ def _sanitize_messages_surrogates(messages: list) -> bool:
|
|||
if isinstance(fn_args, str) and _SURROGATE_RE.search(fn_args):
|
||||
fn["arguments"] = _SURROGATE_RE.sub('\ufffd', fn_args)
|
||||
found = True
|
||||
# Walk any additional string / nested fields (reasoning,
|
||||
# reasoning_content, reasoning_details, etc.) — surrogates from
|
||||
# byte-level reasoning models (xiaomi/mimo, kimi, glm) can lurk
|
||||
# in these fields and aren't covered by the per-field checks above.
|
||||
# Matches _sanitize_messages_non_ascii's coverage (PR #10537).
|
||||
for key, value in msg.items():
|
||||
if key in {"content", "name", "tool_calls", "role"}:
|
||||
continue
|
||||
if isinstance(value, str):
|
||||
if _SURROGATE_RE.search(value):
|
||||
msg[key] = _SURROGATE_RE.sub('\ufffd', value)
|
||||
found = True
|
||||
elif isinstance(value, (dict, list)):
|
||||
if _sanitize_structure_surrogates(value):
|
||||
found = True
|
||||
return found
|
||||
|
||||
|
||||
|
|
@ -5841,6 +5894,7 @@ class AIAgent:
|
|||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._emit_status("🔄 Reconnected — resuming…")
|
||||
continue
|
||||
self._emit_status(
|
||||
"❌ Connection to provider failed after "
|
||||
|
|
@ -6744,6 +6798,14 @@ class AIAgent:
|
|||
"messages": sanitized_messages,
|
||||
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)),
|
||||
}
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
_fixed_temperature_for_model = None
|
||||
if _fixed_temperature_for_model is not None:
|
||||
fixed_temperature = _fixed_temperature_for_model(self.model)
|
||||
if fixed_temperature is not None:
|
||||
api_kwargs["temperature"] = fixed_temperature
|
||||
if self._is_qwen_portal():
|
||||
api_kwargs["metadata"] = {
|
||||
"sessionId": self.session_id or "hermes",
|
||||
|
|
@ -6949,7 +7011,7 @@ class AIAgent:
|
|||
# (gateway, batch, quiet) still get reasoning.
|
||||
# Any reasoning that wasn't shown during streaming is caught by the
|
||||
# CLI post-response display fallback (cli.py _reasoning_shown_this_turn).
|
||||
if not self.stream_delta_callback:
|
||||
if not self.stream_delta_callback and not self._stream_callback:
|
||||
try:
|
||||
self.reasoning_callback(reasoning_text)
|
||||
except Exception:
|
||||
|
|
@ -7154,14 +7216,22 @@ class AIAgent:
|
|||
|
||||
# Use auxiliary client for the flush call when available --
|
||||
# it's cheaper and avoids Codex Responses API incompatibility.
|
||||
from agent.auxiliary_client import call_llm as _call_llm
|
||||
from agent.auxiliary_client import (
|
||||
call_llm as _call_llm,
|
||||
_fixed_temperature_for_model,
|
||||
)
|
||||
_aux_available = True
|
||||
# Use the fixed-temperature override (e.g. kimi-for-coding → 0.6) if
|
||||
# the model has a strict contract; otherwise the historical 0.3 default.
|
||||
_flush_temperature = _fixed_temperature_for_model(self.model)
|
||||
if _flush_temperature is None:
|
||||
_flush_temperature = 0.3
|
||||
try:
|
||||
response = _call_llm(
|
||||
task="flush_memories",
|
||||
messages=api_messages,
|
||||
tools=[memory_tool_def],
|
||||
temperature=0.3,
|
||||
temperature=_flush_temperature,
|
||||
max_tokens=5120,
|
||||
# timeout resolved from auxiliary.flush_memories.timeout config
|
||||
)
|
||||
|
|
@ -7173,7 +7243,7 @@ class AIAgent:
|
|||
# No auxiliary client -- use the Codex Responses path directly
|
||||
codex_kwargs = self._build_api_kwargs(api_messages)
|
||||
codex_kwargs["tools"] = self._responses_tools([memory_tool_def])
|
||||
codex_kwargs["temperature"] = 0.3
|
||||
codex_kwargs["temperature"] = _flush_temperature
|
||||
if "max_output_tokens" in codex_kwargs:
|
||||
codex_kwargs["max_output_tokens"] = 5120
|
||||
response = self._run_codex_stream(codex_kwargs)
|
||||
|
|
@ -7192,7 +7262,7 @@ class AIAgent:
|
|||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
"tools": [memory_tool_def],
|
||||
"temperature": 0.3,
|
||||
"temperature": _flush_temperature,
|
||||
**self._max_tokens_param(5120),
|
||||
}
|
||||
from agent.auxiliary_client import _get_task_timeout
|
||||
|
|
@ -8165,6 +8235,15 @@ class AIAgent:
|
|||
api_messages.insert(sys_offset + idx, pfm.copy())
|
||||
|
||||
summary_extra_body = {}
|
||||
try:
|
||||
from agent.auxiliary_client import _fixed_temperature_for_model
|
||||
except Exception:
|
||||
_fixed_temperature_for_model = None
|
||||
_summary_temperature = (
|
||||
_fixed_temperature_for_model(self.model)
|
||||
if _fixed_temperature_for_model is not None
|
||||
else None
|
||||
)
|
||||
_is_nous = "nousresearch" in self._base_url_lower
|
||||
if self._supports_reasoning_extra_body():
|
||||
if self.reasoning_config is not None:
|
||||
|
|
@ -8188,6 +8267,8 @@ class AIAgent:
|
|||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
}
|
||||
if _summary_temperature is not None:
|
||||
summary_kwargs["temperature"] = _summary_temperature
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
|
||||
|
|
@ -8253,6 +8334,8 @@ class AIAgent:
|
|||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
}
|
||||
if _summary_temperature is not None:
|
||||
summary_kwargs["temperature"] = _summary_temperature
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
if summary_extra_body:
|
||||
|
|
@ -8688,6 +8771,7 @@ class AIAgent:
|
|||
{
|
||||
"name": tc["function"]["name"],
|
||||
"result": _results_by_id.get(tc.get("id")),
|
||||
"arguments": tc["function"].get("arguments"),
|
||||
}
|
||||
for tc in _m["tool_calls"]
|
||||
if isinstance(tc, dict)
|
||||
|
|
@ -9302,8 +9386,7 @@ class AIAgent:
|
|||
"and had none left for the actual response.\n\n"
|
||||
"To fix this:\n"
|
||||
"→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n"
|
||||
"→ Increase the output token limit: "
|
||||
"set `model.max_tokens` in config.yaml"
|
||||
"→ Or switch to a larger/non-reasoning model with `/model`"
|
||||
)
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
|
|
@ -9570,13 +9653,51 @@ class AIAgent:
|
|||
if isinstance(api_error, UnicodeEncodeError) and getattr(self, '_unicode_sanitization_passes', 0) < 2:
|
||||
_err_str = str(api_error).lower()
|
||||
_is_ascii_codec = "'ascii'" in _err_str or "ascii" in _err_str
|
||||
# Detect surrogate errors — utf-8 codec refusing to
|
||||
# encode U+D800..U+DFFF. The error text is:
|
||||
# "'utf-8' codec can't encode characters in position
|
||||
# N-M: surrogates not allowed"
|
||||
_is_surrogate_error = (
|
||||
"surrogate" in _err_str
|
||||
or ("'utf-8'" in _err_str and not _is_ascii_codec)
|
||||
)
|
||||
# Sanitize surrogates from both the canonical `messages`
|
||||
# list AND `api_messages` (the API-copy, which may carry
|
||||
# `reasoning_content`/`reasoning_details` transformed
|
||||
# from `reasoning` — fields the canonical list doesn't
|
||||
# have directly). Also clean `api_kwargs` if built and
|
||||
# `prefill_messages` if present. Mirrors the ASCII
|
||||
# codec recovery below.
|
||||
_surrogates_found = _sanitize_messages_surrogates(messages)
|
||||
if _surrogates_found:
|
||||
if isinstance(api_messages, list):
|
||||
if _sanitize_messages_surrogates(api_messages):
|
||||
_surrogates_found = True
|
||||
if isinstance(api_kwargs, dict):
|
||||
if _sanitize_structure_surrogates(api_kwargs):
|
||||
_surrogates_found = True
|
||||
if isinstance(getattr(self, "prefill_messages", None), list):
|
||||
if _sanitize_messages_surrogates(self.prefill_messages):
|
||||
_surrogates_found = True
|
||||
# Gate the retry on the error type, not on whether we
|
||||
# found anything — _force_ascii_payload / the extended
|
||||
# surrogate walker above cover all known paths, but a
|
||||
# new transformed field could still slip through. If
|
||||
# the error was a surrogate encode failure, always let
|
||||
# the retry run; the proactive sanitizer at line ~8781
|
||||
# runs again on the next iteration. Bounded by
|
||||
# _unicode_sanitization_passes < 2 (outer guard).
|
||||
if _surrogates_found or _is_surrogate_error:
|
||||
self._unicode_sanitization_passes += 1
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
if _surrogates_found:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Stripped invalid surrogate characters from messages. Retrying...",
|
||||
force=True,
|
||||
)
|
||||
else:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Surrogate encoding error — retrying after full-payload sanitization...",
|
||||
force=True,
|
||||
)
|
||||
continue
|
||||
if _is_ascii_codec:
|
||||
self._force_ascii_payload = True
|
||||
|
|
@ -10344,9 +10465,9 @@ class AIAgent:
|
|||
pass
|
||||
wait_time = _retry_after if _retry_after else jittered_backoff(retry_count, base_delay=2.0, max_delay=60.0)
|
||||
if is_rate_limited:
|
||||
self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
|
||||
self._emit_status(f"⏱️ Rate limited. Waiting {wait_time:.1f}s (attempt {retry_count + 1}/{max_retries})...")
|
||||
else:
|
||||
self._emit_status(f"⏳ Retrying in {wait_time}s (attempt {retry_count}/{max_retries})...")
|
||||
self._emit_status(f"⏳ Retrying in {wait_time:.1f}s (attempt {retry_count}/{max_retries})...")
|
||||
logger.warning(
|
||||
"Retrying API call in %ss (attempt %s/%s) %s error=%s",
|
||||
wait_time,
|
||||
|
|
@ -10762,7 +10883,14 @@ class AIAgent:
|
|||
elif self.quiet_mode:
|
||||
clean = self._strip_think_blocks(turn_content).strip()
|
||||
if clean:
|
||||
self._vprint(f" ┊ 💬 {clean}")
|
||||
relayed = False
|
||||
if (
|
||||
self.tool_progress_callback
|
||||
and getattr(self, "platform", "") == "tui"
|
||||
):
|
||||
relayed = True
|
||||
if not relayed:
|
||||
self._vprint(f" ┊ 💬 {clean}")
|
||||
|
||||
# Pop thinking-only prefill message(s) before appending
|
||||
# (tool-call path — same rationale as the final-response path).
|
||||
|
|
|
|||
|
|
@ -721,6 +721,20 @@ function Install-NodeDeps {
|
|||
}
|
||||
}
|
||||
|
||||
# Install TUI dependencies
|
||||
$tuiDir = "$InstallDir\ui-tui"
|
||||
if (Test-Path "$tuiDir\package.json") {
|
||||
Write-Info "Installing TUI dependencies..."
|
||||
Push-Location $tuiDir
|
||||
try {
|
||||
npm install --silent 2>&1 | Out-Null
|
||||
Write-Success "TUI dependencies installed"
|
||||
} catch {
|
||||
Write-Warn "TUI npm install failed (hermes --tui may not work)"
|
||||
}
|
||||
Pop-Location
|
||||
}
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
$bridgeDir = "$InstallDir\scripts\whatsapp-bridge"
|
||||
if (Test-Path "$bridgeDir\package.json") {
|
||||
|
|
|
|||
|
|
@ -1194,6 +1194,16 @@ install_node_deps() {
|
|||
log_success "Browser engine setup complete"
|
||||
fi
|
||||
|
||||
# Install TUI dependencies
|
||||
if [ -f "$INSTALL_DIR/ui-tui/package.json" ]; then
|
||||
log_info "Installing TUI dependencies..."
|
||||
cd "$INSTALL_DIR/ui-tui"
|
||||
npm install --silent 2>/dev/null || {
|
||||
log_warn "TUI npm install failed (hermes --tui may not work)"
|
||||
}
|
||||
log_success "TUI dependencies installed"
|
||||
fi
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
if [ -f "$INSTALL_DIR/scripts/whatsapp-bridge/package.json" ]; then
|
||||
log_info "Installing WhatsApp bridge dependencies..."
|
||||
|
|
|
|||
238
scripts/lib/node-bootstrap.sh
Normal file
238
scripts/lib/node-bootstrap.sh
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
#!/usr/bin/env bash
|
||||
# ============================================================================
|
||||
# scripts/lib/node-bootstrap.sh
|
||||
# ----------------------------------------------------------------------------
|
||||
# Sourceable helper: ensure Node.js >= MIN_VERSION is available for the TUI
|
||||
# (React + Ink), browser tools, and the WhatsApp bridge.
|
||||
#
|
||||
# Strategy (first hit wins — respects the user's existing tooling):
|
||||
# 1. modern `node` already on PATH
|
||||
# 2. ~/.hermes/node/ from a prior Hermes-managed install
|
||||
# 3. fnm, proto, nvm (in that order) if the user already uses a version manager
|
||||
# 4. Termux `pkg`, macOS Homebrew
|
||||
# 5. pinned nodejs.org tarball into ~/.hermes/node/ (always works, zero shell rc edits)
|
||||
#
|
||||
# Usage:
|
||||
# source scripts/lib/node-bootstrap.sh
|
||||
# ensure_node # returns 0 on success, non-zero on failure
|
||||
# if [ "$HERMES_NODE_AVAILABLE" = true ]; then ...; fi
|
||||
#
|
||||
# Env inputs (set before sourcing to override defaults):
|
||||
# HERMES_NODE_MIN_VERSION (default: 20) — accepted on PATH
|
||||
# HERMES_NODE_TARGET_MAJOR (default: 22) — installed when we install
|
||||
# HERMES_HOME (default: $HOME/.hermes)
|
||||
# ============================================================================
|
||||
|
||||
HERMES_NODE_MIN_VERSION="${HERMES_NODE_MIN_VERSION:-20}"
|
||||
HERMES_NODE_TARGET_MAJOR="${HERMES_NODE_TARGET_MAJOR:-22}"
|
||||
HERMES_HOME="${HERMES_HOME:-$HOME/.hermes}"
|
||||
HERMES_NODE_AVAILABLE=false
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging — prefer the host script's log_* helpers when present
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_log() { declare -F log_info >/dev/null 2>&1 && log_info "$*" || printf '→ %s\n' "$*" >&2; }
|
||||
_nb_ok() { declare -F log_success >/dev/null 2>&1 && log_success "$*" || printf '✓ %s\n' "$*" >&2; }
|
||||
_nb_warn() { declare -F log_warn >/dev/null 2>&1 && log_warn "$*" || printf '⚠ %s\n' "$*" >&2; }
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform + version helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_is_termux() {
|
||||
[ -n "${TERMUX_VERSION:-}" ] || [[ "${PREFIX:-}" == *"com.termux/files/usr"* ]]
|
||||
}
|
||||
|
||||
_nb_node_major() {
|
||||
local v
|
||||
v=$(node --version 2>/dev/null | sed 's/^v//' | cut -d. -f1)
|
||||
[[ "$v" =~ ^[0-9]+$ ]] && echo "$v" || echo 0
|
||||
}
|
||||
|
||||
_nb_have_modern_node() {
|
||||
command -v node >/dev/null 2>&1 || return 1
|
||||
[ "$(_nb_node_major)" -ge "$HERMES_NODE_MIN_VERSION" ]
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Version-manager paths — respect what the user already uses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_try_fnm() {
|
||||
command -v fnm >/dev/null 2>&1 || return 1
|
||||
_nb_log "fnm detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
eval "$(fnm env 2>/dev/null)" || true
|
||||
fnm install "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
fnm use "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via fnm"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_proto() {
|
||||
command -v proto >/dev/null 2>&1 || return 1
|
||||
_nb_log "proto detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
proto install node "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via proto"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_nvm() {
|
||||
local nvm_sh="${NVM_DIR:-$HOME/.nvm}/nvm.sh"
|
||||
[ -s "$nvm_sh" ] || return 1
|
||||
# shellcheck source=/dev/null
|
||||
\. "$nvm_sh" >/dev/null 2>&1 || return 1
|
||||
_nb_log "nvm detected — installing Node $HERMES_NODE_TARGET_MAJOR..."
|
||||
nvm install "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
nvm use "$HERMES_NODE_TARGET_MAJOR" >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) activated via nvm"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Platform package managers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_try_termux_pkg() {
|
||||
_nb_is_termux || return 1
|
||||
_nb_log "Installing Node.js via pkg..."
|
||||
pkg install -y nodejs >/dev/null 2>&1 || return 1
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed via pkg"
|
||||
return 0
|
||||
}
|
||||
|
||||
_nb_try_brew() {
|
||||
[ "$(uname -s)" = "Darwin" ] || return 1
|
||||
command -v brew >/dev/null 2>&1 || return 1
|
||||
_nb_log "Installing Node via Homebrew..."
|
||||
brew install "node@${HERMES_NODE_TARGET_MAJOR}" >/dev/null 2>&1 \
|
||||
|| brew install node >/dev/null 2>&1 \
|
||||
|| return 1
|
||||
brew link --overwrite --force "node@${HERMES_NODE_TARGET_MAJOR}" >/dev/null 2>&1 || true
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed via Homebrew"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bundled binary fallback — always works, no shell rc edits
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_nb_install_bundled_node() {
|
||||
local arch node_arch os_name node_os
|
||||
arch=$(uname -m)
|
||||
case "$arch" in
|
||||
x86_64) node_arch="x64" ;;
|
||||
aarch64|arm64) node_arch="arm64" ;;
|
||||
armv7l) node_arch="armv7l" ;;
|
||||
*)
|
||||
_nb_warn "Unsupported arch ($arch) — install Node.js manually: https://nodejs.org/"
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
os_name=$(uname -s)
|
||||
case "$os_name" in
|
||||
Linux*) node_os="linux" ;;
|
||||
Darwin*) node_os="darwin" ;;
|
||||
*)
|
||||
_nb_warn "Unsupported OS ($os_name) — install Node.js manually: https://nodejs.org/"
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
|
||||
local index_url="https://nodejs.org/dist/latest-v${HERMES_NODE_TARGET_MAJOR}.x/"
|
||||
local tarball
|
||||
tarball=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${HERMES_NODE_TARGET_MAJOR}\.[0-9]+\.[0-9]+-${node_os}-${node_arch}\.tar\.xz" \
|
||||
| head -1)
|
||||
if [ -z "$tarball" ]; then
|
||||
tarball=$(curl -fsSL "$index_url" \
|
||||
| grep -oE "node-v${HERMES_NODE_TARGET_MAJOR}\.[0-9]+\.[0-9]+-${node_os}-${node_arch}\.tar\.gz" \
|
||||
| head -1)
|
||||
fi
|
||||
if [ -z "$tarball" ]; then
|
||||
_nb_warn "Could not resolve Node $HERMES_NODE_TARGET_MAJOR binary for $node_os-$node_arch"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local tmp
|
||||
tmp=$(mktemp -d)
|
||||
_nb_log "Downloading $tarball..."
|
||||
curl -fsSL "${index_url}${tarball}" -o "$tmp/$tarball" || {
|
||||
_nb_warn "Download failed"; rm -rf "$tmp"; return 1
|
||||
}
|
||||
|
||||
_nb_log "Extracting to $HERMES_HOME/node/..."
|
||||
if [[ "$tarball" == *.tar.xz ]]; then
|
||||
tar xf "$tmp/$tarball" -C "$tmp" || { rm -rf "$tmp"; return 1; }
|
||||
else
|
||||
tar xzf "$tmp/$tarball" -C "$tmp" || { rm -rf "$tmp"; return 1; }
|
||||
fi
|
||||
|
||||
local extracted
|
||||
extracted=$(find "$tmp" -maxdepth 1 -type d -name 'node-v*' 2>/dev/null | head -1)
|
||||
if [ ! -d "$extracted" ]; then
|
||||
_nb_warn "Extraction produced no node-v* directory"
|
||||
rm -rf "$tmp"
|
||||
return 1
|
||||
fi
|
||||
|
||||
mkdir -p "$HERMES_HOME"
|
||||
rm -rf "$HERMES_HOME/node"
|
||||
mv "$extracted" "$HERMES_HOME/node"
|
||||
rm -rf "$tmp"
|
||||
|
||||
mkdir -p "$HOME/.local/bin"
|
||||
ln -sf "$HERMES_HOME/node/bin/node" "$HOME/.local/bin/node"
|
||||
ln -sf "$HERMES_HOME/node/bin/npm" "$HOME/.local/bin/npm"
|
||||
ln -sf "$HERMES_HOME/node/bin/npx" "$HOME/.local/bin/npx"
|
||||
export PATH="$HERMES_HOME/node/bin:$PATH"
|
||||
|
||||
_nb_have_modern_node || return 1
|
||||
_nb_ok "Node $(node --version) installed to $HERMES_HOME/node/"
|
||||
return 0
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ensure_node() {
|
||||
HERMES_NODE_AVAILABLE=false
|
||||
|
||||
if _nb_have_modern_node; then
|
||||
_nb_ok "Node $(node --version) found"
|
||||
HERMES_NODE_AVAILABLE=true
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ -x "$HERMES_HOME/node/bin/node" ]; then
|
||||
export PATH="$HERMES_HOME/node/bin:$PATH"
|
||||
if _nb_have_modern_node; then
|
||||
_nb_ok "Node $(node --version) found (Hermes-managed)"
|
||||
HERMES_NODE_AVAILABLE=true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
# Version managers first — respect the user's existing setup.
|
||||
_nb_try_fnm && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_proto && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_nvm && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
# Platform package managers.
|
||||
_nb_try_termux_pkg && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
_nb_try_brew && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
# Last resort: pinned nodejs.org tarball.
|
||||
_nb_install_bundled_node && { HERMES_NODE_AVAILABLE=true; return 0; }
|
||||
|
||||
_nb_warn "Node.js install failed — TUI and browser tools will be unavailable."
|
||||
_nb_warn "Install manually: https://nodejs.org/en/download/ (or: \`brew install node\`, \`fnm install $HERMES_NODE_TARGET_MAJOR\`, etc.)"
|
||||
return 1
|
||||
}
|
||||
|
|
@ -44,6 +44,7 @@ AUTHOR_MAP = {
|
|||
"teknium@nousresearch.com": "teknium1",
|
||||
"127238744+teknium1@users.noreply.github.com": "teknium1",
|
||||
# contributors (from noreply pattern)
|
||||
"snreynolds2506@gmail.com": "snreynolds",
|
||||
"35742124+0xbyt4@users.noreply.github.com": "0xbyt4",
|
||||
"82637225+kshitijk4poor@users.noreply.github.com": "kshitijk4poor",
|
||||
"kshitijk4poor@users.noreply.github.com": "kshitijk4poor",
|
||||
|
|
@ -75,6 +76,7 @@ AUTHOR_MAP = {
|
|||
"Asunfly@users.noreply.github.com": "Asunfly",
|
||||
# contributors (manual mapping from git names)
|
||||
"ahmedsherif95@gmail.com": "asheriif",
|
||||
"liujinkun@bytedance.com": "liujinkun2025",
|
||||
"dmayhem93@gmail.com": "dmahan93",
|
||||
"samherring99@gmail.com": "samherring99",
|
||||
"desaiaum08@gmail.com": "Aum08Desai",
|
||||
|
|
@ -95,6 +97,7 @@ AUTHOR_MAP = {
|
|||
"4317663+helix4u@users.noreply.github.com": "helix4u",
|
||||
"331214+counterposition@users.noreply.github.com": "counterposition",
|
||||
"blspear@gmail.com": "BrennerSpear",
|
||||
"akhater@gmail.com": "akhater",
|
||||
"239876380+handsdiff@users.noreply.github.com": "handsdiff",
|
||||
"gpickett00@gmail.com": "gpickett00",
|
||||
"mcosma@gmail.com": "wakamex",
|
||||
|
|
@ -103,6 +106,7 @@ AUTHOR_MAP = {
|
|||
"dangtc94@gmail.com": "dieutx",
|
||||
"jaisehgal11299@gmail.com": "jaisup",
|
||||
"percydikec@gmail.com": "PercyDikec",
|
||||
"noonou7@gmail.com": "HenkDz",
|
||||
"dean.kerr@gmail.com": "deankerr",
|
||||
"socrates1024@gmail.com": "socrates1024",
|
||||
"satelerd@gmail.com": "satelerd",
|
||||
|
|
@ -115,6 +119,7 @@ AUTHOR_MAP = {
|
|||
"vincentcharlebois@gmail.com": "vincentcharlebois",
|
||||
"aryan@synvoid.com": "aryansingh",
|
||||
"johnsonblake1@gmail.com": "blakejohnson",
|
||||
"hcn518@gmail.com": "pedh",
|
||||
"greer.guthrie@gmail.com": "g-guthrie",
|
||||
"kennyx102@gmail.com": "bobashopcashier",
|
||||
"shokatalishaikh95@gmail.com": "areu01or00",
|
||||
|
|
@ -255,6 +260,8 @@ AUTHOR_MAP = {
|
|||
"anthhub@163.com": "anthhub",
|
||||
"shenuu@gmail.com": "shenuu",
|
||||
"xiayh17@gmail.com": "xiayh0107",
|
||||
"asurla@nvidia.com": "anniesurla",
|
||||
"limkuan24@gmail.com": "WideLee",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -42,9 +42,10 @@ class TestToolProgressCallback:
|
|||
def test_emits_tool_call_start(self, mock_conn, event_loop_fixture):
|
||||
"""Tool progress should emit a ToolCallStart update."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
# Run callback in the event loop context
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
|
|
@ -66,9 +67,10 @@ class TestToolProgressCallback:
|
|||
def test_handles_string_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is a JSON string, it should be parsed."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -82,9 +84,10 @@ class TestToolProgressCallback:
|
|||
def test_handles_non_dict_args(self, mock_conn, event_loop_fixture):
|
||||
"""If args is not a dict, it should be wrapped."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -98,10 +101,11 @@ class TestToolProgressCallback:
|
|||
def test_duplicate_same_name_tool_calls_use_fifo_ids(self, mock_conn, event_loop_fixture):
|
||||
"""Multiple same-name tool calls should be tracked independently in order."""
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
progress_cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
step_cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -163,7 +167,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"terminal": "tc-abc123"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -181,7 +185,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
cb(1, [{"name": "unknown_tool", "result": "ok"}])
|
||||
|
|
@ -193,7 +197,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"read_file": "tc-def456"}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts:
|
||||
future = MagicMock(spec=Future)
|
||||
|
|
@ -212,7 +216,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"terminal": deque(["tc-xyz789"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
|
|
@ -224,7 +228,7 @@ class TestStepCallback:
|
|||
cb(1, [{"name": "terminal", "result": '{"output": "hello"}'}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}'
|
||||
"tc-xyz789", "terminal", result='{"output": "hello"}', function_args=None, snapshot=None
|
||||
)
|
||||
|
||||
def test_none_result_passed_through(self, mock_conn, event_loop_fixture):
|
||||
|
|
@ -234,7 +238,7 @@ class TestStepCallback:
|
|||
tool_call_ids = {"web_search": deque(["tc-aaa"])}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids)
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, {})
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
|
|
@ -244,7 +248,50 @@ class TestStepCallback:
|
|||
|
||||
cb(1, [{"name": "web_search", "result": None}])
|
||||
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None)
|
||||
mock_btc.assert_called_once_with("tc-aaa", "web_search", result=None, function_args=None, snapshot=None)
|
||||
|
||||
def test_step_callback_passes_arguments_and_snapshot(self, mock_conn, event_loop_fixture):
|
||||
from collections import deque
|
||||
|
||||
tool_call_ids = {"write_file": deque(["tc-write"])}
|
||||
tool_call_meta = {"tc-write": {"args": {"path": "fallback.txt"}, "snapshot": "snap"}}
|
||||
loop = event_loop_fixture
|
||||
|
||||
cb = make_step_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
|
||||
with patch("acp_adapter.events.asyncio.run_coroutine_threadsafe") as mock_rcts, \
|
||||
patch("acp_adapter.events.build_tool_complete") as mock_btc:
|
||||
future = MagicMock(spec=Future)
|
||||
future.result.return_value = None
|
||||
mock_rcts.return_value = future
|
||||
|
||||
cb(1, [{"name": "write_file", "result": '{"bytes_written": 23}', "arguments": {"path": "diff-test.txt"}}])
|
||||
|
||||
mock_btc.assert_called_once_with(
|
||||
"tc-write",
|
||||
"write_file",
|
||||
result='{"bytes_written": 23}',
|
||||
function_args={"path": "diff-test.txt"},
|
||||
snapshot="snap",
|
||||
)
|
||||
|
||||
def test_tool_progress_captures_snapshot_metadata(self, mock_conn, event_loop_fixture):
|
||||
tool_call_ids = {}
|
||||
tool_call_meta = {}
|
||||
loop = event_loop_fixture
|
||||
|
||||
with patch("acp_adapter.events.make_tool_call_id", return_value="tc-meta"), \
|
||||
patch("acp_adapter.events._send_update") as mock_send, \
|
||||
patch("agent.display.capture_local_edit_snapshot", return_value="snapshot"):
|
||||
cb = make_tool_progress_cb(mock_conn, "session-1", loop, tool_call_ids, tool_call_meta)
|
||||
cb("tool.started", "write_file", None, {"path": "diff-test.txt", "content": "hello"})
|
||||
|
||||
assert list(tool_call_ids["write_file"]) == ["tc-meta"]
|
||||
assert tool_call_meta["tc-meta"] == {
|
||||
"args": {"path": "diff-test.txt", "content": "hello"},
|
||||
"snapshot": "snapshot",
|
||||
}
|
||||
mock_send.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from acp.schema import (
|
|||
|
||||
from acp_adapter.server import HermesACPAgent
|
||||
from acp_adapter.session import SessionManager
|
||||
from acp_adapter.tools import build_tool_start
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -181,6 +182,25 @@ class TestMcpRegistrationE2E:
|
|||
assert complete_event.raw_output is not None
|
||||
assert "hello" in str(complete_event.raw_output)
|
||||
|
||||
def test_patch_mode_tool_start_emits_diff_blocks_for_v4a_patch(self):
|
||||
update = build_tool_start(
|
||||
"tc-1",
|
||||
"patch",
|
||||
{
|
||||
"mode": "patch",
|
||||
"patch": "*** Begin Patch\n*** Update File: src/app.py\n@@\n-old line\n+new line\n*** Add File: src/new.py\n+hello\n*** End Patch",
|
||||
},
|
||||
)
|
||||
|
||||
assert len(update.content) == 2
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "src/app.py"
|
||||
assert update.content[0].old_text == "old line"
|
||||
assert update.content[0].new_text == "new line"
|
||||
assert update.content[1].type == "diff"
|
||||
assert update.content[1].path == "src/new.py"
|
||||
assert update.content[1].new_text == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_tool_results_paired_by_call_id(self, acp_agent, mock_manager):
|
||||
"""The ToolCallUpdate's toolCallId must match the ToolCallStart's."""
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ from acp.schema import (
|
|||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SessionModelState,
|
||||
SetSessionConfigOptionResponse,
|
||||
SetSessionModelResponse,
|
||||
SetSessionModeResponse,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
|
|
@ -127,6 +129,25 @@ class TestSessionOps:
|
|||
assert state is not None
|
||||
assert state.cwd == "/home/user/project"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_returns_model_state(self):
|
||||
manager = SessionManager(
|
||||
agent_factory=lambda: SimpleNamespace(model="gpt-5.4", provider="openai-codex")
|
||||
)
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
|
||||
with patch(
|
||||
"hermes_cli.models.curated_models_for_provider",
|
||||
return_value=[("gpt-5.4", "recommended"), ("gpt-5.4-mini", "")],
|
||||
):
|
||||
resp = await acp_agent.new_session(cwd="/tmp")
|
||||
|
||||
assert isinstance(resp.models, SessionModelState)
|
||||
assert resp.models.current_model_id == "openai-codex:gpt-5.4"
|
||||
assert resp.models.available_models[0].model_id == "openai-codex:gpt-5.4"
|
||||
assert resp.models.available_models[0].description is not None
|
||||
assert "Provider:" in resp.models.available_models[0].description
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_available_commands_include_help(self, agent):
|
||||
help_cmd = next(
|
||||
|
|
@ -204,6 +225,33 @@ class TestListAndFork:
|
|||
assert fork_resp.session_id
|
||||
assert fork_resp.session_id != new_resp.session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_includes_title_and_updated_at(self, agent):
|
||||
with patch.object(
|
||||
agent.session_manager,
|
||||
"list_sessions",
|
||||
return_value=[
|
||||
{
|
||||
"session_id": "session-1",
|
||||
"cwd": "/tmp/project",
|
||||
"title": "Fix Zed session history",
|
||||
"updated_at": 123.0,
|
||||
}
|
||||
],
|
||||
):
|
||||
resp = await agent.list_sessions(cwd="/tmp/project")
|
||||
|
||||
assert isinstance(resp.sessions[0], SessionInfo)
|
||||
assert resp.sessions[0].title == "Fix Zed session history"
|
||||
assert resp.sessions[0].updated_at == "123.0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_passes_cwd_filter(self, agent):
|
||||
with patch.object(agent.session_manager, "list_sessions", return_value=[]) as mock_list:
|
||||
await agent.list_sessions(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
mock_list.assert_called_once_with(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session configuration / model routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -257,6 +305,53 @@ class TestSessionConfiguration:
|
|||
assert result == {}
|
||||
assert state.model == "gpt-5.4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_session_model_accepts_provider_prefixed_choice(self, tmp_path, monkeypatch):
|
||||
runtime_calls = []
|
||||
|
||||
def fake_resolve_runtime_provider(requested=None, **kwargs):
|
||||
runtime_calls.append(requested)
|
||||
provider = requested or "openrouter"
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
|
||||
"base_url": f"https://{provider}.example/v1",
|
||||
"api_key": f"{provider}-key",
|
||||
"command": None,
|
||||
"args": [],
|
||||
}
|
||||
|
||||
def fake_agent(**kwargs):
|
||||
return SimpleNamespace(
|
||||
model=kwargs.get("model"),
|
||||
provider=kwargs.get("provider"),
|
||||
base_url=kwargs.get("base_url"),
|
||||
api_mode=kwargs.get("api_mode"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
|
||||
"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}
|
||||
})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
fake_resolve_runtime_provider,
|
||||
)
|
||||
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
state = manager.create_session(cwd="/tmp")
|
||||
result = await acp_agent.set_session_model(
|
||||
model_id="anthropic:claude-sonnet-4-6",
|
||||
session_id=state.session_id,
|
||||
)
|
||||
|
||||
assert isinstance(result, SetSessionModelResponse)
|
||||
assert state.model == "claude-sonnet-4-6"
|
||||
assert state.agent.provider == "anthropic"
|
||||
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||
assert runtime_calls[-1] == "anthropic"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt
|
||||
|
|
@ -354,6 +449,31 @@ class TestPrompt:
|
|||
update = last_call[1].get("update") or last_call[0][1]
|
||||
assert update.session_update == "agent_message_chunk"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_auto_titles_session(self, agent):
|
||||
new_resp = await agent.new_session(cwd=".")
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
state.agent.run_conversation = MagicMock(return_value={
|
||||
"final_response": "Here is the fix.",
|
||||
"messages": [
|
||||
{"role": "user", "content": "fix the broken ACP history"},
|
||||
{"role": "assistant", "content": "Here is the fix."},
|
||||
],
|
||||
})
|
||||
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
mock_conn.session_update = AsyncMock()
|
||||
agent._conn = mock_conn
|
||||
|
||||
with patch("agent.title_generator.maybe_auto_title") as mock_title:
|
||||
prompt = [TextContentBlock(type="text", text="fix the broken ACP history")]
|
||||
await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
mock_title.assert_called_once()
|
||||
assert mock_title.call_args.args[1] == new_resp.session_id
|
||||
assert mock_title.call_args.args[2] == "fix the broken ACP history"
|
||||
assert mock_title.call_args.args[3] == "Here is the fix."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_populates_usage_from_top_level_run_conversation_fields(self, agent):
|
||||
"""ACP should map top-level token fields into PromptResponse.usage."""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
|
@ -100,15 +101,23 @@ class TestListAndCleanup:
|
|||
def test_list_sessions_returns_created(self, manager):
|
||||
s1 = manager.create_session(cwd="/a")
|
||||
s2 = manager.create_session(cwd="/b")
|
||||
s1.history.append({"role": "user", "content": "hello from a"})
|
||||
s2.history.append({"role": "user", "content": "hello from b"})
|
||||
listing = manager.list_sessions()
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert s1.session_id in ids
|
||||
assert s2.session_id in ids
|
||||
assert len(listing) == 2
|
||||
|
||||
def test_list_sessions_hides_empty_threads(self, manager):
|
||||
manager.create_session(cwd="/empty")
|
||||
assert manager.list_sessions() == []
|
||||
|
||||
def test_cleanup_clears_all(self, manager):
|
||||
manager.create_session()
|
||||
manager.create_session()
|
||||
s1 = manager.create_session()
|
||||
s2 = manager.create_session()
|
||||
s1.history.append({"role": "user", "content": "one"})
|
||||
s2.history.append({"role": "user", "content": "two"})
|
||||
assert len(manager.list_sessions()) == 2
|
||||
manager.cleanup()
|
||||
assert manager.list_sessions() == []
|
||||
|
|
@ -194,6 +203,8 @@ class TestPersistence:
|
|||
def test_list_sessions_includes_db_only(self, manager):
|
||||
"""Sessions only in DB (not in memory) appear in list_sessions."""
|
||||
state = manager.create_session(cwd="/db-only")
|
||||
state.history.append({"role": "user", "content": "database only thread"})
|
||||
manager.save_session(state.session_id)
|
||||
sid = state.session_id
|
||||
|
||||
# Drop from memory.
|
||||
|
|
@ -204,6 +215,53 @@ class TestPersistence:
|
|||
ids = {s["session_id"] for s in listing}
|
||||
assert sid in ids
|
||||
|
||||
def test_list_sessions_filters_by_cwd(self, manager):
|
||||
keep = manager.create_session(cwd="/keep")
|
||||
drop = manager.create_session(cwd="/drop")
|
||||
keep.history.append({"role": "user", "content": "keep me"})
|
||||
drop.history.append({"role": "user", "content": "drop me"})
|
||||
|
||||
listing = manager.list_sessions(cwd="/keep")
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert keep.session_id in ids
|
||||
assert drop.session_id not in ids
|
||||
|
||||
def test_list_sessions_matches_windows_and_wsl_paths(self, manager):
|
||||
state = manager.create_session(cwd="/mnt/e/Projects/AI/browser-link-3")
|
||||
state.history.append({"role": "user", "content": "same project from WSL"})
|
||||
|
||||
listing = manager.list_sessions(cwd=r"E:\Projects\AI\browser-link-3")
|
||||
ids = {s["session_id"] for s in listing}
|
||||
assert state.session_id in ids
|
||||
|
||||
def test_list_sessions_prefers_title_then_preview(self, manager):
|
||||
state = manager.create_session(cwd="/named")
|
||||
state.history.append({"role": "user", "content": "Investigate broken ACP history in Zed"})
|
||||
manager.save_session(state.session_id)
|
||||
db = manager._get_db()
|
||||
db.set_session_title(state.session_id, "Fix Zed ACP history")
|
||||
|
||||
listing = manager.list_sessions(cwd="/named")
|
||||
assert listing[0]["title"] == "Fix Zed ACP history"
|
||||
|
||||
db.set_session_title(state.session_id, "")
|
||||
listing = manager.list_sessions(cwd="/named")
|
||||
assert listing[0]["title"].startswith("Investigate broken ACP history")
|
||||
|
||||
def test_list_sessions_sorted_by_most_recent_activity(self, manager):
|
||||
older = manager.create_session(cwd="/ordered")
|
||||
older.history.append({"role": "user", "content": "older"})
|
||||
manager.save_session(older.session_id)
|
||||
time.sleep(0.02)
|
||||
newer = manager.create_session(cwd="/ordered")
|
||||
newer.history.append({"role": "user", "content": "newer"})
|
||||
manager.save_session(newer.session_id)
|
||||
|
||||
listing = manager.list_sessions(cwd="/ordered")
|
||||
assert [item["session_id"] for item in listing[:2]] == [newer.session_id, older.session_id]
|
||||
assert listing[0]["updated_at"]
|
||||
assert listing[1]["updated_at"]
|
||||
|
||||
def test_fork_restores_source_from_db(self, manager):
|
||||
"""Forking a session that is only in DB should work."""
|
||||
original = manager.create_session()
|
||||
|
|
|
|||
|
|
@ -215,6 +215,46 @@ class TestBuildToolComplete:
|
|||
assert len(display_text) < 6000
|
||||
assert "truncated" in display_text
|
||||
|
||||
def test_build_tool_complete_for_patch_uses_diff_blocks(self):
|
||||
"""Completed patch calls should keep structured diff content for Zed."""
|
||||
patch_result = (
|
||||
'{"success": true, "diff": "--- a/README.md\\n+++ b/README.md\\n@@ -1 +1,2 @@\\n old line\\n+new line\\n", '
|
||||
'"files_modified": ["README.md"]}'
|
||||
)
|
||||
result = build_tool_complete("tc-p1", "patch", patch_result)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert len(result.content) == 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path == "README.md"
|
||||
assert diff_item.old_text == "old line"
|
||||
assert diff_item.new_text == "old line\nnew line"
|
||||
|
||||
def test_build_tool_complete_for_patch_falls_back_to_text_when_no_diff(self):
|
||||
result = build_tool_complete("tc-p2", "patch", '{"success": true}')
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert isinstance(result.content[0], ContentToolCallContent)
|
||||
|
||||
def test_build_tool_complete_for_write_file_uses_snapshot_diff(self, tmp_path):
|
||||
target = tmp_path / "diff-test.txt"
|
||||
snapshot = type("Snapshot", (), {"paths": [target], "before": {str(target): None}})()
|
||||
target.write_text("hello from hermes\n", encoding="utf-8")
|
||||
|
||||
result = build_tool_complete(
|
||||
"tc-wf1",
|
||||
"write_file",
|
||||
'{"bytes_written": 18, "dirs_created": false}',
|
||||
function_args={"path": str(target), "content": "hello from hermes\n"},
|
||||
snapshot=snapshot,
|
||||
)
|
||||
assert isinstance(result, ToolCallProgress)
|
||||
assert len(result.content) == 1
|
||||
diff_item = result.content[0]
|
||||
assert isinstance(diff_item, FileEditToolCallContent)
|
||||
assert diff_item.path.endswith("diff-test.txt")
|
||||
assert diff_item.old_text is None
|
||||
assert diff_item.new_text == "hello from hermes"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_locations
|
||||
|
|
|
|||
|
|
@ -696,6 +696,95 @@ class TestIsConnectionError:
|
|||
assert _is_connection_error(err) is False
|
||||
|
||||
|
||||
class TestKimiForCodingTemperature:
|
||||
"""kimi-for-coding now requires temperature=0.6 exactly."""
|
||||
|
||||
def test_build_call_kwargs_forces_fixed_temperature(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-for-coding",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_build_call_kwargs_injects_temperature_when_missing(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-for-coding",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=None,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_auto_routed_kimi_for_coding_sync_call_uses_fixed_temperature(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.kimi.com/coding/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create.return_value = response
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "kimi-for-coding"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "kimi-for-coding", None, None, None),
|
||||
):
|
||||
result = call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_routed_kimi_for_coding_async_call_uses_fixed_temperature(self):
|
||||
client = MagicMock()
|
||||
client.base_url = "https://api.kimi.com/coding/v1"
|
||||
response = MagicMock()
|
||||
client.chat.completions.create = AsyncMock(return_value=response)
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._get_cached_client",
|
||||
return_value=(client, "kimi-for-coding"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", "kimi-for-coding", None, None, None),
|
||||
):
|
||||
result = await async_call_llm(
|
||||
task="session_search",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
assert result is response
|
||||
kwargs = client.chat.completions.create.call_args.kwargs
|
||||
assert kwargs["model"] == "kimi-for-coding"
|
||||
assert kwargs["temperature"] == 0.6
|
||||
|
||||
def test_non_kimi_model_still_preserves_temperature(self):
|
||||
from agent.auxiliary_client import _build_call_kwargs
|
||||
|
||||
kwargs = _build_call_kwargs(
|
||||
provider="kimi-coding",
|
||||
model="kimi-k2.5",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
assert kwargs["temperature"] == 0.3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# async_call_llm payment / connection fallback (#7512 bug 2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
311
tests/agent/test_auxiliary_main_first.py
Normal file
311
tests/agent/test_auxiliary_main_first.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
"""Regression tests for the ``auto`` → main-model-first policy.
|
||||
|
||||
Prior to this change, aggregator users (OpenRouter / Nous Portal) had aux
|
||||
tasks routed through a cheap provider-side default (Gemini Flash) while
|
||||
non-aggregator users got their main model. This made behavior inconsistent
|
||||
and surprising — users picked Claude but got Gemini Flash summaries.
|
||||
|
||||
The current policy: ``auto`` means "use my main chat model" for every user,
|
||||
regardless of provider type. Explicit per-task overrides in ``config.yaml``
|
||||
(``auxiliary.<task>.provider``) still win. The cheap fallback chain only
|
||||
runs when the main provider has no working client.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ── Text aux tasks — _resolve_auto ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveAutoMainFirst:
|
||||
"""_resolve_auto() must prefer main provider + main model for every user."""
|
||||
|
||||
def test_openrouter_main_uses_main_model_for_aux(self, monkeypatch):
|
||||
"""OpenRouter main user → aux uses their picked OR model, not Gemini Flash."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider",
|
||||
return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-sonnet-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "anthropic/claude-sonnet-4.6")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "anthropic/claude-sonnet-4.6"
|
||||
# Verify it asked resolve_provider_client for the MAIN provider+model,
|
||||
# not a fallback-chain provider
|
||||
mock_resolve.assert_called_once()
|
||||
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
def test_nous_main_uses_main_model_for_aux(self, monkeypatch):
|
||||
"""Nous Portal main user → aux uses their picked Nous model, not free-tier MiMo."""
|
||||
# No OPENROUTER_API_KEY → ensures if main failed we'd fall to chain
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-opus-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "anthropic/claude-opus-4.6")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "anthropic/claude-opus-4.6"
|
||||
assert mock_resolve.call_args.args[0] == "nous"
|
||||
|
||||
def test_non_aggregator_main_still_uses_main(self, monkeypatch):
|
||||
"""Non-aggregator main (DeepSeek) → unchanged behavior, main model used."""
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "ds-test")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="deepseek",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="deepseek-chat",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "deepseek-chat")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is mock_client
|
||||
assert model == "deepseek-chat"
|
||||
assert mock_resolve.call_args.args[0] == "deepseek"
|
||||
|
||||
def test_main_unavailable_falls_through_to_chain(self, monkeypatch):
|
||||
"""Main provider with no working client → fall back to aux chain."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
chain_client = MagicMock()
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="anthropic",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="claude-opus",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(None, None), # main provider has no client
|
||||
), patch(
|
||||
"agent.auxiliary_client._try_openrouter",
|
||||
return_value=(chain_client, "google/gemini-3-flash-preview"),
|
||||
):
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is chain_client
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
|
||||
def test_no_main_config_uses_chain_directly(self):
|
||||
"""No main provider configured → skip step 1, use chain (no regression)."""
|
||||
chain_client = MagicMock()
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="",
|
||||
), patch(
|
||||
"agent.auxiliary_client._try_openrouter",
|
||||
return_value=(chain_client, "google/gemini-3-flash-preview"),
|
||||
):
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
client, model = _resolve_auto()
|
||||
|
||||
assert client is chain_client
|
||||
|
||||
def test_runtime_override_wins_over_config(self, monkeypatch):
|
||||
"""main_runtime kwarg overrides config-read main provider/model."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider",
|
||||
return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="config-model",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = (MagicMock(), "runtime-model")
|
||||
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
|
||||
_resolve_auto(main_runtime={
|
||||
"provider": "anthropic",
|
||||
"model": "runtime-model",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
"api_mode": "",
|
||||
})
|
||||
|
||||
# Runtime override wins
|
||||
assert mock_resolve.call_args.args[0] == "anthropic"
|
||||
assert mock_resolve.call_args.args[1] == "runtime-model"
|
||||
|
||||
|
||||
# ── Vision — resolve_vision_provider_client ─────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveVisionMainFirst:
|
||||
"""Vision auto-detection prefers main provider + main model first."""
|
||||
|
||||
def test_openrouter_main_vision_uses_main_model(self, monkeypatch):
|
||||
"""OpenRouter main with vision-capable model → aux vision uses main model."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-sonnet-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve, patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "anthropic/claude-sonnet-4.6")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "openrouter"
|
||||
assert client is mock_client
|
||||
assert model == "anthropic/claude-sonnet-4.6"
|
||||
# Verify it did NOT call the strict vision backend for OpenRouter
|
||||
# (which would have used a cheap gemini-flash-preview default)
|
||||
mock_resolve.assert_called_once()
|
||||
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
||||
|
||||
def test_nous_main_vision_uses_main_model(self):
|
||||
"""Nous Portal main → aux vision uses main model, not free-tier MiMo-V2-Omni."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="nous",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="openai/gpt-5",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve, patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_resolve.return_value = (mock_client, "openai/gpt-5")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "nous"
|
||||
assert model == "openai/gpt-5"
|
||||
|
||||
def test_exotic_provider_with_vision_override_preserved(self):
|
||||
"""xiaomi → mimo-v2-omni override still wins over main_model."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="xiaomi",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="mimo-v2-pro", # text model
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client"
|
||||
) as mock_resolve, patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
):
|
||||
mock_resolve.return_value = (MagicMock(), "mimo-v2-omni")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "xiaomi"
|
||||
# Should use mimo-v2-omni (vision override), not mimo-v2-pro (text main)
|
||||
assert mock_resolve.call_args.args[1] == "mimo-v2-omni"
|
||||
|
||||
def test_main_unavailable_vision_falls_through_to_aggregators(self):
|
||||
"""Main provider fails → fall back to OpenRouter/Nous strict backends."""
|
||||
fallback_client = MagicMock()
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="deepseek",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model", return_value="deepseek-chat",
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend",
|
||||
return_value=(fallback_client, "google/gemini-3-flash-preview"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
):
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert client is fallback_client
|
||||
assert provider in ("openrouter", "nous")
|
||||
|
||||
def test_explicit_provider_override_still_wins(self):
|
||||
"""Explicit config override bypasses main-first policy."""
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider", return_value="openrouter",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="anthropic/claude-opus-4.6",
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("nous", None, None, None, None), # explicit override
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_strict_vision_backend"
|
||||
) as mock_strict:
|
||||
mock_strict.return_value = (MagicMock(), "nous-default-model")
|
||||
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
# Explicit "nous" override → uses strict backend, NOT main model path
|
||||
assert provider == "nous"
|
||||
mock_strict.assert_called_once_with("nous")
|
||||
|
||||
|
||||
# ── Constant cleanup ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_aggregator_providers_constant_removed():
|
||||
"""The dead _AGGREGATOR_PROVIDERS constant should no longer live in the module.
|
||||
|
||||
Removed when the main-first policy made the aggregator-skip guard obsolete.
|
||||
"""
|
||||
import agent.auxiliary_client as aux_mod
|
||||
|
||||
assert not hasattr(aux_mod, "_AGGREGATOR_PROVIDERS"), (
|
||||
"_AGGREGATOR_PROVIDERS was removed when _resolve_auto stopped "
|
||||
"treating aggregators specially. If you re-added it, the main-first "
|
||||
"policy may have regressed."
|
||||
)
|
||||
|
|
@ -826,6 +826,160 @@ class TestGeminiCloudCodeClient:
|
|||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
class TestGeminiHttpErrorParsing:
|
||||
"""Regression coverage for _gemini_http_error Google-envelope parsing.
|
||||
|
||||
These are the paths that users actually hit during Google-side throttling
|
||||
(April 2026: gemini-2.5-pro MODEL_CAPACITY_EXHAUSTED, gemma-4-26b-it
|
||||
returning 404). The error needs to carry status_code + response so the
|
||||
main loop's error_classifier and Retry-After logic work.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _fake_response(status: int, body: dict | str = "", headers=None):
|
||||
"""Minimal httpx.Response stand-in (duck-typed for _gemini_http_error)."""
|
||||
class _FakeResponse:
|
||||
def __init__(self):
|
||||
self.status_code = status
|
||||
if isinstance(body, dict):
|
||||
self.text = json.dumps(body)
|
||||
else:
|
||||
self.text = body
|
||||
self.headers = headers or {}
|
||||
return _FakeResponse()
|
||||
|
||||
def test_model_capacity_exhausted_produces_friendly_message(self):
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
body = {
|
||||
"error": {
|
||||
"code": 429,
|
||||
"message": "Resource has been exhausted (e.g. check quota).",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "MODEL_CAPACITY_EXHAUSTED",
|
||||
"domain": "googleapis.com",
|
||||
"metadata": {"model": "gemini-2.5-pro"},
|
||||
},
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.RetryInfo",
|
||||
"retryDelay": "30s",
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
err = _gemini_http_error(self._fake_response(429, body))
|
||||
assert err.status_code == 429
|
||||
assert err.code == "code_assist_capacity_exhausted"
|
||||
assert err.retry_after == 30.0
|
||||
assert err.details["reason"] == "MODEL_CAPACITY_EXHAUSTED"
|
||||
# Message must be user-friendly, not a raw JSON dump.
|
||||
message = str(err)
|
||||
assert "gemini-2.5-pro" in message
|
||||
assert "capacity exhausted" in message.lower()
|
||||
assert "30s" in message
|
||||
# response attr is preserved for run_agent's Retry-After header path.
|
||||
assert err.response is not None
|
||||
|
||||
def test_resource_exhausted_without_reason(self):
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
body = {
|
||||
"error": {
|
||||
"code": 429,
|
||||
"message": "Quota exceeded for requests per minute.",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
}
|
||||
}
|
||||
err = _gemini_http_error(self._fake_response(429, body))
|
||||
assert err.status_code == 429
|
||||
assert err.code == "code_assist_rate_limited"
|
||||
message = str(err)
|
||||
assert "quota" in message.lower()
|
||||
|
||||
def test_404_model_not_found_produces_model_retired_message(self):
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
body = {
|
||||
"error": {
|
||||
"code": 404,
|
||||
"message": "models/gemma-4-26b-it is not found for API version v1internal",
|
||||
"status": "NOT_FOUND",
|
||||
}
|
||||
}
|
||||
err = _gemini_http_error(self._fake_response(404, body))
|
||||
assert err.status_code == 404
|
||||
message = str(err)
|
||||
assert "not available" in message.lower() or "retired" in message.lower()
|
||||
# Error message should reference the actual model text from Google.
|
||||
assert "gemma-4-26b-it" in message
|
||||
|
||||
def test_unauthorized_preserves_status_code(self):
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
err = _gemini_http_error(self._fake_response(
|
||||
401, {"error": {"code": 401, "message": "Invalid token", "status": "UNAUTHENTICATED"}},
|
||||
))
|
||||
assert err.status_code == 401
|
||||
assert err.code == "code_assist_unauthorized"
|
||||
|
||||
def test_retry_after_header_fallback(self):
|
||||
"""If the body has no RetryInfo detail, fall back to Retry-After header."""
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
resp = self._fake_response(
|
||||
429,
|
||||
{"error": {"code": 429, "message": "Rate limited", "status": "RESOURCE_EXHAUSTED"}},
|
||||
headers={"Retry-After": "45"},
|
||||
)
|
||||
err = _gemini_http_error(resp)
|
||||
assert err.retry_after == 45.0
|
||||
|
||||
def test_malformed_body_still_produces_structured_error(self):
|
||||
"""Non-JSON body must not swallow status_code — we still want the classifier path."""
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
|
||||
err = _gemini_http_error(self._fake_response(500, "<html>internal error</html>"))
|
||||
assert err.status_code == 500
|
||||
# Raw body snippet must still be there for debugging.
|
||||
assert "500" in str(err)
|
||||
|
||||
def test_status_code_flows_through_error_classifier(self):
|
||||
"""End-to-end: CodeAssistError from a 429 must classify as rate_limit.
|
||||
|
||||
This is the whole point of adding status_code to CodeAssistError —
|
||||
_extract_status_code must see it and FailoverReason.rate_limit must
|
||||
fire, so the main loop triggers fallback_providers.
|
||||
"""
|
||||
from agent.gemini_cloudcode_adapter import _gemini_http_error
|
||||
from agent.error_classifier import classify_api_error, FailoverReason
|
||||
|
||||
body = {
|
||||
"error": {
|
||||
"code": 429,
|
||||
"message": "Resource has been exhausted",
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{
|
||||
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
|
||||
"reason": "MODEL_CAPACITY_EXHAUSTED",
|
||||
"metadata": {"model": "gemini-2.5-pro"},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
err = _gemini_http_error(self._fake_response(429, body))
|
||||
|
||||
classified = classify_api_error(
|
||||
err, provider="google-gemini-cli", model="gemini-2.5-pro",
|
||||
)
|
||||
assert classified.status_code == 429
|
||||
assert classified.reason == FailoverReason.rate_limit
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider registration
|
||||
# =============================================================================
|
||||
|
|
|
|||
71
tests/cli/test_cli_copy_command.py
Normal file
71
tests/cli/test_cli_copy_command.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""Tests for CLI /copy command."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def _make_cli() -> HermesCLI:
|
||||
cli_obj = HermesCLI.__new__(HermesCLI)
|
||||
cli_obj.config = {}
|
||||
cli_obj.console = MagicMock()
|
||||
cli_obj.agent = None
|
||||
cli_obj.conversation_history = []
|
||||
cli_obj.session_id = "sess-copy-test"
|
||||
cli_obj._pending_input = MagicMock()
|
||||
cli_obj._app = None
|
||||
return cli_obj
|
||||
|
||||
|
||||
def test_copy_copies_latest_assistant_message():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj.conversation_history = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{"role": "assistant", "content": "first"},
|
||||
{"role": "assistant", "content": "latest"},
|
||||
]
|
||||
|
||||
with patch.object(cli_obj, "_write_osc52_clipboard") as mock_copy:
|
||||
result = cli_obj.process_command("/copy")
|
||||
|
||||
assert result is True
|
||||
mock_copy.assert_called_once_with("latest")
|
||||
|
||||
|
||||
def test_copy_with_index_uses_requested_assistant_message():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj.conversation_history = [
|
||||
{"role": "assistant", "content": "one"},
|
||||
{"role": "assistant", "content": "two"},
|
||||
]
|
||||
|
||||
with patch.object(cli_obj, "_write_osc52_clipboard") as mock_copy:
|
||||
cli_obj.process_command("/copy 1")
|
||||
|
||||
mock_copy.assert_called_once_with("one")
|
||||
|
||||
|
||||
def test_copy_strips_reasoning_blocks_before_copy():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj.conversation_history = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<REASONING_SCRATCHPAD>internal</REASONING_SCRATCHPAD>\nVisible answer",
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(cli_obj, "_write_osc52_clipboard") as mock_copy:
|
||||
cli_obj.process_command("/copy")
|
||||
|
||||
mock_copy.assert_called_once_with("Visible answer")
|
||||
|
||||
|
||||
def test_copy_invalid_index_does_not_copy():
|
||||
cli_obj = _make_cli()
|
||||
cli_obj.conversation_history = [{"role": "assistant", "content": "only"}]
|
||||
|
||||
with patch.object(cli_obj, "_write_osc52_clipboard") as mock_copy, patch("cli._cprint") as mock_print:
|
||||
cli_obj.process_command("/copy 99")
|
||||
|
||||
mock_copy.assert_not_called()
|
||||
assert any("Invalid response number" in str(call) for call in mock_print.call_args_list)
|
||||
|
|
@ -64,6 +64,24 @@ class TestSaveConfigValueAtomic:
|
|||
result = yaml.safe_load(config_env.read_text())
|
||||
assert result["display"]["skin"] == "ares"
|
||||
|
||||
def test_preserves_env_ref_templates_in_unrelated_fields(self, config_env):
|
||||
"""The /model --global persistence path must not inline env-backed secrets."""
|
||||
config_env.write_text(yaml.dump({
|
||||
"custom_providers": [{
|
||||
"name": "tuzi",
|
||||
"api_key": "${TU_ZI_API_KEY}",
|
||||
"model": "claude-opus-4-6",
|
||||
}],
|
||||
"model": {"default": "test-model", "provider": "openrouter"},
|
||||
}))
|
||||
|
||||
from cli import save_config_value
|
||||
save_config_value("model.default", "doubao-pro")
|
||||
|
||||
result = yaml.safe_load(config_env.read_text())
|
||||
assert result["model"]["default"] == "doubao-pro"
|
||||
assert result["custom_providers"][0]["api_key"] == "${TU_ZI_API_KEY}"
|
||||
|
||||
def test_file_not_truncated_on_error(self, config_env, monkeypatch):
|
||||
"""If atomic_yaml_write raises, the original file is untouched."""
|
||||
original_content = config_env.read_text()
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@
|
|||
|
||||
Surrogates (U+D800..U+DFFF) are invalid in UTF-8 and crash json.dumps()
|
||||
inside the OpenAI SDK. They can appear via clipboard paste from rich-text
|
||||
editors like Google Docs.
|
||||
editors like Google Docs, OR from byte-level reasoning models (xiaomi/mimo,
|
||||
kimi, glm) emitting lone halves in reasoning output.
|
||||
"""
|
||||
import json
|
||||
import pytest
|
||||
|
|
@ -11,6 +12,7 @@ from unittest.mock import MagicMock, patch
|
|||
from run_agent import (
|
||||
_sanitize_surrogates,
|
||||
_sanitize_messages_surrogates,
|
||||
_sanitize_structure_surrogates,
|
||||
_SURROGATE_RE,
|
||||
)
|
||||
|
||||
|
|
@ -109,6 +111,186 @@ class TestSanitizeMessagesSurrogates:
|
|||
assert "\ufffd" in msgs[0]["content"]
|
||||
|
||||
|
||||
class TestReasoningFieldSurrogates:
|
||||
"""Surrogates in reasoning fields (byte-level reasoning models).
|
||||
|
||||
xiaomi/mimo, kimi, glm and similar byte-level tokenizers can emit lone
|
||||
surrogates in reasoning output. These fields are carried through to the
|
||||
API as `reasoning_content` on assistant messages, and must be sanitized
|
||||
or json.dumps() crashes with 'utf-8' codec can't encode surrogates.
|
||||
"""
|
||||
|
||||
def test_reasoning_field_sanitized(self):
|
||||
msgs = [
|
||||
{"role": "assistant", "content": "ok", "reasoning": "thought \udce2 here"},
|
||||
]
|
||||
assert _sanitize_messages_surrogates(msgs) is True
|
||||
assert "\udce2" not in msgs[0]["reasoning"]
|
||||
assert "\ufffd" in msgs[0]["reasoning"]
|
||||
|
||||
def test_reasoning_content_field_sanitized(self):
|
||||
"""api_messages carry `reasoning_content` built from `reasoning`."""
|
||||
msgs = [
|
||||
{"role": "assistant", "content": "ok", "reasoning_content": "thought \udce2 here"},
|
||||
]
|
||||
assert _sanitize_messages_surrogates(msgs) is True
|
||||
assert "\udce2" not in msgs[0]["reasoning_content"]
|
||||
assert "\ufffd" in msgs[0]["reasoning_content"]
|
||||
|
||||
def test_reasoning_details_nested_sanitized(self):
|
||||
"""reasoning_details is a list of dicts with nested string fields."""
|
||||
msgs = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "ok",
|
||||
"reasoning_details": [
|
||||
{"type": "reasoning.summary", "summary": "summary \udce2 text"},
|
||||
{"type": "reasoning.text", "text": "chain \udc00 of thought"},
|
||||
],
|
||||
},
|
||||
]
|
||||
assert _sanitize_messages_surrogates(msgs) is True
|
||||
assert "\udce2" not in msgs[0]["reasoning_details"][0]["summary"]
|
||||
assert "\ufffd" in msgs[0]["reasoning_details"][0]["summary"]
|
||||
assert "\udc00" not in msgs[0]["reasoning_details"][1]["text"]
|
||||
assert "\ufffd" in msgs[0]["reasoning_details"][1]["text"]
|
||||
|
||||
def test_deeply_nested_reasoning_sanitized(self):
|
||||
"""Nested dicts / lists inside extra fields are recursed into."""
|
||||
msgs = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "ok",
|
||||
"reasoning_details": [
|
||||
{
|
||||
"type": "reasoning.encrypted",
|
||||
"content": {
|
||||
"encrypted_content": "opaque",
|
||||
"text_parts": ["part1", "part2 \udce2 part"],
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
assert _sanitize_messages_surrogates(msgs) is True
|
||||
assert (
|
||||
msgs[0]["reasoning_details"][0]["content"]["text_parts"][1]
|
||||
== "part2 \ufffd part"
|
||||
)
|
||||
|
||||
def test_reasoning_end_to_end_json_serialization(self):
|
||||
"""After sanitization, the full message dict must serialize clean."""
|
||||
msgs = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "answer",
|
||||
"reasoning_content": "reasoning with \udce2 surrogate",
|
||||
"reasoning_details": [
|
||||
{"summary": "nested \udcb0 surrogate"},
|
||||
],
|
||||
},
|
||||
]
|
||||
_sanitize_messages_surrogates(msgs)
|
||||
# Must round-trip through json + utf-8 encoding without error
|
||||
payload = json.dumps(msgs, ensure_ascii=False).encode("utf-8")
|
||||
assert b"\\" not in payload[:0] # sanity — just ensure we got bytes
|
||||
assert len(payload) > 0
|
||||
|
||||
def test_no_surrogates_returns_false(self):
|
||||
"""Clean reasoning fields don't trigger a modification."""
|
||||
msgs = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "ok",
|
||||
"reasoning": "clean thought",
|
||||
"reasoning_content": "also clean",
|
||||
"reasoning_details": [{"summary": "clean summary"}],
|
||||
},
|
||||
]
|
||||
assert _sanitize_messages_surrogates(msgs) is False
|
||||
|
||||
|
||||
class TestSanitizeStructureSurrogates:
|
||||
"""Test the _sanitize_structure_surrogates() helper for nested payloads."""
|
||||
|
||||
def test_empty_payload(self):
|
||||
assert _sanitize_structure_surrogates({}) is False
|
||||
assert _sanitize_structure_surrogates([]) is False
|
||||
|
||||
def test_flat_dict(self):
|
||||
payload = {"a": "clean", "b": "dirty \udce2 text"}
|
||||
assert _sanitize_structure_surrogates(payload) is True
|
||||
assert payload["a"] == "clean"
|
||||
assert "\ufffd" in payload["b"]
|
||||
|
||||
def test_flat_list(self):
|
||||
payload = ["clean", "dirty \udce2"]
|
||||
assert _sanitize_structure_surrogates(payload) is True
|
||||
assert payload[0] == "clean"
|
||||
assert "\ufffd" in payload[1]
|
||||
|
||||
def test_nested_dict_in_list(self):
|
||||
payload = [{"x": "dirty \udce2"}, {"x": "clean"}]
|
||||
assert _sanitize_structure_surrogates(payload) is True
|
||||
assert "\ufffd" in payload[0]["x"]
|
||||
assert payload[1]["x"] == "clean"
|
||||
|
||||
def test_deeply_nested(self):
|
||||
payload = {
|
||||
"level1": {
|
||||
"level2": [
|
||||
{"level3": "deep \udce2 surrogate"},
|
||||
],
|
||||
},
|
||||
}
|
||||
assert _sanitize_structure_surrogates(payload) is True
|
||||
assert "\ufffd" in payload["level1"]["level2"][0]["level3"]
|
||||
|
||||
def test_clean_payload_returns_false(self):
|
||||
payload = {"a": "clean", "b": [{"c": "also clean"}]}
|
||||
assert _sanitize_structure_surrogates(payload) is False
|
||||
|
||||
def test_non_string_values_ignored(self):
|
||||
payload = {"int": 42, "list": [1, 2, 3], "dict": {"none": None}, "bool": True}
|
||||
assert _sanitize_structure_surrogates(payload) is False
|
||||
# Non-string values survive unchanged
|
||||
assert payload["int"] == 42
|
||||
assert payload["list"] == [1, 2, 3]
|
||||
|
||||
|
||||
class TestApiMessagesSurrogateRecovery:
|
||||
"""Integration: verify the recovery block sanitizes api_messages.
|
||||
|
||||
The bug this guards against: a surrogate in `reasoning_content` on
|
||||
api_messages (transformed from `reasoning` during build) crashes the
|
||||
OpenAI SDK's json.dumps(), and the recovery block previously only
|
||||
sanitized the canonical `messages` list — not `api_messages` — so the
|
||||
next retry would send the same broken payload and fail 3 times.
|
||||
"""
|
||||
|
||||
def test_api_messages_reasoning_content_sanitized(self):
|
||||
"""The extended sanitizer catches reasoning_content in api_messages."""
|
||||
api_messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "response",
|
||||
"reasoning_content": "thought \udce2 trail",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"function": {"name": "tool", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "result", "tool_call_id": "call_1"},
|
||||
]
|
||||
assert _sanitize_messages_surrogates(api_messages) is True
|
||||
assert "\udce2" not in api_messages[1]["reasoning_content"]
|
||||
# Full payload must now serialize clean
|
||||
json.dumps(api_messages, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
|
||||
class TestRunConversationSurrogateSanitization:
|
||||
"""Integration: verify run_conversation sanitizes user_message."""
|
||||
|
||||
|
|
|
|||
|
|
@ -184,6 +184,8 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
|
|||
"HERMES_BACKGROUND_NOTIFICATIONS",
|
||||
"HERMES_EXEC_ASK",
|
||||
"HERMES_HOME_MODE",
|
||||
"BROWSER_CDP_URL",
|
||||
"CAMOFOX_URL",
|
||||
})
|
||||
|
||||
|
||||
|
|
@ -229,6 +231,15 @@ def _hermetic_environment(tmp_path, monkeypatch):
|
|||
monkeypatch.setenv("LC_ALL", "C.UTF-8")
|
||||
monkeypatch.setenv("PYTHONHASHSEED", "0")
|
||||
|
||||
# 4b. Disable AWS IMDS lookups. Without this, any test that ends up
|
||||
# calling has_aws_credentials() / resolve_aws_auth_env_var()
|
||||
# (e.g. provider auto-detect, status command, cron run_job) burns
|
||||
# ~2s waiting for the metadata service at 169.254.169.254 to time
|
||||
# out. Tests don't run on EC2 — IMDS is always unreachable here.
|
||||
monkeypatch.setenv("AWS_EC2_METADATA_DISABLED", "true")
|
||||
monkeypatch.setenv("AWS_METADATA_SERVICE_TIMEOUT", "1")
|
||||
monkeypatch.setenv("AWS_METADATA_SERVICE_NUM_ATTEMPTS", "1")
|
||||
|
||||
# 5. Reset plugin singleton so tests don't leak plugins from
|
||||
# ~/.hermes/plugins/ (which, per step 3, is now empty — but the
|
||||
# singleton might still be cached from a previous test).
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from unittest.mock import patch
|
|||
|
||||
from gateway.channel_directory import (
|
||||
build_channel_directory,
|
||||
lookup_channel_type,
|
||||
resolve_channel_name,
|
||||
format_directory_for_display,
|
||||
load_directory,
|
||||
|
|
@ -285,3 +286,49 @@ class TestFormatDirectoryForDisplay:
|
|||
assert "Discord (Server1):" in result
|
||||
assert "Discord (Server2):" in result
|
||||
assert "discord:#general" in result
|
||||
|
||||
|
||||
class TestLookupChannelType:
|
||||
def _setup(self, tmp_path, platforms):
|
||||
cache_file = _write_directory(tmp_path, platforms)
|
||||
return patch("gateway.channel_directory.DIRECTORY_PATH", cache_file)
|
||||
|
||||
def test_forum_channel(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "100", "name": "ideas", "guild": "Server1", "type": "forum"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert lookup_channel_type("discord", "100") == "forum"
|
||||
|
||||
def test_regular_channel(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "200", "name": "general", "guild": "Server1", "type": "channel"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert lookup_channel_type("discord", "200") == "channel"
|
||||
|
||||
def test_unknown_chat_id_returns_none(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "200", "name": "general", "guild": "Server1", "type": "channel"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert lookup_channel_type("discord", "999") is None
|
||||
|
||||
def test_unknown_platform_returns_none(self, tmp_path):
|
||||
with self._setup(tmp_path, {}):
|
||||
assert lookup_channel_type("discord", "100") is None
|
||||
|
||||
def test_channel_without_type_key_returns_none(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "300", "name": "general", "guild": "Server1"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert lookup_channel_type("discord", "300") is None
|
||||
|
|
|
|||
|
|
@ -160,6 +160,30 @@ class TestCommandBypassActiveSession:
|
|||
assert sk not in adapter._pending_messages
|
||||
assert any("handled:status" in r for r in adapter.sent_responses)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agents_bypasses_guard(self):
|
||||
"""/agents must bypass so active-task queries don't interrupt runs."""
|
||||
adapter = _make_adapter()
|
||||
sk = _session_key()
|
||||
adapter._active_sessions[sk] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(_make_event("/agents"))
|
||||
|
||||
assert sk not in adapter._pending_messages
|
||||
assert any("handled:agents" in r for r in adapter.sent_responses)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tasks_alias_bypasses_guard(self):
|
||||
"""/tasks alias must bypass active-session guard too."""
|
||||
adapter = _make_adapter()
|
||||
sk = _session_key()
|
||||
adapter._active_sessions[sk] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(_make_event("/tasks"))
|
||||
|
||||
assert sk not in adapter._pending_messages
|
||||
assert any("handled:tasks" in r for r in adapter.sent_responses)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_bypasses_guard(self):
|
||||
"""/background must bypass so it spawns a parallel task, not an interrupt."""
|
||||
|
|
@ -176,6 +200,38 @@ class TestCommandBypassActiveSession:
|
|||
"/background response was not sent back to the user"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help_bypasses_guard(self):
|
||||
"""/help must bypass so it is not silently dropped as pending slash text."""
|
||||
adapter = _make_adapter()
|
||||
sk = _session_key()
|
||||
adapter._active_sessions[sk] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(_make_event("/help"))
|
||||
|
||||
assert sk not in adapter._pending_messages, (
|
||||
"/help was queued as a pending message instead of being dispatched"
|
||||
)
|
||||
assert any("handled:help" in r for r in adapter.sent_responses), (
|
||||
"/help response was not sent back to the user"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_bypasses_guard(self):
|
||||
"""/update must bypass so it is not discarded by the pending-command safety net."""
|
||||
adapter = _make_adapter()
|
||||
sk = _session_key()
|
||||
adapter._active_sessions[sk] = asyncio.Event()
|
||||
|
||||
await adapter.handle_message(_make_event("/update"))
|
||||
|
||||
assert sk not in adapter._pending_messages, (
|
||||
"/update was queued as a pending message instead of being dispatched"
|
||||
)
|
||||
assert any("handled:update" in r for r in adapter.sent_responses), (
|
||||
"/update response was not sent back to the user"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_bypasses_guard(self):
|
||||
"""/queue must bypass so it can queue without interrupting."""
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ class TestSend:
|
|||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
adapter._http_client = mock_client
|
||||
adapter._session_webhooks["chat-123"] = "https://cached.example/webhook"
|
||||
adapter._session_webhooks["chat-123"] = ("https://cached.example/webhook", 9999999999999)
|
||||
|
||||
result = await adapter.send("chat-123", "Hello!")
|
||||
assert result.success is True
|
||||
|
|
@ -681,3 +681,290 @@ class TestIncomingHandlerProcess:
|
|||
processing_gate.set()
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text extraction — mention preservation + platform sanity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestExtractTextMentions:
|
||||
|
||||
def test_preserves_at_mentions_in_text(self):
|
||||
"""@mentions are routing signals (via isInAtList), not text to strip.
|
||||
|
||||
Stripping all @handles collateral-damages emails, SSH URLs, and
|
||||
literal references the user wrote.
|
||||
"""
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
cases = [
|
||||
("@bot hello", "@bot hello"),
|
||||
("contact alice@example.com", "contact alice@example.com"),
|
||||
("git@github.com:foo/bar.git", "git@github.com:foo/bar.git"),
|
||||
("what does @openai think", "what does @openai think"),
|
||||
("@机器人 转发给 @老王", "@机器人 转发给 @老王"),
|
||||
]
|
||||
for text, expected in cases:
|
||||
msg = MagicMock()
|
||||
msg.text = text
|
||||
msg.rich_text = None
|
||||
msg.rich_text_content = None
|
||||
assert DingTalkAdapter._extract_text(msg) == expected, (
|
||||
f"mangled: {text!r} -> {DingTalkAdapter._extract_text(msg)!r}"
|
||||
)
|
||||
|
||||
def test_dingtalk_in_platform_enum(self):
|
||||
assert Platform.DINGTALK.value == "dingtalk"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concurrency — chat-scoped message context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageContextIsolation:
|
||||
|
||||
def test_contexts_keyed_by_chat_id(self):
|
||||
"""Two concurrent chats must not clobber each other's context."""
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
adapter = DingTalkAdapter(PlatformConfig(enabled=True))
|
||||
|
||||
msg_a = MagicMock(conversation_id="chat-A", sender_staff_id="user-A")
|
||||
msg_b = MagicMock(conversation_id="chat-B", sender_staff_id="user-B")
|
||||
adapter._message_contexts["chat-A"] = msg_a
|
||||
adapter._message_contexts["chat-B"] = msg_b
|
||||
|
||||
assert adapter._message_contexts["chat-A"] is msg_a
|
||||
assert adapter._message_contexts["chat-B"] is msg_b
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Card lifecycle: finalize via metadata["streaming"]
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCardLifecycle:
|
||||
|
||||
@pytest.fixture
|
||||
def adapter_with_card(self):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
a = DingTalkAdapter(PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"card_template_id": "tmpl-1"},
|
||||
))
|
||||
a._card_sdk = MagicMock()
|
||||
a._card_sdk.create_card_with_options_async = AsyncMock()
|
||||
a._card_sdk.deliver_card_with_options_async = AsyncMock()
|
||||
a._card_sdk.streaming_update_with_options_async = AsyncMock()
|
||||
a._http_client = AsyncMock()
|
||||
a._get_access_token = AsyncMock(return_value="token")
|
||||
# Minimal message context
|
||||
msg = MagicMock(
|
||||
conversation_id="chat-1",
|
||||
conversation_type="1",
|
||||
sender_staff_id="staff-1",
|
||||
message_id="user-msg-1",
|
||||
)
|
||||
a._message_contexts["chat-1"] = msg
|
||||
a._session_webhooks["chat-1"] = (
|
||||
"https://api.dingtalk.com/x", 9999999999999,
|
||||
)
|
||||
return a
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_final_reply_finalizes_card(self, adapter_with_card):
|
||||
"""send(reply_to=...) creates a closed card (final response path)."""
|
||||
a = adapter_with_card
|
||||
result = await a.send("chat-1", "Hello", reply_to="user-msg-1")
|
||||
assert result.success
|
||||
call = a._card_sdk.streaming_update_with_options_async.call_args
|
||||
assert call[0][0].is_finalize is True
|
||||
# Not tracked as streaming — it's already closed.
|
||||
assert "chat-1" not in a._streaming_cards
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_intermediate_send_stays_streaming(self, adapter_with_card):
|
||||
"""send() without reply_to creates an OPEN card (tool progress /
|
||||
commentary / streaming first chunk). No flicker closed→streaming
|
||||
when edit_message follows."""
|
||||
a = adapter_with_card
|
||||
result = await a.send("chat-1", "💻 terminal: ls")
|
||||
assert result.success
|
||||
call = a._card_sdk.streaming_update_with_options_async.call_args
|
||||
assert call[0][0].is_finalize is False
|
||||
# Tracked for sibling cleanup.
|
||||
assert result.message_id in a._streaming_cards.get("chat-1", {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_done_fires_only_when_reply_to_is_set(self, adapter_with_card):
|
||||
"""reply_to distinguishes final response (base.py) from tool-progress
|
||||
sends (run.py). Done must only fire for the former."""
|
||||
a = adapter_with_card
|
||||
fired: list[str] = []
|
||||
a._fire_done_reaction = lambda cid: fired.append(cid)
|
||||
|
||||
# Tool-progress / commentary path: no reply_to — no Done.
|
||||
await a.send("chat-1", "tool line")
|
||||
assert fired == []
|
||||
|
||||
# Final response path: reply_to set — Done fires.
|
||||
await a.send("chat-1", "final", reply_to="user-msg-1")
|
||||
assert fired == ["chat-1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_message_finalize_fires_done(self, adapter_with_card):
|
||||
"""Stream consumer's final edit_message(finalize=True) fires Done."""
|
||||
a = adapter_with_card
|
||||
fired: list[str] = []
|
||||
a._fire_done_reaction = lambda cid: fired.append(cid)
|
||||
|
||||
await a.send("chat-1", "initial")
|
||||
# Reopen via edit_message(finalize=False) then close.
|
||||
await a.edit_message(
|
||||
chat_id="chat-1", message_id="track-X",
|
||||
content="streaming...", finalize=False,
|
||||
)
|
||||
await a.edit_message(
|
||||
chat_id="chat-1", message_id="track-X",
|
||||
content="final", finalize=True,
|
||||
)
|
||||
assert "chat-1" in fired
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_message_finalize_false_tracks_sibling(self, adapter_with_card):
|
||||
"""After edit_message(finalize=False), card is tracked as open."""
|
||||
a = adapter_with_card
|
||||
await a.edit_message(
|
||||
chat_id="chat-1", message_id="track-1",
|
||||
content="partial", finalize=False,
|
||||
)
|
||||
assert "chat-1" in a._streaming_cards
|
||||
assert a._streaming_cards["chat-1"].get("track-1") == "partial"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_send_auto_closes_sibling_streaming_cards(
|
||||
self, adapter_with_card,
|
||||
):
|
||||
"""Tool-progress card left open (send without reply_to + edits) must
|
||||
be auto-closed when the final-reply send arrives."""
|
||||
a = adapter_with_card
|
||||
# First tool: intermediate send — card stays open.
|
||||
r1 = await a.send("chat-1", "💻 tool1")
|
||||
# Second tool: edit_message(finalize=False) — keeps streaming.
|
||||
await a.edit_message(
|
||||
chat_id="chat-1", message_id=r1.message_id,
|
||||
content="💻 tool1\n💻 tool2", finalize=False,
|
||||
)
|
||||
assert r1.message_id in a._streaming_cards.get("chat-1", {})
|
||||
a._card_sdk.streaming_update_with_options_async.reset_mock()
|
||||
|
||||
# Final response send auto-closes the sibling.
|
||||
await a.send("chat-1", "final answer", reply_to="user-msg")
|
||||
|
||||
calls = a._card_sdk.streaming_update_with_options_async.call_args_list
|
||||
assert len(calls) >= 2
|
||||
# First call was the sibling close with last-seen tool-progress content.
|
||||
first_req = calls[0][0][0]
|
||||
assert first_req.out_track_id == r1.message_id
|
||||
assert first_req.is_finalize is True
|
||||
assert "tool1" in first_req.content
|
||||
# Streaming tracking is cleared after close.
|
||||
assert "chat-1" not in a._streaming_cards
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_message_requires_message_id(self, adapter_with_card):
|
||||
a = adapter_with_card
|
||||
result = await a.edit_message(
|
||||
chat_id="chat-1", message_id="", content="x", finalize=True,
|
||||
)
|
||||
assert result.success is False
|
||||
a._card_sdk.streaming_update_with_options_async.assert_not_called()
|
||||
|
||||
def test_fire_done_reaction_is_idempotent(self, adapter_with_card):
|
||||
a = adapter_with_card
|
||||
captured = []
|
||||
def _capture(coro):
|
||||
captured.append(coro)
|
||||
a._spawn_bg = _capture
|
||||
|
||||
a._fire_done_reaction("chat-1")
|
||||
a._fire_done_reaction("chat-1")
|
||||
assert len(captured) == 1
|
||||
captured[0].close()
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AI Card Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDingTalkAdapterAICards:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
return PlatformConfig(
|
||||
enabled=True,
|
||||
extra={
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"card_template_id": "test_card_template",
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_stream_client(self):
|
||||
client = MagicMock()
|
||||
client.get_access_token = MagicMock(return_value="test_token")
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_client(self):
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
msg = MagicMock()
|
||||
msg.message_id = "test_msg_id"
|
||||
msg.conversation_id = "test_conv_id"
|
||||
msg.conversation_type = "1"
|
||||
msg.sender_id = "sender1"
|
||||
msg.sender_nick = "Test User"
|
||||
msg.sender_staff_id = "staff1"
|
||||
msg.text = MagicMock(content="Hello")
|
||||
msg.session_webhook = "https://api.dingtalk.com/robot/sendBySession?session=test"
|
||||
msg.session_webhook_expired_time = 999999999999
|
||||
msg.create_at = int(datetime.now(tz=timezone.utc).timestamp() * 1000)
|
||||
msg.at_users = []
|
||||
return msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_ai_card_if_configured(self, config, mock_stream_client, mock_http_client, mock_message):
|
||||
from gateway.platforms.dingtalk import DingTalkAdapter
|
||||
|
||||
adapter = DingTalkAdapter(config)
|
||||
adapter._stream_client = mock_stream_client
|
||||
adapter._http_client = mock_http_client
|
||||
adapter._message_contexts["test_conv_id"] = mock_message
|
||||
adapter._session_webhooks = {"test_conv_id": ("https://api.dingtalk.com/robot/sendBySession?session=test", 9999999999999)}
|
||||
adapter._card_template_id = "test_card_template"
|
||||
|
||||
# Mock the card SDK with proper async methods
|
||||
mock_card_sdk = MagicMock()
|
||||
mock_card_sdk.create_card_with_options_async = AsyncMock()
|
||||
mock_card_sdk.deliver_card_with_options_async = AsyncMock()
|
||||
mock_card_sdk.streaming_update_with_options_async = AsyncMock()
|
||||
adapter._card_sdk = mock_card_sdk
|
||||
|
||||
# Mock access token
|
||||
adapter._get_access_token = AsyncMock(return_value="test_token")
|
||||
|
||||
result = await adapter.send("test_conv_id", "Hello World")
|
||||
|
||||
mock_card_sdk.create_card_with_options_async.assert_called_once()
|
||||
mock_card_sdk.deliver_card_with_options_async.assert_called_once()
|
||||
mock_card_sdk.streaming_update_with_options_async.assert_called_once()
|
||||
assert result.success is True
|
||||
|
|
|
|||
|
|
@ -157,3 +157,232 @@ async def test_send_does_not_retry_on_unrelated_errors():
|
|||
# Only the first attempt happens — no reference-retry replay.
|
||||
assert channel.send.await_count == 1
|
||||
assert send_calls[0]["reference"] is reference_obj
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forum channel tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import discord as _discord_mod # noqa: E402 — imported after _ensure_discord_mock
|
||||
|
||||
|
||||
class TestIsForumParent:
|
||||
def test_none_returns_false(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
assert adapter._is_forum_parent(None) is False
|
||||
|
||||
def test_forum_channel_class_instance(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
forum_cls = getattr(_discord_mod, "ForumChannel", None)
|
||||
if forum_cls is None:
|
||||
# Re-create a type for the mock
|
||||
forum_cls = type("ForumChannel", (), {})
|
||||
_discord_mod.ForumChannel = forum_cls
|
||||
ch = forum_cls()
|
||||
assert adapter._is_forum_parent(ch) is True
|
||||
|
||||
def test_type_value_15(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
ch = SimpleNamespace(type=15)
|
||||
assert adapter._is_forum_parent(ch) is True
|
||||
|
||||
def test_regular_channel_returns_false(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
ch = SimpleNamespace(type=0)
|
||||
assert adapter._is_forum_parent(ch) is False
|
||||
|
||||
def test_thread_returns_false(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
ch = SimpleNamespace(type=11) # public thread
|
||||
assert adapter._is_forum_parent(ch) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_forum_creates_thread_post():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
# thread object has no 'send' so _send_to_forum uses thread.thread
|
||||
thread_ch = SimpleNamespace(id=555, send=AsyncMock(return_value=SimpleNamespace(id=600)))
|
||||
thread = SimpleNamespace(
|
||||
id=555,
|
||||
message=SimpleNamespace(id=500),
|
||||
thread=thread_ch,
|
||||
)
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.name = "ideas"
|
||||
forum_channel.create_thread = AsyncMock(return_value=thread)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: forum_channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("999", "Hello forum!")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "500"
|
||||
forum_channel.create_thread.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_forum_sends_remaining_chunks():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
# Force a small max message length so the message splits
|
||||
adapter.MAX_MESSAGE_LENGTH = 20
|
||||
|
||||
chunk_msg_1 = SimpleNamespace(id=500)
|
||||
chunk_msg_2 = SimpleNamespace(id=501)
|
||||
thread_ch = SimpleNamespace(
|
||||
id=555,
|
||||
send=AsyncMock(return_value=chunk_msg_2),
|
||||
)
|
||||
# thread object has no 'send' so _send_to_forum uses thread.thread
|
||||
thread = SimpleNamespace(
|
||||
id=555,
|
||||
message=chunk_msg_1,
|
||||
thread=thread_ch,
|
||||
)
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.name = "ideas"
|
||||
forum_channel.create_thread = AsyncMock(return_value=thread)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: forum_channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("999", "A" * 50)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "500"
|
||||
# Should have sent at least one follow-up chunk
|
||||
assert thread_ch.send.await_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_forum_create_thread_failure():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.name = "ideas"
|
||||
forum_channel.create_thread = AsyncMock(side_effect=Exception("rate limited"))
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: forum_channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("999", "Hello forum!")
|
||||
|
||||
assert result.success is False
|
||||
assert "rate limited" in result.error
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forum follow-up chunk failure reporting + media on forum paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_forum_follow_up_chunk_failures_collected_as_warnings():
|
||||
"""Partial-send chunk failures surface in raw_response['warnings']."""
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
adapter.MAX_MESSAGE_LENGTH = 20
|
||||
|
||||
chunk_msg_1 = SimpleNamespace(id=500)
|
||||
# Every follow-up chunk fails — we should collect a warning per failure
|
||||
thread_ch = SimpleNamespace(
|
||||
id=555,
|
||||
send=AsyncMock(side_effect=Exception("rate limited")),
|
||||
)
|
||||
thread = SimpleNamespace(id=555, message=chunk_msg_1, thread=thread_ch)
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.name = "ideas"
|
||||
forum_channel.create_thread = AsyncMock(return_value=thread)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: forum_channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
# Long enough to produce multiple chunks
|
||||
result = await adapter.send("999", "A" * 60)
|
||||
|
||||
# Starter message (first chunk) was delivered via create_thread, so send is
|
||||
# successful overall — but follow-up chunks all failed and are reported.
|
||||
assert result.success is True
|
||||
assert result.message_id == "500"
|
||||
warnings = (result.raw_response or {}).get("warnings") or []
|
||||
assert len(warnings) >= 1
|
||||
assert all("rate limited" in w for w in warnings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forum_post_file_creates_thread_with_attachment():
|
||||
"""_forum_post_file routes file-bearing sends to create_thread with file kwarg."""
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
thread_ch = SimpleNamespace(id=777, send=AsyncMock())
|
||||
thread = SimpleNamespace(id=777, message=SimpleNamespace(id=800), thread=thread_ch)
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.name = "ideas"
|
||||
forum_channel.create_thread = AsyncMock(return_value=thread)
|
||||
|
||||
# discord.File is a real class; build a MagicMock that looks like one
|
||||
fake_file = SimpleNamespace(filename="photo.png")
|
||||
|
||||
result = await adapter._forum_post_file(
|
||||
forum_channel,
|
||||
content="here is a photo",
|
||||
file=fake_file,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "800"
|
||||
forum_channel.create_thread.assert_awaited_once()
|
||||
call_kwargs = forum_channel.create_thread.await_args.kwargs
|
||||
assert call_kwargs["file"] is fake_file
|
||||
assert call_kwargs["content"] == "here is a photo"
|
||||
# Thread name derived from content's first line
|
||||
assert call_kwargs["name"] == "here is a photo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forum_post_file_uses_filename_when_no_content():
|
||||
"""Thread name falls back to file.filename when no content is provided."""
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
thread = SimpleNamespace(id=1, message=SimpleNamespace(id=2), thread=SimpleNamespace(id=1, send=AsyncMock()))
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 10
|
||||
forum_channel.name = "forum"
|
||||
forum_channel.create_thread = AsyncMock(return_value=thread)
|
||||
|
||||
fake_file = SimpleNamespace(filename="voice-message.ogg")
|
||||
result = await adapter._forum_post_file(forum_channel, content="", file=fake_file)
|
||||
|
||||
assert result.success is True
|
||||
call_kwargs = forum_channel.create_thread.await_args.kwargs
|
||||
# Content was empty → thread name derived from filename
|
||||
assert call_kwargs["name"] == "voice-message.ogg"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forum_post_file_creation_failure():
|
||||
"""_forum_post_file returns a failed SendResult when create_thread raises."""
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.create_thread = AsyncMock(side_effect=Exception("missing perms"))
|
||||
|
||||
result = await adapter._forum_post_file(
|
||||
forum_channel,
|
||||
content="hi",
|
||||
file=SimpleNamespace(filename="x.png"),
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "missing perms" in (result.error or "")
|
||||
|
|
|
|||
|
|
@ -601,6 +601,10 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
calls.append("message_recalled")
|
||||
return self
|
||||
|
||||
def register_p2_customized_event(self, event_key, _handler):
|
||||
calls.append(f"customized:{event_key}")
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
calls.append("build")
|
||||
return "handler"
|
||||
|
|
@ -628,6 +632,7 @@ class TestAdapterBehavior(unittest.TestCase):
|
|||
"bot_deleted",
|
||||
"p2p_chat_entered",
|
||||
"message_recalled",
|
||||
"customized:drive.notice.comment_add_v1",
|
||||
"build",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
261
tests/gateway/test_feishu_comment.py
Normal file
261
tests/gateway/test_feishu_comment.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
"""Tests for feishu_comment — event filtering, access control integration, wiki reverse lookup."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from gateway.platforms.feishu_comment import (
|
||||
parse_drive_comment_event,
|
||||
_ALLOWED_NOTICE_TYPES,
|
||||
_sanitize_comment_text,
|
||||
)
|
||||
|
||||
|
||||
def _make_event(
|
||||
comment_id="c1",
|
||||
reply_id="r1",
|
||||
notice_type="add_reply",
|
||||
file_token="docx_token",
|
||||
file_type="docx",
|
||||
from_open_id="ou_user",
|
||||
to_open_id="ou_bot",
|
||||
is_mentioned=True,
|
||||
):
|
||||
"""Build a minimal drive comment event SimpleNamespace."""
|
||||
return SimpleNamespace(event={
|
||||
"event_id": "evt_1",
|
||||
"comment_id": comment_id,
|
||||
"reply_id": reply_id,
|
||||
"is_mentioned": is_mentioned,
|
||||
"timestamp": "1713200000",
|
||||
"notice_meta": {
|
||||
"file_token": file_token,
|
||||
"file_type": file_type,
|
||||
"notice_type": notice_type,
|
||||
"from_user_id": {"open_id": from_open_id},
|
||||
"to_user_id": {"open_id": to_open_id},
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
class TestParseEvent(unittest.TestCase):
|
||||
def test_parse_valid_event(self):
|
||||
evt = _make_event()
|
||||
parsed = parse_drive_comment_event(evt)
|
||||
self.assertIsNotNone(parsed)
|
||||
self.assertEqual(parsed["comment_id"], "c1")
|
||||
self.assertEqual(parsed["file_type"], "docx")
|
||||
self.assertEqual(parsed["from_open_id"], "ou_user")
|
||||
self.assertEqual(parsed["to_open_id"], "ou_bot")
|
||||
|
||||
def test_parse_missing_event_attr(self):
|
||||
self.assertIsNone(parse_drive_comment_event(object()))
|
||||
|
||||
def test_parse_none_event(self):
|
||||
self.assertIsNone(parse_drive_comment_event(SimpleNamespace()))
|
||||
|
||||
|
||||
class TestEventFiltering(unittest.TestCase):
|
||||
"""Test the filtering logic in handle_drive_comment_event."""
|
||||
|
||||
def _run(self, coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
@patch("gateway.platforms.feishu_comment_rules.load_config")
|
||||
@patch("gateway.platforms.feishu_comment_rules.resolve_rule")
|
||||
@patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
|
||||
def test_self_reply_filtered(self, mock_allowed, mock_resolve, mock_load):
|
||||
"""Events where from_open_id == self_open_id should be dropped."""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
|
||||
evt = _make_event(from_open_id="ou_bot", to_open_id="ou_bot")
|
||||
self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("gateway.platforms.feishu_comment_rules.load_config")
|
||||
@patch("gateway.platforms.feishu_comment_rules.resolve_rule")
|
||||
@patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
|
||||
def test_wrong_receiver_filtered(self, mock_allowed, mock_resolve, mock_load):
|
||||
"""Events where to_open_id != self_open_id should be dropped."""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
|
||||
evt = _make_event(to_open_id="ou_other_bot")
|
||||
self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("gateway.platforms.feishu_comment_rules.load_config")
|
||||
@patch("gateway.platforms.feishu_comment_rules.resolve_rule")
|
||||
@patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
|
||||
def test_empty_to_open_id_filtered(self, mock_allowed, mock_resolve, mock_load):
|
||||
"""Events with empty to_open_id should be dropped."""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
|
||||
evt = _make_event(to_open_id="")
|
||||
self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
|
||||
mock_load.assert_not_called()
|
||||
|
||||
@patch("gateway.platforms.feishu_comment_rules.load_config")
|
||||
@patch("gateway.platforms.feishu_comment_rules.resolve_rule")
|
||||
@patch("gateway.platforms.feishu_comment_rules.is_user_allowed")
|
||||
def test_invalid_notice_type_filtered(self, mock_allowed, mock_resolve, mock_load):
|
||||
"""Events with unsupported notice_type should be dropped."""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
|
||||
evt = _make_event(notice_type="resolve_comment")
|
||||
self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
|
||||
mock_load.assert_not_called()
|
||||
|
||||
def test_allowed_notice_types(self):
|
||||
self.assertIn("add_comment", _ALLOWED_NOTICE_TYPES)
|
||||
self.assertIn("add_reply", _ALLOWED_NOTICE_TYPES)
|
||||
self.assertNotIn("resolve_comment", _ALLOWED_NOTICE_TYPES)
|
||||
|
||||
|
||||
class TestAccessControlIntegration(unittest.TestCase):
|
||||
def _run(self, coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
@patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=False)
|
||||
@patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=False)
|
||||
@patch("gateway.platforms.feishu_comment_rules.resolve_rule")
|
||||
@patch("gateway.platforms.feishu_comment_rules.load_config")
|
||||
def test_denied_user_no_side_effects(self, mock_load, mock_resolve, mock_allowed, mock_wiki_keys):
|
||||
"""Denied user should not trigger typing reaction or agent."""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
from gateway.platforms.feishu_comment_rules import ResolvedCommentRule
|
||||
|
||||
mock_resolve.return_value = ResolvedCommentRule(True, "allowlist", frozenset(), "top")
|
||||
mock_load.return_value = Mock()
|
||||
|
||||
client = Mock()
|
||||
evt = _make_event()
|
||||
self._run(handle_drive_comment_event(client, evt, self_open_id="ou_bot"))
|
||||
|
||||
# No API calls should be made for denied users
|
||||
client.request.assert_not_called()
|
||||
|
||||
@patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=False)
|
||||
@patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=False)
|
||||
@patch("gateway.platforms.feishu_comment_rules.resolve_rule")
|
||||
@patch("gateway.platforms.feishu_comment_rules.load_config")
|
||||
def test_disabled_comment_skipped(self, mock_load, mock_resolve, mock_allowed, mock_wiki_keys):
|
||||
"""Disabled comments should return immediately."""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
from gateway.platforms.feishu_comment_rules import ResolvedCommentRule
|
||||
|
||||
mock_resolve.return_value = ResolvedCommentRule(False, "allowlist", frozenset(), "top")
|
||||
mock_load.return_value = Mock()
|
||||
|
||||
evt = _make_event()
|
||||
self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
|
||||
mock_allowed.assert_not_called()
|
||||
|
||||
|
||||
class TestSanitizeCommentText(unittest.TestCase):
|
||||
def test_angle_brackets_escaped(self):
|
||||
self.assertEqual(_sanitize_comment_text("List<String>"), "List<String>")
|
||||
|
||||
def test_ampersand_escaped_first(self):
|
||||
self.assertEqual(_sanitize_comment_text("a & b"), "a & b")
|
||||
|
||||
def test_ampersand_not_double_escaped(self):
|
||||
result = _sanitize_comment_text("a < b & c > d")
|
||||
self.assertEqual(result, "a < b & c > d")
|
||||
self.assertNotIn("&lt;", result)
|
||||
self.assertNotIn("&gt;", result)
|
||||
|
||||
def test_plain_text_unchanged(self):
|
||||
self.assertEqual(_sanitize_comment_text("hello world"), "hello world")
|
||||
|
||||
def test_empty_string(self):
|
||||
self.assertEqual(_sanitize_comment_text(""), "")
|
||||
|
||||
def test_code_snippet(self):
|
||||
text = 'if (a < b && c > 0) { return "ok"; }'
|
||||
result = _sanitize_comment_text(text)
|
||||
self.assertNotIn("<", result)
|
||||
self.assertNotIn(">", result)
|
||||
self.assertIn("<", result)
|
||||
self.assertIn(">", result)
|
||||
|
||||
|
||||
class TestWikiReverseLookup(unittest.TestCase):
|
||||
def _run(self, coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
@patch("gateway.platforms.feishu_comment._exec_request")
|
||||
def test_reverse_lookup_success(self, mock_exec):
|
||||
from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token
|
||||
|
||||
mock_exec.return_value = (0, "Success", {
|
||||
"node": {"node_token": "WIKI_TOKEN_123", "obj_token": "docx_abc"},
|
||||
})
|
||||
result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
|
||||
self.assertEqual(result, "WIKI_TOKEN_123")
|
||||
# Verify correct API params
|
||||
call_args = mock_exec.call_args
|
||||
queries = call_args[1].get("queries") or call_args[0][3]
|
||||
query_dict = dict(queries)
|
||||
self.assertEqual(query_dict["token"], "docx_abc")
|
||||
self.assertEqual(query_dict["obj_type"], "docx")
|
||||
|
||||
@patch("gateway.platforms.feishu_comment._exec_request")
|
||||
def test_reverse_lookup_not_wiki(self, mock_exec):
|
||||
from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token
|
||||
|
||||
mock_exec.return_value = (131001, "not found", {})
|
||||
result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("gateway.platforms.feishu_comment._exec_request")
|
||||
def test_reverse_lookup_service_error(self, mock_exec):
|
||||
from gateway.platforms.feishu_comment import _reverse_lookup_wiki_token
|
||||
|
||||
mock_exec.return_value = (500, "internal error", {})
|
||||
result = self._run(_reverse_lookup_wiki_token(Mock(), "docx", "docx_abc"))
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("gateway.platforms.feishu_comment._reverse_lookup_wiki_token", new_callable=AsyncMock)
|
||||
@patch("gateway.platforms.feishu_comment_rules.has_wiki_keys", return_value=True)
|
||||
@patch("gateway.platforms.feishu_comment_rules.is_user_allowed", return_value=True)
|
||||
@patch("gateway.platforms.feishu_comment_rules.resolve_rule")
|
||||
@patch("gateway.platforms.feishu_comment_rules.load_config")
|
||||
@patch("gateway.platforms.feishu_comment.add_comment_reaction", new_callable=AsyncMock)
|
||||
@patch("gateway.platforms.feishu_comment.batch_query_comment", new_callable=AsyncMock)
|
||||
@patch("gateway.platforms.feishu_comment.query_document_meta", new_callable=AsyncMock)
|
||||
def test_wiki_lookup_triggered_when_no_exact_match(
|
||||
self, mock_meta, mock_batch, mock_reaction,
|
||||
mock_load, mock_resolve, mock_allowed, mock_wiki_keys, mock_lookup,
|
||||
):
|
||||
"""Wiki reverse lookup should fire when rule falls to wildcard/top and wiki keys exist."""
|
||||
from gateway.platforms.feishu_comment import handle_drive_comment_event
|
||||
from gateway.platforms.feishu_comment_rules import ResolvedCommentRule
|
||||
|
||||
# First resolve returns wildcard (no exact match), second returns exact wiki match
|
||||
mock_resolve.side_effect = [
|
||||
ResolvedCommentRule(True, "allowlist", frozenset(), "wildcard"),
|
||||
ResolvedCommentRule(True, "allowlist", frozenset(), "exact:wiki:WIKI123"),
|
||||
]
|
||||
mock_load.return_value = Mock()
|
||||
mock_lookup.return_value = "WIKI123"
|
||||
mock_meta.return_value = {"title": "Test", "url": ""}
|
||||
mock_batch.return_value = {"is_whole": False, "quote": ""}
|
||||
|
||||
evt = _make_event()
|
||||
# Will proceed past access control but fail later — that's OK, we just test the lookup
|
||||
try:
|
||||
self._run(handle_drive_comment_event(Mock(), evt, self_open_id="ou_bot"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
mock_lookup.assert_called_once_with(unittest.mock.ANY, "docx", "docx_token")
|
||||
self.assertEqual(mock_resolve.call_count, 2)
|
||||
# Second call should include wiki_token
|
||||
second_call_kwargs = mock_resolve.call_args_list[1]
|
||||
self.assertEqual(second_call_kwargs[1].get("wiki_token") or second_call_kwargs[0][3], "WIKI123")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
320
tests/gateway/test_feishu_comment_rules.py
Normal file
320
tests/gateway/test_feishu_comment_rules.py
Normal file
|
|
@ -0,0 +1,320 @@
|
|||
"""Tests for feishu_comment_rules — 3-tier access control rule engine."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from gateway.platforms.feishu_comment_rules import (
|
||||
CommentsConfig,
|
||||
CommentDocumentRule,
|
||||
ResolvedCommentRule,
|
||||
_MtimeCache,
|
||||
_parse_document_rule,
|
||||
has_wiki_keys,
|
||||
is_user_allowed,
|
||||
load_config,
|
||||
pairing_add,
|
||||
pairing_list,
|
||||
pairing_remove,
|
||||
resolve_rule,
|
||||
)
|
||||
|
||||
|
||||
class TestCommentDocumentRuleParsing(unittest.TestCase):
|
||||
def test_parse_full_rule(self):
|
||||
rule = _parse_document_rule({
|
||||
"enabled": False,
|
||||
"policy": "allowlist",
|
||||
"allow_from": ["ou_a", "ou_b"],
|
||||
})
|
||||
self.assertFalse(rule.enabled)
|
||||
self.assertEqual(rule.policy, "allowlist")
|
||||
self.assertEqual(rule.allow_from, frozenset(["ou_a", "ou_b"]))
|
||||
|
||||
def test_parse_partial_rule(self):
|
||||
rule = _parse_document_rule({"policy": "allowlist"})
|
||||
self.assertIsNone(rule.enabled)
|
||||
self.assertEqual(rule.policy, "allowlist")
|
||||
self.assertIsNone(rule.allow_from)
|
||||
|
||||
def test_parse_empty_rule(self):
|
||||
rule = _parse_document_rule({})
|
||||
self.assertIsNone(rule.enabled)
|
||||
self.assertIsNone(rule.policy)
|
||||
self.assertIsNone(rule.allow_from)
|
||||
|
||||
def test_invalid_policy_ignored(self):
|
||||
rule = _parse_document_rule({"policy": "invalid_value"})
|
||||
self.assertIsNone(rule.policy)
|
||||
|
||||
|
||||
class TestResolveRule(unittest.TestCase):
|
||||
def test_exact_match(self):
|
||||
cfg = CommentsConfig(
|
||||
policy="pairing",
|
||||
allow_from=frozenset(["ou_top"]),
|
||||
documents={
|
||||
"docx:abc": CommentDocumentRule(policy="allowlist"),
|
||||
},
|
||||
)
|
||||
rule = resolve_rule(cfg, "docx", "abc")
|
||||
self.assertEqual(rule.policy, "allowlist")
|
||||
self.assertTrue(rule.match_source.startswith("exact:"))
|
||||
|
||||
def test_wildcard_match(self):
|
||||
cfg = CommentsConfig(
|
||||
policy="pairing",
|
||||
documents={
|
||||
"*": CommentDocumentRule(policy="allowlist"),
|
||||
},
|
||||
)
|
||||
rule = resolve_rule(cfg, "docx", "unknown")
|
||||
self.assertEqual(rule.policy, "allowlist")
|
||||
self.assertEqual(rule.match_source, "wildcard")
|
||||
|
||||
def test_top_level_fallback(self):
|
||||
cfg = CommentsConfig(policy="pairing", allow_from=frozenset(["ou_top"]))
|
||||
rule = resolve_rule(cfg, "docx", "whatever")
|
||||
self.assertEqual(rule.policy, "pairing")
|
||||
self.assertEqual(rule.allow_from, frozenset(["ou_top"]))
|
||||
self.assertEqual(rule.match_source, "top")
|
||||
|
||||
def test_exact_overrides_wildcard(self):
|
||||
cfg = CommentsConfig(
|
||||
policy="pairing",
|
||||
documents={
|
||||
"*": CommentDocumentRule(policy="pairing"),
|
||||
"docx:abc": CommentDocumentRule(policy="allowlist"),
|
||||
},
|
||||
)
|
||||
rule = resolve_rule(cfg, "docx", "abc")
|
||||
self.assertEqual(rule.policy, "allowlist")
|
||||
self.assertTrue(rule.match_source.startswith("exact:"))
|
||||
|
||||
def test_field_by_field_fallback(self):
|
||||
"""Exact sets policy, wildcard sets allow_from, enabled from top."""
|
||||
cfg = CommentsConfig(
|
||||
enabled=True,
|
||||
policy="pairing",
|
||||
allow_from=frozenset(["ou_top"]),
|
||||
documents={
|
||||
"*": CommentDocumentRule(allow_from=frozenset(["ou_wildcard"])),
|
||||
"docx:abc": CommentDocumentRule(policy="allowlist"),
|
||||
},
|
||||
)
|
||||
rule = resolve_rule(cfg, "docx", "abc")
|
||||
self.assertEqual(rule.policy, "allowlist")
|
||||
self.assertEqual(rule.allow_from, frozenset(["ou_wildcard"]))
|
||||
self.assertTrue(rule.enabled)
|
||||
|
||||
def test_explicit_empty_allow_from_does_not_fall_through(self):
|
||||
"""allow_from=[] on exact should NOT inherit from wildcard or top."""
|
||||
cfg = CommentsConfig(
|
||||
allow_from=frozenset(["ou_top"]),
|
||||
documents={
|
||||
"*": CommentDocumentRule(allow_from=frozenset(["ou_wildcard"])),
|
||||
"docx:abc": CommentDocumentRule(
|
||||
policy="allowlist",
|
||||
allow_from=frozenset(),
|
||||
),
|
||||
},
|
||||
)
|
||||
rule = resolve_rule(cfg, "docx", "abc")
|
||||
self.assertEqual(rule.allow_from, frozenset())
|
||||
|
||||
def test_wiki_token_match(self):
|
||||
cfg = CommentsConfig(
|
||||
policy="pairing",
|
||||
documents={
|
||||
"wiki:WIKI123": CommentDocumentRule(policy="allowlist"),
|
||||
},
|
||||
)
|
||||
rule = resolve_rule(cfg, "docx", "obj_token", wiki_token="WIKI123")
|
||||
self.assertEqual(rule.policy, "allowlist")
|
||||
self.assertTrue(rule.match_source.startswith("exact:wiki:"))
|
||||
|
||||
def test_exact_takes_priority_over_wiki(self):
|
||||
cfg = CommentsConfig(
|
||||
documents={
|
||||
"docx:abc": CommentDocumentRule(policy="allowlist"),
|
||||
"wiki:WIKI123": CommentDocumentRule(policy="pairing"),
|
||||
},
|
||||
)
|
||||
rule = resolve_rule(cfg, "docx", "abc", wiki_token="WIKI123")
|
||||
self.assertEqual(rule.policy, "allowlist")
|
||||
self.assertTrue(rule.match_source.startswith("exact:docx:"))
|
||||
|
||||
def test_default_config(self):
|
||||
cfg = CommentsConfig()
|
||||
rule = resolve_rule(cfg, "docx", "anything")
|
||||
self.assertTrue(rule.enabled)
|
||||
self.assertEqual(rule.policy, "pairing")
|
||||
self.assertEqual(rule.allow_from, frozenset())
|
||||
|
||||
|
||||
class TestHasWikiKeys(unittest.TestCase):
|
||||
def test_no_wiki_keys(self):
|
||||
cfg = CommentsConfig(documents={
|
||||
"docx:abc": CommentDocumentRule(policy="allowlist"),
|
||||
"*": CommentDocumentRule(policy="pairing"),
|
||||
})
|
||||
self.assertFalse(has_wiki_keys(cfg))
|
||||
|
||||
def test_has_wiki_keys(self):
|
||||
cfg = CommentsConfig(documents={
|
||||
"wiki:WIKI123": CommentDocumentRule(policy="allowlist"),
|
||||
})
|
||||
self.assertTrue(has_wiki_keys(cfg))
|
||||
|
||||
def test_empty_documents(self):
|
||||
cfg = CommentsConfig()
|
||||
self.assertFalse(has_wiki_keys(cfg))
|
||||
|
||||
|
||||
class TestIsUserAllowed(unittest.TestCase):
|
||||
def test_allowlist_allows_listed(self):
|
||||
rule = ResolvedCommentRule(True, "allowlist", frozenset(["ou_a"]), "top")
|
||||
self.assertTrue(is_user_allowed(rule, "ou_a"))
|
||||
|
||||
def test_allowlist_denies_unlisted(self):
|
||||
rule = ResolvedCommentRule(True, "allowlist", frozenset(["ou_a"]), "top")
|
||||
self.assertFalse(is_user_allowed(rule, "ou_b"))
|
||||
|
||||
def test_allowlist_empty_denies_all(self):
|
||||
rule = ResolvedCommentRule(True, "allowlist", frozenset(), "top")
|
||||
self.assertFalse(is_user_allowed(rule, "ou_anyone"))
|
||||
|
||||
def test_pairing_allows_in_allow_from(self):
|
||||
rule = ResolvedCommentRule(True, "pairing", frozenset(["ou_a"]), "top")
|
||||
self.assertTrue(is_user_allowed(rule, "ou_a"))
|
||||
|
||||
def test_pairing_checks_store(self):
|
||||
rule = ResolvedCommentRule(True, "pairing", frozenset(), "top")
|
||||
with patch(
|
||||
"gateway.platforms.feishu_comment_rules._load_pairing_approved",
|
||||
return_value={"ou_approved"},
|
||||
):
|
||||
self.assertTrue(is_user_allowed(rule, "ou_approved"))
|
||||
self.assertFalse(is_user_allowed(rule, "ou_unknown"))
|
||||
|
||||
|
||||
class TestMtimeCache(unittest.TestCase):
|
||||
def test_returns_empty_dict_for_missing_file(self):
|
||||
cache = _MtimeCache(Path("/nonexistent/path.json"))
|
||||
self.assertEqual(cache.load(), {})
|
||||
|
||||
def test_reads_file_and_caches(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump({"key": "value"}, f)
|
||||
f.flush()
|
||||
path = Path(f.name)
|
||||
try:
|
||||
cache = _MtimeCache(path)
|
||||
data = cache.load()
|
||||
self.assertEqual(data, {"key": "value"})
|
||||
# Second load should use cache (same mtime)
|
||||
data2 = cache.load()
|
||||
self.assertEqual(data2, {"key": "value"})
|
||||
finally:
|
||||
path.unlink()
|
||||
|
||||
def test_reloads_on_mtime_change(self):
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump({"v": 1}, f)
|
||||
f.flush()
|
||||
path = Path(f.name)
|
||||
try:
|
||||
cache = _MtimeCache(path)
|
||||
self.assertEqual(cache.load(), {"v": 1})
|
||||
# Modify file
|
||||
time.sleep(0.05)
|
||||
with open(path, "w") as f2:
|
||||
json.dump({"v": 2}, f2)
|
||||
# Force mtime change detection
|
||||
os.utime(path, (time.time() + 1, time.time() + 1))
|
||||
self.assertEqual(cache.load(), {"v": 2})
|
||||
finally:
|
||||
path.unlink()
|
||||
|
||||
|
||||
class TestLoadConfig(unittest.TestCase):
|
||||
def test_load_with_documents(self):
|
||||
raw = {
|
||||
"enabled": True,
|
||||
"policy": "allowlist",
|
||||
"allow_from": ["ou_a"],
|
||||
"documents": {
|
||||
"*": {"policy": "pairing"},
|
||||
"docx:abc": {"policy": "allowlist", "allow_from": ["ou_b"]},
|
||||
},
|
||||
}
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
||||
json.dump(raw, f)
|
||||
path = Path(f.name)
|
||||
try:
|
||||
with patch("gateway.platforms.feishu_comment_rules.RULES_FILE", path):
|
||||
with patch("gateway.platforms.feishu_comment_rules._rules_cache", _MtimeCache(path)):
|
||||
cfg = load_config()
|
||||
self.assertTrue(cfg.enabled)
|
||||
self.assertEqual(cfg.policy, "allowlist")
|
||||
self.assertEqual(cfg.allow_from, frozenset(["ou_a"]))
|
||||
self.assertIn("*", cfg.documents)
|
||||
self.assertIn("docx:abc", cfg.documents)
|
||||
self.assertEqual(cfg.documents["docx:abc"].policy, "allowlist")
|
||||
finally:
|
||||
path.unlink()
|
||||
|
||||
def test_load_missing_file_returns_defaults(self):
|
||||
with patch("gateway.platforms.feishu_comment_rules._rules_cache", _MtimeCache(Path("/nonexistent"))):
|
||||
cfg = load_config()
|
||||
self.assertTrue(cfg.enabled)
|
||||
self.assertEqual(cfg.policy, "pairing")
|
||||
self.assertEqual(cfg.allow_from, frozenset())
|
||||
self.assertEqual(cfg.documents, {})
|
||||
|
||||
|
||||
class TestPairingStore(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._tmpdir = tempfile.mkdtemp()
|
||||
self._pairing_file = Path(self._tmpdir) / "pairing.json"
|
||||
with open(self._pairing_file, "w") as f:
|
||||
json.dump({"approved": {}}, f)
|
||||
self._patcher_file = patch("gateway.platforms.feishu_comment_rules.PAIRING_FILE", self._pairing_file)
|
||||
self._patcher_cache = patch(
|
||||
"gateway.platforms.feishu_comment_rules._pairing_cache",
|
||||
_MtimeCache(self._pairing_file),
|
||||
)
|
||||
self._patcher_file.start()
|
||||
self._patcher_cache.start()
|
||||
|
||||
def tearDown(self):
|
||||
self._patcher_cache.stop()
|
||||
self._patcher_file.stop()
|
||||
if self._pairing_file.exists():
|
||||
self._pairing_file.unlink()
|
||||
os.rmdir(self._tmpdir)
|
||||
|
||||
def test_add_and_list(self):
|
||||
self.assertTrue(pairing_add("ou_new"))
|
||||
approved = pairing_list()
|
||||
self.assertIn("ou_new", approved)
|
||||
|
||||
def test_add_duplicate(self):
|
||||
pairing_add("ou_a")
|
||||
self.assertFalse(pairing_add("ou_a"))
|
||||
|
||||
def test_remove(self):
|
||||
pairing_add("ou_a")
|
||||
self.assertTrue(pairing_remove("ou_a"))
|
||||
self.assertNotIn("ou_a", pairing_list())
|
||||
|
||||
def test_remove_nonexistent(self):
|
||||
self.assertFalse(pairing_remove("ou_nobody"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
@ -179,7 +179,7 @@ class TestVoiceAttachmentSSRFProtection:
|
|||
from gateway.platforms.qqbot import QQAdapter, _ssrf_redirect_guard
|
||||
|
||||
client = mock.AsyncMock()
|
||||
with mock.patch("gateway.platforms.qqbot.httpx.AsyncClient", return_value=client) as async_client_cls:
|
||||
with mock.patch("gateway.platforms.qqbot.adapter.httpx.AsyncClient", return_value=client) as async_client_cls:
|
||||
adapter = QQAdapter(_make_config(app_id="a", client_secret="b"))
|
||||
adapter._ensure_token = mock.AsyncMock(side_effect=RuntimeError("stop after client creation"))
|
||||
|
||||
|
|
|
|||
|
|
@ -202,3 +202,120 @@ async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_p
|
|||
|
||||
assert ok is True
|
||||
assert calls == [(42, False), (42, True)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_gateway_replace_writes_takeover_marker_before_sigterm(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
"""--replace must write a takeover marker BEFORE sending SIGTERM.
|
||||
|
||||
The marker lets the target's shutdown handler identify the signal as a
|
||||
planned takeover (→ exit 0) rather than an unexpected kill (→ exit 1).
|
||||
Without the marker, PR #5646's signal-recovery path would revive the
|
||||
target via systemd Restart=on-failure, starting a flap loop.
|
||||
"""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
# Record the ORDER of marker-write + terminate_pid calls
|
||||
events: list[str] = []
|
||||
marker_paths_seen: list = []
|
||||
|
||||
def record_write_marker(target_pid: int) -> bool:
|
||||
events.append(f"write_marker(target_pid={target_pid})")
|
||||
# Also check that the marker file actually exists after this call
|
||||
marker_paths_seen.append(
|
||||
(tmp_path / ".gateway-takeover.json").exists() is False # not yet
|
||||
)
|
||||
# Actually write the marker so we can verify cleanup later
|
||||
from gateway.status import _get_takeover_marker_path, _write_json_file, _get_process_start_time
|
||||
_write_json_file(_get_takeover_marker_path(), {
|
||||
"target_pid": target_pid,
|
||||
"target_start_time": 0,
|
||||
"replacer_pid": 100,
|
||||
"written_at": "2026-04-17T00:00:00+00:00",
|
||||
})
|
||||
return True
|
||||
|
||||
def record_terminate(pid, force=False):
|
||||
events.append(f"terminate_pid(pid={pid}, force={force})")
|
||||
|
||||
class _CleanExitRunner:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.should_exit_cleanly = True
|
||||
self.exit_reason = None
|
||||
self.adapters = {}
|
||||
|
||||
async def start(self):
|
||||
return True
|
||||
|
||||
async def stop(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 42)
|
||||
monkeypatch.setattr("gateway.status.remove_pid_file", lambda: None)
|
||||
monkeypatch.setattr("gateway.status.release_all_scoped_locks", lambda: 0)
|
||||
monkeypatch.setattr("gateway.status.write_takeover_marker", record_write_marker)
|
||||
monkeypatch.setattr("gateway.status.terminate_pid", record_terminate)
|
||||
monkeypatch.setattr("gateway.run.os.getpid", lambda: 100)
|
||||
# Simulate old process exiting on first check so we don't loop into force-kill
|
||||
monkeypatch.setattr(
|
||||
"gateway.run.os.kill",
|
||||
lambda pid, sig: (_ for _ in ()).throw(ProcessLookupError()),
|
||||
)
|
||||
monkeypatch.setattr("time.sleep", lambda _: None)
|
||||
monkeypatch.setattr("tools.skills_sync.sync_skills", lambda quiet=True: None)
|
||||
monkeypatch.setattr("hermes_logging.setup_logging", lambda hermes_home, mode: tmp_path)
|
||||
monkeypatch.setattr("hermes_logging._add_rotating_handler", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr("gateway.run.GatewayRunner", _CleanExitRunner)
|
||||
|
||||
from gateway.run import start_gateway
|
||||
|
||||
ok = await start_gateway(config=GatewayConfig(), replace=True, verbosity=None)
|
||||
|
||||
assert ok is True
|
||||
# Ordering: marker written BEFORE SIGTERM
|
||||
assert events[0] == "write_marker(target_pid=42)"
|
||||
assert any(e.startswith("terminate_pid(pid=42") for e in events[1:])
|
||||
# Marker file cleanup: replacer cleans it after loop completes
|
||||
assert not (tmp_path / ".gateway-takeover.json").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_gateway_replace_clears_marker_on_permission_denied(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
"""If we fail to kill the existing PID (permission denied), clean up the
|
||||
marker so it doesn't grief an unrelated future shutdown."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
def write_marker(target_pid: int) -> bool:
|
||||
from gateway.status import _get_takeover_marker_path, _write_json_file
|
||||
_write_json_file(_get_takeover_marker_path(), {
|
||||
"target_pid": target_pid,
|
||||
"target_start_time": 0,
|
||||
"replacer_pid": 100,
|
||||
"written_at": "2026-04-17T00:00:00+00:00",
|
||||
})
|
||||
return True
|
||||
|
||||
def raise_permission(pid, force=False):
|
||||
raise PermissionError("simulated EPERM")
|
||||
|
||||
monkeypatch.setattr("gateway.status.get_running_pid", lambda: 42)
|
||||
monkeypatch.setattr("gateway.status.write_takeover_marker", write_marker)
|
||||
monkeypatch.setattr("gateway.status.terminate_pid", raise_permission)
|
||||
monkeypatch.setattr("gateway.run.os.getpid", lambda: 100)
|
||||
monkeypatch.setattr("tools.skills_sync.sync_skills", lambda quiet=True: None)
|
||||
monkeypatch.setattr("hermes_logging.setup_logging", lambda hermes_home, mode: tmp_path)
|
||||
monkeypatch.setattr("hermes_logging._add_rotating_handler", lambda *args, **kwargs: None)
|
||||
|
||||
from gateway.run import start_gateway
|
||||
|
||||
# Should return False due to permission error
|
||||
ok = await start_gateway(config=GatewayConfig(), replace=True, verbosity=None)
|
||||
|
||||
assert ok is False
|
||||
# Marker must NOT be left behind
|
||||
assert not (tmp_path / ".gateway-takeover.json").exists()
|
||||
|
|
|
|||
|
|
@ -288,6 +288,38 @@ async def test_command_messages_do_not_leave_sentinel():
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("command_text", "handler_attr", "handler_result"),
|
||||
[
|
||||
("/help", "_handle_help_command", "Help text"),
|
||||
("/commands", "_handle_commands_command", "Commands text"),
|
||||
("/update", "_handle_update_command", "Update text"),
|
||||
("/profile", "_handle_profile_command", "Profile text"),
|
||||
],
|
||||
)
|
||||
async def test_active_session_bypass_commands_dispatch_without_interrupt(
|
||||
command_text,
|
||||
handler_attr,
|
||||
handler_result,
|
||||
):
|
||||
"""Gateway-handled bypass commands must return directly while an agent runs."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(text=command_text)
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
fake_agent = MagicMock()
|
||||
fake_agent.get_activity_summary.return_value = {"seconds_since_activity": 0}
|
||||
runner._running_agents[session_key] = fake_agent
|
||||
setattr(runner, handler_attr, AsyncMock(return_value=handler_result))
|
||||
|
||||
result = await runner._handle_message(event)
|
||||
|
||||
assert result == handler_result
|
||||
fake_agent.interrupt.assert_not_called()
|
||||
assert session_key not in runner.adapters[Platform.TELEGRAM]._pending_messages
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Test 6: /stop during sentinel force-cleans and unlocks session
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
231
tests/gateway/test_session_state_cleanup.py
Normal file
231
tests/gateway/test_session_state_cleanup.py
Normal file
|
|
@ -0,0 +1,231 @@
|
|||
"""Regression tests for _release_running_agent_state and SessionDB shutdown.
|
||||
|
||||
Before this change, running-agent state lived in three dicts that drifted
|
||||
out of sync:
|
||||
|
||||
self._running_agents — AIAgent instance per session key
|
||||
self._running_agents_ts — start timestamp per session key
|
||||
self._busy_ack_ts — last busy-ack timestamp per session key
|
||||
|
||||
Six cleanup sites did ``del self._running_agents[key]`` without touching
|
||||
the other two; one site only popped ``_running_agents`` and
|
||||
``_running_agents_ts``; and only the stale-eviction site cleaned all
|
||||
three. Each missed entry was a small persistent leak.
|
||||
|
||||
Also: SessionDB connections were never closed on gateway shutdown,
|
||||
leaving WAL locks in place until Python actually exited.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Bare GatewayRunner wired with just the state the helper touches."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._busy_ack_ts = {}
|
||||
return runner
|
||||
|
||||
|
||||
class TestReleaseRunningAgentStateUnit:
|
||||
def test_pops_all_three_dicts(self):
|
||||
runner = _make_runner()
|
||||
runner._running_agents["k"] = MagicMock()
|
||||
runner._running_agents_ts["k"] = 123.0
|
||||
runner._busy_ack_ts["k"] = 456.0
|
||||
|
||||
runner._release_running_agent_state("k")
|
||||
|
||||
assert "k" not in runner._running_agents
|
||||
assert "k" not in runner._running_agents_ts
|
||||
assert "k" not in runner._busy_ack_ts
|
||||
|
||||
def test_idempotent_on_missing_key(self):
|
||||
"""Calling twice (or on an absent key) must not raise."""
|
||||
runner = _make_runner()
|
||||
runner._release_running_agent_state("missing")
|
||||
runner._release_running_agent_state("missing") # still fine
|
||||
|
||||
def test_noop_on_empty_session_key(self):
|
||||
"""Empty string / None key is treated as a no-op."""
|
||||
runner = _make_runner()
|
||||
runner._running_agents[""] = "guard"
|
||||
runner._release_running_agent_state("")
|
||||
# Empty key not processed — guard value survives.
|
||||
assert runner._running_agents[""] == "guard"
|
||||
|
||||
def test_preserves_other_sessions(self):
|
||||
runner = _make_runner()
|
||||
for k in ("a", "b", "c"):
|
||||
runner._running_agents[k] = MagicMock()
|
||||
runner._running_agents_ts[k] = 1.0
|
||||
runner._busy_ack_ts[k] = 1.0
|
||||
|
||||
runner._release_running_agent_state("b")
|
||||
|
||||
assert set(runner._running_agents.keys()) == {"a", "c"}
|
||||
assert set(runner._running_agents_ts.keys()) == {"a", "c"}
|
||||
assert set(runner._busy_ack_ts.keys()) == {"a", "c"}
|
||||
|
||||
def test_handles_missing_busy_ack_attribute(self):
|
||||
"""Backward-compatible with older runners lacking _busy_ack_ts."""
|
||||
runner = _make_runner()
|
||||
del runner._busy_ack_ts # simulate older version
|
||||
runner._running_agents["k"] = MagicMock()
|
||||
runner._running_agents_ts["k"] = 1.0
|
||||
|
||||
runner._release_running_agent_state("k") # should not raise
|
||||
|
||||
assert "k" not in runner._running_agents
|
||||
assert "k" not in runner._running_agents_ts
|
||||
|
||||
def test_concurrent_release_is_safe(self):
|
||||
"""Multiple threads releasing different keys concurrently."""
|
||||
runner = _make_runner()
|
||||
for i in range(50):
|
||||
k = f"s{i}"
|
||||
runner._running_agents[k] = MagicMock()
|
||||
runner._running_agents_ts[k] = float(i)
|
||||
runner._busy_ack_ts[k] = float(i)
|
||||
|
||||
def worker(keys):
|
||||
for k in keys:
|
||||
runner._release_running_agent_state(k)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=worker, args=([f"s{i}" for i in range(start, 50, 5)],))
|
||||
for start in range(5)
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
assert not t.is_alive()
|
||||
|
||||
assert runner._running_agents == {}
|
||||
assert runner._running_agents_ts == {}
|
||||
assert runner._busy_ack_ts == {}
|
||||
|
||||
|
||||
class TestNoMoreBareDeleteSites:
|
||||
"""Regression: all bare `del self._running_agents[key]` sites were
|
||||
converted to use the helper. If a future contributor reverts one,
|
||||
this test flags it. Docstrings / comments mentioning the old
|
||||
pattern are allowed.
|
||||
"""
|
||||
|
||||
def test_no_bare_del_of_running_agents_in_gateway_run(self):
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
gateway_run = (Path(__file__).parent.parent.parent / "gateway" / "run.py").read_text()
|
||||
# Match `del self._running_agents[...]` that is NOT inside a
|
||||
# triple-quoted docstring. We scan non-docstring lines only.
|
||||
lines = gateway_run.splitlines()
|
||||
|
||||
in_docstring = False
|
||||
docstring_delim = None
|
||||
offenders = []
|
||||
for idx, line in enumerate(lines, start=1):
|
||||
stripped = line.strip()
|
||||
if not in_docstring:
|
||||
if stripped.startswith('"""') or stripped.startswith("'''"):
|
||||
delim = stripped[:3]
|
||||
# single-line docstring?
|
||||
if stripped.count(delim) >= 2:
|
||||
continue
|
||||
in_docstring = True
|
||||
docstring_delim = delim
|
||||
continue
|
||||
if re.search(r"\bdel\s+self\._running_agents\[", line):
|
||||
offenders.append((idx, line.rstrip()))
|
||||
else:
|
||||
if docstring_delim and docstring_delim in stripped:
|
||||
in_docstring = False
|
||||
docstring_delim = None
|
||||
|
||||
assert offenders == [], (
|
||||
"Found bare `del self._running_agents[...]` sites in gateway/run.py. "
|
||||
"Use self._release_running_agent_state(session_key) instead so "
|
||||
"_running_agents_ts and _busy_ack_ts are popped in lockstep.\n"
|
||||
+ "\n".join(f" line {n}: {l}" for n, l in offenders)
|
||||
)
|
||||
|
||||
|
||||
class TestSessionDbCloseOnShutdown:
|
||||
"""_stop_impl should call .close() on both self._session_db and
|
||||
self.session_store._db to release SQLite WAL locks before the new
|
||||
gateway (during --replace restart) tries to open the same file.
|
||||
"""
|
||||
|
||||
def test_stop_impl_closes_both_session_dbs(self):
|
||||
"""Run the exact shutdown block that closes SessionDBs and verify
|
||||
.close() was called on both holders."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
|
||||
runner_db = MagicMock()
|
||||
store_db = MagicMock()
|
||||
|
||||
runner._db = runner_db
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store._db = store_db
|
||||
|
||||
# Replicate the exact production loop from _stop_impl.
|
||||
for _db_holder in (runner, getattr(runner, "session_store", None)):
|
||||
_db = getattr(_db_holder, "_db", None) if _db_holder else None
|
||||
if _db is None or not hasattr(_db, "close"):
|
||||
continue
|
||||
_db.close()
|
||||
|
||||
runner_db.close.assert_called_once()
|
||||
store_db.close.assert_called_once()
|
||||
|
||||
def test_shutdown_tolerates_missing_session_store(self):
|
||||
"""Gateway without a session_store attribute must not crash on shutdown."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._db = MagicMock()
|
||||
# Deliberately no session_store attribute.
|
||||
|
||||
for _db_holder in (runner, getattr(runner, "session_store", None)):
|
||||
_db = getattr(_db_holder, "_db", None) if _db_holder else None
|
||||
if _db is None or not hasattr(_db, "close"):
|
||||
continue
|
||||
_db.close()
|
||||
|
||||
runner._db.close.assert_called_once()
|
||||
|
||||
def test_shutdown_tolerates_close_raising(self):
|
||||
"""A close() that raises must not prevent subsequent cleanup."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
flaky_db = MagicMock()
|
||||
flaky_db.close.side_effect = RuntimeError("simulated lock error")
|
||||
healthy_db = MagicMock()
|
||||
|
||||
runner._db = flaky_db
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store._db = healthy_db
|
||||
|
||||
# Same pattern as production: try/except around each close().
|
||||
for _db_holder in (runner, getattr(runner, "session_store", None)):
|
||||
_db = getattr(_db_holder, "_db", None) if _db_holder else None
|
||||
if _db is None or not hasattr(_db, "close"):
|
||||
continue
|
||||
try:
|
||||
_db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
flaky_db.close.assert_called_once()
|
||||
healthy_db.close.assert_called_once()
|
||||
270
tests/gateway/test_session_store_prune.py
Normal file
270
tests/gateway/test_session_store_prune.py
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
"""Tests for SessionStore.prune_old_entries and the gateway watcher that calls it.
|
||||
|
||||
The SessionStore in-memory dict (and its backing sessions.json) grew
|
||||
unbounded — every unique (platform, chat_id, thread_id, user_id) tuple
|
||||
ever seen was kept forever, regardless of how stale it became. These
|
||||
tests pin the prune behaviour:
|
||||
|
||||
* Entries older than max_age_days (by updated_at) are removed
|
||||
* Entries marked ``suspended`` are preserved (user-paused)
|
||||
* Entries with an active process attached are preserved
|
||||
* max_age_days <= 0 disables pruning entirely
|
||||
* sessions.json is rewritten with the post-prune dict
|
||||
* The ``updated_at`` field — not ``created_at`` — drives the decision
|
||||
(so a long-running-but-still-active session isn't pruned)
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, SessionResetPolicy
|
||||
from gateway.session import SessionEntry, SessionStore
|
||||
|
||||
|
||||
def _make_store(tmp_path, max_age_days: int = 90, has_active_processes_fn=None):
|
||||
"""Build a SessionStore bypassing SQLite/disk-load side effects."""
|
||||
config = GatewayConfig(
|
||||
default_reset_policy=SessionResetPolicy(mode="none"),
|
||||
session_store_max_age_days=max_age_days,
|
||||
)
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(
|
||||
sessions_dir=tmp_path,
|
||||
config=config,
|
||||
has_active_processes_fn=has_active_processes_fn,
|
||||
)
|
||||
store._db = None
|
||||
store._loaded = True
|
||||
return store
|
||||
|
||||
|
||||
def _entry(key: str, age_days: float, *, suspended: bool = False,
|
||||
session_id: str | None = None) -> SessionEntry:
|
||||
now = datetime.now()
|
||||
return SessionEntry(
|
||||
session_key=key,
|
||||
session_id=session_id or f"sid_{key}",
|
||||
created_at=now - timedelta(days=age_days + 30), # arbitrary older
|
||||
updated_at=now - timedelta(days=age_days),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
suspended=suspended,
|
||||
)
|
||||
|
||||
|
||||
class TestPruneBasics:
|
||||
def test_prune_removes_entries_past_max_age(self, tmp_path):
|
||||
store = _make_store(tmp_path)
|
||||
store._entries["old"] = _entry("old", age_days=100)
|
||||
store._entries["fresh"] = _entry("fresh", age_days=5)
|
||||
|
||||
removed = store.prune_old_entries(max_age_days=90)
|
||||
|
||||
assert removed == 1
|
||||
assert "old" not in store._entries
|
||||
assert "fresh" in store._entries
|
||||
|
||||
def test_prune_uses_updated_at_not_created_at(self, tmp_path):
|
||||
"""A session created long ago but updated recently must be kept."""
|
||||
store = _make_store(tmp_path)
|
||||
now = datetime.now()
|
||||
entry = SessionEntry(
|
||||
session_key="long-lived",
|
||||
session_id="sid",
|
||||
created_at=now - timedelta(days=365), # ancient
|
||||
updated_at=now - timedelta(days=3), # but just chatted
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
store._entries["long-lived"] = entry
|
||||
|
||||
removed = store.prune_old_entries(max_age_days=30)
|
||||
|
||||
assert removed == 0
|
||||
assert "long-lived" in store._entries
|
||||
|
||||
def test_prune_disabled_when_max_age_is_zero(self, tmp_path):
|
||||
store = _make_store(tmp_path, max_age_days=0)
|
||||
for i in range(5):
|
||||
store._entries[f"s{i}"] = _entry(f"s{i}", age_days=365)
|
||||
|
||||
assert store.prune_old_entries(0) == 0
|
||||
assert len(store._entries) == 5
|
||||
|
||||
def test_prune_disabled_when_max_age_is_negative(self, tmp_path):
|
||||
store = _make_store(tmp_path)
|
||||
store._entries["s"] = _entry("s", age_days=365)
|
||||
|
||||
assert store.prune_old_entries(-1) == 0
|
||||
assert "s" in store._entries
|
||||
|
||||
def test_prune_skips_suspended_entries(self, tmp_path):
|
||||
"""/stop-suspended sessions must be kept for later resume."""
|
||||
store = _make_store(tmp_path)
|
||||
store._entries["suspended"] = _entry(
|
||||
"suspended", age_days=1000, suspended=True
|
||||
)
|
||||
store._entries["idle"] = _entry("idle", age_days=1000)
|
||||
|
||||
removed = store.prune_old_entries(max_age_days=90)
|
||||
|
||||
assert removed == 1
|
||||
assert "suspended" in store._entries
|
||||
assert "idle" not in store._entries
|
||||
|
||||
def test_prune_skips_entries_with_active_processes(self, tmp_path):
|
||||
"""Sessions with active bg processes aren't pruned even if old."""
|
||||
active_session_ids = {"sid_active"}
|
||||
|
||||
def _has_active(session_id: str) -> bool:
|
||||
return session_id in active_session_ids
|
||||
|
||||
store = _make_store(tmp_path, has_active_processes_fn=_has_active)
|
||||
store._entries["active"] = _entry(
|
||||
"active", age_days=1000, session_id="sid_active"
|
||||
)
|
||||
store._entries["idle"] = _entry(
|
||||
"idle", age_days=1000, session_id="sid_idle"
|
||||
)
|
||||
|
||||
removed = store.prune_old_entries(max_age_days=90)
|
||||
|
||||
assert removed == 1
|
||||
assert "active" in store._entries
|
||||
assert "idle" not in store._entries
|
||||
|
||||
def test_prune_does_not_write_disk_when_no_removals(self, tmp_path):
|
||||
"""If nothing is evictable, _save() should NOT be called."""
|
||||
store = _make_store(tmp_path)
|
||||
store._entries["fresh1"] = _entry("fresh1", age_days=1)
|
||||
store._entries["fresh2"] = _entry("fresh2", age_days=2)
|
||||
|
||||
save_calls = []
|
||||
store._save = lambda: save_calls.append(1)
|
||||
|
||||
assert store.prune_old_entries(max_age_days=90) == 0
|
||||
assert save_calls == []
|
||||
|
||||
def test_prune_writes_disk_after_removal(self, tmp_path):
|
||||
store = _make_store(tmp_path)
|
||||
store._entries["stale"] = _entry("stale", age_days=500)
|
||||
store._entries["fresh"] = _entry("fresh", age_days=1)
|
||||
|
||||
save_calls = []
|
||||
store._save = lambda: save_calls.append(1)
|
||||
|
||||
store.prune_old_entries(max_age_days=90)
|
||||
assert save_calls == [1]
|
||||
|
||||
def test_prune_is_thread_safe(self, tmp_path):
|
||||
"""Prune acquires _lock internally; concurrent update_session is safe."""
|
||||
store = _make_store(tmp_path)
|
||||
for i in range(20):
|
||||
age = 1000 if i % 2 == 0 else 1
|
||||
store._entries[f"s{i}"] = _entry(f"s{i}", age_days=age)
|
||||
|
||||
results = []
|
||||
|
||||
def _pruner():
|
||||
results.append(store.prune_old_entries(max_age_days=90))
|
||||
|
||||
def _reader():
|
||||
# Mimic a concurrent update_session reader iterating under lock.
|
||||
with store._lock:
|
||||
list(store._entries.keys())
|
||||
|
||||
threads = [threading.Thread(target=_pruner)]
|
||||
threads += [threading.Thread(target=_reader) for _ in range(4)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
assert not t.is_alive()
|
||||
|
||||
# Exactly one pruner ran; removed exactly the 10 stale entries.
|
||||
assert results == [10]
|
||||
assert len(store._entries) == 10
|
||||
for i in range(20):
|
||||
if i % 2 == 1: # fresh
|
||||
assert f"s{i}" in store._entries
|
||||
|
||||
|
||||
class TestPrunePersistsToDisk:
|
||||
def test_prune_rewrites_sessions_json(self, tmp_path):
|
||||
"""After prune, sessions.json on disk reflects the new dict."""
|
||||
config = GatewayConfig(
|
||||
default_reset_policy=SessionResetPolicy(mode="none"),
|
||||
session_store_max_age_days=90,
|
||||
)
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = None
|
||||
# Force-populate without calling get_or_create to avoid DB side-effects
|
||||
store._entries["stale"] = _entry("stale", age_days=500)
|
||||
store._entries["fresh"] = _entry("fresh", age_days=1)
|
||||
store._loaded = True
|
||||
store._save()
|
||||
|
||||
# Verify pre-prune state on disk.
|
||||
saved_pre = json.loads((tmp_path / "sessions.json").read_text())
|
||||
assert set(saved_pre.keys()) == {"stale", "fresh"}
|
||||
|
||||
# Prune and check disk.
|
||||
store.prune_old_entries(max_age_days=90)
|
||||
saved_post = json.loads((tmp_path / "sessions.json").read_text())
|
||||
assert set(saved_post.keys()) == {"fresh"}
|
||||
|
||||
|
||||
class TestGatewayConfigSerialization:
|
||||
def test_session_store_max_age_days_defaults_to_90(self):
|
||||
cfg = GatewayConfig()
|
||||
assert cfg.session_store_max_age_days == 90
|
||||
|
||||
def test_session_store_max_age_days_roundtrips(self):
|
||||
cfg = GatewayConfig(session_store_max_age_days=30)
|
||||
restored = GatewayConfig.from_dict(cfg.to_dict())
|
||||
assert restored.session_store_max_age_days == 30
|
||||
|
||||
def test_session_store_max_age_days_missing_defaults_90(self):
|
||||
"""Loading an old config (pre-this-field) falls back to default."""
|
||||
restored = GatewayConfig.from_dict({})
|
||||
assert restored.session_store_max_age_days == 90
|
||||
|
||||
def test_session_store_max_age_days_negative_coerced_to_zero(self):
|
||||
"""A negative value (accidental or hostile) becomes 0 (disabled)."""
|
||||
restored = GatewayConfig.from_dict({"session_store_max_age_days": -5})
|
||||
assert restored.session_store_max_age_days == 0
|
||||
|
||||
def test_session_store_max_age_days_bad_type_falls_back(self):
|
||||
"""Non-int values fall back to the default, not a crash."""
|
||||
restored = GatewayConfig.from_dict({"session_store_max_age_days": "nope"})
|
||||
assert restored.session_store_max_age_days == 90
|
||||
|
||||
|
||||
class TestGatewayWatcherCallsPrune:
|
||||
"""The session_expiry_watcher should call prune_old_entries once per hour."""
|
||||
|
||||
def test_prune_gate_fires_on_first_tick(self):
|
||||
"""First watcher tick has _last_prune_ts=0, so the gate opens."""
|
||||
import time as _t
|
||||
|
||||
last_ts = 0.0
|
||||
prune_interval = 3600.0
|
||||
now = _t.time()
|
||||
|
||||
# Mirror the production gate check in _session_expiry_watcher.
|
||||
should_prune = (now - last_ts) > prune_interval
|
||||
assert should_prune is True
|
||||
|
||||
def test_prune_gate_suppresses_within_interval(self):
|
||||
import time as _t
|
||||
|
||||
last_ts = _t.time() - 600 # 10 minutes ago
|
||||
prune_interval = 3600.0
|
||||
now = _t.time()
|
||||
|
||||
should_prune = (now - last_ts) > prune_interval
|
||||
assert should_prune is False
|
||||
|
|
@ -63,6 +63,24 @@ class TestGatewayPidState:
|
|||
|
||||
assert status.get_running_pid() == os.getpid()
|
||||
|
||||
def test_get_running_pid_accepts_explicit_pid_path_without_cleanup(self, tmp_path, monkeypatch):
|
||||
other_home = tmp_path / "profile-home"
|
||||
other_home.mkdir()
|
||||
pid_path = other_home / "gateway.pid"
|
||||
pid_path.write_text(json.dumps({
|
||||
"pid": os.getpid(),
|
||||
"kind": "hermes-gateway",
|
||||
"argv": ["python", "-m", "hermes_cli.main", "gateway"],
|
||||
"start_time": 123,
|
||||
}))
|
||||
|
||||
monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
|
||||
monkeypatch.setattr(status, "_read_process_cmdline", lambda pid: None)
|
||||
|
||||
assert status.get_running_pid(pid_path, cleanup_stale=False) == os.getpid()
|
||||
assert pid_path.exists()
|
||||
|
||||
|
||||
class TestGatewayRuntimeStatus:
|
||||
def test_write_runtime_status_overwrites_stale_pid_on_restart(self, tmp_path, monkeypatch):
|
||||
|
|
@ -246,3 +264,181 @@ class TestScopedLocks:
|
|||
|
||||
status.release_scoped_lock("telegram-bot-token", "secret")
|
||||
assert not lock_path.exists()
|
||||
|
||||
|
||||
class TestTakeoverMarker:
|
||||
"""Tests for the --replace takeover marker.
|
||||
|
||||
The marker breaks the post-#5646 flap loop between two gateway services
|
||||
fighting for the same bot token. The replacer writes a file naming the
|
||||
target PID + start_time; the target's shutdown handler sees it and exits
|
||||
0 instead of 1, so systemd's Restart=on-failure doesn't revive it.
|
||||
"""
|
||||
|
||||
def test_write_marker_records_target_identity(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 42)
|
||||
|
||||
ok = status.write_takeover_marker(target_pid=12345)
|
||||
|
||||
assert ok is True
|
||||
marker = tmp_path / ".gateway-takeover.json"
|
||||
assert marker.exists()
|
||||
payload = json.loads(marker.read_text())
|
||||
assert payload["target_pid"] == 12345
|
||||
assert payload["target_start_time"] == 42
|
||||
assert payload["replacer_pid"] == os.getpid()
|
||||
assert "written_at" in payload
|
||||
|
||||
def test_consume_returns_true_when_marker_names_self(self, tmp_path, monkeypatch):
|
||||
"""Primary happy path: planned takeover is recognised."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
# Mark THIS process as the target
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
|
||||
ok = status.write_takeover_marker(target_pid=os.getpid())
|
||||
assert ok is True
|
||||
|
||||
# Call consume as if this process just got SIGTERMed
|
||||
result = status.consume_takeover_marker_for_self()
|
||||
|
||||
assert result is True
|
||||
# Marker must be unlinked after consumption
|
||||
assert not (tmp_path / ".gateway-takeover.json").exists()
|
||||
|
||||
def test_consume_returns_false_for_different_pid(self, tmp_path, monkeypatch):
|
||||
"""A marker naming a DIFFERENT process must not be consumed as ours."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
|
||||
# Marker names a different PID
|
||||
other_pid = os.getpid() + 9999
|
||||
ok = status.write_takeover_marker(target_pid=other_pid)
|
||||
assert ok is True
|
||||
|
||||
result = status.consume_takeover_marker_for_self()
|
||||
|
||||
assert result is False
|
||||
# Marker IS unlinked even on non-match (the record has been consumed
|
||||
# and isn't relevant to us — leaving it around would grief a later
|
||||
# legitimate check).
|
||||
assert not (tmp_path / ".gateway-takeover.json").exists()
|
||||
|
||||
def test_consume_returns_false_on_start_time_mismatch(self, tmp_path, monkeypatch):
|
||||
"""PID reuse defence: old marker's start_time mismatches current process."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
# Marker says target started at time 100 with our PID
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
|
||||
status.write_takeover_marker(target_pid=os.getpid())
|
||||
|
||||
# Now change the reported start_time to simulate PID reuse
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 9999)
|
||||
|
||||
result = status.consume_takeover_marker_for_self()
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_consume_returns_false_when_marker_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
result = status.consume_takeover_marker_for_self()
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_consume_returns_false_for_stale_marker(self, tmp_path, monkeypatch):
|
||||
"""A marker older than 60s must be ignored."""
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
marker_path = tmp_path / ".gateway-takeover.json"
|
||||
# Hand-craft a marker written 2 minutes ago
|
||||
stale_time = (datetime.now(timezone.utc) - timedelta(minutes=2)).isoformat()
|
||||
marker_path.write_text(json.dumps({
|
||||
"target_pid": os.getpid(),
|
||||
"target_start_time": 123,
|
||||
"replacer_pid": 99999,
|
||||
"written_at": stale_time,
|
||||
}))
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
|
||||
|
||||
result = status.consume_takeover_marker_for_self()
|
||||
|
||||
assert result is False
|
||||
# Stale markers are unlinked so a later legit shutdown isn't griefed
|
||||
assert not marker_path.exists()
|
||||
|
||||
def test_consume_handles_malformed_marker_gracefully(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
marker_path = tmp_path / ".gateway-takeover.json"
|
||||
marker_path.write_text("not valid json{")
|
||||
|
||||
# Must not raise
|
||||
result = status.consume_takeover_marker_for_self()
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_consume_handles_marker_with_missing_fields(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
marker_path = tmp_path / ".gateway-takeover.json"
|
||||
marker_path.write_text(json.dumps({"only_replacer_pid": 99999}))
|
||||
|
||||
result = status.consume_takeover_marker_for_self()
|
||||
|
||||
assert result is False
|
||||
# Malformed marker should be cleaned up
|
||||
assert not marker_path.exists()
|
||||
|
||||
def test_clear_takeover_marker_is_idempotent(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
# Nothing to clear — must not raise
|
||||
status.clear_takeover_marker()
|
||||
|
||||
# Write then clear
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 100)
|
||||
status.write_takeover_marker(target_pid=12345)
|
||||
assert (tmp_path / ".gateway-takeover.json").exists()
|
||||
|
||||
status.clear_takeover_marker()
|
||||
assert not (tmp_path / ".gateway-takeover.json").exists()
|
||||
|
||||
# Clear again — still no error
|
||||
status.clear_takeover_marker()
|
||||
|
||||
def test_write_marker_returns_false_on_write_failure(self, tmp_path, monkeypatch):
|
||||
"""write_takeover_marker is best-effort; returns False but doesn't raise."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
def raise_oserror(*args, **kwargs):
|
||||
raise OSError("simulated write failure")
|
||||
|
||||
monkeypatch.setattr(status, "_write_json_file", raise_oserror)
|
||||
|
||||
ok = status.write_takeover_marker(target_pid=12345)
|
||||
|
||||
assert ok is False
|
||||
|
||||
def test_consume_ignores_marker_for_different_process_and_prevents_stale_grief(
|
||||
self, tmp_path, monkeypatch
|
||||
):
|
||||
"""Regression: a stale marker from a dead replacer naming a dead
|
||||
target must not accidentally cause an unrelated future gateway to
|
||||
exit 0 on legitimate SIGTERM.
|
||||
|
||||
The distinguishing check is ``target_pid == our_pid AND
|
||||
target_start_time == our_start_time``. Different PID always wins.
|
||||
"""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
marker_path = tmp_path / ".gateway-takeover.json"
|
||||
# Fresh marker (timestamp is recent) but names a totally different PID
|
||||
from datetime import datetime, timezone
|
||||
marker_path.write_text(json.dumps({
|
||||
"target_pid": os.getpid() + 10000,
|
||||
"target_start_time": 42,
|
||||
"replacer_pid": 99999,
|
||||
"written_at": datetime.now(timezone.utc).isoformat(),
|
||||
}))
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 42)
|
||||
|
||||
result = status.consume_takeover_marker_for_self()
|
||||
|
||||
# We are not the target — must NOT consume as planned
|
||||
assert result is False
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tests for gateway /status behavior and token persistence."""
|
||||
|
||||
from datetime import datetime
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
|
@ -111,6 +112,75 @@ async def test_status_command_includes_session_title_when_present():
|
|||
assert "**Title:** My titled session" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agents_command_reports_active_agents_and_processes(monkeypatch):
|
||||
session_key = build_session_key(_make_source())
|
||||
session_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
total_tokens=0,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
running_agent = SimpleNamespace(
|
||||
session_id="sess-running",
|
||||
model="openrouter/test-model",
|
||||
interrupt=MagicMock(),
|
||||
get_activity_summary=lambda: {"seconds_since_activity": 0},
|
||||
)
|
||||
runner._running_agents[session_key] = running_agent
|
||||
runner._running_agents_ts = {session_key: time.time() - 8}
|
||||
runner._background_tasks = set()
|
||||
|
||||
class _FakeRegistry:
|
||||
def list_sessions(self):
|
||||
return [
|
||||
{
|
||||
"session_id": "proc-1",
|
||||
"status": "running",
|
||||
"uptime_seconds": 17,
|
||||
"command": "sleep 30",
|
||||
}
|
||||
]
|
||||
|
||||
monkeypatch.setattr("tools.process_registry.process_registry", _FakeRegistry())
|
||||
|
||||
result = await runner._handle_message(_make_event("/agents"))
|
||||
|
||||
assert "**Active agents:** 1" in result
|
||||
assert "**Running background processes:** 1" in result
|
||||
assert "proc-1" in result
|
||||
running_agent.interrupt.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tasks_alias_routes_to_agents_command(monkeypatch):
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
total_tokens=0,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._background_tasks = set()
|
||||
|
||||
class _FakeRegistry:
|
||||
def list_sessions(self):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr("tools.process_registry.process_registry", _FakeRegistry())
|
||||
|
||||
result = await runner._handle_message(_make_event("/tasks"))
|
||||
|
||||
assert "Active Agents & Tasks" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
|
|
|||
|
|
@ -88,6 +88,51 @@ class TestCleanForDisplay:
|
|||
# ── Integration: _send_or_edit strips MEDIA: ─────────────────────────────
|
||||
|
||||
|
||||
class TestFinalizeCapabilityGate:
|
||||
"""Verify REQUIRES_EDIT_FINALIZE gates the redundant final edit.
|
||||
|
||||
Platforms that don't need an explicit finalize signal (Telegram,
|
||||
Slack, Matrix, …) should skip the redundant final edit when the
|
||||
mid-stream edit already delivered the final content. Platforms that
|
||||
*do* need it (DingTalk AI Cards) must always receive a finalize=True
|
||||
edit at the end of the stream.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identical_text_skip_respects_adapter_flag(self):
|
||||
"""_send_or_edit short-circuits identical-text only when the
|
||||
adapter doesn't require an explicit finalize signal."""
|
||||
# Adapter without finalize requirement — should skip identical edit.
|
||||
plain = MagicMock()
|
||||
plain.REQUIRES_EDIT_FINALIZE = False
|
||||
plain.send = AsyncMock(return_value=SimpleNamespace(
|
||||
success=True, message_id="m1",
|
||||
))
|
||||
plain.edit_message = AsyncMock()
|
||||
plain.MAX_MESSAGE_LENGTH = 4096
|
||||
c1 = GatewayStreamConsumer(plain, "chat_1")
|
||||
await c1._send_or_edit("hello") # first send
|
||||
await c1._send_or_edit("hello", finalize=True) # identical → skip
|
||||
plain.edit_message.assert_not_called()
|
||||
|
||||
# Adapter that requires finalize — must still fire the edit.
|
||||
picky = MagicMock()
|
||||
picky.REQUIRES_EDIT_FINALIZE = True
|
||||
picky.send = AsyncMock(return_value=SimpleNamespace(
|
||||
success=True, message_id="m1",
|
||||
))
|
||||
picky.edit_message = AsyncMock(return_value=SimpleNamespace(
|
||||
success=True, message_id="m1",
|
||||
))
|
||||
picky.MAX_MESSAGE_LENGTH = 4096
|
||||
c2 = GatewayStreamConsumer(picky, "chat_1")
|
||||
await c2._send_or_edit("hello")
|
||||
await c2._send_or_edit("hello", finalize=True)
|
||||
# Finalize edit must go through even on identical content.
|
||||
picky.edit_message.assert_called_once()
|
||||
assert picky.edit_message.call_args[1]["finalize"] is True
|
||||
|
||||
|
||||
class TestSendOrEditMediaStripping:
|
||||
"""Verify _send_or_edit strips MEDIA: before sending to the platform."""
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,12 @@ def _ensure_telegram_mock():
|
|||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2 # noqa: E402
|
||||
from gateway.platforms.telegram import ( # noqa: E402
|
||||
TelegramAdapter,
|
||||
_escape_mdv2,
|
||||
_strip_mdv2,
|
||||
_wrap_markdown_tables,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -535,6 +540,152 @@ class TestStripMdv2:
|
|||
assert _strip_mdv2("||hidden text||") == "hidden text"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Markdown table auto-wrap
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestWrapMarkdownTables:
|
||||
"""_wrap_markdown_tables wraps GFM pipe tables in ``` fences so
|
||||
Telegram renders them as monospace preformatted text instead of the
|
||||
noisy backslash-pipe mess MarkdownV2 produces."""
|
||||
|
||||
def test_basic_table_wrapped(self):
|
||||
text = (
|
||||
"Scores:\n\n"
|
||||
"| Player | Score |\n"
|
||||
"|--------|-------|\n"
|
||||
"| Alice | 150 |\n"
|
||||
"| Bob | 120 |\n"
|
||||
"\nEnd."
|
||||
)
|
||||
out = _wrap_markdown_tables(text)
|
||||
# Table is now wrapped in a fence
|
||||
assert "```\n| Player | Score |" in out
|
||||
assert "| Bob | 120 |\n```" in out
|
||||
# Surrounding prose is preserved
|
||||
assert out.startswith("Scores:")
|
||||
assert out.endswith("End.")
|
||||
|
||||
def test_bare_pipe_table_wrapped(self):
|
||||
"""Tables without outer pipes (GFM allows this) are still detected."""
|
||||
text = "head1 | head2\n--- | ---\na | b\nc | d"
|
||||
out = _wrap_markdown_tables(text)
|
||||
assert out.startswith("```\n")
|
||||
assert out.rstrip().endswith("```")
|
||||
assert "head1 | head2" in out
|
||||
|
||||
def test_alignment_separators(self):
|
||||
"""Separator rows with :--- / ---: / :---: alignment markers match."""
|
||||
text = (
|
||||
"| Name | Age | City |\n"
|
||||
"|:-----|----:|:----:|\n"
|
||||
"| Ada | 30 | NYC |"
|
||||
)
|
||||
out = _wrap_markdown_tables(text)
|
||||
assert out.count("```") == 2
|
||||
|
||||
def test_two_consecutive_tables_wrapped_separately(self):
|
||||
text = (
|
||||
"| A | B |\n"
|
||||
"|---|---|\n"
|
||||
"| 1 | 2 |\n"
|
||||
"\n"
|
||||
"| X | Y |\n"
|
||||
"|---|---|\n"
|
||||
"| 9 | 8 |"
|
||||
)
|
||||
out = _wrap_markdown_tables(text)
|
||||
# Four fences total — one opening + closing per table
|
||||
assert out.count("```") == 4
|
||||
|
||||
def test_plain_text_with_pipes_not_wrapped(self):
|
||||
"""A bare pipe in prose must NOT trigger wrapping."""
|
||||
text = "Use the | pipe operator to chain commands."
|
||||
assert _wrap_markdown_tables(text) == text
|
||||
|
||||
def test_horizontal_rule_not_wrapped(self):
|
||||
"""A lone '---' horizontal rule must not be mistaken for a separator."""
|
||||
text = "Section A\n\n---\n\nSection B"
|
||||
assert _wrap_markdown_tables(text) == text
|
||||
|
||||
def test_existing_code_block_with_pipes_left_alone(self):
|
||||
"""A table already inside a fenced code block must not be re-wrapped."""
|
||||
text = (
|
||||
"```\n"
|
||||
"| a | b |\n"
|
||||
"|---|---|\n"
|
||||
"| 1 | 2 |\n"
|
||||
"```"
|
||||
)
|
||||
assert _wrap_markdown_tables(text) == text
|
||||
|
||||
def test_no_pipe_character_short_circuits(self):
|
||||
text = "Plain **bold** text with no table."
|
||||
assert _wrap_markdown_tables(text) == text
|
||||
|
||||
def test_no_dash_short_circuits(self):
|
||||
text = "a | b\nc | d" # has pipes but no '-' separator row
|
||||
assert _wrap_markdown_tables(text) == text
|
||||
|
||||
def test_single_column_separator_not_matched(self):
|
||||
"""Single-column tables (rare) are not detected — we require at
|
||||
least one internal pipe in the separator row to avoid false
|
||||
positives on formatting rules."""
|
||||
text = "| a |\n| - |\n| b |"
|
||||
assert _wrap_markdown_tables(text) == text
|
||||
|
||||
|
||||
class TestFormatMessageTables:
|
||||
"""End-to-end: a pipe table passes through format_message with its
|
||||
pipes and dashes left alone inside the fence, not mangled by MarkdownV2
|
||||
escaping."""
|
||||
|
||||
def test_table_rendered_as_code_block(self, adapter):
|
||||
text = (
|
||||
"Data:\n\n"
|
||||
"| Col1 | Col2 |\n"
|
||||
"|------|------|\n"
|
||||
"| A | B |\n"
|
||||
)
|
||||
out = adapter.format_message(text)
|
||||
# Pipes inside the fenced block are NOT escaped
|
||||
assert "```\n| Col1 | Col2 |" in out
|
||||
assert "\\|" not in out.split("```")[1]
|
||||
# Dashes in separator not escaped inside fence
|
||||
assert "\\-" not in out.split("```")[1]
|
||||
|
||||
def test_text_after_table_still_formatted(self, adapter):
|
||||
text = (
|
||||
"| A | B |\n"
|
||||
"|---|---|\n"
|
||||
"| 1 | 2 |\n"
|
||||
"\n"
|
||||
"Nice **work** team!"
|
||||
)
|
||||
out = adapter.format_message(text)
|
||||
# MarkdownV2 bold conversion still happens outside the table
|
||||
assert "*work*" in out
|
||||
# Exclamation outside fence is escaped
|
||||
assert "\\!" in out
|
||||
|
||||
def test_multiple_tables_in_single_message(self, adapter):
|
||||
text = (
|
||||
"First:\n"
|
||||
"| A | B |\n"
|
||||
"|---|---|\n"
|
||||
"| 1 | 2 |\n"
|
||||
"\n"
|
||||
"Second:\n"
|
||||
"| X | Y |\n"
|
||||
"|---|---|\n"
|
||||
"| 9 | 8 |\n"
|
||||
)
|
||||
out = adapter.format_message(text)
|
||||
# Two separate fenced blocks in the output
|
||||
assert out.count("```") == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_escapes_chunk_indicator_for_markdownv2(adapter):
|
||||
adapter.MAX_MESSAGE_LENGTH = 80
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue