refactor(auth): collapse Nous inference fallback controls

This commit is contained in:
Robin Fernandes 2026-05-17 20:34:39 +10:00 committed by Teknium
parent 89a3d038cf
commit 0bac7dd05b
13 changed files with 1071 additions and 240 deletions

View file

@ -81,6 +81,21 @@ class UpstreamAdapter(ABC):
refresh fails. The proxy will return 401 to the client.
"""
def get_retry_credential(
    self,
    *,
    failed_credential: UpstreamCredential,
    status_code: int,
) -> Optional[UpstreamCredential]:
    """Offer a replacement credential once the upstream has rejected one.

    This base implementation never retries: it ignores both arguments and
    always returns ``None``. Providers may override it to supply a one-shot
    fallback, for example switching from a preferred token type to a legacy
    bearer after the upstream rejects the first request.

    Args:
        failed_credential: The credential the upstream just rejected.
        status_code: The HTTP status the upstream answered with.

    Returns:
        An alternate credential to retry with, or ``None`` for no retry.
    """
    # Both inputs are deliberately unused by the default implementation;
    # binding them to a throwaway name keeps linters quiet.
    _unused = (failed_credential, status_code)
    del _unused
    return None
def describe(self) -> str:
"""One-line status summary for ``proxy status``."""
try:

View file

@ -19,13 +19,16 @@ from typing import Any, Dict, FrozenSet, Optional
from hermes_cli.auth import (
AuthError,
DEFAULT_NOUS_INFERENCE_URL,
NOUS_INFERENCE_AUTH_AUTO,
NOUS_INFERENCE_AUTH_LEGACY,
_load_auth_store,
_is_terminal_nous_refresh_error,
_quarantine_nous_oauth_state,
_quarantine_nous_pool_entries,
_save_auth_store,
_write_shared_nous_state,
refresh_nous_oauth_from_state,
)
)
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
logger = logging.getLogger(__name__)
@ -76,6 +79,21 @@ class NousPortalAdapter(UpstreamAdapter):
)
def get_credential(self) -> UpstreamCredential:
    """Resolve the upstream credential using the ``AUTO`` auth mode.

    Delegates to ``_get_credential`` with
    ``auth_mode=NOUS_INFERENCE_AUTH_AUTO`` and returns its result unchanged.
    """
    credential = self._get_credential(auth_mode=NOUS_INFERENCE_AUTH_AUTO)
    return credential
def get_retry_credential(
    self,
    *,
    failed_credential: UpstreamCredential,
    status_code: int,
) -> Optional[UpstreamCredential]:
    """Return a fallback credential after the upstream answers 401.

    Only a 401 triggers a retry; any other status yields ``None``. On a 401
    the credential is re-resolved through ``_get_credential`` with
    ``auth_mode=NOUS_INFERENCE_AUTH_LEGACY``.

    Args:
        failed_credential: The rejected credential. It is not inspected.
        status_code: The HTTP status the upstream answered with.

    Returns:
        The credential from ``_get_credential`` in legacy mode, or ``None``
        when ``status_code`` is not 401.
    """
    del failed_credential  # the fallback does not depend on what failed
    if status_code == 401:
        logger.info("proxy: Nous upstream rejected bearer; retrying with legacy session key")
        return self._get_credential(auth_mode=NOUS_INFERENCE_AUTH_LEGACY)
    return None
def _get_credential(self, *, auth_mode: str) -> UpstreamCredential:
with self._lock:
state = self._read_state()
if state is None:
@ -84,7 +102,10 @@ class NousPortalAdapter(UpstreamAdapter):
)
try:
refreshed = refresh_nous_oauth_from_state(state)
refreshed = refresh_nous_oauth_from_state(
state,
auth_mode=auth_mode,
)
except AuthError as exc:
if _is_terminal_nous_refresh_error(exc):
_quarantine_nous_oauth_state(
@ -92,7 +113,11 @@ class NousPortalAdapter(UpstreamAdapter):
exc,
reason="proxy_refresh_failure",
)
self._save_state(state)
self._save_state(
state,
quarantine_error=exc,
quarantine_reason="proxy_refresh_failure",
)
raise RuntimeError(
f"Failed to refresh Nous Portal credentials: {exc}"
) from exc
@ -136,9 +161,21 @@ class NousPortalAdapter(UpstreamAdapter):
return None
return dict(state) # copy so the refresh helper can mutate freely
def _save_state(self, state: Dict[str, Any]) -> None:
def _save_state(
self,
state: Dict[str, Any],
*,
quarantine_error: Optional[AuthError] = None,
quarantine_reason: Optional[str] = None,
) -> None:
try:
store = _load_auth_store()
if quarantine_error is not None and quarantine_reason:
_quarantine_nous_pool_entries(
store,
quarantine_error,
reason=quarantine_reason,
)
providers = store.setdefault("providers", {})
providers["nous"] = state
_save_auth_store(store)

View file

@ -26,7 +26,7 @@ except ImportError:
web = None # type: ignore[assignment]
AIOHTTP_AVAILABLE = False
from hermes_cli.proxy.adapters.base import UpstreamAdapter
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
logger = logging.getLogger(__name__)
@ -136,50 +136,93 @@ def create_app(adapter: UpstreamAdapter) -> "web.Application":
logger.warning("proxy: credential resolution failed: %s", exc)
return _json_error(401, str(exc), code="upstream_auth_failed")
upstream_url = f"{cred.base_url.rstrip('/')}{rel_path}"
# Preserve query string verbatim.
if request.query_string:
upstream_url = f"{upstream_url}?{request.query_string}"
# Forward body verbatim. Read into memory once — request bodies for
# chat/completions/embeddings are small (<1MB typically). If we ever
# need to forward large multipart uploads we'll switch to streaming
# the request body too.
body = await request.read()
fwd_headers = _filter_request_headers(request.headers)
fwd_headers["Authorization"] = f"{cred.token_type} {cred.bearer}"
logger.debug(
"proxy: forwarding %s %s -> %s (body=%d bytes)",
request.method, rel_path, upstream_url, len(body),
)
# Use a per-request session so connection state doesn't leak between
# clients. Could be optimized to a shared session later.
timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=300)
try:
session = aiohttp.ClientSession(timeout=timeout)
except Exception as exc: # pragma: no cover - aiohttp setup issue
return _json_error(500, f"proxy session init failed: {exc}")
try:
upstream_resp = await session.request(
request.method,
upstream_url,
data=body if body else None,
headers=fwd_headers,
allow_redirects=False,
async def _send_upstream(active_cred: UpstreamCredential):
upstream_url = f"{active_cred.base_url.rstrip('/')}{rel_path}"
# Preserve query string verbatim.
if request.query_string:
upstream_url = f"{upstream_url}?{request.query_string}"
fwd_headers = _filter_request_headers(request.headers)
fwd_headers["Authorization"] = f"{active_cred.token_type} {active_cred.bearer}"
logger.debug(
"proxy: forwarding %s %s -> %s (body=%d bytes)",
request.method, rel_path, upstream_url, len(body),
)
except aiohttp.ClientError as exc:
await session.close()
logger.warning("proxy: upstream connection failed: %s", exc)
return _json_error(502, f"upstream connection failed: {exc}",
code="upstream_unreachable")
except asyncio.TimeoutError:
await session.close()
return _json_error(504, "upstream request timed out",
code="upstream_timeout")
try:
session = aiohttp.ClientSession(timeout=timeout)
except Exception as exc: # pragma: no cover - aiohttp setup issue
raise RuntimeError(f"proxy session init failed: {exc}") from exc
try:
upstream_resp = await session.request(
request.method,
upstream_url,
data=body if body else None,
headers=fwd_headers,
allow_redirects=False,
)
except Exception:
await session.close()
raise
return session, upstream_resp
async def _open_upstream(active_cred: UpstreamCredential):
    """Call ``_send_upstream`` and turn its failures into JSON error responses.

    Returns a 2-tuple. On success it is exactly what ``_send_upstream``
    returned. On failure it is ``(error_response, None)``, so the caller can
    detect the failure by checking the second element for ``None``.

    Failure mapping:
        * ``RuntimeError``         -> 500 with the exception text
        * ``aiohttp.ClientError``  -> 502, code ``upstream_unreachable``
        * ``asyncio.TimeoutError`` -> 504, code ``upstream_timeout``
    """
    try:
        opened = await _send_upstream(active_cred)
    except RuntimeError as exc:
        failure = _json_error(500, str(exc))
    except aiohttp.ClientError as exc:
        logger.warning("proxy: upstream connection failed: %s", exc)
        failure = _json_error(
            502,
            f"upstream connection failed: {exc}",
            code="upstream_unreachable",
        )
    except asyncio.TimeoutError:
        failure = _json_error(
            504,
            "upstream request timed out",
            code="upstream_timeout",
        )
    else:
        # Success: hand back the (session, response) pair untouched.
        return opened
    return failure, None
session_or_response, upstream_resp = await _open_upstream(cred)
if upstream_resp is None:
return session_or_response
session = session_or_response
if upstream_resp.status == 401:
try:
retry_cred = adapter.get_retry_credential(
failed_credential=cred,
status_code=upstream_resp.status,
)
except Exception as exc:
logger.warning("proxy: retry credential resolution failed: %s", exc)
retry_cred = None
if retry_cred is not None:
upstream_resp.release()
await session.close()
session_or_response, upstream_resp = await _open_upstream(retry_cred)
if upstream_resp is None:
return session_or_response
session = session_or_response
# Stream response back. Headers first, then chunked body.
resp = web.StreamResponse(