diff --git a/tests/cli/test_branch_command.py b/tests/cli/test_branch_command.py index 9c3ec61d8..2483ecbd7 100644 --- a/tests/cli/test_branch_command.py +++ b/tests/cli/test_branch_command.py @@ -21,12 +21,35 @@ import pytest @pytest.fixture def session_db(tmp_path): """Create a real SessionDB for testing.""" - os.environ["HERMES_HOME"] = str(tmp_path / ".hermes") - os.makedirs(tmp_path / ".hermes", exist_ok=True) + hermes_home = tmp_path / ".hermes" + previous_home = os.environ.get("HERMES_HOME") + os.environ["HERMES_HOME"] = str(hermes_home) + os.makedirs(hermes_home, exist_ok=True) from hermes_state import SessionDB - db = SessionDB(db_path=tmp_path / ".hermes" / "test_sessions.db") - yield db - db.close() + db = SessionDB(db_path=hermes_home / "test_sessions.db") + try: + yield db + finally: + db.close() + if previous_home is None: + os.environ.pop("HERMES_HOME", None) + else: + os.environ["HERMES_HOME"] = previous_home + + +def test_session_db_fixture_restores_hermes_home_after_teardown(tmp_path, monkeypatch): + """The session_db fixture should not leak HERMES_HOME after teardown.""" + original_home = str(tmp_path / "original-home") + monkeypatch.setenv("HERMES_HOME", original_home) + + fixture_gen = session_db.__wrapped__(tmp_path) + db = next(fixture_gen) + assert os.environ["HERMES_HOME"] != original_home + + with pytest.raises(StopIteration): + next(fixture_gen) + + assert os.environ.get("HERMES_HOME") == original_home @pytest.fixture