mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
## Problem
`get_model_context_length()` in `agent/model_metadata.py` had a resolution
order bug that caused every Bedrock model to fall back to the 128K default
context length instead of reaching the static Bedrock table (200K for
Claude, etc.).
The root cause: `bedrock-runtime.<region>.amazonaws.com` is not listed in
`_URL_TO_PROVIDER`, so `_is_known_provider_base_url()` returned False.
The resolution order then ran the custom-endpoint probe (step 2) *before*
the Bedrock branch (step 4b), which:
1. Treated Bedrock as a custom endpoint (via `_is_custom_endpoint`).
2. Called `fetch_endpoint_model_metadata()` → `GET /models` on the
bedrock-runtime URL (Bedrock doesn't serve this shape).
3. Fell through to `return DEFAULT_FALLBACK_CONTEXT` (128K) at the
"probe-down" branch — never reaching the Bedrock static table.
Result: users on Bedrock saw 128K context for Claude models that
actually support 200K on Bedrock, causing premature auto-compression.
## Fix
Promote the Bedrock branch from step 4b to step 1b, so it runs *before*
the custom-endpoint probe at step 2. The static table in
`bedrock_adapter.py::get_bedrock_context_length()` is the authoritative
source for Bedrock (the ListFoundationModels API doesn't expose context
window sizes), so there's no reason to probe `/models` first.
The original step 4b is replaced with a one-line breadcrumb comment
pointing to the new location, and the resolution-order docstring is
updated to list the new step 1b.
## Changes
- `agent/model_metadata.py`
  - Add step 1b: Bedrock static-table branch (unchanged predicate, moved).
  - Remove dead step 4b block, replace with breadcrumb comment.
  - Update resolution-order docstring to include step 1b.
- `tests/agent/test_model_metadata.py`
  - New `TestBedrockContextResolution` class (3 tests):
    - `test_bedrock_provider_returns_static_table_before_probe`:
      confirms `provider="bedrock"` hits the static table and does NOT
      call `fetch_endpoint_model_metadata` (regression guard).
    - `test_bedrock_url_without_provider_hint`: confirms the
      `bedrock-runtime.*.amazonaws.com` host match works without an
      explicit `provider=` hint.
    - `test_non_bedrock_url_still_probes`: confirms the probe still
      fires for genuinely-custom endpoints (no over-reach).
## Testing
```
pytest tests/agent/test_model_metadata.py -q
# 83 passed in 1.95s (3 new + 80 existing)
```
## Risk
Very low.
- Predicate is identical to the original step 4b — no behaviour change
for non-Bedrock paths.
- The original step 4b was unreachable in the user-facing case (the
  probe always returned the 128K fallback first), so removing it cannot
  regress behaviour.
- Bedrock path now short-circuits before any network I/O — faster too.
- `ImportError` fall-through preserved so users without `boto3`
installed are unaffected.
## Related
- This is a prerequisite for accurate context-window accounting on
Bedrock — the fix for #14710 (stale-connection client eviction)
depends on correct context sizing to know when to compress.
Signed-off-by: Andre Kurait <andrekurait@gmail.com>
981 lines
42 KiB
Python
981 lines
42 KiB
Python
"""Tests for agent/model_metadata.py — token estimation, context lengths,
|
|
probing, caching, and error parsing.
|
|
|
|
Coverage levels:
|
|
Token estimation — concrete value assertions, edge cases
|
|
Context length lookup — resolution order, fuzzy match, cache priority
|
|
API metadata fetch — caching, TTL, canonical slugs, stale fallback
|
|
Probe tiers — descending, boundaries, extreme inputs
|
|
Error parsing — OpenAI, Ollama, Anthropic, edge cases
|
|
Persistent cache — save/load, corruption, update, provider isolation
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import tempfile
|
|
|
|
import pytest
|
|
import yaml
|
|
from pathlib import Path
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from agent.model_metadata import (
|
|
CONTEXT_PROBE_TIERS,
|
|
DEFAULT_CONTEXT_LENGTHS,
|
|
_strip_provider_prefix,
|
|
estimate_tokens_rough,
|
|
estimate_messages_tokens_rough,
|
|
get_model_context_length,
|
|
get_next_probe_tier,
|
|
get_cached_context_length,
|
|
parse_context_limit_from_error,
|
|
save_context_length,
|
|
fetch_model_metadata,
|
|
_MODEL_CACHE_TTL,
|
|
)
|
|
|
|
|
|
# =========================================================================
|
|
# Token estimation
|
|
# =========================================================================
|
|
|
|
class TestEstimateTokensRough:
    """Rough single-string estimate: about 4 characters per token, rounded up."""

    def test_empty_string(self):
        result = estimate_tokens_rough("")
        assert result == 0

    def test_none_returns_zero(self):
        result = estimate_tokens_rough(None)
        assert result == 0

    def test_known_length(self):
        four_hundred_chars = "a" * 400
        assert estimate_tokens_rough(four_hundred_chars) == 100

    def test_short_text(self):
        # 5 characters -> ceil(5 / 4) == 2
        word = "hello"
        assert estimate_tokens_rough(word) == 2

    def test_proportional(self):
        small = estimate_tokens_rough("hello world")
        big = estimate_tokens_rough("hello world " * 100)
        assert big > small

    def test_unicode_multibyte(self):
        """Each CJK character is a single Python char, so 4 chars -> 1 token."""
        cjk = "你好世界"  # 4 CJK characters
        assert estimate_tokens_rough(cjk) == 1
|
|
|
|
|
|
class TestEstimateMessagesTokensRough:
    """Message-list estimate: total ``len(str(msg))`` over all messages,
    divided by 4 and rounded up.

    Expected values are derived from ``str(msg)`` so the tests pin the
    formula rather than a magic number.
    """

    def test_empty_list(self):
        assert estimate_messages_tokens_rough([]) == 0

    def test_single_message_concrete_value(self):
        """Verify against known str(msg) length (ceiling division)."""
        msg = {"role": "user", "content": "a" * 400}
        result = estimate_messages_tokens_rough([msg])
        n = len(str(msg))
        expected = (n + 3) // 4
        assert result == expected

    def test_multiple_messages_additive(self):
        """Characters are totalled across messages before rounding up."""
        msgs = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there, how can I help?"},
        ]
        result = estimate_messages_tokens_rough(msgs)
        n = sum(len(str(m)) for m in msgs)
        expected = (n + 3) // 4
        assert result == expected

    def test_tool_call_message(self):
        """Tool-call messages whose ``content`` is None still contribute
        tokens through the rest of the serialized message (``tool_calls``)."""
        msg = {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {"id": "1", "function": {"name": "terminal", "arguments": "{}"}}
            ],
        }
        result = estimate_messages_tokens_rough([msg])
        assert result > 0
        assert result == (len(str(msg)) + 3) // 4

    def test_message_with_list_content(self):
        """Vision messages with multimodal content arrays."""
        msg = {"role": "user", "content": [
            {"type": "text", "text": "describe"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}},
        ]}
        result = estimate_messages_tokens_rough([msg])
        assert result == (len(str(msg)) + 3) // 4
|
|
|
|
|
|
# =========================================================================
|
|
# Default context lengths
|
|
# =========================================================================
|
|
|
|
class TestDefaultContextLengths:
    """Sanity checks on the hardcoded DEFAULT_CONTEXT_LENGTHS fallback table."""

    def test_claude_models_context_lengths(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "claude" not in key:
                continue
            # Claude 4.6+ models (4.6 and 4.7) have 1M context at standard
            # API pricing (no long-context premium). Older Claude 4.x and
            # 3.x models cap at 200k.
            if any(tag in key for tag in ("4.6", "4-6", "4.7", "4-7")):
                assert value == 1000000, f"{key} should be 1000000"
            else:
                assert value == 200000, f"{key} should be 200000"

    def test_gpt4_models_128k_or_1m(self):
        # Every gpt-4* entry except the gpt-4.1 family has 128k. The gpt-4.1
        # family (1,047,576 tokens, ~1M) is checked by test_gpt41_models_1m.
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gpt-4" in key and "gpt-4.1" not in key:
                assert value == 128000, f"{key} should be 128000"

    def test_gpt41_models_1m(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gpt-4.1" in key:
                assert value == 1047576, f"{key} should be 1047576"

    def test_gemini_models_1m(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            if "gemini" in key:
                assert value == 1048576, f"{key} should be 1048576"

    def test_grok_models_context_lengths(self):
        # xAI /v1/models does not return context_length metadata, so
        # DEFAULT_CONTEXT_LENGTHS must cover the Grok family explicitly.
        # Values sourced from models.dev (2026-04).
        expected = {
            "grok-4.20": 2000000,
            "grok-4-1-fast": 2000000,
            "grok-4-fast": 2000000,
            "grok-4": 256000,
            "grok-code-fast": 256000,
            "grok-3": 131072,
            "grok-2": 131072,
            "grok-2-vision": 8192,
            "grok": 131072,
        }
        for key, value in expected.items():
            assert key in DEFAULT_CONTEXT_LENGTHS, f"{key} missing from DEFAULT_CONTEXT_LENGTHS"
            assert DEFAULT_CONTEXT_LENGTHS[key] == value, (
                f"{key} should be {value}, got {DEFAULT_CONTEXT_LENGTHS[key]}"
            )

    def test_grok_substring_matching(self):
        # Longest-first substring matching must resolve the real xAI model
        # IDs to the correct fallback entries without 128k probe-down.
        cases = [
            ("grok-4.20-0309-reasoning", 2000000),
            ("grok-4.20-0309-non-reasoning", 2000000),
            ("grok-4.20-multi-agent-0309", 2000000),
            ("grok-4-1-fast-reasoning", 2000000),
            ("grok-4-1-fast-non-reasoning", 2000000),
            ("grok-4-fast-reasoning", 2000000),
            ("grok-4-fast-non-reasoning", 2000000),
            ("grok-4", 256000),
            ("grok-4-0709", 256000),
            ("grok-code-fast-1", 256000),
            ("grok-3", 131072),
            ("grok-3-mini", 131072),
            ("grok-3-mini-fast", 131072),
            ("grok-2", 131072),
            ("grok-2-vision", 8192),
            ("grok-2-vision-1212", 8192),
            ("grok-beta", 131072),
        ]
        # Fake the provider/API/cache layers so the lookup falls through
        # to DEFAULT_CONTEXT_LENGTHS. get_model_context_length and patch come
        # from the module-level imports; no local re-import is needed.
        with patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
                patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
                patch("agent.model_metadata.get_cached_context_length", return_value=None):
            for model_id, expected_ctx in cases:
                actual = get_model_context_length(model_id)
                assert actual == expected_ctx, (
                    f"{model_id}: expected {expected_ctx}, got {actual}"
                )

    def test_all_values_positive(self):
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            assert value > 0, f"{key} has non-positive context length"

    def test_dict_is_not_empty(self):
        assert len(DEFAULT_CONTEXT_LENGTHS) >= 10
|
|
|
|
|
|
# =========================================================================
|
|
# Codex OAuth context-window resolution (provider="openai-codex")
|
|
# =========================================================================
|
|
|
|
class TestCodexOAuthContextLength:
    """ChatGPT Codex OAuth imposes lower context limits than the direct
    OpenAI API for the same slugs. Verified Apr 2026 via live probe of
    chatgpt.com/backend-api/codex/models: every model returns 272k, while
    models.dev reports 1.05M for gpt-5.5/gpt-5.4 and 400k for the rest.
    """

    def setup_method(self):
        # Clear the module-level ``_codex_oauth_context_cache`` and its
        # timestamp before every test so earlier results cannot leak in.
        import agent.model_metadata as mm
        mm._codex_oauth_context_cache = {}
        mm._codex_oauth_context_cache_time = 0.0

    def test_fallback_table_used_without_token(self):
        """With no access token, the hardcoded Codex fallback table wins
        over models.dev (which reports 1.05M for gpt-5.5 but Codex is 272k).
        """
        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
                patch("agent.model_metadata.save_context_length"):
            for model in (
                "gpt-5.5",
                "gpt-5.4",
                "gpt-5.4-mini",
                "gpt-5.3-codex",
                "gpt-5.2-codex",
                "gpt-5.1-codex-max",
                "gpt-5.1-codex-mini",
            ):
                ctx = get_model_context_length(
                    model=model,
                    base_url="https://chatgpt.com/backend-api/codex",
                    api_key="",
                    provider="openai-codex",
                )
                assert ctx == 272_000, (
                    f"Codex {model}: expected 272000 fallback, got {ctx} "
                    "(models.dev leakage?)"
                )

    def test_live_probe_overrides_fallback(self):
        """When a token is provided, the live /models probe is preferred
        and its context_window drives the result."""
        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.json.return_value = {
            "models": [
                {"slug": "gpt-5.5", "context_window": 300_000},
                {"slug": "gpt-5.4", "context_window": 400_000},
            ]
        }

        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
                patch("agent.model_metadata.get_cached_context_length", return_value=None), \
                patch("agent.model_metadata.save_context_length"):
            ctx_55 = get_model_context_length(
                model="gpt-5.5",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="fake-token",
                provider="openai-codex",
            )
            ctx_54 = get_model_context_length(
                model="gpt-5.4",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="fake-token",
                provider="openai-codex",
            )
        assert ctx_55 == 300_000
        assert ctx_54 == 400_000

    def test_probe_failure_falls_back_to_hardcoded(self):
        """If the probe fails (non-200 / network error), we still return
        the hardcoded 272k rather than leaking through to models.dev 1.05M."""
        fake_response = MagicMock()
        fake_response.status_code = 401
        fake_response.json.return_value = {}

        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
                patch("agent.model_metadata.get_cached_context_length", return_value=None), \
                patch("agent.model_metadata.save_context_length"):
            ctx = get_model_context_length(
                model="gpt-5.5",
                base_url="https://chatgpt.com/backend-api/codex",
                api_key="expired-token",
                provider="openai-codex",
            )
        assert ctx == 272_000

    def test_non_codex_providers_unaffected(self):
        """Resolving gpt-5.5 on non-Codex providers must NOT use the Codex
        272k override — OpenRouter / direct OpenAI API have different limits.
        """
        # OpenRouter — should hit its own catalog path first; when mocked
        # empty, falls through to hardcoded DEFAULT_CONTEXT_LENGTHS (400k).
        with patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
                patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
                patch("agent.model_metadata.get_cached_context_length", return_value=None), \
                patch("agent.models_dev.lookup_models_dev_context", return_value=None):
            ctx = get_model_context_length(
                model="openai/gpt-5.5",
                base_url="https://openrouter.ai/api/v1",
                api_key="",
                provider="openrouter",
            )
        assert ctx == 400_000, (
            f"Non-Codex gpt-5.5 resolved to {ctx}; Codex 272k override "
            "leaked outside openai-codex provider"
        )

    def test_stale_codex_cache_over_400k_is_invalidated(self, tmp_path, monkeypatch):
        """Pre-PR #14935 builds cached gpt-5.5 at 1.05M (from models.dev)
        before the Codex-aware branch existed. Upgrading users keep that
        stale entry on disk and the cache-first lookup returns it forever.
        Codex OAuth caps at 272k for every slug, so any cached Codex
        entry >= 400k must be dropped and re-resolved via the live probe.
        """
        from agent import model_metadata as mm

        # Isolate the cache file to tmp_path
        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)

        base_url = "https://chatgpt.com/backend-api/codex/"
        stale_key = f"gpt-5.5@{base_url}"
        other_key = "other-model@https://api.openai.com/v1/"
        cache_file.write_text(yaml.dump({"context_lengths": {
            stale_key: 1_050_000,  # stale pre-fix value
            other_key: 128_000,    # unrelated, must survive
        }}))

        fake_response = MagicMock()
        fake_response.status_code = 200
        fake_response.json.return_value = {
            "models": [{"slug": "gpt-5.5", "context_window": 272_000}]
        }

        with patch("agent.model_metadata.requests.get", return_value=fake_response), \
                patch("agent.model_metadata.save_context_length") as mock_save:
            ctx = mm.get_model_context_length(
                model="gpt-5.5",
                base_url=base_url,
                api_key="fake-token",
                provider="openai-codex",
            )

        assert ctx == 272_000, f"Stale entry should have been re-resolved to 272k, got {ctx}"
        # Live save was called with the fresh value
        mock_save.assert_called_with("gpt-5.5", base_url, 272_000)
        # The stale entry was removed from disk; unrelated entries survived
        remaining = yaml.safe_load(cache_file.read_text()).get("context_lengths", {})
        assert stale_key not in remaining, "Stale entry was not invalidated from the cache file"
        assert remaining.get(other_key) == 128_000, "Unrelated cache entries must not be touched"

    def test_fresh_codex_cache_under_400k_is_respected(self, tmp_path, monkeypatch):
        """Codex entries at the correct 272k must NOT be invalidated —
        only stale pre-fix values (>= 400k) get dropped."""
        from agent import model_metadata as mm

        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)

        base_url = "https://chatgpt.com/backend-api/codex/"
        cache_file.write_text(yaml.dump({"context_lengths": {
            f"gpt-5.5@{base_url}": 272_000,
        }}))

        # If the invalidation incorrectly fired, this would be called; assert it isn't.
        with patch("agent.model_metadata.requests.get") as mock_get:
            ctx = mm.get_model_context_length(
                model="gpt-5.5",
                base_url=base_url,
                api_key="fake-token",
                provider="openai-codex",
            )
        assert ctx == 272_000
        mock_get.assert_not_called()

    def test_stale_invalidation_scoped_to_codex_provider(self, tmp_path, monkeypatch):
        """A cached 1M entry for a non-Codex provider (e.g. Anthropic opus on
        OpenRouter, legitimately 1M) must NOT be invalidated by this guard.

        The hardcoded default for claude-opus-4.6 is also 1M (see
        TestDefaultContextLengths), so the return value alone cannot tell a
        cache hit from a re-resolution. The test therefore also asserts that
        no HTTP fetch happened and that the entry is still on disk.
        """
        from agent import model_metadata as mm

        cache_file = tmp_path / "context_length_cache.yaml"
        monkeypatch.setattr(mm, "_get_context_cache_path", lambda: cache_file)

        base_url = "https://openrouter.ai/api/v1"
        cache_key = f"anthropic/claude-opus-4.6@{base_url}"
        cache_file.write_text(yaml.dump({"context_lengths": {
            cache_key: 1_000_000,
        }}))

        # Mirrors test_fresh_codex_cache_under_400k_is_respected: a respected
        # cache entry must short-circuit before any network request.
        with patch("agent.model_metadata.requests.get") as mock_get:
            ctx = mm.get_model_context_length(
                model="anthropic/claude-opus-4.6",
                base_url=base_url,
                api_key="fake",
                provider="openrouter",
            )
        assert ctx == 1_000_000, "Non-codex 1M cache entries must be respected"
        mock_get.assert_not_called()
        remaining = yaml.safe_load(cache_file.read_text()).get("context_lengths", {})
        assert remaining.get(cache_key) == 1_000_000, (
            "Non-codex 1M cache entry was removed from the cache file"
        )
|
|
|
|
|
|
# =========================================================================
|
|
# get_model_context_length — resolution order
|
|
# =========================================================================
|
|
|
|
class TestGetModelContextLength:
    """Resolution-order tests for get_model_context_length: explicit config
    override, persistent cache, endpoint/API metadata, hardcoded defaults and
    the first probe tier."""

    @pytest.fixture(autouse=True)
    def _isolate_context_cache(self, tmp_path, monkeypatch):
        """Redirect the persistent context-length cache to a per-test file.

        Tests that pass ``base_url`` make the resolver consult the on-disk
        cache (see test_cache_takes_priority_over_api). Without isolation, a
        matching entry in the developer's real cache could change the result,
        and anything the resolver persists would go to that real file. Tests
        that patch ``_get_context_cache_path`` themselves still take precedence
        inside their own ``with`` block.
        """
        import agent.model_metadata as mm
        monkeypatch.setattr(
            mm,
            "_get_context_cache_path",
            lambda: tmp_path / "context_length_cache.yaml",
        )

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_known_model_from_api(self, mock_fetch):
        mock_fetch.return_value = {
            "test/model": {"context_length": 32000}
        }
        assert get_model_context_length("test/model") == 32000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_fallback_to_defaults(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("anthropic/claude-sonnet-4") == 200000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_unknown_model_returns_first_probe_tier(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("unknown/never-heard-of-this") == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_partial_match_in_defaults(self, mock_fetch):
        mock_fetch.return_value = {}
        assert get_model_context_length("openai/gpt-4o") == 128000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen3_coder_plus_context_length(self, mock_fetch):
        """qwen3-coder-plus has a 1M context window, not the generic 128K Qwen default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-coder-plus") == 1000000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen3_coder_context_length(self, mock_fetch):
        """qwen3-coder has a 256K context window, not the generic 128K Qwen default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-coder") == 262144

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_qwen_generic_context_length(self, mock_fetch):
        """Generic qwen models still get the 128K default."""
        mock_fetch.return_value = {}
        assert get_model_context_length("qwen3-plus") == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_api_missing_context_length_key(self, mock_fetch):
        """Model in API but without context_length → defaults to 128000."""
        mock_fetch.return_value = {"test/model": {"name": "Test"}}
        assert get_model_context_length("test/model") == 128000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_cache_takes_priority_over_api(self, mock_fetch, tmp_path):
        """Persistent cache should be checked BEFORE API metadata."""
        mock_fetch.return_value = {"my/model": {"context_length": 999999}}
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("my/model", "http://local", 32768)
            result = get_model_context_length("my/model", base_url="http://local")
        assert result == 32768  # cache wins over API's 999999

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_no_base_url_skips_cache(self, mock_fetch, tmp_path):
        """Without base_url, cache lookup is skipped."""
        mock_fetch.return_value = {}
        cache_file = tmp_path / "cache.yaml"
        with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
            save_context_length("custom/model", "http://local", 32768)
            # No base_url → cache skipped → falls to probe tier
            result = get_model_context_length("custom/model")
        assert result == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_metadata_beats_fuzzy_default(self, mock_endpoint_fetch, mock_fetch):
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "zai-org/GLM-5-TEE": {"context_length": 65536}
        }

        result = get_model_context_length(
            "zai-org/GLM-5-TEE",
            base_url="https://llm.chutes.ai/v1",
            api_key="test-key",
        )

        assert result == 65536

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_without_metadata_skips_name_based_default(self, mock_endpoint_fetch, mock_fetch):
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {}

        result = get_model_context_length(
            "zai-org/GLM-5-TEE",
            base_url="https://llm.chutes.ai/v1",
            api_key="test-key",
        )

        assert result == CONTEXT_PROBE_TIERS[0]

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_single_model_fallback(self, mock_endpoint_fetch, mock_fetch):
        """Single-model servers: use the only model even if name doesn't match."""
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "Qwen3.5-9B-Q4_K_M.gguf": {"context_length": 131072}
        }

        result = get_model_context_length(
            "qwen3.5:9b",
            base_url="http://myserver.example.com:8080/v1",
            api_key="test-key",
        )

        assert result == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_custom_endpoint_fuzzy_substring_match(self, mock_endpoint_fetch, mock_fetch):
        """Fuzzy match: configured model name is substring of endpoint model."""
        mock_fetch.return_value = {}
        mock_endpoint_fetch.return_value = {
            "org/llama-3.3-70b-instruct-fp8": {"context_length": 131072},
            "org/qwen-2.5-72b": {"context_length": 32768},
        }

        result = get_model_context_length(
            "llama-3.3-70b-instruct",
            base_url="http://myserver.example.com:8080/v1",
            api_key="test-key",
        )

        assert result == 131072

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_overrides_all(self, mock_fetch):
        """Explicit config_context_length takes priority over everything."""
        mock_fetch.return_value = {
            "test/model": {"context_length": 200000}
        }

        result = get_model_context_length(
            "test/model",
            config_context_length=65536,
        )

        assert result == 65536

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_zero_is_ignored(self, mock_fetch):
        """config_context_length=0 should be treated as unset."""
        mock_fetch.return_value = {}

        result = get_model_context_length(
            "anthropic/claude-sonnet-4",
            config_context_length=0,
        )

        assert result == 200000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_config_context_length_none_is_ignored(self, mock_fetch):
        """config_context_length=None should be treated as unset."""
        mock_fetch.return_value = {}

        result = get_model_context_length(
            "anthropic/claude-sonnet-4",
            config_context_length=None,
        )

        assert result == 200000
