mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
fix(copilot): send vision header for Copilot vision requests
Thread a vision-request flag through auxiliary provider resolution so Copilot clients can include Copilot-Vision-Request only for vision tasks. This preserves normal text requests while ensuring Copilot vision payloads reach the vision-capable route. Add regression coverage for Copilot vision routing and keep cached text and vision clients separate so a text client without the header is not reused for vision. Co-authored-by: dhabibi <9087935+dhabibi@users.noreply.github.com>
This commit is contained in:
parent
512c610058
commit
8402ba150e
3 changed files with 156 additions and 40 deletions
|
|
@ -1617,8 +1617,14 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option
|
||||||
# below — never look up auth env vars ad-hoc.
|
# below — never look up auth env vars ad-hoc.
|
||||||
|
|
||||||
|
|
||||||
def _to_async_client(sync_client, model: str):
|
def _to_async_client(sync_client, model: str, is_vision: bool = False):
|
||||||
"""Convert a sync client to its async counterpart, preserving Codex routing."""
|
"""Convert a sync client to its async counterpart, preserving Codex routing.
|
||||||
|
|
||||||
|
When ``is_vision=True`` and the underlying base URL is Copilot, the
|
||||||
|
resulting async client carries the ``Copilot-Vision-Request: true``
|
||||||
|
header so the request is routed to Copilot's vision-capable
|
||||||
|
infrastructure (otherwise vision payloads silently time out).
|
||||||
|
"""
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
if isinstance(sync_client, CodexAuxiliaryClient):
|
if isinstance(sync_client, CodexAuxiliaryClient):
|
||||||
|
|
@ -1647,9 +1653,11 @@ def _to_async_client(sync_client, model: str):
|
||||||
if base_url_host_matches(sync_base_url, "openrouter.ai"):
|
if base_url_host_matches(sync_base_url, "openrouter.ai"):
|
||||||
async_kwargs["default_headers"] = dict(_OR_HEADERS)
|
async_kwargs["default_headers"] = dict(_OR_HEADERS)
|
||||||
elif base_url_host_matches(sync_base_url, "api.githubcopilot.com"):
|
elif base_url_host_matches(sync_base_url, "api.githubcopilot.com"):
|
||||||
from hermes_cli.models import copilot_default_headers
|
from hermes_cli.copilot_auth import copilot_request_headers
|
||||||
|
|
||||||
async_kwargs["default_headers"] = copilot_default_headers()
|
async_kwargs["default_headers"] = copilot_request_headers(
|
||||||
|
is_agent_turn=True, is_vision=is_vision
|
||||||
|
)
|
||||||
elif base_url_host_matches(sync_base_url, "api.kimi.com"):
|
elif base_url_host_matches(sync_base_url, "api.kimi.com"):
|
||||||
async_kwargs["default_headers"] = {"User-Agent": "claude-code/0.1.0"}
|
async_kwargs["default_headers"] = {"User-Agent": "claude-code/0.1.0"}
|
||||||
return AsyncOpenAI(**async_kwargs), model
|
return AsyncOpenAI(**async_kwargs), model
|
||||||
|
|
@ -1676,6 +1684,7 @@ def resolve_provider_client(
|
||||||
explicit_api_key: str = None,
|
explicit_api_key: str = None,
|
||||||
api_mode: str = None,
|
api_mode: str = None,
|
||||||
main_runtime: Optional[Dict[str, Any]] = None,
|
main_runtime: Optional[Dict[str, Any]] = None,
|
||||||
|
is_vision: bool = False,
|
||||||
) -> Tuple[Optional[Any], Optional[str]]:
|
) -> Tuple[Optional[Any], Optional[str]]:
|
||||||
"""Central router: given a provider name and optional model, return a
|
"""Central router: given a provider name and optional model, return a
|
||||||
configured client with the correct auth, base URL, and API format.
|
configured client with the correct auth, base URL, and API format.
|
||||||
|
|
@ -1759,7 +1768,7 @@ def resolve_provider_client(
|
||||||
"auxiliary provider (using %r instead)", model, resolved)
|
"auxiliary provider (using %r instead)", model, resolved)
|
||||||
model = None
|
model = None
|
||||||
final_model = model or resolved
|
final_model = model or resolved
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
|
|
||||||
# ── OpenRouter ───────────────────────────────────────────────────
|
# ── OpenRouter ───────────────────────────────────────────────────
|
||||||
|
|
@ -1772,7 +1781,7 @@ def resolve_provider_client(
|
||||||
)
|
)
|
||||||
return None, None
|
return None, None
|
||||||
final_model = _normalize_resolved_model(model or default, provider)
|
final_model = _normalize_resolved_model(model or default, provider)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
|
|
||||||
# ── Nous Portal (OAuth) ──────────────────────────────────────────
|
# ── Nous Portal (OAuth) ──────────────────────────────────────────
|
||||||
|
|
@ -1789,7 +1798,7 @@ def resolve_provider_client(
|
||||||
"but Nous Portal not configured (run: hermes auth)")
|
"but Nous Portal not configured (run: hermes auth)")
|
||||||
return None, None
|
return None, None
|
||||||
final_model = _normalize_resolved_model(model or default, provider)
|
final_model = _normalize_resolved_model(model or default, provider)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
|
|
||||||
# ── OpenAI Codex (OAuth → Responses API) ─────────────────────────
|
# ── OpenAI Codex (OAuth → Responses API) ─────────────────────────
|
||||||
|
|
@ -1816,7 +1825,7 @@ def resolve_provider_client(
|
||||||
"but no Codex OAuth token found (run: hermes model)")
|
"but no Codex OAuth token found (run: hermes model)")
|
||||||
return None, None
|
return None, None
|
||||||
final_model = _normalize_resolved_model(model or default, provider)
|
final_model = _normalize_resolved_model(model or default, provider)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
|
|
||||||
# ── Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY) ───────────
|
# ── Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY) ───────────
|
||||||
|
|
@ -1845,11 +1854,13 @@ def resolve_provider_client(
|
||||||
if base_url_host_matches(custom_base, "api.kimi.com"):
|
if base_url_host_matches(custom_base, "api.kimi.com"):
|
||||||
extra["default_headers"] = {"User-Agent": "claude-code/0.1.0"}
|
extra["default_headers"] = {"User-Agent": "claude-code/0.1.0"}
|
||||||
elif base_url_host_matches(custom_base, "api.githubcopilot.com"):
|
elif base_url_host_matches(custom_base, "api.githubcopilot.com"):
|
||||||
from hermes_cli.models import copilot_default_headers
|
from hermes_cli.copilot_auth import copilot_request_headers
|
||||||
extra["default_headers"] = copilot_default_headers()
|
extra["default_headers"] = copilot_request_headers(
|
||||||
|
is_agent_turn=True, is_vision=is_vision
|
||||||
|
)
|
||||||
client = OpenAI(api_key=custom_key, base_url=_clean_base, **extra)
|
client = OpenAI(api_key=custom_key, base_url=_clean_base, **extra)
|
||||||
client = _wrap_if_needed(client, final_model, custom_base)
|
client = _wrap_if_needed(client, final_model, custom_base)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
# Try custom first, then codex, then API-key providers
|
# Try custom first, then codex, then API-key providers
|
||||||
for try_fn in (_try_custom_endpoint, _try_codex,
|
for try_fn in (_try_custom_endpoint, _try_codex,
|
||||||
|
|
@ -1859,7 +1870,7 @@ def resolve_provider_client(
|
||||||
final_model = _normalize_resolved_model(model or default, provider)
|
final_model = _normalize_resolved_model(model or default, provider)
|
||||||
_cbase = str(getattr(client, "base_url", "") or "")
|
_cbase = str(getattr(client, "base_url", "") or "")
|
||||||
client = _wrap_if_needed(client, final_model, _cbase)
|
client = _wrap_if_needed(client, final_model, _cbase)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
logger.warning("resolve_provider_client: custom/main requested "
|
logger.warning("resolve_provider_client: custom/main requested "
|
||||||
"but no endpoint credentials found")
|
"but no endpoint credentials found")
|
||||||
|
|
@ -1904,7 +1915,7 @@ def resolve_provider_client(
|
||||||
provider,
|
provider,
|
||||||
)
|
)
|
||||||
client = OpenAI(api_key=custom_key, base_url=_clean_base2, **_extra2)
|
client = OpenAI(api_key=custom_key, base_url=_clean_base2, **_extra2)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
sync_anthropic = AnthropicAuxiliaryClient(
|
sync_anthropic = AnthropicAuxiliaryClient(
|
||||||
real_client, final_model, custom_key, custom_base, is_oauth=False,
|
real_client, final_model, custom_key, custom_base, is_oauth=False,
|
||||||
|
|
@ -1923,7 +1934,7 @@ def resolve_provider_client(
|
||||||
client = CodexAuxiliaryClient(client, final_model)
|
client = CodexAuxiliaryClient(client, final_model)
|
||||||
else:
|
else:
|
||||||
client = _wrap_if_needed(client, final_model, custom_base)
|
client = _wrap_if_needed(client, final_model, custom_base)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"resolve_provider_client: named custom provider %r has no base_url",
|
"resolve_provider_client: named custom provider %r has no base_url",
|
||||||
|
|
@ -1955,7 +1966,7 @@ def resolve_provider_client(
|
||||||
logger.warning("resolve_provider_client: anthropic requested but no Anthropic credentials found")
|
logger.warning("resolve_provider_client: anthropic requested but no Anthropic credentials found")
|
||||||
return None, None
|
return None, None
|
||||||
final_model = _normalize_resolved_model(model or default_model, provider)
|
final_model = _normalize_resolved_model(model or default_model, provider)
|
||||||
return (_to_async_client(client, final_model) if async_mode else (client, final_model))
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode else (client, final_model))
|
||||||
|
|
||||||
creds = resolve_api_key_provider_credentials(provider)
|
creds = resolve_api_key_provider_credentials(provider)
|
||||||
api_key = str(creds.get("api_key", "")).strip()
|
api_key = str(creds.get("api_key", "")).strip()
|
||||||
|
|
@ -1981,7 +1992,7 @@ def resolve_provider_client(
|
||||||
if is_native_gemini_base_url(base_url):
|
if is_native_gemini_base_url(base_url):
|
||||||
client = GeminiNativeClient(api_key=api_key, base_url=base_url)
|
client = GeminiNativeClient(api_key=api_key, base_url=base_url)
|
||||||
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
|
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
|
|
||||||
# Provider-specific headers
|
# Provider-specific headers
|
||||||
|
|
@ -1989,9 +2000,11 @@ def resolve_provider_client(
|
||||||
if base_url_host_matches(base_url, "api.kimi.com"):
|
if base_url_host_matches(base_url, "api.kimi.com"):
|
||||||
headers["User-Agent"] = "claude-code/0.1.0"
|
headers["User-Agent"] = "claude-code/0.1.0"
|
||||||
elif base_url_host_matches(base_url, "api.githubcopilot.com"):
|
elif base_url_host_matches(base_url, "api.githubcopilot.com"):
|
||||||
from hermes_cli.models import copilot_default_headers
|
from hermes_cli.copilot_auth import copilot_request_headers
|
||||||
|
|
||||||
headers.update(copilot_default_headers())
|
headers.update(copilot_request_headers(
|
||||||
|
is_agent_turn=True, is_vision=is_vision
|
||||||
|
))
|
||||||
client = OpenAI(api_key=api_key, base_url=base_url,
|
client = OpenAI(api_key=api_key, base_url=base_url,
|
||||||
**({"default_headers": headers} if headers else {}))
|
**({"default_headers": headers} if headers else {}))
|
||||||
|
|
||||||
|
|
@ -2017,7 +2030,7 @@ def resolve_provider_client(
|
||||||
client = _wrap_if_needed(client, final_model, base_url)
|
client = _wrap_if_needed(client, final_model, base_url)
|
||||||
|
|
||||||
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
|
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
|
|
||||||
if pconfig.auth_type == "external_process":
|
if pconfig.auth_type == "external_process":
|
||||||
|
|
@ -2049,7 +2062,7 @@ def resolve_provider_client(
|
||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
|
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
logger.warning("resolve_provider_client: external-process provider %s not "
|
logger.warning("resolve_provider_client: external-process provider %s not "
|
||||||
"directly supported", provider)
|
"directly supported", provider)
|
||||||
|
|
@ -2085,7 +2098,7 @@ def resolve_provider_client(
|
||||||
base_url=f"https://bedrock-runtime.{region}.amazonaws.com",
|
base_url=f"https://bedrock-runtime.{region}.amazonaws.com",
|
||||||
)
|
)
|
||||||
logger.debug("resolve_provider_client: bedrock (%s, %s)", final_model, region)
|
logger.debug("resolve_provider_client: bedrock (%s, %s)", final_model, region)
|
||||||
return (_to_async_client(client, final_model) if async_mode
|
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
|
||||||
else (client, final_model))
|
else (client, final_model))
|
||||||
|
|
||||||
elif pconfig.auth_type in ("oauth_device_code", "oauth_external"):
|
elif pconfig.auth_type in ("oauth_device_code", "oauth_external"):
|
||||||
|
|
@ -2160,8 +2173,13 @@ def _normalize_vision_provider(provider: Optional[str]) -> str:
|
||||||
return _normalize_aux_provider(provider)
|
return _normalize_aux_provider(provider)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_strict_vision_backend(provider: str) -> Tuple[Optional[Any], Optional[str]]:
|
def _resolve_strict_vision_backend(
|
||||||
|
provider: str,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
) -> Tuple[Optional[Any], Optional[str]]:
|
||||||
provider = _normalize_vision_provider(provider)
|
provider = _normalize_vision_provider(provider)
|
||||||
|
if provider == "copilot":
|
||||||
|
return resolve_provider_client("copilot", model, is_vision=True)
|
||||||
if provider == "openrouter":
|
if provider == "openrouter":
|
||||||
return _try_openrouter()
|
return _try_openrouter()
|
||||||
if provider == "nous":
|
if provider == "nous":
|
||||||
|
|
@ -2229,7 +2247,7 @@ def resolve_vision_provider_client(
|
||||||
return resolved_provider, None, None
|
return resolved_provider, None, None
|
||||||
final_model = resolved_model or default_model
|
final_model = resolved_model or default_model
|
||||||
if async_mode:
|
if async_mode:
|
||||||
async_client, async_model = _to_async_client(sync_client, final_model)
|
async_client, async_model = _to_async_client(sync_client, final_model, is_vision=True)
|
||||||
return resolved_provider, async_client, async_model
|
return resolved_provider, async_client, async_model
|
||||||
return resolved_provider, sync_client, final_model
|
return resolved_provider, sync_client, final_model
|
||||||
|
|
||||||
|
|
@ -2261,8 +2279,11 @@ def resolve_vision_provider_client(
|
||||||
main_provider = _read_main_provider()
|
main_provider = _read_main_provider()
|
||||||
main_model = _read_main_model()
|
main_model = _read_main_model()
|
||||||
if main_provider and main_provider not in ("auto", ""):
|
if main_provider and main_provider not in ("auto", ""):
|
||||||
|
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||||
if main_provider == "nous":
|
if main_provider == "nous":
|
||||||
sync_client, default_model = _resolve_strict_vision_backend(main_provider)
|
sync_client, default_model = _resolve_strict_vision_backend(
|
||||||
|
main_provider, vision_model
|
||||||
|
)
|
||||||
if sync_client is not None:
|
if sync_client is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Vision auto-detect: using main provider %s (%s)",
|
"Vision auto-detect: using main provider %s (%s)",
|
||||||
|
|
@ -2270,10 +2291,10 @@ def resolve_vision_provider_client(
|
||||||
)
|
)
|
||||||
return _finalize(main_provider, sync_client, default_model)
|
return _finalize(main_provider, sync_client, default_model)
|
||||||
else:
|
else:
|
||||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
|
||||||
rpc_client, rpc_model = resolve_provider_client(
|
rpc_client, rpc_model = resolve_provider_client(
|
||||||
main_provider, vision_model,
|
main_provider, vision_model,
|
||||||
api_mode=resolved_api_mode)
|
api_mode=resolved_api_mode,
|
||||||
|
is_vision=True)
|
||||||
if rpc_client is not None:
|
if rpc_client is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Vision auto-detect: using main provider %s (%s)",
|
"Vision auto-detect: using main provider %s (%s)",
|
||||||
|
|
@ -2295,11 +2316,14 @@ def resolve_vision_provider_client(
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
if requested in _VISION_AUTO_PROVIDER_ORDER:
|
if requested in _VISION_AUTO_PROVIDER_ORDER:
|
||||||
sync_client, default_model = _resolve_strict_vision_backend(requested)
|
sync_client, default_model = _resolve_strict_vision_backend(
|
||||||
|
requested, resolved_model
|
||||||
|
)
|
||||||
return _finalize(requested, sync_client, default_model)
|
return _finalize(requested, sync_client, default_model)
|
||||||
|
|
||||||
client, final_model = _get_cached_client(requested, resolved_model, async_mode,
|
client, final_model = _get_cached_client(requested, resolved_model, async_mode,
|
||||||
api_mode=resolved_api_mode)
|
api_mode=resolved_api_mode,
|
||||||
|
is_vision=True)
|
||||||
if client is None:
|
if client is None:
|
||||||
return requested, None, None
|
return requested, None, None
|
||||||
return requested, client, final_model
|
return requested, client, final_model
|
||||||
|
|
@ -2363,10 +2387,11 @@ def _client_cache_key(
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_mode: Optional[str] = None,
|
api_mode: Optional[str] = None,
|
||||||
main_runtime: Optional[Dict[str, Any]] = None,
|
main_runtime: Optional[Dict[str, Any]] = None,
|
||||||
|
is_vision: bool = False,
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
runtime = _normalize_main_runtime(main_runtime)
|
runtime = _normalize_main_runtime(main_runtime)
|
||||||
runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
|
runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
|
||||||
return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key)
|
return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key, is_vision)
|
||||||
|
|
||||||
|
|
||||||
def _store_cached_client(cache_key: tuple, client: Any, default_model: Optional[str], *, bound_loop: Any = None) -> None:
|
def _store_cached_client(cache_key: tuple, client: Any, default_model: Optional[str], *, bound_loop: Any = None) -> None:
|
||||||
|
|
@ -2392,6 +2417,7 @@ def _refresh_nous_auxiliary_client(
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
api_mode: Optional[str] = None,
|
api_mode: Optional[str] = None,
|
||||||
main_runtime: Optional[Dict[str, Any]] = None,
|
main_runtime: Optional[Dict[str, Any]] = None,
|
||||||
|
is_vision: bool = False,
|
||||||
) -> Tuple[Optional[Any], Optional[str]]:
|
) -> Tuple[Optional[Any], Optional[str]]:
|
||||||
"""Refresh Nous runtime creds, rebuild the client, and replace the cache entry."""
|
"""Refresh Nous runtime creds, rebuild the client, and replace the cache entry."""
|
||||||
runtime = _resolve_nous_runtime_api(force_refresh=True)
|
runtime = _resolve_nous_runtime_api(force_refresh=True)
|
||||||
|
|
@ -2409,7 +2435,7 @@ def _refresh_nous_auxiliary_client(
|
||||||
current_loop = _aio.get_event_loop()
|
current_loop = _aio.get_event_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
client, final_model = _to_async_client(sync_client, final_model or "")
|
client, final_model = _to_async_client(sync_client, final_model or "", is_vision=is_vision)
|
||||||
else:
|
else:
|
||||||
client = sync_client
|
client = sync_client
|
||||||
|
|
||||||
|
|
@ -2420,6 +2446,7 @@ def _refresh_nous_auxiliary_client(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_mode=api_mode,
|
api_mode=api_mode,
|
||||||
main_runtime=main_runtime,
|
main_runtime=main_runtime,
|
||||||
|
is_vision=is_vision,
|
||||||
)
|
)
|
||||||
_store_cached_client(cache_key, client, final_model, bound_loop=current_loop)
|
_store_cached_client(cache_key, client, final_model, bound_loop=current_loop)
|
||||||
return client, final_model
|
return client, final_model
|
||||||
|
|
@ -2549,6 +2576,7 @@ def _get_cached_client(
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
api_mode: str = None,
|
api_mode: str = None,
|
||||||
main_runtime: Optional[Dict[str, Any]] = None,
|
main_runtime: Optional[Dict[str, Any]] = None,
|
||||||
|
is_vision: bool = False,
|
||||||
) -> Tuple[Optional[Any], Optional[str]]:
|
) -> Tuple[Optional[Any], Optional[str]]:
|
||||||
"""Get or create a cached client for the given provider.
|
"""Get or create a cached client for the given provider.
|
||||||
|
|
||||||
|
|
@ -2585,6 +2613,7 @@ def _get_cached_client(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
api_mode=api_mode,
|
api_mode=api_mode,
|
||||||
main_runtime=main_runtime,
|
main_runtime=main_runtime,
|
||||||
|
is_vision=is_vision,
|
||||||
)
|
)
|
||||||
with _client_cache_lock:
|
with _client_cache_lock:
|
||||||
if cache_key in _client_cache:
|
if cache_key in _client_cache:
|
||||||
|
|
@ -2616,6 +2645,7 @@ def _get_cached_client(
|
||||||
explicit_api_key=api_key,
|
explicit_api_key=api_key,
|
||||||
api_mode=api_mode,
|
api_mode=api_mode,
|
||||||
main_runtime=runtime,
|
main_runtime=runtime,
|
||||||
|
is_vision=is_vision,
|
||||||
)
|
)
|
||||||
if client is not None:
|
if client is not None:
|
||||||
# For async clients, remember which loop they were created on so we
|
# For async clients, remember which loop they were created on so we
|
||||||
|
|
@ -3079,6 +3109,7 @@ def call_llm(
|
||||||
api_key=resolved_api_key,
|
api_key=resolved_api_key,
|
||||||
api_mode=resolved_api_mode,
|
api_mode=resolved_api_mode,
|
||||||
main_runtime=main_runtime,
|
main_runtime=main_runtime,
|
||||||
|
is_vision=(task == "vision"),
|
||||||
)
|
)
|
||||||
if refreshed_client is not None:
|
if refreshed_client is not None:
|
||||||
logger.info("Auxiliary %s: refreshed Nous runtime credentials after 401, retrying",
|
logger.info("Auxiliary %s: refreshed Nous runtime credentials after 401, retrying",
|
||||||
|
|
@ -3369,6 +3400,7 @@ async def async_call_llm(
|
||||||
base_url=resolved_base_url,
|
base_url=resolved_base_url,
|
||||||
api_key=resolved_api_key,
|
api_key=resolved_api_key,
|
||||||
api_mode=resolved_api_mode,
|
api_mode=resolved_api_mode,
|
||||||
|
is_vision=(task == "vision"),
|
||||||
)
|
)
|
||||||
if refreshed_client is not None:
|
if refreshed_client is not None:
|
||||||
logger.info("Auxiliary %s (async): refreshed Nous runtime credentials after 401, retrying",
|
logger.info("Auxiliary %s (async): refreshed Nous runtime credentials after 401, retrying",
|
||||||
|
|
@ -3437,7 +3469,9 @@ async def async_call_llm(
|
||||||
extra_body=effective_extra_body,
|
extra_body=effective_extra_body,
|
||||||
base_url=str(getattr(fb_client, "base_url", "") or ""))
|
base_url=str(getattr(fb_client, "base_url", "") or ""))
|
||||||
# Convert sync fallback client to async
|
# Convert sync fallback client to async
|
||||||
async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "")
|
async_fb, async_fb_model = _to_async_client(
|
||||||
|
fb_client, fb_model or "", is_vision=(task == "vision")
|
||||||
|
)
|
||||||
if async_fb_model and async_fb_model != fb_kwargs.get("model"):
|
if async_fb_model and async_fb_model != fb_kwargs.get("model"):
|
||||||
fb_kwargs["model"] = async_fb_model
|
fb_kwargs["model"] = async_fb_model
|
||||||
return _validate_llm_response(
|
return _validate_llm_response(
|
||||||
|
|
|
||||||
|
|
@ -199,6 +199,7 @@ class TestResolveVisionMainFirst:
|
||||||
mock_resolve.assert_called_once()
|
mock_resolve.assert_called_once()
|
||||||
assert mock_resolve.call_args.args[0] == "openrouter"
|
assert mock_resolve.call_args.args[0] == "openrouter"
|
||||||
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
|
||||||
|
assert mock_resolve.call_args.kwargs.get("is_vision") is True
|
||||||
|
|
||||||
def test_nous_main_vision_uses_paid_nous_vision_backend(self):
|
def test_nous_main_vision_uses_paid_nous_vision_backend(self):
|
||||||
"""Paid Nous main → aux vision uses the dedicated Nous vision backend."""
|
"""Paid Nous main → aux vision uses the dedicated Nous vision backend."""
|
||||||
|
|
@ -266,6 +267,87 @@ class TestResolveVisionMainFirst:
|
||||||
assert provider == "xiaomi"
|
assert provider == "xiaomi"
|
||||||
# Should use mimo-v2.5 (vision override), not mimo-v2-pro (text main)
|
# Should use mimo-v2.5 (vision override), not mimo-v2-pro (text main)
|
||||||
assert mock_resolve.call_args.args[1] == "mimo-v2.5"
|
assert mock_resolve.call_args.args[1] == "mimo-v2.5"
|
||||||
|
assert mock_resolve.call_args.kwargs.get("is_vision") is True
|
||||||
|
|
||||||
|
def test_copilot_vision_sets_vision_header(self, monkeypatch):
|
||||||
|
"""Copilot vision requests include the header required for vision routing."""
|
||||||
|
monkeypatch.setenv("COPILOT_GITHUB_TOKEN", "ghu_test-token")
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_headers(*, is_agent_turn=False, is_vision=False):
|
||||||
|
captured["is_agent_turn"] = is_agent_turn
|
||||||
|
captured["is_vision"] = is_vision
|
||||||
|
return {"Copilot-Vision-Request": "true"} if is_vision else {}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"agent.auxiliary_client._read_main_provider", return_value="copilot",
|
||||||
|
), patch(
|
||||||
|
"agent.auxiliary_client._read_main_model", return_value="configured-copilot-model",
|
||||||
|
), patch(
|
||||||
|
"agent.auxiliary_client._resolve_task_provider_model",
|
||||||
|
return_value=("auto", None, None, None, None),
|
||||||
|
), patch(
|
||||||
|
"agent.auxiliary_client.OpenAI",
|
||||||
|
) as mock_openai, patch(
|
||||||
|
"hermes_cli.auth.resolve_api_key_provider_credentials",
|
||||||
|
return_value={
|
||||||
|
"provider": "copilot",
|
||||||
|
"api_key": "copilot-api-token",
|
||||||
|
"base_url": "https://api.githubcopilot.com",
|
||||||
|
},
|
||||||
|
), patch(
|
||||||
|
"hermes_cli.copilot_auth.copilot_request_headers",
|
||||||
|
side_effect=fake_headers,
|
||||||
|
):
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
from agent.auxiliary_client import resolve_vision_provider_client
|
||||||
|
|
||||||
|
provider, client, model = resolve_vision_provider_client()
|
||||||
|
|
||||||
|
assert provider == "copilot"
|
||||||
|
assert client is mock_client
|
||||||
|
assert model == "configured-copilot-model"
|
||||||
|
assert captured == {"is_agent_turn": True, "is_vision": True}
|
||||||
|
assert mock_openai.call_args.kwargs["default_headers"]["Copilot-Vision-Request"] == "true"
|
||||||
|
|
||||||
|
def test_text_copilot_does_not_set_vision_header(self, monkeypatch):
|
||||||
|
"""Text Copilot requests keep the vision-only header off."""
|
||||||
|
monkeypatch.setenv("COPILOT_GITHUB_TOKEN", "ghu_test-token")
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def fake_headers(*, is_agent_turn=False, is_vision=False):
|
||||||
|
captured["is_agent_turn"] = is_agent_turn
|
||||||
|
captured["is_vision"] = is_vision
|
||||||
|
return {"Copilot-Vision-Request": "true"} if is_vision else {}
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"agent.auxiliary_client.OpenAI",
|
||||||
|
) as mock_openai, patch(
|
||||||
|
"hermes_cli.auth.resolve_api_key_provider_credentials",
|
||||||
|
return_value={
|
||||||
|
"provider": "copilot",
|
||||||
|
"api_key": "copilot-api-token",
|
||||||
|
"base_url": "https://api.githubcopilot.com",
|
||||||
|
},
|
||||||
|
), patch(
|
||||||
|
"hermes_cli.copilot_auth.copilot_request_headers",
|
||||||
|
side_effect=fake_headers,
|
||||||
|
):
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_openai.return_value = mock_client
|
||||||
|
|
||||||
|
from agent.auxiliary_client import resolve_provider_client
|
||||||
|
|
||||||
|
client, model = resolve_provider_client("copilot", "gpt-5-mini")
|
||||||
|
|
||||||
|
assert client is mock_client
|
||||||
|
assert model == "gpt-5-mini"
|
||||||
|
assert captured == {"is_agent_turn": True, "is_vision": False}
|
||||||
|
assert "default_headers" not in mock_openai.call_args.kwargs
|
||||||
|
|
||||||
def test_main_unavailable_vision_falls_through_to_aggregators(self):
|
def test_main_unavailable_vision_falls_through_to_aggregators(self):
|
||||||
"""Main provider fails → fall back to OpenRouter/Nous strict backends."""
|
"""Main provider fails → fall back to OpenRouter/Nous strict backends."""
|
||||||
|
|
@ -312,7 +394,7 @@ class TestResolveVisionMainFirst:
|
||||||
|
|
||||||
# Explicit "nous" override → uses strict backend, NOT main model path
|
# Explicit "nous" override → uses strict backend, NOT main model path
|
||||||
assert provider == "nous"
|
assert provider == "nous"
|
||||||
mock_strict.assert_called_once_with("nous")
|
mock_strict.assert_called_once_with("nous", None)
|
||||||
|
|
||||||
|
|
||||||
# ── Constant cleanup ────────────────────────────────────────────────────────
|
# ── Constant cleanup ────────────────────────────────────────────────────────
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,7 @@ class TestCleanupStaleAsyncClients:
|
||||||
mock_client._client = MagicMock()
|
mock_client._client = MagicMock()
|
||||||
mock_client._client.is_closed = False
|
mock_client._client.is_closed = False
|
||||||
|
|
||||||
key = ("test_stale", True, "", "", "", ())
|
key = ("test_stale", True, "", "", "", (), False)
|
||||||
with _client_cache_lock:
|
with _client_cache_lock:
|
||||||
_client_cache[key] = (mock_client, "test-model", loop)
|
_client_cache[key] = (mock_client, "test-model", loop)
|
||||||
|
|
||||||
|
|
@ -127,7 +127,7 @@ class TestCleanupStaleAsyncClients:
|
||||||
loop = asyncio.new_event_loop() # NOT closed
|
loop = asyncio.new_event_loop() # NOT closed
|
||||||
|
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
key = ("test_live", True, "", "", "", ())
|
key = ("test_live", True, "", "", "", (), False)
|
||||||
with _client_cache_lock:
|
with _client_cache_lock:
|
||||||
_client_cache[key] = (mock_client, "test-model", loop)
|
_client_cache[key] = (mock_client, "test-model", loop)
|
||||||
|
|
||||||
|
|
@ -149,7 +149,7 @@ class TestCleanupStaleAsyncClients:
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
key = ("test_sync", False, "", "", "", ())
|
key = ("test_sync", False, "", "", "", (), False)
|
||||||
with _client_cache_lock:
|
with _client_cache_lock:
|
||||||
_client_cache[key] = (mock_client, "test-model", None)
|
_client_cache[key] = (mock_client, "test-model", None)
|
||||||
|
|
||||||
|
|
@ -182,7 +182,7 @@ class TestClientCacheBoundedGrowth:
|
||||||
_get_cached_client,
|
_get_cached_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
key = ("test_replace", True, "", "", "", ())
|
key = ("test_replace", True, "", "", "", (), False)
|
||||||
|
|
||||||
# Simulate a stale entry from a closed loop
|
# Simulate a stale entry from a closed loop
|
||||||
old_loop = asyncio.new_event_loop()
|
old_loop = asyncio.new_event_loop()
|
||||||
|
|
@ -217,7 +217,7 @@ class TestClientCacheBoundedGrowth:
|
||||||
_client_cache_lock,
|
_client_cache_lock,
|
||||||
)
|
)
|
||||||
|
|
||||||
key = ("test_no_grow", True, "", "", "", ())
|
key = ("test_no_grow", True, "", "", "", (), False)
|
||||||
|
|
||||||
loops = []
|
loops = []
|
||||||
try:
|
try:
|
||||||
|
|
@ -269,7 +269,7 @@ class TestClientCacheBoundedGrowth:
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_client._client = MagicMock()
|
mock_client._client = MagicMock()
|
||||||
mock_client._client.is_closed = False
|
mock_client._client.is_closed = False
|
||||||
key = (f"evict_test_{i}", False, "", "", "", ())
|
key = (f"evict_test_{i}", False, "", "", "", (), False)
|
||||||
with _client_cache_lock:
|
with _client_cache_lock:
|
||||||
# Inline the eviction logic (same as _get_cached_client)
|
# Inline the eviction logic (same as _get_cached_client)
|
||||||
while len(_client_cache) >= _CLIENT_CACHE_MAX_SIZE:
|
while len(_client_cache) >= _CLIENT_CACHE_MAX_SIZE:
|
||||||
|
|
@ -281,9 +281,9 @@ class TestClientCacheBoundedGrowth:
|
||||||
assert len(_client_cache) <= _CLIENT_CACHE_MAX_SIZE, \
|
assert len(_client_cache) <= _CLIENT_CACHE_MAX_SIZE, \
|
||||||
f"Cache size {len(_client_cache)} exceeds max {_CLIENT_CACHE_MAX_SIZE}"
|
f"Cache size {len(_client_cache)} exceeds max {_CLIENT_CACHE_MAX_SIZE}"
|
||||||
# The earliest entries should have been evicted
|
# The earliest entries should have been evicted
|
||||||
assert ("evict_test_0", False, "", "", "", ()) not in _client_cache
|
assert ("evict_test_0", False, "", "", "", (), False) not in _client_cache
|
||||||
# The latest entries should be present
|
# The latest entries should be present
|
||||||
assert (f"evict_test_{_CLIENT_CACHE_MAX_SIZE + 4}", False, "", "", "", ()) in _client_cache
|
assert (f"evict_test_{_CLIENT_CACHE_MAX_SIZE + 4}", False, "", "", "", (), False) in _client_cache
|
||||||
finally:
|
finally:
|
||||||
with _client_cache_lock:
|
with _client_cache_lock:
|
||||||
_client_cache.clear()
|
_client_cache.clear()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue