fix: resolve all invalid-return-type ty diagnostics across codebase

Widen return type annotations to match actual control flow, add
unreachable assertions after retry loops that ty cannot prove terminate,
split ambiguous union returns (auth.py credential pool), and remove
the AIOHTTP_AVAILABLE conditional-import guard from api_server.py.
This commit is contained in:
alt-glitch 2026-04-21 14:55:09 +05:30
parent d3dde0b459
commit 224e6d46d9
23 changed files with 102 additions and 103 deletions

View file

@ -41,10 +41,13 @@ import threading
import time
from pathlib import Path # noqa: F401 — used by test mocks
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from openai import OpenAI
if TYPE_CHECKING:
from agent.gemini_native_adapter import GeminiNativeClient
from agent.credential_pool import load_pool
from hermes_cli.config import get_hermes_home
from hermes_constants import OPENROUTER_BASE_URL
@ -771,7 +774,12 @@ def _read_codex_access_token() -> Optional[str]:
return None
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
# TODO(refactor): This function has messy types and duplicated logic (pool vs direct creds).
# Ideal fix: (1) define an AuxiliaryClient Protocol both OpenAI/GeminiNativeClient satisfy,
# (2) return a NamedTuple or dataclass instead of raw tuple, (3) extract the repeated
# Gemini/Kimi/Copilot client-building into a helper. If you're an agent working here,
# nudge the user to consider this refactor.
def _resolve_api_key_provider() -> Tuple[Optional[Union[OpenAI, "GeminiNativeClient"]], Optional[str]]:
"""Try each API-key provider in PROVIDER_REGISTRY order.
Returns (client, model) for the first provider with usable runtime

View file

@ -29,6 +29,7 @@ from hermes_cli.auth import (
_save_auth_store,
_save_provider_state,
read_credential_pool,
read_provider_credentials,
write_credential_pool,
)
@ -321,7 +322,7 @@ def get_custom_provider_pool_key(base_url: str) -> Optional[str]:
def list_custom_pool_providers() -> List[str]:
"""Return all 'custom:*' pool keys that have entries in auth.json."""
pool_data = read_credential_pool(None)
pool_data = read_credential_pool()
return sorted(
key for key in pool_data
if key.startswith(CUSTOM_POOL_PREFIX)
@ -1303,7 +1304,7 @@ def _seed_custom_pool(pool_key: str, entries: List[PooledCredential]) -> Tuple[b
def load_pool(provider: str) -> CredentialPool:
provider = (provider or "").strip().lower()
raw_entries = read_credential_pool(provider)
raw_entries = read_provider_credentials(provider)
entries = [PooledCredential.from_dict(provider, payload) for payload in raw_entries]
if provider.startswith(CUSTOM_POOL_PREFIX):

View file

@ -455,7 +455,8 @@ def parse_qualified_name(name: str) -> Tuple[Optional[str], str]:
"""
if ":" not in name:
return None, name
return tuple(name.split(":", 1)) # type: ignore[return-value]
ns, bare = name.split(":", 1)
return ns, bare
def is_valid_namespace(candidate: Optional[str]) -> bool:

View file