|
|
|
|
|
|
# =========================================================================
|
|
# Bedrock context resolution — must run BEFORE custom-endpoint probe
|
|
# =========================================================================
|
|
|
|
class TestBedrockContextResolution:
    """Regression tests for Bedrock context-length resolution order.

    Bug: because ``bedrock-runtime.<region>.amazonaws.com`` is not listed in
    ``_URL_TO_PROVIDER``, ``_is_known_provider_base_url`` returned False and
    the custom-endpoint probe at step 2 ran first — fetching ``/models`` from
    Bedrock (which it doesn't serve), returning the 128K default-fallback
    before execution ever reached the Bedrock branch.

    Fix: promote the Bedrock branch ahead of the custom-endpoint probe.
    """

    @pytest.fixture(autouse=True)
    def _isolate_context_cache(self, tmp_path, monkeypatch):
        """Point the persistent context-length cache at a per-test file.

        Every test here passes ``base_url``, so an entry in the developer's
        real cache could otherwise mask the resolution order under test, and
        a probed value could be persisted to the real file.
        """
        import agent.model_metadata as mm
        monkeypatch.setattr(
            mm,
            "_get_context_cache_path",
            lambda: tmp_path / "context_length_cache.yaml",
        )

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_bedrock_provider_returns_static_table_before_probe(self, mock_fetch):
        """provider='bedrock' resolves via static table, bypasses /models probe."""
        # The Bedrock branch falls through on ImportError so users without
        # boto3 are unaffected; in that environment the static table is never
        # reached and this assertion cannot hold, so skip instead of failing.
        pytest.importorskip("boto3")
        ctx = get_model_context_length(
            "anthropic.claude-opus-4-v1:0",
            provider="bedrock",
            base_url="https://bedrock-runtime.us-east-1.amazonaws.com",
        )
        # Must return the static Bedrock table value (200K for Claude),
        # NOT DEFAULT_FALLBACK_CONTEXT (128K).
        assert ctx == 200000
        mock_fetch.assert_not_called()

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_bedrock_url_without_provider_hint(self, mock_fetch):
        """bedrock-runtime host infers Bedrock even when provider is omitted."""
        # See test_bedrock_provider_returns_static_table_before_probe: without
        # boto3 the Bedrock branch is skipped via its ImportError fall-through.
        pytest.importorskip("boto3")
        ctx = get_model_context_length(
            "anthropic.claude-sonnet-4-v1:0",
            base_url="https://bedrock-runtime.us-west-2.amazonaws.com",
        )
        assert ctx == 200000
        mock_fetch.assert_not_called()

    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
    def test_non_bedrock_url_still_probes(self, mock_fetch):
        """Non-Bedrock hosts still reach the custom-endpoint probe."""
        mock_fetch.return_value = {"some-model": {"context_length": 50000}}
        ctx = get_model_context_length(
            "some-model",
            base_url="https://api.example.com/v1",
        )
        assert ctx == 50000
        assert mock_fetch.called
