fix(auth): sync shared Nous refresh tokens

This commit is contained in:
Michael Nguyen 2026-05-05 23:36:09 +07:00 committed by Teknium
parent 38b1c7dce5
commit a84e56d4c6
2 changed files with 362 additions and 161 deletions

View file

@ -2769,6 +2769,7 @@ def _poll_for_token(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
NOUS_SHARED_STORE_FILENAME = "nous_auth.json" NOUS_SHARED_STORE_FILENAME = "nous_auth.json"
_nous_shared_lock_holder = threading.local()
def _nous_shared_auth_dir() -> Path: def _nous_shared_auth_dir() -> Path:
@ -2808,6 +2809,100 @@ def _nous_shared_store_path() -> Path:
return path return path
@contextmanager
def _nous_shared_store_lock(timeout_seconds: float = AUTH_LOCK_TIMEOUT_SECONDS):
"""Cross-profile lock for the shared Nous OAuth store."""
if getattr(_nous_shared_lock_holder, "depth", 0) > 0:
_nous_shared_lock_holder.depth += 1
try:
yield
finally:
_nous_shared_lock_holder.depth -= 1
return
try:
lock_path = _nous_shared_store_path().with_suffix(".lock")
except RuntimeError:
yield
return
lock_path.parent.mkdir(parents=True, exist_ok=True)
if fcntl is None and msvcrt is None:
_nous_shared_lock_holder.depth = 1
try:
yield
finally:
_nous_shared_lock_holder.depth = 0
return
if msvcrt and (not lock_path.exists() or lock_path.stat().st_size == 0):
lock_path.write_text(" ", encoding="utf-8")
with lock_path.open("r+" if msvcrt else "a+") as lock_file:
deadline = time.time() + max(1.0, timeout_seconds)
while True:
try:
if fcntl:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
else:
lock_file.seek(0)
msvcrt.locking(lock_file.fileno(), msvcrt.LK_NBLCK, 1)
break
except (BlockingIOError, OSError, PermissionError):
if time.time() >= deadline:
raise TimeoutError("Timed out waiting for shared Nous auth lock")
time.sleep(0.05)
_nous_shared_lock_holder.depth = 1
try:
yield
finally:
_nous_shared_lock_holder.depth = 0
if fcntl:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
elif msvcrt:
try:
lock_file.seek(0)
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
except (OSError, IOError):
pass
def _merge_shared_nous_oauth_state(state: Dict[str, Any]) -> bool:
"""Copy fresher shared OAuth tokens into a profile-local Nous state."""
shared = _read_shared_nous_state()
if not shared:
return False
shared_refresh = shared.get("refresh_token")
if not isinstance(shared_refresh, str) or not shared_refresh.strip():
return False
local_refresh = state.get("refresh_token")
shared_access_exp = _parse_iso_timestamp(shared.get("expires_at")) or 0.0
local_access_exp = _parse_iso_timestamp(state.get("expires_at")) or 0.0
refresh_changed = shared_refresh.strip() != str(local_refresh or "").strip()
fresher_access = shared_access_exp > local_access_exp
if not refresh_changed and not fresher_access:
return False
for key in (
"access_token",
"refresh_token",
"token_type",
"scope",
"client_id",
"portal_base_url",
"inference_base_url",
"obtained_at",
"expires_at",
):
value = shared.get(key)
if value not in (None, ""):
state[key] = value
return True
def _write_shared_nous_state(state: Dict[str, Any]) -> None: def _write_shared_nous_state(state: Dict[str, Any]) -> None:
"""Persist a minimal copy of the Nous OAuth state to the shared store. """Persist a minimal copy of the Nous OAuth state to the shared store.
@ -2840,15 +2935,16 @@ def _write_shared_nous_state(state: Dict[str, Any]) -> None:
"updated_at": datetime.now(timezone.utc).isoformat(), "updated_at": datetime.now(timezone.utc).isoformat(),
} }
try: try:
path = _nous_shared_store_path() with _nous_shared_store_lock():
path.parent.mkdir(parents=True, exist_ok=True) path = _nous_shared_store_path()
tmp = path.with_suffix(path.suffix + ".tmp") path.parent.mkdir(parents=True, exist_ok=True)
tmp.write_text(json.dumps(shared, indent=2, sort_keys=True)) tmp = path.with_suffix(path.suffix + ".tmp")
try: tmp.write_text(json.dumps(shared, indent=2, sort_keys=True))
os.chmod(tmp, 0o600) try:
except OSError: os.chmod(tmp, 0o600)
pass except OSError:
os.replace(tmp, path) pass
os.replace(tmp, path)
_oauth_trace( _oauth_trace(
"nous_shared_store_written", "nous_shared_store_written",
path=str(path), path=str(path),
@ -2905,36 +3001,38 @@ def _try_import_shared_nous_state(
etc.) caller should then fall through to the normal device-code etc.) caller should then fall through to the normal device-code
flow. flow.
""" """
shared = _read_shared_nous_state()
if not shared:
return None
# Build a full state dict so refresh_nous_oauth_from_state has every
# field it needs. force_refresh=True gets us a fresh access_token
# for this profile; force_mint=True gets us a fresh agent_key.
state: Dict[str, Any] = {
"access_token": shared.get("access_token"),
"refresh_token": shared.get("refresh_token"),
"client_id": shared.get("client_id") or DEFAULT_NOUS_CLIENT_ID,
"portal_base_url": shared.get("portal_base_url") or DEFAULT_NOUS_PORTAL_URL,
"inference_base_url": shared.get("inference_base_url") or DEFAULT_NOUS_INFERENCE_URL,
"token_type": shared.get("token_type") or "Bearer",
"scope": shared.get("scope") or DEFAULT_NOUS_SCOPE,
"obtained_at": shared.get("obtained_at"),
"expires_at": shared.get("expires_at"),
"agent_key": None,
"agent_key_expires_at": None,
"tls": {"insecure": False, "ca_bundle": None},
}
try: try:
refreshed = refresh_nous_oauth_from_state( with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
state, shared = _read_shared_nous_state()
min_key_ttl_seconds=min_key_ttl_seconds, if not shared:
timeout_seconds=timeout_seconds, return None
force_refresh=True,
force_mint=True, # Build a full state dict so refresh_nous_oauth_from_state has every
) # field it needs. force_refresh=True gets us a fresh access_token
# for this profile; force_mint=True gets us a fresh agent_key.
state: Dict[str, Any] = {
"access_token": shared.get("access_token"),
"refresh_token": shared.get("refresh_token"),
"client_id": shared.get("client_id") or DEFAULT_NOUS_CLIENT_ID,
"portal_base_url": shared.get("portal_base_url") or DEFAULT_NOUS_PORTAL_URL,
"inference_base_url": shared.get("inference_base_url") or DEFAULT_NOUS_INFERENCE_URL,
"token_type": shared.get("token_type") or "Bearer",
"scope": shared.get("scope") or DEFAULT_NOUS_SCOPE,
"obtained_at": shared.get("obtained_at"),
"expires_at": shared.get("expires_at"),
"agent_key": None,
"agent_key_expires_at": None,
"tls": {"insecure": False, "ca_bundle": None},
}
refreshed = refresh_nous_oauth_from_state(
state,
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=True,
force_mint=True,
)
_write_shared_nous_state(refreshed)
except AuthError as exc: except AuthError as exc:
_oauth_trace( _oauth_trace(
"nous_shared_import_failed", "nous_shared_import_failed",
@ -3136,59 +3234,65 @@ def resolve_nous_access_token(
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID) client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state) verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state)
access_token = state.get("access_token") with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
refresh_token = state.get("refresh_token") merged_shared = _merge_shared_nous_oauth_state(state)
if not isinstance(access_token, str) or not access_token: access_token = state.get("access_token")
raise AuthError( refresh_token = state.get("refresh_token")
"No access token found for Nous Portal login.", if not isinstance(access_token, str) or not access_token:
provider="nous", raise AuthError(
relogin_required=True, "No access token found for Nous Portal login.",
) provider="nous",
relogin_required=True,
)
if not _is_expiring(state.get("expires_at"), refresh_skew_seconds): if not _is_expiring(state.get("expires_at"), refresh_skew_seconds):
return access_token if merged_shared:
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
return access_token
if not isinstance(refresh_token, str) or not refresh_token: if not isinstance(refresh_token, str) or not refresh_token:
raise AuthError( raise AuthError(
"Session expired and no refresh token is available.", "Session expired and no refresh token is available.",
provider="nous", provider="nous",
relogin_required=True, relogin_required=True,
) )
timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0) timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0)
with httpx.Client( with httpx.Client(
timeout=timeout, timeout=timeout,
headers={"Accept": "application/json"}, headers={"Accept": "application/json"},
verify=verify, verify=verify,
) as client: ) as client:
refreshed = _refresh_access_token( refreshed = _refresh_access_token(
client=client, client=client,
portal_base_url=portal_base_url, portal_base_url=portal_base_url,
client_id=client_id, client_id=client_id,
refresh_token=refresh_token, refresh_token=refresh_token,
) )
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"] state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope") state["scope"] = refreshed.get("scope") or state.get("scope")
state["obtained_at"] = now.isoformat() state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp( state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, now.timestamp() + access_ttl,
tz=timezone.utc, tz=timezone.utc,
).isoformat() ).isoformat()
state["portal_base_url"] = portal_base_url state["portal_base_url"] = portal_base_url
state["client_id"] = client_id state["client_id"] = client_id
state["tls"] = { state["tls"] = {
"insecure": verify is False, "insecure": verify is False,
"ca_bundle": verify if isinstance(verify, str) else None, "ca_bundle": verify if isinstance(verify, str) else None,
} }
_save_provider_state(auth_store, "nous", state) _save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store) _save_auth_store(auth_store)
return state["access_token"] _write_shared_nous_state(state)
return state["access_token"]
def refresh_nous_oauth_pure( def refresh_nous_oauth_pure(
@ -3456,46 +3560,53 @@ def resolve_nous_runtime_credentials(
# Step 1: refresh access token if expiring # Step 1: refresh access token if expiring
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS): if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
if not isinstance(refresh_token, str) or not refresh_token: with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
raise AuthError("Session expired and no refresh token is available.", if _merge_shared_nous_oauth_state(state):
provider="nous", relogin_required=True) access_token = state.get("access_token")
refresh_token = state.get("refresh_token")
_persist_state("post_shared_merge_access_expiring")
_oauth_trace( if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
"refresh_start", if not isinstance(refresh_token, str) or not refresh_token:
sequence_id=sequence_id, raise AuthError("Session expired and no refresh token is available.",
reason="access_expiring", provider="nous", relogin_required=True)
refresh_token_fp=_token_fingerprint(refresh_token),
) _oauth_trace(
refreshed = _refresh_access_token( "refresh_start",
client=client, portal_base_url=portal_base_url, sequence_id=sequence_id,
client_id=client_id, refresh_token=refresh_token, reason="access_expiring",
) refresh_token_fp=_token_fingerprint(refresh_token),
now = datetime.now(timezone.utc) )
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) refreshed = _refresh_access_token(
previous_refresh_token = refresh_token client=client, portal_base_url=portal_base_url,
state["access_token"] = refreshed["access_token"] client_id=client_id, refresh_token=refresh_token,
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token )
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" now = datetime.now(timezone.utc)
state["scope"] = refreshed.get("scope") or state.get("scope") access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) previous_refresh_token = refresh_token
if refreshed_url: state["access_token"] = refreshed["access_token"]
inference_base_url = refreshed_url state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["obtained_at"] = now.isoformat() state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["expires_in"] = access_ttl state["scope"] = refreshed.get("scope") or state.get("scope")
state["expires_at"] = datetime.fromtimestamp( refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
now.timestamp() + access_ttl, tz=timezone.utc if refreshed_url:
).isoformat() inference_base_url = refreshed_url
access_token = state["access_token"] state["obtained_at"] = now.isoformat()
refresh_token = state["refresh_token"] state["expires_in"] = access_ttl
_oauth_trace( state["expires_at"] = datetime.fromtimestamp(
"refresh_success", now.timestamp() + access_ttl, tz=timezone.utc
sequence_id=sequence_id, ).isoformat()
reason="access_expiring", access_token = state["access_token"]
previous_refresh_token_fp=_token_fingerprint(previous_refresh_token), refresh_token = state["refresh_token"]
new_refresh_token_fp=_token_fingerprint(refresh_token), _oauth_trace(
) "refresh_success",
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens. sequence_id=sequence_id,
_persist_state("post_refresh_access_expiring") reason="access_expiring",
previous_refresh_token_fp=_token_fingerprint(previous_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens.
_persist_state("post_refresh_access_expiring")
# Step 2: mint agent key if missing/expiring # Step 2: mint agent key if missing/expiring
used_cached_key = False used_cached_key = False
@ -3528,41 +3639,47 @@ def resolve_nous_runtime_credentials(
and isinstance(latest_refresh_token, str) and isinstance(latest_refresh_token, str)
and latest_refresh_token and latest_refresh_token
): ):
_oauth_trace( with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
"refresh_start", if _merge_shared_nous_oauth_state(state):
sequence_id=sequence_id, access_token = state.get("access_token")
reason="mint_retry_after_invalid_token", latest_refresh_token = state.get("refresh_token")
refresh_token_fp=_token_fingerprint(latest_refresh_token), _persist_state("post_shared_merge_mint_retry")
) else:
refreshed = _refresh_access_token( _oauth_trace(
client=client, portal_base_url=portal_base_url, "refresh_start",
client_id=client_id, refresh_token=latest_refresh_token, sequence_id=sequence_id,
) reason="mint_retry_after_invalid_token",
now = datetime.now(timezone.utc) refresh_token_fp=_token_fingerprint(latest_refresh_token),
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) )
state["access_token"] = refreshed["access_token"] refreshed = _refresh_access_token(
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token client=client, portal_base_url=portal_base_url,
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" client_id=client_id, refresh_token=latest_refresh_token,
state["scope"] = refreshed.get("scope") or state.get("scope") )
refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) now = datetime.now(timezone.utc)
if refreshed_url: access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
inference_base_url = refreshed_url state["access_token"] = refreshed["access_token"]
state["obtained_at"] = now.isoformat() state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
state["expires_in"] = access_ttl state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["expires_at"] = datetime.fromtimestamp( state["scope"] = refreshed.get("scope") or state.get("scope")
now.timestamp() + access_ttl, tz=timezone.utc refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
).isoformat() if refreshed_url:
access_token = state["access_token"] inference_base_url = refreshed_url
refresh_token = state["refresh_token"] state["obtained_at"] = now.isoformat()
_oauth_trace( state["expires_in"] = access_ttl
"refresh_success", state["expires_at"] = datetime.fromtimestamp(
sequence_id=sequence_id, now.timestamp() + access_ttl, tz=timezone.utc
reason="mint_retry_after_invalid_token", ).isoformat()
previous_refresh_token_fp=_token_fingerprint(latest_refresh_token), access_token = state["access_token"]
new_refresh_token_fp=_token_fingerprint(refresh_token), refresh_token = state["refresh_token"]
) _oauth_trace(
# Persist retry refresh immediately for crash safety and cross-process visibility. "refresh_success",
_persist_state("post_refresh_mint_retry") sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
previous_refresh_token_fp=_token_fingerprint(latest_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
mint_payload = _mint_agent_key( mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url, client=client, portal_base_url=portal_base_url,

View file

@ -1179,3 +1179,87 @@ def test_shared_store_survives_across_profile_switch(
shared_after = auth_mod._read_shared_nous_state() shared_after = auth_mod._read_shared_nous_state()
assert shared_after is not None assert shared_after is not None
assert shared_after["refresh_token"] == "b-refresh-tok" assert shared_after["refresh_token"] == "b-refresh-tok"
def test_runtime_refresh_uses_newer_shared_token_before_local_stale_token(
tmp_path, monkeypatch, shared_store_env,
):
"""A sibling profile may rotate the single-use Nous refresh token.
When this profile later wakes with an expired local token, runtime
resolution must adopt the shared token before refreshing. Otherwise it
can submit the stale local refresh token and trigger portal reuse
revocation for the whole shared session.
"""
from hermes_cli import auth as auth_mod
profile_b = tmp_path / "profile_b"
_setup_nous_auth(
profile_b,
access_token="local-expired-access",
refresh_token="local-stale-refresh",
)
monkeypatch.setenv("HERMES_HOME", str(profile_b))
shared_state = _full_state_fixture()
shared_state["access_token"] = "shared-fresh-access"
shared_state["refresh_token"] = "shared-fresh-refresh"
shared_state["expires_at"] = "2099-01-01T00:00:00+00:00"
auth_mod._write_shared_nous_state(shared_state)
def _refresh_should_not_happen(**_kwargs):
raise AssertionError("stale profile-local refresh token was used")
minted_with: list[str] = []
def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
minted_with.append(access_token)
return _mint_payload(api_key="agent-key-from-shared-token")
monkeypatch.setattr(auth_mod, "_refresh_access_token", _refresh_should_not_happen)
monkeypatch.setattr(auth_mod, "_mint_agent_key", _fake_mint_agent_key)
creds = auth_mod.resolve_nous_runtime_credentials(
min_key_ttl_seconds=300,
force_mint=True,
)
assert creds["api_key"] == "agent-key-from-shared-token"
assert minted_with == ["shared-fresh-access"]
profile_state = auth_mod.get_provider_auth_state("nous")
assert profile_state is not None
assert profile_state["refresh_token"] == "shared-fresh-refresh"
assert profile_state["access_token"] == "shared-fresh-access"
def test_managed_gateway_access_token_uses_newer_shared_token(
tmp_path, monkeypatch, shared_store_env,
):
"""Managed-tool token reads share the same stale-refresh-token hazard."""
from hermes_cli import auth as auth_mod
profile_b = tmp_path / "profile_b"
_setup_nous_auth(
profile_b,
access_token="local-expired-access",
refresh_token="local-stale-refresh",
)
monkeypatch.setenv("HERMES_HOME", str(profile_b))
shared_state = _full_state_fixture()
shared_state["access_token"] = "shared-fresh-access"
shared_state["refresh_token"] = "shared-fresh-refresh"
shared_state["expires_at"] = "2099-01-01T00:00:00+00:00"
auth_mod._write_shared_nous_state(shared_state)
def _refresh_should_not_happen(**_kwargs):
raise AssertionError("stale profile-local refresh token was used")
monkeypatch.setattr(auth_mod, "_refresh_access_token", _refresh_should_not_happen)
assert auth_mod.resolve_nous_access_token() == "shared-fresh-access"
profile_state = auth_mod.get_provider_auth_state("nous")
assert profile_state is not None
assert profile_state["refresh_token"] == "shared-fresh-refresh"