diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 0a448990d3..6e14d9d997 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -146,6 +146,9 @@ _MAX_COMPLETION_KEYS = ( "max_tokens", ) +# Local server hostnames / address patterns +_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0") + def _normalize_base_url(base_url: str) -> str: return (base_url or "").strip().rstrip("/") @@ -178,6 +181,99 @@ def _is_known_provider_base_url(base_url: str) -> bool: return any(known_host in host for known_host in known_hosts) +def is_local_endpoint(base_url: str) -> bool: + """Return True if base_url points to a local machine (localhost / RFC-1918 / WSL).""" + normalized = _normalize_base_url(base_url) + if not normalized: + return False + url = normalized if "://" in normalized else f"http://{normalized}" + try: + parsed = urlparse(url) + host = parsed.hostname or "" + except Exception: + return False + if host in _LOCAL_HOSTS: + return True + # RFC-1918 private ranges and link-local + import ipaddress + try: + addr = ipaddress.ip_address(host) + return addr.is_private or addr.is_loopback or addr.is_link_local + except ValueError: + pass + # Bare IP that looks like a private range (e.g. 172.26.x.x for WSL) + parts = host.split(".") + if len(parts) == 4: + try: + first, second = int(parts[0]), int(parts[1]) + if first == 10: + return True + if first == 172 and 16 <= second <= 31: + return True + if first == 192 and second == 168: + return True + except ValueError: + pass + return False + + +def detect_local_server_type(base_url: str) -> Optional[str]: + """Detect which local server is running at base_url by probing known endpoints. + + Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None. + """ + import httpx + + normalized = _normalize_base_url(base_url) + server_url = normalized + if server_url.endswith("/v1"): + server_url = server_url[:-3] + + try: + with httpx.Client(timeout=2.0) as client: + # LM Studio exposes /api/v1/models — check first (most specific) + try: + r = client.get(f"{server_url}/api/v1/models") + if r.status_code == 200: + return "lm-studio" + except Exception: + pass + # Ollama exposes /api/tags and responds with {"models": [...]} + # LM Studio returns {"error": "Unexpected endpoint"} with status 200 + # on this path, so we must verify the response contains "models". 
+ try: + r = client.get(f"{server_url}/api/tags") + if r.status_code == 200: + try: + data = r.json() + if "models" in data: + return "ollama" + except Exception: + pass + except Exception: + pass + # llama.cpp exposes /props + try: + r = client.get(f"{server_url}/props") + if r.status_code == 200 and "default_generation_settings" in r.text: + return "llamacpp" + except Exception: + pass + # vLLM: /version + try: + r = client.get(f"{server_url}/version") + if r.status_code == 200: + data = r.json() + if "version" in data: + return "vllm" + except Exception: + pass + except Exception: + pass + + return None + + def _iter_nested_dicts(value: Any): if isinstance(value, dict): yield value @@ -383,7 +479,7 @@ def _get_context_cache_path() -> Path: def _load_context_cache() -> Dict[str, int]: - """Load the model+provider → context_length cache from disk.""" + """Load the model+provider -> context_length cache from disk.""" path = _get_context_cache_path() if not path.exists(): return {} @@ -412,7 +508,7 @@ def save_context_length(model: str, base_url: str, length: int) -> None: path.parent.mkdir(parents=True, exist_ok=True) with open(path, "w") as f: yaml.dump({"context_lengths": cache}, f, default_flow_style=False) - logger.info("Cached context length %s → %s tokens", key, f"{length:,}") + logger.info("Cached context length %s -> %s tokens", key, f"{length:,}") except Exception as e: logger.debug("Failed to save context length cache: %s", e) @@ -460,6 +556,116 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]: return None +def _model_id_matches(candidate_id: str, lookup_model: str) -> bool: + """Return True if *candidate_id* (from server) matches *lookup_model* (configured). + + Supports two forms: + - Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1" + - Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1" + (the part after the last "/" equals lookup_model) + + This covers LM Studio's native API which stores models as "publisher/slug" + while users typically configure only the slug after the "local:" prefix. + """ + if candidate_id == lookup_model: + return True + # Slug match: basename of candidate equals the lookup name + if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model: + return True + return False + + +def _query_local_context_length(model: str, base_url: str) -> Optional[int]: + """Query a local server for the model's context length.""" + import httpx + + # Strip provider prefix (e.g., "local:model-name" → "model-name"). + # LM Studio and Ollama don't use provider prefixes in their model IDs. 
+ if ":" in model and not model.startswith("http"): + model = model.split(":", 1)[1] + + # Strip /v1 suffix to get the server root + server_url = base_url.rstrip("/") + if server_url.endswith("/v1"): + server_url = server_url[:-3] + + try: + server_type = detect_local_server_type(base_url) + except Exception: + server_type = None + + try: + with httpx.Client(timeout=3.0) as client: + # Ollama: /api/show returns model details with context info + if server_type == "ollama": + resp = client.post(f"{server_url}/api/show", json={"name": model}) + if resp.status_code == 200: + data = resp.json() + # Check model_info for context length + model_info = data.get("model_info", {}) + for key, value in model_info.items(): + if "context_length" in key and isinstance(value, (int, float)): + return int(value) + # Check parameters string for num_ctx + params = data.get("parameters", "") + if "num_ctx" in params: + for line in params.split("\n"): + if "num_ctx" in line: + parts = line.strip().split() + if len(parts) >= 2: + try: + return int(parts[-1]) + except ValueError: + pass + + # LM Studio native API: /api/v1/models returns max_context_length. + # This is more reliable than the OpenAI-compat /v1/models which + # doesn't include context window information for LM Studio servers. + # Use _model_id_matches for fuzzy matching: LM Studio stores models as + # "publisher/slug" but users configure only "slug" after "local:" prefix. + if server_type == "lm-studio": + resp = client.get(f"{server_url}/api/v1/models") + if resp.status_code == 200: + data = resp.json() + for m in data.get("models", []): + if _model_id_matches(m.get("key", ""), model) or _model_id_matches(m.get("id", ""), model): + # Prefer loaded instance context (actual runtime value) + for inst in m.get("loaded_instances", []): + cfg = inst.get("config", {}) + ctx = cfg.get("context_length") + if ctx and isinstance(ctx, (int, float)): + return int(ctx) + # Fall back to max_context_length (theoretical model max) + ctx = m.get("max_context_length") or m.get("context_length") + if ctx and isinstance(ctx, (int, float)): + return int(ctx) + + # LM Studio / vLLM / llama.cpp: try /v1/models/{model} + resp = client.get(f"{server_url}/v1/models/{model}") + if resp.status_code == 200: + data = resp.json() + # vLLM returns max_model_len + ctx = data.get("max_model_len") or data.get("context_length") or data.get("max_tokens") + if ctx and isinstance(ctx, (int, float)): + return int(ctx) + + # Try /v1/models and find the model in the list. + # Use _model_id_matches to handle "publisher/slug" vs bare "slug". + resp = client.get(f"{server_url}/v1/models") + if resp.status_code == 200: + data = resp.json() + models_list = data.get("data", []) + for m in models_list: + if _model_id_matches(m.get("id", ""), model): + ctx = m.get("max_model_len") or m.get("context_length") or m.get("max_tokens") + if ctx and isinstance(ctx, (int, float)): + return int(ctx) + except Exception: + pass + + return None + + def get_model_context_length( model: str, base_url: str = "", @@ -472,14 +678,21 @@ def get_model_context_length( 0. Explicit config override (model.context_length in config.yaml) 1. Persistent cache (previously discovered via probing) 2. Active endpoint metadata (/models for explicit custom endpoints) - 3. OpenRouter API metadata - 4. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only) - 5. First probe tier (2M) — will be narrowed on first context error + 3. Local server query (for local endpoints when model not in /models list) + 4. 
OpenRouter API metadata + 5. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only) + 6. First probe tier (2M) — will be narrowed on first context error """ # 0. Explicit config override — user knows best if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0: return config_context_length + # Normalise provider-prefixed model names (e.g. "local:model-name" → + # "model-name") so cache lookups and server queries use the bare ID that + # local servers actually know about. + if ":" in model and not model.startswith("http"): + model = model.split(":", 1)[1] + # 1. Check persistent cache (model+provider) if base_url: cached = get_cached_context_length(model, base_url) @@ -507,6 +720,12 @@ def get_model_context_length( if not _is_known_provider_base_url(base_url): # Explicit third-party endpoints should not borrow fuzzy global # defaults from unrelated providers with similarly named models. + # But first try querying the local server directly. + if is_local_endpoint(base_url): + local_ctx = _query_local_context_length(model, base_url) + if local_ctx and local_ctx > 0: + save_context_length(model, base_url, local_ctx) + return local_ctx logger.info( "Could not detect context length for model %r at %s — " "defaulting to %s tokens (probe-down). Set model.context_length " @@ -527,7 +746,14 @@ def get_model_context_length( if default_model in model or model in default_model: return length - # 5. Unknown model — start at highest probe tier + # 5. Query local server for unknown models before defaulting to 2M + if base_url and is_local_endpoint(base_url): + local_ctx = _query_local_context_length(model, base_url) + if local_ctx and local_ctx > 0: + save_context_length(model, base_url, local_ctx) + return local_ctx + + # 6. Unknown model — start at highest probe tier return CONTEXT_PROBE_TIERS[0] diff --git a/run_agent.py b/run_agent.py index b035ee0eaa..2655ac288c 100644 --- a/run_agent.py +++ b/run_agent.py @@ -6569,7 +6569,21 @@ class AIAgent: self._response_was_previewed = True break - # No fallback -- append the empty message as-is + # No fallback -- if reasoning_text exists, the model put its + # entire response inside tags; use that as the content. + if reasoning_text: + self._vprint(f"{self.log_prefix}Using reasoning as response content (model wrapped entire response in think tags).", force=True) + final_response = reasoning_text + empty_msg = { + "role": "assistant", + "content": final_response, + "reasoning": reasoning_text, + "finish_reason": finish_reason, + } + messages.append(empty_msg) + break + + # Truly empty -- no reasoning and no content empty_msg = { "role": "assistant", "content": final_response, @@ -6577,10 +6591,10 @@ class AIAgent: "finish_reason": finish_reason, } messages.append(empty_msg) - + self._cleanup_task_resources(effective_task_id) self._persist_session(messages, conversation_history) - + return { "final_response": final_response or None, "messages": messages, diff --git a/tests/test_model_metadata_local_ctx.py b/tests/test_model_metadata_local_ctx.py new file mode 100644 index 0000000000..e5ad0dc58c --- /dev/null +++ b/tests/test_model_metadata_local_ctx.py @@ -0,0 +1,493 @@ +"""Tests for _query_local_context_length and the local server fallback in +get_model_context_length. + +All tests use synthetic inputs — no filesystem or live server required. 
+""" + +import sys +import os +import json +from unittest.mock import MagicMock, patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import pytest + + +# --------------------------------------------------------------------------- +# _query_local_context_length — unit tests with mocked httpx +# --------------------------------------------------------------------------- + +class TestQueryLocalContextLengthOllama: + """_query_local_context_length with server_type == 'ollama'.""" + + def _make_resp(self, status_code, body): + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = body + return resp + + def test_ollama_model_info_context_length(self): + """Reads context length from model_info dict in /api/show response.""" + from agent.model_metadata import _query_local_context_length + + show_resp = self._make_resp(200, { + "model_info": {"llama.context_length": 131072} + }) + models_resp = self._make_resp(404, {}) + + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.return_value = show_resp + client_mock.get.return_value = models_resp + + with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1") + + assert result == 131072 + + def test_ollama_parameters_num_ctx(self): + """Falls back to num_ctx in parameters string when model_info lacks context_length.""" + from agent.model_metadata import _query_local_context_length + + show_resp = self._make_resp(200, { + "model_info": {}, + "parameters": "num_ctx 32768\ntemperature 0.7\n" + }) + models_resp = self._make_resp(404, {}) + + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.return_value = show_resp + client_mock.get.return_value = models_resp + + with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length("some-model", "http://localhost:11434/v1") + + assert result == 32768 + + def test_ollama_show_404_falls_through(self): + """When /api/show returns 404, falls through to /v1/models/{model}.""" + from agent.model_metadata import _query_local_context_length + + show_resp = self._make_resp(404, {}) + model_detail_resp = self._make_resp(200, {"max_model_len": 65536}) + + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.return_value = show_resp + client_mock.get.return_value = model_detail_resp + + with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length("some-model", "http://localhost:11434/v1") + + assert result == 65536 + + +class TestQueryLocalContextLengthVllm: + """_query_local_context_length with vLLM-style /v1/models/{model} response.""" + + def _make_resp(self, status_code, body): + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = body + return resp + + def test_vllm_max_model_len(self): + """Reads max_model_len from /v1/models/{model} response.""" + from agent.model_metadata import _query_local_context_length + + detail_resp = self._make_resp(200, {"id": "omnicoder-9b", 
"max_model_len": 100000}) + list_resp = self._make_resp(404, {}) + + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.return_value = self._make_resp(404, {}) + client_mock.get.return_value = detail_resp + + with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length("omnicoder-9b", "http://localhost:8000/v1") + + assert result == 100000 + + def test_vllm_context_length_key(self): + """Reads context_length from /v1/models/{model} response.""" + from agent.model_metadata import _query_local_context_length + + detail_resp = self._make_resp(200, {"id": "some-model", "context_length": 32768}) + + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.return_value = self._make_resp(404, {}) + client_mock.get.return_value = detail_resp + + with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length("some-model", "http://localhost:8000/v1") + + assert result == 32768 + + +class TestQueryLocalContextLengthModelsList: + """_query_local_context_length: falls back to /v1/models list.""" + + def _make_resp(self, status_code, body): + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = body + return resp + + def test_models_list_max_model_len(self): + """Finds context length for model in /v1/models list.""" + from agent.model_metadata import _query_local_context_length + + detail_resp = self._make_resp(404, {}) + list_resp = self._make_resp(200, { + "data": [ + {"id": "other-model", "max_model_len": 4096}, + {"id": "omnicoder-9b", "max_model_len": 131072}, + ] + }) + + call_count = [0] + def side_effect(url, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return detail_resp # /v1/models/omnicoder-9b + return list_resp # /v1/models + + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.return_value = self._make_resp(404, {}) + client_mock.get.side_effect = side_effect + + with patch("agent.model_metadata.detect_local_server_type", return_value=None), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length("omnicoder-9b", "http://localhost:1234") + + assert result == 131072 + + def test_models_list_model_not_found_returns_none(self): + """Returns None when model is not in the /v1/models list.""" + from agent.model_metadata import _query_local_context_length + + detail_resp = self._make_resp(404, {}) + list_resp = self._make_resp(200, { + "data": [{"id": "other-model", "max_model_len": 4096}] + }) + + call_count = [0] + def side_effect(url, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + return detail_resp + return list_resp + + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.return_value = self._make_resp(404, {}) + client_mock.get.side_effect = side_effect + + with patch("agent.model_metadata.detect_local_server_type", return_value=None), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length("omnicoder-9b", "http://localhost:1234") + + assert result is None + + +class 
TestQueryLocalContextLengthLmStudio: + """_query_local_context_length with LM Studio native /api/v1/models response.""" + + def _make_resp(self, status_code, body): + resp = MagicMock() + resp.status_code = status_code + resp.json.return_value = body + return resp + + def _make_client(self, native_resp, detail_resp, list_resp): + """Build a mock httpx.Client with sequenced GET responses.""" + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.return_value = self._make_resp(404, {}) + + responses = [native_resp, detail_resp, list_resp] + call_idx = [0] + + def get_side_effect(url, **kwargs): + idx = call_idx[0] + call_idx[0] += 1 + if idx < len(responses): + return responses[idx] + return self._make_resp(404, {}) + + client_mock.get.side_effect = get_side_effect + return client_mock + + def test_lmstudio_exact_key_match(self): + """Reads max_context_length when key matches exactly.""" + from agent.model_metadata import _query_local_context_length + + native_resp = self._make_resp(200, { + "models": [ + {"key": "nvidia/nvidia-nemotron-super-49b-v1", "id": "nvidia/nvidia-nemotron-super-49b-v1", + "max_context_length": 131072}, + ] + }) + client_mock = self._make_client( + native_resp, + self._make_resp(404, {}), + self._make_resp(404, {}), + ) + + with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length( + "nvidia/nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1" + ) + + assert result == 131072 + + def test_lmstudio_slug_only_matches_key_with_publisher_prefix(self): + """Fuzzy match: bare model slug matches key that includes publisher prefix. + + When the user configures the model as "local:nvidia-nemotron-super-49b-v1" + (slug only, no publisher), but LM Studio's native API stores it as + "nvidia/nvidia-nemotron-super-49b-v1", the lookup must still succeed. + """ + from agent.model_metadata import _query_local_context_length + + native_resp = self._make_resp(200, { + "models": [ + {"key": "nvidia/nvidia-nemotron-super-49b-v1", + "id": "nvidia/nvidia-nemotron-super-49b-v1", + "max_context_length": 131072}, + ] + }) + client_mock = self._make_client( + native_resp, + self._make_resp(404, {}), + self._make_resp(404, {}), + ) + + with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ + patch("httpx.Client", return_value=client_mock): + # Model passed in is just the slug after stripping "local:" prefix + result = _query_local_context_length( + "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1" + ) + + assert result == 131072 + + def test_lmstudio_v1_models_list_slug_fuzzy_match(self): + """Fuzzy match also works for /v1/models list when exact match fails. + + LM Studio's OpenAI-compat /v1/models returns id like + "nvidia/nvidia-nemotron-super-49b-v1" — must match bare slug. 
+ """ + from agent.model_metadata import _query_local_context_length + + # native /api/v1/models: no match + native_resp = self._make_resp(404, {}) + # /v1/models/{model}: no match + detail_resp = self._make_resp(404, {}) + # /v1/models list: model found with publisher prefix, includes context_length + list_resp = self._make_resp(200, { + "data": [ + {"id": "nvidia/nvidia-nemotron-super-49b-v1", "context_length": 131072}, + ] + }) + client_mock = self._make_client(native_resp, detail_resp, list_resp) + + with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length( + "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1" + ) + + assert result == 131072 + + def test_lmstudio_loaded_instances_context_length(self): + """Reads active context_length from loaded_instances when max_context_length absent.""" + from agent.model_metadata import _query_local_context_length + + native_resp = self._make_resp(200, { + "models": [ + { + "key": "nvidia/nvidia-nemotron-super-49b-v1", + "id": "nvidia/nvidia-nemotron-super-49b-v1", + "loaded_instances": [ + {"config": {"context_length": 65536}}, + ], + }, + ] + }) + client_mock = self._make_client( + native_resp, + self._make_resp(404, {}), + self._make_resp(404, {}), + ) + + with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length( + "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1" + ) + + assert result == 65536 + + def test_lmstudio_loaded_instance_beats_max_context_length(self): + """loaded_instances context_length takes priority over max_context_length. + + LM Studio may show max_context_length=1_048_576 (theoretical model max) + while the actual loaded context is 122_651 (runtime setting). The loaded + value is the real constraint and must be preferred. + """ + from agent.model_metadata import _query_local_context_length + + native_resp = self._make_resp(200, { + "models": [ + { + "key": "nvidia/nvidia-nemotron-3-nano-4b", + "id": "nvidia/nvidia-nemotron-3-nano-4b", + "max_context_length": 1_048_576, + "loaded_instances": [ + {"config": {"context_length": 122_651}}, + ], + }, + ] + }) + client_mock = self._make_client( + native_resp, + self._make_resp(404, {}), + self._make_resp(404, {}), + ) + + with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length( + "nvidia-nemotron-3-nano-4b", "http://192.168.1.22:1234/v1" + ) + + assert result == 122_651, ( + f"Expected loaded instance context (122651) but got {result}. " + "max_context_length (1048576) must not win over loaded_instances." 
+ ) + + +class TestQueryLocalContextLengthNetworkError: + """_query_local_context_length handles network failures gracefully.""" + + def test_connection_error_returns_none(self): + """Returns None when the server is unreachable.""" + from agent.model_metadata import _query_local_context_length + + client_mock = MagicMock() + client_mock.__enter__ = lambda s: client_mock + client_mock.__exit__ = MagicMock(return_value=False) + client_mock.post.side_effect = Exception("Connection refused") + client_mock.get.side_effect = Exception("Connection refused") + + with patch("agent.model_metadata.detect_local_server_type", return_value=None), \ + patch("httpx.Client", return_value=client_mock): + result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1") + + assert result is None + + +# --------------------------------------------------------------------------- +# get_model_context_length — integration-style tests with mocked helpers +# --------------------------------------------------------------------------- + +class TestGetModelContextLengthLocalFallback: + """get_model_context_length uses local server query before falling back to 2M.""" + + def test_local_endpoint_unknown_model_queries_server(self): + """Unknown model on local endpoint gets ctx from server, not 2M default.""" + from agent.model_metadata import get_model_context_length + + with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ + patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \ + patch("agent.model_metadata.fetch_model_metadata", return_value={}), \ + patch("agent.model_metadata.is_local_endpoint", return_value=True), \ + patch("agent.model_metadata._query_local_context_length", return_value=131072), \ + patch("agent.model_metadata.save_context_length") as mock_save: + result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1") + + assert result == 131072 + + def test_local_endpoint_unknown_model_result_is_cached(self): + """Context length returned from local server is persisted to cache.""" + from agent.model_metadata import get_model_context_length + + with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ + patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \ + patch("agent.model_metadata.fetch_model_metadata", return_value={}), \ + patch("agent.model_metadata.is_local_endpoint", return_value=True), \ + patch("agent.model_metadata._query_local_context_length", return_value=131072), \ + patch("agent.model_metadata.save_context_length") as mock_save: + get_model_context_length("omnicoder-9b", "http://localhost:11434/v1") + + mock_save.assert_called_once_with("omnicoder-9b", "http://localhost:11434/v1", 131072) + + def test_local_endpoint_server_returns_none_falls_back_to_2m(self): + """When local server returns None, still falls back to 2M probe tier.""" + from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS + + with patch("agent.model_metadata.get_cached_context_length", return_value=None), \ + patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \ + patch("agent.model_metadata.fetch_model_metadata", return_value={}), \ + patch("agent.model_metadata.is_local_endpoint", return_value=True), \ + patch("agent.model_metadata._query_local_context_length", return_value=None): + result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1") + + assert result == CONTEXT_PROBE_TIERS[0] + + def 
test_non_local_endpoint_does_not_query_local_server(self):
+        """For non-local endpoints, _query_local_context_length is not called."""
+        from agent.model_metadata import get_model_context_length
+
+        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
+             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
+             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
+             patch("agent.model_metadata.is_local_endpoint", return_value=False), \
+             patch("agent.model_metadata._query_local_context_length") as mock_query:
+            get_model_context_length(
+                "unknown-model", "https://some-cloud-api.example.com/v1"
+            )
+
+        mock_query.assert_not_called()
+
+    def test_cached_result_skips_local_query(self):
+        """Cached context length is returned without querying the local server."""
+        from agent.model_metadata import get_model_context_length
+
+        with patch("agent.model_metadata.get_cached_context_length", return_value=65536), \
+             patch("agent.model_metadata._query_local_context_length") as mock_query:
+            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")
+
+        assert result == 65536
+        mock_query.assert_not_called()
+
+    def test_no_base_url_does_not_query_local_server(self):
+        """When base_url is empty, local server is not queried."""
+        from agent.model_metadata import get_model_context_length
+
+        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
+             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
+             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
+             patch("agent.model_metadata._query_local_context_length") as mock_query:
+            get_model_context_length("unknown-xyz-model", "")
+
+        mock_query.assert_not_called()
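
For reference, a minimal usage sketch of the new local-detection path (illustrative only, not part of the patch): the "local:" prefix and the on-disk caching follow the conventions in the diff above, while the model name and port are assumptions (11434 is Ollama's default) and a reachable local server is required.

# Sketch: exercise the new helpers against a local OpenAI-compatible server.
# "omnicoder-9b" is an illustrative model name taken from the tests above.
from agent.model_metadata import (
    detect_local_server_type,
    get_model_context_length,
    is_local_endpoint,
)

base_url = "http://localhost:11434/v1"
assert is_local_endpoint(base_url)            # loopback hosts count as local
print(detect_local_server_type(base_url))     # "ollama", "lm-studio", "vllm", "llamacpp", or None
ctx = get_model_context_length("local:omnicoder-9b", base_url)
print(ctx)                                    # queried from the server on first use, then cached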