mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): preserve notify context in executor threads
Gateway executor work now inherits the active session contextvars via copy_context() so background process watchers retain the correct platform/chat/user/session metadata for routing completion events back to the originating chat. Cherry-picked from #10647 by @helix4u with: - Use asyncio.get_running_loop() instead of deprecated get_event_loop() - Strip trailing whitespace - Add *args forwarding test - Add exception propagation test
This commit is contained in:
parent
4093982f19
commit
8021a735c2
2 changed files with 80 additions and 7 deletions
|
|
@ -24,6 +24,7 @@ import signal
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from contextvars import copy_context
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Optional, Any, List
|
from typing import Dict, Optional, Any, List
|
||||||
|
|
@ -5715,8 +5716,7 @@ class GatewayRunner:
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
result = await self._run_in_executor_with_context(run_sync)
|
||||||
result = await loop.run_in_executor(None, run_sync)
|
|
||||||
|
|
||||||
response = result.get("final_response", "") if result else ""
|
response = result.get("final_response", "") if result else ""
|
||||||
if not response and result and result.get("error"):
|
if not response and result and result.get("error"):
|
||||||
|
|
@ -5898,8 +5898,7 @@ class GatewayRunner:
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
result = await self._run_in_executor_with_context(run_sync)
|
||||||
result = await loop.run_in_executor(None, run_sync)
|
|
||||||
|
|
||||||
response = (result.get("final_response") or "") if result else ""
|
response = (result.get("final_response") or "") if result else ""
|
||||||
if not response and result and result.get("error"):
|
if not response and result and result.get("error"):
|
||||||
|
|
@ -7319,6 +7318,12 @@ class GatewayRunner:
|
||||||
from gateway.session_context import clear_session_vars
|
from gateway.session_context import clear_session_vars
|
||||||
clear_session_vars(tokens)
|
clear_session_vars(tokens)
|
||||||
|
|
||||||
|
async def _run_in_executor_with_context(self, func, *args):
|
||||||
|
"""Run blocking work in the thread pool while preserving session contextvars."""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
ctx = copy_context()
|
||||||
|
return await loop.run_in_executor(None, ctx.run, func, *args)
|
||||||
|
|
||||||
async def _enrich_message_with_vision(
|
async def _enrich_message_with_vision(
|
||||||
self,
|
self,
|
||||||
user_text: str,
|
user_text: str,
|
||||||
|
|
@ -9094,9 +9099,8 @@ class GatewayRunner:
|
||||||
_agent_warning_raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900))
|
_agent_warning_raw = float(os.getenv("HERMES_AGENT_TIMEOUT_WARNING", 900))
|
||||||
_agent_warning = _agent_warning_raw if _agent_warning_raw > 0 else None
|
_agent_warning = _agent_warning_raw if _agent_warning_raw > 0 else None
|
||||||
_warning_fired = False
|
_warning_fired = False
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
_executor_task = asyncio.ensure_future(
|
_executor_task = asyncio.ensure_future(
|
||||||
loop.run_in_executor(None, run_sync)
|
self._run_in_executor_with_context(run_sync)
|
||||||
)
|
)
|
||||||
|
|
||||||
_inactivity_timeout = False
|
_inactivity_timeout = False
|
||||||
|
|
|
||||||
|
|
@ -251,3 +251,72 @@ def test_session_key_no_race_condition_with_contextvars(monkeypatch):
|
||||||
assert results["session-B"] == "session-B", (
|
assert results["session-B"] == "session-B", (
|
||||||
f"Session B got '{results['session-B']}' instead of 'session-B' — race condition!"
|
f"Session B got '{results['session-B']}' instead of 'session-B' — race condition!"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_in_executor_with_context_preserves_session_env(monkeypatch):
|
||||||
|
"""Gateway executor work should inherit session contextvars for tool routing."""
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||||
|
monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False)
|
||||||
|
monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False)
|
||||||
|
monkeypatch.delenv("HERMES_SESSION_USER_ID", raising=False)
|
||||||
|
|
||||||
|
source = SessionSource(
|
||||||
|
platform=Platform.TELEGRAM,
|
||||||
|
chat_id="2144471399",
|
||||||
|
chat_type="dm",
|
||||||
|
user_id="123456",
|
||||||
|
user_name="alice",
|
||||||
|
thread_id=None,
|
||||||
|
)
|
||||||
|
context = SessionContext(
|
||||||
|
source=source,
|
||||||
|
connected_platforms=[],
|
||||||
|
home_channels={},
|
||||||
|
session_key="agent:main:telegram:dm:2144471399",
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = runner._set_session_env(context)
|
||||||
|
try:
|
||||||
|
result = await runner._run_in_executor_with_context(
|
||||||
|
lambda: {
|
||||||
|
"platform": get_session_env("HERMES_SESSION_PLATFORM"),
|
||||||
|
"chat_id": get_session_env("HERMES_SESSION_CHAT_ID"),
|
||||||
|
"user_id": get_session_env("HERMES_SESSION_USER_ID"),
|
||||||
|
"session_key": get_session_env("HERMES_SESSION_KEY"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
runner._clear_session_env(tokens)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"platform": "telegram",
|
||||||
|
"chat_id": "2144471399",
|
||||||
|
"user_id": "123456",
|
||||||
|
"session_key": "agent:main:telegram:dm:2144471399",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_in_executor_with_context_forwards_args():
|
||||||
|
"""_run_in_executor_with_context should forward *args to the callable."""
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
|
||||||
|
def add(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
result = await runner._run_in_executor_with_context(add, 3, 7)
|
||||||
|
assert result == 10
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_in_executor_with_context_propagates_exceptions():
|
||||||
|
"""Exceptions inside the executor should propagate to the caller."""
|
||||||
|
runner = object.__new__(GatewayRunner)
|
||||||
|
|
||||||
|
def blow_up():
|
||||||
|
raise ValueError("boom")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="boom"):
|
||||||
|
await runner._run_in_executor_with_context(blow_up)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue