mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
add nemo gym support
This commit is contained in:
parent
bd9e0b605f
commit
44371a9bbb
2 changed files with 166 additions and 1 deletions
|
|
@ -193,6 +193,10 @@ class HermesAgentLoop:
|
|||
|
||||
import time as _time
|
||||
|
||||
prompt_token_ids = None
|
||||
generation_token_ids = None
|
||||
generation_log_probs = None
|
||||
|
||||
for turn in range(self.max_turns):
|
||||
turn_start = _time.monotonic()
|
||||
|
||||
|
|
@ -220,6 +224,7 @@ class HermesAgentLoop:
|
|||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
print(response)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
|
|
@ -230,6 +235,9 @@ class HermesAgentLoop:
|
|||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
# prompt_token_ids=response.choices[0].message.prompt_token_ids if hasattr(response.choices[0].message, "prompt_token_ids") else None,
|
||||
# generation_token_ids=response.choices[0].message.generation_token_ids if hasattr(response.choices[0].message, "generation_token_ids") else None,
|
||||
# generation_log_probs=response.choices[0].message.generation_log_probs if hasattr(response.choices[0].message, "generation_log_probs") else None,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
|
@ -246,6 +254,12 @@ class HermesAgentLoop:
|
|||
)
|
||||
|
||||
assistant_msg = response.choices[0].message
|
||||
if hasattr(assistant_msg, "prompt_token_ids"):
|
||||
prompt_token_ids = assistant_msg.prompt_token_ids
|
||||
if hasattr(assistant_msg, "generation_token_ids"):
|
||||
generation_token_ids = assistant_msg.generation_token_ids
|
||||
if hasattr(assistant_msg, "generation_log_probs"):
|
||||
generation_log_probs = assistant_msg.generation_log_probs
|
||||
|
||||
# Extract reasoning content from the response (all provider formats)
|
||||
reasoning = _extract_reasoning_from_message(assistant_msg)
|
||||
|
|
@ -308,7 +322,10 @@ class HermesAgentLoop:
|
|||
"content": assistant_msg.content or "",
|
||||
"tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls],
|
||||
}
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
msg_dict["prompt_token_ids"] = prompt_token_ids
|
||||
msg_dict["generation_token_ids"] = generation_token_ids
|
||||
msg_dict["generation_log_probs"] = generation_log_probs
|
||||
# Preserve reasoning_content for multi-turn chat template handling
|
||||
# (e.g., Kimi-K2's template renders <think> blocks differently
|
||||
# for history vs. the latest turn based on this field)
|
||||
|
|
@ -471,6 +488,10 @@ class HermesAgentLoop:
|
|||
}
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
if prompt_token_ids is not None:
|
||||
msg_dict["prompt_token_ids"] = prompt_token_ids
|
||||
msg_dict["generation_token_ids"] = generation_token_ids
|
||||
msg_dict["generation_log_probs"] = generation_log_probs
|
||||
messages.append(msg_dict)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
|
|
|
|||
144
environments/check_gym_compat.py
Normal file
144
environments/check_gym_compat.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick compatibility check: connect to a local OpenAI-compatible endpoint
|
||||
and run a single agent turn via HermesAgentLoop with all standard tools.
|
||||
|
||||
Usage:
|
||||
python environments/check_gym_compat.py # auto-detect model
|
||||
python environments/check_gym_compat.py --model my-model # explicit model
|
||||
python environments/check_gym_compat.py --base-url http://... --model ...
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure repo root is on sys.path when run as a standalone script
|
||||
_repo_root = str(Path(__file__).resolve().parent.parent)
|
||||
if _repo_root not in sys.path:
|
||||
sys.path.insert(0, _repo_root)
|
||||
|
||||
import requests
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from environments.agent_loop import HermesAgentLoop, AgentResult
|
||||
from model_tools import get_tool_definitions
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thin server wrapper — gives HermesAgentLoop the chat_completion() it wants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class OpenAIServer:
    """Thin async adapter giving HermesAgentLoop the ``chat_completion()``
    interface it expects, backed by any OpenAI-compatible endpoint."""

    def __init__(self, base_url: str, model: str, api_key: str = "dummy"):
        # Local OpenAI-compatible servers accept any key; "dummy" suffices.
        self.client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        self.model = model

    async def chat_completion(self, **kwargs):
        # Fill in the configured model only when the caller did not pass one.
        if "model" not in kwargs:
            kwargs["model"] = self.model
        return await self.client.chat.completions.create(**kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def detect_model(base_url: str) -> str:
    """Query the endpoint's ``/models`` route and return the first model id.

    Falls back to the literal string ``"default"`` when the endpoint cannot
    be reached, the response is malformed, or no models are advertised.
    """
    try:
        resp = requests.get(f"{base_url}/models", timeout=10)
        resp.raise_for_status()
        available = resp.json().get("data", [])
        if not available:
            print("WARNING: /v1/models returned no models")
            return "default"
        model_id = available[0]["id"]
        print(f"Auto-detected model: {model_id}")
    except Exception as e:  # best-effort probe: any failure falls back
        print(f"Could not auto-detect model ({e}), falling back to 'default'")
        return "default"
    return model_id
|
||||
|
||||
|
||||
async def run_check(base_url: str, model: str, message: str) -> AgentResult:
    """Run one short agent session against *base_url* and return its result."""
    server = OpenAIServer(base_url=base_url, model=model)

    # Register every standard hermes tool with the loop.
    tool_schemas = get_tool_definitions(quiet_mode=False)
    names = {schema["function"]["name"] for schema in tool_schemas}

    loop = HermesAgentLoop(
        server=server,
        tool_schemas=tool_schemas,
        valid_tool_names=names,
        max_turns=5,
    )

    conversation = [
        {"role": "system", "content": "You are a helpful assistant with access to tools."},
        {"role": "user", "content": message},
    ]
    return await loop.run(conversation)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
    """CLI entry point: parse args, run the compatibility check, report pass/fail.

    Exits with status 1 when the check raises; otherwise prints a summary,
    the final assistant response, any tool errors, and a pass/warn verdict.
    """
    parser = argparse.ArgumentParser(description="Check gym endpoint compatibility")
    parser.add_argument("--base-url", default="http://127.0.0.1:11746/v1")
    parser.add_argument("--model", default=None)
    parser.add_argument("--message", default="Hello! What's the current directory you're in?")
    args = parser.parse_args()

    # Explicit --model wins; otherwise probe the endpoint's /models route.
    model = args.model or detect_model(args.base_url)

    print(f"\n{'='*60}")
    print(f"Endpoint: {args.base_url}")
    print(f"Model: {model}")
    print(f"Message: {args.message}")
    print(f"{'='*60}\n")

    try:
        result = asyncio.run(run_check(args.base_url, model, args.message))

        print(f"\n{'='*60}")
        print(f"Turns used: {result.turns_used}")
        print(f"Finished naturally: {result.finished_naturally}")
        print(f"Tool errors: {len(result.tool_errors)}")
        print(f"{'='*60}")

        # Print only the final assistant text response, not the whole
        # transcript (the previous debug loop dumped every message).
        for msg in reversed(result.messages):
            if msg.get("role") == "assistant" and msg.get("content"):
                print("\nRESPONSE:")
                print(msg["content"])
                break

        if result.tool_errors:
            print("\nTOOL ERRORS:")
            for err in result.tool_errors:
                print(f"  turn {err.turn}: {err.tool_name} — {err.error}")

        status = "✅ passed" if result.finished_naturally else "⚠️ hit max turns"
        print(f"\nGym compatibility check {status}")

    except Exception as e:
        print(f"\n❌ Gym compatibility check failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue