diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 8ff43da507..11b5c5b80c 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -45,16 +45,18 @@ class ContextCompressor: quiet_mode: bool = False, summary_model_override: str = None, base_url: str = "", + api_key: str = "", ): self.model = model self.base_url = base_url + self.api_key = api_key self.threshold_percent = threshold_percent self.protect_first_n = protect_first_n self.protect_last_n = protect_last_n self.summary_target_tokens = summary_target_tokens self.quiet_mode = quiet_mode - self.context_length = get_model_context_length(model, base_url=base_url) + self.context_length = get_model_context_length(model, base_url=base_url, api_key=api_key) self.threshold_tokens = int(self.context_length * threshold_percent) self.compression_count = 0 self._context_probed = False # True after a step-down from context error diff --git a/agent/model_metadata.py b/agent/model_metadata.py index fb0d38466e..8283e8d32f 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -10,6 +10,7 @@ import re import time from pathlib import Path from typing import Any, Dict, List, Optional +from urllib.parse import urlparse import requests import yaml @@ -21,6 +22,9 @@ logger = logging.getLogger(__name__) _model_metadata_cache: Dict[str, Dict[str, Any]] = {} _model_metadata_cache_time: float = 0 _MODEL_CACHE_TTL = 3600 +_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {} +_endpoint_model_metadata_cache_time: Dict[str, float] = {} +_ENDPOINT_MODEL_CACHE_TTL = 300 # Descending tiers for context length probing when the model is unknown. # We start high and step down on context-length errors until one works. @@ -123,6 +127,128 @@ DEFAULT_CONTEXT_LENGTHS = { "qwen-vl-max": 32768, } +_CONTEXT_LENGTH_KEYS = ( + "context_length", + "context_window", + "max_context_length", + "max_position_embeddings", + "max_model_len", + "max_input_tokens", + "max_sequence_length", + "max_seq_len", +) + +_MAX_COMPLETION_KEYS = ( + "max_completion_tokens", + "max_output_tokens", + "max_tokens", +) + + +def _normalize_base_url(base_url: str) -> str: + return (base_url or "").strip().rstrip("/") + + +def _is_openrouter_base_url(base_url: str) -> bool: + return "openrouter.ai" in _normalize_base_url(base_url).lower() + + +def _is_custom_endpoint(base_url: str) -> bool: + normalized = _normalize_base_url(base_url) + return bool(normalized) and not _is_openrouter_base_url(normalized) + + +def _is_known_provider_base_url(base_url: str) -> bool: + normalized = _normalize_base_url(base_url) + if not normalized: + return False + parsed = urlparse(normalized if "://" in normalized else f"https://{normalized}") + host = parsed.netloc.lower() or parsed.path.lower() + known_hosts = ( + "api.openai.com", + "chatgpt.com", + "api.anthropic.com", + "api.z.ai", + "api.moonshot.ai", + "api.kimi.com", + "api.minimax", + ) + return any(known_host in host for known_host in known_hosts) + + +def _iter_nested_dicts(value: Any): + if isinstance(value, dict): + yield value + for nested in value.values(): + yield from _iter_nested_dicts(nested) + elif isinstance(value, list): + for item in value: + yield from _iter_nested_dicts(item) + + +def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]: + try: + if isinstance(value, bool): + return None + if isinstance(value, str): + value = value.strip().replace(",", "") + result = int(value) + except (TypeError, ValueError): + return None 
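+    # Only values inside the plausible token-count window (by default
+    # 1_024..10_000_000) are trusted; anything else is treated as missing.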
+ if minimum <= result <= maximum: + return result + return None + + +def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]: + keyset = {key.lower() for key in keys} + for mapping in _iter_nested_dicts(payload): + for key, value in mapping.items(): + if str(key).lower() not in keyset: + continue + coerced = _coerce_reasonable_int(value) + if coerced is not None: + return coerced + return None + + +def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]: + return _extract_first_int(payload, _CONTEXT_LENGTH_KEYS) + + +def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]: + return _extract_first_int(payload, _MAX_COMPLETION_KEYS) + + +def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]: + alias_map = { + "prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"), + "completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"), + "request": ("request", "request_cost"), + "cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"), + "cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"), + } + for mapping in _iter_nested_dicts(payload): + normalized = {str(key).lower(): value for key, value in mapping.items()} + if not any(any(alias in normalized for alias in aliases) for aliases in alias_map.values()): + continue + pricing: Dict[str, Any] = {} + for target, aliases in alias_map.items(): + for alias in aliases: + if alias in normalized and normalized[alias] not in (None, ""): + pricing[target] = normalized[alias] + break + if pricing: + return pricing + return {} + + +def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None: + cache[model_id] = entry + if "/" in model_id: + bare_model = model_id.split("/", 1)[1] + cache.setdefault(bare_model, entry) + def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]: """Fetch model metadata from OpenRouter (cached for 1 hour).""" @@ -139,15 +265,16 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any cache = {} for model in data.get("data", []): model_id = model.get("id", "") - cache[model_id] = { + entry = { "context_length": model.get("context_length", 128000), "max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096), "name": model.get("name", model_id), "pricing": model.get("pricing", {}), } + _add_model_aliases(cache, model_id, entry) canonical = model.get("canonical_slug", "") if canonical and canonical != model_id: - cache[canonical] = cache[model_id] + _add_model_aliases(cache, canonical, entry) _model_metadata_cache = cache _model_metadata_cache_time = time.time() @@ -159,6 +286,75 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any return _model_metadata_cache or {} +def fetch_endpoint_model_metadata( + base_url: str, + api_key: str = "", + force_refresh: bool = False, +) -> Dict[str, Dict[str, Any]]: + """Fetch model metadata from an OpenAI-compatible ``/models`` endpoint. + + This is used for explicit custom endpoints where hardcoded global model-name + defaults are unreliable. Results are cached in memory per base URL. 
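+
+    Both the configured base URL and a ``/v1``-toggled variant are tried;
+    failed fetches are cached as an empty mapping for the TTL so unreachable
+    endpoints are not re-probed on every lookup.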
+ """ + normalized = _normalize_base_url(base_url) + if not normalized or _is_openrouter_base_url(normalized): + return {} + + if not force_refresh: + cached = _endpoint_model_metadata_cache.get(normalized) + cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0) + if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL: + return cached + + candidates = [normalized] + if normalized.endswith("/v1"): + alternate = normalized[:-3].rstrip("/") + else: + alternate = normalized + "/v1" + if alternate and alternate not in candidates: + candidates.append(alternate) + + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + last_error: Optional[Exception] = None + + for candidate in candidates: + url = candidate.rstrip("/") + "/models" + try: + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() + payload = response.json() + cache: Dict[str, Dict[str, Any]] = {} + for model in payload.get("data", []): + if not isinstance(model, dict): + continue + model_id = model.get("id") + if not model_id: + continue + entry: Dict[str, Any] = {"name": model.get("name", model_id)} + context_length = _extract_context_length(model) + if context_length is not None: + entry["context_length"] = context_length + max_completion_tokens = _extract_max_completion_tokens(model) + if max_completion_tokens is not None: + entry["max_completion_tokens"] = max_completion_tokens + pricing = _extract_pricing(model) + if pricing: + entry["pricing"] = pricing + _add_model_aliases(cache, model_id, entry) + + _endpoint_model_metadata_cache[normalized] = cache + _endpoint_model_metadata_cache_time[normalized] = time.time() + return cache + except Exception as exc: + last_error = exc + + if last_error: + logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error) + _endpoint_model_metadata_cache[normalized] = {} + _endpoint_model_metadata_cache_time[normalized] = time.time() + return {} + + def _get_context_cache_path() -> Path: """Return path to the persistent context length cache file.""" hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) @@ -243,14 +439,15 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]: return None -def get_model_context_length(model: str, base_url: str = "") -> int: +def get_model_context_length(model: str, base_url: str = "", api_key: str = "") -> int: """Get the context length for a model. Resolution order: 1. Persistent cache (previously discovered via probing) - 2. OpenRouter API metadata - 3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match) - 4. First probe tier (2M) — will be narrowed on first context error + 2. Active endpoint metadata (/models for explicit custom endpoints) + 3. OpenRouter API metadata + 4. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only) + 5. First probe tier (2M) — will be narrowed on first context error """ # 1. Check persistent cache (model+provider) if base_url: @@ -258,19 +455,31 @@ def get_model_context_length(model: str, base_url: str = "") -> int: if cached is not None: return cached - # 2. OpenRouter API metadata + # 2. 
Active endpoint metadata for explicit custom routes + if _is_custom_endpoint(base_url): + endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key) + if model in endpoint_metadata: + context_length = endpoint_metadata[model].get("context_length") + if isinstance(context_length, int): + return context_length + if not _is_known_provider_base_url(base_url): + # Explicit third-party endpoints should not borrow fuzzy global + # defaults from unrelated providers with similarly named models. + return CONTEXT_PROBE_TIERS[0] + + # 3. OpenRouter API metadata metadata = fetch_model_metadata() if model in metadata: return metadata[model].get("context_length", 128000) - # 3. Hardcoded defaults (fuzzy match — longest key first for specificity) + # 4. Hardcoded defaults (fuzzy match — longest key first for specificity) for default_model, length in sorted( DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True ): if default_model in model or model in default_model: return length - # 4. Unknown model — start at highest probe tier + # 5. Unknown model — start at highest probe tier return CONTEXT_PROBE_TIERS[0] diff --git a/agent/usage_pricing.py b/agent/usage_pricing.py index 29e7df254f..81c50026ec 100644 --- a/agent/usage_pricing.py +++ b/agent/usage_pricing.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone from decimal import Decimal from typing import Any, Dict, Literal, Optional -from agent.model_metadata import fetch_model_metadata +from agent.model_metadata import fetch_endpoint_model_metadata, fetch_model_metadata DEFAULT_PRICING = {"input": 0.0, "output": 0.0} @@ -335,8 +335,21 @@ def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry] def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]: - metadata = fetch_model_metadata() - model_id = route.model + return _pricing_entry_from_metadata( + fetch_model_metadata(), + route.model, + source_url="https://openrouter.ai/docs/api/api-reference/models/get-models", + pricing_version="openrouter-models-api", + ) + + +def _pricing_entry_from_metadata( + metadata: Dict[str, Dict[str, Any]], + model_id: str, + *, + source_url: str, + pricing_version: str, +) -> Optional[PricingEntry]: if model_id not in metadata: return None pricing = metadata[model_id].get("pricing") or {} @@ -355,6 +368,7 @@ def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]: ) if prompt is None and completion is None and request is None: return None + def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]: if value is None: return None @@ -367,8 +381,8 @@ def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]: cache_write_cost_per_million=_per_token_to_per_million(cache_write), request_cost=request, source="provider_models_api", - source_url="https://openrouter.ai/docs/api/api-reference/models/get-models", - pricing_version="openrouter-models-api", + source_url=source_url, + pricing_version=pricing_version, fetched_at=_UTC_NOW(), ) @@ -377,6 +391,7 @@ def get_pricing_entry( model_name: str, provider: Optional[str] = None, base_url: Optional[str] = None, + api_key: Optional[str] = None, ) -> Optional[PricingEntry]: route = resolve_billing_route(model_name, provider=provider, base_url=base_url) if route.billing_mode == "subscription_included": @@ -390,6 +405,15 @@ def get_pricing_entry( ) if route.provider == "openrouter": return _openrouter_pricing_entry(route) + if route.base_url: + entry = _pricing_entry_from_metadata( + 
fetch_endpoint_model_metadata(route.base_url, api_key=api_key or ""), + route.model, + source_url=f"{route.base_url.rstrip('/')}/models", + pricing_version="openai-compatible-models-api", + ) + if entry: + return entry return _lookup_official_docs_pricing(route) @@ -460,6 +484,7 @@ def estimate_usage_cost( *, provider: Optional[str] = None, base_url: Optional[str] = None, + api_key: Optional[str] = None, ) -> CostResult: route = resolve_billing_route(model_name, provider=provider, base_url=base_url) if route.billing_mode == "subscription_included": @@ -471,7 +496,7 @@ def estimate_usage_cost( pricing_version="included-route", ) - entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) + entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key) if not entry: return CostResult(amount_usd=None, status="unknown", source="none", label="n/a") @@ -536,6 +561,7 @@ def has_known_pricing( model_name: str, provider: Optional[str] = None, base_url: Optional[str] = None, + api_key: Optional[str] = None, ) -> bool: """Check whether we have pricing data for this model+route. @@ -545,7 +571,7 @@ def has_known_pricing( route = resolve_billing_route(model_name, provider=provider, base_url=base_url) if route.billing_mode == "subscription_included": return True - entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) + entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key) return entry is not None @@ -553,13 +579,14 @@ def get_pricing( model_name: str, provider: Optional[str] = None, base_url: Optional[str] = None, + api_key: Optional[str] = None, ) -> Dict[str, float]: """Backward-compatible thin wrapper for legacy callers. Returns only non-cache input/output fields when a pricing entry exists. Unknown routes return zeroes. """ - entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) + entry = get_pricing_entry(model_name, provider=provider, base_url=base_url, api_key=api_key) if not entry: return {"input": 0.0, "output": 0.0} return { @@ -575,6 +602,7 @@ def estimate_cost_usd( *, provider: Optional[str] = None, base_url: Optional[str] = None, + api_key: Optional[str] = None, ) -> float: """Backward-compatible helper for legacy callers. @@ -586,6 +614,7 @@ def estimate_cost_usd( CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens), provider=provider, base_url=base_url, + api_key=api_key, ) return float(result.amount_usd or _ZERO) diff --git a/model_tools.py b/model_tools.py index 87d5210918..3d252f4498 100644 --- a/model_tools.py +++ b/model_tools.py @@ -276,6 +276,7 @@ def get_tool_definitions( # The registry still holds their schemas; dispatch just returns a stub error # so if something slips through, the LLM sees a sensible message. _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"} +_READ_SEARCH_TOOLS = {"read_file", "search_files"} def handle_function_call( @@ -305,7 +306,6 @@ def handle_function_call( """ # Notify the read-loop tracker when a non-read/search tool runs, # so the *consecutive* counter resets (reads after other work are fine). 
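+    # (_READ_SEARCH_TOOLS is the module-level constant defined alongside
+    # _AGENT_LOOP_TOOLS above.)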
- _READ_SEARCH_TOOLS = {"read_file", "search_files"} if function_name not in _READ_SEARCH_TOOLS: try: from tools.file_tools import notify_other_tool_call diff --git a/run_agent.py b/run_agent.py index 0ce3919fcf..f0e8f25dbc 100644 --- a/run_agent.py +++ b/run_agent.py @@ -263,11 +263,20 @@ def _inject_honcho_turn_context(content, turn_context: str): class AIAgent: """ AI Agent with tool calling capabilities. - + This class manages the conversation flow, tool execution, and response handling for AI models that support function calling. """ - + + @property + def base_url(self) -> str: + return self._base_url + + @base_url.setter + def base_url(self, value: str) -> None: + self._base_url = value + self._base_url_lower = value.lower() if value else "" + def __init__( self, base_url: str = None, @@ -383,10 +392,10 @@ class AIAgent: self.api_mode = api_mode elif self.provider == "openai-codex": self.api_mode = "codex_responses" - elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self.base_url.lower(): + elif (provider_name is None) and "chatgpt.com/backend-api/codex" in self._base_url_lower: self.api_mode = "codex_responses" self.provider = "openai-codex" - elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self.base_url.lower()): + elif self.provider == "anthropic" or (provider_name is None and "api.anthropic.com" in self._base_url_lower): self.api_mode = "anthropic_messages" self.provider = "anthropic" else: @@ -395,7 +404,7 @@ class AIAgent: # Pre-warm OpenRouter model metadata cache in a background thread. # fetch_model_metadata() is cached for 1 hour; this avoids a blocking # HTTP request on the first API response when pricing is estimated. - if self.provider == "openrouter" or "openrouter" in self.base_url.lower(): + if self.provider == "openrouter" or "openrouter" in self._base_url_lower: threading.Thread( target=lambda: fetch_model_metadata(), daemon=True, @@ -439,7 +448,7 @@ class AIAgent: # Anthropic prompt caching: auto-enabled for Claude models via OpenRouter. # Reduces input costs by ~75% on multi-turn conversations by caching the # conversation prefix. Uses system_and_3 strategy (4 breakpoints). 
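+        # Caching applies when a Claude model is routed through OpenRouter,
+        # and always on native Anthropic (anthropic_messages) connections.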
- is_openrouter = "openrouter" in self.base_url.lower() + is_openrouter = "openrouter" in self._base_url_lower is_claude = "claude" in self.model.lower() is_native_anthropic = self.api_mode == "anthropic_messages" self._use_prompt_caching = (is_openrouter and is_claude) or is_native_anthropic @@ -555,6 +564,7 @@ class AIAgent: if self.api_mode == "anthropic_messages": from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token effective_key = api_key or resolve_anthropic_token() or "" + self.api_key = effective_key self._anthropic_api_key = effective_key self._anthropic_base_url = base_url from agent.anthropic_adapter import _is_oauth_token as _is_oat @@ -609,6 +619,7 @@ class AIAgent: } self._client_kwargs = client_kwargs # stored for rebuilding after interrupt + self.api_key = client_kwargs.get("api_key", "") try: self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True) if not self.quiet_mode: @@ -732,6 +743,13 @@ class AIAgent: from tools.todo_tool import TodoStore self._todo_store = TodoStore() + # Load config once for memory, skills, and compression sections + try: + from hermes_cli.config import load_config as _load_agent_config + _agent_cfg = _load_agent_config() + except Exception: + _agent_cfg = {} + # Persistent memory (MEMORY.md + USER.md) -- loaded from disk self._memory_store = None self._memory_enabled = False @@ -742,8 +760,7 @@ class AIAgent: self._iters_since_skill = 0 if not skip_memory: try: - from hermes_cli.config import load_config as _load_mem_config - mem_config = _load_mem_config().get("memory", {}) + mem_config = _agent_cfg.get("memory", {}) self._memory_enabled = mem_config.get("memory_enabled", False) self._user_profile_enabled = mem_config.get("user_profile_enabled", False) self._memory_nudge_interval = int(mem_config.get("nudge_interval", 10)) @@ -831,21 +848,16 @@ class AIAgent: # Skills config: nudge interval for skill creation reminders self._skill_nudge_interval = 10 try: - from hermes_cli.config import load_config as _load_skills_config - skills_config = _load_skills_config().get("skills", {}) + skills_config = _agent_cfg.get("skills", {}) self._skill_nudge_interval = int(skills_config.get("creation_nudge_interval", 15)) except Exception: pass - + # Initialize context compressor for automatic context management # Compresses conversation when approaching model's context limit # Configuration via config.yaml (compression section) - try: - from hermes_cli.config import load_config as _load_compression_config - _compression_cfg = _load_compression_config().get("compression", {}) - if not isinstance(_compression_cfg, dict): - _compression_cfg = {} - except ImportError: + _compression_cfg = _agent_cfg.get("compression", {}) + if not isinstance(_compression_cfg, dict): _compression_cfg = {} compression_threshold = float(_compression_cfg.get("threshold", 0.50)) compression_enabled = str(_compression_cfg.get("enabled", True)).lower() in ("true", "1", "yes") @@ -860,6 +872,7 @@ class AIAgent: summary_model_override=compression_summary_model, quiet_mode=self.quiet_mode, base_url=self.base_url, + api_key=getattr(self, "api_key", ""), ) self.compression_enabled = compression_enabled self._user_turn_count = 0 @@ -915,8 +928,8 @@ class AIAgent: OpenAI models use 'max_tokens'. 
""" _is_direct_openai = ( - "api.openai.com" in self.base_url.lower() - and "openrouter" not in self.base_url.lower() + "api.openai.com" in self._base_url_lower + and "openrouter" not in self._base_url_lower ) if _is_direct_openai: return {"max_completion_tokens": value} @@ -3643,7 +3656,7 @@ class AIAgent: extra_body = {} - _is_openrouter = "openrouter" in self.base_url.lower() + _is_openrouter = "openrouter" in self._base_url_lower # Provider preferences (only, ignore, order, sort) are OpenRouter- # specific. Only send to OpenRouter-compatible endpoints. @@ -3651,7 +3664,7 @@ class AIAgent: # for _is_nous when their backend is updated. if provider_preferences and _is_openrouter: extra_body["provider"] = provider_preferences - _is_nous = "nousresearch" in self.base_url.lower() + _is_nous = "nousresearch" in self._base_url_lower if self._supports_reasoning_extra_body(): if self.reasoning_config is not None: @@ -3684,14 +3697,13 @@ class AIAgent: Some providers/routes reject `reasoning` with 400s, so gate it to known reasoning-capable model families and direct Nous Portal. """ - base_url = (self.base_url or "").lower() - if "nousresearch" in base_url: + if "nousresearch" in self._base_url_lower: return True - if "ai-gateway.vercel.sh" in base_url: + if "ai-gateway.vercel.sh" in self._base_url_lower: return True - if "openrouter" not in base_url: + if "openrouter" not in self._base_url_lower: return False - if "api.mistral.ai" in base_url: + if "api.mistral.ai" in self._base_url_lower: return False model = (self.model or "").lower() @@ -3877,7 +3889,7 @@ class AIAgent: try: # Build API messages for the flush call - _is_strict_api = "api.mistral.ai" in self.base_url.lower() + _is_strict_api = "api.mistral.ai" in self._base_url_lower api_messages = [] for msg in messages: api_msg = msg.copy() @@ -4653,7 +4665,7 @@ class AIAgent: try: # Build API messages, stripping internal-only fields # (finish_reason, reasoning) that strict APIs like Mistral reject with 422 - _is_strict_api = "api.mistral.ai" in self.base_url.lower() + _is_strict_api = "api.mistral.ai" in self._base_url_lower api_messages = [] for msg in messages: api_msg = msg.copy() @@ -4674,7 +4686,7 @@ class AIAgent: api_messages.insert(sys_offset + idx, pfm.copy()) summary_extra_body = {} - _is_nous = "nousresearch" in self.base_url.lower() + _is_nous = "nousresearch" in self._base_url_lower if self._supports_reasoning_extra_body(): if self.reasoning_config is not None: summary_extra_body["reasoning"] = self.reasoning_config @@ -5092,7 +5104,7 @@ class AIAgent: # strict providers like Mistral that reject unknown fields with 422. # Uses new dicts so the internal messages list retains the fields # for Codex Responses compatibility. 
- if "api.mistral.ai" in self.base_url.lower(): + if "api.mistral.ai" in self._base_url_lower: self._sanitize_tool_calls_for_strict_api(api_msg) # Keep 'reasoning_details' - OpenRouter uses this for multi-turn reasoning context # The signature field helps maintain reasoning continuity @@ -5464,6 +5476,7 @@ class AIAgent: canonical_usage, provider=self.provider, base_url=self.base_url, + api_key=getattr(self, "api_key", ""), ) if cost_result.amount_usd is not None: self.session_estimated_cost_usd += float(cost_result.amount_usd) diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index 75570e343a..aa35be9b93 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -188,6 +188,36 @@ class TestGetModelContextLength: result = get_model_context_length("custom/model") assert result == CONTEXT_PROBE_TIERS[0] + @patch("agent.model_metadata.fetch_model_metadata") + @patch("agent.model_metadata.fetch_endpoint_model_metadata") + def test_custom_endpoint_metadata_beats_fuzzy_default(self, mock_endpoint_fetch, mock_fetch): + mock_fetch.return_value = {} + mock_endpoint_fetch.return_value = { + "zai-org/GLM-5-TEE": {"context_length": 65536} + } + + result = get_model_context_length( + "zai-org/GLM-5-TEE", + base_url="https://llm.chutes.ai/v1", + api_key="test-key", + ) + + assert result == 65536 + + @patch("agent.model_metadata.fetch_model_metadata") + @patch("agent.model_metadata.fetch_endpoint_model_metadata") + def test_custom_endpoint_without_metadata_skips_name_based_default(self, mock_endpoint_fetch, mock_fetch): + mock_fetch.return_value = {} + mock_endpoint_fetch.return_value = {} + + result = get_model_context_length( + "zai-org/GLM-5-TEE", + base_url="https://llm.chutes.ai/v1", + api_key="test-key", + ) + + assert result == CONTEXT_PROBE_TIERS[0] + # ========================================================================= # fetch_model_metadata — caching, TTL, slugs, failures @@ -258,6 +288,25 @@ class TestFetchModelMetadata: assert "anthropic/claude-3.5-sonnet" in result assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000 + @patch("agent.model_metadata.requests.get") + def test_provider_prefixed_models_get_bare_aliases(self, mock_get): + self._reset_cache() + mock_response = MagicMock() + mock_response.json.return_value = { + "data": [{ + "id": "provider/test-model", + "context_length": 123456, + "name": "Provider: Test Model", + }] + } + mock_response.raise_for_status = MagicMock() + mock_get.return_value = mock_response + + result = fetch_model_metadata(force_refresh=True) + + assert result["provider/test-model"]["context_length"] == 123456 + assert result["test-model"]["context_length"] == 123456 + @patch("agent.model_metadata.requests.get") def test_ttl_expiry_triggers_refetch(self, mock_get): """Cache expires after _MODEL_CACHE_TTL seconds.""" diff --git a/tests/agent/test_usage_pricing.py b/tests/agent/test_usage_pricing.py index 6d972dfa7b..a65668bb44 100644 --- a/tests/agent/test_usage_pricing.py +++ b/tests/agent/test_usage_pricing.py @@ -99,3 +99,27 @@ def test_estimate_usage_cost_refuses_cache_pricing_without_official_cache_rate(m ) assert result.status == "unknown" + + +def test_custom_endpoint_models_api_pricing_is_supported(monkeypatch): + monkeypatch.setattr( + "agent.usage_pricing.fetch_endpoint_model_metadata", + lambda base_url, api_key=None: { + "zai-org/GLM-5-TEE": { + "pricing": { + "prompt": "0.0000005", + "completion": "0.000002", + } + } + }, + ) + + entry = get_pricing_entry( + 
"zai-org/GLM-5-TEE", + provider="custom", + base_url="https://llm.chutes.ai/v1", + api_key="test-key", + ) + + assert float(entry.input_cost_per_million) == 0.5 + assert float(entry.output_cost_per_million) == 2.0