fix(langfuse): bound _TRACE_STATE growth from non-finalizing turns

Scoping the trace key by turn_id (the prior commit) fixed cross-turn
collisions but introduced a slow leak: _finish_trace only pops a key when a
turn ends cleanly (final response has content and no tool calls), so any
turn that is interrupted, ends on a tool call, or has empty final content
now leaves its uniquely-keyed entry in _TRACE_STATE forever. Previously the
constant per-session key was overwritten by the next turn, capping growth at
~1 entry per session.

Add an LRU cap (_MAX_TRACE_STATE) enforced by _evict_stale_locked, called
under _STATE_LOCK immediately before each insert. It evicts the
least-recently-updated entries (using the previously-dead last_updated_at
field) and ends their root span so nothing dangles. Regression test drives
50 non-finalizing turns against a cap of 8 and asserts the dict stays bounded
with the most-recent turns surviving.
This commit is contained in:
kshitijk4poor 2026-06-18 12:59:41 +05:30
parent e1d10ec1ed
commit f4fbaa6cda
2 changed files with 56 additions and 0 deletions

View file

@ -54,6 +54,15 @@ class TraceState:
_STATE_LOCK = threading.Lock()
_TRACE_STATE: Dict[str, TraceState] = {}
# Hard cap on live trace state. Each turn keys _TRACE_STATE by a unique
# turn_id, and an entry is normally reclaimed by _finish_trace when a turn
# ends cleanly (final response has content and no tool calls). A turn that
# never reaches that state — interrupted, a tool-only final step, or empty
# final content — would otherwise linger forever, so once the dict is at the
# cap we evict the least-recently-updated entries (ending their root span
# first) before inserting a new one. The cap is far above any realistic
# concurrent-live-turn working set; it exists only to bound the leak from
# non-finalizing turns, not to limit concurrency.
_MAX_TRACE_STATE = 256
_LANGFUSE_CLIENT = None
_READ_FILE_LINE_RE = re.compile(r"^\s*(\d+)\|(.*)$")
_READ_FILE_HEAD_LINES = 25
@ -701,6 +710,30 @@ def _merge_trace_output(output: Any, state: TraceState) -> Any:
return merged
def _evict_stale_locked() -> None:
    """Evict least-recently-updated trace state so one new entry fits the cap.

    Must be called with ``_STATE_LOCK`` held, immediately before a single new
    key is inserted into ``_TRACE_STATE``. Entries are removed oldest-first by
    ``last_updated_at`` until at most ``_MAX_TRACE_STATE - 1`` remain, so the
    pending insert brings the dict to at most ``_MAX_TRACE_STATE`` — which
    makes the cap a true ceiling rather than a soft target.

    This bounds the leak from turns that never reach ``_finish_trace``
    (interrupted / tool-only final step / empty final content), whose unique
    per-turn key would otherwise linger forever. Each evicted entry's root
    span is ended so it is not left dangling on the Langfuse side; a failure
    while ending a span is reported through ``_debug`` and otherwise ignored.
    """
    excess = len(_TRACE_STATE) - (_MAX_TRACE_STATE - 1)
    if excess <= 0:
        # Still room for the entry the caller is about to add.
        return

    # Rank a snapshot of the entries oldest-first and keep only as many
    # victims as are needed to make room; the dict is mutated afterwards.
    victims = list(_TRACE_STATE.items())
    victims.sort(key=lambda entry: entry[1].last_updated_at)
    del victims[excess:]

    for task_key, evicted in victims:
        _TRACE_STATE.pop(task_key, None)
        try:
            evicted.root_span.end()
        except Exception as exc:  # pragma: no cover - fail-open
            _debug(f"evict stale trace failed: {exc}")
def _finish_trace(task_key: str, *, output: Any = None) -> None:
client = _get_langfuse()
if client is None:
@ -785,6 +818,7 @@ def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str =
turn_id=turn_id,
api_request_id=api_request_id,
)
_evict_stale_locked()
_TRACE_STATE[task_key] = state
state.last_updated_at = time.time()
@ -848,6 +882,7 @@ def on_pre_llm_request(
turn_id=turn_id,
api_request_id=api_request_id,
)
_evict_stale_locked()
_TRACE_STATE[task_key] = state
state.last_updated_at = time.time()
previous = state.generations.pop(req_key, None)

View file

@ -373,6 +373,27 @@ class TestTurnTraceIsolation:
assert k_pre_api == k_post_api == k_post_turn
def test_non_finalizing_turns_do_not_grow_state_unboundedly(self, monkeypatch):
    """Non-finalizing turns must not make ``_TRACE_STATE`` grow without bound.

    With per-turn keys, every turn that never finalizes leaves its entry
    behind. The LRU eviction has to hold the dict at ``_MAX_TRACE_STATE``
    and keep the most-recently-updated turns.
    """
    cap = 8
    total_turns = 50

    started: list = []
    mod = self._fresh_plugin()
    monkeypatch.setattr(mod, "_get_langfuse", lambda: self._fake_client(started))
    monkeypatch.setattr(mod, "_end_observation", lambda *a, **k: None)
    monkeypatch.setattr(mod, "_MAX_TRACE_STATE", cap)
    mod._TRACE_STATE.clear()

    # Drive far more non-finalizing turns than the cap allows.
    for turn_n in range(total_turns):
        self._run_turn(mod, session="sess-leak", turn_n=turn_n, finalize=False)

    assert len(mod._TRACE_STATE) <= 8

    # Only the newest ``cap`` turns may survive (LRU eviction); the turn
    # number is the suffix after the last "turn" in each key.
    surviving = [int(key.rsplit("turn", 1)[1]) for key in mod._TRACE_STATE]
    surviving.sort()
    assert surviving == list(range(total_turns - cap, total_turns))
def test_trace_key_strings_unchanged_by_refactor(self):
"""Pin the exact key strings across all task/session/turn/api
combinations so the _scope_prefix extraction can never silently change