diff --git a/cli.py b/cli.py index 61cb8d966c..19f1ad6ee1 100755 --- a/cli.py +++ b/cli.py @@ -31,19 +31,9 @@ os.environ["HERMES_QUIET"] = "1" # Our own modules import yaml -# prompt_toolkit for fixed input area TUI +# prompt_toolkit for input only (readline replacement) +from prompt_toolkit import PromptSession from prompt_toolkit.history import FileHistory -from prompt_toolkit.styles import Style as PTStyle -from prompt_toolkit.patch_stdout import patch_stdout -from prompt_toolkit.application import Application -from prompt_toolkit.layout import Layout, HSplit, Window, FormattedTextControl, ConditionalContainer -from prompt_toolkit.layout.processors import Processor, Transformation, PasswordProcessor, ConditionalProcessor -from prompt_toolkit.filters import Condition -from prompt_toolkit.layout.dimension import Dimension -from prompt_toolkit.layout.menus import CompletionsMenu -from prompt_toolkit.widgets import TextArea -from prompt_toolkit.key_binding import KeyBindings -from prompt_toolkit import print_formatted_text as _pt_print from prompt_toolkit.formatted_text import ANSI as _PT_ANSI import threading import queue @@ -668,36 +658,8 @@ _DIM = "\033[2m" _RST = "\033[0m" def _cprint(text: str): - """Print ANSI-colored text through prompt_toolkit's native renderer. - - Raw ANSI escapes written via print() are swallowed by patch_stdout's - StdoutProxy. Routing through print_formatted_text(ANSI(...)) lets - prompt_toolkit parse the escapes and render real colors. - """ - _pt_print(_PT_ANSI(text)) - - -class ChatConsole: - """Rich Console adapter for prompt_toolkit's patch_stdout context. - - Captures Rich's rendered ANSI output and routes it through _cprint - so colors and markup render correctly inside the interactive chat loop. - Drop-in replacement for Rich Console — just pass this to any function - that expects a console.print() interface. - """ - - def __init__(self): - from io import StringIO - self._buffer = StringIO() - self._inner = Console(file=self._buffer, force_terminal=True, highlight=False) - - def print(self, *args, **kwargs): - self._buffer.seek(0) - self._buffer.truncate() - self._inner.print(*args, **kwargs) - output = self._buffer.getvalue() - for line in output.rstrip("\n").split("\n"): - _cprint(line) + """Print ANSI-colored text to stdout.""" + print(text) # ASCII Art - HERMES-AGENT logo (full width, single line - requires ~95 char terminal) HERMES_AGENT_LOGO = """[bold #FFD700]██╗ ██╗███████╗██████╗ ███╗ ███╗███████╗███████╗ █████╗ ██████╗ ███████╗███╗ ██╗████████╗[/] @@ -1157,7 +1119,7 @@ class HermesCLI: # Agent will be initialized on first use self.agent: Optional[AIAgent] = None - self._app = None # prompt_toolkit Application (set in run()) + self._app = None # retained for backward compat (no longer used by TUI) # Conversation state self.conversation_history: List[Dict[str, Any]] = [] @@ -1185,15 +1147,11 @@ class HermesCLI: # History file for persistent input recall across sessions self._history_file = Path.home() / ".hermes_history" - self._last_invalidate: float = 0.0 # throttle UI repaints + self._last_invalidate: float = 0.0 def _invalidate(self, min_interval: float = 0.25) -> None: - """Throttled UI repaint — prevents terminal blinking on slow/SSH connections.""" - import time as _time - now = _time.monotonic() - if hasattr(self, "_app") and self._app and (now - self._last_invalidate) >= min_interval: - self._last_invalidate = now - self._app.invalidate() + """No-op — retained for callback compatibility.""" + pass def _normalize_model_for_provider(self, resolved_provider: str) -> bool: """Strip provider prefixes and swap the default model for Codex. @@ -1385,6 +1343,7 @@ class HermesCLI: platform="cli", session_db=self._session_db, clarify_callback=self._clarify_callback, + stream_delta_callback=self._stream_delta, honcho_session_key=self.session_id, fallback_model=self._fallback_model, ) @@ -2313,7 +2272,7 @@ class HermesCLI: def _handle_skills_command(self, cmd: str): """Handle /skills slash command — delegates to hermes_cli.skills_hub.""" from hermes_cli.skills_hub import handle_skills_slash - handle_skills_slash(cmd, ChatConsole()) + handle_skills_slash(cmd, self.console) def _show_gateway_status(self): """Show status of the gateway and connected messaging platforms.""" @@ -2403,47 +2362,10 @@ class HermesCLI: self.agent.flush_memories(self.conversation_history) except Exception: pass - # Clear terminal screen. Inside the TUI, Rich's console.clear() - # goes through patch_stdout's StdoutProxy which swallows the - # screen-clear escape sequences. Use prompt_toolkit's output - # object directly to actually clear the terminal. - if self._app: - out = self._app.output - out.erase_screen() - out.cursor_goto(0, 0) - out.flush() - else: - self.console.clear() - # Reset conversation + self.console.clear() self.conversation_history = [] - # Show fresh banner. Inside the TUI we must route Rich output - # through ChatConsole (which uses prompt_toolkit's native ANSI - # renderer) instead of self.console (which writes raw to stdout - # and gets mangled by patch_stdout). - if self._app: - cc = ChatConsole() - term_w = shutil.get_terminal_size().columns - if self.compact or term_w < 80: - cc.print(_build_compact_banner()) - else: - tools = get_tool_definitions(enabled_toolsets=self.enabled_toolsets, quiet_mode=True) - cwd = os.getenv("TERMINAL_CWD", os.getcwd()) - ctx_len = None - if hasattr(self, 'agent') and self.agent and hasattr(self.agent, 'context_compressor'): - ctx_len = self.agent.context_compressor.context_length - build_welcome_banner( - console=cc, - model=self.model, - cwd=cwd, - tools=tools, - enabled_toolsets=self.enabled_toolsets, - session_id=self.session_id, - context_length=ctx_len, - ) - _cprint(" ✨ (◕‿◕)✨ Fresh start! Screen cleared and conversation reset.\n") - else: - self.show_banner() - print(" ✨ (◕‿◕)✨ Fresh start! Screen cleared and conversation reset.\n") + self.show_banner() + print(" ✨ (◕‿◕)✨ Fresh start! Screen cleared and conversation reset.\n") elif cmd_lower == "/history": self.show_history() elif cmd_lower.startswith("/title"): @@ -2904,6 +2826,11 @@ class HermesCLI: except Exception as e: print(f" ❌ MCP reload failed: {e}") + def _stream_delta(self, text: str) -> None: + """Write streaming token directly to stdout.""" + sys.stdout.write(text) + sys.stdout.flush() + def _clarify_callback(self, question, choices): """ Platform callback for the clarify tool. Called from the agent thread. @@ -3041,152 +2968,53 @@ class HermesCLI: self._approval_deadline = 0 self._invalidate() def chat(self, message, images: list = None) -> Optional[str]: - """ - Send a message to the agent and get a response. - - Handles streaming output, interrupt detection (user typing while agent - is working), and re-queueing of interrupted messages. - - Uses a dedicated _interrupt_queue (separate from _pending_input) to avoid - race conditions between the process_loop and interrupt monitoring. Messages - typed while the agent is running go to _interrupt_queue; messages typed while - idle go to _pending_input. - - Args: - message: The user's message (str or multimodal content list) - images: Optional list of Path objects for attached images - - Returns: - The agent's response, or None on error - """ - # Refresh provider credentials if needed (handles key rotation transparently) + """Send a message and get a response. Runs synchronously — streaming + tokens go directly to stdout via stream_delta_callback.""" if not self._ensure_runtime_credentials(): return None - - # Initialize agent if needed if not self._init_agent(): return None - - # Pre-process images through the vision tool (Gemini Flash) so the - # main model receives text descriptions instead of raw base64 image - # content — works with any model, not just vision-capable ones. + if images: message = self._preprocess_images_with_vision( - message if isinstance(message, str) else "", images + message if isinstance(message, str) else "", images) + + self.conversation_history.append({"role": "user", "content": message}) + + w = shutil.get_terminal_size().columns + print(f"{_GOLD}{'─' * w}{_RST}", flush=True) + + try: + result = self.agent.run_conversation( + user_message=message, + conversation_history=self.conversation_history[:-1], + task_id=self.session_id, ) - # Add user message to history - self.conversation_history.append({"role": "user", "content": message}) - - w = shutil.get_terminal_size().columns - _cprint(f"{_GOLD}{'─' * w}{_RST}") - print(flush=True) - - try: - # Run the conversation with interrupt monitoring - result = None - - def run_agent(): - nonlocal result - result = self.agent.run_conversation( - user_message=message, - conversation_history=self.conversation_history[:-1], # Exclude the message we just added - task_id=self.session_id, - ) - - # Start agent in background thread - agent_thread = threading.Thread(target=run_agent) - agent_thread.start() - - # Monitor the dedicated interrupt queue while the agent runs. - # _interrupt_queue is separate from _pending_input, so process_loop - # and chat() never compete for the same queue. - # When a clarify question is active, user input is handled entirely - # by the Enter key binding (routed to the clarify response queue), - # so we skip interrupt processing to avoid stealing that input. - interrupt_msg = None - while agent_thread.is_alive(): - if hasattr(self, '_interrupt_queue'): - try: - interrupt_msg = self._interrupt_queue.get(timeout=0.1) - if interrupt_msg: - # If clarify is active, the Enter handler routes - # input directly; this queue shouldn't have anything. - # But if it does (race condition), don't interrupt. - if self._clarify_state or self._clarify_freetext: - continue - print(f"\n⚡ New message detected, interrupting...") - self.agent.interrupt(interrupt_msg) - break - except queue.Empty: - pass # Queue empty or timeout, continue waiting - else: - # Fallback for non-interactive mode (e.g., single-query) - agent_thread.join(0.1) - - agent_thread.join() # Ensure agent thread completes - - # Drain any remaining agent output still in the StdoutProxy - # buffer so tool/status lines render ABOVE our response box. - # The flush pushes data into the renderer queue; the short - # sleep lets the renderer actually paint it before we draw. - import time as _time - sys.stdout.flush() - _time.sleep(0.15) - - # Update history with full conversation self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history - - # Get the final response response = result.get("final_response", "") if result else "" - - # Handle failed results (e.g., non-retryable errors like invalid model) - if result and result.get("failed") and not response: - error_detail = result.get("error", "Unknown error") - response = f"Error: {error_detail}" - - # Handle interrupt - check if we were interrupted - pending_message = None - if result and result.get("interrupted"): - pending_message = result.get("interrupt_message") or interrupt_msg - # Add indicator that we were interrupted - if response and pending_message: - response = response + "\n\n---\n_[Interrupted - processing new message]_" - - if response: - w = shutil.get_terminal_size().columns - label = " ⚕ Hermes " - fill = w - 2 - len(label) # 2 for ╭ and ╮ - top = f"{_GOLD}╭─{label}{'─' * max(fill - 1, 0)}╮{_RST}" - bot = f"{_GOLD}╰{'─' * (w - 2)}╯{_RST}" - # Render box + response as a single _cprint call so - # nothing can interleave between the box borders. - _cprint(f"\n{top}\n{response}\n\n{bot}") - - # Play terminal bell when agent finishes (if enabled). - # Works over SSH — the bell propagates to the user's terminal. + if result and result.get("failed") and not response: + response = f"Error: {result.get('error', 'Unknown error')}" + + # If streaming was active, tokens were already printed to stdout. + # If not (codex path, non-streaming fallback), print the response. + if response and not self.agent.stream_delta_callback: + print(f"\n{response}") + print(flush=True) + w = shutil.get_terminal_size().columns + print(f"{_GOLD}{'─' * w}{_RST}") + if self.bell_on_complete: sys.stdout.write("\a") sys.stdout.flush() - - # Combine all interrupt messages (user may have typed multiple while waiting) - # and re-queue as one prompt for process_loop - if pending_message and hasattr(self, '_pending_input'): - all_parts = [pending_message] - while not self._interrupt_queue.empty(): - try: - extra = self._interrupt_queue.get_nowait() - if extra: - all_parts.append(extra) - except queue.Empty: - break - combined = "\n".join(all_parts) - print(f"\n📨 Queued: '{combined[:50]}{'...' if len(combined) > 50 else ''}'") - self._pending_input.put(combined) - + return response - + except KeyboardInterrupt: + if self.agent: + self.agent.interrupt() + print("\n⚡ Interrupted") + return None except Exception as e: print(f"Error: {e}") return None @@ -3218,11 +3046,11 @@ class HermesCLI: print("Goodbye! ⚕") def run(self): - """Run the interactive CLI loop with persistent input at bottom.""" + """Run the interactive CLI loop. Uses PromptSession for input, plain + stdout for output. Streaming tokens go directly to stdout — no TUI + framework, no proxy, no layout. Copy/paste and scrolling work natively.""" self.show_banner() - # If resuming a session, load history and display it immediately - # so the user has context before typing their first message. if self._resumed: if self._preload_resumed_session(): self._display_resumed_history() @@ -3230,780 +3058,72 @@ class HermesCLI: self.console.print("[#FFF8DC]Welcome to Hermes Agent! Type your message or /help for commands.[/]") self.console.print() - # State for async operation self._agent_running = False - self._pending_input = queue.Queue() # For normal input (commands + new queries) - self._interrupt_queue = queue.Queue() # For messages typed while agent is running self._should_exit = False - self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit - - # Clarify tool state: interactive question/answer with the user. - # When the agent calls the clarify tool, _clarify_state is set and - # the prompt_toolkit UI switches to a selection mode. - self._clarify_state = None # dict with question, choices, selected, response_queue - self._clarify_freetext = False # True when user chose "Other" and is typing - self._clarify_deadline = 0 # monotonic timestamp when the clarify times out - - # Sudo password prompt state (similar mechanism to clarify) - self._sudo_state = None # dict with response_queue when active + self._last_ctrl_c_time = 0 + self._clarify_state = None + self._clarify_freetext = False + self._clarify_deadline = 0 + self._sudo_state = None self._sudo_deadline = 0 - - # Dangerous command approval state (similar mechanism to clarify) - self._approval_state = None # dict with command, description, choices, selected, response_queue + self._approval_state = None self._approval_deadline = 0 - - # Clipboard image attachments (paste images into the CLI) self._attached_images: list[Path] = [] self._image_counter = 0 - # Register callbacks so terminal_tool prompts route through our UI set_sudo_password_callback(self._sudo_password_callback) set_approval_callback(self._approval_callback) - - # Key bindings for the input area - kb = KeyBindings() - - @kb.add('enter') - def handle_enter(event): - """Handle Enter key - submit input. - - Routes to the correct queue based on active UI state: - - Sudo password prompt: password goes to sudo response queue - - Approval selection: selected choice goes to approval response queue - - Clarify freetext mode: answer goes to the clarify response queue - - Clarify choice mode: selected choice goes to the clarify response queue - - Agent running: goes to _interrupt_queue (chat() monitors this) - - Agent idle: goes to _pending_input (process_loop monitors this) - Commands (starting with /) always go to _pending_input so they're - handled as commands, not sent as interrupt text to the agent. - """ - # --- Sudo password prompt: submit the typed password --- - if self._sudo_state: - text = event.app.current_buffer.text - self._sudo_state["response_queue"].put(text) - self._sudo_state = None - event.app.current_buffer.reset() - event.app.invalidate() - return - # --- Approval selection: confirm the highlighted choice --- - if self._approval_state: - state = self._approval_state - selected = state["selected"] - choices = state["choices"] - if 0 <= selected < len(choices): - state["response_queue"].put(choices[selected]) - self._approval_state = None - event.app.invalidate() - return - - # --- Clarify freetext mode: user typed their own answer --- - if self._clarify_freetext and self._clarify_state: - text = event.app.current_buffer.text.strip() - if text: - self._clarify_state["response_queue"].put(text) - self._clarify_state = None - self._clarify_freetext = False - event.app.current_buffer.reset() - event.app.invalidate() - return - - # --- Clarify choice mode: confirm the highlighted selection --- - if self._clarify_state and not self._clarify_freetext: - state = self._clarify_state - selected = state["selected"] - choices = state.get("choices") or [] - if selected < len(choices): - state["response_queue"].put(choices[selected]) - self._clarify_state = None - event.app.invalidate() - else: - # "Other" selected → switch to freetext - self._clarify_freetext = True - event.app.invalidate() - return - - # --- Normal input routing --- - text = event.app.current_buffer.text.strip() - has_images = bool(self._attached_images) - if text or has_images: - # Snapshot and clear attached images - images = list(self._attached_images) - self._attached_images.clear() - event.app.invalidate() - # Bundle text + images as a tuple when images are present - payload = (text, images) if images else text - if self._agent_running and not (text and text.startswith("/")): - self._interrupt_queue.put(payload) - else: - self._pending_input.put(payload) - event.app.current_buffer.reset(append_to_history=True) - - @kb.add('escape', 'enter') - def handle_alt_enter(event): - """Alt+Enter inserts a newline for multi-line input.""" - event.current_buffer.insert_text('\n') - - @kb.add('c-j') - def handle_ctrl_enter(event): - """Ctrl+Enter (c-j) inserts a newline. Most terminals send c-j for Ctrl+Enter.""" - event.current_buffer.insert_text('\n') - - # --- Clarify tool: arrow-key navigation for multiple-choice questions --- - - @kb.add('up', filter=Condition(lambda: bool(self._clarify_state) and not self._clarify_freetext)) - def clarify_up(event): - """Move selection up in clarify choices.""" - if self._clarify_state: - self._clarify_state["selected"] = max(0, self._clarify_state["selected"] - 1) - event.app.invalidate() - - @kb.add('down', filter=Condition(lambda: bool(self._clarify_state) and not self._clarify_freetext)) - def clarify_down(event): - """Move selection down in clarify choices.""" - if self._clarify_state: - choices = self._clarify_state.get("choices") or [] - max_idx = len(choices) # last index is the "Other" option - self._clarify_state["selected"] = min(max_idx, self._clarify_state["selected"] + 1) - event.app.invalidate() - - # --- Dangerous command approval: arrow-key navigation --- - - @kb.add('up', filter=Condition(lambda: bool(self._approval_state))) - def approval_up(event): - if self._approval_state: - self._approval_state["selected"] = max(0, self._approval_state["selected"] - 1) - event.app.invalidate() - - @kb.add('down', filter=Condition(lambda: bool(self._approval_state))) - def approval_down(event): - if self._approval_state: - max_idx = len(self._approval_state["choices"]) - 1 - self._approval_state["selected"] = min(max_idx, self._approval_state["selected"] + 1) - event.app.invalidate() - - # --- History navigation: up/down browse history in normal input mode --- - # The TextArea is multiline, so by default up/down only move the cursor. - # Buffer.auto_up/auto_down handle both: cursor movement when multi-line, - # history browsing when on the first/last line (or single-line input). - _normal_input = Condition( - lambda: not self._clarify_state and not self._approval_state and not self._sudo_state - ) - - @kb.add('up', filter=_normal_input) - def history_up(event): - """Up arrow: browse history when on first line, else move cursor up.""" - event.app.current_buffer.auto_up(count=event.arg) - - @kb.add('down', filter=_normal_input) - def history_down(event): - """Down arrow: browse history when on last line, else move cursor down.""" - event.app.current_buffer.auto_down(count=event.arg) - - @kb.add('c-c') - def handle_ctrl_c(event): - """Handle Ctrl+C - cancel interactive prompts, interrupt agent, or exit. - - Priority: - 1. Cancel active sudo/approval/clarify prompt - 2. Interrupt the running agent (first press) - 3. Force exit (second press within 2s, or when idle) - """ - import time as _time - now = _time.time() - - # Cancel sudo prompt - if self._sudo_state: - self._sudo_state["response_queue"].put("") - self._sudo_state = None - event.app.current_buffer.reset() - event.app.invalidate() - return - - # Cancel approval prompt (deny) - if self._approval_state: - self._approval_state["response_queue"].put("deny") - self._approval_state = None - event.app.invalidate() - return - - # Cancel clarify prompt - if self._clarify_state: - self._clarify_state["response_queue"].put( - "The user cancelled. Use your best judgement to proceed." - ) - self._clarify_state = None - self._clarify_freetext = False - event.app.current_buffer.reset() - event.app.invalidate() - return - - if self._agent_running and self.agent: - if now - self._last_ctrl_c_time < 2.0: - print("\n⚡ Force exiting...") - self._should_exit = True - event.app.exit() - return - - self._last_ctrl_c_time = now - print("\n⚡ Interrupting agent... (press Ctrl+C again to force exit)") - self.agent.interrupt() - else: - # If there's text or images, clear them (like bash). - # If everything is already empty, exit. - if event.app.current_buffer.text or self._attached_images: - event.app.current_buffer.reset() - self._attached_images.clear() - event.app.invalidate() - else: - self._should_exit = True - event.app.exit() - - @kb.add('c-d') - def handle_ctrl_d(event): - """Handle Ctrl+D - exit.""" - self._should_exit = True - event.app.exit() - - from prompt_toolkit.keys import Keys - - @kb.add(Keys.BracketedPaste, eager=True) - def handle_paste(event): - """Handle terminal paste — detect clipboard images. - - When the terminal supports bracketed paste, Ctrl+V / Cmd+V - triggers this with the pasted text. We also check the - clipboard for an image on every paste event. - """ - pasted_text = event.data or "" - if self._try_attach_clipboard_image(): - event.app.invalidate() - if pasted_text: - event.current_buffer.insert_text(pasted_text) - - @kb.add('c-v') - def handle_ctrl_v(event): - """Fallback image paste for terminals without bracketed paste. - - On Linux terminals (GNOME Terminal, Konsole, etc.), Ctrl+V - sends raw byte 0x16 instead of triggering a paste. This - binding catches that and checks the clipboard for images. - On terminals that DO intercept Ctrl+V for paste (macOS - Terminal, iTerm2, VSCode, Windows Terminal), the bracketed - paste handler fires instead and this binding never triggers. - """ - if self._try_attach_clipboard_image(): - event.app.invalidate() - - @kb.add('escape', 'v') - def handle_alt_v(event): - """Alt+V — paste image from clipboard. - - Alt key combos pass through all terminal emulators (sent as - ESC + key), unlike Ctrl+V which terminals intercept for text - paste. This is the reliable way to attach clipboard images - on WSL2, VSCode, and any terminal over SSH where Ctrl+V - can't reach the application for image-only clipboard. - """ - if self._try_attach_clipboard_image(): - event.app.invalidate() - else: - # No image found — show a hint - pass # silent when no image (avoid noise on accidental press) - - # Dynamic prompt: shows Hermes symbol when agent is working, - # or answer prompt when clarify freetext mode is active. - cli_ref = self - - def get_prompt(): - if cli_ref._sudo_state: - return [('class:sudo-prompt', '🔐 ❯ ')] - if cli_ref._approval_state: - return [('class:prompt-working', '⚠ ❯ ')] - if cli_ref._clarify_freetext: - return [('class:clarify-selected', '✎ ❯ ')] - if cli_ref._clarify_state: - return [('class:prompt-working', '? ❯ ')] - if cli_ref._agent_running: - return [('class:prompt-working', '⚕ ❯ ')] - return [('class:prompt', '❯ ')] - - # Create the input area with multiline (shift+enter), autocomplete, and paste handling - input_area = TextArea( - height=Dimension(min=1, max=8, preferred=1), - prompt=get_prompt, - style='class:input-area', - multiline=True, - wrap_lines=True, + from hermes_cli.commands import SlashCommandCompleter + session = PromptSession( history=FileHistory(str(self._history_file)), completer=SlashCommandCompleter(skill_commands_provider=lambda: _skill_commands), - complete_while_typing=True, ) - # Dynamic height: accounts for both explicit newlines AND visual - # wrapping of long lines so the input area always fits its content. - # The prompt characters ("❯ " etc.) consume ~4 columns. - def _input_height(): - try: - doc = input_area.buffer.document - available_width = shutil.get_terminal_size().columns - 4 # subtract prompt width - if available_width < 10: - available_width = 40 - visual_lines = 0 - for line in doc.lines: - # Each logical line takes at least 1 visual row; long lines wrap - if len(line) == 0: - visual_lines += 1 - else: - visual_lines += max(1, -(-len(line) // available_width)) # ceil division - return min(max(visual_lines, 1), 8) - except Exception: - return 1 - - input_area.window.height = _input_height - - # Paste collapsing: detect large pastes and save to temp file - _paste_counter = [0] - _prev_text_len = [0] - - def _on_text_changed(buf): - """Detect large pastes and collapse them to a file reference.""" - text = buf.text - line_count = text.count('\n') - chars_added = len(text) - _prev_text_len[0] - _prev_text_len[0] = len(text) - # Heuristic: a real paste adds many characters at once (not just a - # single newline from Alt+Enter) AND the result has 5+ lines. - if line_count >= 5 and chars_added > 1 and not text.startswith('/'): - _paste_counter[0] += 1 - # Save to temp file - paste_dir = Path(os.path.expanduser("~/.hermes/pastes")) - paste_dir.mkdir(parents=True, exist_ok=True) - paste_file = paste_dir / f"paste_{_paste_counter[0]}_{datetime.now().strftime('%H%M%S')}.txt" - paste_file.write_text(text, encoding="utf-8") - # Replace buffer with compact reference - buf.text = f"[Pasted text #{_paste_counter[0]}: {line_count + 1} lines → {paste_file}]" - buf.cursor_position = len(buf.text) - - input_area.buffer.on_text_changed += _on_text_changed - - # --- Input processors for password masking and inline placeholder --- - - # Mask input with '*' when the sudo password prompt is active - input_area.control.input_processors.append( - ConditionalProcessor( - PasswordProcessor(), - filter=Condition(lambda: bool(cli_ref._sudo_state)), - ) - ) - - class _PlaceholderProcessor(Processor): - """Render grayed-out placeholder text inside the input when empty.""" - def __init__(self, get_text): - self._get_text = get_text - - def apply_transformation(self, ti): - if not ti.document.text and ti.lineno == 0: - text = self._get_text() - if text: - # Append after existing fragments (preserves the ❯ prompt) - return Transformation(fragments=ti.fragments + [('class:placeholder', text)]) - return Transformation(fragments=ti.fragments) - - def _get_placeholder(): - if cli_ref._sudo_state: - return "type password (hidden), Enter to skip" - if cli_ref._approval_state: - return "" - if cli_ref._clarify_state: - return "" - if cli_ref._agent_running: - return "type a message + Enter to interrupt, Ctrl+C to cancel" - return "" - - input_area.control.input_processors.append(_PlaceholderProcessor(_get_placeholder)) - - # Hint line above input: shown only for interactive prompts that need - # extra instructions (sudo countdown, approval navigation, clarify). - # The agent-running interrupt hint is now an inline placeholder above. - def get_hint_text(): - import time as _time - - if cli_ref._sudo_state: - remaining = max(0, int(cli_ref._sudo_deadline - _time.monotonic())) - return [ - ('class:hint', ' password hidden · Enter to skip'), - ('class:clarify-countdown', f' ({remaining}s)'), - ] - - if cli_ref._approval_state: - remaining = max(0, int(cli_ref._approval_deadline - _time.monotonic())) - return [ - ('class:hint', ' ↑/↓ to select, Enter to confirm'), - ('class:clarify-countdown', f' ({remaining}s)'), - ] - - if cli_ref._clarify_state: - remaining = max(0, int(cli_ref._clarify_deadline - _time.monotonic())) - countdown = f' ({remaining}s)' if cli_ref._clarify_deadline else '' - if cli_ref._clarify_freetext: - return [ - ('class:hint', ' type your answer and press Enter'), - ('class:clarify-countdown', countdown), - ] - return [ - ('class:hint', ' ↑/↓ to select, Enter to confirm'), - ('class:clarify-countdown', countdown), - ] - - return [] - - def get_hint_height(): - if cli_ref._sudo_state or cli_ref._approval_state or cli_ref._clarify_state: - return 1 - # Keep a 1-line spacer while agent runs so output doesn't push - # right up against the top rule of the input area - return 1 if cli_ref._agent_running else 0 - - spacer = Window( - content=FormattedTextControl(get_hint_text), - height=get_hint_height, - ) - - # --- Clarify tool: dynamic display widget for questions + choices --- - - def _get_clarify_display(): - """Build styled text for the clarify question/choices panel.""" - state = cli_ref._clarify_state - if not state: - return [] - - question = state["question"] - choices = state.get("choices") or [] - selected = state.get("selected", 0) - - lines = [] - # Box top border - lines.append(('class:clarify-border', '╭─ ')) - lines.append(('class:clarify-title', 'Hermes needs your input')) - lines.append(('class:clarify-border', ' ─────────────────────────────╮\n')) - lines.append(('class:clarify-border', '│\n')) - - # Question text - lines.append(('class:clarify-border', '│ ')) - lines.append(('class:clarify-question', question)) - lines.append(('', '\n')) - lines.append(('class:clarify-border', '│\n')) - - if choices: - # Multiple-choice mode: show selectable options - for i, choice in enumerate(choices): - lines.append(('class:clarify-border', '│ ')) - if i == selected and not cli_ref._clarify_freetext: - lines.append(('class:clarify-selected', f'❯ {choice}')) - else: - lines.append(('class:clarify-choice', f' {choice}')) - lines.append(('', '\n')) - - # "Other" option (5th line, only shown when choices exist) - other_idx = len(choices) - lines.append(('class:clarify-border', '│ ')) - if selected == other_idx and not cli_ref._clarify_freetext: - lines.append(('class:clarify-selected', '❯ Other (type your answer)')) - elif cli_ref._clarify_freetext: - lines.append(('class:clarify-active-other', '❯ Other (type below)')) - else: - lines.append(('class:clarify-choice', ' Other (type your answer)')) - lines.append(('', '\n')) - - lines.append(('class:clarify-border', '│\n')) - lines.append(('class:clarify-border', '╰──────────────────────────────────────────────────╯\n')) - return lines - - clarify_widget = ConditionalContainer( - Window( - FormattedTextControl(_get_clarify_display), - wrap_lines=True, - ), - filter=Condition(lambda: cli_ref._clarify_state is not None), - ) - - # --- Sudo password: display widget --- - - def _get_sudo_display(): - state = cli_ref._sudo_state - if not state: - return [] - lines = [] - lines.append(('class:sudo-border', '╭─ ')) - lines.append(('class:sudo-title', '🔐 Sudo Password Required')) - lines.append(('class:sudo-border', ' ──────────────────────────╮\n')) - lines.append(('class:sudo-border', '│\n')) - lines.append(('class:sudo-border', '│ ')) - lines.append(('class:sudo-text', 'Enter password below (hidden), or press Enter to skip')) - lines.append(('', '\n')) - lines.append(('class:sudo-border', '│\n')) - lines.append(('class:sudo-border', '╰──────────────────────────────────────────────────╯\n')) - return lines - - sudo_widget = ConditionalContainer( - Window( - FormattedTextControl(_get_sudo_display), - wrap_lines=True, - ), - filter=Condition(lambda: cli_ref._sudo_state is not None), - ) - - # --- Dangerous command approval: display widget --- - - def _get_approval_display(): - state = cli_ref._approval_state - if not state: - return [] - command = state["command"] - description = state["description"] - choices = state["choices"] - selected = state.get("selected", 0) - - cmd_display = command[:70] + '...' if len(command) > 70 else command - choice_labels = { - "once": "Allow once", - "session": "Allow for this session", - "always": "Add to permanent allowlist", - "deny": "Deny", - } - - lines = [] - lines.append(('class:approval-border', '╭─ ')) - lines.append(('class:approval-title', '⚠️ Dangerous Command')) - lines.append(('class:approval-border', ' ───────────────────────────────╮\n')) - lines.append(('class:approval-border', '│\n')) - lines.append(('class:approval-border', '│ ')) - lines.append(('class:approval-desc', description)) - lines.append(('', '\n')) - lines.append(('class:approval-border', '│ ')) - lines.append(('class:approval-cmd', cmd_display)) - lines.append(('', '\n')) - lines.append(('class:approval-border', '│\n')) - for i, choice in enumerate(choices): - lines.append(('class:approval-border', '│ ')) - label = choice_labels.get(choice, choice) - if i == selected: - lines.append(('class:approval-selected', f'❯ {label}')) - else: - lines.append(('class:approval-choice', f' {label}')) - lines.append(('', '\n')) - lines.append(('class:approval-border', '│\n')) - lines.append(('class:approval-border', '╰──────────────────────────────────────────────────────╯\n')) - return lines - - approval_widget = ConditionalContainer( - Window( - FormattedTextControl(_get_approval_display), - wrap_lines=True, - ), - filter=Condition(lambda: cli_ref._approval_state is not None), - ) - - # Horizontal rules above and below the input (bronze, 1 line each). - # The bottom rule moves down as the TextArea grows with newlines. - # Using char='─' instead of hardcoded repetition so the rule - # always spans the full terminal width on any screen size. - input_rule_top = Window( - char='─', - height=1, - style='class:input-rule', - ) - input_rule_bot = Window( - char='─', - height=1, - style='class:input-rule', - ) - - # Image attachment indicator — shows badges like [📎 Image #1] above input - cli_ref = self - - def _get_image_bar(): - if not cli_ref._attached_images: - return [] - base = cli_ref._image_counter - len(cli_ref._attached_images) + 1 - badges = " ".join( - f"[📎 Image #{base + i}]" - for i in range(len(cli_ref._attached_images)) - ) - return [("class:image-badge", f" {badges} ")] - - image_bar = Window( - content=FormattedTextControl(_get_image_bar), - height=Condition(lambda: bool(cli_ref._attached_images)), - ) - - # Layout: interactive prompt widgets + ruled input at bottom. - # The sudo, approval, and clarify widgets appear above the input when - # the corresponding interactive prompt is active. - layout = Layout( - HSplit([ - Window(height=0), - sudo_widget, - approval_widget, - clarify_widget, - spacer, - input_rule_top, - image_bar, - input_area, - input_rule_bot, - CompletionsMenu(max_height=12, scroll_offset=1), - ]) - ) - - # Style for the application - style = PTStyle.from_dict({ - 'input-area': '#FFF8DC', - 'placeholder': '#555555 italic', - 'prompt': '#FFF8DC', - 'prompt-working': '#888888 italic', - 'hint': '#555555 italic', - # Bronze horizontal rules around the input area - 'input-rule': '#CD7F32', - # Clipboard image attachment badges - 'image-badge': '#87CEEB bold', - 'completion-menu': 'bg:#1a1a2e #FFF8DC', - 'completion-menu.completion': 'bg:#1a1a2e #FFF8DC', - 'completion-menu.completion.current': 'bg:#333355 #FFD700', - 'completion-menu.meta.completion': 'bg:#1a1a2e #888888', - 'completion-menu.meta.completion.current': 'bg:#333355 #FFBF00', - # Clarify question panel - 'clarify-border': '#CD7F32', - 'clarify-title': '#FFD700 bold', - 'clarify-question': '#FFF8DC bold', - 'clarify-choice': '#AAAAAA', - 'clarify-selected': '#FFD700 bold', - 'clarify-active-other': '#FFD700 italic', - 'clarify-countdown': '#CD7F32', - # Sudo password panel - 'sudo-prompt': '#FF6B6B bold', - 'sudo-border': '#CD7F32', - 'sudo-title': '#FF6B6B bold', - 'sudo-text': '#FFF8DC', - # Dangerous command approval panel - 'approval-border': '#CD7F32', - 'approval-title': '#FF8C00 bold', - 'approval-desc': '#FFF8DC bold', - 'approval-cmd': '#AAAAAA italic', - 'approval-choice': '#AAAAAA', - 'approval-selected': '#FFD700 bold', - }) - - # Create the application - app = Application( - layout=layout, - key_bindings=kb, - style=style, - full_screen=False, - mouse_support=False, - ) - self._app = app # Store reference for clarify_callback - - # Background thread to process inputs and run agent - def process_loop(): + atexit.register(_run_cleanup) + try: while not self._should_exit: try: - # Check for pending input with timeout - try: - user_input = self._pending_input.get(timeout=0.1) - except queue.Empty: - continue - - if not user_input: - continue - - # Unpack image payload: (text, [Path, ...]) or plain str - submit_images = [] - if isinstance(user_input, tuple): - user_input, submit_images = user_input - - # Check for commands - if isinstance(user_input, str) and user_input.startswith("/"): - print(f"\n⚙️ {user_input}") - if not self.process_command(user_input): - self._should_exit = True - # Schedule app exit - if app.is_running: - app.exit() - continue - - # Expand paste references back to full content - import re as _re - paste_match = _re.match(r'\[Pasted text #\d+: \d+ lines → (.+)\]', user_input) if isinstance(user_input, str) else None - if paste_match: - paste_path = Path(paste_match.group(1)) - if paste_path.exists(): - full_text = paste_path.read_text(encoding="utf-8") - line_count = full_text.count('\n') + 1 - print() - _cprint(f"{_GOLD}●{_RST} {_BOLD}[Pasted text: {line_count} lines]{_RST}") - user_input = full_text - else: - print() - _cprint(f"{_GOLD}●{_RST} {_BOLD}{user_input}{_RST}") + text = session.prompt(_PT_ANSI(f"{_GOLD}❯{_RST} ")) + except KeyboardInterrupt: + if self._agent_running and self.agent: + self.agent.interrupt() + print("\n⚡ Interrupted") else: - if '\n' in user_input: - first_line = user_input.split('\n')[0] - line_count = user_input.count('\n') + 1 - print() - _cprint(f"{_GOLD}●{_RST} {_BOLD}{first_line}{_RST} {_DIM}(+{line_count - 1} lines){_RST}") - else: - print() - _cprint(f"{_GOLD}●{_RST} {_BOLD}{user_input}{_RST}") - - # Show image attachment count - if submit_images: - n = len(submit_images) - _cprint(f" {_DIM}📎 {n} image{'s' if n > 1 else ''} attached{_RST}") + print("\nGoodbye! ⚕") + break + continue + except EOFError: + break - # Regular chat - run agent - self._agent_running = True - app.invalidate() # Refresh status line - - try: - self.chat(user_input, images=submit_images or None) - finally: - self._agent_running = False - app.invalidate() # Refresh status line - - except Exception as e: - print(f"Error: {e}") - - # Start processing thread - process_thread = threading.Thread(target=process_loop, daemon=True) - process_thread.start() - - # Register atexit cleanup so resources are freed even on unexpected exit - atexit.register(_run_cleanup) - - # Run the application with patch_stdout for proper output handling - try: - with patch_stdout(): - app.run() - except (EOFError, KeyboardInterrupt): - pass + text = text.strip() + if not text: + continue + + if text.startswith("/"): + if not self.process_command(text): + break + continue + + print(f"\n{_GOLD}●{_RST} {_BOLD}{text[:80]}{'...' if len(text) > 80 else ''}{_RST}") + self._agent_running = True + try: + self.chat(text) + finally: + self._agent_running = False finally: - self._should_exit = True - # Flush memories before exit (only for substantial conversations) if self.agent and self.conversation_history: try: self.agent.flush_memories(self.conversation_history) except Exception: pass - # Unregister terminal_tool callbacks to avoid dangling references set_sudo_password_callback(None) set_approval_callback(None) - # Close session in SQLite if hasattr(self, '_session_db') and self._session_db and self.agent: try: self._session_db.end_session(self.agent.session_id, "cli_close") - except Exception as e: - logger.debug("Could not close session in DB: %s", e) + except Exception: + pass _run_cleanup() self._print_exit_summary() diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 4dd9cd25d9..0e21f7f2b9 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -319,7 +319,7 @@ class SendResult: raw_response: Any = None -# Type for message handlers +# Handler may return str (sent by base) or dict(content=..., already_sent=True). MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]] @@ -691,11 +691,20 @@ class BasePlatformAdapter(ABC): try: # Call the handler (this can take a while with tool calls) - response = await self._message_handler(event) + handler_result = await self._message_handler(event) + + # Normalise: handler may return str or dict(content, already_sent) + already_sent = False + if isinstance(handler_result, dict): + response = handler_result.get("content") or "" + already_sent = handler_result.get("already_sent", False) + else: + response = handler_result # Send response if any if not response: - logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id) + if not already_sent: + logger.warning("[%s] Handler returned empty/None response for %s", self.name, event.source.chat_id) if response: # Extract MEDIA: tags (from TTS tool) before other processing media_files, response = self.extract_media(response) @@ -706,7 +715,7 @@ class BasePlatformAdapter(ABC): logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response)) # Send the text portion first (if any remains after extractions) - if text_content: + if text_content and not already_sent: logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id) result = await self.send( chat_id=event.source.chat_id, diff --git a/gateway/run.py b/gateway/run.py index 2584521d12..731ffa73ec 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1291,7 +1291,9 @@ class GatewayRunner: # Update session self.session_store.update_session(session_entry.session_key) - + + if agent_result.get("already_sent"): + return {"content": response, "already_sent": True} return response except Exception as e: @@ -2450,6 +2452,83 @@ class GatewayRunner: # Queue for progress messages (thread-safe) progress_queue = queue.Queue() if tool_progress_enabled else None last_tool = [None] # Mutable container for tracking in closure + + # Streaming token queue — same pattern as progress_queue but for + # assistant text deltas. An async drain task sends/edits a single + # platform message with the accumulated text. + stream_queue = queue.Queue() + stream_sent = [False] # set True once any delta was delivered + + def _stream_delta(text: str): + stream_queue.put(text) + + async def send_stream_messages(): + """Drain stream_queue, deliver via send/edit_message.""" + _adapter = self.adapters.get(source.platform) + if not _adapter: + return + + accumulated = [] + msg_id = None + can_edit = True + last_edit_ts = 0.0 + EDIT_INTERVAL = 0.6 # seconds between edits (rate-limit safe) + + while True: + try: + delta = stream_queue.get_nowait() + accumulated.append(delta) + stream_sent[0] = True + + now = asyncio.get_event_loop().time() + if now - last_edit_ts < EDIT_INTERVAL: + # Coalesce — will flush on next poll cycle + await asyncio.sleep(0.05) + continue + + full_text = "".join(accumulated) + if msg_id is None: + res = await _adapter.send( + chat_id=source.chat_id, content=full_text) + if res.success and res.message_id: + msg_id = res.message_id + elif can_edit: + res = await _adapter.edit_message( + chat_id=source.chat_id, + message_id=msg_id, + content=full_text, + ) + if not res.success: + can_edit = False + last_edit_ts = now + + except queue.Empty: + await asyncio.sleep(0.15) + except asyncio.CancelledError: + # Final flush + while not stream_queue.empty(): + try: + accumulated.append(stream_queue.get_nowait()) + except Exception: + break + if accumulated: + full_text = "".join(accumulated) + if msg_id is None: + await _adapter.send( + chat_id=source.chat_id, content=full_text) + elif can_edit: + try: + await _adapter.edit_message( + chat_id=source.chat_id, + message_id=msg_id, + content=full_text, + ) + except Exception: + pass + return + except Exception as e: + logger.error("Stream message error: %s", e) + await asyncio.sleep(0.5) def progress_callback(tool_name: str, preview: str = None, args: dict = None): """Callback invoked by agent when a tool is called.""" @@ -2693,6 +2772,7 @@ class GatewayRunner: session_id=session_id, tool_progress_callback=progress_callback if tool_progress_enabled else None, step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None, + stream_delta_callback=_stream_delta, platform=platform_key, honcho_session_key=session_key, session_db=self._session_db, @@ -2815,12 +2895,16 @@ class GatewayRunner: "api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0, "tools": tools_holder[0] or [], "history_offset": len(agent_history), + "already_sent": stream_sent[0], } # Start progress message sender if enabled progress_task = None if tool_progress_enabled: progress_task = asyncio.create_task(send_progress_messages()) + + # Start stream message sender + stream_task = asyncio.create_task(send_stream_messages()) # Track this agent as running for this session (for interrupt support) # We do this in a callback after the agent is created @@ -2896,9 +2980,10 @@ class GatewayRunner: session_key=session_key ) finally: - # Stop progress sender and interrupt monitor + # Stop progress sender, stream sender, and interrupt monitor if progress_task: progress_task.cancel() + stream_task.cancel() interrupt_monitor.cancel() # Clean up tracking @@ -2907,7 +2992,7 @@ class GatewayRunner: del self._running_agents[session_key] # Wait for cancelled tasks - for task in [progress_task, interrupt_monitor, tracking_task]: + for task in [progress_task, stream_task, interrupt_monitor, tracking_task]: if task: try: await task diff --git a/pyproject.toml b/pyproject.toml index 5f86cabd2f..de4da9f27c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "fire", "httpx", "rich", + "textual", "tenacity", "pyyaml", "requests", diff --git a/requirements.txt b/requirements.txt index 030c846564..4ed3efd213 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ python-dotenv fire httpx rich +textual tenacity prompt_toolkit pyyaml diff --git a/run_agent.py b/run_agent.py index c1f2623c83..a56a9f02cb 100644 --- a/run_agent.py +++ b/run_agent.py @@ -174,6 +174,7 @@ class AIAgent: tool_progress_callback: callable = None, clarify_callback: callable = None, step_callback: callable = None, + stream_delta_callback: callable = None, max_tokens: int = None, reasoning_config: Dict[str, Any] = None, prefill_messages: List[Dict[str, Any]] = None, @@ -258,6 +259,7 @@ class AIAgent: self.tool_progress_callback = tool_progress_callback self.clarify_callback = clarify_callback self.step_callback = step_callback + self.stream_delta_callback = stream_delta_callback self._last_reported_tool = None # Track for "new tool" mode # Interrupt mechanism for breaking out of tool loops @@ -2158,6 +2160,137 @@ class AIAgent: raise result["error"] return result["response"] + def _interruptible_streaming_api_call(self, api_kwargs: dict, on_first_delta=None): + """Streaming variant of _interruptible_api_call for chat_completions. + + Fires self.stream_delta_callback(text) as content tokens arrive and + accumulates the full response into a SimpleNamespace matching the shape + downstream code expects. Falls back to the non-streaming path when the + provider rejects the stream request. + """ + from types import SimpleNamespace + + result = {"response": None, "error": None} + first_delta_fired = [False] + + def _stream(): + try: + stream_kwargs = {**api_kwargs, "stream": True, + "stream_options": {"include_usage": True}} + stream = self.client.chat.completions.create(**stream_kwargs) + + content_parts = [] + tool_calls_acc = {} + finish_reason = "stop" + usage = None + reasoning_content = None + model = None + + for chunk in stream: + if not chunk.choices: + if hasattr(chunk, "usage") and chunk.usage: + usage = chunk.usage + continue + + choice = chunk.choices[0] + if choice.finish_reason: + finish_reason = choice.finish_reason + if model is None and hasattr(chunk, "model"): + model = chunk.model + + delta = choice.delta + if delta is None: + continue + + if delta.content: + content_parts.append(delta.content) + if not first_delta_fired[0]: + first_delta_fired[0] = True + if on_first_delta: + on_first_delta() + if self.stream_delta_callback: + try: + self.stream_delta_callback(delta.content) + except Exception: + pass + + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in tool_calls_acc: + tool_calls_acc[idx] = { + "id": tc_delta.id or "", + "type": tc_delta.type or "function", + "function": { + "name": getattr(tc_delta.function, "name", None) or "", + "arguments": getattr(tc_delta.function, "arguments", None) or "", + }, + } + else: + entry = tool_calls_acc[idx] + if tc_delta.id: + entry["id"] = tc_delta.id + fn = tc_delta.function + if fn: + if fn.name: + entry["function"]["name"] = fn.name + if fn.arguments: + entry["function"]["arguments"] += fn.arguments + + rc = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None) + if rc: + reasoning_content = (reasoning_content or "") + rc + + tool_calls_list = None + if tool_calls_acc: + tool_calls_list = [ + SimpleNamespace( + id=tc["id"], call_id=tc["id"], type=tc["type"], + function=SimpleNamespace(name=tc["function"]["name"], + arguments=tc["function"]["arguments"]), + ) + for idx, tc in sorted(tool_calls_acc.items()) + ] + + message = SimpleNamespace( + content="".join(content_parts) or None, + tool_calls=tool_calls_list, + reasoning=reasoning_content, + reasoning_content=reasoning_content, + reasoning_details=None, + ) + result["response"] = SimpleNamespace( + choices=[SimpleNamespace(message=message, finish_reason=finish_reason)], + usage=usage, + model=model, + ) + except Exception as e: + result["error"] = e + + t = threading.Thread(target=_stream, daemon=True) + t.start() + while t.is_alive(): + t.join(timeout=0.3) + if self._interrupt_requested: + try: + self.client.close() + except Exception: + pass + try: + self.client = OpenAI(**self._client_kwargs) + except Exception: + pass + raise InterruptedError("Agent interrupted during streaming API call") + + if result["error"] is not None: + err = result["error"] + err_str = str(err).lower() + if any(kw in err_str for kw in ("stream", "not support", "unsupported")): + logger.debug("Streaming failed (%s), falling back to non-streaming.", err) + return self._interruptible_api_call(api_kwargs) + raise err + return result["response"] + # ── Provider fallback ────────────────────────────────────────────────── # API-key providers: provider → (base_url, [env_var_names]) @@ -3353,12 +3486,24 @@ class AIAgent: if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}: self._dump_api_request_debug(api_kwargs, reason="preflight") - response = self._interruptible_api_call(api_kwargs) + if self.stream_delta_callback and self.api_mode != "codex_responses": + def _stop_spinner(): + nonlocal thinking_spinner + if thinking_spinner: + thinking_spinner.stop("") + thinking_spinner = None + response = self._interruptible_streaming_api_call( + api_kwargs, on_first_delta=_stop_spinner) + # Newline after streamed content so tool lines don't overwrite it + if response and hasattr(response, 'choices') and response.choices: + msg = response.choices[0].message + if msg and msg.content and msg.tool_calls: + print(flush=True) + else: + response = self._interruptible_api_call(api_kwargs) api_duration = time.time() - api_start_time - # Stop thinking spinner silently -- the response box or tool - # execution messages that follow are more informative. if thinking_spinner: thinking_spinner.stop("") thinking_spinner = None @@ -4055,8 +4200,8 @@ class AIAgent: turn_content = assistant_message.content or "" if turn_content and self._has_content_after_think_block(turn_content): self._last_content_with_tools = turn_content - # Show intermediate commentary so the user can follow along - if self.quiet_mode: + # Show intermediate commentary — skip when streaming (already in buffer) + if self.quiet_mode and not self.stream_delta_callback: clean = self._strip_think_blocks(turn_content).strip() if clean: print(f" ┊ 💬 {clean}") diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000000..c368f22323 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,256 @@ +"""Tests for streaming token output — accumulator shape, callback order, fallback.""" + +import queue +import threading +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, call + +import pytest + +from run_agent import AIAgent + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_tool_defs(*names): + return [ + {"type": "function", "function": {"name": n, "description": f"{n}", "parameters": {"type": "object", "properties": {}}}} + for n in names + ] + + +@pytest.fixture() +def agent(): + with ( + patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + cb = MagicMock() + a = AIAgent( + api_key="test-key-1234567890", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + stream_delta_callback=cb, + ) + a.client = MagicMock() + a._stream_cb = cb + return a + + +# --------------------------------------------------------------------------- +# Helpers — fake streaming chunks +# --------------------------------------------------------------------------- + +def _chunk(content=None, tool_call_delta=None, finish_reason=None, usage=None, model=None): + delta = SimpleNamespace(content=content, tool_calls=tool_call_delta) + choice = SimpleNamespace(delta=delta, finish_reason=finish_reason) + c = SimpleNamespace(choices=[choice]) + if usage is not None: + c.usage = SimpleNamespace(**usage) + if model: + c.model = model + return c + + +def _usage_chunk(**kw): + c = SimpleNamespace(choices=[], usage=SimpleNamespace(**kw)) + return c + + +def _tc_delta(index, id=None, name=None, arguments=None, type=None): + fn = SimpleNamespace(name=name, arguments=arguments) + return SimpleNamespace(index=index, id=id, type=type, function=fn) + + +# --------------------------------------------------------------------------- +# Tests: accumulator shape +# --------------------------------------------------------------------------- + + +class TestStreamingAccumulator: + def test_text_only_response(self, agent): + """Streaming text-only response produces correct synthetic shape.""" + chunks = [ + _chunk(content="Hello", model="test/m"), + _chunk(content=" world"), + _chunk(finish_reason="stop"), + _usage_chunk(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._interruptible_streaming_api_call({"model": "test"}) + + assert resp.choices[0].message.content == "Hello world" + assert resp.choices[0].message.tool_calls is None + assert resp.choices[0].finish_reason == "stop" + assert resp.usage.prompt_tokens == 10 + assert resp.model == "test/m" + + def test_tool_call_response(self, agent): + """Streaming tool-call response accumulates function name + arguments.""" + chunks = [ + _chunk(tool_call_delta=[_tc_delta(0, id="call_1", name="web_search", arguments='{"q', type="function")]), + _chunk(tool_call_delta=[_tc_delta(0, arguments='uery": "hi"}')]), + _chunk(finish_reason="tool_calls"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._interruptible_streaming_api_call({"model": "test"}) + + tc = resp.choices[0].message.tool_calls + assert tc is not None + assert len(tc) == 1 + assert tc[0].id == "call_1" + assert tc[0].function.name == "web_search" + assert tc[0].function.arguments == '{"query": "hi"}' + assert resp.choices[0].finish_reason == "tool_calls" + + def test_mixed_content_and_tool_calls(self, agent): + """Content + tool calls in same stream are both accumulated.""" + chunks = [ + _chunk(content="Let me check."), + _chunk(tool_call_delta=[_tc_delta(0, id="c1", name="web_search", arguments="{}", type="function")]), + _chunk(finish_reason="tool_calls"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + resp = agent._interruptible_streaming_api_call({"model": "test"}) + + assert resp.choices[0].message.content == "Let me check." + assert len(resp.choices[0].message.tool_calls) == 1 + + +class TestStreamingCallbacks: + def test_deltas_fire_in_order(self, agent): + """stream_delta_callback receives content deltas in order.""" + received = [] + agent.stream_delta_callback = lambda t: received.append(t) + chunks = [_chunk(content="a"), _chunk(content="b"), _chunk(content="c"), _chunk(finish_reason="stop")] + agent.client.chat.completions.create.return_value = iter(chunks) + + agent._interruptible_streaming_api_call({"model": "test"}) + + assert received == ["a", "b", "c"] + + def test_on_first_delta_fires_once(self, agent): + first = MagicMock() + chunks = [_chunk(content="x"), _chunk(content="y"), _chunk(finish_reason="stop")] + agent.client.chat.completions.create.return_value = iter(chunks) + + agent._interruptible_streaming_api_call({"model": "test"}, on_first_delta=first) + + first.assert_called_once() + + def test_tool_only_does_not_fire_callback(self, agent): + """Tool-call-only stream does not invoke stream_delta_callback.""" + received = [] + agent.stream_delta_callback = lambda t: received.append(t) + chunks = [ + _chunk(tool_call_delta=[_tc_delta(0, id="c1", name="t", arguments="{}", type="function")]), + _chunk(finish_reason="tool_calls"), + ] + agent.client.chat.completions.create.return_value = iter(chunks) + + agent._interruptible_streaming_api_call({"model": "test"}) + + assert received == [] + + +class TestStreamingFallback: + def test_stream_error_falls_back(self, agent): + """When streaming fails with 'not support', falls back to non-streaming.""" + agent.client.chat.completions.create.side_effect = [ + Exception("streaming not supported by this provider"), + SimpleNamespace( + choices=[SimpleNamespace( + message=SimpleNamespace(content="ok", tool_calls=None, reasoning=None, reasoning_content=None, reasoning_details=None), + finish_reason="stop", + )], + usage=None, + model="test/m", + ), + ] + + resp = agent._interruptible_streaming_api_call({"model": "test"}) + + assert resp.choices[0].message.content == "ok" + assert agent.client.chat.completions.create.call_count == 2 + + def test_non_stream_error_raises(self, agent): + """Non-stream-related errors propagate normally.""" + agent.client.chat.completions.create.side_effect = ValueError("bad request") + + with pytest.raises(ValueError, match="bad request"): + agent._interruptible_streaming_api_call({"model": "test"}) + + +# --------------------------------------------------------------------------- +# Tests: base.py already_sent contract +# --------------------------------------------------------------------------- + +class TestAlreadySentContract: + def _make_adapter(self, send_side_effect=None): + from gateway.platforms.base import BasePlatformAdapter, SendResult + from gateway.config import Platform, PlatformConfig + + class FakeAdapter(BasePlatformAdapter): + async def connect(self): return True + async def disconnect(self): pass + async def get_chat_info(self, chat_id): return {"name": "test"} + async def send(self, chat_id, content, reply_to=None, metadata=None): + if send_side_effect is not None: + send_side_effect(content) + return SendResult(success=True, message_id="1") + + cfg = PlatformConfig(enabled=True) + adapter = FakeAdapter(cfg, Platform.TELEGRAM) + adapter._running = True + return adapter + + @pytest.mark.asyncio + async def test_already_sent_skips_send(self): + """Handler returning already_sent=True prevents base from calling send().""" + from gateway.platforms.base import MessageEvent + from gateway.config import Platform + from gateway.session import SessionSource + + sent = [] + adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c)) + + async def handler(event): + return {"content": "hello", "already_sent": True} + adapter.set_message_handler(handler) + + event = MessageEvent( + text="hi", + source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"), + ) + await adapter._process_message_background(event, "s1") + + assert sent == [], "send() should not be called when already_sent=True" + + @pytest.mark.asyncio + async def test_string_response_sends_normally(self): + """Handler returning a plain string triggers send() as before.""" + from gateway.platforms.base import MessageEvent + from gateway.config import Platform + from gateway.session import SessionSource + + sent = [] + adapter = self._make_adapter(send_side_effect=lambda c: sent.append(c)) + + async def handler(event): + return "hello" + adapter.set_message_handler(handler) + + event = MessageEvent( + text="hi", + source=SessionSource(platform=Platform.TELEGRAM, chat_id="1", user_id="u1"), + ) + await adapter._process_message_background(event, "s1") + + assert "hello" in sent