diff --git a/gateway/run.py b/gateway/run.py index 9434c9e5f..1ba7fc847 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -6637,6 +6637,7 @@ class GatewayRunner: thread_id=str(context.source.thread_id) if context.source.thread_id else "", user_id=str(context.source.user_id) if context.source.user_id else "", user_name=str(context.source.user_name) if context.source.user_name else "", + session_key=context.session_key, ) def _clear_session_env(self, tokens: list) -> None: @@ -7379,8 +7380,8 @@ class GatewayRunner: # `_resolve_turn_agent_config(message, …)`. nonlocal message - # Pass session_key to process registry via env var so background - # processes can be mapped back to this gateway session + # session_key is now set via contextvars in _set_session_env() + # (concurrency-safe). Keep os.environ as fallback for CLI/cron. os.environ["HERMES_SESSION_KEY"] = session_key or "" # Read from env var or use default (same as CLI) diff --git a/gateway/session_context.py b/gateway/session_context.py index 6d676dc1e..b9fdcdfaf 100644 --- a/gateway/session_context.py +++ b/gateway/session_context.py @@ -48,6 +48,7 @@ _SESSION_CHAT_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_CHAT_NAME", def _SESSION_THREAD_ID: ContextVar[str] = ContextVar("HERMES_SESSION_THREAD_ID", default="") _SESSION_USER_ID: ContextVar[str] = ContextVar("HERMES_SESSION_USER_ID", default="") _SESSION_USER_NAME: ContextVar[str] = ContextVar("HERMES_SESSION_USER_NAME", default="") +_SESSION_KEY: ContextVar[str] = ContextVar("HERMES_SESSION_KEY", default="") _VAR_MAP = { "HERMES_SESSION_PLATFORM": _SESSION_PLATFORM, @@ -56,6 +57,7 @@ _VAR_MAP = { "HERMES_SESSION_THREAD_ID": _SESSION_THREAD_ID, "HERMES_SESSION_USER_ID": _SESSION_USER_ID, "HERMES_SESSION_USER_NAME": _SESSION_USER_NAME, + "HERMES_SESSION_KEY": _SESSION_KEY, } @@ -66,6 +68,7 @@ def set_session_vars( thread_id: str = "", user_id: str = "", user_name: str = "", + session_key: str = "", ) -> list: """Set all session context variables and return reset tokens. @@ -82,6 +85,7 @@ def set_session_vars( _SESSION_THREAD_ID.set(thread_id), _SESSION_USER_ID.set(user_id), _SESSION_USER_NAME.set(user_name), + _SESSION_KEY.set(session_key), ] return tokens @@ -97,6 +101,7 @@ def clear_session_vars(tokens: list) -> None: _SESSION_THREAD_ID, _SESSION_USER_ID, _SESSION_USER_NAME, + _SESSION_KEY, ] for var, token in zip(vars_in_order, tokens): var.reset(token) diff --git a/tests/gateway/test_session_env.py b/tests/gateway/test_session_env.py index b75e267f1..9f556f884 100644 --- a/tests/gateway/test_session_env.py +++ b/tests/gateway/test_session_env.py @@ -1,3 +1,4 @@ +import asyncio import os from gateway.config import Platform @@ -130,3 +131,99 @@ def test_set_session_env_handles_missing_optional_fields(): assert get_session_env("HERMES_SESSION_THREAD_ID") == "" runner._clear_session_env(tokens) + + +# --------------------------------------------------------------------------- +# SESSION_KEY contextvars tests +# --------------------------------------------------------------------------- + + +def test_session_key_set_via_contextvars(monkeypatch): + """set_session_vars should set HERMES_SESSION_KEY via contextvars.""" + monkeypatch.delenv("HERMES_SESSION_KEY", raising=False) + + tokens = set_session_vars( + platform="telegram", + chat_id="-1001", + session_key="tg:-1001:17585", + ) + assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585" + + clear_session_vars(tokens) + assert get_session_env("HERMES_SESSION_KEY") == "" + + +def test_session_key_falls_back_to_os_environ(monkeypatch): + """get_session_env for SESSION_KEY should fall back to os.environ.""" + monkeypatch.setenv("HERMES_SESSION_KEY", "env-session-123") + + # No contextvar set — should read from os.environ + assert get_session_env("HERMES_SESSION_KEY") == "env-session-123" + + # Set contextvar — should prefer it + tokens = set_session_vars(session_key="ctx-session-456") + assert get_session_env("HERMES_SESSION_KEY") == "ctx-session-456" + + # Restore — should fall back to os.environ + clear_session_vars(tokens) + assert get_session_env("HERMES_SESSION_KEY") == "env-session-123" + + +def test_set_session_env_includes_session_key(): + """_set_session_env should propagate session_key from SessionContext.""" + runner = object.__new__(GatewayRunner) + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_name="Group", + chat_type="group", + thread_id="17585", + ) + context = SessionContext( + source=source, + connected_platforms=[], + home_channels={}, + session_key="tg:-1001:17585", + ) + + tokens = runner._set_session_env(context) + assert get_session_env("HERMES_SESSION_KEY") == "tg:-1001:17585" + runner._clear_session_env(tokens) + assert get_session_env("HERMES_SESSION_KEY") == "" + + +def test_session_key_no_race_condition_with_contextvars(monkeypatch): + """Prove contextvars isolates SESSION_KEY across concurrent async tasks. + + Two tasks set different session keys. With contextvars each task + reads back its own value. With os.environ the second task would + overwrite the first (the old bug). + """ + monkeypatch.delenv("HERMES_SESSION_KEY", raising=False) + + results = {} + + async def handler(key: str, delay: float): + tokens = set_session_vars(session_key=key) + try: + await asyncio.sleep(delay) + read_back = get_session_env("HERMES_SESSION_KEY") + results[key] = read_back + finally: + clear_session_vars(tokens) + + async def run(): + task_a = asyncio.create_task(handler("session-A", 0.15)) + await asyncio.sleep(0.05) + task_b = asyncio.create_task(handler("session-B", 0.05)) + await asyncio.gather(task_a, task_b) + + asyncio.run(run()) + + # Both tasks must read back their own session key + assert results["session-A"] == "session-A", ( + f"Session A got '{results['session-A']}' instead of 'session-A' — race condition!" + ) + assert results["session-B"] == "session-B", ( + f"Session B got '{results['session-B']}' instead of 'session-B' — race condition!" + ) diff --git a/tools/approval.py b/tools/approval.py index faf888f18..9a3a4ef26 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -40,11 +40,18 @@ def reset_current_session_key(token: contextvars.Token[str]) -> None: def get_current_session_key(default: str = "default") -> str: - """Return the active session key, preferring context-local state.""" + """Return the active session key, preferring context-local state. + + Resolution order: + 1. approval-specific contextvars (set by gateway before agent.run) + 2. session_context contextvars (set by _set_session_env) + 3. os.environ fallback (CLI, cron, tests) + """ session_key = _approval_session_key.get() if session_key: return session_key - return os.getenv("HERMES_SESSION_KEY", default) + from gateway.session_context import get_session_env + return get_session_env("HERMES_SESSION_KEY", default) # Sensitive write targets that should trigger approval even when referenced # via shell expansions like $HOME or $HERMES_HOME.