diff --git a/agent/model_metadata.py b/agent/model_metadata.py
index 2b39be989b..c03c5e89cb 100644
--- a/agent/model_metadata.py
+++ b/agent/model_metadata.py
@@ -210,6 +210,13 @@ def _normalize_base_url(base_url: str) -> str:
     return (base_url or "").strip().rstrip("/")
 
 
+def _auth_headers(api_key: str = "") -> Dict[str, str]:
+    token = str(api_key or "").strip()
+    if not token:
+        return {}
+    return {"Authorization": f"Bearer {token}"}
+
+
 def _is_openrouter_base_url(base_url: str) -> bool:
     return "openrouter.ai" in _normalize_base_url(base_url).lower()
 
@@ -309,7 +316,7 @@ def is_local_endpoint(base_url: str) -> bool:
     return False
 
 
-def detect_local_server_type(base_url: str) -> Optional[str]:
+def detect_local_server_type(base_url: str, api_key: str = "") -> Optional[str]:
     """Detect which local server is running at base_url by probing known endpoints.
 
     Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.
@@ -321,8 +328,10 @@ def detect_local_server_type(base_url: str) -> Optional[str]:
     if server_url.endswith("/v1"):
         server_url = server_url[:-3]
 
+    headers = _auth_headers(api_key)
+
     try:
-        with httpx.Client(timeout=2.0) as client:
+        with httpx.Client(timeout=2.0, headers=headers) as client:
             # LM Studio exposes /api/v1/models — check first (most specific)
             try:
                 r = client.get(f"{server_url}/api/v1/models")
@@ -509,6 +518,59 @@ def fetch_endpoint_model_metadata(
     headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
     last_error: Optional[Exception] = None
 
+    if is_local_endpoint(normalized):
+        try:
+            if detect_local_server_type(normalized, api_key=api_key) == "lm-studio":
+                server_url = normalized[:-3].rstrip("/") if normalized.endswith("/v1") else normalized
+                response = requests.get(
+                    server_url.rstrip("/") + "/api/v1/models",
+                    headers=headers,
+                    timeout=10,
+                )
+                response.raise_for_status()
+                payload = response.json()
+                cache: Dict[str, Dict[str, Any]] = {}
+                for model in payload.get("models", []):
+                    if not isinstance(model, dict):
+                        continue
+                    model_id = model.get("key") or model.get("id")
+                    if not model_id:
+                        continue
+                    entry: Dict[str, Any] = {"name": model.get("name", model_id)}
+
+                    context_length = None
+                    for inst in model.get("loaded_instances", []) or []:
+                        if not isinstance(inst, dict):
+                            continue
+                        cfg = inst.get("config", {})
+                        ctx = cfg.get("context_length") if isinstance(cfg, dict) else None
+                        if isinstance(ctx, int) and ctx > 0:
+                            context_length = ctx
+                            break
+                    if context_length is None:
+                        context_length = _extract_context_length(model)
+                    if context_length is not None:
+                        entry["context_length"] = context_length
+
+                    max_completion_tokens = _extract_max_completion_tokens(model)
+                    if max_completion_tokens is not None:
+                        entry["max_completion_tokens"] = max_completion_tokens
+
+                    pricing = _extract_pricing(model)
+                    if pricing:
+                        entry["pricing"] = pricing
+
+                    _add_model_aliases(cache, model_id, entry)
+                    alt_id = model.get("id")
+                    if isinstance(alt_id, str) and alt_id and alt_id != model_id:
+                        _add_model_aliases(cache, alt_id, entry)
+
+                _endpoint_model_metadata_cache[normalized] = cache
+                _endpoint_model_metadata_cache_time[normalized] = time.time()
+                return cache
+        except Exception as exc:
+            last_error = exc
+
     for candidate in candidates:
         url = candidate.rstrip("/") + "/models"
         try:
@@ -715,7 +777,7 @@ def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
     return False
 
 
-def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
+def query_ollama_num_ctx(model: str, base_url: str, api_key: str = "") -> Optional[int]:
     """Query an Ollama server for the model's context length.
 
     Returns the model's maximum context from GGUF metadata via ``/api/show``,
@@ -733,14 +795,16 @@ def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
         server_url = server_url[:-3]
 
     try:
-        server_type = detect_local_server_type(base_url)
+        server_type = detect_local_server_type(base_url, api_key=api_key)
     except Exception:
         return None
     if server_type != "ollama":
         return None
 
+    headers = _auth_headers(api_key)
+
     try:
-        with httpx.Client(timeout=3.0) as client:
+        with httpx.Client(timeout=3.0, headers=headers) as client:
             resp = client.post(f"{server_url}/api/show", json={"name": bare_model})
             if resp.status_code != 200:
                 return None
@@ -768,7 +832,7 @@ def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
     return None
 
 
-def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
+def _query_local_context_length(model: str, base_url: str, api_key: str = "") -> Optional[int]:
     """Query a local server for the model's context length."""
     import httpx
 
@@ -781,13 +845,15 @@ def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
     if server_url.endswith("/v1"):
         server_url = server_url[:-3]
 
+    headers = _auth_headers(api_key)
+
     try:
-        server_type = detect_local_server_type(base_url)
+        server_type = detect_local_server_type(base_url, api_key=api_key)
     except Exception:
         server_type = None
 
     try:
-        with httpx.Client(timeout=3.0) as client:
+        with httpx.Client(timeout=3.0, headers=headers) as client:
             # Ollama: /api/show returns model details with context info
             if server_type == "ollama":
                 resp = client.post(f"{server_url}/api/show", json={"name": model})
@@ -998,7 +1064,7 @@ def get_model_context_length(
     if not _is_known_provider_base_url(base_url):
         # 3. Try querying local server directly
         if is_local_endpoint(base_url):
-            local_ctx = _query_local_context_length(model, base_url)
+            local_ctx = _query_local_context_length(model, base_url, api_key=api_key)
             if local_ctx and local_ctx > 0:
                 save_context_length(model, base_url, local_ctx)
                 return local_ctx
@@ -1068,7 +1134,7 @@ def get_model_context_length(
 
     # 9. Query local server as last resort
     if base_url and is_local_endpoint(base_url):
-        local_ctx = _query_local_context_length(model, base_url)
+        local_ctx = _query_local_context_length(model, base_url, api_key=api_key)
         if local_ctx and local_ctx > 0:
             save_context_length(model, base_url, local_ctx)
             return local_ctx
diff --git a/gateway/run.py b/gateway/run.py
index eb0dfe237f..36c5655b10 100644
--- a/gateway/run.py
+++ b/gateway/run.py
@@ -3876,9 +3876,11 @@ class GatewayRunner:
         from agent.model_metadata import get_model_context_length
 
         _msg_cwd = os.environ.get("TERMINAL_CWD", os.path.expanduser("~"))
+        _msg_runtime = _resolve_runtime_agent_kwargs()
         _msg_ctx_len = get_model_context_length(
             self._model,
-            base_url=self._base_url or "",
+            base_url=self._base_url or _msg_runtime.get("base_url") or "",
+            api_key=_msg_runtime.get("api_key") or "",
         )
         _ctx_result = await preprocess_context_references_async(
             message_text,
diff --git a/tests/agent/test_model_metadata_local_ctx.py b/tests/agent/test_model_metadata_local_ctx.py
index 6852a82cc9..5da1ed7037 100644
--- a/tests/agent/test_model_metadata_local_ctx.py
+++ b/tests/agent/test_model_metadata_local_ctx.py
@@ -424,6 +424,68 @@ class TestQueryLocalContextLengthLmStudio:
         )
 
 
+class TestDetectLocalServerTypeAuth:
+    def test_passes_bearer_token_to_probe_requests(self):
+        from agent.model_metadata import detect_local_server_type
+
+        resp = MagicMock()
+        resp.status_code = 200
+
+        client_mock = MagicMock()
+        client_mock.__enter__ = lambda s: client_mock
+        client_mock.__exit__ = MagicMock(return_value=False)
+        client_mock.get.return_value = resp
+
+        with patch("httpx.Client", return_value=client_mock) as mock_client:
+            result = detect_local_server_type("http://localhost:1234/v1", api_key="lm-token")
+
+        assert result == "lm-studio"
+        assert mock_client.call_args.kwargs["headers"] == {
+            "Authorization": "Bearer lm-token"
+        }
+
+
+class TestFetchEndpointModelMetadataLmStudio:
+    """fetch_endpoint_model_metadata should use LM Studio's native models endpoint."""
+
+    def _make_resp(self, body):
+        resp = MagicMock()
+        resp.raise_for_status.return_value = None
+        resp.json.return_value = body
+        return resp
+
+    def test_uses_native_models_endpoint_only(self):
+        from agent.model_metadata import fetch_endpoint_model_metadata
+
+        native_resp = self._make_resp(
+            {
+                "models": [
+                    {
+                        "key": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf",
+                        "id": "lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf",
+                        "max_context_length": 131072,
+                    }
+                ]
+            }
+        )
+
+        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
+             patch("agent.model_metadata.requests.get", return_value=native_resp) as mock_get:
+            result = fetch_endpoint_model_metadata(
+                "http://localhost:1234/v1",
+                api_key="lm-token",
+                force_refresh=True,
+            )
+
+        assert mock_get.call_count == 1
+        assert mock_get.call_args[0][0] == "http://localhost:1234/api/v1/models"
+        assert mock_get.call_args.kwargs["headers"] == {
+            "Authorization": "Bearer lm-token"
+        }
+        assert result["lmstudio-community/Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
+        assert result["Qwen3.5-27B-GGUF/Qwen3.5-27B-Q8_0.gguf"]["context_length"] == 131072
+
+
 class TestQueryLocalContextLengthNetworkError:
     """_query_local_context_length handles network failures gracefully."""
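
Usage sketch (illustrative, not part of the patch): the snippet below shows how
the new api_key plumbing is meant to be exercised end to end against a
token-protected LM Studio endpoint. The base URL, token, and model name are
hypothetical placeholders; only the function names and signatures come from
this diff.

    from agent.model_metadata import (
        detect_local_server_type,
        fetch_endpoint_model_metadata,
        get_model_context_length,
    )

    base_url = "http://localhost:1234/v1"  # hypothetical local endpoint
    api_key = "lm-token"                   # hypothetical bearer token

    # Probe requests now carry "Authorization: Bearer lm-token", so server
    # detection works even when the endpoint rejects unauthenticated calls.
    server_type = detect_local_server_type(base_url, api_key=api_key)

    # For LM Studio, metadata comes from the native /api/v1/models endpoint
    # (which reports per-model context lengths) rather than the OpenAI-style
    # /v1/models listing, and is cached per normalized base URL.
    metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key, force_refresh=True)

    # The gateway now forwards the runtime api_key here too, so the
    # local-server fallbacks (steps 3 and 9) can authenticate their probes.
    ctx = get_model_context_length("my-local-model", base_url=base_url, api_key=api_key)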