mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-30 11:52:04 +00:00
fix(auth): preserve concurrently-added credentials on pool rewrite
This commit is contained in:
parent
163cb24d45
commit
8b4c29f0f0
3 changed files with 164 additions and 9 deletions
|
|
@ -537,10 +537,11 @@ class CredentialPool:
|
|||
self._entries[idx] = new
|
||||
return
|
||||
|
||||
def _persist(self) -> None:
|
||||
def _persist(self, *, removed_ids: Optional[List[str]] = None) -> None:
|
||||
write_credential_pool(
|
||||
self.provider,
|
||||
[entry.to_dict() for entry in self._entries],
|
||||
removed_ids=removed_ids,
|
||||
)
|
||||
|
||||
def _is_terminal_auth_failure(
|
||||
|
|
@ -1124,13 +1125,17 @@ class CredentialPool:
|
|||
logger.debug(
|
||||
"Failed to clear terminal xAI OAuth state: %s", clear_exc
|
||||
)
|
||||
removed_ids = [
|
||||
item.id for item in self._entries
|
||||
if item.source == "loopback_pkce"
|
||||
]
|
||||
self._entries = [
|
||||
item for item in self._entries
|
||||
if item.source != "loopback_pkce"
|
||||
]
|
||||
if self._current_id == entry.id:
|
||||
self._current_id = None
|
||||
self._persist()
|
||||
self._persist(removed_ids=removed_ids)
|
||||
return None
|
||||
# For openai-codex: same race as xAI/nous — another Hermes process
|
||||
# may have consumed the refresh token between our proactive sync
|
||||
|
|
@ -1190,13 +1195,17 @@ class CredentialPool:
|
|||
logger.debug(
|
||||
"Failed to clear terminal Codex OAuth state: %s", clear_exc
|
||||
)
|
||||
removed_ids = [
|
||||
item.id for item in self._entries
|
||||
if item.source == "device_code"
|
||||
]
|
||||
self._entries = [
|
||||
item for item in self._entries
|
||||
if item.source != "device_code"
|
||||
]
|
||||
if self._current_id == entry.id:
|
||||
self._current_id = None
|
||||
self._persist()
|
||||
self._persist(removed_ids=removed_ids)
|
||||
return None
|
||||
# For nous: another process may have consumed the refresh token
|
||||
# between our proactive sync and the HTTP call. Re-sync from
|
||||
|
|
@ -1253,13 +1262,17 @@ class CredentialPool:
|
|||
auth_mod.NOUS_DEVICE_CODE_SOURCE,
|
||||
f"manual:{auth_mod.NOUS_DEVICE_CODE_SOURCE}",
|
||||
}
|
||||
removed_ids = [
|
||||
item.id for item in self._entries
|
||||
if item.source in singleton_sources
|
||||
]
|
||||
self._entries = [
|
||||
item for item in self._entries
|
||||
if item.source not in singleton_sources
|
||||
]
|
||||
if self._current_id == entry.id:
|
||||
self._current_id = None
|
||||
self._persist()
|
||||
self._persist(removed_ids=removed_ids)
|
||||
return None
|
||||
self._mark_exhausted(entry, None)
|
||||
return None
|
||||
|
|
@ -1421,7 +1434,7 @@ class CredentialPool:
|
|||
pruned_ids = set(entries_to_prune)
|
||||
self._entries = [e for e in self._entries if e.id not in pruned_ids]
|
||||
if cleared_any:
|
||||
self._persist()
|
||||
self._persist(removed_ids=entries_to_prune)
|
||||
return available
|
||||
|
||||
def _select_unlocked(self) -> Optional[PooledCredential]:
|
||||
|
|
@ -1595,7 +1608,11 @@ class CredentialPool:
|
|||
replace(entry, priority=new_priority)
|
||||
for new_priority, entry in enumerate(self._entries)
|
||||
]
|
||||
self._persist()
|
||||
write_credential_pool(
|
||||
self.provider,
|
||||
[entry.to_dict() for entry in self._entries],
|
||||
removed_ids=[removed.id],
|
||||
)
|
||||
if self._current_id == removed.id:
|
||||
self._current_id = None
|
||||
return removed
|
||||
|
|
@ -2257,6 +2274,11 @@ def _seed_custom_pool(pool_key: str, entries: List[PooledCredential]) -> Tuple[b
|
|||
def load_pool(provider: str) -> CredentialPool:
|
||||
provider = (provider or "").strip().lower()
|
||||
raw_entries = read_credential_pool(provider)
|
||||
disk_ids = {
|
||||
entry.get("id")
|
||||
for entry in raw_entries
|
||||
if isinstance(entry, dict) and entry.get("id")
|
||||
}
|
||||
raw_needs_sanitization = any(
|
||||
isinstance(payload, dict)
|
||||
and sanitize_borrowed_credential_payload(payload, provider) != payload
|
||||
|
|
@ -2285,8 +2307,10 @@ def load_pool(provider: str) -> CredentialPool:
|
|||
changed |= _normalize_pool_priorities(provider, entries)
|
||||
|
||||
if changed:
|
||||
new_ids = {entry.id for entry in entries}
|
||||
write_credential_pool(
|
||||
provider,
|
||||
[entry.to_dict() for entry in sorted(entries, key=lambda item: item.priority)],
|
||||
removed_ids=disk_ids - new_ids,
|
||||
)
|
||||
return CredentialPool(provider, entries)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from dataclasses import dataclass, field
|
|||
from datetime import datetime, timezone
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer, ThreadingHTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
import httpx
|
||||
|
|
@ -1260,24 +1260,54 @@ def read_credential_pool(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
|||
return list(global_entries) if isinstance(global_entries, list) else []
|
||||
|
||||
|
||||
def write_credential_pool(provider_id: str, entries: List[Dict[str, Any]]) -> Path:
|
||||
def write_credential_pool(
|
||||
provider_id: str,
|
||||
entries: List[Dict[str, Any]],
|
||||
*,
|
||||
removed_ids: Optional[Iterable[str]] = None,
|
||||
) -> Path:
|
||||
"""Persist one provider's credential pool under auth.json.
|
||||
|
||||
This is the final disk-boundary guard for borrowed/reference-only
|
||||
credentials. Callers may pass raw dictionaries, so sanitize here even when
|
||||
``PooledCredential.to_dict()`` already did the same work upstream.
|
||||
|
||||
Re-read the on-disk pool under the same lock and merge entries present on
|
||||
disk but missing from ``entries``. Those were added by another process after
|
||||
the caller loaded its in-memory snapshot; without this merge a later
|
||||
rotation/exhaustion rewrite drops the concurrent credential.
|
||||
|
||||
Pass ``removed_ids`` for entries the caller intentionally removed, so the
|
||||
merge does not resurrect them from the on-disk copy.
|
||||
"""
|
||||
removed = {rid for rid in (removed_ids or ()) if rid}
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
pool = auth_store.get("credential_pool")
|
||||
if not isinstance(pool, dict):
|
||||
pool = {}
|
||||
auth_store["credential_pool"] = pool
|
||||
pool[provider_id] = [
|
||||
sanitized_entries = [
|
||||
sanitize_borrowed_credential_payload(entry, provider_id)
|
||||
if isinstance(entry, dict) else entry
|
||||
for entry in entries
|
||||
]
|
||||
existing = pool.get(provider_id)
|
||||
existing_list = existing if isinstance(existing, list) else []
|
||||
new_ids = {
|
||||
entry.get("id")
|
||||
for entry in sanitized_entries
|
||||
if isinstance(entry, dict) and entry.get("id")
|
||||
}
|
||||
merged: List[Dict[str, Any]] = list(sanitized_entries)
|
||||
for disk_entry in existing_list:
|
||||
if not isinstance(disk_entry, dict):
|
||||
continue
|
||||
disk_id = disk_entry.get("id")
|
||||
if not disk_id or disk_id in new_ids or disk_id in removed:
|
||||
continue
|
||||
merged.append(sanitize_borrowed_credential_payload(disk_entry, provider_id))
|
||||
pool[provider_id] = merged
|
||||
return _save_auth_store(auth_store)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3045,3 +3045,104 @@ def test_codex_oauth_nonterminal_refresh_does_not_quarantine(tmp_path, monkeypat
|
|||
tokens = auth_payload["providers"]["openai-codex"].get("tokens", {})
|
||||
assert tokens.get("access_token") == "old-access-token"
|
||||
assert tokens.get("refresh_token") == "old-refresh-token"
|
||||
|
||||
|
||||
def test_persist_preserves_concurrent_disk_only_entry(tmp_path, monkeypatch):
|
||||
"""Regression for #19566: stale rotation writes keep concurrent entries."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "cred-A",
|
||||
"label": "primary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-A",
|
||||
},
|
||||
{
|
||||
"id": "cred-B",
|
||||
"label": "secondary",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "sk-B",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
from hermes_cli.auth import read_credential_pool, write_credential_pool
|
||||
|
||||
pool = load_pool("anthropic")
|
||||
assert {entry.id for entry in pool.entries()} == {"cred-A", "cred-B"}
|
||||
|
||||
disk_snapshot = read_credential_pool("anthropic")
|
||||
disk_snapshot.append(
|
||||
{
|
||||
"id": "cred-C",
|
||||
"label": "added-concurrently",
|
||||
"auth_type": "api_key",
|
||||
"priority": 2,
|
||||
"source": "manual",
|
||||
"access_token": "sk-C",
|
||||
}
|
||||
)
|
||||
write_credential_pool("anthropic", disk_snapshot)
|
||||
|
||||
pool.mark_exhausted_and_rotate(status_code=429)
|
||||
|
||||
final = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
final_ids = [entry["id"] for entry in final["credential_pool"]["anthropic"]]
|
||||
assert set(final_ids) == {"cred-A", "cred-B", "cred-C"}
|
||||
persisted_a = next(
|
||||
entry
|
||||
for entry in final["credential_pool"]["anthropic"]
|
||||
if entry["id"] == "cred-A"
|
||||
)
|
||||
assert persisted_a["last_status"] == "exhausted"
|
||||
|
||||
|
||||
def test_remove_index_does_not_resurrect_via_disk_merge(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
|
||||
_write_auth_store(
|
||||
tmp_path,
|
||||
{
|
||||
"version": 1,
|
||||
"credential_pool": {
|
||||
"anthropic": [
|
||||
{
|
||||
"id": "cred-A",
|
||||
"label": "keep",
|
||||
"auth_type": "api_key",
|
||||
"priority": 0,
|
||||
"source": "manual",
|
||||
"access_token": "sk-A",
|
||||
},
|
||||
{
|
||||
"id": "cred-B",
|
||||
"label": "drop",
|
||||
"auth_type": "api_key",
|
||||
"priority": 1,
|
||||
"source": "manual",
|
||||
"access_token": "sk-B",
|
||||
},
|
||||
]
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
|
||||
pool = load_pool("anthropic")
|
||||
pool.remove_index(2)
|
||||
|
||||
final = json.loads((tmp_path / "hermes" / "auth.json").read_text())
|
||||
final_ids = [entry["id"] for entry in final["credential_pool"]["anthropic"]]
|
||||
assert final_ids == ["cred-A"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue