hermes-agent/tests/agent/test_model_metadata.py
Andre Kurait b290297d66 fix(bedrock): resolve context length via static table before custom-endpoint probe
## Problem

`get_model_context_length()` in `agent/model_metadata.py` had a resolution
order bug that caused every Bedrock model to fall back to the 128K default
context length instead of reaching the static Bedrock table (200K for
Claude, etc.).

The root cause: `bedrock-runtime.<region>.amazonaws.com` is not listed in
`_URL_TO_PROVIDER`, so `_is_known_provider_base_url()` returned False.
The resolution order then ran the custom-endpoint probe (step 2) *before*
the Bedrock branch (step 4b), which:

  1. Treated Bedrock as a custom endpoint (via `_is_custom_endpoint`).
  2. Called `fetch_endpoint_model_metadata()` → `GET /models` on the
     bedrock-runtime URL (Bedrock doesn't serve this shape).
  3. Fell through to `return DEFAULT_FALLBACK_CONTEXT` (128K) at the
     "probe-down" branch — never reaching the Bedrock static table.

Result: users on Bedrock saw 128K context for Claude models that
actually support 200K on Bedrock, causing premature auto-compression.
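
The ordering problem described above can be reproduced in a few lines. This is a simplified sketch, not the code in `agent/model_metadata.py`: `KNOWN_PROVIDER_HOSTS`, `BEDROCK_STATIC_TABLE`, `probe_models` and `resolve_before_fix` are hypothetical stand-ins, and the table contents are assumed for illustration.

```python
from urllib.parse import urlparse

DEFAULT_FALLBACK_CONTEXT = 128_000
# Assumed contents; the only point is that no bedrock-runtime host is listed.
KNOWN_PROVIDER_HOSTS = {"api.openai.com", "openrouter.ai"}
# Stand-in for the static Bedrock table (200K for Claude, per the text above).
BEDROCK_STATIC_TABLE = {"anthropic.claude": 200_000}


def probe_models(base_url):
    """Stand-in for GET /models; Bedrock does not serve this shape."""
    return {}


def resolve_before_fix(model, base_url):
    host = urlparse(base_url).hostname or ""
    if host not in KNOWN_PROVIDER_HOSTS:
        # Step 2: the custom-endpoint probe runs first for unknown hosts.
        metadata = probe_models(base_url)
        if model in metadata:
            return metadata[model]
        # "Probe-down" branch: returns before the Bedrock branch is reached.
        return DEFAULT_FALLBACK_CONTEXT
    # Step 4b: Bedrock static table, unreachable for bedrock-runtime hosts.
    for prefix, ctx in BEDROCK_STATIC_TABLE.items():
        if model.startswith(prefix):
            return ctx
    return DEFAULT_FALLBACK_CONTEXT


ctx = resolve_before_fix(
    "anthropic.claude-opus-4-v1:0",
    "https://bedrock-runtime.us-east-1.amazonaws.com",
)
print(ctx)
```

Under these assumptions the call returns the 128K fallback even though the static table holds a 200K entry for the model.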

## Fix

Promote the Bedrock branch from step 4b to step 1b, so it runs *before*
the custom-endpoint probe at step 2. The static table in
`bedrock_adapter.py::get_bedrock_context_length()` is the authoritative
source for Bedrock (the ListFoundationModels API doesn't expose context
window sizes), so there's no reason to probe `/models` first.

The original step 4b is replaced with a one-line breadcrumb comment
pointing to the new location, to make the resolution-order docstring
accurate.
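
A rough sketch of the reordered lookup, under stated assumptions: `looks_like_bedrock`, `resolve_after_fix`, the host regex and the table contents are hypothetical, and the real predicate in `agent/model_metadata.py` may differ. It only illustrates that the static table is consulted before any probe is attempted.

```python
import re
from urllib.parse import urlparse

DEFAULT_FALLBACK_CONTEXT = 128_000
# Stand-in for the static Bedrock table; contents are assumed.
BEDROCK_STATIC_TABLE = {"anthropic.claude": 200_000}
# Assumed host pattern for bedrock-runtime.<region>.amazonaws.com.
_BEDROCK_HOST = re.compile(r"^bedrock-runtime\.[a-z0-9-]+\.amazonaws\.com$")


def looks_like_bedrock(provider, base_url):
    """True for an explicit provider hint or a bedrock-runtime host."""
    if provider == "bedrock":
        return True
    host = urlparse(base_url or "").hostname or ""
    return bool(_BEDROCK_HOST.match(host))


def resolve_after_fix(model, base_url=None, provider=None, probe=lambda url: {}):
    # Step 1b: Bedrock static table, checked before any network I/O.
    if looks_like_bedrock(provider, base_url):
        for prefix, ctx in BEDROCK_STATIC_TABLE.items():
            if model.startswith(prefix):
                return ctx
    # Step 2: custom-endpoint probe (assumed to run for any other base_url).
    if base_url:
        metadata = probe(base_url)
        if model in metadata:
            return metadata[model]
    return DEFAULT_FALLBACK_CONTEXT


calls = []


def recording_probe(url):
    """Hypothetical probe that records each call, mirroring the mock in the tests."""
    calls.append(url)
    return {"some-model": 50_000}


bedrock_ctx = resolve_after_fix(
    "anthropic.claude-sonnet-4-v1:0",
    base_url="https://bedrock-runtime.us-west-2.amazonaws.com",
    probe=recording_probe,
)
custom_ctx = resolve_after_fix(
    "some-model",
    base_url="https://api.example.com/v1",
    probe=recording_probe,
)
```

In this sketch the Bedrock call never touches the probe, while the non-Bedrock call still does, which is the same split the three new tests check.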

## Changes

- `agent/model_metadata.py`
  - Add step 1b: Bedrock static-table branch (unchanged predicate, moved).
  - Remove dead step 4b block, replace with breadcrumb comment.
  - Update resolution-order docstring to include step 1b.

- `tests/agent/test_model_metadata.py`
  - New `TestBedrockContextResolution` class (3 tests):
    - `test_bedrock_provider_returns_static_table_before_probe`:
      confirms `provider="bedrock"` hits the static table and does NOT
      call `fetch_endpoint_model_metadata` (regression guard).
    - `test_bedrock_url_without_provider_hint`: confirms the
      `bedrock-runtime.*.amazonaws.com` host match works without an
      explicit `provider=` hint.
    - `test_non_bedrock_url_still_probes`: confirms the probe still
      fires for genuinely-custom endpoints (no over-reach).

## Testing

  pytest tests/agent/test_model_metadata.py -q
  # 83 passed in 1.95s (3 new + 80 existing)

## Risk

Very low.

- Predicate is identical to the original step 4b — no behaviour change
  for non-Bedrock paths.
- Original step 4b was dead code for the user-facing case (always hit
  the 128K fallback first), so removing it cannot regress behaviour.
- Bedrock path now short-circuits before any network I/O — faster too.
- `ImportError` fall-through preserved so users without `boto3`
  installed are unaffected.
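
The `ImportError` fall-through in the last bullet might look roughly like the following. This is a sketch only: the module path `agent.bedrock_adapter` and the helper `bedrock_context_or_none` are assumptions, and only the function name `get_bedrock_context_length` comes from the description above.

```python
import importlib


def bedrock_context_or_none(model, module_name="agent.bedrock_adapter"):
    """Return the static-table value, or None if the adapter cannot be imported.

    The module path is an assumption made for illustration.
    """
    try:
        adapter = importlib.import_module(module_name)
    except ImportError:
        # Adapter or one of its dependencies (e.g. boto3) is unavailable:
        # signal the caller to continue with the remaining resolution steps.
        return None
    return adapter.get_bedrock_context_length(model)


# A module name that does not exist exercises the fall-through path.
result = bedrock_context_or_none(
    "anthropic.claude-opus-4-v1:0",
    module_name="module_that_is_not_installed_xyz",
)
```

Returning `None` rather than raising keeps the Bedrock branch optional: the caller can treat it like any other resolution step that produced no answer.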

## Related

- This is a prerequisite for accurate context-window accounting on
  Bedrock — the fix for #14710 (stale-connection client eviction)
  depends on correct context sizing to know when to compress.

Signed-off-by: Andre Kurait <andrekurait@gmail.com>
2026-04-24 07:26:07 -07:00


"""Tests for agent/model_metadata.py — token estimation, context lengths,
probing, caching, and error parsing.

Coverage levels:
    Token estimation: concrete value assertions, edge cases
    Context length lookup: resolution order, fuzzy match, cache priority
    API metadata fetch: caching, TTL, canonical slugs, stale fallback
    Probe tiers: descending, boundaries, extreme inputs
    Error parsing: OpenAI, Ollama, Anthropic, edge cases
    Persistent cache: save/load, corruption, update, provider isolation
"""
import os
import time
import tempfile
import pytest
import yaml
from pathlib import Path
from unittest.mock import patch, MagicMock

from agent.model_metadata import (
    CONTEXT_PROBE_TIERS,
    DEFAULT_CONTEXT_LENGTHS,
    _strip_provider_prefix,
    estimate_tokens_rough,
    estimate_messages_tokens_rough,
    get_model_context_length,
    get_next_probe_tier,
    get_cached_context_length,
    parse_context_limit_from_error,
    save_context_length,
    fetch_model_metadata,
    _MODEL_CACHE_TTL,
)

# =========================================================================
# Token estimation
# =========================================================================
class TestEstimateTokensRough:
    def test_empty_string(self):
        assert estimate_tokens_rough("") == 0

    def test_none_returns_zero(self):
        assert estimate_tokens_rough(None) == 0

    def test_known_length(self):
        assert estimate_tokens_rough("a" * 400) == 100

    def test_short_text(self):
        # "hello" = 5 chars → ceil(5/4) = 2
        assert estimate_tokens_rough("hello") == 2

    def test_proportional(self):
        short = estimate_tokens_rough("hello world")
        long = estimate_tokens_rough("hello world " * 100)
        assert long > short

    def test_unicode_multibyte(self):
        """Unicode chars are still 1 Python char each, so 4 chars/token holds."""
        text = "你好世界"  # 4 CJK characters
        assert estimate_tokens_rough(text) == 1


class TestEstimateMessagesTokensRough:
    def test_empty_list(self):
        assert estimate_messages_tokens_rough([]) == 0

    def test_single_message_concrete_value(self):
        """Verify against known str(msg) length (ceiling division)."""
        msg = {"role": "user", "content": "a" * 400}
        result = estimate_messages_tokens_rough([msg])
        n = len(str(msg))
        expected = (n + 3) // 4
        assert result == expected

    def test_multiple_messages_additive(self):
        msgs = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there, how can I help?"},
        ]
        result = estimate_messages_tokens_rough(msgs)
        n = sum(len(str(m)) for m in msgs)
        expected = (n + 3) // 4
        assert result == expected

    def test_tool_call_message(self):
        """Tool call messages whose 'content' is None still contribute tokens."""
        msg = {"role": "assistant", "content": None,
               "tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]}
        result = estimate_messages_tokens_rough([msg])
        assert result > 0
        assert result == (len(str(msg)) + 3) // 4

    def test_message_with_list_content(self):
        """Vision messages with multimodal content arrays."""
        msg = {"role": "user", "content": [
            {"type": "text", "text": "describe"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
        ]}
        result = estimate_messages_tokens_rough([msg])
        assert result == (len(str(msg)) + 3) // 4


# =========================================================================
# Default context lengths
# =========================================================================
class TestDefaultContextLengths:
    def test_claude_models_context_lengths(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "claude" not in key:
                continue
            # Claude 4.6+ models (4.6 and 4.7) have 1M context at standard
            # API pricing (no long-context premium). Older Claude 4.x and
            # 3.x models cap at 200k.
            if any(tag in key for tag in ("4.6", "4-6", "4.7", "4-7")):
                assert value == 1000000, f"{key} should be 1000000"
            else:
                assert value == 200000, f"{key} should be 200000"

    def test_gpt4_models_128k_or_1m(self):
        # gpt-4.1 and gpt-4.1-mini have 1M context; other gpt-4* have 128k
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gpt-4" in key and "gpt-4.1" not in key:
                assert value == 128000, f"{key} should be 128000"

    def test_gpt41_models_1m(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gpt-4.1" in key:
                assert value == 1047576, f"{key} should be 1047576"

    def test_gemini_models_1m(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gemini" in key:
                assert value == 1048576, f"{key} should be 1048576"

    def test_grok_models_context_lengths(self):
        # xAI /v1/models does not return context_length metadata, so
        # DEFAULT_CONTEXT_LENGTHS must cover the Grok family explicitly.
        # Values sourced from models.dev (2026-04).
        expected = {
            "grok-4.20": 2000000,
            "grok-4-1-fast": 2000000,
            "grok-4-fast": 2000000,
            "grok-4": 256000,
            "grok-code-fast": 256000,
            "grok-3": 131072,
            "grok-2": 131072,
            "grok-2-vision": 8192,
            "grok": 131072,
        }
        for key, value in expected.items():
            assert key in DEFAULT_CONTEXT_LENGTHS, f"{key} missing from DEFAULT_CONTEXT_LENGTHS"
            assert DEFAULT_CONTEXT_LENGTHS[key] == value, (
                f"{key} should be {value}, got {DEFAULT_CONTEXT_LENGTHS[key]}"
            )

    def test_grok_substring_matching(self):
        # Longest-first substring matching must resolve the real xAI model
        # IDs to the correct fallback entries without 128k probe-down.
        from agent.model_metadata import get_model_context_length
        from unittest.mock import patch as mock_patch
        # Fake the provider/API/cache layers so the lookup falls through
        # to DEFAULT_CONTEXT_LENGTHS.
        with mock_patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             mock_patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             mock_patch("agent.model_metadata.get_cached_context_length", return_value=None):
            cases = [
                ("grok-4.20-0309-reasoning", 2000000),
                ("grok-4.20-0309-non-reasoning", 2000000),
                ("grok-4.20-multi-agent-0309", 2000000),
                ("grok-4-1-fast-reasoning", 2000000),
                ("grok-4-1-fast-non-reasoning", 2000000),
                ("grok-4-fast-reasoning", 2000000),
                ("grok-4-fast-non-reasoning", 2000000),
                ("grok-4", 256000),
                ("grok-4-0709", 256000),
                ("grok-code-fast-1", 256000),
                ("grok-3", 131072),
                ("grok-3-mini", 131072),
                ("grok-3-mini-fast", 131072),
                ("grok-2", 131072),
                ("grok-2-vision", 8192),
                ("grok-2-vision-1212", 8192),
                ("grok-beta", 131072),
            ]
            for model_id, expected_ctx in cases:
                actual = get_model_context_length(model_id)
                assert actual == expected_ctx, (
                    f"{model_id}: expected {expected_ctx}, got {actual}"
                )

    def test_all_values_positive(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            assert value > 0, f"{key} has non-positive context length"

    def test_dict_is_not_empty(self):
        assert len(DEFAULT_CONTEXT_LENGTHS) >= 10


# =========================================================================
# Codex OAuth context-window resolution (provider="openai-codex")
# =========================================================================
class TestCodexOAuthContextLength:
    """ChatGPT Codex OAuth imposes lower context limits than the direct
    OpenAI API for the same slugs. Verified Apr 2026 via live probe of
    chatgpt.com/backend-api/codex/models: every model returns 272k, while
    models.dev reports 1.05M for gpt-5.5/gpt-5.4 and 400k for the rest.
    """

    def setup_method(self):
        import agent.model_metadata as mm
        mm._codex_oauth_context_cache = {}
        mm._codex_oauth_context_cache_time = 0.0

    def test_fallback_table_used_without_token(self):
        """With no access token, the hardcoded Codex fallback table wins
        over models.dev (which reports 1.05M for gpt-5.5 but Codex is 272k).
        """
        from agent.model_metadata import get_model_context_length
        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.save_context_length"):
            for model in (
                "gpt-5.5",
                "gpt-5.4",
                "gpt-5.4-mini",
                "gpt-5.3-codex",
                "gpt-5.2-codex",
                "gpt-5.1-codex-max",
                "gpt-5.1-codex-mini",
            ):
                ctx = get_model_context_length(
                    model=model,
                    base_url="https://chatgpt.com/backend-api/codex",
                    api_key="",
                    provider="openai-codex",
                )
                assert ctx == 272_000, (
                    f"Codex {model}: expected 272000 fallback, got {ctx} "
                    "(models.dev leakage?)"
                )

    def test_live_probe_overrides_fallback(self):
        """When a token is provided, the live /models probe is preferred
        and its context_window drives the result."""
        from agent.model_metadata import get_model_context_length
        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.json.return_value = {
            "models": [
                {"slug": "gpt-5.5", "context_window": 300_000},
                {"slug": "gpt-5.4", "context_window": 400_000},
            ]
        }
        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
             patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.save_context_length"):
            ctx_55 = get_model_context_length(
                model="gpt-5.5",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="fake-token",
                provider="openai-codex",
            )
            ctx_54 = get_model_context_length(
                model="gpt-5.4",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="fake-token",
                provider="openai-codex",
            )
        assert ctx_55 == 300_000
        assert ctx_54 == 400_000

    def test_probe_failure_falls_back_to_hardcoded(self):
        """If the probe fails (non-200 / network error), we still return
        the hardcoded 272k rather than leaking through to models.dev 1.05M."""
        from agent.model_metadata import get_model_context_length
        fake_response = MagicMock()
        fake_response.status_code = 401
        fake_response.json.return_value = {}
        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
             patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.save_context_length"):
            ctx = get_model_context_length(
                model="gpt-5.5",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="expired-token",
                provider="openai-codex",
            )
        assert ctx == 272_000

    def test_non_codex_providers_unaffected(self):
        """Resolving gpt-5.5 on non-Codex providers must NOT use the Codex
        272k override; OpenRouter / direct OpenAI API have different limits.
        """
        from agent.model_metadata import get_model_context_length
        # OpenRouter: should hit its own catalog path first; when mocked
        # empty, falls through to hardcoded DEFAULT_CONTEXT_LENGTHS (400k).
        with patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.models_dev.lookup_models_dev_context", return_value=None):
            ctx = get_model_context_length(
                model="openai/gpt-5.5",
                base_url="https://openrouter.ai/api/v1",
                api_key="",
                provider="openrouter",
            )
        assert ctx == 400_000, (
            f"Non-Codex gpt-5.5 resolved to {ctx}; Codex 272k override "
            "leaked outside openai-codex provider"
        )

    def test_stale_codex_cache_over_400k_is_invalidated(self, tmp_path, monkeypatch):
        """Pre-PR #14935 builds cached gpt-5.5 at 1.05M (from models.dev)
        before the Codex-aware branch existed. Upgrading users keep that
        stale entry on disk and the cache-first lookup returns it forever.
        Codex OAuth caps at 272k for every slug, so any cached Codex
        entry >= 400k must be dropped and re-resolved via the live probe.
        """
        from agent import model_metadata as mm
        # Isolate the cache file to tmp_path
        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)
        base_url = "https://chatgpt.com/backend-api/codex/"
        stale_key = f"gpt-5.5@{base_url}"
        other_key = "other-model@https://api.openai.com/v1/"
        import yaml as _yaml
        cache_file.write_text(_yaml.dump({"context_lengths": {
            stale_key: 1_050_000,  # stale pre-fix value
            other_key: 128_000,  # unrelated, must survive
        }}))
        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.json.return_value = {
            "models": [{"slug": "gpt-5.5", "context_window": 272_000}]
        }
        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
             patch("agent.model_metadata.save_context_length") as mock_save:
            ctx = mm.get_model_context_length(
                model="gpt-5.5",
                base_url=base_url,
                api_key="fake-token",
                provider="openai-codex",
            )
        assert ctx == 272_000, f"Stale entry should have been re-resolved to 272k, got {ctx}"
        # Live save was called with the fresh value
        mock_save.assert_called_with("gpt-5.5", base_url, 272_000)
        # The stale entry was removed from disk; unrelated entries survived
        remaining = _yaml.safe_load(cache_file.read_text()).get("context_lengths", {})
        assert stale_key not in remaining, "Stale entry was not invalidated from the cache file"
        assert remaining.get(other_key) == 128_000, "Unrelated cache entries must not be touched"

    def test_fresh_codex_cache_under_400k_is_respected(self, tmp_path, monkeypatch):
        """Codex entries at the correct 272k must NOT be invalidated;
        only stale pre-fix values (>= 400k) get dropped."""
        from agent import model_metadata as mm
        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)
        base_url = "https://chatgpt.com/backend-api/codex/"
        import yaml as _yaml
        cache_file.write_text(_yaml.dump({"context_lengths": {
            f"gpt-5.5@{base_url}": 272_000,
        }}))
        # If the invalidation incorrectly fired, this would be called; assert it isn't.
        with patch("agent.model_metadata.requests.get") as mock_get:
            ctx = mm.get_model_context_length(
                model="gpt-5.5",
                base_url=base_url,
                api_key="fake-token",
                provider="openai-codex",
            )
        assert ctx == 272_000
        mock_get.assert_not_called()

    def test_stale_invalidation_scoped_to_codex_provider(self, tmp_path, monkeypatch):
        """A cached 1M entry for a non-Codex provider (e.g. Anthropic opus on
        OpenRouter, legitimately 1M) must NOT be invalidated by this guard."""
        from agent import model_metadata as mm
        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)
        base_url = "https://openrouter.ai/api/v1"
        import yaml as _yaml
        cache_file.write_text(_yaml.dump({"context_lengths": {
            f"anthropic/claude-opus-4.6@{base_url}": 1_000_000,
        }}))
        ctx = mm.get_model_context_length(
            model="anthropic/claude-opus-4.6",
            base_url=base_url,
            api_key="fake",
            provider="openrouter",
        )
        assert ctx == 1_000_000, "Non-codex 1M cache entries must be respected"


# =========================================================================
# get_model_context_length — resolution order
# =========================================================================
class TestGetModelContextLength:
    @patch("agent.model_metadata.fetch_model_metadata")
    def test_known_model_from_api(self, mock_fetch):
        mock_fetch.return_value = {
            "test/model": {"context_length": 32000}
        }
        assert get_model_context_length("test/model") == 32000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_fallback_to_defaults(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("anthropic/claude-sonnet-4") == 200000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_unknown_model_returns_first_probe_tier(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("unknown/never-heard-of-this") == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_partial_match_in_defaults(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("openai/gpt-4o") == 128000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen3_coder_plus_context_length(self, mock_fetch):
        """qwen3-coder-plus has a 1M context window, not the generic 128K Qwen default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-coder-plus") == 1000000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen3_coder_context_length(self, mock_fetch):
        """qwen3-coder has a 256K context window, not the generic 128K Qwen default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-coder") == 262144

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen_generic_context_length(self, mock_fetch):
        """Generic qwen models still get the 128K default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-plus") == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_api_missing_context_length_key(self, mock_fetch):
        """Model in API but without context_length → defaults to 128000."""
        mock_fetch.return_value = {"test/model": {"name": "Test"}}
        assert get_model_context_length("test/model") == 128000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_cache_takes_priority_over_api(self, mock_fetch, tmp_path):
        """Persistent cache should be checked BEFORE API metadata."""
        mock_fetch.return_value = {"my/model": {"context_length": 999999}}
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("my/model", "http://local", 32768)
            result = get_model_context_length("my/model", base_url="http://local")
            assert result == 32768  # cache wins over API's 999999

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_no_base_url_skips_cache(self, mock_fetch, tmp_path):
        """Without base_url, cache lookup is skipped."""
        mock_fetch.return_value = {}
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("custom/model", "http://local", 32768)
            # No base_url → cache skipped → falls to probe tier
            result = get_model_context_length("custom/model")
            assert result == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_metadata_beats_fuzzy_default(self, mock_endpoint_fetch, mock_fetch):
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "zai-org/GLM-5-TEE": {"context_length": 65536}
        }
        result = get_model_context_length(
            "zai-org/GLM-5-TEE",
            base_url="https://llm.chutes.ai/v1",
            api_key="test-key",
        )
        assert result == 65536

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_without_metadata_skips_name_based_default(self, mock_endpoint_fetch, mock_fetch):
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {}
        result = get_model_context_length(
            "zai-org/GLM-5-TEE",
            base_url="https://llm.chutes.ai/v1",
            api_key="test-key",
        )
        assert result == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_single_model_fallback(self, mock_endpoint_fetch, mock_fetch):
        """Single-model servers: use the only model even if name doesn't match."""
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "Qwen3.5-9B-Q4_K_M.gguf": {"context_length": 131072}
        }
        result = get_model_context_length(
            "qwen3.5:9b",
            base_url="http://myserver.example.com:8080/v1",
            api_key="test-key",
        )
        assert result == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_fuzzy_substring_match(self, mock_endpoint_fetch, mock_fetch):
        """Fuzzy match: configured model name is substring of endpoint model."""
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "org/llama-3.3-70b-instruct-fp8": {"context_length": 131072},
            "org/qwen-2.5-72b": {"context_length": 32768},
        }
        result = get_model_context_length(
            "llama-3.3-70b-instruct",
            base_url="http://myserver.example.com:8080/v1",
            api_key="test-key",
        )
        assert result == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_overrides_all(self, mock_fetch):
        """Explicit config_context_length takes priority over everything."""
        mock_fetch.return_value = {
            "test/model": {"context_length": 200000}
        }
        result = get_model_context_length(
            "test/model",
            config_context_length=65536,
        )
        assert result == 65536

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_zero_is_ignored(self, mock_fetch):
        """config_context_length=0 should be treated as unset."""
        mock_fetch.return_value = {}
        result = get_model_context_length(
            "anthropic/claude-sonnet-4",
            config_context_length=0,
        )
        assert result == 200000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_none_is_ignored(self, mock_fetch):
        """config_context_length=None should be treated as unset."""
        mock_fetch.return_value = {}
        result = get_model_context_length(
            "anthropic/claude-sonnet-4",
            config_context_length=None,
        )
        assert result == 200000


# =========================================================================
# Bedrock context resolution — must run BEFORE custom-endpoint probe
# =========================================================================
class TestBedrockContextResolution:
    """Regression tests for Bedrock context-length resolution order.

    Bug: because ``bedrock-runtime.<region>.amazonaws.com`` is not listed in
    ``_URL_TO_PROVIDER``, ``_is_known_provider_base_url`` returned False and
    the custom-endpoint probe at step 2 ran first, fetching ``/models`` from
    Bedrock (which it doesn't serve) and returning the 128K default fallback
    before execution ever reached the Bedrock branch.

    Fix: promote the Bedrock branch ahead of the custom-endpoint probe.
    """

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_bedrock_provider_returns_static_table_before_probe(self, mock_fetch):
        """provider='bedrock' resolves via static table, bypasses /models probe."""
        ctx = get_model_context_length(
            "anthropic.claude-opus-4-v1:0",
            provider="bedrock",
            base_url="https://bedrock-runtime.us-east-1.amazonaws.com",
        )
        # Must return the static Bedrock table value (200K for Claude),
        # NOT DEFAULT_FALLBACK_CONTEXT (128K).
        assert ctx == 200000
        mock_fetch.assert_not_called()

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_bedrock_url_without_provider_hint(self, mock_fetch):
        """bedrock-runtime host infers Bedrock even when provider is omitted."""
        ctx = get_model_context_length(
            "anthropic.claude-sonnet-4-v1:0",
            base_url="https://bedrock-runtime.us-west-2.amazonaws.com",
        )
        assert ctx == 200000
        mock_fetch.assert_not_called()

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_non_bedrock_url_still_probes(self, mock_fetch):
        """Non-Bedrock hosts still reach the custom-endpoint probe."""
        mock_fetch.return_value = {"some-model": {"context_length": 50000}}
        ctx = get_model_context_length(
            "some-model",
            base_url="https://api.example.com/v1",
        )
        assert ctx == 50000
        assert mock_fetch.called


# =========================================================================
# _strip_provider_prefix — Ollama model:tag vs provider:model
# =========================================================================
class TestStripProviderPrefix:
    def test_known_provider_prefix_is_stripped(self):
        assert _strip_provider_prefix("local:my-model") == "my-model"
        assert _strip_provider_prefix("openrouter:anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"
        assert _strip_provider_prefix("anthropic:claude-sonnet-4") == "claude-sonnet-4"
        assert _strip_provider_prefix("stepfun:step-3.5-flash") == "step-3.5-flash"

    def test_ollama_model_tag_preserved(self):
        """Ollama model:tag format must NOT be stripped."""
        assert _strip_provider_prefix("qwen3.5:27b") == "qwen3.5:27b"
        assert _strip_provider_prefix("llama3.3:70b") == "llama3.3:70b"
        assert _strip_provider_prefix("gemma2:9b") == "gemma2:9b"
        assert _strip_provider_prefix("codellama:13b-instruct-q4_0") == "codellama:13b-instruct-q4_0"

    def test_http_urls_preserved(self):
        assert _strip_provider_prefix("http://example.com") == "http://example.com"
        assert _strip_provider_prefix("https://example.com") == "https://example.com"

    def test_no_colon_returns_unchanged(self):
        assert _strip_provider_prefix("gpt-4o") == "gpt-4o"
        assert _strip_provider_prefix("anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_ollama_model_tag_not_mangled_in_context_lookup(self, mock_fetch):
        """Ensure 'qwen3.5:27b' is NOT reduced to '27b' during context length lookup.

        We mock a custom endpoint that knows 'qwen3.5:27b'; the full name
        must reach the endpoint metadata lookup intact.
        """
        mock_fetch.return_value = {}
        with patch("agent.model_metadata.fetch_endpoint_model_metadata") as mock_ep, \
             patch("agent.model_metadata._is_custom_endpoint", return_value=True):
            mock_ep.return_value = {"qwen3.5:27b": {"context_length": 32768}}
            result = get_model_context_length(
                "qwen3.5:27b",
                base_url="http://localhost:11434/v1",
            )
        assert result == 32768


# =========================================================================
# fetch_model_metadata — caching, TTL, slugs, failures
# =========================================================================
class TestFetchModelMetadata:
    def _reset_cache(self):
        import agent.model_metadata as mm
        mm._model_metadata_cache = {}
        mm._model_metadata_cache_time = 0

    @patch("agent.model_metadata.requests.get")
    def test_caches_result(self, mock_get):
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [{"id": "test/model", "context_length": 99999, "name": "Test"}]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response
        result1 = fetch_model_metadata(force_refresh=True)
        assert "test/model" in result1
        assert mock_get.call_count == 1
        result2 = fetch_model_metadata()
        assert "test/model" in result2
        assert mock_get.call_count == 1  # cached

    @patch("agent.model_metadata.requests.get")
    def test_api_failure_returns_empty_on_cold_cache(self, mock_get):
        self._reset_cache()
        mock_get.side_effect = Exception("Network error")
        result = fetch_model_metadata(force_refresh=True)
        assert result == {}

    @patch("agent.model_metadata.requests.get")
    def test_api_failure_returns_stale_cache(self, mock_get):
        """On API failure with existing cache, stale data is returned."""
        import agent.model_metadata as mm
        mm._model_metadata_cache = {"old/model": {"context_length": 50000}}
        mm._model_metadata_cache_time = 0  # expired
        mock_get.side_effect = Exception("Network error")
        result = fetch_model_metadata(force_refresh=True)
        assert "old/model" in result
        assert result["old/model"]["context_length"] == 50000

    @patch("agent.model_metadata.requests.get")
    def test_canonical_slug_aliasing(self, mock_get):
        """Models with canonical_slug get indexed under both IDs."""
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [{
                "id": "anthropic/claude-3.5-sonnet:beta",
                "canonical_slug": "anthropic/claude-3.5-sonnet",
                "context_length": 200000,
                "name": "Claude 3.5 Sonnet"
            }]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response
        result = fetch_model_metadata(force_refresh=True)
        # Both the original ID and canonical slug should work
        assert "anthropic/claude-3.5-sonnet:beta" in result
        assert "anthropic/claude-3.5-sonnet" in result
        assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000

    @patch("agent.model_metadata.requests.get")
    def test_provider_prefixed_models_get_bare_aliases(self, mock_get):
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [{
                "id": "provider/test-model",
                "context_length": 123456,
                "name": "Provider: Test Model",
            }]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response
        result = fetch_model_metadata(force_refresh=True)
        assert result["provider/test-model"]["context_length"] == 123456
        assert result["test-model"]["context_length"] == 123456

    @patch("agent.model_metadata.requests.get")
    def test_ttl_expiry_triggers_refetch(self, mock_get):
        """Cache expires after _MODEL_CACHE_TTL seconds."""
        import agent.model_metadata as mm
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [{"id": "m1", "context_length": 1000, "name": "M1"}]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response
        fetch_model_metadata(force_refresh=True)
        assert mock_get.call_count == 1
        # Simulate TTL expiry
        mm._model_metadata_cache_time = time.time() - _MODEL_CACHE_TTL - 1
        fetch_model_metadata()
        assert mock_get.call_count == 2  # refetched

    @patch("agent.model_metadata.requests.get")
    def test_malformed_json_no_data_key(self, mock_get):
        """API returns JSON without 'data' key: empty cache, no crash."""
        self._reset_cache()
        mock_response = MagicMock()
        mock_response.json.return_value = {"error": "something"}
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response
        result = fetch_model_metadata(force_refresh=True)
        assert result == {}


# =========================================================================
# Context probe tiers
# =========================================================================
class TestContextProbeTiers:
    def test_tiers_descending(self):
        for i in range(len(CONTEXT_PROBE_TIERS) - 1):
            assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1]

    def test_first_tier_is_128k(self):
        assert CONTEXT_PROBE_TIERS[0] == 128_000

    def test_last_tier_is_8k(self):
        assert CONTEXT_PROBE_TIERS[-1] == 8_000


class TestGetNextProbeTier:
    def test_from_128k(self):
        assert get_next_probe_tier(128_000) == 64_000

    def test_from_64k(self):
        assert get_next_probe_tier(64_000) == 32_000

    def test_from_32k(self):
        assert get_next_probe_tier(32_000) == 16_000

    def test_from_8k_returns_none(self):
        assert get_next_probe_tier(8_000) is None

    def test_from_below_min_returns_none(self):
        assert get_next_probe_tier(4_000) is None

    def test_from_arbitrary_value(self):
        assert get_next_probe_tier(100_000) == 64_000

    def test_above_max_tier(self):
        """Value above 128K should return 128K."""
        assert get_next_probe_tier(500_000) == 128_000

    def test_zero_returns_none(self):
        assert get_next_probe_tier(0) is None




# =========================================================================
# Error message parsing
# =========================================================================

class TestParseContextLimitFromError:
    def test_openai_format(self):
        msg = "This model's maximum context length is 32768 tokens. However, your messages resulted in 45000 tokens."
        assert parse_context_limit_from_error(msg) == 32768

    def test_context_length_exceeded(self):
        msg = "context_length_exceeded: maximum context length is 131072"
        assert parse_context_limit_from_error(msg) == 131072

    def test_context_size_exceeded(self):
        msg = "Maximum context size 65536 exceeded"
        assert parse_context_limit_from_error(msg) == 65536

    def test_no_limit_in_message(self):
        assert parse_context_limit_from_error("Something went wrong with the API") is None

    def test_unreasonable_small_number_rejected(self):
        assert parse_context_limit_from_error("context length is 42 tokens") is None

    def test_ollama_format(self):
        msg = "Context size has been exceeded. Maximum context size is 32768"
        assert parse_context_limit_from_error(msg) == 32768

    def test_anthropic_format(self):
        msg = "prompt is too long: 250000 tokens > 200000 maximum"
        # Should extract 200000 (the limit), not 250000 (the input size)
        assert parse_context_limit_from_error(msg) == 200000

    def test_lmstudio_format(self):
        msg = "Error: context window of 4096 tokens exceeded"
        assert parse_context_limit_from_error(msg) == 4096

    def test_minimax_delta_only_message_returns_none(self):
        msg = "invalid params, context window exceeds limit (2013)"
        assert parse_context_limit_from_error(msg) is None

    def test_completely_unrelated_error(self):
        assert parse_context_limit_from_error("Invalid API key") is None

    def test_empty_string(self):
        assert parse_context_limit_from_error("") is None

    def test_number_outside_reasonable_range(self):
        """Very large number (>10M) should be rejected."""
        msg = "maximum context length is 99999999999"
        assert parse_context_limit_from_error(msg) is None


# =========================================================================
# Persistent context length cache
# =========================================================================

class TestContextLengthCache:
    def test_save_and_load(self, tmp_path):
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("test/model", "http://localhost:8080/v1", 32768)
            assert get_cached_context_length("test/model", "http://localhost:8080/v1") == 32768

    def test_missing_cache_returns_none(self, tmp_path):
        cache_file = tmp_path / "nonexistent.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            assert get_cached_context_length("test/model", "http://x") is None

    def test_multiple_models_cached(self, tmp_path):
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("model-a", "http://a", 64000)
            save_context_length("model-b", "http://b", 128000)
            assert get_cached_context_length("model-a", "http://a") == 64000
            assert get_cached_context_length("model-b", "http://b") == 128000

    def test_same_model_different_providers(self, tmp_path):
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("llama-3", "http://local:8080", 32768)
            save_context_length("llama-3", "https://openrouter.ai/api/v1", 131072)
            assert get_cached_context_length("llama-3", "http://local:8080") == 32768
            assert get_cached_context_length("llama-3", "https://openrouter.ai/api/v1") == 131072

    def test_idempotent_save(self, tmp_path):
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("model", "http://x", 32768)
            save_context_length("model", "http://x", 32768)
            with open(cache_file) as f:
                data = yaml.safe_load(f)
            assert len(data["context_lengths"]) == 1

    def test_update_existing_value(self, tmp_path):
        """Saving a different value for the same key overwrites it."""
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("model", "http://x", 128000)
            save_context_length("model", "http://x", 64000)
            assert get_cached_context_length("model", "http://x") == 64000

    def test_corrupted_yaml_returns_empty(self, tmp_path):
        """Corrupted cache file is handled gracefully."""
        cache_file = tmp_path / "cache.yaml"
        cache_file.write_text("{{{{not valid yaml: [[[")
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            assert get_cached_context_length("model", "http://x") is None

    def test_wrong_structure_returns_none(self, tmp_path):
        """YAML that loads but has wrong structure."""
        cache_file = tmp_path / "cache.yaml"
        cache_file.write_text("just_a_string\n")
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            assert get_cached_context_length("model", "http://x") is None

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_cached_value_takes_priority(self, mock_fetch, tmp_path):
        mock_fetch.return_value = {}
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("unknown/model", "http://local", 65536)
            assert get_model_context_length("unknown/model", base_url="http://local") == 65536

    def test_special_chars_in_model_name(self, tmp_path):
        """Model names with colons, slashes, etc. don't break the cache."""
        cache_file = tmp_path / "cache.yaml"
        model = "anthropic/claude-3.5-sonnet:beta"
        url = "https://api.example.com/v1"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length(model, url, 200000)
            assert get_cached_context_length(model, url) == 200000