fix: restrict provider URL detection to exact hostname matches

This commit is contained in:
Aslaaen 2026-04-21 02:07:13 +03:00 committed by Teknium
parent fdd0ecaf13
commit 5356797f1b
4 changed files with 67 additions and 6 deletions

View file

@ -6,6 +6,7 @@ import logging
import os import os
import re import re
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from urllib.parse import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,6 +36,14 @@ def _normalize_custom_provider_name(value: str) -> str:
return value.strip().lower().replace(" ", "-") return value.strip().lower().replace(" ", "-")
def _base_url_hostname(base_url: str) -> str:
raw = (base_url or "").strip()
if not raw:
return ""
parsed = urlparse(raw if "://" in raw else f"//{raw}")
return (parsed.hostname or "").lower().rstrip(".")
def _detect_api_mode_for_url(base_url: str) -> Optional[str]: def _detect_api_mode_for_url(base_url: str) -> Optional[str]:
"""Auto-detect api_mode from the resolved base URL. """Auto-detect api_mode from the resolved base URL.
@ -47,9 +56,10 @@ def _detect_api_mode_for_url(base_url: str) -> Optional[str]:
``chat_completions``. ``chat_completions``.
""" """
normalized = (base_url or "").strip().lower().rstrip("/") normalized = (base_url or "").strip().lower().rstrip("/")
if "api.x.ai" in normalized: hostname = _base_url_hostname(base_url)
if hostname == "api.x.ai":
return "codex_responses" return "codex_responses"
if "api.openai.com" in normalized and "openrouter" not in normalized: if hostname == "api.openai.com":
return "codex_responses" return "codex_responses"
if normalized.endswith("/anthropic"): if normalized.endswith("/anthropic"):
return "anthropic_messages" return "anthropic_messages"

View file

@ -38,6 +38,7 @@ import threading
from types import SimpleNamespace from types import SimpleNamespace
import uuid import uuid
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from urllib.parse import urlparse
from openai import OpenAI from openai import OpenAI
import fire import fire
from datetime import datetime from datetime import datetime
@ -127,6 +128,14 @@ from agent.trajectory import (
from utils import atomic_json_write, env_var_enabled from utils import atomic_json_write, env_var_enabled
def _base_url_hostname(base_url: str) -> str:
raw = (base_url or "").strip()
if not raw:
return ""
parsed = urlparse(raw if "://" in raw else f"//{raw}")
return (parsed.hostname or "").lower().rstrip(".")
class _SafeWriter: class _SafeWriter:
"""Transparent stdio wrapper that catches OSError/ValueError from broken pipes. """Transparent stdio wrapper that catches OSError/ValueError from broken pipes.
@ -703,6 +712,7 @@ class AIAgent:
def base_url(self, value: str) -> None: def base_url(self, value: str) -> None:
self._base_url = value self._base_url = value
self._base_url_lower = value.lower() if value else "" self._base_url_lower = value.lower() if value else ""
self._base_url_hostname = _base_url_hostname(value)
def __init__( def __init__(
self, self,
@ -847,7 +857,7 @@ class AIAgent:
elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self._base_url_lower: elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self._base_url_lower:
self.api_mode = "codex_responses" self.api_mode = "codex_responses"
self.provider = "openai-codex" self.provider = "openai-codex"
elif (provider_name is None) and "api.x.ai" in self._base_url_lower: elif (provider_name is None) and self._base_url_hostname == "api.x.ai":
self.api_mode = "codex_responses" self.api_mode = "codex_responses"
self.provider = "xai" self.provider = "xai"
elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self._base_url_lower): elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self._base_url_lower):
@ -2259,8 +2269,13 @@ class AIAgent:
def _is_direct_openai_url(self, base_url: str = None) -> bool: def _is_direct_openai_url(self, base_url: str = None) -> bool:
"""Return True when a base URL targets OpenAI's native API.""" """Return True when a base URL targets OpenAI's native API."""
url = (base_url or self._base_url_lower).lower() if base_url is not None:
return "api.openai.com" in url and "openrouter" not in url hostname = _base_url_hostname(base_url)
else:
hostname = getattr(self, "_base_url_hostname", "") or _base_url_hostname(
getattr(self, "_base_url_lower", "")
)
return hostname == "api.openai.com"
def _resolved_api_call_timeout(self) -> float: def _resolved_api_call_timeout(self) -> float:
"""Resolve the effective per-call request timeout in seconds. """Resolve the effective per-call request timeout in seconds.
@ -6747,7 +6762,7 @@ class AIAgent:
if not is_github_responses: if not is_github_responses:
kwargs["prompt_cache_key"] = self.session_id kwargs["prompt_cache_key"] = self.session_id
is_xai_responses = self.provider == "xai" or "api.x.ai" in (self.base_url or "").lower() is_xai_responses = self.provider == "xai" or self._base_url_hostname == "api.x.ai"
if reasoning_enabled and is_xai_responses: if reasoning_enabled and is_xai_responses:
# xAI reasons automatically — no effort param, just include encrypted content # xAI reasons automatically — no effort param, just include encrypted content

View file

@ -0,0 +1,27 @@
from __future__ import annotations
from run_agent import AIAgent
def _agent_with_base_url(base_url: str) -> AIAgent:
agent = object.__new__(AIAgent)
agent.base_url = base_url
return agent
def test_direct_openai_url_requires_openai_host():
agent = _agent_with_base_url("https://api.openai.com.example/v1")
assert agent._is_direct_openai_url() is False
def test_direct_openai_url_ignores_path_segment_match():
agent = _agent_with_base_url("https://proxy.example.test/api.openai.com/v1")
assert agent._is_direct_openai_url() is False
def test_direct_openai_url_accepts_native_host():
agent = _agent_with_base_url("https://api.openai.com/v1")
assert agent._is_direct_openai_url() is True

View file

@ -28,6 +28,15 @@ class TestCodexResponsesDetection:
# api.openai.com check must exclude openrouter (which routes to openai-hosted models). # api.openai.com check must exclude openrouter (which routes to openai-hosted models).
assert _detect_api_mode_for_url("https://openrouter.ai/api/v1") is None assert _detect_api_mode_for_url("https://openrouter.ai/api/v1") is None
def test_openai_host_suffix_does_not_match(self):
assert _detect_api_mode_for_url("https://api.openai.com.example/v1") is None
def test_openai_path_segment_does_not_match(self):
assert _detect_api_mode_for_url("https://proxy.example.test/api.openai.com/v1") is None
def test_xai_host_suffix_does_not_match(self):
assert _detect_api_mode_for_url("https://api.x.ai.example/v1") is None
class TestAnthropicMessagesDetection: class TestAnthropicMessagesDetection:
"""Third-party gateways that speak the Anthropic protocol under /anthropic.""" """Third-party gateways that speak the Anthropic protocol under /anthropic."""