|
|
|
|
|
|
# =========================================================================
|
|
# _strip_provider_prefix — Ollama model:tag vs provider:model
|
|
# =========================================================================
|
|
|
|
class TestStripProviderPrefix:
    """_strip_provider_prefix must drop known ``provider:`` prefixes while
    leaving Ollama ``model:tag`` names and URLs intact."""

    def test_known_provider_prefix_is_stripped(self):
        assert _strip_provider_prefix("local:my-model") == "my-model"
        assert _strip_provider_prefix("openrouter:anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"
        assert _strip_provider_prefix("anthropic:claude-sonnet-4") == "claude-sonnet-4"
        assert _strip_provider_prefix("stepfun:step-3.5-flash") == "step-3.5-flash"

    def test_ollama_model_tag_preserved(self):
        """Ollama model:tag format must NOT be stripped."""
        assert _strip_provider_prefix("qwen3.5:27b") == "qwen3.5:27b"
        assert _strip_provider_prefix("llama3.3:70b") == "llama3.3:70b"
        assert _strip_provider_prefix("gemma2:9b") == "gemma2:9b"
        assert _strip_provider_prefix("codellama:13b-instruct-q4_0") == "codellama:13b-instruct-q4_0"

    def test_http_urls_preserved(self):
        assert _strip_provider_prefix("http://example.com") == "http://example.com"
        assert _strip_provider_prefix("https://example.com") == "https://example.com"

    def test_no_colon_returns_unchanged(self):
        assert _strip_provider_prefix("gpt-4o") == "gpt-4o"
        assert _strip_provider_prefix("anthropic/claude-sonnet-4") == "anthropic/claude-sonnet-4"

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_ollama_model_tag_not_mangled_in_context_lookup(self, mock_fetch):
        """Ensure 'qwen3.5:27b' is NOT reduced to '27b' during context length lookup.

        We mock a custom endpoint that knows 'qwen3.5:27b' — the full name
        must reach the endpoint metadata lookup intact. The persistent-cache
        read/write helpers are stubbed (as in the Codex tests) so the
        developer's on-disk cache can neither short-circuit the lookup nor be
        written to.
        """
        mock_fetch.return_value = {}
        with patch("agent.model_metadata.fetch_endpoint_model_metadata") as mock_ep, \
                patch("agent.model_metadata._is_custom_endpoint", return_value=True), \
                patch("agent.model_metadata.get_cached_context_length", return_value=None), \
                patch("agent.model_metadata.save_context_length"):
            mock_ep.return_value = {"qwen3.5:27b": {"context_length": 32768}}
            result = get_model_context_length(
                "qwen3.5:27b",
                base_url="http://localhost:11434/v1",
            )
        assert result == 32768
|
|
|
|
|
|
# =========================================================================
|
|
# fetch_model_metadata — caching, TTL, slugs, failures
|
|
# =========================================================================
|
|
|
|
class TestFetchModelMetadata:
|
|
def _reset_cache(self):
    """Empty the in-memory OpenRouter metadata cache and mark it expired."""
    import agent.model_metadata as metadata_mod
    metadata_mod._model_metadata_cache_time = 0
    metadata_mod._model_metadata_cache = {}
|
|
|
|
@patch("agent.model_metadata.requests.get")
def test_caches_result(self, mock_get):
    """A second, non-forced fetch is served from the in-memory cache."""
    self._reset_cache()
    payload = {
        "data": [{"id": "test/model", "context_length": 99999, "name": "Test"}]
    }
    response = MagicMock()
    response.json.return_value = payload
    response.raise_for_status = MagicMock()
    mock_get.return_value = response

    first = fetch_model_metadata(force_refresh=True)
    assert "test/model" in first
    assert mock_get.call_count == 1

    second = fetch_model_metadata()
    assert "test/model" in second
    # Still one HTTP call: the second lookup hit the cache.
    assert mock_get.call_count == 1
|
|
|
|
@patch("agent.model_metadata.requests.get")
def test_api_failure_returns_empty_on_cold_cache(self, mock_get):
    """With nothing cached, a failed fetch yields an empty mapping."""
    self._reset_cache()
    mock_get.side_effect = Exception("Network error")
    assert fetch_model_metadata(force_refresh=True) == {}
|
|
|
|
@patch("agent.model_metadata.requests.get")
def test_api_failure_returns_stale_cache(self, mock_get, monkeypatch):
    """On API failure with existing cache, stale data is returned.

    The module-level cache is seeded through ``monkeypatch`` so the fake
    ``old/model`` entry and the zeroed timestamp are rolled back after the
    test instead of leaking into later tests.
    """
    import agent.model_metadata as mm

    monkeypatch.setattr(
        mm, "_model_metadata_cache", {"old/model": {"context_length": 50000}}
    )
    monkeypatch.setattr(mm, "_model_metadata_cache_time", 0)  # expired

    mock_get.side_effect = Exception("Network error")
    result = fetch_model_metadata(force_refresh=True)
    assert "old/model" in result
    assert result["old/model"]["context_length"] == 50000
|
|
|
|
@patch("agent.model_metadata.requests.get")
|
|
def test_canonical_slug_aliasing(self, mock_get):
|
|
"""Models with canonical_slug get indexed under both IDs."""
|
|
self._reset_cache()
|
|
mock_response = MagicMock()
|
|
mock_response.json.return_value = {
|
|
"data": [{
|
|
"id": "anthropic/claude-3.5-sonnet:beta",
|
|
"canonical_slug": "anthropic/claude-3.5-sonnet",
|
|
"context_length": 200000,
|
|
"name": "Claude 3.5 Sonnet"
|
|
}]
|
|
}
|
|
mock_response.raise_for_status = MagicMock()
|
|
mock_get.return_value = mock_response
|
|
|
|
result = fetch_model_metadata(force_refresh=True)
|
|
# Both the original ID and canonical slug should work
|
|
assert "anthropic/claude-3.5-sonnet:beta" in result
|
|
assert "anthropic/claude-3.5-sonnet" in result
|
|
assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000
|
|
|
|
@patch("agent.model_metadata.requests.get")
|
|
def test_provider_prefixed_models_get_bare_aliases(self, mock_get):
|
|
self._reset_cache()
|
|
mock_response = MagicMock()
|
|
mock_response.json.return_value = {
|
|
"data": [{
|
|
"id": "provider/test-model",
|
|
"context_length": 123456,
|
|
"name": "Provider: Test Model",
|
|
}]
|
|
}
|
|
mock_response.raise_for_status = MagicMock()
|
|
mock_get.return_value = mock_response
|
|
|
|
result = fetch_model_metadata(force_refresh=True)
|
|
|
|
assert result["provider/test-model"]["context_length"] == 123456
|
|
assert result["test-model"]["context_length"] == 123456
|
|
|
|
@patch("agent.model_metadata.requests.get")
|
|
def test_ttl_expiry_triggers_refetch(self, mock_get):
|
|
"""Cache expires after _MODEL_CACHE_TTL seconds."""
|
|
import agent.model_metadata as mm
|
|
self._reset_cache()
|
|
|
|
mock_response = MagicMock()
|
|
mock_response.json.return_value = {
|
|
"data": [{"id": "m1", "context_length": 1000, "name": "M1"}]
|
|
}
|
|
mock_response.raise_for_status = MagicMock()
|
|
mock_get.return_value = mock_response
|
|
|
|
fetch_model_metadata(force_refresh=True)
|
|
assert mock_get.call_count == 1
|
|
|
|
# Simulate TTL expiry
|
|
mm._model_metadata_cache_time = time.time() - _MODEL_CACHE_TTL - 1
|
|
fetch_model_metadata()
|
|
assert mock_get.call_count == 2 # refetched
|
|
|
|
@patch("agent.model_metadata.requests.get")
|
|
def test_malformed_json_no_data_key(self, mock_get):
|
|
"""API returns JSON without 'data' key — empty cache, no crash."""
|
|
self._reset_cache()
|
|
mock_response = MagicMock()
|
|
mock_response.json.return_value = {"error": "something"}
|
|
mock_response.raise_for_status = MagicMock()
|
|
mock_get.return_value = mock_response
|
|
|
|
result = fetch_model_metadata(force_refresh=True)
|
|
assert result == {}
|
|
|
|
|
|
# =========================================================================
|
|
# Context probe tiers
|
|
# =========================================================================
|
|
|
|
class TestContextProbeTiers:
|
|
def test_tiers_descending(self):
|
|
for i in range(len(CONTEXT_PROBE_TIERS) - 1):
|
|
assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1]
|
|
|
|
def test_first_tier_is_128k(self):
|
|
assert CONTEXT_PROBE_TIERS[0] == 128_000
|
|
|
|
def test_last_tier_is_8k(self):
|
|
assert CONTEXT_PROBE_TIERS[-1] == 8_000
|
|
|
|
|
|
class TestGetNextProbeTier:
|
|
def test_from_128k(self):
|
|
assert get_next_probe_tier(128_000) == 64_000
|
|
|
|
def test_from_64k(self):
|
|
assert get_next_probe_tier(64_000) == 32_000
|
|
|
|
def test_from_32k(self):
|
|
assert get_next_probe_tier(32_000) == 16_000
|
|
|
|
def test_from_8k_returns_none(self):
|
|
assert get_next_probe_tier(8_000) is None
|
|
|
|
def test_from_below_min_returns_none(self):
|
|
assert get_next_probe_tier(4_000) is None
|
|
|
|
def test_from_arbitrary_value(self):
|
|
assert get_next_probe_tier(100_000) == 64_000
|
|
|
|
def test_above_max_tier(self):
|
|
"""Value above 128K should return 128K."""
|
|
assert get_next_probe_tier(500_000) == 128_000
|
|
|
|
def test_zero_returns_none(self):
|
|
assert get_next_probe_tier(0) is None
|
|
|
|
|
|
# =========================================================================
|
|
# Error message parsing
|
|
# =========================================================================
|
|
|
|
class TestParseContextLimitFromError:
    """``parse_context_limit_from_error`` extracts a context-window limit from
    provider error text, and returns ``None`` when no plausible limit is present."""

    # --- messages that carry a usable limit --------------------------------

    def test_openai_format(self):
        error_text = "This model's maximum context length is 32768 tokens. However, your messages resulted in 45000 tokens."
        limit = parse_context_limit_from_error(error_text)
        assert limit == 32768

    def test_context_length_exceeded(self):
        error_text = "context_length_exceeded: maximum context length is 131072"
        limit = parse_context_limit_from_error(error_text)
        assert limit == 131072

    def test_context_size_exceeded(self):
        error_text = "Maximum context size 65536 exceeded"
        limit = parse_context_limit_from_error(error_text)
        assert limit == 65536

    def test_ollama_format(self):
        error_text = "Context size has been exceeded. Maximum context size is 32768"
        limit = parse_context_limit_from_error(error_text)
        assert limit == 32768

    def test_anthropic_format(self):
        error_text = "prompt is too long: 250000 tokens > 200000 maximum"
        limit = parse_context_limit_from_error(error_text)
        # The right-hand number is the limit; the left-hand one is the prompt size.
        assert limit == 200000

    def test_lmstudio_format(self):
        error_text = "Error: context window of 4096 tokens exceeded"
        limit = parse_context_limit_from_error(error_text)
        assert limit == 4096

    # --- messages that must NOT yield a limit ------------------------------

    def test_no_limit_in_message(self):
        limit = parse_context_limit_from_error("Something went wrong with the API")
        assert limit is None

    def test_unreasonable_small_number_rejected(self):
        limit = parse_context_limit_from_error("context length is 42 tokens")
        assert limit is None

    def test_minimax_delta_only_message_returns_none(self):
        error_text = "invalid params, context window exceeds limit (2013)"
        limit = parse_context_limit_from_error(error_text)
        assert limit is None

    def test_completely_unrelated_error(self):
        limit = parse_context_limit_from_error("Invalid API key")
        assert limit is None

    def test_empty_string(self):
        limit = parse_context_limit_from_error("")
        assert limit is None

    def test_number_outside_reasonable_range(self):
        """An implausibly huge number (>10M) is not accepted as a limit."""
        error_text = "maximum context length is 99999999999"
        limit = parse_context_limit_from_error(error_text)
        assert limit is None
