diff --git a/gateway/platforms/helpers.py b/gateway/platforms/helpers.py index 18d97fcb7a..27dc7b21f6 100644 --- a/gateway/platforms/helpers.py +++ b/gateway/platforms/helpers.py @@ -211,20 +211,24 @@ class ThreadParticipationTracker: def __init__(self, platform_name: str, max_tracked: int = 500): self._platform = platform_name self._max_tracked = max_tracked - self._threads: set = self._load() + self._threads: dict[str, None] = { + str(thread_id): None for thread_id in self._load() + } def _state_path(self) -> Path: from hermes_constants import get_hermes_home return get_hermes_home() / f"{self._platform}_threads.json" - def _load(self) -> set: + def _load(self) -> list[str]: path = self._state_path() if path.exists(): try: - return set(json.loads(path.read_text(encoding="utf-8"))) + data = json.loads(path.read_text(encoding="utf-8")) + if isinstance(data, list): + return [str(thread_id) for thread_id in data] except Exception: pass - return set() + return [] def _save(self) -> None: path = self._state_path() @@ -232,13 +236,13 @@ class ThreadParticipationTracker: thread_list = list(self._threads) if len(thread_list) > self._max_tracked: thread_list = thread_list[-self._max_tracked:] - self._threads = set(thread_list) + self._threads = {thread_id: None for thread_id in thread_list} path.write_text(json.dumps(thread_list), encoding="utf-8") def mark(self, thread_id: str) -> None: """Mark *thread_id* as participated and persist.""" if thread_id not in self._threads: - self._threads.add(thread_id) + self._threads[thread_id] = None self._save() def __contains__(self, thread_id: str) -> bool: diff --git a/tests/gateway/test_discord_thread_persistence.py b/tests/gateway/test_discord_thread_persistence.py index 083f61ac7c..b6be0a6683 100644 --- a/tests/gateway/test_discord_thread_persistence.py +++ b/tests/gateway/test_discord_thread_persistence.py @@ -67,6 +67,21 @@ class TestDiscordThreadPersistence: saved = json.loads((tmp_path / "discord_threads.json").read_text()) assert len(saved) == 5 + assert saved == ["5", "6", "7", "8", "9"] + + def test_capacity_keeps_newest_thread_when_existing_state_is_full(self, tmp_path): + """A newly joined thread must not be evicted by unordered set iteration.""" + state_file = tmp_path / "discord_threads.json" + state_file.write_text(json.dumps(["0", "1", "2", "3", "4"]), encoding="utf-8") + adapter = self._make_adapter(tmp_path) + adapter._threads._max_tracked = 5 + + with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}): + adapter._threads.mark("newest") + + saved = json.loads(state_file.read_text(encoding="utf-8")) + assert saved == ["1", "2", "3", "4", "newest"] + assert "newest" in adapter._threads def test_corrupted_state_file_falls_back_to_empty(self, tmp_path): state_file = tmp_path / "discord_threads.json"