def _persist(self, *, removed_ids: Optional[List[str]] = None) -> None:
    """Write the pool's current in-memory entries to the auth store.

    Each entry is serialized with ``to_dict()`` and handed to
    ``write_credential_pool`` for this pool's provider.

    Args:
        removed_ids: Optional ids forwarded unchanged to
            ``write_credential_pool`` as its ``removed_ids`` keyword.
            Callers pass the ids of entries they have just dropped from
            ``self._entries``.
    """
    provider = self.provider
    serialized = [item.to_dict() for item in self._entries]
    write_credential_pool(provider, serialized, removed_ids=removed_ids)
Re-sync from @@ -1253,13 +1262,17 @@ class CredentialPool: auth_mod.NOUS_DEVICE_CODE_SOURCE, f"manual:{auth_mod.NOUS_DEVICE_CODE_SOURCE}", } + removed_ids = [ + item.id for item in self._entries + if item.source in singleton_sources + ] self._entries = [ item for item in self._entries if item.source not in singleton_sources ] if self._current_id == entry.id: self._current_id = None - self._persist() + self._persist(removed_ids=removed_ids) return None self._mark_exhausted(entry, None) return None @@ -1421,7 +1434,7 @@ class CredentialPool: pruned_ids = set(entries_to_prune) self._entries = [e for e in self._entries if e.id not in pruned_ids] if cleared_any: - self._persist() + self._persist(removed_ids=entries_to_prune) return available def _select_unlocked(self) -> Optional[PooledCredential]: @@ -1595,7 +1608,11 @@ class CredentialPool: replace(entry, priority=new_priority) for new_priority, entry in enumerate(self._entries) ] - self._persist() + write_credential_pool( + self.provider, + [entry.to_dict() for entry in self._entries], + removed_ids=[removed.id], + ) if self._current_id == removed.id: self._current_id = None return removed @@ -2257,6 +2274,11 @@ def _seed_custom_pool(pool_key: str, entries: List[PooledCredential]) -> Tuple[b def load_pool(provider: str) -> CredentialPool: provider = (provider or "").strip().lower() raw_entries = read_credential_pool(provider) + disk_ids = { + entry.get("id") + for entry in raw_entries + if isinstance(entry, dict) and entry.get("id") + } raw_needs_sanitization = any( isinstance(payload, dict) and sanitize_borrowed_credential_payload(payload, provider) != payload @@ -2285,8 +2307,10 @@ def load_pool(provider: str) -> CredentialPool: changed |= _normalize_pool_priorities(provider, entries) if changed: + new_ids = {entry.id for entry in entries} write_credential_pool( provider, [entry.to_dict() for entry in sorted(entries, key=lambda item: item.priority)], + removed_ids=disk_ids - new_ids, ) return 
def write_credential_pool(
    provider_id: str,
    entries: List[Dict[str, Any]],
    *,
    removed_ids: Optional[Iterable[str]] = None,
) -> Path:
    """Persist one provider's credential pool under auth.json.

    This is the final disk-boundary guard for borrowed/reference-only
    credentials. Callers may pass raw dictionaries, so sanitize here even when
    ``PooledCredential.to_dict()`` already did the same work upstream.

    Re-read the on-disk pool under the same lock and merge entries present on
    disk but missing from ``entries``. Those were added by another process after
    the caller loaded its in-memory snapshot; without this merge a later
    rotation/exhaustion rewrite drops the concurrent credential.

    Args:
        provider_id: Key under ``credential_pool`` in the auth store whose
            list is being rewritten.
        entries: The caller's view of the pool. Dict entries are sanitized;
            non-dict entries are written through unchanged.
        removed_ids: Ids the caller intentionally removed, so the merge does
            not resurrect them from the on-disk copy. Falsy ids are ignored.
            A single id passed as a bare string is treated as one id.

    Returns:
        The value returned by ``_save_auth_store`` for the updated store.
    """
    if isinstance(removed_ids, str):
        # A bare string is itself an iterable of one-character strings.
        # Iterating it would yield ids that never match, and the entry the
        # caller meant to remove would be merged back in from disk.
        removed_ids = (removed_ids,)
    removed = {rid for rid in (removed_ids or ()) if rid}

    with _auth_store_lock():
        auth_store = _load_auth_store()
        pool = auth_store.get("credential_pool")
        if not isinstance(pool, dict):
            pool = {}
            auth_store["credential_pool"] = pool

        sanitized_entries = [
            sanitize_borrowed_credential_payload(entry, provider_id)
            if isinstance(entry, dict)
            else entry
            for entry in entries
        ]

        existing = pool.get(provider_id)
        existing_list = existing if isinstance(existing, list) else []
        new_ids = {
            entry.get("id")
            for entry in sanitized_entries
            if isinstance(entry, dict) and entry.get("id")
        }

        # The caller's entries come first; disk-only entries are appended
        # after them in their on-disk order.
        merged: List[Dict[str, Any]] = list(sanitized_entries)
        for disk_entry in existing_list:
            if not isinstance(disk_entry, dict):
                continue
            disk_id = disk_entry.get("id")
            # Skip entries that cannot be identified, that the caller already
            # supplied, or that the caller removed on purpose.
            if not disk_id or disk_id in new_ids or disk_id in removed:
                continue
            merged.append(
                sanitize_borrowed_credential_payload(disk_entry, provider_id)
            )

        pool[provider_id] = merged
        return _save_auth_store(auth_store)
def test_remove_index_does_not_resurrect_via_disk_merge(tmp_path, monkeypatch):
    """An entry removed via ``remove_index`` must stay gone from auth.json."""
    hermes_home = tmp_path / "hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    kept_entry = {
        "id": "cred-A",
        "label": "keep",
        "auth_type": "api_key",
        "priority": 0,
        "source": "manual",
        "access_token": "sk-A",
    }
    dropped_entry = {
        "id": "cred-B",
        "label": "drop",
        "auth_type": "api_key",
        "priority": 1,
        "source": "manual",
        "access_token": "sk-B",
    }
    seed_store = {
        "version": 1,
        "credential_pool": {"anthropic": [kept_entry, dropped_entry]},
    }
    _write_auth_store(tmp_path, seed_store)

    from agent.credential_pool import load_pool

    # NOTE(review): index 2 is expected to target "cred-B", i.e. the index
    # appears to be 1-based -- confirm against ``remove_index``.
    load_pool("anthropic").remove_index(2)

    auth_payload = json.loads((hermes_home / "auth.json").read_text())
    persisted = auth_payload["credential_pool"]["anthropic"]
    assert [item["id"] for item in persisted] == ["cred-A"]