fix(auth): preserve concurrently-added credentials on pool rewrite

This commit is contained in:
konsisumer 2026-06-05 07:41:00 +02:00 committed by Teknium
parent 163cb24d45
commit 8b4c29f0f0
3 changed files with 164 additions and 9 deletions

View file

@ -537,10 +537,11 @@ class CredentialPool:
self._entries[idx] = new
return
# NOTE(review): this view is a diff rendering with the +/- markers stripped.
# The first ``def`` line below appears to be the pre-change signature; the
# second, which adds the keyword-only ``removed_ids`` parameter, supersedes it.
def _persist(self) -> None:
def _persist(self, *, removed_ids: Optional[List[str]] = None) -> None:
# Serialize every in-memory entry with ``to_dict()`` and hand the list to
# ``write_credential_pool`` for this pool's provider. ``removed_ids`` is
# forwarded unchanged so callers can name the entries they dropped on purpose.
write_credential_pool(
self.provider,
[entry.to_dict() for entry in self._entries],
removed_ids=removed_ids,
)
def _is_terminal_auth_failure(
@ -1124,13 +1125,17 @@ class CredentialPool:
logger.debug(
"Failed to clear terminal xAI OAuth state: %s", clear_exc
)
removed_ids = [
item.id for item in self._entries
if item.source == "loopback_pkce"
]
self._entries = [
item for item in self._entries
if item.source != "loopback_pkce"
]
if self._current_id == entry.id:
self._current_id = None
self._persist()
self._persist(removed_ids=removed_ids)
return None
# For openai-codex: same race as xAI/nous — another Hermes process
# may have consumed the refresh token between our proactive sync
@ -1190,13 +1195,17 @@ class CredentialPool:
logger.debug(
"Failed to clear terminal Codex OAuth state: %s", clear_exc
)
removed_ids = [
item.id for item in self._entries
if item.source == "device_code"
]
self._entries = [
item for item in self._entries
if item.source != "device_code"
]
if self._current_id == entry.id:
self._current_id = None
self._persist()
self._persist(removed_ids=removed_ids)
return None
# For nous: another process may have consumed the refresh token
# between our proactive sync and the HTTP call. Re-sync from
@ -1253,13 +1262,17 @@ class CredentialPool:
auth_mod.NOUS_DEVICE_CODE_SOURCE,
f"manual:{auth_mod.NOUS_DEVICE_CODE_SOURCE}",
}
removed_ids = [
item.id for item in self._entries
if item.source in singleton_sources
]
self._entries = [
item for item in self._entries
if item.source not in singleton_sources
]
if self._current_id == entry.id:
self._current_id = None
self._persist()
self._persist(removed_ids=removed_ids)
return None
self._mark_exhausted(entry, None)
return None
@ -1421,7 +1434,7 @@ class CredentialPool:
pruned_ids = set(entries_to_prune)
self._entries = [e for e in self._entries if e.id not in pruned_ids]
if cleared_any:
self._persist()
self._persist(removed_ids=entries_to_prune)
return available
def _select_unlocked(self) -> Optional[PooledCredential]:
@ -1595,7 +1608,11 @@ class CredentialPool:
replace(entry, priority=new_priority)
for new_priority, entry in enumerate(self._entries)
]
self._persist()
write_credential_pool(
self.provider,
[entry.to_dict() for entry in self._entries],
removed_ids=[removed.id],
)
if self._current_id == removed.id:
self._current_id = None
return removed
@ -2257,6 +2274,11 @@ def _seed_custom_pool(pool_key: str, entries: List[PooledCredential]) -> Tuple[b
def load_pool(provider: str) -> CredentialPool:
provider = (provider or "").strip().lower()
raw_entries = read_credential_pool(provider)
disk_ids = {
entry.get("id")
for entry in raw_entries
if isinstance(entry, dict) and entry.get("id")
}
raw_needs_sanitization = any(
isinstance(payload, dict)
and sanitize_borrowed_credential_payload(payload, provider) != payload
@ -2285,8 +2307,10 @@ def load_pool(provider: str) -> CredentialPool:
changed |= _normalize_pool_priorities(provider, entries)
if changed:
new_ids = {entry.id for entry in entries}
write_credential_pool(
provider,
[entry.to_dict() for entry in sorted(entries, key=lambda item: item.priority)],
removed_ids=disk_ids - new_ids,
)
return CredentialPool(provider, entries)

View file

@ -38,7 +38,7 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer
from pathlib import Path
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse
import httpx
@ -1260,24 +1260,54 @@ def read_credential_pool(provider_id: Optional[str] = None) -> Dict[str, Any]:
return list(global_entries) if isinstance(global_entries, list) else []
# NOTE(review): this view is a diff rendering with the +/- markers stripped.
# The one-line ``def`` directly below appears to be the pre-change signature;
# the multi-line ``def`` that follows (adding ``removed_ids``) supersedes it.
def write_credential_pool(provider_id: str, entries: List[Dict[str, Any]]) -> Path:
def write_credential_pool(
provider_id: str,
entries: List[Dict[str, Any]],
*,
removed_ids: Optional[Iterable[str]] = None,
) -> Path:
"""Persist one provider's credential pool under auth.json.
This is the final disk-boundary guard for borrowed/reference-only
credentials. Callers may pass raw dictionaries, so sanitize here even when
``PooledCredential.to_dict()`` already did the same work upstream.
Re-read the on-disk pool under the same lock and merge entries present on
disk but missing from ``entries``. Those were added by another process after
the caller loaded its in-memory snapshot; without this merge a later
rotation/exhaustion rewrite drops the concurrent credential.
Pass ``removed_ids`` for entries the caller intentionally removed, so the
merge does not resurrect them from the on-disk copy.
"""
# Collapse ``removed_ids`` (which may be None or any iterable) into a set,
# discarding falsy ids so they can never match a disk entry.
removed = {rid for rid in (removed_ids or ()) if rid}
# The load, the merge, and the save below all happen inside this lock.
with _auth_store_lock():
auth_store = _load_auth_store()
pool = auth_store.get("credential_pool")
# A missing or non-dict "credential_pool" value is replaced by a fresh dict
# that is also attached to the store so the save picks it up.
if not isinstance(pool, dict):
pool = {}
auth_store["credential_pool"] = pool
# NOTE(review): the assignment target on the next line appears to be the
# pre-change form; post-change, the comprehension is bound to
# ``sanitized_entries`` and ``pool[provider_id]`` is assigned after the merge.
pool[provider_id] = [
sanitized_entries = [
sanitize_borrowed_credential_payload(entry, provider_id)
if isinstance(entry, dict) else entry
for entry in entries
]
existing = pool.get(provider_id)
# Anything other than a list on disk is treated as "no existing entries".
existing_list = existing if isinstance(existing, list) else []
# Ids the caller supplied; a disk entry with one of these ids is considered
# already represented and the caller's copy wins.
new_ids = {
entry.get("id")
for entry in sanitized_entries
if isinstance(entry, dict) and entry.get("id")
}
# Caller-supplied entries keep their order; disk-only entries are appended
# after them in on-disk order.
merged: List[Dict[str, Any]] = list(sanitized_entries)
for disk_entry in existing_list:
# Non-dict disk entries are not carried over.
if not isinstance(disk_entry, dict):
continue
disk_id = disk_entry.get("id")
# Skip disk entries that have no id, that the caller already supplied, or
# that the caller explicitly listed as removed.
if not disk_id or disk_id in new_ids or disk_id in removed:
continue
# Disk-only entries are sanitized too before being written back.
merged.append(sanitize_borrowed_credential_payload(disk_entry, provider_id))
pool[provider_id] = merged
# The function's return value is whatever ``_save_auth_store`` returns
# (annotated as ``Path`` in the signature).
return _save_auth_store(auth_store)

View file

@ -3045,3 +3045,104 @@ def test_codex_oauth_nonterminal_refresh_does_not_quarantine(tmp_path, monkeypat
tokens = auth_payload["providers"]["openai-codex"].get("tokens", {})
assert tokens.get("access_token") == "old-access-token"
assert tokens.get("refresh_token") == "old-refresh-token"
def test_persist_preserves_concurrent_disk_only_entry(tmp_path, monkeypatch):
    """Regression for #19566: a stale rotation rewrite must keep disk-only entries.

    Scenario: a pool is loaded with two credentials, a third one is then
    written to disk behind the pool's back, and finally the pool persists an
    exhaustion update. All three ids must still be on disk afterwards, and the
    exhaustion status must have been recorded for ``cred-A``.
    """

    def _api_key_entry(cred_id, label, priority, token):
        # Same key order as the literal fixtures used elsewhere in this file.
        return {
            "id": cred_id,
            "label": label,
            "auth_type": "api_key",
            "priority": priority,
            "source": "manual",
            "access_token": token,
        }

    hermes_home = tmp_path / "hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    seed_store = {
        "version": 1,
        "credential_pool": {
            "anthropic": [
                _api_key_entry("cred-A", "primary", 0, "sk-A"),
                _api_key_entry("cred-B", "secondary", 1, "sk-B"),
            ]
        },
    }
    _write_auth_store(tmp_path, seed_store)

    # Imported only after HERMES_HOME and the seed store are in place.
    from agent.credential_pool import load_pool
    from hermes_cli.auth import read_credential_pool, write_credential_pool

    pool = load_pool("anthropic")
    loaded_ids = {entry.id for entry in pool.entries()}
    assert loaded_ids == {"cred-A", "cred-B"}

    # Simulate another process adding cred-C after our in-memory snapshot.
    on_disk = read_credential_pool("anthropic")
    on_disk.append(_api_key_entry("cred-C", "added-concurrently", 2, "sk-C"))
    write_credential_pool("anthropic", on_disk)

    # The stale pool now rewrites the provider's entries.
    pool.mark_exhausted_and_rotate(status_code=429)

    auth_json = tmp_path / "hermes" / "auth.json"
    persisted = json.loads(auth_json.read_text())["credential_pool"]["anthropic"]

    assert {item["id"] for item in persisted} == {"cred-A", "cred-B", "cred-C"}

    cred_a_rows = [item for item in persisted if item["id"] == "cred-A"]
    assert cred_a_rows[0]["last_status"] == "exhausted"
def test_remove_index_does_not_resurrect_via_disk_merge(tmp_path, monkeypatch):
    """An entry removed via ``remove_index`` must not reappear in auth.json.

    The removed credential is still present in the on-disk copy at the moment
    the pool is rewritten, so this checks that the rewrite does not bring it
    back.
    """

    def _api_key_entry(cred_id, label, priority, token):
        # Same key order as the literal fixtures used elsewhere in this file.
        return {
            "id": cred_id,
            "label": label,
            "auth_type": "api_key",
            "priority": priority,
            "source": "manual",
            "access_token": token,
        }

    hermes_home = tmp_path / "hermes"
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    seed_store = {
        "version": 1,
        "credential_pool": {
            "anthropic": [
                _api_key_entry("cred-A", "keep", 0, "sk-A"),
                _api_key_entry("cred-B", "drop", 1, "sk-B"),
            ]
        },
    }
    _write_auth_store(tmp_path, seed_store)

    # Imported only after HERMES_HOME and the seed store are in place.
    from agent.credential_pool import load_pool

    pool = load_pool("anthropic")
    # NOTE(review): index 2 is expected to target cred-B (the second entry),
    # which implies remove_index is 1-based - confirm against its definition.
    pool.remove_index(2)

    auth_json = tmp_path / "hermes" / "auth.json"
    persisted = json.loads(auth_json.read_text())["credential_pool"]["anthropic"]
    remaining_ids = [item["id"] for item in persisted]
    assert remaining_ids == ["cred-A"]