@ -32,14 +32,7 @@ import sqlite3
import time
import uuid
from typing import Any, Dict, List, Optional
try:
from aiohttp import web
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
web = None # type: ignore[assignment]
from aiohttp import web
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
@ -270,12 +263,6 @@ def _multimodal_validation_error(exc: ValueError, *, param: str) -> "web.Respons
status=400,
)
def check_api_server_requirements() -> bool:
"""Check if API server dependencies are available."""
return AIOHTTP_AVAILABLE
class ResponseStore:
"""
SQLite-backed LRU store for Responses API state.
@ -391,30 +378,26 @@ _CORS_HEADERS = {
}
@web.middleware
async def cors_middleware(request, handler):
    """Add CORS headers for explicitly allowed origins; handle OPTIONS preflight."""
    adapter = request.app.get("api_server_adapter")
    origin = request.headers.get("Origin", "")
    cors_headers = None
    if adapter is not None:
        # Disallowed origins are rejected outright rather than merely
        # served without CORS headers.
        if not adapter._origin_allowed(origin):
            return web.Response(status=403)
        cors_headers = adapter._cors_headers_for_origin(origin)
    if request.method == "OPTIONS":
        # Preflight: answer directly without invoking the handler chain.
        if cors_headers is None:
            return web.Response(status=403)
        return web.Response(status=200, headers=cors_headers)
    response = await handler(request)
    if cors_headers is not None:
        response.headers.update(cors_headers)
    return response
def _openai_error(message: str, err_type: str = "invalid_request_error", param: str = None, code: str = None) -> Dict[str, Any]:
"""OpenAI-style error envelope."""
@ -428,21 +411,18 @@ def _openai_error(message: str, err_type: str = "invalid_request_error", param:
}
@web.middleware
async def body_limit_middleware(request, handler):
    """Reject overly large request bodies early based on Content-Length."""
    # Only body-carrying methods are checked; GET/DELETE etc. pass through.
    if request.method in ("POST", "PUT", "PATCH"):
        cl = request.headers.get("Content-Length")
        if cl is not None:
            try:
                if int(cl) > MAX_REQUEST_BYTES:
                    return web.json_response(_openai_error("Request body too large.", code="body_too_large"), status=413)
            except ValueError:
                # Non-numeric Content-Length is a malformed request.
                return web.json_response(_openai_error("Invalid Content-Length header.", code="invalid_content_length"), status=400)
    return await handler(request)
_SECURITY_HEADERS = {
"X-Content-Type-Options": "nosniff",
@ -450,16 +430,13 @@ _SECURITY_HEADERS = {
}
@web.middleware
async def security_headers_middleware(request, handler):
    """Add security headers to all responses (including errors)."""
    response = await handler(request)
    # setdefault: never clobber a header a handler set deliberately.
    for k, v in _SECURITY_HEADERS.items():
        response.headers.setdefault(k, v)
    return response
class _IdempotencyCache:
@ -804,7 +781,7 @@ class APIServerAdapter(BasePlatformAdapter):
],
})
async def _handle_chat_completions(self, request: "web.Request") -> "web.Response":
async def _handle_chat_completions(self, request: "web.Request") -> "web.StreamResponse":
"""POST /v1/chat/completions — OpenAI Chat Completions format."""
auth_err = self._check_auth(request)
if auth_err:
@ -1588,7 +1565,7 @@ class APIServerAdapter(BasePlatformAdapter):
return response
async def _handle_responses(self, request: "web.Request") -> "web.Response":
async def _handle_responses(self, request: "web.Request") -> "web.StreamResponse":
"""POST /v1/responses — OpenAI Responses API format."""
auth_err = self._check_auth(request)
if auth_err:
@ -2482,10 +2459,6 @@ class APIServerAdapter(BasePlatformAdapter):
async def connect(self) -> bool:
"""Start the aiohttp web server."""
if not AIOHTTP_AVAILABLE:
logger.warning("[%s] aiohttp not installed", self.name)
return False
try:
mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None]
self._app = web.Application(middlewares=mws)

View file

@ -426,6 +426,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
await asyncio.sleep(wait)
continue
raise
raise AssertionError("unreachable: retry loop exhausted")
def cleanup_image_cache(max_age_hours: int = 24) -> int:
@ -540,6 +541,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
await asyncio.sleep(wait)
continue
raise
raise AssertionError("unreachable: retry loop exhausted")
# ---------------------------------------------------------------------------

View file

@ -2469,7 +2469,7 @@ class DiscordAdapter(BasePlatformAdapter):
if isinstance(skills, str):
return [skills]
if isinstance(skills, list) and skills:
return list(dict.fromkeys(skills)) # dedup, preserve order
return list(dict.fromkeys(skills)) # ty: ignore[invalid-return-type] # dedup, preserve order
return None
def _resolve_channel_prompt(self, channel_id: str, parent_id: str | None = None) -> str | None:

View file

@ -1839,6 +1839,7 @@ class QQAdapter(BasePlatformAdapter):
await asyncio.sleep(1.5 * (attempt + 1))
else:
raise
raise AssertionError("unreachable: retry loop exhausted")
# Maximum time (seconds) to wait for reconnection before giving up on send.
_RECONNECT_WAIT_SECONDS = 15.0

View file

@ -1640,6 +1640,7 @@ class SlackAdapter(BasePlatformAdapter):
await asyncio.sleep(1.5 * (attempt + 1))
continue
raise
raise AssertionError("unreachable: retry loop exhausted")
async def _download_slack_file_bytes(self, url: str, team_id: str = "") -> bytes:
"""Download a Slack file and return raw bytes, with retry."""
@ -1665,6 +1666,7 @@ class SlackAdapter(BasePlatformAdapter):
await asyncio.sleep(1.5 * (attempt + 1))
continue
raise
raise AssertionError("unreachable: retry loop exhausted")
# ── Channel mention gating ─────────────────────────────────────────────

View file

def _resolve_system_dns() -> set[str]:
    """Return the IPv4 addresses that the OS resolver gives for api.telegram.org."""
    try:
        results = socket.getaddrinfo(_TELEGRAM_API_HOST, 443, socket.AF_INET)
        # Coerce to str: getaddrinfo's sockaddr tuple elements are typed
        # loosely, and the declared return type is set[str].
        return {str(addr[4][0]) for addr in results}
    except Exception:
        # Resolver failure (offline, DNS outage) degrades to an empty set.
        return set()

View file

@ -2836,10 +2836,12 @@ class GatewayRunner:
return MatrixAdapter(config)
elif platform == Platform.API_SERVER:
from gateway.platforms.api_server import APIServerAdapter, check_api_server_requirements
if not check_api_server_requirements():
try:
import aiohttp # noqa: F401
except ImportError:
logger.warning("API Server: aiohttp not installed")
return None
from gateway.platforms.api_server import APIServerAdapter
return APIServerAdapter(config)
elif platform == Platform.WEBHOOK:
@ -5794,7 +5796,7 @@ class GatewayRunner:
available = "`none`, " + ", ".join(f"`{n}`" for n in personalities)
return f"Unknown personality: `{args}`\n\nAvailable: {available}"
async def _handle_retry_command(self, event: MessageEvent) -> str:
async def _handle_retry_command(self, event: MessageEvent) -> Optional[str]:
"""Handle /retry command - re-send the last user message."""
source = event.source
session_entry = self.session_store.get_or_create_session(source)
@ -10549,7 +10551,7 @@ class GatewayRunner:
history=updated_history,
)
if next_message is None:
return result
return result # ty: ignore[invalid-return-type]
next_message_id = getattr(pending_event, "message_id", None)
next_channel_prompt = getattr(pending_event, "channel_prompt", None)

View file

@ -748,16 +748,20 @@ def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Di
auth_store["active_provider"] = provider_id
def read_credential_pool() -> Dict[str, Any]:
    """Return the entire persisted credential pool.

    Reads the "credential_pool" mapping out of the auth store loaded by
    _load_auth_store() and returns a shallow copy so callers cannot mutate
    the stored mapping. A missing or non-dict value degrades to {}.
    """
    auth_store = _load_auth_store()
    pool = auth_store.get("credential_pool")
    if not isinstance(pool, dict):
        pool = {}
    return dict(pool)
def read_provider_credentials(provider_id: str) -> List[Dict[str, Any]]:
    """Return credential entries for a single provider.

    A provider with no stored entries, or whose stored value is not a
    list, yields an empty list.
    """
    full_pool = read_credential_pool()
    stored = full_pool.get(provider_id)
    if not isinstance(stored, list):
        return []
    return list(stored)
def write_credential_pool(provider_id: str, entries: List[Dict[str, Any]]) -> Path:

View file

def _get_ps_exe() -> str | None:
    """Return the PowerShell executable path, or None when unavailable.

    Module global ``_ps_exe`` is a tri-state cache: False means "not yet
    probed"; after the first call it holds _find_powershell()'s result.
    """
    global _ps_exe
    if _ps_exe is False:
        _ps_exe = _find_powershell()
    # Narrow the cached value to str | None for the declared return type
    # (presumably _find_powershell may return None — TODO confirm).
    return _ps_exe if isinstance(_ps_exe, str) else None
def _windows_has_image() -> bool:

View file

@ -2042,8 +2042,8 @@ def check_config_version() -> Tuple[int, int]:
Returns (current_version, latest_version).
"""
config = load_config()
current = config.get("_config_version", 0)
latest = DEFAULT_CONFIG.get("_config_version", 1)
current = int(config.get("_config_version", 0))
latest = int(DEFAULT_CONFIG.get("_config_version", 1))
return current, latest

