fix(auth): sync shared Nous refresh tokens

This commit is contained in:
Michael Nguyen 2026-05-05 23:36:09 +07:00 committed by Teknium
parent 38b1c7dce5
commit a84e56d4c6
2 changed files with 362 additions and 161 deletions

View file

@ -2769,6 +2769,7 @@ def _poll_for_token(
# -----------------------------------------------------------------------------
NOUS_SHARED_STORE_FILENAME = "nous_auth.json"
_nous_shared_lock_holder = threading.local()
def _nous_shared_auth_dir() -> Path:
@ -2808,6 +2809,100 @@ def _nous_shared_store_path() -> Path:
return path
@contextmanager
def _nous_shared_store_lock(timeout_seconds: float = AUTH_LOCK_TIMEOUT_SECONDS):
"""Cross-profile lock for the shared Nous OAuth store."""
if getattr(_nous_shared_lock_holder, "depth", 0) > 0:
_nous_shared_lock_holder.depth += 1
try:
yield
finally:
_nous_shared_lock_holder.depth -= 1
return
try:
lock_path = _nous_shared_store_path().with_suffix(".lock")
except RuntimeError:
yield
return
lock_path.parent.mkdir(parents=True, exist_ok=True)
if fcntl is None and msvcrt is None:
_nous_shared_lock_holder.depth = 1
try:
yield
finally:
_nous_shared_lock_holder.depth = 0
return
if msvcrt and (not lock_path.exists() or lock_path.stat().st_size == 0):
lock_path.write_text(" ", encoding="utf-8")
with lock_path.open("r+" if msvcrt else "a+") as lock_file:
deadline = time.time() + max(1.0, timeout_seconds)
while True:
try:
if fcntl:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
else:
lock_file.seek(0)
msvcrt.locking(lock_file.fileno(), msvcrt.LK_NBLCK, 1)
break
except (BlockingIOError, OSError, PermissionError):
if time.time() >= deadline:
raise TimeoutError("Timed out waiting for shared Nous auth lock")
time.sleep(0.05)
_nous_shared_lock_holder.depth = 1
try:
yield
finally:
_nous_shared_lock_holder.depth = 0
if fcntl:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
elif msvcrt:
try:
lock_file.seek(0)
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
except (OSError, IOError):
pass
def _merge_shared_nous_oauth_state(state: Dict[str, Any]) -> bool:
"""Copy fresher shared OAuth tokens into a profile-local Nous state."""
shared = _read_shared_nous_state()
if not shared:
return False
shared_refresh = shared.get("refresh_token")
if not isinstance(shared_refresh, str) or not shared_refresh.strip():
return False
local_refresh = state.get("refresh_token")
shared_access_exp = _parse_iso_timestamp(shared.get("expires_at")) or 0.0
local_access_exp = _parse_iso_timestamp(state.get("expires_at")) or 0.0
refresh_changed = shared_refresh.strip() != str(local_refresh or "").strip()
fresher_access = shared_access_exp > local_access_exp
if not refresh_changed and not fresher_access:
return False
for key in (
"access_token",
"refresh_token",
"token_type",
"scope",
"client_id",
"portal_base_url",
"inference_base_url",
"obtained_at",
"expires_at",
):
value = shared.get(key)
if value not in (None, ""):
state[key] = value
return True
def _write_shared_nous_state(state: Dict[str, Any]) -> None:
"""Persist a minimal copy of the Nous OAuth state to the shared store.
@ -2840,15 +2935,16 @@ def _write_shared_nous_state(state: Dict[str, Any]) -> None:
"updated_at": datetime.now(timezone.utc).isoformat(),
}
try:
path = _nous_shared_store_path()
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp")
tmp.write_text(json.dumps(shared, indent=2, sort_keys=True))
try:
os.chmod(tmp, 0o600)
except OSError:
pass
os.replace(tmp, path)
with _nous_shared_store_lock():
path = _nous_shared_store_path()
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp")
tmp.write_text(json.dumps(shared, indent=2, sort_keys=True))
try:
os.chmod(tmp, 0o600)
except OSError:
pass
os.replace(tmp, path)
_oauth_trace(
"nous_shared_store_written",
path=str(path),
@ -2905,36 +3001,38 @@ def _try_import_shared_nous_state(
etc.) caller should then fall through to the normal device-code
flow.
"""
shared = _read_shared_nous_state()
if not shared:
return None
# Build a full state dict so refresh_nous_oauth_from_state has every
# field it needs. force_refresh=True gets us a fresh access_token
# for this profile; force_mint=True gets us a fresh agent_key.
state: Dict[str, Any] = {
"access_token": shared.get("access_token"),
"refresh_token": shared.get("refresh_token"),
"client_id": shared.get("client_id") or DEFAULT_NOUS_CLIENT_ID,
"portal_base_url": shared.get("portal_base_url") or DEFAULT_NOUS_PORTAL_URL,
"inference_base_url": shared.get("inference_base_url") or DEFAULT_NOUS_INFERENCE_URL,
"token_type": shared.get("token_type") or "Bearer",
"scope": shared.get("scope") or DEFAULT_NOUS_SCOPE,
"obtained_at": shared.get("obtained_at"),
"expires_at": shared.get("expires_at"),
"agent_key": None,
"agent_key_expires_at": None,
"tls": {"insecure": False, "ca_bundle": None},
}
try:
refreshed = refresh_nous_oauth_from_state(
state,
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=True,
force_mint=True,
)
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
shared = _read_shared_nous_state()
if not shared:
return None
# Build a full state dict so refresh_nous_oauth_from_state has every
# field it needs. force_refresh=True gets us a fresh access_token
# for this profile; force_mint=True gets us a fresh agent_key.
state: Dict[str, Any] = {
"access_token": shared.get("access_token"),
"refresh_token": shared.get("refresh_token"),
"client_id": shared.get("client_id") or DEFAULT_NOUS_CLIENT_ID,
"portal_base_url": shared.get("portal_base_url") or DEFAULT_NOUS_PORTAL_URL,
"inference_base_url": shared.get("inference_base_url") or DEFAULT_NOUS_INFERENCE_URL,
"token_type": shared.get("token_type") or "Bearer",
"scope": shared.get("scope") or DEFAULT_NOUS_SCOPE,
"obtained_at": shared.get("obtained_at"),
"expires_at": shared.get("expires_at"),
"agent_key": None,
"agent_key_expires_at": None,
"tls": {"insecure": False, "ca_bundle": None},
}
refreshed = refresh_nous_oauth_from_state(
state,
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=True,
force_mint=True,
)
_write_shared_nous_state(refreshed)
except AuthError as exc:
_oauth_trace(
"nous_shared_import_failed",
@ -3136,59 +3234,65 @@ def resolve_nous_access_token(
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state)
access_token = state.get("access_token")
refresh_token = state.get("refresh_token")
if not isinstance(access_token, str) or not access_token:
raise AuthError(
"No access token found for Nous Portal login.",
provider="nous",
relogin_required=True,
)
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
merged_shared = _merge_shared_nous_oauth_state(state)
access_token = state.get("access_token")
refresh_token = state.get("refresh_token")
if not isinstance(access_token, str) or not access_token:
raise AuthError(
"No access token found for Nous Portal login.",
provider="nous",
relogin_required=True,
)
if not _is_expiring(state.get("expires_at"), refresh_skew_seconds):
return access_token
if not _is_expiring(state.get("expires_at"), refresh_skew_seconds):
if merged_shared:
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
return access_token
if not isinstance(refresh_token, str) or not refresh_token:
raise AuthError(
"Session expired and no refresh token is available.",
provider="nous",
relogin_required=True,
)
if not isinstance(refresh_token, str) or not refresh_token:
raise AuthError(
"Session expired and no refresh token is available.",
provider="nous",
relogin_required=True,
)
timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0)
with httpx.Client(
timeout=timeout,
headers={"Accept": "application/json"},
verify=verify,
) as client:
refreshed = _refresh_access_token(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
refresh_token=refresh_token,
)
timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0)
with httpx.Client(
timeout=timeout,
headers={"Accept": "application/json"},
verify=verify,
) as client:
refreshed = _refresh_access_token(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
refresh_token=refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope")
state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl,
tz=timezone.utc,
).isoformat()
state["portal_base_url"] = portal_base_url
state["client_id"] = client_id
state["tls"] = {
"insecure": verify is False,
"ca_bundle": verify if isinstance(verify, str) else None,
}
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
return state["access_token"]
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope")
state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl,
tz=timezone.utc,
).isoformat()
state["portal_base_url"] = portal_base_url
state["client_id"] = client_id
state["tls"] = {
"insecure": verify is False,
"ca_bundle": verify if isinstance(verify, str) else None,
}
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
_write_shared_nous_state(state)
return state["access_token"]
def refresh_nous_oauth_pure(
@ -3456,46 +3560,53 @@ def resolve_nous_runtime_credentials(
# Step 1: refresh access token if expiring
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
if not isinstance(refresh_token, str) or not refresh_token:
raise AuthError("Session expired and no refresh token is available.",
provider="nous", relogin_required=True)
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
if _merge_shared_nous_oauth_state(state):
access_token = state.get("access_token")
refresh_token = state.get("refresh_token")
_persist_state("post_shared_merge_access_expiring")
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="access_expiring",
refresh_token_fp=_token_fingerprint(refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
previous_refresh_token = refresh_token
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope")
refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
if refreshed_url:
inference_base_url = refreshed_url
state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="access_expiring",
previous_refresh_token_fp=_token_fingerprint(previous_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens.
_persist_state("post_refresh_access_expiring")
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
if not isinstance(refresh_token, str) or not refresh_token:
raise AuthError("Session expired and no refresh token is available.",
provider="nous", relogin_required=True)
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="access_expiring",
refresh_token_fp=_token_fingerprint(refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
previous_refresh_token = refresh_token
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope")
refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
if refreshed_url:
inference_base_url = refreshed_url
state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="access_expiring",
previous_refresh_token_fp=_token_fingerprint(previous_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens.
_persist_state("post_refresh_access_expiring")
# Step 2: mint agent key if missing/expiring
used_cached_key = False
@ -3528,41 +3639,47 @@ def resolve_nous_runtime_credentials(
and isinstance(latest_refresh_token, str)
and latest_refresh_token
):
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
refresh_token_fp=_token_fingerprint(latest_refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=latest_refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope")
refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
if refreshed_url:
inference_base_url = refreshed_url
state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
previous_refresh_token_fp=_token_fingerprint(latest_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
if _merge_shared_nous_oauth_state(state):
access_token = state.get("access_token")
latest_refresh_token = state.get("refresh_token")
_persist_state("post_shared_merge_mint_retry")
else:
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
refresh_token_fp=_token_fingerprint(latest_refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=latest_refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope")
refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
if refreshed_url:
inference_base_url = refreshed_url
state["obtained_at"] = now.isoformat()
state["expires_in"] = access_ttl
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
previous_refresh_token_fp=_token_fingerprint(latest_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,

View file

@ -1179,3 +1179,87 @@ def test_shared_store_survives_across_profile_switch(
shared_after = auth_mod._read_shared_nous_state()
assert shared_after is not None
assert shared_after["refresh_token"] == "b-refresh-tok"
def test_runtime_refresh_uses_newer_shared_token_before_local_stale_token(
tmp_path, monkeypatch, shared_store_env,
):
"""A sibling profile may rotate the single-use Nous refresh token.
When this profile later wakes with an expired local token, runtime
resolution must adopt the shared token before refreshing. Otherwise it
can submit the stale local refresh token and trigger portal reuse
revocation for the whole shared session.
"""
from hermes_cli import auth as auth_mod
profile_b = tmp_path / "profile_b"
_setup_nous_auth(
profile_b,
access_token="local-expired-access",
refresh_token="local-stale-refresh",
)
monkeypatch.setenv("HERMES_HOME", str(profile_b))
shared_state = _full_state_fixture()
shared_state["access_token"] = "shared-fresh-access"
shared_state["refresh_token"] = "shared-fresh-refresh"
shared_state["expires_at"] = "2099-01-01T00:00:00+00:00"
auth_mod._write_shared_nous_state(shared_state)
def _refresh_should_not_happen(**_kwargs):
raise AssertionError("stale profile-local refresh token was used")
minted_with: list[str] = []
def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
minted_with.append(access_token)
return _mint_payload(api_key="agent-key-from-shared-token")
monkeypatch.setattr(auth_mod, "_refresh_access_token", _refresh_should_not_happen)
monkeypatch.setattr(auth_mod, "_mint_agent_key", _fake_mint_agent_key)
creds = auth_mod.resolve_nous_runtime_credentials(
min_key_ttl_seconds=300,
force_mint=True,
)
assert creds["api_key"] == "agent-key-from-shared-token"
assert minted_with == ["shared-fresh-access"]
profile_state = auth_mod.get_provider_auth_state("nous")
assert profile_state is not None
assert profile_state["refresh_token"] == "shared-fresh-refresh"
assert profile_state["access_token"] == "shared-fresh-access"
def test_managed_gateway_access_token_uses_newer_shared_token(
tmp_path, monkeypatch, shared_store_env,
):
"""Managed-tool token reads share the same stale-refresh-token hazard."""
from hermes_cli import auth as auth_mod
profile_b = tmp_path / "profile_b"
_setup_nous_auth(
profile_b,
access_token="local-expired-access",
refresh_token="local-stale-refresh",
)
monkeypatch.setenv("HERMES_HOME", str(profile_b))
shared_state = _full_state_fixture()
shared_state["access_token"] = "shared-fresh-access"
shared_state["refresh_token"] = "shared-fresh-refresh"
shared_state["expires_at"] = "2099-01-01T00:00:00+00:00"
auth_mod._write_shared_nous_state(shared_state)
def _refresh_should_not_happen(**_kwargs):
raise AssertionError("stale profile-local refresh token was used")
monkeypatch.setattr(auth_mod, "_refresh_access_token", _refresh_should_not_happen)
assert auth_mod.resolve_nous_access_token() == "shared-fresh-access"
profile_state = auth_mod.get_provider_auth_state("nous")
assert profile_state is not None
assert profile_state["refresh_token"] == "shared-fresh-refresh"