diff --git a/gateway/run.py b/gateway/run.py index f909a2c73..c50c67462 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3252,6 +3252,18 @@ class GatewayRunner: logger.debug("Gateway memory flush on reset failed: %s", e) self._evict_cached_agent(session_key) + try: + from tools.env_passthrough import clear_env_passthrough + clear_env_passthrough() + except Exception: + pass + + try: + from tools.credential_files import clear_credential_files + clear_credential_files() + except Exception: + pass + # Reset the session new_entry = self.session_store.reset_session(session_key) diff --git a/tools/credential_files.py b/tools/credential_files.py index 9a30f9bff..49768bff4 100644 --- a/tools/credential_files.py +++ b/tools/credential_files.py @@ -22,14 +22,26 @@ from __future__ import annotations import logging import os +from contextvars import ContextVar from pathlib import Path from typing import Dict, List logger = logging.getLogger(__name__) # Session-scoped list of credential files to mount. -# Key: container_path (deduplicated), Value: host_path -_registered_files: Dict[str, str] = {} +# Backed by ContextVar to prevent cross-session data bleed in the gateway pipeline. +_registered_files_var: ContextVar[Dict[str, str]] = ContextVar("_registered_files") + + +def _get_registered() -> Dict[str, str]: + """Get or create the registered credential files dict for the current context/session.""" + try: + return _registered_files_var.get() + except LookupError: + val: Dict[str, str] = {} + _registered_files_var.set(val) + return val + # Cache for config-based file list (loaded once per process). _config_files: List[Dict[str, str]] | None = None @@ -86,7 +98,7 @@ def register_credential_file( return False container_path = f"{container_base.rstrip('/')}/{relative_path}" - _registered_files[container_path] = str(resolved) + _get_registered()[container_path] = str(resolved) logger.debug("credential_files: registered %s -> %s", resolved, container_path) return True @@ -174,7 +186,7 @@ def get_credential_file_mounts() -> List[Dict[str, str]]: mounts: Dict[str, str] = {} # Skill-registered files - for container_path, host_path in _registered_files.items(): + for container_path, host_path in _get_registered().items(): # Re-check existence (file may have been deleted since registration) if Path(host_path).is_file(): mounts[container_path] = host_path @@ -395,7 +407,7 @@ def iter_cache_files( def clear_credential_files() -> None: """Reset the skill-scoped registry (e.g. on session reset).""" - _registered_files.clear() + _get_registered().clear() def reset_config_cache() -> None: diff --git a/tools/env_passthrough.py b/tools/env_passthrough.py index 29e94e7c3..e8dc51272 100644 --- a/tools/env_passthrough.py +++ b/tools/env_passthrough.py @@ -21,13 +21,25 @@ from __future__ import annotations import logging import os -from pathlib import Path +from contextvars import ContextVar from typing import Iterable logger = logging.getLogger(__name__) # Session-scoped set of env var names that should pass through to sandboxes. -_allowed_env_vars: set[str] = set() +# Backed by ContextVar to prevent cross-session data bleed in the gateway pipeline. +_allowed_env_vars_var: ContextVar[set[str]] = ContextVar("_allowed_env_vars") + + +def _get_allowed() -> set[str]: + """Get or create the allowed env vars set for the current context/session.""" + try: + return _allowed_env_vars_var.get() + except LookupError: + val: set[str] = set() + _allowed_env_vars_var.set(val) + return val + # Cache for the config-based allowlist (loaded once per process). _config_passthrough: frozenset[str] | None = None @@ -41,7 +53,7 @@ def register_env_passthrough(var_names: Iterable[str]) -> None: for name in var_names: name = name.strip() if name: - _allowed_env_vars.add(name) + _get_allowed().add(name) logger.debug("env passthrough: registered %s", name) @@ -78,19 +90,19 @@ def is_env_passthrough(var_name: str) -> bool: Returns ``True`` if the variable was registered by a skill or listed in the user's ``tools.env_passthrough`` config. """ - if var_name in _allowed_env_vars: + if var_name in _get_allowed(): return True return var_name in _load_config_passthrough() def get_all_passthrough() -> frozenset[str]: """Return the union of skill-registered and config-based passthrough vars.""" - return frozenset(_allowed_env_vars) | _load_config_passthrough() + return frozenset(_get_allowed()) | _load_config_passthrough() def clear_env_passthrough() -> None: """Reset the skill-scoped allowlist (e.g. on session reset).""" - _allowed_env_vars.clear() + _get_allowed().clear() def reset_config_cache() -> None: