diff --git a/AGENTS.md b/AGENTS.md index 8045c3d213..8f227968e3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -351,8 +351,9 @@ Cache-breaking forces dramatically higher costs. The ONLY time we alter context ### Background Process Notifications (Gateway) -When `terminal(background=true, check_interval=...)` is used, the gateway runs a watcher that -pushes status updates to the user's chat. Control verbosity with `display.background_process_notifications` +When `terminal(background=true, notify_on_complete=true)` is used, the gateway runs a watcher that +detects process completion and triggers a new agent turn. Control verbosity of background process +messages with `display.background_process_notifications` in config.yaml (or `HERMES_BACKGROUND_NOTIFICATIONS` env var): - `all` — running-output updates + final message (default) diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 069a5b65e1..5aa95dc01b 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional from agent.auxiliary_client import call_llm from agent.context_engine import ContextEngine from agent.model_metadata import ( + MINIMUM_CONTEXT_LENGTH, get_model_context_length, estimate_messages_tokens_rough, ) @@ -87,7 +88,10 @@ class ContextCompressor(ContextEngine): self.api_key = api_key self.provider = provider self.context_length = context_length - self.threshold_tokens = int(context_length * self.threshold_percent) + self.threshold_tokens = max( + int(context_length * self.threshold_percent), + MINIMUM_CONTEXT_LENGTH, + ) def __init__( self, @@ -118,7 +122,14 @@ class ContextCompressor(ContextEngine): config_context_length=config_context_length, provider=provider, ) - self.threshold_tokens = int(self.context_length * threshold_percent) + # Floor: never compress below MINIMUM_CONTEXT_LENGTH tokens even if + # the percentage would suggest a lower value. 
This prevents premature + # compression on large-context models at 50% while keeping the % sane + # for models right at the minimum. + self.threshold_tokens = max( + int(self.context_length * threshold_percent), + MINIMUM_CONTEXT_LENGTH, + ) self.compression_count = 0 # Derive token budgets: ratio is relative to the threshold, not total context diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 2ef6830e58..03f70b3fe4 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -85,6 +85,11 @@ CONTEXT_PROBE_TIERS = [ # Default context length when no detection method succeeds. DEFAULT_FALLBACK_CONTEXT = CONTEXT_PROBE_TIERS[0] +# Minimum context length required to run Hermes Agent. Models with fewer +# tokens cannot maintain enough working memory for tool-calling workflows. +# Sessions, model switches, and cron jobs should reject models below this. +MINIMUM_CONTEXT_LENGTH = 64_000 + # Thin fallback defaults — only broad model family patterns. # These fire only when provider is unknown AND models.dev/OpenRouter/Anthropic # all miss. Replaced the previous 80+ entry dict. @@ -179,6 +184,12 @@ _MAX_COMPLETION_KEYS = ( # Local server hostnames / address patterns _LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0") +# Docker / Podman / Lima DNS names that resolve to the host machine +_CONTAINER_LOCAL_SUFFIXES = ( + ".docker.internal", + ".containers.internal", + ".lima.internal", +) def _normalize_base_url(base_url: str) -> str: @@ -254,6 +265,9 @@ def is_local_endpoint(base_url: str) -> bool: return False if host in _LOCAL_HOSTS: return True + # Docker / Podman / Lima internal DNS names (e.g. 
host.docker.internal) + if any(host.endswith(suffix) for suffix in _CONTAINER_LOCAL_SUFFIXES): + return True # RFC-1918 private ranges and link-local import ipaddress try: @@ -1031,16 +1045,21 @@ def get_model_context_length( def estimate_tokens_rough(text: str) -> int: - """Rough token estimate (~4 chars/token) for pre-flight checks.""" + """Rough token estimate (~4 chars/token) for pre-flight checks. + + Uses ceiling division so short texts (1-3 chars) never estimate as + 0 tokens, which would cause the compressor and pre-flight checks to + systematically undercount when many short tool results are present. + """ if not text: return 0 - return len(text) // 4 + return (len(text) + 3) // 4 def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int: """Rough token estimate for a message list (pre-flight only).""" total_chars = sum(len(str(msg)) for msg in messages) - return total_chars // 4 + return (total_chars + 3) // 4 def estimate_request_tokens_rough( @@ -1063,4 +1082,4 @@ def estimate_request_tokens_rough( total_chars += sum(len(str(msg)) for msg in messages) if tools: total_chars += len(str(tools)) - return total_chars // 4 + return (total_chars + 3) // 4 diff --git a/cli-config.yaml.example b/cli-config.yaml.example index 12e2b39995..c9e6645bba 100644 --- a/cli-config.yaml.example +++ b/cli-config.yaml.example @@ -774,6 +774,11 @@ display: # Toggle at runtime with /verbose in the CLI tool_progress: all + # Gateway-only natural mid-turn assistant updates. + # When true, completed assistant status messages are sent as separate chat + # messages. This is independent of tool_progress and gateway streaming. + interim_assistant_messages: true + # What Enter does when Hermes is already busy in the CLI. # interrupt: Interrupt the current run and redirect Hermes (default) # queue: Queue your message for the next turn @@ -782,7 +787,7 @@ display: # Background process notifications (gateway/messaging only). 
# Controls how chatty the process watcher is when you use - # terminal(background=true, check_interval=...) from Telegram/Discord/etc. + # terminal(background=true, notify_on_complete=true) from Telegram/Discord/etc. # off: No watcher messages at all # result: Only the final completion message # error: Only the final message when exit code != 0 diff --git a/cli.py b/cli.py index ff80a49b8f..26a2233880 100644 --- a/cli.py +++ b/cli.py @@ -1355,6 +1355,19 @@ class ChatConsole: for line in output.rstrip("\n").split("\n"): _cprint(line) + @contextmanager + def status(self, *_args, **_kwargs): + """Provide a no-op Rich-compatible status context. + + Some slash command helpers use ``console.status(...)`` when running in + the standalone CLI. Interactive chat routes those helpers through + ``ChatConsole()``, which historically only implemented ``print()``. + Returning a silent context manager keeps slash commands compatible + without duplicating the higher-level busy indicator already shown by + ``HermesCLI._busy_command()``. 
+ """ + yield self + # ASCII Art - HERMES-AGENT logo (full width, single line - requires ~95 char terminal) HERMES_AGENT_LOGO = """[bold #FFD700]██╗ ██╗███████╗██████╗ ███╗ ███╗███████╗███████╗ █████╗ ██████╗ ███████╗███╗ ██╗████████╗[/] [bold #FFD700]██║ ██║██╔════╝██╔══██╗████╗ ████║██╔════╝██╔════╝ ██╔══██╗██╔════╝ ██╔════╝████╗ ██║╚══██╔══╝[/] @@ -1804,6 +1817,7 @@ class HermesCLI: self._approval_state = None self._approval_deadline = 0 self._approval_lock = threading.Lock() + self._model_picker_state = None self._secret_state = None self._secret_deadline = 0 self._spinner_text: str = "" # thinking spinner text for TUI @@ -2046,7 +2060,7 @@ class HermesCLI: return f"⚕ {self.model if getattr(self, 'model', None) else 'Hermes'}" def _get_status_bar_fragments(self): - if not self._status_bar_visible: + if not self._status_bar_visible or getattr(self, '_model_picker_state', None): return [] try: snapshot = self._get_status_bar_snapshot() @@ -4278,6 +4292,265 @@ class HermesCLI: remaining = len(self.conversation_history) print(f" {remaining} message(s) remaining in history.") + def _run_curses_picker(self, title: str, items: list[str], default_index: int = 0) -> int | None: + """Run curses_single_select via run_in_terminal so prompt_toolkit handles terminal ownership cleanly.""" + import threading + from hermes_cli.curses_ui import curses_single_select + + result = [None] + + def _pick(): + result[0] = curses_single_select(title, items, default_index=default_index) + + # run_in_terminal requires an asyncio event loop — only exists in the + # main prompt_toolkit thread. If we're in a background thread (e.g. + # process_loop), fall back to direct curses call. 
+ in_main_thread = threading.current_thread() is threading.main_thread() + + if self._app and in_main_thread: + from prompt_toolkit.application import run_in_terminal + was_visible = self._status_bar_visible + self._status_bar_visible = False + self._app.invalidate() + try: + run_in_terminal(_pick) + finally: + self._status_bar_visible = was_visible + self._app.invalidate() + else: + _pick() + + return result[0] + + def _prompt_text_input(self, prompt_text: str) -> str | None: + """Prompt for free-text input safely inside or outside prompt_toolkit.""" + result = [None] + + def _ask(): + try: + result[0] = input(prompt_text).strip() or None + except (KeyboardInterrupt, EOFError): + pass + + if self._app: + from prompt_toolkit.application import run_in_terminal + was_visible = self._status_bar_visible + self._status_bar_visible = False + self._app.invalidate() + try: + run_in_terminal(_ask) + finally: + self._status_bar_visible = was_visible + self._app.invalidate() + else: + _ask() + return result[0] + + def _interactive_provider_selection( + self, providers: list, current_model: str, current_provider: str + ) -> str | None: + """Show provider picker, return slug or None on cancel.""" + choices = [] + for p in providers: + count = p.get("total_models", len(p.get("models", []))) + label = f"{p['name']} ({count} model{'s' if count != 1 else ''})" + if p.get("is_current"): + label += " ← current" + choices.append(label) + + default_idx = next( + (i for i, p in enumerate(providers) if p.get("is_current")), 0 + ) + + idx = self._run_curses_picker( + f"Select a provider (current: {current_model} on {current_provider}):", + choices, + default_index=default_idx, + ) + if idx is None: + return None + return providers[idx]["slug"] + + def _interactive_model_selection( + self, model_list: list, provider_data: dict + ) -> str | None: + """Show model picker for a given provider, return model_id or None on cancel.""" + pname = provider_data.get("name", provider_data.get("slug", 
"")) + total = provider_data.get("total_models", len(model_list)) + + if not model_list: + _cprint(f"\n No models listed for {pname}.") + return self._prompt_text_input(" Enter model name manually (or Enter to cancel): ") + + choices = list(model_list) + ["Enter custom model name"] + idx = self._run_curses_picker( + f"Select model from {pname} ({len(model_list)} of {total}):", + choices, + ) + if idx is None: + return None + if idx < len(model_list): + return model_list[idx] + return self._prompt_text_input(" Enter model name: ") + + def _open_model_picker(self, providers: list, current_model: str, current_provider: str, user_provs=None, custom_provs=None) -> None: + """Open prompt_toolkit-native /model picker modal.""" + self._capture_modal_input_snapshot() + default_idx = next((i for i, p in enumerate(providers) if p.get("is_current")), 0) + self._model_picker_state = { + "stage": "provider", + "providers": providers, + "selected": default_idx, + "current_model": current_model, + "current_provider": current_provider, + "user_provs": user_provs, + "custom_provs": custom_provs, + } + self._invalidate(min_interval=0.0) + + def _close_model_picker(self) -> None: + self._model_picker_state = None + self._restore_modal_input_snapshot() + self._invalidate(min_interval=0.0) + + def _apply_model_switch_result(self, result, persist_global: bool) -> None: + if not result.success: + _cprint(f" ✗ {result.error_message}") + return + + old_model = self.model + self.model = result.new_model + self.provider = result.target_provider + self.requested_provider = result.target_provider + if result.api_key: + self.api_key = result.api_key + self._explicit_api_key = result.api_key + if result.base_url: + self.base_url = result.base_url + self._explicit_base_url = result.base_url + if result.api_mode: + self.api_mode = result.api_mode + + if self.agent is not None: + try: + self.agent.switch_model( + new_model=result.new_model, + new_provider=result.target_provider, + 
api_key=result.api_key, + base_url=result.base_url, + api_mode=result.api_mode, + ) + except Exception as exc: + _cprint(f" ⚠ Agent swap failed ({exc}); change applied to next session.") + + self._pending_model_switch_note = ( + f"[Note: model was just switched from {old_model} to {result.new_model} " + f"via {result.provider_label or result.target_provider}. " + f"Adjust your self-identification accordingly.]" + ) + + provider_label = result.provider_label or result.target_provider + _cprint(f" ✓ Model switched: {result.new_model}") + _cprint(f" Provider: {provider_label}") + + mi = result.model_info + if mi: + if mi.context_window: + _cprint(f" Context: {mi.context_window:,} tokens") + if mi.max_output: + _cprint(f" Max output: {mi.max_output:,} tokens") + if mi.has_cost_data(): + _cprint(f" Cost: {mi.format_cost()}") + _cprint(f" Capabilities: {mi.format_capabilities()}") + else: + try: + from agent.model_metadata import get_model_context_length + ctx = get_model_context_length( + result.new_model, + base_url=result.base_url or self.base_url, + api_key=result.api_key or self.api_key, + provider=result.target_provider, + ) + _cprint(f" Context: {ctx:,} tokens") + except Exception: + pass + + cache_enabled = ( + ("openrouter" in (result.base_url or "").lower() and "claude" in result.new_model.lower()) + or result.api_mode == "anthropic_messages" + ) + if cache_enabled: + _cprint(" Prompt caching: enabled") + if result.warning_message: + _cprint(f" ⚠ {result.warning_message}") + if persist_global: + save_config_value("model.default", result.new_model) + if result.provider_changed: + save_config_value("model.provider", result.target_provider) + _cprint(" Saved to config.yaml (--global)") + else: + _cprint(" (session only — add --global to persist)") + + def _handle_model_picker_selection(self, persist_global: bool = False) -> None: + state = self._model_picker_state + if not state: + return + selected = state.get("selected", 0) + stage = state.get("stage") + if 
stage == "provider": + providers = state.get("providers") or [] + if selected >= len(providers): + self._close_model_picker() + return + provider_data = providers[selected] + model_list = [] + try: + from hermes_cli.models import provider_model_ids + live = provider_model_ids(provider_data["slug"]) + if live: + model_list = live + except Exception: + pass + if not model_list: + model_list = provider_data.get("models", []) + state["stage"] = "model" + state["provider_data"] = provider_data + state["model_list"] = model_list + state["selected"] = 0 + self._invalidate(min_interval=0.0) + return + if stage == "model": + provider_data = state.get("provider_data") or {} + model_list = state.get("model_list") or [] + back_idx = len(model_list) + cancel_idx = len(model_list) + 1 + if selected == back_idx: + state["stage"] = "provider" + state["selected"] = next((i for i, p in enumerate(state.get("providers") or []) if p.get("slug") == provider_data.get("slug")), 0) + self._invalidate(min_interval=0.0) + return + if selected >= cancel_idx: + self._close_model_picker() + return + if selected < len(model_list): + from hermes_cli.model_switch import switch_model + chosen_model = model_list[selected] + result = switch_model( + raw_input=chosen_model, + current_provider=self.provider or "", + current_model=self.model or "", + current_base_url=self.base_url or "", + current_api_key=self.api_key or "", + is_global=persist_global, + explicit_provider=provider_data.get("slug"), + user_providers=state.get("user_provs"), + custom_providers=state.get("custom_provs"), + ) + self._close_model_picker() + self._apply_model_switch_result(result, persist_global) + return + self._close_model_picker() + def _handle_model_switch(self, cmd_original: str): """Handle /model command — switch model for this session. 
@@ -4300,56 +4573,46 @@ class HermesCLI: user_provs = None custom_provs = None - try: - from hermes_cli.config import load_config - cfg = load_config() - user_provs = cfg.get("providers") - custom_provs = cfg.get("custom_providers") - except Exception: - pass - # No args at all: show available providers + models + # No args at all: open prompt_toolkit-native picker modal if not model_input and not explicit_provider: model_display = self.model or "unknown" provider_display = get_label(self.provider) if self.provider else "unknown" - _cprint(f" Current: {model_display} on {provider_display}") - _cprint("") - # Show authenticated providers with top models + user_provs = None + custom_provs = None + try: + from hermes_cli.config import load_config + cfg = load_config() + user_provs = cfg.get("providers") + custom_provs = cfg.get("custom_providers") + except Exception: + pass + try: providers = list_authenticated_providers( current_provider=self.provider or "", user_providers=user_provs, custom_providers=custom_provs, - max_models=6, + max_models=50, ) - if providers: - for p in providers: - tag = " (current)" if p["is_current"] else "" - _cprint(f" {p['name']} [--provider {p['slug']}]{tag}:") - if p["models"]: - model_strs = ", ".join(p["models"]) - extra = f" (+{p['total_models'] - len(p['models'])} more)" if p["total_models"] > len(p["models"]) else "" - _cprint(f" {model_strs}{extra}") - elif p.get("api_url"): - _cprint(f" {p['api_url']} (use /model --provider {p['slug']})") - else: - _cprint(f" (no models listed)") - _cprint("") - else: - _cprint(" No authenticated providers found.") - _cprint("") except Exception: - pass + providers = [] - # Aliases - from hermes_cli.model_switch import MODEL_ALIASES - alias_list = ", ".join(sorted(MODEL_ALIASES.keys())) - _cprint(f" Aliases: {alias_list}") - _cprint("") - _cprint(" /model switch model") - _cprint(" /model --provider switch provider") - _cprint(" /model --global persist to config") + if not providers: + _cprint(" 
No authenticated providers found.") + _cprint("") + _cprint(" /model switch model") + _cprint(" /model --provider switch provider") + return + + self._open_model_picker( + providers, + model_display, + provider_display, + user_provs=user_provs, + custom_provs=custom_provs, + ) return # Perform the switch @@ -4457,6 +4720,18 @@ class HermesCLI: else: _cprint(" (session only — add --global to persist)") + def _should_handle_model_command_inline(self, text: str, has_images: bool = False) -> bool: + """Return True when /model should be handled immediately on the UI thread.""" + if not text or has_images or not _looks_like_slash_command(text): + return False + try: + from hermes_cli.commands import resolve_command + base = text.split(None, 1)[0].lower().lstrip('/') + cmd = resolve_command(base) + return bool(cmd and cmd.name == "model") + except Exception: + return False + def _show_model_and_providers(self): """Show current model + provider and list all authenticated providers. @@ -7679,7 +7954,8 @@ class HermesCLI: secret_widget, approval_widget, clarify_widget, - spinner_widget, + model_picker_widget=None, + spinner_widget=None, spacer, status_bar, input_rule_top, @@ -7696,21 +7972,24 @@ class HermesCLI: ordering. 
""" return [ - Window(height=0), - sudo_widget, - secret_widget, - approval_widget, - clarify_widget, - spinner_widget, - spacer, - *self._get_extra_tui_widgets(), - status_bar, - input_rule_top, - image_bar, - input_area, - input_rule_bot, - voice_status_bar, - completions_menu, + item for item in [ + Window(height=0), + sudo_widget, + secret_widget, + approval_widget, + clarify_widget, + model_picker_widget, + spinner_widget, + spacer, + *self._get_extra_tui_widgets(), + status_bar, + input_rule_top, + image_bar, + input_area, + input_rule_bot, + voice_status_bar, + completions_menu, + ] if item is not None ] def run(self): @@ -7871,6 +8150,12 @@ class HermesCLI: event.app.invalidate() return + # --- /model picker modal --- + if self._model_picker_state: + self._handle_model_picker_selection() + event.app.invalidate() + return + # --- Clarify freetext mode: user typed their own answer --- if self._clarify_freetext and self._clarify_state: text = event.app.current_buffer.text.strip() @@ -7901,6 +8186,16 @@ class HermesCLI: text = event.app.current_buffer.text.strip() has_images = bool(self._attached_images) if text or has_images: + # Handle /model directly on the UI thread so interactive pickers + # can safely use prompt_toolkit terminal handoff helpers. 
+ if self._should_handle_model_command_inline(text, has_images=has_images): + if not self.process_command(text): + self._should_exit = True + if event.app.is_running: + event.app.exit() + event.app.current_buffer.reset(append_to_history=True) + return + # Snapshot and clear attached images images = list(self._attached_images) self._attached_images.clear() @@ -8004,12 +8299,31 @@ class HermesCLI: self._approval_state["selected"] = min(max_idx, self._approval_state["selected"] + 1) event.app.invalidate() + # --- /model picker: arrow-key navigation --- + @kb.add('up', filter=Condition(lambda: bool(self._model_picker_state))) + def model_picker_up(event): + if self._model_picker_state: + self._model_picker_state["selected"] = max(0, self._model_picker_state.get("selected", 0) - 1) + event.app.invalidate() + + @kb.add('down', filter=Condition(lambda: bool(self._model_picker_state))) + def model_picker_down(event): + state = self._model_picker_state + if not state: + return + if state.get("stage") == "provider": + max_idx = len(state.get("providers") or []) + else: + max_idx = len(state.get("model_list") or []) + 1 + state["selected"] = min(max_idx, state.get("selected", 0) + 1) + event.app.invalidate() + # --- History navigation: up/down browse history in normal input mode --- # The TextArea is multiline, so by default up/down only move the cursor. # Buffer.auto_up/auto_down handle both: cursor movement when multi-line, # history browsing when on the first/last line (or single-line input). 
_normal_input = Condition( - lambda: not self._clarify_state and not self._approval_state and not self._sudo_state and not self._secret_state + lambda: not self._clarify_state and not self._approval_state and not self._sudo_state and not self._secret_state and not self._model_picker_state ) @kb.add('up', filter=_normal_input) @@ -8075,6 +8389,13 @@ class HermesCLI: event.app.invalidate() return + # Cancel /model picker + if self._model_picker_state: + self._close_model_picker() + event.app.current_buffer.reset() + event.app.invalidate() + return + # Cancel clarify prompt if self._clarify_state: self._clarify_state["response_queue"].put( @@ -8127,7 +8448,7 @@ class HermesCLI: agent_name = get_active_skin().get_branding("agent_name", "Hermes Agent") msg = f"\n{agent_name} has been suspended. Run `fg` to bring {agent_name} back." def _suspend(): - os.write(1, msg.encode("utf-8", errors="replace")) + os.write(1, msg.encode()) os.kill(0, _sig.SIGTSTP) run_in_terminal(_suspend) @@ -8692,6 +9013,60 @@ class HermesCLI: filter=Condition(lambda: cli_ref._approval_state is not None), ) + # --- /model picker: display widget --- + def _get_model_picker_display(): + state = cli_ref._model_picker_state + if not state: + return [] + stage = state.get("stage", "provider") + if stage == "provider": + title = "⚙ Model Picker — Select Provider" + choices = [] + for p in state.get("providers") or []: + count = p.get("total_models", len(p.get("models", []))) + label = f"{p['name']} ({count} model{'s' if count != 1 else ''})" + if p.get("is_current"): + label += " ← current" + choices.append(label) + choices.append("Cancel") + hint = f"Current: {state.get('current_model', 'unknown')} on {state.get('current_provider', 'unknown')}" + else: + provider_data = state.get("provider_data") or {} + model_list = state.get("model_list") or [] + title = f"⚙ Model Picker — {provider_data.get('name', provider_data.get('slug', 'Provider'))}" + choices = list(model_list) + ["← Back", "Cancel"] + if 
model_list: + hint = f"Select a model ({len(model_list)} available)" + else: + hint = "No models listed for this provider. Use Back or Cancel." + + box_width = _panel_box_width(title, [hint] + choices, min_width=46, max_width=84) + inner_text_width = max(8, box_width - 6) + lines = [] + lines.append(('class:clarify-border', '╭─ ')) + lines.append(('class:clarify-title', title)) + lines.append(('class:clarify-border', ' ' + ('─' * max(0, box_width - len(title) - 3)) + '╮\n')) + _append_blank_panel_line(lines, 'class:clarify-border', box_width) + _append_panel_line(lines, 'class:clarify-border', 'class:clarify-hint', hint, box_width) + _append_blank_panel_line(lines, 'class:clarify-border', box_width) + selected = state.get("selected", 0) + for idx, choice in enumerate(choices): + style = 'class:clarify-selected' if idx == selected else 'class:clarify-choice' + prefix = '❯ ' if idx == selected else ' ' + for wrapped in _wrap_panel_text(prefix + choice, inner_text_width, subsequent_indent=' '): + _append_panel_line(lines, 'class:clarify-border', style, wrapped, box_width) + _append_blank_panel_line(lines, 'class:clarify-border', box_width) + lines.append(('class:clarify-border', '╰' + ('─' * box_width) + '╯\n')) + return lines + + model_picker_widget = ConditionalContainer( + Window( + FormattedTextControl(_get_model_picker_display), + wrap_lines=True, + ), + filter=Condition(lambda: cli_ref._model_picker_state is not None), + ) + # Horizontal rules above and below the input. # On narrow/mobile terminals we keep the top separator for structure but # hide the bottom one to recover a full row for conversation content. 
@@ -8767,6 +9142,7 @@ class HermesCLI: secret_widget=secret_widget, approval_widget=approval_widget, clarify_widget=clarify_widget, + model_picker_widget=model_picker_widget, spinner_widget=spinner_widget, spacer=spacer, status_bar=status_bar, diff --git a/cron/scheduler.py b/cron/scheduler.py index 72c95302e2..eb7cb42c53 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -44,7 +44,7 @@ logger = logging.getLogger(__name__) _KNOWN_DELIVERY_PLATFORMS = frozenset({ "telegram", "discord", "slack", "whatsapp", "signal", "matrix", "mattermost", "homeassistant", "dingtalk", "feishu", - "wecom", "weixin", "sms", "email", "webhook", "bluebubbles", + "wecom", "wecom_callback", "weixin", "sms", "email", "webhook", "bluebubbles", }) from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run @@ -708,6 +708,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: provider_sort=pr.get("sort"), disabled_toolsets=["cronjob", "messaging", "clarify"], quiet_mode=True, + skip_context_files=True, # Don't inject SOUL.md/AGENTS.md from scheduler cwd skip_memory=True, # Cron system prompts would corrupt user representations platform="cron", session_id=_cron_session_id, diff --git a/docs/specs/container-cli-review-fixes.md b/docs/specs/container-cli-review-fixes.md new file mode 100644 index 0000000000..0eb9070dbf --- /dev/null +++ b/docs/specs/container-cli-review-fixes.md @@ -0,0 +1,329 @@ +# Container-Aware CLI Review Fixes Spec + +**PR:** NousResearch/hermes-agent#7543 +**Review:** cursor[bot] bugbot review (4094049442) + two prior rounds +**Date:** 2026-04-12 +**Branch:** `feat/container-aware-cli-clean` + +## Review Issues Summary + +Six issues were raised across three bugbot review rounds. Three were fixed in intermediate commits (38277a6a, 726cf90f). This spec addresses remaining design concerns surfaced by those reviews and simplifies the implementation based on interview decisions. 
+ +| # | Issue | Severity | Status | +|---|-------|----------|--------| +| 1 | `os.execvp` retry loop unreachable | Medium | Fixed in 79e8cd12 (switched to subprocess.run) | +| 2 | Redundant `shutil.which("sudo")` | Medium | Fixed in 38277a6a (reuses `sudo` var) | +| 3 | Missing `chown -h` on symlink update | Low | Fixed in 38277a6a | +| 4 | Container routing after `parse_args()` | High | Fixed in 726cf90f | +| 5 | Hardcoded `/home/${user}` | Medium | Fixed in 726cf90f | +| 6 | Group membership not gated on `container.enable` | Low | Fixed in 726cf90f | + +The mechanical fixes are in place but the overall design needs revision. The retry loop, error swallowing, and process model have deeper issues than what the bugbot flagged. + +--- + +## Spec: Revised `_exec_in_container` + +### Design Principles + +1. **Let it crash.** No silent fallbacks. If `.container-mode` exists but something goes wrong, the error propagates naturally (Python traceback). The only case where container routing is skipped is when `.container-mode` doesn't exist or `HERMES_DEV=1`. +2. **No retries.** Probe once for sudo, exec once. If it fails, docker/podman's stderr reaches the user verbatim. +3. **Completely transparent.** No error wrapping, no prefixes, no spinners. Docker's output goes straight through. +4. **`os.execvp` on the happy path.** Replace the Python process entirely so there's no idle parent during interactive sessions. Note: `execvp` never returns on success (process is replaced) and raises `OSError` on failure (it does not return a value). The container process's exit code becomes the process exit code by definition — no explicit propagation needed. +5. **One human-readable exception to "let it crash".** `subprocess.TimeoutExpired` from the sudo probe gets a specific catch with a readable message, since a raw traceback for "your Docker daemon is slow" is confusing. All other exceptions propagate naturally. + +### Execution Flow + +``` +1. 
get_container_exec_info() + - HERMES_DEV=1 → return None (skip routing) + - Inside container → return None (skip routing) + - .container-mode doesn't exist → return None (skip routing) + - .container-mode exists → parse and return dict + - .container-mode exists but malformed/unreadable → LET IT CRASH (no try/except) + +2. _exec_in_container(container_info, sys.argv[1:]) + a. shutil.which(backend) → if None, print "{backend} not found on PATH" and sys.exit(1) + b. Sudo probe: subprocess.run([runtime, "inspect", "--format", "ok", container_name], timeout=15) + - If succeeds → needs_sudo = False + - If fails → try subprocess.run([sudo, "-n", runtime, "inspect", ...], timeout=15) + - If succeeds → needs_sudo = True + - If fails → print error with sudoers hint (including why -n is required) and sys.exit(1) + - If TimeoutExpired → catch specifically, print human-readable message about slow daemon + c. Build exec_cmd: [sudo? + runtime, "exec", tty_flags, "-u", exec_user, env_flags, container, hermes_bin, *cli_args] + d. 
os.execvp(exec_cmd[0], exec_cmd) + - On success: process is replaced — Python is gone, container exit code IS the process exit code + - On OSError: let it crash (natural traceback) +``` + +### Changes to `hermes_cli/main.py` + +#### `_exec_in_container` — rewrite + +Remove: +- The entire retry loop (`max_retries`, `for attempt in range(...)`) +- Spinner logic (`"Waiting for container..."`, dots) +- Exit code classification (125/126/127 handling) +- `subprocess.run` for the exec call (keep it only for the sudo probe) +- Special TTY vs non-TTY retry counts +- The `time` import (no longer needed) + +Change: +- Use `os.execvp(exec_cmd[0], exec_cmd)` as the final call +- Keep the `subprocess` import only for the sudo probe +- Keep TTY detection for the `-it` vs `-i` flag +- Keep env var forwarding (TERM, COLORTERM, LANG, LC_ALL) +- Keep the sudo probe as-is (it's the one "smart" part) +- Bump probe `timeout` from 5s to 15s — cold podman on a loaded machine needs headroom +- Catch `subprocess.TimeoutExpired` specifically on both probe calls — print a readable message about the daemon being unresponsive instead of a raw traceback +- Expand the sudoers hint error message to explain *why* `-n` (non-interactive) is required: a password prompt would hang the CLI or break piped commands + +The function becomes roughly: + +```python +def _exec_in_container(container_info: dict, cli_args: list): + """Replace the current process with a command inside the managed container. + + Probes whether sudo is needed (rootful containers), then os.execvp + into the container. If exec fails, the OS error propagates naturally. + """ + import shutil + import subprocess + + backend = container_info["backend"] + container_name = container_info["container_name"] + exec_user = container_info["exec_user"] + hermes_bin = container_info["hermes_bin"] + + runtime = shutil.which(backend) + if not runtime: + print(f"Error: {backend} not found on PATH. 
Cannot route to container.", + file=sys.stderr) + sys.exit(1) + + # Probe whether we need sudo to see the rootful container. + # Timeout is 15s — cold podman on a loaded machine can take a while. + # TimeoutExpired is caught specifically for a human-readable message; + # all other exceptions propagate naturally. + needs_sudo = False + sudo = None + try: + probe = subprocess.run( + [runtime, "inspect", "--format", "ok", container_name], + capture_output=True, text=True, timeout=15, + ) + except subprocess.TimeoutExpired: + print( + f"Error: timed out waiting for {backend} to respond.\n" + f"The {backend} daemon may be unresponsive or starting up.", + file=sys.stderr, + ) + sys.exit(1) + + if probe.returncode != 0: + sudo = shutil.which("sudo") + if sudo: + try: + probe2 = subprocess.run( + [sudo, "-n", runtime, "inspect", "--format", "ok", container_name], + capture_output=True, text=True, timeout=15, + ) + except subprocess.TimeoutExpired: + print( + f"Error: timed out waiting for sudo {backend} to respond.", + file=sys.stderr, + ) + sys.exit(1) + + if probe2.returncode == 0: + needs_sudo = True + else: + print( + f"Error: container '{container_name}' not found via {backend}.\n" + f"\n" + f"The NixOS service runs the container as root. Your user cannot\n" + f"see it because {backend} uses per-user namespaces.\n" + f"\n" + f"Fix: grant passwordless sudo for {backend}. The -n (non-interactive)\n" + f"flag is required because the CLI calls sudo non-interactively —\n" + f"a password prompt would hang or break piped commands:\n" + f"\n" + f' security.sudo.extraRules = [{{\n' + f' users = [ "{os.getenv("USER", "your-user")}" ];\n' + f' commands = [{{ command = "{runtime}"; options = [ "NOPASSWD" ]; }}];\n' + f' }}];\n' + f"\n" + f"Or run: sudo hermes {' '.join(cli_args)}", + file=sys.stderr, + ) + sys.exit(1) + else: + print( + f"Error: container '{container_name}' not found via {backend}.\n" + f"The container may be running under root. 
Try: sudo hermes {' '.join(cli_args)}", + file=sys.stderr, + ) + sys.exit(1) + + is_tty = sys.stdin.isatty() + tty_flags = ["-it"] if is_tty else ["-i"] + + env_flags = [] + for var in ("TERM", "COLORTERM", "LANG", "LC_ALL"): + val = os.environ.get(var) + if val: + env_flags.extend(["-e", f"{var}={val}"]) + + cmd_prefix = [sudo, "-n", runtime] if needs_sudo else [runtime] + exec_cmd = ( + cmd_prefix + ["exec"] + + tty_flags + + ["-u", exec_user] + + env_flags + + [container_name, hermes_bin] + + cli_args + ) + + # execvp replaces this process entirely — it never returns on success. + # On failure it raises OSError, which propagates naturally. + os.execvp(exec_cmd[0], exec_cmd) +``` + +#### Container routing call site in `main()` — remove try/except + +Current: +```python +try: + from hermes_cli.config import get_container_exec_info + container_info = get_container_exec_info() + if container_info: + _exec_in_container(container_info, sys.argv[1:]) + sys.exit(1) # exec failed if we reach here +except SystemExit: + raise +except Exception: + pass # Container routing unavailable, proceed locally +``` + +Revised: +```python +from hermes_cli.config import get_container_exec_info +container_info = get_container_exec_info() +if container_info: + _exec_in_container(container_info, sys.argv[1:]) + # Unreachable: os.execvp never returns on success (process is replaced) + # and raises OSError on failure (which propagates as a traceback). + # This line exists only as a defensive assertion. + sys.exit(1) +``` + +No try/except. If `.container-mode` doesn't exist, `get_container_exec_info()` returns `None` and we skip routing. If it exists but is broken, the exception propagates with a natural traceback. + +Note: `sys.exit(1)` after `_exec_in_container` is dead code in all paths — `os.execvp` either replaces the process or raises. It's kept as a belt-and-suspenders assertion with a comment marking it unreachable, not as actual error handling. 
+ +### Changes to `hermes_cli/config.py` + +#### `get_container_exec_info` — remove inner try/except + +Current code catches `(OSError, IOError)` and returns `None`. This silently hides permission errors, corrupt files, etc. + +Change: Remove the try/except around file reading. Keep the early returns for `HERMES_DEV=1` and `_is_inside_container()`. The `FileNotFoundError` from `open()` when `.container-mode` doesn't exist should still return `None` (this is the "container mode not enabled" case). All other exceptions propagate. + +```python +def get_container_exec_info() -> Optional[dict]: + if os.environ.get("HERMES_DEV") == "1": + return None + if _is_inside_container(): + return None + + container_mode_file = get_hermes_home() / ".container-mode" + + try: + with open(container_mode_file, "r") as f: + # ... parse key=value lines ... + except FileNotFoundError: + return None + # All other exceptions (PermissionError, malformed data, etc.) propagate + + return { ... } +``` + +--- + +## Spec: NixOS Module Changes + +### Symlink creation — simplify to two branches + +Current: 4 branches (symlink exists, directory exists, other file, doesn't exist). + +Revised: 2 branches. + +```bash +if [ -d "${symlinkPath}" ] && [ ! -L "${symlinkPath}" ]; then + # Real directory — back it up, then create symlink + _backup="${symlinkPath}.bak.$(date +%s)" + echo "hermes-agent: backing up existing ${symlinkPath} to $_backup" + mv "${symlinkPath}" "$_backup" +fi +# For everything else (symlink, doesn't exist, etc.) — just force-create +ln -sfn "${target}" "${symlinkPath}" +chown -h ${user}:${cfg.group} "${symlinkPath}" +``` + +`ln -sfn` handles: existing symlink (replaces), doesn't exist (creates), and after the `mv` above (creates). The only case that needs special handling is a real directory, because `ln -sfn` cannot atomically replace a directory. + +Note: there is a theoretical race between the `[ -d ... 
]` check and the `mv` (something could create/remove the directory in between). In practice this is a NixOS activation script running as root during `nixos-rebuild switch` — no other process should be touching `~/.hermes` at that moment. Not worth adding locking for. + +### Sudoers — document, don't auto-configure + +Do NOT add `security.sudo.extraRules` to the module. Document the sudoers requirement in the module's description/comments and in the error message the CLI prints when sudo probe fails. + +### Group membership gating — keep as-is + +The fix in 726cf90f (`cfg.container.enable && cfg.container.hostUsers != []`) is correct. Leftover group membership when container mode is disabled is harmless. No cleanup needed. + +--- + +## Spec: Test Rewrite + +The existing test file (`tests/hermes_cli/test_container_aware_cli.py`) has 16 tests. With the simplified exec model, several are obsolete. + +### Tests to keep (update as needed) + +- `test_is_inside_container_dockerenv` — unchanged +- `test_is_inside_container_containerenv` — unchanged +- `test_is_inside_container_cgroup_docker` — unchanged +- `test_is_inside_container_false_on_host` — unchanged +- `test_get_container_exec_info_returns_metadata` — unchanged +- `test_get_container_exec_info_none_inside_container` — unchanged +- `test_get_container_exec_info_none_without_file` — unchanged +- `test_get_container_exec_info_skipped_when_hermes_dev` — unchanged +- `test_get_container_exec_info_not_skipped_when_hermes_dev_zero` — unchanged +- `test_get_container_exec_info_defaults` — unchanged +- `test_get_container_exec_info_docker_backend` — unchanged + +### Tests to add + +- `test_get_container_exec_info_crashes_on_permission_error` — verify that `PermissionError` propagates (no silent `None` return) +- `test_exec_in_container_calls_execvp` — verify `os.execvp` is called with correct args (runtime, tty flags, user, env, container, binary, cli args) +- `test_exec_in_container_sudo_probe_sets_prefix` — verify that 
when first probe fails and sudo probe succeeds, `os.execvp` is called with `sudo -n` prefix +- `test_exec_in_container_no_runtime_hard_fails` — keep existing, verify `sys.exit(1)` when `shutil.which` returns None +- `test_exec_in_container_non_tty_uses_i_only` — update to check `os.execvp` args instead of `subprocess.run` args +- `test_exec_in_container_probe_timeout_prints_message` — verify that `subprocess.TimeoutExpired` from the probe produces a human-readable error and `sys.exit(1)`, not a raw traceback +- `test_exec_in_container_container_not_running_no_sudo` — verify the path where runtime exists (`shutil.which` returns a path) but probe returns non-zero and no sudo is available. Should print the "container may be running under root" error. This is distinct from `no_runtime_hard_fails` which covers `shutil.which` returning None. + +### Tests to delete + +- `test_exec_in_container_tty_retries_on_container_failure` — retry loop removed +- `test_exec_in_container_non_tty_retries_silently_exits_126` — retry loop removed +- `test_exec_in_container_propagates_hermes_exit_code` — no subprocess.run to check exit codes; execvp replaces the process. Note: exit code propagation still works correctly — when `os.execvp` succeeds, the container's process *becomes* this process, so its exit code is the process exit code by OS semantics. No application code needed, no test needed. A comment in the function docstring documents this intent for future readers. 
+ +--- + +## Out of Scope + +- Auto-configuring sudoers rules in the NixOS module +- Any changes to `get_container_exec_info` parsing logic beyond the try/except narrowing +- Changes to `.container-mode` file format +- Changes to the `HERMES_DEV=1` bypass +- Changes to container detection logic (`_is_inside_container`) diff --git a/gateway/config.py b/gateway/config.py index ff363fe260..fdc8cc1b18 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -69,6 +69,7 @@ class Platform(Enum): WEBHOOK = "webhook" FEISHU = "feishu" WECOM = "wecom" + WECOM_CALLBACK = "wecom_callback" WEIXIN = "weixin" BLUEBUBBLES = "bluebubbles" @@ -319,9 +320,14 @@ class GatewayConfig: # Feishu uses extra dict for app credentials elif platform == Platform.FEISHU and config.extra.get("app_id"): connected.append(platform) - # WeCom uses extra dict for bot credentials + # WeCom bot mode uses extra dict for bot credentials elif platform == Platform.WECOM and config.extra.get("bot_id"): connected.append(platform) + # WeCom callback mode uses corp_id or apps list + elif platform == Platform.WECOM_CALLBACK and ( + config.extra.get("corp_id") or config.extra.get("apps") + ): + connected.append(platform) # BlueBubbles uses extra dict for local server config elif platform == Platform.BLUEBUBBLES and config.extra.get("server_url") and config.extra.get("password"): connected.append(platform) @@ -1035,6 +1041,23 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("WECOM_HOME_CHANNEL_NAME", "Home"), ) + # WeCom callback mode (self-built apps) + wecom_callback_corp_id = os.getenv("WECOM_CALLBACK_CORP_ID") + wecom_callback_corp_secret = os.getenv("WECOM_CALLBACK_CORP_SECRET") + if wecom_callback_corp_id and wecom_callback_corp_secret: + if Platform.WECOM_CALLBACK not in config.platforms: + config.platforms[Platform.WECOM_CALLBACK] = PlatformConfig() + config.platforms[Platform.WECOM_CALLBACK].enabled = True + config.platforms[Platform.WECOM_CALLBACK].extra.update({ + "corp_id": 
wecom_callback_corp_id,
+            "corp_secret": wecom_callback_corp_secret,
+            "agent_id": os.getenv("WECOM_CALLBACK_AGENT_ID", ""),
+            "token": os.getenv("WECOM_CALLBACK_TOKEN", ""),
+            "encoding_aes_key": os.getenv("WECOM_CALLBACK_ENCODING_AES_KEY", ""),
+            "host": os.getenv("WECOM_CALLBACK_HOST", "0.0.0.0"),
+            "port": int(os.getenv("WECOM_CALLBACK_PORT", "8645")),
+        })
+
     # Weixin (personal WeChat via iLink Bot API)
     weixin_token = os.getenv("WEIXIN_TOKEN")
     weixin_account_id = os.getenv("WEIXIN_ACCOUNT_ID")
diff --git a/gateway/display_config.py b/gateway/display_config.py
new file mode 100644
index 0000000000..e148be9103
--- /dev/null
+++ b/gateway/display_config.py
@@ -0,0 +1,206 @@
+"""Per-platform display/verbosity configuration resolver.
+
+Provides ``resolve_display_setting()`` — the single entry-point for reading
+display settings with platform-specific overrides and sensible defaults.
+
+Resolution order (first non-None wins):
+    1. ``display.platforms.<platform>.<setting>`` — explicit per-platform user override
+    2. ``display.<setting>`` — global user setting
+    3. ``_PLATFORM_DEFAULTS[<platform>][<setting>]`` — built-in sensible default
+    4. ``_GLOBAL_DEFAULTS[<setting>]`` — built-in global default
+
+Backward compatibility: ``display.tool_progress_overrides`` is still read as a
+fallback for ``tool_progress`` when no ``display.platforms`` entry exists. A
+config migration (version bump) automatically moves the old format into the new
+``display.platforms`` structure.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+# ---------------------------------------------------------------------------
+# Overrideable display settings and their global defaults
+# ---------------------------------------------------------------------------
+# These are the settings that can be configured per-platform.
+# Other display settings (compact, personality, skin, etc.) are CLI-only
+# and don't participate in per-platform resolution.
+ +_GLOBAL_DEFAULTS: dict[str, Any] = { + "tool_progress": "all", + "show_reasoning": False, + "tool_preview_length": 0, + "streaming": None, # None = follow top-level streaming config +} + +# --------------------------------------------------------------------------- +# Sensible per-platform defaults — tiered by platform capability +# --------------------------------------------------------------------------- +# Tier 1 (high): Supports message editing, typically personal/team use +# Tier 2 (medium): Supports editing but often workspace/customer-facing +# Tier 3 (low): No edit support — each progress msg is permanent +# Tier 4 (minimal): Batch/non-interactive delivery + +_TIER_HIGH = { + "tool_progress": "all", + "show_reasoning": False, + "tool_preview_length": 40, + "streaming": None, # follow global +} + +_TIER_MEDIUM = { + "tool_progress": "new", + "show_reasoning": False, + "tool_preview_length": 40, + "streaming": None, +} + +_TIER_LOW = { + "tool_progress": "off", + "show_reasoning": False, + "tool_preview_length": 40, + "streaming": False, +} + +_TIER_MINIMAL = { + "tool_progress": "off", + "show_reasoning": False, + "tool_preview_length": 0, + "streaming": False, +} + +_PLATFORM_DEFAULTS: dict[str, dict[str, Any]] = { + # Tier 1 — full edit support, personal/team use + "telegram": _TIER_HIGH, + "discord": _TIER_HIGH, + + # Tier 2 — edit support, often customer/workspace channels + "slack": _TIER_MEDIUM, + "mattermost": _TIER_MEDIUM, + "matrix": _TIER_MEDIUM, + "feishu": _TIER_MEDIUM, + + # Tier 3 — no edit support, progress messages are permanent + "signal": _TIER_LOW, + "whatsapp": _TIER_LOW, + "bluebubbles": _TIER_LOW, + "weixin": _TIER_LOW, + "wecom": _TIER_LOW, + "wecom_callback": _TIER_LOW, + "dingtalk": _TIER_LOW, + + # Tier 4 — batch or non-interactive delivery + "email": _TIER_MINIMAL, + "sms": _TIER_MINIMAL, + "webhook": _TIER_MINIMAL, + "homeassistant": _TIER_MINIMAL, + "api_server": {**_TIER_HIGH, "tool_preview_length": 0}, +} + +# Canonical set 
of per-platform overrideable keys (for validation).
+OVERRIDEABLE_KEYS = frozenset(_GLOBAL_DEFAULTS.keys())
+
+
+def resolve_display_setting(
+    user_config: dict,
+    platform_key: str,
+    setting: str,
+    fallback: Any = None,
+) -> Any:
+    """Resolve a display setting with per-platform override support.
+
+    Parameters
+    ----------
+    user_config : dict
+        The full parsed config.yaml dict.
+    platform_key : str
+        Platform config key (e.g. ``"telegram"``, ``"slack"``). Use
+        ``_platform_config_key(source.platform)`` from gateway/run.py.
+    setting : str
+        Display setting name (e.g. ``"tool_progress"``, ``"show_reasoning"``).
+    fallback : Any
+        Fallback value when the setting isn't found anywhere.
+
+    Returns
+    -------
+    The resolved value, or *fallback* if nothing is configured.
+    """
+    display_cfg = user_config.get("display") or {}
+
+    # 1. Explicit per-platform override (display.platforms.<platform>.<setting>)
+    platforms = display_cfg.get("platforms") or {}
+    plat_overrides = platforms.get(platform_key)
+    if isinstance(plat_overrides, dict):
+        val = plat_overrides.get(setting)
+        if val is not None:
+            return _normalise(setting, val)
+
+    # 1b. Backward compat: display.tool_progress_overrides.<platform>
+    if setting == "tool_progress":
+        legacy = display_cfg.get("tool_progress_overrides")
+        if isinstance(legacy, dict):
+            val = legacy.get(platform_key)
+            if val is not None:
+                return _normalise(setting, val)
+
+    # 2. Global user setting (display.<setting>)
+    val = display_cfg.get(setting)
+    if val is not None:
+        return _normalise(setting, val)
+
+    # 3. Built-in platform default
+    plat_defaults = _PLATFORM_DEFAULTS.get(platform_key)
+    if plat_defaults:
+        val = plat_defaults.get(setting)
+        if val is not None:
+            return val
+
+    # 4. Built-in global default
+    val = _GLOBAL_DEFAULTS.get(setting)
+    if val is not None:
+        return val
+
+    return fallback
+
+
+def get_platform_defaults(platform_key: str) -> dict[str, Any]:
+    """Return the built-in default display settings for a platform.
+ + Falls back to ``_GLOBAL_DEFAULTS`` for unknown platforms. + """ + return dict(_PLATFORM_DEFAULTS.get(platform_key, _GLOBAL_DEFAULTS)) + + +def get_effective_display(user_config: dict, platform_key: str) -> dict[str, Any]: + """Return the fully-resolved display settings for a platform. + + Useful for status commands that want to show all effective settings. + """ + return { + key: resolve_display_setting(user_config, platform_key, key) + for key in OVERRIDEABLE_KEYS + } + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _normalise(setting: str, value: Any) -> Any: + """Normalise YAML quirks (bare ``off`` → False in YAML 1.1).""" + if setting == "tool_progress": + if value is False: + return "off" + if value is True: + return "all" + return str(value).lower() + if setting in ("show_reasoning", "streaming"): + if isinstance(value, str): + return value.lower() in ("true", "1", "yes", "on") + return bool(value) + if setting == "tool_preview_length": + try: + return int(value) + except (TypeError, ValueError): + return 0 + return value diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index b1d07e5d65..43a9338d78 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -456,6 +456,7 @@ class DiscordAdapter(BasePlatformAdapter): # show the standard typing gateway event for bots) self._typing_tasks: Dict[str, asyncio.Task] = {} self._bot_task: Optional[asyncio.Task] = None + self._post_connect_task: Optional[asyncio.Task] = None # Dedup cache: prevents duplicate bot responses when Discord # RESUME replays events after reconnects. 
self._dedup = MessageDeduplicator() @@ -545,15 +546,14 @@ class DiscordAdapter(BasePlatformAdapter): # Resolve any usernames in the allowed list to numeric IDs await adapter_self._resolve_allowed_usernames() - - # Sync slash commands with Discord - try: - synced = await adapter_self._client.tree.sync() - logger.info("[%s] Synced %d slash command(s)", adapter_self.name, len(synced)) - except Exception as e: # pragma: no cover - defensive logging - logger.warning("[%s] Slash command sync failed: %s", adapter_self.name, e, exc_info=True) adapter_self._ready_event.set() + if adapter_self._post_connect_task and not adapter_self._post_connect_task.done(): + adapter_self._post_connect_task.cancel() + adapter_self._post_connect_task = asyncio.create_task( + adapter_self._run_post_connect_initialization() + ) + @self._client.event async def on_message(message: DiscordMessage): # Dedup: Discord RESUME replays events after reconnects (#4777) @@ -686,14 +686,36 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: # pragma: no cover - defensive logging logger.warning("[%s] Error during disconnect: %s", self.name, e, exc_info=True) + if self._post_connect_task and not self._post_connect_task.done(): + self._post_connect_task.cancel() + try: + await self._post_connect_task + except asyncio.CancelledError: + pass + self._running = False self._client = None self._ready_event.clear() + self._post_connect_task = None self._release_platform_lock() logger.info("[%s] Disconnected", self.name) + async def _run_post_connect_initialization(self) -> None: + """Finish non-critical startup work after Discord is connected.""" + if not self._client: + return + try: + synced = await asyncio.wait_for(self._client.tree.sync(), timeout=30) + logger.info("[%s] Synced %d slash command(s)", self.name, len(synced)) + except asyncio.TimeoutError: + logger.warning("[%s] Slash command sync timed out after 30s", self.name) + except asyncio.CancelledError: + raise + except Exception as e: # 
pragma: no cover - defensive logging + logger.warning("[%s] Slash command sync failed: %s", self.name, e, exc_info=True) + async def _add_reaction(self, message: Any, emoji: str) -> bool: """Add an emoji reaction to a Discord message.""" if not message or not hasattr(message, "add_reaction"): diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 349f962d2e..75d7e9c9f6 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -104,7 +104,7 @@ MAX_MESSAGE_LENGTH = 4000 # Uses get_hermes_home() so each profile gets its own Matrix store. from hermes_constants import get_hermes_dir as _get_hermes_dir _STORE_DIR = _get_hermes_dir("platforms/matrix/store", "matrix/store") -_CRYPTO_PICKLE_PATH = _STORE_DIR / "crypto_store.pickle" +_CRYPTO_DB_PATH = _STORE_DIR / "crypto.db" # Grace period: ignore messages older than this many seconds before startup. _STARTUP_GRACE_SECONDS = 5 @@ -165,6 +165,33 @@ def check_matrix_requirements() -> bool: return True +class _CryptoStateStore: + """Adapter that satisfies the mautrix crypto StateStore interface. + + OlmMachine requires a StateStore with ``is_encrypted``, + ``get_encryption_info``, and ``find_shared_rooms``. The basic + ``MemoryStateStore`` from ``mautrix.client`` doesn't implement these, + so we provide simple implementations that consult the client's room + state. + """ + + def __init__(self, client_state_store: Any, joined_rooms: set): + self._ss = client_state_store + self._joined_rooms = joined_rooms + + async def is_encrypted(self, room_id: str) -> bool: + return (await self.get_encryption_info(room_id)) is not None + + async def get_encryption_info(self, room_id: str): + if hasattr(self._ss, "get_encryption_info"): + return await self._ss.get_encryption_info(room_id) + return None + + async def find_shared_rooms(self, user_id: str) -> list: + # Return all joined rooms — simple but correct for a single-user bot. 
+ return list(self._joined_rooms) + + class MatrixAdapter(BasePlatformAdapter): """Gateway adapter for Matrix (any homeserver).""" @@ -199,6 +226,7 @@ class MatrixAdapter(BasePlatformAdapter): ) self._client: Any = None # mautrix.client.Client + self._crypto_db: Any = None # mautrix.util.async_db.Database self._sync_task: Optional[asyncio.Task] = None self._closing = False self._startup_ts: float = 0.0 @@ -252,6 +280,92 @@ class MatrixAdapter(BasePlatformAdapter): self._processed_events_set.add(event_id) return False + # ------------------------------------------------------------------ + # E2EE helpers + # ------------------------------------------------------------------ + + async def _verify_device_keys_on_server(self, client: Any, olm: Any) -> bool: + """Verify our device keys are on the homeserver after loading crypto state. + + Returns True if keys are valid or were successfully re-uploaded. + Returns False if verification fails (caller should refuse E2EE). + """ + try: + resp = await client.query_keys({client.mxid: [client.device_id]}) + except Exception as exc: + logger.error( + "Matrix: cannot verify device keys on server: %s — refusing E2EE", exc, + ) + return False + + # query_keys returns typed objects (QueryKeysResponse, DeviceKeys + # with KeyID keys). Normalise to plain strings for comparison. + device_keys_map = getattr(resp, "device_keys", {}) or {} + our_user_devices = device_keys_map.get(str(client.mxid)) or {} + our_keys = our_user_devices.get(str(client.device_id)) + + if not our_keys: + logger.warning("Matrix: device keys missing from server — re-uploading") + olm.account.shared = False + try: + await olm.share_keys() + except Exception as exc: + logger.error("Matrix: failed to re-upload device keys: %s", exc) + return False + return True + + # DeviceKeys.keys is a dict[KeyID, str]. Iterate to find the + # ed25519 key rather than constructing a KeyID for lookup. 
+ server_ed25519 = None + keys_dict = getattr(our_keys, "keys", {}) or {} + for key_id, key_value in keys_dict.items(): + if str(key_id).startswith("ed25519:"): + server_ed25519 = str(key_value) + break + local_ed25519 = olm.account.identity_keys.get("ed25519") + + if server_ed25519 != local_ed25519: + if olm.account.shared: + # Restored account from DB but server has different keys — corrupted state. + logger.error( + "Matrix: server has different identity keys for device %s — " + "local crypto state is stale. Delete %s and restart.", + client.device_id, + _CRYPTO_DB_PATH, + ) + return False + + # Fresh account (never uploaded). Server has stale keys from a + # previous installation. Try to delete the old device and re-upload. + logger.warning( + "Matrix: server has stale keys for device %s — attempting re-upload", + client.device_id, + ) + try: + await client.api.request( + client.api.Method.DELETE + if hasattr(client.api, "Method") + else "DELETE", + f"/_matrix/client/v3/devices/{client.device_id}", + ) + logger.info("Matrix: deleted stale device %s from server", client.device_id) + except Exception: + # Device deletion often requires UIA or may simply not be + # permitted — that's fine, share_keys will try to overwrite. + pass + try: + await olm.share_keys() + except Exception as exc: + logger.error( + "Matrix: cannot upload device keys for %s: %s. 
" + "Try generating a new access token to get a fresh device.", + client.device_id, + exc, + ) + return False + + return True + # ------------------------------------------------------------------ # Required overrides # ------------------------------------------------------------------ @@ -350,54 +464,54 @@ class MatrixAdapter(BasePlatformAdapter): return False try: from mautrix.crypto import OlmMachine - from mautrix.crypto.store import MemoryCryptoStore + from mautrix.crypto.store.asyncpg import PgCryptoStore + from mautrix.util.async_db import Database + + _STORE_DIR.mkdir(parents=True, exist_ok=True) + + # Remove legacy pickle file from pre-SQLite era. + legacy_pickle = _STORE_DIR / "crypto_store.pickle" + if legacy_pickle.exists(): + logger.info("Matrix: removing legacy crypto_store.pickle (migrated to SQLite)") + legacy_pickle.unlink() + + # Open SQLite-backed crypto store. + crypto_db = Database.create( + f"sqlite:///{_CRYPTO_DB_PATH}", + upgrade_table=PgCryptoStore.upgrade_table, + ) + await crypto_db.start() + self._crypto_db = crypto_db - # account_id and pickle_key are required by mautrix ≥0.21. - # Use the Matrix user ID as account_id for stable identity. - # pickle_key secures in-memory serialisation; derive from - # the same user_id:device_id pair used for the on-disk HMAC. _acct_id = self._user_id or "hermes" - _pickle_key = f"{_acct_id}:{self._device_id}" - crypto_store = MemoryCryptoStore( + _pickle_key = f"{_acct_id}:{self._device_id or 'default'}" + crypto_store = PgCryptoStore( account_id=_acct_id, pickle_key=_pickle_key, + db=crypto_db, ) + await crypto_store.open() - # Restore persisted crypto state from a previous run. - # Uses HMAC to verify integrity before unpickling. - pickle_path = _CRYPTO_PICKLE_PATH - if pickle_path.exists(): - try: - import hashlib, hmac, pickle - raw = pickle_path.read_bytes() - # Format: 32-byte HMAC-SHA256 signature + pickle data. 
- if len(raw) > 32: - sig, payload = raw[:32], raw[32:] - # Key is derived from the device_id + user_id (stable per install). - hmac_key = f"{self._user_id}:{self._device_id}".encode() - expected = hmac.new(hmac_key, payload, hashlib.sha256).digest() - if hmac.compare_digest(sig, expected): - saved = pickle.loads(payload) # noqa: S301 - if isinstance(saved, MemoryCryptoStore): - crypto_store = saved - logger.info("Matrix: restored E2EE crypto store from %s", pickle_path) - else: - logger.warning("Matrix: crypto store HMAC mismatch — ignoring stale/tampered file") - except Exception as exc: - logger.warning("Matrix: could not restore crypto store: %s", exc) + crypto_state = _CryptoStateStore(state_store, self._joined_rooms) + olm = OlmMachine(client, crypto_store, crypto_state) - olm = OlmMachine(client, crypto_store, state_store) - - # Set trust policy: accept unverified devices so senders - # share Megolm session keys with us automatically. + # Accept unverified devices so senders share Megolm + # session keys with us automatically. olm.share_keys_min_trust = TrustState.UNVERIFIED olm.send_keys_min_trust = TrustState.UNVERIFIED await olm.load() + + # Verify our device keys are still on the homeserver. + if not await self._verify_device_keys_on_server(client, olm): + await crypto_db.stop() + await api.session.close() + return False + client.crypto = olm logger.info( "Matrix: E2EE enabled (store: %s%s)", - str(_STORE_DIR), + str(_CRYPTO_DB_PATH), f", device_id={client.device_id}" if client.device_id else "", ) except Exception as exc: @@ -438,6 +552,15 @@ class MatrixAdapter(BasePlatformAdapter): ) # Build DM room cache from m.direct account data. await self._refresh_dm_cache() + + # Dispatch events from the initial sync so the OlmMachine + # receives to-device key shares queued while we were offline. 
+ try: + tasks = client.handle_sync(sync_data) + if tasks: + await asyncio.gather(*tasks) + except Exception as exc: + logger.warning("Matrix: initial sync event dispatch error: %s", exc) else: logger.warning("Matrix: initial sync returned unexpected type %s", type(sync_data).__name__) except Exception as exc: @@ -466,21 +589,12 @@ class MatrixAdapter(BasePlatformAdapter): except (asyncio.CancelledError, Exception): pass - # Persist E2EE crypto store before closing so the next restart - # can decrypt events using sessions from this run. - if self._client and self._encryption and getattr(self._client, "crypto", None): + # Close the SQLite crypto store database. + if hasattr(self, "_crypto_db") and self._crypto_db: try: - import hashlib, hmac, pickle - crypto_store = self._client.crypto.crypto_store - _STORE_DIR.mkdir(parents=True, exist_ok=True) - pickle_path = _CRYPTO_PICKLE_PATH - payload = pickle.dumps(crypto_store) - hmac_key = f"{self._user_id}:{self._device_id}".encode() - sig = hmac.new(hmac_key, payload, hashlib.sha256).digest() - pickle_path.write_bytes(sig + payload) - logger.info("Matrix: persisted E2EE crypto store to %s", pickle_path) + await self._crypto_db.stop() except Exception as exc: - logger.debug("Matrix: could not persist crypto store on disconnect: %s", exc) + logger.debug("Matrix: could not close crypto DB on disconnect: %s", exc) if self._client: try: @@ -853,13 +967,6 @@ class MatrixAdapter(BasePlatformAdapter): except Exception as exc: logger.warning("Matrix: sync event dispatch error: %s", exc) - # Share keys periodically if E2EE is enabled. - if self._encryption and getattr(client, "crypto", None): - try: - await client.crypto.share_keys() - except Exception as exc: - logger.warning("Matrix: E2EE key share failed: %s", exc) - # Retry any buffered undecrypted events. 
if self._pending_megolm: await self._retry_pending_decryptions() diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 884ef9c45b..2653296026 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -299,9 +299,11 @@ class TelegramAdapter(BasePlatformAdapter): # Exhausted retries — fatal message = ( - "Another Telegram bot poller is already using this token. " + "Another process is already polling this Telegram bot token " + "(possibly OpenClaw or another Hermes instance). " "Hermes stopped Telegram polling after %d retries. " - "Make sure only one gateway instance is running for this bot token." + "Only one poller can run per token — stop the other process " + "and restart with 'hermes start'." % MAX_CONFLICT_RETRIES ) logger.error("[%s] %s Original error: %s", self.name, message, error) diff --git a/gateway/platforms/webhook.py b/gateway/platforms/webhook.py index bb874f8f59..dfe7a70f3f 100644 --- a/gateway/platforms/webhook.py +++ b/gateway/platforms/webhook.py @@ -201,6 +201,7 @@ class WebhookAdapter(BasePlatformAdapter): "dingtalk", "feishu", "wecom", + "wecom_callback", "weixin", "bluebubbles", ): diff --git a/gateway/platforms/wecom_callback.py b/gateway/platforms/wecom_callback.py new file mode 100644 index 0000000000..4bb67d5cfa --- /dev/null +++ b/gateway/platforms/wecom_callback.py @@ -0,0 +1,387 @@ +"""WeCom callback-mode adapter for self-built enterprise applications. + +Unlike the bot/websocket adapter in ``wecom.py``, this handles the standard +WeCom callback flow: WeCom POSTs encrypted XML to an HTTP endpoint, the +adapter decrypts it, queues the message for the agent, and immediately +acknowledges. The agent's reply is delivered later via the proactive +``message/send`` API using an access-token. + +Supports multiple self-built apps under one gateway instance, scoped by +``corp_id:user_id`` to avoid cross-corp collisions. 
+""" + +from __future__ import annotations + +import asyncio +import logging +import socket as _socket +import time +from typing import Any, Dict, List, Optional +from xml.etree import ElementTree as ET + +try: + from aiohttp import web + + AIOHTTP_AVAILABLE = True +except ImportError: + web = None # type: ignore[assignment] + AIOHTTP_AVAILABLE = False + +try: + import httpx + + HTTPX_AVAILABLE = True +except ImportError: + httpx = None # type: ignore[assignment] + HTTPX_AVAILABLE = False + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult +from gateway.platforms.wecom_crypto import WXBizMsgCrypt, WeComCryptoError + +logger = logging.getLogger(__name__) + +DEFAULT_HOST = "0.0.0.0" +DEFAULT_PORT = 8645 +DEFAULT_PATH = "/wecom/callback" +ACCESS_TOKEN_TTL_SECONDS = 7200 +MESSAGE_DEDUP_TTL_SECONDS = 300 + + +def check_wecom_callback_requirements() -> bool: + return AIOHTTP_AVAILABLE and HTTPX_AVAILABLE + + +class WecomCallbackAdapter(BasePlatformAdapter): + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.WECOM_CALLBACK) + extra = config.extra or {} + self._host = str(extra.get("host") or DEFAULT_HOST) + self._port = int(extra.get("port") or DEFAULT_PORT) + self._path = str(extra.get("path") or DEFAULT_PATH) + self._apps: List[Dict[str, Any]] = self._normalize_apps(extra) + self._runner: Optional[web.AppRunner] = None + self._site: Optional[web.TCPSite] = None + self._app: Optional[web.Application] = None + self._http_client: Optional[httpx.AsyncClient] = None + self._message_queue: asyncio.Queue[MessageEvent] = asyncio.Queue() + self._poll_task: Optional[asyncio.Task] = None + self._seen_messages: Dict[str, float] = {} + self._user_app_map: Dict[str, str] = {} + self._access_tokens: Dict[str, Dict[str, Any]] = {} + + # ------------------------------------------------------------------ + # App normalisation + # 
------------------------------------------------------------------ + + @staticmethod + def _user_app_key(corp_id: str, user_id: str) -> str: + return f"{corp_id}:{user_id}" if corp_id else user_id + + @staticmethod + def _normalize_apps(extra: Dict[str, Any]) -> List[Dict[str, Any]]: + apps = extra.get("apps") + if isinstance(apps, list) and apps: + return [dict(app) for app in apps if isinstance(app, dict)] + if extra.get("corp_id"): + return [ + { + "name": extra.get("name") or "default", + "corp_id": extra.get("corp_id", ""), + "corp_secret": extra.get("corp_secret", ""), + "agent_id": str(extra.get("agent_id", "")), + "token": extra.get("token", ""), + "encoding_aes_key": extra.get("encoding_aes_key", ""), + } + ] + return [] + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + if not self._apps: + logger.warning("[WecomCallback] No callback apps configured") + return False + if not check_wecom_callback_requirements(): + logger.warning("[WecomCallback] aiohttp/httpx not installed") + return False + + # Quick port-in-use check. 
+ try: + with _socket.socket(_socket.AF_INET, _socket.SOCK_STREAM) as sock: + sock.settimeout(1) + sock.connect(("127.0.0.1", self._port)) + logger.error("[WecomCallback] Port %d already in use", self._port) + return False + except (ConnectionRefusedError, OSError): + pass + + try: + self._http_client = httpx.AsyncClient(timeout=20.0) + self._app = web.Application() + self._app.router.add_get("/health", self._handle_health) + self._app.router.add_get(self._path, self._handle_verify) + self._app.router.add_post(self._path, self._handle_callback) + self._runner = web.AppRunner(self._app) + await self._runner.setup() + self._site = web.TCPSite(self._runner, self._host, self._port) + await self._site.start() + self._poll_task = asyncio.create_task(self._poll_loop()) + self._mark_connected() + logger.info( + "[WecomCallback] HTTP server listening on %s:%s%s", + self._host, self._port, self._path, + ) + for app in self._apps: + try: + await self._refresh_access_token(app) + except Exception as exc: + logger.warning( + "[WecomCallback] Initial token refresh failed for app '%s': %s", + app.get("name", "default"), exc, + ) + return True + except Exception: + await self._cleanup() + logger.exception("[WecomCallback] Failed to start") + return False + + async def disconnect(self) -> None: + self._running = False + if self._poll_task: + self._poll_task.cancel() + try: + await self._poll_task + except asyncio.CancelledError: + pass + self._poll_task = None + await self._cleanup() + self._mark_disconnected() + logger.info("[WecomCallback] Disconnected") + + async def _cleanup(self) -> None: + self._site = None + if self._runner: + await self._runner.cleanup() + self._runner = None + self._app = None + if self._http_client: + await self._http_client.aclose() + self._http_client = None + + # ------------------------------------------------------------------ + # Outbound: proactive send via access-token API + # ------------------------------------------------------------------ + + 
async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + app = self._resolve_app_for_chat(chat_id) + touser = chat_id.split(":", 1)[1] if ":" in chat_id else chat_id + try: + token = await self._get_access_token(app) + payload = { + "touser": touser, + "msgtype": "text", + "agentid": int(str(app.get("agent_id") or 0)), + "text": {"content": content[:2048]}, + "safe": 0, + } + resp = await self._http_client.post( + f"https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token={token}", + json=payload, + ) + data = resp.json() + if data.get("errcode") != 0: + return SendResult(success=False, error=str(data)) + return SendResult( + success=True, + message_id=str(data.get("msgid", "")), + raw_response=data, + ) + except Exception as exc: + return SendResult(success=False, error=str(exc)) + + def _resolve_app_for_chat(self, chat_id: str) -> Dict[str, Any]: + """Pick the app associated with *chat_id*, falling back sensibly.""" + app_name = self._user_app_map.get(chat_id) + if not app_name and ":" not in chat_id: + # Legacy bare user_id — try to find a unique match. 
+ matching = [k for k in self._user_app_map if k.endswith(f":{chat_id}")] + if len(matching) == 1: + app_name = self._user_app_map.get(matching[0]) + app = self._get_app_by_name(app_name) if app_name else None + return app or self._apps[0] + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + return {"name": chat_id, "type": "dm"} + + # ------------------------------------------------------------------ + # Inbound: HTTP callback handlers + # ------------------------------------------------------------------ + + async def _handle_health(self, request: web.Request) -> web.Response: + return web.json_response({"status": "ok", "platform": "wecom_callback"}) + + async def _handle_verify(self, request: web.Request) -> web.Response: + """GET endpoint — WeCom URL verification handshake.""" + msg_signature = request.query.get("msg_signature", "") + timestamp = request.query.get("timestamp", "") + nonce = request.query.get("nonce", "") + echostr = request.query.get("echostr", "") + for app in self._apps: + try: + crypt = self._crypt_for_app(app) + plain = crypt.verify_url(msg_signature, timestamp, nonce, echostr) + return web.Response(text=plain, content_type="text/plain") + except Exception: + continue + return web.Response(status=403, text="signature verification failed") + + async def _handle_callback(self, request: web.Request) -> web.Response: + """POST endpoint — receive an encrypted message callback.""" + msg_signature = request.query.get("msg_signature", "") + timestamp = request.query.get("timestamp", "") + nonce = request.query.get("nonce", "") + body = await request.text() + + for app in self._apps: + try: + decrypted = self._decrypt_request( + app, body, msg_signature, timestamp, nonce, + ) + event = self._build_event(app, decrypted) + if event is not None: + # Record which app this user belongs to. 
+ if event.source and event.source.user_id: + map_key = self._user_app_key( + str(app.get("corp_id") or ""), event.source.user_id, + ) + self._user_app_map[map_key] = app["name"] + await self._message_queue.put(event) + # Immediately acknowledge — the agent's reply will arrive + # later via the proactive message/send API. + return web.Response(text="success", content_type="text/plain") + except WeComCryptoError: + continue + except Exception: + logger.exception("[WecomCallback] Error handling message") + break + return web.Response(status=400, text="invalid callback payload") + + async def _poll_loop(self) -> None: + """Drain the message queue and dispatch to the gateway runner.""" + while True: + event = await self._message_queue.get() + try: + task = asyncio.create_task(self.handle_message(event)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + except Exception: + logger.exception("[WecomCallback] Failed to enqueue event") + + # ------------------------------------------------------------------ + # XML / crypto helpers + # ------------------------------------------------------------------ + + def _decrypt_request( + self, app: Dict[str, Any], body: str, + msg_signature: str, timestamp: str, nonce: str, + ) -> str: + root = ET.fromstring(body) + encrypt = root.findtext("Encrypt", default="") + crypt = self._crypt_for_app(app) + return crypt.decrypt(msg_signature, timestamp, nonce, encrypt).decode("utf-8") + + def _build_event(self, app: Dict[str, Any], xml_text: str) -> Optional[MessageEvent]: + root = ET.fromstring(xml_text) + msg_type = (root.findtext("MsgType") or "").lower() + # Silently acknowledge lifecycle events. 
+ if msg_type == "event": + event_name = (root.findtext("Event") or "").lower() + if event_name in {"enter_agent", "subscribe"}: + return None + if msg_type not in {"text", "event"}: + return None + + user_id = root.findtext("FromUserName", default="") + corp_id = root.findtext("ToUserName", default=app.get("corp_id", "")) + scoped_chat_id = self._user_app_key(corp_id, user_id) + content = root.findtext("Content", default="").strip() + if not content and msg_type == "event": + content = "/start" + msg_id = ( + root.findtext("MsgId") + or f"{user_id}:{root.findtext('CreateTime', default='0')}" + ) + source = self.build_source( + chat_id=scoped_chat_id, + chat_name=user_id, + chat_type="dm", + user_id=user_id, + user_name=user_id, + ) + return MessageEvent( + text=content, + message_type=MessageType.TEXT, + source=source, + raw_message=xml_text, + message_id=msg_id, + ) + + def _crypt_for_app(self, app: Dict[str, Any]) -> WXBizMsgCrypt: + return WXBizMsgCrypt( + token=str(app.get("token") or ""), + encoding_aes_key=str(app.get("encoding_aes_key") or ""), + receive_id=str(app.get("corp_id") or ""), + ) + + def _get_app_by_name(self, name: Optional[str]) -> Optional[Dict[str, Any]]: + if not name: + return None + for app in self._apps: + if app.get("name") == name: + return app + return None + + # ------------------------------------------------------------------ + # Access-token management + # ------------------------------------------------------------------ + + async def _get_access_token(self, app: Dict[str, Any]) -> str: + cached = self._access_tokens.get(app["name"]) + now = time.time() + if cached and cached.get("expires_at", 0) > now + 60: + return cached["token"] + return await self._refresh_access_token(app) + + async def _refresh_access_token(self, app: Dict[str, Any]) -> str: + resp = await self._http_client.get( + "https://qyapi.weixin.qq.com/cgi-bin/gettoken", + params={ + "corpid": app.get("corp_id"), + "corpsecret": app.get("corp_secret"), + }, + ) + 
data = resp.json() + if data.get("errcode") != 0: + raise RuntimeError(f"WeCom token refresh failed: {data}") + token = data["access_token"] + expires_in = int(data.get("expires_in", ACCESS_TOKEN_TTL_SECONDS)) + self._access_tokens[app["name"]] = { + "token": token, + "expires_at": time.time() + expires_in, + } + logger.info( + "[WecomCallback] Token refreshed for app '%s' (corp=%s), expires in %ss", + app.get("name", "default"), + app.get("corp_id", ""), + expires_in, + ) + return token diff --git a/gateway/platforms/wecom_crypto.py b/gateway/platforms/wecom_crypto.py new file mode 100644 index 0000000000..f984ca80c3 --- /dev/null +++ b/gateway/platforms/wecom_crypto.py @@ -0,0 +1,142 @@ +"""WeCom BizMsgCrypt-compatible AES-CBC encryption for callback mode. + +Implements the same wire format as Tencent's official ``WXBizMsgCrypt`` +SDK so that WeCom can verify, encrypt, and decrypt callback payloads. +""" + +from __future__ import annotations + +import base64 +import hashlib +import os +import secrets +import socket +import struct +from typing import Optional +from xml.etree import ElementTree as ET + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + + +class WeComCryptoError(Exception): + pass + + +class SignatureError(WeComCryptoError): + pass + + +class DecryptError(WeComCryptoError): + pass + + +class EncryptError(WeComCryptoError): + pass + + +class PKCS7Encoder: + block_size = 32 + + @classmethod + def encode(cls, text: bytes) -> bytes: + amount_to_pad = cls.block_size - (len(text) % cls.block_size) + if amount_to_pad == 0: + amount_to_pad = cls.block_size + pad = bytes([amount_to_pad]) * amount_to_pad + return text + pad + + @classmethod + def decode(cls, decrypted: bytes) -> bytes: + if not decrypted: + raise DecryptError("empty decrypted payload") + pad = decrypted[-1] + if pad < 1 or pad > cls.block_size: + raise DecryptError("invalid PKCS7 padding") + if 
decrypted[-pad:] != bytes([pad]) * pad: + raise DecryptError("malformed PKCS7 padding") + return decrypted[:-pad] + + +def _sha1_signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str: + parts = sorted([token, timestamp, nonce, encrypt]) + return hashlib.sha1("".join(parts).encode("utf-8")).hexdigest() + + +class WXBizMsgCrypt: + """Minimal WeCom callback crypto helper compatible with BizMsgCrypt semantics.""" + + def __init__(self, token: str, encoding_aes_key: str, receive_id: str): + if not token: + raise ValueError("token is required") + if not encoding_aes_key: + raise ValueError("encoding_aes_key is required") + if len(encoding_aes_key) != 43: + raise ValueError("encoding_aes_key must be 43 chars") + if not receive_id: + raise ValueError("receive_id is required") + + self.token = token + self.receive_id = receive_id + self.key = base64.b64decode(encoding_aes_key + "=") + self.iv = self.key[:16] + + def verify_url(self, msg_signature: str, timestamp: str, nonce: str, echostr: str) -> str: + plain = self.decrypt(msg_signature, timestamp, nonce, echostr) + return plain.decode("utf-8") + + def decrypt(self, msg_signature: str, timestamp: str, nonce: str, encrypt: str) -> bytes: + expected = _sha1_signature(self.token, timestamp, nonce, encrypt) + if expected != msg_signature: + raise SignatureError("signature mismatch") + try: + cipher_text = base64.b64decode(encrypt) + except Exception as exc: + raise DecryptError(f"invalid base64 payload: {exc}") from exc + try: + cipher = Cipher(algorithms.AES(self.key), modes.CBC(self.iv), backend=default_backend()) + decryptor = cipher.decryptor() + padded = decryptor.update(cipher_text) + decryptor.finalize() + plain = PKCS7Encoder.decode(padded) + content = plain[16:] # skip 16-byte random prefix + xml_length = socket.ntohl(struct.unpack("I", content[:4])[0]) + xml_content = content[4:4 + xml_length] + receive_id = content[4 + xml_length:].decode("utf-8") + except WeComCryptoError: + raise + except 
Exception as exc: + raise DecryptError(f"decrypt failed: {exc}") from exc + + if receive_id != self.receive_id: + raise DecryptError("receive_id mismatch") + return xml_content + + def encrypt(self, plaintext: str, nonce: Optional[str] = None, timestamp: Optional[str] = None) -> str: + nonce = nonce or self._random_nonce() + timestamp = timestamp or str(int(__import__("time").time())) + encrypt = self._encrypt_bytes(plaintext.encode("utf-8")) + signature = _sha1_signature(self.token, timestamp, nonce, encrypt) + root = ET.Element("xml") + ET.SubElement(root, "Encrypt").text = encrypt + ET.SubElement(root, "MsgSignature").text = signature + ET.SubElement(root, "TimeStamp").text = timestamp + ET.SubElement(root, "Nonce").text = nonce + return ET.tostring(root, encoding="unicode") + + def _encrypt_bytes(self, raw: bytes) -> str: + try: + random_prefix = os.urandom(16) + msg_len = struct.pack("I", socket.htonl(len(raw))) + payload = random_prefix + msg_len + raw + self.receive_id.encode("utf-8") + padded = PKCS7Encoder.encode(payload) + cipher = Cipher(algorithms.AES(self.key), modes.CBC(self.iv), backend=default_backend()) + encryptor = cipher.encryptor() + encrypted = encryptor.update(padded) + encryptor.finalize() + return base64.b64encode(encrypted).decode("utf-8") + except Exception as exc: + raise EncryptError(f"encrypt failed: {exc}") from exc + + @staticmethod + def _random_nonce(length: int = 10) -> str: + alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + return "".join(secrets.choice(alphabet) for _ in range(length)) diff --git a/gateway/run.py b/gateway/run.py index d71b9a0128..9d6f149f13 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -76,7 +76,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) # Resolve Hermes home directory (respects HERMES_HOME override) from hermes_constants import get_hermes_home -from utils import atomic_yaml_write +from utils import atomic_yaml_write, is_truthy_value _hermes_home = 
get_hermes_home() # Load environment variables from ~/.hermes/.env first. @@ -352,19 +352,14 @@ def _build_media_placeholder(event) -> str: return "\n".join(parts) -def _dequeue_pending_text(adapter, session_key: str) -> str | None: - """Consume and return the text of a pending queued message. +def _dequeue_pending_event(adapter, session_key: str) -> MessageEvent | None: + """Consume and return the full pending event for a session. - Preserves media context for captionless photo/document events by - building a placeholder so the message isn't silently dropped. + Queued follow-ups must preserve their media metadata so they can re-enter + the normal image/STT/document preprocessing path instead of being reduced + to a placeholder string. """ - event = adapter.get_pending_message(session_key) - if not event: - return None - text = event.text - if not text and getattr(event, "media_urls", None): - text = _build_media_placeholder(event) - return text + return adapter.get_pending_message(session_key) def _check_unavailable_skill(command_name: str) -> str | None: @@ -930,6 +925,12 @@ class GatewayRunner: adapter.fatal_error_code or "unknown", adapter.fatal_error_message or "unknown error", ) + self._update_platform_runtime_status( + adapter.platform.value, + platform_state="retrying" if adapter.fatal_error_retryable else "fatal", + error_code=adapter.fatal_error_code, + error_message=adapter.fatal_error_message, + ) existing = self.adapters.get(adapter.platform) if existing is adapter: @@ -1007,6 +1008,25 @@ class GatewayRunner: ) except Exception: pass + + def _update_platform_runtime_status( + self, + platform: str, + *, + platform_state: Optional[str] = None, + error_code: Optional[str] = None, + error_message: Optional[str] = None, + ) -> None: + try: + from gateway.status import write_runtime_status + write_runtime_status( + platform=platform, + platform_state=platform_state, + error_code=error_code, + error_message=error_message, + ) + except Exception: + pass 
@staticmethod def _load_prefill_messages() -> List[Dict[str, Any]]: @@ -1440,6 +1460,7 @@ class GatewayRunner: "MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS", "FEISHU_ALLOWED_USERS", "WECOM_ALLOWED_USERS", + "WECOM_CALLBACK_ALLOWED_USERS", "WEIXIN_ALLOWED_USERS", "BLUEBUBBLES_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS") @@ -1453,6 +1474,7 @@ class GatewayRunner: "MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS", "FEISHU_ALLOW_ALL_USERS", "WECOM_ALLOW_ALL_USERS", + "WECOM_CALLBACK_ALLOW_ALL_USERS", "WEIXIN_ALLOW_ALL_USERS", "BLUEBUBBLES_ALLOW_ALL_USERS") ) @@ -1520,16 +1542,34 @@ class GatewayRunner: # Try to connect logger.info("Connecting to %s...", platform.value) + self._update_platform_runtime_status( + platform.value, + platform_state="connecting", + error_code=None, + error_message=None, + ) try: success = await adapter.connect() if success: self.adapters[platform] = adapter self._sync_voice_mode_state_to_adapter(adapter) connected_count += 1 + self._update_platform_runtime_status( + platform.value, + platform_state="connected", + error_code=None, + error_message=None, + ) logger.info("✓ %s connected", platform.value) else: logger.warning("✗ %s failed to connect", platform.value) if adapter.has_fatal_error: + self._update_platform_runtime_status( + platform.value, + platform_state="retrying" if adapter.fatal_error_retryable else "fatal", + error_code=adapter.fatal_error_code, + error_message=adapter.fatal_error_message, + ) target = ( startup_retryable_errors if adapter.fatal_error_retryable @@ -1546,6 +1586,12 @@ class GatewayRunner: "next_retry": time.monotonic() + 30, } else: + self._update_platform_runtime_status( + platform.value, + platform_state="retrying", + error_code=None, + error_message="failed to connect", + ) startup_retryable_errors.append( f"{platform.value}: failed to connect" ) @@ -1557,6 +1603,12 @@ class GatewayRunner: } except Exception as e: logger.error("✗ %s error: %s", platform.value, e) + self._update_platform_runtime_status( + 
platform.value, + platform_state="retrying", + error_code=None, + error_message=str(e), + ) startup_retryable_errors.append(f"{platform.value}: {e}") # Unexpected exceptions are typically transient — queue for retry self._failed_platforms[platform] = { @@ -1835,6 +1887,12 @@ class GatewayRunner: self._sync_voice_mode_state_to_adapter(adapter) self.delivery_router.adapters = self.adapters del self._failed_platforms[platform] + self._update_platform_runtime_status( + platform.value, + platform_state="connected", + error_code=None, + error_message=None, + ) logger.info("✓ %s reconnected successfully", platform.value) # Rebuild channel directory with the new adapter @@ -1846,12 +1904,24 @@ class GatewayRunner: else: # Check if the failure is non-retryable if adapter.has_fatal_error and not adapter.fatal_error_retryable: + self._update_platform_runtime_status( + platform.value, + platform_state="fatal", + error_code=adapter.fatal_error_code, + error_message=adapter.fatal_error_message, + ) logger.warning( "Reconnect %s: non-retryable error (%s), removing from retry queue", platform.value, adapter.fatal_error_message, ) del self._failed_platforms[platform] else: + self._update_platform_runtime_status( + platform.value, + platform_state="retrying", + error_code=adapter.fatal_error_code, + error_message=adapter.fatal_error_message or "failed to reconnect", + ) backoff = min(30 * (2 ** (attempt - 1)), _BACKOFF_CAP) info["attempts"] = attempt info["next_retry"] = time.monotonic() + backoff @@ -1860,6 +1930,12 @@ class GatewayRunner: platform.value, backoff, ) except Exception as e: + self._update_platform_runtime_status( + platform.value, + platform_state="retrying", + error_code=None, + error_message=str(e), + ) backoff = min(30 * (2 ** (attempt - 1)), _BACKOFF_CAP) info["attempts"] = attempt info["next_retry"] = time.monotonic() + backoff @@ -2081,6 +2157,16 @@ class GatewayRunner: return None return FeishuAdapter(config) + elif platform == Platform.WECOM_CALLBACK: + from 
gateway.platforms.wecom_callback import ( + WecomCallbackAdapter, + check_wecom_callback_requirements, + ) + if not check_wecom_callback_requirements(): + logger.warning("WeComCallback: aiohttp/httpx not installed") + return None + return WecomCallbackAdapter(config) + elif platform == Platform.WECOM: from gateway.platforms.wecom import WeComAdapter, check_wecom_requirements if not check_wecom_requirements(): @@ -2170,6 +2256,7 @@ class GatewayRunner: Platform.DINGTALK: "DINGTALK_ALLOWED_USERS", Platform.FEISHU: "FEISHU_ALLOWED_USERS", Platform.WECOM: "WECOM_ALLOWED_USERS", + Platform.WECOM_CALLBACK: "WECOM_CALLBACK_ALLOWED_USERS", Platform.WEIXIN: "WEIXIN_ALLOWED_USERS", Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS", } @@ -2186,6 +2273,7 @@ class GatewayRunner: Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS", Platform.FEISHU: "FEISHU_ALLOW_ALL_USERS", Platform.WECOM: "WECOM_ALLOW_ALL_USERS", + Platform.WECOM_CALLBACK: "WECOM_CALLBACK_ALLOW_ALL_USERS", Platform.WEIXIN: "WEIXIN_ALLOW_ALL_USERS", Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS", } @@ -2821,6 +2909,162 @@ class GatewayRunner: del self._running_agents[_quick_key] self._running_agents_ts.pop(_quick_key, None) + async def _prepare_inbound_message_text( + self, + *, + event: MessageEvent, + source: SessionSource, + history: List[Dict[str, Any]], + ) -> Optional[str]: + """Prepare inbound event text for the agent. + + Keep the normal inbound path and the queued follow-up path on the same + preprocessing pipeline so sender attribution, image enrichment, STT, + document notes, reply context, and @ references all behave the same. 
+ """ + history = history or [] + message_text = event.text or "" + + _is_shared_thread = ( + source.chat_type != "dm" + and source.thread_id + and not getattr(self.config, "thread_sessions_per_user", False) + ) + if _is_shared_thread and source.user_name: + message_text = f"[{source.user_name}] {message_text}" + + if event.media_urls: + image_paths = [] + audio_paths = [] + for i, path in enumerate(event.media_urls): + mtype = event.media_types[i] if i < len(event.media_types) else "" + if mtype.startswith("image/") or event.message_type == MessageType.PHOTO: + image_paths.append(path) + if mtype.startswith("audio/") or event.message_type in (MessageType.VOICE, MessageType.AUDIO): + audio_paths.append(path) + + if image_paths: + message_text = await self._enrich_message_with_vision( + message_text, + image_paths, + ) + + if audio_paths: + message_text = await self._enrich_message_with_transcription( + message_text, + audio_paths, + ) + _stt_fail_markers = ( + "No STT provider", + "STT is disabled", + "can't listen", + "VOICE_TOOLS_OPENAI_KEY", + ) + if any(marker in message_text for marker in _stt_fail_markers): + _stt_adapter = self.adapters.get(source.platform) + _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None + if _stt_adapter: + try: + _stt_msg = ( + "🎤 I received your voice message but can't transcribe it — " + "no speech-to-text provider is configured.\n\n" + "To enable voice: install faster-whisper " + "(`pip install faster-whisper` in the Hermes venv) " + "and set `stt.enabled: true` in config.yaml, " + "then /restart the gateway." 
+ ) + if self._has_setup_skill(): + _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`" + await _stt_adapter.send( + source.chat_id, + _stt_msg, + metadata=_stt_meta, + ) + except Exception: + pass + + if event.media_urls and event.message_type == MessageType.DOCUMENT: + import mimetypes as _mimetypes + + _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"} + for i, path in enumerate(event.media_urls): + mtype = event.media_types[i] if i < len(event.media_types) else "" + if mtype in ("", "application/octet-stream"): + import os as _os2 + + _ext = _os2.path.splitext(path)[1].lower() + if _ext in _TEXT_EXTENSIONS: + mtype = "text/plain" + else: + guessed, _ = _mimetypes.guess_type(path) + if guessed: + mtype = guessed + if not mtype.startswith(("application/", "text/")): + continue + + import os as _os + import re as _re + + basename = _os.path.basename(path) + parts = basename.split("_", 2) + display_name = parts[2] if len(parts) >= 3 else basename + display_name = _re.sub(r'[^\w.\- ]', '_', display_name) + + if mtype.startswith("text/"): + context_note = ( + f"[The user sent a text document: '{display_name}'. " + f"Its content has been included below. " + f"The file is also saved at: {path}]" + ) + else: + context_note = ( + f"[The user sent a document: '{display_name}'. " + f"The file is saved at: {path}. 
" + f"Ask the user what they'd like you to do with it.]" + ) + message_text = f"{context_note}\n\n{message_text}" + + if getattr(event, "reply_to_text", None) and event.reply_to_message_id: + reply_snippet = event.reply_to_text[:500] + found_in_history = any( + reply_snippet[:200] in (msg.get("content") or "") + for msg in history + if msg.get("role") in ("assistant", "user", "tool") + ) + if not found_in_history: + message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}' + + if "@" in message_text: + try: + from agent.context_references import preprocess_context_references_async + from agent.model_metadata import get_model_context_length + + _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~")) + _msg_ctx_len = get_model_context_length( + self._model, + base_url=self._base_url or "", + ) + _ctx_result = await preprocess_context_references_async( + message_text, + cwd=_msg_cwd, + context_length=_msg_ctx_len, + allowed_root=_msg_cwd, + ) + if _ctx_result.blocked: + _adapter = self.adapters.get(source.platform) + if _adapter: + await _adapter.send( + source.chat_id, + "\n".join(_ctx_result.warnings) or "Context injection refused.", + ) + return None + if _ctx_result.expanded: + message_text = _ctx_result.message + except Exception as exc: + logger.debug("@ context reference expansion failed: %s", exc) + + return message_text + async def _handle_message_with_agent(self, event, source, _quick_key: str): """Inner handler that runs under the _running_agents sentinel guard.""" _msg_start_time = time.time() @@ -3261,149 +3505,13 @@ class GatewayRunner: # attachments (documents, audio, etc.) are not sent to the vision # tool even when they appear in the same message. # ----------------------------------------------------------------- - message_text = event.text or "" - - # ----------------------------------------------------------------- - # Sender attribution for shared thread sessions. 
- # - # When multiple users share a single thread session (the default for - # threads), prefix each message with [sender name] so the agent can - # tell participants apart. Skip for DMs (single-user by nature) and - # when per-user thread isolation is explicitly enabled. - # ----------------------------------------------------------------- - _is_shared_thread = ( - source.chat_type != "dm" - and source.thread_id - and not getattr(self.config, "thread_sessions_per_user", False) + message_text = await self._prepare_inbound_message_text( + event=event, + source=source, + history=history, ) - if _is_shared_thread and source.user_name: - message_text = f"[{source.user_name}] {message_text}" - - if event.media_urls: - image_paths = [] - for i, path in enumerate(event.media_urls): - # Check media_types if available; otherwise infer from message type - mtype = event.media_types[i] if i < len(event.media_types) else "" - is_image = ( - mtype.startswith("image/") - or event.message_type == MessageType.PHOTO - ) - if is_image: - image_paths.append(path) - if image_paths: - message_text = await self._enrich_message_with_vision( - message_text, image_paths - ) - - # ----------------------------------------------------------------- - # Auto-transcribe voice/audio messages sent by the user - # ----------------------------------------------------------------- - if event.media_urls: - audio_paths = [] - for i, path in enumerate(event.media_urls): - mtype = event.media_types[i] if i < len(event.media_types) else "" - is_audio = ( - mtype.startswith("audio/") - or event.message_type in (MessageType.VOICE, MessageType.AUDIO) - ) - if is_audio: - audio_paths.append(path) - if audio_paths: - message_text = await self._enrich_message_with_transcription( - message_text, audio_paths - ) - # If STT failed, send a direct message to the user so they - # know voice isn't configured — don't rely on the agent to - # relay the error clearly. 
- _stt_fail_markers = ( - "No STT provider", - "STT is disabled", - "can't listen", - "VOICE_TOOLS_OPENAI_KEY", - ) - if any(m in message_text for m in _stt_fail_markers): - _stt_adapter = self.adapters.get(source.platform) - _stt_meta = {"thread_id": source.thread_id} if source.thread_id else None - if _stt_adapter: - try: - _stt_msg = ( - "🎤 I received your voice message but can't transcribe it — " - "no speech-to-text provider is configured.\n\n" - "To enable voice: install faster-whisper " - "(`pip install faster-whisper` in the Hermes venv) " - "and set `stt.enabled: true` in config.yaml, " - "then /restart the gateway." - ) - # Point to setup skill if it's installed - if self._has_setup_skill(): - _stt_msg += "\n\nFor full setup instructions, type: `/skill hermes-agent-setup`" - await _stt_adapter.send( - source.chat_id, _stt_msg, - metadata=_stt_meta, - ) - except Exception: - pass - - # ----------------------------------------------------------------- - # Enrich document messages with context notes for the agent - # ----------------------------------------------------------------- - if event.media_urls and event.message_type == MessageType.DOCUMENT: - import mimetypes as _mimetypes - _TEXT_EXTENSIONS = {".txt", ".md", ".csv", ".log", ".json", ".xml", ".yaml", ".yml", ".toml", ".ini", ".cfg"} - for i, path in enumerate(event.media_urls): - mtype = event.media_types[i] if i < len(event.media_types) else "" - # Fall back to extension-based detection when MIME type is unreliable. 
- if mtype in ("", "application/octet-stream"): - import os as _os2 - _ext = _os2.path.splitext(path)[1].lower() - if _ext in _TEXT_EXTENSIONS: - mtype = "text/plain" - else: - guessed, _ = _mimetypes.guess_type(path) - if guessed: - mtype = guessed - if not mtype.startswith(("application/", "text/")): - continue - # Extract display filename by stripping the doc_{uuid12}_ prefix - import os as _os - basename = _os.path.basename(path) - # Format: doc_<12hex>_ - parts = basename.split("_", 2) - display_name = parts[2] if len(parts) >= 3 else basename - # Sanitize to prevent prompt injection via filenames - import re as _re - display_name = _re.sub(r'[^\w.\- ]', '_', display_name) - - if mtype.startswith("text/"): - context_note = ( - f"[The user sent a text document: '{display_name}'. " - f"Its content has been included below. " - f"The file is also saved at: {path}]" - ) - else: - context_note = ( - f"[The user sent a document: '{display_name}'. " - f"The file is saved at: {path}. " - f"Ask the user what they'd like you to do with it.]" - ) - message_text = f"{context_note}\n\n{message_text}" - - # ----------------------------------------------------------------- - # Inject reply context when user replies to a message not in history. - # Telegram (and other platforms) let users reply to specific messages, - # but if the quoted message is from a previous session, cron delivery, - # or background task, the agent has no context about what's being - # referenced. Prepend the quoted text so the agent understands. 
(#1594) - # ----------------------------------------------------------------- - if getattr(event, 'reply_to_text', None) and event.reply_to_message_id: - reply_snippet = event.reply_to_text[:500] - found_in_history = any( - reply_snippet[:200] in (msg.get("content") or "") - for msg in history - if msg.get("role") in ("assistant", "user", "tool") - ) - if not found_in_history: - message_text = f'[Replying to: "{reply_snippet}"]\n\n{message_text}' + if message_text is None: + return try: # Emit agent:start hook @@ -3415,30 +3523,6 @@ class GatewayRunner: } await self.hooks.emit("agent:start", hook_ctx) - # Expand @ context references (@file:, @folder:, @diff, etc.) - if "@" in message_text: - try: - from agent.context_references import preprocess_context_references_async - from agent.model_metadata import get_model_context_length - _msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~")) - _msg_ctx_len = get_model_context_length( - self._model, base_url=self._base_url or "") - _ctx_result = await preprocess_context_references_async( - message_text, cwd=_msg_cwd, - context_length=_msg_ctx_len, allowed_root=_msg_cwd) - if _ctx_result.blocked: - _adapter = self.adapters.get(source.platform) - if _adapter: - await _adapter.send( - source.chat_id, - "\n".join(_ctx_result.warnings) or "Context injection refused.", - ) - return - if _ctx_result.expanded: - message_text = _ctx_result.message - except Exception as exc: - logger.debug("@ context reference expansion failed: %s", exc) - # Run the agent agent_result = await self._run_agent( message=message_text, @@ -3502,8 +3586,18 @@ class GatewayRunner: if agent_result.get("session_id") and agent_result["session_id"] != session_entry.session_id: session_entry.session_id = agent_result["session_id"] - # Prepend reasoning/thinking if display is enabled - if getattr(self, "_show_reasoning", False) and response: + # Prepend reasoning/thinking if display is enabled (per-platform) + try: + from gateway.display_config 
import resolve_display_setting as _rds + _show_reasoning_effective = _rds( + _load_gateway_config(), + _platform_config_key(source.platform), + "show_reasoning", + getattr(self, "_show_reasoning", False), + ) + except Exception: + _show_reasoning_effective = getattr(self, "_show_reasoning", False) + if _show_reasoning_effective and response: last_reasoning = agent_result.get("last_reasoning") if last_reasoning: # Collapse long reasoning to keep messages readable @@ -5489,16 +5583,20 @@ class GatewayRunner: "_Usage:_ `/reasoning `" ) - # Display toggle + # Display toggle (per-platform) + platform_key = _platform_config_key(event.source.platform) if args in ("show", "on"): self._show_reasoning = True - _save_config_key("display.show_reasoning", True) - return "🧠 ✓ Reasoning display: **ON**\nModel thinking will be shown before each response." + _save_config_key(f"display.platforms.{platform_key}.show_reasoning", True) + return ( + "🧠 ✓ Reasoning display: **ON**\n" + f"Model thinking will be shown before each response on **{platform_key}**." + ) if args in ("hide", "off"): self._show_reasoning = False - _save_config_key("display.show_reasoning", False) - return "🧠 ✓ Reasoning display: **OFF**" + _save_config_key(f"display.platforms.{platform_key}.show_reasoning", False) + return f"🧠 ✓ Reasoning display: **OFF** for **{platform_key}**" # Effort level change effort = args.strip() @@ -5601,11 +5699,14 @@ class GatewayRunner: Gated by ``display.tool_progress_command`` in config.yaml (default off). When enabled, cycles the tool progress mode through off → new → all → - verbose → off, same as the CLI. + verbose → off for the *current platform*. The setting is saved to + ``display.platforms..tool_progress`` so each channel can + have its own verbosity level independently. 
""" import yaml config_path = _hermes_home / "config.yaml" + platform_key = _platform_config_key(event.source.platform) # --- check config gate ------------------------------------------------ try: @@ -5624,7 +5725,7 @@ class GatewayRunner: "display:\n tool_progress_command: true\n```" ) - # --- cycle mode ------------------------------------------------------- + # --- cycle mode (per-platform) ---------------------------------------- cycle = ["off", "new", "all", "verbose"] descriptions = { "off": "⚙️ Tool progress: **OFF** — no tool activity shown.", @@ -5633,26 +5734,29 @@ class GatewayRunner: "verbose": "⚙️ Tool progress: **VERBOSE** — every tool call with full arguments.", } - raw_progress = user_config.get("display", {}).get("tool_progress", "all") - # YAML 1.1 parses bare "off" as boolean False — normalise back - if raw_progress is False: - current = "off" - elif raw_progress is True: - current = "all" - else: - current = str(raw_progress).lower() + # Read current effective mode for this platform via the resolver + from gateway.display_config import resolve_display_setting + current = resolve_display_setting(user_config, platform_key, "tool_progress", "all") if current not in cycle: current = "all" idx = (cycle.index(current) + 1) % len(cycle) new_mode = cycle[idx] - # Save to config.yaml + # Save to display.platforms..tool_progress try: if "display" not in user_config or not isinstance(user_config.get("display"), dict): user_config["display"] = {} - user_config["display"]["tool_progress"] = new_mode + display = user_config["display"] + if "platforms" not in display or not isinstance(display.get("platforms"), dict): + display["platforms"] = {} + if platform_key not in display["platforms"] or not isinstance(display["platforms"].get(platform_key), dict): + display["platforms"][platform_key] = {} + display["platforms"][platform_key]["tool_progress"] = new_mode atomic_yaml_write(config_path, user_config) - return f"{descriptions[new_mode]}\n_(saved to config — 
takes effect on next message)_" + return ( + f"{descriptions[new_mode]}\n" + f"_(saved for **{platform_key}** — takes effect on next message)_" + ) except Exception as e: logger.warning("Failed to save tool_progress mode: %s", e) return f"{descriptions[new_mode]}\n_(could not save to config: {e})_" @@ -6273,7 +6377,7 @@ class GatewayRunner: Platform.TELEGRAM, Platform.DISCORD, Platform.SLACK, Platform.WHATSAPP, Platform.SIGNAL, Platform.MATTERMOST, Platform.MATRIX, Platform.HOMEASSISTANT, Platform.EMAIL, Platform.SMS, Platform.DINGTALK, - Platform.FEISHU, Platform.WECOM, Platform.WEIXIN, Platform.BLUEBUBBLES, Platform.LOCAL, + Platform.FEISHU, Platform.WECOM, Platform.WECOM_CALLBACK, Platform.WEIXIN, Platform.BLUEBUBBLES, Platform.LOCAL, }) async def _handle_update_command(self, event: MessageEvent) -> str: @@ -6686,6 +6790,7 @@ class GatewayRunner: thread_id=str(context.source.thread_id) if context.source.thread_id else "", user_id=str(context.source.user_id) if context.source.user_id else "", user_name=str(context.source.user_name) if context.source.user_name else "", + session_key=context.session_key, ) def _clear_session_env(self, tokens: list) -> None: @@ -7127,31 +7232,27 @@ class GatewayRunner: from hermes_cli.tools_config import _get_platform_tools enabled_toolsets = sorted(_get_platform_tools(user_config, platform_key)) + display_config = user_config.get("display", {}) + if not isinstance(display_config, dict): + display_config = {} + + # Per-platform display settings — resolve via display_config module + # which checks display.platforms.. first, then + # display. global, then built-in platform defaults. 
+ from gateway.display_config import resolve_display_setting + # Apply tool preview length config (0 = no limit) try: from agent.display import set_tool_preview_max_len - _tpl = user_config.get("display", {}).get("tool_preview_length", 0) + _tpl = resolve_display_setting(user_config, platform_key, "tool_preview_length", 0) set_tool_preview_max_len(int(_tpl) if _tpl else 0) except Exception: pass - # Tool progress mode from config.yaml: "all", "new", "verbose", "off" - # Falls back to env vars for backward compatibility. - # YAML 1.1 parses bare `off` as boolean False — normalise before - # the `or` chain so it doesn't silently fall through to "all". - # - # Per-platform overrides (display.tool_progress_overrides) take - # priority over the global setting — e.g. Signal users can set - # tool_progress to "off" while keeping Telegram on "all". - _display_cfg = user_config.get("display", {}) - _overrides = _display_cfg.get("tool_progress_overrides", {}) - _raw_tp = _overrides.get(platform_key) - if _raw_tp is None: - _raw_tp = _display_cfg.get("tool_progress") - if _raw_tp is False: - _raw_tp = "off" + # Tool progress mode — resolved per-platform with env var fallback + _resolved_tp = resolve_display_setting(user_config, platform_key, "tool_progress") progress_mode = ( - _raw_tp + _resolved_tp or os.getenv("HERMES_TOOL_PROGRESS_MODE") or "all" ) @@ -7159,6 +7260,16 @@ class GatewayRunner: # so each progress line would be sent as a separate message. from gateway.config import Platform tool_progress_enabled = progress_mode != "off" and source.platform != Platform.WEBHOOK + # Natural assistant status messages are intentionally independent from + # tool progress and token streaming. Users can keep tool_progress quiet + # in chat platforms while opting into concise mid-turn updates. 
+ interim_assistant_messages_enabled = ( + source.platform != Platform.WEBHOOK + and is_truthy_value( + display_config.get("interim_assistant_messages"), + default=True, + ) + ) # Queue for progress messages (thread-safe) progress_queue = queue.Queue() if tool_progress_enabled else None @@ -7428,8 +7539,8 @@ class GatewayRunner: # `_resolve_turn_agent_config(message, …)`. nonlocal message - # Pass session_key to process registry via env var so background - # processes can be mapped back to this gateway session + # session_key is now set via contextvars in _set_session_env() + # (concurrency-safe). Keep os.environ as fallback for CLI/cron. os.environ["HERMES_SESSION_KEY"] = session_key or "" # Read from env var or use default (same as CLI) @@ -7471,7 +7582,7 @@ class GatewayRunner: reasoning_config = self._load_reasoning_config() self._reasoning_config = reasoning_config self._service_tier = self._load_service_tier() - # Set up streaming consumer if enabled + # Set up stream consumer for token streaming or interim commentary. _stream_consumer = None _stream_delta_cb = None _scfg = getattr(getattr(self, 'config', None), 'streaming', None) @@ -7479,7 +7590,22 @@ class GatewayRunner: from gateway.config import StreamingConfig _scfg = StreamingConfig() - if _scfg.enabled and _scfg.transport != "off": + # Per-platform streaming gate: display.platforms..streaming + # can disable streaming for specific platforms even when the global + # streaming config is enabled. 
+ _plat_streaming = resolve_display_setting( + user_config, platform_key, "streaming" + ) + # None = no per-platform override → follow global config + _streaming_enabled = ( + _scfg.enabled and _scfg.transport != "off" + if _plat_streaming is None + else bool(_plat_streaming) + ) + _want_stream_deltas = _streaming_enabled + _want_interim_messages = interim_assistant_messages_enabled + _want_interim_consumer = _want_interim_messages + if _want_stream_deltas or _want_interim_consumer: try: from gateway.stream_consumer import GatewayStreamConsumer, StreamConsumerConfig _adapter = self.adapters.get(source.platform) @@ -7495,11 +7621,33 @@ class GatewayRunner: config=_consumer_cfg, metadata={"thread_id": _progress_thread_id} if _progress_thread_id else None, ) - _stream_delta_cb = _stream_consumer.on_delta + if _want_stream_deltas: + _stream_delta_cb = _stream_consumer.on_delta stream_consumer_holder[0] = _stream_consumer except Exception as _sc_err: logger.debug("Could not set up stream consumer: %s", _sc_err) + def _interim_assistant_cb(text: str, *, already_streamed: bool = False) -> None: + if _stream_consumer is not None: + if already_streamed: + _stream_consumer.on_segment_break() + else: + _stream_consumer.on_commentary(text) + return + if already_streamed or not _status_adapter or not str(text or "").strip(): + return + try: + asyncio.run_coroutine_threadsafe( + _status_adapter.send( + _status_chat_id, + text, + metadata=_status_thread_metadata, + ), + _loop_for_step, + ) + except Exception as _e: + logger.debug("interim_assistant_callback error: %s", _e) + turn_route = self._resolve_turn_agent_config(message, model, runtime_kwargs) # Check agent cache — reuse the AIAgent from the previous message @@ -7557,6 +7705,7 @@ class GatewayRunner: agent.tool_progress_callback = progress_callback if tool_progress_enabled else None agent.step_callback = _step_callback_sync if _hooks_ref.loaded_hooks else None agent.stream_delta_callback = _stream_delta_cb + 
agent.interim_assistant_callback = _interim_assistant_cb if _want_interim_messages else None agent.status_callback = _status_callback_sync agent.reasoning_config = reasoning_config agent.service_tier = self._service_tier @@ -7860,6 +8009,7 @@ class GatewayRunner: "output_tokens": _output_toks, "model": _resolved_model, "session_id": effective_session_id, + "response_previewed": result.get("response_previewed", False), } # Start progress message sender if enabled @@ -8111,17 +8261,16 @@ class GatewayRunner: # Get pending message from adapter. # Use session_key (not source.chat_id) to match adapter's storage keys. + pending_event = None pending = None if result and adapter and session_key: - if result.get("interrupted"): - pending = _dequeue_pending_text(adapter, session_key) - if not pending and result.get("interrupt_message"): - pending = result.get("interrupt_message") - else: - pending = _dequeue_pending_text(adapter, session_key) - if pending: - logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) - + pending_event = _dequeue_pending_event(adapter, session_key) + if result.get("interrupted") and not pending_event and result.get("interrupt_message"): + pending = result.get("interrupt_message") + elif pending_event: + pending = pending_event.text or _build_media_placeholder(pending_event) + logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) + # Safety net: if the pending text is a slash command (e.g. "/stop", # "/new"), discard it — commands should never be passed to the agent # as user input. 
The primary fix is in base.py (commands bypass the @@ -8139,27 +8288,29 @@ class GatewayRunner: "commands must not be passed as agent input", _pending_cmd_word, ) + pending_event = None pending = None except Exception: pass - if self._draining and pending: + if self._draining and (pending_event or pending): logger.info( "Discarding pending follow-up for session %s during gateway %s", session_key[:20] if session_key else "?", self._status_action_label(), ) + pending_event = None pending = None - if pending: + if pending_event or pending: logger.debug("Processing pending message: '%s...'", pending[:40]) - + # Clear the adapter's interrupt event so the next _run_agent call # doesn't immediately re-trigger the interrupt before the new agent # even makes its first API call (this was causing an infinite loop). if adapter and hasattr(adapter, '_active_sessions') and session_key and session_key in adapter._active_sessions: adapter._active_sessions[session_key].clear() - + # Cap recursion depth to prevent resource exhaustion when the # user sends multiple messages while the agent keeps failing. (#816) if _interrupt_depth >= self._MAX_INTERRUPT_DEPTH: @@ -8168,9 +8319,10 @@ class GatewayRunner: "queueing message instead of recursing.", _interrupt_depth, session_key, ) - # Queue the pending message for normal processing on next turn adapter = self.adapters.get(source.platform) - if adapter and hasattr(adapter, 'queue_message'): + if adapter and pending_event: + merge_pending_message_event(adapter._pending_messages, session_key, pending_event) + elif adapter and hasattr(adapter, 'queue_message'): adapter.queue_message(session_key, pending) return result_holder[0] or {"final_response": response, "messages": history} @@ -8180,28 +8332,66 @@ class GatewayRunner: # response before processing the queued follow-up. # Skip if streaming already delivered it. 
_sc = stream_consumer_holder[0] - _already_streamed = _sc and getattr(_sc, "already_sent", False) + if _sc and stream_task: + try: + await asyncio.wait_for(stream_task, timeout=5.0) + except (asyncio.TimeoutError, asyncio.CancelledError): + stream_task.cancel() + try: + await stream_task + except asyncio.CancelledError: + pass + except Exception as e: + logger.debug("Stream consumer wait before queued message failed: %s", e) + _response_previewed = bool(result.get("response_previewed")) + _already_streamed = bool( + _sc + and ( + getattr(_sc, "final_response_sent", False) + or ( + _response_previewed + and getattr(_sc, "already_sent", False) + ) + ) + ) first_response = result.get("final_response", "") if first_response and not _already_streamed: try: - await adapter.send(source.chat_id, first_response, - metadata=getattr(event, "metadata", None)) + await adapter.send( + source.chat_id, + first_response, + metadata=_status_thread_metadata, + ) except Exception as e: logger.warning("Failed to send first response before queued message: %s", e) # else: interrupted — discard the interrupted response ("Operation # interrupted." is just noise; the user already knows they sent a # new message). 
- # Process the pending message with updated history updated_history = result.get("messages", history) + next_source = source + next_message = pending + next_message_id = None + if pending_event is not None: + next_source = getattr(pending_event, "source", None) or source + next_message = await self._prepare_inbound_message_text( + event=pending_event, + source=next_source, + history=updated_history, + ) + if next_message is None: + return result + next_message_id = getattr(pending_event, "message_id", None) + return await self._run_agent( - message=pending, + message=next_message, context_prompt=context_prompt, history=updated_history, - source=source, + source=next_source, session_id=session_id, session_key=session_key, _interrupt_depth=_interrupt_depth + 1, + event_message_id=next_message_id, ) finally: # Stop progress sender, interrupt monitor, and notification task @@ -8244,8 +8434,15 @@ class GatewayRunner: # message is new content the user hasn't seen, and it must reach # them even if streaming had sent earlier partial output. _sc = stream_consumer_holder[0] - if _sc and _sc.already_sent and isinstance(response, dict): - if not response.get("failed"): + if _sc and isinstance(response, dict) and not response.get("failed"): + _response_previewed = bool(response.get("response_previewed")) + if ( + getattr(_sc, "final_response_sent", False) + or ( + _response_previewed + and getattr(_sc, "already_sent", False) + ) + ): response["already_sent"] = True return response @@ -8394,23 +8591,11 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = except Exception: pass - # Centralized logging — agent.log (INFO+) and errors.log (WARNING+). + # Centralized logging — agent.log (INFO+), errors.log (WARNING+), + # and gateway.log (INFO+, gateway-component records only). # Idempotent, so repeated calls from AIAgent.__init__ won't duplicate. 
from hermes_logging import setup_logging - log_dir = setup_logging(hermes_home=_hermes_home, mode="gateway") - - # Gateway-specific rotating log — captures all gateway-level messages - # (session management, platform adapters, slash commands, etc.). - from agent.redact import RedactingFormatter - from hermes_logging import _add_rotating_handler - _add_rotating_handler( - logging.getLogger(), - log_dir / 'gateway.log', - level=logging.INFO, - max_bytes=5 * 1024 * 1024, - backup_count=3, - formatter=RedactingFormatter('%(asctime)s %(levelname)s %(name)s: %(message)s'), - ) + setup_logging(hermes_home=_hermes_home, mode="gateway") # Optional stderr handler — level driven by -v/-q flags on the CLI. # verbosity=None (-q/--quiet): no stderr output diff --git a/gateway/session.py b/gateway/session.py index 3492796b08..323f7e51ff 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -817,9 +817,9 @@ class SessionStore: to avoid resetting long-idle sessions that are harmless to resume. Returns the number of sessions that were suspended. 
""" - import time as _time + from datetime import timedelta - cutoff = _time.time() - max_age_seconds + cutoff = _now() - timedelta(seconds=max_age_seconds) count = 0 with self._lock: self._ensure_loaded_locked() diff --git a/gateway/session_context.py b/gateway/session_context.py index 6d676dc1ec..b9fdcdfaf7 100644 --- a/gateway/session_context.py +++ b/gateway/session_context.py @@ -48,6 +48,7 @@ _SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", def _SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="") _SESSION_USER_ID: ContextVar[str] = ContextVar("HERMES_SESSION_USER_ID", default="") _SESSION_USER_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_USER_NAME", default="") +_SESSION_KEY: ContextVar[str] = ContextVar("HERMES_SESSION_KEY", default="") _VAR_MAP = { "HERMES_SESSION_PLATFORM": _SESSION_PLATFORM, @@ -56,6 +57,7 @@ _VAR_MAP = { "HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID, "HERMES_SESSION_USER_ID": _SESSION_USER_ID, "HERMES_SESSION_USER_NAME": _SESSION_USER_NAME, + "HERMES_SESSION_KEY": _SESSION_KEY, } @@ -66,6 +68,7 @@ def set_session_vars( thread_id: str = "", user_id: str = "", user_name: str = "", + session_key: str = "", ) -> list: """Set all session context variables and return reset tokens. 
@@ -82,6 +85,7 @@ def set_session_vars( _SESSION_THREAD_ID.set(thread_id), _SESSION_USER_ID.set(user_id), _SESSION_USER_NAME.set(user_name), + _SESSION_KEY.set(session_key), ] return tokens @@ -97,6 +101,7 @@ def clear_session_vars(tokens: list) -> None: _SESSION_THREAD_ID, _SESSION_USER_ID, _SESSION_USER_NAME, + _SESSION_KEY, ] for var, token in zip(vars_in_order, tokens): var.reset(token) diff --git a/gateway/status.py b/gateway/status.py index 5423461c2f..d7f357b363 100644 --- a/gateway/status.py +++ b/gateway/status.py @@ -26,6 +26,7 @@ _GATEWAY_KIND = "hermes-gateway" _RUNTIME_STATUS_FILE = "gateway_state.json" _LOCKS_DIRNAME = "gateway-locks" _IS_WINDOWS = sys.platform == "win32" +_UNSET = object() def _get_pid_path() -> Path: @@ -218,14 +219,14 @@ def write_pid_file() -> None: def write_runtime_status( *, - gateway_state: Optional[str] = None, - exit_reason: Optional[str] = None, - restart_requested: Optional[bool] = None, - active_agents: Optional[int] = None, - platform: Optional[str] = None, - platform_state: Optional[str] = None, - error_code: Optional[str] = None, - error_message: Optional[str] = None, + gateway_state: Any = _UNSET, + exit_reason: Any = _UNSET, + restart_requested: Any = _UNSET, + active_agents: Any = _UNSET, + platform: Any = _UNSET, + platform_state: Any = _UNSET, + error_code: Any = _UNSET, + error_message: Any = _UNSET, ) -> None: """Persist gateway runtime health information for diagnostics/status.""" path = _get_runtime_status_path() @@ -236,22 +237,22 @@ def write_runtime_status( payload["start_time"] = _get_process_start_time(os.getpid()) payload["updated_at"] = _utc_now_iso() - if gateway_state is not None: + if gateway_state is not _UNSET: payload["gateway_state"] = gateway_state - if exit_reason is not None: + if exit_reason is not _UNSET: payload["exit_reason"] = exit_reason - if restart_requested is not None: + if restart_requested is not _UNSET: payload["restart_requested"] = bool(restart_requested) - if active_agents is 
not None: + if active_agents is not _UNSET: payload["active_agents"] = max(0, int(active_agents)) - if platform is not None: + if platform is not _UNSET: platform_payload = payload["platforms"].get(platform, {}) - if platform_state is not None: + if platform_state is not _UNSET: platform_payload["state"] = platform_state - if error_code is not None: + if error_code is not _UNSET: platform_payload["error_code"] = error_code - if error_message is not None: + if error_message is not _UNSET: platform_payload["error_message"] = error_message platform_payload["updated_at"] = _utc_now_iso() payload["platforms"][platform] = platform_payload diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index de0a1453b9..486d179de9 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -32,6 +32,10 @@ _DONE = object() # new one so that subsequent text appears below tool progress messages. _NEW_SEGMENT = object() +# Queue marker for a completed assistant commentary message emitted between +# API/tool iterations (for example: "I'll inspect the repo first."). 
+_COMMENTARY = object() + @dataclass class StreamConsumerConfig: @@ -75,20 +79,43 @@ class GatewayStreamConsumer: self._accumulated = "" self._message_id: Optional[str] = None self._already_sent = False - self._edit_supported = True # Disabled on first edit failure (Signal/Email/HA) + self._edit_supported = True # Disabled when progressive edits are no longer usable self._last_edit_time = 0.0 self._last_sent_text = "" # Track last-sent text to skip redundant edits self._fallback_final_send = False self._fallback_prefix = "" self._flood_strikes = 0 # Consecutive flood-control edit failures self._current_edit_interval = self.cfg.edit_interval # Adaptive backoff + self._final_response_sent = False @property def already_sent(self) -> bool: - """True if at least one message was sent/edited — signals the base - adapter to skip re-sending the final response.""" + """True if at least one message was sent or edited during the run.""" return self._already_sent + @property + def final_response_sent(self) -> bool: + """True when the stream consumer delivered the final assistant reply.""" + return self._final_response_sent + + def on_segment_break(self) -> None: + """Finalize the current stream segment and start a fresh message.""" + self._queue.put(_NEW_SEGMENT) + + def on_commentary(self, text: str) -> None: + """Queue a completed interim assistant commentary message.""" + if text: + self._queue.put((_COMMENTARY, text)) + + def _reset_segment_state(self, *, preserve_no_edit: bool = False) -> None: + if preserve_no_edit and self._message_id == "__no_edit__": + return + self._message_id = None + self._accumulated = "" + self._last_sent_text = "" + self._fallback_final_send = False + self._fallback_prefix = "" + def on_delta(self, text: str) -> None: """Thread-safe callback — called from the agent's worker thread. 
@@ -99,7 +126,7 @@ class GatewayStreamConsumer: if text: self._queue.put(text) elif text is None: - self._queue.put(_NEW_SEGMENT) + self.on_segment_break() def finish(self) -> None: """Signal that the stream is complete.""" @@ -116,6 +143,7 @@ class GatewayStreamConsumer: # Drain all available items from the queue got_done = False got_segment_break = False + commentary_text = None while True: try: item = self._queue.get_nowait() @@ -125,6 +153,9 @@ class GatewayStreamConsumer: if item is _NEW_SEGMENT: got_segment_break = True break + if isinstance(item, tuple) and len(item) == 2 and item[0] is _COMMENTARY: + commentary_text = item[1] + break self._accumulated += item except queue.Empty: break @@ -135,11 +166,13 @@ class GatewayStreamConsumer: should_edit = ( got_done or got_segment_break + or commentary_text is not None or (elapsed >= self._current_edit_interval and self._accumulated) or len(self._accumulated) >= self.cfg.buffer_threshold ) + current_update_visible = False if should_edit and self._accumulated: # Split overflow: if accumulated text exceeds the platform # limit, split into properly sized chunks. 
@@ -161,6 +194,7 @@ class GatewayStreamConsumer: self._last_sent_text = "" self._last_edit_time = time.monotonic() if got_done: + self._final_response_sent = self._already_sent return if got_segment_break: self._message_id = None @@ -192,10 +226,10 @@ class GatewayStreamConsumer: self._last_sent_text = "" display_text = self._accumulated - if not got_done and not got_segment_break: + if not got_done and not got_segment_break and commentary_text is None: display_text += self.cfg.cursor - await self._send_or_edit(display_text) + current_update_visible = await self._send_or_edit(display_text) self._last_edit_time = time.monotonic() if got_done: @@ -206,12 +240,20 @@ class GatewayStreamConsumer: if self._accumulated: if self._fallback_final_send: await self._send_fallback_final(self._accumulated) + elif current_update_visible: + self._final_response_sent = True elif self._message_id: - await self._send_or_edit(self._accumulated) + self._final_response_sent = await self._send_or_edit(self._accumulated) elif not self._already_sent: - await self._send_or_edit(self._accumulated) + self._final_response_sent = await self._send_or_edit(self._accumulated) return + if commentary_text is not None: + self._reset_segment_state() + await self._send_commentary(commentary_text) + self._last_edit_time = time.monotonic() + self._reset_segment_state() + # Tool boundary: reset message state so the next text chunk # creates a fresh message below any tool-progress messages. # @@ -220,17 +262,14 @@ class GatewayStreamConsumer: # github_comment delivery). Resetting to None would re-enter # the "first send" path on every tool boundary and post one # platform message per tool call — that is what caused 155 - # comments under a single PR. Instead, keep all state so the - # full continuation is delivered once via _send_fallback_final. + # comments under a single PR. Instead, preserve the sentinel + # so the full continuation is delivered once via + # _send_fallback_final. 
# (When editing fails mid-stream due to flood control the id is # a real string like "msg_1", not "__no_edit__", so that case # still resets and creates a fresh segment as intended.) - if got_segment_break and self._message_id != "__no_edit__": - self._message_id = None - self._accumulated = "" - self._last_sent_text = "" - self._fallback_final_send = False - self._fallback_prefix = "" + if got_segment_break: + self._reset_segment_state(preserve_no_edit=True) await asyncio.sleep(0.05) # Small yield to not busy-loop @@ -339,6 +378,7 @@ class GatewayStreamConsumer: if not continuation.strip(): # Nothing new to send — the visible partial already matches final text. self._already_sent = True + self._final_response_sent = True return raw_limit = getattr(self.adapter, "MAX_MESSAGE_LENGTH", 4096) @@ -373,6 +413,7 @@ class GatewayStreamConsumer: # the base gateway final-send path so we don't resend the # full response and create another duplicate. self._already_sent = True + self._final_response_sent = True self._message_id = last_message_id self._last_sent_text = last_successful_chunk self._fallback_prefix = "" @@ -390,6 +431,7 @@ class GatewayStreamConsumer: self._message_id = last_message_id self._already_sent = True + self._final_response_sent = True self._last_sent_text = chunks[-1] self._fallback_prefix = "" @@ -420,6 +462,24 @@ class GatewayStreamConsumer: except Exception: pass # best-effort — don't let this block the fallback path + async def _send_commentary(self, text: str) -> bool: + """Send a completed interim assistant commentary message.""" + text = self._clean_for_display(text) + if not text.strip(): + return False + try: + result = await self.adapter.send( + chat_id=self.chat_id, + content=text, + metadata=self.metadata, + ) + if result.success: + self._already_sent = True + return True + except Exception as e: + logger.error("Commentary send error: %s", e) + return False + async def _send_or_edit(self, text: str) -> bool: """Send or edit the streaming 
message. @@ -501,23 +561,21 @@ class GatewayStreamConsumer: content=text, metadata=self.metadata, ) - if result.success and result.message_id: - self._message_id = result.message_id + if result.success: + if result.message_id: + self._message_id = result.message_id + else: + self._edit_supported = False self._already_sent = True self._last_sent_text = text + if not result.message_id: + self._fallback_prefix = self._visible_prefix() + self._fallback_final_send = True + # Sentinel prevents re-entering the first-send path on + # every delta/tool boundary when platforms accept a + # message but do not return an editable message id. + self._message_id = "__no_edit__" return True - elif result.success: - # Platform accepted the message but returned no message_id - # (e.g. Signal). Can't edit without an ID — switch to - # fallback mode: suppress intermediate deltas, send only - # the missing tail once the final response is ready. - self._already_sent = True - self._edit_supported = False - self._fallback_prefix = self._clean_for_display(text) - self._fallback_final_send = True - # Sentinel prevents re-entering this branch on every delta - self._message_id = "__no_edit__" - return True # platform accepted, just can't edit else: # Initial send failed — disable streaming for this session self._edit_supported = False diff --git a/hermes_cli/backup.py b/hermes_cli/backup.py new file mode 100644 index 0000000000..9aca0f8221 --- /dev/null +++ b/hermes_cli/backup.py @@ -0,0 +1,399 @@ +""" +Backup and import commands for hermes CLI. + +`hermes backup` creates a zip archive of the entire ~/.hermes/ directory +(excluding the hermes-agent repo and transient files). + +`hermes import` restores from a backup zip, overlaying onto the current +HERMES_HOME root. 
+""" + +import os +import sys +import time +import zipfile +from datetime import datetime +from pathlib import Path + +from hermes_constants import get_default_hermes_root, display_hermes_home + + +# --------------------------------------------------------------------------- +# Exclusion rules +# --------------------------------------------------------------------------- + +# Directory names to skip entirely (matched against each path component) +_EXCLUDED_DIRS = { + "hermes-agent", # the codebase repo — re-clone instead + "__pycache__", # bytecode caches — regenerated on import + ".git", # nested git dirs (profiles shouldn't have these, but safety) + "node_modules", # js deps if website/ somehow leaks in +} + +# File-name suffixes to skip +_EXCLUDED_SUFFIXES = ( + ".pyc", + ".pyo", +) + +# File names to skip (runtime state that's meaningless on another machine) +_EXCLUDED_NAMES = { + "gateway.pid", + "cron.pid", +} + + +def _should_exclude(rel_path: Path) -> bool: + """Return True if *rel_path* (relative to hermes root) should be skipped.""" + parts = rel_path.parts + + # Any path component matches an excluded dir name + for part in parts: + if part in _EXCLUDED_DIRS: + return True + + name = rel_path.name + + if name in _EXCLUDED_NAMES: + return True + + if name.endswith(_EXCLUDED_SUFFIXES): + return True + + return False + + +# --------------------------------------------------------------------------- +# Backup +# --------------------------------------------------------------------------- + +def _format_size(nbytes: int) -> str: + """Human-readable file size.""" + for unit in ("B", "KB", "MB", "GB"): + if nbytes < 1024: + return f"{nbytes:.1f} {unit}" if unit != "B" else f"{nbytes} {unit}" + nbytes /= 1024 + return f"{nbytes:.1f} TB" + + +def run_backup(args) -> None: + """Create a zip backup of the Hermes home directory.""" + hermes_root = get_default_hermes_root() + + if not hermes_root.is_dir(): + print(f"Error: Hermes home directory not found at 
{hermes_root}") + sys.exit(1) + + # Determine output path + if args.output: + out_path = Path(args.output).expanduser().resolve() + # If user gave a directory, put the zip inside it + if out_path.is_dir(): + stamp = datetime.now().strftime("%Y-%m-%d-%H%M%S") + out_path = out_path / f"hermes-backup-{stamp}.zip" + else: + stamp = datetime.now().strftime("%Y-%m-%d-%H%M%S") + out_path = Path.home() / f"hermes-backup-{stamp}.zip" + + # Ensure the suffix is .zip + if out_path.suffix.lower() != ".zip": + out_path = out_path.with_suffix(out_path.suffix + ".zip") + + # Ensure parent directory exists + out_path.parent.mkdir(parents=True, exist_ok=True) + + # Collect files + print(f"Scanning {display_hermes_home()} ...") + files_to_add: list[tuple[Path, Path]] = [] # (absolute, relative) + skipped_dirs = set() + + for dirpath, dirnames, filenames in os.walk(hermes_root, followlinks=False): + dp = Path(dirpath) + rel_dir = dp.relative_to(hermes_root) + + # Prune excluded directories in-place so os.walk doesn't descend + orig_dirnames = dirnames[:] + dirnames[:] = [ + d for d in dirnames + if d not in _EXCLUDED_DIRS + ] + for removed in set(orig_dirnames) - set(dirnames): + skipped_dirs.add(str(rel_dir / removed)) + + for fname in filenames: + fpath = dp / fname + rel = fpath.relative_to(hermes_root) + + if _should_exclude(rel): + continue + + # Skip the output zip itself if it happens to be inside hermes root + try: + if fpath.resolve() == out_path.resolve(): + continue + except (OSError, ValueError): + pass + + files_to_add.append((fpath, rel)) + + if not files_to_add: + print("No files to back up.") + return + + # Create the zip + file_count = len(files_to_add) + print(f"Backing up {file_count} files ...") + + total_bytes = 0 + errors = [] + t0 = time.monotonic() + + with zipfile.ZipFile(out_path, "w", zipfile.ZIP_DEFLATED, compresslevel=6) as zf: + for i, (abs_path, rel_path) in enumerate(files_to_add, 1): + try: + zf.write(abs_path, arcname=str(rel_path)) + total_bytes += 
abs_path.stat().st_size + except (PermissionError, OSError) as exc: + errors.append(f" {rel_path}: {exc}") + continue + + # Progress every 500 files + if i % 500 == 0: + print(f" {i}/{file_count} files ...") + + elapsed = time.monotonic() - t0 + zip_size = out_path.stat().st_size + + # Summary + print() + print(f"Backup complete: {out_path}") + print(f" Files: {file_count}") + print(f" Original: {_format_size(total_bytes)}") + print(f" Compressed: {_format_size(zip_size)}") + print(f" Time: {elapsed:.1f}s") + + if skipped_dirs: + print(f"\n Excluded directories:") + for d in sorted(skipped_dirs): + print(f" {d}/") + + if errors: + print(f"\n Warnings ({len(errors)} files skipped):") + for e in errors[:10]: + print(e) + if len(errors) > 10: + print(f" ... and {len(errors) - 10} more") + + print(f"\nRestore with: hermes import {out_path.name}") + + +# --------------------------------------------------------------------------- +# Import +# --------------------------------------------------------------------------- + +def _validate_backup_zip(zf: zipfile.ZipFile) -> tuple[bool, str]: + """Check that a zip looks like a Hermes backup. + + Returns (ok, reason). + """ + names = zf.namelist() + if not names: + return False, "zip archive is empty" + + # Look for telltale files that a hermes home would have + markers = {"config.yaml", ".env", "hermes_state.db", "memory_store.db"} + found = set() + for n in names: + # Could be at the root or one level deep (if someone zipped the directory) + basename = Path(n).name + if basename in markers: + found.add(basename) + + if not found: + return False, ( + "zip does not appear to be a Hermes backup " + "(no config.yaml, .env, or state databases found)" + ) + + return True, "" + + +def _detect_prefix(zf: zipfile.ZipFile) -> str: + """Detect if the zip has a common directory prefix wrapping all entries. + + Some tools zip as `.hermes/config.yaml` instead of `config.yaml`. + Returns the prefix to strip (empty string if none). 
+ """ + names = [n for n in zf.namelist() if not n.endswith("/")] + if not names: + return "" + + # Find common prefix + parts_list = [Path(n).parts for n in names] + + # Check if all entries share a common first directory + first_parts = {p[0] for p in parts_list if len(p) > 1} + if len(first_parts) == 1: + prefix = first_parts.pop() + # Only strip if it looks like a hermes dir name + if prefix in (".hermes", "hermes"): + return prefix + "/" + + return "" + + +def run_import(args) -> None: + """Restore a Hermes backup from a zip file.""" + zip_path = Path(args.zipfile).expanduser().resolve() + + if not zip_path.is_file(): + print(f"Error: File not found: {zip_path}") + sys.exit(1) + + if not zipfile.is_zipfile(zip_path): + print(f"Error: Not a valid zip file: {zip_path}") + sys.exit(1) + + hermes_root = get_default_hermes_root() + + with zipfile.ZipFile(zip_path, "r") as zf: + # Validate + ok, reason = _validate_backup_zip(zf) + if not ok: + print(f"Error: {reason}") + sys.exit(1) + + prefix = _detect_prefix(zf) + members = [n for n in zf.namelist() if not n.endswith("/")] + file_count = len(members) + + print(f"Backup contains {file_count} files") + print(f"Target: {display_hermes_home()}") + + if prefix: + print(f"Detected archive prefix: {prefix!r} (will be stripped)") + + # Check for existing installation + has_config = (hermes_root / "config.yaml").exists() + has_env = (hermes_root / ".env").exists() + + if (has_config or has_env) and not args.force: + print() + print("Warning: Target directory already has Hermes configuration.") + print("Importing will overwrite existing files with backup contents.") + print() + try: + answer = input("Continue? 
[y/N] ").strip().lower() + except (EOFError, KeyboardInterrupt): + print("\nAborted.") + sys.exit(1) + if answer not in ("y", "yes"): + print("Aborted.") + return + + # Extract + print(f"\nImporting {file_count} files ...") + hermes_root.mkdir(parents=True, exist_ok=True) + + errors = [] + restored = 0 + t0 = time.monotonic() + + for member in members: + # Strip prefix if detected + if prefix and member.startswith(prefix): + rel = member[len(prefix):] + else: + rel = member + + if not rel: + continue + + target = hermes_root / rel + + # Security: reject absolute paths and traversals + try: + target.resolve().relative_to(hermes_root.resolve()) + except ValueError: + errors.append(f" {rel}: path traversal blocked") + continue + + try: + target.parent.mkdir(parents=True, exist_ok=True) + with zf.open(member) as src, open(target, "wb") as dst: + dst.write(src.read()) + restored += 1 + except (PermissionError, OSError) as exc: + errors.append(f" {rel}: {exc}") + + if restored % 500 == 0: + print(f" {restored}/{file_count} files ...") + + elapsed = time.monotonic() - t0 + + # Summary + print() + print(f"Import complete: {restored} files restored in {elapsed:.1f}s") + print(f" Target: {display_hermes_home()}") + + if errors: + print(f"\n Warnings ({len(errors)} files skipped):") + for e in errors[:10]: + print(e) + if len(errors) > 10: + print(f" ... 
and {len(errors) - 10} more") + + # Post-import: restore profile wrapper scripts + profiles_dir = hermes_root / "profiles" + restored_profiles = [] + if profiles_dir.is_dir(): + try: + from hermes_cli.profiles import ( + create_wrapper_script, check_alias_collision, + _is_wrapper_dir_in_path, _get_wrapper_dir, + ) + for entry in sorted(profiles_dir.iterdir()): + if not entry.is_dir(): + continue + profile_name = entry.name + # Only create wrappers for directories with config + if not (entry / "config.yaml").exists() and not (entry / ".env").exists(): + continue + collision = check_alias_collision(profile_name) + if collision: + print(f" Skipped alias '{profile_name}': {collision}") + restored_profiles.append((profile_name, False)) + else: + wrapper = create_wrapper_script(profile_name) + restored_profiles.append((profile_name, wrapper is not None)) + + if restored_profiles: + created = [n for n, ok in restored_profiles if ok] + skipped = [n for n, ok in restored_profiles if not ok] + if created: + print(f"\n Profile aliases restored: {', '.join(created)}") + if skipped: + print(f" Profile aliases skipped: {', '.join(skipped)}") + if not _is_wrapper_dir_in_path(): + print(f"\n Note: {_get_wrapper_dir()} is not in your PATH.") + print(' Add to your shell config (~/.bashrc or ~/.zshrc):') + print(' export PATH="$HOME/.local/bin:$PATH"') + except ImportError: + # hermes_cli.profiles might not be available (fresh install) + if any(profiles_dir.iterdir()): + print(f"\n Profiles detected but aliases could not be created.") + print(f" Run: hermes profile list (after installing hermes)") + + # Guidance + print() + if not (hermes_root / "hermes-agent").is_dir(): + print("Note: The hermes-agent codebase was not included in the backup.") + print(" If this is a fresh install, run: hermes update") + + if restored_profiles: + gw_profiles = [n for n, _ in restored_profiles] + print("\nTo re-enable gateway services for profiles:") + for pname in gw_profiles: + print(f" hermes -p 
{pname} gateway install") + + print("Done. Your Hermes configuration has been restored.") diff --git a/hermes_cli/claw.py b/hermes_cli/claw.py index 3ab6bf9a8d..d0bfd73d23 100644 --- a/hermes_cli/claw.py +++ b/hermes_cli/claw.py @@ -52,6 +52,41 @@ _OPENCLAW_SCRIPT_INSTALLED = ( # Known OpenClaw directory names (current + legacy) _OPENCLAW_DIR_NAMES = (".openclaw", ".clawdbot", ".moldbot") +def _warn_if_gateway_running(auto_yes: bool) -> None: + """Check if a Hermes gateway is running with connected platforms. + + Migrating bot tokens while the gateway is polling will cause conflicts + (e.g. Telegram 409 "terminated by other getUpdates request"). Warn the + user and let them decide whether to continue. + """ + from gateway.status import get_running_pid, read_runtime_status + + if not get_running_pid(): + return + + data = read_runtime_status() or {} + platforms = data.get("platforms") or {} + connected = [name for name, info in platforms.items() + if isinstance(info, dict) and info.get("state") == "connected"] + if not connected: + return + + print() + print_error( + "Hermes gateway is running with active connections: " + + ", ".join(connected) + ) + print_info( + "Migrating bot tokens while the gateway is active will cause " + "conflicts (Telegram, Discord, and Slack only allow one active " + "session per token)." + ) + print_info("Recommendation: stop the gateway first with 'hermes stop'.") + print() + if not auto_yes and not prompt_yes_no("Continue anyway?", default=False): + print_info("Migration cancelled. Stop the gateway and try again.") + sys.exit(0) + # State files commonly found in OpenClaw workspace directories that cause # confusion after migration (the agent discovers them and writes to them) _WORKSPACE_STATE_GLOBS = ( @@ -252,6 +287,10 @@ def _cmd_migrate(args): print_info(f"Workspace: {workspace_target}") print() + # Check if a gateway is running with connected platforms — migrating tokens + # while the gateway is active will cause conflicts (e.g. 
Telegram 409). + _warn_if_gateway_running(auto_yes) + # Ensure config.yaml exists before migration tries to read it config_path = get_config_path() if not config_path.exists(): diff --git a/hermes_cli/config.py b/hermes_cli/config.py index a8866fef56..93afb3f8cc 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -38,6 +38,9 @@ _EXTRA_ENV_KEYS = frozenset({ "DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET", "FEISHU_APP_ID", "FEISHU_APP_SECRET", "FEISHU_ENCRYPT_KEY", "FEISHU_VERIFICATION_TOKEN", "WECOM_BOT_ID", "WECOM_SECRET", + "WECOM_CALLBACK_CORP_ID", "WECOM_CALLBACK_CORP_SECRET", "WECOM_CALLBACK_AGENT_ID", + "WECOM_CALLBACK_TOKEN", "WECOM_CALLBACK_ENCODING_AES_KEY", + "WECOM_CALLBACK_HOST", "WECOM_CALLBACK_PORT", "WEIXIN_ACCOUNT_ID", "WEIXIN_TOKEN", "WEIXIN_BASE_URL", "WEIXIN_CDN_BASE_URL", "WEIXIN_HOME_CHANNEL", "WEIXIN_HOME_CHANNEL_NAME", "WEIXIN_DM_POLICY", "WEIXIN_GROUP_POLICY", "WEIXIN_ALLOWED_USERS", "WEIXIN_GROUP_ALLOWED_USERS", "WEIXIN_ALLOW_ALL_USERS", @@ -142,6 +145,73 @@ def managed_error(action: str = "modify configuration"): print(format_managed_message(action), file=sys.stderr) +# ============================================================================= +# Container-aware CLI (NixOS container mode) +# ============================================================================= + +def _is_inside_container() -> bool: + """Detect if we're already running inside a Docker/Podman container.""" + # Standard Docker/Podman indicators + if os.path.exists("/.dockerenv"): + return True + # Podman uses /run/.containerenv + if os.path.exists("/run/.containerenv"): + return True + # Check cgroup for container runtime evidence (works for both Docker & Podman) + try: + with open("/proc/1/cgroup", "r") as f: + cgroup = f.read() + if "docker" in cgroup or "podman" in cgroup or "/lxc/" in cgroup: + return True + except OSError: + pass + return False + + +def get_container_exec_info() -> Optional[dict]: + """Read container mode metadata from 
HERMES_HOME/.container-mode. + + Returns a dict with keys: backend, container_name, exec_user, hermes_bin + or None if container mode is not active, we're already inside the + container, or HERMES_DEV=1 is set. + + The .container-mode file is written by the NixOS activation script when + container.enable = true. It tells the host CLI to exec into the container + instead of running locally. + """ + if os.environ.get("HERMES_DEV") == "1": + return None + + if _is_inside_container(): + return None + + container_mode_file = get_hermes_home() / ".container-mode" + + try: + info = {} + with open(container_mode_file, "r") as f: + for line in f: + line = line.strip() + if "=" in line and not line.startswith("#"): + key, _, value = line.partition("=") + info[key.strip()] = value.strip() + except FileNotFoundError: + return None + # All other exceptions (PermissionError, malformed data, etc.) propagate + + backend = info.get("backend", "docker") + container_name = info.get("container_name", "hermes-agent") + exec_user = info.get("exec_user", "hermes") + hermes_bin = info.get("hermes_bin", "/data/current-package/bin/hermes") + + return { + "backend": backend, + "container_name": container_name, + "exec_user": exec_user, + "hermes_bin": hermes_bin, + } + + # ============================================================================= # Config paths # ============================================================================= @@ -447,9 +517,11 @@ DEFAULT_CONFIG = { "inline_diffs": True, # Show inline diff previews for write actions (write_file, patch, skill_manage) "show_cost": False, # Show $ cost in the status bar (off by default) "skin": "default", + "interim_assistant_messages": True, # Gateway: show natural mid-turn assistant status messages "tool_progress_command": False, # Enable /verbose command in messaging gateway - "tool_progress_overrides": {}, # Per-platform overrides: {"signal": "off", "telegram": "all"} + "tool_progress_overrides": {}, # DEPRECATED — use 
display.platforms instead "tool_preview_length": 0, # Max chars for tool call previews (0 = no limit, show full paths/commands) + "platforms": {}, # Per-platform display overrides: {"telegram": {"tool_progress": "all"}, "slack": {"tool_progress": "off"}} }, # Privacy settings @@ -637,7 +709,7 @@ DEFAULT_CONFIG = { }, # Config schema version - bump this when adding new required fields - "_config_version": 14, + "_config_version": 16, } # ============================================================================= @@ -1901,6 +1973,44 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A if not quiet: print(f" ✓ Migrated legacy stt.model to provider-specific config") + # ── Version 14 → 15: add explicit gateway interim-message gate ── + if current_ver < 15: + config = read_raw_config() + display = config.get("display", {}) + if not isinstance(display, dict): + display = {} + if "interim_assistant_messages" not in display: + display["interim_assistant_messages"] = True + config["display"] = display + results["config_added"].append("display.interim_assistant_messages=true (default)") + save_config(config) + if not quiet: + print(" ✓ Added display.interim_assistant_messages=true") + + # ── Version 15 → 16: migrate tool_progress_overrides into display.platforms ── + if current_ver < 16: + config = read_raw_config() + display = config.get("display", {}) + if not isinstance(display, dict): + display = {} + old_overrides = display.get("tool_progress_overrides") + if isinstance(old_overrides, dict) and old_overrides: + platforms = display.get("platforms", {}) + if not isinstance(platforms, dict): + platforms = {} + for plat, mode in old_overrides.items(): + if plat not in platforms: + platforms[plat] = {} + if "tool_progress" not in platforms[plat]: + platforms[plat]["tool_progress"] = mode + display["platforms"] = platforms + config["display"] = display + save_config(config) + if not quiet: + migrated = ", ".join(f"{p}={m}" for p, m in 
old_overrides.items()) + print(f" ✓ Migrated tool_progress_overrides → display.platforms: {migrated}") + results["config_added"].append("display.platforms (migrated from tool_progress_overrides)") + if current_ver < latest_ver and not quiet: print(f"Config version: {current_ver} → {latest_ver}") diff --git a/hermes_cli/curses_ui.py b/hermes_cli/curses_ui.py index 9cebaf60f8..4880171fd4 100644 --- a/hermes_cli/curses_ui.py +++ b/hermes_cli/curses_ui.py @@ -287,6 +287,129 @@ def _radio_numbered_fallback( return cancel_returns +def curses_single_select( + title: str, + items: List[str], + default_index: int = 0, + *, + cancel_label: str = "Cancel", +) -> int | None: + """Curses single-select menu. Returns selected index or None on cancel. + + Works inside prompt_toolkit because curses.wrapper() restores the terminal + safely, unlike simple_term_menu which conflicts with /dev/tty. + """ + if not sys.stdin.isatty(): + return None + + try: + import curses + result_holder: list = [None] + + all_items = list(items) + [cancel_label] + cancel_idx = len(items) + + def _draw(stdscr): + curses.curs_set(0) + if curses.has_colors(): + curses.start_color() + curses.use_default_colors() + curses.init_pair(1, curses.COLOR_GREEN, -1) + curses.init_pair(2, curses.COLOR_YELLOW, -1) + cursor = min(default_index, len(all_items) - 1) + scroll_offset = 0 + + while True: + stdscr.clear() + max_y, max_x = stdscr.getmaxyx() + + try: + hattr = curses.A_BOLD + if curses.has_colors(): + hattr |= curses.color_pair(2) + stdscr.addnstr(0, 0, title, max_x - 1, hattr) + stdscr.addnstr( + 1, 0, + " ↑↓ navigate ENTER confirm ESC/q cancel", + max_x - 1, curses.A_DIM, + ) + except curses.error: + pass + + visible_rows = max_y - 3 + if cursor < scroll_offset: + scroll_offset = cursor + elif cursor >= scroll_offset + visible_rows: + scroll_offset = cursor - visible_rows + 1 + + for draw_i, i in enumerate( + range(scroll_offset, min(len(all_items), scroll_offset + visible_rows)) + ): + y = draw_i + 3 + if y 
>= max_y - 1: + break + arrow = "→" if i == cursor else " " + line = f" {arrow} {all_items[i]}" + attr = curses.A_NORMAL + if i == cursor: + attr = curses.A_BOLD + if curses.has_colors(): + attr |= curses.color_pair(1) + try: + stdscr.addnstr(y, 0, line, max_x - 1, attr) + except curses.error: + pass + + stdscr.refresh() + key = stdscr.getch() + + if key in (curses.KEY_UP, ord("k")): + cursor = (cursor - 1) % len(all_items) + elif key in (curses.KEY_DOWN, ord("j")): + cursor = (cursor + 1) % len(all_items) + elif key in (curses.KEY_ENTER, 10, 13): + result_holder[0] = cursor + return + elif key in (27, ord("q")): + result_holder[0] = None + return + + curses.wrapper(_draw) + flush_stdin() + if result_holder[0] is not None and result_holder[0] >= cancel_idx: + return None + return result_holder[0] + + except Exception: + all_items = list(items) + [cancel_label] + cancel_idx = len(items) + return _numbered_single_fallback(title, all_items, cancel_idx) + + +def _numbered_single_fallback( + title: str, + items: List[str], + cancel_idx: int, +) -> int | None: + """Text-based numbered fallback for single-select.""" + print(f"\n {title}\n") + for i, label in enumerate(items, 1): + print(f" {i}. 
{label}") + print() + try: + val = input(f" Choice [1-{len(items)}]: ").strip() + if not val: + return None + idx = int(val) - 1 + if 0 <= idx < len(items) and idx < cancel_idx: + return idx + if idx == cancel_idx: + return None + except (ValueError, KeyboardInterrupt, EOFError): + pass + return None + + def _numbered_fallback( title: str, items: List[str], diff --git a/hermes_cli/dump.py b/hermes_cli/dump.py index 00441c0ccb..caa6b7e8ca 100644 --- a/hermes_cli/dump.py +++ b/hermes_cli/dump.py @@ -119,6 +119,7 @@ def _configured_platforms() -> list[str]: "dingtalk": "DINGTALK_CLIENT_ID", "feishu": "FEISHU_APP_ID", "wecom": "WECOM_BOT_ID", + "wecom_callback": "WECOM_CALLBACK_CORP_ID", "weixin": "WEIXIN_ACCOUNT_ID", } return [name for name, env in checks.items() if os.getenv(env)] diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 897e61df54..624a9f7f72 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -157,30 +157,54 @@ def _request_gateway_self_restart(pid: int) -> bool: return True -def find_gateway_pids(exclude_pids: set | None = None) -> list: +def find_gateway_pids(exclude_pids: set | None = None, all_profiles: bool = False) -> list: """Find PIDs of running gateway processes. Args: exclude_pids: PIDs to exclude from the result (e.g. service-managed PIDs that should not be killed during a stale-process sweep). + all_profiles: When ``True``, return gateway PIDs across **all** + profiles (the pre-7923 global behaviour). ``hermes update`` + needs this because a code update affects every profile. + When ``False`` (default), only PIDs belonging to the current + Hermes profile are returned. 
""" - pids = [] _exclude = exclude_pids or set() + pids = [pid for pid in _get_service_pids() if pid not in _exclude] patterns = [ "hermes_cli.main gateway", + "hermes_cli.main --profile", + "hermes_cli.main -p", "hermes_cli/main.py gateway", + "hermes_cli/main.py --profile", + "hermes_cli/main.py -p", "hermes gateway", "gateway/run.py", ] + current_home = str(get_hermes_home().resolve()) + current_profile_arg = _profile_arg(current_home) + current_profile_name = current_profile_arg.split()[-1] if current_profile_arg else "" + + def _matches_current_profile(command: str) -> bool: + if current_profile_name: + return ( + f"--profile {current_profile_name}" in command + or f"-p {current_profile_name}" in command + or f"HERMES_HOME={current_home}" in command + ) + + if "--profile " in command or " -p " in command: + return False + if "HERMES_HOME=" in command and f"HERMES_HOME={current_home}" not in command: + return False + return True try: if is_windows(): - # Windows: use wmic to search command lines result = subprocess.run( ["wmic", "process", "get", "ProcessId,CommandLine", "/FORMAT:LIST"], capture_output=True, text=True, timeout=10 ) - # Parse WMIC LIST output: blocks of "CommandLine=...\nProcessId=...\n" current_cmd = "" for line in result.stdout.split('\n'): line = line.strip() @@ -188,7 +212,7 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list: current_cmd = line[len("CommandLine="):] elif line.startswith("ProcessId="): pid_str = line[len("ProcessId="):] - if any(p in current_cmd for p in patterns): + if any(p in current_cmd for p in patterns) and (all_profiles or _matches_current_profile(current_cmd)): try: pid = int(pid_str) if pid != os.getpid() and pid not in pids and pid not in _exclude: @@ -198,41 +222,57 @@ def find_gateway_pids(exclude_pids: set | None = None) -> list: current_cmd = "" else: result = subprocess.run( - ["ps", "aux"], + ["ps", "eww", "-ax", "-o", "pid=,command="], capture_output=True, text=True, timeout=10, ) for line in 
result.stdout.split('\n'): - # Skip grep and current process - if 'grep' in line or str(os.getpid()) in line: + stripped = line.strip() + if not stripped or 'grep' in stripped: continue - for pattern in patterns: - if pattern in line: - parts = line.split() - if len(parts) > 1: - try: - pid = int(parts[1]) - if pid not in pids and pid not in _exclude: - pids.append(pid) - except ValueError: - continue - break - except Exception: + + pid = None + command = "" + + parts = stripped.split(None, 1) + if len(parts) == 2: + try: + pid = int(parts[0]) + command = parts[1] + except ValueError: + pid = None + + if pid is None: + aux_parts = stripped.split() + if len(aux_parts) > 10 and aux_parts[1].isdigit(): + pid = int(aux_parts[1]) + command = " ".join(aux_parts[10:]) + + if pid is None: + continue + if pid == os.getpid() or pid in pids or pid in _exclude: + continue + if any(pattern in command for pattern in patterns) and (all_profiles or _matches_current_profile(command)): + pids.append(pid) + except (OSError, subprocess.TimeoutExpired): pass return pids -def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None) -> int: +def kill_gateway_processes(force: bool = False, exclude_pids: set | None = None, + all_profiles: bool = False) -> int: """Kill any running gateway processes. Returns count killed. Args: force: Use the platform's force-kill mechanism instead of graceful terminate. exclude_pids: PIDs to skip (e.g. service-managed PIDs that were just restarted and should not be killed). + all_profiles: When ``True``, kill across all profiles. Passed + through to :func:`find_gateway_pids`. 
""" - pids = find_gateway_pids(exclude_pids=exclude_pids) + pids = find_gateway_pids(exclude_pids=exclude_pids, all_profiles=all_profiles) killed = 0 for pid in pids: @@ -633,6 +673,17 @@ def print_systemd_linger_guidance() -> None: print(" If you want the gateway user service to survive logout, run:") print(" sudo loginctl enable-linger $USER") +def _launchd_user_home() -> Path: + """Return the real macOS user home for launchd artifacts. + + Profile-mode Hermes often sets ``HOME`` to a profile-scoped directory, but + launchd user agents still live under the actual account home. + """ + import pwd + + return Path(pwd.getpwuid(os.getuid()).pw_dir) + + def get_launchd_plist_path() -> Path: """Return the launchd plist path, scoped per profile. @@ -641,7 +692,7 @@ def get_launchd_plist_path() -> Path: """ suffix = _profile_suffix() name = f"ai.hermes.gateway-{suffix}" if suffix else "ai.hermes.gateway" - return Path.home() / "Library" / "LaunchAgents" / f"{name}.plist" + return _launchd_user_home() / "Library" / "LaunchAgents" / f"{name}.plist" def _detect_venv_dir() -> Path | None: """Detect the active virtualenv directory. @@ -839,6 +890,25 @@ def _normalize_service_definition(text: str) -> str: return "\n".join(line.rstrip() for line in text.strip().splitlines()) +def _normalize_launchd_plist_for_comparison(text: str) -> str: + """Normalize launchd plist text for staleness checks. + + The generated plist intentionally captures a broad PATH assembled from the + invoking shell so user-installed tools remain reachable under launchd. + That makes raw text comparison unstable across shells, so ignore the PATH + payload when deciding whether the installed plist is stale. 
+ """ + import re + + normalized = _normalize_service_definition(text) + return re.sub( + r'(PATH\s*)(.*?)()', + r'\1__HERMES_PATH__\3', + normalized, + flags=re.S, + ) + + def systemd_unit_is_current(system: bool = False) -> bool: unit_path = get_systemd_unit_path(system=system) if not unit_path.exists(): @@ -1220,7 +1290,7 @@ def launchd_plist_is_current() -> bool: installed = plist_path.read_text(encoding="utf-8") expected = generate_launchd_plist() - return _normalize_service_definition(installed) == _normalize_service_definition(expected) + return _normalize_launchd_plist_for_comparison(installed) == _normalize_launchd_plist_for_comparison(expected) def refresh_launchd_plist_if_needed() -> bool: @@ -1751,6 +1821,37 @@ _PLATFORMS = [ "help": "Chat ID for scheduled results and notifications."}, ], }, + { + "key": "wecom_callback", + "label": "WeCom Callback (Self-Built App)", + "emoji": "💬", + "token_var": "WECOM_CALLBACK_CORP_ID", + "setup_instructions": [ + "1. Go to WeCom Admin Console → Applications → Create Self-Built App", + "2. Note the Corp ID (top of admin console) and create a Corp Secret", + "3. Under Receive Messages, configure the callback URL to point to your server", + "4. Copy the Token and EncodingAESKey from the callback configuration", + "5. The adapter runs an HTTP server — ensure the port is reachable from WeCom", + "6. 
Restrict access with WECOM_CALLBACK_ALLOWED_USERS for production use", + ], + "vars": [ + {"name": "WECOM_CALLBACK_CORP_ID", "prompt": "Corp ID", "password": False, + "help": "Your WeCom enterprise Corp ID."}, + {"name": "WECOM_CALLBACK_CORP_SECRET", "prompt": "Corp Secret", "password": True, + "help": "The secret for your self-built application."}, + {"name": "WECOM_CALLBACK_AGENT_ID", "prompt": "Agent ID", "password": False, + "help": "The Agent ID of your self-built application."}, + {"name": "WECOM_CALLBACK_TOKEN", "prompt": "Callback Token", "password": True, + "help": "The Token from your WeCom callback configuration."}, + {"name": "WECOM_CALLBACK_ENCODING_AES_KEY", "prompt": "Encoding AES Key", "password": True, + "help": "The EncodingAESKey from your WeCom callback configuration."}, + {"name": "WECOM_CALLBACK_PORT", "prompt": "Callback server port (default: 8645)", "password": False, + "help": "Port for the HTTP callback server."}, + {"name": "WECOM_CALLBACK_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated, or empty)", "password": False, + "is_allowlist": True, + "help": "Restrict which WeCom users can interact with the app."}, + ], + }, { "key": "weixin", "label": "Weixin / WeChat", @@ -1981,6 +2082,36 @@ def _setup_whatsapp(): cmd_whatsapp(argparse.Namespace()) +def _setup_email(): + """Configure Email via the standard platform setup.""" + email_platform = next(p for p in _PLATFORMS if p["key"] == "email") + _setup_standard_platform(email_platform) + + +def _setup_sms(): + """Configure SMS (Twilio) via the standard platform setup.""" + sms_platform = next(p for p in _PLATFORMS if p["key"] == "sms") + _setup_standard_platform(sms_platform) + + +def _setup_dingtalk(): + """Configure DingTalk via the standard platform setup.""" + dingtalk_platform = next(p for p in _PLATFORMS if p["key"] == "dingtalk") + _setup_standard_platform(dingtalk_platform) + + +def _setup_feishu(): + """Configure Feishu / Lark via the standard platform setup.""" + 
feishu_platform = next(p for p in _PLATFORMS if p["key"] == "feishu") + _setup_standard_platform(feishu_platform) + + +def _setup_wecom(): + """Configure WeCom (Enterprise WeChat) via the standard platform setup.""" + wecom_platform = next(p for p in _PLATFORMS if p["key"] == "wecom") + _setup_standard_platform(wecom_platform) + + def _is_service_installed() -> bool: """Check if the gateway is installed as a system service.""" if supports_systemd_services(): @@ -2566,7 +2697,7 @@ def gateway_command(args): service_available = True except subprocess.CalledProcessError: pass - killed = kill_gateway_processes() + killed = kill_gateway_processes(all_profiles=True) total = killed + (1 if service_available else 0) if total: print(f"✓ Stopped {total} gateway process(es) across all profiles") diff --git a/hermes_cli/logs.py b/hermes_cli/logs.py index d598494089..9a829a4bdc 100644 --- a/hermes_cli/logs.py +++ b/hermes_cli/logs.py @@ -1,16 +1,18 @@ """``hermes logs`` — view and filter Hermes log files. -Supports tailing, following, session filtering, level filtering, and -relative time ranges. All log files live under ``~/.hermes/logs/``. +Supports tailing, following, session filtering, level filtering, +component filtering, and relative time ranges. All log files live +under ``~/.hermes/logs/``. 
Usage examples:: hermes logs # last 50 lines of agent.log hermes logs -f # follow agent.log in real time hermes logs errors # last 50 lines of errors.log - hermes logs gateway -n 100 # last 100 lines of gateway.log + hermes logs gateway -n 100 # last 100 lines of gateway.log hermes logs --level WARNING # only WARNING+ lines hermes logs --session abc123 # filter by session ID substring + hermes logs --component tools # only tool-related lines hermes logs --since 1h # lines from the last hour hermes logs --since 30m -f # follow, starting 30 min ago """ @@ -20,7 +22,7 @@ import sys import time from datetime import datetime, timedelta from pathlib import Path -from typing import Optional +from typing import Optional, Sequence from hermes_constants import get_hermes_home, display_hermes_home @@ -38,6 +40,15 @@ _TS_RE = re.compile(r"^(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})") # Level extraction — matches " INFO ", " WARNING ", " ERROR ", " DEBUG ", " CRITICAL " _LEVEL_RE = re.compile(r"\s(DEBUG|INFO|WARNING|ERROR|CRITICAL)\s") +# Logger name extraction — after level and optional session tag, the next +# non-space token before ":" is the logger name. +# Matches: "INFO gateway.run:" or "INFO [sess_abc] tools.terminal_tool:" +_LOGGER_NAME_RE = re.compile( + r"\s(?:DEBUG|INFO|WARNING|ERROR|CRITICAL)" # level + r"(?:\s+\[.*?\])?" 
# optional session tag + r"\s+(\S+):" # logger name +) + # Level ordering for >= filtering _LEVEL_ORDER = {"DEBUG": 0, "INFO": 1, "WARNING": 2, "ERROR": 3, "CRITICAL": 4} @@ -79,12 +90,27 @@ def _extract_level(line: str) -> Optional[str]: return m.group(1) if m else None +def _extract_logger_name(line: str) -> Optional[str]: + """Extract the logger name from a log line.""" + m = _LOGGER_NAME_RE.search(line) + return m.group(1) if m else None + + +def _line_matches_component(line: str, prefixes: Sequence[str]) -> bool: + """Check if a log line's logger name starts with any of *prefixes*.""" + name = _extract_logger_name(line) + if name is None: + return False + return name.startswith(tuple(prefixes)) + + def _matches_filters( line: str, *, min_level: Optional[str] = None, session_filter: Optional[str] = None, since: Optional[datetime] = None, + component_prefixes: Optional[Sequence[str]] = None, ) -> bool: """Check if a log line passes all active filters.""" if since is not None: @@ -102,6 +128,10 @@ def _matches_filters( if session_filter not in line: return False + if component_prefixes is not None: + if not _line_matches_component(line, component_prefixes): + return False + return True @@ -113,6 +143,7 @@ def tail_log( level: Optional[str] = None, session: Optional[str] = None, since: Optional[str] = None, + component: Optional[str] = None, ) -> None: """Read and display log lines, optionally following in real time. @@ -130,6 +161,8 @@ def tail_log( Session ID substring to filter on. since Relative time string (e.g. ``"1h"``, ``"30m"``). + component + Component name to filter by (e.g. ``"gateway"``, ``"tools"``). """ filename = LOG_FILES.get(log_name) if filename is None: @@ -155,13 +188,29 @@ def tail_log( print(f"Invalid --level: {level!r}. 
Use DEBUG, INFO, WARNING, ERROR, or CRITICAL.") sys.exit(1) - has_filters = min_level is not None or session is not None or since_dt is not None + # Resolve component to logger name prefixes + component_prefixes = None + if component: + from hermes_logging import COMPONENT_PREFIXES + component_lower = component.lower() + if component_lower not in COMPONENT_PREFIXES: + available = ", ".join(sorted(COMPONENT_PREFIXES)) + print(f"Unknown component: {component!r}. Available: {available}") + sys.exit(1) + component_prefixes = COMPONENT_PREFIXES[component_lower] + + has_filters = ( + min_level is not None + or session is not None + or since_dt is not None + or component_prefixes is not None + ) # Read and display the tail try: lines = _read_tail(log_path, num_lines, has_filters=has_filters, min_level=min_level, session_filter=session, - since=since_dt) + since=since_dt, component_prefixes=component_prefixes) except PermissionError: print(f"Permission denied: {log_path}") sys.exit(1) @@ -172,6 +221,8 @@ def tail_log( filter_parts.append(f"level>={min_level}") if session: filter_parts.append(f"session={session}") + if component: + filter_parts.append(f"component={component}") if since: filter_parts.append(f"since={since}") filter_desc = f" [{', '.join(filter_parts)}]" if filter_parts else "" @@ -190,7 +241,7 @@ def tail_log( # Follow mode — poll for new content try: _follow_log(log_path, min_level=min_level, session_filter=session, - since=since_dt) + since=since_dt, component_prefixes=component_prefixes) except KeyboardInterrupt: print("\n--- stopped ---") @@ -203,6 +254,7 @@ def _read_tail( min_level: Optional[str] = None, session_filter: Optional[str] = None, since: Optional[datetime] = None, + component_prefixes: Optional[Sequence[str]] = None, ) -> list: """Read the last *num_lines* matching lines from a log file. 
@@ -215,7 +267,8 @@ def _read_tail( filtered = [ l for l in raw_lines if _matches_filters(l, min_level=min_level, - session_filter=session_filter, since=since) + session_filter=session_filter, since=since, + component_prefixes=component_prefixes) ] return filtered[-num_lines:] else: @@ -284,6 +337,7 @@ def _follow_log( min_level: Optional[str] = None, session_filter: Optional[str] = None, since: Optional[datetime] = None, + component_prefixes: Optional[Sequence[str]] = None, ) -> None: """Poll a log file for new content and print matching lines.""" with open(path, "r", encoding="utf-8", errors="replace") as f: @@ -293,7 +347,8 @@ def _follow_log( line = f.readline() if line: if _matches_filters(line, min_level=min_level, - session_filter=session_filter, since=since): + session_filter=session_filter, since=since, + component_prefixes=component_prefixes): print(line, end="") sys.stdout.flush() else: diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 4b7dd600b3..037c0a72fe 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -528,6 +528,113 @@ def _resolve_last_cli_session() -> Optional[str]: return None +def _probe_container(cmd: list, backend: str, via_sudo: bool = False): + """Run a container inspect probe, returning the CompletedProcess. + + Catches TimeoutExpired specifically for a human-readable message; + all other exceptions propagate naturally. + """ + try: + return subprocess.run(cmd, capture_output=True, text=True, timeout=15) + except subprocess.TimeoutExpired: + label = f"sudo {backend}" if via_sudo else backend + print( + f"Error: timed out waiting for {label} to respond.\n" + f"The {backend} daemon may be unresponsive or starting up.", + file=sys.stderr, + ) + sys.exit(1) + + +def _exec_in_container(container_info: dict, cli_args: list): + """Replace the current process with a command inside the managed container. + + Probes whether sudo is needed (rootful containers), then os.execvp + into the container. 
On success the Python process is replaced entirely + and the container's exit code becomes the process exit code (OS semantics). + On failure, OSError propagates naturally. + + Args: + container_info: dict with backend, container_name, exec_user, hermes_bin + cli_args: the original CLI arguments (everything after 'hermes') + """ + import shutil + + backend = container_info["backend"] + container_name = container_info["container_name"] + exec_user = container_info["exec_user"] + hermes_bin = container_info["hermes_bin"] + + runtime = shutil.which(backend) + if not runtime: + print(f"Error: {backend} not found on PATH. Cannot route to container.", + file=sys.stderr) + sys.exit(1) + + # Rootful containers (NixOS systemd service) are invisible to unprivileged + # users — Podman uses per-user namespaces, Docker needs group access. + # Probe whether the runtime can see the container; if not, try via sudo. + sudo_path = None + probe = _probe_container( + [runtime, "inspect", "--format", "ok", container_name], backend, + ) + if probe.returncode != 0: + sudo_path = shutil.which("sudo") + if sudo_path: + probe2 = _probe_container( + [sudo_path, "-n", runtime, "inspect", "--format", "ok", container_name], + backend, via_sudo=True, + ) + if probe2.returncode != 0: + print( + f"Error: container '{container_name}' not found via {backend}.\n" + f"\n" + f"The container is likely running as root. Your user cannot see it\n" + f"because {backend} uses per-user namespaces. 
Grant passwordless\n" + f"sudo for {backend} — the -n (non-interactive) flag is required\n" + f"because a password prompt would hang or break piped commands.\n" + f"\n" + f"On NixOS:\n" + f"\n" + f' security.sudo.extraRules = [{{\n' + f' users = [ "{os.getenv("USER", "your-user")}" ];\n' + f' commands = [{{ command = "{runtime}"; options = [ "NOPASSWD" ]; }}];\n' + f' }}];\n' + f"\n" + f"Or run: sudo hermes {' '.join(cli_args)}", + file=sys.stderr, + ) + sys.exit(1) + else: + print( + f"Error: container '{container_name}' not found via {backend}.\n" + f"The container may be running under root. Try: sudo hermes {' '.join(cli_args)}", + file=sys.stderr, + ) + sys.exit(1) + + is_tty = sys.stdin.isatty() + tty_flags = ["-it"] if is_tty else ["-i"] + + env_flags = [] + for var in ("TERM", "COLORTERM", "LANG", "LC_ALL"): + val = os.environ.get(var) + if val: + env_flags.extend(["-e", f"{var}={val}"]) + + cmd_prefix = [sudo_path, "-n", runtime] if sudo_path else [runtime] + exec_cmd = ( + cmd_prefix + ["exec"] + + tty_flags + + ["-u", exec_user] + + env_flags + + [container_name, hermes_bin] + + cli_args + ) + + os.execvp(exec_cmd[0], exec_cmd) + + def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]: """Resolve a session name (title) or ID to a session ID. @@ -2711,6 +2818,18 @@ def cmd_config(args): config_command(args) +def cmd_backup(args): + """Back up Hermes home directory to a zip file.""" + from hermes_cli.backup import run_backup + run_backup(args) + + +def cmd_import(args): + """Restore a Hermes backup from a zip file.""" + from hermes_cli.backup import run_import + run_import(args) + + def cmd_version(args): """Show version.""" print(f"Hermes Agent v{__version__} ({__release_date__})") @@ -3876,7 +3995,7 @@ def cmd_update(args): # Exclude PIDs that belong to just-restarted services so we don't # immediately kill the process that systemd/launchd just spawned. 
service_pids = _get_service_pids() - manual_pids = find_gateway_pids(exclude_pids=service_pids) + manual_pids = find_gateway_pids(exclude_pids=service_pids, all_profiles=True) for pid in manual_pids: try: os.kill(pid, _signal.SIGTERM) @@ -4231,6 +4350,7 @@ def cmd_logs(args): level=getattr(args, "level", None), session=getattr(args, "session", None), since=getattr(args, "since", None), + component=getattr(args, "component", None), ) @@ -4796,7 +4916,43 @@ For more help on a command: help="Show redacted API key prefixes (first/last 4 chars) instead of just set/not set" ) dump_parser.set_defaults(func=cmd_dump) - + + # ========================================================================= + # backup command + # ========================================================================= + backup_parser = subparsers.add_parser( + "backup", + help="Back up Hermes home directory to a zip file", + description="Create a zip archive of your entire Hermes configuration, " + "skills, sessions, and data (excludes the hermes-agent codebase)" + ) + backup_parser.add_argument( + "-o", "--output", + help="Output path for the zip file (default: ~/hermes-backup-<timestamp>.zip)" + ) + backup_parser.set_defaults(func=cmd_backup) + + # ========================================================================= + # import command + # ========================================================================= + import_parser = subparsers.add_parser( + "import", + help="Restore a Hermes backup from a zip file", + description="Extract a previously created Hermes backup into your " + "Hermes home directory, restoring configuration, skills, " + "sessions, and data" + ) + import_parser.add_argument( + "zipfile", + help="Path to the backup zip file" + ) + import_parser.add_argument( + "--force", "-f", + action="store_true", + help="Overwrite existing files without confirmation" + ) + import_parser.set_defaults(func=cmd_import) + # ========================================================================= # 
config command # ========================================================================= @@ -5146,6 +5302,8 @@ For more help on a command: mcp_add_p.add_argument("--command", help="Stdio command (e.g. npx)") mcp_add_p.add_argument("--args", nargs="*", default=[], help="Arguments for stdio command") mcp_add_p.add_argument("--auth", choices=["oauth", "header"], help="Auth method") + mcp_add_p.add_argument("--preset", help="Known MCP preset name") + mcp_add_p.add_argument("--env", nargs="*", default=[], help="Environment variables for stdio servers (KEY=VALUE)") mcp_rm_p = mcp_sub.add_parser("remove", aliases=["rm"], help="Remove an MCP server") mcp_rm_p.add_argument("name", help="Server name to remove") @@ -5628,6 +5786,7 @@ Examples: hermes logs gateway -n 100 Show last 100 lines of gateway.log hermes logs --level WARNING Only show WARNING and above hermes logs --session abc123 Filter by session ID + hermes logs --component tools Only show tool-related lines hermes logs --since 1h Lines from the last hour hermes logs --since 30m -f Follow, starting from 30 min ago hermes logs list List available log files with sizes @@ -5657,6 +5816,10 @@ Examples: "--since", metavar="TIME", help="Show lines since TIME ago (e.g. 1h, 30m, 2d)", ) + logs_parser.add_argument( + "--component", metavar="NAME", + help="Filter by component: gateway, agent, tools, cli, cron", + ) logs_parser.set_defaults(func=cmd_logs) # ========================================================================= @@ -5665,9 +5828,22 @@ Examples: # Pre-process argv so unquoted multi-word session names after -c / -r # are merged into a single token before argparse sees them. # e.g. ``hermes -c Pokemon Agent Dev`` → ``hermes -c 'Pokemon Agent Dev'`` + # ── Container-aware routing ──────────────────────────────────────── + # When NixOS container mode is active, route ALL subcommands into + # the managed container. 
This MUST run before parse_args() so that + # --help, unrecognised flags, and every subcommand are forwarded + # transparently instead of being intercepted by argparse on the host. + from hermes_cli.config import get_container_exec_info + container_info = get_container_exec_info() + if container_info: + _exec_in_container(container_info, sys.argv[1:]) + # Unreachable: os.execvp never returns on success (process is replaced) + # and raises OSError on failure (which propagates as a traceback). + sys.exit(1) + _processed_argv = _coalesce_session_name_args(sys.argv[1:]) args = parser.parse_args(_processed_argv) - + # Handle --version flag if args.version: cmd_version(args) diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py index cf2dde0892..b21234ce0a 100644 --- a/hermes_cli/mcp_config.py +++ b/hermes_cli/mcp_config.py @@ -9,7 +9,6 @@ configuration in ~/.hermes/config.yaml under the ``mcp_servers`` key. """ import asyncio -import getpass import logging import os import re @@ -28,6 +27,11 @@ from hermes_constants import display_hermes_home logger = logging.getLogger(__name__) +_ENV_VAR_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +_MCP_PRESETS: Dict[str, Dict[str, Any]] = {} + # ─── UI Helpers ─────────────────────────────────────────────────────────────── @@ -98,6 +102,59 @@ def _env_key_for_server(name: str) -> str: return f"MCP_{name.upper().replace('-', '_')}_API_KEY" +def _parse_env_assignments(raw_env: Optional[List[str]]) -> Dict[str, str]: + """Parse ``KEY=VALUE`` strings from CLI args into an env dict.""" + parsed: Dict[str, str] = {} + for item in raw_env or []: + text = str(item or "").strip() + if not text: + continue + if "=" not in text: + raise ValueError(f"Invalid --env value '{text}' (expected KEY=VALUE)") + key, value = text.split("=", 1) + key = key.strip() + if not key: + raise ValueError(f"Invalid --env value '{text}' (missing variable name)") + if not _ENV_VAR_NAME_RE.match(key): + raise ValueError(f"Invalid --env variable 
name '{key}'") + parsed[key] = value + return parsed + + + def _apply_mcp_preset( + name: str, + *, + preset_name: Optional[str], + url: Optional[str], + command: Optional[str], + cmd_args: List[str], + server_config: Dict[str, Any], + ) -> tuple[Optional[str], Optional[str], List[str], bool]: + """Apply a known MCP preset when transport details were omitted.""" + if not preset_name: + return url, command, cmd_args, False + + preset = _MCP_PRESETS.get(preset_name) + if not preset: + raise ValueError(f"Unknown MCP preset: {preset_name}") + + if url or command: + return url, command, cmd_args, False + + url = preset.get("url") + command = preset.get("command") + cmd_args = list(preset.get("args") or []) + + if url: + server_config["url"] = url + if command: + server_config["command"] = command + if cmd_args: + server_config["args"] = cmd_args + + return url, command, cmd_args, True + + # ─── Discovery (temporary connect) ─────────────────────────────────────────── def _probe_single_server( @@ -166,13 +223,35 @@ def cmd_mcp_add(args): command = getattr(args, "command", None) cmd_args = getattr(args, "args", None) or [] auth_type = getattr(args, "auth", None) + preset_name = getattr(args, "preset", None) + raw_env = getattr(args, "env", None) + + server_config: Dict[str, Any] = {} + try: + explicit_env = _parse_env_assignments(raw_env) + url, command, cmd_args, _preset_applied = _apply_mcp_preset( + name, + preset_name=preset_name, + url=url, + command=command, + cmd_args=list(cmd_args), + server_config=server_config, + ) + except ValueError as exc: + _error(str(exc)) + return + + if url and explicit_env: + _error("--env is only supported for stdio MCP servers (--command or stdio presets)") + return # Validate transport if not url and not command: - _error("Must specify --url <url> or --command <cmd>") + _error("Must specify --url <url>, --command <cmd>, or --preset <name>") _info("Examples:") _info(' hermes mcp add ink --url "https://mcp.ml.ink/mcp"') _info(' hermes mcp add github --command npx 
--args @modelcontextprotocol/server-github') + _info(' hermes mcp add myserver --preset mypreset') return # Check if server already exists @@ -183,13 +262,15 @@ def cmd_mcp_add(args): return # Build initial config - server_config: Dict[str, Any] = {} if url: server_config["url"] = url else: server_config["command"] = command if cmd_args: server_config["args"] = cmd_args + if explicit_env: + server_config["env"] = explicit_env + # ── Authentication ──────────────────────────────────────────────── @@ -627,6 +708,7 @@ def mcp_command(args): _info("hermes mcp serve Run as MCP server") _info("hermes mcp add --url Add an MCP server") _info("hermes mcp add --command Add a stdio server") + _info("hermes mcp add --preset Add from a known preset") _info("hermes mcp remove Remove a server") _info("hermes mcp list List servers") _info("hermes mcp test Test connection") diff --git a/hermes_cli/model_normalize.py b/hermes_cli/model_normalize.py index 8c0c30fbfa..68e8dc898e 100644 --- a/hermes_cli/model_normalize.py +++ b/hermes_cli/model_normalize.py @@ -74,13 +74,13 @@ _DOT_TO_HYPHEN_PROVIDERS: frozenset[str] = frozenset({ _STRIP_VENDOR_ONLY_PROVIDERS: frozenset[str] = frozenset({ "copilot", "copilot-acp", + "openai-codex", }) # Providers whose native naming is authoritative -- pass through unchanged. 
_AUTHORITATIVE_NATIVE_PROVIDERS: frozenset[str] = frozenset({ "gemini", "huggingface", - "openai-codex", }) # Direct providers that accept bare native names but should repair a matching @@ -360,7 +360,11 @@ def normalize_model_for_provider(model_input: str, target_provider: str) -> str: # --- Copilot: strip matching provider prefix, keep dots --- if provider in _STRIP_VENDOR_ONLY_PROVIDERS: - return _strip_matching_provider_prefix(name, provider) + stripped = _strip_matching_provider_prefix(name, provider) + if stripped == name and name.startswith("openai/"): + # openai-codex maps openai/gpt-5.4 -> gpt-5.4 + return name.split("/", 1)[1] + return stripped # --- DeepSeek: map to one of two canonical names --- if provider == "deepseek": diff --git a/hermes_cli/models.py b/hermes_cli/models.py index 17c1072dbe..ae4146415e 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -1809,6 +1809,35 @@ def validate_requested_model( "message": message, } + # OpenAI Codex has its own catalog path; /v1/models probing is not the right validation path. + if normalized == "openai-codex": + try: + codex_models = provider_model_ids("openai-codex") + except Exception: + codex_models = [] + if codex_models: + if requested_for_lookup in set(codex_models): + return { + "accepted": True, + "persist": True, + "recognized": True, + "message": None, + } + suggestions = get_close_matches(requested_for_lookup, codex_models, n=3, cutoff=0.5) + suggestion_text = "" + if suggestions: + suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions) + return { + "accepted": True, + "persist": True, + "recognized": False, + "message": ( + f"Note: `{requested}` was not found in the OpenAI Codex model listing. " + f"It may still work if your account has access to it." 
+ f"{suggestion_text}" + ), + } + # Probe the live API to check if the model actually exists api_models = fetch_api_models(api_key, base_url) diff --git a/hermes_cli/platforms.py b/hermes_cli/platforms.py index d84d974398..5ffd6fc210 100644 --- a/hermes_cli/platforms.py +++ b/hermes_cli/platforms.py @@ -33,6 +33,7 @@ PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([ ("dingtalk", PlatformInfo(label="💬 DingTalk", default_toolset="hermes-dingtalk")), ("feishu", PlatformInfo(label="🪽 Feishu", default_toolset="hermes-feishu")), ("wecom", PlatformInfo(label="💬 WeCom", default_toolset="hermes-wecom")), + ("wecom_callback", PlatformInfo(label="💬 WeCom Callback", default_toolset="hermes-wecom-callback")), ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")), ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")), ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")), diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index fb70d9081c..e12f7d1a76 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -1969,6 +1969,48 @@ def _setup_weixin(): _gateway_setup_weixin() +def _setup_signal(): + """Configure Signal via gateway setup.""" + from hermes_cli.gateway import _setup_signal as _gateway_setup_signal + _gateway_setup_signal() + + +def _setup_email(): + """Configure Email via gateway setup.""" + from hermes_cli.gateway import _setup_email as _gateway_setup_email + _gateway_setup_email() + + +def _setup_sms(): + """Configure SMS (Twilio) via gateway setup.""" + from hermes_cli.gateway import _setup_sms as _gateway_setup_sms + _gateway_setup_sms() + + +def _setup_dingtalk(): + """Configure DingTalk via gateway setup.""" + from hermes_cli.gateway import _setup_dingtalk as _gateway_setup_dingtalk + _gateway_setup_dingtalk() + + +def _setup_feishu(): + """Configure Feishu / Lark via gateway setup.""" + from hermes_cli.gateway import _setup_feishu as _gateway_setup_feishu + 
_gateway_setup_feishu() + + +def _setup_wecom(): + """Configure WeCom (Enterprise WeChat) via gateway setup.""" + from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom + _gateway_setup_wecom() + + +def _setup_wecom_callback(): + """Configure WeCom Callback (self-built app) via gateway setup.""" + from hermes_cli.gateway import _setup_wecom_callback as _gw_setup + _gw_setup() + + def _setup_bluebubbles(): """Configure BlueBubbles iMessage gateway.""" print_header("BlueBubbles (iMessage)") @@ -2085,9 +2127,16 @@ _GATEWAY_PLATFORMS = [ ("Telegram", "TELEGRAM_BOT_TOKEN", _setup_telegram), ("Discord", "DISCORD_BOT_TOKEN", _setup_discord), ("Slack", "SLACK_BOT_TOKEN", _setup_slack), + ("Signal", "SIGNAL_HTTP_URL", _setup_signal), + ("Email", "EMAIL_ADDRESS", _setup_email), + ("SMS (Twilio)", "TWILIO_ACCOUNT_SID", _setup_sms), ("Matrix", "MATRIX_ACCESS_TOKEN", _setup_matrix), ("Mattermost", "MATTERMOST_TOKEN", _setup_mattermost), ("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp), + ("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk), + ("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu), + ("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom), + ("WeCom Callback (Self-Built App)", "WECOM_CALLBACK_CORP_ID", _setup_wecom_callback), ("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin), ("BlueBubbles (iMessage)", "BLUEBUBBLES_SERVER_URL", _setup_bluebubbles), ("Webhooks (GitHub, GitLab, etc.)", "WEBHOOK_ENABLED", _setup_webhooks), @@ -2129,10 +2178,17 @@ def setup_gateway(config: dict): get_env_value("TELEGRAM_BOT_TOKEN") or get_env_value("DISCORD_BOT_TOKEN") or get_env_value("SLACK_BOT_TOKEN") + or get_env_value("SIGNAL_HTTP_URL") + or get_env_value("EMAIL_ADDRESS") + or get_env_value("TWILIO_ACCOUNT_SID") or get_env_value("MATTERMOST_TOKEN") or get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD") or get_env_value("WHATSAPP_ENABLED") + or get_env_value("DINGTALK_CLIENT_ID") + or get_env_value("FEISHU_APP_ID") + or 
get_env_value("WECOM_BOT_ID") + or get_env_value("WEIXIN_ACCOUNT_ID") or get_env_value("BLUEBUBBLES_SERVER_URL") or get_env_value("WEBHOOK_ENABLED") ) @@ -2321,12 +2377,30 @@ def _get_section_config_summary(config: dict, section_key: str) -> Optional[str] platforms.append("Discord") if get_env_value("SLACK_BOT_TOKEN"): platforms.append("Slack") - if get_env_value("WHATSAPP_PHONE_NUMBER_ID"): - platforms.append("WhatsApp") if get_env_value("SIGNAL_ACCOUNT"): platforms.append("Signal") + if get_env_value("EMAIL_ADDRESS"): + platforms.append("Email") + if get_env_value("TWILIO_ACCOUNT_SID"): + platforms.append("SMS") + if get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD"): + platforms.append("Matrix") + if get_env_value("MATTERMOST_TOKEN"): + platforms.append("Mattermost") + if get_env_value("WHATSAPP_PHONE_NUMBER_ID"): + platforms.append("WhatsApp") + if get_env_value("DINGTALK_CLIENT_ID"): + platforms.append("DingTalk") + if get_env_value("FEISHU_APP_ID"): + platforms.append("Feishu") + if get_env_value("WECOM_BOT_ID"): + platforms.append("WeCom") + if get_env_value("WEIXIN_ACCOUNT_ID"): + platforms.append("Weixin") if get_env_value("BLUEBUBBLES_SERVER_URL"): platforms.append("BlueBubbles") + if get_env_value("WEBHOOK_ENABLED"): + platforms.append("Webhooks") if platforms: return ", ".join(platforms) return None # No platforms configured — section must run diff --git a/hermes_cli/status.py b/hermes_cli/status.py index 2c50bc98ff..54dd923c4e 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -302,6 +302,7 @@ def show_status(args): "DingTalk": ("DINGTALK_CLIENT_ID", None), "Feishu": ("FEISHU_APP_ID", "FEISHU_HOME_CHANNEL"), "WeCom": ("WECOM_BOT_ID", "WECOM_HOME_CHANNEL"), + "WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None), "Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"), "BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"), } diff --git a/hermes_logging.py b/hermes_logging.py index b765e94640..f1c20e3fa2 
100644 --- a/hermes_logging.py +++ b/hermes_logging.py @@ -7,16 +7,28 @@ gateway call early in their startup path. All log files live under Log files produced: agent.log — INFO+, all agent/tool/session activity (the main log) errors.log — WARNING+, errors and warnings only (quick triage) + gateway.log — INFO+, gateway-only events (created when mode="gateway") -Both files use ``RotatingFileHandler`` with ``RedactingFormatter`` so +All files use ``RotatingFileHandler`` with ``RedactingFormatter`` so secrets are never written to disk. + +Component separation: + gateway.log only receives records from ``gateway.*`` loggers — + platform adapters, session management, slash commands, delivery. + agent.log remains the catch-all (everything goes there). + +Session context: + Call ``set_session_context(session_id)`` at the start of a conversation + and ``clear_session_context()`` when done. All log lines emitted on + that thread will include ``[session_id]`` for filtering/correlation. """ import logging import os +import threading from logging.handlers import RotatingFileHandler from pathlib import Path -from typing import Optional +from typing import Optional, Sequence from hermes_constants import get_config_path, get_hermes_home @@ -25,9 +37,14 @@ from hermes_constants import get_config_path, get_hermes_home # unless ``force=True``. _logging_initialized = False -# Default log format — includes timestamp, level, logger name, and message. -_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s: %(message)s" -_LOG_FORMAT_VERBOSE = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +# Thread-local storage for per-conversation session context. +_session_context = threading.local() + +# Default log format — includes timestamp, level, optional session tag, +# logger name, and message. The ``%(session_tag)s`` field is guaranteed to +# exist on every LogRecord via _install_session_record_factory() below. 
+_LOG_FORMAT = "%(asctime)s %(levelname)s%(session_tag)s %(name)s: %(message)s" +_LOG_FORMAT_VERBOSE = "%(asctime)s - %(name)s - %(levelname)s%(session_tag)s - %(message)s" # Third-party loggers that are noisy at DEBUG/INFO level. _NOISY_LOGGERS = ( @@ -48,6 +65,99 @@ _NOISY_LOGGERS = ( ) +# --------------------------------------------------------------------------- +# Public session context API +# --------------------------------------------------------------------------- + +def set_session_context(session_id: str) -> None: + """Set the session ID for the current thread. + + All subsequent log records on this thread will include ``[session_id]`` + in the formatted output. Call at the start of ``run_conversation()``. + """ + _session_context.session_id = session_id + + +def clear_session_context() -> None: + """Clear the session ID for the current thread. + + Optional — ``set_session_context()`` overwrites the previous value, + so explicit clearing is only needed if the thread is reused for + non-conversation work after ``run_conversation()`` returns. + """ + _session_context.session_id = None + + +# --------------------------------------------------------------------------- +# Record factory — injects session_tag into every LogRecord at creation +# --------------------------------------------------------------------------- + +def _install_session_record_factory() -> None: + """Replace the global LogRecord factory with one that adds ``session_tag``. + + Unlike a ``logging.Filter`` on a handler or logger, the record factory + runs for EVERY record in the process — including records that propagate + from child loggers and records handled by third-party handlers. This + guarantees ``%(session_tag)s`` is always available in format strings, + eliminating the KeyError that would occur if a handler used our format + without having a ``_SessionFilter`` attached. + + Idempotent — checks for a marker attribute to avoid double-wrapping if + the module is reloaded. 
+ """ + current_factory = logging.getLogRecordFactory() + if getattr(current_factory, "_hermes_session_injector", False): + return # already installed + + def _session_record_factory(*args, **kwargs): + record = current_factory(*args, **kwargs) + sid = getattr(_session_context, "session_id", None) + record.session_tag = f" [{sid}]" if sid else "" # type: ignore[attr-defined] + return record + + _session_record_factory._hermes_session_injector = True # type: ignore[attr-defined] + logging.setLogRecordFactory(_session_record_factory) + + +# Install immediately on import — session_tag is available on all records +# from this point forward, even before setup_logging() is called. +_install_session_record_factory() + + +# --------------------------------------------------------------------------- +# Filters +# --------------------------------------------------------------------------- + +class _ComponentFilter(logging.Filter): + """Only pass records whose logger name starts with one of *prefixes*. + + Used to route gateway-specific records to ``gateway.log`` while + keeping ``agent.log`` as the catch-all. + """ + + def __init__(self, prefixes: Sequence[str]) -> None: + super().__init__() + self._prefixes = tuple(prefixes) + + def filter(self, record: logging.LogRecord) -> bool: + return record.name.startswith(self._prefixes) + + +# Logger name prefixes that belong to each component. +# Used by _ComponentFilter and exposed for ``hermes logs --component``. +COMPONENT_PREFIXES = { + "gateway": ("gateway",), + "agent": ("agent", "run_agent", "model_tools", "batch_runner"), + "tools": ("tools",), + "cli": ("hermes_cli", "cli"), + "cron": ("cron",), +} + + +# --------------------------------------------------------------------------- +# Main setup +# --------------------------------------------------------------------------- + def setup_logging( *, hermes_home: Optional[Path] = None, @@ -78,8 +188,9 @@ def setup_logging( Number of rotated backup files to keep. 
Defaults to 3 or the value from config.yaml ``logging.backup_count``. mode - Hint for the caller context: ``"cli"``, ``"gateway"``, ``"cron"``. - Currently used only for log format tuning (gateway includes PID). + Caller context: ``"cli"``, ``"gateway"``, ``"cron"``. + When ``"gateway"``, an additional ``gateway.log`` file is created + that receives only gateway-component records. force Re-run setup even if it has already been called. @@ -130,6 +241,18 @@ def setup_logging( formatter=RedactingFormatter(_LOG_FORMAT), ) + # --- gateway.log (INFO+, gateway component only) ------------------------ + if mode == "gateway": + _add_rotating_handler( + root, + log_dir / "gateway.log", + level=logging.INFO, + max_bytes=5 * 1024 * 1024, + backup_count=3, + formatter=RedactingFormatter(_LOG_FORMAT), + log_filter=_ComponentFilter(COMPONENT_PREFIXES["gateway"]), + ) + # Ensure root logger level is low enough for the handlers to fire. if root.level == logging.NOTSET or root.level > level: root.setLevel(level) @@ -218,9 +341,16 @@ def _add_rotating_handler( max_bytes: int, backup_count: int, formatter: logging.Formatter, + log_filter: Optional[logging.Filter] = None, ) -> None: """Add a ``RotatingFileHandler`` to *logger*, skipping if one already exists for the same resolved file path (idempotent). + + Parameters + ---------- + log_filter + Optional filter to attach to the handler (e.g. ``_ComponentFilter`` + for gateway.log). """ resolved = path.resolve() for existing in logger.handlers: @@ -236,6 +366,8 @@ def _add_rotating_handler( ) handler.setLevel(level) handler.setFormatter(formatter) + if log_filter is not None: + handler.addFilter(log_filter) logger.addHandler(handler) diff --git a/nix/nixosModules.nix b/nix/nixosModules.nix index b1be031df2..75b3dca31b 100644 --- a/nix/nixosModules.nix +++ b/nix/nixosModules.nix @@ -499,6 +499,16 @@ default = "ubuntu:24.04"; description = "OCI container image. 
The container pulls this at runtime via Docker/Podman."; }; + + hostUsers = mkOption { + type = types.listOf types.str; + default = [ ]; + description = '' + Interactive users who get a ~/.hermes symlink to the service + stateDir. These users are automatically added to the hermes group. + ''; + example = [ "sidbin" ]; + }; }; }; @@ -557,6 +567,25 @@ environment.variables.HERMES_HOME = "${cfg.stateDir}/.hermes"; }) + # ── Host user group membership ───────────────────────────────────── + (lib.mkIf (cfg.container.enable && cfg.container.hostUsers != []) { + users.users = lib.genAttrs cfg.container.hostUsers (user: { + extraGroups = [ cfg.group ]; + }); + }) + + # ── Warnings ────────────────────────────────────────────────────── + (lib.mkIf (cfg.container.enable && !cfg.addToSystemPackages && cfg.container.hostUsers != []) { + warnings = [ + '' + services.hermes-agent: container.enable is true and container.hostUsers + is set, but addToSystemPackages is false. Without a host-installed hermes + binary, container routing will not work for interactive users. + Set addToSystemPackages = true or ensure hermes is on PATH. + '' + ]; + }) + # ── Directories ─────────────────────────────────────────────────── { systemd.tmpfiles.rules = [ @@ -611,6 +640,59 @@ chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.managed chmod 0644 ${cfg.stateDir}/.hermes/.managed + # Container mode metadata — tells the host CLI to exec into the + # container instead of running locally. Removed when container mode + # is disabled so the host CLI falls back to native execution. + ${if cfg.container.enable then '' + cat > ${cfg.stateDir}/.hermes/.container-mode <<'HERMES_CONTAINER_MODE_EOF' +# Written by NixOS activation script. Do not edit manually. 
+backend=${cfg.container.backend} +container_name=${containerName} +exec_user=${cfg.user} +hermes_bin=${containerDataDir}/current-package/bin/hermes +HERMES_CONTAINER_MODE_EOF + chown ${cfg.user}:${cfg.group} ${cfg.stateDir}/.hermes/.container-mode + chmod 0644 ${cfg.stateDir}/.hermes/.container-mode + '' else '' + rm -f ${cfg.stateDir}/.hermes/.container-mode + + # Remove symlink bridge for hostUsers + ${lib.concatStringsSep "\n" (map (user: + let + userHome = config.users.users.${user}.home; + symlinkPath = "${userHome}/.hermes"; + in '' + if [ -L "${symlinkPath}" ] && [ "$(readlink "${symlinkPath}")" = "${cfg.stateDir}/.hermes" ]; then + rm -f "${symlinkPath}" + echo "hermes-agent: removed symlink ${symlinkPath}" + fi + '') cfg.container.hostUsers)} + ''} + + # ── Symlink bridge for interactive users ─────────────────────── + # Create ~/.hermes -> stateDir/.hermes for each hostUser so the + # host CLI shares state with the container service. + # Only runs when container mode is enabled. + ${lib.optionalString cfg.container.enable + (lib.concatStringsSep "\n" (map (user: + let + userHome = config.users.users.${user}.home; + symlinkPath = "${userHome}/.hermes"; + target = "${cfg.stateDir}/.hermes"; + in '' + if [ -d "${symlinkPath}" ] && [ ! -L "${symlinkPath}" ]; then + # Real directory — back it up, then create symlink. + # (ln -sfn cannot atomically replace a directory.) + _backup="${symlinkPath}.bak.$(date +%s)" + echo "hermes-agent: backing up existing ${symlinkPath} to $_backup" + mv "${symlinkPath}" "$_backup" + fi + # For everything else (existing symlink, doesn't exist, etc.) + # ln -sfn handles it: replaces symlinks, creates new ones. 
+ ln -sfn "${target}" "${symlinkPath}" + chown -h ${user}:${cfg.group} "${symlinkPath}" + '') cfg.container.hostUsers))} + # Seed auth file if provided ${lib.optionalString (cfg.authFile != null) '' ${if cfg.authFileForceOverwrite then '' diff --git a/pyproject.toml b/pyproject.toml index 28a4a300a7..95a1dfddd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dev = ["debugpy>=1.8.0,<2", "pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "py messaging = ["python-telegram-bot[webhooks]>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"] cron = ["croniter>=6.0.0,<7"] slack = ["slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"] -matrix = ["mautrix[encryption]>=0.20,<1", "Markdown>=3.6,<4"] +matrix = ["mautrix[encryption]>=0.20,<1", "Markdown>=3.6,<4", "aiosqlite>=0.20", "asyncpg>=0.29"] cli = ["simple-term-menu>=1.0,<2"] tts-premium = ["elevenlabs>=1.0,<2"] voice = [ diff --git a/run_agent.py b/run_agent.py index ba0a9f93d3..3956f89048 100644 --- a/run_agent.py +++ b/run_agent.py @@ -339,10 +339,7 @@ def _paths_overlap(left: Path, right: Path) -> bool: _SURROGATE_RE = re.compile(r'[\ud800-\udfff]') -_BUDGET_WARNING_RE = re.compile( - r"\[BUDGET(?:\s+WARNING)?:\s+Iteration\s+\d+/\d+\..*?\]", - re.DOTALL, -) + def _sanitize_surrogates(text: str) -> str: @@ -463,34 +460,7 @@ def _sanitize_messages_non_ascii(messages: list) -> bool: return found -def _strip_budget_warnings_from_history(messages: list) -> None: - """Remove budget pressure warnings from tool-result messages in-place. - Budget warnings are turn-scoped signals that must not leak into replayed - history. They live in tool-result ``content`` either as a JSON key - (``_budget_warning``) or appended plain text. 
- """ - for msg in messages: - if not isinstance(msg, dict) or msg.get("role") != "tool": - continue - content = msg.get("content") - if not isinstance(content, str) or "_budget_warning" not in content and "[BUDGET" not in content: - continue - - # Try JSON first (the common case: _budget_warning key in a dict) - try: - parsed = json.loads(content) - if isinstance(parsed, dict) and "_budget_warning" in parsed: - del parsed["_budget_warning"] - msg["content"] = json.dumps(parsed, ensure_ascii=False) - continue - except (json.JSONDecodeError, TypeError): - pass - - # Fallback: strip the text pattern from plain-text tool results - cleaned = _BUDGET_WARNING_RE.sub("", content).strip() - if cleaned != content: - msg["content"] = cleaned # ========================================================================= @@ -579,6 +549,7 @@ class AIAgent: clarify_callback: callable = None, step_callback: callable = None, stream_delta_callback: callable = None, + interim_assistant_callback: callable = None, tool_gen_callback: callable = None, status_callback: callable = None, max_tokens: int = None, @@ -728,6 +699,7 @@ class AIAgent: self.clarify_callback = clarify_callback self.step_callback = step_callback self.stream_delta_callback = stream_delta_callback + self.interim_assistant_callback = interim_assistant_callback self.status_callback = status_callback self.tool_gen_callback = tool_gen_callback @@ -775,12 +747,14 @@ class AIAgent: self._use_prompt_caching = (is_openrouter and is_claude) or is_native_anthropic self._cache_ttl = "5m" # Default 5-minute TTL (1.25x write cost) - # Iteration budget pressure: warn the LLM as it approaches max_iterations. - # Warnings are injected into the last tool result JSON (not as separate - # messages) so they don't break message structure or invalidate caching. 
- self._budget_caution_threshold = 0.7 # 70% — nudge to start wrapping up - self._budget_warning_threshold = 0.9 # 90% — urgent, respond now - self._budget_pressure_enabled = True + # Iteration budget: the LLM is only notified when it actually exhausts + # the iteration budget (api_call_count >= max_iterations). At that + # point we inject ONE message, allow one final API call, and if the + # model doesn't produce a text response, force a user-message asking + # it to summarise. No intermediate pressure warnings — they caused + # models to "give up" prematurely on complex tasks (#7915). + self._budget_exhausted_injected = False + self._budget_grace_call = False # Context pressure warnings: notify the USER (not the LLM) as context # fills up. Purely informational — displayed in CLI output and sent via @@ -831,6 +805,11 @@ class AIAgent: # Deferred paragraph break flag — set after tool iterations so a # single "\n\n" is prepended to the next real text delta. self._stream_needs_break = False + # Visible assistant text already delivered through live token callbacks + # during the current model response. Used to avoid re-sending the same + # commentary when the provider later returns it as a completed interim + # assistant message. + self._current_streamed_assistant_text = "" # Optional current-turn user-message override used when the API-facing # user message intentionally differs from the persisted transcript @@ -1331,6 +1310,19 @@ class AIAgent: ) self.compression_enabled = compression_enabled + # Reject models whose context window is below the minimum required + # for reliable tool-calling workflows (64K tokens). + from agent.model_metadata import MINIMUM_CONTEXT_LENGTH + _ctx = getattr(self.context_compressor, "context_length", 0) + if _ctx and _ctx < MINIMUM_CONTEXT_LENGTH: + raise ValueError( + f"Model {self.model} has a context window of {_ctx:,} tokens, " + f"which is below the minimum {MINIMUM_CONTEXT_LENGTH:,} required " + f"by Hermes Agent. 
Choose a model with at least " + f"{MINIMUM_CONTEXT_LENGTH // 1000}K context, or set " + f"model.context_length in config.yaml to override." + ) + # Inject context engine tool schemas (e.g. lcm_grep, lcm_describe, lcm_expand) self._context_engine_tool_names: set = set() if hasattr(self, "context_compressor") and self.context_compressor and self.tools is not None: @@ -3190,7 +3182,7 @@ class AIAgent: if platform_key in PLATFORM_HINTS: prompt_parts.append(PLATFORM_HINTS[platform_key]) - return "\n\n".join(prompt_parts) + return "\n\n".join(p.strip() for p in prompt_parts if p.strip()) # ========================================================================= # Pre/post-call guardrails (inspired by PR #1321 — @alireza78a) @@ -3446,6 +3438,7 @@ class AIAgent: def _chat_messages_to_responses_input(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Convert internal chat-style messages to Responses input items.""" items: List[Dict[str, Any]] = [] + seen_item_ids: set = set() for msg in messages: if not isinstance(msg, dict): @@ -3466,7 +3459,12 @@ class AIAgent: if isinstance(codex_reasoning, list): for ri in codex_reasoning: if isinstance(ri, dict) and ri.get("encrypted_content"): + item_id = ri.get("id") + if item_id and item_id in seen_item_ids: + continue items.append(ri) + if item_id: + seen_item_ids.add(item_id) has_codex_reasoning = True if content_text.strip(): @@ -3546,6 +3544,7 @@ class AIAgent: raise ValueError("Codex Responses input must be a list of input items.") normalized: List[Dict[str, Any]] = [] + seen_ids: set = set() for idx, item in enumerate(raw_items): if not isinstance(item, dict): raise ValueError(f"Codex Responses input[{idx}] must be an object.") @@ -3598,8 +3597,12 @@ class AIAgent: if item_type == "reasoning": encrypted = item.get("encrypted_content") if isinstance(encrypted, str) and encrypted: - reasoning_item = {"type": "reasoning", "encrypted_content": encrypted} item_id = item.get("id") + if isinstance(item_id, str) and 
item_id: + if item_id in seen_ids: + continue + seen_ids.add(item_id) + reasoning_item = {"type": "reasoning", "encrypted_content": encrypted} if isinstance(item_id, str) and item_id: reasoning_item["id"] = item_id summary = item.get("summary") @@ -4719,6 +4722,49 @@ class AIAgent: # ── Unified streaming API call ───────────────────────────────────────── + def _reset_stream_delivery_tracking(self) -> None: + """Reset tracking for text delivered during the current model response.""" + self._current_streamed_assistant_text = "" + + def _record_streamed_assistant_text(self, text: str) -> None: + """Accumulate visible assistant text emitted through stream callbacks.""" + if isinstance(text, str) and text: + self._current_streamed_assistant_text = ( + getattr(self, "_current_streamed_assistant_text", "") + text + ) + + @staticmethod + def _normalize_interim_visible_text(text: str) -> str: + if not isinstance(text, str): + return "" + return re.sub(r"\s+", " ", text).strip() + + def _interim_content_was_streamed(self, content: str) -> bool: + visible_content = self._normalize_interim_visible_text( + self._strip_think_blocks(content or "") + ) + if not visible_content: + return False + streamed = self._normalize_interim_visible_text( + self._strip_think_blocks(getattr(self, "_current_streamed_assistant_text", "") or "") + ) + return bool(streamed) and streamed == visible_content + + def _emit_interim_assistant_message(self, assistant_msg: Dict[str, Any]) -> None: + """Surface a real mid-turn assistant commentary message to the UI layer.""" + cb = getattr(self, "interim_assistant_callback", None) + if cb is None or not isinstance(assistant_msg, dict): + return + content = assistant_msg.get("content") + visible = self._strip_think_blocks(content or "").strip() + if not visible or visible == "(empty)": + return + already_streamed = self._interim_content_was_streamed(visible) + try: + cb(visible, already_streamed=already_streamed) + except Exception: + 
logger.debug("interim_assistant_callback error", exc_info=True) + def _fire_stream_delta(self, text: str) -> None: """Fire all registered stream delta callbacks (display + TTS).""" # If a tool iteration set the break flag, prepend a single paragraph @@ -4728,12 +4774,16 @@ class AIAgent: if getattr(self, "_stream_needs_break", False) and text and text.strip(): self._stream_needs_break = False text = "\n\n" + text - for cb in (self.stream_delta_callback, self._stream_callback): - if cb is not None: - try: - cb(text) - except Exception: - pass + callbacks = [cb for cb in (self.stream_delta_callback, self._stream_callback) if cb is not None] + delivered = False + for cb in callbacks: + try: + cb(text) + delivered = True + except Exception: + pass + if delivered: + self._record_streamed_assistant_text(text) def _fire_reasoning_delta(self, text: str) -> None: """Fire reasoning callback if registered.""" @@ -4917,6 +4967,7 @@ class AIAgent: if self.stream_delta_callback: try: self.stream_delta_callback(delta.content) + self._record_streamed_assistant_text(delta.content) except Exception: pass @@ -6908,24 +6959,6 @@ class AIAgent: turn_tool_msgs = messages[-num_tools:] enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id)) - # ── Budget pressure injection ──────────────────────────────────── - budget_warning = self._get_budget_warning(api_call_count) - if budget_warning and messages and messages[-1].get("role") == "tool": - last_content = messages[-1]["content"] - try: - parsed = json.loads(last_content) - if isinstance(parsed, dict): - parsed["_budget_warning"] = budget_warning - messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False) - else: - messages[-1]["content"] = last_content + f"\n\n{budget_warning}" - except (json.JSONDecodeError, TypeError): - messages[-1]["content"] = last_content + f"\n\n{budget_warning}" - if not self.quiet_mode: - remaining = self.max_iterations - api_call_count - tier = "⚠️ WARNING" if remaining <= 
self.max_iterations * 0.1 else "💡 CAUTION" - print(f"{self.log_prefix}{tier}: {remaining} iterations remaining") - def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None: """Execute tool calls sequentially (original behavior). Used for single calls or interactive tools.""" for i, tool_call in enumerate(assistant_message.tool_calls, 1): @@ -6974,6 +7007,15 @@ class AIAgent: self._current_tool = function_name self._touch_activity(f"executing tool: {function_name}") + # Set activity callback for long-running tool execution (terminal + # commands, etc.) so the gateway's inactivity monitor doesn't kill + # the agent while a command is running. + try: + from tools.environments.base import set_activity_callback + set_activity_callback(self._touch_activity) + except Exception: + pass + if self.tool_progress_callback: try: preview = _build_tool_preview(function_name, function_args) @@ -7263,50 +7305,7 @@ class AIAgent: if num_tools_seq > 0: enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id)) - # ── Budget pressure injection ───────────────────────────────── - # After all tool calls in this turn are processed, check if we're - # approaching max_iterations. If so, inject a warning into the LAST - # tool result's JSON so the LLM sees it naturally when reading results. 
- budget_warning = self._get_budget_warning(api_call_count) - if budget_warning and messages and messages[-1].get("role") == "tool": - last_content = messages[-1]["content"] - try: - parsed = json.loads(last_content) - if isinstance(parsed, dict): - parsed["_budget_warning"] = budget_warning - messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False) - else: - messages[-1]["content"] = last_content + f"\n\n{budget_warning}" - except (json.JSONDecodeError, TypeError): - messages[-1]["content"] = last_content + f"\n\n{budget_warning}" - if not self.quiet_mode: - remaining = self.max_iterations - api_call_count - tier = "⚠️ WARNING" if remaining <= self.max_iterations * 0.1 else "💡 CAUTION" - print(f"{self.log_prefix}{tier}: {remaining} iterations remaining") - def _get_budget_warning(self, api_call_count: int) -> Optional[str]: - """Return a budget pressure string, or None if not yet needed. - - Two-tier system: - - Caution (70%): nudge to consolidate work - - Warning (90%): urgent, must respond now - """ - if not self._budget_pressure_enabled or self.max_iterations <= 0: - return None - progress = api_call_count / self.max_iterations - remaining = self.max_iterations - api_call_count - if progress >= self._budget_warning_threshold: - return ( - f"[BUDGET WARNING: Iteration {api_call_count}/{self.max_iterations}. " - f"Only {remaining} iteration(s) left. " - "Provide your final response NOW. No more tool calls unless absolutely critical.]" - ) - if progress >= self._budget_caution_threshold: - return ( - f"[BUDGET: Iteration {api_call_count}/{self.max_iterations}. " - f"{remaining} iterations left. Start consolidating your work.]" - ) - return None def _emit_context_pressure(self, compaction_progress: float, compressor) -> None: """Notify the user that context is approaching the compaction threshold. @@ -7530,6 +7529,11 @@ class AIAgent: # Installed once, transparent when streams are healthy, prevents crash on write. 
_install_safe_stdio() + # Tag all log records on this thread with the session ID so + # ``hermes logs --session `` can filter a single conversation. + from hermes_logging import set_session_context + set_session_context(self.session_id) + # If the previous turn activated fallback, restore the primary # runtime so this turn gets a fresh attempt with the preferred model. # No-op when _fallback_activated is False (gateway, first turn, etc.). @@ -7599,14 +7603,6 @@ class AIAgent: # Initialize conversation (copy to avoid mutating the caller's list) messages = list(conversation_history) if conversation_history else [] - # Strip budget pressure warnings from previous turns. These are - # turn-scoped signals injected by _get_budget_warning() into tool - # result content. If left in the replayed history, models (especially - # GPT-family) interpret them as still-active instructions and avoid - # making tool calls in ALL subsequent turns. - if messages: - _strip_budget_warnings_from_history(messages) - # Hydrate todo store from conversation history (gateway creates a fresh # AIAgent per message, so the in-memory store is empty -- we need to # recover the todo state from the most recent todo tool response in history) @@ -7823,7 +7819,7 @@ class AIAgent: except Exception: pass - while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0: + while (api_call_count < self.max_iterations and self.iteration_budget.remaining > 0) or self._budget_grace_call: # Reset per-turn checkpoint dedup so each iteration can take one snapshot self._checkpoint_mgr.new_turn() @@ -7838,7 +7834,13 @@ class AIAgent: api_call_count += 1 self._api_call_count = api_call_count self._touch_activity(f"starting API call #{api_call_count}") - if not self.iteration_budget.consume(): + + # Grace call: the budget is exhausted but we gave the model one + # more chance. Consume the grace flag so the loop exits after + # this iteration regardless of outcome. 
+ if self._budget_grace_call: + self._budget_grace_call = False + elif not self.iteration_budget.consume(): _turn_exit_reason = "budget_exhausted" if not self.quiet_mode: self._safe_print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)") @@ -7965,9 +7967,39 @@ class AIAgent: # manual message manipulation are always caught. api_messages = self._sanitize_api_messages(api_messages) + # Normalize message whitespace and tool-call JSON for consistent + # prefix matching. Ensures bit-perfect prefixes across turns, + # which enables KV cache reuse on local inference servers + # (llama.cpp, vLLM, Ollama) and improves cache hit rates for + # cloud providers. Operates on api_messages (the API copy) so + # the original conversation history in `messages` is untouched. + for am in api_messages: + if isinstance(am.get("content"), str): + am["content"] = am["content"].strip() + for am in api_messages: + tcs = am.get("tool_calls") + if not tcs: + continue + new_tcs = [] + for tc in tcs: + if isinstance(tc, dict) and "function" in tc: + try: + args_obj = json.loads(tc["function"]["arguments"]) + tc = {**tc, "function": { + **tc["function"], + "arguments": json.dumps( + args_obj, separators=(",", ":"), + sort_keys=True, + ), + }} + except Exception: + pass + new_tcs.append(tc) + am["tool_calls"] = new_tcs + # Calculate approximate request size for logging total_chars = sum(len(str(msg)) for msg in api_messages) - approx_tokens = total_chars // 4 # Rough estimate: 4 chars per token + approx_tokens = estimate_messages_tokens_rough(api_messages) # Thinking spinner for quiet mode (animated during API call) thinking_spinner = None @@ -8016,6 +8048,7 @@ class AIAgent: while retry_count < max_retries: try: + self._reset_stream_delivery_tracking() api_kwargs = self._build_api_kwargs(api_messages) if self.api_mode == "codex_responses": api_kwargs = self._preflight_codex_api_kwargs(api_kwargs, allow_stream=False) @@ -8174,6 
+8207,8 @@ class AIAgent: self._emit_status("⚠️ Empty/malformed response — switching to fallback...") if self._try_activate_fallback(): retry_count = 0 + compression_attempts = 0 + primary_recovery_attempted = False continue # Check for error field in response (some providers include this) @@ -8209,6 +8244,8 @@ class AIAgent: self._emit_status(f"⚠️ Max retries ({max_retries}) for invalid responses — trying fallback...") if self._try_activate_fallback(): retry_count = 0 + compression_attempts = 0 + primary_recovery_attempted = False continue self._emit_status(f"❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up.") logging.error(f"{self.log_prefix}Invalid API response after {max_retries} retries.") @@ -8863,6 +8900,8 @@ class AIAgent: self._emit_status("⚠️ Rate limited — switching to fallback provider...") if self._try_activate_fallback(): retry_count = 0 + compression_attempts = 0 + primary_recovery_attempted = False continue is_payload_too_large = ( @@ -9076,6 +9115,8 @@ class AIAgent: self._emit_status(f"⚠️ Non-retryable error (HTTP {status_code}) — trying fallback...") if self._try_activate_fallback(): retry_count = 0 + compression_attempts = 0 + primary_recovery_attempted = False continue if api_kwargs is not None: self._dump_api_request_debug( @@ -9141,6 +9182,8 @@ class AIAgent: self._emit_status(f"⚠️ Max retries ({max_retries}) exhausted — trying fallback...") if self._try_activate_fallback(): retry_count = 0 + compression_attempts = 0 + primary_recovery_attempted = False continue _final_summary = self._summarize_api_error(api_error) if is_rate_limited: @@ -9364,8 +9407,6 @@ class AIAgent: # Check for incomplete (opened but never closed) # This means the model ran out of output tokens mid-reasoning — retry up to 2 times if has_incomplete_scratchpad(assistant_message.content or ""): - if not hasattr(self, '_incomplete_scratchpad_retries'): - self._incomplete_scratchpad_retries = 0 self._incomplete_scratchpad_retries += 1 
self._vprint(f"{self.log_prefix}⚠️ Incomplete detected (opened but never closed)") @@ -9393,12 +9434,9 @@ class AIAgent: } # Reset incomplete scratchpad counter on clean response - if hasattr(self, '_incomplete_scratchpad_retries'): - self._incomplete_scratchpad_retries = 0 + self._incomplete_scratchpad_retries = 0 if self.api_mode == "codex_responses" and finish_reason == "incomplete": - if not hasattr(self, "_codex_incomplete_retries"): - self._codex_incomplete_retries = 0 self._codex_incomplete_retries += 1 interim_msg = self._build_assistant_message(assistant_message, finish_reason) @@ -9425,6 +9463,7 @@ class AIAgent: ) if not duplicate_interim: messages.append(interim_msg) + self._emit_interim_assistant_message(interim_msg) if self._codex_incomplete_retries < 3: if not self.quiet_mode: @@ -9469,8 +9508,6 @@ class AIAgent: ] if invalid_tool_calls: # Track retries for invalid tool calls - if not hasattr(self, '_invalid_tool_retries'): - self._invalid_tool_retries = 0 self._invalid_tool_retries += 1 # Return helpful error to model — model can self-correct next turn @@ -9506,8 +9543,7 @@ class AIAgent: }) continue # Reset retry counter on successful tool call validation - if hasattr(self, '_invalid_tool_retries'): - self._invalid_tool_retries = 0 + self._invalid_tool_retries = 0 # Validate tool call arguments are valid JSON # Handle empty strings as empty objects (common model quirk) @@ -9647,6 +9683,7 @@ class AIAgent: messages.pop() messages.append(assistant_msg) + self._emit_interim_assistant_message(assistant_msg) # Close any open streaming display (response box, reasoning # box) before tool execution begins. 
Intermediate turns may @@ -9909,8 +9946,7 @@ class AIAgent: break # Reset retry counter/signature on successful content - if hasattr(self, '_empty_content_retries'): - self._empty_content_retries = 0 + self._empty_content_retries = 0 self._thinking_prefill_retries = 0 if ( @@ -9926,6 +9962,7 @@ class AIAgent: codex_ack_continuations += 1 interim_msg = self._build_assistant_message(assistant_message, "incomplete") messages.append(interim_msg) + self._emit_interim_assistant_message(interim_msg) continue_msg = { "role": "user", @@ -9976,8 +10013,7 @@ class AIAgent: except (OSError, ValueError): logger.error(error_msg) - if self.verbose_logging: - logging.exception("Detailed error information:") + logger.debug("Outer loop error in API call #%d", api_call_count, exc_info=True) # If an assistant message with tool_calls was already appended, # the API expects a role="tool" result for every tool_call_id. @@ -10023,7 +10059,31 @@ class AIAgent: if final_response is None and ( api_call_count >= self.max_iterations or self.iteration_budget.remaining <= 0 - ): + ) and not self._budget_exhausted_injected: + # Budget exhausted but we haven't tried asking the model to + # summarise yet. Inject a user message and give it one grace + # API call to produce a text response. + self._budget_exhausted_injected = True + self._budget_grace_call = True + _grace_msg = ( + "Your tool budget ran out. Please give me the information " + "or actions you've completed so far." + ) + messages.append({"role": "user", "content": _grace_msg}) + self._emit_status( + f"⚠️ Iteration budget exhausted ({api_call_count}/{self.max_iterations}) " + "— asking model to summarise" + ) + if not self.quiet_mode: + self._safe_print( + f"\n⚠️ Iteration budget exhausted ({api_call_count}/{self.max_iterations}) " + "— requesting summary..." 
+ ) + + if final_response is None and ( + api_call_count >= self.max_iterations + or self.iteration_budget.remaining <= 0 + ) and not self._budget_grace_call: _turn_exit_reason = f"max_iterations_reached({api_call_count}/{self.max_iterations})" if self.iteration_budget.remaining <= 0 and not self.quiet_mode: print(f"\n⚠️ Iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} iterations used)") diff --git a/tests/agent/test_context_compressor.py b/tests/agent/test_context_compressor.py index 88a23b44cf..f4cf19666f 100644 --- a/tests/agent/test_context_compressor.py +++ b/tests/agent/test_context_compressor.py @@ -576,11 +576,19 @@ class TestSummaryTargetRatio: assert c.summary_target_ratio == 0.80 def test_default_threshold_is_50_percent(self): - """Default compression threshold should be 50%.""" + """Default compression threshold should be 50%, with a 64K floor.""" with patch("agent.context_compressor.get_model_context_length", return_value=100_000): c = ContextCompressor(model="test", quiet_mode=True) assert c.threshold_percent == 0.50 - assert c.threshold_tokens == 50_000 + # 50% of 100K = 50K, but the floor is 64K + assert c.threshold_tokens == 64_000 + + def test_threshold_floor_does_not_apply_above_128k(self): + """On large-context models the 50% percentage is used directly.""" + with patch("agent.context_compressor.get_model_context_length", return_value=200_000): + c = ContextCompressor(model="test", quiet_mode=True) + # 50% of 200K = 100K, which is above the 64K floor + assert c.threshold_tokens == 100_000 def test_default_protect_last_n_is_20(self): """Default protect_last_n should be 20.""" diff --git a/tests/agent/test_local_stream_timeout.py b/tests/agent/test_local_stream_timeout.py index 929f2e3c84..8184dd2d49 100644 --- a/tests/agent/test_local_stream_timeout.py +++ b/tests/agent/test_local_stream_timeout.py @@ -22,6 +22,9 @@ class TestLocalStreamReadTimeout: "http://0.0.0.0:5000", "http://192.168.1.100:8000", 
"http://10.0.0.5:1234", + "http://host.docker.internal:11434", + "http://host.containers.internal:11434", + "http://host.lima.internal:11434", ]) def test_local_endpoint_bumps_read_timeout(self, base_url): """Local endpoint + default timeout -> bumps to base_timeout.""" @@ -68,3 +71,38 @@ class TestLocalStreamReadTimeout: if _stream_read_timeout == 120.0 and base_url and is_local_endpoint(base_url): _stream_read_timeout = _base_timeout assert _stream_read_timeout == 120.0 + + +class TestIsLocalEndpoint: + """Direct unit tests for is_local_endpoint.""" + + @pytest.mark.parametrize("url", [ + "http://localhost:11434", + "http://127.0.0.1:8080", + "http://0.0.0.0:5000", + "http://[::1]:11434", + "http://192.168.1.100:8000", + "http://10.0.0.5:1234", + "http://172.17.0.1:11434", + ]) + def test_classic_local_addresses(self, url): + assert is_local_endpoint(url) is True + + @pytest.mark.parametrize("url", [ + "http://host.docker.internal:11434", + "http://host.docker.internal:8080/v1", + "http://gateway.docker.internal:11434", + "http://host.containers.internal:11434", + "http://host.lima.internal:11434", + ]) + def test_container_dns_names(self, url): + assert is_local_endpoint(url) is True + + @pytest.mark.parametrize("url", [ + "https://api.openai.com", + "https://openrouter.ai/api", + "https://api.anthropic.com", + "https://evil.docker.internal.example.com", + ]) + def test_remote_endpoints(self, url): + assert is_local_endpoint(url) is False diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index 1eac37e20f..df680fb241 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -50,7 +50,8 @@ class TestEstimateTokensRough: assert estimate_tokens_rough("a" * 400) == 100 def test_short_text(self): - assert estimate_tokens_rough("hello") == 1 + # "hello" = 5 chars → ceil(5/4) = 2 + assert estimate_tokens_rough("hello") == 2 def test_proportional(self): short = estimate_tokens_rough("hello world") @@ 
-68,10 +69,11 @@ class TestEstimateMessagesTokensRough: assert estimate_messages_tokens_rough([]) == 0 def test_single_message_concrete_value(self): - """Verify against known str(msg) length.""" + """Verify against known str(msg) length (ceiling division).""" msg = {"role": "user", "content": "a" * 400} result = estimate_messages_tokens_rough([msg]) - expected = len(str(msg)) // 4 + n = len(str(msg)) + expected = (n + 3) // 4 assert result == expected def test_multiple_messages_additive(self): @@ -80,7 +82,8 @@ class TestEstimateMessagesTokensRough: {"role": "assistant", "content": "Hi there, how can I help?"}, ] result = estimate_messages_tokens_rough(msgs) - expected = sum(len(str(m)) for m in msgs) // 4 + n = sum(len(str(m)) for m in msgs) + expected = (n + 3) // 4 assert result == expected def test_tool_call_message(self): @@ -89,7 +92,7 @@ class TestEstimateMessagesTokensRough: "tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]} result = estimate_messages_tokens_rough([msg]) assert result > 0 - assert result == len(str(msg)) // 4 + assert result == (len(str(msg)) + 3) // 4 def test_message_with_list_content(self): """Vision messages with multimodal content arrays.""" @@ -98,7 +101,7 @@ class TestEstimateMessagesTokensRough: {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}} ]} result = estimate_messages_tokens_rough([msg]) - assert result == len(str(msg)) // 4 + assert result == (len(str(msg)) + 3) // 4 # ========================================================================= diff --git a/tests/agent/test_prompt_builder.py b/tests/agent/test_prompt_builder.py index 3b6a4c3ec1..1f2f6ada77 100644 --- a/tests/agent/test_prompt_builder.py +++ b/tests/agent/test_prompt_builder.py @@ -1009,65 +1009,4 @@ class TestOpenAIModelExecutionGuidance: # ========================================================================= -class TestStripBudgetWarningsFromHistory: - def test_strips_json_budget_warning_key(self): - 
import json - from run_agent import _strip_budget_warnings_from_history - messages = [ - {"role": "tool", "tool_call_id": "c1", "content": json.dumps({ - "output": "hello", - "exit_code": 0, - "_budget_warning": "[BUDGET: Iteration 55/60. 5 iterations left. Start consolidating your work.]", - })}, - ] - _strip_budget_warnings_from_history(messages) - parsed = json.loads(messages[0]["content"]) - assert "_budget_warning" not in parsed - assert parsed["output"] == "hello" - assert parsed["exit_code"] == 0 - - def test_strips_text_budget_warning(self): - from run_agent import _strip_budget_warnings_from_history - - messages = [ - {"role": "tool", "tool_call_id": "c1", - "content": "some result\n\n[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"}, - ] - _strip_budget_warnings_from_history(messages) - assert messages[0]["content"] == "some result" - - def test_leaves_non_tool_messages_unchanged(self): - from run_agent import _strip_budget_warnings_from_history - - messages = [ - {"role": "assistant", "content": "[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. 
No more tool calls unless absolutely critical.]"}, - {"role": "user", "content": "hello"}, - ] - original_contents = [m["content"] for m in messages] - _strip_budget_warnings_from_history(messages) - assert [m["content"] for m in messages] == original_contents - - def test_handles_empty_and_missing_content(self): - from run_agent import _strip_budget_warnings_from_history - - messages = [ - {"role": "tool", "tool_call_id": "c1", "content": ""}, - {"role": "tool", "tool_call_id": "c2"}, - ] - _strip_budget_warnings_from_history(messages) - assert messages[0]["content"] == "" - - def test_strips_caution_variant(self): - import json - from run_agent import _strip_budget_warnings_from_history - - messages = [ - {"role": "tool", "tool_call_id": "c1", "content": json.dumps({ - "output": "ok", - "_budget_warning": "[BUDGET: Iteration 42/60. 18 iterations left. Start consolidating your work.]", - })}, - ] - _strip_budget_warnings_from_history(messages) - parsed = json.loads(messages[0]["content"]) - assert "_budget_warning" not in parsed diff --git a/tests/gateway/test_discord_connect.py b/tests/gateway/test_discord_connect.py index 9f094dd0dd..04490f2462 100644 --- a/tests/gateway/test_discord_connect.py +++ b/tests/gateway/test_discord_connect.py @@ -74,6 +74,26 @@ class FakeBot: return None +class SlowSyncTree(FakeTree): + def __init__(self): + super().__init__() + self.started = asyncio.Event() + self.allow_finish = asyncio.Event() + + async def _slow_sync(): + self.started.set() + await self.allow_finish.wait() + return [] + + self.sync = AsyncMock(side_effect=_slow_sync) + + +class SlowSyncBot(FakeBot): + def __init__(self, *, intents, proxy=None): + super().__init__(intents=intents, proxy=proxy) + self.tree = SlowSyncTree() + + @pytest.mark.asyncio @pytest.mark.parametrize( ("allowed_users", "expected_members_intent"), @@ -138,3 +158,36 @@ async def test_connect_releases_token_lock_on_timeout(monkeypatch): assert ok is False assert released == [("discord-bot-token", 
"test-token")] assert adapter._platform_lock_identity is None + + +@pytest.mark.asyncio +async def test_connect_does_not_wait_for_slash_sync(monkeypatch): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="test-token")) + + monkeypatch.setattr("gateway.status.acquire_scoped_lock", lambda scope, identity, metadata=None: (True, None)) + monkeypatch.setattr("gateway.status.release_scoped_lock", lambda scope, identity: None) + + intents = SimpleNamespace(message_content=False, dm_messages=False, guild_messages=False, members=False, voice_states=False) + monkeypatch.setattr(discord_platform.Intents, "default", lambda: intents) + + created = {} + + def fake_bot_factory(*, command_prefix, intents, proxy=None): + bot = SlowSyncBot(intents=intents, proxy=proxy) + created["bot"] = bot + return bot + + monkeypatch.setattr(discord_platform.commands, "Bot", fake_bot_factory) + monkeypatch.setattr(adapter, "_resolve_allowed_usernames", AsyncMock()) + + ok = await asyncio.wait_for(adapter.connect(), timeout=1.0) + + assert ok is True + assert adapter._ready_event.is_set() + + await asyncio.wait_for(created["bot"].tree.started.wait(), timeout=1.0) + assert created["bot"].tree.sync.await_count == 1 + + created["bot"].tree.allow_finish.set() + await asyncio.sleep(0) + await adapter.disconnect() diff --git a/tests/gateway/test_display_config.py b/tests/gateway/test_display_config.py new file mode 100644 index 0000000000..4dd73ebd28 --- /dev/null +++ b/tests/gateway/test_display_config.py @@ -0,0 +1,355 @@ +"""Tests for gateway.display_config — per-platform display/verbosity resolver.""" +import pytest + + +# --------------------------------------------------------------------------- +# Resolver: resolution order +# --------------------------------------------------------------------------- + +class TestResolveDisplaySetting: + """resolve_display_setting() resolves with correct priority.""" + + def test_explicit_platform_override_wins(self): + """display.platforms.. 
takes top priority.""" + from gateway.display_config import resolve_display_setting + + config = { + "display": { + "tool_progress": "all", + "platforms": { + "telegram": {"tool_progress": "verbose"}, + }, + } + } + assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose" + + def test_global_setting_when_no_platform_override(self): + """Falls back to display. when no platform override exists.""" + from gateway.display_config import resolve_display_setting + + config = { + "display": { + "tool_progress": "new", + "platforms": {}, + } + } + assert resolve_display_setting(config, "telegram", "tool_progress") == "new" + + def test_platform_default_when_no_user_config(self): + """Falls back to built-in platform default.""" + from gateway.display_config import resolve_display_setting + + # Empty config — should get built-in defaults + config = {} + # Telegram defaults to tier_high → "all" + assert resolve_display_setting(config, "telegram", "tool_progress") == "all" + # Email defaults to tier_minimal → "off" + assert resolve_display_setting(config, "email", "tool_progress") == "off" + + def test_global_default_for_unknown_platform(self): + """Unknown platforms get the global defaults.""" + from gateway.display_config import resolve_display_setting + + config = {} + # Unknown platform, no config → global default "all" + assert resolve_display_setting(config, "unknown_platform", "tool_progress") == "all" + + def test_fallback_parameter_used_last(self): + """Explicit fallback is used when nothing else matches.""" + from gateway.display_config import resolve_display_setting + + config = {} + # "nonexistent_key" isn't in any defaults + result = resolve_display_setting(config, "telegram", "nonexistent_key", "my_fallback") + assert result == "my_fallback" + + def test_platform_override_only_affects_that_platform(self): + """Other platforms are unaffected by a specific platform override.""" + from gateway.display_config import resolve_display_setting + + 
config = { + "display": { + "tool_progress": "all", + "platforms": { + "slack": {"tool_progress": "off"}, + }, + } + } + assert resolve_display_setting(config, "slack", "tool_progress") == "off" + assert resolve_display_setting(config, "telegram", "tool_progress") == "all" + + +# --------------------------------------------------------------------------- +# Backward compatibility: tool_progress_overrides +# --------------------------------------------------------------------------- + +class TestBackwardCompat: + """Legacy tool_progress_overrides is still respected as a fallback.""" + + def test_legacy_overrides_read(self): + """tool_progress_overrides is read when no platforms entry exists.""" + from gateway.display_config import resolve_display_setting + + config = { + "display": { + "tool_progress": "all", + "tool_progress_overrides": { + "signal": "off", + "telegram": "verbose", + }, + } + } + assert resolve_display_setting(config, "signal", "tool_progress") == "off" + assert resolve_display_setting(config, "telegram", "tool_progress") == "verbose" + + def test_new_platforms_takes_precedence_over_legacy(self): + """display.platforms beats tool_progress_overrides.""" + from gateway.display_config import resolve_display_setting + + config = { + "display": { + "tool_progress": "all", + "tool_progress_overrides": {"telegram": "verbose"}, + "platforms": {"telegram": {"tool_progress": "new"}}, + } + } + assert resolve_display_setting(config, "telegram", "tool_progress") == "new" + + def test_legacy_overrides_only_for_tool_progress(self): + """Legacy overrides don't affect other settings.""" + from gateway.display_config import resolve_display_setting + + config = { + "display": { + "tool_progress_overrides": {"telegram": "verbose"}, + } + } + # show_reasoning should NOT read from tool_progress_overrides + assert resolve_display_setting(config, "telegram", "show_reasoning") is False + + +# --------------------------------------------------------------------------- +# 
YAML normalisation +# --------------------------------------------------------------------------- + +class TestYAMLNormalisation: + """YAML 1.1 quirks (bare off → False, on → True) are handled.""" + + def test_tool_progress_false_normalised_to_off(self): + """YAML's bare `off` parses as False — normalised to 'off' string.""" + from gateway.display_config import resolve_display_setting + + config = {"display": {"tool_progress": False}} + assert resolve_display_setting(config, "telegram", "tool_progress") == "off" + + def test_tool_progress_true_normalised_to_all(self): + """YAML's bare `on` parses as True — normalised to 'all'.""" + from gateway.display_config import resolve_display_setting + + config = {"display": {"tool_progress": True}} + assert resolve_display_setting(config, "telegram", "tool_progress") == "all" + + def test_show_reasoning_string_true(self): + """String 'true' is normalised to bool True.""" + from gateway.display_config import resolve_display_setting + + config = {"display": {"platforms": {"telegram": {"show_reasoning": "true"}}}} + assert resolve_display_setting(config, "telegram", "show_reasoning") is True + + def test_tool_preview_length_string(self): + """String numbers are normalised to int.""" + from gateway.display_config import resolve_display_setting + + config = {"display": {"platforms": {"slack": {"tool_preview_length": "80"}}}} + assert resolve_display_setting(config, "slack", "tool_preview_length") == 80 + + def test_platform_override_false_tool_progress(self): + """Per-platform bare off → normalised.""" + from gateway.display_config import resolve_display_setting + + config = {"display": {"platforms": {"slack": {"tool_progress": False}}}} + assert resolve_display_setting(config, "slack", "tool_progress") == "off" + + +# --------------------------------------------------------------------------- +# Built-in platform defaults (tier system) +# --------------------------------------------------------------------------- + +class 
TestPlatformDefaults: + """Built-in defaults reflect platform capability tiers.""" + + def test_high_tier_platforms(self): + """Telegram and Discord default to 'all' tool progress.""" + from gateway.display_config import resolve_display_setting + + for plat in ("telegram", "discord"): + assert resolve_display_setting({}, plat, "tool_progress") == "all", plat + + def test_medium_tier_platforms(self): + """Slack, Mattermost, Matrix default to 'new' tool progress.""" + from gateway.display_config import resolve_display_setting + + for plat in ("slack", "mattermost", "matrix", "feishu"): + assert resolve_display_setting({}, plat, "tool_progress") == "new", plat + + def test_low_tier_platforms(self): + """Signal, WhatsApp, etc. default to 'off' tool progress.""" + from gateway.display_config import resolve_display_setting + + for plat in ("signal", "whatsapp", "bluebubbles", "weixin", "wecom", "dingtalk"): + assert resolve_display_setting({}, plat, "tool_progress") == "off", plat + + def test_minimal_tier_platforms(self): + """Email, SMS, webhook default to 'off' tool progress.""" + from gateway.display_config import resolve_display_setting + + for plat in ("email", "sms", "webhook", "homeassistant"): + assert resolve_display_setting({}, plat, "tool_progress") == "off", plat + + def test_low_tier_streaming_defaults_to_false(self): + """Low-tier platforms default streaming to False.""" + from gateway.display_config import resolve_display_setting + + assert resolve_display_setting({}, "signal", "streaming") is False + assert resolve_display_setting({}, "email", "streaming") is False + + def test_high_tier_streaming_defaults_to_none(self): + """High-tier platforms default streaming to None (follow global).""" + from gateway.display_config import resolve_display_setting + + assert resolve_display_setting({}, "telegram", "streaming") is None + + +# --------------------------------------------------------------------------- +# get_effective_display / get_platform_defaults +# 
--------------------------------------------------------------------------- + +class TestHelpers: + """Helper functions return correct composite results.""" + + def test_get_effective_display_merges_correctly(self): + from gateway.display_config import get_effective_display + + config = { + "display": { + "tool_progress": "new", + "show_reasoning": True, + "platforms": { + "telegram": {"tool_progress": "verbose"}, + }, + } + } + eff = get_effective_display(config, "telegram") + assert eff["tool_progress"] == "verbose" # platform override + assert eff["show_reasoning"] is True # global + assert "tool_preview_length" in eff # default filled in + + def test_get_platform_defaults_returns_dict(self): + from gateway.display_config import get_platform_defaults + + defaults = get_platform_defaults("telegram") + assert "tool_progress" in defaults + assert "show_reasoning" in defaults + # Returns a new dict (not the shared tier dict) + defaults["tool_progress"] = "changed" + assert get_platform_defaults("telegram")["tool_progress"] != "changed" + + +# --------------------------------------------------------------------------- +# Config migration: tool_progress_overrides → display.platforms +# --------------------------------------------------------------------------- + +class TestConfigMigration: + """Version 16 migration moves tool_progress_overrides into display.platforms.""" + + def test_migration_creates_platforms_entries(self, tmp_path, monkeypatch): + """Old overrides are migrated into display.platforms..tool_progress.""" + import yaml + + config_path = tmp_path / "config.yaml" + config = { + "_config_version": 15, + "display": { + "tool_progress_overrides": { + "signal": "off", + "telegram": "all", + }, + }, + } + config_path.write_text(yaml.dump(config)) + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + # Re-import to pick up the new HERMES_HOME + import importlib + import hermes_cli.config as cfg_mod + importlib.reload(cfg_mod) + + result = 
cfg_mod.migrate_config(interactive=False, quiet=True) + # Re-read config + updated = yaml.safe_load(config_path.read_text()) + platforms = updated.get("display", {}).get("platforms", {}) + assert platforms.get("signal", {}).get("tool_progress") == "off" + assert platforms.get("telegram", {}).get("tool_progress") == "all" + + def test_migration_preserves_existing_platforms_entries(self, tmp_path, monkeypatch): + """Existing display.platforms entries are NOT overwritten by migration.""" + import yaml + + config_path = tmp_path / "config.yaml" + config = { + "_config_version": 15, + "display": { + "tool_progress_overrides": {"telegram": "off"}, + "platforms": {"telegram": {"tool_progress": "verbose"}}, + }, + } + config_path.write_text(yaml.dump(config)) + + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + import importlib + import hermes_cli.config as cfg_mod + importlib.reload(cfg_mod) + + cfg_mod.migrate_config(interactive=False, quiet=True) + updated = yaml.safe_load(config_path.read_text()) + # Existing "verbose" should NOT be overwritten by legacy "off" + assert updated["display"]["platforms"]["telegram"]["tool_progress"] == "verbose" + + +# --------------------------------------------------------------------------- +# Streaming per-platform (None = follow global) +# --------------------------------------------------------------------------- + +class TestStreamingPerPlatform: + """Streaming per-platform override semantics.""" + + def test_none_means_follow_global(self): + """When streaming is None, the caller should use global config.""" + from gateway.display_config import resolve_display_setting + + config = {} + # Telegram has no streaming override in defaults → None + result = resolve_display_setting(config, "telegram", "streaming") + assert result is None # caller should check global StreamingConfig + + def test_explicit_false_disables(self): + """Explicit False disables streaming for that platform.""" + from gateway.display_config import 
resolve_display_setting + + config = { + "display": { + "platforms": {"telegram": {"streaming": False}}, + } + } + assert resolve_display_setting(config, "telegram", "streaming") is False + + def test_explicit_true_enables(self): + """Explicit True enables streaming for that platform.""" + from gateway.display_config import resolve_display_setting + + config = { + "display": { + "platforms": {"email": {"streaming": True}}, + } + } + assert resolve_display_setting(config, "email", "streaming") is True diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index 4bde50b638..d5db07c645 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -157,12 +157,44 @@ def _make_fake_mautrix(): mautrix_crypto_store = types.ModuleType("mautrix.crypto.store") class MemoryCryptoStore: - def __init__(self, account_id="", pickle_key=""): + def __init__(self, account_id="", pickle_key=""): # noqa: S301 self.account_id = account_id self.pickle_key = pickle_key mautrix_crypto_store.MemoryCryptoStore = MemoryCryptoStore + # --- mautrix.crypto.store.asyncpg --- + mautrix_crypto_store_asyncpg = types.ModuleType("mautrix.crypto.store.asyncpg") + + class PgCryptoStore: + upgrade_table = MagicMock() + + def __init__(self, account_id="", pickle_key="", db=None): # noqa: S301 + self.account_id = account_id + self.pickle_key = pickle_key + self.db = db + + async def open(self): + pass + + mautrix_crypto_store_asyncpg.PgCryptoStore = PgCryptoStore + + # --- mautrix.util --- + mautrix_util = types.ModuleType("mautrix.util") + + # --- mautrix.util.async_db --- + mautrix_util_async_db = types.ModuleType("mautrix.util.async_db") + + class Database: + @classmethod + def create(cls, url, upgrade_table=None): + db = MagicMock() + db.start = AsyncMock() + db.stop = AsyncMock() + return db + + mautrix_util_async_db.Database = Database + return { "mautrix": mautrix, "mautrix.api": mautrix_api, @@ -171,6 +203,9 @@ def _make_fake_mautrix(): 
"mautrix.client.state_store": mautrix_client_state_store, "mautrix.crypto": mautrix_crypto, "mautrix.crypto.store": mautrix_crypto_store, + "mautrix.crypto.store.asyncpg": mautrix_crypto_store_asyncpg, + "mautrix.util": mautrix_util, + "mautrix.util.async_db": mautrix_util_async_db, } @@ -740,6 +775,12 @@ class TestMatrixAccessTokenAuth: mock_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123")) mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}}) mock_client.add_event_handler = MagicMock() + mock_client.handle_sync = MagicMock(return_value=[]) + mock_client.query_keys = AsyncMock(return_value={ + "device_keys": {"@bot:example.org": {"DEV123": { + "keys": {"ed25519:DEV123": "fake_ed25519_key"}, + }}}, + }) mock_client.api = MagicMock() mock_client.api.token = "syt_test_access_token" mock_client.api.session = MagicMock() @@ -751,6 +792,8 @@ class TestMatrixAccessTokenAuth: mock_olm.share_keys = AsyncMock() mock_olm.share_keys_min_trust = None mock_olm.send_keys_min_trust = None + mock_olm.account = MagicMock() + mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"} # Patch Client constructor to return our mock fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) @@ -924,6 +967,12 @@ class TestMatrixDeviceId: mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="WHOAMI_DEV")) mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}}) mock_client.add_event_handler = MagicMock() + mock_client.handle_sync = MagicMock(return_value=[]) + mock_client.query_keys = AsyncMock(return_value={ + "device_keys": {"@bot:example.org": {"MY_STABLE_DEVICE": { + "keys": {"ed25519:MY_STABLE_DEVICE": "fake_ed25519_key"}, + }}}, + }) mock_client.api = MagicMock() mock_client.api.token = "syt_test_access_token" mock_client.api.session = MagicMock() @@ -934,6 +983,8 @@ class TestMatrixDeviceId: mock_olm.share_keys = 
AsyncMock() mock_olm.share_keys_min_trust = None mock_olm.send_keys_min_trust = None + mock_olm.account = MagicMock() + mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"} fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm) @@ -1030,8 +1081,8 @@ class TestMatrixDeviceIdConfig: class TestMatrixSyncLoop: @pytest.mark.asyncio - async def test_sync_loop_shares_keys_when_encryption_enabled(self): - """_sync_loop should call crypto.share_keys() after each sync.""" + async def test_sync_loop_dispatches_events_and_stores_token(self): + """_sync_loop should call handle_sync() and persist next_batch.""" adapter = _make_adapter() adapter._encryption = True adapter._closing = False @@ -1046,7 +1097,6 @@ class TestMatrixSyncLoop: return {"rooms": {"join": {"!room:example.org": {}}}, "next_batch": "s1234"} mock_crypto = MagicMock() - mock_crypto.share_keys = AsyncMock() mock_sync_store = MagicMock() mock_sync_store.get_next_batch = AsyncMock(return_value=None) @@ -1062,7 +1112,6 @@ class TestMatrixSyncLoop: await adapter._sync_loop() fake_client.sync.assert_awaited_once() - mock_crypto.share_keys.assert_awaited_once() fake_client.handle_sync.assert_called_once() mock_sync_store.put_next_batch.assert_awaited_once_with("s1234") @@ -1248,6 +1297,12 @@ class TestMatrixEncryptedEventHandler: mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123")) mock_client.sync = AsyncMock(return_value={"rooms": {"join": {"!room:server": {}}}}) mock_client.add_event_handler = MagicMock() + mock_client.handle_sync = MagicMock(return_value=[]) + mock_client.query_keys = AsyncMock(return_value={ + "device_keys": {"@bot:example.org": {"DEV123": { + "keys": {"ed25519:DEV123": "fake_ed25519_key"}, + }}}, + }) mock_client.api = MagicMock() mock_client.api.token = "syt_test_token" mock_client.api.session = MagicMock() @@ -1258,6 +1313,8 
@@ class TestMatrixEncryptedEventHandler: mock_olm.share_keys = AsyncMock() mock_olm.share_keys_min_trust = None mock_olm.send_keys_min_trust = None + mock_olm.account = MagicMock() + mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"} fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm) diff --git a/tests/gateway/test_queue_consumption.py b/tests/gateway/test_queue_consumption.py index 2a4dd4ff02..50effc139d 100644 --- a/tests/gateway/test_queue_consumption.py +++ b/tests/gateway/test_queue_consumption.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from gateway.run import _dequeue_pending_event from gateway.platforms.base import ( BasePlatformAdapter, MessageEvent, @@ -79,6 +80,26 @@ class TestQueueMessageStorage: # Should be consumed (cleared) assert adapter.get_pending_message(session_key) is None + def test_dequeue_pending_event_preserves_voice_media_metadata(self): + adapter = _StubAdapter() + session_key = "telegram:user:voice" + event = MessageEvent( + text="", + message_type=MessageType.VOICE, + source=MagicMock(chat_id="123", platform=Platform.TELEGRAM), + message_id="voice-q1", + media_urls=["/tmp/voice.ogg"], + media_types=["audio/ogg"], + ) + adapter._pending_messages[session_key] = event + + retrieved = _dequeue_pending_event(adapter, session_key) + + assert retrieved is event + assert retrieved.media_urls == ["/tmp/voice.ogg"] + assert retrieved.media_types == ["audio/ogg"] + assert adapter.get_pending_message(session_key) is None + def test_queue_does_not_set_interrupt_event(self): """The whole point of /queue — no interrupt signal.""" adapter = _StubAdapter() diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index c28317d7e4..6b1d46567d 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py 
@@ -8,8 +8,8 @@ from types import SimpleNamespace import pytest -from gateway.config import Platform, PlatformConfig -from gateway.platforms.base import BasePlatformAdapter, SendResult +from gateway.config import Platform, PlatformConfig, StreamingConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult from gateway.session import SessionSource @@ -104,6 +104,11 @@ def _make_runner(adapter): runner._session_db = None runner._running_agents = {} runner.hooks = SimpleNamespace(loaded_hooks=False) + runner.config = SimpleNamespace( + thread_sessions_per_user=False, + group_sessions_per_user=False, + stt_enabled=False, + ) return runner @@ -118,6 +123,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa fake_run_agent = types.ModuleType("run_agent") fake_run_agent.AIAgent = FakeAgent monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test adapter = ProgressCaptureAdapter() runner = _make_runner(adapter) @@ -144,7 +150,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa assert adapter.sent == [ { "chat_id": "-1001", - "content": '⚙️ terminal: "pwd"', + "content": '💻 terminal: "pwd"', "reply_to": None, "metadata": {"thread_id": "17585"}, } @@ -334,3 +340,238 @@ def test_all_mode_no_truncation_when_preview_fits(monkeypatch, tmp_path): content = adapter.sent[0]["content"] # With a 200-char cap, the 165-char command should NOT be truncated assert "..." 
not in content, f"Preview was truncated when it shouldn't be: {content}" + + +class CommentaryAgent: + def __init__(self, **kwargs): + self.tool_progress_callback = kwargs.get("tool_progress_callback") + self.interim_assistant_callback = kwargs.get("interim_assistant_callback") + self.stream_delta_callback = kwargs.get("stream_delta_callback") + self.tools = [] + + def run_conversation(self, message, conversation_history=None, task_id=None): + if self.interim_assistant_callback: + self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False) + time.sleep(0.1) + if self.stream_delta_callback: + self.stream_delta_callback("done") + return { + "final_response": "done", + "messages": [], + "api_calls": 1, + } + + +class PreviewedResponseAgent: + def __init__(self, **kwargs): + self.interim_assistant_callback = kwargs.get("interim_assistant_callback") + self.tools = [] + + def run_conversation(self, message, conversation_history=None, task_id=None): + if self.interim_assistant_callback: + self.interim_assistant_callback("You're welcome.", already_streamed=False) + return { + "final_response": "You're welcome.", + "response_previewed": True, + "messages": [], + "api_calls": 1, + } + + +class QueuedCommentaryAgent: + calls = 0 + + def __init__(self, **kwargs): + self.interim_assistant_callback = kwargs.get("interim_assistant_callback") + self.tools = [] + + def run_conversation(self, message, conversation_history=None, task_id=None): + type(self).calls += 1 + if type(self).calls == 1 and self.interim_assistant_callback: + self.interim_assistant_callback("I'll inspect the repo first.", already_streamed=False) + return { + "final_response": f"final response {type(self).calls}", + "messages": [], + "api_calls": 1, + } + + +async def _run_with_agent( + monkeypatch, + tmp_path, + agent_cls, + *, + session_id, + pending_text=None, + config_data=None, +): + if config_data: + import yaml + + (tmp_path / "config.yaml").write_text(yaml.dump(config_data), 
encoding="utf-8") + + fake_dotenv = types.ModuleType("dotenv") + fake_dotenv.load_dotenv = lambda *args, **kwargs: None + monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) + + fake_run_agent = types.ModuleType("run_agent") + fake_run_agent.AIAgent = agent_cls + monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) + + adapter = ProgressCaptureAdapter() + runner = _make_runner(adapter) + gateway_run = importlib.import_module("gateway.run") + if config_data and "streaming" in config_data: + runner.config.streaming = StreamingConfig.from_dict(config_data["streaming"]) + monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path) + monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"}) + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_type="group", + thread_id="17585", + ) + session_key = "agent:main:telegram:group:-1001:17585" + if pending_text is not None: + adapter._pending_messages[session_key] = MessageEvent( + text=pending_text, + message_type=MessageType.TEXT, + source=source, + message_id="queued-1", + ) + + result = await runner._run_agent( + message="hello", + context_prompt="", + history=[], + source=source, + session_id=session_id, + session_key=session_key, + ) + return adapter, result + + +@pytest.mark.asyncio +async def test_run_agent_surfaces_real_interim_commentary(monkeypatch, tmp_path): + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + CommentaryAgent, + session_id="sess-commentary", + config_data={"display": {"interim_assistant_messages": True}}, + ) + + assert result.get("already_sent") is not True + assert any(call["content"] == "I'll inspect the repo first." 
for call in adapter.sent) + + +@pytest.mark.asyncio +async def test_run_agent_surfaces_interim_commentary_by_default(monkeypatch, tmp_path): + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + CommentaryAgent, + session_id="sess-commentary-default-on", + ) + + assert any(call["content"] == "I'll inspect the repo first." for call in adapter.sent) + + +@pytest.mark.asyncio +async def test_run_agent_suppresses_interim_commentary_when_disabled(monkeypatch, tmp_path): + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + CommentaryAgent, + session_id="sess-commentary-disabled", + config_data={"display": {"interim_assistant_messages": False}}, + ) + + assert result.get("already_sent") is not True + assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent) + + +@pytest.mark.asyncio +async def test_run_agent_tool_progress_does_not_control_interim_commentary(monkeypatch, tmp_path): + """tool_progress=all with interim_assistant_messages=false should not surface commentary.""" + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + CommentaryAgent, + session_id="sess-commentary-tool-progress", + config_data={"display": {"tool_progress": "all", "interim_assistant_messages": False}}, + ) + + assert result.get("already_sent") is not True + assert not any(call["content"] == "I'll inspect the repo first." 
for call in adapter.sent) + + +@pytest.mark.asyncio +async def test_run_agent_streaming_does_not_enable_completed_interim_commentary( + monkeypatch, tmp_path +): + """Streaming alone with interim_assistant_messages=false should not surface commentary.""" + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + CommentaryAgent, + session_id="sess-commentary-streaming", + config_data={ + "display": {"tool_progress": "off", "interim_assistant_messages": False}, + "streaming": {"enabled": True}, + }, + ) + + assert result.get("already_sent") is True + assert not any(call["content"] == "I'll inspect the repo first." for call in adapter.sent) + + +@pytest.mark.asyncio +async def test_run_agent_interim_commentary_works_with_tool_progress_off(monkeypatch, tmp_path): + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + CommentaryAgent, + session_id="sess-commentary-explicit-on", + config_data={ + "display": { + "tool_progress": "off", + "interim_assistant_messages": True, + }, + }, + ) + + assert result.get("already_sent") is not True + assert any(call["content"] == "I'll inspect the repo first." 
for call in adapter.sent) + + +@pytest.mark.asyncio +async def test_run_agent_previewed_final_marks_already_sent(monkeypatch, tmp_path): + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + PreviewedResponseAgent, + session_id="sess-previewed", + config_data={"display": {"interim_assistant_messages": True}}, + ) + + assert result.get("already_sent") is True + assert [call["content"] for call in adapter.sent] == ["You're welcome."] + + +@pytest.mark.asyncio +async def test_run_agent_queued_message_does_not_treat_commentary_as_final(monkeypatch, tmp_path): + QueuedCommentaryAgent.calls = 0 + adapter, result = await _run_with_agent( + monkeypatch, + tmp_path, + QueuedCommentaryAgent, + session_id="sess-queued-commentary", + pending_text="queued follow-up", + config_data={"display": {"interim_assistant_messages": True}}, + ) + + sent_texts = [call["content"] for call in adapter.sent] + assert result["final_response"] == "final response 2" + assert "I'll inspect the repo first." 
in sent_texts + assert "final response 1" in sent_texts diff --git a/tests/gateway/test_runner_startup_failures.py b/tests/gateway/test_runner_startup_failures.py index 1be67b71bb..787cb0adad 100644 --- a/tests/gateway/test_runner_startup_failures.py +++ b/tests/gateway/test_runner_startup_failures.py @@ -1,4 +1,5 @@ import pytest +from unittest.mock import AsyncMock from gateway.config import GatewayConfig, Platform, PlatformConfig from gateway.platforms.base import BasePlatformAdapter @@ -45,6 +46,23 @@ class _DisabledAdapter(BasePlatformAdapter): return {"id": chat_id} +class _SuccessfulAdapter(BasePlatformAdapter): + def __init__(self): + super().__init__(PlatformConfig(enabled=True, token="***"), Platform.DISCORD) + + async def connect(self) -> bool: + return True + + async def disconnect(self) -> None: + self._mark_disconnected() + + async def send(self, chat_id, content, reply_to=None, metadata=None): + raise NotImplementedError + + async def get_chat_info(self, chat_id): + return {"id": chat_id} + + @pytest.mark.asyncio async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) @@ -65,7 +83,7 @@ async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, state = read_runtime_status() assert state["gateway_state"] == "startup_failed" assert "temporary DNS resolution failure" in state["exit_reason"] - assert state["platforms"]["telegram"]["state"] == "fatal" + assert state["platforms"]["telegram"]["state"] == "retrying" assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error" @@ -89,6 +107,31 @@ async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkey assert state["gateway_state"] == "running" +@pytest.mark.asyncio +async def test_runner_records_connected_platform_state_on_success(monkeypatch, tmp_path): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + config = GatewayConfig( + platforms={ + Platform.DISCORD: 
PlatformConfig(enabled=True, token="***") + }, + sessions_dir=tmp_path / "sessions", + ) + runner = GatewayRunner(config) + + monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _SuccessfulAdapter()) + monkeypatch.setattr(runner.hooks, "discover_and_load", lambda: None) + monkeypatch.setattr(runner.hooks, "emit", AsyncMock()) + + ok = await runner.start() + + assert ok is True + state = read_runtime_status() + assert state["gateway_state"] == "running" + assert state["platforms"]["discord"]["state"] == "connected" + assert state["platforms"]["discord"]["error_code"] is None + assert state["platforms"]["discord"]["error_message"] is None + + @pytest.mark.asyncio async def test_start_gateway_replace_force_uses_terminate_pid(monkeypatch, tmp_path): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) diff --git a/tests/gateway/test_session_env.py b/tests/gateway/test_session_env.py index b75e267f11..9f556f8846 100644 --- a/tests/gateway/test_session_env.py +++ b/tests/gateway/test_session_env.py @@ -1,3 +1,4 @@ +import asyncio import os from gateway.config import Platform @@ -130,3 +131,99 @@ def test_set_session_env_handles_missing_optional_fields(): assert get_session_env("HERMES_SESSION_THREAD_ID") == "" runner._clear_session_env(tokens) + + +# --------------------------------------------------------------------------- +# SESSION_KEY contextvars tests +# --------------------------------------------------------------------------- + + +def test_session_key_set_via_contextvars(monkeypatch): + """set_session_vars should set HERMES_SESSION_KEY via contextvars.""" + monkeypatch.delenv("HERMES_SESSION_KEY", raising=False) + + tokens = set_session_vars( + platform="telegram", + chat_id="-1001", + session_key="tg:-1001:17585", + ) + assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585" + + clear_session_vars(tokens) + assert get_session_env("HERMES_SESSION_KEY") == "" + + +def test_session_key_falls_back_to_os_environ(monkeypatch): + 
"""get_session_env for SESSION_KEY should fall back to os.environ.""" + monkeypatch.setenv("HERMES_SESSION_KEY", "env-session-123") + + # No contextvar set — should read from os.environ + assert get_session_env("HERMES_SESSION_KEY") == "env-session-123" + + # Set contextvar — should prefer it + tokens = set_session_vars(session_key="ctx-session-456") + assert get_session_env("HERMES_SESSION_KEY") == "ctx-session-456" + + # Restore — should fall back to os.environ + clear_session_vars(tokens) + assert get_session_env("HERMES_SESSION_KEY") == "env-session-123" + + +def test_set_session_env_includes_session_key(): + """_set_session_env should propagate session_key from SessionContext.""" + runner = object.__new__(GatewayRunner) + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_name="Group", + chat_type="group", + thread_id="17585", + ) + context = SessionContext( + source=source, + connected_platforms=[], + home_channels={}, + session_key="tg:-1001:17585", + ) + + tokens = runner._set_session_env(context) + assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585" + runner._clear_session_env(tokens) + assert get_session_env("HERMES_SESSION_KEY") == "" + + +def test_session_key_no_race_condition_with_contextvars(monkeypatch): + """Prove contextvars isolates SESSION_KEY across concurrent async tasks. + + Two tasks set different session keys. With contextvars each task + reads back its own value. With os.environ the second task would + overwrite the first (the old bug). 
+ """ + monkeypatch.delenv("HERMES_SESSION_KEY", raising=False) + + results = {} + + async def handler(key: str, delay: float): + tokens = set_session_vars(session_key=key) + try: + await asyncio.sleep(delay) + read_back = get_session_env("HERMES_SESSION_KEY") + results[key] = read_back + finally: + clear_session_vars(tokens) + + async def run(): + task_a = asyncio.create_task(handler("session-A", 0.15)) + await asyncio.sleep(0.05) + task_b = asyncio.create_task(handler("session-B", 0.05)) + await asyncio.gather(task_a, task_b) + + asyncio.run(run()) + + # Both tasks must read back their own session key + assert results["session-A"] == "session-A", ( + f"Session A got '{results['session-A']}' instead of 'session-A' — race condition!" + ) + assert results["session-B"] == "session-B", ( + f"Session B got '{results['session-B']}' instead of 'session-B' — race condition!" + ) diff --git a/tests/gateway/test_status.py b/tests/gateway/test_status.py index 6792061f92..16d4bfc5e8 100644 --- a/tests/gateway/test_status.py +++ b/tests/gateway/test_status.py @@ -104,6 +104,34 @@ class TestGatewayRuntimeStatus: assert payload["platforms"]["telegram"]["error_code"] == "telegram_polling_conflict" assert payload["platforms"]["telegram"]["error_message"] == "another poller is active" + def test_write_runtime_status_explicit_none_clears_stale_fields(self, tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + status.write_runtime_status( + gateway_state="startup_failed", + exit_reason="stale error", + platform="discord", + platform_state="fatal", + error_code="discord_timeout", + error_message="stale platform error", + ) + + status.write_runtime_status( + gateway_state="running", + exit_reason=None, + platform="discord", + platform_state="connected", + error_code=None, + error_message=None, + ) + + payload = status.read_runtime_status() + assert payload["gateway_state"] == "running" + assert payload["exit_reason"] is None + assert 
payload["platforms"]["discord"]["state"] == "connected" + assert payload["platforms"]["discord"]["error_code"] is None + assert payload["platforms"]["discord"]["error_message"] is None + class TestTerminatePid: def test_force_uses_taskkill_on_windows(self, monkeypatch): diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index 5cebb20eee..8f7fb6dd5d 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -505,3 +505,81 @@ class TestSegmentBreakOnToolBoundary: assert len(sent_texts) == 3 assert sent_texts[0].startswith(prefix) assert sum(len(t) for t in sent_texts[1:]) == len(tail) + + +class TestInterimCommentaryMessages: + @pytest.mark.asyncio + async def test_commentary_message_stays_separate_from_final_stream(self): + adapter = MagicMock() + adapter.send = AsyncMock(side_effect=[ + SimpleNamespace(success=True, message_id="msg_1"), + SimpleNamespace(success=True, message_id="msg_2"), + ]) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer( + adapter, + "chat_123", + StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5), + ) + + consumer.on_commentary("I'll inspect the repository first.") + consumer.on_delta("Done.") + consumer.finish() + + await consumer.run() + + sent_texts = [call[1]["content"] for call in adapter.send.call_args_list] + assert sent_texts == ["I'll inspect the repository first.", "Done."] + assert consumer.final_response_sent is True + + @pytest.mark.asyncio + async def test_failed_final_send_does_not_mark_final_response_sent(self): + adapter = MagicMock() + adapter.send = AsyncMock(return_value=SimpleNamespace(success=False, message_id=None)) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer( + adapter, + "chat_123", + 
StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5), + ) + + consumer.on_delta("Done.") + consumer.finish() + + await consumer.run() + + assert consumer.final_response_sent is False + assert consumer.already_sent is False + + @pytest.mark.asyncio + async def test_success_without_message_id_marks_visible_and_sends_only_tail(self): + adapter = MagicMock() + adapter.send = AsyncMock(side_effect=[ + SimpleNamespace(success=True, message_id=None), + SimpleNamespace(success=True, message_id=None), + ]) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer( + adapter, + "chat_123", + StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉"), + ) + + consumer.on_delta("Hello") + task = asyncio.create_task(consumer.run()) + await asyncio.sleep(0.08) + consumer.on_delta(" world") + await asyncio.sleep(0.08) + consumer.finish() + await task + + sent_texts = [call[1]["content"] for call in adapter.send.call_args_list] + assert sent_texts == ["Hello ▉", "world"] + assert consumer.already_sent is True + assert consumer.final_response_sent is True diff --git a/tests/gateway/test_stt_config.py b/tests/gateway/test_stt_config.py index a49e402151..23ba06af22 100644 --- a/tests/gateway/test_stt_config.py +++ b/tests/gateway/test_stt_config.py @@ -6,7 +6,9 @@ from unittest.mock import AsyncMock, patch import pytest import yaml -from gateway.config import GatewayConfig, load_gateway_config +from gateway.config import GatewayConfig, Platform, load_gateway_config +from gateway.platforms.base import MessageEvent, MessageType +from gateway.session import SessionSource def test_gateway_config_stt_disabled_from_dict_nested(): @@ -69,3 +71,46 @@ async def test_enrich_message_with_transcription_avoids_bogus_no_provider_messag assert "No STT provider is configured" not in result assert "trouble transcribing" in result assert "caption" in result + + +@pytest.mark.asyncio 
+async def test_prepare_inbound_message_text_transcribes_queued_voice_event(): + from gateway.run import GatewayRunner + + runner = GatewayRunner.__new__(GatewayRunner) + runner.config = GatewayConfig(stt_enabled=True) + runner.adapters = {} + runner._model = "test-model" + runner._base_url = "" + runner._has_setup_skill = lambda: False + + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + ) + event = MessageEvent( + text="", + message_type=MessageType.VOICE, + source=source, + media_urls=["/tmp/queued-voice.ogg"], + media_types=["audio/ogg"], + ) + + with patch( + "tools.transcription_tools.transcribe_audio", + return_value={ + "success": True, + "transcript": "queued voice transcript", + "provider": "local_command", + }, + ): + result = await runner._prepare_inbound_message_text( + event=event, + source=source, + history=[], + ) + + assert result is not None + assert "queued voice transcript" in result + assert "voice message" in result.lower() diff --git a/tests/gateway/test_verbose_command.py b/tests/gateway/test_verbose_command.py index 857d0744e1..c34167b2e4 100644 --- a/tests/gateway/test_verbose_command.py +++ b/tests/gateway/test_verbose_command.py @@ -63,7 +63,7 @@ class TestVerboseCommand: @pytest.mark.asyncio async def test_enabled_cycles_mode(self, tmp_path, monkeypatch): - """When enabled, /verbose cycles tool_progress mode.""" + """When enabled, /verbose cycles tool_progress mode per-platform.""" hermes_home = tmp_path / "hermes" hermes_home.mkdir() config_path = hermes_home / "config.yaml" @@ -79,10 +79,11 @@ class TestVerboseCommand: # all -> verbose assert "VERBOSE" in result + assert "telegram" in result.lower() # per-platform feedback - # Verify config was saved + # Verify config was saved to display.platforms.telegram saved = yaml.safe_load(config_path.read_text(encoding="utf-8")) - assert saved["display"]["tool_progress"] == "verbose" + assert saved["display"]["platforms"]["telegram"]["tool_progress"] 
== "verbose" @pytest.mark.asyncio async def test_cycles_through_all_modes(self, tmp_path, monkeypatch): @@ -103,8 +104,9 @@ class TestVerboseCommand: for mode in expected: result = await runner._handle_verbose_command(_make_event()) saved = yaml.safe_load(config_path.read_text(encoding="utf-8")) - assert saved["display"]["tool_progress"] == mode, \ - f"Expected {mode}, got {saved['display']['tool_progress']}" + actual = saved["display"]["platforms"]["telegram"]["tool_progress"] + assert actual == mode, \ + f"Expected {mode}, got {actual}" @pytest.mark.asyncio async def test_defaults_to_all_when_no_tool_progress_set(self, tmp_path, monkeypatch): @@ -122,10 +124,45 @@ class TestVerboseCommand: runner = _make_runner() result = await runner._handle_verbose_command(_make_event()) - # default "all" -> verbose + # Telegram default is "all" (high tier) → cycles to verbose assert "VERBOSE" in result saved = yaml.safe_load(config_path.read_text(encoding="utf-8")) - assert saved["display"]["tool_progress"] == "verbose" + assert saved["display"]["platforms"]["telegram"]["tool_progress"] == "verbose" + + @pytest.mark.asyncio + async def test_per_platform_isolation(self, tmp_path, monkeypatch): + """Cycling /verbose on Telegram doesn't change Slack's setting. + + Without a global tool_progress, each platform uses its built-in + default: Telegram = 'all' (high tier), Slack = 'new' (medium tier). 
+ """ + hermes_home = tmp_path / "hermes" + hermes_home.mkdir() + config_path = hermes_home / "config.yaml" + # No global tool_progress → built-in platform defaults apply + config_path.write_text( + "display:\n tool_progress_command: true\n", + encoding="utf-8", + ) + + monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home) + runner = _make_runner() + + # Cycle on Telegram + await runner._handle_verbose_command( + _make_event(platform=Platform.TELEGRAM) + ) + # Cycle on Slack + await runner._handle_verbose_command( + _make_event(platform=Platform.SLACK) + ) + + saved = yaml.safe_load(config_path.read_text(encoding="utf-8")) + platforms = saved["display"]["platforms"] + # Telegram: all -> verbose (high tier default = all) + assert platforms["telegram"]["tool_progress"] == "verbose" + # Slack: new -> all (medium tier default = new, cycle to all) + assert platforms["slack"]["tool_progress"] == "all" @pytest.mark.asyncio async def test_no_config_file_returns_disabled(self, tmp_path, monkeypatch): diff --git a/tests/gateway/test_wecom_callback.py b/tests/gateway/test_wecom_callback.py new file mode 100644 index 0000000000..88c084ae3e --- /dev/null +++ b/tests/gateway/test_wecom_callback.py @@ -0,0 +1,185 @@ +"""Tests for the WeCom callback-mode adapter.""" + +import asyncio +from xml.etree import ElementTree as ET + +import pytest + +from gateway.config import PlatformConfig +from gateway.platforms.wecom_callback import WecomCallbackAdapter +from gateway.platforms.wecom_crypto import WXBizMsgCrypt + + +def _app(name="test-app", corp_id="ww1234567890", agent_id="1000002"): + return { + "name": name, + "corp_id": corp_id, + "corp_secret": "test-secret", + "agent_id": agent_id, + "token": "test-callback-token", + "encoding_aes_key": "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", + } + + +def _config(apps=None): + return PlatformConfig( + enabled=True, + extra={"mode": "callback", "host": "127.0.0.1", "port": 0, "apps": apps or [_app()]}, + ) + + +class 
TestWecomCrypto: + def test_roundtrip_encrypt_decrypt(self): + app = _app() + crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"]) + encrypted_xml = crypt.encrypt( + "hello", nonce="nonce123", timestamp="123456", + ) + root = ET.fromstring(encrypted_xml) + decrypted = crypt.decrypt( + root.findtext("MsgSignature", default=""), + root.findtext("TimeStamp", default=""), + root.findtext("Nonce", default=""), + root.findtext("Encrypt", default=""), + ) + assert b"hello" in decrypted + + def test_signature_mismatch_raises(self): + app = _app() + crypt = WXBizMsgCrypt(app["token"], app["encoding_aes_key"], app["corp_id"]) + encrypted_xml = crypt.encrypt("", nonce="n", timestamp="1") + root = ET.fromstring(encrypted_xml) + from gateway.platforms.wecom_crypto import SignatureError + with pytest.raises(SignatureError): + crypt.decrypt("bad-sig", "1", "n", root.findtext("Encrypt", default="")) + + +class TestWecomCallbackEventConstruction: + def test_build_event_extracts_text_message(self): + adapter = WecomCallbackAdapter(_config()) + xml_text = """ + + ww1234567890 + zhangsan + 1710000000 + text + \u4f60\u597d + 123456789 + + """ + event = adapter._build_event(_app(), xml_text) + assert event is not None + assert event.source is not None + assert event.source.user_id == "zhangsan" + assert event.source.chat_id == "ww1234567890:zhangsan" + assert event.message_id == "123456789" + assert event.text == "\u4f60\u597d" + + def test_build_event_returns_none_for_subscribe(self): + adapter = WecomCallbackAdapter(_config()) + xml_text = """ + + ww1234567890 + zhangsan + 1710000000 + event + subscribe + + """ + event = adapter._build_event(_app(), xml_text) + assert event is None + + +class TestWecomCallbackRouting: + def test_user_app_key_scopes_across_corps(self): + adapter = WecomCallbackAdapter(_config()) + assert adapter._user_app_key("corpA", "alice") == "corpA:alice" + assert adapter._user_app_key("corpB", "alice") == "corpB:alice" + assert 
adapter._user_app_key("corpA", "alice") != adapter._user_app_key("corpB", "alice") + + @pytest.mark.asyncio + async def test_send_selects_correct_app_for_scoped_chat_id(self): + apps = [ + _app(name="corp-a", corp_id="corpA", agent_id="1001"), + _app(name="corp-b", corp_id="corpB", agent_id="2002"), + ] + adapter = WecomCallbackAdapter(_config(apps=apps)) + adapter._user_app_map["corpB:alice"] = "corp-b" + adapter._access_tokens["corp-b"] = {"token": "tok-b", "expires_at": 9999999999} + + calls = {} + + class FakeResponse: + def json(self): + return {"errcode": 0, "msgid": "ok1"} + + class FakeClient: + async def post(self, url, json): + calls["url"] = url + calls["json"] = json + return FakeResponse() + + adapter._http_client = FakeClient() + result = await adapter.send("corpB:alice", "hello") + + assert result.success is True + assert calls["json"]["touser"] == "alice" + assert calls["json"]["agentid"] == 2002 + assert "tok-b" in calls["url"] + + @pytest.mark.asyncio + async def test_send_falls_back_from_bare_user_id_when_unique(self): + apps = [_app(name="corp-a", corp_id="corpA", agent_id="1001")] + adapter = WecomCallbackAdapter(_config(apps=apps)) + adapter._user_app_map["corpA:alice"] = "corp-a" + adapter._access_tokens["corp-a"] = {"token": "tok-a", "expires_at": 9999999999} + + calls = {} + + class FakeResponse: + def json(self): + return {"errcode": 0, "msgid": "ok2"} + + class FakeClient: + async def post(self, url, json): + calls["url"] = url + calls["json"] = json + return FakeResponse() + + adapter._http_client = FakeClient() + result = await adapter.send("alice", "hello") + + assert result.success is True + assert calls["json"]["agentid"] == 1001 + + +class TestWecomCallbackPollLoop: + @pytest.mark.asyncio + async def test_poll_loop_dispatches_handle_message(self, monkeypatch): + adapter = WecomCallbackAdapter(_config()) + calls = [] + + async def fake_handle_message(event): + calls.append(event.text) + + monkeypatch.setattr(adapter, 
"handle_message", fake_handle_message) + event = adapter._build_event( + _app(), + """ + + ww1234567890 + lisi + 1710000000 + text + test + m2 + + """, + ) + task = asyncio.create_task(adapter._poll_loop()) + await adapter._message_queue.put(event) + await asyncio.sleep(0.05) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + assert calls == ["test"] diff --git a/tests/hermes_cli/test_backup.py b/tests/hermes_cli/test_backup.py new file mode 100644 index 0000000000..8ef3858962 --- /dev/null +++ b/tests/hermes_cli/test_backup.py @@ -0,0 +1,897 @@ +"""Tests for hermes backup and import commands.""" + +import os +import zipfile +from argparse import Namespace +from pathlib import Path +from unittest.mock import patch + +import pytest + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_hermes_tree(root: Path) -> None: + """Create a realistic ~/.hermes directory structure for testing.""" + (root / "config.yaml").write_text("model:\n provider: openrouter\n") + (root / ".env").write_text("OPENROUTER_API_KEY=sk-test-123\n") + (root / "memory_store.db").write_bytes(b"fake-sqlite") + (root / "hermes_state.db").write_bytes(b"fake-state") + + # Sessions + (root / "sessions").mkdir(exist_ok=True) + (root / "sessions" / "abc123.json").write_text("{}") + + # Skills + (root / "skills").mkdir(exist_ok=True) + (root / "skills" / "my-skill").mkdir() + (root / "skills" / "my-skill" / "SKILL.md").write_text("# My Skill\n") + + # Skins + (root / "skins").mkdir(exist_ok=True) + (root / "skins" / "cyber.yaml").write_text("name: cyber\n") + + # Cron + (root / "cron").mkdir(exist_ok=True) + (root / "cron" / "jobs.json").write_text("[]") + + # Memories + (root / "memories").mkdir(exist_ok=True) + (root / "memories" / "notes.json").write_text("{}") + + # Profiles + (root / "profiles").mkdir(exist_ok=True) + (root / "profiles" / 
"coder").mkdir() + (root / "profiles" / "coder" / "config.yaml").write_text("model:\n provider: anthropic\n") + (root / "profiles" / "coder" / ".env").write_text("ANTHROPIC_API_KEY=sk-ant-123\n") + + # hermes-agent repo (should be EXCLUDED) + (root / "hermes-agent").mkdir(exist_ok=True) + (root / "hermes-agent" / "run_agent.py").write_text("# big file\n") + (root / "hermes-agent" / ".git").mkdir() + (root / "hermes-agent" / ".git" / "HEAD").write_text("ref: refs/heads/main\n") + + # __pycache__ (should be EXCLUDED) + (root / "plugins").mkdir(exist_ok=True) + (root / "plugins" / "__pycache__").mkdir() + (root / "plugins" / "__pycache__" / "mod.cpython-312.pyc").write_bytes(b"\x00") + + # PID files (should be EXCLUDED) + (root / "gateway.pid").write_text("12345") + + # Logs (should be included) + (root / "logs").mkdir(exist_ok=True) + (root / "logs" / "agent.log").write_text("log line\n") + + +# --------------------------------------------------------------------------- +# _should_exclude tests +# --------------------------------------------------------------------------- + +class TestShouldExclude: + def test_excludes_hermes_agent(self): + from hermes_cli.backup import _should_exclude + assert _should_exclude(Path("hermes-agent/run_agent.py")) + assert _should_exclude(Path("hermes-agent/.git/HEAD")) + + def test_excludes_pycache(self): + from hermes_cli.backup import _should_exclude + assert _should_exclude(Path("plugins/__pycache__/mod.cpython-312.pyc")) + + def test_excludes_pyc_files(self): + from hermes_cli.backup import _should_exclude + assert _should_exclude(Path("some/module.pyc")) + + def test_excludes_pid_files(self): + from hermes_cli.backup import _should_exclude + assert _should_exclude(Path("gateway.pid")) + assert _should_exclude(Path("cron.pid")) + + def test_includes_config(self): + from hermes_cli.backup import _should_exclude + assert not _should_exclude(Path("config.yaml")) + + def test_includes_env(self): + from hermes_cli.backup import 
_should_exclude + assert not _should_exclude(Path(".env")) + + def test_includes_skills(self): + from hermes_cli.backup import _should_exclude + assert not _should_exclude(Path("skills/my-skill/SKILL.md")) + + def test_includes_profiles(self): + from hermes_cli.backup import _should_exclude + assert not _should_exclude(Path("profiles/coder/config.yaml")) + + def test_includes_sessions(self): + from hermes_cli.backup import _should_exclude + assert not _should_exclude(Path("sessions/abc.json")) + + def test_includes_logs(self): + from hermes_cli.backup import _should_exclude + assert not _should_exclude(Path("logs/agent.log")) + + +# --------------------------------------------------------------------------- +# Backup tests +# --------------------------------------------------------------------------- + +class TestBackup: + def test_creates_zip(self, tmp_path, monkeypatch): + """Backup creates a valid zip containing expected files.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + _make_hermes_tree(hermes_home) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + # get_default_hermes_root needs this + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + out_zip = tmp_path / "backup.zip" + args = Namespace(output=str(out_zip)) + + from hermes_cli.backup import run_backup + run_backup(args) + + assert out_zip.exists() + with zipfile.ZipFile(out_zip, "r") as zf: + names = zf.namelist() + # Config should be present + assert "config.yaml" in names + assert ".env" in names + # Skills + assert "skills/my-skill/SKILL.md" in names + # Profiles + assert "profiles/coder/config.yaml" in names + assert "profiles/coder/.env" in names + # Sessions + assert "sessions/abc123.json" in names + # Logs + assert "logs/agent.log" in names + # Skins + assert "skins/cyber.yaml" in names + + def test_excludes_hermes_agent(self, tmp_path, monkeypatch): + """Backup does NOT include hermes-agent/ directory.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + 
_make_hermes_tree(hermes_home) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + out_zip = tmp_path / "backup.zip" + args = Namespace(output=str(out_zip)) + + from hermes_cli.backup import run_backup + run_backup(args) + + with zipfile.ZipFile(out_zip, "r") as zf: + names = zf.namelist() + agent_files = [n for n in names if "hermes-agent" in n] + assert agent_files == [], f"hermes-agent files leaked into backup: {agent_files}" + + def test_excludes_pycache(self, tmp_path, monkeypatch): + """Backup does NOT include __pycache__ dirs.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + _make_hermes_tree(hermes_home) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + out_zip = tmp_path / "backup.zip" + args = Namespace(output=str(out_zip)) + + from hermes_cli.backup import run_backup + run_backup(args) + + with zipfile.ZipFile(out_zip, "r") as zf: + names = zf.namelist() + pycache_files = [n for n in names if "__pycache__" in n] + assert pycache_files == [] + + def test_excludes_pid_files(self, tmp_path, monkeypatch): + """Backup does NOT include PID files.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + _make_hermes_tree(hermes_home) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + out_zip = tmp_path / "backup.zip" + args = Namespace(output=str(out_zip)) + + from hermes_cli.backup import run_backup + run_backup(args) + + with zipfile.ZipFile(out_zip, "r") as zf: + names = zf.namelist() + pid_files = [n for n in names if n.endswith(".pid")] + assert pid_files == [] + + def test_default_output_path(self, tmp_path, monkeypatch): + """When no output path given, zip goes to ~/hermes-backup-*.zip.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text("model: test\n") + + monkeypatch.setenv("HERMES_HOME", 
str(hermes_home))
+        monkeypatch.setattr(Path, "home", lambda: tmp_path)
+
+        args = Namespace(output=None)
+
+        from hermes_cli.backup import run_backup
+        run_backup(args)
+
+        # Should exist in home dir
+        zips = list(tmp_path.glob("hermes-backup-*.zip"))
+        assert len(zips) == 1
+
+
+# ---------------------------------------------------------------------------
+# Import tests
+# ---------------------------------------------------------------------------
+
+class TestImport:
+    def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None:
+        """Create a test zip with given files."""
+        with zipfile.ZipFile(zip_path, "w") as zf:
+            for name, content in files.items():
+                # ZipFile.writestr accepts both str and bytes for its data
+                # argument, so no isinstance(content, bytes) branching is
+                # needed (the previous branch had two identical arms).
+                zf.writestr(name, content)
+
+    def test_restores_files(self, tmp_path, monkeypatch):
+        """Import extracts files into hermes home."""
+        hermes_home = tmp_path / ".hermes"
+        hermes_home.mkdir()
+        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
+        monkeypatch.setattr(Path, "home", lambda: tmp_path)
+
+        zip_path = tmp_path / "backup.zip"
+        self._make_backup_zip(zip_path, {
+            "config.yaml": "model:\n provider: openrouter\n",
+            ".env": "OPENROUTER_API_KEY=sk-test\n",
+            "skills/my-skill/SKILL.md": "# My Skill\n",
+            "profiles/coder/config.yaml": "model:\n provider: anthropic\n",
+        })
+
+        args = Namespace(zipfile=str(zip_path), force=True)
+
+        from hermes_cli.backup import run_import
+        run_import(args)
+
+        assert (hermes_home / "config.yaml").read_text() == "model:\n provider: openrouter\n"
+        assert (hermes_home / ".env").read_text() == "OPENROUTER_API_KEY=sk-test\n"
+        assert (hermes_home / "skills" / "my-skill" / "SKILL.md").read_text() == "# My Skill\n"
+        assert (hermes_home / "profiles" / "coder" / "config.yaml").exists()
+
+    def test_strips_hermes_prefix(self, tmp_path, monkeypatch):
+        """Import strips .hermes/ prefix if all entries share it."""
+        hermes_home = tmp_path / ".hermes"
+        hermes_home.mkdir()
+        
monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, { + ".hermes/config.yaml": "model: test\n", + ".hermes/skills/a/SKILL.md": "# A\n", + }) + + args = Namespace(zipfile=str(zip_path), force=True) + + from hermes_cli.backup import run_import + run_import(args) + + assert (hermes_home / "config.yaml").read_text() == "model: test\n" + assert (hermes_home / "skills" / "a" / "SKILL.md").read_text() == "# A\n" + + def test_rejects_empty_zip(self, tmp_path, monkeypatch): + """Import rejects an empty zip.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "empty.zip" + with zipfile.ZipFile(zip_path, "w"): + pass # empty + + args = Namespace(zipfile=str(zip_path), force=True) + + from hermes_cli.backup import run_import + with pytest.raises(SystemExit): + run_import(args) + + def test_rejects_non_hermes_zip(self, tmp_path, monkeypatch): + """Import rejects a zip that doesn't look like a hermes backup.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "random.zip" + self._make_backup_zip(zip_path, { + "some/random/file.txt": "hello", + "another/thing.json": "{}", + }) + + args = Namespace(zipfile=str(zip_path), force=True) + + from hermes_cli.backup import run_import + with pytest.raises(SystemExit): + run_import(args) + + def test_blocks_path_traversal(self, tmp_path, monkeypatch): + """Import blocks zip entries with path traversal.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "evil.zip" + # Include a marker file 
so validation passes + self._make_backup_zip(zip_path, { + "config.yaml": "model: test\n", + "../../etc/passwd": "root:x:0:0\n", + }) + + args = Namespace(zipfile=str(zip_path), force=True) + + from hermes_cli.backup import run_import + run_import(args) + + # config.yaml should be restored + assert (hermes_home / "config.yaml").exists() + # traversal file should NOT exist outside hermes home + assert not (tmp_path / "etc" / "passwd").exists() + + def test_confirmation_prompt_abort(self, tmp_path, monkeypatch): + """Import aborts when user says no to confirmation.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + # Pre-existing config triggers the confirmation + (hermes_home / "config.yaml").write_text("existing: true\n") + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, { + "config.yaml": "model: restored\n", + }) + + args = Namespace(zipfile=str(zip_path), force=False) + + from hermes_cli.backup import run_import + with patch("builtins.input", return_value="n"): + run_import(args) + + # Original config should be unchanged + assert (hermes_home / "config.yaml").read_text() == "existing: true\n" + + def test_force_skips_confirmation(self, tmp_path, monkeypatch): + """Import with --force skips confirmation and overwrites.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text("existing: true\n") + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, { + "config.yaml": "model: restored\n", + }) + + args = Namespace(zipfile=str(zip_path), force=True) + + from hermes_cli.backup import run_import + run_import(args) + + assert (hermes_home / "config.yaml").read_text() == "model: restored\n" + + def test_missing_file_exits(self, tmp_path, monkeypatch): + 
"""Import exits with error for nonexistent file.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + args = Namespace(zipfile=str(tmp_path / "nonexistent.zip"), force=True) + + from hermes_cli.backup import run_import + with pytest.raises(SystemExit): + run_import(args) + + +# --------------------------------------------------------------------------- +# Round-trip test +# --------------------------------------------------------------------------- + +class TestRoundTrip: + def test_backup_then_import(self, tmp_path, monkeypatch): + """Full round-trip: backup -> import to a new location -> verify.""" + # Source + src_home = tmp_path / "source" / ".hermes" + src_home.mkdir(parents=True) + _make_hermes_tree(src_home) + + monkeypatch.setenv("HERMES_HOME", str(src_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path / "source") + + # Backup + out_zip = tmp_path / "roundtrip.zip" + from hermes_cli.backup import run_backup, run_import + + run_backup(Namespace(output=str(out_zip))) + assert out_zip.exists() + + # Import into a different location + dst_home = tmp_path / "dest" / ".hermes" + dst_home.mkdir(parents=True) + monkeypatch.setenv("HERMES_HOME", str(dst_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path / "dest") + + run_import(Namespace(zipfile=str(out_zip), force=True)) + + # Verify key files + assert (dst_home / "config.yaml").read_text() == "model:\n provider: openrouter\n" + assert (dst_home / ".env").read_text() == "OPENROUTER_API_KEY=sk-test-123\n" + assert (dst_home / "skills" / "my-skill" / "SKILL.md").exists() + assert (dst_home / "profiles" / "coder" / "config.yaml").exists() + assert (dst_home / "sessions" / "abc123.json").exists() + assert (dst_home / "logs" / "agent.log").exists() + + # hermes-agent should NOT be present + assert not (dst_home / "hermes-agent").exists() + # __pycache__ should NOT be present + assert not (dst_home / "plugins" / "__pycache__").exists() + 
# PID files should NOT be present + assert not (dst_home / "gateway.pid").exists() + + +# --------------------------------------------------------------------------- +# Validate / detect-prefix unit tests +# --------------------------------------------------------------------------- + +class TestFormatSize: + def test_bytes(self): + from hermes_cli.backup import _format_size + assert _format_size(512) == "512 B" + + def test_kilobytes(self): + from hermes_cli.backup import _format_size + assert "KB" in _format_size(2048) + + def test_megabytes(self): + from hermes_cli.backup import _format_size + assert "MB" in _format_size(5 * 1024 * 1024) + + def test_gigabytes(self): + from hermes_cli.backup import _format_size + assert "GB" in _format_size(3 * 1024 ** 3) + + def test_terabytes(self): + from hermes_cli.backup import _format_size + assert "TB" in _format_size(2 * 1024 ** 4) + + +class TestValidation: + def test_validate_with_config(self): + """Zip with config.yaml passes validation.""" + import io + from hermes_cli.backup import _validate_backup_zip + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("config.yaml", "test") + buf.seek(0) + with zipfile.ZipFile(buf, "r") as zf: + ok, reason = _validate_backup_zip(zf) + assert ok + + def test_validate_with_env(self): + """Zip with .env passes validation.""" + import io + from hermes_cli.backup import _validate_backup_zip + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr(".env", "KEY=val") + buf.seek(0) + with zipfile.ZipFile(buf, "r") as zf: + ok, reason = _validate_backup_zip(zf) + assert ok + + def test_validate_rejects_random(self): + """Zip without hermes markers fails validation.""" + import io + from hermes_cli.backup import _validate_backup_zip + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("random/file.txt", "hello") + buf.seek(0) + with zipfile.ZipFile(buf, "r") as zf: + ok, reason = _validate_backup_zip(zf) + assert not ok + 
+ def test_detect_prefix_hermes(self): + """Detects .hermes/ prefix wrapping all entries.""" + import io + from hermes_cli.backup import _detect_prefix + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr(".hermes/config.yaml", "test") + zf.writestr(".hermes/skills/a/SKILL.md", "skill") + buf.seek(0) + with zipfile.ZipFile(buf, "r") as zf: + assert _detect_prefix(zf) == ".hermes/" + + def test_detect_prefix_none(self): + """No prefix when entries are at root.""" + import io + from hermes_cli.backup import _detect_prefix + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("config.yaml", "test") + zf.writestr("skills/a/SKILL.md", "skill") + buf.seek(0) + with zipfile.ZipFile(buf, "r") as zf: + assert _detect_prefix(zf) == "" + + def test_detect_prefix_only_dirs(self): + """Prefix detection returns empty for zip with only directory entries.""" + import io + from hermes_cli.backup import _detect_prefix + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + # Only directory entries (trailing slash) + zf.writestr(".hermes/", "") + zf.writestr(".hermes/skills/", "") + buf.seek(0) + with zipfile.ZipFile(buf, "r") as zf: + assert _detect_prefix(zf) == "" + + +# --------------------------------------------------------------------------- +# Edge case tests for uncovered paths +# --------------------------------------------------------------------------- + +class TestBackupEdgeCases: + def test_nonexistent_hermes_home(self, tmp_path, monkeypatch): + """Backup exits when hermes home doesn't exist.""" + fake_home = tmp_path / "nonexistent" / ".hermes" + monkeypatch.setenv("HERMES_HOME", str(fake_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path / "nonexistent") + + args = Namespace(output=str(tmp_path / "out.zip")) + + from hermes_cli.backup import run_backup + with pytest.raises(SystemExit): + run_backup(args) + + def test_output_is_directory(self, tmp_path, monkeypatch): + """When output path is a directory, 
zip is created inside it.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text("model: test\n") + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + out_dir = tmp_path / "backups" + out_dir.mkdir() + + args = Namespace(output=str(out_dir)) + + from hermes_cli.backup import run_backup + run_backup(args) + + zips = list(out_dir.glob("hermes-backup-*.zip")) + assert len(zips) == 1 + + def test_output_without_zip_suffix(self, tmp_path, monkeypatch): + """Output path without .zip gets suffix appended.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text("model: test\n") + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + out_path = tmp_path / "mybackup.tar" + args = Namespace(output=str(out_path)) + + from hermes_cli.backup import run_backup + run_backup(args) + + # Should have .tar.zip suffix + assert (tmp_path / "mybackup.tar.zip").exists() + + def test_empty_hermes_home(self, tmp_path, monkeypatch): + """Backup handles empty hermes home (no files to back up).""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + # Only excluded dirs, no actual files + (hermes_home / "__pycache__").mkdir() + (hermes_home / "__pycache__" / "foo.pyc").write_bytes(b"\x00") + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + args = Namespace(output=str(tmp_path / "out.zip")) + + from hermes_cli.backup import run_backup + run_backup(args) + + # No zip should be created + assert not (tmp_path / "out.zip").exists() + + def test_permission_error_during_backup(self, tmp_path, monkeypatch): + """Backup handles permission errors gracefully.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text("model: test\n") + + # Create an unreadable file + bad_file 
= hermes_home / "secret.db" + bad_file.write_text("data") + bad_file.chmod(0o000) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + out_zip = tmp_path / "out.zip" + args = Namespace(output=str(out_zip)) + + from hermes_cli.backup import run_backup + try: + run_backup(args) + finally: + # Restore permissions for cleanup + bad_file.chmod(0o644) + + # Zip should still be created with the readable files + assert out_zip.exists() + + def test_skips_output_zip_inside_hermes(self, tmp_path, monkeypatch): + """Backup skips its own output zip if it's inside hermes root.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text("model: test\n") + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + # Output inside hermes home + out_zip = hermes_home / "backup.zip" + args = Namespace(output=str(out_zip)) + + from hermes_cli.backup import run_backup + run_backup(args) + + # The zip should exist but not contain itself + assert out_zip.exists() + with zipfile.ZipFile(out_zip, "r") as zf: + assert "backup.zip" not in zf.namelist() + + +class TestImportEdgeCases: + def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None: + with zipfile.ZipFile(zip_path, "w") as zf: + for name, content in files.items(): + zf.writestr(name, content) + + def test_not_a_zip(self, tmp_path, monkeypatch): + """Import rejects a non-zip file.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + not_zip = tmp_path / "fake.zip" + not_zip.write_text("this is not a zip") + + args = Namespace(zipfile=str(not_zip), force=True) + + from hermes_cli.backup import run_import + with pytest.raises(SystemExit): + run_import(args) + + def test_eof_during_confirmation(self, tmp_path, monkeypatch): + """Import handles EOFError during confirmation prompt.""" + 
hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text("existing\n") + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, {"config.yaml": "new\n"}) + + args = Namespace(zipfile=str(zip_path), force=False) + + from hermes_cli.backup import run_import + with patch("builtins.input", side_effect=EOFError): + with pytest.raises(SystemExit): + run_import(args) + + def test_keyboard_interrupt_during_confirmation(self, tmp_path, monkeypatch): + """Import handles KeyboardInterrupt during confirmation prompt.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + (hermes_home / ".env").write_text("KEY=val\n") + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, {"config.yaml": "new\n"}) + + args = Namespace(zipfile=str(zip_path), force=False) + + from hermes_cli.backup import run_import + with patch("builtins.input", side_effect=KeyboardInterrupt): + with pytest.raises(SystemExit): + run_import(args) + + def test_permission_error_during_import(self, tmp_path, monkeypatch): + """Import handles permission errors during extraction.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + # Create a read-only directory so extraction fails + locked_dir = hermes_home / "locked" + locked_dir.mkdir() + locked_dir.chmod(0o555) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, { + "config.yaml": "model: test\n", + "locked/secret.txt": "data", + }) + + args = Namespace(zipfile=str(zip_path), force=True) + + from hermes_cli.backup import run_import + try: + run_import(args) + finally: + locked_dir.chmod(0o755) + + # config.yaml should still be 
restored despite the error + assert (hermes_home / "config.yaml").exists() + + def test_progress_with_many_files(self, tmp_path, monkeypatch): + """Import shows progress with 500+ files.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "big.zip" + files = {"config.yaml": "model: test\n"} + for i in range(600): + files[f"sessions/s{i:04d}.json"] = "{}" + + self._make_backup_zip(zip_path, files) + + args = Namespace(zipfile=str(zip_path), force=True) + + from hermes_cli.backup import run_import + run_import(args) + + assert (hermes_home / "config.yaml").exists() + assert (hermes_home / "sessions" / "s0599.json").exists() + + +# --------------------------------------------------------------------------- +# Profile restoration tests +# --------------------------------------------------------------------------- + +class TestProfileRestoration: + def _make_backup_zip(self, zip_path: Path, files: dict[str, str | bytes]) -> None: + with zipfile.ZipFile(zip_path, "w") as zf: + for name, content in files.items(): + zf.writestr(name, content) + + def test_import_creates_profile_wrappers(self, tmp_path, monkeypatch): + """Import auto-creates wrapper scripts for restored profiles.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + # Mock the wrapper dir to be inside tmp_path + wrapper_dir = tmp_path / ".local" / "bin" + wrapper_dir.mkdir(parents=True) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, { + "config.yaml": "model:\n provider: openrouter\n", + "profiles/coder/config.yaml": "model:\n provider: anthropic\n", + "profiles/coder/.env": "ANTHROPIC_API_KEY=sk-test\n", + "profiles/researcher/config.yaml": "model:\n provider: deepseek\n", + }) + + args = Namespace(zipfile=str(zip_path), 
force=True) + + from hermes_cli.backup import run_import + run_import(args) + + # Profile directories should exist + assert (hermes_home / "profiles" / "coder" / "config.yaml").exists() + assert (hermes_home / "profiles" / "researcher" / "config.yaml").exists() + + # Wrapper scripts should be created + assert (wrapper_dir / "coder").exists() + assert (wrapper_dir / "researcher").exists() + + # Wrappers should contain the right content + coder_wrapper = (wrapper_dir / "coder").read_text() + assert "hermes -p coder" in coder_wrapper + + def test_import_skips_profile_dirs_without_config(self, tmp_path, monkeypatch): + """Import doesn't create wrappers for profile dirs without config.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + wrapper_dir = tmp_path / ".local" / "bin" + wrapper_dir.mkdir(parents=True) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, { + "config.yaml": "model: test\n", + "profiles/valid/config.yaml": "model: test\n", + "profiles/empty/readme.txt": "nothing here\n", + }) + + args = Namespace(zipfile=str(zip_path), force=True) + + from hermes_cli.backup import run_import + run_import(args) + + # Only valid profile should get a wrapper + assert (wrapper_dir / "valid").exists() + assert not (wrapper_dir / "empty").exists() + + def test_import_without_profiles_module(self, tmp_path, monkeypatch): + """Import gracefully handles missing profiles module (fresh install).""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.setattr(Path, "home", lambda: tmp_path) + + zip_path = tmp_path / "backup.zip" + self._make_backup_zip(zip_path, { + "config.yaml": "model: test\n", + "profiles/coder/config.yaml": "model: test\n", + }) + + args = Namespace(zipfile=str(zip_path), force=True) + + # Simulate profiles module not being available + 
import hermes_cli.backup as backup_mod + original_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__ + + def fake_import(name, *a, **kw): + if name == "hermes_cli.profiles": + raise ImportError("no profiles module") + return original_import(name, *a, **kw) + + from hermes_cli.backup import run_import + with patch("builtins.__import__", side_effect=fake_import): + run_import(args) + + # Files should still be restored even if wrappers can't be created + assert (hermes_home / "profiles" / "coder" / "config.yaml").exists() diff --git a/tests/hermes_cli/test_cli_model_picker.py b/tests/hermes_cli/test_cli_model_picker.py new file mode 100644 index 0000000000..1fe9fe51ac --- /dev/null +++ b/tests/hermes_cli/test_cli_model_picker.py @@ -0,0 +1,254 @@ +"""Tests for the interactive CLI /model picker (provider → model drill-down).""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + + +class _FakeBuffer: + def __init__(self, text="draft text"): + self.text = text + self.cursor_position = len(text) + self.reset_calls = [] + + def reset(self, append_to_history=False): + self.reset_calls.append(append_to_history) + self.text = "" + self.cursor_position = 0 + + +def _make_providers(): + return [ + { + "slug": "openrouter", + "name": "OpenRouter", + "is_current": True, + "is_user_defined": False, + "models": ["anthropic/claude-opus-4.6", "openai/gpt-5.4"], + "total_models": 2, + "source": "built-in", + }, + { + "slug": "anthropic", + "name": "Anthropic", + "is_current": False, + "is_user_defined": False, + "models": ["claude-opus-4.6", "claude-sonnet-4.6"], + "total_models": 2, + "source": "built-in", + }, + { + "slug": "custom:my-ollama", + "name": "My Ollama", + "is_current": False, + "is_user_defined": True, + "models": ["llama3", "mistral"], + "total_models": 2, + "source": "user-config", + "api_url": "http://localhost:11434/v1", + }, + ] + + +def _make_picker_cli(picker_return_value): + cli = MagicMock() + 
cli._run_curses_picker = MagicMock(return_value=picker_return_value) + cli._app = MagicMock() + cli._status_bar_visible = True + return cli + + +def _make_modal_cli(): + from cli import HermesCLI + + cli = HermesCLI.__new__(HermesCLI) + cli.model = "gpt-5.4" + cli.provider = "openrouter" + cli.requested_provider = "openrouter" + cli.base_url = "" + cli.api_key = "" + cli.api_mode = "" + cli._explicit_api_key = "" + cli._explicit_base_url = "" + cli._pending_model_switch_note = None + cli._model_picker_state = None + cli._modal_input_snapshot = None + cli._status_bar_visible = True + cli._invalidate = MagicMock() + cli.agent = None + cli.config = {} + cli.console = MagicMock() + cli._app = SimpleNamespace( + current_buffer=_FakeBuffer(), + invalidate=MagicMock(), + ) + return cli + + +def test_provider_selection_returns_slug_on_choice(): + providers = _make_providers() + cli = _make_picker_cli(1) + from cli import HermesCLI + + result = HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter") + + assert result == "anthropic" + cli._run_curses_picker.assert_called_once() + + +def test_provider_selection_returns_none_on_cancel(): + providers = _make_providers() + cli = _make_picker_cli(None) + from cli import HermesCLI + + result = HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter") + + assert result is None + + +def test_provider_selection_default_is_current(): + providers = _make_providers() + cli = _make_picker_cli(0) + from cli import HermesCLI + + HermesCLI._interactive_provider_selection(cli, providers, "gpt-5.4", "OpenRouter") + + assert cli._run_curses_picker.call_args.kwargs["default_index"] == 0 + + +def test_model_selection_returns_model_on_choice(): + provider_data = _make_providers()[0] + cli = _make_picker_cli(0) + from cli import HermesCLI + + result = HermesCLI._interactive_model_selection(cli, provider_data["models"], provider_data) + + assert result == "anthropic/claude-opus-4.6" + + +def 
test_model_selection_custom_entry_prompts_for_input(): + provider_data = _make_providers()[0] + cli = _make_picker_cli(2) + from cli import HermesCLI + + cli._prompt_text_input = MagicMock(return_value="my-custom-model") + result = HermesCLI._interactive_model_selection(cli, provider_data["models"], provider_data) + + assert result == "my-custom-model" + cli._prompt_text_input.assert_called_once_with(" Enter model name: ") + + +def test_model_selection_empty_prompts_for_manual_input(): + provider_data = { + "slug": "custom:empty", + "name": "Empty Provider", + "models": [], + "total_models": 0, + } + cli = _make_picker_cli(None) + from cli import HermesCLI + + cli._prompt_text_input = MagicMock(return_value="my-model") + result = HermesCLI._interactive_model_selection(cli, [], provider_data) + + assert result == "my-model" + cli._prompt_text_input.assert_called_once_with(" Enter model name manually (or Enter to cancel): ") + + +def test_prompt_text_input_uses_run_in_terminal_when_app_active(): + from cli import HermesCLI + + cli = _make_modal_cli() + + with ( + patch("prompt_toolkit.application.run_in_terminal", side_effect=lambda fn: fn()) as run_mock, + patch("builtins.input", return_value="manual-value"), + ): + result = HermesCLI._prompt_text_input(cli, "Enter value: ") + + assert result == "manual-value" + run_mock.assert_called_once() + assert cli._status_bar_visible is True + + +def test_should_handle_model_command_inline_uses_command_name_resolution(): + from cli import HermesCLI + + cli = _make_modal_cli() + + with patch("hermes_cli.commands.resolve_command", return_value=SimpleNamespace(name="model")): + assert HermesCLI._should_handle_model_command_inline(cli, "/model") is True + + with patch("hermes_cli.commands.resolve_command", return_value=SimpleNamespace(name="help")): + assert HermesCLI._should_handle_model_command_inline(cli, "/model") is False + + assert HermesCLI._should_handle_model_command_inline(cli, "/model", has_images=True) is False + + 
+def test_process_command_model_without_args_opens_modal_picker_and_captures_draft(): + from cli import HermesCLI + + cli = _make_modal_cli() + providers = _make_providers() + + with ( + patch("hermes_cli.model_switch.list_authenticated_providers", return_value=providers), + patch("cli._cprint"), + ): + result = cli.process_command("/model") + + assert result is True + assert cli._model_picker_state is not None + assert cli._model_picker_state["stage"] == "provider" + assert cli._model_picker_state["selected"] == 0 + assert cli._modal_input_snapshot == {"text": "draft text", "cursor_position": len("draft text")} + assert cli._app.current_buffer.text == "" + + +def test_model_picker_provider_then_model_selection_applies_switch_result_and_restores_draft(): + from cli import HermesCLI + + cli = _make_modal_cli() + providers = _make_providers() + + with ( + patch("hermes_cli.model_switch.list_authenticated_providers", return_value=providers), + patch("cli._cprint"), + ): + assert cli.process_command("/model") is True + + cli._model_picker_state["selected"] = 1 + with patch("hermes_cli.models.provider_model_ids", return_value=["claude-opus-4.6", "claude-sonnet-4.6"]): + HermesCLI._handle_model_picker_selection(cli) + + assert cli._model_picker_state["stage"] == "model" + assert cli._model_picker_state["provider_data"]["slug"] == "anthropic" + assert cli._model_picker_state["model_list"] == ["claude-opus-4.6", "claude-sonnet-4.6"] + + cli._model_picker_state["selected"] = 0 + switch_result = SimpleNamespace( + success=True, + error_message=None, + new_model="claude-opus-4.6", + target_provider="anthropic", + api_key="", + base_url="", + api_mode="anthropic_messages", + provider_label="Anthropic", + model_info=None, + warning_message=None, + provider_changed=True, + ) + + with ( + patch("hermes_cli.model_switch.switch_model", return_value=switch_result) as switch_mock, + patch("cli._cprint"), + ): + HermesCLI._handle_model_picker_selection(cli) + + assert 
cli._model_picker_state is None + assert cli.model == "claude-opus-4.6" + assert cli.provider == "anthropic" + assert cli.requested_provider == "anthropic" + assert cli._app.current_buffer.text == "draft text" + switch_mock.assert_called_once() + assert switch_mock.call_args.kwargs["explicit_provider"] == "anthropic" diff --git a/tests/hermes_cli/test_config.py b/tests/hermes_cli/test_config.py index 1c245577e9..d934a80125 100644 --- a/tests/hermes_cli/test_config.py +++ b/tests/hermes_cli/test_config.py @@ -68,6 +68,7 @@ class TestLoadConfigDefaults: assert "max_turns" not in config assert "terminal" in config assert config["terminal"]["backend"] == "local" + assert config["display"]["interim_assistant_messages"] is True def test_legacy_root_level_max_turns_migrates_to_agent_config(self, tmp_path): with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): @@ -421,3 +422,25 @@ class TestAnthropicTokenMigration: }): migrate_config(interactive=False, quiet=True) assert load_env().get("ANTHROPIC_TOKEN") == "current-token" + + +class TestInterimAssistantMessageConfig: + """Test the explicit gateway interim-message config gate.""" + + def test_default_config_enables_interim_assistant_messages(self): + assert DEFAULT_CONFIG["display"]["interim_assistant_messages"] is True + + def test_migrate_to_v15_adds_interim_assistant_message_gate(self, tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump({"_config_version": 14, "display": {"tool_progress": "off"}}), + encoding="utf-8", + ) + + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + migrate_config(interactive=False, quiet=True) + raw = yaml.safe_load(config_path.read_text(encoding="utf-8")) + + assert raw["_config_version"] == 16 + assert raw["display"]["tool_progress"] == "off" + assert raw["display"]["interim_assistant_messages"] is True diff --git a/tests/hermes_cli/test_container_aware_cli.py b/tests/hermes_cli/test_container_aware_cli.py new file mode 100644 
index 0000000000..9e21c0b8d2 --- /dev/null +++ b/tests/hermes_cli/test_container_aware_cli.py @@ -0,0 +1,342 @@ +"""Tests for container-aware CLI routing (NixOS container mode). + +When container.enable = true in the NixOS module, the activation script +writes a .container-mode metadata file. The host CLI detects this and +execs into the container instead of running locally. +""" +import os +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from hermes_cli.config import ( + _is_inside_container, + get_container_exec_info, +) + + +# ============================================================================= +# _is_inside_container +# ============================================================================= + + +def test_is_inside_container_dockerenv(): + """Detects /.dockerenv marker file.""" + with patch("os.path.exists") as mock_exists: + mock_exists.side_effect = lambda p: p == "/.dockerenv" + assert _is_inside_container() is True + + +def test_is_inside_container_containerenv(): + """Detects Podman's /run/.containerenv marker.""" + with patch("os.path.exists") as mock_exists: + mock_exists.side_effect = lambda p: p == "/run/.containerenv" + assert _is_inside_container() is True + + +def test_is_inside_container_cgroup_docker(): + """Detects 'docker' in /proc/1/cgroup.""" + with patch("os.path.exists", return_value=False), \ + patch("builtins.open", create=True) as mock_open: + mock_open.return_value.__enter__ = lambda s: s + mock_open.return_value.__exit__ = MagicMock(return_value=False) + mock_open.return_value.read = MagicMock( + return_value="12:memory:/docker/abc123\n" + ) + assert _is_inside_container() is True + + +def test_is_inside_container_false_on_host(): + """Returns False when none of the container indicators are present.""" + with patch("os.path.exists", return_value=False), \ + patch("builtins.open", side_effect=OSError("no such file")): + assert _is_inside_container() is False + + +# 
============================================================================= +# get_container_exec_info +# ============================================================================= + + +@pytest.fixture +def container_env(tmp_path, monkeypatch): + """Set up a fake HERMES_HOME with .container-mode file.""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.delenv("HERMES_DEV", raising=False) + + container_mode = hermes_home / ".container-mode" + container_mode.write_text( + "# Written by NixOS activation script. Do not edit manually.\n" + "backend=podman\n" + "container_name=hermes-agent\n" + "exec_user=hermes\n" + "hermes_bin=/data/current-package/bin/hermes\n" + ) + return hermes_home + + +def test_get_container_exec_info_returns_metadata(container_env): + """Reads .container-mode and returns all fields including exec_user.""" + with patch("hermes_cli.config._is_inside_container", return_value=False): + info = get_container_exec_info() + + assert info is not None + assert info["backend"] == "podman" + assert info["container_name"] == "hermes-agent" + assert info["exec_user"] == "hermes" + assert info["hermes_bin"] == "/data/current-package/bin/hermes" + + +def test_get_container_exec_info_none_inside_container(container_env): + """Returns None when we're already inside a container.""" + with patch("hermes_cli.config._is_inside_container", return_value=True): + info = get_container_exec_info() + + assert info is None + + +def test_get_container_exec_info_none_without_file(tmp_path, monkeypatch): + """Returns None when .container-mode doesn't exist (native mode).""" + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + monkeypatch.delenv("HERMES_DEV", raising=False) + + with patch("hermes_cli.config._is_inside_container", return_value=False): + info = get_container_exec_info() + + assert info is None + + +def 
test_get_container_exec_info_skipped_when_hermes_dev(container_env, monkeypatch): + """Returns None when HERMES_DEV=1 is set (dev mode bypass).""" + monkeypatch.setenv("HERMES_DEV", "1") + + with patch("hermes_cli.config._is_inside_container", return_value=False): + info = get_container_exec_info() + + assert info is None + + +def test_get_container_exec_info_not_skipped_when_hermes_dev_zero(container_env, monkeypatch): + """HERMES_DEV=0 does NOT trigger bypass — only '1' does.""" + monkeypatch.setenv("HERMES_DEV", "0") + + with patch("hermes_cli.config._is_inside_container", return_value=False): + info = get_container_exec_info() + + assert info is not None + + +def test_get_container_exec_info_defaults(): + """Falls back to defaults for missing keys.""" + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + hermes_home = Path(tmpdir) / ".hermes" + hermes_home.mkdir() + (hermes_home / ".container-mode").write_text( + "# minimal file with no keys\n" + ) + + with patch("hermes_cli.config._is_inside_container", return_value=False), \ + patch("hermes_cli.config.get_hermes_home", return_value=hermes_home), \ + patch.dict(os.environ, {}, clear=False): + os.environ.pop("HERMES_DEV", None) + info = get_container_exec_info() + + assert info is not None + assert info["backend"] == "docker" + assert info["container_name"] == "hermes-agent" + assert info["exec_user"] == "hermes" + assert info["hermes_bin"] == "/data/current-package/bin/hermes" + + +def test_get_container_exec_info_docker_backend(container_env): + """Correctly reads docker backend with custom exec_user.""" + (container_env / ".container-mode").write_text( + "backend=docker\n" + "container_name=hermes-custom\n" + "exec_user=myuser\n" + "hermes_bin=/opt/hermes/bin/hermes\n" + ) + + with patch("hermes_cli.config._is_inside_container", return_value=False): + info = get_container_exec_info() + + assert info["backend"] == "docker" + assert info["container_name"] == "hermes-custom" + assert 
info["exec_user"] == "myuser" + assert info["hermes_bin"] == "/opt/hermes/bin/hermes" + + +def test_get_container_exec_info_crashes_on_permission_error(container_env): + """PermissionError propagates instead of being silently swallowed.""" + with patch("hermes_cli.config._is_inside_container", return_value=False), \ + patch("builtins.open", side_effect=PermissionError("permission denied")): + with pytest.raises(PermissionError): + get_container_exec_info() + + +# ============================================================================= +# _exec_in_container +# ============================================================================= + + +@pytest.fixture +def docker_container_info(): + return { + "backend": "docker", + "container_name": "hermes-agent", + "exec_user": "hermes", + "hermes_bin": "/data/current-package/bin/hermes", + } + + +@pytest.fixture +def podman_container_info(): + return { + "backend": "podman", + "container_name": "hermes-agent", + "exec_user": "hermes", + "hermes_bin": "/data/current-package/bin/hermes", + } + + +def test_exec_in_container_calls_execvp(docker_container_info): + """Verifies os.execvp is called with correct args: runtime, tty flags, + user, env vars, container name, binary, and CLI args.""" + from hermes_cli.main import _exec_in_container + + with patch("shutil.which", return_value="/usr/bin/docker"), \ + patch("subprocess.run") as mock_run, \ + patch("sys.stdin") as mock_stdin, \ + patch("os.execvp") as mock_execvp, \ + patch.dict(os.environ, {"TERM": "xterm-256color", "LANG": "en_US.UTF-8"}, + clear=False): + mock_stdin.isatty.return_value = True + mock_run.return_value = MagicMock(returncode=0) + + _exec_in_container(docker_container_info, ["chat", "-m", "opus"]) + + mock_execvp.assert_called_once() + cmd = mock_execvp.call_args[0][1] + assert cmd[0] == "/usr/bin/docker" + assert cmd[1] == "exec" + assert "-it" in cmd + idx_u = cmd.index("-u") + assert cmd[idx_u + 1] == "hermes" + e_indices = [i for i, v in 
enumerate(cmd) if v == "-e"] + e_values = [cmd[i + 1] for i in e_indices] + assert "TERM=xterm-256color" in e_values + assert "LANG=en_US.UTF-8" in e_values + assert "hermes-agent" in cmd + assert "/data/current-package/bin/hermes" in cmd + assert "chat" in cmd + + +def test_exec_in_container_non_tty_uses_i_only(docker_container_info): + """Non-TTY mode uses -i instead of -it.""" + from hermes_cli.main import _exec_in_container + + with patch("shutil.which", return_value="/usr/bin/docker"), \ + patch("subprocess.run") as mock_run, \ + patch("sys.stdin") as mock_stdin, \ + patch("os.execvp") as mock_execvp: + mock_stdin.isatty.return_value = False + mock_run.return_value = MagicMock(returncode=0) + + _exec_in_container(docker_container_info, ["sessions", "list"]) + + cmd = mock_execvp.call_args[0][1] + assert "-i" in cmd + assert "-it" not in cmd + + +def test_exec_in_container_no_runtime_hard_fails(podman_container_info): + """Hard fails when runtime not found (no fallback).""" + from hermes_cli.main import _exec_in_container + + with patch("shutil.which", return_value=None), \ + patch("subprocess.run") as mock_run, \ + patch("os.execvp") as mock_execvp, \ + pytest.raises(SystemExit) as exc_info: + _exec_in_container(podman_container_info, ["chat"]) + + mock_run.assert_not_called() + mock_execvp.assert_not_called() + assert exc_info.value.code != 0 + + +def test_exec_in_container_sudo_probe_sets_prefix(podman_container_info): + """When first probe fails and sudo probe succeeds, execvp is called + with sudo -n prefix.""" + from hermes_cli.main import _exec_in_container + + def which_side_effect(name): + if name == "podman": + return "/usr/bin/podman" + if name == "sudo": + return "/usr/bin/sudo" + return None + + with patch("shutil.which", side_effect=which_side_effect), \ + patch("subprocess.run") as mock_run, \ + patch("sys.stdin") as mock_stdin, \ + patch("os.execvp") as mock_execvp: + mock_stdin.isatty.return_value = True + mock_run.side_effect = [ + 
MagicMock(returncode=1), # direct probe fails + MagicMock(returncode=0), # sudo probe succeeds + ] + + _exec_in_container(podman_container_info, ["chat"]) + + mock_execvp.assert_called_once() + cmd = mock_execvp.call_args[0][1] + assert cmd[0] == "/usr/bin/sudo" + assert cmd[1] == "-n" + assert cmd[2] == "/usr/bin/podman" + assert cmd[3] == "exec" + + +def test_exec_in_container_probe_timeout_prints_message(docker_container_info): + """TimeoutExpired from probe produces a human-readable error, not a + raw traceback.""" + from hermes_cli.main import _exec_in_container + + with patch("shutil.which", return_value="/usr/bin/docker"), \ + patch("subprocess.run", side_effect=subprocess.TimeoutExpired( + cmd=["docker", "inspect"], timeout=15)), \ + patch("os.execvp") as mock_execvp, \ + pytest.raises(SystemExit) as exc_info: + _exec_in_container(docker_container_info, ["chat"]) + + mock_execvp.assert_not_called() + assert exc_info.value.code == 1 + + +def test_exec_in_container_container_not_running_no_sudo(docker_container_info): + """When runtime exists but container not found and no sudo available, + prints helpful error about root containers.""" + from hermes_cli.main import _exec_in_container + + def which_side_effect(name): + if name == "docker": + return "/usr/bin/docker" + return None + + with patch("shutil.which", side_effect=which_side_effect), \ + patch("subprocess.run") as mock_run, \ + patch("os.execvp") as mock_execvp, \ + pytest.raises(SystemExit) as exc_info: + mock_run.return_value = MagicMock(returncode=1) + + _exec_in_container(docker_container_info, ["chat"]) + + mock_execvp.assert_not_called() + assert exc_info.value.code == 1 diff --git a/tests/hermes_cli/test_gateway.py b/tests/hermes_cli/test_gateway.py index 955449547c..fd88a26c6a 100644 --- a/tests/hermes_cli/test_gateway.py +++ b/tests/hermes_cli/test_gateway.py @@ -260,7 +260,7 @@ class TestWaitForGatewayExit: def test_kill_gateway_processes_force_uses_helper(self, monkeypatch): calls = [] - 
monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None: [11, 22]) + monkeypatch.setattr(gateway, "find_gateway_pids", lambda exclude_pids=None, all_profiles=False: [11, 22]) monkeypatch.setattr(gateway, "terminate_pid", lambda pid, force=False: calls.append((pid, force))) killed = gateway.kill_gateway_processes(force=True) diff --git a/tests/hermes_cli/test_gateway_service.py b/tests/hermes_cli/test_gateway_service.py index c5d4cb4f5d..cba3a8192f 100644 --- a/tests/hermes_cli/test_gateway_service.py +++ b/tests/hermes_cli/test_gateway_service.py @@ -1,6 +1,7 @@ """Tests for gateway service management helpers.""" import os +import pwd from pathlib import Path from types import SimpleNamespace @@ -129,7 +130,7 @@ class TestGatewayStopCleanup: monkeypatch.setattr( gateway_cli, "kill_gateway_processes", - lambda force=False: kill_calls.append(force) or 2, + lambda force=False, all_profiles=False: kill_calls.append(force) or 2, ) gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop")) @@ -155,7 +156,7 @@ class TestGatewayStopCleanup: monkeypatch.setattr( gateway_cli, "kill_gateway_processes", - lambda force=False: kill_calls.append(force) or 2, + lambda force=False, all_profiles=False: kill_calls.append(force) or 2, ) gateway_cli.gateway_command(SimpleNamespace(gateway_command="stop", **{"all": True})) @@ -924,6 +925,23 @@ class TestProfileArg: assert "--profile" in plist assert "mybot" in plist + def test_launchd_plist_path_uses_real_user_home_not_profile_home(self, tmp_path, monkeypatch): + profile_dir = tmp_path / ".hermes" / "profiles" / "orcha" + profile_dir.mkdir(parents=True) + machine_home = tmp_path / "machine-home" + machine_home.mkdir() + profile_home = profile_dir / "home" + profile_home.mkdir() + + monkeypatch.setattr(Path, "home", lambda: profile_home) + monkeypatch.setenv("HERMES_HOME", str(profile_dir)) + monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir) + monkeypatch.setattr(pwd, "getpwuid", lambda 
uid: SimpleNamespace(pw_dir=str(machine_home))) + + plist_path = gateway_cli.get_launchd_plist_path() + + assert plist_path == machine_home / "Library" / "LaunchAgents" / "ai.hermes.gateway-orcha.plist" + class TestRemapPathForUser: """Unit tests for _remap_path_for_user().""" diff --git a/tests/hermes_cli/test_logs.py b/tests/hermes_cli/test_logs.py index d379226db5..0827143fc6 100644 --- a/tests/hermes_cli/test_logs.py +++ b/tests/hermes_cli/test_logs.py @@ -1,288 +1,255 @@ -"""Tests for hermes_cli/logs.py — log viewing and filtering.""" +"""Tests for hermes_cli.logs — log viewing and filtering.""" import os -import textwrap from datetime import datetime, timedelta -from io import StringIO from pathlib import Path -from unittest.mock import patch import pytest from hermes_cli.logs import ( LOG_FILES, _extract_level, + _extract_logger_name, + _line_matches_component, _matches_filters, _parse_line_timestamp, _parse_since, _read_last_n_lines, - list_logs, - tail_log, + _read_tail, ) # --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def log_dir(tmp_path, monkeypatch): - """Create a fake HERMES_HOME with a logs/ directory.""" - home = Path(os.environ["HERMES_HOME"]) - logs = home / "logs" - logs.mkdir(parents=True, exist_ok=True) - return logs - - -@pytest.fixture -def sample_agent_log(log_dir): - """Write a realistic agent.log with mixed levels and sessions.""" - lines = textwrap.dedent("""\ - 2026-04-05 10:00:00,000 INFO run_agent: conversation turn: session=sess_aaa model=claude provider=openrouter platform=cli history=0 msg='hello' - 2026-04-05 10:00:01,000 INFO run_agent: tool terminal completed (0.50s, 200 chars) - 2026-04-05 10:00:02,000 INFO run_agent: API call #1: model=claude provider=openrouter in=1000 out=200 total=1200 latency=1.5s - 2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): 
timeout - 2026-04-05 10:00:04,000 INFO run_agent: conversation turn: session=sess_bbb model=gpt-5 provider=openai platform=telegram history=5 msg='fix bug' - 2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited - 2026-04-05 10:00:06,000 INFO run_agent: tool read_file completed (0.01s, 500 chars) - 2026-04-05 10:00:07,000 DEBUG run_agent: verbose internal detail - 2026-04-05 10:00:08,000 INFO credential_pool: credential pool: marking key-1 exhausted (status=429), rotating - 2026-04-05 10:00:09,000 INFO credential_pool: credential pool: rotated to key-2 - """) - path = log_dir / "agent.log" - path.write_text(lines) - return path - - -@pytest.fixture -def sample_errors_log(log_dir): - """Write a small errors.log.""" - lines = textwrap.dedent("""\ - 2026-04-05 10:00:03,000 WARNING run_agent: Tool web_search returned error (2.00s): timeout - 2026-04-05 10:00:05,000 ERROR run_agent: API call failed after 3 retries. rate limited - """) - path = log_dir / "errors.log" - path.write_text(lines) - return path - - -# --------------------------------------------------------------------------- -# _parse_since +# Timestamp parsing # --------------------------------------------------------------------------- class TestParseSince: def test_hours(self): cutoff = _parse_since("2h") assert cutoff is not None - assert (datetime.now() - cutoff).total_seconds() == pytest.approx(7200, abs=5) + assert abs((datetime.now() - cutoff).total_seconds() - 7200) < 2 def test_minutes(self): cutoff = _parse_since("30m") assert cutoff is not None - assert (datetime.now() - cutoff).total_seconds() == pytest.approx(1800, abs=5) + assert abs((datetime.now() - cutoff).total_seconds() - 1800) < 2 def test_days(self): cutoff = _parse_since("1d") assert cutoff is not None - assert (datetime.now() - cutoff).total_seconds() == pytest.approx(86400, abs=5) + assert abs((datetime.now() - cutoff).total_seconds() - 86400) < 2 def test_seconds(self): - cutoff = _parse_since("60s") 
+ cutoff = _parse_since("120s") assert cutoff is not None - assert (datetime.now() - cutoff).total_seconds() == pytest.approx(60, abs=5) + assert abs((datetime.now() - cutoff).total_seconds() - 120) < 2 def test_invalid_returns_none(self): assert _parse_since("abc") is None assert _parse_since("") is None assert _parse_since("10x") is None - def test_whitespace_handling(self): - cutoff = _parse_since(" 1h ") + def test_whitespace_tolerance(self): + cutoff = _parse_since(" 5m ") assert cutoff is not None -# --------------------------------------------------------------------------- -# _parse_line_timestamp -# --------------------------------------------------------------------------- - class TestParseLineTimestamp: def test_standard_format(self): - ts = _parse_line_timestamp("2026-04-05 10:00:00,123 INFO something") - assert ts is not None - assert ts.year == 2026 - assert ts.hour == 10 + ts = _parse_line_timestamp("2026-04-11 10:23:45 INFO gateway.run: msg") + assert ts == datetime(2026, 4, 11, 10, 23, 45) def test_no_timestamp(self): - assert _parse_line_timestamp("just some text") is None + assert _parse_line_timestamp("no timestamp here") is None - def test_continuation_line(self): - assert _parse_line_timestamp(" at module.function (line 42)") is None - - -# --------------------------------------------------------------------------- -# _extract_level -# --------------------------------------------------------------------------- class TestExtractLevel: def test_info(self): - assert _extract_level("2026-04-05 10:00:00 INFO run_agent: something") == "INFO" + assert _extract_level("2026-01-01 00:00:00 INFO gateway.run: msg") == "INFO" def test_warning(self): - assert _extract_level("2026-04-05 10:00:00 WARNING run_agent: bad") == "WARNING" + assert _extract_level("2026-01-01 00:00:00 WARNING tools.file: msg") == "WARNING" def test_error(self): - assert _extract_level("2026-04-05 10:00:00 ERROR run_agent: crash") == "ERROR" + assert _extract_level("2026-01-01 
00:00:00 ERROR run_agent: msg") == "ERROR" def test_debug(self): - assert _extract_level("2026-04-05 10:00:00 DEBUG run_agent: detail") == "DEBUG" + assert _extract_level("2026-01-01 00:00:00 DEBUG agent.aux: msg") == "DEBUG" def test_no_level(self): - assert _extract_level("just a plain line") is None + assert _extract_level("random text") is None # --------------------------------------------------------------------------- -# _matches_filters +# Logger name extraction (new for component filtering) +# --------------------------------------------------------------------------- + +class TestExtractLoggerName: + def test_standard_line(self): + line = "2026-04-11 10:23:45 INFO gateway.run: Starting gateway" + assert _extract_logger_name(line) == "gateway.run" + + def test_nested_logger(self): + line = "2026-04-11 10:23:45 INFO gateway.platforms.telegram: connected" + assert _extract_logger_name(line) == "gateway.platforms.telegram" + + def test_warning_level(self): + line = "2026-04-11 10:23:45 WARNING tools.terminal_tool: timeout" + assert _extract_logger_name(line) == "tools.terminal_tool" + + def test_with_session_tag(self): + line = "2026-04-11 10:23:45 INFO [abc123] tools.file_tools: reading file" + assert _extract_logger_name(line) == "tools.file_tools" + + def test_with_session_tag_and_error(self): + line = "2026-04-11 10:23:45 ERROR [sess_xyz] agent.context_compressor: failed" + assert _extract_logger_name(line) == "agent.context_compressor" + + def test_top_level_module(self): + line = "2026-04-11 10:23:45 INFO run_agent: starting conversation" + assert _extract_logger_name(line) == "run_agent" + + def test_no_match(self): + assert _extract_logger_name("random text") is None + + +class TestLineMatchesComponent: + def test_gateway_component(self): + line = "2026-04-11 10:23:45 INFO gateway.run: msg" + assert _line_matches_component(line, ("gateway",)) + + def test_gateway_nested(self): + line = "2026-04-11 10:23:45 INFO gateway.platforms.telegram: msg" + 
assert _line_matches_component(line, ("gateway",)) + + def test_tools_component(self): + line = "2026-04-11 10:23:45 INFO tools.terminal_tool: msg" + assert _line_matches_component(line, ("tools",)) + + def test_agent_with_multiple_prefixes(self): + prefixes = ("agent", "run_agent", "model_tools") + assert _line_matches_component( + "2026-04-11 10:23:45 INFO agent.context_compressor: msg", prefixes) + assert _line_matches_component( + "2026-04-11 10:23:45 INFO run_agent: msg", prefixes) + assert _line_matches_component( + "2026-04-11 10:23:45 INFO model_tools: msg", prefixes) + + def test_no_match(self): + line = "2026-04-11 10:23:45 INFO tools.browser: msg" + assert not _line_matches_component(line, ("gateway",)) + + def test_with_session_tag(self): + line = "2026-04-11 10:23:45 INFO [abc] gateway.run: msg" + assert _line_matches_component(line, ("gateway",)) + + def test_unparseable_line(self): + assert not _line_matches_component("random text", ("gateway",)) + + +# --------------------------------------------------------------------------- +# Combined filter # --------------------------------------------------------------------------- class TestMatchesFilters: - def test_no_filters_always_matches(self): - assert _matches_filters("any line") is True + def test_no_filters_passes_everything(self): + assert _matches_filters("any line") - def test_level_filter_passes(self): + def test_level_filter(self): assert _matches_filters( - "2026-04-05 10:00:00 WARNING something", - min_level="WARNING", - ) is True + "2026-01-01 00:00:00 WARNING x: msg", min_level="WARNING") + assert not _matches_filters( + "2026-01-01 00:00:00 INFO x: msg", min_level="WARNING") - def test_level_filter_rejects(self): + def test_session_filter(self): assert _matches_filters( - "2026-04-05 10:00:00 INFO something", - min_level="WARNING", - ) is False + "2026-01-01 00:00:00 INFO [abc123] x: msg", session_filter="abc123") + assert not _matches_filters( + "2026-01-01 00:00:00 INFO [xyz789] x: msg", 
session_filter="abc123") - def test_session_filter_passes(self): + def test_component_filter(self): assert _matches_filters( - "session=sess_aaa model=claude", - session_filter="sess_aaa", - ) is True - - def test_session_filter_rejects(self): - assert _matches_filters( - "session=sess_aaa model=claude", - session_filter="sess_bbb", - ) is False - - def test_since_filter_passes(self): - # Line from the future should always pass - assert _matches_filters( - "2099-01-01 00:00:00 INFO future", - since=datetime.now(), - ) is True - - def test_since_filter_rejects(self): - assert _matches_filters( - "2020-01-01 00:00:00 INFO past", - since=datetime.now(), - ) is False + "2026-01-01 00:00:00 INFO gateway.run: msg", + component_prefixes=("gateway",)) + assert not _matches_filters( + "2026-01-01 00:00:00 INFO tools.file: msg", + component_prefixes=("gateway",)) def test_combined_filters(self): - line = "2099-01-01 00:00:00 WARNING run_agent: session=abc error" + """All filters must pass for a line to match.""" + line = "2026-04-11 10:00:00 WARNING [sess_1] gateway.run: connection lost" assert _matches_filters( - line, min_level="WARNING", session_filter="abc", - since=datetime.now(), - ) is True - # Fails session filter + line, + min_level="WARNING", + session_filter="sess_1", + component_prefixes=("gateway",), + ) + # Fails component filter + assert not _matches_filters( + line, + min_level="WARNING", + session_filter="sess_1", + component_prefixes=("tools",), + ) + + def test_since_filter(self): + # Line with a very old timestamp should be filtered out + assert not _matches_filters( + "2020-01-01 00:00:00 INFO x: old msg", + since=datetime.now() - timedelta(hours=1)) + # Line with a recent timestamp should pass + recent = datetime.now().strftime("%Y-%m-%d %H:%M:%S") assert _matches_filters( - line, min_level="WARNING", session_filter="xyz", - ) is False + f"{recent} INFO x: recent msg", + since=datetime.now() - timedelta(hours=1)) # 
--------------------------------------------------------------------------- -# _read_last_n_lines +# File reading # --------------------------------------------------------------------------- -class TestReadLastNLines: - def test_reads_correct_count(self, sample_agent_log): - lines = _read_last_n_lines(sample_agent_log, 3) - assert len(lines) == 3 +class TestReadTail: + def test_read_small_file(self, tmp_path): + log_file = tmp_path / "test.log" + lines = [f"2026-01-01 00:00:0{i} INFO x: line {i}\n" for i in range(10)] + log_file.write_text("".join(lines)) - def test_reads_all_when_fewer(self, sample_agent_log): - lines = _read_last_n_lines(sample_agent_log, 100) - assert len(lines) == 10 # sample has 10 lines + result = _read_last_n_lines(log_file, 5) + assert len(result) == 5 + assert "line 9" in result[-1] - def test_empty_file(self, log_dir): - empty = log_dir / "empty.log" - empty.write_text("") - lines = _read_last_n_lines(empty, 10) - assert lines == [] + def test_read_with_component_filter(self, tmp_path): + log_file = tmp_path / "test.log" + lines = [ + "2026-01-01 00:00:00 INFO gateway.run: gw msg\n", + "2026-01-01 00:00:01 INFO tools.file: tool msg\n", + "2026-01-01 00:00:02 INFO gateway.session: session msg\n", + "2026-01-01 00:00:03 INFO agent.compressor: agent msg\n", + ] + log_file.write_text("".join(lines)) - def test_last_line_content(self, sample_agent_log): - lines = _read_last_n_lines(sample_agent_log, 1) - assert "rotated to key-2" in lines[0] + result = _read_tail( + log_file, 50, + has_filters=True, + component_prefixes=("gateway",), + ) + assert len(result) == 2 + assert "gw msg" in result[0] + assert "session msg" in result[1] + + def test_empty_file(self, tmp_path): + log_file = tmp_path / "empty.log" + log_file.write_text("") + result = _read_last_n_lines(log_file, 10) + assert result == [] # --------------------------------------------------------------------------- -# tail_log +# LOG_FILES registry # 
--------------------------------------------------------------------------- -class TestTailLog: - def test_basic_tail(self, sample_agent_log, capsys): - tail_log("agent", num_lines=3) - captured = capsys.readouterr() - assert "agent.log" in captured.out - # Should have the header + 3 lines - lines = captured.out.strip().split("\n") - assert len(lines) == 4 # 1 header + 3 content - - def test_level_filter(self, sample_agent_log, capsys): - tail_log("agent", num_lines=50, level="ERROR") - captured = capsys.readouterr() - assert "level>=ERROR" in captured.out - # Only the ERROR line should appear - content_lines = [l for l in captured.out.strip().split("\n") if not l.startswith("---")] - assert len(content_lines) == 1 - assert "API call failed" in content_lines[0] - - def test_session_filter(self, sample_agent_log, capsys): - tail_log("agent", num_lines=50, session="sess_bbb") - captured = capsys.readouterr() - content_lines = [l for l in captured.out.strip().split("\n") if not l.startswith("---")] - assert len(content_lines) == 1 - assert "sess_bbb" in content_lines[0] - - def test_errors_log(self, sample_errors_log, capsys): - tail_log("errors", num_lines=10) - captured = capsys.readouterr() - assert "errors.log" in captured.out - assert "WARNING" in captured.out or "ERROR" in captured.out - - def test_unknown_log_exits(self): - with pytest.raises(SystemExit): - tail_log("nonexistent") - - def test_missing_file_exits(self, log_dir): - with pytest.raises(SystemExit): - tail_log("agent") # agent.log doesn't exist in clean log_dir - - -# --------------------------------------------------------------------------- -# list_logs -# --------------------------------------------------------------------------- - -class TestListLogs: - def test_lists_files(self, sample_agent_log, sample_errors_log, capsys): - list_logs() - captured = capsys.readouterr() - assert "agent.log" in captured.out - assert "errors.log" in captured.out - - def test_empty_dir(self, log_dir, capsys): - 
list_logs() - captured = capsys.readouterr() - assert "no log files yet" in captured.out - - def test_shows_sizes(self, sample_agent_log, capsys): - list_logs() - captured = capsys.readouterr() - # File is small, should show as bytes or KB - assert "B" in captured.out or "KB" in captured.out +class TestLogFiles: + def test_known_log_files(self): + assert "agent" in LOG_FILES + assert "errors" in LOG_FILES + assert "gateway" in LOG_FILES diff --git a/tests/hermes_cli/test_mcp_config.py b/tests/hermes_cli/test_mcp_config.py index 91a5f988cc..9647a0b95b 100644 --- a/tests/hermes_cli/test_mcp_config.py +++ b/tests/hermes_cli/test_mcp_config.py @@ -46,6 +46,8 @@ def _make_args(**kwargs): "command": None, "args": None, "auth": None, + "preset": None, + "env": None, "mcp_action": None, } defaults.update(kwargs) @@ -269,6 +271,145 @@ class TestMcpAdd: config = load_config() assert config["mcp_servers"]["broken"]["enabled"] is False + def test_add_stdio_server_with_env(self, tmp_path, capsys, monkeypatch): + """Stdio servers can persist explicit environment variables.""" + fake_tools = [FakeTool("search", "Search repos")] + + def mock_probe(name, config, **kw): + assert config["env"] == { + "MY_API_KEY": "secret123", + "DEBUG": "true", + } + return [(t.name, t.description) for t in fake_tools] + + monkeypatch.setattr( + "hermes_cli.mcp_config._probe_single_server", mock_probe + ) + monkeypatch.setattr("builtins.input", lambda _: "") + + from hermes_cli.mcp_config import cmd_mcp_add + + cmd_mcp_add(_make_args( + name="github", + command="npx", + args=["@mcp/github"], + env=["MY_API_KEY=secret123", "DEBUG=true"], + )) + out = capsys.readouterr().out + assert "Saved" in out + + from hermes_cli.config import load_config + + config = load_config() + srv = config["mcp_servers"]["github"] + assert srv["env"] == { + "MY_API_KEY": "secret123", + "DEBUG": "true", + } + + def test_add_stdio_server_rejects_invalid_env_name(self, capsys): + """Invalid environment variable names are 
rejected up front.""" + from hermes_cli.mcp_config import cmd_mcp_add + + cmd_mcp_add(_make_args( + name="github", + command="npx", + args=["@mcp/github"], + env=["BAD-NAME=value"], + )) + out = capsys.readouterr().out + assert "Invalid --env variable name" in out + + def test_add_http_server_rejects_env_flag(self, capsys): + """The --env flag is only valid for stdio transports.""" + from hermes_cli.mcp_config import cmd_mcp_add + + cmd_mcp_add(_make_args( + name="ink", + url="https://mcp.ml.ink/mcp", + env=["DEBUG=true"], + )) + out = capsys.readouterr().out + assert "only supported for stdio MCP servers" in out + + def test_add_preset_fills_transport(self, tmp_path, capsys, monkeypatch): + """A preset fills in command/args when no explicit transport given.""" + monkeypatch.setattr( + "hermes_cli.mcp_config._MCP_PRESETS", + {"testmcp": {"command": "npx", "args": ["-y", "test-mcp-server"], "display_name": "Test MCP"}}, + ) + fake_tools = [FakeTool("do_thing", "Does a thing")] + + def mock_probe(name, config, **kw): + assert name == "myserver" + assert config["command"] == "npx" + assert config["args"] == ["-y", "test-mcp-server"] + assert "env" not in config + return [(t.name, t.description) for t in fake_tools] + + monkeypatch.setattr( + "hermes_cli.mcp_config._probe_single_server", mock_probe + ) + monkeypatch.setattr("builtins.input", lambda _: "") + + from hermes_cli.mcp_config import cmd_mcp_add + from hermes_cli.config import read_raw_config + + cmd_mcp_add(_make_args(name="myserver", preset="testmcp")) + out = capsys.readouterr().out + assert "Saved" in out + + config = read_raw_config() + srv = config["mcp_servers"]["myserver"] + assert srv["command"] == "npx" + assert srv["args"] == ["-y", "test-mcp-server"] + assert "env" not in srv + + def test_preset_does_not_override_explicit_command(self, tmp_path, capsys, monkeypatch): + """Explicit transports win over presets.""" + monkeypatch.setattr( + "hermes_cli.mcp_config._MCP_PRESETS", + {"testmcp": 
{"command": "npx", "args": ["-y", "test-mcp-server"], "display_name": "Test MCP"}}, + ) + fake_tools = [FakeTool("search", "Search repos")] + + def mock_probe(name, config, **kw): + assert config["command"] == "uvx" + assert config["args"] == ["custom-server"] + assert "env" not in config + return [(t.name, t.description) for t in fake_tools] + + monkeypatch.setattr( + "hermes_cli.mcp_config._probe_single_server", mock_probe + ) + monkeypatch.setattr("builtins.input", lambda _: "") + + from hermes_cli.mcp_config import cmd_mcp_add + from hermes_cli.config import read_raw_config + + cmd_mcp_add(_make_args( + name="custom", + preset="testmcp", + command="uvx", + args=["custom-server"], + )) + out = capsys.readouterr().out + assert "Saved" in out + + config = read_raw_config() + srv = config["mcp_servers"]["custom"] + assert srv["command"] == "uvx" + assert srv["args"] == ["custom-server"] + assert "env" not in srv + + def test_unknown_preset_rejected(self, capsys): + """An unknown preset name is rejected with a clear error.""" + from hermes_cli.mcp_config import cmd_mcp_add + + cmd_mcp_add(_make_args(name="foo", preset="nonexistent")) + out = capsys.readouterr().out + assert "Unknown MCP preset" in out + # --------------------------------------------------------------------------- # Tests: cmd_mcp_test diff --git a/tests/hermes_cli/test_skills_hub.py b/tests/hermes_cli/test_skills_hub.py index 0ef6c2d69a..bf9fa71a3a 100644 --- a/tests/hermes_cli/test_skills_hub.py +++ b/tests/hermes_cli/test_skills_hub.py @@ -1,8 +1,10 @@ from io import StringIO +from unittest.mock import patch import pytest from rich.console import Console +from cli import ChatConsole from hermes_cli.skills_hub import do_check, do_install, do_list, do_update, handle_skills_slash @@ -179,6 +181,21 @@ def test_do_update_reinstalls_outdated_skills(monkeypatch): assert "Updated 1 skill" in output +def test_handle_skills_slash_search_accepts_chatconsole_without_status_errors(): + results = [type("R", (), 
{ + "name": "kubernetes", + "description": "Cluster orchestration", + "source": "skills.sh", + "trust_level": "community", + "identifier": "skills-sh/example/kubernetes", + })()] + + with patch("tools.skills_hub.unified_search", return_value=results), \ + patch("tools.skills_hub.create_source_router", return_value={}), \ + patch("tools.skills_hub.GitHubAuth"): + handle_skills_slash("/skills search kubernetes", console=ChatConsole()) + + def test_do_install_scans_with_resolved_identifier(monkeypatch, tmp_path, hub_env): import tools.skills_guard as guard import tools.skills_hub as hub diff --git a/tests/hermes_cli/test_update_gateway_restart.py b/tests/hermes_cli/test_update_gateway_restart.py index ceb05f65c9..822b22742d 100644 --- a/tests/hermes_cli/test_update_gateway_restart.py +++ b/tests/hermes_cli/test_update_gateway_restart.py @@ -191,6 +191,19 @@ class TestLaunchdPlistPath: raise AssertionError("PATH key not found in plist") +class TestLaunchdPlistCurrentness: + def test_launchd_plist_is_current_ignores_path_drift(self, tmp_path, monkeypatch): + plist_path = tmp_path / "ai.hermes.gateway.plist" + monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path) + + monkeypatch.setenv("PATH", "/custom/bin:/usr/bin:/bin") + plist_path.write_text(gateway_cli.generate_launchd_plist(), encoding="utf-8") + + monkeypatch.setenv("PATH", "/opt/homebrew/bin:/usr/local/bin:/usr/bin:/bin") + + assert gateway_cli.launchd_plist_is_current() is True + + # --------------------------------------------------------------------------- # cmd_update — macOS launchd detection # --------------------------------------------------------------------------- @@ -536,7 +549,7 @@ class TestServicePidExclusion: gateway_cli, "_get_service_pids", return_value={SERVICE_PID} ), patch.object( gateway_cli, "find_gateway_pids", - side_effect=lambda exclude_pids=None: ( + side_effect=lambda exclude_pids=None, all_profiles=False: ( [SERVICE_PID] if not exclude_pids else [p for p in 
[SERVICE_PID] if p not in exclude_pids] ), @@ -579,7 +592,7 @@ class TestServicePidExclusion: gateway_cli, "_get_service_pids", return_value={SERVICE_PID} ), patch.object( gateway_cli, "find_gateway_pids", - side_effect=lambda exclude_pids=None: ( + side_effect=lambda exclude_pids=None, all_profiles=False: ( [SERVICE_PID] if not exclude_pids else [p for p in [SERVICE_PID] if p not in exclude_pids] ), @@ -618,7 +631,7 @@ class TestServicePidExclusion: launchctl_loaded=True, ) - def fake_find(exclude_pids=None): + def fake_find(exclude_pids=None, all_profiles=False): _exclude = exclude_pids or set() return [p for p in [SERVICE_PID, MANUAL_PID] if p not in _exclude] @@ -760,3 +773,28 @@ class TestFindGatewayPidsExclude: pids = gateway_cli.find_gateway_pids() assert 100 in pids assert 200 in pids + + def test_filters_to_current_profile(self, monkeypatch, tmp_path): + profile_dir = tmp_path / ".hermes" / "profiles" / "orcha" + profile_dir.mkdir(parents=True) + monkeypatch.setattr(gateway_cli, "is_windows", lambda: False) + monkeypatch.setattr(gateway_cli, "get_hermes_home", lambda: profile_dir) + + def fake_run(cmd, **kwargs): + return subprocess.CompletedProcess( + cmd, 0, + stdout=( + "100 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile orcha gateway run --replace\n" + "200 /Users/dgrieco/.hermes/hermes-agent/venv/bin/python -m hermes_cli.main --profile other gateway run --replace\n" + ), + stderr="", + ) + + monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run) + monkeypatch.setattr("os.getpid", lambda: 999) + monkeypatch.setattr(gateway_cli, "_get_service_pids", lambda: set()) + monkeypatch.setattr(gateway_cli, "_profile_arg", lambda hermes_home=None: "--profile orcha") + + pids = gateway_cli.find_gateway_pids() + + assert pids == [100] diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 61137fe90a..d716b59b27 100644 --- a/tests/run_agent/test_run_agent.py +++ 
b/tests/run_agent/test_run_agent.py @@ -2742,74 +2742,12 @@ class TestSystemPromptStability: assert "Hermes Agent" in agent._cached_system_prompt class TestBudgetPressure: - """Budget pressure warning system (issue #414).""" + """Budget exhaustion grace call system.""" - def test_no_warning_below_caution(self, agent): - agent.max_iterations = 60 - assert agent._get_budget_warning(30) is None - - def test_caution_at_70_percent(self, agent): - agent.max_iterations = 60 - msg = agent._get_budget_warning(42) - assert msg is not None - assert "[BUDGET:" in msg - assert "18 iterations left" in msg - - def test_warning_at_90_percent(self, agent): - agent.max_iterations = 60 - msg = agent._get_budget_warning(54) - assert "[BUDGET WARNING:" in msg - assert "Provide your final response NOW" in msg - - def test_last_iteration(self, agent): - agent.max_iterations = 60 - msg = agent._get_budget_warning(59) - assert "1 iteration(s) left" in msg - - def test_disabled(self, agent): - agent.max_iterations = 60 - agent._budget_pressure_enabled = False - assert agent._get_budget_warning(55) is None - - def test_zero_max_iterations(self, agent): - agent.max_iterations = 0 - assert agent._get_budget_warning(0) is None - - def test_injects_into_json_tool_result(self, agent): - """Warning should be injected as _budget_warning field in JSON tool results.""" - import json - agent.max_iterations = 10 - messages = [ - {"role": "tool", "content": json.dumps({"output": "done", "exit_code": 0}), "tool_call_id": "tc1"} - ] - warning = agent._get_budget_warning(9) - assert warning is not None - # Simulate the injection logic - last_content = messages[-1]["content"] - parsed = json.loads(last_content) - parsed["_budget_warning"] = warning - messages[-1]["content"] = json.dumps(parsed, ensure_ascii=False) - result = json.loads(messages[-1]["content"]) - assert "_budget_warning" in result - assert "BUDGET WARNING" in result["_budget_warning"] - assert result["output"] == "done" # original content 
preserved - - def test_appends_to_non_json_tool_result(self, agent): - """Warning should be appended as text for non-JSON tool results.""" - agent.max_iterations = 10 - messages = [ - {"role": "tool", "content": "plain text result", "tool_call_id": "tc1"} - ] - warning = agent._get_budget_warning(9) - # Simulate injection logic for non-JSON - last_content = messages[-1]["content"] - try: - import json - json.loads(last_content) - except (json.JSONDecodeError, TypeError): - messages[-1]["content"] = last_content + f"\n\n{warning}" - assert "plain text result" in messages[-1]["content"] - assert "BUDGET WARNING" in messages[-1]["content"] + def test_grace_call_flags_initialized(self, agent): + """Agent should have budget grace call flags.""" + assert agent._budget_exhausted_injected is False + assert agent._budget_grace_call is False class TestSafeWriter: diff --git a/tests/run_agent/test_run_agent_codex_responses.py b/tests/run_agent/test_run_agent_codex_responses.py index 6756ed6fde..533a85ac83 100644 --- a/tests/run_agent/test_run_agent_codex_responses.py +++ b/tests/run_agent/test_run_agent_codex_responses.py @@ -744,6 +744,44 @@ def test_normalize_codex_response_marks_commentary_only_message_as_incomplete(mo assert "inspect the repository" in (assistant_message.content or "") +def test_interim_commentary_is_not_marked_already_streamed_without_callbacks(monkeypatch): + agent = _build_agent(monkeypatch) + observed = {} + + agent._fire_stream_delta("short version: yes") + agent.interim_assistant_callback = lambda text, *, already_streamed=False: observed.update( + {"text": text, "already_streamed": already_streamed} + ) + + agent._emit_interim_assistant_message({"role": "assistant", "content": "short version: yes"}) + + assert observed == { + "text": "short version: yes", + "already_streamed": False, + } + + +def test_interim_commentary_is_not_marked_already_streamed_when_stream_callback_fails(monkeypatch): + agent = _build_agent(monkeypatch) + observed = {} + + 
def failing_callback(_text): + raise RuntimeError("display failed") + + agent.stream_delta_callback = failing_callback + agent._fire_stream_delta("short version: yes") + agent.interim_assistant_callback = lambda text, *, already_streamed=False: observed.update( + {"text": text, "already_streamed": already_streamed} + ) + + agent._emit_interim_assistant_message({"role": "assistant", "content": "short version: yes"}) + + assert observed == { + "text": "short version: yes", + "already_streamed": False, + } + + def test_run_conversation_codex_continues_after_commentary_phase_message(monkeypatch): agent = _build_agent(monkeypatch) responses = [ @@ -1104,3 +1142,58 @@ def test_duplicate_detection_distinguishes_different_codex_reasoning(monkeypatch ] assert "enc_first" in encrypted_contents assert "enc_second" in encrypted_contents + + +def test_chat_messages_to_responses_input_deduplicates_reasoning_ids(monkeypatch): + """Duplicate reasoning item IDs across multi-turn incomplete responses + must be deduplicated so the Responses API doesn't reject with HTTP 400.""" + agent = _build_agent(monkeypatch) + messages = [ + {"role": "user", "content": "think hard"}, + { + "role": "assistant", + "content": "", + "codex_reasoning_items": [ + {"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"}, + {"type": "reasoning", "id": "rs_bbb", "encrypted_content": "enc_2"}, + ], + }, + { + "role": "assistant", + "content": "partial answer", + "codex_reasoning_items": [ + # rs_aaa is duplicated from the previous turn + {"type": "reasoning", "id": "rs_aaa", "encrypted_content": "enc_1"}, + {"type": "reasoning", "id": "rs_ccc", "encrypted_content": "enc_3"}, + ], + }, + ] + items = agent._chat_messages_to_responses_input(messages) + + reasoning_ids = [it["id"] for it in items if it.get("type") == "reasoning"] + # rs_aaa should appear only once (first occurrence kept) + assert reasoning_ids.count("rs_aaa") == 1 + # rs_bbb and rs_ccc should each appear once + assert 
reasoning_ids.count("rs_bbb") == 1 + assert reasoning_ids.count("rs_ccc") == 1 + assert len(reasoning_ids) == 3 + + +def test_preflight_codex_input_deduplicates_reasoning_ids(monkeypatch): + """_preflight_codex_input_items should also deduplicate reasoning items by ID.""" + agent = _build_agent(monkeypatch) + raw_input = [ + {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"}, + {"role": "assistant", "content": "ok"}, + {"type": "reasoning", "id": "rs_xyz", "encrypted_content": "enc_a"}, + {"type": "reasoning", "id": "rs_zzz", "encrypted_content": "enc_b"}, + {"role": "assistant", "content": "done"}, + ] + normalized = agent._preflight_codex_input_items(raw_input) + + reasoning_items = [it for it in normalized if it.get("type") == "reasoning"] + reasoning_ids = [it["id"] for it in reasoning_items] + assert reasoning_ids.count("rs_xyz") == 1 + assert reasoning_ids.count("rs_zzz") == 1 + assert len(reasoning_items) == 2 diff --git a/tests/test_hermes_logging.py b/tests/test_hermes_logging.py index 80a23dc688..46969d58d6 100644 --- a/tests/test_hermes_logging.py +++ b/tests/test_hermes_logging.py @@ -3,6 +3,7 @@ import logging import os import stat +import threading from logging.handlers import RotatingFileHandler from pathlib import Path from unittest.mock import patch @@ -34,6 +35,8 @@ def _reset_logging_state(): h.close() else: pre_existing.append(h) + # Ensure the record factory is installed (it's idempotent). + hermes_logging._install_session_record_factory() yield # Restore — remove any handlers added during the test. 
for h in list(root.handlers): @@ -41,6 +44,7 @@ def _reset_logging_state(): root.removeHandler(h) h.close() hermes_logging._logging_initialized = False + hermes_logging.clear_session_context() @pytest.fixture @@ -220,6 +224,294 @@ class TestSetupLogging: ] assert agent_handlers[0].level == logging.WARNING + def test_record_factory_installed(self, hermes_home): + """The custom record factory injects session_tag on all records.""" + hermes_logging.setup_logging(hermes_home=hermes_home) + factory = logging.getLogRecordFactory() + assert getattr(factory, "_hermes_session_injector", False), ( + "Record factory should have _hermes_session_injector marker" + ) + # Verify session_tag exists on a fresh record + record = factory("test", logging.INFO, "", 0, "msg", (), None) + assert hasattr(record, "session_tag") + + +class TestGatewayMode: + """setup_logging(mode='gateway') creates a filtered gateway.log.""" + + def test_gateway_log_created(self, hermes_home): + hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway") + root = logging.getLogger() + + gw_handlers = [ + h for h in root.handlers + if isinstance(h, RotatingFileHandler) + and "gateway.log" in getattr(h, "baseFilename", "") + ] + assert len(gw_handlers) == 1 + + def test_gateway_log_not_created_in_cli_mode(self, hermes_home): + hermes_logging.setup_logging(hermes_home=hermes_home, mode="cli") + root = logging.getLogger() + + gw_handlers = [ + h for h in root.handlers + if isinstance(h, RotatingFileHandler) + and "gateway.log" in getattr(h, "baseFilename", "") + ] + assert len(gw_handlers) == 0 + + def test_gateway_log_receives_gateway_records(self, hermes_home): + """gateway.log captures records from gateway.* loggers.""" + hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway") + + gw_logger = logging.getLogger("gateway.platforms.telegram") + gw_logger.info("telegram connected") + + for h in logging.getLogger().handlers: + h.flush() + + gw_log = hermes_home / "logs" / "gateway.log" + 
assert gw_log.exists() + assert "telegram connected" in gw_log.read_text() + + def test_gateway_log_rejects_non_gateway_records(self, hermes_home): + """gateway.log does NOT capture records from tools.*, agent.*, etc.""" + hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway") + + tool_logger = logging.getLogger("tools.terminal_tool") + tool_logger.info("running command") + + agent_logger = logging.getLogger("agent.context_compressor") + agent_logger.info("compressing context") + + for h in logging.getLogger().handlers: + h.flush() + + gw_log = hermes_home / "logs" / "gateway.log" + if gw_log.exists(): + content = gw_log.read_text() + assert "running command" not in content + assert "compressing context" not in content + + def test_agent_log_still_receives_all(self, hermes_home): + """agent.log (catch-all) still receives gateway AND tool records.""" + hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway") + + logging.getLogger("gateway.run").info("gateway msg") + logging.getLogger("tools.file_tools").info("file msg") + + for h in logging.getLogger().handlers: + h.flush() + + agent_log = hermes_home / "logs" / "agent.log" + content = agent_log.read_text() + assert "gateway msg" in content + assert "file msg" in content + + +class TestSessionContext: + """set_session_context / clear_session_context + _SessionFilter.""" + + def test_session_tag_in_log_output(self, hermes_home): + """When session context is set, log lines include [session_id].""" + hermes_logging.setup_logging(hermes_home=hermes_home) + hermes_logging.set_session_context("abc123") + + test_logger = logging.getLogger("test.session_tag") + test_logger.info("tagged message") + + for h in logging.getLogger().handlers: + h.flush() + + agent_log = hermes_home / "logs" / "agent.log" + content = agent_log.read_text() + assert "[abc123]" in content + assert "tagged message" in content + + def test_no_session_tag_without_context(self, hermes_home): + """Without session context, log 
lines have no session tag.""" + hermes_logging.setup_logging(hermes_home=hermes_home) + hermes_logging.clear_session_context() + + test_logger = logging.getLogger("test.no_session") + test_logger.info("untagged message") + + for h in logging.getLogger().handlers: + h.flush() + + agent_log = hermes_home / "logs" / "agent.log" + content = agent_log.read_text() + assert "untagged message" in content + # Should not have any [xxx] session tag + import re + for line in content.splitlines(): + if "untagged message" in line: + assert not re.search(r"\[.+?\]", line.split("INFO")[1].split("test.no_session")[0]) + + def test_clear_session_context(self, hermes_home): + """After clearing, session tag disappears.""" + hermes_logging.setup_logging(hermes_home=hermes_home) + hermes_logging.set_session_context("xyz789") + hermes_logging.clear_session_context() + + test_logger = logging.getLogger("test.cleared") + test_logger.info("after clear") + + for h in logging.getLogger().handlers: + h.flush() + + agent_log = hermes_home / "logs" / "agent.log" + content = agent_log.read_text() + assert "[xyz789]" not in content + + def test_session_context_thread_isolated(self, hermes_home): + """Session context is per-thread — one thread's context doesn't leak.""" + hermes_logging.setup_logging(hermes_home=hermes_home) + + results = {} + + def thread_a(): + hermes_logging.set_session_context("thread_a_session") + logging.getLogger("test.thread_a").info("from thread A") + for h in logging.getLogger().handlers: + h.flush() + + def thread_b(): + hermes_logging.set_session_context("thread_b_session") + logging.getLogger("test.thread_b").info("from thread B") + for h in logging.getLogger().handlers: + h.flush() + + ta = threading.Thread(target=thread_a) + tb = threading.Thread(target=thread_b) + ta.start() + ta.join() + tb.start() + tb.join() + + agent_log = hermes_home / "logs" / "agent.log" + content = agent_log.read_text() + + # Each thread's message should have its own session tag + for line 
in content.splitlines(): + if "from thread A" in line: + assert "[thread_a_session]" in line + assert "[thread_b_session]" not in line + if "from thread B" in line: + assert "[thread_b_session]" in line + assert "[thread_a_session]" not in line + + +class TestRecordFactory: + """Unit tests for the custom LogRecord factory.""" + + def test_record_has_session_tag(self): + """Every record gets a session_tag attribute.""" + factory = logging.getLogRecordFactory() + record = factory("test", logging.INFO, "", 0, "msg", (), None) + assert hasattr(record, "session_tag") + + def test_empty_tag_without_context(self): + hermes_logging.clear_session_context() + factory = logging.getLogRecordFactory() + record = factory("test", logging.INFO, "", 0, "msg", (), None) + assert record.session_tag == "" + + def test_tag_with_context(self): + hermes_logging.set_session_context("sess_42") + factory = logging.getLogRecordFactory() + record = factory("test", logging.INFO, "", 0, "msg", (), None) + assert record.session_tag == " [sess_42]" + + def test_idempotent_install(self): + """Calling _install_session_record_factory() twice doesn't double-wrap.""" + hermes_logging._install_session_record_factory() + factory_a = logging.getLogRecordFactory() + hermes_logging._install_session_record_factory() + factory_b = logging.getLogRecordFactory() + assert factory_a is factory_b + + def test_works_with_any_handler(self): + """A handler using %(session_tag)s works even without _SessionFilter.""" + hermes_logging.set_session_context("any_handler_test") + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter("%(session_tag)s %(message)s")) + + logger = logging.getLogger("_test_any_handler") + logger.addHandler(handler) + logger.setLevel(logging.DEBUG) + try: + # Should not raise KeyError + logger.info("hello") + finally: + logger.removeHandler(handler) + + +class TestComponentFilter: + """Unit tests for _ComponentFilter.""" + + def test_passes_matching_prefix(self): + f = 
hermes_logging._ComponentFilter(("gateway",)) + record = logging.LogRecord( + "gateway.run", logging.INFO, "", 0, "msg", (), None + ) + assert f.filter(record) is True + + def test_passes_nested_matching_prefix(self): + f = hermes_logging._ComponentFilter(("gateway",)) + record = logging.LogRecord( + "gateway.platforms.telegram", logging.INFO, "", 0, "msg", (), None + ) + assert f.filter(record) is True + + def test_blocks_non_matching(self): + f = hermes_logging._ComponentFilter(("gateway",)) + record = logging.LogRecord( + "tools.terminal_tool", logging.INFO, "", 0, "msg", (), None + ) + assert f.filter(record) is False + + def test_multiple_prefixes(self): + f = hermes_logging._ComponentFilter(("agent", "run_agent", "model_tools")) + assert f.filter(logging.LogRecord( + "agent.compressor", logging.INFO, "", 0, "", (), None + )) + assert f.filter(logging.LogRecord( + "run_agent", logging.INFO, "", 0, "", (), None + )) + assert f.filter(logging.LogRecord( + "model_tools", logging.INFO, "", 0, "", (), None + )) + assert not f.filter(logging.LogRecord( + "tools.browser", logging.INFO, "", 0, "", (), None + )) + + +class TestComponentPrefixes: + """COMPONENT_PREFIXES covers the expected components.""" + + def test_gateway_prefix(self): + assert "gateway" in hermes_logging.COMPONENT_PREFIXES + assert ("gateway",) == hermes_logging.COMPONENT_PREFIXES["gateway"] + + def test_agent_prefix(self): + prefixes = hermes_logging.COMPONENT_PREFIXES["agent"] + assert "agent" in prefixes + assert "run_agent" in prefixes + assert "model_tools" in prefixes + + def test_tools_prefix(self): + assert ("tools",) == hermes_logging.COMPONENT_PREFIXES["tools"] + + def test_cli_prefix(self): + prefixes = hermes_logging.COMPONENT_PREFIXES["cli"] + assert "hermes_cli" in prefixes + assert "cli" in prefixes + + def test_cron_prefix(self): + assert ("cron",) == hermes_logging.COMPONENT_PREFIXES["cron"] + class TestSetupVerboseLogging: """setup_verbose_logging() adds a DEBUG-level console 
handler.""" @@ -301,6 +593,59 @@ class TestAddRotatingHandler: logger.removeHandler(h) h.close() + def test_log_filter_attached(self, tmp_path): + """Optional log_filter is attached to the handler.""" + log_path = tmp_path / "filtered.log" + logger = logging.getLogger("_test_rotating_filter") + formatter = logging.Formatter("%(message)s") + component_filter = hermes_logging._ComponentFilter(("test",)) + + hermes_logging._add_rotating_handler( + logger, log_path, + level=logging.INFO, max_bytes=1024, backup_count=1, + formatter=formatter, + log_filter=component_filter, + ) + + handlers = [h for h in logger.handlers if isinstance(h, RotatingFileHandler)] + assert len(handlers) == 1 + assert component_filter in handlers[0].filters + # Clean up + for h in list(logger.handlers): + if isinstance(h, RotatingFileHandler): + logger.removeHandler(h) + h.close() + + def test_no_session_filter_on_handler(self, tmp_path): + """Handlers rely on record factory, not per-handler _SessionFilter.""" + log_path = tmp_path / "no_session_filter.log" + logger = logging.getLogger("_test_no_session_filter") + formatter = logging.Formatter("%(session_tag)s%(message)s") + + hermes_logging._add_rotating_handler( + logger, log_path, + level=logging.INFO, max_bytes=1024, backup_count=1, + formatter=formatter, + ) + + handlers = [h for h in logger.handlers if isinstance(h, RotatingFileHandler)] + assert len(handlers) == 1 + # No _SessionFilter on the handler — record factory handles it + assert len(handlers[0].filters) == 0 + + # But session_tag still works (via record factory) + hermes_logging.set_session_context("factory_test") + logger.info("test msg") + handlers[0].flush() + content = log_path.read_text() + assert "[factory_test]" in content + + # Clean up + for h in list(logger.handlers): + if isinstance(h, RotatingFileHandler): + logger.removeHandler(h) + h.close() + def test_managed_mode_initial_open_sets_group_writable(self, tmp_path): log_path = tmp_path / "managed-open.log" logger = 
logging.getLogger("_test_rotating_managed_open") diff --git a/tests/tools/test_browser_camofox_state.py b/tests/tools/test_browser_camofox_state.py index b1f128ccee..33a939f094 100644 --- a/tests/tools/test_browser_camofox_state.py +++ b/tests/tools/test_browser_camofox_state.py @@ -59,8 +59,9 @@ class TestCamofoxConfigDefaults: browser_cfg = DEFAULT_CONFIG["browser"] assert browser_cfg["camofox"]["managed_persistence"] is False - def test_config_version_unchanged(self): + def test_config_version_matches_current_schema(self): from hermes_cli.config import DEFAULT_CONFIG - # managed_persistence is auto-merged by _deep_merge, no version bump needed - assert DEFAULT_CONFIG["_config_version"] == 13 + # The current schema version is tracked globally; unrelated default + # options may bump it after browser defaults are added. + assert DEFAULT_CONFIG["_config_version"] == 15 diff --git a/tests/tools/test_checkpoint_manager.py b/tests/tools/test_checkpoint_manager.py index ef843465f1..ba9da6da1f 100644 --- a/tests/tools/test_checkpoint_manager.py +++ b/tests/tools/test_checkpoint_manager.py @@ -1,9 +1,6 @@ """Tests for tools/checkpoint_manager.py — CheckpointManager.""" import logging -import os -import json -import shutil import subprocess import pytest from pathlib import Path @@ -42,6 +39,19 @@ def checkpoint_base(tmp_path): return tmp_path / "checkpoints" +@pytest.fixture() +def fake_home(tmp_path, monkeypatch): + """Set a deterministic fake home for expanduser/path-home behavior.""" + home = tmp_path / "home" + home.mkdir() + monkeypatch.setenv("HOME", str(home)) + monkeypatch.setenv("USERPROFILE", str(home)) + monkeypatch.delenv("HOMEDRIVE", raising=False) + monkeypatch.delenv("HOMEPATH", raising=False) + monkeypatch.setattr(Path, "home", classmethod(lambda cls: home)) + return home + + @pytest.fixture() def mgr(work_dir, checkpoint_base, monkeypatch): """CheckpointManager with redirected checkpoint base.""" @@ -78,6 +88,16 @@ class TestShadowRepoPath: p = 
_shadow_repo_path(str(work_dir)) assert str(p).startswith(str(checkpoint_base)) + def test_tilde_and_expanded_home_share_shadow_repo(self, fake_home, checkpoint_base, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + project = fake_home / "project" + project.mkdir() + + tilde_path = f"~/{project.name}" + expanded_path = str(project) + + assert _shadow_repo_path(tilde_path) == _shadow_repo_path(expanded_path) + # ========================================================================= # Shadow repo init @@ -221,6 +241,20 @@ class TestListCheckpoints: assert result[0]["reason"] == "third" assert result[2]["reason"] == "first" + def test_tilde_path_lists_same_checkpoints_as_expanded_path(self, checkpoint_base, fake_home, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + mgr = CheckpointManager(enabled=True, max_snapshots=50) + project = fake_home / "project" + project.mkdir() + (project / "main.py").write_text("v1\n") + + tilde_path = f"~/{project.name}" + assert mgr.ensure_checkpoint(tilde_path, "initial") is True + + listed = mgr.list_checkpoints(str(project)) + assert len(listed) == 1 + assert listed[0]["reason"] == "initial" + # ========================================================================= # CheckpointManager — restoring @@ -271,6 +305,28 @@ class TestRestore: assert len(all_cps) >= 2 assert "pre-rollback" in all_cps[0]["reason"] + def test_tilde_path_supports_diff_and_restore_flow(self, checkpoint_base, fake_home, monkeypatch): + monkeypatch.setattr("tools.checkpoint_manager.CHECKPOINT_BASE", checkpoint_base) + mgr = CheckpointManager(enabled=True, max_snapshots=50) + project = fake_home / "project" + project.mkdir() + file_path = project / "main.py" + file_path.write_text("original\n") + + tilde_path = f"~/{project.name}" + assert mgr.ensure_checkpoint(tilde_path, "initial") is True + mgr.new_turn() + + file_path.write_text("changed\n") + checkpoints = 
mgr.list_checkpoints(str(project)) + diff_result = mgr.diff(tilde_path, checkpoints[0]["hash"]) + assert diff_result["success"] is True + assert "main.py" in diff_result["diff"] + + restore_result = mgr.restore(tilde_path, checkpoints[0]["hash"]) + assert restore_result["success"] is True + assert file_path.read_text() == "original\n" + # ========================================================================= # CheckpointManager — working dir resolution @@ -310,6 +366,19 @@ class TestWorkingDirResolution: result = mgr.get_working_dir_for_path(str(filepath)) assert result == str(filepath.parent) + def test_resolves_tilde_path_to_project_root(self, fake_home): + mgr = CheckpointManager(enabled=True) + project = fake_home / "myproject" + project.mkdir() + (project / "pyproject.toml").write_text("[project]\n") + subdir = project / "src" + subdir.mkdir() + filepath = subdir / "main.py" + filepath.write_text("x\n") + + result = mgr.get_working_dir_for_path(f"~/{project.name}/src/main.py") + assert result == str(project) + # ========================================================================= # Git env isolation @@ -333,6 +402,14 @@ class TestGitEnvIsolation: env = _git_env(shadow, str(tmp_path)) assert "GIT_INDEX_FILE" not in env + def test_expands_tilde_in_work_tree(self, fake_home, tmp_path): + shadow = tmp_path / "shadow" + work = fake_home / "work" + work.mkdir() + + env = _git_env(shadow, f"~/{work.name}") + assert env["GIT_WORK_TREE"] == str(work.resolve()) + # ========================================================================= # format_checkpoint_list @@ -384,6 +461,8 @@ class TestErrorResilience: assert result is False def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog): + work = tmp_path / "work" + work.mkdir() completed = subprocess.CompletedProcess( args=["git", "diff", "--cached", "--quiet"], returncode=1, @@ -395,7 +474,7 @@ class TestErrorResilience: ok, stdout, stderr = _run_git( ["diff", "--cached", "--quiet"], 
tmp_path / "shadow", - str(tmp_path / "work"), + str(work), allowed_returncodes={1}, ) assert ok is False @@ -403,6 +482,38 @@ class TestErrorResilience: assert stderr == "" assert not caplog.records + def test_run_git_invalid_working_dir_reports_path_error(self, tmp_path, caplog): + missing = tmp_path / "missing" + with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"): + ok, stdout, stderr = _run_git( + ["status"], + tmp_path / "shadow", + str(missing), + ) + assert ok is False + assert stdout == "" + assert "working directory not found" in stderr + assert not any("Git executable not found" in r.getMessage() for r in caplog.records) + + def test_run_git_missing_git_reports_git_not_found(self, tmp_path, monkeypatch, caplog): + work = tmp_path / "work" + work.mkdir() + + def raise_missing_git(*args, **kwargs): + raise FileNotFoundError(2, "No such file or directory", "git") + + monkeypatch.setattr("tools.checkpoint_manager.subprocess.run", raise_missing_git) + with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"): + ok, stdout, stderr = _run_git( + ["status"], + tmp_path / "shadow", + str(work), + ) + assert ok is False + assert stdout == "" + assert stderr == "git not found" + assert any("Git executable not found" in r.getMessage() for r in caplog.records) + def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch): """Checkpoint failures should never raise — they're silently logged.""" def broken_run_git(*args, **kwargs): @@ -411,3 +522,68 @@ class TestErrorResilience: # Should not raise result = mgr.ensure_checkpoint(str(work_dir), "test") assert result is False + + +# ========================================================================= +# Security / Input validation +# ========================================================================= + +class TestSecurity: + def test_restore_rejects_argument_injection(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Try to pass a git 
flag as a commit hash + result = mgr.restore(str(work_dir), "--patch") + assert result["success"] is False + assert "Invalid commit hash" in result["error"] + assert "must not start with '-'" in result["error"] + + result = mgr.restore(str(work_dir), "-p") + assert result["success"] is False + assert "Invalid commit hash" in result["error"] + + def test_restore_rejects_invalid_hex_chars(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Git hashes should not contain characters like ;, &, | + result = mgr.restore(str(work_dir), "abc; rm -rf /") + assert result["success"] is False + assert "expected 4-64 hex characters" in result["error"] + + result = mgr.diff(str(work_dir), "abc&def") + assert result["success"] is False + assert "expected 4-64 hex characters" in result["error"] + + def test_restore_rejects_path_traversal(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + # Real commit hash but malicious path + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + # Absolute path outside + result = mgr.restore(str(work_dir), target_hash, file_path="/etc/passwd") + assert result["success"] is False + assert "got absolute path" in result["error"] + + # Relative traversal outside path + result = mgr.restore(str(work_dir), target_hash, file_path="../outside_file.txt") + assert result["success"] is False + assert "escapes the working directory" in result["error"] + + def test_restore_accepts_valid_file_path(self, mgr, work_dir): + mgr.ensure_checkpoint(str(work_dir), "initial") + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + # Valid path inside directory + result = mgr.restore(str(work_dir), target_hash, file_path="main.py") + assert result["success"] is True + + # Another valid path with subdirectories + (work_dir / "subdir").mkdir() + (work_dir / "subdir" / "test.txt").write_text("hello") + mgr.new_turn() + 
mgr.ensure_checkpoint(str(work_dir), "second") + checkpoints = mgr.list_checkpoints(str(work_dir)) + target_hash = checkpoints[0]["hash"] + + result = mgr.restore(str(work_dir), target_hash, file_path="subdir/test.txt") + assert result["success"] is True diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py index e015e5d42b..a269218c2a 100644 --- a/tests/tools/test_code_execution.py +++ b/tests/tools/test_code_execution.py @@ -380,7 +380,7 @@ class TestStubSchemaDrift(unittest.TestCase): # Parameters that are internal (injected by the handler, not user-facing) _INTERNAL_PARAMS = {"task_id", "user_task"} # Parameters intentionally blocked in the sandbox - _BLOCKED_TERMINAL_PARAMS = {"background", "check_interval", "pty", "notify_on_complete"} + _BLOCKED_TERMINAL_PARAMS = {"background", "pty", "notify_on_complete"} def test_stubs_cover_all_schema_params(self): """Every user-facing parameter in the real schema must appear in the diff --git a/tests/tools/test_modal_bulk_upload.py b/tests/tools/test_modal_bulk_upload.py new file mode 100644 index 0000000000..e179e702aa --- /dev/null +++ b/tests/tools/test_modal_bulk_upload.py @@ -0,0 +1,295 @@ +"""Tests for Modal bulk upload via tar/base64 archive.""" + +import asyncio +import base64 +import io +import tarfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from tools.environments import modal as modal_env + + +def _make_mock_modal_env(monkeypatch, tmp_path): + """Create a minimal mock ModalEnvironment for testing upload methods. + + Returns a ModalEnvironment-like object with _sandbox and _worker mocked. + We don't call __init__ because it requires the Modal SDK. 
+ """ + env = object.__new__(modal_env.ModalEnvironment) + env._sandbox = MagicMock() + env._worker = MagicMock() + env._persistent = False + env._task_id = "test" + env._sync_manager = None + return env + + +def _make_mock_stdin(): + """Create a mock stdin that captures written data.""" + stdin = MagicMock() + written_chunks = [] + + def mock_write(data): + written_chunks.append(data) + + stdin.write = mock_write + stdin.write_eof = MagicMock() + stdin.drain = MagicMock() + stdin.drain.aio = AsyncMock() + stdin._written_chunks = written_chunks + return stdin + + +def _wire_async_exec(env, exec_calls=None): + """Wire mock sandbox.exec.aio and a real run_coroutine on the env. + + Optionally captures exec call args into *exec_calls* list. + Returns (exec_calls, run_kwargs, stdin_mock). + """ + if exec_calls is None: + exec_calls = [] + run_kwargs: dict = {} + stdin_mock = _make_mock_stdin() + + async def mock_exec_fn(*args, **kwargs): + exec_calls.append(args) + proc = MagicMock() + proc.wait = MagicMock() + proc.wait.aio = AsyncMock(return_value=0) + proc.stdin = stdin_mock + proc.stderr = MagicMock() + proc.stderr.read = MagicMock() + proc.stderr.read.aio = AsyncMock(return_value="") + return proc + + env._sandbox.exec = MagicMock() + env._sandbox.exec.aio = mock_exec_fn + + def real_run_coroutine(coro, **kwargs): + run_kwargs.update(kwargs) + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + env._worker.run_coroutine = real_run_coroutine + return exec_calls, run_kwargs, stdin_mock + + +class TestModalBulkUpload: + """Test _modal_bulk_upload method.""" + + def test_empty_files_is_noop(self, monkeypatch, tmp_path): + """Empty file list should not call worker.run_coroutine.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + env._modal_bulk_upload([]) + env._worker.run_coroutine.assert_not_called() + + def test_tar_archive_contains_all_files(self, monkeypatch, tmp_path): + """The tar archive sent via 
stdin should contain all files.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src_a = tmp_path / "a.json" + src_b = tmp_path / "b.py" + src_a.write_text("cred_content") + src_b.write_text("skill_content") + + files = [ + (str(src_a), "/root/.hermes/credentials/a.json"), + (str(src_b), "/root/.hermes/skills/b.py"), + ] + + exec_calls, _, stdin_mock = _wire_async_exec(env) + env._modal_bulk_upload(files) + + # Verify the command reads from stdin (no echo with embedded payload) + assert len(exec_calls) == 1 + args = exec_calls[0] + assert args[0] == "bash" + assert args[1] == "-c" + cmd = args[2] + assert "mkdir -p" in cmd + assert "base64 -d" in cmd + assert "tar xzf" in cmd + assert "-C /" in cmd + + # Reassemble the base64 payload from stdin chunks and verify tar contents + payload = "".join(stdin_mock._written_chunks) + tar_data = base64.b64decode(payload) + buf = io.BytesIO(tar_data) + with tarfile.open(fileobj=buf, mode="r:gz") as tar: + names = sorted(tar.getnames()) + assert "root/.hermes/credentials/a.json" in names + assert "root/.hermes/skills/b.py" in names + + # Verify content + a_content = tar.extractfile("root/.hermes/credentials/a.json").read() + assert a_content == b"cred_content" + b_content = tar.extractfile("root/.hermes/skills/b.py").read() + assert b_content == b"skill_content" + + # Verify stdin was closed + stdin_mock.write_eof.assert_called_once() + + def test_mkdir_includes_all_parents(self, monkeypatch, tmp_path): + """Remote parent directories should be pre-created in the command.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src = tmp_path / "f.txt" + src.write_text("data") + + files = [ + (str(src), "/root/.hermes/credentials/f.txt"), + (str(src), "/root/.hermes/skills/deep/nested/f.txt"), + ] + + exec_calls, _, _ = _wire_async_exec(env) + env._modal_bulk_upload(files) + + cmd = exec_calls[0][2] + assert "/root/.hermes/credentials" in cmd + assert "/root/.hermes/skills/deep/nested" in cmd + + def 
test_single_exec_call(self, monkeypatch, tmp_path): + """Bulk upload should use exactly one exec call regardless of file count.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + files = [] + for i in range(20): + src = tmp_path / f"file_{i}.txt" + src.write_text(f"content_{i}") + files.append((str(src), f"/root/.hermes/cache/file_{i}.txt")) + + exec_calls, _, _ = _wire_async_exec(env) + env._modal_bulk_upload(files) + + # Should be exactly 1 exec call, not 20 + assert len(exec_calls) == 1 + + def test_bulk_upload_wired_in_filesyncmanager(self, monkeypatch): + """Verify ModalEnvironment passes bulk_upload_fn to FileSyncManager.""" + captured_kwargs = {} + + def capture_fsm(**kwargs): + captured_kwargs.update(kwargs) + return type("M", (), {"sync": lambda self, **k: None})() + + monkeypatch.setattr(modal_env, "FileSyncManager", capture_fsm) + + # Create a minimal env without full __init__ + env = object.__new__(modal_env.ModalEnvironment) + env._sandbox = MagicMock() + env._worker = MagicMock() + env._persistent = False + env._task_id = "test" + + # Manually call the part of __init__ that wires FileSyncManager + from tools.environments.file_sync import iter_sync_files + env._sync_manager = modal_env.FileSyncManager( + get_files_fn=lambda: iter_sync_files("/root/.hermes"), + upload_fn=env._modal_upload, + delete_fn=env._modal_delete, + bulk_upload_fn=env._modal_bulk_upload, + ) + + assert "bulk_upload_fn" in captured_kwargs + assert captured_kwargs["bulk_upload_fn"] is not None + assert callable(captured_kwargs["bulk_upload_fn"]) + + def test_timeout_set_to_120(self, monkeypatch, tmp_path): + """Bulk upload uses a 120s timeout (not the per-file 15s).""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src = tmp_path / "f.txt" + src.write_text("data") + files = [(str(src), "/root/.hermes/f.txt")] + + _, run_kwargs, _ = _wire_async_exec(env) + env._modal_bulk_upload(files) + + assert run_kwargs.get("timeout") == 120 + + def test_nonzero_exit_raises(self, 
monkeypatch, tmp_path): + """Non-zero exit code from remote exec should raise RuntimeError.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src = tmp_path / "f.txt" + src.write_text("data") + files = [(str(src), "/root/.hermes/f.txt")] + + stdin_mock = _make_mock_stdin() + + async def mock_exec_fn(*args, **kwargs): + proc = MagicMock() + proc.wait = MagicMock() + proc.wait.aio = AsyncMock(return_value=1) # non-zero exit + proc.stdin = stdin_mock + proc.stderr = MagicMock() + proc.stderr.read = MagicMock() + proc.stderr.read.aio = AsyncMock(return_value="tar: error") + return proc + + env._sandbox.exec = MagicMock() + env._sandbox.exec.aio = mock_exec_fn + + def real_run_coroutine(coro, **kwargs): + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + env._worker.run_coroutine = real_run_coroutine + + with pytest.raises(RuntimeError, match="Modal bulk upload failed"): + env._modal_bulk_upload(files) + + def test_payload_not_in_command_string(self, monkeypatch, tmp_path): + """The base64 payload must NOT appear in the bash -c argument. + + This is the core ARG_MAX fix: the payload goes through stdin, + not embedded in the command string. 
+ """ + env = _make_mock_modal_env(monkeypatch, tmp_path) + + src = tmp_path / "f.txt" + src.write_text("some data to upload") + files = [(str(src), "/root/.hermes/f.txt")] + + exec_calls, _, stdin_mock = _wire_async_exec(env) + env._modal_bulk_upload(files) + + # The command should NOT contain an echo with the payload + cmd = exec_calls[0][2] + assert "echo" not in cmd + # The payload should go through stdin + assert len(stdin_mock._written_chunks) > 0 + + def test_stdin_chunked_for_large_payloads(self, monkeypatch, tmp_path): + """Payloads larger than _STDIN_CHUNK_SIZE should be split into multiple writes.""" + env = _make_mock_modal_env(monkeypatch, tmp_path) + + # Use random bytes so gzip cannot compress them -- ensures the + # base64 payload exceeds one 1 MB chunk. + import os as _os + src = tmp_path / "large.bin" + src.write_bytes(_os.urandom(1024 * 1024 + 512 * 1024)) + files = [(str(src), "/root/.hermes/large.bin")] + + exec_calls, _, stdin_mock = _wire_async_exec(env) + env._modal_bulk_upload(files) + + # Should have multiple stdin write chunks + assert len(stdin_mock._written_chunks) >= 2 + + # Reassembled payload should still decode to valid tar + payload = "".join(stdin_mock._written_chunks) + tar_data = base64.b64decode(payload) + buf = io.BytesIO(tar_data) + with tarfile.open(fileobj=buf, mode="r:gz") as tar: + names = tar.getnames() + assert "root/.hermes/large.bin" in names diff --git a/tests/tools/test_ssh_bulk_upload.py b/tests/tools/test_ssh_bulk_upload.py new file mode 100644 index 0000000000..97cb39f53c --- /dev/null +++ b/tests/tools/test_ssh_bulk_upload.py @@ -0,0 +1,517 @@ +"""Tests for SSH bulk upload via tar pipe.""" + +import os +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from tools.environments import ssh as ssh_env +from tools.environments.file_sync import quoted_mkdir_command, unique_parent_dirs +from tools.environments.ssh import SSHEnvironment + + +def _mock_proc(*, 
returncode=0, poll_return=0, communicate_return=(b"", b""), + stderr_read=b""): + """Create a MagicMock mimicking subprocess.Popen for tar/ssh pipes.""" + m = MagicMock() + m.stdout = MagicMock() + m.returncode = returncode + m.poll.return_value = poll_return + m.communicate.return_value = communicate_return + m.stderr = MagicMock() + m.stderr.read.return_value = stderr_read + return m + + +@pytest.fixture +def mock_env(monkeypatch): + """Create an SSHEnvironment with mocked connection/sync.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/testuser") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + monkeypatch.setattr( + ssh_env, "FileSyncManager", + lambda **kw: type("M", (), {"sync": lambda self, **k: None})(), + ) + return SSHEnvironment(host="example.com", user="testuser") + + +class TestSSHBulkUpload: + """Unit tests for _ssh_bulk_upload — tar pipe mechanics.""" + + def test_empty_files_is_noop(self, mock_env): + """Empty file list should not spawn any subprocesses.""" + with patch.object(subprocess, "run") as mock_run, \ + patch.object(subprocess, "Popen") as mock_popen: + mock_env._ssh_bulk_upload([]) + mock_run.assert_not_called() + mock_popen.assert_not_called() + + def test_mkdir_batched_into_single_call(self, mock_env, tmp_path): + """All parent directories should be created in one SSH call.""" + # Create test files + f1 = tmp_path / "a.txt" + f1.write_text("aaa") + f2 = tmp_path / "b.txt" + f2.write_text("bbb") + + files = [ + (str(f1), "/home/testuser/.hermes/skills/a.txt"), + (str(f2), "/home/testuser/.hermes/credentials/b.txt"), + ] + + # Mock subprocess.run for mkdir and Popen for tar pipe + mock_run = 
MagicMock(return_value=subprocess.CompletedProcess([], 0)) + + def make_proc(cmd, **kwargs): + m = MagicMock() + m.stdout = MagicMock() + m.returncode = 0 + m.poll.return_value = 0 + m.communicate.return_value = (b"", b"") + m.stderr = MagicMock() + m.stderr.read.return_value = b"" + return m + + with patch.object(subprocess, "run", mock_run), \ + patch.object(subprocess, "Popen", side_effect=make_proc): + mock_env._ssh_bulk_upload(files) + + # Exactly one subprocess.run call for mkdir + assert mock_run.call_count == 1 + mkdir_cmd = mock_run.call_args[0][0] + # Should contain mkdir -p with both parent dirs + mkdir_str = " ".join(mkdir_cmd) + assert "mkdir -p" in mkdir_str + assert "/home/testuser/.hermes/skills" in mkdir_str + assert "/home/testuser/.hermes/credentials" in mkdir_str + + def test_staging_symlinks_mirror_remote_layout(self, mock_env, tmp_path): + """Symlinks in staging dir should mirror the remote path structure.""" + f1 = tmp_path / "local_a.txt" + f1.write_text("content a") + + files = [ + (str(f1), "/home/testuser/.hermes/skills/my_skill.md"), + ] + + staging_paths = [] + + def capture_tar_cmd(cmd, **kwargs): + if cmd[0] == "tar": + # Capture the staging dir from -C argument + c_idx = cmd.index("-C") + staging_dir = cmd[c_idx + 1] + # Check the symlink exists + expected = os.path.join( + staging_dir, "home/testuser/.hermes/skills/my_skill.md" + ) + staging_paths.append(expected) + assert os.path.islink(expected), f"Expected symlink at {expected}" + assert os.readlink(expected) == os.path.abspath(str(f1)) + + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=capture_tar_cmd): + mock_env._ssh_bulk_upload(files) + + assert 
len(staging_paths) == 1, "tar command should have been called" + + def test_tar_pipe_commands(self, mock_env, tmp_path): + """Verify tar and SSH commands are wired correctly.""" + f1 = tmp_path / "x.txt" + f1.write_text("x") + + files = [(str(f1), "/home/testuser/.hermes/cache/x.txt")] + + popen_cmds = [] + + def capture_popen(cmd, **kwargs): + popen_cmds.append(cmd) + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=capture_popen): + mock_env._ssh_bulk_upload(files) + + assert len(popen_cmds) == 2, "Should spawn tar + ssh processes" + + tar_cmd = popen_cmds[0] + ssh_cmd = popen_cmds[1] + + # tar: create, dereference symlinks, to stdout + assert tar_cmd[0] == "tar" + assert "-chf" in tar_cmd + assert "-" in tar_cmd # stdout + assert "-C" in tar_cmd + + # ssh: extract from stdin at / + ssh_str = " ".join(ssh_cmd) + assert "ssh" in ssh_str + assert "tar xf - -C /" in ssh_str + assert "testuser@example.com" in ssh_str + + def test_mkdir_failure_raises(self, mock_env, tmp_path): + """mkdir failure should raise RuntimeError before tar pipe.""" + f1 = tmp_path / "y.txt" + f1.write_text("y") + files = [(str(f1), "/home/testuser/.hermes/skills/y.txt")] + + failed_run = subprocess.CompletedProcess([], 1, stderr="Permission denied") + with patch.object(subprocess, "run", return_value=failed_run): + with pytest.raises(RuntimeError, match="remote mkdir failed"): + mock_env._ssh_bulk_upload(files) + + def test_tar_create_failure_raises(self, mock_env, tmp_path): + """tar create failure should raise RuntimeError.""" + f1 = tmp_path / "z.txt" + f1.write_text("z") + files = [(str(f1), "/home/testuser/.hermes/skills/z.txt")] + + mock_tar = MagicMock() + mock_tar.stdout = 
MagicMock() + mock_tar.returncode = 1 + mock_tar.poll.return_value = 1 + mock_tar.communicate.return_value = (b"tar: error", b"") + mock_tar.stderr = MagicMock() + mock_tar.stderr.read.return_value = b"tar: error" + + mock_ssh = MagicMock() + mock_ssh.communicate.return_value = (b"", b"") + mock_ssh.returncode = 0 + + def popen_side_effect(cmd, **kwargs): + if cmd[0] == "tar": + return mock_tar + return mock_ssh + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=popen_side_effect): + with pytest.raises(RuntimeError, match="tar create failed"): + mock_env._ssh_bulk_upload(files) + + def test_ssh_extract_failure_raises(self, mock_env, tmp_path): + """SSH tar extract failure should raise RuntimeError.""" + f1 = tmp_path / "w.txt" + f1.write_text("w") + files = [(str(f1), "/home/testuser/.hermes/skills/w.txt")] + + mock_tar = MagicMock() + mock_tar.stdout = MagicMock() + mock_tar.returncode = 0 + mock_tar.poll.return_value = 0 + mock_tar.communicate.return_value = (b"", b"") + mock_tar.stderr = MagicMock() + mock_tar.stderr.read.return_value = b"" + + mock_ssh = MagicMock() + mock_ssh.communicate.return_value = (b"", b"Permission denied") + mock_ssh.returncode = 1 + + def popen_side_effect(cmd, **kwargs): + if cmd[0] == "tar": + return mock_tar + return mock_ssh + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=popen_side_effect): + with pytest.raises(RuntimeError, match="tar extract over SSH failed"): + mock_env._ssh_bulk_upload(files) + + def test_ssh_command_uses_control_socket(self, mock_env, tmp_path): + """SSH command for tar extract should reuse ControlMaster socket.""" + f1 = tmp_path / "c.txt" + f1.write_text("c") + files = [(str(f1), "/home/testuser/.hermes/cache/c.txt")] + + popen_cmds = [] + + def capture_popen(cmd, **kwargs): + popen_cmds.append(cmd) + mock = MagicMock() + 
mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=capture_popen): + mock_env._ssh_bulk_upload(files) + + # The SSH command (second Popen call) should include ControlPath + ssh_cmd = popen_cmds[1] + assert f"ControlPath={mock_env.control_socket}" in " ".join(ssh_cmd) + + def test_custom_port_and_key_in_ssh_command(self, monkeypatch, tmp_path): + """Bulk upload SSH command should include custom port and key.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/home/u") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + monkeypatch.setattr( + ssh_env, "FileSyncManager", + lambda **kw: type("M", (), {"sync": lambda self, **k: None})(), + ) + env = SSHEnvironment(host="h", user="u", port=2222, key_path="/my/key") + + f1 = tmp_path / "d.txt" + f1.write_text("d") + files = [(str(f1), "/home/u/.hermes/skills/d.txt")] + + run_cmds = [] + popen_cmds = [] + + def capture_run(cmd, **kwargs): + run_cmds.append(cmd) + return subprocess.CompletedProcess([], 0) + + def capture_popen(cmd, **kwargs): + popen_cmds.append(cmd) + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", side_effect=capture_run), \ + patch.object(subprocess, "Popen", side_effect=capture_popen): + 
env._ssh_bulk_upload(files) + + # Check mkdir SSH call includes port and key + assert len(run_cmds) == 1 + mkdir_cmd = run_cmds[0] + assert "-p" in mkdir_cmd and "2222" in mkdir_cmd + assert "-i" in mkdir_cmd and "/my/key" in mkdir_cmd + + # Check tar extract SSH call includes port and key + ssh_cmd = popen_cmds[1] + assert "-p" in ssh_cmd and "2222" in ssh_cmd + assert "-i" in ssh_cmd and "/my/key" in ssh_cmd + + def test_parent_dirs_deduplicated(self, mock_env, tmp_path): + """Multiple files in the same dir should produce one mkdir entry.""" + f1 = tmp_path / "a.txt" + f1.write_text("a") + f2 = tmp_path / "b.txt" + f2.write_text("b") + f3 = tmp_path / "c.txt" + f3.write_text("c") + + files = [ + (str(f1), "/home/testuser/.hermes/skills/a.txt"), + (str(f2), "/home/testuser/.hermes/skills/b.txt"), + (str(f3), "/home/testuser/.hermes/credentials/c.txt"), + ] + + run_cmds = [] + + def capture_run(cmd, **kwargs): + run_cmds.append(cmd) + return subprocess.CompletedProcess([], 0) + + def make_mock_proc(cmd, **kwargs): + mock = MagicMock() + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", side_effect=capture_run), \ + patch.object(subprocess, "Popen", side_effect=make_mock_proc): + mock_env._ssh_bulk_upload(files) + + # Only one mkdir call + assert len(run_cmds) == 1 + mkdir_str = " ".join(run_cmds[0]) + # skills dir should appear exactly once despite two files + assert mkdir_str.count("/home/testuser/.hermes/skills") == 1 + assert "/home/testuser/.hermes/credentials" in mkdir_str + + def test_tar_stdout_closed_for_sigpipe(self, mock_env, tmp_path): + """tar_proc.stdout must be closed so SIGPIPE propagates correctly.""" + f1 = tmp_path / "s.txt" + f1.write_text("s") + files = [(str(f1), "/home/testuser/.hermes/skills/s.txt")] + + mock_tar_stdout = MagicMock() + + def 
make_proc(cmd, **kwargs): + mock = MagicMock() + if cmd[0] == "tar": + mock.stdout = mock_tar_stdout + else: + mock.stdout = MagicMock() + mock.returncode = 0 + mock.poll.return_value = 0 + mock.communicate.return_value = (b"", b"") + mock.stderr = MagicMock() + mock.stderr.read.return_value = b"" + return mock + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=make_proc): + mock_env._ssh_bulk_upload(files) + + mock_tar_stdout.close.assert_called_once() + + def test_timeout_kills_both_processes(self, mock_env, tmp_path): + """TimeoutExpired during communicate should kill both processes.""" + f1 = tmp_path / "t.txt" + f1.write_text("t") + files = [(str(f1), "/home/testuser/.hermes/skills/t.txt")] + + mock_tar = MagicMock() + mock_tar.stdout = MagicMock() + mock_tar.returncode = None + mock_tar.poll.return_value = None + + mock_ssh = MagicMock() + mock_ssh.communicate.side_effect = subprocess.TimeoutExpired("ssh", 120) + mock_ssh.returncode = None + + def make_proc(cmd, **kwargs): + if cmd[0] == "tar": + return mock_tar + return mock_ssh + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=make_proc): + with pytest.raises(RuntimeError, match="SSH bulk upload timed out"): + mock_env._ssh_bulk_upload(files) + + mock_tar.kill.assert_called_once() + mock_ssh.kill.assert_called_once() + + +class TestSSHBulkUploadWiring: + """Verify bulk_upload_fn is wired into FileSyncManager.""" + + def test_filesyncmanager_receives_bulk_upload_fn(self, monkeypatch): + """SSHEnvironment should pass _ssh_bulk_upload to FileSyncManager.""" + monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh") + monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "_detect_remote_home", lambda self: "/root") + 
monkeypatch.setattr(ssh_env.SSHEnvironment, "_ensure_remote_dirs", lambda self: None) + monkeypatch.setattr(ssh_env.SSHEnvironment, "init_session", lambda self: None) + + captured_kwargs = {} + + class FakeSyncManager: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + def sync(self, **kw): + pass + + monkeypatch.setattr(ssh_env, "FileSyncManager", FakeSyncManager) + + env = SSHEnvironment(host="h", user="u") + + assert "bulk_upload_fn" in captured_kwargs + assert captured_kwargs["bulk_upload_fn"] is not None + # Should be the bound method + assert callable(captured_kwargs["bulk_upload_fn"]) + + +class TestSharedHelpers: + """Direct unit tests for file_sync.py helpers.""" + + def test_quoted_mkdir_command_basic(self): + result = quoted_mkdir_command(["/a", "/b/c"]) + assert result == "mkdir -p /a /b/c" + + def test_quoted_mkdir_command_quotes_special_chars(self): + result = quoted_mkdir_command(["/path/with spaces", "/path/'quotes'"]) + assert "mkdir -p" in result + # shlex.quote wraps in single quotes + assert "'/path/with spaces'" in result + + def test_quoted_mkdir_command_empty(self): + result = quoted_mkdir_command([]) + assert result == "mkdir -p " + + def test_unique_parent_dirs_deduplicates(self): + files = [ + ("/local/a.txt", "/remote/dir/a.txt"), + ("/local/b.txt", "/remote/dir/b.txt"), + ("/local/c.txt", "/remote/other/c.txt"), + ] + result = unique_parent_dirs(files) + assert result == ["/remote/dir", "/remote/other"] + + def test_unique_parent_dirs_sorted(self): + files = [ + ("/local/z.txt", "/z/file.txt"), + ("/local/a.txt", "/a/file.txt"), + ] + result = unique_parent_dirs(files) + assert result == ["/a", "/z"] + + def test_unique_parent_dirs_empty(self): + assert unique_parent_dirs([]) == [] + + +class TestSSHBulkUploadEdgeCases: + """Edge cases for _ssh_bulk_upload.""" + + def test_ssh_popen_failure_kills_tar(self, mock_env, tmp_path): + """If SSH Popen raises, tar process must be killed and cleaned up.""" + f1 = tmp_path / 
"e.txt" + f1.write_text("e") + files = [(str(f1), "/home/testuser/.hermes/skills/e.txt")] + + mock_tar = _mock_proc() + + call_count = 0 + + def failing_ssh_popen(cmd, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return mock_tar # tar Popen succeeds + raise OSError("SSH binary not found") + + with patch.object(subprocess, "run", + return_value=subprocess.CompletedProcess([], 0)), \ + patch.object(subprocess, "Popen", side_effect=failing_ssh_popen): + with pytest.raises(OSError, match="SSH binary not found"): + mock_env._ssh_bulk_upload(files) + + mock_tar.kill.assert_called_once() + mock_tar.wait.assert_called_once() diff --git a/tests/tools/test_todo_tool.py b/tests/tools/test_todo_tool.py index d4fd03bafe..6215078525 100644 --- a/tests/tools/test_todo_tool.py +++ b/tests/tools/test_todo_tool.py @@ -24,6 +24,18 @@ class TestWriteAndRead: items[0]["content"] = "MUTATED" assert store.read()[0]["content"] == "Task" + def test_write_deduplicates_duplicate_ids(self): + store = TodoStore() + result = store.write([ + {"id": "1", "content": "First version", "status": "pending"}, + {"id": "2", "content": "Other task", "status": "pending"}, + {"id": "1", "content": "Latest version", "status": "in_progress"}, + ]) + assert result == [ + {"id": "2", "content": "Other task", "status": "pending"}, + {"id": "1", "content": "Latest version", "status": "in_progress"}, + ] + class TestHasItems: def test_empty_store(self): diff --git a/tests/tools/test_tool_result_storage.py b/tests/tools/test_tool_result_storage.py index f95b5dc08a..0bbb95bbd6 100644 --- a/tests/tools/test_tool_result_storage.py +++ b/tests/tools/test_tool_result_storage.py @@ -124,6 +124,34 @@ class TestWriteToSandbox: cmd = env.execute.call_args[0][0] assert "mkdir -p /data/data/com.termux/files/usr/tmp/hermes-results" in cmd + def test_path_with_spaces_is_quoted(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + remote_path = "/tmp/hermes 
results/abc file.txt" + _write_to_sandbox("content", remote_path, env) + cmd = env.execute.call_args[0][0] + assert "'/tmp/hermes results'" in cmd + assert "'/tmp/hermes results/abc file.txt'" in cmd + + def test_shell_metacharacters_neutralized(self): + """Paths with shell metacharacters must be quoted to prevent injection.""" + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + malicious_path = "/tmp/hermes-results/$(whoami).txt" + _write_to_sandbox("content", malicious_path, env) + cmd = env.execute.call_args[0][0] + # The $() must not appear unquoted — shlex.quote wraps it + assert "'/tmp/hermes-results/$(whoami).txt'" in cmd + + def test_semicolon_injection_neutralized(self): + env = MagicMock() + env.execute.return_value = {"output": "", "returncode": 0} + malicious_path = "/tmp/x; rm -rf /; echo .txt" + _write_to_sandbox("content", malicious_path, env) + cmd = env.execute.call_args[0][0] + # The semicolons must be inside quotes, not acting as command separators + assert "'/tmp/x; rm -rf /; echo .txt'" in cmd + class TestResolveStorageDir: def test_defaults_to_storage_dir_without_env(self): diff --git a/tools/approval.py b/tools/approval.py index faf888f184..9a3a4ef260 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -40,11 +40,18 @@ def reset_current_session_key(token: contextvars.Token[str]) -> None: def get_current_session_key(default: str = "default") -> str: - """Return the active session key, preferring context-local state.""" + """Return the active session key, preferring context-local state. + + Resolution order: + 1. approval-specific contextvars (set by gateway before agent.run) + 2. session_context contextvars (set by _set_session_env) + 3. 
os.environ fallback (CLI, cron, tests) + """ session_key = _approval_session_key.get() if session_key: return session_key - return os.getenv("HERMES_SESSION_KEY", default) + from gateway.session_context import get_session_env + return get_session_env("HERMES_SESSION_KEY", default) # Sensitive write targets that should trigger approval even when referenced # via shell expansions like $HOME or $HERMES_HOME. diff --git a/tools/checkpoint_manager.py b/tools/checkpoint_manager.py index c298aa0bb6..42900a643d 100644 --- a/tools/checkpoint_manager.py +++ b/tools/checkpoint_manager.py @@ -21,6 +21,7 @@ into the user's project directory. import hashlib import logging import os +import re import shutil import subprocess from pathlib import Path @@ -64,23 +65,72 @@ _GIT_TIMEOUT: int = max(10, min(60, int(os.getenv("HERMES_CHECKPOINT_TIMEOUT", " # Max files to snapshot — skip huge directories to avoid slowdowns. _MAX_FILES = 50_000 +# Valid git commit hash pattern: 4–40 hex chars (short or full SHA-1/SHA-256). +_COMMIT_HASH_RE = re.compile(r'^[0-9a-fA-F]{4,64}$') + + +# --------------------------------------------------------------------------- +# Input validation helpers +# --------------------------------------------------------------------------- + +def _validate_commit_hash(commit_hash: str) -> Optional[str]: + """Validate a commit hash to prevent git argument injection. + + Returns an error string if invalid, None if valid. + Values starting with '-' would be interpreted as git flags + (e.g., '--patch', '-p') instead of revision specifiers. 
+ """ + if not commit_hash or not commit_hash.strip(): + return "Empty commit hash" + if commit_hash.startswith("-"): + return f"Invalid commit hash (must not start with '-'): {commit_hash!r}" + if not _COMMIT_HASH_RE.match(commit_hash): + return f"Invalid commit hash (expected 4-64 hex characters): {commit_hash!r}" + return None + + +def _validate_file_path(file_path: str, working_dir: str) -> Optional[str]: + """Validate a file path to prevent path traversal outside the working directory. + + Returns an error string if invalid, None if valid. + """ + if not file_path or not file_path.strip(): + return "Empty file path" + # Reject absolute paths — restore targets must be relative to the workdir + if os.path.isabs(file_path): + return f"File path must be relative, got absolute path: {file_path!r}" + # Resolve and check containment within working_dir + abs_workdir = _normalize_path(working_dir) + resolved = (abs_workdir / file_path).resolve() + try: + resolved.relative_to(abs_workdir) + except ValueError: + return f"File path escapes the working directory via traversal: {file_path!r}" + return None + # --------------------------------------------------------------------------- # Shadow repo helpers # --------------------------------------------------------------------------- +def _normalize_path(path_value: str) -> Path: + """Return a canonical absolute path for checkpoint operations.""" + return Path(path_value).expanduser().resolve() + + def _shadow_repo_path(working_dir: str) -> Path: """Deterministic shadow repo path: sha256(abs_path)[:16].""" - abs_path = str(Path(working_dir).resolve()) + abs_path = str(_normalize_path(working_dir)) dir_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:16] return CHECKPOINT_BASE / dir_hash def _git_env(shadow_repo: Path, working_dir: str) -> dict: """Build env dict that redirects git to the shadow repo.""" + normalized_working_dir = _normalize_path(working_dir) env = os.environ.copy() env["GIT_DIR"] = str(shadow_repo) - 
env["GIT_WORK_TREE"] = str(Path(working_dir).resolve()) + env["GIT_WORK_TREE"] = str(normalized_working_dir) env.pop("GIT_INDEX_FILE", None) env.pop("GIT_NAMESPACE", None) env.pop("GIT_ALTERNATE_OBJECT_DIRECTORIES", None) @@ -100,7 +150,17 @@ def _run_git( exits while preserving the normal ``ok = (returncode == 0)`` contract. Example: ``git diff --cached --quiet`` returns 1 when changes exist. """ - env = _git_env(shadow_repo, working_dir) + normalized_working_dir = _normalize_path(working_dir) + if not normalized_working_dir.exists(): + msg = f"working directory not found: {normalized_working_dir}" + logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg) + return False, "", msg + if not normalized_working_dir.is_dir(): + msg = f"working directory is not a directory: {normalized_working_dir}" + logger.error("Git command skipped: %s (%s)", " ".join(["git"] + list(args)), msg) + return False, "", msg + + env = _git_env(shadow_repo, str(normalized_working_dir)) cmd = ["git"] + list(args) allowed_returncodes = allowed_returncodes or set() try: @@ -110,7 +170,7 @@ def _run_git( text=True, timeout=timeout, env=env, - cwd=str(Path(working_dir).resolve()), + cwd=str(normalized_working_dir), ) ok = result.returncode == 0 stdout = result.stdout.strip() @@ -125,9 +185,14 @@ def _run_git( msg = f"git timed out after {timeout}s: {' '.join(cmd)}" logger.error(msg, exc_info=True) return False, "", msg - except FileNotFoundError: - logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True) - return False, "", "git not found" + except FileNotFoundError as exc: + missing_target = getattr(exc, "filename", None) + if missing_target == "git": + logger.error("Git executable not found: %s", " ".join(cmd), exc_info=True) + return False, "", "git not found" + msg = f"working directory not found: {normalized_working_dir}" + logger.error("Git command failed before execution: %s (%s)", " ".join(cmd), msg, exc_info=True) + return False, "", msg except 
Exception as exc: logger.error("Unexpected git error running %s: %s", " ".join(cmd), exc, exc_info=True) return False, "", str(exc) @@ -154,7 +219,7 @@ def _init_shadow_repo(shadow_repo: Path, working_dir: str) -> Optional[str]: ) (shadow_repo / "HERMES_WORKDIR").write_text( - str(Path(working_dir).resolve()) + "\n", encoding="utf-8" + str(_normalize_path(working_dir)) + "\n", encoding="utf-8" ) logger.debug("Initialised checkpoint repo at %s for %s", shadow_repo, working_dir) @@ -229,7 +294,7 @@ class CheckpointManager: if not self._git_available: return False - abs_dir = str(Path(working_dir).resolve()) + abs_dir = str(_normalize_path(working_dir)) # Skip root, home, and other overly broad directories if abs_dir in ("/", str(Path.home())): @@ -254,7 +319,7 @@ class CheckpointManager: Returns a list of dicts with keys: hash, short_hash, timestamp, reason, files_changed, insertions, deletions. Most recent first. """ - abs_dir = str(Path(working_dir).resolve()) + abs_dir = str(_normalize_path(working_dir)) shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): @@ -311,7 +376,12 @@ class CheckpointManager: Returns dict with success, diff text, and stat summary. """ - abs_dir = str(Path(working_dir).resolve()) + # Validate commit_hash to prevent git argument injection + hash_err = _validate_commit_hash(commit_hash) + if hash_err: + return {"success": False, "error": hash_err} + + abs_dir = str(_normalize_path(working_dir)) shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): @@ -364,7 +434,19 @@ class CheckpointManager: Returns dict with success/error info. 
""" - abs_dir = str(Path(working_dir).resolve()) + # Validate commit_hash to prevent git argument injection + hash_err = _validate_commit_hash(commit_hash) + if hash_err: + return {"success": False, "error": hash_err} + + abs_dir = str(_normalize_path(working_dir)) + + # Validate file_path to prevent path traversal outside the working dir + if file_path: + path_err = _validate_file_path(file_path, abs_dir) + if path_err: + return {"success": False, "error": path_err} + shadow = _shadow_repo_path(abs_dir) if not (shadow / "HEAD").exists(): @@ -413,7 +495,7 @@ class CheckpointManager: (directory containing .git, pyproject.toml, package.json, etc.). Falls back to the file's parent directory. """ - path = Path(file_path).resolve() + path = _normalize_path(file_path) if path.is_dir(): candidate = path else: diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index d6c561e2c3..8b5f794555 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -301,7 +301,7 @@ def _call(tool_name, args): # --------------------------------------------------------------------------- # Terminal parameters that must not be used from ephemeral sandbox scripts -_TERMINAL_BLOCKED_PARAMS = {"background", "check_interval", "pty", "notify_on_complete", "watch_patterns"} +_TERMINAL_BLOCKED_PARAMS = {"background", "pty", "notify_on_complete", "watch_patterns"} def _rpc_server_loop( diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index e2db933813..80c88e3534 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -456,7 +456,7 @@ Important safety rule: cron-run sessions should not recursively schedule more cr }, "deliver": { "type": "string", - "description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, weixin, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, email, sms, bluebubbles, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. 
Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'" + "description": "Delivery target: origin, local, telegram, discord, slack, whatsapp, signal, weixin, matrix, mattermost, homeassistant, dingtalk, feishu, wecom, wecom_callback, email, sms, bluebubbles, or platform:chat_id or platform:chat_id:thread_id for Telegram topics. Examples: 'origin', 'local', 'telegram', 'telegram:-1001234567890:17585', 'discord:#engineering'" }, "skills": { "type": "array", diff --git a/tools/environments/base.py b/tools/environments/base.py index 1598c22110..19c3bf024e 100644 --- a/tools/environments/base.py +++ b/tools/environments/base.py @@ -23,6 +23,19 @@ from tools.interrupt import is_interrupted logger = logging.getLogger(__name__) +# Thread-local activity callback. The agent sets this before a tool call so +# long-running _wait_for_process loops can report liveness to the gateway. +_activity_callback_local = threading.local() + + +def set_activity_callback(cb: Callable[[str], None] | None) -> None: + """Register a callback that _wait_for_process fires periodically.""" + _activity_callback_local.callback = cb + + +def _get_activity_callback() -> Callable[[str], None] | None: + return getattr(_activity_callback_local, "callback", None) + def get_sandbox_dir() -> Path: """Return the host-side root for all sandbox storage (Docker workspaces, @@ -370,6 +383,10 @@ class BaseEnvironment(ABC): """Poll-based wait with interrupt checking and stdout draining. Shared across all backends — not overridden. + + Fires the ``activity_callback`` (if set on this instance) every 10s + while the process is running so the gateway's inactivity timeout + doesn't kill long-running commands. 
""" output_chunks: list[str] = [] @@ -388,6 +405,8 @@ class BaseEnvironment(ABC): drain_thread = threading.Thread(target=_drain, daemon=True) drain_thread.start() deadline = time.monotonic() + timeout + _last_activity_touch = time.monotonic() + _ACTIVITY_INTERVAL = 10.0 # seconds between activity touches while proc.poll() is None: if is_interrupted(): @@ -408,6 +427,17 @@ class BaseEnvironment(ABC): else timeout_msg.lstrip(), "returncode": 124, } + # Periodic activity touch so the gateway knows we're alive + _now = time.monotonic() + if _now - _last_activity_touch >= _ACTIVITY_INTERVAL: + _last_activity_touch = _now + _cb = _get_activity_callback() + if _cb: + try: + _elapsed = int(_now - (deadline - timeout)) + _cb(f"terminal command running ({_elapsed}s elapsed)") + except Exception: + pass time.sleep(0.2) drain_thread.join(timeout=5) diff --git a/tools/environments/daytona.py b/tools/environments/daytona.py index 5fe074681d..c2913e585e 100644 --- a/tools/environments/daytona.py +++ b/tools/environments/daytona.py @@ -15,7 +15,13 @@ from tools.environments.base import ( BaseEnvironment, _ThreadedProcessHandle, ) -from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command +from tools.environments.file_sync import ( + FileSyncManager, + iter_sync_files, + quoted_mkdir_command, + quoted_rm_command, + unique_parent_dirs, +) logger = logging.getLogger(__name__) @@ -150,11 +156,9 @@ class DaytonaEnvironment(BaseEnvironment): if not files: return - # Pre-create all unique parent directories in one shell call - parents = sorted({str(Path(remote).parent) for _, remote in files}) + parents = unique_parent_dirs(files) if parents: - mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(p) for p in parents) - self._sandbox.process.exec(mkdir_cmd) + self._sandbox.process.exec(quoted_mkdir_command(parents)) uploads = [ FileUpload(source=host_path, destination=remote_path) diff --git a/tools/environments/file_sync.py b/tools/environments/file_sync.py 
index 29b45f858f..64a5b56dc4 100644 --- a/tools/environments/file_sync.py +++ b/tools/environments/file_sync.py @@ -10,6 +10,7 @@ import logging import os import shlex import time +from pathlib import Path from typing import Callable from tools.environments.base import _file_mtime_key @@ -60,6 +61,16 @@ def quoted_rm_command(remote_paths: list[str]) -> str: return "rm -f " + " ".join(shlex.quote(p) for p in remote_paths) +def quoted_mkdir_command(dirs: list[str]) -> str: + """Build a shell ``mkdir -p`` command for a batch of directories.""" + return "mkdir -p " + " ".join(shlex.quote(d) for d in dirs) + + +def unique_parent_dirs(files: list[tuple[str, str]]) -> list[str]: + """Extract sorted unique parent directories from (host, remote) pairs.""" + return sorted({str(Path(remote).parent) for _, remote in files}) + + class FileSyncManager: """Tracks local file changes and syncs to a remote environment. diff --git a/tools/environments/modal.py b/tools/environments/modal.py index 365eca9fb1..5c5c721c1e 100644 --- a/tools/environments/modal.py +++ b/tools/environments/modal.py @@ -5,8 +5,11 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions. 
""" import asyncio +import base64 +import io import logging import shlex +import tarfile import threading from pathlib import Path from typing import Any, Optional @@ -18,7 +21,13 @@ from tools.environments.base import ( _load_json_store, _save_json_store, ) -from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command +from tools.environments.file_sync import ( + FileSyncManager, + iter_sync_files, + quoted_mkdir_command, + quoted_rm_command, + unique_parent_dirs, +) logger = logging.getLogger(__name__) @@ -259,26 +268,84 @@ class ModalEnvironment(BaseEnvironment): get_files_fn=lambda: iter_sync_files("/root/.hermes"), upload_fn=self._modal_upload, delete_fn=self._modal_delete, + bulk_upload_fn=self._modal_bulk_upload, ) self._sync_manager.sync(force=True) self.init_session() def _modal_upload(self, host_path: str, remote_path: str) -> None: - """Upload a single file via base64-over-exec.""" - import base64 + """Upload a single file via base64 piped through stdin.""" content = Path(host_path).read_bytes() b64 = base64.b64encode(content).decode("ascii") container_dir = str(Path(remote_path).parent) cmd = ( f"mkdir -p {shlex.quote(container_dir)} && " - f"echo {shlex.quote(b64)} | base64 -d > {shlex.quote(remote_path)}" + f"base64 -d > {shlex.quote(remote_path)}" ) async def _write(): proc = await self._sandbox.exec.aio("bash", "-c", cmd) + offset = 0 + chunk_size = self._STDIN_CHUNK_SIZE + while offset < len(b64): + proc.stdin.write(b64[offset:offset + chunk_size]) + await proc.stdin.drain.aio() + offset += chunk_size + proc.stdin.write_eof() + await proc.stdin.drain.aio() await proc.wait.aio() - self._worker.run_coroutine(_write(), timeout=15) + self._worker.run_coroutine(_write(), timeout=30) + + # Modal SDK stdin buffer limit (legacy server path). The command-router + # path allows 16 MB, but we must stay under the smaller 2 MB cap for + # compatibility. 
Chunks are written below this threshold and flushed + # individually via drain(). + _STDIN_CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB — safe for both transport paths + + def _modal_bulk_upload(self, files: list[tuple[str, str]]) -> None: + """Upload many files via tar archive piped through stdin. + + Builds a gzipped tar archive in memory and streams it into a + ``base64 -d | tar xzf -`` pipeline via the process's stdin, + avoiding the Modal SDK's 64 KB ``ARG_MAX_BYTES`` exec-arg limit. + """ + if not files: + return + + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tar: + for host_path, remote_path in files: + tar.add(host_path, arcname=remote_path.lstrip("/")) + payload = base64.b64encode(buf.getvalue()).decode("ascii") + + parents = unique_parent_dirs(files) + mkdir_part = quoted_mkdir_command(parents) + cmd = f"{mkdir_part} && base64 -d | tar xzf - -C /" + + async def _bulk(): + proc = await self._sandbox.exec.aio("bash", "-c", cmd) + + # Stream payload through stdin in chunks to stay under the + # SDK's per-write buffer limit (2 MB legacy / 16 MB router). 
+ offset = 0 + chunk_size = self._STDIN_CHUNK_SIZE + while offset < len(payload): + proc.stdin.write(payload[offset:offset + chunk_size]) + await proc.stdin.drain.aio() + offset += chunk_size + + proc.stdin.write_eof() + await proc.stdin.drain.aio() + + exit_code = await proc.wait.aio() + if exit_code != 0: + stderr_text = await proc.stderr.read.aio() + raise RuntimeError( + f"Modal bulk upload failed (exit {exit_code}): {stderr_text}" + ) + + self._worker.run_coroutine(_bulk(), timeout=120) def _modal_delete(self, remote_paths: list[str]) -> None: """Batch-delete remote files via exec.""" diff --git a/tools/environments/ssh.py b/tools/environments/ssh.py index 8cb1b0c570..0491764b2f 100644 --- a/tools/environments/ssh.py +++ b/tools/environments/ssh.py @@ -1,6 +1,7 @@ """SSH remote execution environment with ControlMaster connection persistence.""" import logging +import os import shlex import shutil import subprocess @@ -8,7 +9,13 @@ import tempfile from pathlib import Path from tools.environments.base import BaseEnvironment, _popen_bash -from tools.environments.file_sync import FileSyncManager, iter_sync_files, quoted_rm_command +from tools.environments.file_sync import ( + FileSyncManager, + iter_sync_files, + quoted_mkdir_command, + quoted_rm_command, + unique_parent_dirs, +) logger = logging.getLogger(__name__) @@ -50,6 +57,7 @@ class SSHEnvironment(BaseEnvironment): get_files_fn=lambda: iter_sync_files(f"{self._remote_home}/.hermes"), upload_fn=self._scp_upload, delete_fn=self._ssh_delete, + bulk_upload_fn=self._ssh_bulk_upload, ) self._sync_manager.sync(force=True) @@ -107,9 +115,8 @@ class SSHEnvironment(BaseEnvironment): """Create base ~/.hermes directory tree on remote in one SSH call.""" base = f"{self._remote_home}/.hermes" dirs = [base, f"{base}/skills", f"{base}/credentials", f"{base}/cache"] - mkdir_cmd = "mkdir -p " + " ".join(shlex.quote(d) for d in dirs) cmd = self._build_ssh_command() - cmd.append(mkdir_cmd) + 
cmd.append(quoted_mkdir_command(dirs)) subprocess.run(cmd, capture_output=True, text=True, timeout=10) # _get_sync_files provided via iter_sync_files in FileSyncManager init @@ -131,6 +138,84 @@ class SSHEnvironment(BaseEnvironment): if result.returncode != 0: raise RuntimeError(f"scp failed: {result.stderr.strip()}") + def _ssh_bulk_upload(self, files: list[tuple[str, str]]) -> None: + """Upload many files in a single tar-over-SSH stream. + + Pipes ``tar c`` on the local side through an SSH connection to + ``tar x`` on the remote, transferring all files in one TCP stream + instead of spawning a subprocess per file. Directory creation is + batched into a single ``mkdir -p`` call beforehand. + + Typical improvement: ~580 files goes from O(N) scp round-trips + to a single streaming transfer. + """ + if not files: + return + + parents = unique_parent_dirs(files) + if parents: + cmd = self._build_ssh_command() + cmd.append(quoted_mkdir_command(parents)) + result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) + if result.returncode != 0: + raise RuntimeError(f"remote mkdir failed: {result.stderr.strip()}") + + # Symlink staging avoids fragile GNU tar --transform rules. 
+ with tempfile.TemporaryDirectory(prefix="hermes-ssh-bulk-") as staging: + for host_path, remote_path in files: + staged = os.path.join(staging, remote_path.lstrip("/")) + os.makedirs(os.path.dirname(staged), exist_ok=True) + os.symlink(os.path.abspath(host_path), staged) + + tar_cmd = ["tar", "-chf", "-", "-C", staging, "."] + ssh_cmd = self._build_ssh_command() + ssh_cmd.append("tar xf - -C /") + + tar_proc = subprocess.Popen( + tar_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + try: + ssh_proc = subprocess.Popen( + ssh_cmd, stdin=tar_proc.stdout, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + except Exception: + tar_proc.kill() + tar_proc.wait() + raise + + # Allow tar_proc to receive SIGPIPE if ssh_proc exits early + tar_proc.stdout.close() + + try: + _, ssh_stderr = ssh_proc.communicate(timeout=120) + # Use communicate() instead of wait() to drain stderr and + # avoid deadlock if tar produces more than PIPE_BUF of errors. + tar_stderr_raw = b"" + if tar_proc.poll() is None: + _, tar_stderr_raw = tar_proc.communicate(timeout=10) + else: + tar_stderr_raw = tar_proc.stderr.read() if tar_proc.stderr else b"" + except subprocess.TimeoutExpired: + tar_proc.kill() + ssh_proc.kill() + tar_proc.wait() + ssh_proc.wait() + raise RuntimeError("SSH bulk upload timed out") + + if tar_proc.returncode != 0: + raise RuntimeError( + f"tar create failed (rc={tar_proc.returncode}): " + f"{tar_stderr_raw.decode(errors='replace').strip()}" + ) + if ssh_proc.returncode != 0: + raise RuntimeError( + f"tar extract over SSH failed (rc={ssh_proc.returncode}): " + f"{ssh_stderr.decode(errors='replace').strip()}" + ) + + logger.debug("SSH: bulk-uploaded %d file(s) via tar pipe", len(files)) + def _ssh_delete(self, remote_paths: list[str]) -> None: """Batch-delete remote files in one SSH call.""" cmd = self._build_ssh_command() diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index f0cbff0f4c..3dfa786e1a 100644 --- a/tools/terminal_tool.py +++ 
b/tools/terminal_tool.py @@ -1137,7 +1137,6 @@ def terminal_tool( task_id: Optional[str] = None, force: bool = False, workdir: Optional[str] = None, - check_interval: Optional[int] = None, pty: bool = False, notify_on_complete: bool = False, watch_patterns: Optional[List[str]] = None, @@ -1152,7 +1151,6 @@ def terminal_tool( task_id: Unique identifier for environment isolation (optional) force: If True, skip dangerous command check (use after user confirms) workdir: Working directory for this command (optional, uses session cwd if not set) - check_interval: Seconds between auto-checks for background processes (gateway only, min 30) pty: If True, use pseudo-terminal for interactive CLI tools (local backend only) notify_on_complete: If True and background=True, auto-notify the agent when the process exits watch_patterns: List of strings to watch for in background output; triggers notification on match @@ -1424,7 +1422,7 @@ def terminal_tool( # turn. CLI mode uses the completion_queue directly. 
from gateway.session_context import get_session_env as _gse _gw_platform = _gse("HERMES_SESSION_PLATFORM", "") - if _gw_platform and not check_interval: + if _gw_platform: _gw_chat_id = _gse("HERMES_SESSION_CHAT_ID", "") _gw_thread_id = _gse("HERMES_SESSION_THREAD_ID", "") _gw_user_id = _gse("HERMES_SESSION_USER_ID", "") @@ -1452,39 +1450,6 @@ def terminal_tool( proc_session.watch_patterns = list(watch_patterns) result_data["watch_patterns"] = proc_session.watch_patterns - # Register check_interval watcher (gateway picks this up after agent run) - if check_interval and background: - effective_interval = max(30, check_interval) - if check_interval < 30: - result_data["check_interval_note"] = ( - f"Requested {check_interval}s raised to minimum 30s" - ) - from gateway.session_context import get_session_env as _gse2 - watcher_platform = _gse2("HERMES_SESSION_PLATFORM", "") - watcher_chat_id = _gse2("HERMES_SESSION_CHAT_ID", "") - watcher_thread_id = _gse2("HERMES_SESSION_THREAD_ID", "") - watcher_user_id = _gse2("HERMES_SESSION_USER_ID", "") - watcher_user_name = _gse2("HERMES_SESSION_USER_NAME", "") - - # Store on session for checkpoint persistence - proc_session.watcher_platform = watcher_platform - proc_session.watcher_chat_id = watcher_chat_id - proc_session.watcher_user_id = watcher_user_id - proc_session.watcher_user_name = watcher_user_name - proc_session.watcher_thread_id = watcher_thread_id - proc_session.watcher_interval = effective_interval - - process_registry.pending_watchers.append({ - "session_id": proc_session.id, - "check_interval": effective_interval, - "session_key": session_key, - "platform": watcher_platform, - "chat_id": watcher_chat_id, - "user_id": watcher_user_id, - "user_name": watcher_user_name, - "thread_id": watcher_thread_id, - }) - return json.dumps(result_data, ensure_ascii=False) except Exception as e: return json.dumps({ @@ -1767,11 +1732,6 @@ TERMINAL_SCHEMA = { "type": "string", "description": "Working directory for this command 
(absolute path). Defaults to the session working directory." }, - "check_interval": { - "type": "integer", - "description": "Seconds between automatic status checks for background processes (gateway/messaging only, minimum 30). When set, I'll proactively report progress.", - "minimum": 30 - }, "pty": { "type": "boolean", "description": "Run in pseudo-terminal (PTY) mode for interactive CLI tools like Codex, Claude Code, or Python REPL. Only works with local and SSH backends. Default: false.", @@ -1800,7 +1760,6 @@ def _handle_terminal(args, **kw): timeout=args.get("timeout"), task_id=kw.get("task_id"), workdir=args.get("workdir"), - check_interval=args.get("check_interval"), pty=args.get("pty", False), notify_on_complete=args.get("notify_on_complete", False), watch_patterns=args.get("watch_patterns"), diff --git a/tools/todo_tool.py b/tools/todo_tool.py index 9021fbc2d3..b0d38a2342 100644 --- a/tools/todo_tool.py +++ b/tools/todo_tool.py @@ -46,11 +46,11 @@ class TodoStore: """ if not merge: # Replace mode: new list entirely - self._items = [self._validate(t) for t in todos] + self._items = [self._validate(t) for t in self._dedupe_by_id(todos)] else: # Merge mode: update existing items by id, append new ones existing = {item["id"]: item for item in self._items} - for t in todos: + for t in self._dedupe_by_id(todos): item_id = str(t.get("id", "")).strip() if not item_id: continue # Can't merge without an id @@ -143,6 +143,15 @@ class TodoStore: return {"id": item_id, "content": content, "status": status} + @staticmethod + def _dedupe_by_id(todos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Collapse duplicate ids, keeping the last occurrence in its position.""" + last_index: Dict[str, int] = {} + for i, item in enumerate(todos): + item_id = str(item.get("id", "")).strip() or "?" 
+ last_index[item_id] = i + return [todos[i] for i in sorted(last_index.values())] + def todo_tool( todos: Optional[List[Dict[str, Any]]] = None, diff --git a/tools/tool_result_storage.py b/tools/tool_result_storage.py index a8ec5440bc..4342264482 100644 --- a/tools/tool_result_storage.py +++ b/tools/tool_result_storage.py @@ -24,6 +24,7 @@ Defense against context-window overflow operates at three levels: import logging import os +import shlex import uuid from tools.budget_config import ( @@ -79,7 +80,7 @@ def _write_to_sandbox(content: str, remote_path: str, env) -> bool: marker = _heredoc_marker(content) storage_dir = os.path.dirname(remote_path) cmd = ( - f"mkdir -p {storage_dir} && cat > {remote_path} << '{marker}'\n" + f"mkdir -p {shlex.quote(storage_dir)} && cat > {shlex.quote(remote_path)} << '{marker}'\n" f"{content}\n" f"{marker}" ) diff --git a/toolsets.py b/toolsets.py index 6fbc963e62..57e03d2500 100644 --- a/toolsets.py +++ b/toolsets.py @@ -365,6 +365,12 @@ TOOLSETS = { "includes": [] }, + "hermes-wecom-callback": { + "description": "WeCom callback toolset - enterprise self-built app messaging (full access)", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + "hermes-sms": { "description": "SMS bot toolset - interact with Hermes via SMS (Twilio)", "tools": _HERMES_CORE_TOOLS, @@ -380,7 +386,7 @@ TOOLSETS = { "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", "hermes-feishu", "hermes-wecom", "hermes-weixin", "hermes-webhook"] + "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-signal", "hermes-bluebubbles", "hermes-homeassistant", "hermes-email", "hermes-sms", "hermes-mattermost", "hermes-matrix", "hermes-dingtalk", 
"hermes-feishu", "hermes-wecom", "hermes-wecom-callback", "hermes-weixin", "hermes-webhook"] } } diff --git a/website/docs/developer-guide/adding-platform-adapters.md b/website/docs/developer-guide/adding-platform-adapters.md new file mode 100644 index 0000000000..1ddb07f08b --- /dev/null +++ b/website/docs/developer-guide/adding-platform-adapters.md @@ -0,0 +1,256 @@ +--- +sidebar_position: 9 +--- + +# Adding a Platform Adapter + +This guide covers adding a new messaging platform to the Hermes gateway. A platform adapter connects Hermes to an external messaging service (Telegram, Discord, WeCom, etc.) so users can interact with the agent through that service. + +:::tip +Adding a platform adapter touches 20+ files across code, config, and docs. Use this guide as a checklist — the adapter file itself is typically only 40% of the work. +::: + +## Architecture Overview + +``` +User ↔ Messaging Platform ↔ Platform Adapter ↔ Gateway Runner ↔ AIAgent +``` + +Every adapter extends `BasePlatformAdapter` from `gateway/platforms/base.py` and implements: + +- **`connect()`** — Establish connection (WebSocket, long-poll, HTTP server, etc.) +- **`disconnect()`** — Clean shutdown +- **`send()`** — Send a text message to a chat +- **`send_typing()`** — Show typing indicator (optional) +- **`get_chat_info()`** — Return chat metadata + +Inbound messages are received by the adapter and forwarded via `self.handle_message(event)`, which the base class routes to the gateway runner. + +## Step-by-Step Checklist + +### 1. Platform Enum + +Add your platform to the `Platform` enum in `gateway/config.py`: + +```python +class Platform(str, Enum): + # ... existing platforms ... + NEWPLAT = "newplat" +``` + +### 2. 
Adapter File + +Create `gateway/platforms/newplat.py`: + +```python +import os + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, MessageEvent, MessageType, SendResult, +) + +def check_newplat_requirements() -> bool: + """Return True if dependencies are available.""" + return SOME_SDK_AVAILABLE + +class NewPlatAdapter(BasePlatformAdapter): + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.NEWPLAT) + # Read config from config.extra dict + extra = config.extra or {} + self._api_key = extra.get("api_key") or os.getenv("NEWPLAT_API_KEY", "") + + async def connect(self) -> bool: + # Set up connection, start polling/webhook + self._mark_connected() + return True + + async def disconnect(self) -> None: + self._running = False + self._mark_disconnected() + + async def send(self, chat_id, content, reply_to=None, metadata=None): + # Send message via platform API + return SendResult(success=True, message_id="...") + + async def get_chat_info(self, chat_id): + return {"name": chat_id, "type": "dm"} +``` + +For inbound messages, build a `MessageEvent` and call `self.handle_message(event)`: + +```python +source = self.build_source( + chat_id=chat_id, + chat_name=name, + chat_type="dm", # or "group" + user_id=user_id, + user_name=user_name, +) +event = MessageEvent( + text=content, + message_type=MessageType.TEXT, + source=source, + message_id=msg_id, +) +await self.handle_message(event) +``` + +### 3. Gateway Config (`gateway/config.py`) + +Three touchpoints: + +1. **`get_connected_platforms()`** — Add a check for your platform's required credentials +2. **`load_gateway_config()`** — Add token env map entry: `Platform.NEWPLAT: "NEWPLAT_TOKEN"` +3. **`_apply_env_overrides()`** — Map all `NEWPLAT_*` env vars to config + +### 4. Gateway Runner (`gateway/run.py`) + +Six touchpoints: + +1. **`_create_adapter()`** — Add an `elif platform == Platform.NEWPLAT:` branch +2.
**`_is_user_authorized()` allowed_users map** — `Platform.NEWPLAT: "NEWPLAT_ALLOWED_USERS"` +3. **`_is_user_authorized()` allow_all map** — `Platform.NEWPLAT: "NEWPLAT_ALLOW_ALL_USERS"` +4. **Early env check `_any_allowlist` tuple** — Add `"NEWPLAT_ALLOWED_USERS"` +5. **Early env check `_allow_all` tuple** — Add `"NEWPLAT_ALLOW_ALL_USERS"` +6. **`_UPDATE_ALLOWED_PLATFORMS` frozenset** — Add `Platform.NEWPLAT` + +### 5. Cross-Platform Delivery + +1. **`gateway/platforms/webhook.py`** — Add `"newplat"` to the delivery type tuple +2. **`cron/scheduler.py`** — Add to `_KNOWN_DELIVERY_PLATFORMS` frozenset and `_deliver_result()` platform map + +### 6. CLI Integration + +1. **`hermes_cli/config.py`** — Add all `NEWPLAT_*` vars to `_EXTRA_ENV_KEYS` +2. **`hermes_cli/gateway.py`** — Add entry to `_PLATFORMS` list with key, label, emoji, token_var, setup_instructions, and vars +3. **`hermes_cli/platforms.py`** — Add `PlatformInfo` entry with label and default_toolset (used by `skills_config` and `tools_config` TUIs) +4. **`hermes_cli/setup.py`** — Add `_setup_newplat()` function (can delegate to `gateway.py`) and add tuple to the messaging platforms list +5. **`hermes_cli/status.py`** — Add platform detection entry: `"NewPlat": ("NEWPLAT_TOKEN", "NEWPLAT_HOME_CHANNEL")` +6. **`hermes_cli/dump.py`** — Add `"newplat": "NEWPLAT_TOKEN"` to platform detection dict + +### 7. Tools + +1. **`tools/send_message_tool.py`** — Add `"newplat": Platform.NEWPLAT` to platform map +2. **`tools/cronjob_tools.py`** — Add `newplat` to the delivery target description string + +### 8. Toolsets + +1. **`toolsets.py`** — Add `"hermes-newplat"` toolset definition with `_HERMES_CORE_TOOLS` +2. **`toolsets.py`** — Add `"hermes-newplat"` to the `"hermes-gateway"` includes list + +### 9. Optional: Platform Hints + +**`agent/prompt_builder.py`** — If your platform has specific rendering limitations (no markdown, message length limits, etc.), add an entry to the `_PLATFORM_HINTS` dict. 
This injects platform-specific guidance into the system prompt: + +```python +_PLATFORM_HINTS = { + # ... + "newplat": ( + "You are chatting via NewPlat. It supports markdown formatting " + "but has a 4000-character message limit." + ), +} +``` + +Not all platforms need hints — only add one if the agent's behavior should differ. + +### 10. Tests + +Create `tests/gateway/test_newplat.py` covering: + +- Adapter construction from config +- Message event building +- Send method (mock the external API) +- Platform-specific features (encryption, routing, etc.) + +### 11. Documentation + +| File | What to add | +|------|-------------| +| `website/docs/user-guide/messaging/newplat.md` | Full platform setup page | +| `website/docs/user-guide/messaging/index.md` | Platform comparison table, architecture diagram, toolsets table, security section, next-steps link | +| `website/docs/reference/environment-variables.md` | All NEWPLAT_* env vars | +| `website/docs/reference/toolsets-reference.md` | hermes-newplat toolset | +| `website/docs/integrations/index.md` | Platform link | +| `website/sidebars.ts` | Sidebar entry for the docs page | +| `website/docs/developer-guide/architecture.md` | Adapter count + listing | +| `website/docs/developer-guide/gateway-internals.md` | Adapter file listing | + +## Parity Audit + +Before marking a new platform PR as complete, run a parity audit against an established platform: + +```bash +# Find every .py file mentioning the reference platform +search_files "bluebubbles" output_mode="files_only" file_glob="*.py" + +# Find every .py file mentioning the new platform +search_files "newplat" output_mode="files_only" file_glob="*.py" + +# Any file in the first set but not the second is a potential gap +``` + +Repeat for `.md` and `.ts` files. Investigate each gap — is it a platform enumeration (needs updating) or a platform-specific reference (skip)? 
+ +## Common Patterns + +### Long-Poll Adapters + +If your adapter uses long-polling (like Telegram or Weixin), use a polling loop task: + +```python +async def connect(self): + self._poll_task = asyncio.create_task(self._poll_loop()) + self._mark_connected() + +async def _poll_loop(self): + while self._running: + messages = await self._fetch_updates() + for msg in messages: + await self.handle_message(self._build_event(msg)) +``` + +### Callback/Webhook Adapters + +If the platform pushes messages to your endpoint (like WeCom Callback), run an HTTP server: + +```python +async def connect(self): + self._app = web.Application() + self._app.router.add_post("/callback", self._handle_callback) + # ... start aiohttp server + self._mark_connected() + +async def _handle_callback(self, request): + event = self._build_event(await request.text()) + await self._message_queue.put(event) + return web.Response(text="success") # Acknowledge immediately +``` + +For platforms with tight response deadlines (e.g., WeCom's 5-second limit), always acknowledge immediately and deliver the agent's reply proactively via API later. Agent sessions run 3–30 minutes — inline replies within a callback response window are not feasible. + +### Token Locks + +If the adapter holds a persistent connection with a unique credential, add a scoped lock to prevent two profiles from using the same credential: + +```python +from gateway.status import acquire_scoped_lock, release_scoped_lock + +async def connect(self): + if not acquire_scoped_lock("newplat", self._token): + logger.error("Token already in use by another profile") + return False + # ... 
connect + +async def disconnect(self): + release_scoped_lock("newplat", self._token) +``` + +## Reference Implementations + +| Adapter | Pattern | Complexity | Good reference for | +|---------|---------|------------|-------------------| +| `bluebubbles.py` | REST + webhook | Medium | Simple REST API integration | +| `weixin.py` | Long-poll + CDN | High | Media handling, encryption | +| `wecom_callback.py` | Callback/webhook | Medium | HTTP server, AES crypto, multi-app | +| `telegram.py` | Long-poll + Bot API | High | Full-featured adapter with groups, threads | diff --git a/website/docs/getting-started/nix-setup.md b/website/docs/getting-started/nix-setup.md index 4db4939868..858315329b 100644 --- a/website/docs/getting-started/nix-setup.md +++ b/website/docs/getting-started/nix-setup.md @@ -122,6 +122,41 @@ services.hermes-agent.environmentFiles = [ "/var/lib/hermes/env" ]; Setting `addToSystemPackages = true` does two things: puts the `hermes` CLI on your system PATH **and** sets `HERMES_HOME` system-wide so the interactive CLI shares state (sessions, skills, cron) with the gateway service. Without it, running `hermes` in your shell creates a separate `~/.hermes/` directory. ::: +:::info Container-aware CLI +When `container.enable = true` and `addToSystemPackages = true`, **every** `hermes` command on the host automatically routes into the managed container. This means your interactive CLI session runs inside the same environment as the gateway service — with access to all container-installed packages and tools. + +- The routing is transparent: `hermes chat`, `hermes sessions list`, `hermes version`, etc. 
all exec into the container under the hood +- All CLI flags are forwarded as-is +- If the container isn't running, the CLI retries briefly (5s with a spinner for interactive use, 10s silently for scripts) then fails with a clear error — no silent fallback +- For developers working on the hermes codebase, set `HERMES_DEV=1` to bypass container routing and run the local checkout directly + +Set `container.hostUsers` to create a `~/.hermes` symlink to the service state directory, so the host CLI and the container share sessions, config, and memories: + +```nix +services.hermes-agent = { + container.enable = true; + container.hostUsers = [ "your-username" ]; + addToSystemPackages = true; +}; +``` + +Users listed in `hostUsers` are automatically added to the `hermes` group for file permission access. + +**Podman users:** The NixOS service runs the container as root. Docker users get access via the `docker` group socket, but Podman's rootful containers require sudo. Grant passwordless sudo for your container runtime: + +```nix +security.sudo.extraRules = [{ + users = [ "your-username" ]; + commands = [{ + command = "/run/current-system/sw/bin/podman"; + options = [ "NOPASSWD" ]; + }]; +}]; +``` + +The CLI auto-detects when sudo is needed and uses it transparently. Without this, you'll need to run `sudo hermes chat` manually. 
+::: + ### Verify It Works After `nixos-rebuild switch`, check that the service is running: @@ -246,6 +281,7 @@ Run `nix build .#configKeys && cat result` to see every leaf config key extracte container = { image = "ubuntu:24.04"; backend = "docker"; + hostUsers = [ "your-username" ]; extraVolumes = [ "/home/user/projects:/projects:rw" ]; extraOptions = [ "--gpus" "all" ]; }; @@ -285,6 +321,7 @@ Quick reference for the most common things Nix users want to customize: | Mount host directories into container | `container.extraVolumes` | `[ "/data:/data:rw" ]` | | Pass GPU access to container | `container.extraOptions` | `[ "--gpus" "all" ]` | | Use Podman instead of Docker | `container.backend` | `"podman"` | +| Share state between host CLI and container | `container.hostUsers` | `[ "sidbin" ]` | | Add tools to the service PATH (native only) | `extraPackages` | `[ pkgs.pandoc pkgs.imagemagick ]` | | Use a custom base image | `container.image` | `"ubuntu:24.04"` | | Override the hermes package | `package` | `inputs.hermes-agent.packages.${system}.default.override { ... }` | @@ -518,6 +555,7 @@ When container mode is enabled, hermes runs inside a persistent Ubuntu container Host Container ──── ───────── /nix/store/...-hermes-agent-0.1.0 ──► /nix/store/... (ro) +~/.hermes -> /var/lib/hermes/.hermes (symlink bridge, per hostUsers) /var/lib/hermes/ ──► /data/ (rw) ├── current-package -> /nix/store/... (symlink, updated each rebuild) ├── .gc-root -> /nix/store/... (prevents nix-collect-garbage) @@ -526,6 +564,7 @@ Host Container │ ├── .env (merged from environment + environmentFiles) │ ├── config.yaml (Nix-generated, deep-merged by activation) │ ├── .managed (marker file) + │ ├── .container-mode (routing metadata: backend, exec_user, etc.) 
│ ├── state.db, sessions/, memories/ (runtime state) │ └── mcp-tokens/ (OAuth tokens for MCP servers) ├── home/ ──► /home/hermes (rw) @@ -698,6 +737,7 @@ nix build .#checks.x86_64-linux.config-roundtrip # merge script preserves use | `container.image` | `str` | `"ubuntu:24.04"` | Base image (pulled at runtime) | | `container.extraVolumes` | `listOf str` | `[]` | Extra volume mounts (`host:container:mode`) | | `container.extraOptions` | `listOf str` | `[]` | Extra args passed to `docker create` | +| `container.hostUsers` | `listOf str` | `[]` | Interactive users who get a `~/.hermes` symlink to the service stateDir and are auto-added to the `hermes` group | --- @@ -818,3 +858,5 @@ nix-store --query --roots $(docker exec hermes-agent readlink /data/current-pack | `hermes version` shows old version | Container not restarted | `systemctl restart hermes-agent` | | Permission denied on `/var/lib/hermes` | State dir is `0750 hermes:hermes` | Use `docker exec` or `sudo -u hermes` | | `nix-collect-garbage` removed hermes | GC root missing | Restart the service (preStart recreates the GC root) | +| `no container with name or ID "hermes-agent"` (Podman) | Podman rootful container not visible to regular user | Add passwordless sudo for podman (see [Container-aware CLI](#container-aware-cli) section) | +| `unable to find user hermes` | Container still starting (entrypoint hasn't created user yet) | Wait a few seconds and retry — the CLI retries automatically | diff --git a/website/docs/getting-started/quickstart.md b/website/docs/getting-started/quickstart.md index bd26f1eebb..9646fbcc9f 100644 --- a/website/docs/getting-started/quickstart.md +++ b/website/docs/getting-started/quickstart.md @@ -64,6 +64,10 @@ hermes setup # Or configure everything at once | **Vercel AI Gateway** | Vercel AI Gateway routing | Set `AI_GATEWAY_API_KEY` | | **Custom Endpoint** | VLLM, SGLang, Ollama, or any OpenAI-compatible API | Set base URL + API key | +:::caution Minimum context: 64K tokens 
+Hermes Agent requires a model with at least **64,000 tokens** of context. Models with smaller windows cannot maintain enough working memory for multi-step tool-calling workflows and will be rejected at startup. Most hosted models (Claude, GPT, Gemini, Qwen, DeepSeek) meet this easily. If you're running a local model, set its context size to at least 64K (e.g. `--ctx-size 65536` for llama.cpp or `-c 65536` for Ollama). +::: + :::tip You can switch providers at any time with `hermes model` — no code changes, no lock-in. When configuring a custom endpoint, Hermes will prompt for the context window size and auto-detect it when possible. See [Context Length Detection](../integrations/providers.md#context-length-detection) for details. ::: diff --git a/website/docs/integrations/index.md b/website/docs/integrations/index.md index 6dccc44e96..cfc82d41d1 100644 --- a/website/docs/integrations/index.md +++ b/website/docs/integrations/index.md @@ -82,7 +82,7 @@ Speech-to-text supports three providers: local Whisper (free, runs on-device), G Hermes runs as a gateway bot on 15+ messaging platforms, all configured through the same `gateway` subsystem: -- **[Telegram](/docs/user-guide/messaging/telegram)**, **[Discord](/docs/user-guide/messaging/discord)**, **[Slack](/docs/user-guide/messaging/slack)**, **[WhatsApp](/docs/user-guide/messaging/whatsapp)**, **[Signal](/docs/user-guide/messaging/signal)**, **[Matrix](/docs/user-guide/messaging/matrix)**, **[Mattermost](/docs/user-guide/messaging/mattermost)**, **[Email](/docs/user-guide/messaging/email)**, **[SMS](/docs/user-guide/messaging/sms)**, **[DingTalk](/docs/user-guide/messaging/dingtalk)**, **[Feishu/Lark](/docs/user-guide/messaging/feishu)**, **[WeCom](/docs/user-guide/messaging/wecom)**, **[Weixin](/docs/user-guide/messaging/weixin)**, **[BlueBubbles](/docs/user-guide/messaging/bluebubbles)**, **[Home Assistant](/docs/user-guide/messaging/homeassistant)**, **[Webhooks](/docs/user-guide/messaging/webhooks)** +- 
**[Telegram](/docs/user-guide/messaging/telegram)**, **[Discord](/docs/user-guide/messaging/discord)**, **[Slack](/docs/user-guide/messaging/slack)**, **[WhatsApp](/docs/user-guide/messaging/whatsapp)**, **[Signal](/docs/user-guide/messaging/signal)**, **[Matrix](/docs/user-guide/messaging/matrix)**, **[Mattermost](/docs/user-guide/messaging/mattermost)**, **[Email](/docs/user-guide/messaging/email)**, **[SMS](/docs/user-guide/messaging/sms)**, **[DingTalk](/docs/user-guide/messaging/dingtalk)**, **[Feishu/Lark](/docs/user-guide/messaging/feishu)**, **[WeCom](/docs/user-guide/messaging/wecom)**, **[WeCom Callback](/docs/user-guide/messaging/wecom-callback)**, **[Weixin](/docs/user-guide/messaging/weixin)**, **[BlueBubbles](/docs/user-guide/messaging/bluebubbles)**, **[Home Assistant](/docs/user-guide/messaging/homeassistant)**, **[Webhooks](/docs/user-guide/messaging/webhooks)** See the [Messaging Gateway overview](/docs/user-guide/messaging) for the platform comparison table and setup guide. diff --git a/website/docs/reference/environment-variables.md b/website/docs/reference/environment-variables.md index a548a6ff6d..bec5ff1c37 100644 --- a/website/docs/reference/environment-variables.md +++ b/website/docs/reference/environment-variables.md @@ -232,6 +232,15 @@ For cloud sandbox backends, persistence is filesystem-oriented. 
`TERMINAL_LIFETI | `WECOM_WEBSOCKET_URL` | Custom WebSocket URL (default: `wss://openws.work.weixin.qq.com`) | | `WECOM_ALLOWED_USERS` | Comma-separated WeCom user IDs allowed to message the bot | | `WECOM_HOME_CHANNEL` | WeCom chat ID for cron delivery and notifications | +| `WECOM_CALLBACK_CORP_ID` | WeCom enterprise Corp ID for callback self-built app | +| `WECOM_CALLBACK_CORP_SECRET` | Corp secret for the self-built app | +| `WECOM_CALLBACK_AGENT_ID` | Agent ID of the self-built app | +| `WECOM_CALLBACK_TOKEN` | Callback verification token | +| `WECOM_CALLBACK_ENCODING_AES_KEY` | AES key for callback encryption | +| `WECOM_CALLBACK_HOST` | Callback server bind address (default: `0.0.0.0`) | +| `WECOM_CALLBACK_PORT` | Callback server port (default: `8645`) | +| `WECOM_CALLBACK_ALLOWED_USERS` | Comma-separated user IDs for allowlist | +| `WECOM_CALLBACK_ALLOW_ALL_USERS` | Set `true` to allow all users without an allowlist | | `WEIXIN_ACCOUNT_ID` | Weixin account ID obtained via QR login through iLink Bot API | | `WEIXIN_TOKEN` | Weixin authentication token obtained via QR login through iLink Bot API | | `WEIXIN_BASE_URL` | Override Weixin iLink Bot API base URL (default: `https://ilinkai.weixin.qq.com`) | diff --git a/website/docs/reference/toolsets-reference.md b/website/docs/reference/toolsets-reference.md index 5516cfdfa5..96856552e0 100644 --- a/website/docs/reference/toolsets-reference.md +++ b/website/docs/reference/toolsets-reference.md @@ -103,6 +103,7 @@ Platform toolsets define the complete tool configuration for a deployment target | `hermes-dingtalk` | Same as `hermes-cli`. | | `hermes-feishu` | Same as `hermes-cli`. | | `hermes-wecom` | Same as `hermes-cli`. | +| `hermes-wecom-callback` | WeCom callback toolset — enterprise self-built app messaging (full access). | | `hermes-weixin` | Same as `hermes-cli`. | | `hermes-bluebubbles` | Same as `hermes-cli`. | | `hermes-homeassistant` | Same as `hermes-cli`. 
| diff --git a/website/docs/user-guide/configuration.md b/website/docs/user-guide/configuration.md index 7b735bbdee..9f7c9e2dd4 100644 --- a/website/docs/user-guide/configuration.md +++ b/website/docs/user-guide/configuration.md @@ -800,7 +800,7 @@ You can also change the reasoning effort at runtime with the `/reasoning` comman ## Tool-Use Enforcement -Some models (especially GPT-family) occasionally describe intended actions as text instead of making tool calls. Tool-use enforcement injects guidance that steers the model back to actually calling tools. +Some models occasionally describe intended actions as text instead of making tool calls ("I would run the tests..." instead of actually calling the terminal). Tool-use enforcement injects system prompt guidance that steers the model back to actually calling tools. ```yaml agent: @@ -809,12 +809,31 @@ agent: | Value | Behavior | |-------|----------| -| `"auto"` (default) | Enabled for GPT models (`gpt-`, `openai/gpt-`) and disabled for all others. | -| `true` | Always enabled for all models. | -| `false` | Always disabled. | -| `["gpt-", "o1-", "custom-model"]` | Enabled only for models whose name contains one of the listed substrings. | +| `"auto"` (default) | Enabled for models matching: `gpt`, `codex`, `gemini`, `gemma`, `grok`. Disabled for all others (Claude, DeepSeek, Qwen, etc.). | +| `true` | Always enabled, regardless of model. Useful if you notice your current model describing actions instead of performing them. | +| `false` | Always disabled, regardless of model. | +| `["gpt", "codex", "qwen", "llama"]` | Enabled only when the model name contains one of the listed substrings (case-insensitive). | -When enabled, the system prompt includes guidance reminding the model to make actual tool calls rather than describing what it would do. This is transparent to the user and has no effect on models that already use tools reliably. 
+### What it injects + +When enabled, three layers of guidance may be added to the system prompt: + +1. **General tool-use enforcement** (all matched models) — instructs the model to make tool calls immediately instead of describing intentions, keep working until the task is complete, and never end a turn with a promise of future action. + +2. **OpenAI execution discipline** (GPT and Codex models only) — additional guidance addressing GPT-specific failure modes: abandoning work on partial results, skipping prerequisite lookups, hallucinating instead of using tools, and declaring "done" without verification. + +3. **Google operational guidance** (Gemini and Gemma models only) — conciseness, absolute paths, parallel tool calls, and verify-before-edit patterns. + +These are transparent to the user and only affect the system prompt. Models that already use tools reliably (like Claude) don't need this guidance, which is why `"auto"` excludes them. + +### When to turn it on + +If you're using a model not in the default auto list and notice it frequently describes what it *would* do instead of doing it, set `tool_use_enforcement: true` or add the model substring to the list: + +```yaml +agent: + tool_use_enforcement: ["gpt", "codex", "gemini", "grok", "my-custom-model"] +``` ## TTS Configuration @@ -846,6 +865,7 @@ display: tool_progress: all # off | new | all | verbose tool_progress_command: false # Enable /verbose slash command in messaging gateway tool_progress_overrides: {} # Per-platform overrides (see below) + interim_assistant_messages: true # Gateway: send natural mid-turn assistant updates as separate messages skin: default # Built-in or custom CLI skin (see user-guide/features/skins) personality: "kawaii" # Legacy cosmetic field still surfaced in some summaries compact: false # Compact output mode (less whitespace) @@ -881,6 +901,8 @@ display: Platforms without an override fall back to the global `tool_progress` value. 
Valid platform keys: `telegram`, `discord`, `slack`, `signal`, `whatsapp`, `matrix`, `mattermost`, `email`, `sms`, `homeassistant`, `dingtalk`, `feishu`, `wecom`, `weixin`, `bluebubbles`. +`interim_assistant_messages` is gateway-only. When enabled, Hermes sends completed mid-turn assistant updates as separate chat messages. This is independent from `tool_progress` and does not require gateway streaming. + ## Privacy ```yaml @@ -971,6 +993,8 @@ streaming: When enabled, the bot sends a message on the first token, then progressively edits it as more tokens arrive. Platforms that don't support message editing (Signal, Email, Home Assistant) are auto-detected on the first attempt — streaming is gracefully disabled for that session with no flood of messages. +For separate natural mid-turn assistant updates without progressive token editing, set `display.interim_assistant_messages: true`. + **Overflow handling:** If the streamed text exceeds the platform's message length limit (~4096 chars), the current message is finalized and a new one starts automatically. :::note diff --git a/website/docs/user-guide/features/code-execution.md b/website/docs/user-guide/features/code-execution.md index 01ee862073..53668da901 100644 --- a/website/docs/user-guide/features/code-execution.md +++ b/website/docs/user-guide/features/code-execution.md @@ -153,7 +153,7 @@ When your script calls a function like `web_search("query")`: 3. The result is sent back over the socket 4. The function returns the parsed result -This means tool calls inside scripts behave identically to normal tool calls — same rate limits, same error handling, same capabilities. The only restriction is that `terminal()` is foreground-only (no `background`, `pty`, or `check_interval` parameters). +This means tool calls inside scripts behave identically to normal tool calls — same rate limits, same error handling, same capabilities. 
The only restriction is that `terminal()` is foreground-only (no `background` or `pty` parameters). ## Error Handling diff --git a/website/docs/user-guide/messaging/index.md b/website/docs/user-guide/messaging/index.md index 335c6530bc..f4131385e2 100644 --- a/website/docs/user-guide/messaging/index.md +++ b/website/docs/user-guide/messaging/index.md @@ -27,6 +27,7 @@ For the full voice feature set — including CLI microphone mode, spoken replies | DingTalk | — | — | — | — | — | ✅ | ✅ | | Feishu/Lark | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | WeCom | ✅ | ✅ | ✅ | — | — | ✅ | ✅ | +| WeCom Callback | — | — | — | — | — | — | — | | Weixin | ✅ | ✅ | ✅ | — | — | ✅ | ✅ | | BlueBubbles | — | ✅ | ✅ | — | ✅ | ✅ | — | @@ -51,6 +52,7 @@ flowchart TB dt[DingTalk] fs[Feishu/Lark] wc[WeCom] + wcb[WeCom Callback] wx[Weixin] bb[BlueBubbles] api["API Server
(OpenAI-compatible)"] @@ -75,6 +77,7 @@ flowchart TB dt --> store fs --> store wc --> store + wcb --> store wx --> store bb --> store api --> store @@ -178,6 +181,9 @@ EMAIL_ALLOWED_USERS=trusted@example.com,colleague@work.com MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c MATRIX_ALLOWED_USERS=@alice:matrix.org DINGTALK_ALLOWED_USERS=user-id-1 +FEISHU_ALLOWED_USERS=ou_xxxxxxxx,ou_yyyyyyyy +WECOM_ALLOWED_USERS=user-id-1,user-id-2 +WECOM_CALLBACK_ALLOWED_USERS=user-id-1,user-id-2 # Or allow GATEWAY_ALLOWED_USERS=123456789,987654321 @@ -360,6 +366,7 @@ Each platform has its own toolset: | DingTalk | `hermes-dingtalk` | Full tools including terminal | | Feishu/Lark | `hermes-feishu` | Full tools including terminal | | WeCom | `hermes-wecom` | Full tools including terminal | +| WeCom Callback | `hermes-wecom-callback` | Full tools including terminal | | Weixin | `hermes-weixin` | Full tools including terminal | | BlueBubbles | `hermes-bluebubbles` | Full tools including terminal | | API Server | `hermes` (default) | Full tools including terminal | @@ -380,6 +387,7 @@ Each platform has its own toolset: - [DingTalk Setup](dingtalk.md) - [Feishu/Lark Setup](feishu.md) - [WeCom Setup](wecom.md) +- [WeCom Callback Setup](wecom-callback.md) - [Weixin Setup (WeChat)](weixin.md) - [BlueBubbles Setup (iMessage)](bluebubbles.md) - [Open WebUI + API Server](open-webui.md) diff --git a/website/docs/user-guide/messaging/matrix.md b/website/docs/user-guide/messaging/matrix.md index 2c9bdb2291..ccde0740d6 100644 --- a/website/docs/user-guide/messaging/matrix.md +++ b/website/docs/user-guide/messaging/matrix.md @@ -344,9 +344,79 @@ pip install 'hermes-agent[matrix]' **Fix**: 1. Verify `libolm` is installed on your system (see the E2EE section above). 2. Make sure `MATRIX_ENCRYPTION=true` is set in your `.env`. -3. In your Matrix client (Element), go to the bot's profile → **Sessions** → verify/trust the bot's device. +3. 
In your Matrix client (Element), go to the bot's profile -> Sessions -> verify/trust the bot's device. 4. If the bot just joined an encrypted room, it can only decrypt messages sent *after* it joined. Older messages are inaccessible. +### Upgrading from a previous version with E2EE + +If you previously used Hermes with `MATRIX_ENCRYPTION=true` and are upgrading to +a version that uses the new SQLite-based crypto store, the bot's encryption +identity has changed. Your Matrix client (Element) may cache the old device keys +and refuse to share encryption sessions with the bot. + +**Symptoms**: The bot connects and shows "E2EE enabled" in the logs, but all +messages show "could not decrypt event" and the bot never responds. + +**What's happening**: The old encryption state (from the previous `matrix-nio` or +serialization-based `mautrix` backend) is incompatible with the new SQLite crypto +store. The bot creates a fresh encryption identity, but your Matrix client still +has the old keys cached and won't share the room's encryption session with a +device whose keys changed. This is a Matrix security feature -- clients treat +changed identity keys for the same device as suspicious. + +**Fix** (one-time migration): + +1. **Generate a new access token** to get a fresh device ID. The simplest way: + + ```bash + curl -X POST https://your-server/_matrix/client/v3/login \ + -H "Content-Type: application/json" \ + -d '{ + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "@hermes:your-server.org"}, + "password": "your-password", + "initial_device_display_name": "Hermes Agent" + }' + ``` + + Copy the new `access_token` and update `MATRIX_ACCESS_TOKEN` in `~/.hermes/.env`. + +2. **Delete old encryption state**: + + ```bash + rm -f ~/.hermes/platforms/matrix/store/crypto.db + rm -f ~/.hermes/platforms/matrix/store/crypto_store.* + ``` + +3. **Force your Matrix client to rotate the encryption session**. 
In Element, + open the DM room with the bot and type `/discardsession`. This forces Element + to create a new encryption session and share it with the bot's new device. + +4. **Restart the gateway**: + + ```bash + hermes gateway run + ``` + +5. **Send a new message**. The bot should decrypt and respond normally. + +:::note +After migration, messages sent *before* the upgrade cannot be decrypted -- the old +encryption keys are gone. This only affects the transition; new messages work +normally. +::: + +:::tip +**New installations are not affected.** This migration is only needed if you had +a working E2EE setup with a previous version of Hermes and are upgrading. + +**Why a new access token?** Each Matrix access token is bound to a specific device +ID. Reusing the same device ID with new encryption keys causes other Matrix +clients to distrust the device (they see changed identity keys as a potential +security breach). A new access token gets a new device ID with no stale key +history, so other clients trust it immediately. +::: + ### Sync issues / bot falls behind **Cause**: Long-running tool executions can delay the sync loop, or the homeserver is slow. diff --git a/website/docs/user-guide/messaging/wecom-callback.md b/website/docs/user-guide/messaging/wecom-callback.md new file mode 100644 index 0000000000..4662942769 --- /dev/null +++ b/website/docs/user-guide/messaging/wecom-callback.md @@ -0,0 +1,147 @@ +--- +sidebar_position: 15 +--- + +# WeCom Callback (Self-Built App) + +Connect Hermes to WeCom (Enterprise WeChat) as a self-built enterprise application using the callback/webhook model. + +:::info WeCom Bot vs WeCom Callback +Hermes supports two WeCom integration modes: +- **[WeCom Bot](wecom.md)** — bot-style, connects via WebSocket. Simpler setup, works in group chats. +- **WeCom Callback** (this page) — self-built app, receives encrypted XML callbacks. Shows as a first-class app in users' WeCom sidebar. Supports multi-corp routing. 
+::: + +## How It Works + +1. You register a self-built application in the WeCom Admin Console +2. WeCom pushes encrypted XML to your HTTP callback endpoint +3. Hermes decrypts the message, queues it for the agent +4. Immediately acknowledges (silent — nothing displayed to the user) +5. The agent processes the request (typically 3–30 minutes) +6. The reply is delivered proactively via the WeCom `message/send` API + +## Prerequisites + +- A WeCom enterprise account with admin access +- `aiohttp` and `httpx` Python packages (included in the default install) +- A publicly reachable server for the callback URL (or a tunnel like ngrok) + +## Setup + +### 1. Create a Self-Built App in WeCom + +1. Go to [WeCom Admin Console](https://work.weixin.qq.com/) → **Applications** → **Create App** +2. Note your **Corp ID** (shown at the top of the admin console) +3. In the app settings, create a **Corp Secret** +4. Note the **Agent ID** from the app's overview page +5. Under **Receive Messages**, configure the callback URL: + - URL: `http://YOUR_PUBLIC_IP:8645/wecom/callback` + - Token: Generate a random token (WeCom provides one) + - EncodingAESKey: Generate a key (WeCom provides one) + +### 2. Configure Environment Variables + +Add to your `.env` file: + +```bash +WECOM_CALLBACK_CORP_ID=your-corp-id +WECOM_CALLBACK_CORP_SECRET=your-corp-secret +WECOM_CALLBACK_AGENT_ID=1000002 +WECOM_CALLBACK_TOKEN=your-callback-token +WECOM_CALLBACK_ENCODING_AES_KEY=your-43-char-aes-key + +# Optional +WECOM_CALLBACK_HOST=0.0.0.0 +WECOM_CALLBACK_PORT=8645 +WECOM_CALLBACK_ALLOWED_USERS=user1,user2 +``` + +### 3. Start the Gateway + +```bash +hermes gateway start +``` + +The callback adapter starts an HTTP server on the configured port. WeCom will verify the callback URL via a GET request, then begin sending messages via POST. 
+ +## Configuration Reference + +Set these in `config.yaml` under `platforms.wecom_callback.extra`, or use environment variables: + +| Setting | Default | Description | +|---------|---------|-------------| +| `corp_id` | — | WeCom enterprise Corp ID (required) | +| `corp_secret` | — | Corp secret for the self-built app (required) | +| `agent_id` | — | Agent ID of the self-built app (required) | +| `token` | — | Callback verification token (required) | +| `encoding_aes_key` | — | 43-character AES key for callback encryption (required) | +| `host` | `0.0.0.0` | Bind address for the HTTP callback server | +| `port` | `8645` | Port for the HTTP callback server | +| `path` | `/wecom/callback` | URL path for the callback endpoint | + +## Multi-App Routing + +For enterprises running multiple self-built apps (e.g., across different departments or subsidiaries), configure the `apps` list in `config.yaml`: + +```yaml +platforms: + wecom_callback: + enabled: true + extra: + host: "0.0.0.0" + port: 8645 + apps: + - name: "dept-a" + corp_id: "ww_corp_a" + corp_secret: "secret-a" + agent_id: "1000002" + token: "token-a" + encoding_aes_key: "key-a-43-chars..." + - name: "dept-b" + corp_id: "ww_corp_b" + corp_secret: "secret-b" + agent_id: "1000003" + token: "token-b" + encoding_aes_key: "key-b-43-chars..." +``` + +Users are scoped by `corp_id:user_id` to prevent cross-corp collisions. When a user sends a message, the adapter records which app (corp) they belong to and routes replies through the correct app's access token. 
+ +## Access Control + +Restrict which users can interact with the app: + +```bash +# Allowlist specific users +WECOM_CALLBACK_ALLOWED_USERS=zhangsan,lisi,wangwu + +# Or allow all users +WECOM_CALLBACK_ALLOW_ALL_USERS=true +``` + +## Endpoints + +The adapter exposes: + +| Method | Path | Purpose | +|--------|------|---------| +| GET | `/wecom/callback` | URL verification handshake (WeCom sends this during setup) | +| POST | `/wecom/callback` | Encrypted message callback (WeCom sends user messages here) | +| GET | `/health` | Health check — returns `{"status": "ok"}` | + +## Encryption + +All callback payloads are encrypted with AES-CBC using the EncodingAESKey. The adapter handles: + +- **Inbound**: Decrypt XML payload, verify SHA1 signature +- **Outbound**: Replies sent via proactive API (not encrypted callback response) + +The crypto implementation is compatible with Tencent's official WXBizMsgCrypt SDK. + +## Limitations + +- **No streaming** — replies arrive as complete messages after the agent finishes +- **No typing indicators** — the callback model doesn't support typing status +- **Text only** — currently supports text messages; image/file/voice not yet implemented +- **Response latency** — agent sessions take 3–30 minutes; users see the reply when processing completes diff --git a/website/sidebars.ts b/website/sidebars.ts index 52fd589c7f..973cfe89cb 100644 --- a/website/sidebars.ts +++ b/website/sidebars.ts @@ -108,6 +108,7 @@ const sidebars: SidebarsConfig = { 'user-guide/messaging/dingtalk', 'user-guide/messaging/feishu', 'user-guide/messaging/wecom', + 'user-guide/messaging/wecom-callback', 'user-guide/messaging/weixin', 'user-guide/messaging/bluebubbles', 'user-guide/messaging/open-webui', @@ -175,6 +176,7 @@ const sidebars: SidebarsConfig = { items: [ 'developer-guide/adding-tools', 'developer-guide/adding-providers', + 'developer-guide/adding-platform-adapters', 'developer-guide/memory-provider-plugin', 'developer-guide/context-engine-plugin', 
'developer-guide/creating-skills',