mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-25 05:52:34 +00:00
chore: remove Atropos RL environments and tinker-atropos integration (#26106)
* chore: remove Atropos RL environments, tools, tests, skill, and tinker-atropos submodule Delete: - environments/ (43 files — base env, agent loop, tool call parsers, benchmarks) - rl_cli.py (standalone RL training CLI) - tools/rl_training_tool.py (all 10 rl_* tools) - tests: test_rl_training_tool, test_tool_call_parsers, test_managed_server_tool_support, test_agent_loop, test_agent_loop_vllm, test_agent_loop_tool_calling, test_terminalbench2_env_security - optional-skills/mlops/hermes-atropos-environments/ - tinker-atropos git submodule + .gitmodules * chore: remove RL/Atropos references from Python source - toolsets.py: remove rl toolset block + update comment - model_tools.py: remove rl_tools group + update async bridging comment - hermes_cli/tools_config.py: remove RL display entry, _DEFAULT_OFF_TOOLSETS, setup block, and rl_training post-setup handler - tools/budget_config.py: remove RL environment reference in docstring - tests/test_model_tools.py: remove rl_tools from expected groups - tests/run_agent/test_streaming_tool_call_repair.py: fix stale cross-reference * chore: remove rl/yc-bench extras and tinker-atropos refs from pyproject.toml - Remove rl extra (atroposlib, tinker, fastapi, uvicorn, wandb) - Remove yc-bench extra - Remove rl_cli from py-modules - Remove [tool.ty.src] exclude for tinker-atropos - Remove [tool.ruff] exclude for tinker-atropos - Regenerate uv.lock * chore: remove tinker-atropos from install/setup scripts - setup-hermes.sh: remove entire tinker-atropos submodule install block - scripts/install.sh: remove both tinker-atropos blocks (Termux + standard) - scripts/install.ps1: remove tinker-atropos block - nix/hermes-agent.nix: remove tinker-atropos pip install line * chore: remove RL references from cli-config.yaml.example * docs: remove Atropos/RL references from README, CONTRIBUTING, AGENTS.md * docs: remove RL/Atropos references from website - Delete: environments.md, rl-training.md, mlops-hermes-atropos-environments.md - sidebars.ts: remove rl-training and environments sidebar entries - optional-skills-catalog.md: remove hermes-atropos-environments row - tools-reference.md: remove entire rl toolset section - toolsets-reference.md: remove rl row + update example - integrations/index.md: remove RL Training bullet - architecture.md: remove environments/ from tree + RL section - contributing.md: remove tinker-atropos setup - updating.md: remove tinker-atropos install + stale submodule update * chore: remove remaining RL/Atropos stragglers - hermes_cli/config.py: remove TINKER_API_KEY + WANDB_API_KEY env var defs - hermes_cli/doctor.py: remove Submodules check section (tinker-atropos) - hermes_cli/setup.py: remove RL Training status check - hermes_cli/status.py: remove Tinker + WandB from API key status display - agent/display.py: remove both rl_* tool preview/activity blocks - website/docs: remove RL references from providers.md + env-variables.md - tests: remove TINKER_API_KEY from conftest, set_config_value, setup_script * chore: remove RL training section from .env.example
This commit is contained in:
parent
d364132114
commit
5af672c753
97 changed files with 18 additions and 15690 deletions
|
|
@ -101,7 +101,6 @@ _CREDENTIAL_NAMES = frozenset({
|
|||
"RETAINDB_API_KEY",
|
||||
"HINDSIGHT_API_KEY",
|
||||
"HINDSIGHT_LLM_API_KEY",
|
||||
"TINKER_API_KEY",
|
||||
"DAYTONA_API_KEY",
|
||||
"TWILIO_AUTH_TOKEN",
|
||||
"TELEGRAM_BOT_TOKEN",
|
||||
|
|
|
|||
|
|
@ -1,164 +0,0 @@
|
|||
"""Security tests for Terminal-Bench 2 archive extraction."""
|
||||
|
||||
import base64
|
||||
import importlib
|
||||
import io
|
||||
import sys
|
||||
import tarfile
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = types.ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
def _load_terminalbench_module(monkeypatch):
|
||||
class _EvalHandlingEnum:
|
||||
STOP_TRAIN = "stop_train"
|
||||
|
||||
class _APIServerConfig:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
class _AgentResult:
|
||||
pass
|
||||
|
||||
class _HermesAgentLoop:
|
||||
pass
|
||||
|
||||
class _HermesAgentBaseEnv:
|
||||
pass
|
||||
|
||||
class _HermesAgentEnvConfig:
|
||||
pass
|
||||
|
||||
class _ToolContext:
|
||||
pass
|
||||
|
||||
stub_modules = {
|
||||
"atroposlib": _stub_module("atroposlib"),
|
||||
"atroposlib.envs": _stub_module("atroposlib.envs"),
|
||||
"atroposlib.envs.base": _stub_module(
|
||||
"atroposlib.envs.base",
|
||||
EvalHandlingEnum=_EvalHandlingEnum,
|
||||
),
|
||||
"atroposlib.envs.server_handling": _stub_module("atroposlib.envs.server_handling"),
|
||||
"atroposlib.envs.server_handling.server_manager": _stub_module(
|
||||
"atroposlib.envs.server_handling.server_manager",
|
||||
APIServerConfig=_APIServerConfig,
|
||||
),
|
||||
"environments.agent_loop": _stub_module(
|
||||
"environments.agent_loop",
|
||||
AgentResult=_AgentResult,
|
||||
HermesAgentLoop=_HermesAgentLoop,
|
||||
),
|
||||
"environments.hermes_base_env": _stub_module(
|
||||
"environments.hermes_base_env",
|
||||
HermesAgentBaseEnv=_HermesAgentBaseEnv,
|
||||
HermesAgentEnvConfig=_HermesAgentEnvConfig,
|
||||
),
|
||||
"environments.tool_context": _stub_module(
|
||||
"environments.tool_context",
|
||||
ToolContext=_ToolContext,
|
||||
),
|
||||
"tools.terminal_tool": _stub_module(
|
||||
"tools.terminal_tool",
|
||||
register_task_env_overrides=lambda *args, **kwargs: None,
|
||||
clear_task_env_overrides=lambda *args, **kwargs: None,
|
||||
cleanup_vm=lambda *args, **kwargs: None,
|
||||
),
|
||||
}
|
||||
|
||||
stub_modules["atroposlib"].envs = stub_modules["atroposlib.envs"]
|
||||
stub_modules["atroposlib.envs"].base = stub_modules["atroposlib.envs.base"]
|
||||
stub_modules["atroposlib.envs"].server_handling = stub_modules["atroposlib.envs.server_handling"]
|
||||
stub_modules["atroposlib.envs.server_handling"].server_manager = stub_modules[
|
||||
"atroposlib.envs.server_handling.server_manager"
|
||||
]
|
||||
|
||||
for name, module in stub_modules.items():
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
|
||||
module_name = "environments.benchmarks.terminalbench_2.terminalbench2_env"
|
||||
sys.modules.pop(module_name, None)
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
def _build_tar_b64(entries):
|
||||
buf = io.BytesIO()
|
||||
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
||||
for entry in entries:
|
||||
kind = entry["kind"]
|
||||
info = tarfile.TarInfo(entry["name"])
|
||||
|
||||
if kind == "dir":
|
||||
info.type = tarfile.DIRTYPE
|
||||
tar.addfile(info)
|
||||
continue
|
||||
|
||||
if kind == "file":
|
||||
data = entry["data"].encode("utf-8")
|
||||
info.size = len(data)
|
||||
tar.addfile(info, io.BytesIO(data))
|
||||
continue
|
||||
|
||||
if kind == "symlink":
|
||||
info.type = tarfile.SYMTYPE
|
||||
info.linkname = entry["target"]
|
||||
tar.addfile(info)
|
||||
continue
|
||||
|
||||
raise ValueError(f"Unknown tar entry kind: {kind}")
|
||||
|
||||
return base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
|
||||
def test_extract_base64_tar_allows_safe_files(tmp_path, monkeypatch):
|
||||
module = _load_terminalbench_module(monkeypatch)
|
||||
archive = _build_tar_b64(
|
||||
[
|
||||
{"kind": "dir", "name": "nested"},
|
||||
{"kind": "file", "name": "nested/hello.txt", "data": "hello"},
|
||||
]
|
||||
)
|
||||
|
||||
target = tmp_path / "extract"
|
||||
module._extract_base64_tar(archive, target)
|
||||
|
||||
assert (target / "nested" / "hello.txt").read_text(encoding="utf-8") == "hello"
|
||||
|
||||
|
||||
def test_extract_base64_tar_rejects_path_traversal(tmp_path, monkeypatch):
|
||||
module = _load_terminalbench_module(monkeypatch)
|
||||
archive = _build_tar_b64(
|
||||
[
|
||||
{"kind": "file", "name": "../escape.txt", "data": "owned"},
|
||||
]
|
||||
)
|
||||
|
||||
target = tmp_path / "extract"
|
||||
with pytest.raises(ValueError, match="Unsafe archive member path"):
|
||||
module._extract_base64_tar(archive, target)
|
||||
|
||||
assert not (tmp_path / "escape.txt").exists()
|
||||
|
||||
|
||||
def test_extract_base64_tar_rejects_symlinks(tmp_path, monkeypatch):
|
||||
module = _load_terminalbench_module(monkeypatch)
|
||||
archive = _build_tar_b64(
|
||||
[
|
||||
{"kind": "symlink", "name": "link", "target": "../../escape.txt"},
|
||||
]
|
||||
)
|
||||
|
||||
target = tmp_path / "extract"
|
||||
with pytest.raises(ValueError, match="Unsupported archive member type"):
|
||||
module._extract_base64_tar(archive, target)
|
||||
|
||||
assert not (target / "link").exists()
|
||||
|
|
@ -39,8 +39,6 @@ class TestExplicitAllowlist:
|
|||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"ANTHROPIC_API_KEY",
|
||||
"WANDB_API_KEY",
|
||||
"TINKER_API_KEY",
|
||||
"HONCHO_API_KEY",
|
||||
"FIRECRAWL_API_KEY",
|
||||
"BROWSERBASE_API_KEY",
|
||||
|
|
|
|||
|
|
@ -18,4 +18,3 @@ def test_setup_hermes_script_has_termux_path():
|
|||
assert ".[termux]" in content
|
||||
assert "constraints-termux.txt" in content
|
||||
assert "$PREFIX/bin" in content
|
||||
assert "Skipping tinker-atropos on Termux" in content
|
||||
|
|
|
|||
|
|
@ -1,505 +0,0 @@
|
|||
"""
|
||||
Tests for environments/agent_loop.py — HermesAgentLoop.
|
||||
|
||||
Tests the multi-turn agent engine using mocked servers, without needing
|
||||
real API keys or running servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure repo root is importable
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
||||
|
||||
try:
|
||||
from environments.agent_loop import (
|
||||
AgentResult,
|
||||
HermesAgentLoop,
|
||||
ToolError,
|
||||
_extract_reasoning_from_message,
|
||||
resize_tool_pool,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
||||
|
||||
|
||||
# ─── Mock server infrastructure ─────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockFunction:
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolCall:
|
||||
id: str
|
||||
function: MockFunction
|
||||
type: str = "function"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMessage:
|
||||
content: Optional[str]
|
||||
role: str = "assistant"
|
||||
tool_calls: Optional[List[MockToolCall]] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
reasoning: Optional[str] = None
|
||||
reasoning_details: Optional[list] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockChoice:
|
||||
message: MockMessage
|
||||
finish_reason: str = "stop"
|
||||
index: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockChatCompletion:
|
||||
choices: List[MockChoice]
|
||||
id: str = "chatcmpl-mock"
|
||||
model: str = "mock-model"
|
||||
|
||||
|
||||
class MockServer:
|
||||
"""
|
||||
Mock server that returns pre-configured responses in sequence.
|
||||
Mimics the chat_completion() interface.
|
||||
"""
|
||||
|
||||
def __init__(self, responses: List[MockChatCompletion]):
|
||||
self.responses = responses
|
||||
self.call_count = 0
|
||||
self.call_history: List[Dict[str, Any]] = []
|
||||
|
||||
async def chat_completion(self, **kwargs) -> MockChatCompletion:
|
||||
self.call_history.append(kwargs)
|
||||
if self.call_count >= len(self.responses):
|
||||
# Return a simple text response if we run out
|
||||
return MockChatCompletion(
|
||||
choices=[MockChoice(message=MockMessage(content="Done."))]
|
||||
)
|
||||
resp = self.responses[self.call_count]
|
||||
self.call_count += 1
|
||||
return resp
|
||||
|
||||
|
||||
def make_text_response(content: str) -> MockChatCompletion:
|
||||
"""Create a simple text-only response (no tool calls)."""
|
||||
return MockChatCompletion(
|
||||
choices=[MockChoice(message=MockMessage(content=content))]
|
||||
)
|
||||
|
||||
|
||||
def make_tool_response(
|
||||
tool_name: str,
|
||||
arguments: dict,
|
||||
content: str = "",
|
||||
tool_call_id: str = "call_001",
|
||||
) -> MockChatCompletion:
|
||||
"""Create a response with a single tool call."""
|
||||
return MockChatCompletion(
|
||||
choices=[
|
||||
MockChoice(
|
||||
message=MockMessage(
|
||||
content=content,
|
||||
tool_calls=[
|
||||
MockToolCall(
|
||||
id=tool_call_id,
|
||||
function=MockFunction(
|
||||
name=tool_name,
|
||||
arguments=json.dumps(arguments),
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ─── Tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentResult:
|
||||
def test_defaults(self):
|
||||
result = AgentResult(messages=[])
|
||||
assert result.messages == []
|
||||
assert result.managed_state is None
|
||||
assert result.turns_used == 0
|
||||
assert result.finished_naturally is False
|
||||
assert result.reasoning_per_turn == []
|
||||
assert result.tool_errors == []
|
||||
|
||||
|
||||
class TestExtractReasoning:
|
||||
def test_reasoning_content_field(self):
|
||||
msg = MockMessage(content="hello", reasoning_content="I think...")
|
||||
assert _extract_reasoning_from_message(msg) == "I think..."
|
||||
|
||||
def test_reasoning_field(self):
|
||||
msg = MockMessage(content="hello", reasoning="Let me consider...")
|
||||
assert _extract_reasoning_from_message(msg) == "Let me consider..."
|
||||
|
||||
def test_reasoning_details(self):
|
||||
detail = MagicMock()
|
||||
detail.text = "Detail reasoning"
|
||||
msg = MockMessage(content="hello", reasoning_details=[detail])
|
||||
assert _extract_reasoning_from_message(msg) == "Detail reasoning"
|
||||
|
||||
def test_reasoning_details_dict_format(self):
|
||||
msg = MockMessage(
|
||||
content="hello",
|
||||
reasoning_details=[{"text": "Dict reasoning"}],
|
||||
)
|
||||
assert _extract_reasoning_from_message(msg) == "Dict reasoning"
|
||||
|
||||
def test_no_reasoning(self):
|
||||
msg = MockMessage(content="hello")
|
||||
assert _extract_reasoning_from_message(msg) is None
|
||||
|
||||
def test_reasoning_content_takes_priority(self):
|
||||
msg = MockMessage(
|
||||
content="hello",
|
||||
reasoning_content="First",
|
||||
reasoning="Second",
|
||||
)
|
||||
assert _extract_reasoning_from_message(msg) == "First"
|
||||
|
||||
|
||||
class TestHermesAgentLoop:
|
||||
"""Test the agent loop with mock servers."""
|
||||
|
||||
@pytest.fixture
|
||||
def basic_tools(self):
|
||||
"""Minimal tool schema for testing."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "terminal",
|
||||
"description": "Run a command",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Command to run",
|
||||
}
|
||||
},
|
||||
"required": ["command"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string"},
|
||||
},
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def valid_names(self):
|
||||
return {"terminal", "read_file", "todo"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_text_response(self, basic_tools, valid_names):
|
||||
"""Model responds with text only, no tool calls."""
|
||||
server = MockServer([make_text_response("Hello! How can I help?")])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is True
|
||||
assert result.turns_used == 1
|
||||
assert len(result.messages) >= 2 # user + assistant
|
||||
assert result.messages[-1]["role"] == "assistant"
|
||||
assert result.messages[-1]["content"] == "Hello! How can I help?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_then_text(self, basic_tools, valid_names):
|
||||
"""Model calls a tool, then responds with text."""
|
||||
server = MockServer([
|
||||
make_tool_response("todo", {"todos": [{"id": "1", "content": "test", "status": "pending"}]}),
|
||||
make_text_response("I created a todo for you."),
|
||||
])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Create a todo"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is True
|
||||
assert result.turns_used == 2
|
||||
# Should have: user, assistant (tool_call), tool (result), assistant (text)
|
||||
roles = [m["role"] for m in result.messages]
|
||||
assert roles == ["user", "assistant", "tool", "assistant"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_turns_reached(self, basic_tools, valid_names):
|
||||
"""Model keeps calling tools until max_turns is hit."""
|
||||
# Create responses that always call a tool
|
||||
responses = [
|
||||
make_tool_response("todo", {"todos": [{"id": str(i), "content": f"task {i}", "status": "pending"}]}, tool_call_id=f"call_{i}")
|
||||
for i in range(10)
|
||||
]
|
||||
server = MockServer(responses)
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=3,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Keep going"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is False
|
||||
assert result.turns_used == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_name(self, basic_tools, valid_names):
|
||||
"""Model calls a tool not in valid_tool_names."""
|
||||
server = MockServer([
|
||||
make_tool_response("nonexistent_tool", {"arg": "val"}),
|
||||
make_text_response("OK, that didn't work."),
|
||||
])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Call something weird"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Should record a tool error
|
||||
assert len(result.tool_errors) >= 1
|
||||
assert result.tool_errors[0].tool_name == "nonexistent_tool"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response(self, basic_tools, valid_names):
|
||||
"""Server returns empty response."""
|
||||
server = MockServer([MockChatCompletion(choices=[])])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is False
|
||||
assert result.turns_used == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_handling(self, basic_tools, valid_names):
|
||||
"""Server raises an exception."""
|
||||
|
||||
class FailingServer:
|
||||
async def chat_completion(self, **kwargs):
|
||||
raise ConnectionError("Server unreachable")
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=FailingServer(),
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is False
|
||||
assert result.turns_used == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_passed_to_server(self, basic_tools, valid_names):
|
||||
"""Verify tools are passed in the chat_completion kwargs."""
|
||||
server = MockServer([make_text_response("OK")])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
await agent.run(messages)
|
||||
|
||||
assert len(server.call_history) == 1
|
||||
assert "tools" in server.call_history[0]
|
||||
assert server.call_history[0]["tools"] == basic_tools
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_body_forwarded(self, basic_tools, valid_names):
|
||||
"""extra_body should be forwarded to server."""
|
||||
extra = {"provider": {"ignore": ["DeepInfra"]}}
|
||||
server = MockServer([make_text_response("OK")])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
extra_body=extra,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
await agent.run(messages)
|
||||
|
||||
assert server.call_history[0].get("extra_body") == extra
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_managed_state_returned(self, basic_tools, valid_names):
|
||||
"""If server has get_state(), result should include managed_state."""
|
||||
server = MockServer([make_text_response("OK")])
|
||||
server.get_state = lambda: {"nodes": [{"test": True}]}
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.managed_state is not None
|
||||
assert "nodes" in result.managed_state
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_managed_state_without_get_state(self, basic_tools, valid_names):
|
||||
"""Regular server without get_state() should return None managed_state."""
|
||||
server = MockServer([make_text_response("OK")])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.managed_state is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_tool_blocked(self, basic_tools):
|
||||
"""Memory tool should return error in RL environments."""
|
||||
valid = {"terminal", "read_file", "todo", "memory"}
|
||||
server = MockServer([
|
||||
make_tool_response("memory", {"action": "add", "target": "user", "content": "test"}),
|
||||
make_text_response("Done"),
|
||||
])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Remember this"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Find the tool response
|
||||
tool_msgs = [m for m in result.messages if m["role"] == "tool"]
|
||||
assert len(tool_msgs) >= 1
|
||||
tool_result = json.loads(tool_msgs[0]["content"])
|
||||
assert "error" in tool_result
|
||||
assert "not available" in tool_result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_search_blocked(self, basic_tools):
|
||||
"""session_search should return error in RL environments."""
|
||||
valid = {"terminal", "read_file", "todo", "session_search"}
|
||||
server = MockServer([
|
||||
make_tool_response("session_search", {"query": "test"}),
|
||||
make_text_response("Done"),
|
||||
])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Search sessions"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
tool_msgs = [m for m in result.messages if m["role"] == "tool"]
|
||||
assert len(tool_msgs) >= 1
|
||||
tool_result = json.loads(tool_msgs[0]["content"])
|
||||
assert "error" in tool_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_content_preserved(self, basic_tools, valid_names):
|
||||
"""Reasoning content should be extracted and preserved."""
|
||||
resp = MockChatCompletion(
|
||||
choices=[
|
||||
MockChoice(
|
||||
message=MockMessage(
|
||||
content="The answer is 42.",
|
||||
reasoning_content="Let me think about this step by step...",
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
server = MockServer([resp])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "What is the meaning of life?"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert len(result.reasoning_per_turn) == 1
|
||||
assert result.reasoning_per_turn[0] == "Let me think about this step by step..."
|
||||
|
||||
|
||||
class TestResizeToolPool:
|
||||
def test_resize_works(self):
|
||||
"""resize_tool_pool should not raise."""
|
||||
resize_tool_pool(16) # Small pool for testing
|
||||
resize_tool_pool(128) # Restore default
|
||||
|
||||
def test_resize_shuts_down_previous_executor(self, monkeypatch):
|
||||
"""Replacing the global tool executor should shut down the old pool."""
|
||||
import environments.agent_loop as agent_loop_module
|
||||
|
||||
old_executor = MagicMock()
|
||||
new_executor = MagicMock()
|
||||
|
||||
monkeypatch.setattr(agent_loop_module, "_tool_executor", old_executor)
|
||||
monkeypatch.setattr(
|
||||
agent_loop_module.concurrent.futures,
|
||||
"ThreadPoolExecutor",
|
||||
MagicMock(return_value=new_executor),
|
||||
)
|
||||
|
||||
resize_tool_pool(16)
|
||||
|
||||
old_executor.shutdown.assert_called_once_with(wait=False)
|
||||
assert agent_loop_module._tool_executor is new_executor
|
||||
|
|
@ -1,552 +0,0 @@
|
|||
"""Integration tests for HermesAgentLoop tool calling.
|
||||
|
||||
Tests the full agent loop with real LLM calls via OpenRouter.
|
||||
Uses stepfun/step-3.5-flash:free by default (zero cost), falls back
|
||||
to anthropic/claude-sonnet-4 if the free model is unavailable.
|
||||
|
||||
These tests verify:
|
||||
1. Single tool call: model calls a tool, gets result, responds
|
||||
2. Multi-tool call: model calls multiple tools in one turn
|
||||
3. Multi-turn: model calls tools across multiple turns
|
||||
4. Unknown tool rejection: model calling a non-existent tool gets an error
|
||||
5. Max turns: loop stops when max_turns is reached
|
||||
6. No tools: model responds without calling any tools
|
||||
7. Tool error handling: tool execution errors are captured
|
||||
|
||||
Run:
|
||||
pytest tests/test_agent_loop_tool_calling.py -v
|
||||
pytest tests/test_agent_loop_tool_calling.py -v -k "single" # run one test
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# pytestmark removed — tests skip gracefully via OPENROUTER_API_KEY check on line 59
|
||||
|
||||
# Ensure repo root is importable
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
try:
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Test infrastructure
|
||||
# =========================================================================
|
||||
|
||||
# Models to try, in order of preference (free first)
|
||||
_MODELS = [
|
||||
"stepfun/step-3.5-flash:free",
|
||||
"google/gemini-2.0-flash-001",
|
||||
"anthropic/claude-sonnet-4",
|
||||
]
|
||||
|
||||
def _get_api_key():
|
||||
key = os.getenv("OPENROUTER_API_KEY", "")
|
||||
if not key:
|
||||
pytest.skip("OPENROUTER_API_KEY not set")
|
||||
return key
|
||||
|
||||
|
||||
def _make_server(model: str = None):
|
||||
"""Create an OpenAI server for testing."""
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
config = APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name=model or _MODELS[0],
|
||||
server_type="openai",
|
||||
api_key=_get_api_key(),
|
||||
health_check=False,
|
||||
)
|
||||
return OpenAIServer(config)
|
||||
|
||||
|
||||
async def _try_models(test_fn):
|
||||
"""Try running a test with each model until one works."""
|
||||
last_error = None
|
||||
for model in _MODELS:
|
||||
try:
|
||||
server = _make_server(model)
|
||||
return await test_fn(server, model)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if "rate" in str(e).lower() or "limit" in str(e).lower():
|
||||
continue # Rate limited, try next model
|
||||
raise # Real error
|
||||
pytest.skip(f"All models failed. Last error: {last_error}")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Fake tools for testing
|
||||
# =========================================================================
|
||||
|
||||
# Simple calculator tool
|
||||
CALC_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"description": "Calculate a math expression. Returns the numeric result.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Math expression to evaluate, e.g. '2 + 3'"
|
||||
}
|
||||
},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Weather lookup tool
|
||||
WEATHER_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a city. Returns temperature and conditions.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name, e.g. 'Tokyo'"
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Lookup tool (always succeeds)
|
||||
LOOKUP_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup",
|
||||
"description": "Look up a fact. Returns a short answer string.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "What to look up"
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Error tool (always fails)
|
||||
ERROR_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "failing_tool",
|
||||
"description": "A tool that always fails with an error.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {"type": "string"}
|
||||
},
|
||||
"required": ["input"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str:
|
||||
"""Handle fake tool calls for testing."""
|
||||
if tool_name == "calculate":
|
||||
expr = args.get("expression", "0")
|
||||
try:
|
||||
# Safe eval for simple math
|
||||
result = eval(expr, {"__builtins__": {}}, {})
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)})
|
||||
|
||||
elif tool_name == "get_weather":
|
||||
city = args.get("city", "Unknown")
|
||||
# Return canned weather
|
||||
return json.dumps({
|
||||
"city": city,
|
||||
"temperature": 22,
|
||||
"conditions": "sunny",
|
||||
"humidity": 45,
|
||||
})
|
||||
|
||||
elif tool_name == "lookup":
|
||||
query = args.get("query", "")
|
||||
return json.dumps({"answer": f"The answer to '{query}' is 42."})
|
||||
|
||||
elif tool_name == "failing_tool":
|
||||
raise RuntimeError("This tool always fails!")
|
||||
|
||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_call():
|
||||
"""Model should call a single tool, get the result, and respond."""
|
||||
|
||||
async def _run(server, model):
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[WEATHER_TOOL],
|
||||
valid_tool_names={"get_weather"},
|
||||
max_turns=5,
|
||||
temperature=0.0,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What's the weather in Tokyo? Use the get_weather tool."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert isinstance(result, AgentResult)
|
||||
assert result.turns_used >= 2, f"Expected at least 2 turns (tool call + response), got {result.turns_used}"
|
||||
|
||||
# Verify a tool call happened
|
||||
tool_calls_found = False
|
||||
for msg in result.messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
if tc["function"]["name"] == "get_weather":
|
||||
tool_calls_found = True
|
||||
args = json.loads(tc["function"]["arguments"])
|
||||
assert "city" in args
|
||||
assert tool_calls_found, "Model should have called get_weather"
|
||||
|
||||
# Verify tool result is in conversation
|
||||
tool_results = [m for m in result.messages if m.get("role") == "tool"]
|
||||
assert len(tool_results) >= 1, "Should have at least one tool result"
|
||||
|
||||
# Verify the final response references the weather
|
||||
final_msg = result.messages[-1]
|
||||
assert final_msg["role"] == "assistant"
|
||||
assert final_msg["content"], "Final response should have content"
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_tool_single_turn():
|
||||
"""Model should call multiple tools in a single turn."""
|
||||
|
||||
async def _run(server, model):
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[WEATHER_TOOL, CALC_TOOL],
|
||||
valid_tool_names={"get_weather", "calculate"},
|
||||
max_turns=5,
|
||||
temperature=0.0,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": (
|
||||
"I need two things at once: "
|
||||
"1) What's the weather in Paris? Use get_weather. "
|
||||
"2) What is 15 * 7? Use calculate. "
|
||||
"Call BOTH tools in a single response."
|
||||
)},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Count distinct tools called
|
||||
tools_called = set()
|
||||
for msg in result.messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tools_called.add(tc["function"]["name"])
|
||||
|
||||
# At minimum, both tools should have been called (maybe in different turns)
|
||||
assert "get_weather" in tools_called, f"get_weather not called. Called: {tools_called}"
|
||||
assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}"
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_turn_conversation():
|
||||
"""Agent should handle multiple turns of tool calls."""
|
||||
|
||||
async def _run(server, model):
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[LOOKUP_TOOL, CALC_TOOL],
|
||||
valid_tool_names={"lookup", "calculate"},
|
||||
max_turns=10,
|
||||
temperature=0.0,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": (
|
||||
"First, use the lookup tool to look up 'meaning of life'. "
|
||||
"Then use calculate to compute 6 * 7. "
|
||||
"Do these in separate tool calls, one at a time."
|
||||
)},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Should have used both tools
|
||||
tools_called = set()
|
||||
for msg in result.messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tools_called.add(tc["function"]["name"])
|
||||
|
||||
assert "lookup" in tools_called, f"lookup not called. Called: {tools_called}"
|
||||
assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}"
|
||||
|
||||
# Should finish naturally
|
||||
assert result.finished_naturally, "Should finish naturally after answering"
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_rejected():
|
||||
"""If the model calls a tool not in valid_tool_names, it gets an error."""
|
||||
|
||||
async def _run(server, model):
|
||||
# Only allow "calculate" but give schema for both
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[CALC_TOOL, WEATHER_TOOL],
|
||||
valid_tool_names={"calculate"}, # weather NOT allowed
|
||||
max_turns=5,
|
||||
temperature=0.0,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What's the weather in London? Use get_weather."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Check if get_weather was called and rejected
|
||||
if result.tool_errors:
|
||||
weather_errors = [e for e in result.tool_errors if e.tool_name == "get_weather"]
|
||||
assert len(weather_errors) > 0, "get_weather should have been rejected"
|
||||
assert "Unknown tool" in weather_errors[0].error
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_turns_limit():
|
||||
"""Agent should stop after max_turns even if model keeps calling tools."""
|
||||
|
||||
async def _run(server, model):
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[LOOKUP_TOOL],
|
||||
valid_tool_names={"lookup"},
|
||||
max_turns=2, # Very low limit
|
||||
temperature=0.0,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": (
|
||||
"Keep looking up facts. Look up 'fact 1', then 'fact 2', "
|
||||
"then 'fact 3', then 'fact 4'. Do them one at a time."
|
||||
)},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.turns_used <= 2, f"Should stop at max_turns=2, used {result.turns_used}"
|
||||
assert not result.finished_naturally, "Should NOT finish naturally (hit max_turns)"
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tools_direct_response():
|
||||
"""When no tools are useful, model should respond directly."""
|
||||
|
||||
async def _run(server, model):
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[WEATHER_TOOL],
|
||||
valid_tool_names={"get_weather"},
|
||||
max_turns=5,
|
||||
temperature=0.0,
|
||||
max_tokens=200,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What is 2 + 2? Just answer directly, no tools needed."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally, "Should finish naturally with a direct response"
|
||||
assert result.turns_used == 1, f"Should take exactly 1 turn for a direct answer, took {result.turns_used}"
|
||||
|
||||
final = result.messages[-1]
|
||||
assert final["role"] == "assistant"
|
||||
assert final["content"], "Should have text content"
|
||||
assert "4" in final["content"], "Should contain the answer '4'"
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_handling():
|
||||
"""Tool execution errors should be captured and reported to the model."""
|
||||
|
||||
async def _run(server, model):
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[ERROR_TOOL],
|
||||
valid_tool_names={"failing_tool"},
|
||||
max_turns=5,
|
||||
temperature=0.0,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Please call the failing_tool with input 'test'."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# The tool error should be recorded
|
||||
assert len(result.tool_errors) >= 1, "Should have at least one tool error"
|
||||
assert "RuntimeError" in result.tool_errors[0].error or "always fails" in result.tool_errors[0].error
|
||||
|
||||
# The error should be in the conversation as a tool result
|
||||
tool_results = [m for m in result.messages if m.get("role") == "tool"]
|
||||
assert len(tool_results) >= 1
|
||||
error_result = json.loads(tool_results[0]["content"])
|
||||
assert "error" in error_result
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_result_structure():
|
||||
"""Verify the AgentResult has all expected fields populated."""
|
||||
|
||||
async def _run(server, model):
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[CALC_TOOL],
|
||||
valid_tool_names={"calculate"},
|
||||
max_turns=5,
|
||||
temperature=0.0,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What is 3 + 4? Use the calculate tool."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Structural checks
|
||||
assert isinstance(result, AgentResult)
|
||||
assert isinstance(result.messages, list)
|
||||
assert len(result.messages) >= 3, "Should have user + assistant(tool) + tool_result + assistant(final)"
|
||||
assert isinstance(result.turns_used, int)
|
||||
assert result.turns_used > 0
|
||||
assert isinstance(result.finished_naturally, bool)
|
||||
assert isinstance(result.tool_errors, list)
|
||||
assert isinstance(result.reasoning_per_turn, list)
|
||||
|
||||
# Messages should follow OpenAI format
|
||||
for msg in result.messages:
|
||||
assert "role" in msg, f"Message missing 'role': {msg}"
|
||||
assert msg["role"] in ("system", "user", "assistant", "tool"), f"Invalid role: {msg['role']}"
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_history_preserved():
|
||||
"""The full conversation history should be in result.messages."""
|
||||
|
||||
async def _run(server, model):
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=[WEATHER_TOOL],
|
||||
valid_tool_names={"get_weather"},
|
||||
max_turns=5,
|
||||
temperature=0.0,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful weather assistant."},
|
||||
{"role": "user", "content": "What's the weather in Berlin? Use get_weather."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# System message should be preserved
|
||||
assert result.messages[0]["role"] == "system"
|
||||
assert "weather assistant" in result.messages[0]["content"]
|
||||
|
||||
# User message should be preserved
|
||||
assert result.messages[1]["role"] == "user"
|
||||
assert "Berlin" in result.messages[1]["content"]
|
||||
|
||||
# Should have assistant + tool + assistant sequence
|
||||
roles = [m["role"] for m in result.messages]
|
||||
assert "tool" in roles, "Should have tool results in conversation"
|
||||
|
||||
return result
|
||||
|
||||
await _try_models(_run)
|
||||
|
|
@ -1,359 +0,0 @@
|
|||
"""Integration tests for HermesAgentLoop with a local vLLM server.
|
||||
|
||||
Tests the full Phase 2 flow: ManagedServer + tool calling with a real
|
||||
vLLM backend, producing actual token IDs and logprobs for RL training.
|
||||
|
||||
Requires a running vLLM server. Start one from the atropos directory:
|
||||
|
||||
python -m example_trainer.vllm_api_server \
|
||||
--model Qwen/Qwen3-4B-Thinking-2507 \
|
||||
--port 9001 \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--max-model-len=32000
|
||||
|
||||
Tests are automatically skipped if the server is not reachable.
|
||||
|
||||
Run:
|
||||
pytest tests/test_agent_loop_vllm.py -v
|
||||
pytest tests/test_agent_loop_vllm.py -v -k "single"
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
# Ensure repo root is importable
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
try:
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
except ImportError:
|
||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Configuration
|
||||
# =========================================================================
|
||||
|
||||
VLLM_HOST = "localhost"
|
||||
VLLM_PORT = 9001
|
||||
VLLM_BASE_URL = f"http://{VLLM_HOST}:{VLLM_PORT}"
|
||||
VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
|
||||
|
||||
|
||||
def _vllm_is_running() -> bool:
|
||||
"""Check if the vLLM server is reachable."""
|
||||
try:
|
||||
r = requests.get(f"{VLLM_BASE_URL}/health", timeout=3)
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Skip all tests in this module if vLLM is not running
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not _vllm_is_running(),
|
||||
reason=(
|
||||
f"vLLM server not reachable at {VLLM_BASE_URL}. "
|
||||
"Start it with: python -m example_trainer.vllm_api_server "
|
||||
f"--model {VLLM_MODEL} --port {VLLM_PORT} "
|
||||
"--gpu-memory-utilization 0.8 --max-model-len=32000"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Server setup
|
||||
# =========================================================================
|
||||
|
||||
def _make_server_manager():
|
||||
"""Create a ServerManager pointing to the local vLLM server."""
|
||||
from atroposlib.envs.server_handling.server_manager import (
|
||||
ServerManager,
|
||||
APIServerConfig,
|
||||
)
|
||||
|
||||
config = APIServerConfig(
|
||||
base_url=VLLM_BASE_URL,
|
||||
model_name=VLLM_MODEL,
|
||||
server_type="vllm",
|
||||
health_check=False,
|
||||
)
|
||||
sm = ServerManager([config], tool_parser="hermes")
|
||||
sm.servers[0].server_healthy = True
|
||||
return sm
|
||||
|
||||
|
||||
def _get_tokenizer():
|
||||
"""Load the tokenizer for the model."""
|
||||
from transformers import AutoTokenizer
|
||||
return AutoTokenizer.from_pretrained(VLLM_MODEL)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Fake tools
|
||||
# =========================================================================
|
||||
|
||||
WEATHER_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather for a city. Returns temperature and conditions.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name, e.g. 'Tokyo'",
|
||||
}
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
CALC_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"description": "Calculate a math expression. Returns the numeric result.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"description": "Math expression, e.g. '2 + 3'",
|
||||
}
|
||||
},
|
||||
"required": ["expression"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str:
|
||||
"""Handle fake tool calls for testing."""
|
||||
if tool_name == "get_weather":
|
||||
city = args.get("city", "Unknown")
|
||||
return json.dumps({
|
||||
"city": city,
|
||||
"temperature": 22,
|
||||
"conditions": "sunny",
|
||||
"humidity": 45,
|
||||
})
|
||||
elif tool_name == "calculate":
|
||||
expr = args.get("expression", "0")
|
||||
try:
|
||||
result = eval(expr, {"__builtins__": {}}, {})
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)})
|
||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_single_tool_call():
|
||||
"""vLLM model calls a tool, gets result, responds — full Phase 2 flow."""
|
||||
sm = _make_server_manager()
|
||||
tokenizer = _get_tokenizer()
|
||||
|
||||
async with sm.managed_server(tokenizer=tokenizer) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=[WEATHER_TOOL],
|
||||
valid_tool_names={"get_weather"},
|
||||
max_turns=5,
|
||||
temperature=0.6,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What's the weather in Tokyo? Use the get_weather tool."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert isinstance(result, AgentResult)
|
||||
assert result.turns_used >= 2, f"Expected at least 2 turns, got {result.turns_used}"
|
||||
|
||||
# Verify tool call happened
|
||||
tool_calls_found = False
|
||||
for msg in result.messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
if tc["function"]["name"] == "get_weather":
|
||||
tool_calls_found = True
|
||||
args = json.loads(tc["function"]["arguments"])
|
||||
assert "city" in args
|
||||
assert tool_calls_found, "Model should have called get_weather"
|
||||
|
||||
# Verify tool results in conversation
|
||||
tool_results = [m for m in result.messages if m.get("role") == "tool"]
|
||||
assert len(tool_results) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_multi_tool_calls():
|
||||
"""vLLM model calls multiple tools across turns."""
|
||||
sm = _make_server_manager()
|
||||
tokenizer = _get_tokenizer()
|
||||
|
||||
async with sm.managed_server(tokenizer=tokenizer) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=[WEATHER_TOOL, CALC_TOOL],
|
||||
valid_tool_names={"get_weather", "calculate"},
|
||||
max_turns=10,
|
||||
temperature=0.6,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": (
|
||||
"I need two things: "
|
||||
"1) What's the weather in Paris? Use get_weather. "
|
||||
"2) What is 15 * 7? Use calculate."
|
||||
)},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Both tools should be called
|
||||
tools_called = set()
|
||||
for msg in result.messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tools_called.add(tc["function"]["name"])
|
||||
|
||||
assert "get_weather" in tools_called, f"get_weather not called. Called: {tools_called}"
|
||||
assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_managed_server_produces_nodes():
|
||||
"""ManagedServer should produce SequenceNodes with tokens and logprobs."""
|
||||
sm = _make_server_manager()
|
||||
tokenizer = _get_tokenizer()
|
||||
|
||||
async with sm.managed_server(tokenizer=tokenizer) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=[WEATHER_TOOL],
|
||||
valid_tool_names={"get_weather"},
|
||||
max_turns=5,
|
||||
temperature=0.6,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What's the weather in Berlin? Use get_weather."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Get the managed state — should have SequenceNodes
|
||||
state = managed.get_state()
|
||||
|
||||
assert state is not None, "ManagedServer should return state"
|
||||
nodes = state.get("nodes", [])
|
||||
assert len(nodes) >= 1, f"Should have at least 1 node, got {len(nodes)}"
|
||||
|
||||
node = nodes[0]
|
||||
assert hasattr(node, "tokens"), "Node should have tokens"
|
||||
assert hasattr(node, "logprobs"), "Node should have logprobs"
|
||||
assert len(node.tokens) > 0, "Tokens should not be empty"
|
||||
assert len(node.logprobs) > 0, "Logprobs should not be empty"
|
||||
assert len(node.tokens) == len(node.logprobs), (
|
||||
f"Tokens ({len(node.tokens)}) and logprobs ({len(node.logprobs)}) should have same length"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_no_tools_direct_response():
|
||||
"""vLLM model should respond directly when no tools are needed."""
|
||||
sm = _make_server_manager()
|
||||
tokenizer = _get_tokenizer()
|
||||
|
||||
async with sm.managed_server(tokenizer=tokenizer) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=[WEATHER_TOOL],
|
||||
valid_tool_names={"get_weather"},
|
||||
max_turns=5,
|
||||
temperature=0.6,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What is 2 + 2? Answer directly, no tools."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally, "Should finish naturally"
|
||||
assert result.turns_used == 1, f"Should take 1 turn, took {result.turns_used}"
|
||||
|
||||
final = result.messages[-1]
|
||||
assert final["role"] == "assistant"
|
||||
assert final["content"], "Should have content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_thinking_content_extracted():
|
||||
"""Qwen3-Thinking model should produce reasoning content."""
|
||||
sm = _make_server_manager()
|
||||
tokenizer = _get_tokenizer()
|
||||
|
||||
async with sm.managed_server(
|
||||
tokenizer=tokenizer,
|
||||
preserve_think_blocks=True,
|
||||
) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=[CALC_TOOL],
|
||||
valid_tool_names={"calculate"},
|
||||
max_turns=5,
|
||||
temperature=0.6,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What is 123 * 456? Use the calculate tool."},
|
||||
]
|
||||
|
||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Qwen3-Thinking should generate <think> blocks
|
||||
# Check if any content contains thinking markers
|
||||
has_thinking = False
|
||||
for msg in result.messages:
|
||||
content = msg.get("content", "") or ""
|
||||
if "<think>" in content or "</think>" in content:
|
||||
has_thinking = True
|
||||
break
|
||||
|
||||
# Also check reasoning_per_turn
|
||||
has_reasoning = any(r for r in result.reasoning_per_turn if r)
|
||||
|
||||
# At least one of these should be true for a thinking model
|
||||
assert has_thinking or has_reasoning, (
|
||||
"Qwen3-Thinking should produce <think> blocks or reasoning content"
|
||||
)
|
||||
|
|
@ -23,7 +23,7 @@ class TestStreamingAssemblyRepair:
|
|||
|
||||
These tests verify the REPAIR FUNCTION itself works correctly for the
|
||||
cases that arise during streaming assembly. Integration tests that
|
||||
exercise the full streaming path are in test_agent_loop_tool_calling.py.
|
||||
exercise the full streaming path are in run_agent.py's streaming tests.
|
||||
"""
|
||||
|
||||
# -- Truncation cases (most common streaming failure) --
|
||||
|
|
|
|||
|
|
@ -278,7 +278,7 @@ class TestLegacyToolsetMap:
|
|||
expected = [
|
||||
"web_tools", "terminal_tools", "vision_tools", "moa_tools",
|
||||
"image_tools", "skills_tools", "browser_tools", "cronjob_tools",
|
||||
"rl_tools", "file_tools", "tts_tools",
|
||||
"file_tools", "tts_tools",
|
||||
]
|
||||
for name in expected:
|
||||
assert name in _LEGACY_TOOLSET_MAP, f"Missing legacy toolset: {name}"
|
||||
|
|
|
|||
|
|
@ -1,178 +0,0 @@
|
|||
"""
|
||||
Tests for ManagedServer / tool-parser integration.
|
||||
|
||||
Validates that:
|
||||
1. The installed atroposlib API still matches Hermes's expectations
|
||||
2. Hermes's parser registry remains compatible with ManagedServer parsing
|
||||
3. HermesAgentBaseEnv wires the selected parser into ServerManager correctly
|
||||
|
||||
These tests verify the contract between hermes-agent's environments/ code
|
||||
and atroposlib's ManagedServer. They detect API incompatibilities early.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
try:
|
||||
import atroposlib # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
||||
|
||||
|
||||
class TestManagedServerAPI:
|
||||
"""Test that ManagedServer's API matches what hermes-agent expects."""
|
||||
|
||||
def test_managed_server_init_signature(self):
|
||||
"""ManagedServer should accept tool_call_parser parameter."""
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
sig = inspect.signature(ManagedServer.__init__)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
# Core params that must exist
|
||||
assert "self" in params
|
||||
assert "server" in params
|
||||
assert "tokenizer" in params
|
||||
assert "track_tree" in params
|
||||
|
||||
# tool_call_parser — required for tool_call_support branch
|
||||
# If this fails, atroposlib hasn't been updated to tool_call_support
|
||||
has_tool_parser = "tool_call_parser" in params
|
||||
if not has_tool_parser:
|
||||
pytest.skip(
|
||||
"ManagedServer does not have tool_call_parser param — "
|
||||
"baseline atroposlib (pre tool_call_support branch)"
|
||||
)
|
||||
|
||||
def test_server_manager_managed_server_signature(self):
|
||||
"""ServerManager.managed_server() should accept tool_call_parser."""
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
sig = inspect.signature(ServerManager.managed_server)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
assert "self" in params
|
||||
assert "tokenizer" in params
|
||||
|
||||
has_tool_parser = "tool_call_parser" in params
|
||||
if not has_tool_parser:
|
||||
pytest.skip(
|
||||
"ServerManager.managed_server() does not have tool_call_parser param — "
|
||||
"baseline atroposlib (pre tool_call_support branch)"
|
||||
)
|
||||
|
||||
def test_managed_server_chat_template_kwargs(self):
|
||||
"""ManagedServer should have CHAT_TEMPLATE_KWARGS for forwarding tools/thinking."""
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
if not hasattr(ManagedServer, "CHAT_TEMPLATE_KWARGS"):
|
||||
pytest.skip(
|
||||
"ManagedServer does not have CHAT_TEMPLATE_KWARGS — "
|
||||
"baseline atroposlib (pre tool_call_support branch)"
|
||||
)
|
||||
|
||||
kwargs = ManagedServer.CHAT_TEMPLATE_KWARGS
|
||||
assert "tools" in kwargs, "tools must be in CHAT_TEMPLATE_KWARGS"
|
||||
|
||||
def test_no_get_logprobs_method(self):
|
||||
"""get_logprobs should be removed in tool_call_support branch."""
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
# In baseline, get_logprobs exists. In tool_call_support, it's removed.
|
||||
# We just note the state — not a hard fail either way.
|
||||
has_get_logprobs = hasattr(ManagedServer, "get_logprobs")
|
||||
if has_get_logprobs:
|
||||
pytest.skip(
|
||||
"ManagedServer still has get_logprobs — baseline atroposlib"
|
||||
)
|
||||
|
||||
|
||||
class TestParserCompatibility:
|
||||
"""Test that hermes-agent's parsers match ManagedServer's expectations."""
|
||||
|
||||
def test_parser_parse_returns_correct_format(self):
|
||||
"""
|
||||
ManagedServer expects parser.parse(text) -> (content, tool_calls)
|
||||
where tool_calls is a list of objects with .id, .function.name, .function.arguments
|
||||
"""
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
|
||||
tc = tool_calls[0]
|
||||
# ManagedServer accesses these attrs directly
|
||||
assert hasattr(tc, "id")
|
||||
assert hasattr(tc, "function")
|
||||
assert hasattr(tc.function, "name")
|
||||
assert hasattr(tc.function, "arguments")
|
||||
|
||||
def test_parser_no_tools_returns_none(self):
|
||||
"""ManagedServer checks `if parsed_tool_calls:` — None should be falsy."""
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse("Just text, no tools")
|
||||
assert tool_calls is None
|
||||
|
||||
def test_parser_content_is_string_or_none(self):
|
||||
"""ManagedServer uses `parsed_content or ""` — must be str or None."""
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
|
||||
# With tool calls
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>'
|
||||
content, _ = parser.parse(text)
|
||||
assert content is None or isinstance(content, str)
|
||||
|
||||
# Without tool calls
|
||||
content2, _ = parser.parse("Just text")
|
||||
assert isinstance(content2, str)
|
||||
|
||||
|
||||
class TestBaseEnvCompatibility:
|
||||
"""Test that hermes_base_env.py's tool-parser wiring matches the current API."""
|
||||
|
||||
def test_hermes_base_env_sets_server_manager_tool_parser(self):
|
||||
"""Hermes wires parser selection through ServerManager.tool_parser."""
|
||||
import ast
|
||||
|
||||
base_env_path = Path(__file__).parent.parent.parent / "environments" / "hermes_base_env.py"
|
||||
source = base_env_path.read_text()
|
||||
tree = ast.parse(source)
|
||||
|
||||
found_assignment = False
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Assign):
|
||||
for target in node.targets:
|
||||
if isinstance(target, ast.Attribute) and target.attr == "tool_parser":
|
||||
parent = target.value
|
||||
if (
|
||||
isinstance(parent, ast.Attribute)
|
||||
and parent.attr == "server"
|
||||
and isinstance(parent.value, ast.Name)
|
||||
and parent.value.id == "self"
|
||||
):
|
||||
found_assignment = True
|
||||
|
||||
assert found_assignment, (
|
||||
"hermes_base_env.py should set self.server.tool_parser from config.tool_call_parser"
|
||||
)
|
||||
|
||||
def test_hermes_base_env_uses_config_tool_call_parser(self):
|
||||
"""Verify hermes_base_env uses the config field rather than a local parser instance."""
|
||||
base_env_path = Path(__file__).parent.parent.parent / "environments" / "hermes_base_env.py"
|
||||
source = base_env_path.read_text()
|
||||
|
||||
assert 'tool_call_parser: str = Field(' in source
|
||||
assert 'self.server.tool_parser = config.tool_call_parser' in source
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
"""Tests for rl_training_tool.py — file handle lifecycle and cleanup.
|
||||
|
||||
Verifies that _stop_training_run properly closes log file handles,
|
||||
terminates processes, and handles edge cases on failure paths.
|
||||
Inspired by PR #715 (0xbyt4).
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.rl_training_tool import RunState, _stop_training_run
|
||||
|
||||
|
||||
def _make_run_state(**overrides) -> RunState:
|
||||
"""Create a minimal RunState for testing."""
|
||||
defaults = {
|
||||
"run_id": "test-run-001",
|
||||
"environment": "test_env",
|
||||
"config": {},
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return RunState(**defaults)
|
||||
|
||||
|
||||
class TestStopTrainingRunFileHandles:
|
||||
"""Verify that _stop_training_run closes log file handles stored as attributes."""
|
||||
|
||||
def test_closes_all_log_file_handles(self):
|
||||
state = _make_run_state()
|
||||
files = {}
|
||||
for attr in ("api_log_file", "trainer_log_file", "env_log_file"):
|
||||
fh = MagicMock()
|
||||
setattr(state, attr, fh)
|
||||
files[attr] = fh
|
||||
|
||||
_stop_training_run(state)
|
||||
|
||||
for attr, fh in files.items():
|
||||
fh.close.assert_called_once()
|
||||
assert getattr(state, attr) is None
|
||||
|
||||
def test_clears_file_attrs_to_none(self):
|
||||
state = _make_run_state()
|
||||
state.api_log_file = MagicMock()
|
||||
|
||||
_stop_training_run(state)
|
||||
|
||||
assert state.api_log_file is None
|
||||
|
||||
def test_close_exception_does_not_propagate(self):
|
||||
"""If a file handle .close() raises, it must not crash."""
|
||||
state = _make_run_state()
|
||||
bad_fh = MagicMock()
|
||||
bad_fh.close.side_effect = OSError("already closed")
|
||||
good_fh = MagicMock()
|
||||
state.api_log_file = bad_fh
|
||||
state.trainer_log_file = good_fh
|
||||
|
||||
_stop_training_run(state) # should not raise
|
||||
|
||||
bad_fh.close.assert_called_once()
|
||||
good_fh.close.assert_called_once()
|
||||
|
||||
def test_handles_missing_file_attrs(self):
|
||||
"""RunState without log file attrs should not crash."""
|
||||
state = _make_run_state()
|
||||
# No log file attrs set at all — getattr(..., None) should handle it
|
||||
_stop_training_run(state) # should not raise
|
||||
|
||||
|
||||
class TestStopTrainingRunProcesses:
|
||||
"""Verify that _stop_training_run terminates processes correctly."""
|
||||
|
||||
def test_terminates_running_processes(self):
|
||||
state = _make_run_state()
|
||||
for attr in ("api_process", "trainer_process", "env_process"):
|
||||
proc = MagicMock()
|
||||
proc.poll.return_value = None # still running
|
||||
setattr(state, attr, proc)
|
||||
|
||||
_stop_training_run(state)
|
||||
|
||||
for attr in ("api_process", "trainer_process", "env_process"):
|
||||
getattr(state, attr).terminate.assert_called_once()
|
||||
|
||||
def test_does_not_terminate_exited_processes(self):
|
||||
state = _make_run_state()
|
||||
proc = MagicMock()
|
||||
proc.poll.return_value = 0 # already exited
|
||||
state.api_process = proc
|
||||
|
||||
_stop_training_run(state)
|
||||
|
||||
proc.terminate.assert_not_called()
|
||||
|
||||
def test_handles_none_processes(self):
|
||||
state = _make_run_state()
|
||||
# All process attrs are None by default
|
||||
_stop_training_run(state) # should not raise
|
||||
|
||||
def test_handles_mixed_running_and_exited_processes(self):
|
||||
state = _make_run_state()
|
||||
# api still running
|
||||
api = MagicMock()
|
||||
api.poll.return_value = None
|
||||
state.api_process = api
|
||||
# trainer already exited
|
||||
trainer = MagicMock()
|
||||
trainer.poll.return_value = 0
|
||||
state.trainer_process = trainer
|
||||
# env is None
|
||||
state.env_process = None
|
||||
|
||||
_stop_training_run(state)
|
||||
|
||||
api.terminate.assert_called_once()
|
||||
trainer.terminate.assert_not_called()
|
||||
|
||||
|
||||
class TestStopTrainingRunStatus:
|
||||
"""Verify status transitions in _stop_training_run."""
|
||||
|
||||
def test_sets_status_to_stopped_when_running(self):
|
||||
state = _make_run_state(status="running")
|
||||
_stop_training_run(state)
|
||||
assert state.status == "stopped"
|
||||
|
||||
def test_does_not_change_status_when_failed(self):
|
||||
state = _make_run_state(status="failed")
|
||||
_stop_training_run(state)
|
||||
assert state.status == "failed"
|
||||
|
||||
def test_does_not_change_status_when_pending(self):
|
||||
state = _make_run_state(status="pending")
|
||||
_stop_training_run(state)
|
||||
assert state.status == "pending"
|
||||
|
||||
def test_no_crash_with_no_processes_and_no_files(self):
|
||||
state = _make_run_state()
|
||||
_stop_training_run(state) # should not raise
|
||||
assert state.status == "pending"
|
||||
|
|
@ -1,274 +0,0 @@
|
|||
"""
|
||||
Tests for environments/tool_call_parsers/ — client-side tool call parsers.
|
||||
|
||||
These parsers extract structured tool_calls from raw model output text.
|
||||
Used in Phase 2 (VLLM/generate) where the server returns raw tokens.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure repo root is importable
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
try:
|
||||
from environments.tool_call_parsers import (
|
||||
ParseResult,
|
||||
ToolCallParser,
|
||||
get_parser,
|
||||
list_parsers,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
||||
|
||||
|
||||
# ─── Registry tests ─────────────────────────────────────────────────────
|
||||
|
||||
class TestParserRegistry:
|
||||
def test_list_parsers_returns_nonempty(self):
|
||||
parsers = list_parsers()
|
||||
assert len(parsers) > 0
|
||||
|
||||
def test_hermes_parser_registered(self):
|
||||
parsers = list_parsers()
|
||||
assert "hermes" in parsers
|
||||
|
||||
def test_get_parser_returns_instance(self):
|
||||
parser = get_parser("hermes")
|
||||
assert isinstance(parser, ToolCallParser)
|
||||
|
||||
def test_get_parser_unknown_raises(self):
|
||||
with pytest.raises(KeyError):
|
||||
get_parser("nonexistent_parser_xyz")
|
||||
|
||||
def test_all_registered_parsers_instantiate(self):
|
||||
"""Every registered parser should be importable and instantiable."""
|
||||
for name in list_parsers():
|
||||
parser = get_parser(name)
|
||||
assert isinstance(parser, ToolCallParser)
|
||||
assert hasattr(parser, "parse")
|
||||
|
||||
|
||||
# ─── Hermes parser tests ────────────────────────────────────────────────
|
||||
|
||||
class TestHermesParser:
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
return get_parser("hermes")
|
||||
|
||||
def test_no_tool_call(self, parser):
|
||||
text = "Hello, I can help you with that."
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == text
|
||||
assert tool_calls is None
|
||||
|
||||
def test_single_tool_call(self, parser):
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls -la"}}</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "terminal"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["command"] == "ls -la"
|
||||
|
||||
def test_tool_call_with_surrounding_text(self, parser):
|
||||
text = 'Let me check that for you.\n<tool_call>{"name": "terminal", "arguments": {"command": "pwd"}}</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "terminal"
|
||||
# Content should have the surrounding text
|
||||
if content is not None:
|
||||
assert "check that" in content or content.strip() != ""
|
||||
|
||||
def test_multiple_tool_calls(self, parser):
|
||||
text = (
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>\n'
|
||||
'<tool_call>{"name": "read_file", "arguments": {"path": "test.py"}}</tool_call>'
|
||||
)
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 2
|
||||
names = {tc.function.name for tc in tool_calls}
|
||||
assert "terminal" in names
|
||||
assert "read_file" in names
|
||||
|
||||
def test_tool_call_ids_are_unique(self, parser):
|
||||
text = (
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>\n'
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": "pwd"}}</tool_call>'
|
||||
)
|
||||
_, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
ids = [tc.id for tc in tool_calls]
|
||||
assert len(ids) == len(set(ids)), "Tool call IDs must be unique"
|
||||
|
||||
def test_empty_string(self, parser):
|
||||
content, tool_calls = parser.parse("")
|
||||
assert tool_calls is None
|
||||
|
||||
def test_malformed_json_in_tool_call(self, parser):
|
||||
text = '<tool_call>not valid json</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
# Should either return None tool_calls or handle gracefully
|
||||
# (implementation may vary — some parsers return error tool calls)
|
||||
|
||||
def test_truncated_tool_call(self, parser):
|
||||
"""Test handling of unclosed tool_call tag (model truncated mid-generation)."""
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls -la"}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
# Parser should handle truncated output gracefully
|
||||
# Either parse it successfully or return None
|
||||
|
||||
|
||||
# ─── Parse result contract tests (applies to ALL parsers) ───────────────
|
||||
|
||||
class TestParseResultContract:
|
||||
"""Ensure all parsers conform to the ParseResult contract."""
|
||||
|
||||
@pytest.fixture(params=["hermes"]) # Add more as needed
|
||||
def parser(self, request):
|
||||
return get_parser(request.param)
|
||||
|
||||
def test_returns_tuple_of_two(self, parser):
|
||||
result = parser.parse("hello world")
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_no_tools_returns_none_tool_calls(self, parser):
|
||||
content, tool_calls = parser.parse("Just plain text, no tools.")
|
||||
assert tool_calls is None
|
||||
assert content is not None
|
||||
|
||||
def test_tool_calls_are_proper_objects(self, parser):
|
||||
"""When tool calls are found, they should be ChatCompletionMessageToolCall objects."""
|
||||
# Use hermes format since that's universal
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "echo hi"}}</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
if tool_calls is not None:
|
||||
for tc in tool_calls:
|
||||
assert hasattr(tc, "id")
|
||||
assert hasattr(tc, "function")
|
||||
assert hasattr(tc.function, "name")
|
||||
assert hasattr(tc.function, "arguments")
|
||||
assert tc.id is not None
|
||||
assert isinstance(tc.function.name, str)
|
||||
assert isinstance(tc.function.arguments, str)
|
||||
|
||||
|
||||
# ─── DeepSeek V3 parser tests ───────────────────────────────────────────
|
||||
|
||||
class TestDeepSeekV3Parser:
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
return get_parser("deepseek_v3")
|
||||
|
||||
def test_no_tool_call(self, parser):
|
||||
text = "Hello, how can I help you?"
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == text
|
||||
assert tool_calls is None
|
||||
|
||||
def test_single_tool_call(self, parser):
|
||||
text = (
|
||||
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n'
|
||||
'```json\n{"city": "London"}\n```<|tool▁call▁end|><|tool▁calls▁end|>'
|
||||
)
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "get_weather"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["city"] == "London"
|
||||
|
||||
def test_multiple_tool_calls(self, parser):
|
||||
text = (
|
||||
'<|tool▁calls▁begin|>'
|
||||
'<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n'
|
||||
'```json\n{"city": "London"}\n```<|tool▁call▁end|>'
|
||||
'<|tool▁call▁begin|>function<|tool▁sep|>get_time\n'
|
||||
'```json\n{"timezone": "UTC"}\n```<|tool▁call▁end|>'
|
||||
'<|tool▁calls▁end|>'
|
||||
)
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"
|
||||
names = [tc.function.name for tc in tool_calls]
|
||||
assert "get_weather" in names
|
||||
assert "get_time" in names
|
||||
|
||||
def test_tool_call_with_preceding_text(self, parser):
|
||||
text = (
|
||||
'Let me check that for you.\n'
|
||||
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>terminal\n'
|
||||
'```json\n{"command": "ls"}\n```<|tool▁call▁end|><|tool▁calls▁end|>'
|
||||
)
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
|
||||
|
||||
# ─── Mistral parser tests ───────────────────────────────────────────────
|
||||
|
||||
class TestMistralParser:
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
return get_parser("mistral")
|
||||
|
||||
def test_no_tool_call(self, parser):
|
||||
text = "Hello, how can I help you?"
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == text
|
||||
assert tool_calls is None
|
||||
|
||||
def test_pre_v11_single_tool_call(self, parser):
|
||||
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"key": "val"}}]'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["key"] == "val"
|
||||
|
||||
def test_pre_v11_nested_json(self, parser):
|
||||
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"nested": {"deep": true}}}]'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["nested"]["deep"] is True
|
||||
|
||||
def test_v11_single_tool_call(self, parser):
|
||||
text = '[TOOL_CALLS]get_weather{"city": "London"}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "get_weather"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["city"] == "London"
|
||||
|
||||
def test_v11_multiple_tool_calls(self, parser):
|
||||
text = '[TOOL_CALLS]func1{"a": 1}[TOOL_CALLS]func2{"b": 2}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 2
|
||||
names = [tc.function.name for tc in tool_calls]
|
||||
assert "func1" in names
|
||||
assert "func2" in names
|
||||
|
||||
def test_preceding_text_preserved(self, parser):
|
||||
text = 'Hello[TOOL_CALLS]func{"a": 1}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == "Hello"
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
|
||||
def test_malformed_json_fallback(self, parser):
|
||||
text = "[TOOL_CALLS] not valid json"
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue