diff --git a/mini_swe_runner.py b/mini_swe_runner.py
index 28c0ae48c9..739074402d 100644
--- a/mini_swe_runner.py
+++ b/mini_swe_runner.py
@@ -43,6 +43,15 @@ from dotenv import load_dotenv
 load_dotenv()
 
 
+def _effective_temperature_for_model(model: str) -> Optional[float]:
+    """Return a fixed temperature for models with strict sampling contracts."""
+    try:
+        from agent.auxiliary_client import _fixed_temperature_for_model
+    except Exception:
+        return None
+    return _fixed_temperature_for_model(model)
+
+
 # ============================================================================
@@ -442,12 +451,17 @@ Complete the user's task step by step."""
 
             # Make API call
             try:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=api_messages,
-                    tools=self.tools,
-                    timeout=300.0
-                )
+                api_kwargs = {
+                    "model": self.model,
+                    "messages": api_messages,
+                    "tools": self.tools,
+                    "timeout": 300.0,
+                }
+                fixed_temperature = _effective_temperature_for_model(self.model)
+                if fixed_temperature is not None:
+                    api_kwargs["temperature"] = fixed_temperature
+
+                response = self.client.chat.completions.create(**api_kwargs)
             except Exception as e:
                 self.logger.error(f"API call failed: {e}")
                 break
diff --git a/tests/test_mini_swe_runner.py b/tests/test_mini_swe_runner.py
new file mode 100644
index 0000000000..adecb5582a
--- /dev/null
+++ b/tests/test_mini_swe_runner.py
@@ -0,0 +1,28 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+
+def test_run_task_forces_kimi_fixed_temperature():
+    with patch("openai.OpenAI") as mock_openai:
+        client = MagicMock()
+        client.chat.completions.create.return_value = SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(content="done", tool_calls=[]))]
+        )
+        mock_openai.return_value = client
+
+        from mini_swe_runner import MiniSWERunner
+
+        runner = MiniSWERunner(
+            model="kimi-for-coding",
+            base_url="https://api.kimi.com/coding/v1",
+            api_key="test-key",
+            env_type="local",
+            max_iterations=1,
+        )
+        runner._create_env = MagicMock()
+        runner._cleanup_env = MagicMock()
+
+        result = runner.run_task("2+2")
+
+        assert result["completed"] is True
+        assert client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
diff --git a/tests/test_trajectory_compressor.py b/tests/test_trajectory_compressor.py
index dc66ef4c4a..682097173a 100644
--- a/tests/test_trajectory_compressor.py
+++ b/tests/test_trajectory_compressor.py
@@ -31,6 +31,29 @@ def test_import_loads_env_from_hermes_home(tmp_path, monkeypatch):
     assert os.getenv("OPENROUTER_API_KEY") == "from-hermes-home"
 
 
+def test_generate_summary_custom_client_forces_kimi_temperature():
+    config = CompressionConfig(
+        summarization_model="kimi-for-coding",
+        temperature=0.3,
+        summary_target_tokens=100,
+        max_retries=1,
+    )
+    compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
+    compressor.config = config
+    compressor.logger = MagicMock()
+    compressor._use_call_llm = False
+    compressor.client = MagicMock()
+    compressor.client.chat.completions.create.return_value = SimpleNamespace(
+        choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
+    )
+
+    metrics = TrajectoryMetrics()
+    result = compressor._generate_summary("tool output", metrics)
+
+    assert result.startswith("[CONTEXT SUMMARY]:")
+    assert compressor.client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
+
+
 # ---------------------------------------------------------------------------
 # CompressionConfig
 # ---------------------------------------------------------------------------
diff --git a/tests/test_trajectory_compressor_async.py b/tests/test_trajectory_compressor_async.py
index 1c671471d9..7bf5191621 100644
--- a/tests/test_trajectory_compressor_async.py
+++ b/tests/test_trajectory_compressor_async.py
@@ -11,6 +11,7 @@ each asyncio.run() gets a client bound to the current loop.
 """
 
 import types
+from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -113,3 +114,30 @@ class TestSourceLineVerification:
         """_get_async_client method should exist."""
         src = self._read_file()
         assert "def _get_async_client(self)" in src
+
+
+@pytest.mark.asyncio
+async def test_generate_summary_async_custom_client_forces_kimi_temperature():
+    from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
+
+    config = CompressionConfig(
+        summarization_model="kimi-for-coding",
+        temperature=0.3,
+        summary_target_tokens=100,
+        max_retries=1,
+    )
+    compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
+    compressor.config = config
+    compressor.logger = MagicMock()
+    compressor._use_call_llm = False
+    async_client = MagicMock()
+    async_client.chat.completions.create = MagicMock(return_value=SimpleNamespace(
+        choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
+    ))
+    compressor._get_async_client = MagicMock(return_value=async_client)
+
+    metrics = TrajectoryMetrics()
+    result = await compressor._generate_summary_async("tool output", metrics)
+
+    assert result.startswith("[CONTEXT SUMMARY]:")
+    assert async_client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
diff --git a/tests/tools/test_approval.py b/tests/tools/test_approval.py
index 661b86bf3f..2d7bfe6b0a 100644
--- a/tests/tools/test_approval.py
+++ b/tests/tools/test_approval.py
@@ -2,11 +2,13 @@
 
 import ast
 from pathlib import Path
+from types import SimpleNamespace
 from unittest.mock import patch as mock_patch
 
 import tools.approval as approval_module
 from tools.approval import (
     _get_approval_mode,
+    _smart_approve,
     approve_session,
     detect_dangerous_command,
     is_approved,
@@ -26,6 +28,21 @@ class TestApprovalModeParsing:
         assert _get_approval_mode() == "off"
 
 
+class TestSmartApproval:
+    def test_smart_approval_uses_call_llm(self):
+        response = SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(content="APPROVE"))]
+        )
+        with mock_patch("agent.auxiliary_client.call_llm", return_value=response) as mock_call:
+            result = _smart_approve("python -c \"print('hello')\"", "script execution via -c flag")
+
+        assert result == "approve"
+        mock_call.assert_called_once()
+        assert mock_call.call_args.kwargs["task"] == "approval"
+        assert mock_call.call_args.kwargs["temperature"] == 0
+        assert mock_call.call_args.kwargs["max_tokens"] == 16
+
+
 class TestDetectDangerousRm:
     def test_rm_rf_detected(self):
         is_dangerous, key, desc = detect_dangerous_command("rm -rf /home/user")
@@ -820,4 +837,3 @@ class TestChmodExecuteCombo:
 
             dangerous, _, _ = detect_dangerous_command(cmd)
             assert dangerous is False
-
diff --git a/tools/approval.py b/tools/approval.py
index d9fcf51a88..7d8c5b032e 100644
--- a/tools/approval.py
+++ b/tools/approval.py
@@ -542,12 +542,7 @@ def _smart_approve(command: str, description: str) -> str:
     (openai/codex#13860).
""" try: - from agent.auxiliary_client import get_text_auxiliary_client, auxiliary_max_tokens_param - - client, model = get_text_auxiliary_client(task="approval") - if not client or not model: - logger.debug("Smart approvals: no aux client available, escalating") - return "escalate" + from agent.auxiliary_client import call_llm prompt = f"""You are a security reviewer for an AI coding agent. A terminal command was flagged by pattern matching as potentially dangerous. @@ -563,11 +558,11 @@ Rules: Respond with exactly one word: APPROVE, DENY, or ESCALATE""" - response = client.chat.completions.create( - model=model, + response = call_llm( + task="approval", messages=[{"role": "user", "content": prompt}], - **auxiliary_max_tokens_param(16), temperature=0, + max_tokens=16, ) answer = (response.choices[0].message.content or "").strip().upper() diff --git a/trajectory_compressor.py b/trajectory_compressor.py index 3c0e3f1b7a..dff15b2278 100644 --- a/trajectory_compressor.py +++ b/trajectory_compressor.py @@ -54,6 +54,19 @@ _project_env = Path(__file__).parent / ".env" load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env) +def _effective_temperature_for_model(model: str, requested_temperature: float) -> float: + """Apply fixed model temperature contracts to direct client calls.""" + try: + from agent.auxiliary_client import _fixed_temperature_for_model + except Exception: + return requested_temperature + + fixed_temperature = _fixed_temperature_for_model(model) + if fixed_temperature is not None: + return fixed_temperature + return requested_temperature + + @dataclass class CompressionConfig: """Configuration for trajectory compression.""" @@ -567,6 +580,10 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" for attempt in range(self.config.max_retries): try: metrics.summarization_api_calls += 1 + summary_temperature = _effective_temperature_for_model( + self.config.summarization_model, + self.config.temperature, + ) if getattr(self, '_use_call_llm', False): from agent.auxiliary_client import call_llm @@ -574,14 +591,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" provider=self._llm_provider, model=self.config.summarization_model, messages=[{"role": "user", "content": prompt}], - temperature=self.config.temperature, + temperature=summary_temperature, max_tokens=self.config.summary_target_tokens * 2, ) else: response = self.client.chat.completions.create( model=self.config.summarization_model, messages=[{"role": "user", "content": prompt}], - temperature=self.config.temperature, + temperature=summary_temperature, max_tokens=self.config.summary_target_tokens * 2, ) @@ -629,6 +646,10 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" for attempt in range(self.config.max_retries): try: metrics.summarization_api_calls += 1 + summary_temperature = _effective_temperature_for_model( + self.config.summarization_model, + self.config.temperature, + ) if getattr(self, '_use_call_llm', False): from agent.auxiliary_client import async_call_llm @@ -636,14 +657,14 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix.""" provider=self._llm_provider, model=self.config.summarization_model, messages=[{"role": "user", "content": prompt}], - temperature=self.config.temperature, + temperature=summary_temperature, max_tokens=self.config.summary_target_tokens * 2, ) else: response = await self._get_async_client().chat.completions.create( model=self.config.summarization_model, messages=[{"role": "user", "content": 
prompt}], - temperature=self.config.temperature, + temperature=summary_temperature, max_tokens=self.config.summary_target_tokens * 2, )
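
Note on _fixed_temperature_for_model: this patch imports the helper from agent.auxiliary_client but does not include its definition. A minimal sketch of the assumed contract, inferred only from the tests above (which pin kimi-for-coding to temperature 0.6); the registry contents and module layout are assumptions, not part of this patch:

    from typing import Optional

    # Assumed registry mapping model names to contractually fixed sampling
    # temperatures. Only the kimi-for-coding entry is confirmed by the tests
    # in this patch; any other entries would be hypothetical.
    _FIXED_MODEL_TEMPERATURES = {
        "kimi-for-coding": 0.6,
    }


    def _fixed_temperature_for_model(model: str) -> Optional[float]:
        """Return the fixed temperature a model's API requires, or None."""
        return _FIXED_MODEL_TEMPERATURES.get(model)

Callers treat None as "no contract": _effective_temperature_for_model in mini_swe_runner.py omits the temperature parameter entirely when the helper returns None, while trajectory_compressor.py falls back to the configured temperature.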