Mirror of https://github.com/NousResearch/hermes-agent.git, synced 2026-04-25 00:51:20 +00:00
feat: use endpoint metadata for custom model context and pricing (#1906)
* perf: cache base_url.lower() via property, consolidate triple load_config(), hoist set constant

  run_agent.py:
  - Add a base_url property that auto-caches _base_url_lower on every assignment, eliminating 12+ redundant .lower() calls per API cycle across __init__, _build_api_kwargs, _supports_reasoning_extra_body, and the main conversation loop
  - Consolidate three separate load_config() disk reads in __init__ (memory, skills, compression) into a single call, reusing the result dict for all three config sections

  model_tools.py:
  - Hoist the _READ_SEARCH_TOOLS set to module level (it was rebuilt inside handle_function_call on every tool invocation)

* Use endpoint metadata for custom model context and pricing

---------

Co-authored-by: kshitij <82637225+kshitijk4poor@users.noreply.github.com>
parent 11f029c311
commit a2440f72f6

7 changed files with 375 additions and 49 deletions
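As a rough illustration of the base_url property caching described in the first commit-message bullet, here is a minimal standalone sketch (class and method names hypothetical; the real implementation is in the run_agent.py hunk below):

    class Client:
        """Cache the lowercased base URL once per assignment instead of
        calling .lower() on every hot-path membership check."""

        def __init__(self, base_url: str = ""):
            self.base_url = base_url  # goes through the setter below

        @property
        def base_url(self) -> str:
            return self._base_url

        @base_url.setter
        def base_url(self, value: str) -> None:
            self._base_url = value
            self._base_url_lower = value.lower() if value else ""

        def is_openrouter(self) -> bool:
            # Hot path: reads the cached string, no per-call .lower() allocation.
            return "openrouter" in self._base_url_lower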
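The feature half of the commit queries the active endpoint's own /models listing instead of guessing context length and pricing from global model-name defaults. A hedged sketch of that idea (function name and URLs hypothetical; the real logic, with per-URL caching and /v1 fallbacks, is in fetch_endpoint_model_metadata below):

    import requests

    def fetch_models(base_url: str, api_key: str = "") -> dict:
        # Query an OpenAI-compatible /models endpoint and index entries by id.
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
        resp = requests.get(base_url.rstrip("/") + "/models", headers=headers, timeout=10)
        resp.raise_for_status()
        return {m["id"]: m for m in resp.json().get("data", []) if isinstance(m, dict)}

    # Hypothetical usage: the endpoint reports its own context window,
    # so the agent no longer borrows a fuzzy default from an unrelated provider.
    # fetch_models("https://llm.example.com/v1", "sk-...")["my-model"].get("context_length")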
@@ -45,16 +45,18 @@ class ContextCompressor:
         quiet_mode: bool = False,
         summary_model_override: str = None,
         base_url: str = "",
+        api_key: str = "",
     ):
         self.model = model
         self.base_url = base_url
+        self.api_key = api_key
         self.threshold_percent = threshold_percent
         self.protect_first_n = protect_first_n
         self.protect_last_n = protect_last_n
         self.summary_target_tokens = summary_target_tokens
         self.quiet_mode = quiet_mode

-        self.context_length = get_model_context_length(model, base_url=base_url)
+        self.context_length = get_model_context_length(model, base_url=base_url, api_key=api_key)
         self.threshold_tokens = int(self.context_length * threshold_percent)
         self.compression_count = 0
         self._context_probed = False  # True after a step-down from context error
@@ -10,6 +10,7 @@ import re
 import time
 from pathlib import Path
 from typing import Any, Dict, List, Optional
+from urllib.parse import urlparse

 import requests
 import yaml
@@ -21,6 +22,9 @@ logger = logging.getLogger(__name__)
 _model_metadata_cache: Dict[str, Dict[str, Any]] = {}
 _model_metadata_cache_time: float = 0
 _MODEL_CACHE_TTL = 3600
+_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
+_endpoint_model_metadata_cache_time: Dict[str, float] = {}
+_ENDPOINT_MODEL_CACHE_TTL = 300

 # Descending tiers for context length probing when the model is unknown.
 # We start high and step down on context-length errors until one works.
@@ -123,6 +127,128 @@ DEFAULT_CONTEXT_LENGTHS = {
     "qwen-vl-max": 32768,
 }

+_CONTEXT_LENGTH_KEYS = (
+    "context_length",
+    "context_window",
+    "max_context_length",
+    "max_position_embeddings",
+    "max_model_len",
+    "max_input_tokens",
+    "max_sequence_length",
+    "max_seq_len",
+)
+
+_MAX_COMPLETION_KEYS = (
+    "max_completion_tokens",
+    "max_output_tokens",
+    "max_tokens",
+)
+
+
+def _normalize_base_url(base_url: str) -> str:
+    return (base_url or "").strip().rstrip("/")
+
+
+def _is_openrouter_base_url(base_url: str) -> bool:
+    return "openrouter.ai" in _normalize_base_url(base_url).lower()
+
+
+def _is_custom_endpoint(base_url: str) -> bool:
+    normalized = _normalize_base_url(base_url)
+    return bool(normalized) and not _is_openrouter_base_url(normalized)
+
+
+def _is_known_provider_base_url(base_url: str) -> bool:
+    normalized = _normalize_base_url(base_url)
+    if not normalized:
+        return False
+    parsed = urlparse(normalized if "://" in normalized else f"https://{normalized}")
+    host = parsed.netloc.lower() or parsed.path.lower()
+    known_hosts = (
+        "api.openai.com",
+        "chatgpt.com",
+        "api.anthropic.com",
+        "api.z.ai",
+        "api.moonshot.ai",
+        "api.kimi.com",
+        "api.minimax",
+    )
+    return any(known_host in host for known_host in known_hosts)
+
+
+def _iter_nested_dicts(value: Any):
+    if isinstance(value, dict):
+        yield value
+        for nested in value.values():
+            yield from _iter_nested_dicts(nested)
+    elif isinstance(value, list):
+        for item in value:
+            yield from _iter_nested_dicts(item)
+
+
+def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
+    try:
+        if isinstance(value, bool):
+            return None
+        if isinstance(value, str):
+            value = value.strip().replace(",", "")
+        result = int(value)
+    except (TypeError, ValueError):
+        return None
+    if minimum <= result <= maximum:
+        return result
+    return None
+
+
+def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
+    keyset = {key.lower() for key in keys}
+    for mapping in _iter_nested_dicts(payload):
+        for key, value in mapping.items():
+            if str(key).lower() not in keyset:
+                continue
+            coerced = _coerce_reasonable_int(value)
+            if coerced is not None:
+                return coerced
+    return None
+
+
+def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
+    return _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)
+
+
+def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
+    return _extract_first_int(payload, _MAX_COMPLETION_KEYS)
+
+
+def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
+    alias_map = {
+        "prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
+        "completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
+        "request": ("request", "request_cost"),
+        "cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
+        "cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
+    }
+    for mapping in _iter_nested_dicts(payload):
+        normalized = {str(key).lower(): value for key, value in mapping.items()}
+        if not any(any(alias in normalized for alias in aliases) for aliases in alias_map.values()):
+            continue
+        pricing: Dict[str, Any] = {}
+        for target, aliases in alias_map.items():
+            for alias in aliases:
+                if alias in normalized and normalized[alias] not in (None, ""):
+                    pricing[target] = normalized[alias]
+                    break
+        if pricing:
+            return pricing
+    return {}
+
+
+def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
+    cache[model_id] = entry
+    if "/" in model_id:
+        bare_model = model_id.split("/", 1)[1]
+        cache.setdefault(bare_model, entry)
+

 def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
     """Fetch model metadata from OpenRouter (cached for 1 hour)."""
@@ -139,15 +265,16 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
     cache = {}
     for model in data.get("data", []):
         model_id = model.get("id", "")
-        cache[model_id] = {
+        entry = {
             "context_length": model.get("context_length", 128000),
             "max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
             "name": model.get("name", model_id),
             "pricing": model.get("pricing", {}),
         }
+        _add_model_aliases(cache, model_id, entry)
         canonical = model.get("canonical_slug", "")
         if canonical and canonical != model_id:
-            cache[canonical] = cache[model_id]
+            _add_model_aliases(cache, canonical, entry)

     _model_metadata_cache = cache
     _model_metadata_cache_time = time.time()
@@ -159,6 +286,75 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
     return _model_metadata_cache or {}


+def fetch_endpoint_model_metadata(
+    base_url: str,
+    api_key: str = "",
+    force_refresh: bool = False,
+) -> Dict[str, Dict[str, Any]]:
+    """Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.
+
+    This is used for explicit custom endpoints where hardcoded global model-name
+    defaults are unreliable. Results are cached in memory per base URL.
+    """
+    normalized = _normalize_base_url(base_url)
+    if not normalized or _is_openrouter_base_url(normalized):
+        return {}
+
+    if not force_refresh:
+        cached = _endpoint_model_metadata_cache.get(normalized)
+        cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
+        if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
+            return cached
+
+    candidates = [normalized]
+    if normalized.endswith("/v1"):
+        alternate = normalized[:-3].rstrip("/")
+    else:
+        alternate = normalized + "/v1"
+    if alternate and alternate not in candidates:
+        candidates.append(alternate)
+
+    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
+    last_error: Optional[Exception] = None
+
+    for candidate in candidates:
+        url = candidate.rstrip("/") + "/models"
+        try:
+            response = requests.get(url, headers=headers, timeout=10)
+            response.raise_for_status()
+            payload = response.json()
+            cache: Dict[str, Dict[str, Any]] = {}
+            for model in payload.get("data", []):
+                if not isinstance(model, dict):
+                    continue
+                model_id = model.get("id")
+                if not model_id:
+                    continue
+                entry: Dict[str, Any] = {"name": model.get("name", model_id)}
+                context_length = _extract_context_length(model)
+                if context_length is not None:
+                    entry["context_length"] = context_length
+                max_completion_tokens = _extract_max_completion_tokens(model)
+                if max_completion_tokens is not None:
+                    entry["max_completion_tokens"] = max_completion_tokens
+                pricing = _extract_pricing(model)
+                if pricing:
+                    entry["pricing"] = pricing
+                _add_model_aliases(cache, model_id, entry)
+
+            _endpoint_model_metadata_cache[normalized] = cache
+            _endpoint_model_metadata_cache_time[normalized] = time.time()
+            return cache
+        except Exception as exc:
+            last_error = exc
+
+    if last_error:
+        logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
+    _endpoint_model_metadata_cache[normalized] = {}
+    _endpoint_model_metadata_cache_time[normalized] = time.time()
+    return {}
+
+
 def _get_context_cache_path() -> Path:
     """Return path to the persistent context length cache file."""
     hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
@@ -243,14 +439,15 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
     return None


-def get_model_context_length(model: str, base_url: str = "") -> int:
+def get_model_context_length(model: str, base_url: str = "", api_key: str = "") -> int:
     """Get the context length for a model.

     Resolution order:
     1. Persistent cache (previously discovered via probing)
-    2. OpenRouter API metadata
-    3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match)
-    4. First probe tier (2M) — will be narrowed on first context error
+    2. Active endpoint metadata (/models for explicit custom endpoints)
+    3. OpenRouter API metadata
+    4. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
+    5. First probe tier (2M) — will be narrowed on first context error
     """
     # 1. Check persistent cache (model+provider)
     if base_url:
@@ -258,19 +455,31 @@ def get_model_context_length(model: str, base_url: str = "") -> int:
         if cached is not None:
             return cached

-    # 2. OpenRouter API metadata
+    # 2. Active endpoint metadata for explicit custom routes
+    if _is_custom_endpoint(base_url):
+        endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
+        if model in endpoint_metadata:
+            context_length = endpoint_metadata[model].get("context_length")
+            if isinstance(context_length, int):
+                return context_length
+        if not _is_known_provider_base_url(base_url):
+            # Explicit third-party endpoints should not borrow fuzzy global
+            # defaults from unrelated providers with similarly named models.
+            return CONTEXT_PROBE_TIERS[0]
+
+    # 3. OpenRouter API metadata
     metadata = fetch_model_metadata()
     if model in metadata:
         return metadata[model].get("context_length", 128000)

-    # 3. Hardcoded defaults (fuzzy match — longest key first for specificity)
+    # 4. Hardcoded defaults (fuzzy match — longest key first for specificity)
     for default_model, length in sorted(
         DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
     ):
         if default_model in model or model in default_model:
             return length

-    # 4. Unknown model — start at highest probe tier
+    # 5. Unknown model — start at highest probe tier
     return CONTEXT_PROBE_TIERS[0]
@@ -5,7 +5,7 @@ from datetime import datetime, timezone
 from decimal import Decimal
 from typing import Any, Dict, Literal, Optional

-from agent.model_metadata import fetch_model_metadata
+from agent.model_metadata import fetch_endpoint_model_metadata, fetch_model_metadata

 DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
@@ -335,8 +335,21 @@ def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]


 def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
-    metadata = fetch_model_metadata()
-    model_id = route.model
+    return _pricing_entry_from_metadata(
+        fetch_model_metadata(),
+        route.model,
+        source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
+        pricing_version="openrouter-models-api",
+    )
+
+
+def _pricing_entry_from_metadata(
+    metadata: Dict[str, Dict[str, Any]],
+    model_id: str,
+    *,
+    source_url: str,
+    pricing_version: str,
+) -> Optional[PricingEntry]:
     if model_id not in metadata:
         return None
     pricing = metadata[model_id].get("pricing") or {}
@@ -355,6 +368,7 @@ def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
     )
     if prompt is None and completion is None and request is None:
         return None
+
     def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]:
         if value is None:
             return None
@@ -367,8 +381,8 @@ def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]:
         cache_write_cost_per_million=_per_token_to_per_million(cache_write),
         request_cost=request,
         source="provider_models_api",
-        source_url="https://openrouter.ai/docs/api/api-reference/models/get-models",
-        pricing_version="openrouter-models-api",
+        source_url=source_url,
+        pricing_version=pricing_version,
         fetched_at=_UTC_NOW(),
     )
@@ -377,6 +391,7 @@ def get_pricing_entry(
     model_name: str,
     provider: Optional[str] = None,
     base_url: Optional[str] = None,
+    api_key: Optional[str] = None,
 ) -> Optional[PricingEntry]:
     route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
     if route.billing_mode == "subscription_included":
@@ -390,6 +405,15 @@ def get_pricing_entry(
     )
     if route.provider == "openrouter":
         return _openrouter_pricing_entry(route)
+    if route.base_url:
+        entry = _pricing_entry_from_metadata(
+            fetch_endpoint_model_metadata(route.base_url, api_key=api_key or ""),
+            route.model,
+            source_url=f"{route.base_url.rstrip('/')}/models",
+            pricing_version="openai-compatible-models-api",
+        )
+        if entry:
+            return entry
     return _lookup_official_docs_pricing(route)
@@ -460,6 +484,7 @@ def estimate_usage_cost(
     *,
     provider: Optional[str] = None,
     base_url: Optional[str] = None,
+    api_key: Optional[str] = None,
 ) -> CostResult:
     route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
     if route.billing_mode == "subscription_included":
@@ -471,7 +496,7 @@ def estimate_usage_cost(
         pricing_version="included-route",
     )

-    entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
+    entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
     if not entry:
         return CostResult(amount_usd=None, status="unknown", source="none", label="n/a")
@@ -536,6 +561,7 @@ def has_known_pricing(
     model_name: str,
     provider: Optional[str] = None,
    base_url: Optional[str] = None,
+    api_key: Optional[str] = None,
 ) -> bool:
     """Check whether we have pricing data for this model+route.
@@ -545,7 +571,7 @@ def has_known_pricing(
     route = resolve_billing_route(model_name, provider=provider, base_url=base_url)
     if route.billing_mode == "subscription_included":
         return True
-    entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
+    entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
     return entry is not None
@@ -553,13 +579,14 @@ def get_pricing(
     model_name: str,
     provider: Optional[str] = None,
     base_url: Optional[str] = None,
+    api_key: Optional[str] = None,
 ) -> Dict[str, float]:
     """Backward-compatible thin wrapper for legacy callers.

     Returns only non-cache input/output fields when a pricing entry exists.
     Unknown routes return zeroes.
     """
-    entry = get_pricing_entry(model_name, provider=provider, base_url=base_url)
+    entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key)
     if not entry:
         return {"input": 0.0, "output": 0.0}
     return {
@@ -575,6 +602,7 @@ def estimate_cost_usd(
     *,
     provider: Optional[str] = None,
     base_url: Optional[str] = None,
+    api_key: Optional[str] = None,
 ) -> float:
     """Backward-compatible helper for legacy callers.
@@ -586,6 +614,7 @@ def estimate_cost_usd(
         CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens),
         provider=provider,
         base_url=base_url,
+        api_key=api_key,
     )
     return float(result.amount_usd or _ZERO)
@@ -276,6 +276,7 @@ def get_tool_definitions(
 # The registry still holds their schemas; dispatch just returns a stub error
 # so if something slips through, the LLM sees a sensible message.
 _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
+_READ_SEARCH_TOOLS = {"read_file", "search_files"}


 def handle_function_call(
@@ -305,7 +306,6 @@ def handle_function_call(
     """
     # Notify the read-loop tracker when a non-read/search tool runs,
     # so the *consecutive* counter resets (reads after other work are fine).
-    _READ_SEARCH_TOOLS = {"read_file", "search_files"}
     if function_name not in _READ_SEARCH_TOOLS:
         try:
             from tools.file_tools import notify_other_tool_call
run_agent.py (73 changes)
@@ -263,11 +263,20 @@ def _inject_honcho_turn_context(content, turn_context: str):
 class AIAgent:
     """
     AI Agent with tool calling capabilities.

     This class manages the conversation flow, tool execution, and response handling
     for AI models that support function calling.
     """

+    @property
+    def base_url(self) -> str:
+        return self._base_url
+
+    @base_url.setter
+    def base_url(self, value: str) -> None:
+        self._base_url = value
+        self._base_url_lower = value.lower() if value else ""
+
     def __init__(
         self,
         base_url: str = None,
@@ -383,10 +392,10 @@ class AIAgent:
             self.api_mode = api_mode
         elif self.provider == "openai-codex":
             self.api_mode = "codex_responses"
-        elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self.base_url.lower():
+        elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self._base_url_lower:
             self.api_mode = "codex_responses"
             self.provider = "openai-codex"
-        elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self.base_url.lower()):
+        elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self._base_url_lower):
             self.api_mode = "anthropic_messages"
             self.provider = "anthropic"
         else:
@@ -395,7 +404,7 @@ class AIAgent:
         # Pre-warm OpenRouter model metadata cache in a background thread.
         # fetch_model_metadata() is cached for 1 hour; this avoids a blocking
         # HTTP request on the first API response when pricing is estimated.
-        if self.provider == "openrouter" or "openrouter" in self.base_url.lower():
+        if self.provider == "openrouter" or "openrouter" in self._base_url_lower:
             threading.Thread(
                 target=lambda: fetch_model_metadata(),
                 daemon=True,
@@ -439,7 +448,7 @@ class AIAgent:
         # Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
         # Reduces input costs by ~75% on multi-turn conversations by caching the
         # conversation prefix. Uses system_and_3 strategy (4 breakpoints).
-        is_openrouter = "openrouter" in self.base_url.lower()
+        is_openrouter = "openrouter" in self._base_url_lower
         is_claude = "claude" in self.model.lower()
         is_native_anthropic = self.api_mode == "anthropic_messages"
         self._use_prompt_caching = (is_openrouter and is_claude) or is_native_anthropic
@@ -555,6 +564,7 @@ class AIAgent:
         if self.api_mode == "anthropic_messages":
             from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
             effective_key = api_key or resolve_anthropic_token() or ""
+            self.api_key = effective_key
             self._anthropic_api_key = effective_key
             self._anthropic_base_url = base_url
             from agent.anthropic_adapter import _is_oauth_token as _is_oat
@@ -609,6 +619,7 @@ class AIAgent:
         }

         self._client_kwargs = client_kwargs  # stored for rebuilding after interrupt
+        self.api_key = client_kwargs.get("api_key", "")
         try:
             self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
             if not self.quiet_mode:
@@ -732,6 +743,13 @@ class AIAgent:
         from tools.todo_tool import TodoStore
         self._todo_store = TodoStore()

+        # Load config once for memory, skills, and compression sections
+        try:
+            from hermes_cli.config import load_config as _load_agent_config
+            _agent_cfg = _load_agent_config()
+        except Exception:
+            _agent_cfg = {}
+
         # Persistent memory (MEMORY.md + USER.md) -- loaded from disk
         self._memory_store = None
         self._memory_enabled = False
@@ -742,8 +760,7 @@ class AIAgent:
         self._iters_since_skill = 0
         if not skip_memory:
             try:
-                from hermes_cli.config import load_config as _load_mem_config
-                mem_config = _load_mem_config().get("memory", {})
+                mem_config = _agent_cfg.get("memory", {})
                 self._memory_enabled = mem_config.get("memory_enabled", False)
                 self._user_profile_enabled = mem_config.get("user_profile_enabled", False)
                 self._memory_nudge_interval = int(mem_config.get("nudge_interval", 10))
@@ -831,21 +848,16 @@ class AIAgent:
         # Skills config: nudge interval for skill creation reminders
         self._skill_nudge_interval = 10
         try:
-            from hermes_cli.config import load_config as _load_skills_config
-            skills_config = _load_skills_config().get("skills", {})
+            skills_config = _agent_cfg.get("skills", {})
             self._skill_nudge_interval = int(skills_config.get("creation_nudge_interval", 15))
         except Exception:
             pass

         # Initialize context compressor for automatic context management
         # Compresses conversation when approaching model's context limit
         # Configuration via config.yaml (compression section)
-        try:
-            from hermes_cli.config import load_config as _load_compression_config
-            _compression_cfg = _load_compression_config().get("compression", {})
-            if not isinstance(_compression_cfg, dict):
-                _compression_cfg = {}
-        except ImportError:
+        _compression_cfg = _agent_cfg.get("compression", {})
+        if not isinstance(_compression_cfg, dict):
             _compression_cfg = {}
         compression_threshold = float(_compression_cfg.get("threshold", 0.50))
         compression_enabled = str(_compression_cfg.get("enabled", True)).lower() in ("true", "1", "yes")
@@ -860,6 +872,7 @@ class AIAgent:
             summary_model_override=compression_summary_model,
             quiet_mode=self.quiet_mode,
             base_url=self.base_url,
+            api_key=getattr(self, "api_key", ""),
         )
         self.compression_enabled = compression_enabled
         self._user_turn_count = 0
@@ -915,8 +928,8 @@ class AIAgent:
         OpenAI models use 'max_tokens'.
         """
         _is_direct_openai = (
-            "api.openai.com" in self.base_url.lower()
-            and "openrouter" not in self.base_url.lower()
+            "api.openai.com" in self._base_url_lower
+            and "openrouter" not in self._base_url_lower
         )
         if _is_direct_openai:
             return {"max_completion_tokens": value}
@@ -3643,7 +3656,7 @@ class AIAgent:

         extra_body = {}

-        _is_openrouter = "openrouter" in self.base_url.lower()
+        _is_openrouter = "openrouter" in self._base_url_lower

         # Provider preferences (only, ignore, order, sort) are OpenRouter-
         # specific. Only send to OpenRouter-compatible endpoints.
@@ -3651,7 +3664,7 @@ class AIAgent:
         # for _is_nous when their backend is updated.
         if provider_preferences and _is_openrouter:
             extra_body["provider"] = provider_preferences
-        _is_nous = "nousresearch" in self.base_url.lower()
+        _is_nous = "nousresearch" in self._base_url_lower

         if self._supports_reasoning_extra_body():
             if self.reasoning_config is not None:
@@ -3684,14 +3697,13 @@ class AIAgent:
         Some providers/routes reject `reasoning` with 400s, so gate it to
         known reasoning-capable model families and direct Nous Portal.
         """
-        base_url = (self.base_url or "").lower()
-        if "nousresearch" in base_url:
+        if "nousresearch" in self._base_url_lower:
             return True
-        if "ai-gateway.vercel.sh" in base_url:
+        if "ai-gateway.vercel.sh" in self._base_url_lower:
             return True
-        if "openrouter" not in base_url:
+        if "openrouter" not in self._base_url_lower:
             return False
-        if "api.mistral.ai" in base_url:
+        if "api.mistral.ai" in self._base_url_lower:
             return False

         model = (self.model or "").lower()
@@ -3877,7 +3889,7 @@ class AIAgent:

             try:
                 # Build API messages for the flush call
-                _is_strict_api = "api.mistral.ai" in self.base_url.lower()
+                _is_strict_api = "api.mistral.ai" in self._base_url_lower
                 api_messages = []
                 for msg in messages:
                     api_msg = msg.copy()
@@ -4653,7 +4665,7 @@ class AIAgent:
         try:
             # Build API messages, stripping internal-only fields
             # (finish_reason, reasoning) that strict APIs like Mistral reject with 422
-            _is_strict_api = "api.mistral.ai" in self.base_url.lower()
+            _is_strict_api = "api.mistral.ai" in self._base_url_lower
             api_messages = []
             for msg in messages:
                 api_msg = msg.copy()
@@ -4674,7 +4686,7 @@ class AIAgent:
                 api_messages.insert(sys_offset + idx, pfm.copy())

             summary_extra_body = {}
-            _is_nous = "nousresearch" in self.base_url.lower()
+            _is_nous = "nousresearch" in self._base_url_lower
             if self._supports_reasoning_extra_body():
                 if self.reasoning_config is not None:
                     summary_extra_body["reasoning"] = self.reasoning_config
@@ -5092,7 +5104,7 @@ class AIAgent:
             # strict providers like Mistral that reject unknown fields with 422.
             # Uses new dicts so the internal messages list retains the fields
             # for Codex Responses compatibility.
-            if "api.mistral.ai" in self.base_url.lower():
+            if "api.mistral.ai" in self._base_url_lower:
                 self._sanitize_tool_calls_for_strict_api(api_msg)
             # Keep 'reasoning_details' - OpenRouter uses this for multi-turn reasoning context
             # The signature field helps maintain reasoning continuity
@@ -5464,6 +5476,7 @@ class AIAgent:
                 canonical_usage,
                 provider=self.provider,
                 base_url=self.base_url,
+                api_key=getattr(self, "api_key", ""),
             )
             if cost_result.amount_usd is not None:
                 self.session_estimated_cost_usd += float(cost_result.amount_usd)
@@ -188,6 +188,36 @@ class TestGetModelContextLength:
         result = get_model_context_length("custom/model")
         assert result == CONTEXT_PROBE_TIERS[0]

+    @patch("agent.model_metadata.fetch_model_metadata")
+    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
+    def test_custom_endpoint_metadata_beats_fuzzy_default(self, mock_endpoint_fetch, mock_fetch):
+        mock_fetch.return_value = {}
+        mock_endpoint_fetch.return_value = {
+            "zai-org/GLM-5-TEE": {"context_length": 65536}
+        }
+
+        result = get_model_context_length(
+            "zai-org/GLM-5-TEE",
+            base_url="https://llm.chutes.ai/v1",
+            api_key="test-key",
+        )
+
+        assert result == 65536
+
+    @patch("agent.model_metadata.fetch_model_metadata")
+    @patch("agent.model_metadata.fetch_endpoint_model_metadata")
+    def test_custom_endpoint_without_metadata_skips_name_based_default(self, mock_endpoint_fetch, mock_fetch):
+        mock_fetch.return_value = {}
+        mock_endpoint_fetch.return_value = {}
+
+        result = get_model_context_length(
+            "zai-org/GLM-5-TEE",
+            base_url="https://llm.chutes.ai/v1",
+            api_key="test-key",
+        )
+
+        assert result == CONTEXT_PROBE_TIERS[0]
+

 # =========================================================================
 # fetch_model_metadata — caching, TTL, slugs, failures
@@ -258,6 +288,25 @@ class TestFetchModelMetadata:
         assert "anthropic/claude-3.5-sonnet" in result
         assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000

+    @patch("agent.model_metadata.requests.get")
+    def test_provider_prefixed_models_get_bare_aliases(self, mock_get):
+        self._reset_cache()
+        mock_response = MagicMock()
+        mock_response.json.return_value = {
+            "data": [{
+                "id": "provider/test-model",
+                "context_length": 123456,
+                "name": "Provider: Test Model",
+            }]
+        }
+        mock_response.raise_for_status = MagicMock()
+        mock_get.return_value = mock_response
+
+        result = fetch_model_metadata(force_refresh=True)
+
+        assert result["provider/test-model"]["context_length"] == 123456
+        assert result["test-model"]["context_length"] == 123456
+
     @patch("agent.model_metadata.requests.get")
     def test_ttl_expiry_triggers_refetch(self, mock_get):
         """Cache expires after _MODEL_CACHE_TTL seconds."""
@@ -99,3 +99,27 @@ def test_estimate_usage_cost_refuses_cache_pricing_without_official_cache_rate(m
     )

     assert result.status == "unknown"
+
+
+def test_custom_endpoint_models_api_pricing_is_supported(monkeypatch):
+    monkeypatch.setattr(
+        "agent.usage_pricing.fetch_endpoint_model_metadata",
+        lambda base_url, api_key=None: {
+            "zai-org/GLM-5-TEE": {
+                "pricing": {
+                    "prompt": "0.0000005",
+                    "completion": "0.000002",
+                }
+            }
+        },
+    )
+
+    entry = get_pricing_entry(
+        "zai-org/GLM-5-TEE",
+        provider="custom",
+        base_url="https://llm.chutes.ai/v1",
+        api_key="test-key",
+    )
+
+    assert float(entry.input_cost_per_million) == 0.5
+    assert float(entry.output_cost_per_million) == 2.0