This commit is contained in:
Jai Suphavadeeprasit 2025-10-10 18:04:22 -04:00
parent c5386ed7e6
commit e698b7e0e5
19 changed files with 3924 additions and 132 deletions

View file

@ -24,6 +24,8 @@ import json
import logging
import os
import time
import uuid
import asyncio
from typing import List, Dict, Any, Optional
from openai import OpenAI
import fire
@ -31,6 +33,19 @@ from datetime import datetime
# Import our tool system
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
from mock_web_tools import MOCK_TOOL_FUNCTIONS, MOCK_WEB_TOOLS
# Import WebSocket connection pool (optional dependency)
# Use synchronous API to avoid event loop management in agent layer
try:
from api_endpoint.websocket_connection_pool import connect_sync, send_event_sync, is_connected
WEBSOCKET_LOGGER_AVAILABLE = True
except ImportError:
WEBSOCKET_LOGGER_AVAILABLE = False
connect_sync = None
send_event_sync = None
is_connected = None
print("⚠️ WebSocket logger not available (missing websockets package)")
class AIAgent:
@ -51,7 +66,11 @@ class AIAgent:
enabled_toolsets: List[str] = None,
disabled_toolsets: List[str] = None,
save_trajectories: bool = False,
verbose_logging: bool = False
verbose_logging: bool = False,
enable_websocket_logging: bool = False,
websocket_server: str = "ws://localhost:8000/ws",
mock_web_tools: bool = False,
mock_delay: int = 60
):
"""
Initialize the AI Agent.
@ -66,12 +85,21 @@ class AIAgent:
disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
save_trajectories (bool): Whether to save conversation trajectories to JSONL files (default: False)
verbose_logging (bool): Enable verbose logging for debugging (default: False)
enable_websocket_logging (bool): Enable real-time WebSocket logging (default: False)
websocket_server (str): WebSocket server URL (default: ws://localhost:8000/ws)
mock_web_tools (bool): Use mock web tools for testing (no API calls, configurable delays) (default: False)
mock_delay (int): Delay in seconds for mock web_extract to test timeout (default: 60)
"""
self.model = model
self.max_iterations = max_iterations
self.tool_delay = tool_delay
self.save_trajectories = save_trajectories
self.verbose_logging = verbose_logging
self.enable_websocket_logging = enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE
self.websocket_server = websocket_server
self.mock_web_tools = mock_web_tools
self.mock_delay = mock_delay
# Note: We use global ws_pool instead of per-instance connection
# Store toolset filtering options
self.enabled_toolsets = enabled_toolsets
@ -145,6 +173,11 @@ class AIAgent:
# Show trajectory saving status
if self.save_trajectories:
print("📝 Trajectory saving enabled")
# Show mock tools status
if self.mock_web_tools:
print(f"🧪 MOCK MODE ENABLED - Web tools will use fake data (delay: {self.mock_delay}s)")
print(f" Perfect for testing WebSocket reconnection without API costs!")
def _format_tools_for_system_message(self) -> str:
"""
@ -320,11 +353,38 @@ class AIAgent:
except Exception as e:
print(f"⚠️ Failed to save trajectory: {e}")
def _init_websocket_connection(self, session_id: str):
    """
    Open the shared WebSocket logging connection and announce a new session.

    Does nothing unless WebSocket logging is enabled on this agent, the
    optional logger import succeeded, and `connect_sync` is available.
    Otherwise it connects to `self.websocket_server` through the synchronous
    API and emits a "session_start" event carrying the session ID and the
    current local timestamp in ISO format.

    Any exception raised along the way is reported on stdout and switches
    WebSocket logging off for this agent instead of propagating.

    Args:
        session_id (str): Identifier attached to the session_start event.
    """
    logging_ready = (
        self.enable_websocket_logging
        and WEBSOCKET_LOGGER_AVAILABLE
        and connect_sync
    )
    if not logging_ready:
        return

    try:
        # Per the original author's note, connecting is expected to be safe
        # to repeat when a connection already exists, and the API layer is
        # expected to own all event loop handling.
        connect_sync(self.websocket_server)

        # Announce this particular session to the logging server.
        start_payload = {
            "session_id": session_id,
            "start_time": datetime.now().isoformat(),
        }
        send_event_sync("session_start", session_id, start_payload)

        short_id = session_id[:8]
        print(f"📡 WebSocket logging enabled (session: {short_id}...)")
    except Exception as exc:
        # Best-effort logging: report the failure and disable the feature
        # rather than interrupting the agent run.
        print(f"⚠️ Failed to initialize WebSocket connection: {exc}")
        self.enable_websocket_logging = False
def run_conversation(
self,
user_message: str,
system_message: str = None,
conversation_history: List[Dict[str, Any]] = None
conversation_history: List[Dict[str, Any]] = None,
session_id: str = None
) -> Dict[str, Any]:
"""
Run a complete conversation with tool calling until completion.
@ -333,10 +393,37 @@ class AIAgent:
user_message (str): The user's message/question
system_message (str): Custom system message (optional)
conversation_history (List[Dict]): Previous conversation messages (optional)
session_id (str): Optional session ID (generated if not provided)
Returns:
Dict: Complete conversation result with final response and message history
"""
# ============================================================
# WEBSOCKET LOGGING: Session Initialization
# ============================================================
# Generate unique session ID for this agent execution (or use provided one)
# This ID will be used to link all events together in the log file
if session_id is None:
session_id = str(uuid.uuid4())
# Initialize WebSocket logger if enabled (via --enable_websocket_logging flag)
# Uses synchronous API - no event loop management in agent layer
if self.enable_websocket_logging:
try:
# Connect to logging server and log initial query
# All event loop management is handled inside the API layer
self._init_websocket_connection(session_id)
send_event_sync("query", session_id, {
"query": user_message,
"model": self.model,
"toolsets": self.enabled_toolsets
})
except Exception as e:
print(f"⚠️ WebSocket logging initialization failed: {e}")
import traceback
if self.verbose_logging:
traceback.print_exc()
# Initialize conversation
messages = conversation_history or []
@ -356,6 +443,22 @@ class AIAgent:
api_call_count += 1
print(f"\n🔄 Making API call #{api_call_count}...")
# ============================================================
# WEBSOCKET LOGGING: API Call Start
# ============================================================
# Log that we're about to make an API call to the model
# Captures: which call number, how many messages, whether tools available
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
try:
send_event_sync("api_call", session_id, {
"call_number": api_call_count,
"message_count": len(messages),
"has_tools": bool(self.tools)
})
except Exception as e:
if self.verbose_logging:
print(f"⚠️ WebSocket logging error: {e}")
# Log request details if verbose
if self.verbose_logging:
logging.debug(f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}")
@ -374,10 +477,31 @@ class AIAgent:
tools=self.tools if self.tools else None,
timeout=60.0 # Add explicit timeout
)
print(f"🔧 Response: {response}")
api_duration = time.time() - api_start_time
print(f"⏱️ API call completed in {api_duration:.2f}s")
# ============================================================
# WEBSOCKET LOGGING: API Response
# ============================================================
# Log the response we got back from the AI model
# Captures: what the model said, whether it wants tools, how long it took
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
try:
assistant_msg = response.choices[0].message
send_event_sync("response", session_id, {
"call_number": api_call_count,
"content": assistant_msg.content if hasattr(assistant_msg, 'content') else None,
"has_tool_calls": hasattr(assistant_msg, 'tool_calls') and bool(assistant_msg.tool_calls),
"tool_call_count": len(assistant_msg.tool_calls) if hasattr(assistant_msg, 'tool_calls') and assistant_msg.tool_calls else 0,
"duration": api_duration
})
except Exception as e:
if self.verbose_logging:
print(f"⚠️ WebSocket logging error: {e}")
if self.verbose_logging:
logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
@ -399,10 +523,12 @@ class AIAgent:
# Handle assistant response
if assistant_message.content:
print(f"🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}")
print(f"🤖 Assistant: {assistant_message.content}")
# Check for tool calls
if assistant_message.tool_calls:
print(f"🔧 Tool calls: {assistant_message.tool_calls}")
print(f"🔧 Processing {len(assistant_message.tool_calls)} tool call(s)...")
if self.verbose_logging:
@ -438,10 +564,37 @@ class AIAgent:
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())})")
# ============================================================
# WEBSOCKET LOGGING: Tool Call (Before Execution)
# ============================================================
# Log which tool we're about to execute and with what parameters
# This happens BEFORE tool runs, so we know what was requested
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
try:
send_event_sync("tool_call", session_id, {
"call_number": api_call_count,
"tool_index": i,
"tool_name": function_name,
"parameters": function_args, # E.g., {"query": "Python", "limit": 5}
"tool_call_id": tool_call.id
})
except Exception as e:
if self.verbose_logging:
print(f"⚠️ WebSocket logging error: {e}")
tool_start_time = time.time()
# Execute the tool
function_result = handle_function_call(function_name, function_args)
# Execute the tool (mock or real based on configuration)
if self.mock_web_tools and function_name in MOCK_TOOL_FUNCTIONS:
# Use mock implementation (no API calls, configurable delay)
mock_function = MOCK_TOOL_FUNCTIONS[function_name]
# Inject mock_delay for web_extract if not provided
if function_name == "web_extract" and "delay" not in function_args:
function_args["delay"] = self.mock_delay
function_result = mock_function(**function_args)
else:
# Use real tool implementation
function_result = handle_function_call(function_name, function_args)
tool_duration = time.time() - tool_start_time
result_preview = function_result[:200] if len(function_result) > 200 else function_result
@ -459,6 +612,36 @@ class AIAgent:
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s")
# ============================================================
# WEBSOCKET LOGGING: Tool Result (After Execution)
# ============================================================
# Log the result we got back from the tool
# IMPORTANT: Logs BOTH truncated preview AND full raw result
#
# Why both?
# - result: Truncated to 1000 chars for quick preview in UI
# - raw_result: FULL untruncated output for verification
#
# This is crucial for web tools where you want to see:
# - What the scraper actually returned (raw_result)
# - What the LLM processed it into (compare against raw)
# - Verify the LLM isn't losing important information
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
try:
send_event_sync("tool_result", session_id, {
"call_number": api_call_count,
"tool_index": i,
"tool_name": function_name,
"result": function_result[:1000] if function_result else None, # Truncated preview
"raw_result": function_result, # Full untruncated result (can be 100KB+)
"error": None,
"duration": tool_duration,
"tool_call_id": tool_call.id
})
except Exception as e:
if self.verbose_logging:
print(f"⚠️ WebSocket logging error: {e}")
# Delay between tool calls
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
time.sleep(self.tool_delay)
@ -483,6 +666,21 @@ class AIAgent:
error_msg = f"Error during API call #{api_call_count}: {str(e)}"
print(f"{error_msg}")
# ============================================================
# WEBSOCKET LOGGING: Error Event
# ============================================================
# Log any errors that occur during API calls or tool execution
# Helps track failures and debug issues
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
try:
send_event_sync("error", session_id, {
"error_message": error_msg,
"call_number": api_call_count
})
except Exception as ws_error:
if self.verbose_logging:
print(f"⚠️ WebSocket logging error: {ws_error}")
if self.verbose_logging:
logging.exception("Detailed error information:")
@ -509,14 +707,37 @@ class AIAgent:
# Save trajectory if enabled
self._save_trajectory(messages, user_message, completed)
# ============================================================
# WEBSOCKET LOGGING: Session Complete
# ============================================================
# Log final completion event for this session
# Note: WebSocket connection stays open for future runs (persistent pool)
# Uses synchronous API - no event loop management in agent layer
if self.enable_websocket_logging and WEBSOCKET_LOGGER_AVAILABLE and send_event_sync:
try:
# Log completion with summary information
# API layer handles event loop management internally
send_event_sync("complete", session_id, {
"final_response": final_response, # What the agent finally answered
"total_calls": api_call_count, # How many API calls were made
"completed": completed # Did it finish successfully?
})
# Connection persists automatically - agent has no control over lifecycle
except Exception as e:
if self.verbose_logging:
print(f"⚠️ WebSocket logging error: {e}")
import traceback
traceback.print_exc()
return {
"final_response": final_response,
"messages": messages,
"api_calls": api_call_count,
"completed": completed
"completed": completed,
"session_id": session_id if self.enable_websocket_logging else None
}
def chat(self, message: str) -> str:
def chat(self, message: str) -> str: # After we connect the UI we can put whatever we want here
"""
Simple chat interface that returns just the final response.
@ -532,7 +753,7 @@ class AIAgent:
def main(
query: str = None,
model: str = "claude-opus-4-20250514",
model: str = "claude-sonnet-4-5-20250929",
api_key: str = None,
base_url: str = "https://api.anthropic.com/v1/",
max_turns: int = 10,
@ -540,7 +761,11 @@ def main(
disabled_toolsets: str = None,
list_tools: bool = False,
save_trajectories: bool = False,
verbose: bool = False
verbose: bool = False,
enable_websocket_logging: bool = False,
websocket_server: str = "ws://localhost:8000/ws",
mock_web_tools: bool = False,
mock_delay: int = 60
):
"""
Main function for running the agent directly.
@ -558,9 +783,24 @@ def main(
list_tools (bool): Just list available tools and exit
save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
verbose (bool): Enable verbose logging for debugging. Defaults to False.
enable_websocket_logging (bool): Enable real-time WebSocket logging. Defaults to False.
websocket_server (str): WebSocket server URL. Defaults to ws://localhost:8000/ws.
mock_web_tools (bool): Use mock web tools for testing (no API calls, configurable delays). Defaults to False.
mock_delay (int): Delay in seconds for mock web_extract (default: 60s to test timeout). Defaults to 60.
Toolset Examples:
- "research": Web search, extract, crawl + vision tools
Mock Tools (Testing):
Use --mock_web_tools to test WebSocket reconnection without API calls:
- web_search: Returns fake results after 2s
- web_extract: Returns fake content after 60s (tests timeout)
- web_crawl: Returns fake pages after 30s
WebSocket Logging:
1. Start logging server: python logging_server.py
2. Run agent with --enable_websocket_logging flag
3. View logs in realtime at http://localhost:8000
"""
print("🤖 AI Agent with Tool Calling")
print("=" * 50)
@ -665,6 +905,11 @@ def main(
print(f" - Successful conversations → trajectory_samples.jsonl")
print(f" - Failed conversations → failed_trajectories.jsonl")
if enable_websocket_logging:
print(f"📡 WebSocket logging: ENABLED")
print(f" - Server: {websocket_server}")
print(f" - Make sure logging server is running: python logging_server.py")
# Initialize agent with provided parameters
try:
agent = AIAgent(
@ -675,7 +920,11 @@ def main(
enabled_toolsets=enabled_toolsets_list,
disabled_toolsets=disabled_toolsets_list,
save_trajectories=save_trajectories,
verbose_logging=verbose
verbose_logging=verbose,
enable_websocket_logging=enable_websocket_logging,
websocket_server=websocket_server,
mock_web_tools=mock_web_tools,
mock_delay=mock_delay
)
except RuntimeError as e:
print(f"❌ Failed to initialize agent: {e}")
@ -689,6 +938,9 @@ def main(
)
else:
user_query = query
# There needs to be a multi-turn conversation here
# Hermes Agent needs to be multi-turn to be useful
print(f"\n📝 User Query: {user_query}")
print("\n" + "=" * 50)
@ -713,3 +965,12 @@ def main(
# Script entry point: when run directly, hand `main` to python-fire so its
# keyword parameters can be supplied from the command line.
if __name__ == "__main__":
fire.Fire(main)
# Order of operations:
# First track the ways in which information flows through the agent in realtime
# Create a FastAPI endpoint that is first able to listen for the logging through sockets
# Create the UI on top of that endpoint, and now you have a pretty UI. CHECKPOINT 1
# Now that you have better visualization write out the chat interface and allow it to be controlled through the UI as well as the main program
# Now decide how the information flows through the agent; you may need to do some trial and error to get this part right
# Implement multiturn conversation now and then CHECKPOINT 2 is now done with multiturn conversations