fix leakage

This commit is contained in:
hjc-puro 2025-11-03 17:42:23 -05:00
parent 0ca3e0aaa9
commit a4db3fdee5
4 changed files with 67 additions and 58 deletions

View file

@ -156,6 +156,7 @@ def _process_single_prompt(
print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}")
# Initialize agent with sampled toolsets
# Use prompt_index as task_id to ensure each task gets its own isolated VM
agent = AIAgent(
base_url=config.get("base_url"),
api_key=config.get("api_key"),
@ -164,7 +165,8 @@ def _process_single_prompt(
enabled_toolsets=selected_toolsets,
save_trajectories=False, # We handle saving ourselves
verbose_logging=config.get("verbose", False),
ephemeral_system_prompt=config.get("ephemeral_system_prompt")
ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
task_id=f"task_{prompt_index}"
)
# Run the agent

View file

@ -480,14 +480,15 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
else:
return json.dumps({"error": f"Unknown web function: {function_name}"})
def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
"""
Handle function calls for terminal tools.
Args:
function_name (str): Name of the terminal function to call
function_args (Dict): Arguments for the function
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
Returns:
str: Function result as JSON string
"""
@ -498,8 +499,8 @@ def handle_terminal_function_call(function_name: str, function_args: Dict[str, A
idle_threshold = function_args.get("idle_threshold", 5.0)
timeout = function_args.get("timeout")
return terminal_tool(command, input_keys, None, background, idle_threshold, timeout)
return terminal_tool(command, input_keys, None, background, idle_threshold, timeout, task_id)
else:
return json.dumps({"error": f"Unknown terminal function: {function_name}"})
@ -614,21 +615,22 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
return json.dumps({"error": f"Unknown image generation function: {function_name}"})
def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
"""
Main function call dispatcher that routes calls to appropriate toolsets.
This function determines which toolset a function belongs to and dispatches
the call to the appropriate handler. This makes it easy to add new toolsets
without changing the main calling interface.
Args:
function_name (str): Name of the function to call
function_args (Dict): Arguments for the function
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
Returns:
str: Function result as JSON string
Raises:
None: Returns error as JSON string instead of raising exceptions
"""
@ -636,28 +638,28 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any]) -> s
# Route web tools
if function_name in ["web_search", "web_extract", "web_crawl"]:
return handle_web_function_call(function_name, function_args)
# Route terminal tools
elif function_name in ["terminal"]:
return handle_terminal_function_call(function_name, function_args)
return handle_terminal_function_call(function_name, function_args, task_id)
# Route vision tools
elif function_name in ["vision_analyze"]:
return handle_vision_function_call(function_name, function_args)
# Route MoA tools
elif function_name in ["mixture_of_agents"]:
return handle_moa_function_call(function_name, function_args)
# Route image generation tools
elif function_name in ["image_generate"]:
return handle_image_function_call(function_name, function_args)
else:
error_msg = f"Unknown function: {function_name}"
print(f"{error_msg}")
return json.dumps({"error": error_msg})
except Exception as e:
error_msg = f"Error executing {function_name}: {str(e)}"
print(f"{error_msg}")

View file

@ -54,9 +54,9 @@ class AIAgent:
"""
def __init__(
self,
base_url: str = None,
api_key: str = None,
self,
base_url: str = None,
api_key: str = None,
model: str = "gpt-4",
max_iterations: int = 10,
tool_delay: float = 1.0,
@ -64,11 +64,12 @@ class AIAgent:
disabled_toolsets: List[str] = None,
save_trajectories: bool = False,
verbose_logging: bool = False,
ephemeral_system_prompt: str = None
ephemeral_system_prompt: str = None,
task_id: str = None
):
"""
Initialize the AI Agent.
Args:
base_url (str): Base URL for the model API (optional)
api_key (str): API key for authentication (optional, uses env var if not provided)
@ -80,6 +81,7 @@ class AIAgent:
save_trajectories (bool): Whether to save conversation trajectories to JSONL files (default: False)
verbose_logging (bool): Enable verbose logging for debugging (default: False)
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
task_id (str): Unique identifier for this task to isolate VMs between concurrent tasks (optional)
"""
self.model = model
self.max_iterations = max_iterations
@ -87,7 +89,11 @@ class AIAgent:
self.save_trajectories = save_trajectories
self.verbose_logging = verbose_logging
self.ephemeral_system_prompt = ephemeral_system_prompt
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
import uuid
self.task_id = task_id or str(uuid.uuid4())
# Store toolset filtering options
self.enabled_toolsets = enabled_toolsets
self.disabled_toolsets = disabled_toolsets
@ -469,12 +475,12 @@ class AIAgent:
function_args = {}
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())})")
tool_start_time = time.time()
# Execute the tool
function_result = handle_function_call(function_name, function_args)
# Execute the tool with task_id to isolate VMs between concurrent tasks
function_result = handle_function_call(function_name, function_args, self.task_id)
tool_duration = time.time() - tool_start_time
result_preview = function_result[:200] if len(function_result) > 200 else function_result

View file

@ -75,8 +75,9 @@ When commands enter interactive mode (vim, nano, less, git prompts, package mana
# Global state for VM lifecycle management
# These persist across tool calls to enable session continuity
_active_instance = None
_active_context = None
# Changed to dictionaries keyed by task_id to prevent leakage between concurrent tasks
_active_instances: Dict[str, Any] = {}
_active_contexts: Dict[str, Any] = {}
_instance_lock = threading.Lock()
def terminal_tool(
@ -85,23 +86,25 @@ def terminal_tool(
session_id: Optional[str] = None,
background: bool = False,
idle_threshold: float = 5.0,
timeout: Optional[int] = None
timeout: Optional[int] = None,
task_id: Optional[str] = None
) -> str:
"""
Execute a command on a Morph VM with optional interactive session support.
This tool uses Hecate's VM lifecycle management to automatically create
and manage VMs. VMs are reused within the configured lifetime window
and automatically cleaned up after inactivity.
Args:
command: The command to execute (optional if continuing existing session)
input_keys: Keystrokes to send to interactive session (e.g., "hello\\n")
session_id: ID of existing session to continue (optional)
background: Whether to run the command in the background (default: False)
background: Whether to run the command in the background (default: False)
idle_threshold: Seconds to wait for output before considering session idle (default: 5.0)
timeout: Command timeout in seconds (optional)
task_id: Unique identifier for this task to isolate VMs between concurrent tasks (optional)
Returns:
str: JSON string containing command output, session info, exit code, and any errors
@ -120,7 +123,7 @@ def terminal_tool(
# Run a background task
>>> result = terminal_tool(command="sleep 60", background=True)
"""
global _active_instance, _active_context
global _active_instances, _active_contexts
try:
# Import required modules lazily so this module can be imported
@ -135,10 +138,8 @@ def terminal_tool(
return json.dumps({
"output": "",
"screen": "",
"session_id": None,
"exit_code": -1,
"error": f"Terminal tool is disabled due to import error: {import_error}",
"status": "disabled"
"error": f"Terminal tool is disabled due to import error: {import_error}"
})
# Get configuration from environment
@ -151,25 +152,27 @@ def terminal_tool(
return json.dumps({
"output": "",
"screen": "",
"session_id": None,
"exit_code": -1,
"error": "MORPH_API_KEY environment variable not set",
"status": "disabled"
"error": "MORPH_API_KEY environment variable not set"
})
# Get or create VM instance and execution context
# Use task_id to isolate VMs between concurrent tasks
# If no task_id provided, use "default" for backward compatibility
effective_task_id = task_id or "default"
# Get or create VM instance and execution context per task
# This is critical for interactive session support - the context must persist!
with _instance_lock:
if _active_instance is None:
if effective_task_id not in _active_instances:
morph_client = MorphCloudClient(api_key=morph_api_key)
_active_instance = morph_client.instances.start(snapshot_id=snapshot_id)
_active_instances[effective_task_id] = morph_client.instances.start(snapshot_id=snapshot_id)
# Get or create persistent execution context
if _active_context is None:
_active_context = ExecutionContext()
# Get or create persistent execution context per task
if effective_task_id not in _active_contexts:
_active_contexts[effective_task_id] = ExecutionContext()
instance = _active_instance
ctx = _active_context
instance = _active_instances[effective_task_id]
ctx = _active_contexts[effective_task_id]
# Build tool input based on provided parameters
tool_input = {}
@ -208,15 +211,13 @@ def terminal_tool(
ctx=ctx
)
# Format the result with all possible fields
# Format the result with only essential fields for the LLM
# Map hecate's "stdout" to "output" for compatibility
formatted_result = {
"output": result.get("stdout", result.get("output", "")),
"screen": result.get("screen", ""),
"session_id": result.get("session_id"),
"exit_code": result.get("returncode", result.get("exit_code", -1)),
"error": result.get("error"),
"status": "active" if result.get("session_id") else "ended"
"error": result.get("error")
}
return json.dumps(formatted_result)
@ -225,10 +226,8 @@ def terminal_tool(
return json.dumps({
"output": "",
"screen": "",
"session_id": None,
"exit_code": -1,
"error": f"Failed to execute terminal command: {str(e)}",
"status": "error"
"error": f"Failed to execute terminal command: {str(e)}"
})
def check_hecate_requirements() -> bool: