fix(auth): refresh xAI OAuth tokens earlier

This commit is contained in:
Veritas-7 2026-06-15 20:25:57 +09:00 committed by Teknium
parent aab2e99bae
commit febdddb41a
3 changed files with 79 additions and 31 deletions

View file

@ -103,7 +103,12 @@ XAI_OAUTH_SCOPE = "openid profile email offline_access grok-cli:access api:acces
XAI_OAUTH_REDIRECT_HOST = "127.0.0.1"
XAI_OAUTH_REDIRECT_PORT = 56121
XAI_OAUTH_REDIRECT_PATH = "/callback"
XAI_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
# xAI/Grok OAuth access tokens are intentionally short-lived (about 6h in
# current SuperGrok flows). A two-minute refresh window is too narrow for
# gateway/cron workloads that may only touch the provider every 30 minutes,
# leaving brief but noisy credential-expiry gaps. Refresh up to one hour
# early so ordinary runtime calls keep the token warm without user reauth.
XAI_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 3600
QWEN_OAUTH_CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"
QWEN_OAUTH_TOKEN_URL = "https://chat.qwen.ai/api/v1/oauth2/token"
QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120

View file

@ -152,7 +152,7 @@ def test_xai_access_token_is_expiring_returns_true_for_expired_jwt():
def test_xai_access_token_is_expiring_returns_false_for_fresh_jwt():
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
assert _xai_access_token_is_expiring(fresh, 0) is False
@ -476,7 +476,7 @@ def test_read_xai_oauth_tokens_missing_refresh_token(tmp_path, monkeypatch):
def test_resolve_xai_runtime_credentials_returns_singleton_state(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.delenv("HERMES_XAI_BASE_URL", raising=False)
@ -501,7 +501,7 @@ def test_resolve_xai_runtime_credentials_refreshes_expiring_token(tmp_path, monk
)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
new_access = _jwt_with_exp(int(time.time()) + 3600)
new_access = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
called = {"count": 0}
def _fake_refresh(tokens, **kwargs):
@ -520,7 +520,7 @@ def test_resolve_xai_runtime_credentials_refreshes_expiring_token(tmp_path, monk
def test_resolve_xai_runtime_credentials_force_refresh(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(
hermes_home,
access_token=fresh,
@ -546,7 +546,7 @@ def test_resolve_xai_runtime_credentials_force_refresh(tmp_path, monkeypatch):
def test_resolve_xai_runtime_credentials_honours_env_base_url(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.setenv("HERMES_XAI_BASE_URL", "https://custom.x.ai/v1/")
@ -669,7 +669,7 @@ def test_resolve_xai_runtime_credentials_rejects_off_origin_env_base_url(tmp_pat
# the resolver MUST silently fall back to the default rather than ship
# the OAuth bearer to the attacker.
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.setenv("XAI_BASE_URL", "https://attacker.example/v1")
@ -807,7 +807,7 @@ def test_resolve_credentials_does_not_quarantine_on_transient_refresh_failure(
def test_get_xai_oauth_auth_status_logged_in_via_singleton(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@ -910,7 +910,7 @@ def test_format_auth_error_tier_denied_does_not_suggest_relogin():
def test_refresh_xai_oauth_pure_returns_updated_tokens(monkeypatch):
new_access = _jwt_with_exp(int(time.time()) + 3600)
new_access = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
response = _StubHTTPResponse(
200,
{
@ -941,7 +941,7 @@ def test_refresh_xai_oauth_pure_returns_updated_tokens(monkeypatch):
def test_refresh_xai_oauth_pure_keeps_refresh_token_when_response_omits_it(monkeypatch):
"""Some OAuth providers don't rotate refresh tokens — preserve the old one."""
new_access = _jwt_with_exp(int(time.time()) + 3600)
new_access = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
response = _StubHTTPResponse(
200,
{
@ -1080,7 +1080,7 @@ def test_refresh_xai_oauth_pure_accepts_apex_and_subdomain_endpoints(monkeypatch
``*.x.ai`` subdomain (e.g. ``auth.x.ai`` today, future migrations to
``accounts.x.ai`` etc.). Without subdomain support we'd lock the
integration to whatever xAI happens to use today."""
new_access = _jwt_with_exp(int(time.time()) + 3600)
new_access = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
response = _StubHTTPResponse(
200,
{"access_token": new_access, "expires_in": 3600, "token_type": "Bearer"},
@ -1172,7 +1172,7 @@ def test_credential_pool_seeds_xai_oauth_from_singleton(tmp_path, monkeypatch):
from agent.credential_pool import load_pool
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh, refresh_token="rt-1")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@ -1214,7 +1214,7 @@ def test_credential_pool_seed_respects_suppression(tmp_path, monkeypatch):
from agent.credential_pool import load_pool
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@ -1246,7 +1246,7 @@ def test_auth_remove_xai_oauth_clears_singleton_and_sticks(tmp_path, monkeypatch
from types import SimpleNamespace
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh, refresh_token="rt-1")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@ -1294,7 +1294,7 @@ def test_pool_sync_back_writes_to_singleton(tmp_path, monkeypatch):
_setup_hermes_auth(hermes_home, access_token=expired, refresh_token="rt-old")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
new_access = _jwt_with_exp(int(time.time()) + 3600)
new_access = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
def _fake_refresh(access_token, refresh_token, **kwargs):
assert refresh_token == "rt-old"
@ -1334,7 +1334,7 @@ def test_runtime_provider_uses_pool_entry_for_xai_oauth(tmp_path, monkeypatch):
from hermes_cli.runtime_provider import resolve_runtime_provider
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.delenv("HERMES_XAI_BASE_URL", raising=False)
@ -1360,7 +1360,7 @@ def test_runtime_provider_default_base_url_when_pool_entry_missing_url(tmp_path,
monkeypatch.delenv("HERMES_XAI_BASE_URL", raising=False)
monkeypatch.delenv("XAI_BASE_URL", raising=False)
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
pool = load_pool("xai-oauth")
pool.add_entry(
PooledCredential(
@ -1404,7 +1404,7 @@ def test_pool_entry_needs_refresh_when_jwt_within_skew(tmp_path, monkeypatch):
(hermes_home / "auth.json").write_text(json.dumps({"version": 1, "providers": {}}))
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
# Token expires in 30s — well inside the 120s skew window.
# Token expires in 30s — well inside the proactive refresh skew window.
near_expiry = _jwt_with_exp(int(time.time()) + 30)
pool = load_pool("xai-oauth")
entry = PooledCredential(
@ -1433,7 +1433,7 @@ def test_pool_entry_no_refresh_for_fresh_jwt(tmp_path, monkeypatch):
(hermes_home / "auth.json").write_text(json.dumps({"version": 1, "providers": {}}))
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
pool = load_pool("xai-oauth")
entry = PooledCredential(
provider="xai-oauth",
@ -1463,7 +1463,7 @@ def test_pool_select_proactively_refreshes_expiring_token(tmp_path, monkeypatch)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
near_expiry = _jwt_with_exp(int(time.time()) + 30)
new_access = _jwt_with_exp(int(time.time()) + 3600)
new_access = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
refresh_calls = {"count": 0}
@ -1520,7 +1520,7 @@ def test_pool_try_refresh_current_handles_xai_oauth(tmp_path, monkeypatch):
# We simulate the scenario where the server rejected the token (401)
# despite client-side expiry math saying it's still valid (e.g. clock
# skew, server-side revocation, token bound to a session that expired).
seemingly_fresh = _jwt_with_exp(int(time.time()) + 3600)
seemingly_fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
new_access = _jwt_with_exp(int(time.time()) + 7200)
def _fake_refresh(access_token, refresh_token, **kwargs):
@ -1577,7 +1577,7 @@ def test_pool_refresh_marks_entry_exhausted_on_failure(tmp_path, monkeypatch):
monkeypatch.setattr("hermes_cli.auth.refresh_xai_oauth_pure", _fake_refresh_fail)
pool = load_pool("xai-oauth")
seemingly_fresh = _jwt_with_exp(int(time.time()) + 3600)
seemingly_fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
pool.add_entry(
PooledCredential(
provider="xai-oauth",
@ -1609,7 +1609,7 @@ def test_pool_seeded_entry_sync_back_after_refresh(tmp_path, monkeypatch):
_setup_hermes_auth(hermes_home, access_token=near_expiry, refresh_token="rt-singleton")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
new_access = _jwt_with_exp(int(time.time()) + 3600)
new_access = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
def _fake_refresh(access_token, refresh_token, **kwargs):
assert refresh_token == "rt-singleton"
@ -1658,7 +1658,7 @@ def test_pool_refresh_adopts_singleton_tokens_when_consumed_elsewhere(tmp_path,
# Now simulate "another process refreshed the tokens" by overwriting
# the singleton on disk WITHOUT touching this process's pool object.
other_process_at = _jwt_with_exp(int(time.time()) + 3600)
other_process_at = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
raw = json.loads((hermes_home / "auth.json").read_text())
raw["providers"]["xai-oauth"]["tokens"] = {
"access_token": other_process_at,
@ -1708,7 +1708,7 @@ def test_pool_refresh_recovers_when_other_process_already_refreshed(tmp_path, mo
pool = load_pool("xai-oauth")
other_process_at = _jwt_with_exp(int(time.time()) + 3600)
other_process_at = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
def _fake_refresh(access_token, refresh_token, **kwargs):
# Simulate the racing process winning at the auth server right
@ -1750,7 +1750,7 @@ def test_pool_exhausted_xai_entry_recovers_after_singleton_refresh(tmp_path, mon
from dataclasses import replace
hermes_home = tmp_path / "hermes"
stale_at = _jwt_with_exp(int(time.time()) + 3600)
stale_at = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=stale_at, refresh_token="rt-stale")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@ -1804,7 +1804,7 @@ def test_pool_manual_xai_entry_not_synced_from_singleton(tmp_path, monkeypatch):
import uuid
hermes_home = tmp_path / "hermes"
singleton_at = _jwt_with_exp(int(time.time()) + 3600)
singleton_at = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=singleton_at, refresh_token="rt-singleton")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@ -1842,7 +1842,7 @@ def test_pool_manual_entry_does_not_sync_back_to_singleton(tmp_path, monkeypatch
hermes_home = tmp_path / "hermes"
# Singleton has its own tokens (separate login).
singleton_at = _jwt_with_exp(int(time.time()) + 3600)
singleton_at = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=singleton_at, refresh_token="rt-singleton")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@ -1911,7 +1911,7 @@ def test_auxiliary_client_routes_xai_oauth_through_responses_api(tmp_path, monke
)
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
monkeypatch.delenv("HERMES_XAI_BASE_URL", raising=False)
@ -1955,7 +1955,7 @@ def test_auxiliary_client_xai_oauth_requires_explicit_model(tmp_path, monkeypatc
from agent.auxiliary_client import resolve_provider_client
hermes_home = tmp_path / "hermes"
fresh = _jwt_with_exp(int(time.time()) + 3600)
fresh = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
_setup_hermes_auth(hermes_home, access_token=fresh)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
@ -1991,7 +1991,7 @@ def test_pool_sync_back_preserves_active_provider(tmp_path, monkeypatch):
raw["active_provider"] = "openrouter"
(hermes_home / "auth.json").write_text(json.dumps(raw))
new_access = _jwt_with_exp(int(time.time()) + 3600)
new_access = _jwt_with_exp(int(time.time()) + 2 * 60 * 60)
def _fake_refresh(access_token, refresh_token, **kwargs):
return {

View file

@ -0,0 +1,43 @@
from __future__ import annotations
import base64
import json
import time
from hermes_cli import auth
def _jwt_with_exp(exp: int) -> str:
    """Return an unsigned, JWT-shaped string whose payload is ``{"exp": exp}``.

    The result has three dot-separated parts: a header segment encoding
    ``{"alg": "none"}``, a payload segment encoding the ``exp`` claim, and
    the literal placeholder ``sig`` in the signature position.

    Args:
        exp: Value stored under the ``exp`` key of the payload.

    Returns:
        The ``header.payload.sig`` string.
    """

    def _segment(claims: dict) -> str:
        # URL-safe base64 of the JSON text, with trailing "=" padding removed.
        raw = json.dumps(claims).encode()
        return base64.urlsafe_b64encode(raw).decode().rstrip("=")

    return ".".join((_segment({"alg": "none"}), _segment({"exp": exp}), "sig"))
def test_xai_oauth_refresh_skew_is_one_hour() -> None:
    """Pin the xAI proactive-refresh skew constant at 3600 seconds."""
    one_hour_seconds = 60 * 60
    configured_skew = auth.XAI_ACCESS_TOKEN_REFRESH_SKEW_SECONDS
    assert configured_skew == one_hour_seconds
def test_xai_oauth_token_expiring_uses_one_hour_skew() -> None:
    """A token whose ``exp`` is 30 minutes away is reported as expiring.

    The check is made with the module's configured refresh skew rather than
    a hard-coded number.
    """
    expires_at = int(time.time()) + 30 * 60
    skew_seconds = auth.XAI_ACCESS_TOKEN_REFRESH_SKEW_SECONDS
    is_expiring = auth._xai_access_token_is_expiring(
        _jwt_with_exp(expires_at),
        skew_seconds,
    )
    assert is_expiring
def test_xai_oauth_token_not_expiring_beyond_one_hour_skew() -> None:
    """A token whose ``exp`` is 90 minutes away is not reported as expiring.

    The check is made with the module's configured refresh skew rather than
    a hard-coded number.
    """
    expires_at = int(time.time()) + 90 * 60
    skew_seconds = auth.XAI_ACCESS_TOKEN_REFRESH_SKEW_SECONDS
    is_expiring = auth._xai_access_token_is_expiring(
        _jwt_with_exp(expires_at),
        skew_seconds,
    )
    assert not is_expiring