mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Implement interrupt handling for message processing in GatewayRunner and BasePlatformAdapter
- Introduced a monitoring mechanism in GatewayRunner to detect incoming messages while an agent is active, allowing for graceful interruption and processing of new messages. - Enhanced BasePlatformAdapter to manage active sessions and pending messages, ensuring that new messages can interrupt ongoing tasks effectively. - Improved the handling of pending messages by checking for interrupts and processing them in the correct order, enhancing user experience during message interactions. - Updated the cleanup process for active tasks to ensure proper resource management after interruptions.
This commit is contained in:
parent
9bfe185a2e
commit
51a6b7d2b5
2 changed files with 118 additions and 30 deletions
|
|
@ -108,6 +108,11 @@ class BasePlatformAdapter(ABC):
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
self._message_handler: Optional[MessageHandler] = None
|
self._message_handler: Optional[MessageHandler] = None
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
# Track active message handlers per session for interrupt support
|
||||||
|
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||||
|
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||||
|
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
|
|
@ -190,12 +195,33 @@ class BasePlatformAdapter(ABC):
|
||||||
"""
|
"""
|
||||||
Process an incoming message.
|
Process an incoming message.
|
||||||
|
|
||||||
Calls the registered message handler and sends the response.
|
This method returns quickly by spawning background tasks.
|
||||||
Keeps typing indicator active throughout processing.
|
This allows new messages to be processed even while an agent is running,
|
||||||
|
enabling interruption support.
|
||||||
"""
|
"""
|
||||||
if not self._message_handler:
|
if not self._message_handler:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
session_key = event.source.chat_id
|
||||||
|
|
||||||
|
# Check if there's already an active handler for this session
|
||||||
|
if session_key in self._active_sessions:
|
||||||
|
# Store this as a pending message - it will interrupt the running agent
|
||||||
|
print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt")
|
||||||
|
self._pending_messages[session_key] = event
|
||||||
|
# Signal the interrupt (the processing task checks this)
|
||||||
|
self._active_sessions[session_key].set()
|
||||||
|
return # Don't process now - will be handled after current task finishes
|
||||||
|
|
||||||
|
# Spawn background task to process this message
|
||||||
|
asyncio.create_task(self._process_message_background(event, session_key))
|
||||||
|
|
||||||
|
async def _process_message_background(self, event: MessageEvent, session_key: str) -> None:
|
||||||
|
"""Background task that actually processes the message."""
|
||||||
|
# Create interrupt event for this session
|
||||||
|
interrupt_event = asyncio.Event()
|
||||||
|
self._active_sessions[session_key] = interrupt_event
|
||||||
|
|
||||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
|
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
|
||||||
|
|
||||||
|
|
@ -222,6 +248,23 @@ class BasePlatformAdapter(ABC):
|
||||||
)
|
)
|
||||||
if not fallback_result.success:
|
if not fallback_result.success:
|
||||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||||
|
|
||||||
|
# Check if there's a pending message that was queued during our processing
|
||||||
|
if session_key in self._pending_messages:
|
||||||
|
pending_event = self._pending_messages.pop(session_key)
|
||||||
|
print(f"[{self.name}] 📨 Processing queued message from interrupt")
|
||||||
|
# Clean up current session before processing pending
|
||||||
|
if session_key in self._active_sessions:
|
||||||
|
del self._active_sessions[session_key]
|
||||||
|
typing_task.cancel()
|
||||||
|
try:
|
||||||
|
await typing_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
# Process pending message in new background task
|
||||||
|
await self._process_message_background(pending_event, session_key)
|
||||||
|
return # Already cleaned up
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{self.name}] Error handling message: {e}")
|
print(f"[{self.name}] Error handling message: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -233,6 +276,17 @@ class BasePlatformAdapter(ABC):
|
||||||
await typing_task
|
await typing_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
# Clean up session tracking
|
||||||
|
if session_key in self._active_sessions:
|
||||||
|
del self._active_sessions[session_key]
|
||||||
|
|
||||||
|
def has_pending_interrupt(self, session_key: str) -> bool:
|
||||||
|
"""Check if there's a pending interrupt for a session."""
|
||||||
|
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
|
||||||
|
|
||||||
|
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
|
||||||
|
"""Get and clear any pending message for a session."""
|
||||||
|
return self._pending_messages.get(session_key)
|
||||||
|
|
||||||
def build_source(
|
def build_source(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -531,6 +531,27 @@ class GatewayRunner:
|
||||||
|
|
||||||
tracking_task = asyncio.create_task(track_agent())
|
tracking_task = asyncio.create_task(track_agent())
|
||||||
|
|
||||||
|
# Monitor for interrupts from the adapter (new messages arriving)
|
||||||
|
async def monitor_for_interrupt():
|
||||||
|
adapter = self.adapters.get(source.platform)
|
||||||
|
if not adapter:
|
||||||
|
return
|
||||||
|
|
||||||
|
chat_id = source.chat_id
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(0.2) # Check every 200ms
|
||||||
|
# Check if adapter has a pending interrupt for this session
|
||||||
|
if hasattr(adapter, 'has_pending_interrupt') and adapter.has_pending_interrupt(chat_id):
|
||||||
|
agent = agent_holder[0]
|
||||||
|
if agent:
|
||||||
|
pending_event = adapter.get_pending_message(chat_id)
|
||||||
|
pending_text = pending_event.text if pending_event else None
|
||||||
|
print(f"[gateway] ⚡ Interrupt detected from adapter, signaling agent...")
|
||||||
|
agent.interrupt(pending_text)
|
||||||
|
break
|
||||||
|
|
||||||
|
interrupt_monitor = asyncio.create_task(monitor_for_interrupt())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Run in thread pool to not block
|
# Run in thread pool to not block
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
@ -538,42 +559,55 @@ class GatewayRunner:
|
||||||
|
|
||||||
# Check if we were interrupted and have a pending message
|
# Check if we were interrupted and have a pending message
|
||||||
result = result_holder[0]
|
result = result_holder[0]
|
||||||
if result and result.get("interrupted") and session_key:
|
adapter = self.adapters.get(source.platform)
|
||||||
pending = self._pending_messages.pop(session_key, None)
|
|
||||||
if pending:
|
# Get pending message from adapter if interrupted
|
||||||
print(f"[gateway] 📨 Processing interrupted message: '{pending[:40]}...'")
|
pending = None
|
||||||
# Add an indicator to the response
|
if result and result.get("interrupted") and adapter:
|
||||||
if response:
|
pending_event = adapter.get_pending_message(source.chat_id)
|
||||||
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
if pending_event:
|
||||||
|
pending = pending_event.text
|
||||||
# Send the interrupted response first
|
elif result.get("interrupt_message"):
|
||||||
adapter = self.adapters.get(source.platform)
|
pending = result.get("interrupt_message")
|
||||||
if adapter and response:
|
|
||||||
await adapter.send(chat_id=source.chat_id, content=response)
|
if pending:
|
||||||
|
print(f"[gateway] 📨 Processing interrupted message: '{pending[:40]}...'")
|
||||||
# Now process the pending message with updated history
|
# Add an indicator to the response
|
||||||
updated_history = result.get("messages", history)
|
if response:
|
||||||
return await self._run_agent(
|
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
||||||
message=pending,
|
|
||||||
context_prompt=context_prompt,
|
# Send the interrupted response first
|
||||||
history=updated_history,
|
if adapter and response:
|
||||||
source=source,
|
await adapter.send(chat_id=source.chat_id, content=response)
|
||||||
session_id=session_id,
|
|
||||||
session_key=session_key
|
# Now process the pending message with updated history
|
||||||
)
|
updated_history = result.get("messages", history)
|
||||||
|
return await self._run_agent(
|
||||||
|
message=pending,
|
||||||
|
context_prompt=context_prompt,
|
||||||
|
history=updated_history,
|
||||||
|
source=source,
|
||||||
|
session_id=session_id,
|
||||||
|
session_key=session_key
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
# Stop progress sender
|
# Stop progress sender and interrupt monitor
|
||||||
if progress_task:
|
if progress_task:
|
||||||
progress_task.cancel()
|
progress_task.cancel()
|
||||||
|
interrupt_monitor.cancel()
|
||||||
|
|
||||||
# Clean up tracking
|
# Clean up tracking
|
||||||
tracking_task.cancel()
|
tracking_task.cancel()
|
||||||
if session_key and session_key in self._running_agents:
|
if session_key and session_key in self._running_agents:
|
||||||
del self._running_agents[session_key]
|
del self._running_agents[session_key]
|
||||||
try:
|
|
||||||
await progress_task
|
# Wait for cancelled tasks
|
||||||
except asyncio.CancelledError:
|
for task in [progress_task, interrupt_monitor, tracking_task]:
|
||||||
pass
|
if task:
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue