mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Implement interrupt handling for agent and CLI input and persistent prompt line at bottom of CLI :)
- Enhanced the AIAgent class to support interrupt requests, allowing for graceful interruption of ongoing tasks and processing of new messages. - Updated the HermesCLI to manage user input in a persistent manner, enabling real-time interruption of the agent's conversation. - Introduced a mechanism in the GatewayRunner to handle incoming messages while an agent is running, allowing for immediate response to user commands. - Improved overall user experience by providing feedback during interruptions and ensuring that pending messages are processed correctly.
This commit is contained in:
parent
beeb7896e0
commit
9bfe185a2e
3 changed files with 336 additions and 34 deletions
200
cli.py
200
cli.py
|
|
@ -33,6 +33,15 @@ from prompt_toolkit.history import FileHistory
|
||||||
from prompt_toolkit.styles import Style as PTStyle
|
from prompt_toolkit.styles import Style as PTStyle
|
||||||
from prompt_toolkit.formatted_text import HTML
|
from prompt_toolkit.formatted_text import HTML
|
||||||
from prompt_toolkit.patch_stdout import patch_stdout
|
from prompt_toolkit.patch_stdout import patch_stdout
|
||||||
|
from prompt_toolkit.application import Application, get_app
|
||||||
|
from prompt_toolkit.buffer import Buffer
|
||||||
|
from prompt_toolkit.layout import Layout, HSplit, Window, FormattedTextControl
|
||||||
|
from prompt_toolkit.layout.processors import BeforeInput
|
||||||
|
from prompt_toolkit.widgets import TextArea
|
||||||
|
from prompt_toolkit.key_binding import KeyBindings
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
|
||||||
# Load environment variables first
|
# Load environment variables first
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -1284,17 +1293,52 @@ class HermesCLI:
|
||||||
print("─" * 60, flush=True)
|
print("─" * 60, flush=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Run the conversation
|
# Run the conversation with interrupt monitoring
|
||||||
result = self.agent.run_conversation(
|
result = None
|
||||||
user_message=message,
|
|
||||||
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
|
def run_agent():
|
||||||
)
|
nonlocal result
|
||||||
|
result = self.agent.run_conversation(
|
||||||
|
user_message=message,
|
||||||
|
conversation_history=self.conversation_history[:-1], # Exclude the message we just added
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start agent in background thread
|
||||||
|
agent_thread = threading.Thread(target=run_agent)
|
||||||
|
agent_thread.start()
|
||||||
|
|
||||||
|
# Monitor for new input in the pending queue while agent runs
|
||||||
|
interrupt_msg = None
|
||||||
|
while agent_thread.is_alive():
|
||||||
|
# Check if there's new input in the queue (from the persistent input area)
|
||||||
|
if hasattr(self, '_pending_input'):
|
||||||
|
try:
|
||||||
|
interrupt_msg = self._pending_input.get(timeout=0.1)
|
||||||
|
if interrupt_msg:
|
||||||
|
print(f"\n⚡ New message detected, interrupting...")
|
||||||
|
self.agent.interrupt(interrupt_msg)
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
pass # Queue empty or timeout, continue waiting
|
||||||
|
else:
|
||||||
|
# Fallback if no queue (shouldn't happen)
|
||||||
|
agent_thread.join(0.1)
|
||||||
|
|
||||||
|
agent_thread.join() # Ensure agent thread completes
|
||||||
|
|
||||||
# Update history with full conversation
|
# Update history with full conversation
|
||||||
self.conversation_history = result.get("messages", self.conversation_history)
|
self.conversation_history = result.get("messages", self.conversation_history) if result else self.conversation_history
|
||||||
|
|
||||||
# Get the final response
|
# Get the final response
|
||||||
response = result.get("final_response", "")
|
response = result.get("final_response", "") if result else ""
|
||||||
|
|
||||||
|
# Handle interrupt - check if we were interrupted
|
||||||
|
pending_message = None
|
||||||
|
if result and result.get("interrupted"):
|
||||||
|
pending_message = result.get("interrupt_message") or interrupt_msg
|
||||||
|
# Add indicator that we were interrupted
|
||||||
|
if response and pending_message:
|
||||||
|
response = response + "\n\n---\n_[Interrupted - processing new message]_"
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
# Use simple print for compatibility with prompt_toolkit's patch_stdout
|
# Use simple print for compatibility with prompt_toolkit's patch_stdout
|
||||||
|
|
@ -1307,6 +1351,11 @@ class HermesCLI:
|
||||||
print()
|
print()
|
||||||
print("─" * 60)
|
print("─" * 60)
|
||||||
|
|
||||||
|
# If we have a pending message from interrupt, process it immediately
|
||||||
|
if pending_message:
|
||||||
|
print(f"\n📨 Processing: '{pending_message[:50]}{'...' if len(pending_message) > 50 else ''}'")
|
||||||
|
return self.chat(pending_message) # Recursive call to handle the new message
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -1345,22 +1394,101 @@ class HermesCLI:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Run the interactive CLI loop with fixed input at bottom."""
|
"""Run the interactive CLI loop with persistent input at bottom."""
|
||||||
self.show_banner()
|
self.show_banner()
|
||||||
|
|
||||||
# These Rich prints work fine BEFORE patch_stdout
|
|
||||||
self.console.print("[#FFF8DC]Welcome to Hermes Agent! Type your message or /help for commands.[/]")
|
self.console.print("[#FFF8DC]Welcome to Hermes Agent! Type your message or /help for commands.[/]")
|
||||||
self.console.print()
|
self.console.print()
|
||||||
|
|
||||||
# Use patch_stdout to ensure all output appears above the input prompt
|
# State for async operation
|
||||||
with patch_stdout():
|
self._agent_running = False
|
||||||
while True:
|
self._pending_input = queue.Queue()
|
||||||
|
self._should_exit = False
|
||||||
|
|
||||||
|
# Create a persistent input area using prompt_toolkit Application
|
||||||
|
input_buffer = Buffer()
|
||||||
|
|
||||||
|
# Key bindings for the input area
|
||||||
|
kb = KeyBindings()
|
||||||
|
|
||||||
|
@kb.add('enter')
|
||||||
|
def handle_enter(event):
|
||||||
|
"""Handle Enter key - submit input."""
|
||||||
|
text = event.app.current_buffer.text.strip()
|
||||||
|
if text:
|
||||||
|
# Store the input
|
||||||
|
self._pending_input.put(text)
|
||||||
|
# Clear the buffer
|
||||||
|
event.app.current_buffer.reset()
|
||||||
|
|
||||||
|
@kb.add('c-c')
|
||||||
|
def handle_ctrl_c(event):
|
||||||
|
"""Handle Ctrl+C - interrupt or exit."""
|
||||||
|
if self._agent_running and self.agent:
|
||||||
|
print("\n⚡ Interrupting agent...")
|
||||||
|
self.agent.interrupt()
|
||||||
|
else:
|
||||||
|
self._should_exit = True
|
||||||
|
event.app.exit()
|
||||||
|
|
||||||
|
@kb.add('c-d')
|
||||||
|
def handle_ctrl_d(event):
|
||||||
|
"""Handle Ctrl+D - exit."""
|
||||||
|
self._should_exit = True
|
||||||
|
event.app.exit()
|
||||||
|
|
||||||
|
# Create the input area widget
|
||||||
|
input_area = TextArea(
|
||||||
|
height=1,
|
||||||
|
prompt='❯ ',
|
||||||
|
style='class:input-area',
|
||||||
|
multiline=False,
|
||||||
|
wrap_lines=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a status line that shows when agent is working
|
||||||
|
def get_status_text():
|
||||||
|
if self._agent_running:
|
||||||
|
return [('class:status', ' 🔄 Agent working... (type to interrupt) ')]
|
||||||
|
return [('class:status', '')]
|
||||||
|
|
||||||
|
status_window = Window(
|
||||||
|
content=FormattedTextControl(get_status_text),
|
||||||
|
height=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Layout with status and input at bottom
|
||||||
|
layout = Layout(
|
||||||
|
HSplit([
|
||||||
|
Window(height=0), # Spacer that expands
|
||||||
|
status_window,
|
||||||
|
input_area,
|
||||||
|
])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Style for the application
|
||||||
|
style = PTStyle.from_dict({
|
||||||
|
'input-area': '#FFF8DC',
|
||||||
|
'status': 'bg:#333333 #FFD700',
|
||||||
|
})
|
||||||
|
|
||||||
|
# Create the application
|
||||||
|
app = Application(
|
||||||
|
layout=layout,
|
||||||
|
key_bindings=kb,
|
||||||
|
style=style,
|
||||||
|
full_screen=False,
|
||||||
|
mouse_support=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Background thread to process inputs and run agent
|
||||||
|
def process_loop():
|
||||||
|
while not self._should_exit:
|
||||||
try:
|
try:
|
||||||
user_input = self.get_input()
|
# Check for pending input with timeout
|
||||||
|
try:
|
||||||
if user_input is None:
|
user_input = self._pending_input.get(timeout=0.1)
|
||||||
print("\nGoodbye! ⚕")
|
except queue.Empty:
|
||||||
break
|
continue
|
||||||
|
|
||||||
if not user_input:
|
if not user_input:
|
||||||
continue
|
continue
|
||||||
|
|
@ -1368,16 +1496,38 @@ class HermesCLI:
|
||||||
# Check for commands
|
# Check for commands
|
||||||
if user_input.startswith("/"):
|
if user_input.startswith("/"):
|
||||||
if not self.process_command(user_input):
|
if not self.process_command(user_input):
|
||||||
print("\nGoodbye! ⚕")
|
self._should_exit = True
|
||||||
break
|
# Schedule app exit
|
||||||
|
if app.is_running:
|
||||||
|
app.exit()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Regular chat message
|
# Regular chat - run agent
|
||||||
self.chat(user_input)
|
self._agent_running = True
|
||||||
|
app.invalidate() # Refresh status line
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
try:
|
||||||
print("\nInterrupted. Type /quit to exit.")
|
self.chat(user_input)
|
||||||
continue
|
finally:
|
||||||
|
self._agent_running = False
|
||||||
|
app.invalidate() # Refresh status line
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
# Start processing thread
|
||||||
|
process_thread = threading.Thread(target=process_loop, daemon=True)
|
||||||
|
process_thread.start()
|
||||||
|
|
||||||
|
# Run the application with patch_stdout for proper output handling
|
||||||
|
try:
|
||||||
|
with patch_stdout():
|
||||||
|
app.run()
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._should_exit = True
|
||||||
|
print("\nGoodbye! ⚕")
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
|
||||||
104
gateway/run.py
104
gateway/run.py
|
|
@ -72,6 +72,11 @@ class GatewayRunner:
|
||||||
self.delivery_router = DeliveryRouter(self.config)
|
self.delivery_router = DeliveryRouter(self.config)
|
||||||
self._running = False
|
self._running = False
|
||||||
self._shutdown_event = asyncio.Event()
|
self._shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
# Track running agents per session for interrupt support
|
||||||
|
# Key: session_key, Value: AIAgent instance
|
||||||
|
self._running_agents: Dict[str, Any] = {}
|
||||||
|
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||||
|
|
||||||
async def start(self) -> bool:
|
async def start(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
@ -217,10 +222,11 @@ class GatewayRunner:
|
||||||
This is the core message processing pipeline:
|
This is the core message processing pipeline:
|
||||||
1. Check user authorization
|
1. Check user authorization
|
||||||
2. Check for commands (/new, /reset, etc.)
|
2. Check for commands (/new, /reset, etc.)
|
||||||
3. Get or create session
|
3. Check for running agent and interrupt if needed
|
||||||
4. Build context for agent
|
4. Get or create session
|
||||||
5. Run agent conversation
|
5. Build context for agent
|
||||||
6. Return response
|
6. Run agent conversation
|
||||||
|
7. Return response
|
||||||
"""
|
"""
|
||||||
source = event.source
|
source = event.source
|
||||||
|
|
||||||
|
|
@ -229,7 +235,7 @@ class GatewayRunner:
|
||||||
print(f"[gateway] Unauthorized user: {source.user_id} ({source.user_name}) on {source.platform.value}")
|
print(f"[gateway] Unauthorized user: {source.user_id} ({source.user_name}) on {source.platform.value}")
|
||||||
return None # Silently ignore unauthorized users
|
return None # Silently ignore unauthorized users
|
||||||
|
|
||||||
# Check for reset commands
|
# Check for commands
|
||||||
command = event.get_command()
|
command = event.get_command()
|
||||||
if command in ["new", "reset"]:
|
if command in ["new", "reset"]:
|
||||||
return await self._handle_reset_command(event)
|
return await self._handle_reset_command(event)
|
||||||
|
|
@ -237,8 +243,21 @@ class GatewayRunner:
|
||||||
if command == "status":
|
if command == "status":
|
||||||
return await self._handle_status_command(event)
|
return await self._handle_status_command(event)
|
||||||
|
|
||||||
|
if command == "stop":
|
||||||
|
return await self._handle_stop_command(event)
|
||||||
|
|
||||||
# Get or create session
|
# Get or create session
|
||||||
session_entry = self.session_store.get_or_create_session(source)
|
session_entry = self.session_store.get_or_create_session(source)
|
||||||
|
session_key = session_entry.session_key
|
||||||
|
|
||||||
|
# Check if there's already a running agent for this session
|
||||||
|
if session_key in self._running_agents:
|
||||||
|
running_agent = self._running_agents[session_key]
|
||||||
|
print(f"[gateway] ⚡ Interrupting running agent for session {session_key[:20]}...")
|
||||||
|
running_agent.interrupt(event.text)
|
||||||
|
# Store the new message to be processed after current agent finishes
|
||||||
|
self._pending_messages[session_key] = event.text
|
||||||
|
return None # Don't respond yet - let the interrupt handle it
|
||||||
|
|
||||||
# Build session context
|
# Build session context
|
||||||
context = build_session_context(source, self.config, session_entry)
|
context = build_session_context(source, self.config, session_entry)
|
||||||
|
|
@ -259,7 +278,8 @@ class GatewayRunner:
|
||||||
context_prompt=context_prompt,
|
context_prompt=context_prompt,
|
||||||
history=history,
|
history=history,
|
||||||
source=source,
|
source=source,
|
||||||
session_id=session_entry.session_id
|
session_id=session_entry.session_id,
|
||||||
|
session_key=session_key
|
||||||
)
|
)
|
||||||
|
|
||||||
# Append to transcript
|
# Append to transcript
|
||||||
|
|
@ -309,6 +329,10 @@ class GatewayRunner:
|
||||||
|
|
||||||
connected_platforms = [p.value for p in self.adapters.keys()]
|
connected_platforms = [p.value for p in self.adapters.keys()]
|
||||||
|
|
||||||
|
# Check if there's an active agent
|
||||||
|
session_key = session_entry.session_key
|
||||||
|
is_running = session_key in self._running_agents
|
||||||
|
|
||||||
lines = [
|
lines = [
|
||||||
"📊 **Hermes Gateway Status**",
|
"📊 **Hermes Gateway Status**",
|
||||||
"",
|
"",
|
||||||
|
|
@ -316,12 +340,26 @@ class GatewayRunner:
|
||||||
f"**Created:** {session_entry.created_at.strftime('%Y-%m-%d %H:%M')}",
|
f"**Created:** {session_entry.created_at.strftime('%Y-%m-%d %H:%M')}",
|
||||||
f"**Last Activity:** {session_entry.updated_at.strftime('%Y-%m-%d %H:%M')}",
|
f"**Last Activity:** {session_entry.updated_at.strftime('%Y-%m-%d %H:%M')}",
|
||||||
f"**Tokens:** {session_entry.total_tokens:,}",
|
f"**Tokens:** {session_entry.total_tokens:,}",
|
||||||
|
f"**Agent Running:** {'Yes ⚡' if is_running else 'No'}",
|
||||||
"",
|
"",
|
||||||
f"**Connected Platforms:** {', '.join(connected_platforms)}",
|
f"**Connected Platforms:** {', '.join(connected_platforms)}",
|
||||||
]
|
]
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
async def _handle_stop_command(self, event: MessageEvent) -> str:
|
||||||
|
"""Handle /stop command - interrupt a running agent."""
|
||||||
|
source = event.source
|
||||||
|
session_entry = self.session_store.get_or_create_session(source)
|
||||||
|
session_key = session_entry.session_key
|
||||||
|
|
||||||
|
if session_key in self._running_agents:
|
||||||
|
agent = self._running_agents[session_key]
|
||||||
|
agent.interrupt()
|
||||||
|
return "⚡ Stopping the current task... The agent will finish its current step and respond."
|
||||||
|
else:
|
||||||
|
return "No active task to stop."
|
||||||
|
|
||||||
def _set_session_env(self, context: SessionContext) -> None:
|
def _set_session_env(self, context: SessionContext) -> None:
|
||||||
"""Set environment variables for the current session."""
|
"""Set environment variables for the current session."""
|
||||||
os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value
|
os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value
|
||||||
|
|
@ -341,12 +379,14 @@ class GatewayRunner:
|
||||||
context_prompt: str,
|
context_prompt: str,
|
||||||
history: List[Dict[str, Any]],
|
history: List[Dict[str, Any]],
|
||||||
source: SessionSource,
|
source: SessionSource,
|
||||||
session_id: str
|
session_id: str,
|
||||||
|
session_key: str = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Run the agent with the given message and context.
|
Run the agent with the given message and context.
|
||||||
|
|
||||||
This is run in a thread pool to not block the event loop.
|
This is run in a thread pool to not block the event loop.
|
||||||
|
Supports interruption via new messages.
|
||||||
"""
|
"""
|
||||||
from run_agent import AIAgent
|
from run_agent import AIAgent
|
||||||
import queue
|
import queue
|
||||||
|
|
@ -432,6 +472,10 @@ class GatewayRunner:
|
||||||
print(f"[Gateway] Progress message error: {e}")
|
print(f"[Gateway] Progress message error: {e}")
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# We need to share the agent instance for interrupt support
|
||||||
|
agent_holder = [None] # Mutable container for the agent instance
|
||||||
|
result_holder = [None] # Mutable container for the result
|
||||||
|
|
||||||
def run_sync():
|
def run_sync():
|
||||||
# Read from env var or use default (same as CLI)
|
# Read from env var or use default (same as CLI)
|
||||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
||||||
|
|
@ -446,6 +490,9 @@ class GatewayRunner:
|
||||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Store agent reference for interrupt support
|
||||||
|
agent_holder[0] = agent
|
||||||
|
|
||||||
# Convert transcript history to agent format
|
# Convert transcript history to agent format
|
||||||
# Transcript has timestamps; agent expects {"role": ..., "content": ...}
|
# Transcript has timestamps; agent expects {"role": ..., "content": ...}
|
||||||
agent_history = []
|
agent_history = []
|
||||||
|
|
@ -456,6 +503,7 @@ class GatewayRunner:
|
||||||
agent_history.append({"role": role, "content": content})
|
agent_history.append({"role": role, "content": content})
|
||||||
|
|
||||||
result = agent.run_conversation(message, conversation_history=agent_history)
|
result = agent.run_conversation(message, conversation_history=agent_history)
|
||||||
|
result_holder[0] = result
|
||||||
|
|
||||||
# Return final response, or a message if something went wrong
|
# Return final response, or a message if something went wrong
|
||||||
final_response = result.get("final_response")
|
final_response = result.get("final_response")
|
||||||
|
|
@ -472,14 +520,56 @@ class GatewayRunner:
|
||||||
if tool_progress_enabled:
|
if tool_progress_enabled:
|
||||||
progress_task = asyncio.create_task(send_progress_messages())
|
progress_task = asyncio.create_task(send_progress_messages())
|
||||||
|
|
||||||
|
# Track this agent as running for this session (for interrupt support)
|
||||||
|
# We do this in a callback after the agent is created
|
||||||
|
async def track_agent():
|
||||||
|
# Wait for agent to be created
|
||||||
|
while agent_holder[0] is None:
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
if session_key:
|
||||||
|
self._running_agents[session_key] = agent_holder[0]
|
||||||
|
|
||||||
|
tracking_task = asyncio.create_task(track_agent())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Run in thread pool to not block
|
# Run in thread pool to not block
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
response = await loop.run_in_executor(None, run_sync)
|
response = await loop.run_in_executor(None, run_sync)
|
||||||
|
|
||||||
|
# Check if we were interrupted and have a pending message
|
||||||
|
result = result_holder[0]
|
||||||
|
if result and result.get("interrupted") and session_key:
|
||||||
|
pending = self._pending_messages.pop(session_key, None)
|
||||||
|
if pending:
|
||||||
|
print(f"[gateway] 📨 Processing interrupted message: '{pending[:40]}...'")
|
||||||
|
# Add an indicator to the response
|
||||||
|
if response:
|
||||||
|
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
||||||
|
|
||||||
|
# Send the interrupted response first
|
||||||
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if adapter and response:
|
||||||
|
await adapter.send(chat_id=source.chat_id, content=response)
|
||||||
|
|
||||||
|
# Now process the pending message with updated history
|
||||||
|
updated_history = result.get("messages", history)
|
||||||
|
return await self._run_agent(
|
||||||
|
message=pending,
|
||||||
|
context_prompt=context_prompt,
|
||||||
|
history=updated_history,
|
||||||
|
source=source,
|
||||||
|
session_id=session_id,
|
||||||
|
session_key=session_key
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
# Stop progress sender
|
# Stop progress sender
|
||||||
if progress_task:
|
if progress_task:
|
||||||
progress_task.cancel()
|
progress_task.cancel()
|
||||||
|
|
||||||
|
# Clean up tracking
|
||||||
|
tracking_task.cancel()
|
||||||
|
if session_key and session_key in self._running_agents:
|
||||||
|
del self._running_agents[session_key]
|
||||||
try:
|
try:
|
||||||
await progress_task
|
await progress_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|
|
||||||
66
run_agent.py
66
run_agent.py
|
|
@ -639,6 +639,10 @@ class AIAgent:
|
||||||
self.tool_progress_callback = tool_progress_callback
|
self.tool_progress_callback = tool_progress_callback
|
||||||
self._last_reported_tool = None # Track for "new tool" mode
|
self._last_reported_tool = None # Track for "new tool" mode
|
||||||
|
|
||||||
|
# Interrupt mechanism for breaking out of tool loops
|
||||||
|
self._interrupt_requested = False
|
||||||
|
self._interrupt_message = None # Optional message that triggered interrupt
|
||||||
|
|
||||||
# Store OpenRouter provider preferences
|
# Store OpenRouter provider preferences
|
||||||
self.providers_allowed = providers_allowed
|
self.providers_allowed = providers_allowed
|
||||||
self.providers_ignored = providers_ignored
|
self.providers_ignored = providers_ignored
|
||||||
|
|
@ -1302,6 +1306,42 @@ class AIAgent:
|
||||||
if self.verbose_logging:
|
if self.verbose_logging:
|
||||||
logging.warning(f"Failed to save session log: {e}")
|
logging.warning(f"Failed to save session log: {e}")
|
||||||
|
|
||||||
|
def interrupt(self, message: str = None) -> None:
|
||||||
|
"""
|
||||||
|
Request the agent to interrupt its current tool-calling loop.
|
||||||
|
|
||||||
|
Call this from another thread (e.g., input handler, message receiver)
|
||||||
|
to gracefully stop the agent and process a new message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Optional new message that triggered the interrupt.
|
||||||
|
If provided, the agent will include this in its response context.
|
||||||
|
|
||||||
|
Example (CLI):
|
||||||
|
# In a separate input thread:
|
||||||
|
if user_typed_something:
|
||||||
|
agent.interrupt(user_input)
|
||||||
|
|
||||||
|
Example (Messaging):
|
||||||
|
# When new message arrives for active session:
|
||||||
|
if session_has_running_agent:
|
||||||
|
running_agent.interrupt(new_message.text)
|
||||||
|
"""
|
||||||
|
self._interrupt_requested = True
|
||||||
|
self._interrupt_message = message
|
||||||
|
if not self.quiet_mode:
|
||||||
|
print(f"\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
|
||||||
|
|
||||||
|
def clear_interrupt(self) -> None:
|
||||||
|
"""Clear any pending interrupt request."""
|
||||||
|
self._interrupt_requested = False
|
||||||
|
self._interrupt_message = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_interrupted(self) -> bool:
|
||||||
|
"""Check if an interrupt has been requested."""
|
||||||
|
return self._interrupt_requested
|
||||||
|
|
||||||
def run_conversation(
|
def run_conversation(
|
||||||
self,
|
self,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
|
|
@ -1359,8 +1399,19 @@ class AIAgent:
|
||||||
# Main conversation loop
|
# Main conversation loop
|
||||||
api_call_count = 0
|
api_call_count = 0
|
||||||
final_response = None
|
final_response = None
|
||||||
|
interrupted = False
|
||||||
|
|
||||||
|
# Clear any stale interrupt state at start
|
||||||
|
self.clear_interrupt()
|
||||||
|
|
||||||
while api_call_count < self.max_iterations:
|
while api_call_count < self.max_iterations:
|
||||||
|
# Check for interrupt request (e.g., user sent new message)
|
||||||
|
if self._interrupt_requested:
|
||||||
|
interrupted = True
|
||||||
|
if not self.quiet_mode:
|
||||||
|
print(f"\n⚡ Breaking out of tool loop due to interrupt...")
|
||||||
|
break
|
||||||
|
|
||||||
api_call_count += 1
|
api_call_count += 1
|
||||||
|
|
||||||
# Prepare messages for API call
|
# Prepare messages for API call
|
||||||
|
|
@ -2059,13 +2110,24 @@ class AIAgent:
|
||||||
self._session_messages = messages
|
self._session_messages = messages
|
||||||
self._save_session_log(messages)
|
self._save_session_log(messages)
|
||||||
|
|
||||||
return {
|
# Build result with interrupt info if applicable
|
||||||
|
result = {
|
||||||
"final_response": final_response,
|
"final_response": final_response,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"api_calls": api_call_count,
|
"api_calls": api_call_count,
|
||||||
"completed": completed,
|
"completed": completed,
|
||||||
"partial": False # True only when stopped due to invalid tool calls
|
"partial": False, # True only when stopped due to invalid tool calls
|
||||||
|
"interrupted": interrupted,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Include interrupt message if one triggered the interrupt
|
||||||
|
if interrupted and self._interrupt_message:
|
||||||
|
result["interrupt_message"] = self._interrupt_message
|
||||||
|
|
||||||
|
# Clear interrupt state after handling
|
||||||
|
self.clear_interrupt()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def chat(self, message: str) -> str:
|
def chat(self, message: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue