fix: forward auth when probing local model metadata

Pass the user's configured api_key through local-server detection and
context-length probes (detect_local_server_type, _query_local_context_length,
query_ollama_num_ctx) and use LM Studio's native /api/v1/models endpoint in
fetch_endpoint_model_metadata when a loaded instance is present — so the
probed context length is the actual runtime value the user loaded the model
at, not just the model's theoretical max.

Helps local-LLM users whose auto-detected context length was wrong, causing
compression failures and context-overrun crashes.
This commit is contained in:
Tanner Fokkens 2026-04-20 20:49:44 -07:00 committed by Teknium
parent 3821921ef7
commit cde7283821
3 changed files with 141 additions and 11 deletions

View file

@ -210,6 +210,13 @@ def _normalize_base_url(base_url: str) -> str:
return (base_url or "").strip().rstrip("/")
def _auth_headers(api_key: str = "") -> Dict[str, str]:
token = str(api_key or "").strip()
if not token:
return {}
return {"Authorization": f"Bearer {token}"}
def _is_openrouter_base_url(base_url: str) -> bool:
return "openrouter.ai" in _normalize_base_url(base_url).lower()
@ -309,7 +316,7 @@ def is_local_endpoint(base_url: str) -> bool:
return False
def detect_local_server_type(base_url: str) -> Optional[str]:
def detect_local_server_type(base_url: str, api_key: str = "") -> Optional[str]:
"""Detect which local server is running at base_url by probing known endpoints.
Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.
@ -321,8 +328,10 @@ def detect_local_server_type(base_url: str) -> Optional[str]:
if server_url.endswith("/v1"):
server_url = server_url[:-3]
headers = _auth_headers(api_key)
try:
with httpx.Client(timeout=2.0) as client:
with httpx.Client(timeout=2.0, headers=headers) as client:
# LM Studio exposes /api/v1/models — check first (most specific)
try:
r = client.get(f"{server_url}/api/v1/models")
@ -509,6 +518,59 @@ def fetch_endpoint_model_metadata(
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
last_error: Optional[Exception] = None
if is_local_endpoint(normalized):
try:
if detect_local_server_type(normalized, api_key=api_key) == "lm-studio":
server_url = normalized[:-3].rstrip("/") if normalized.endswith("/v1") else normalized
response = requests.get(
server_url.rstrip("/") + "/api/v1/models",
headers=headers,
timeout=10,
)
response.raise_for_status()
payload = response.json()
cache: Dict[str, Dict[str, Any]] = {}
for model in payload.get("models", []):
if not isinstance(model, dict):
continue
model_id = model.get("key") or model.get("id")
if not model_id:
continue
entry: Dict[str, Any] = {"name": model.get("name", model_id)}
context_length = None
for inst in model.get("loaded_instances", []) or []:
if not isinstance(inst, dict):
continue
cfg = inst.get("config", {})
ctx = cfg.get("context_length") if isinstance(cfg, dict) else None
if isinstance(ctx, int) and ctx > 0:
context_length = ctx
break
if context_length is None:
context_length = _extract_context_length(model)
if context_length is not None:
entry["context_length"] = context_length
max_completion_tokens = _extract_max_completion_tokens(model)
if max_completion_tokens is not None:
entry["max_completion_tokens"] = max_completion_tokens
pricing = _extract_pricing(model)
if pricing:
entry["pricing"] = pricing
_add_model_aliases(cache, model_id, entry)
alt_id = model.get("id")
if isinstance(alt_id, str) and alt_id and alt_id != model_id:
_add_model_aliases(cache, alt_id, entry)
_endpoint_model_metadata_cache[normalized] = cache
_endpoint_model_metadata_cache_time[normalized] = time.time()
return cache
except Exception as exc:
last_error = exc
for candidate in candidates:
url = candidate.rstrip("/") + "/models"
try:
@ -715,7 +777,7 @@ def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
return False
def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
def query_ollama_num_ctx(model: str, base_url: str, api_key: str = "") -> Optional[int]:
"""Query an Ollama server for the model's context length.
Returns the model's maximum context from GGUF metadata via ``/api/show``,
@ -733,14 +795,16 @@ def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
server_url = server_url[:-3]
try:
server_type = detect_local_server_type(base_url)
server_type = detect_local_server_type(base_url, api_key=api_key)
except Exception:
return None
if server_type != "ollama":
return None
headers = _auth_headers(api_key)
try:
with httpx.Client(timeout=3.0) as client:
with httpx.Client(timeout=3.0, headers=headers) as client:
resp = client.post(f"{server_url}/api/show", json={"name": bare_model})
if resp.status_code != 200:
return None
@ -768,7 +832,7 @@ def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
return None
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
def _query_local_context_length(model: str, base_url: str, api_key: str = "") -> Optional[int]:
"""Query a local server for the model's context length."""
import httpx
@ -781,13 +845,15 @@ def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
if server_url.endswith("/v1"):
server_url = server_url[:-3]
headers = _auth_headers(api_key)
try:
server_type = detect_local_server_type(base_url)
server_type = detect_local_server_type(base_url, api_key=api_key)
except Exception:
server_type = None
try:
with httpx.Client(timeout=3.0) as client:
with httpx.Client(timeout=3.0, headers=headers) as client:
# Ollama: /api/show returns model details with context info
if server_type == "ollama":
resp = client.post(f"{server_url}/api/show", json={"name": model})
@ -998,7 +1064,7 @@ def get_model_context_length(
if not _is_known_provider_base_url(base_url):
# 3. Try querying local server directly
if is_local_endpoint(base_url):
local_ctx = _query_local_context_length(model, base_url)
local_ctx = _query_local_context_length(model, base_url, api_key=api_key)
if local_ctx and local_ctx > 0:
save_context_length(model, base_url, local_ctx)
return local_ctx
@ -1068,7 +1134,7 @@ def get_model_context_length(
# 9. Query local server as last resort
if base_url and is_local_endpoint(base_url):
local_ctx = _query_local_context_length(model, base_url)
local_ctx = _query_local_context_length(model, base_url, api_key=api_key)
if local_ctx and local_ctx > 0:
save_context_length(model, base_url, local_ctx)
return local_ctx