mirror of https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(kimi): cover remaining fixed-temperature bypasses
parent 53e4a2f2c6
commit 148459716c
7 changed files with 145 additions and 20 deletions
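The Kimi coding endpoint pins sampling to a fixed temperature, so any code path that assembles an API request without consulting the shared contract table silently bypasses it. This commit routes the remaining direct-call paths (the mini-SWE runner and the trajectory compressor's summarization calls) through _fixed_temperature_for_model in agent.auxiliary_client. That helper is only imported by this diff, never shown; a minimal sketch of its presumed shape, with the 0.6 value grounded in the test assertions below and the table contents otherwise hypothetical:

    # Hypothetical reconstruction -- agent/auxiliary_client.py is not part of this diff.
    # The 0.6 for "kimi-for-coding" matches the assertions in the tests below.
    from typing import Optional

    _FIXED_MODEL_TEMPERATURES = {
        "kimi-for-coding": 0.6,  # Kimi coding endpoint: fixed sampling contract
    }

    def _fixed_temperature_for_model(model: str) -> Optional[float]:
        """Return the contractually fixed temperature for `model`, or None."""
        return _FIXED_MODEL_TEMPERATURES.get(model)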
@@ -43,6 +43,15 @@ from dotenv import load_dotenv
 load_dotenv()
 
 
+def _effective_temperature_for_model(model: str) -> Optional[float]:
+    """Return a fixed temperature for models with strict sampling contracts."""
+    try:
+        from agent.auxiliary_client import _fixed_temperature_for_model
+    except Exception:
+        return None
+    return _fixed_temperature_for_model(model)
+
+
 # ============================================================================
@@ -442,12 +451,17 @@ Complete the user's task step by step."""
             # Make API call
             try:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=api_messages,
-                    tools=self.tools,
-                    timeout=300.0
-                )
+                api_kwargs = {
+                    "model": self.model,
+                    "messages": api_messages,
+                    "tools": self.tools,
+                    "timeout": 300.0,
+                }
+                fixed_temperature = _effective_temperature_for_model(self.model)
+                if fixed_temperature is not None:
+                    api_kwargs["temperature"] = fixed_temperature
+
+                response = self.client.chat.completions.create(**api_kwargs)
             except Exception as e:
                 self.logger.error(f"API call failed: {e}")
                 break
 
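Design note: the request is rebuilt as a kwargs dict so that temperature is attached only when a fixed contract exists; every other model keeps the provider's default sampling rather than receiving an explicit value.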
tests/test_mini_swe_runner.py (new file, +28)
@@ -0,0 +1,28 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+
+def test_run_task_forces_kimi_fixed_temperature():
+    with patch("openai.OpenAI") as mock_openai:
+        client = MagicMock()
+        client.chat.completions.create.return_value = SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(content="done", tool_calls=[]))]
+        )
+        mock_openai.return_value = client
+
+        from mini_swe_runner import MiniSWERunner
+
+        runner = MiniSWERunner(
+            model="kimi-for-coding",
+            base_url="https://api.kimi.com/coding/v1",
+            api_key="test-key",
+            env_type="local",
+            max_iterations=1,
+        )
+        runner._create_env = MagicMock()
+        runner._cleanup_env = MagicMock()
+
+        result = runner.run_task("2+2")
+
+        assert result["completed"] is True
+        assert client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
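The test patches openai.OpenAI and only then imports the runner inside the with block, so client construction is intercepted; _create_env and _cleanup_env are stubbed to keep the run in-process. Because the mocked completion carries no tool calls, the loop ends after one iteration and the forced 0.6 can be read straight off the recorded call kwargs.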
@@ -31,6 +31,29 @@ def test_import_loads_env_from_hermes_home(tmp_path, monkeypatch):
     assert os.getenv("OPENROUTER_API_KEY") == "from-hermes-home"
 
 
+def test_generate_summary_custom_client_forces_kimi_temperature():
+    config = CompressionConfig(
+        summarization_model="kimi-for-coding",
+        temperature=0.3,
+        summary_target_tokens=100,
+        max_retries=1,
+    )
+    compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
+    compressor.config = config
+    compressor.logger = MagicMock()
+    compressor._use_call_llm = False
+    compressor.client = MagicMock()
+    compressor.client.chat.completions.create.return_value = SimpleNamespace(
+        choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
+    )
+
+    metrics = TrajectoryMetrics()
+    result = compressor._generate_summary("tool output", metrics)
+
+    assert result.startswith("[CONTEXT SUMMARY]:")
+    assert compressor.client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
+
+
 # ---------------------------------------------------------------------------
 # CompressionConfig
 # ---------------------------------------------------------------------------
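The TrajectoryCompressor.__new__(TrajectoryCompressor) construction deserves a note: it allocates an instance without running __init__, so the test can hand-set exactly the attributes _generate_summary touches instead of building a real client. The trick in isolation (illustrative names only):

    # Allocate without __init__, then set only what the method under test reads.
    class Expensive:
        def __init__(self):
            raise RuntimeError("would construct a real API client")

    obj = Expensive.__new__(Expensive)  # __init__ is skipped entirely
    obj.client = object()               # inject a stand-in dependency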
@@ -11,6 +11,7 @@ each asyncio.run() gets a client bound to the current loop.
 """
 
 import types
+from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -113,3 +114,30 @@ class TestSourceLineVerification:
         """_get_async_client method should exist."""
         src = self._read_file()
         assert "def _get_async_client(self)" in src
+
+
+@pytest.mark.asyncio
+async def test_generate_summary_async_custom_client_forces_kimi_temperature():
+    from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
+
+    config = CompressionConfig(
+        summarization_model="kimi-for-coding",
+        temperature=0.3,
+        summary_target_tokens=100,
+        max_retries=1,
+    )
+    compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
+    compressor.config = config
+    compressor.logger = MagicMock()
+    compressor._use_call_llm = False
+    async_client = MagicMock()
+    async_client.chat.completions.create = MagicMock(return_value=SimpleNamespace(
+        choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
+    ))
+    compressor._get_async_client = MagicMock(return_value=async_client)
+
+    metrics = TrajectoryMetrics()
+    result = await compressor._generate_summary_async("tool output", metrics)
+
+    assert result.startswith("[CONTEXT SUMMARY]:")
+    assert async_client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
@@ -2,11 +2,13 @@
 
 import ast
 from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import patch as mock_patch
 
 import tools.approval as approval_module
 from tools.approval import (
     _get_approval_mode,
     _smart_approve,
     approve_session,
     detect_dangerous_command,
     is_approved,
@@ -26,6 +28,21 @@ class TestApprovalModeParsing:
         assert _get_approval_mode() == "off"
 
 
+class TestSmartApproval:
+    def test_smart_approval_uses_call_llm(self):
+        response = SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(content="APPROVE"))]
+        )
+        with mock_patch("agent.auxiliary_client.call_llm", return_value=response) as mock_call:
+            result = _smart_approve("python -c \"print('hello')\"", "script execution via -c flag")
+
+        assert result == "approve"
+        mock_call.assert_called_once()
+        assert mock_call.call_args.kwargs["task"] == "approval"
+        assert mock_call.call_args.kwargs["temperature"] == 0
+        assert mock_call.call_args.kwargs["max_tokens"] == 16
+
+
 class TestDetectDangerousRm:
     def test_rm_rf_detected(self):
         is_dangerous, key, desc = detect_dangerous_command("rm -rf /home/user")
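Note the contrast with the Kimi assertions elsewhere in this commit: the approval path pins temperature to 0 for deterministic verdicts, while the Kimi contract fixes generation at 0.6; both values now flow through the same call_llm chokepoint.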
@@ -820,4 +837,3 @@ class TestChmodExecuteCombo:
             dangerous, _, _ = detect_dangerous_command(cmd)
             assert dangerous is False
 
-
@@ -542,12 +542,7 @@ def _smart_approve(command: str, description: str) -> str:
     (openai/codex#13860).
     """
     try:
-        from agent.auxiliary_client import get_text_auxiliary_client, auxiliary_max_tokens_param
-
-        client, model = get_text_auxiliary_client(task="approval")
-        if not client or not model:
-            logger.debug("Smart approvals: no aux client available, escalating")
-            return "escalate"
+        from agent.auxiliary_client import call_llm
 
         prompt = f"""You are a security reviewer for an AI coding agent. A terminal command was flagged by pattern matching as potentially dangerous.
@@ -563,11 +558,11 @@ Rules:
 
 Respond with exactly one word: APPROVE, DENY, or ESCALATE"""
 
-        response = client.chat.completions.create(
-            model=model,
+        response = call_llm(
+            task="approval",
             messages=[{"role": "user", "content": prompt}],
-            **auxiliary_max_tokens_param(16),
             temperature=0,
+            max_tokens=16,
         )
 
         answer = (response.choices[0].message.content or "").strip().upper()
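Swapping the raw client call for call_llm moves client selection and max-tokens translation out of the approval path, so per-model contracts can be applied at a single chokepoint. call_llm's body is not part of this diff; a hypothetical sketch of how it might apply the contract, reusing get_text_auxiliary_client from the code removed above:

    # Hypothetical sketch only -- call_llm's real implementation is not shown here.
    def call_llm(task, messages, temperature=None, max_tokens=None, **kwargs):
        client, model = get_text_auxiliary_client(task=task)  # helper seen in the removed code
        fixed = _fixed_temperature_for_model(model)
        if fixed is not None:
            temperature = fixed  # the model's contract overrides the caller's value
        return client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )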
@@ -54,6 +54,19 @@ _project_env = Path(__file__).parent / ".env"
 load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
 
 
+def _effective_temperature_for_model(model: str, requested_temperature: float) -> float:
+    """Apply fixed model temperature contracts to direct client calls."""
+    try:
+        from agent.auxiliary_client import _fixed_temperature_for_model
+    except Exception:
+        return requested_temperature
+
+    fixed_temperature = _fixed_temperature_for_model(model)
+    if fixed_temperature is not None:
+        return fixed_temperature
+    return requested_temperature
+
+
 @dataclass
 class CompressionConfig:
     """Configuration for trajectory compression."""
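Two same-named helpers now exist with different shapes: the runner's returns Optional[float] (attach temperature only if a contract exists), while this one always resolves to a concrete float, because the compressor has always sent an explicit temperature. Usage side by side (the no-contract fallback model below is an illustrative assumption):

    # Runner variant: Optional -- None means "send no temperature at all".
    _effective_temperature_for_model("kimi-for-coding")        # -> 0.6
    # Compressor variant: always a float -- contract wins, else the requested value.
    _effective_temperature_for_model("kimi-for-coding", 0.3)   # -> 0.6
    _effective_temperature_for_model("other-model", 0.3)       # -> 0.3 (assumed no contract)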
@@ -567,6 +580,10 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
         for attempt in range(self.config.max_retries):
             try:
                 metrics.summarization_api_calls += 1
+                summary_temperature = _effective_temperature_for_model(
+                    self.config.summarization_model,
+                    self.config.temperature,
+                )
 
                 if getattr(self, '_use_call_llm', False):
                     from agent.auxiliary_client import call_llm
@@ -574,14 +591,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
                         provider=self._llm_provider,
                         model=self.config.summarization_model,
                         messages=[{"role": "user", "content": prompt}],
-                        temperature=self.config.temperature,
+                        temperature=summary_temperature,
                         max_tokens=self.config.summary_target_tokens * 2,
                     )
                 else:
                     response = self.client.chat.completions.create(
                         model=self.config.summarization_model,
                         messages=[{"role": "user", "content": prompt}],
-                        temperature=self.config.temperature,
+                        temperature=summary_temperature,
                         max_tokens=self.config.summary_target_tokens * 2,
                     )
 
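Both branches, call_llm and the direct client, now consume the same resolved summary_temperature; the async pair of hunks below applies the identical change to _generate_summary_async.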
@@ -629,6 +646,10 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
         for attempt in range(self.config.max_retries):
             try:
                 metrics.summarization_api_calls += 1
+                summary_temperature = _effective_temperature_for_model(
+                    self.config.summarization_model,
+                    self.config.temperature,
+                )
 
                 if getattr(self, '_use_call_llm', False):
                     from agent.auxiliary_client import async_call_llm
@@ -636,14 +657,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
                         provider=self._llm_provider,
                         model=self.config.summarization_model,
                         messages=[{"role": "user", "content": prompt}],
-                        temperature=self.config.temperature,
+                        temperature=summary_temperature,
                         max_tokens=self.config.summary_target_tokens * 2,
                     )
                 else:
                     response = await self._get_async_client().chat.completions.create(
                         model=self.config.summarization_model,
                         messages=[{"role": "user", "content": prompt}],
-                        temperature=self.config.temperature,
+                        temperature=summary_temperature,
                         max_tokens=self.config.summary_target_tokens * 2,
                     )
 