|
|
|
|
|
|
# =========================================================================
|
|
# Persistent context length cache
|
|
# =========================================================================
|
|
|
|
class TestContextLengthCache:
|
|
def test_save_and_load(self, tmp_path):
    """A saved length can be read back for the same (model, base_url) pair."""
    cache_file = tmp_path / "cache.yaml"
    model, url = "test/model", "http://localhost:8080/v1"
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        save_context_length(model, url, 32768)
        assert get_cached_context_length(model, url) == 32768
|
|
|
|
def test_missing_cache_returns_none(self, tmp_path):
    """With no cache file on disk, every lookup misses."""
    absent_file = tmp_path / "nonexistent.yaml"
    with patch("agent.model_metadata._get_context_cache_path", return_value=absent_file):
        cached = get_cached_context_length("test/model", "http://x")
    assert cached is None
|
|
|
|
def test_multiple_models_cached(self, tmp_path):
    """Entries for different models and URLs coexist in one cache file."""
    cache_file = tmp_path / "cache.yaml"
    entries = [("model-a", "http://a", 64000), ("model-b", "http://b", 128000)]
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        for model, url, length in entries:
            save_context_length(model, url, length)
        for model, url, length in entries:
            assert get_cached_context_length(model, url) == length
|
|
|
|
def test_same_model_different_providers(self, tmp_path):
    """The cache key includes the base URL, so one model name can hold two values."""
    cache_file = tmp_path / "cache.yaml"
    per_provider = [
        ("http://local:8080", 32768),
        ("https://openrouter.ai/api/v1", 131072),
    ]
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        for url, length in per_provider:
            save_context_length("llama-3", url, length)
        for url, length in per_provider:
            assert get_cached_context_length("llama-3", url) == length