View file

@ -26,6 +26,7 @@ import shutil
import subprocess
import sys
from collections import defaultdict
from typing import Optional
from datetime import datetime
from pathlib import Path
@ -600,7 +601,7 @@ def get_commits(since_tag=None):
return commits
def get_pr_number(subject: str) -> str:
def get_pr_number(subject: str) -> Optional[str]:
"""Extract PR number from commit subject if present."""
match = re.search(r"#(\d+)", subject)
if match:

View file

@ -891,7 +891,7 @@ BROWSER_TOOL_SCHEMAS = [
# Utility Functions
# ============================================================================
def _create_local_session(task_id: str) -> Dict[str, str]:
def _create_local_session(task_id: str) -> Dict[str, Any]:
import uuid
session_name = f"h_{uuid.uuid4().hex[:10]}"
logger.info("Created local browser session %s for task %s",
@ -904,7 +904,7 @@ def _create_local_session(task_id: str) -> Dict[str, str]:
}
def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, str]:
def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, Any]:
"""Create a session that connects to a user-supplied CDP endpoint."""
import uuid
session_name = f"cdp_{uuid.uuid4().hex[:10]}"
@ -918,7 +918,7 @@ def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, str]:
}
def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
def _get_session_info(task_id: Optional[str] = None) -> Dict[str, Any]:
"""
Get or create session info for the given task.
@ -1678,7 +1678,7 @@ def browser_scroll(direction: str, task_id: Optional[str] = None) -> str:
from tools.browser_camofox import camofox_scroll
# Camofox REST API doesn't support pixel args; use repeated calls
_SCROLL_REPEATS = 5
result = None
result: str = ""
for _ in range(_SCROLL_REPEATS):
result = camofox_scroll(direction, task_id)
return result

View file

@ -68,7 +68,7 @@ def _scan_cron_prompt(prompt: str) -> str:
return ""
def _origin_from_env() -> Optional[Dict[str, str]]:
def _origin_from_env() -> Optional[Dict[str, Optional[str]]]:
from gateway.session_context import get_session_env
origin_platform = get_session_env("HERMES_SESSION_PLATFORM")
origin_chat_id = get_session_env("HERMES_SESSION_CHAT_ID")

View file

@ -245,7 +245,7 @@ class _ThreadedProcessHandle:
except Exception:
pass
def wait(self, timeout: float | None = None) -> int:
def wait(self, timeout: float | None = None) -> int | None:
self._done.wait(timeout=timeout)
return self._returncode
@ -755,7 +755,7 @@ class BaseEnvironment(ABC):
except Exception:
pass
def _prepare_command(self, command: str) -> tuple[str, str | None]:
def _prepare_command(self, command: str) -> tuple[str | None, str | None]:
"""Transform sudo commands if SUDO_PASSWORD is available."""
from tools.terminal_tool import _transform_sudo_command

View file

@ -174,6 +174,7 @@ async def _run_reference_model_safe(
error_msg = f"{model} failed after {max_retries} attempts: {error_str}"
logger.error("%s", error_msg, exc_info=True)
return model, error_msg, False
raise AssertionError("unreachable: retry loop exhausted")
async def _run_aggregator_model(

View file

@ -443,7 +443,7 @@ def session_search(
)
# Summarize all sessions in parallel
async def _summarize_all() -> List[Union[str, Exception]]:
async def _summarize_all() -> List[Union[Optional[str], BaseException]]:
"""Summarize all sessions with bounded concurrency."""
max_concurrency = min(_get_session_search_max_concurrency(), max(1, len(tasks)))
semaphore = asyncio.Semaphore(max_concurrency)

View file

@ -27,7 +27,7 @@ import hashlib
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Tuple
from typing import List, Optional, Tuple
@ -639,7 +639,7 @@ def scan_skill(skill_path: Path, source: str = "community") -> ScanResult:
)
def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[bool, str]:
def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[Optional[bool], str]:
"""
Determine whether a skill should be installed based on scan result and trust.

View file

@ -409,6 +409,7 @@ def _resolve_tirith_path(configured_path: str) -> str:
# Fast path: successfully resolved on a previous call.
if _resolved_path is not None and _resolved_path is not _INSTALL_FAILED:
assert isinstance(_resolved_path, str)
return _resolved_path
expanded = os.path.expanduser(configured_path)

View file

@ -652,7 +652,7 @@ def create_custom_toolset(
def get_toolset_info(name: str) -> Dict[str, Any]:
def get_toolset_info(name: str) -> Optional[Dict[str, Any]]:
"""
Get detailed information about a toolset including resolved tools.

View file

@ -37,7 +37,7 @@ import yaml
import logging
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple, Callable
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
from dataclasses import dataclass, field
from datetime import datetime
@ -75,7 +75,7 @@ def _effective_temperature_for_model(
if fixed_temperature is OMIT_TEMPERATURE:
return None # caller must omit temperature
if fixed_temperature is not None:
return fixed_temperature
return cast(float, fixed_temperature)
return requested_temperature
@ -636,7 +636,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
else:
# Fallback: create a basic summary
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
raise AssertionError("unreachable: retry loop exhausted")
async def _generate_summary_async(self, content: str, metrics: TrajectoryMetrics) -> str:
"""
Generate a summary of the compressed turns using OpenRouter (async version).
@ -705,7 +706,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
else:
# Fallback: create a basic summary
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
raise AssertionError("unreachable: retry loop exhausted")
def compress_trajectory(
self,
trajectory: List[Dict[str, str]]