fix(copilot): send vision header for Copilot vision requests

Thread an `is_vision` flag through auxiliary provider resolution so Copilot clients can include the `Copilot-Vision-Request: true` header only for vision tasks. This preserves normal text requests unchanged while ensuring Copilot vision payloads reach the vision-capable route.

Add regression coverage for Copilot vision routing and keep cached text and vision clients separate so a text client without the header is not reused for vision.

Co-authored-by: dhabibi <9087935+dhabibi@users.noreply.github.com>
This commit is contained in:
hermes-agent-dhabibi 2026-04-26 21:29:55 +00:00 committed by Teknium
parent 512c610058
commit 8402ba150e
3 changed files with 156 additions and 40 deletions

View file

@ -1617,8 +1617,14 @@ def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Option
# below — never look up auth env vars ad-hoc.
def _to_async_client(sync_client, model: str):
"""Convert a sync client to its async counterpart, preserving Codex routing."""
def _to_async_client(sync_client, model: str, is_vision: bool = False):
"""Convert a sync client to its async counterpart, preserving Codex routing.
When ``is_vision=True`` and the underlying base URL is Copilot, the
resulting async client carries the ``Copilot-Vision-Request: true``
header so the request is routed to Copilot's vision-capable
infrastructure (otherwise vision payloads silently time out).
"""
from openai import AsyncOpenAI
if isinstance(sync_client, CodexAuxiliaryClient):
@ -1647,9 +1653,11 @@ def _to_async_client(sync_client, model: str):
if base_url_host_matches(sync_base_url, "openrouter.ai"):
async_kwargs["default_headers"] = dict(_OR_HEADERS)
elif base_url_host_matches(sync_base_url, "api.githubcopilot.com"):
from hermes_cli.models import copilot_default_headers
from hermes_cli.copilot_auth import copilot_request_headers
async_kwargs["default_headers"] = copilot_default_headers()
async_kwargs["default_headers"] = copilot_request_headers(
is_agent_turn=True, is_vision=is_vision
)
elif base_url_host_matches(sync_base_url, "api.kimi.com"):
async_kwargs["default_headers"] = {"User-Agent": "claude-code/0.1.0"}
return AsyncOpenAI(**async_kwargs), model
@ -1676,6 +1684,7 @@ def resolve_provider_client(
explicit_api_key: str = None,
api_mode: str = None,
main_runtime: Optional[Dict[str, Any]] = None,
is_vision: bool = False,
) -> Tuple[Optional[Any], Optional[str]]:
"""Central router: given a provider name and optional model, return a
configured client with the correct auth, base URL, and API format.
@ -1759,7 +1768,7 @@ def resolve_provider_client(
"auxiliary provider (using %r instead)", model, resolved)
model = None
final_model = model or resolved
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
# ── OpenRouter ───────────────────────────────────────────────────
@ -1772,7 +1781,7 @@ def resolve_provider_client(
)
return None, None
final_model = _normalize_resolved_model(model or default, provider)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
# ── Nous Portal (OAuth) ──────────────────────────────────────────
@ -1789,7 +1798,7 @@ def resolve_provider_client(
"but Nous Portal not configured (run: hermes auth)")
return None, None
final_model = _normalize_resolved_model(model or default, provider)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
# ── OpenAI Codex (OAuth → Responses API) ─────────────────────────
@ -1816,7 +1825,7 @@ def resolve_provider_client(
"but no Codex OAuth token found (run: hermes model)")
return None, None
final_model = _normalize_resolved_model(model or default, provider)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
# ── Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY) ───────────
@ -1845,11 +1854,13 @@ def resolve_provider_client(
if base_url_host_matches(custom_base, "api.kimi.com"):
extra["default_headers"] = {"User-Agent": "claude-code/0.1.0"}
elif base_url_host_matches(custom_base, "api.githubcopilot.com"):
from hermes_cli.models import copilot_default_headers
extra["default_headers"] = copilot_default_headers()
from hermes_cli.copilot_auth import copilot_request_headers
extra["default_headers"] = copilot_request_headers(
is_agent_turn=True, is_vision=is_vision
)
client = OpenAI(api_key=custom_key, base_url=_clean_base, **extra)
client = _wrap_if_needed(client, final_model, custom_base)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
# Try custom first, then codex, then API-key providers
for try_fn in (_try_custom_endpoint, _try_codex,
@ -1859,7 +1870,7 @@ def resolve_provider_client(
final_model = _normalize_resolved_model(model or default, provider)
_cbase = str(getattr(client, "base_url", "") or "")
client = _wrap_if_needed(client, final_model, _cbase)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
logger.warning("resolve_provider_client: custom/main requested "
"but no endpoint credentials found")
@ -1904,7 +1915,7 @@ def resolve_provider_client(
provider,
)
client = OpenAI(api_key=custom_key, base_url=_clean_base2, **_extra2)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
sync_anthropic = AnthropicAuxiliaryClient(
real_client, final_model, custom_key, custom_base, is_oauth=False,
@ -1923,7 +1934,7 @@ def resolve_provider_client(
client = CodexAuxiliaryClient(client, final_model)
else:
client = _wrap_if_needed(client, final_model, custom_base)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
logger.warning(
"resolve_provider_client: named custom provider %r has no base_url",
@ -1955,7 +1966,7 @@ def resolve_provider_client(
logger.warning("resolve_provider_client: anthropic requested but no Anthropic credentials found")
return None, None
final_model = _normalize_resolved_model(model or default_model, provider)
return (_to_async_client(client, final_model) if async_mode else (client, final_model))
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode else (client, final_model))
creds = resolve_api_key_provider_credentials(provider)
api_key = str(creds.get("api_key", "")).strip()
@ -1981,7 +1992,7 @@ def resolve_provider_client(
if is_native_gemini_base_url(base_url):
client = GeminiNativeClient(api_key=api_key, base_url=base_url)
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
# Provider-specific headers
@ -1989,9 +2000,11 @@ def resolve_provider_client(
if base_url_host_matches(base_url, "api.kimi.com"):
headers["User-Agent"] = "claude-code/0.1.0"
elif base_url_host_matches(base_url, "api.githubcopilot.com"):
from hermes_cli.models import copilot_default_headers
from hermes_cli.copilot_auth import copilot_request_headers
headers.update(copilot_default_headers())
headers.update(copilot_request_headers(
is_agent_turn=True, is_vision=is_vision
))
client = OpenAI(api_key=api_key, base_url=base_url,
**({"default_headers": headers} if headers else {}))
@ -2017,7 +2030,7 @@ def resolve_provider_client(
client = _wrap_if_needed(client, final_model, base_url)
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
if pconfig.auth_type == "external_process":
@ -2049,7 +2062,7 @@ def resolve_provider_client(
args=args,
)
logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
logger.warning("resolve_provider_client: external-process provider %s not "
"directly supported", provider)
@ -2085,7 +2098,7 @@ def resolve_provider_client(
base_url=f"https://bedrock-runtime.{region}.amazonaws.com",
)
logger.debug("resolve_provider_client: bedrock (%s, %s)", final_model, region)
return (_to_async_client(client, final_model) if async_mode
return (_to_async_client(client, final_model, is_vision=is_vision) if async_mode
else (client, final_model))
elif pconfig.auth_type in ("oauth_device_code", "oauth_external"):
@ -2160,8 +2173,13 @@ def _normalize_vision_provider(provider: Optional[str]) -> str:
return _normalize_aux_provider(provider)
def _resolve_strict_vision_backend(provider: str) -> Tuple[Optional[Any], Optional[str]]:
def _resolve_strict_vision_backend(
provider: str,
model: Optional[str] = None,
) -> Tuple[Optional[Any], Optional[str]]:
provider = _normalize_vision_provider(provider)
if provider == "copilot":
return resolve_provider_client("copilot", model, is_vision=True)
if provider == "openrouter":
return _try_openrouter()
if provider == "nous":
@ -2229,7 +2247,7 @@ def resolve_vision_provider_client(
return resolved_provider, None, None
final_model = resolved_model or default_model
if async_mode:
async_client, async_model = _to_async_client(sync_client, final_model)
async_client, async_model = _to_async_client(sync_client, final_model, is_vision=True)
return resolved_provider, async_client, async_model
return resolved_provider, sync_client, final_model
@ -2261,8 +2279,11 @@ def resolve_vision_provider_client(
main_provider = _read_main_provider()
main_model = _read_main_model()
if main_provider and main_provider not in ("auto", ""):
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
if main_provider == "nous":
sync_client, default_model = _resolve_strict_vision_backend(main_provider)
sync_client, default_model = _resolve_strict_vision_backend(
main_provider, vision_model
)
if sync_client is not None:
logger.info(
"Vision auto-detect: using main provider %s (%s)",
@ -2270,10 +2291,10 @@ def resolve_vision_provider_client(
)
return _finalize(main_provider, sync_client, default_model)
else:
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
rpc_client, rpc_model = resolve_provider_client(
main_provider, vision_model,
api_mode=resolved_api_mode)
api_mode=resolved_api_mode,
is_vision=True)
if rpc_client is not None:
logger.info(
"Vision auto-detect: using main provider %s (%s)",
@ -2295,11 +2316,14 @@ def resolve_vision_provider_client(
return None, None, None
if requested in _VISION_AUTO_PROVIDER_ORDER:
sync_client, default_model = _resolve_strict_vision_backend(requested)
sync_client, default_model = _resolve_strict_vision_backend(
requested, resolved_model
)
return _finalize(requested, sync_client, default_model)
client, final_model = _get_cached_client(requested, resolved_model, async_mode,
api_mode=resolved_api_mode)
api_mode=resolved_api_mode,
is_vision=True)
if client is None:
return requested, None, None
return requested, client, final_model
@ -2363,10 +2387,11 @@ def _client_cache_key(
api_key: Optional[str] = None,
api_mode: Optional[str] = None,
main_runtime: Optional[Dict[str, Any]] = None,
is_vision: bool = False,
) -> tuple:
runtime = _normalize_main_runtime(main_runtime)
runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key)
return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key, is_vision)
def _store_cached_client(cache_key: tuple, client: Any, default_model: Optional[str], *, bound_loop: Any = None) -> None:
@ -2392,6 +2417,7 @@ def _refresh_nous_auxiliary_client(
api_key: Optional[str] = None,
api_mode: Optional[str] = None,
main_runtime: Optional[Dict[str, Any]] = None,
is_vision: bool = False,
) -> Tuple[Optional[Any], Optional[str]]:
"""Refresh Nous runtime creds, rebuild the client, and replace the cache entry."""
runtime = _resolve_nous_runtime_api(force_refresh=True)
@ -2409,7 +2435,7 @@ def _refresh_nous_auxiliary_client(
current_loop = _aio.get_event_loop()
except RuntimeError:
pass
client, final_model = _to_async_client(sync_client, final_model or "")
client, final_model = _to_async_client(sync_client, final_model or "", is_vision=is_vision)
else:
client = sync_client
@ -2420,6 +2446,7 @@ def _refresh_nous_auxiliary_client(
api_key=api_key,
api_mode=api_mode,
main_runtime=main_runtime,
is_vision=is_vision,
)
_store_cached_client(cache_key, client, final_model, bound_loop=current_loop)
return client, final_model
@ -2549,6 +2576,7 @@ def _get_cached_client(
api_key: str = None,
api_mode: str = None,
main_runtime: Optional[Dict[str, Any]] = None,
is_vision: bool = False,
) -> Tuple[Optional[Any], Optional[str]]:
"""Get or create a cached client for the given provider.
@ -2585,6 +2613,7 @@ def _get_cached_client(
api_key=api_key,
api_mode=api_mode,
main_runtime=main_runtime,
is_vision=is_vision,
)
with _client_cache_lock:
if cache_key in _client_cache:
@ -2616,6 +2645,7 @@ def _get_cached_client(
explicit_api_key=api_key,
api_mode=api_mode,
main_runtime=runtime,
is_vision=is_vision,
)
if client is not None:
# For async clients, remember which loop they were created on so we
@ -3079,6 +3109,7 @@ def call_llm(
api_key=resolved_api_key,
api_mode=resolved_api_mode,
main_runtime=main_runtime,
is_vision=(task == "vision"),
)
if refreshed_client is not None:
logger.info("Auxiliary %s: refreshed Nous runtime credentials after 401, retrying",
@ -3369,6 +3400,7 @@ async def async_call_llm(
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
is_vision=(task == "vision"),
)
if refreshed_client is not None:
logger.info("Auxiliary %s (async): refreshed Nous runtime credentials after 401, retrying",
@ -3437,7 +3469,9 @@ async def async_call_llm(
extra_body=effective_extra_body,
base_url=str(getattr(fb_client, "base_url", "") or ""))
# Convert sync fallback client to async
async_fb, async_fb_model = _to_async_client(fb_client, fb_model or "")
async_fb, async_fb_model = _to_async_client(
fb_client, fb_model or "", is_vision=(task == "vision")
)
if async_fb_model and async_fb_model != fb_kwargs.get("model"):
fb_kwargs["model"] = async_fb_model
return _validate_llm_response(

View file

@ -199,6 +199,7 @@ class TestResolveVisionMainFirst:
mock_resolve.assert_called_once()
assert mock_resolve.call_args.args[0] == "openrouter"
assert mock_resolve.call_args.args[1] == "anthropic/claude-sonnet-4.6"
assert mock_resolve.call_args.kwargs.get("is_vision") is True
def test_nous_main_vision_uses_paid_nous_vision_backend(self):
"""Paid Nous main → aux vision uses the dedicated Nous vision backend."""
@ -266,6 +267,87 @@ class TestResolveVisionMainFirst:
assert provider == "xiaomi"
# Should use mimo-v2.5 (vision override), not mimo-v2-pro (text main)
assert mock_resolve.call_args.args[1] == "mimo-v2.5"
assert mock_resolve.call_args.kwargs.get("is_vision") is True
def test_copilot_vision_sets_vision_header(self, monkeypatch):
    """Copilot vision requests include the header required for vision routing."""
    monkeypatch.setenv("COPILOT_GITHUB_TOKEN", "ghu_test-token")
    seen = {}

    def record_headers(*, is_agent_turn=False, is_vision=False):
        # Record the keyword flags so we can assert the resolver passed
        # is_vision=True through to the Copilot header builder.
        seen["is_agent_turn"] = is_agent_turn
        seen["is_vision"] = is_vision
        return {"Copilot-Vision-Request": "true"} if is_vision else {}

    # Pin the main provider/model to Copilot, stub credential lookup, and
    # intercept both the OpenAI constructor and the Copilot header builder.
    with patch("agent.auxiliary_client._read_main_provider", return_value="copilot"), \
         patch("agent.auxiliary_client._read_main_model", return_value="configured-copilot-model"), \
         patch("agent.auxiliary_client._resolve_task_provider_model",
               return_value=("auto", None, None, None, None)), \
         patch("agent.auxiliary_client.OpenAI") as openai_ctor, \
         patch("hermes_cli.auth.resolve_api_key_provider_credentials",
               return_value={
                   "provider": "copilot",
                   "api_key": "copilot-api-token",
                   "base_url": "https://api.githubcopilot.com",
               }), \
         patch("hermes_cli.copilot_auth.copilot_request_headers",
               side_effect=record_headers):
        stub_client = MagicMock()
        openai_ctor.return_value = stub_client
        from agent.auxiliary_client import resolve_vision_provider_client
        provider, client, model = resolve_vision_provider_client()

    assert provider == "copilot"
    assert client is stub_client
    assert model == "configured-copilot-model"
    # The header builder must have been invoked for an agent vision turn,
    # and the resulting client must carry the vision routing header.
    assert seen == {"is_agent_turn": True, "is_vision": True}
    assert openai_ctor.call_args.kwargs["default_headers"]["Copilot-Vision-Request"] == "true"
def test_text_copilot_does_not_set_vision_header(self, monkeypatch):
    """Text Copilot requests keep the vision-only header off."""
    monkeypatch.setenv("COPILOT_GITHUB_TOKEN", "ghu_test-token")
    seen = {}

    def record_headers(*, is_agent_turn=False, is_vision=False):
        # Record the flags so the test can verify that a plain text request
        # does NOT ask for the vision header.
        seen["is_agent_turn"] = is_agent_turn
        seen["is_vision"] = is_vision
        return {"Copilot-Vision-Request": "true"} if is_vision else {}

    # Stub credentials and intercept the OpenAI constructor plus the
    # Copilot header builder; no vision flag is passed here.
    with patch("agent.auxiliary_client.OpenAI") as openai_ctor, \
         patch("hermes_cli.auth.resolve_api_key_provider_credentials",
               return_value={
                   "provider": "copilot",
                   "api_key": "copilot-api-token",
                   "base_url": "https://api.githubcopilot.com",
               }), \
         patch("hermes_cli.copilot_auth.copilot_request_headers",
               side_effect=record_headers):
        stub_client = MagicMock()
        openai_ctor.return_value = stub_client
        from agent.auxiliary_client import resolve_provider_client
        client, model = resolve_provider_client("copilot", "gpt-5-mini")

    assert client is stub_client
    assert model == "gpt-5-mini"
    # Text path: header builder called with is_vision=False, and since it
    # returned no headers, none should be forwarded to the client.
    assert seen == {"is_agent_turn": True, "is_vision": False}
    assert "default_headers" not in openai_ctor.call_args.kwargs
def test_main_unavailable_vision_falls_through_to_aggregators(self):
"""Main provider fails → fall back to OpenRouter/Nous strict backends."""
@ -312,7 +394,7 @@ class TestResolveVisionMainFirst:
# Explicit "nous" override → uses strict backend, NOT main model path
assert provider == "nous"
mock_strict.assert_called_once_with("nous")
mock_strict.assert_called_once_with("nous", None)
# ── Constant cleanup ────────────────────────────────────────────────────────

View file

@ -103,7 +103,7 @@ class TestCleanupStaleAsyncClients:
mock_client._client = MagicMock()
mock_client._client.is_closed = False
key = ("test_stale", True, "", "", "", ())
key = ("test_stale", True, "", "", "", (), False)
with _client_cache_lock:
_client_cache[key] = (mock_client, "test-model", loop)
@ -127,7 +127,7 @@ class TestCleanupStaleAsyncClients:
loop = asyncio.new_event_loop() # NOT closed
mock_client = MagicMock()
key = ("test_live", True, "", "", "", ())
key = ("test_live", True, "", "", "", (), False)
with _client_cache_lock:
_client_cache[key] = (mock_client, "test-model", loop)
@ -149,7 +149,7 @@ class TestCleanupStaleAsyncClients:
)
mock_client = MagicMock()
key = ("test_sync", False, "", "", "", ())
key = ("test_sync", False, "", "", "", (), False)
with _client_cache_lock:
_client_cache[key] = (mock_client, "test-model", None)
@ -182,7 +182,7 @@ class TestClientCacheBoundedGrowth:
_get_cached_client,
)
key = ("test_replace", True, "", "", "", ())
key = ("test_replace", True, "", "", "", (), False)
# Simulate a stale entry from a closed loop
old_loop = asyncio.new_event_loop()
@ -217,7 +217,7 @@ class TestClientCacheBoundedGrowth:
_client_cache_lock,
)
key = ("test_no_grow", True, "", "", "", ())
key = ("test_no_grow", True, "", "", "", (), False)
loops = []
try:
@ -269,7 +269,7 @@ class TestClientCacheBoundedGrowth:
mock_client = MagicMock()
mock_client._client = MagicMock()
mock_client._client.is_closed = False
key = (f"evict_test_{i}", False, "", "", "", ())
key = (f"evict_test_{i}", False, "", "", "", (), False)
with _client_cache_lock:
# Inline the eviction logic (same as _get_cached_client)
while len(_client_cache) >= _CLIENT_CACHE_MAX_SIZE:
@ -281,9 +281,9 @@ class TestClientCacheBoundedGrowth:
assert len(_client_cache) <= _CLIENT_CACHE_MAX_SIZE, \
f"Cache size {len(_client_cache)} exceeds max {_CLIENT_CACHE_MAX_SIZE}"
# The earliest entries should have been evicted
assert ("evict_test_0", False, "", "", "", ()) not in _client_cache
assert ("evict_test_0", False, "", "", "", (), False) not in _client_cache
# The latest entries should be present
assert (f"evict_test_{_CLIENT_CACHE_MAX_SIZE + 4}", False, "", "", "", ()) in _client_cache
assert (f"evict_test_{_CLIENT_CACHE_MAX_SIZE + 4}", False, "", "", "", (), False) in _client_cache
finally:
with _client_cache_lock:
_client_cache.clear()