fix: add service domain blocklist and entity_id validation to HA tools

Block dangerous HA service domains (shell_command, command_line,
python_script, pyscript, hassio, rest_command) that allow arbitrary
code execution or SSRF. Add regex validation for entity_id to prevent
path traversal attacks. 17 new tests covering both security features.
This commit is contained in:
0xbyt4 2026-03-01 11:53:50 +03:00
parent dfd50ceccd
commit 25fb9aafcb
2 changed files with 119 additions and 0 deletions

View file

@ -13,6 +13,7 @@ import asyncio
import json
import logging
import os
import re
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
@ -24,6 +25,21 @@ logger = logging.getLogger(__name__)
_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/")
_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
# Service domains blocked for security -- these allow arbitrary code/command
# execution on the HA host or enable SSRF attacks on the local network.
# HA provides zero service-level access control; all safety must be in our layer.
_BLOCKED_DOMAINS = frozenset({
"shell_command", # arbitrary shell commands as root in HA container
"command_line", # sensors/switches that execute shell commands
"python_script", # sandboxed but can escalate via hass.services.call()
"pyscript", # scripting integration with broader access
"hassio", # addon control, host shutdown/reboot, stdin to containers
"rest_command", # HTTP requests from HA server (SSRF vector)
})
def _get_headers() -> Dict[str, str]:
"""Return authorization headers for HA REST API."""
@ -198,6 +214,8 @@ def _handle_get_state(args: dict, **kw) -> str:
entity_id = args.get("entity_id", "")
if not entity_id:
return json.dumps({"error": "Missing required parameter: entity_id"})
if not _ENTITY_ID_RE.match(entity_id):
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
try:
result = _run_async(_async_get_state(entity_id))
return json.dumps({"result": result})
@ -213,7 +231,16 @@ def _handle_call_service(args: dict, **kw) -> str:
if not domain or not service:
return json.dumps({"error": "Missing required parameters: domain and service"})
if domain in _BLOCKED_DOMAINS:
return json.dumps({
"error": f"Service domain '{domain}' is blocked for security. "
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
})
entity_id = args.get("entity_id")
if entity_id and not _ENTITY_ID_RE.match(entity_id):
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
data = args.get("data")
try:
result = _run_async(_async_call_service(domain, service, entity_id, data))