mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-18 09:51:59 +00:00
Merge pull request #48292 from NousResearch/fix/langfuse-trace-scope-salvage
fix(langfuse): scope trace state by turn/request ids (salvage #47945)
This commit is contained in:
commit
9b2f7d2cb1
2 changed files with 323 additions and 11 deletions
|
|
@ -54,6 +54,15 @@ class TraceState:
|
|||
|
||||
_STATE_LOCK = threading.Lock()
|
||||
_TRACE_STATE: Dict[str, TraceState] = {}
|
||||
# Hard cap on live trace state. Each turn keys _TRACE_STATE by a unique
|
||||
# turn_id, and an entry is normally reclaimed by _finish_trace when a turn
|
||||
# ends cleanly (final response has content and no tool calls). A turn that
|
||||
# never reaches that state — interrupted, a tool-only final step, or empty
|
||||
# final content — would otherwise linger forever, so over the cap we evict
|
||||
# the least-recently-updated entries (ending their root span first). The cap
|
||||
# is far above any realistic concurrent-live-turn working set; it exists only
|
||||
# to bound the leak from non-finalizing turns, not to limit concurrency.
|
||||
_MAX_TRACE_STATE = 256
|
||||
_LANGFUSE_CLIENT = None
|
||||
_READ_FILE_LINE_RE = re.compile(r"^\s*(\d+)\|(.*)$")
|
||||
_READ_FILE_HEAD_LINES = 25
|
||||
|
|
@ -219,14 +228,43 @@ def _get_langfuse() -> Optional[Langfuse]:
|
|||
return _LANGFUSE_CLIENT
|
||||
|
||||
|
||||
def _trace_key(task_id: str, session_id: str) -> str:
|
||||
def _scope_prefix(task_id: str, session_id: str) -> str:
|
||||
"""The task/session/thread prefix shared by every trace-key shape."""
|
||||
if task_id:
|
||||
return task_id
|
||||
return f"task:{task_id}"
|
||||
if session_id:
|
||||
return f"session:{session_id}"
|
||||
return f"thread:{threading.get_ident()}"
|
||||
|
||||
|
||||
def _trace_key(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
*,
|
||||
turn_id: str = "",
|
||||
api_request_id: str = "",
|
||||
) -> str:
|
||||
"""Build a stable in-process trace scope key for one agent turn.
|
||||
|
||||
Older Hermes paths only expose ``task_id``/``session_id``. Newer paths
|
||||
pass ``turn_id`` and ``api_request_id`` in LLM/tool hooks; when present,
|
||||
they must scope trace state so concurrent requests sharing one task/session
|
||||
never collide. ``turn_id`` is preferred over ``api_request_id`` so the
|
||||
turn-level ``post_llm_call`` hook (which carries ``turn_id`` but no
|
||||
``api_request_id``) resolves to the same key as the request-level hooks.
|
||||
"""
|
||||
if turn_id:
|
||||
return f"{_scope_prefix(task_id, session_id)}:turn:{turn_id}"
|
||||
if api_request_id:
|
||||
return f"{_scope_prefix(task_id, session_id)}:api:{api_request_id}"
|
||||
# Legacy shape: a bare ``task_id`` (NOT the ``task:`` prefix) when present,
|
||||
# otherwise the session/thread prefix. Kept distinct for backward
|
||||
# compatibility with keys minted before turn/request scoping existed.
|
||||
if task_id:
|
||||
return task_id
|
||||
return _scope_prefix(task_id, session_id)
|
||||
|
||||
|
||||
def _is_base64_data_uri(value: str) -> bool:
|
||||
prefix = value[:200].lower()
|
||||
return prefix.startswith("data:") and ";base64," in prefix
|
||||
|
|
@ -563,12 +601,15 @@ def _usage_and_cost(response: Any, *, provider: str, api_mode: str, model: str,
|
|||
|
||||
|
||||
def _start_root_trace(task_key: str, *, task_id: str, session_id: str, platform: str, provider: str, model: str,
|
||||
api_mode: str, messages: Any, client: Langfuse) -> TraceState:
|
||||
api_mode: str, messages: Any, client: Langfuse,
|
||||
turn_id: str = "", api_request_id: str = "") -> TraceState:
|
||||
trace_id = client.create_trace_id(seed=f"{session_id or 'sessionless'}::{task_id or task_key}")
|
||||
trace_input = _extract_last_user_message(messages)
|
||||
metadata = {
|
||||
"source": "hermes",
|
||||
"task_id": task_id,
|
||||
"turn_id": turn_id,
|
||||
"api_request_id": api_request_id,
|
||||
"platform": platform,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
|
|
@ -669,6 +710,30 @@ def _merge_trace_output(output: Any, state: TraceState) -> Any:
|
|||
return merged
|
||||
|
||||
|
||||
def _evict_stale_locked() -> None:
|
||||
"""Drop least-recently-updated trace state to make room for a new entry.
|
||||
|
||||
Caller MUST hold ``_STATE_LOCK`` and call this immediately before inserting
|
||||
one new entry. Bounds the leak from turns that never reach ``_finish_trace``
|
||||
(interrupted / tool-only final step / empty final content), whose unique
|
||||
per-turn key would otherwise linger forever. We evict down to
|
||||
``_MAX_TRACE_STATE - 1`` so that the about-to-be-added entry leaves the dict
|
||||
at ``_MAX_TRACE_STATE`` — a true ceiling. The evicted entry's root span is
|
||||
ended so it is not left dangling on the Langfuse side.
|
||||
"""
|
||||
over = len(_TRACE_STATE) - (_MAX_TRACE_STATE - 1)
|
||||
if over <= 0:
|
||||
return
|
||||
# Oldest-first by last_updated_at; evict just enough to make room.
|
||||
stale = sorted(_TRACE_STATE.items(), key=lambda kv: kv[1].last_updated_at)[:over]
|
||||
for key, state in stale:
|
||||
_TRACE_STATE.pop(key, None)
|
||||
try:
|
||||
state.root_span.end()
|
||||
except Exception as exc: # pragma: no cover - fail-open
|
||||
_debug(f"evict stale trace failed: {exc}")
|
||||
|
||||
|
||||
def _finish_trace(task_key: str, *, output: Any = None) -> None:
|
||||
client = _get_langfuse()
|
||||
if client is None:
|
||||
|
|
@ -712,7 +777,8 @@ def _request_key(api_call_count: Any) -> str:
|
|||
def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str = "", model: str = "",
|
||||
provider: str = "", base_url: str = "", api_mode: str = "",
|
||||
api_call_count: int = 0, messages: Any = None, turn_type: str = "user",
|
||||
conversation_history: Any = None, user_message: Any = None, **_: Any) -> None:
|
||||
conversation_history: Any = None, user_message: Any = None,
|
||||
turn_id: str = "", api_request_id: str = "", **_: Any) -> None:
|
||||
# Older Hermes branches used pre_llm_call for request-scoped tracing and
|
||||
# passed the actual API messages. Current Hermes also has a turn-scoped
|
||||
# pre_llm_call used for context injection; tracing that hook creates an
|
||||
|
|
@ -729,7 +795,12 @@ def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str =
|
|||
# pre_llm_call with API messages directly. Current Hermes fires
|
||||
# pre_llm_call for context injection (conversation_history/user_message,
|
||||
# no messages list) — tracing that would create orphan traces.
|
||||
task_key = _trace_key(task_id, session_id)
|
||||
task_key = _trace_key(
|
||||
task_id,
|
||||
session_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
|
||||
with _STATE_LOCK:
|
||||
state = _TRACE_STATE.get(task_key)
|
||||
|
|
@ -744,7 +815,10 @@ def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str =
|
|||
api_mode=api_mode,
|
||||
messages=messages,
|
||||
client=client,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
_evict_stale_locked()
|
||||
_TRACE_STATE[task_key] = state
|
||||
state.last_updated_at = time.time()
|
||||
|
||||
|
|
@ -769,6 +843,8 @@ def on_pre_llm_request(
|
|||
max_tokens: Any = None,
|
||||
conversation_history: Any = None,
|
||||
user_message: Any = None,
|
||||
turn_id: str = "",
|
||||
api_request_id: str = "",
|
||||
**_: Any,
|
||||
) -> None:
|
||||
client = _get_langfuse()
|
||||
|
|
@ -782,7 +858,12 @@ def on_pre_llm_request(
|
|||
user_message=user_message,
|
||||
)
|
||||
|
||||
task_key = _trace_key(task_id, session_id)
|
||||
task_key = _trace_key(
|
||||
task_id,
|
||||
session_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
req_key = _request_key(api_call_count)
|
||||
|
||||
with _STATE_LOCK:
|
||||
|
|
@ -798,7 +879,10 @@ def on_pre_llm_request(
|
|||
api_mode=api_mode,
|
||||
messages=input_messages,
|
||||
client=client,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
_evict_stale_locked()
|
||||
_TRACE_STATE[task_key] = state
|
||||
state.last_updated_at = time.time()
|
||||
previous = state.generations.pop(req_key, None)
|
||||
|
|
@ -827,12 +911,18 @@ def on_post_llm_call(*, task_id: str = "", session_id: str = "", provider: str =
|
|||
api_duration: float = 0.0, finish_reason: str = "",
|
||||
usage: Any = None, assistant_content_chars: int = 0,
|
||||
assistant_tool_call_count: int = 0, assistant_response: Any = None,
|
||||
turn_id: str = "", api_request_id: str = "",
|
||||
**_: Any) -> None:
|
||||
client = _get_langfuse()
|
||||
if client is None:
|
||||
return
|
||||
|
||||
task_key = _trace_key(task_id, session_id)
|
||||
task_key = _trace_key(
|
||||
task_id,
|
||||
session_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
req_key = _request_key(api_call_count)
|
||||
|
||||
with _STATE_LOCK:
|
||||
|
|
@ -950,12 +1040,18 @@ def on_post_llm_call(*, task_id: str = "", session_id: str = "", provider: str =
|
|||
|
||||
|
||||
def on_pre_tool_call(*, tool_name: str = "", args: Any = None, task_id: str = "",
|
||||
session_id: str = "", tool_call_id: str = "", **_: Any) -> None:
|
||||
session_id: str = "", tool_call_id: str = "",
|
||||
turn_id: str = "", api_request_id: str = "", **_: Any) -> None:
|
||||
client = _get_langfuse()
|
||||
if client is None:
|
||||
return
|
||||
|
||||
task_key = _trace_key(task_id, session_id)
|
||||
task_key = _trace_key(
|
||||
task_id,
|
||||
session_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
|
||||
with _STATE_LOCK:
|
||||
state = _TRACE_STATE.get(task_key)
|
||||
|
|
@ -976,8 +1072,14 @@ def on_pre_tool_call(*, tool_name: str = "", args: Any = None, task_id: str = ""
|
|||
|
||||
|
||||
def on_post_tool_call(*, tool_name: str = "", args: Any = None, result: Any = None,
|
||||
task_id: str = "", session_id: str = "", tool_call_id: str = "", **_: Any) -> None:
|
||||
task_key = _trace_key(task_id, session_id)
|
||||
task_id: str = "", session_id: str = "", tool_call_id: str = "",
|
||||
turn_id: str = "", api_request_id: str = "", **_: Any) -> None:
|
||||
task_key = _trace_key(
|
||||
task_id,
|
||||
session_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
observation = None
|
||||
|
||||
with _STATE_LOCK:
|
||||
|
|
|
|||
|
|
@ -205,6 +205,216 @@ class TestPayloadSanitization:
|
|||
}
|
||||
|
||||
|
||||
class TestTraceScopeKey:
|
||||
def _fresh_plugin(self):
|
||||
mod_name = "plugins.observability.langfuse"
|
||||
sys.modules.pop(mod_name, None)
|
||||
return importlib.import_module(mod_name)
|
||||
|
||||
def test_trace_key_scopes_by_turn_id_when_available(self):
|
||||
plugin = self._fresh_plugin()
|
||||
|
||||
key_a = plugin._trace_key("task-1", "session-1", turn_id="turn-a")
|
||||
key_b = plugin._trace_key("task-1", "session-1", turn_id="turn-b")
|
||||
|
||||
assert key_a != key_b
|
||||
assert "turn:turn-a" in key_a
|
||||
assert "turn:turn-b" in key_b
|
||||
|
||||
def test_trace_key_scopes_by_api_request_id_when_turn_missing(self):
|
||||
plugin = self._fresh_plugin()
|
||||
|
||||
key_a = plugin._trace_key("task-1", "session-1", api_request_id="req-a")
|
||||
key_b = plugin._trace_key("task-1", "session-1", api_request_id="req-b")
|
||||
|
||||
assert key_a != key_b
|
||||
assert "api:req-a" in key_a
|
||||
assert "api:req-b" in key_b
|
||||
|
||||
def test_trace_key_keeps_legacy_shape_without_turn_or_api_id(self):
|
||||
plugin = self._fresh_plugin()
|
||||
assert plugin._trace_key("task-1", "session-1") == "task-1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# End-to-end collision regression: two turns of ONE gateway session must not
|
||||
# share trace state. The helper-level tests above prove _trace_key returns
|
||||
# distinct keys; this drives the real pre/post hooks to prove the keys are
|
||||
# actually threaded through so the second turn gets its own root trace.
|
||||
#
|
||||
# Gateway reality this reproduces:
|
||||
# * task_id == session_id for every turn (gateway/run.py)
|
||||
# * turn_id is unique per turn (turn_context.py)
|
||||
# * api_call_count resets to 1 each turn (conversation_loop.py)
|
||||
#
|
||||
# Before the turn/request scoping, _trace_key collapsed to the constant
|
||||
# session_id. That worked only because _finish_trace pops the key on a clean
|
||||
# turn end. When turn 1 does NOT finalize (interrupted, tool-only final step,
|
||||
# or empty final content), its state lingered under session_id and turn 2
|
||||
# silently merged into turn 1's trace instead of opening its own.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTurnTraceIsolation:
|
||||
def _fresh_plugin(self):
|
||||
sys.modules.pop("plugins.observability.langfuse", None)
|
||||
return importlib.import_module("plugins.observability.langfuse")
|
||||
|
||||
@staticmethod
|
||||
def _fake_client(started):
|
||||
"""A minimal Langfuse stand-in that records each root trace opened.
|
||||
|
||||
``_start_root_trace`` calls ``create_trace_id`` then opens a root via
|
||||
``start_as_current_observation(...)`` (a context manager whose
|
||||
``__enter__`` returns the root span). We record one entry per root
|
||||
actually opened so the test can count distinct traces.
|
||||
"""
|
||||
|
||||
class _Span:
|
||||
def update(self, **kw):
|
||||
pass
|
||||
|
||||
def end(self, **kw):
|
||||
pass
|
||||
|
||||
def set_trace_io(self, **kw):
|
||||
pass
|
||||
|
||||
def start_observation(self, **kw):
|
||||
return _Span()
|
||||
|
||||
class _RootCM:
|
||||
def __enter__(self):
|
||||
return _Span()
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
class _Client:
|
||||
def create_trace_id(self, seed=None):
|
||||
return f"trace::{seed}"
|
||||
|
||||
def start_as_current_observation(self, **kw):
|
||||
started.append(kw.get("trace_context", {}).get("trace_id"))
|
||||
return _RootCM()
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
return _Client()
|
||||
|
||||
def _run_turn(self, mod, *, session, turn_n, finalize):
|
||||
"""Drive one turn through the request-scoped hooks the gateway fires."""
|
||||
task_id = session # gateway sets task_id == session_id
|
||||
turn_id = f"{session}:{task_id}:turn{turn_n}"
|
||||
api_call_count = 1 # resets every turn
|
||||
api_request_id = f"{turn_id}:api:{api_call_count}"
|
||||
|
||||
mod.on_pre_llm_request(
|
||||
task_id=task_id,
|
||||
session_id=session,
|
||||
model="m",
|
||||
provider="p",
|
||||
api_mode="chat",
|
||||
api_call_count=api_call_count,
|
||||
request_messages=[{"role": "user", "content": "hi"}],
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
# finalize=False => leave a tool call on the final response so
|
||||
# _finish_trace is skipped and the turn's state lingers.
|
||||
mod.on_post_llm_call(
|
||||
task_id=task_id,
|
||||
session_id=session,
|
||||
model="m",
|
||||
provider="p",
|
||||
api_mode="chat",
|
||||
api_call_count=api_call_count,
|
||||
assistant_content_chars=5 if finalize else 0,
|
||||
assistant_tool_call_count=0 if finalize else 1,
|
||||
usage={"input_tokens": 10, "output_tokens": 5},
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
)
|
||||
|
||||
def test_unfinalized_turn_does_not_capture_next_turn(self, monkeypatch):
|
||||
"""A turn that never finalizes must not absorb the following turn."""
|
||||
mod = self._fresh_plugin()
|
||||
started: list = []
|
||||
monkeypatch.setattr(mod, "_get_langfuse", lambda: self._fake_client(started))
|
||||
monkeypatch.setattr(mod, "_end_observation", lambda *a, **k: None)
|
||||
mod._TRACE_STATE.clear()
|
||||
|
||||
# Turn 1 ends without finalizing (its final step still has a tool call).
|
||||
self._run_turn(mod, session="sess-iso", turn_n=1, finalize=False)
|
||||
# Turn 2 is a normal, fully finalizing turn in the SAME session.
|
||||
self._run_turn(mod, session="sess-iso", turn_n=2, finalize=True)
|
||||
|
||||
# Each turn opened its OWN root trace. On the pre-fix code the second
|
||||
# turn reused turn 1's lingering state and only one trace was opened.
|
||||
assert len(started) == 2
|
||||
|
||||
# Turn 2 finalized and was popped by _finish_trace; only turn 1's
|
||||
# (non-finalizing) state lingers. Assert the surviving key is turn 1's
|
||||
# and that turn 2 never merged into it — `all(...)` over an empty set
|
||||
# would pass vacuously, so pin the exact surviving key instead.
|
||||
keys = list(mod._TRACE_STATE.keys())
|
||||
assert len(keys) == 1
|
||||
assert "turn1" in keys[0]
|
||||
assert "turn2" not in keys[0]
|
||||
|
||||
def test_pre_and_post_hooks_share_one_key_within_a_turn(self, monkeypatch):
|
||||
"""turn_id is preferred over api_request_id so the turn-scoped
|
||||
post_llm_call (which carries no api_request_id) still resolves to the
|
||||
same key as the request-scoped pre/post_api_request hooks. If the
|
||||
ordering were reversed, finalization would silently break."""
|
||||
mod = self._fresh_plugin()
|
||||
turn_id = "S:T:turnX"
|
||||
api_request_id = f"{turn_id}:api:1"
|
||||
|
||||
k_pre_api = mod._trace_key("T", "S", turn_id=turn_id, api_request_id=api_request_id)
|
||||
k_post_api = mod._trace_key("T", "S", turn_id=turn_id, api_request_id=api_request_id)
|
||||
k_post_turn = mod._trace_key("T", "S", turn_id=turn_id, api_request_id="")
|
||||
|
||||
assert k_pre_api == k_post_api == k_post_turn
|
||||
|
||||
def test_non_finalizing_turns_do_not_grow_state_unboundedly(self, monkeypatch):
|
||||
"""Per-turn keys mean a turn that never finalizes leaves a lingering
|
||||
entry. Without a cap that grows once per non-finalizing turn forever;
|
||||
the LRU eviction must bound _TRACE_STATE at _MAX_TRACE_STATE.
|
||||
"""
|
||||
mod = self._fresh_plugin()
|
||||
started: list = []
|
||||
monkeypatch.setattr(mod, "_get_langfuse", lambda: self._fake_client(started))
|
||||
monkeypatch.setattr(mod, "_end_observation", lambda *a, **k: None)
|
||||
monkeypatch.setattr(mod, "_MAX_TRACE_STATE", 8)
|
||||
mod._TRACE_STATE.clear()
|
||||
|
||||
# Far more non-finalizing turns than the cap.
|
||||
for n in range(50):
|
||||
self._run_turn(mod, session="sess-leak", turn_n=n, finalize=False)
|
||||
|
||||
assert len(mod._TRACE_STATE) <= 8
|
||||
# The survivors are the most-recently-updated turns (LRU eviction).
|
||||
surviving = sorted(int(k.rsplit("turn", 1)[1]) for k in mod._TRACE_STATE)
|
||||
assert surviving == list(range(42, 50))
|
||||
|
||||
def test_trace_key_strings_unchanged_by_refactor(self):
|
||||
"""Pin the exact key strings across all task/session/turn/api
|
||||
combinations so the _scope_prefix extraction can never silently change
|
||||
a key (keys are matched across hooks; a drift breaks finalization)."""
|
||||
mod = self._fresh_plugin()
|
||||
tk = mod._trace_key
|
||||
assert tk("t", "s", turn_id="u") == "task:t:turn:u"
|
||||
assert tk("", "s", turn_id="u") == "session:s:turn:u"
|
||||
assert tk("t", "s", api_request_id="r") == "task:t:api:r"
|
||||
assert tk("", "s", api_request_id="r") == "session:s:api:r"
|
||||
assert tk("t", "s") == "t" # legacy: bare task_id
|
||||
assert tk("", "s") == "session:s"
|
||||
# turn_id wins over api_request_id when both are present.
|
||||
assert tk("t", "s", turn_id="u", api_request_id="r") == "task:t:turn:u"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Placeholder-credential guard (#23823).
|
||||
#
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue