fix(auth) fix a few cases where refresh tokens were not rotated.

This commit is contained in:
Robin Fernandes 2026-05-17 22:29:40 +10:00 committed by Teknium
parent 20bffa5b37
commit 569bc94b59
6 changed files with 166 additions and 109 deletions

View file

@ -103,7 +103,7 @@ def test_nous_adapter_authenticated_with_refresh_token_only(tmp_path, monkeypatc
assert NousPortalAdapter().is_authenticated()
def test_nous_adapter_get_credential_refreshes_and_persists(tmp_path, monkeypatch):
def test_nous_adapter_get_credential_uses_runtime_resolver(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_write_auth_store(tmp_path, {
"access_token": "access-tok",
@ -114,32 +114,24 @@ def test_nous_adapter_get_credential_refreshes_and_persists(tmp_path, monkeypatc
})
refreshed_state = {
"access_token": "access-tok",
"refresh_token": "refresh-tok",
"client_id": "hermes-cli",
"portal_base_url": "https://portal.nousresearch.com",
"inference_base_url": "https://inference-api.nousresearch.com/v1",
"agent_key": "minted-bearer",
"agent_key_expires_at": "2099-01-01T00:00:00Z",
"api_key": "minted-bearer",
"base_url": "https://inference-api.nousresearch.com/v1",
"expires_at": "2099-01-01T00:00:00Z",
}
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
return_value=refreshed_state,
) as mock_refresh:
) as mock_resolve:
adapter = NousPortalAdapter()
cred = adapter.get_credential()
mock_refresh.assert_called_once()
mock_resolve.assert_called_once()
assert cred.bearer == "minted-bearer"
assert cred.base_url == "https://inference-api.nousresearch.com/v1"
assert cred.expires_at == "2099-01-01T00:00:00Z"
assert cred.token_type == "Bearer"
# Verify state was persisted back
stored = json.loads((tmp_path / "auth.json").read_text())
assert stored["providers"]["nous"]["agent_key"] == "minted-bearer"
def test_nous_adapter_retry_credential_forces_legacy_mint(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
@ -153,19 +145,15 @@ def test_nous_adapter_retry_credential_forces_legacy_mint(tmp_path, monkeypatch)
})
refreshed_state = {
"access_token": "jwt-access",
"refresh_token": "refresh-tok",
"client_id": "hermes-cli",
"portal_base_url": "https://portal.nousresearch.com",
"inference_base_url": "https://inference-api.nousresearch.com/v1",
"agent_key": "legacy-bearer",
"agent_key_expires_at": "2099-01-01T00:00:00Z",
"api_key": "legacy-bearer",
"base_url": "https://inference-api.nousresearch.com/v1",
"expires_at": "2099-01-01T00:00:00Z",
}
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
return_value=refreshed_state,
) as mock_refresh:
) as mock_resolve:
adapter = NousPortalAdapter()
cred = adapter.get_retry_credential(
failed_credential=UpstreamCredential(
@ -177,7 +165,7 @@ def test_nous_adapter_retry_credential_forces_legacy_mint(tmp_path, monkeypatch)
assert cred is not None
assert cred.bearer == "legacy-bearer"
assert mock_refresh.call_args.kwargs["inference_auth_mode"] == "legacy"
assert mock_resolve.call_args.kwargs["inference_auth_mode"] == "legacy"
def test_nous_adapter_retry_credential_skips_opaque_bearer(tmp_path, monkeypatch):
@ -189,8 +177,8 @@ def test_nous_adapter_retry_credential_skips_opaque_bearer(tmp_path, monkeypatch
})
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
) as mock_refresh:
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
) as mock_resolve:
adapter = NousPortalAdapter()
cred = adapter.get_retry_credential(
failed_credential=UpstreamCredential(
@ -201,7 +189,7 @@ def test_nous_adapter_retry_credential_skips_opaque_bearer(tmp_path, monkeypatch
)
assert cred is None
mock_refresh.assert_not_called()
mock_resolve.assert_not_called()
def test_nous_adapter_get_credential_raises_when_not_logged_in(tmp_path, monkeypatch):
@ -219,7 +207,7 @@ def test_nous_adapter_get_credential_raises_on_refresh_failure(tmp_path, monkeyp
})
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
side_effect=RuntimeError("Refresh session has been revoked"),
):
adapter = NousPortalAdapter()
@ -240,7 +228,7 @@ def test_nous_adapter_quarantines_terminal_refresh_failure(tmp_path, monkeypatch
assert load_pool("nous").select() is not None
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
side_effect=AuthError(
"Refresh session has been revoked",
provider="nous",
@ -270,7 +258,7 @@ def test_nous_adapter_get_credential_raises_when_no_agent_key_returned(tmp_path,
})
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
return_value={"access_token": "a", "refresh_token": "r"},
):
adapter = NousPortalAdapter()
@ -291,7 +279,7 @@ def test_nous_adapter_concurrent_refresh_serialized(tmp_path, monkeypatch):
counter = [0]
counter_lock = threading.Lock()
def serializing_refresh(state, **kwargs):
def serializing_refresh(**kwargs):
# If another thread is already inside refresh, the lock is broken.
if in_flight.is_set():
overlap_detected.set()
@ -305,10 +293,9 @@ def test_nous_adapter_concurrent_refresh_serialized(tmp_path, monkeypatch):
counter[0] += 1
idx = counter[0]
return {
**state,
"agent_key": f"key-{idx}",
"agent_key_expires_at": "2099-01-01T00:00:00Z",
"inference_base_url": "https://inference-api.nousresearch.com/v1",
"api_key": f"key-{idx}",
"expires_at": "2099-01-01T00:00:00Z",
"base_url": "https://inference-api.nousresearch.com/v1",
}
finally:
in_flight.clear()
@ -324,7 +311,7 @@ def test_nous_adapter_concurrent_refresh_serialized(tmp_path, monkeypatch):
errors.append(exc)
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
side_effect=serializing_refresh,
):
threads = [threading.Thread(target=worker) for _ in range(3)]