fix(kimi): cover remaining fixed-temperature bypasses

Author: helix4u, 2026-04-17 20:39:24 -06:00 (committed by Teknium)
Parent: 53e4a2f2c6
Commit: 148459716c
7 changed files with 145 additions and 20 deletions
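All of the call sites below delegate to `_fixed_temperature_for_model` in `agent.auxiliary_client`, which this diff does not include. As a reading aid, here is a minimal sketch of the shape those callers assume; the `kimi-for-coding` -> 0.6 entry is implied by the tests in this commit, and everything else (dict structure, other entries) is an assumption:

    # Sketch only -- the real helper lives in agent/auxiliary_client.py and is
    # not part of this diff. The kimi-for-coding -> 0.6 mapping is implied by
    # the tests in this commit; the dict itself is an assumption.
    from typing import Optional

    _FIXED_TEMPERATURES = {
        "kimi-for-coding": 0.6,  # provider contract: fixed sampling temperature
    }

    def _fixed_temperature_for_model(model: str) -> Optional[float]:
        """Return the contractually fixed temperature for `model`, if any."""
        return _FIXED_TEMPERATURES.get(model)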


@@ -43,6 +43,15 @@ from dotenv import load_dotenv
 load_dotenv()
+
+def _effective_temperature_for_model(model: str) -> Optional[float]:
+    """Return a fixed temperature for models with strict sampling contracts."""
+    try:
+        from agent.auxiliary_client import _fixed_temperature_for_model
+    except Exception:
+        return None
+    return _fixed_temperature_for_model(model)
+
 # ============================================================================
@@ -442,12 +451,17 @@ Complete the user's task step by step."""
 
             # Make API call
             try:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=api_messages,
-                    tools=self.tools,
-                    timeout=300.0
-                )
+                api_kwargs = {
+                    "model": self.model,
+                    "messages": api_messages,
+                    "tools": self.tools,
+                    "timeout": 300.0,
+                }
+                fixed_temperature = _effective_temperature_for_model(self.model)
+                if fixed_temperature is not None:
+                    api_kwargs["temperature"] = fixed_temperature
+
+                response = self.client.chat.completions.create(**api_kwargs)
             except Exception as e:
                 self.logger.error(f"API call failed: {e}")
                 break
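
Building `api_kwargs` and only conditionally setting `"temperature"` matters: when the key is omitted, the SDK sends no temperature at all and the provider default applies, which is not the same as passing `temperature=None` to every provider. Expected guard behavior, where the non-Kimi model name is only illustrative:

    # Hypothetical spot check of the guard; "gpt-4o" stands in for any model
    # without a fixed-temperature contract.
    assert _effective_temperature_for_model("kimi-for-coding") == 0.6
    assert _effective_temperature_for_model("gpt-4o") is None  # key stays unset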


@@ -0,0 +1,28 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+
+def test_run_task_forces_kimi_fixed_temperature():
+    with patch("openai.OpenAI") as mock_openai:
+        client = MagicMock()
+        client.chat.completions.create.return_value = SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(content="done", tool_calls=[]))]
+        )
+        mock_openai.return_value = client
+
+        from mini_swe_runner import MiniSWERunner
+
+        runner = MiniSWERunner(
+            model="kimi-for-coding",
+            base_url="https://api.kimi.com/coding/v1",
+            api_key="test-key",
+            env_type="local",
+            max_iterations=1,
+        )
+        runner._create_env = MagicMock()
+        runner._cleanup_env = MagicMock()
+
+        result = runner.run_task("2+2")
+
+    assert result["completed"] is True
+    assert client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
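
The assertion leans on `Mock.call_args.kwargs`, which captures keyword arguments only; because the runner now expands `**api_kwargs`, every parameter lands there. The same mechanism in isolation:

    from unittest.mock import MagicMock

    m = MagicMock()
    m(model="kimi-for-coding", temperature=0.6)
    assert m.call_args.kwargs["temperature"] == 0.6  # keyword-only view
    assert m.call_args.args == ()                    # nothing positional captured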


@@ -31,6 +31,29 @@ def test_import_loads_env_from_hermes_home(tmp_path, monkeypatch):
     assert os.getenv("OPENROUTER_API_KEY") == "from-hermes-home"
+
+
+def test_generate_summary_custom_client_forces_kimi_temperature():
+    config = CompressionConfig(
+        summarization_model="kimi-for-coding",
+        temperature=0.3,
+        summary_target_tokens=100,
+        max_retries=1,
+    )
+    compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
+    compressor.config = config
+    compressor.logger = MagicMock()
+    compressor._use_call_llm = False
+    compressor.client = MagicMock()
+    compressor.client.chat.completions.create.return_value = SimpleNamespace(
+        choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
+    )
+
+    metrics = TrajectoryMetrics()
+    result = compressor._generate_summary("tool output", metrics)
+
+    assert result.startswith("[CONTEXT SUMMARY]:")
+    assert compressor.client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
 
 
 # ---------------------------------------------------------------------------
 # CompressionConfig
 # ---------------------------------------------------------------------------
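
The test constructs the compressor via `TrajectoryCompressor.__new__` so that `__init__` (which presumably builds real API clients) never runs, then hand-sets only the attributes `_generate_summary` reads. The same pattern in miniature, with a made-up class:

    class Expensive:
        def __init__(self):
            raise RuntimeError("would hit the network")  # what the test avoids

    obj = Expensive.__new__(Expensive)  # allocate without running __init__
    obj.ready = True                    # hand-set just the state under test
    assert obj.ready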


@@ -11,6 +11,7 @@ each asyncio.run() gets a client bound to the current loop.
 """
 import types
 from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
@@ -113,3 +114,30 @@ class TestSourceLineVerification:
         """_get_async_client method should exist."""
         src = self._read_file()
         assert "def _get_async_client(self)" in src
+
+
+@pytest.mark.asyncio
+async def test_generate_summary_async_custom_client_forces_kimi_temperature():
+    from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
+
+    config = CompressionConfig(
+        summarization_model="kimi-for-coding",
+        temperature=0.3,
+        summary_target_tokens=100,
+        max_retries=1,
+    )
+    compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
+    compressor.config = config
+    compressor.logger = MagicMock()
+    compressor._use_call_llm = False
+
+    async_client = MagicMock()
+    async_client.chat.completions.create = AsyncMock(return_value=SimpleNamespace(
+        choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
+    ))
+    compressor._get_async_client = MagicMock(return_value=async_client)
+
+    metrics = TrajectoryMetrics()
+    result = await compressor._generate_summary_async("tool output", metrics)
+
+    assert result.startswith("[CONTEXT SUMMARY]:")
+    assert async_client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
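
Because `_generate_summary_async` awaits the client call (see the last file in this commit), the mocked `create` must return an awaitable; `AsyncMock` does, whereas awaiting a plain `MagicMock` result raises `TypeError`. A minimal demonstration:

    import asyncio
    from unittest.mock import AsyncMock

    create = AsyncMock(return_value={"ok": True})

    async def main():
        result = await create(temperature=0.6)  # AsyncMock call yields a coroutine
        assert result == {"ok": True}
        assert create.call_args.kwargs["temperature"] == 0.6

    asyncio.run(main())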


@@ -2,11 +2,13 @@
 import ast
 from pathlib import Path
+from types import SimpleNamespace
 from unittest.mock import patch as mock_patch
 
 import tools.approval as approval_module
 from tools.approval import (
     _get_approval_mode,
+    _smart_approve,
     approve_session,
     detect_dangerous_command,
     is_approved,
@@ -26,6 +28,21 @@ class TestApprovalModeParsing:
         assert _get_approval_mode() == "off"
+
+
+class TestSmartApproval:
+    def test_smart_approval_uses_call_llm(self):
+        response = SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(content="APPROVE"))]
+        )
+        with mock_patch("agent.auxiliary_client.call_llm", return_value=response) as mock_call:
+            result = _smart_approve("python -c \"print('hello')\"", "script execution via -c flag")
+        assert result == "approve"
+        mock_call.assert_called_once()
+        assert mock_call.call_args.kwargs["task"] == "approval"
+        assert mock_call.call_args.kwargs["temperature"] == 0
+        assert mock_call.call_args.kwargs["max_tokens"] == 16
+
 
 class TestDetectDangerousRm:
     def test_rm_rf_detected(self):
         is_dangerous, key, desc = detect_dangerous_command("rm -rf /home/user")
@@ -820,4 +837,3 @@ class TestChmodExecuteCombo:
             dangerous, _, _ = detect_dangerous_command(cmd)
             assert dangerous is False
-


@@ -542,12 +542,7 @@ def _smart_approve(command: str, description: str) -> str:
     (openai/codex#13860).
     """
     try:
-        from agent.auxiliary_client import get_text_auxiliary_client, auxiliary_max_tokens_param
-
-        client, model = get_text_auxiliary_client(task="approval")
-        if not client or not model:
-            logger.debug("Smart approvals: no aux client available, escalating")
-            return "escalate"
+        from agent.auxiliary_client import call_llm
 
         prompt = f"""You are a security reviewer for an AI coding agent. A terminal command was flagged by pattern matching as potentially dangerous.
@@ -563,11 +558,11 @@
 Respond with exactly one word: APPROVE, DENY, or ESCALATE"""
 
-        response = client.chat.completions.create(
-            model=model,
+        response = call_llm(
+            task="approval",
             messages=[{"role": "user", "content": prompt}],
-            **auxiliary_max_tokens_param(16),
             temperature=0,
+            max_tokens=16,
         )
 
         answer = (response.choices[0].message.content or "").strip().upper()
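
`call_llm` now owns client selection, the max-tokens parameter name, and the fixed-temperature override, so this call site no longer needs `get_text_auxiliary_client` or `auxiliary_max_tokens_param`. Its signature is not shown in this diff; a plausible shape inferred purely from the call sites in this commit:

    # Inferred sketch -- the real function lives in agent/auxiliary_client.py
    # and may differ in both signature and behavior.
    from typing import Any, Optional

    def call_llm(task: str,
                 messages: list,
                 provider: Optional[str] = None,
                 model: Optional[str] = None,
                 temperature: Optional[float] = None,
                 max_tokens: Optional[int] = None) -> Any:
        """Resolve the auxiliary client/model for `task`, apply any fixed
        per-model temperature, map max_tokens to the provider's parameter
        name, and issue the chat completion."""
        raise NotImplementedError  # centralized in agent.auxiliary_client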


@@ -54,6 +54,19 @@ _project_env = Path(__file__).parent / ".env"
 load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
+
+
+def _effective_temperature_for_model(model: str, requested_temperature: float) -> float:
+    """Apply fixed model temperature contracts to direct client calls."""
+    try:
+        from agent.auxiliary_client import _fixed_temperature_for_model
+    except Exception:
+        return requested_temperature
+    fixed_temperature = _fixed_temperature_for_model(model)
+    if fixed_temperature is not None:
+        return fixed_temperature
+    return requested_temperature
+
 
 @dataclass
 class CompressionConfig:
     """Configuration for trajectory compression."""
@@ -567,6 +580,10 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
         for attempt in range(self.config.max_retries):
             try:
                 metrics.summarization_api_calls += 1
+                summary_temperature = _effective_temperature_for_model(
+                    self.config.summarization_model,
+                    self.config.temperature,
+                )
                 if getattr(self, '_use_call_llm', False):
                     from agent.auxiliary_client import call_llm
@@ -574,14 +591,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
                         provider=self._llm_provider,
                         model=self.config.summarization_model,
                         messages=[{"role": "user", "content": prompt}],
-                        temperature=self.config.temperature,
+                        temperature=summary_temperature,
                         max_tokens=self.config.summary_target_tokens * 2,
                     )
                 else:
                     response = self.client.chat.completions.create(
                         model=self.config.summarization_model,
                         messages=[{"role": "user", "content": prompt}],
-                        temperature=self.config.temperature,
+                        temperature=summary_temperature,
                         max_tokens=self.config.summary_target_tokens * 2,
                     )
@@ -629,6 +646,10 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
         for attempt in range(self.config.max_retries):
             try:
                 metrics.summarization_api_calls += 1
+                summary_temperature = _effective_temperature_for_model(
+                    self.config.summarization_model,
+                    self.config.temperature,
+                )
                 if getattr(self, '_use_call_llm', False):
                     from agent.auxiliary_client import async_call_llm
@@ -636,14 +657,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
                         provider=self._llm_provider,
                         model=self.config.summarization_model,
                         messages=[{"role": "user", "content": prompt}],
-                        temperature=self.config.temperature,
+                        temperature=summary_temperature,
                         max_tokens=self.config.summary_target_tokens * 2,
                     )
                 else:
                     response = await self._get_async_client().chat.completions.create(
                         model=self.config.summarization_model,
                         messages=[{"role": "user", "content": prompt}],
-                        temperature=self.config.temperature,
+                        temperature=summary_temperature,
                         max_tokens=self.config.summary_target_tokens * 2,
                     )
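
Unlike the runner-side helper, this variant takes the requested temperature and returns it unchanged when no contract applies, so both summarization call sites can pass `summary_temperature` unconditionally. Expected behavior, with an illustrative non-Kimi model name:

    # Hypothetical spot check; "gpt-4o-mini" stands in for any model without
    # a fixed-temperature contract.
    assert _effective_temperature_for_model("kimi-for-coding", 0.3) == 0.6
    assert _effective_temperature_for_model("gpt-4o-mini", 0.3) == 0.3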