fix(langfuse): scope trace state by turn/request ids

This commit is contained in:
infinitycrew39 2026-06-17 23:38:47 +07:00 committed by kshitijk4poor
parent edcde6b26f
commit 0b54a33a34

View file

@ -219,7 +219,34 @@ def _get_langfuse() -> Optional[Langfuse]:
return _LANGFUSE_CLIENT
def _trace_key(task_id: str, session_id: str) -> str:
def _trace_key(
task_id: str,
session_id: str,
*,
turn_id: str = "",
api_request_id: str = "",
) -> str:
"""Build a stable in-process trace scope key for one agent turn.
Older Hermes paths only expose ``task_id``/``session_id``. Newer paths
also pass ``turn_id`` and ``api_request_id`` in LLM/tool hooks; when
present, they scope trace state so concurrent requests sharing one
task/session never collide. ``turn_id`` takes precedence:
``api_request_id`` is only used when ``turn_id`` is empty.
"""
if turn_id:
if task_id:
return f"task:{task_id}:turn:{turn_id}"
if session_id:
return f"session:{session_id}:turn:{turn_id}"
return f"thread:{threading.get_ident()}:turn:{turn_id}"
if api_request_id:
if task_id:
return f"task:{task_id}:api:{api_request_id}"
if session_id:
return f"session:{session_id}:api:{api_request_id}"
return f"thread:{threading.get_ident()}:api:{api_request_id}"
if task_id:
return task_id
if session_id:
@ -563,12 +590,15 @@ def _usage_and_cost(response: Any, *, provider: str, api_mode: str, model: str,
def _start_root_trace(task_key: str, *, task_id: str, session_id: str, platform: str, provider: str, model: str,
api_mode: str, messages: Any, client: Langfuse) -> TraceState:
api_mode: str, messages: Any, client: Langfuse,
turn_id: str = "", api_request_id: str = "") -> TraceState:
trace_id = client.create_trace_id(seed=f"{session_id or 'sessionless'}::{task_id or task_key}")
trace_input = _extract_last_user_message(messages)
metadata = {
"source": "hermes",
"task_id": task_id,
"turn_id": turn_id,
"api_request_id": api_request_id,
"platform": platform,
"provider": provider,
"model": model,
@ -712,7 +742,8 @@ def _request_key(api_call_count: Any) -> str:
def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str = "", model: str = "",
provider: str = "", base_url: str = "", api_mode: str = "",
api_call_count: int = 0, messages: Any = None, turn_type: str = "user",
conversation_history: Any = None, user_message: Any = None, **_: Any) -> None:
conversation_history: Any = None, user_message: Any = None,
turn_id: str = "", api_request_id: str = "", **_: Any) -> None:
# Older Hermes branches used pre_llm_call for request-scoped tracing and
# passed the actual API messages. Current Hermes also has a turn-scoped
# pre_llm_call used for context injection; tracing that hook creates an
@ -729,7 +760,12 @@ def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str =
# pre_llm_call with API messages directly. Current Hermes fires
# pre_llm_call for context injection (conversation_history/user_message,
# no messages list) — tracing that would create orphan traces.
task_key = _trace_key(task_id, session_id)
task_key = _trace_key(
task_id,
session_id,
turn_id=turn_id,
api_request_id=api_request_id,
)
with _STATE_LOCK:
state = _TRACE_STATE.get(task_key)
@ -744,6 +780,8 @@ def on_pre_llm_call(*, task_id: str = "", session_id: str = "", platform: str =
api_mode=api_mode,
messages=messages,
client=client,
turn_id=turn_id,
api_request_id=api_request_id,
)
_TRACE_STATE[task_key] = state
state.last_updated_at = time.time()
@ -769,6 +807,8 @@ def on_pre_llm_request(
max_tokens: Any = None,
conversation_history: Any = None,
user_message: Any = None,
turn_id: str = "",
api_request_id: str = "",
**_: Any,
) -> None:
client = _get_langfuse()
@ -782,7 +822,12 @@ def on_pre_llm_request(
user_message=user_message,
)
task_key = _trace_key(task_id, session_id)
task_key = _trace_key(
task_id,
session_id,
turn_id=turn_id,
api_request_id=api_request_id,
)
req_key = _request_key(api_call_count)
with _STATE_LOCK:
@ -798,6 +843,8 @@ def on_pre_llm_request(
api_mode=api_mode,
messages=input_messages,
client=client,
turn_id=turn_id,
api_request_id=api_request_id,
)
_TRACE_STATE[task_key] = state
state.last_updated_at = time.time()
@ -827,12 +874,18 @@ def on_post_llm_call(*, task_id: str = "", session_id: str = "", provider: str =
api_duration: float = 0.0, finish_reason: str = "",
usage: Any = None, assistant_content_chars: int = 0,
assistant_tool_call_count: int = 0, assistant_response: Any = None,
turn_id: str = "", api_request_id: str = "",
**_: Any) -> None:
client = _get_langfuse()
if client is None:
return
task_key = _trace_key(task_id, session_id)
task_key = _trace_key(
task_id,
session_id,
turn_id=turn_id,
api_request_id=api_request_id,
)
req_key = _request_key(api_call_count)
with _STATE_LOCK:
@ -950,12 +1003,18 @@ def on_post_llm_call(*, task_id: str = "", session_id: str = "", provider: str =
def on_pre_tool_call(*, tool_name: str = "", args: Any = None, task_id: str = "",
session_id: str = "", tool_call_id: str = "", **_: Any) -> None:
session_id: str = "", tool_call_id: str = "",
turn_id: str = "", api_request_id: str = "", **_: Any) -> None:
client = _get_langfuse()
if client is None:
return
task_key = _trace_key(task_id, session_id)
task_key = _trace_key(
task_id,
session_id,
turn_id=turn_id,
api_request_id=api_request_id,
)
with _STATE_LOCK:
state = _TRACE_STATE.get(task_key)
@ -976,8 +1035,14 @@ def on_pre_tool_call(*, tool_name: str = "", args: Any = None, task_id: str = ""
def on_post_tool_call(*, tool_name: str = "", args: Any = None, result: Any = None,
task_id: str = "", session_id: str = "", tool_call_id: str = "", **_: Any) -> None:
task_key = _trace_key(task_id, session_id)
task_id: str = "", session_id: str = "", tool_call_id: str = "",
turn_id: str = "", api_request_id: str = "", **_: Any) -> None:
task_key = _trace_key(
task_id,
session_id,
turn_id=turn_id,
api_request_id=api_request_id,
)
observation = None
with _STATE_LOCK: