diff --git a/plugins/observability/langfuse/__init__.py b/plugins/observability/langfuse/__init__.py index b992484b05e..31904d47e30 100644 --- a/plugins/observability/langfuse/__init__.py +++ b/plugins/observability/langfuse/__init__.py @@ -54,6 +54,15 @@ class TraceState: _STATE_LOCK = threading.Lock() _TRACE_STATE: Dict[str, TraceState] = {} +# Hard cap on live trace state. Each turn keys _TRACE_STATE by a unique +# turn_id, and an entry is normally reclaimed by _finish_trace when a turn +# ends cleanly (final response has content and no tool calls). A turn that +# never reaches that state — interrupted, a tool-only final step, or empty +# final content — would otherwise linger forever, so over the cap we evict +# the least-recently-updated entries (ending their root span first). The cap +# is far above any realistic concurrent-live-turn working set; it exists only +# to bound the leak from non-finalizing turns, not to limit concurrency. +_MAX_TRACE_STATE = 256 _LANGFUSE_CLIENT = None _READ_FILE_LINE_RE = re.compile(r"^\s*(\d+)\|(.*)$") _READ_FILE_HEAD_LINES = 25 @@ -219,14 +228,43 @@ def _get_langfuse() -> Optional[Langfuse]: return _LANGFUSE_CLIENT -def _trace_key(task_id: str, session_id: str) -> str: +def _scope_prefix(task_id: str, session_id: str) -> str: + """The task/session/thread prefix shared by every trace-key shape.""" if task_id: - return task_id + return f"task:{task_id}" if session_id: return f"session:{session_id}" return f"thread:{threading.get_ident()}" +def _trace_key( + task_id: str, + session_id: str, + *, + turn_id: str = "", + api_request_id: str = "", +) -> str: + """Build a stable in-process trace scope key for one agent turn. + + Older Hermes paths only expose ``task_id``/``session_id``. Newer paths + pass ``turn_id`` and ``api_request_id`` in LLM/tool hooks; when present, + they must scope trace state so concurrent requests sharing one task/session + never collide. ``turn_id`` is preferred over ``api_request_id`` so the + turn-level ``post_llm_call`` hook (which carries ``turn_id`` but no + ``api_request_id``) resolves to the same key as the request-level hooks. + """ + if turn_id: + return f"{_scope_prefix(task_id, session_id)}:turn:{turn_id}" + if api_request_id: + return f"{_scope_prefix(task_id, session_id)}:api:{api_request_id}" + # Legacy shape: a bare ``task_id`` (NOT the ``task:`` prefix) when present, + # otherwise the session/thread prefix. Kept distinct for backward + # compatibility with keys minted before turn/request scoping existed. + if task_id: + return task_id + return _scope_prefix(task_id, session_id) + + def _is_base64_data_uri(value: str) -> bool: prefix = value[:200].lower() return prefix.startswith("data:") and ";base64," in prefix @@ -563,12 +601,15 @@ def _usage_and_cost(response: Any, *, provider: str, api_mode: str, model: str, def _start_root_trace(task_key: str, *, task_id: str, session_id: str, platform: str, provider: str, model: str, - api_mode: str, messages: Any, client: Langfuse) -> TraceState: + api_mode: str, messages: Any, client: Langfuse, + turn_id: str = "", api_request_id: str = "") -> TraceState: trace_id = client.create_trace_id(seed=f"{session_id or 'sessionless'}::{task_id or task_key}") trace_input = _extract_last_user_message(messages) metadata = { "source": "hermes", "task_id": task_id, + "turn_id": turn_id, + "api_request_id": api_request_id, "platform": platform, "provider": provider, "model": model, @@ -669,6 +710,30 @@ def _merge_trace_output(output: Any, state: TraceState) -> Any: return merged +def _evict_stale_locked() -> None: + """Drop least-recently-updated trace state to make room for a new entry. + + Caller MUST hold ``_STATE_LOCK`` and call this immediately before inserting + one new entry. Bounds the leak from turns that never reach ``_finish_trace`` + (interrupted / tool-only final step / empty final content), whose unique + per-turn key would otherwise linger forever. We evict down to + ``_MAX_TRACE_STATE - 1`` so that the about-to-be-added entry leaves the dict + at ``_MAX_TRACE_STATE`` — a true ceiling. The evicted entry's root span is + ended so it is not left dangling on the Langfuse side. + """ + over = len(_TRACE_STATE) - (_MAX_TRACE_STATE - 1) + if over <= 0: + return + # Oldest-first by last_updated_at; evict just enough to make room. + stale = sorted(_TRACE_STATE.items(), key=lambda kv: kv[1].last_updated_at)[:over] + for key, state in stale: + _TRACE_STATE.pop(key, None) + try: + state.root_span.end() + except Exception as exc: # pragma: no cover - fail-open + _debug(f"evict stale trace failed: {exc}") + + def _finish_trace(task_key: str, *, output: Any = None) -> None: client = _get_langfuse() if client is None: @@ -712,7 +777,8 @@ def _request_key(api_call_count: Any) -> str: def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str = "", model: str = "", provider: str = "", base_url: str = "", api_mode: str = "", api_call_count: int = 0, messages: Any = None, turn_type: str = "user", - conversation_history: Any = None, user_message: Any = None, **_: Any) -> None: + conversation_history: Any = None, user_message: Any = None, + turn_id: str = "", api_request_id: str = "", **_: Any) -> None: # Older Hermes branches used pre_llm_call for request-scoped tracing and # passed the actual API messages. Current Hermes also has a turn-scoped # pre_llm_call used for context injection; tracing that hook creates an @@ -729,7 +795,12 @@ def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str = # pre_llm_call with API messages directly. Current Hermes fires # pre_llm_call for context injection (conversation_history/user_message, # no messages list) — tracing that would create orphan traces. - task_key = _trace_key(task_id, session_id) + task_key = _trace_key( + task_id, + session_id, + turn_id=turn_id, + api_request_id=api_request_id, + ) with _STATE_LOCK: state = _TRACE_STATE.get(task_key) @@ -744,7 +815,10 @@ def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str = api_mode=api_mode, messages=messages, client=client, + turn_id=turn_id, + api_request_id=api_request_id, ) + _evict_stale_locked() _TRACE_STATE[task_key] = state state.last_updated_at = time.time() @@ -769,6 +843,8 @@ def on_pre_llm_request( max_tokens: Any = None, conversation_history: Any = None, user_message: Any = None, + turn_id: str = "", + api_request_id: str = "", **_: Any, ) -> None: client = _get_langfuse() @@ -782,7 +858,12 @@ def on_pre_llm_request( user_message=user_message, ) - task_key = _trace_key(task_id, session_id) + task_key = _trace_key( + task_id, + session_id, + turn_id=turn_id, + api_request_id=api_request_id, + ) req_key = _request_key(api_call_count) with _STATE_LOCK: @@ -798,7 +879,10 @@ def on_pre_llm_request( api_mode=api_mode, messages=input_messages, client=client, + turn_id=turn_id, + api_request_id=api_request_id, ) + _evict_stale_locked() _TRACE_STATE[task_key] = state state.last_updated_at = time.time() previous = state.generations.pop(req_key, None) @@ -827,12 +911,18 @@ def on_post_llm_call(*, task_id: str = "", session_id: str = "", provider: str = api_duration: float = 0.0, finish_reason: str = "", usage: Any = None, assistant_content_chars: int = 0, assistant_tool_call_count: int = 0, assistant_response: Any = None, + turn_id: str = "", api_request_id: str = "", **_: Any) -> None: client = _get_langfuse() if client is None: return - task_key = _trace_key(task_id, session_id) + task_key = _trace_key( + task_id, + session_id, + turn_id=turn_id, + api_request_id=api_request_id, + ) req_key = _request_key(api_call_count) with _STATE_LOCK: @@ -950,12 +1040,18 @@ def on_post_llm_call(*, task_id: str = "", session_id: str = "", provider: str = def on_pre_tool_call(*, tool_name: str = "", args: Any = None, task_id: str = "", - session_id: str = "", tool_call_id: str = "", **_: Any) -> None: + session_id: str = "", tool_call_id: str = "", + turn_id: str = "", api_request_id: str = "", **_: Any) -> None: client = _get_langfuse() if client is None: return - task_key = _trace_key(task_id, session_id) + task_key = _trace_key( + task_id, + session_id, + turn_id=turn_id, + api_request_id=api_request_id, + ) with _STATE_LOCK: state = _TRACE_STATE.get(task_key) @@ -976,8 +1072,14 @@ def on_pre_tool_call(*, tool_name: str = "", args: Any = None, task_id: str = "" def on_post_tool_call(*, tool_name: str = "", args: Any = None, result: Any = None, - task_id: str = "", session_id: str = "", tool_call_id: str = "", **_: Any) -> None: - task_key = _trace_key(task_id, session_id) + task_id: str = "", session_id: str = "", tool_call_id: str = "", + turn_id: str = "", api_request_id: str = "", **_: Any) -> None: + task_key = _trace_key( + task_id, + session_id, + turn_id=turn_id, + api_request_id=api_request_id, + ) observation = None with _STATE_LOCK: diff --git a/tests/plugins/test_langfuse_plugin.py b/tests/plugins/test_langfuse_plugin.py index ca91feae613..dd58149eba2 100644 --- a/tests/plugins/test_langfuse_plugin.py +++ b/tests/plugins/test_langfuse_plugin.py @@ -205,6 +205,216 @@ class TestPayloadSanitization: } +class TestTraceScopeKey: + def _fresh_plugin(self): + mod_name = "plugins.observability.langfuse" + sys.modules.pop(mod_name, None) + return importlib.import_module(mod_name) + + def test_trace_key_scopes_by_turn_id_when_available(self): + plugin = self._fresh_plugin() + + key_a = plugin._trace_key("task-1", "session-1", turn_id="turn-a") + key_b = plugin._trace_key("task-1", "session-1", turn_id="turn-b") + + assert key_a != key_b + assert "turn:turn-a" in key_a + assert "turn:turn-b" in key_b + + def test_trace_key_scopes_by_api_request_id_when_turn_missing(self): + plugin = self._fresh_plugin() + + key_a = plugin._trace_key("task-1", "session-1", api_request_id="req-a") + key_b = plugin._trace_key("task-1", "session-1", api_request_id="req-b") + + assert key_a != key_b + assert "api:req-a" in key_a + assert "api:req-b" in key_b + + def test_trace_key_keeps_legacy_shape_without_turn_or_api_id(self): + plugin = self._fresh_plugin() + assert plugin._trace_key("task-1", "session-1") == "task-1" + + +# --------------------------------------------------------------------------- +# End-to-end collision regression: two turns of ONE gateway session must not +# share trace state. The helper-level tests above prove _trace_key returns +# distinct keys; this drives the real pre/post hooks to prove the keys are +# actually threaded through so the second turn gets its own root trace. +# +# Gateway reality this reproduces: +# * task_id == session_id for every turn (gateway/run.py) +# * turn_id is unique per turn (turn_context.py) +# * api_call_count resets to 1 each turn (conversation_loop.py) +# +# Before the turn/request scoping, _trace_key collapsed to the constant +# session_id. That worked only because _finish_trace pops the key on a clean +# turn end. When turn 1 does NOT finalize (interrupted, tool-only final step, +# or empty final content), its state lingered under session_id and turn 2 +# silently merged into turn 1's trace instead of opening its own. +# --------------------------------------------------------------------------- + + +class TestTurnTraceIsolation: + def _fresh_plugin(self): + sys.modules.pop("plugins.observability.langfuse", None) + return importlib.import_module("plugins.observability.langfuse") + + @staticmethod + def _fake_client(started): + """A minimal Langfuse stand-in that records each root trace opened. + + ``_start_root_trace`` calls ``create_trace_id`` then opens a root via + ``start_as_current_observation(...)`` (a context manager whose + ``__enter__`` returns the root span). We record one entry per root + actually opened so the test can count distinct traces. + """ + + class _Span: + def update(self, **kw): + pass + + def end(self, **kw): + pass + + def set_trace_io(self, **kw): + pass + + def start_observation(self, **kw): + return _Span() + + class _RootCM: + def __enter__(self): + return _Span() + + def __exit__(self, *exc): + return False + + class _Client: + def create_trace_id(self, seed=None): + return f"trace::{seed}" + + def start_as_current_observation(self, **kw): + started.append(kw.get("trace_context", {}).get("trace_id")) + return _RootCM() + + def flush(self): + pass + + return _Client() + + def _run_turn(self, mod, *, session, turn_n, finalize): + """Drive one turn through the request-scoped hooks the gateway fires.""" + task_id = session # gateway sets task_id == session_id + turn_id = f"{session}:{task_id}:turn{turn_n}" + api_call_count = 1 # resets every turn + api_request_id = f"{turn_id}:api:{api_call_count}" + + mod.on_pre_llm_request( + task_id=task_id, + session_id=session, + model="m", + provider="p", + api_mode="chat", + api_call_count=api_call_count, + request_messages=[{"role": "user", "content": "hi"}], + turn_id=turn_id, + api_request_id=api_request_id, + ) + # finalize=False => leave a tool call on the final response so + # _finish_trace is skipped and the turn's state lingers. + mod.on_post_llm_call( + task_id=task_id, + session_id=session, + model="m", + provider="p", + api_mode="chat", + api_call_count=api_call_count, + assistant_content_chars=5 if finalize else 0, + assistant_tool_call_count=0 if finalize else 1, + usage={"input_tokens": 10, "output_tokens": 5}, + turn_id=turn_id, + api_request_id=api_request_id, + ) + + def test_unfinalized_turn_does_not_capture_next_turn(self, monkeypatch): + """A turn that never finalizes must not absorb the following turn.""" + mod = self._fresh_plugin() + started: list = [] + monkeypatch.setattr(mod, "_get_langfuse", lambda: self._fake_client(started)) + monkeypatch.setattr(mod, "_end_observation", lambda *a, **k: None) + mod._TRACE_STATE.clear() + + # Turn 1 ends without finalizing (its final step still has a tool call). + self._run_turn(mod, session="sess-iso", turn_n=1, finalize=False) + # Turn 2 is a normal, fully finalizing turn in the SAME session. + self._run_turn(mod, session="sess-iso", turn_n=2, finalize=True) + + # Each turn opened its OWN root trace. On the pre-fix code the second + # turn reused turn 1's lingering state and only one trace was opened. + assert len(started) == 2 + + # Turn 2 finalized and was popped by _finish_trace; only turn 1's + # (non-finalizing) state lingers. Assert the surviving key is turn 1's + # and that turn 2 never merged into it — `all(...)` over an empty set + # would pass vacuously, so pin the exact surviving key instead. + keys = list(mod._TRACE_STATE.keys()) + assert len(keys) == 1 + assert "turn1" in keys[0] + assert "turn2" not in keys[0] + + def test_pre_and_post_hooks_share_one_key_within_a_turn(self, monkeypatch): + """turn_id is preferred over api_request_id so the turn-scoped + post_llm_call (which carries no api_request_id) still resolves to the + same key as the request-scoped pre/post_api_request hooks. If the + ordering were reversed, finalization would silently break.""" + mod = self._fresh_plugin() + turn_id = "S:T:turnX" + api_request_id = f"{turn_id}:api:1" + + k_pre_api = mod._trace_key("T", "S", turn_id=turn_id, api_request_id=api_request_id) + k_post_api = mod._trace_key("T", "S", turn_id=turn_id, api_request_id=api_request_id) + k_post_turn = mod._trace_key("T", "S", turn_id=turn_id, api_request_id="") + + assert k_pre_api == k_post_api == k_post_turn + + def test_non_finalizing_turns_do_not_grow_state_unboundedly(self, monkeypatch): + """Per-turn keys mean a turn that never finalizes leaves a lingering + entry. Without a cap that grows once per non-finalizing turn forever; + the LRU eviction must bound _TRACE_STATE at _MAX_TRACE_STATE. + """ + mod = self._fresh_plugin() + started: list = [] + monkeypatch.setattr(mod, "_get_langfuse", lambda: self._fake_client(started)) + monkeypatch.setattr(mod, "_end_observation", lambda *a, **k: None) + monkeypatch.setattr(mod, "_MAX_TRACE_STATE", 8) + mod._TRACE_STATE.clear() + + # Far more non-finalizing turns than the cap. + for n in range(50): + self._run_turn(mod, session="sess-leak", turn_n=n, finalize=False) + + assert len(mod._TRACE_STATE) <= 8 + # The survivors are the most-recently-updated turns (LRU eviction). + surviving = sorted(int(k.rsplit("turn", 1)[1]) for k in mod._TRACE_STATE) + assert surviving == list(range(42, 50)) + + def test_trace_key_strings_unchanged_by_refactor(self): + """Pin the exact key strings across all task/session/turn/api + combinations so the _scope_prefix extraction can never silently change + a key (keys are matched across hooks; a drift breaks finalization).""" + mod = self._fresh_plugin() + tk = mod._trace_key + assert tk("t", "s", turn_id="u") == "task:t:turn:u" + assert tk("", "s", turn_id="u") == "session:s:turn:u" + assert tk("t", "s", api_request_id="r") == "task:t:api:r" + assert tk("", "s", api_request_id="r") == "session:s:api:r" + assert tk("t", "s") == "t" # legacy: bare task_id + assert tk("", "s") == "session:s" + # turn_id wins over api_request_id when both are present. + assert tk("t", "s", turn_id="u", api_request_id="r") == "task:t:turn:u" + + # --------------------------------------------------------------------------- # Placeholder-credential guard (#23823). #