fix(auth): sync shared Nous refresh tokens

This commit is contained in:
Michael Nguyen 2026-05-05 23:36:09 +07:00 committed by Teknium
parent 38b1c7dce5
commit a84e56d4c6
2 changed files with 362 additions and 161 deletions

View file

@ -2769,6 +2769,7 @@ def _poll_for_token(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
NOUS_SHARED_STORE_FILENAME = "nous_auth.json" NOUS_SHARED_STORE_FILENAME = "nous_auth.json"
# Thread-local bookkeeping for _nous_shared_store_lock(): the context manager
# stores a ``depth`` attribute here that is above zero while the current
# thread is inside it, so nested entries on the same thread only bump the
# counter instead of trying to acquire the lock a second time.
_nous_shared_lock_holder = threading.local()
def _nous_shared_auth_dir() -> Path: def _nous_shared_auth_dir() -> Path:
@ -2808,6 +2809,100 @@ def _nous_shared_store_path() -> Path:
return path return path
@contextmanager
def _nous_shared_store_lock(timeout_seconds: float = AUTH_LOCK_TIMEOUT_SECONDS):
    """Serialize access to the shared Nous OAuth store across profiles.

    Takes an exclusive, non-blocking lock on a ``.lock`` file derived from
    the shared store path and retries until it succeeds or the deadline
    passes.  ``fcntl.flock`` is used when ``fcntl`` is available, otherwise
    ``msvcrt.locking`` on the first byte of the lock file.

    The lock is re-entrant per thread: while the thread-local ``depth``
    counter is above zero, nested entries only adjust the counter.

    Two degraded modes yield without taking any file lock:
      * ``_nous_shared_store_path()`` raises ``RuntimeError``;
      * neither ``fcntl`` nor ``msvcrt`` is available (the depth counter is
        still maintained so nesting behaves consistently).

    Args:
        timeout_seconds: How long to keep retrying acquisition.  Values
            below one second are raised to one second.

    Raises:
        TimeoutError: The lock could not be acquired before the deadline.
    """
    # Re-entrant fast path: this thread is already inside the lock.
    if getattr(_nous_shared_lock_holder, "depth", 0) > 0:
        _nous_shared_lock_holder.depth += 1
        try:
            yield
        finally:
            _nous_shared_lock_holder.depth -= 1
        return

    try:
        lock_path = _nous_shared_store_path().with_suffix(".lock")
    except RuntimeError:
        # No usable shared store location: run the body unlocked.
        yield
        return

    lock_path.parent.mkdir(parents=True, exist_ok=True)

    if fcntl is None and msvcrt is None:
        # No locking primitive on this platform; only track nesting depth.
        _nous_shared_lock_holder.depth = 1
        try:
            yield
        finally:
            _nous_shared_lock_holder.depth = 0
        return

    # Seed the file with one byte on the msvcrt path -- presumably so the
    # one-byte region locked below is backed by real file content.
    if msvcrt and (not lock_path.exists() or lock_path.stat().st_size == 0):
        lock_path.write_text(" ", encoding="utf-8")

    with lock_path.open("r+" if msvcrt else "a+") as lock_file:
        # Use a monotonic clock so wall-clock adjustments (NTP, DST, manual
        # changes) cannot shorten or extend the acquisition timeout.
        deadline = time.monotonic() + max(1.0, timeout_seconds)
        while True:
            try:
                if fcntl:
                    fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                else:
                    lock_file.seek(0)
                    msvcrt.locking(lock_file.fileno(), msvcrt.LK_NBLCK, 1)
                break
            except OSError as exc:
                # BlockingIOError and PermissionError are OSError subclasses,
                # so this single clause covers every contention signal the
                # original tuple listed.
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        "Timed out waiting for shared Nous auth lock"
                    ) from exc
                time.sleep(0.05)

        _nous_shared_lock_holder.depth = 1
        try:
            yield
        finally:
            # Reset the depth first so the thread never looks "locked" after
            # the file lock has been dropped.
            _nous_shared_lock_holder.depth = 0
            if fcntl:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
            elif msvcrt:
                try:
                    lock_file.seek(0)
                    msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
                except OSError:
                    # Best-effort unlock; the file is closed right after.
                    pass
def _merge_shared_nous_oauth_state(state: Dict[str, Any]) -> bool:
    """Pull OAuth fields from the shared Nous store into ``state`` in place.

    The shared entry is adopted when it carries a non-blank refresh token
    and either that token differs from the local one or the shared
    ``expires_at`` is later than the local ``expires_at`` (unparseable or
    missing timestamps count as ``0.0``).

    Args:
        state: Profile-local Nous auth state; mutated when a merge happens.

    Returns:
        ``True`` if shared values were copied into ``state``, else ``False``.
    """
    shared_snapshot = _read_shared_nous_state()
    if not shared_snapshot:
        return False

    candidate_refresh = shared_snapshot.get("refresh_token")
    if not (isinstance(candidate_refresh, str) and candidate_refresh.strip()):
        return False

    current_refresh = str(state.get("refresh_token") or "").strip()
    shared_expiry = _parse_iso_timestamp(shared_snapshot.get("expires_at")) or 0.0
    local_expiry = _parse_iso_timestamp(state.get("expires_at")) or 0.0

    # NOTE(review): a differing refresh token is adopted even when the shared
    # expiry is older than the local one -- confirm the shared store can
    # never lag behind a profile that has already rotated its token.
    same_refresh = candidate_refresh.strip() == current_refresh
    if same_refresh and not (shared_expiry > local_expiry):
        return False

    copied_fields = (
        "access_token",
        "refresh_token",
        "token_type",
        "scope",
        "client_id",
        "portal_base_url",
        "inference_base_url",
        "obtained_at",
        "expires_at",
    )
    for field in copied_fields:
        incoming = shared_snapshot.get(field)
        if incoming in (None, ""):
            # Never blank out a local value with a missing shared one.
            continue
        state[field] = incoming
    return True
