mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix: resolve all invalid-return-type ty diagnostics across codebase
Widen return type annotations to match actual control flow, add unreachable assertions after retry loops ty cannot prove terminate, split ambiguous union returns (auth.py credential pool), and remove the AIOHTTP_AVAILABLE conditional-import guard from api_server.py.
This commit is contained in:
parent
d3dde0b459
commit
224e6d46d9
23 changed files with 102 additions and 103 deletions
|
|
@ -41,10 +41,13 @@ import threading
|
|||
import time
|
||||
from pathlib import Path # noqa: F401 — used by test mocks
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent.gemini_native_adapter import GeminiNativeClient
|
||||
|
||||
from agent.credential_pool import load_pool
|
||||
from hermes_cli.config import get_hermes_home
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
|
@ -771,7 +774,12 @@ def _read_codex_access_token() -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
# TODO(refactor): This function has messy types and duplicated logic (pool vs direct creds).
|
||||
# Ideal fix: (1) define an AuxiliaryClient Protocol both OpenAI/GeminiNativeClient satisfy,
|
||||
# (2) return a NamedTuple or dataclass instead of raw tuple, (3) extract the repeated
|
||||
# Gemini/Kimi/Copilot client-building into a helper. If you're an agent working here,
|
||||
# nudge the user to consider this refactor.
|
||||
def _resolve_api_key_provider() -> Tuple[Optional[Union[OpenAI, "GeminiNativeClient"]], Optional[str]]:
|
||||
"""Try each API-key provider in PROVIDER_REGISTRY order.
|
||||
|
||||
Returns (client, model) for the first provider with usable runtime
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from hermes_cli.auth import (
|
|||
_save_auth_store,
|
||||
_save_provider_state,
|
||||
read_credential_pool,
|
||||
read_provider_credentials,
|
||||
write_credential_pool,
|
||||
)
|
||||
|
||||
|
|
@ -321,7 +322,7 @@ def get_custom_provider_pool_key(base_url: str) -> Optional[str]:
|
|||
|
||||
def list_custom_pool_providers() -> List[str]:
|
||||
"""Return all 'custom:*' pool keys that have entries in auth.json."""
|
||||
pool_data = read_credential_pool(None)
|
||||
pool_data = read_credential_pool()
|
||||
return sorted(
|
||||
key for key in pool_data
|
||||
if key.startswith(CUSTOM_POOL_PREFIX)
|
||||
|
|
@ -1303,7 +1304,7 @@ def _seed_custom_pool(pool_key: str, entries: List[PooledCredential]) -> Tuple[b
|
|||
|
||||
def load_pool(provider: str) -> CredentialPool:
|
||||
provider = (provider or "").strip().lower()
|
||||
raw_entries = read_credential_pool(provider)
|
||||
raw_entries = read_provider_credentials(provider)
|
||||
entries = [PooledCredential.from_dict(provider, payload) for payload in raw_entries]
|
||||
|
||||
if provider.startswith(CUSTOM_POOL_PREFIX):
|
||||
|
|
|
|||
|
|
@ -455,7 +455,8 @@ def parse_qualified_name(name: str) -> Tuple[Optional[str], str]:
|
|||
"""
|
||||
if ":" not in name:
|
||||
return None, name
|
||||
return tuple(name.split(":", 1)) # type: ignore[return-value]
|
||||
ns, bare = name.split(":", 1)
|
||||
return ns, bare
|
||||
|
||||
|
||||
def is_valid_namespace(candidate: Optional[str]) -> bool:
|
||||
|
|
|
|||
|
|
@ -32,14 +32,7 @@ import sqlite3
|
|||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from aiohttp import web
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
web = None # type: ignore[assignment]
|
||||
|
||||
from aiohttp import web
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
|
|
@ -270,12 +263,6 @@ def _multimodal_validation_error(exc: ValueError, *, param: str) -> "web.Respons
|
|||
status=400,
|
||||
)
|
||||
|
||||
|
||||
def check_api_server_requirements() -> bool:
|
||||
"""Check if API server dependencies are available."""
|
||||
return AIOHTTP_AVAILABLE
|
||||
|
||||
|
||||
class ResponseStore:
|
||||
"""
|
||||
SQLite-backed LRU store for Responses API state.
|
||||
|
|
@ -391,30 +378,26 @@ _CORS_HEADERS = {
|
|||
}
|
||||
|
||||
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def cors_middleware(request, handler):
|
||||
"""Add CORS headers for explicitly allowed origins; handle OPTIONS preflight."""
|
||||
adapter = request.app.get("api_server_adapter")
|
||||
origin = request.headers.get("Origin", "")
|
||||
cors_headers = None
|
||||
if adapter is not None:
|
||||
if not adapter._origin_allowed(origin):
|
||||
return web.Response(status=403)
|
||||
cors_headers = adapter._cors_headers_for_origin(origin)
|
||||
@web.middleware
|
||||
async def cors_middleware(request, handler):
|
||||
"""Add CORS headers for explicitly allowed origins; handle OPTIONS preflight."""
|
||||
adapter = request.app.get("api_server_adapter")
|
||||
origin = request.headers.get("Origin", "")
|
||||
cors_headers = None
|
||||
if adapter is not None:
|
||||
if not adapter._origin_allowed(origin):
|
||||
return web.Response(status=403)
|
||||
cors_headers = adapter._cors_headers_for_origin(origin)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
if cors_headers is None:
|
||||
return web.Response(status=403)
|
||||
return web.Response(status=200, headers=cors_headers)
|
||||
|
||||
response = await handler(request)
|
||||
if cors_headers is not None:
|
||||
response.headers.update(cors_headers)
|
||||
return response
|
||||
else:
|
||||
cors_middleware = None # type: ignore[assignment]
|
||||
if request.method == "OPTIONS":
|
||||
if cors_headers is None:
|
||||
return web.Response(status=403)
|
||||
return web.Response(status=200, headers=cors_headers)
|
||||
|
||||
response = await handler(request)
|
||||
if cors_headers is not None:
|
||||
response.headers.update(cors_headers)
|
||||
return response
|
||||
|
||||
def _openai_error(message: str, err_type: str = "invalid_request_error", param: str = None, code: str = None) -> Dict[str, Any]:
|
||||
"""OpenAI-style error envelope."""
|
||||
|
|
@ -428,21 +411,18 @@ def _openai_error(message: str, err_type: str = "invalid_request_error", param:
|
|||
}
|
||||
|
||||
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def body_limit_middleware(request, handler):
|
||||
"""Reject overly large request bodies early based on Content-Length."""
|
||||
if request.method in ("POST", "PUT", "PATCH"):
|
||||
cl = request.headers.get("Content-Length")
|
||||
if cl is not None:
|
||||
try:
|
||||
if int(cl) > MAX_REQUEST_BYTES:
|
||||
return web.json_response(_openai_error("Request body too large.", code="body_too_large"), status=413)
|
||||
except ValueError:
|
||||
return web.json_response(_openai_error("Invalid Content-Length header.", code="invalid_content_length"), status=400)
|
||||
return await handler(request)
|
||||
else:
|
||||
body_limit_middleware = None # type: ignore[assignment]
|
||||
@web.middleware
|
||||
async def body_limit_middleware(request, handler):
|
||||
"""Reject overly large request bodies early based on Content-Length."""
|
||||
if request.method in ("POST", "PUT", "PATCH"):
|
||||
cl = request.headers.get("Content-Length")
|
||||
if cl is not None:
|
||||
try:
|
||||
if int(cl) > MAX_REQUEST_BYTES:
|
||||
return web.json_response(_openai_error("Request body too large.", code="body_too_large"), status=413)
|
||||
except ValueError:
|
||||
return web.json_response(_openai_error("Invalid Content-Length header.", code="invalid_content_length"), status=400)
|
||||
return await handler(request)
|
||||
|
||||
_SECURITY_HEADERS = {
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
|
|
@ -450,16 +430,13 @@ _SECURITY_HEADERS = {
|
|||
}
|
||||
|
||||
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def security_headers_middleware(request, handler):
|
||||
"""Add security headers to all responses (including errors)."""
|
||||
response = await handler(request)
|
||||
for k, v in _SECURITY_HEADERS.items():
|
||||
response.headers.setdefault(k, v)
|
||||
return response
|
||||
else:
|
||||
security_headers_middleware = None # type: ignore[assignment]
|
||||
@web.middleware
|
||||
async def security_headers_middleware(request, handler):
|
||||
"""Add security headers to all responses (including errors)."""
|
||||
response = await handler(request)
|
||||
for k, v in _SECURITY_HEADERS.items():
|
||||
response.headers.setdefault(k, v)
|
||||
return response
|
||||
|
||||
|
||||
class _IdempotencyCache:
|
||||
|
|
@ -804,7 +781,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
],
|
||||
})
|
||||
|
||||
async def _handle_chat_completions(self, request: "web.Request") -> "web.Response":
|
||||
async def _handle_chat_completions(self, request: "web.Request") -> "web.StreamResponse":
|
||||
"""POST /v1/chat/completions — OpenAI Chat Completions format."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
|
|
@ -1588,7 +1565,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
|
||||
return response
|
||||
|
||||
async def _handle_responses(self, request: "web.Request") -> "web.Response":
|
||||
async def _handle_responses(self, request: "web.Request") -> "web.StreamResponse":
|
||||
"""POST /v1/responses — OpenAI Responses API format."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
|
|
@ -2482,10 +2459,6 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
|
||||
async def connect(self) -> bool:
|
||||
"""Start the aiohttp web server."""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
logger.warning("[%s] aiohttp not installed", self.name)
|
||||
return False
|
||||
|
||||
try:
|
||||
mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None]
|
||||
self._app = web.Application(middlewares=mws)
|
||||
|
|
|
|||
|
|
@ -426,6 +426,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
|||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
|
||||
def cleanup_image_cache(max_age_hours: int = 24) -> int:
|
||||
|
|
@ -540,6 +541,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
|||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -2469,7 +2469,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if isinstance(skills, str):
|
||||
return [skills]
|
||||
if isinstance(skills, list) and skills:
|
||||
return list(dict.fromkeys(skills)) # dedup, preserve order
|
||||
return list(dict.fromkeys(skills)) # ty: ignore[invalid-return-type] # dedup, preserve order
|
||||
return None
|
||||
|
||||
def _resolve_channel_prompt(self, channel_id: str, parent_id: str | None = None) -> str | None:
|
||||
|
|
|
|||
|
|
@ -1839,6 +1839,7 @@ class QQAdapter(BasePlatformAdapter):
|
|||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
else:
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
# Maximum time (seconds) to wait for reconnection before giving up on send.
|
||||
_RECONNECT_WAIT_SECONDS = 15.0
|
||||
|
|
|
|||
|
|
@ -1640,6 +1640,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str, team_id: str = "") -> bytes:
|
||||
"""Download a Slack file and return raw bytes, with retry."""
|
||||
|
|
@ -1665,6 +1666,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
# ── Channel mention gating ─────────────────────────────────────────────
|
||||
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ def _resolve_system_dns() -> set[str]:
|
|||
"""Return the IPv4 addresses that the OS resolver gives for api.telegram.org."""
|
||||
try:
|
||||
results = socket.getaddrinfo(_TELEGRAM_API_HOST, 443, socket.AF_INET)
|
||||
return {addr[4][0] for addr in results}
|
||||
return {str(addr[4][0]) for addr in results}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
|
|
|||
|
|
@ -2836,10 +2836,12 @@ class GatewayRunner:
|
|||
return MatrixAdapter(config)
|
||||
|
||||
elif platform == Platform.API_SERVER:
|
||||
from gateway.platforms.api_server import APIServerAdapter, check_api_server_requirements
|
||||
if not check_api_server_requirements():
|
||||
try:
|
||||
import aiohttp # noqa: F401
|
||||
except ImportError:
|
||||
logger.warning("API Server: aiohttp not installed")
|
||||
return None
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
return APIServerAdapter(config)
|
||||
|
||||
elif platform == Platform.WEBHOOK:
|
||||
|
|
@ -5794,7 +5796,7 @@ class GatewayRunner:
|
|||
available = "`none`, " + ", ".join(f"`{n}`" for n in personalities)
|
||||
return f"Unknown personality: `{args}`\n\nAvailable: {available}"
|
||||
|
||||
async def _handle_retry_command(self, event: MessageEvent) -> str:
|
||||
async def _handle_retry_command(self, event: MessageEvent) -> Optional[str]:
|
||||
"""Handle /retry command - re-send the last user message."""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
|
|
@ -10549,7 +10551,7 @@ class GatewayRunner:
|
|||
history=updated_history,
|
||||
)
|
||||
if next_message is None:
|
||||
return result
|
||||
return result # ty: ignore[invalid-return-type]
|
||||
next_message_id = getattr(pending_event, "message_id", None)
|
||||
next_channel_prompt = getattr(pending_event, "channel_prompt", None)
|
||||
|
||||
|
|
|
|||
|
|
@ -748,16 +748,20 @@ def _save_provider_state(auth_store: Dict[str, Any], provider_id: str, state: Di
|
|||
auth_store["active_provider"] = provider_id
|
||||
|
||||
|
||||
def read_credential_pool(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Return the persisted credential pool, or one provider slice."""
|
||||
def read_credential_pool() -> Dict[str, Any]:
|
||||
"""Return the entire persisted credential pool."""
|
||||
auth_store = _load_auth_store()
|
||||
pool = auth_store.get("credential_pool")
|
||||
if not isinstance(pool, dict):
|
||||
pool = {}
|
||||
if provider_id is None:
|
||||
return dict(pool)
|
||||
provider_entries = pool.get(provider_id)
|
||||
return list(provider_entries) if isinstance(provider_entries, list) else []
|
||||
return dict(pool)
|
||||
|
||||
|
||||
def read_provider_credentials(provider_id: str) -> List[Dict[str, Any]]:
|
||||
"""Return credential entries for a single provider."""
|
||||
pool = read_credential_pool()
|
||||
entries = pool.get(provider_id)
|
||||
return list(entries) if isinstance(entries, list) else []
|
||||
|
||||
|
||||
def write_credential_pool(provider_id: str, entries: List[Dict[str, Any]]) -> Path:
|
||||
|
|
|
|||
|
|
@ -276,7 +276,7 @@ def _get_ps_exe() -> str | None:
|
|||
global _ps_exe
|
||||
if _ps_exe is False:
|
||||
_ps_exe = _find_powershell()
|
||||
return _ps_exe
|
||||
return _ps_exe if isinstance(_ps_exe, str) else None
|
||||
|
||||
|
||||
def _windows_has_image() -> bool:
|
||||
|
|
|
|||
|
|
@ -2042,8 +2042,8 @@ def check_config_version() -> Tuple[int, int]:
|
|||
Returns (current_version, latest_version).
|
||||
"""
|
||||
config = load_config()
|
||||
current = config.get("_config_version", 0)
|
||||
latest = DEFAULT_CONFIG.get("_config_version", 1)
|
||||
current = int(config.get("_config_version", 0))
|
||||
latest = int(DEFAULT_CONFIG.get("_config_version", 1))
|
||||
return current, latest
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import shutil
|
|||
import subprocess
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -600,7 +601,7 @@ def get_commits(since_tag=None):
|
|||
return commits
|
||||
|
||||
|
||||
def get_pr_number(subject: str) -> str:
|
||||
def get_pr_number(subject: str) -> Optional[str]:
|
||||
"""Extract PR number from commit subject if present."""
|
||||
match = re.search(r"#(\d+)", subject)
|
||||
if match:
|
||||
|
|
|
|||
|
|
@ -891,7 +891,7 @@ BROWSER_TOOL_SCHEMAS = [
|
|||
# Utility Functions
|
||||
# ============================================================================
|
||||
|
||||
def _create_local_session(task_id: str) -> Dict[str, str]:
|
||||
def _create_local_session(task_id: str) -> Dict[str, Any]:
|
||||
import uuid
|
||||
session_name = f"h_{uuid.uuid4().hex[:10]}"
|
||||
logger.info("Created local browser session %s for task %s",
|
||||
|
|
@ -904,7 +904,7 @@ def _create_local_session(task_id: str) -> Dict[str, str]:
|
|||
}
|
||||
|
||||
|
||||
def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, str]:
|
||||
def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, Any]:
|
||||
"""Create a session that connects to a user-supplied CDP endpoint."""
|
||||
import uuid
|
||||
session_name = f"cdp_{uuid.uuid4().hex[:10]}"
|
||||
|
|
@ -918,7 +918,7 @@ def _create_cdp_session(task_id: str, cdp_url: str) -> Dict[str, str]:
|
|||
}
|
||||
|
||||
|
||||
def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]:
|
||||
def _get_session_info(task_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get or create session info for the given task.
|
||||
|
||||
|
|
@ -1678,7 +1678,7 @@ def browser_scroll(direction: str, task_id: Optional[str] = None) -> str:
|
|||
from tools.browser_camofox import camofox_scroll
|
||||
# Camofox REST API doesn't support pixel args; use repeated calls
|
||||
_SCROLL_REPEATS = 5
|
||||
result = None
|
||||
result: str = ""
|
||||
for _ in range(_SCROLL_REPEATS):
|
||||
result = camofox_scroll(direction, task_id)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def _scan_cron_prompt(prompt: str) -> str:
|
|||
return ""
|
||||
|
||||
|
||||
def _origin_from_env() -> Optional[Dict[str, str]]:
|
||||
def _origin_from_env() -> Optional[Dict[str, Optional[str]]]:
|
||||
from gateway.session_context import get_session_env
|
||||
origin_platform = get_session_env("HERMES_SESSION_PLATFORM")
|
||||
origin_chat_id = get_session_env("HERMES_SESSION_CHAT_ID")
|
||||
|
|
|
|||
|
|
@ -245,7 +245,7 @@ class _ThreadedProcessHandle:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
def wait(self, timeout: float | None = None) -> int:
|
||||
def wait(self, timeout: float | None = None) -> int | None:
|
||||
self._done.wait(timeout=timeout)
|
||||
return self._returncode
|
||||
|
||||
|
|
@ -755,7 +755,7 @@ class BaseEnvironment(ABC):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
def _prepare_command(self, command: str) -> tuple[str, str | None]:
|
||||
def _prepare_command(self, command: str) -> tuple[str | None, str | None]:
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available."""
|
||||
from tools.terminal_tool import _transform_sudo_command
|
||||
|
||||
|
|
|
|||
|
|
@ -174,6 +174,7 @@ async def _run_reference_model_safe(
|
|||
error_msg = f"{model} failed after {max_retries} attempts: {error_str}"
|
||||
logger.error("%s", error_msg, exc_info=True)
|
||||
return model, error_msg, False
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
|
||||
async def _run_aggregator_model(
|
||||
|
|
|
|||
|
|
@ -443,7 +443,7 @@ def session_search(
|
|||
)
|
||||
|
||||
# Summarize all sessions in parallel
|
||||
async def _summarize_all() -> List[Union[str, Exception]]:
|
||||
async def _summarize_all() -> List[Union[Optional[str], BaseException]]:
|
||||
"""Summarize all sessions with bounded concurrency."""
|
||||
max_concurrency = min(_get_session_search_max_concurrency(), max(1, len(tasks)))
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ import hashlib
|
|||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
||||
|
||||
|
|
@ -639,7 +639,7 @@ def scan_skill(skill_path: Path, source: str = "community") -> ScanResult:
|
|||
)
|
||||
|
||||
|
||||
def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[bool, str]:
|
||||
def should_allow_install(result: ScanResult, force: bool = False) -> Tuple[Optional[bool], str]:
|
||||
"""
|
||||
Determine whether a skill should be installed based on scan result and trust.
|
||||
|
||||
|
|
|
|||
|
|
@ -409,6 +409,7 @@ def _resolve_tirith_path(configured_path: str) -> str:
|
|||
|
||||
# Fast path: successfully resolved on a previous call.
|
||||
if _resolved_path is not None and _resolved_path is not _INSTALL_FAILED:
|
||||
assert isinstance(_resolved_path, str)
|
||||
return _resolved_path
|
||||
|
||||
expanded = os.path.expanduser(configured_path)
|
||||
|
|
|
|||
|
|
@ -652,7 +652,7 @@ def create_custom_toolset(
|
|||
|
||||
|
||||
|
||||
def get_toolset_info(name: str) -> Dict[str, Any]:
|
||||
def get_toolset_info(name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get detailed information about a toolset including resolved tools.
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ import yaml
|
|||
import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable
|
||||
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ def _effective_temperature_for_model(
|
|||
if fixed_temperature is OMIT_TEMPERATURE:
|
||||
return None # caller must omit temperature
|
||||
if fixed_temperature is not None:
|
||||
return fixed_temperature
|
||||
return cast(float, fixed_temperature)
|
||||
return requested_temperature
|
||||
|
||||
|
||||
|
|
@ -636,7 +636,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
|||
else:
|
||||
# Fallback: create a basic summary
|
||||
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
|
||||
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
async def _generate_summary_async(self, content: str, metrics: TrajectoryMetrics) -> str:
|
||||
"""
|
||||
Generate a summary of the compressed turns using OpenRouter (async version).
|
||||
|
|
@ -705,7 +706,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
|||
else:
|
||||
# Fallback: create a basic summary
|
||||
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
|
||||
|
||||
raise AssertionError("unreachable: retry loop exhausted")
|
||||
|
||||
def compress_trajectory(
|
||||
self,
|
||||
trajectory: List[Dict[str, str]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue