diff --git a/tests/hermes_cli/test_models.py b/tests/hermes_cli/test_models.py index 5b9840c286..d40a471444 100644 --- a/tests/hermes_cli/test_models.py +++ b/tests/hermes_cli/test_models.py @@ -6,6 +6,7 @@ from hermes_cli.models import ( OPENROUTER_MODELS, fetch_openrouter_models, menu_labels, model_ids, detect_provider_for_model, filter_nous_free_models, _NOUS_ALLOWED_FREE_MODELS, is_nous_free_tier, partition_nous_models_by_tier, + check_nous_free_tier, _FREE_TIER_CACHE_TTL, ) import hermes_cli.models as _models_mod @@ -351,3 +352,48 @@ class TestPartitionNousModelsByTier: assert unav == models +class TestCheckNousFreeTierCache: + """Tests for the TTL cache on check_nous_free_tier().""" + + def setup_method(self): + _models_mod._free_tier_cache = None + + def teardown_method(self): + _models_mod._free_tier_cache = None + + @patch("hermes_cli.models.fetch_nous_account_tier") + @patch("hermes_cli.models.is_nous_free_tier", return_value=True) + def test_result_is_cached(self, mock_is_free, mock_fetch): + """Second call within TTL returns cached result without API call.""" + mock_fetch.return_value = {"subscription": {"monthly_charge": 0}} + with patch("hermes_cli.auth.get_provider_auth_state", return_value={"access_token": "tok"}), \ + patch("hermes_cli.auth.resolve_nous_runtime_credentials"): + result1 = check_nous_free_tier() + result2 = check_nous_free_tier() + + assert result1 is True + assert result2 is True + assert mock_fetch.call_count == 1 + + @patch("hermes_cli.models.fetch_nous_account_tier") + @patch("hermes_cli.models.is_nous_free_tier", return_value=False) + def test_cache_expires_after_ttl(self, mock_is_free, mock_fetch): + """After TTL expires, the API is called again.""" + mock_fetch.return_value = {"subscription": {"monthly_charge": 20}} + with patch("hermes_cli.auth.get_provider_auth_state", return_value={"access_token": "tok"}), \ + patch("hermes_cli.auth.resolve_nous_runtime_credentials"): + result1 = check_nous_free_tier() + assert mock_fetch.call_count == 1 + + cached_result, cached_at = _models_mod._free_tier_cache + _models_mod._free_tier_cache = (cached_result, cached_at - _FREE_TIER_CACHE_TTL - 1) + + result2 = check_nous_free_tier() + assert mock_fetch.call_count == 2 + + assert result1 is False + assert result2 is False + + def test_cache_ttl_is_short(self): + """TTL should be short enough to catch upgrades quickly (<=5 min).""" + assert _FREE_TIER_CACHE_TTL <= 300 diff --git a/tests/tools/test_approval.py b/tests/tools/test_approval.py index a684b247ba..99edb3b182 100644 --- a/tests/tools/test_approval.py +++ b/tests/tools/test_approval.py @@ -110,6 +110,52 @@ class TestSafeCommand: assert desc is None +def _clear_session(key): + """Replace for removed clear_session() — directly clear internal state.""" + approval_module._session_approved.pop(key, None) + approval_module._pending.pop(key, None) + + +class TestApproveAndCheckSession: + def test_session_approval(self): + key = "test_session_approve" + _clear_session(key) + + assert is_approved(key, "rm") is False + approve_session(key, "rm") + assert is_approved(key, "rm") is True + + +class TestSessionKeyContext: + def test_context_session_key_overrides_process_env(self): + token = approval_module.set_current_session_key("alice") + try: + with mock_patch.dict("os.environ", {"HERMES_SESSION_KEY": "bob"}, clear=False): + assert approval_module.get_current_session_key() == "alice" + finally: + approval_module.reset_current_session_key(token) + + def test_gateway_runner_binds_session_key_to_context_before_agent_run(self): + run_py = Path(__file__).resolve().parents[2] / "gateway" / "run.py" + module = ast.parse(run_py.read_text(encoding="utf-8")) + + run_sync = None + for node in ast.walk(module): + if isinstance(node, ast.FunctionDef) and node.name == "run_sync": + run_sync = node + break + + assert run_sync is not None, "gateway.run.run_sync not found" + + called_names = set() + for node in ast.walk(run_sync): + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + called_names.add(node.func.id) + + assert "set_current_session_key" in called_names + assert "reset_current_session_key" in called_names + + class TestRmFalsePositiveFix: """Regression tests: filenames starting with 'r' must NOT trigger recursive delete.""" @@ -383,6 +429,19 @@ class TestPatternKeyUniqueness: "approving one silently approves the other" ) + def test_approving_find_exec_does_not_approve_find_delete(self): + """Session approval for find -exec rm must not carry over to find -delete.""" + _, key_exec, _ = detect_dangerous_command("find . -exec rm {} \\;") + _, key_delete, _ = detect_dangerous_command("find . -name '*.tmp' -delete") + session = "test_find_collision" + _clear_session(session) + approve_session(session, key_exec) + assert is_approved(session, key_exec) is True + assert is_approved(session, key_delete) is False, ( + "approving find -exec rm should not auto-approve find -delete" + ) + _clear_session(session) + def test_legacy_find_key_still_approves_find_exec(self): """Old allowlist entry 'find' should keep approving the matching command.""" _, key_exec, _ = detect_dangerous_command("find . -exec rm {} \\;")