This commit is contained in:
HARRY 2026-04-24 17:29:24 -05:00 committed by GitHub
commit 485f90b16e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 25 additions and 6 deletions

View file

@ -211,20 +211,24 @@ class ThreadParticipationTracker:
def __init__(self, platform_name: str, max_tracked: int = 500):
self._platform = platform_name
self._max_tracked = max_tracked
self._threads: set = self._load()
self._threads: dict[str, None] = {
str(thread_id): None for thread_id in self._load()
}
def _state_path(self) -> Path:
from hermes_constants import get_hermes_home
return get_hermes_home() / f"{self._platform}_threads.json"
def _load(self) -> set:
def _load(self) -> list[str]:
path = self._state_path()
if path.exists():
try:
return set(json.loads(path.read_text(encoding="utf-8")))
data = json.loads(path.read_text(encoding="utf-8"))
if isinstance(data, list):
return [str(thread_id) for thread_id in data]
except Exception:
pass
return set()
return []
def _save(self) -> None:
path = self._state_path()
@ -232,13 +236,13 @@ class ThreadParticipationTracker:
thread_list = list(self._threads)
if len(thread_list) > self._max_tracked:
thread_list = thread_list[-self._max_tracked:]
self._threads = set(thread_list)
self._threads = {thread_id: None for thread_id in thread_list}
path.write_text(json.dumps(thread_list), encoding="utf-8")
def mark(self, thread_id: str) -> None:
"""Mark *thread_id* as participated and persist."""
if thread_id not in self._threads:
self._threads.add(thread_id)
self._threads[thread_id] = None
self._save()
def __contains__(self, thread_id: str) -> bool:

View file

@ -67,6 +67,21 @@ class TestDiscordThreadPersistence:
saved = json.loads((tmp_path / "discord_threads.json").read_text())
assert len(saved) == 5
assert saved == ["5", "6", "7", "8", "9"]
def test_capacity_keeps_newest_thread_when_existing_state_is_full(self, tmp_path):
"""A newly joined thread must not be evicted by unordered set iteration."""
state_file = tmp_path / "discord_threads.json"
state_file.write_text(json.dumps(["0", "1", "2", "3", "4"]), encoding="utf-8")
adapter = self._make_adapter(tmp_path)
adapter._threads._max_tracked = 5
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter._threads.mark("newest")
saved = json.loads(state_file.read_text(encoding="utf-8"))
assert saved == ["1", "2", "3", "4", "newest"]
assert "newest" in adapter._threads
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
state_file = tmp_path / "discord_threads.json"