diff --git a/hermes_cli/models.py b/hermes_cli/models.py
index 4a6d4c478b..85413267da 100644
--- a/hermes_cli/models.py
+++ b/hermes_cli/models.py
@@ -404,13 +404,38 @@ def partition_nous_models_by_tier(
     return (selectable, unavailable)
 
 
+# ---------------------------------------------------------------------------
+# TTL cache for free-tier detection — avoids repeated API calls within a
+# session while still picking up upgrades quickly.
+# ---------------------------------------------------------------------------
+_FREE_TIER_CACHE_TTL: int = 180  # seconds (3 minutes)
+_free_tier_cache: tuple[bool, float] | None = None  # (result, timestamp)
+
+
+def clear_nous_free_tier_cache() -> None:
+    """Invalidate the cached free-tier result (e.g. after login/logout)."""
+    global _free_tier_cache
+    _free_tier_cache = None
+
+
 def check_nous_free_tier() -> bool:
     """Check if the current Nous Portal user is on a free (unpaid) tier.
 
-    Resolves the OAuth access token from the auth store, calls the
-    portal account endpoint, and returns True if the account has no
-    paid subscription. Returns False (assume paid) on any error.
+    Results are cached for ``_FREE_TIER_CACHE_TTL`` seconds to avoid
+    hitting the Portal API on every call. The cache is short-lived so
+    that an account upgrade is reflected within a few minutes.
+
+    Returns False (assume paid) on any error — never blocks paying users.
     """
+    global _free_tier_cache
+    import time
+
+    now = time.monotonic()
+    if _free_tier_cache is not None:
+        cached_result, cached_at = _free_tier_cache
+        if now - cached_at < _FREE_TIER_CACHE_TTL:
+            return cached_result
+
     try:
         from hermes_cli.auth import get_provider_auth_state, resolve_nous_runtime_credentials
 
@@ -419,15 +444,20 @@ def check_nous_free_tier() -> bool:
 
         state = get_provider_auth_state("nous")
         if not state:
+            _free_tier_cache = (False, now)
             return False
         access_token = state.get("access_token", "")
         portal_url = state.get("portal_base_url", "")
         if not access_token:
+            _free_tier_cache = (False, now)
             return False
 
         account_info = fetch_nous_account_tier(access_token, portal_url)
-        return is_nous_free_tier(account_info)
+        result = is_nous_free_tier(account_info)
+        _free_tier_cache = (result, now)
+        return result
     except Exception:
+        _free_tier_cache = (False, now)
         return False  # default to paid on error — don't block users
 
 
diff --git a/tests/hermes_cli/test_models.py b/tests/hermes_cli/test_models.py
index 3d1564ae23..776256f0f0 100644
--- a/tests/hermes_cli/test_models.py
+++ b/tests/hermes_cli/test_models.py
@@ -1,10 +1,15 @@
 """Tests for the hermes_cli models module."""
 
+from unittest.mock import patch, MagicMock
+
 from hermes_cli.models import (
     OPENROUTER_MODELS, menu_labels, model_ids, detect_provider_for_model,
     filter_nous_free_models, _NOUS_ALLOWED_FREE_MODELS,
     is_nous_free_tier, partition_nous_models_by_tier,
+    check_nous_free_tier, clear_nous_free_tier_cache,
+    _FREE_TIER_CACHE_TTL,
 )
+import hermes_cli.models as _models_mod
 
 
 class TestModelIds:
@@ -291,3 +296,63 @@ class TestPartitionNousModelsByTier:
         sel, unav = partition_nous_models_by_tier(models, pricing, free_tier=True)
         assert sel == []
         assert unav == models
+
+
+class TestCheckNousFreeTierCache:
+    """Tests for the TTL cache on check_nous_free_tier()."""
+
+    def setup_method(self):
+        """Reset cache before each test."""
+        clear_nous_free_tier_cache()
+
+    def teardown_method(self):
+        """Reset cache after each test."""
+        clear_nous_free_tier_cache()
+
+    @patch("hermes_cli.models.fetch_nous_account_tier")
+    @patch("hermes_cli.models.is_nous_free_tier", return_value=True)
+    def test_result_is_cached(self, mock_is_free, mock_fetch):
+        """Second call within TTL returns cached result without API call."""
+        mock_fetch.return_value = {"subscription": {"monthly_charge": 0}}
+        with patch("hermes_cli.auth.get_provider_auth_state", return_value={"access_token": "tok"}), \
+             patch("hermes_cli.auth.resolve_nous_runtime_credentials"):
+            result1 = check_nous_free_tier()
+            result2 = check_nous_free_tier()
+
+        assert result1 is True
+        assert result2 is True
+        # fetch_nous_account_tier should only be called once (cached on second call)
+        assert mock_fetch.call_count == 1
+
+    @patch("hermes_cli.models.fetch_nous_account_tier")
+    @patch("hermes_cli.models.is_nous_free_tier", return_value=False)
+    def test_cache_expires_after_ttl(self, mock_is_free, mock_fetch):
+        """After TTL expires, the API is called again."""
+        mock_fetch.return_value = {"subscription": {"monthly_charge": 20}}
+        with patch("hermes_cli.auth.get_provider_auth_state", return_value={"access_token": "tok"}), \
+             patch("hermes_cli.auth.resolve_nous_runtime_credentials"):
+            result1 = check_nous_free_tier()
+            assert mock_fetch.call_count == 1
+
+            # Simulate TTL expiry by backdating the cache timestamp
+            cached_result, cached_at = _models_mod._free_tier_cache
+            _models_mod._free_tier_cache = (cached_result, cached_at - _FREE_TIER_CACHE_TTL - 1)
+
+            result2 = check_nous_free_tier()
+            assert mock_fetch.call_count == 2
+
+        assert result1 is False
+        assert result2 is False
+
+    def test_clear_cache_forces_refresh(self):
+        """clear_nous_free_tier_cache() invalidates the cached result."""
+        # Manually seed the cache
+        import time
+        _models_mod._free_tier_cache = (True, time.monotonic())
+
+        clear_nous_free_tier_cache()
+        assert _models_mod._free_tier_cache is None
+
+    def test_cache_ttl_is_short(self):
+        """TTL should be short enough to catch upgrades quickly (<=5 min)."""
+        assert _FREE_TIER_CACHE_TTL <= 300