def _write_shared_nous_state(state: Dict[str, Any]) -> None: def _write_shared_nous_state(state: Dict[str, Any]) -> None:
"""Persist a minimal copy of the Nous OAuth state to the shared store. """Persist a minimal copy of the Nous OAuth state to the shared store.
@ -2840,6 +2935,7 @@ def _write_shared_nous_state(state: Dict[str, Any]) -> None:
"updated_at": datetime.now(timezone.utc).isoformat(), "updated_at": datetime.now(timezone.utc).isoformat(),
} }
try: try:
with _nous_shared_store_lock():
path = _nous_shared_store_path() path = _nous_shared_store_path()
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp") tmp = path.with_suffix(path.suffix + ".tmp")
@ -2905,6 +3001,8 @@ def _try_import_shared_nous_state(
etc.) caller should then fall through to the normal device-code etc.) caller should then fall through to the normal device-code
flow. flow.
""" """
try:
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
shared = _read_shared_nous_state() shared = _read_shared_nous_state()
if not shared: if not shared:
return None return None
@ -2927,7 +3025,6 @@ def _try_import_shared_nous_state(
"tls": {"insecure": False, "ca_bundle": None}, "tls": {"insecure": False, "ca_bundle": None},
} }
try:
refreshed = refresh_nous_oauth_from_state( refreshed = refresh_nous_oauth_from_state(
state, state,
min_key_ttl_seconds=min_key_ttl_seconds, min_key_ttl_seconds=min_key_ttl_seconds,
@ -2935,6 +3032,7 @@ def _try_import_shared_nous_state(
force_refresh=True, force_refresh=True,
force_mint=True, force_mint=True,
) )
_write_shared_nous_state(refreshed)
except AuthError as exc: except AuthError as exc:
_oauth_trace( _oauth_trace(
"nous_shared_import_failed", "nous_shared_import_failed",
@ -3136,6 +3234,8 @@ def resolve_nous_access_token(
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID) client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state) verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state)
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
merged_shared = _merge_shared_nous_oauth_state(state)
access_token = state.get("access_token") access_token = state.get("access_token")
refresh_token = state.get("refresh_token") refresh_token = state.get("refresh_token")
if not isinstance(access_token, str) or not access_token: if not isinstance(access_token, str) or not access_token:
@ -3146,6 +3246,9 @@ def resolve_nous_access_token(
) )
if not _is_expiring(state.get("expires_at"), refresh_skew_seconds): if not _is_expiring(state.get("expires_at"), refresh_skew_seconds):
if merged_shared:
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
return access_token return access_token
if not isinstance(refresh_token, str) or not refresh_token: if not isinstance(refresh_token, str) or not refresh_token:
@ -3188,6 +3291,7 @@ def resolve_nous_access_token(
} }
_save_provider_state(auth_store, "nous", state) _save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store) _save_auth_store(auth_store)
_write_shared_nous_state(state)
return state["access_token"] return state["access_token"]
@ -3455,6 +3559,13 @@ def resolve_nous_runtime_credentials(
provider="nous", relogin_required=True) provider="nous", relogin_required=True)
# Step 1: refresh access token if expiring # Step 1: refresh access token if expiring
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
if _merge_shared_nous_oauth_state(state):
access_token = state.get("access_token")
refresh_token = state.get("refresh_token")
_persist_state("post_shared_merge_access_expiring")
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS): if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
if not isinstance(refresh_token, str) or not refresh_token: if not isinstance(refresh_token, str) or not refresh_token:
raise AuthError("Session expired and no refresh token is available.", raise AuthError("Session expired and no refresh token is available.",
@ -3528,6 +3639,12 @@ def resolve_nous_runtime_credentials(
and isinstance(latest_refresh_token, str) and isinstance(latest_refresh_token, str)
and latest_refresh_token and latest_refresh_token
): ):
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
if _merge_shared_nous_oauth_state(state):
access_token = state.get("access_token")
latest_refresh_token = state.get("refresh_token")
_persist_state("post_shared_merge_mint_retry")
else:
_oauth_trace( _oauth_trace(
"refresh_start", "refresh_start",
sequence_id=sequence_id, sequence_id=sequence_id,

View file

@ -1179,3 +1179,87 @@ def test_shared_store_survives_across_profile_switch(
shared_after = auth_mod._read_shared_nous_state() shared_after = auth_mod._read_shared_nous_state()
assert shared_after is not None assert shared_after is not None
assert shared_after["refresh_token"] == "b-refresh-tok" assert shared_after["refresh_token"] == "b-refresh-tok"
def test_runtime_refresh_uses_newer_shared_token_before_local_stale_token(
    tmp_path, monkeypatch, shared_store_env,
):
    """Runtime resolution must prefer a rotated shared token over a stale local one.

    Nous refresh tokens are single-use, and a sibling profile may already
    have rotated the shared one.  If this profile wakes up with an expired
    local token and refreshes with it anyway, the portal can treat that as
    token reuse and revoke the whole shared session -- so the shared token
    has to be adopted first and no refresh call may be made.
    """
    from hermes_cli import auth as auth_mod

    stale_home = tmp_path / "profile_b"
    _setup_nous_auth(
        stale_home,
        access_token="local-expired-access",
        refresh_token="local-stale-refresh",
    )
    monkeypatch.setenv("HERMES_HOME", str(stale_home))

    # Publish a far-future token pair as if a sibling profile had rotated it.
    rotated = _full_state_fixture()
    rotated.update(
        access_token="shared-fresh-access",
        refresh_token="shared-fresh-refresh",
        expires_at="2099-01-01T00:00:00+00:00",
    )
    auth_mod._write_shared_nous_state(rotated)

    seen_access_tokens: list[str] = []

    def _forbid_refresh(**_kwargs):
        raise AssertionError("stale profile-local refresh token was used")

    def _record_mint(*, client, portal_base_url, access_token, min_ttl_seconds):
        seen_access_tokens.append(access_token)
        return _mint_payload(api_key="agent-key-from-shared-token")

    monkeypatch.setattr(auth_mod, "_refresh_access_token", _forbid_refresh)
    monkeypatch.setattr(auth_mod, "_mint_agent_key", _record_mint)

    resolved = auth_mod.resolve_nous_runtime_credentials(
        min_key_ttl_seconds=300,
        force_mint=True,
    )

    assert resolved["api_key"] == "agent-key-from-shared-token"
    assert seen_access_tokens == ["shared-fresh-access"]

    # The adopted tokens must also have been persisted to the profile.
    persisted = auth_mod.get_provider_auth_state("nous")
    assert persisted is not None
    assert persisted["refresh_token"] == "shared-fresh-refresh"
    assert persisted["access_token"] == "shared-fresh-access"
def test_managed_gateway_access_token_uses_newer_shared_token(
    tmp_path, monkeypatch, shared_store_env,
):
    """Access-token resolution for managed tools must also adopt the shared token.

    Same stale-refresh-token hazard as the runtime-credentials path: the
    fresher shared token has to win and no refresh call may be issued.
    """
    from hermes_cli import auth as auth_mod

    stale_home = tmp_path / "profile_b"
    _setup_nous_auth(
        stale_home,
        access_token="local-expired-access",
        refresh_token="local-stale-refresh",
    )
    monkeypatch.setenv("HERMES_HOME", str(stale_home))

    # Publish a far-future token pair as if a sibling profile had rotated it.
    rotated = _full_state_fixture()
    rotated.update(
        access_token="shared-fresh-access",
        refresh_token="shared-fresh-refresh",
        expires_at="2099-01-01T00:00:00+00:00",
    )
    auth_mod._write_shared_nous_state(rotated)

    def _forbid_refresh(**_kwargs):
        raise AssertionError("stale profile-local refresh token was used")

    monkeypatch.setattr(auth_mod, "_refresh_access_token", _forbid_refresh)

    resolved_token = auth_mod.resolve_nous_access_token()
    assert resolved_token == "shared-fresh-access"

    # The adopted refresh token must also have been persisted to the profile.
    persisted = auth_mod.get_provider_auth_state("nous")
    assert persisted is not None
    assert persisted["refresh_token"] == "shared-fresh-refresh"