fix(gateway): enforce env variable template expansion on runtime config loaders

This commit is contained in:
QuenVix 2026-05-23 10:44:32 +03:00 committed by Teknium
parent d21ac579e9
commit 2362cc4688
2 changed files with 140 additions and 83 deletions

View file

@ -1308,6 +1308,26 @@ def _load_gateway_config() -> dict:
return {}
def _load_gateway_runtime_config() -> dict:
    """Return the gateway config with supported ``${VAR}`` references expanded.

    The raw mapping is taken from ``_load_gateway_config()`` rather than the
    canonical loader, so runtime reads keep respecting tests that monkeypatch
    ``gateway.run._hermes_home`` while still applying the env-template
    expansion documented for ``config.yaml``.

    Returns:
        The expanded mapping. ``{}`` when the raw config is empty, is not a
        dict, or expands to something that is not a dict. The unexpanded raw
        config when the expander cannot be imported or raises.
    """
    raw_config = _load_gateway_config()
    if not (isinstance(raw_config, dict) and raw_config):
        return {}

    try:
        from hermes_cli.config import _expand_env_vars

        result = _expand_env_vars(raw_config)
    except Exception:
        # Best effort: hand back the unexpanded config instead of failing the
        # runtime read.
        return raw_config

    if isinstance(result, dict):
        return result
    return {}
