fix(gateway): ensure deterministic thread eviction in helpers

This commit is contained in:
hharry11 2026-04-21 21:26:52 +03:00 committed by Teknium
parent 935cf2fcca
commit 247c9d468c
2 changed files with 25 additions and 6 deletions

View file

@ -222,33 +222,37 @@ class ThreadParticipationTracker:
def __init__(self, platform_name: str, max_tracked: int = 500): def __init__(self, platform_name: str, max_tracked: int = 500):
self._platform = platform_name self._platform = platform_name
self._max_tracked = max_tracked self._max_tracked = max_tracked
self._threads: set = self._load() self._threads: dict[str, None] = {
str(thread_id): None for thread_id in self._load()
}
def _state_path(self) -> Path: def _state_path(self) -> Path:
from hermes_constants import get_hermes_home from hermes_constants import get_hermes_home
return get_hermes_home() / f"{self._platform}_threads.json" return get_hermes_home() / f"{self._platform}_threads.json"
def _load(self) -> set: def _load(self) -> list[str]:
path = self._state_path() path = self._state_path()
if path.exists(): if path.exists():
try: try:
return set(json.loads(path.read_text(encoding="utf-8"))) data = json.loads(path.read_text(encoding="utf-8"))
if isinstance(data, list):
return [str(thread_id) for thread_id in data]
except Exception: except Exception:
pass pass
return set() return []
def _save(self) -> None: def _save(self) -> None:
path = self._state_path() path = self._state_path()
thread_list = list(self._threads) thread_list = list(self._threads)
if len(thread_list) > self._max_tracked: if len(thread_list) > self._max_tracked:
thread_list = thread_list[-self._max_tracked:] thread_list = thread_list[-self._max_tracked:]
self._threads = set(thread_list) self._threads = {thread_id: None for thread_id in thread_list}
atomic_json_write(path, thread_list, indent=None) atomic_json_write(path, thread_list, indent=None)
def mark(self, thread_id: str) -> None: def mark(self, thread_id: str) -> None:
"""Mark *thread_id* as participated and persist.""" """Mark *thread_id* as participated and persist."""
if thread_id not in self._threads: if thread_id not in self._threads:
self._threads.add(thread_id) self._threads[thread_id] = None
self._save() self._save()
def __contains__(self, thread_id: str) -> bool: def __contains__(self, thread_id: str) -> bool:

View file

@ -67,6 +67,21 @@ class TestDiscordThreadPersistence:
saved = json.loads((tmp_path / "discord_threads.json").read_text()) saved = json.loads((tmp_path / "discord_threads.json").read_text())
assert len(saved) == 5 assert len(saved) == 5
assert saved == ["5", "6", "7", "8", "9"]
def test_capacity_keeps_newest_thread_when_existing_state_is_full(self, tmp_path):
"""A newly joined thread must not be evicted by unordered set iteration."""
state_file = tmp_path / "discord_threads.json"
state_file.write_text(json.dumps(["0", "1", "2", "3", "4"]), encoding="utf-8")
adapter = self._make_adapter(tmp_path)
adapter._threads._max_tracked = 5
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
adapter._threads.mark("newest")
saved = json.loads(state_file.read_text(encoding="utf-8"))
assert saved == ["1", "2", "3", "4", "newest"]
assert "newest" in adapter._threads
def test_corrupted_state_file_falls_back_to_empty(self, tmp_path): def test_corrupted_state_file_falls_back_to_empty(self, tmp_path):
state_file = tmp_path / "discord_threads.json" state_file = tmp_path / "discord_threads.json"