feat: query local server for actual context window size

Instead of defaulting to 2M for unknown local models, query the server
API for the real context length. Supports Ollama (/api/show), vLLM
(max_model_len), and LM Studio (/v1/models). Results are cached to
avoid repeated queries.
This commit is contained in:
Peppi Littera 2026-03-18 21:38:41 +01:00
parent 4c0c7f4c6e
commit d223f7388d
2 changed files with 485 additions and 6 deletions

View file

@ -146,6 +146,9 @@ _MAX_COMPLETION_KEYS = (
"max_tokens",
)
# Local server hostnames / address patterns
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
def _normalize_base_url(base_url: str) -> str:
return (base_url or "").strip().rstrip("/")
@ -178,6 +181,92 @@ def _is_known_provider_base_url(base_url: str) -> bool:
return any(known_host in host for known_host in known_hosts)
def is_local_endpoint(base_url: str) -> bool:
"""Return True if base_url points to a local machine (localhost / RFC-1918 / WSL)."""
normalized = _normalize_base_url(base_url)
if not normalized:
return False
url = normalized if "://" in normalized else f"http://{normalized}"
try:
parsed = urlparse(url)
host = parsed.hostname or ""
except Exception:
return False
if host in _LOCAL_HOSTS:
return True
# RFC-1918 private ranges and link-local
import ipaddress
try:
addr = ipaddress.ip_address(host)
return addr.is_private or addr.is_loopback or addr.is_link_local
except ValueError:
pass
# Bare IP that looks like a private range (e.g. 172.26.x.x for WSL)
parts = host.split(".")
if len(parts) == 4:
try:
first, second = int(parts[0]), int(parts[1])
if first == 10:
return True
if first == 172 and 16 <= second <= 31:
return True
if first == 192 and second == 168:
return True
except ValueError:
pass
return False
def detect_local_server_type(base_url: str) -> Optional[str]:
"""Detect which local server is running at base_url by probing known endpoints.
Returns one of: "ollama", "lmstudio", "vllm", "llamacpp", or None.
"""
import httpx
normalized = _normalize_base_url(base_url)
server_url = normalized
if server_url.endswith("/v1"):
server_url = server_url[:-3]
try:
with httpx.Client(timeout=2.0) as client:
# Ollama exposes /api/tags
try:
r = client.get(f"{server_url}/api/tags")
if r.status_code == 200:
return "ollama"
except Exception:
pass
# LM Studio exposes /api/v0/models
try:
r = client.get(f"{server_url}/api/v0/models")
if r.status_code == 200:
return "lmstudio"
except Exception:
pass
# llama.cpp exposes /props
try:
r = client.get(f"{server_url}/props")
if r.status_code == 200 and "default_generation_settings" in r.text:
return "llamacpp"
except Exception:
pass
# vLLM: /version
try:
r = client.get(f"{server_url}/version")
if r.status_code == 200:
data = r.json()
if "version" in data:
return "vllm"
except Exception:
pass
except Exception:
pass
return None
def _iter_nested_dicts(value: Any):
if isinstance(value, dict):
yield value
@ -383,7 +472,7 @@ def _get_context_cache_path() -> Path:
def _load_context_cache() -> Dict[str, int]:
"""Load the model+provider context_length cache from disk."""
"""Load the model+provider -> context_length cache from disk."""
path = _get_context_cache_path()
if not path.exists():
return {}
@ -412,7 +501,7 @@ def save_context_length(model: str, base_url: str, length: int) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
logger.info("Cached context length %s %s tokens", key, f"{length:,}")
logger.info("Cached context length %s -> %s tokens", key, f"{length:,}")
except Exception as e:
logger.debug("Failed to save context length cache: %s", e)
@ -460,6 +549,69 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
return None
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
"""Query a local server for the model's context length."""
import httpx
# Strip /v1 suffix to get the server root
server_url = base_url.rstrip("/")
if server_url.endswith("/v1"):
server_url = server_url[:-3]
try:
server_type = detect_local_server_type(base_url)
except Exception:
server_type = None
try:
with httpx.Client(timeout=3.0) as client:
# Ollama: /api/show returns model details with context info
if server_type == "ollama":
resp = client.post(f"{server_url}/api/show", json={"name": model})
if resp.status_code == 200:
data = resp.json()
# Check model_info for context length
model_info = data.get("model_info", {})
for key, value in model_info.items():
if "context_length" in key and isinstance(value, (int, float)):
return int(value)
# Check parameters string for num_ctx
params = data.get("parameters", "")
if "num_ctx" in params:
for line in params.split("\n"):
if "num_ctx" in line:
parts = line.strip().split()
if len(parts) >= 2:
try:
return int(parts[-1])
except ValueError:
pass
# LM Studio / vLLM / llama.cpp: try /v1/models/{model}
resp = client.get(f"{server_url}/v1/models/{model}")
if resp.status_code == 200:
data = resp.json()
# vLLM returns max_model_len
ctx = data.get("max_model_len") or data.get("context_length") or data.get("max_tokens")
if ctx and isinstance(ctx, (int, float)):
return int(ctx)
# Try /v1/models and find the model in the list
resp = client.get(f"{server_url}/v1/models")
if resp.status_code == 200:
data = resp.json()
models_list = data.get("data", [])
for m in models_list:
if m.get("id") == model:
ctx = m.get("max_model_len") or m.get("context_length") or m.get("max_tokens")
if ctx and isinstance(ctx, (int, float)):
return int(ctx)
except Exception:
pass
return None
def get_model_context_length(
model: str,
base_url: str = "",
@ -472,9 +624,10 @@ def get_model_context_length(
0. Explicit config override (model.context_length in config.yaml)
1. Persistent cache (previously discovered via probing)
2. Active endpoint metadata (/models for explicit custom endpoints)
3. OpenRouter API metadata
4. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
5. First probe tier (2M) will be narrowed on first context error
3. Local server query (for local endpoints when model not in /models list)
4. OpenRouter API metadata
5. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
6. First probe tier (2M) will be narrowed on first context error
"""
# 0. Explicit config override — user knows best
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
@ -507,6 +660,12 @@ def get_model_context_length(
if not _is_known_provider_base_url(base_url):
# Explicit third-party endpoints should not borrow fuzzy global
# defaults from unrelated providers with similarly named models.
# But first try querying the local server directly.
if is_local_endpoint(base_url):
local_ctx = _query_local_context_length(model, base_url)
if local_ctx and local_ctx > 0:
save_context_length(model, base_url, local_ctx)
return local_ctx
logger.info(
"Could not detect context length for model %r at %s"
"defaulting to %s tokens (probe-down). Set model.context_length "
@ -527,7 +686,14 @@ def get_model_context_length(
if default_model in model or model in default_model:
return length
# 5. Unknown model — start at highest probe tier
# 5. Query local server for unknown models before defaulting to 2M
if base_url and is_local_endpoint(base_url):
local_ctx = _query_local_context_length(model, base_url)
if local_ctx and local_ctx > 0:
save_context_length(model, base_url, local_ctx)
return local_ctx
# 6. Unknown model — start at highest probe tier
return CONTEXT_PROBE_TIERS[0]

View file

@ -0,0 +1,313 @@
"""Tests for _query_local_context_length and the local server fallback in
get_model_context_length.
All tests use synthetic inputs no filesystem or live server required.
"""
import sys
import os
import json
from unittest.mock import MagicMock, patch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
# ---------------------------------------------------------------------------
# _query_local_context_length — unit tests with mocked httpx
# ---------------------------------------------------------------------------
class TestQueryLocalContextLengthOllama:
"""_query_local_context_length with server_type == 'ollama'."""
def _make_resp(self, status_code, body):
resp = MagicMock()
resp.status_code = status_code
resp.json.return_value = body
return resp
def test_ollama_model_info_context_length(self):
"""Reads context length from model_info dict in /api/show response."""
from agent.model_metadata import _query_local_context_length
show_resp = self._make_resp(200, {
"model_info": {"llama.context_length": 131072}
})
models_resp = self._make_resp(404, {})
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.return_value = show_resp
client_mock.get.return_value = models_resp
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1")
assert result == 131072
def test_ollama_parameters_num_ctx(self):
"""Falls back to num_ctx in parameters string when model_info lacks context_length."""
from agent.model_metadata import _query_local_context_length
show_resp = self._make_resp(200, {
"model_info": {},
"parameters": "num_ctx 32768\ntemperature 0.7\n"
})
models_resp = self._make_resp(404, {})
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.return_value = show_resp
client_mock.get.return_value = models_resp
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length("some-model", "http://localhost:11434/v1")
assert result == 32768
def test_ollama_show_404_falls_through(self):
"""When /api/show returns 404, falls through to /v1/models/{model}."""
from agent.model_metadata import _query_local_context_length
show_resp = self._make_resp(404, {})
model_detail_resp = self._make_resp(200, {"max_model_len": 65536})
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.return_value = show_resp
client_mock.get.return_value = model_detail_resp
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length("some-model", "http://localhost:11434/v1")
assert result == 65536
class TestQueryLocalContextLengthVllm:
"""_query_local_context_length with vLLM-style /v1/models/{model} response."""
def _make_resp(self, status_code, body):
resp = MagicMock()
resp.status_code = status_code
resp.json.return_value = body
return resp
def test_vllm_max_model_len(self):
"""Reads max_model_len from /v1/models/{model} response."""
from agent.model_metadata import _query_local_context_length
detail_resp = self._make_resp(200, {"id": "omnicoder-9b", "max_model_len": 100000})
list_resp = self._make_resp(404, {})
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.return_value = self._make_resp(404, {})
client_mock.get.return_value = detail_resp
with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length("omnicoder-9b", "http://localhost:8000/v1")
assert result == 100000
def test_vllm_context_length_key(self):
"""Reads context_length from /v1/models/{model} response."""
from agent.model_metadata import _query_local_context_length
detail_resp = self._make_resp(200, {"id": "some-model", "context_length": 32768})
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.return_value = self._make_resp(404, {})
client_mock.get.return_value = detail_resp
with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length("some-model", "http://localhost:8000/v1")
assert result == 32768
class TestQueryLocalContextLengthModelsList:
"""_query_local_context_length: falls back to /v1/models list."""
def _make_resp(self, status_code, body):
resp = MagicMock()
resp.status_code = status_code
resp.json.return_value = body
return resp
def test_models_list_max_model_len(self):
"""Finds context length for model in /v1/models list."""
from agent.model_metadata import _query_local_context_length
detail_resp = self._make_resp(404, {})
list_resp = self._make_resp(200, {
"data": [
{"id": "other-model", "max_model_len": 4096},
{"id": "omnicoder-9b", "max_model_len": 131072},
]
})
call_count = [0]
def side_effect(url, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
return detail_resp # /v1/models/omnicoder-9b
return list_resp # /v1/models
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.return_value = self._make_resp(404, {})
client_mock.get.side_effect = side_effect
with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length("omnicoder-9b", "http://localhost:1234")
assert result == 131072
def test_models_list_model_not_found_returns_none(self):
"""Returns None when model is not in the /v1/models list."""
from agent.model_metadata import _query_local_context_length
detail_resp = self._make_resp(404, {})
list_resp = self._make_resp(200, {
"data": [{"id": "other-model", "max_model_len": 4096}]
})
call_count = [0]
def side_effect(url, **kwargs):
call_count[0] += 1
if call_count[0] == 1:
return detail_resp
return list_resp
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.return_value = self._make_resp(404, {})
client_mock.get.side_effect = side_effect
with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length("omnicoder-9b", "http://localhost:1234")
assert result is None
class TestQueryLocalContextLengthNetworkError:
"""_query_local_context_length handles network failures gracefully."""
def test_connection_error_returns_none(self):
"""Returns None when the server is unreachable."""
from agent.model_metadata import _query_local_context_length
client_mock = MagicMock()
client_mock.__enter__ = lambda s: client_mock
client_mock.__exit__ = MagicMock(return_value=False)
client_mock.post.side_effect = Exception("Connection refused")
client_mock.get.side_effect = Exception("Connection refused")
with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
patch("httpx.Client", return_value=client_mock):
result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1")
assert result is None
# ---------------------------------------------------------------------------
# get_model_context_length — integration-style tests with mocked helpers
# ---------------------------------------------------------------------------
class TestGetModelContextLengthLocalFallback:
"""get_model_context_length uses local server query before falling back to 2M."""
def test_local_endpoint_unknown_model_queries_server(self):
"""Unknown model on local endpoint gets ctx from server, not 2M default."""
from agent.model_metadata import get_model_context_length
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
patch("agent.model_metadata.is_local_endpoint", return_value=True), \
patch("agent.model_metadata._query_local_context_length", return_value=131072), \
patch("agent.model_metadata.save_context_length") as mock_save:
result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
assert result == 131072
def test_local_endpoint_unknown_model_result_is_cached(self):
"""Context length returned from local server is persisted to cache."""
from agent.model_metadata import get_model_context_length
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
patch("agent.model_metadata.is_local_endpoint", return_value=True), \
patch("agent.model_metadata._query_local_context_length", return_value=131072), \
patch("agent.model_metadata.save_context_length") as mock_save:
get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
mock_save.assert_called_once_with("omnicoder-9b", "http://localhost:11434/v1", 131072)
def test_local_endpoint_server_returns_none_falls_back_to_2m(self):
"""When local server returns None, still falls back to 2M probe tier."""
from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
patch("agent.model_metadata.is_local_endpoint", return_value=True), \
patch("agent.model_metadata._query_local_context_length", return_value=None):
result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
assert result == CONTEXT_PROBE_TIERS[0]
def test_non_local_endpoint_does_not_query_local_server(self):
"""For non-local endpoints, _query_local_context_length is not called."""
from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
patch("agent.model_metadata.is_local_endpoint", return_value=False), \
patch("agent.model_metadata._query_local_context_length") as mock_query:
result = get_model_context_length(
"unknown-model", "https://some-cloud-api.example.com/v1"
)
mock_query.assert_not_called()
def test_cached_result_skips_local_query(self):
"""Cached context length is returned without querying the local server."""
from agent.model_metadata import get_model_context_length
with patch("agent.model_metadata.get_cached_context_length", return_value=65536), \
patch("agent.model_metadata._query_local_context_length") as mock_query:
result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
assert result == 65536
mock_query.assert_not_called()
def test_no_base_url_does_not_query_local_server(self):
"""When base_url is empty, local server is not queried."""
from agent.model_metadata import get_model_context_length
with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
patch("agent.model_metadata._query_local_context_length") as mock_query:
result = get_model_context_length("unknown-xyz-model", "")
mock_query.assert_not_called()