|
|
|
|
def test_idempotent_save(self, tmp_path):
    """Saving the identical (model, url, length) twice leaves a single entry."""
    cache_file = tmp_path / "cache.yaml"
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        for _ in range(2):
            save_context_length("model", "http://x", 32768)
    with open(cache_file) as f:
        data = yaml.safe_load(f)
    assert len(data["context_lengths"]) == 1
|
|
|
|
def test_update_existing_value(self, tmp_path):
    """A later save for the same (model, url) key replaces the earlier value."""
    cache_file = tmp_path / "cache.yaml"
    key = ("model", "http://x")
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        save_context_length(*key, 128000)
        save_context_length(*key, 64000)
        assert get_cached_context_length(*key) == 64000
|
|
|
|
def test_corrupted_yaml_returns_empty(self, tmp_path):
    """An unparseable cache file does not raise; the lookup returns None."""
    cache_file = tmp_path / "cache.yaml"
    cache_file.write_text("{{{{not valid yaml: [[[")
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        cached = get_cached_context_length("model", "http://x")
    assert cached is None
|
|
|
|
def test_wrong_structure_returns_none(self, tmp_path):
    """Valid YAML that is not the expected mapping (here a bare string) misses."""
    cache_file = tmp_path / "cache.yaml"
    cache_file.write_text("just_a_string\n")
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        cached = get_cached_context_length("model", "http://x")
    assert cached is None
|
|
|
|
@patch("agent.model_metadata.fetch_model_metadata")
def test_cached_value_takes_priority(self, mock_fetch, tmp_path):
    """get_model_context_length returns the persisted value for a model that
    the (mocked, empty) remote metadata does not know."""
    mock_fetch.return_value = {}
    cache_file = tmp_path / "cache.yaml"
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        save_context_length("unknown/model", "http://local", 65536)
        resolved = get_model_context_length("unknown/model", base_url="http://local")
    assert resolved == 65536
|
|
|
|
def test_special_chars_in_model_name(self, tmp_path):
    """Colons and slashes in the model name and URL survive a save/load round trip."""
    cache_file = tmp_path / "cache.yaml"
    model, url = "anthropic/claude-3.5-sonnet:beta", "https://api.example.com/v1"
    with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
        save_context_length(model, url, 200000)
        assert get_cached_context_length(model, url) == 200000
|