def _resolve_gateway_model(config: dict | None = None) -> str:
"""Read model from config.yaml — single source of truth.
@ -2642,15 +2662,8 @@ class GatewayRunner:
"""
file_path = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "")
if not file_path:
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
file_path = cfg.get("prefill_messages_file", "")
except Exception:
pass
cfg = _load_gateway_runtime_config()
file_path = str(cfg.get("prefill_messages_file", "") or "")
if not file_path:
return []
path = Path(file_path).expanduser()
@ -2680,16 +2693,8 @@ class GatewayRunner:
prompt = os.getenv("HERMES_EPHEMERAL_SYSTEM_PROMPT", "")
if prompt:
return prompt
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
return (cfg_get(cfg, "agent", "system_prompt", default="") or "").strip()
except Exception:
pass
return ""
cfg = _load_gateway_runtime_config()
return str(cfg_get(cfg, "agent", "system_prompt", default="") or "").strip()
@staticmethod
def _load_reasoning_config() -> dict | None:
@ -2700,16 +2705,8 @@ class GatewayRunner:
default (medium).
"""
from hermes_constants import parse_reasoning_effort
effort = ""
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
effort = str(cfg_get(cfg, "agent", "reasoning_effort", default="") or "").strip()
except Exception:
pass
cfg = _load_gateway_runtime_config()
effort = str(cfg_get(cfg, "agent", "reasoning_effort", default="") or "").strip()
result = parse_reasoning_effort(effort)
if effort and effort.strip() and result is None:
logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
@ -2783,16 +2780,8 @@ class GatewayRunner:
"fast"/"priority"/"on" => "priority", while "normal"/"off" disables it.
Returns None when unset or unsupported.
"""
raw = ""
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
raw = str(cfg_get(cfg, "agent", "service_tier", default="") or "").strip()
except Exception:
pass
cfg = _load_gateway_runtime_config()
raw = str(cfg_get(cfg, "agent", "service_tier", default="") or "").strip()
value = raw.lower()
if not value or value in {"normal", "default", "standard", "off", "none"}:
@ -2805,34 +2794,19 @@ class GatewayRunner:
@staticmethod
def _load_show_reasoning() -> bool:
"""Load show_reasoning toggle from config.yaml display section."""
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
return is_truthy_value(
cfg_get(cfg, "display", "show_reasoning"),
default=False,
)
except Exception:
pass
return False
cfg = _load_gateway_runtime_config()
return is_truthy_value(
cfg_get(cfg, "display", "show_reasoning"),
default=False,
)
@staticmethod
def _load_busy_input_mode() -> str:
"""Load gateway drain-time busy-input behavior from config/env."""
mode = os.getenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "").strip().lower()
if not mode:
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
mode = str(cfg_get(cfg, "display", "busy_input_mode", default="") or "").strip().lower()
except Exception:
pass
cfg = _load_gateway_runtime_config()
mode = str(cfg_get(cfg, "display", "busy_input_mode", default="") or "").strip().lower()
if mode == "queue":
return "queue"
if mode == "steer":
@ -2844,15 +2818,8 @@ class GatewayRunner:
"""Load graceful gateway restart/stop drain timeout in seconds."""
raw = os.getenv("HERMES_RESTART_DRAIN_TIMEOUT", "").strip()
if not raw:
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
raw = str(cfg_get(cfg, "agent", "restart_drain_timeout", default="") or "").strip()
except Exception:
pass
cfg = _load_gateway_runtime_config()
raw = str(cfg_get(cfg, "agent", "restart_drain_timeout", default="") or "").strip()
value = parse_restart_drain_timeout(raw)
if raw and value == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT:
try:
@ -2877,19 +2844,12 @@ class GatewayRunner:
"""
mode = os.getenv("HERMES_BACKGROUND_NOTIFICATIONS", "")
if not mode:
try:
import yaml as _y
cfg_path = _hermes_home / "config.yaml"
if cfg_path.exists():
with open(cfg_path, encoding="utf-8") as _f:
cfg = _y.safe_load(_f) or {}
raw = cfg_get(cfg, "display", "background_process_notifications")
if raw is False:
mode = "off"
elif raw not in {None, ""}:
mode = str(raw)
except Exception:
pass
cfg = _load_gateway_runtime_config()
raw = cfg_get(cfg, "display", "background_process_notifications")
if raw is False:
mode = "off"
elif raw not in {None, ""}:
mode = str(raw)
mode = (mode or "all").strip().lower()
valid = {"all", "result", "error", "off"}
if mode not in valid:

View file

@ -0,0 +1,97 @@
"""Regression tests for gateway runtime config env-var expansion."""
from __future__ import annotations
import json
import pytest
import gateway.run as gateway_run
def _write_config(home, body: str) -> None:
    """Write *body* as UTF-8 text to ``config.yaml`` inside the *home* directory."""
    config_path = home / "config.yaml"
    config_path.write_text(body, encoding="utf-8")
@pytest.fixture
def gateway_home(monkeypatch, tmp_path):
    """Redirect the gateway home to ``tmp_path`` and clear env overrides.

    ``gateway_run._hermes_home`` is patched to the temporary directory, and
    the ``HERMES_*`` variables listed below are removed from the environment
    for the duration of the test.

    Returns:
        The temporary directory now acting as the gateway home.
    """
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    for env_override in (
        "HERMES_PREFILL_MESSAGES_FILE",
        "HERMES_EPHEMERAL_SYSTEM_PROMPT",
        "HERMES_GATEWAY_BUSY_INPUT_MODE",
        "HERMES_RESTART_DRAIN_TIMEOUT",
        "HERMES_BACKGROUND_NOTIFICATIONS",
    ):
        monkeypatch.delenv(env_override, raising=False)
    return tmp_path
def test_load_prefill_messages_expands_env_var_path(monkeypatch, gateway_home):
    """A ``${PREFILL_FILE}`` reference in ``prefill_messages_file`` is expanded."""
    expected_messages = [{"role": "system", "content": "few-shot"}]
    prefill_path = gateway_home / "prefill.json"
    prefill_path.write_text(json.dumps(expected_messages), encoding="utf-8")
    _write_config(gateway_home, "prefill_messages_file: ${PREFILL_FILE}\n")
    monkeypatch.setenv("PREFILL_FILE", "prefill.json")

    loaded_messages = gateway_run.GatewayRunner._load_prefill_messages()
    assert loaded_messages == expected_messages
@pytest.mark.parametrize(
    ("config_body", "env_name", "env_value", "loader_name", "expected"),
    [
        # Each case: YAML body holding a ``${VAR}`` template, the variable
        # name, its value, the GatewayRunner loader to call, and the result
        # that loader must return once the template has been expanded.
        ("agent:\n system_prompt: ${GW_PROMPT}\n", "GW_PROMPT", "expanded prompt",
         "_load_ephemeral_system_prompt", "expanded prompt"),
        ("agent:\n reasoning_effort: ${REASONING_LEVEL}\n", "REASONING_LEVEL", "high",
         "_load_reasoning_config", {"enabled": True, "effort": "high"}),
        ("agent:\n service_tier: ${SERVICE_TIER}\n", "SERVICE_TIER", "priority",
         "_load_service_tier", "priority"),
        ("display:\n busy_input_mode: ${BUSY_MODE}\n", "BUSY_MODE", "steer",
         "_load_busy_input_mode", "steer"),
        ("agent:\n restart_drain_timeout: ${DRAIN_TIMEOUT}\n", "DRAIN_TIMEOUT", "12",
         "_load_restart_drain_timeout", 12.0),
        ("display:\n background_process_notifications: ${BG_MODE}\n", "BG_MODE", "error",
         "_load_background_notifications_mode", "error"),
    ],
)
def test_gateway_runtime_loaders_expand_env_var_templates(
    monkeypatch,
    gateway_home,
    config_body,
    env_name,
    env_value,
    loader_name,
    expected,
):
    """Each runtime loader returns the value its ``${VAR}`` template expands to."""
    _write_config(gateway_home, config_body)
    monkeypatch.setenv(env_name, env_value)

    actual = getattr(gateway_run.GatewayRunner, loader_name)()
    assert actual == expected