fix: propagate kimi base-url temperature overrides

Follow up on the salvaged PR #12668 by threading base_url through the
remaining direct-call sites so kimi-k2.5 uses temperature=1.0 on
api.moonshot.ai and keeps 0.6 on api.kimi.com/coding. Add focused
regression tests for run_agent, trajectory_compressor, and
mini_swe_runner.
This commit is contained in:
kshitijk4poor 2026-04-20 01:35:42 +05:30 committed by Teknium
parent 6f79b8f01d
commit 50d6799389
7 changed files with 119 additions and 8 deletions

View file

@ -43,13 +43,16 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
def _effective_temperature_for_model(model: str) -> Optional[float]: def _effective_temperature_for_model(
model: str,
base_url: Optional[str] = None,
) -> Optional[float]:
"""Return a fixed temperature for models with strict sampling contracts.""" """Return a fixed temperature for models with strict sampling contracts."""
try: try:
from agent.auxiliary_client import _fixed_temperature_for_model from agent.auxiliary_client import _fixed_temperature_for_model
except Exception: except Exception:
return None return None
return _fixed_temperature_for_model(model) return _fixed_temperature_for_model(model, base_url)
@ -457,7 +460,10 @@ Complete the user's task step by step."""
"tools": self.tools, "tools": self.tools,
"timeout": 300.0, "timeout": 300.0,
} }
fixed_temperature = _effective_temperature_for_model(self.model) fixed_temperature = _effective_temperature_for_model(
self.model,
str(getattr(self.client, "base_url", "") or ""),
)
if fixed_temperature is not None: if fixed_temperature is not None:
api_kwargs["temperature"] = fixed_temperature api_kwargs["temperature"] = fixed_temperature

View file

@ -7173,7 +7173,7 @@ class AIAgent:
except Exception: except Exception:
_fixed_temperature_for_model = None _fixed_temperature_for_model = None
if _fixed_temperature_for_model is not None: if _fixed_temperature_for_model is not None:
fixed_temperature = _fixed_temperature_for_model(self.model) fixed_temperature = _fixed_temperature_for_model(self.model, self.base_url)
if fixed_temperature is not None: if fixed_temperature is not None:
api_kwargs["temperature"] = fixed_temperature api_kwargs["temperature"] = fixed_temperature
if self._is_qwen_portal(): if self._is_qwen_portal():
@ -7619,7 +7619,7 @@ class AIAgent:
_aux_available = True _aux_available = True
# Use the fixed-temperature override (e.g. kimi-for-coding → 0.6) if # Use the fixed-temperature override (e.g. kimi-for-coding → 0.6) if
# the model has a strict contract; otherwise the historical 0.3 default. # the model has a strict contract; otherwise the historical 0.3 default.
_flush_temperature = _fixed_temperature_for_model(self.model) _flush_temperature = _fixed_temperature_for_model(self.model, self.base_url)
if _flush_temperature is None: if _flush_temperature is None:
_flush_temperature = 0.3 _flush_temperature = 0.3
try: try:
@ -8675,7 +8675,7 @@ class AIAgent:
except Exception: except Exception:
_fixed_temperature_for_model = None _fixed_temperature_for_model = None
_summary_temperature = ( _summary_temperature = (
_fixed_temperature_for_model(self.model) _fixed_temperature_for_model(self.model, self.base_url)
if _fixed_temperature_for_model is not None if _fixed_temperature_for_model is not None
else None else None
) )

View file

@ -918,6 +918,26 @@ class TestBuildApiKwargs:
assert kwargs["messages"] is messages assert kwargs["messages"] is messages
assert kwargs["timeout"] == 1800.0 assert kwargs["timeout"] == 1800.0
def test_public_moonshot_kimi_k2_5_forces_temperature_1(self, agent):
agent.base_url = "https://api.moonshot.ai/v1"
agent._base_url_lower = agent.base_url.lower()
agent.model = "kimi-k2.5"
messages = [{"role": "user", "content": "hi"}]
kwargs = agent._build_api_kwargs(messages)
assert kwargs["temperature"] == 1.0
def test_kimi_coding_endpoint_keeps_kimi_k2_5_at_0_6(self, agent):
agent.base_url = "https://api.kimi.com/coding/v1"
agent._base_url_lower = agent.base_url.lower()
agent.model = "kimi-k2.5"
messages = [{"role": "user", "content": "hi"}]
kwargs = agent._build_api_kwargs(messages)
assert kwargs["temperature"] == 0.6
def test_provider_preferences_injected(self, agent): def test_provider_preferences_injected(self, agent):
agent.base_url = "https://openrouter.ai/api/v1" agent.base_url = "https://openrouter.ai/api/v1"
agent.providers_allowed = ["Anthropic"] agent.providers_allowed = ["Anthropic"]

View file

@ -26,3 +26,30 @@ def test_run_task_forces_kimi_fixed_temperature():
assert result["completed"] is True assert result["completed"] is True
assert client.chat.completions.create.call_args.kwargs["temperature"] == 0.6 assert client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
def test_run_task_public_moonshot_kimi_k2_5_forces_temperature_1():
with patch("openai.OpenAI") as mock_openai:
client = MagicMock()
client.base_url = "https://api.moonshot.ai/v1"
client.chat.completions.create.return_value = SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content="done", tool_calls=[]))]
)
mock_openai.return_value = client
from mini_swe_runner import MiniSWERunner
runner = MiniSWERunner(
model="kimi-k2.5",
base_url="https://api.moonshot.ai/v1",
api_key="test-key",
env_type="local",
max_iterations=1,
)
runner._create_env = MagicMock()
runner._cleanup_env = MagicMock()
result = runner.run_task("2+2")
assert result["completed"] is True
assert client.chat.completions.create.call_args.kwargs["temperature"] == 1.0

View file

@ -54,6 +54,30 @@ def test_generate_summary_custom_client_forces_kimi_temperature():
assert compressor.client.chat.completions.create.call_args.kwargs["temperature"] == 0.6 assert compressor.client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
def test_generate_summary_public_moonshot_kimi_k2_5_forces_temperature_1():
config = CompressionConfig(
summarization_model="kimi-k2.5",
base_url="https://api.moonshot.ai/v1",
temperature=0.3,
summary_target_tokens=100,
max_retries=1,
)
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
compressor.config = config
compressor.logger = MagicMock()
compressor._use_call_llm = False
compressor.client = MagicMock()
compressor.client.chat.completions.create.return_value = SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
)
metrics = TrajectoryMetrics()
result = compressor._generate_summary("tool output", metrics)
assert result.startswith("[CONTEXT SUMMARY]:")
assert compressor.client.chat.completions.create.call_args.kwargs["temperature"] == 1.0
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# CompressionConfig # CompressionConfig
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -141,3 +141,31 @@ async def test_generate_summary_async_custom_client_forces_kimi_temperature():
assert result.startswith("[CONTEXT SUMMARY]:") assert result.startswith("[CONTEXT SUMMARY]:")
assert async_client.chat.completions.create.call_args.kwargs["temperature"] == 0.6 assert async_client.chat.completions.create.call_args.kwargs["temperature"] == 0.6
@pytest.mark.asyncio
async def test_generate_summary_async_public_moonshot_kimi_k2_5_forces_temperature_1():
from trajectory_compressor import CompressionConfig, TrajectoryCompressor, TrajectoryMetrics
config = CompressionConfig(
summarization_model="kimi-k2.5",
base_url="https://api.moonshot.ai/v1",
temperature=0.3,
summary_target_tokens=100,
max_retries=1,
)
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
compressor.config = config
compressor.logger = MagicMock()
compressor._use_call_llm = False
async_client = MagicMock()
async_client.chat.completions.create = MagicMock(return_value=SimpleNamespace(
choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
))
compressor._get_async_client = MagicMock(return_value=async_client)
metrics = TrajectoryMetrics()
result = await compressor._generate_summary_async("tool output", metrics)
assert result.startswith("[CONTEXT SUMMARY]:")
assert async_client.chat.completions.create.call_args.kwargs["temperature"] == 1.0

View file

@ -54,14 +54,18 @@ _project_env = Path(__file__).parent / ".env"
load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env) load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
def _effective_temperature_for_model(model: str, requested_temperature: float) -> float: def _effective_temperature_for_model(
model: str,
requested_temperature: float,
base_url: Optional[str] = None,
) -> float:
"""Apply fixed model temperature contracts to direct client calls.""" """Apply fixed model temperature contracts to direct client calls."""
try: try:
from agent.auxiliary_client import _fixed_temperature_for_model from agent.auxiliary_client import _fixed_temperature_for_model
except Exception: except Exception:
return requested_temperature return requested_temperature
fixed_temperature = _fixed_temperature_for_model(model) fixed_temperature = _fixed_temperature_for_model(model, base_url)
if fixed_temperature is not None: if fixed_temperature is not None:
return fixed_temperature return fixed_temperature
return requested_temperature return requested_temperature
@ -583,6 +587,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
summary_temperature = _effective_temperature_for_model( summary_temperature = _effective_temperature_for_model(
self.config.summarization_model, self.config.summarization_model,
self.config.temperature, self.config.temperature,
self.config.base_url,
) )
if getattr(self, '_use_call_llm', False): if getattr(self, '_use_call_llm', False):
@ -649,6 +654,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
summary_temperature = _effective_temperature_for_model( summary_temperature = _effective_temperature_for_model(
self.config.summarization_model, self.config.summarization_model,
self.config.temperature, self.config.temperature,
self.config.base_url,
) )
if getattr(self, '_use_call_llm', False): if getattr(self, '_use_call_llm', False):