From a84e56d4c662770798584a79d34260fb86c6600d Mon Sep 17 00:00:00 2001 From: Michael Nguyen Date: Tue, 5 May 2026 23:36:09 +0700 Subject: [PATCH] fix(auth): sync shared Nous refresh tokens --- hermes_cli/auth.py | 439 +++++++++++++------- tests/hermes_cli/test_auth_nous_provider.py | 84 ++++ 2 files changed, 362 insertions(+), 161 deletions(-) diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index 5ff5638b91..889f8ce1ee 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -2769,6 +2769,7 @@ def _poll_for_token( # ----------------------------------------------------------------------------- NOUS_SHARED_STORE_FILENAME = "nous_auth.json" +_nous_shared_lock_holder = threading.local() def _nous_shared_auth_dir() -> Path: @@ -2808,6 +2809,100 @@ def _nous_shared_store_path() -> Path: return path +@contextmanager +def _nous_shared_store_lock(timeout_seconds: float = AUTH_LOCK_TIMEOUT_SECONDS): + """Cross-profile lock for the shared Nous OAuth store.""" + if getattr(_nous_shared_lock_holder, "depth", 0) > 0: + _nous_shared_lock_holder.depth += 1 + try: + yield + finally: + _nous_shared_lock_holder.depth -= 1 + return + + try: + lock_path = _nous_shared_store_path().with_suffix(".lock") + except RuntimeError: + yield + return + lock_path.parent.mkdir(parents=True, exist_ok=True) + + if fcntl is None and msvcrt is None: + _nous_shared_lock_holder.depth = 1 + try: + yield + finally: + _nous_shared_lock_holder.depth = 0 + return + + if msvcrt and (not lock_path.exists() or lock_path.stat().st_size == 0): + lock_path.write_text(" ", encoding="utf-8") + + with lock_path.open("r+" if msvcrt else "a+") as lock_file: + deadline = time.time() + max(1.0, timeout_seconds) + while True: + try: + if fcntl: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + else: + lock_file.seek(0) + msvcrt.locking(lock_file.fileno(), msvcrt.LK_NBLCK, 1) + break + except (BlockingIOError, OSError, PermissionError): + if time.time() >= deadline: + raise TimeoutError("Timed out waiting for shared Nous auth lock") + time.sleep(0.05) + + _nous_shared_lock_holder.depth = 1 + try: + yield + finally: + _nous_shared_lock_holder.depth = 0 + if fcntl: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + elif msvcrt: + try: + lock_file.seek(0) + msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) + except (OSError, IOError): + pass + + +def _merge_shared_nous_oauth_state(state: Dict[str, Any]) -> bool: + """Copy fresher shared OAuth tokens into a profile-local Nous state.""" + shared = _read_shared_nous_state() + if not shared: + return False + + shared_refresh = shared.get("refresh_token") + if not isinstance(shared_refresh, str) or not shared_refresh.strip(): + return False + + local_refresh = state.get("refresh_token") + shared_access_exp = _parse_iso_timestamp(shared.get("expires_at")) or 0.0 + local_access_exp = _parse_iso_timestamp(state.get("expires_at")) or 0.0 + refresh_changed = shared_refresh.strip() != str(local_refresh or "").strip() + fresher_access = shared_access_exp > local_access_exp + if not refresh_changed and not fresher_access: + return False + + for key in ( + "access_token", + "refresh_token", + "token_type", + "scope", + "client_id", + "portal_base_url", + "inference_base_url", + "obtained_at", + "expires_at", + ): + value = shared.get(key) + if value not in (None, ""): + state[key] = value + return True + + def _write_shared_nous_state(state: Dict[str, Any]) -> None: """Persist a minimal copy of the Nous OAuth state to the shared store. @@ -2840,15 +2935,16 @@ def _write_shared_nous_state(state: Dict[str, Any]) -> None: "updated_at": datetime.now(timezone.utc).isoformat(), } try: - path = _nous_shared_store_path() - path.parent.mkdir(parents=True, exist_ok=True) - tmp = path.with_suffix(path.suffix + ".tmp") - tmp.write_text(json.dumps(shared, indent=2, sort_keys=True)) - try: - os.chmod(tmp, 0o600) - except OSError: - pass - os.replace(tmp, path) + with _nous_shared_store_lock(): + path = _nous_shared_store_path() + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(json.dumps(shared, indent=2, sort_keys=True)) + try: + os.chmod(tmp, 0o600) + except OSError: + pass + os.replace(tmp, path) _oauth_trace( "nous_shared_store_written", path=str(path), @@ -2905,36 +3001,38 @@ def _try_import_shared_nous_state( etc.) — caller should then fall through to the normal device-code flow. """ - shared = _read_shared_nous_state() - if not shared: - return None - - # Build a full state dict so refresh_nous_oauth_from_state has every - # field it needs. force_refresh=True gets us a fresh access_token - # for this profile; force_mint=True gets us a fresh agent_key. - state: Dict[str, Any] = { - "access_token": shared.get("access_token"), - "refresh_token": shared.get("refresh_token"), - "client_id": shared.get("client_id") or DEFAULT_NOUS_CLIENT_ID, - "portal_base_url": shared.get("portal_base_url") or DEFAULT_NOUS_PORTAL_URL, - "inference_base_url": shared.get("inference_base_url") or DEFAULT_NOUS_INFERENCE_URL, - "token_type": shared.get("token_type") or "Bearer", - "scope": shared.get("scope") or DEFAULT_NOUS_SCOPE, - "obtained_at": shared.get("obtained_at"), - "expires_at": shared.get("expires_at"), - "agent_key": None, - "agent_key_expires_at": None, - "tls": {"insecure": False, "ca_bundle": None}, - } - try: - refreshed = refresh_nous_oauth_from_state( - state, - min_key_ttl_seconds=min_key_ttl_seconds, - timeout_seconds=timeout_seconds, - force_refresh=True, - force_mint=True, - ) + with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)): + shared = _read_shared_nous_state() + if not shared: + return None + + # Build a full state dict so refresh_nous_oauth_from_state has every + # field it needs. force_refresh=True gets us a fresh access_token + # for this profile; force_mint=True gets us a fresh agent_key. + state: Dict[str, Any] = { + "access_token": shared.get("access_token"), + "refresh_token": shared.get("refresh_token"), + "client_id": shared.get("client_id") or DEFAULT_NOUS_CLIENT_ID, + "portal_base_url": shared.get("portal_base_url") or DEFAULT_NOUS_PORTAL_URL, + "inference_base_url": shared.get("inference_base_url") or DEFAULT_NOUS_INFERENCE_URL, + "token_type": shared.get("token_type") or "Bearer", + "scope": shared.get("scope") or DEFAULT_NOUS_SCOPE, + "obtained_at": shared.get("obtained_at"), + "expires_at": shared.get("expires_at"), + "agent_key": None, + "agent_key_expires_at": None, + "tls": {"insecure": False, "ca_bundle": None}, + } + + refreshed = refresh_nous_oauth_from_state( + state, + min_key_ttl_seconds=min_key_ttl_seconds, + timeout_seconds=timeout_seconds, + force_refresh=True, + force_mint=True, + ) + _write_shared_nous_state(refreshed) except AuthError as exc: _oauth_trace( "nous_shared_import_failed", @@ -3136,59 +3234,65 @@ def resolve_nous_access_token( client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID) verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state) - access_token = state.get("access_token") - refresh_token = state.get("refresh_token") - if not isinstance(access_token, str) or not access_token: - raise AuthError( - "No access token found for Nous Portal login.", - provider="nous", - relogin_required=True, - ) + with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)): + merged_shared = _merge_shared_nous_oauth_state(state) + access_token = state.get("access_token") + refresh_token = state.get("refresh_token") + if not isinstance(access_token, str) or not access_token: + raise AuthError( + "No access token found for Nous Portal login.", + provider="nous", + relogin_required=True, + ) - if not _is_expiring(state.get("expires_at"), refresh_skew_seconds): - return access_token + if not _is_expiring(state.get("expires_at"), refresh_skew_seconds): + if merged_shared: + _save_provider_state(auth_store, "nous", state) + _save_auth_store(auth_store) + return access_token - if not isinstance(refresh_token, str) or not refresh_token: - raise AuthError( - "Session expired and no refresh token is available.", - provider="nous", - relogin_required=True, - ) + if not isinstance(refresh_token, str) or not refresh_token: + raise AuthError( + "Session expired and no refresh token is available.", + provider="nous", + relogin_required=True, + ) - timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0) - with httpx.Client( - timeout=timeout, - headers={"Accept": "application/json"}, - verify=verify, - ) as client: - refreshed = _refresh_access_token( - client=client, - portal_base_url=portal_base_url, - client_id=client_id, - refresh_token=refresh_token, - ) + timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0) + with httpx.Client( + timeout=timeout, + headers={"Accept": "application/json"}, + verify=verify, + ) as client: + refreshed = _refresh_access_token( + client=client, + portal_base_url=portal_base_url, + client_id=client_id, + refresh_token=refresh_token, + ) - now = datetime.now(timezone.utc) - access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) - state["access_token"] = refreshed["access_token"] - state["refresh_token"] = refreshed.get("refresh_token") or refresh_token - state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" - state["scope"] = refreshed.get("scope") or state.get("scope") - state["obtained_at"] = now.isoformat() - state["expires_in"] = access_ttl - state["expires_at"] = datetime.fromtimestamp( - now.timestamp() + access_ttl, - tz=timezone.utc, - ).isoformat() - state["portal_base_url"] = portal_base_url - state["client_id"] = client_id - state["tls"] = { - "insecure": verify is False, - "ca_bundle": verify if isinstance(verify, str) else None, - } - _save_provider_state(auth_store, "nous", state) - _save_auth_store(auth_store) - return state["access_token"] + now = datetime.now(timezone.utc) + access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) + state["access_token"] = refreshed["access_token"] + state["refresh_token"] = refreshed.get("refresh_token") or refresh_token + state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" + state["scope"] = refreshed.get("scope") or state.get("scope") + state["obtained_at"] = now.isoformat() + state["expires_in"] = access_ttl + state["expires_at"] = datetime.fromtimestamp( + now.timestamp() + access_ttl, + tz=timezone.utc, + ).isoformat() + state["portal_base_url"] = portal_base_url + state["client_id"] = client_id + state["tls"] = { + "insecure": verify is False, + "ca_bundle": verify if isinstance(verify, str) else None, + } + _save_provider_state(auth_store, "nous", state) + _save_auth_store(auth_store) + _write_shared_nous_state(state) + return state["access_token"] def refresh_nous_oauth_pure( @@ -3456,46 +3560,53 @@ def resolve_nous_runtime_credentials( # Step 1: refresh access token if expiring if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS): - if not isinstance(refresh_token, str) or not refresh_token: - raise AuthError("Session expired and no refresh token is available.", - provider="nous", relogin_required=True) + with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)): + if _merge_shared_nous_oauth_state(state): + access_token = state.get("access_token") + refresh_token = state.get("refresh_token") + _persist_state("post_shared_merge_access_expiring") - _oauth_trace( - "refresh_start", - sequence_id=sequence_id, - reason="access_expiring", - refresh_token_fp=_token_fingerprint(refresh_token), - ) - refreshed = _refresh_access_token( - client=client, portal_base_url=portal_base_url, - client_id=client_id, refresh_token=refresh_token, - ) - now = datetime.now(timezone.utc) - access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) - previous_refresh_token = refresh_token - state["access_token"] = refreshed["access_token"] - state["refresh_token"] = refreshed.get("refresh_token") or refresh_token - state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" - state["scope"] = refreshed.get("scope") or state.get("scope") - refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) - if refreshed_url: - inference_base_url = refreshed_url - state["obtained_at"] = now.isoformat() - state["expires_in"] = access_ttl - state["expires_at"] = datetime.fromtimestamp( - now.timestamp() + access_ttl, tz=timezone.utc - ).isoformat() - access_token = state["access_token"] - refresh_token = state["refresh_token"] - _oauth_trace( - "refresh_success", - sequence_id=sequence_id, - reason="access_expiring", - previous_refresh_token_fp=_token_fingerprint(previous_refresh_token), - new_refresh_token_fp=_token_fingerprint(refresh_token), - ) - # Persist immediately so downstream mint failures cannot drop rotated refresh tokens. - _persist_state("post_refresh_access_expiring") + if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS): + if not isinstance(refresh_token, str) or not refresh_token: + raise AuthError("Session expired and no refresh token is available.", + provider="nous", relogin_required=True) + + _oauth_trace( + "refresh_start", + sequence_id=sequence_id, + reason="access_expiring", + refresh_token_fp=_token_fingerprint(refresh_token), + ) + refreshed = _refresh_access_token( + client=client, portal_base_url=portal_base_url, + client_id=client_id, refresh_token=refresh_token, + ) + now = datetime.now(timezone.utc) + access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) + previous_refresh_token = refresh_token + state["access_token"] = refreshed["access_token"] + state["refresh_token"] = refreshed.get("refresh_token") or refresh_token + state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" + state["scope"] = refreshed.get("scope") or state.get("scope") + refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) + if refreshed_url: + inference_base_url = refreshed_url + state["obtained_at"] = now.isoformat() + state["expires_in"] = access_ttl + state["expires_at"] = datetime.fromtimestamp( + now.timestamp() + access_ttl, tz=timezone.utc + ).isoformat() + access_token = state["access_token"] + refresh_token = state["refresh_token"] + _oauth_trace( + "refresh_success", + sequence_id=sequence_id, + reason="access_expiring", + previous_refresh_token_fp=_token_fingerprint(previous_refresh_token), + new_refresh_token_fp=_token_fingerprint(refresh_token), + ) + # Persist immediately so downstream mint failures cannot drop rotated refresh tokens. + _persist_state("post_refresh_access_expiring") # Step 2: mint agent key if missing/expiring used_cached_key = False @@ -3528,41 +3639,47 @@ def resolve_nous_runtime_credentials( and isinstance(latest_refresh_token, str) and latest_refresh_token ): - _oauth_trace( - "refresh_start", - sequence_id=sequence_id, - reason="mint_retry_after_invalid_token", - refresh_token_fp=_token_fingerprint(latest_refresh_token), - ) - refreshed = _refresh_access_token( - client=client, portal_base_url=portal_base_url, - client_id=client_id, refresh_token=latest_refresh_token, - ) - now = datetime.now(timezone.utc) - access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) - state["access_token"] = refreshed["access_token"] - state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token - state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" - state["scope"] = refreshed.get("scope") or state.get("scope") - refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) - if refreshed_url: - inference_base_url = refreshed_url - state["obtained_at"] = now.isoformat() - state["expires_in"] = access_ttl - state["expires_at"] = datetime.fromtimestamp( - now.timestamp() + access_ttl, tz=timezone.utc - ).isoformat() - access_token = state["access_token"] - refresh_token = state["refresh_token"] - _oauth_trace( - "refresh_success", - sequence_id=sequence_id, - reason="mint_retry_after_invalid_token", - previous_refresh_token_fp=_token_fingerprint(latest_refresh_token), - new_refresh_token_fp=_token_fingerprint(refresh_token), - ) - # Persist retry refresh immediately for crash safety and cross-process visibility. - _persist_state("post_refresh_mint_retry") + with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)): + if _merge_shared_nous_oauth_state(state): + access_token = state.get("access_token") + latest_refresh_token = state.get("refresh_token") + _persist_state("post_shared_merge_mint_retry") + else: + _oauth_trace( + "refresh_start", + sequence_id=sequence_id, + reason="mint_retry_after_invalid_token", + refresh_token_fp=_token_fingerprint(latest_refresh_token), + ) + refreshed = _refresh_access_token( + client=client, portal_base_url=portal_base_url, + client_id=client_id, refresh_token=latest_refresh_token, + ) + now = datetime.now(timezone.utc) + access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) + state["access_token"] = refreshed["access_token"] + state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token + state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" + state["scope"] = refreshed.get("scope") or state.get("scope") + refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) + if refreshed_url: + inference_base_url = refreshed_url + state["obtained_at"] = now.isoformat() + state["expires_in"] = access_ttl + state["expires_at"] = datetime.fromtimestamp( + now.timestamp() + access_ttl, tz=timezone.utc + ).isoformat() + access_token = state["access_token"] + refresh_token = state["refresh_token"] + _oauth_trace( + "refresh_success", + sequence_id=sequence_id, + reason="mint_retry_after_invalid_token", + previous_refresh_token_fp=_token_fingerprint(latest_refresh_token), + new_refresh_token_fp=_token_fingerprint(refresh_token), + ) + # Persist retry refresh immediately for crash safety and cross-process visibility. + _persist_state("post_refresh_mint_retry") mint_payload = _mint_agent_key( client=client, portal_base_url=portal_base_url, diff --git a/tests/hermes_cli/test_auth_nous_provider.py b/tests/hermes_cli/test_auth_nous_provider.py index d0e24aeaab..136265c7e4 100644 --- a/tests/hermes_cli/test_auth_nous_provider.py +++ b/tests/hermes_cli/test_auth_nous_provider.py @@ -1179,3 +1179,87 @@ def test_shared_store_survives_across_profile_switch( shared_after = auth_mod._read_shared_nous_state() assert shared_after is not None assert shared_after["refresh_token"] == "b-refresh-tok" + + +def test_runtime_refresh_uses_newer_shared_token_before_local_stale_token( + tmp_path, monkeypatch, shared_store_env, +): + """A sibling profile may rotate the single-use Nous refresh token. + + When this profile later wakes with an expired local token, runtime + resolution must adopt the shared token before refreshing. Otherwise it + can submit the stale local refresh token and trigger portal reuse + revocation for the whole shared session. + """ + from hermes_cli import auth as auth_mod + + profile_b = tmp_path / "profile_b" + _setup_nous_auth( + profile_b, + access_token="local-expired-access", + refresh_token="local-stale-refresh", + ) + monkeypatch.setenv("HERMES_HOME", str(profile_b)) + + shared_state = _full_state_fixture() + shared_state["access_token"] = "shared-fresh-access" + shared_state["refresh_token"] = "shared-fresh-refresh" + shared_state["expires_at"] = "2099-01-01T00:00:00+00:00" + auth_mod._write_shared_nous_state(shared_state) + + def _refresh_should_not_happen(**_kwargs): + raise AssertionError("stale profile-local refresh token was used") + + minted_with: list[str] = [] + + def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds): + minted_with.append(access_token) + return _mint_payload(api_key="agent-key-from-shared-token") + + monkeypatch.setattr(auth_mod, "_refresh_access_token", _refresh_should_not_happen) + monkeypatch.setattr(auth_mod, "_mint_agent_key", _fake_mint_agent_key) + + creds = auth_mod.resolve_nous_runtime_credentials( + min_key_ttl_seconds=300, + force_mint=True, + ) + + assert creds["api_key"] == "agent-key-from-shared-token" + assert minted_with == ["shared-fresh-access"] + + profile_state = auth_mod.get_provider_auth_state("nous") + assert profile_state is not None + assert profile_state["refresh_token"] == "shared-fresh-refresh" + assert profile_state["access_token"] == "shared-fresh-access" + + +def test_managed_gateway_access_token_uses_newer_shared_token( + tmp_path, monkeypatch, shared_store_env, +): + """Managed-tool token reads share the same stale-refresh-token hazard.""" + from hermes_cli import auth as auth_mod + + profile_b = tmp_path / "profile_b" + _setup_nous_auth( + profile_b, + access_token="local-expired-access", + refresh_token="local-stale-refresh", + ) + monkeypatch.setenv("HERMES_HOME", str(profile_b)) + + shared_state = _full_state_fixture() + shared_state["access_token"] = "shared-fresh-access" + shared_state["refresh_token"] = "shared-fresh-refresh" + shared_state["expires_at"] = "2099-01-01T00:00:00+00:00" + auth_mod._write_shared_nous_state(shared_state) + + def _refresh_should_not_happen(**_kwargs): + raise AssertionError("stale profile-local refresh token was used") + + monkeypatch.setattr(auth_mod, "_refresh_access_token", _refresh_should_not_happen) + + assert auth_mod.resolve_nous_access_token() == "shared-fresh-access" + + profile_state = auth_mod.get_provider_auth_state("nous") + assert profile_state is not None + assert profile_state["refresh_token"] == "shared-fresh-refresh"