diff --git a/agent/transports/__init__.py b/agent/transports/__init__.py index 575211332..615fadce1 100644 --- a/agent/transports/__init__.py +++ b/agent/transports/__init__.py @@ -6,7 +6,17 @@ Usage: result = transport.normalize_response(raw_response) """ -from agent.transports.types import NormalizedResponse, ToolCall, Usage, build_tool_call, map_finish_reason # noqa: F401 +from agent.transports.types import NormalizedResponse, ToolCall, Usage, build_tool_call, map_finish_reason + +__all__ = [ + "NormalizedResponse", + "ToolCall", + "Usage", + "build_tool_call", + "get_transport", + "map_finish_reason", + "register_transport", +] _REGISTRY: dict = {} @@ -34,18 +44,18 @@ def get_transport(api_mode: str): def _discover_transports() -> None: """Import all transport modules to trigger auto-registration.""" try: - import agent.transports.anthropic # noqa: F401 + import agent.transports.anthropic except ImportError: pass try: - import agent.transports.codex # noqa: F401 + import agent.transports.codex except ImportError: pass try: - import agent.transports.chat_completions # noqa: F401 + import agent.transports.chat_completions except ImportError: pass try: - import agent.transports.bedrock # noqa: F401 + import agent.transports.bedrock except ImportError: pass diff --git a/agent/transports/anthropic.py b/agent/transports/anthropic.py index 66c485b52..2dee499d8 100644 --- a/agent/transports/anthropic.py +++ b/agent/transports/anthropic.py @@ -4,7 +4,7 @@ Delegates to the existing adapter functions in agent/anthropic_adapter.py. This transport owns format conversion and normalization — NOT client lifecycle. """ -from typing import Any, Dict, List, Optional +from typing import Any, ClassVar from agent.transports.base import ProviderTransport from agent.transports.types import NormalizedResponse @@ -21,7 +21,7 @@ class AnthropicTransport(ProviderTransport): def api_mode(self) -> str: return "anthropic_messages" - def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any: + def convert_messages(self, messages: list[dict[str, Any]], **kwargs) -> Any: """Convert OpenAI messages to Anthropic (system, messages) tuple. kwargs: @@ -32,7 +32,7 @@ class AnthropicTransport(ProviderTransport): base_url = kwargs.get("base_url") return convert_messages_to_anthropic(messages, base_url=base_url) - def convert_tools(self, tools: List[Dict[str, Any]]) -> Any: + def convert_tools(self, tools: list[dict[str, Any]]) -> Any: """Convert OpenAI tool schemas to Anthropic input_schema format.""" from agent.anthropic_adapter import convert_tools_to_anthropic @@ -41,10 +41,10 @@ class AnthropicTransport(ProviderTransport): def build_kwargs( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, **params, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Build Anthropic messages.create() kwargs. Calls convert_messages and convert_tools internally. @@ -82,6 +82,7 @@ class AnthropicTransport(ProviderTransport): to OpenAI finish_reason, and collects reasoning_details in provider_data. 
""" import json + from agent.anthropic_adapter import _to_plain_data from agent.transports.types import ToolCall @@ -104,7 +105,7 @@ class AnthropicTransport(ProviderTransport): elif block.type == "tool_use": name = block.name if strip_tool_prefix and name.startswith(_MCP_PREFIX): - name = name[len(_MCP_PREFIX):] + name = name[len(_MCP_PREFIX) :] tool_calls.append( ToolCall( id=block.id, @@ -145,7 +146,7 @@ class AnthropicTransport(ProviderTransport): return getattr(response, "stop_reason", None) == "end_turn" return True - def extract_cache_stats(self, response: Any) -> Optional[Dict[str, int]]: + def extract_cache_stats(self, response: Any) -> dict[str, int] | None: """Extract Anthropic cache_read and cache_creation token counts.""" usage = getattr(response, "usage", None) if usage is None: @@ -156,8 +157,7 @@ class AnthropicTransport(ProviderTransport): return {"cached_tokens": cached, "creation_tokens": written} return None - # Promote the adapter's canonical mapping to module level so it's shared - _STOP_REASON_MAP = { + _STOP_REASON_MAP: ClassVar[dict[str, str]] = { "end_turn": "stop", "tool_use": "tool_calls", "max_tokens": "length", diff --git a/agent/transports/base.py b/agent/transports/base.py index b516967b6..2d8e297b2 100644 --- a/agent/transports/base.py +++ b/agent/transports/base.py @@ -8,7 +8,7 @@ prompt caching, interrupt handling, or retry logic. Those stay on AIAgent. """ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any from agent.transports.types import NormalizedResponse @@ -23,7 +23,7 @@ class ProviderTransport(ABC): ... @abstractmethod - def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any: + def convert_messages(self, messages: list[dict[str, Any]], **kwargs) -> Any: """Convert OpenAI-format messages to provider-native format. Returns provider-specific structure (e.g. (system, messages) for Anthropic, @@ -32,7 +32,7 @@ class ProviderTransport(ABC): ... @abstractmethod - def convert_tools(self, tools: List[Dict[str, Any]]) -> Any: + def convert_tools(self, tools: list[dict[str, Any]]) -> Any: """Convert OpenAI-format tool definitions to provider-native format. Returns provider-specific tool list (e.g. Anthropic input_schema format). @@ -43,10 +43,10 @@ class ProviderTransport(ABC): def build_kwargs( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, **params, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Build the complete API call kwargs dict. This is the primary entry point — it typically calls convert_messages() @@ -72,7 +72,7 @@ class ProviderTransport(ABC): """ return True - def extract_cache_stats(self, response: Any) -> Optional[Dict[str, int]]: + def extract_cache_stats(self, response: Any) -> dict[str, int] | None: """Optional: extract provider-specific cache hit/creation stats. Returns dict with 'cached_tokens' and 'creation_tokens', or None. diff --git a/agent/transports/bedrock.py b/agent/transports/bedrock.py index af549e7ea..3a8e71c39 100644 --- a/agent/transports/bedrock.py +++ b/agent/transports/bedrock.py @@ -6,7 +6,7 @@ owns format conversion and normalization, while client construction and boto3 calls stay on AIAgent. 
""" -from typing import Any, Dict, List, Optional +from typing import Any from agent.transports.base import ProviderTransport from agent.transports.types import NormalizedResponse, ToolCall, Usage @@ -19,23 +19,25 @@ class BedrockTransport(ProviderTransport): def api_mode(self) -> str: return "bedrock_converse" - def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any: + def convert_messages(self, messages: list[dict[str, Any]], **kwargs) -> Any: """Convert OpenAI messages to Bedrock Converse format.""" from agent.bedrock_adapter import convert_messages_to_converse + return convert_messages_to_converse(messages) - def convert_tools(self, tools: List[Dict[str, Any]]) -> Any: + def convert_tools(self, tools: list[dict[str, Any]]) -> Any: """Convert OpenAI tool schemas to Bedrock Converse toolConfig.""" from agent.bedrock_adapter import convert_tools_to_converse + return convert_tools_to_converse(tools) def build_kwargs( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, **params, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Build Bedrock converse() kwargs. Calls convert_messages and convert_tools internally. @@ -105,7 +107,9 @@ class BedrockTransport(ProviderTransport): total_tokens=getattr(u, "total_tokens", 0) or 0, ) - reasoning = getattr(msg, "reasoning", None) or getattr(msg, "reasoning_content", None) + reasoning = getattr(msg, "reasoning", None) or getattr( + msg, "reasoning_content", None + ) return NormalizedResponse( content=msg.content, diff --git a/agent/transports/chat_completions.py b/agent/transports/chat_completions.py index 815897513..5b11c028c 100644 --- a/agent/transports/chat_completions.py +++ b/agent/transports/chat_completions.py @@ -10,7 +10,7 @@ reasoning configuration, temperature handling, and extra_body assembly. """ import copy -from typing import Any, Dict, List, Optional +from typing import Any from agent.prompt_builder import DEVELOPER_ROLE_MODELS from agent.transports.base import ProviderTransport @@ -27,7 +27,9 @@ class ChatCompletionsTransport(ProviderTransport): def api_mode(self) -> str: return "chat_completions" - def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: + def convert_messages( + self, messages: list[dict[str, Any]], **kwargs + ) -> list[dict[str, Any]]: """Messages are already in OpenAI format — sanitize Codex leaks only. 
Strips Codex Responses API fields (``codex_reasoning_items`` on the @@ -44,7 +46,9 @@ class ChatCompletionsTransport(ProviderTransport): tool_calls = msg.get("tool_calls") if isinstance(tool_calls, list): for tc in tool_calls: - if isinstance(tc, dict) and ("call_id" in tc or "response_item_id" in tc): + if isinstance(tc, dict) and ( + "call_id" in tc or "response_item_id" in tc + ): needs_sanitize = True break if needs_sanitize: @@ -66,17 +70,17 @@ class ChatCompletionsTransport(ProviderTransport): tc.pop("response_item_id", None) return sanitized - def convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def convert_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """Tools are already in OpenAI format — identity.""" return tools def build_kwargs( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, **params, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Build chat.completions.create() kwargs. This is the most complex transport method — it handles ~16 providers @@ -86,7 +90,8 @@ class ChatCompletionsTransport(ProviderTransport): timeout: float — API call timeout max_tokens: int | None — user-configured max tokens ephemeral_max_output_tokens: int | None — one-shot override (error recovery) - max_tokens_param_fn: callable — returns {max_tokens: N} or {max_completion_tokens: N} + max_tokens_param_fn: callable — {max_tokens: N} or + {max_completion_tokens: N} reasoning_config: dict | None request_overrides: dict | None session_id: str | None @@ -105,7 +110,7 @@ class ChatCompletionsTransport(ProviderTransport): provider_preferences: dict | None # Qwen-specific qwen_prepare_fn: callable | None — runs AFTER codex sanitization - qwen_prepare_inplace_fn: callable | None — in-place variant for deepcopied lists + qwen_prepare_inplace_fn: callable | None — in-place deepcopy variant # Temperature fixed_temperature: Any — from _fixed_temperature_for_model() omit_temperature: bool @@ -155,7 +160,7 @@ class ChatCompletionsTransport(ProviderTransport): sanitized = list(sanitized) sanitized[0] = {**sanitized[0], "role": "developer"} - api_kwargs: Dict[str, Any] = { + api_kwargs: dict[str, Any] = { "model": model, "messages": sanitized, } @@ -220,7 +225,7 @@ class ChatCompletionsTransport(ProviderTransport): api_kwargs["reasoning_effort"] = _kimi_effort # extra_body assembly - extra_body: Dict[str, Any] = {} + extra_body: dict[str, Any] = {} is_openrouter = params.get("is_openrouter", False) is_nous = params.get("is_nous", False) @@ -233,9 +238,8 @@ class ChatCompletionsTransport(ProviderTransport): # Kimi extra_body.thinking if is_kimi: _kimi_thinking_enabled = True - if reasoning_config and isinstance(reasoning_config, dict): - if reasoning_config.get("enabled") is False: - _kimi_thinking_enabled = False + if reasoning_config and isinstance(reasoning_config, dict) and reasoning_config.get("enabled") is False: + _kimi_thinking_enabled = False extra_body["thinking"] = { "type": "enabled" if _kimi_thinking_enabled else "disabled", } @@ -267,8 +271,7 @@ class ChatCompletionsTransport(ProviderTransport): extra_body["options"] = options # Ollama/custom think=false - if params.get("is_custom_provider", False): - if reasoning_config and isinstance(reasoning_config, dict): + if params.get("is_custom_provider", False) and reasoning_config and isinstance(reasoning_config, dict): _effort = (reasoning_config.get("effort") or "").strip().lower() 
_enabled = reasoning_config.get("enabled", True) if _effort == "none" or _enabled is False: @@ -314,7 +317,7 @@ class ChatCompletionsTransport(ProviderTransport): sanitized = list(sanitized) sanitized[0] = {**sanitized[0], "role": "developer"} - api_kwargs: Dict[str, Any] = { + api_kwargs: dict[str, Any] = { "model": model, "messages": sanitized, } @@ -356,15 +359,17 @@ class ChatCompletionsTransport(ProviderTransport): # Provider-specific api_kwargs extras (reasoning_effort, metadata, etc.) reasoning_config = params.get("reasoning_config") - extra_body_from_profile, top_level_from_profile = profile.build_api_kwargs_extras( - reasoning_config=reasoning_config, - supports_reasoning=params.get("supports_reasoning", False), - qwen_session_metadata=params.get("qwen_session_metadata"), + extra_body_from_profile, top_level_from_profile = ( + profile.build_api_kwargs_extras( + reasoning_config=reasoning_config, + supports_reasoning=params.get("supports_reasoning", False), + qwen_session_metadata=params.get("qwen_session_metadata"), + ) ) api_kwargs.update(top_level_from_profile) # extra_body assembly - extra_body: Dict[str, Any] = {} + extra_body: dict[str, Any] = {} # Profile's extra_body (tags, provider prefs, vl_high_resolution, etc.) profile_body = profile.build_extra_body( @@ -418,7 +423,7 @@ class ChatCompletionsTransport(ProviderTransport): # Gemini 3 thinking models attach extra_content with # thought_signature — without replay on the next turn the API # rejects the request with 400. - tc_provider_data: Dict[str, Any] = {} + tc_provider_data: dict[str, Any] = {} extra = getattr(tc, "extra_content", None) if extra is None and hasattr(tc, "model_extra"): extra = (tc.model_extra or {}).get("extra_content") @@ -429,12 +434,14 @@ class ChatCompletionsTransport(ProviderTransport): except Exception: pass tc_provider_data["extra_content"] = extra - tool_calls.append(ToolCall( - id=tc.id, - name=tc.function.name, - arguments=tc.function.arguments, - provider_data=tc_provider_data or None, - )) + tool_calls.append( + ToolCall( + id=tc.id, + name=tc.function.name, + arguments=tc.function.arguments, + provider_data=tc_provider_data or None, + ) + ) usage = None if hasattr(response, "usage") and response.usage: @@ -452,7 +459,7 @@ class ChatCompletionsTransport(ProviderTransport): reasoning = getattr(msg, "reasoning", None) reasoning_content = getattr(msg, "reasoning_content", None) - provider_data: Dict[str, Any] = {} + provider_data: dict[str, Any] = {} if reasoning_content: provider_data["reasoning_content"] = reasoning_content rd = getattr(msg, "reasoning_details", None) @@ -474,11 +481,9 @@ class ChatCompletionsTransport(ProviderTransport): return False if not hasattr(response, "choices") or response.choices is None: return False - if not response.choices: - return False - return True + return bool(response.choices) - def extract_cache_stats(self, response: Any) -> Optional[Dict[str, int]]: + def extract_cache_stats(self, response: Any) -> dict[str, int] | None: """Extract OpenRouter/OpenAI cache stats from prompt_tokens_details.""" usage = getattr(response, "usage", None) if usage is None: diff --git a/agent/transports/codex.py b/agent/transports/codex.py index ec4835219..863848fc0 100644 --- a/agent/transports/codex.py +++ b/agent/transports/codex.py @@ -5,10 +5,10 @@ This transport owns format conversion and normalization — NOT client lifecycle streaming, or the _run_codex_stream() call path. 
""" -from typing import Any, Dict, List, Optional +from typing import Any from agent.transports.base import ProviderTransport -from agent.transports.types import NormalizedResponse, ToolCall, Usage +from agent.transports.types import NormalizedResponse, ToolCall class ResponsesApiTransport(ProviderTransport): @@ -21,23 +21,25 @@ class ResponsesApiTransport(ProviderTransport): def api_mode(self) -> str: return "codex_responses" - def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any: + def convert_messages(self, messages: list[dict[str, Any]], **kwargs) -> Any: """Convert OpenAI chat messages to Responses API input items.""" from agent.codex_responses_adapter import _chat_messages_to_responses_input + return _chat_messages_to_responses_input(messages) - def convert_tools(self, tools: List[Dict[str, Any]]) -> Any: + def convert_tools(self, tools: list[dict[str, Any]]) -> Any: """Convert OpenAI tool schemas to Responses API function definitions.""" from agent.codex_responses_adapter import _responses_tools + return _responses_tools(tools) def build_kwargs( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, **params, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Build Responses API kwargs. Calls convert_messages and convert_tools internally. @@ -60,13 +62,11 @@ class ResponsesApiTransport(ProviderTransport): _chat_messages_to_responses_input, _responses_tools, ) - from run_agent import DEFAULT_AGENT_IDENTITY instructions = params.get("instructions", "") payload_messages = messages - if not instructions: - if messages and messages[0].get("role") == "system": + if not instructions and messages and messages[0].get("role") == "system": instructions = str(messages[0].get("content") or "").strip() payload_messages = messages[1:] if not instructions: @@ -133,8 +133,6 @@ class ResponsesApiTransport(ProviderTransport): """Normalize Codex Responses API response to NormalizedResponse.""" from agent.codex_responses_adapter import ( _normalize_codex_response, - _extract_responses_message_text, - _extract_responses_reasoning_text, ) # _normalize_codex_response returns (SimpleNamespace, finish_reason_str) @@ -149,12 +147,20 @@ class ResponsesApiTransport(ProviderTransport): provider_data["call_id"] = tc.call_id if hasattr(tc, "response_item_id") and tc.response_item_id: provider_data["response_item_id"] = tc.response_item_id - tool_calls.append(ToolCall( - id=tc.id if hasattr(tc, "id") else (tc.function.name if hasattr(tc, "function") else None), - name=tc.function.name if hasattr(tc, "function") else getattr(tc, "name", ""), - arguments=tc.function.arguments if hasattr(tc, "function") else getattr(tc, "arguments", "{}"), - provider_data=provider_data or None, - )) + tool_calls.append( + ToolCall( + id=tc.id + if hasattr(tc, "id") + else (tc.function.name if hasattr(tc, "function") else None), + name=tc.function.name + if hasattr(tc, "function") + else getattr(tc, "name", ""), + arguments=tc.function.arguments + if hasattr(tc, "function") + else getattr(tc, "arguments", "{}"), + provider_data=provider_data or None, + ) + ) # Extract reasoning items for provider_data provider_data = {} @@ -182,9 +188,7 @@ class ResponsesApiTransport(ProviderTransport): if response is None: return False output = getattr(response, "output", None) - if not isinstance(output, list) or not output: - return False - return True + return isinstance(output, list) and bool(output) def 
preflight_kwargs(self, api_kwargs: Any, *, allow_stream: bool = False) -> dict: """Validate and sanitize Codex API kwargs before the call. @@ -192,6 +196,7 @@ class ResponsesApiTransport(ProviderTransport): Normalizes input items, strips unsupported fields, validates structure. """ from agent.codex_responses_adapter import _preflight_codex_api_kwargs + return _preflight_codex_api_kwargs(api_kwargs, allow_stream=allow_stream) def map_finish_reason(self, raw_reason: str) -> str: diff --git a/agent/transports/types.py b/agent/transports/types.py index 5199a5db1..8345be278 100644 --- a/agent/transports/types.py +++ b/agent/transports/types.py @@ -12,7 +12,7 @@ from __future__ import annotations import json from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any @dataclass @@ -32,10 +32,10 @@ class ToolCall: * Others: ``None`` """ - id: Optional[str] + id: str | None name: str arguments: str # JSON string - provider_data: Optional[Dict[str, Any]] = field(default=None, repr=False) + provider_data: dict[str, Any] | None = field(default=None, repr=False) # ── Backward compatibility ────────────────────────────────── # The agent loop reads tc.function.name / tc.function.arguments @@ -47,17 +47,17 @@ class ToolCall: return "function" @property - def function(self) -> "ToolCall": + def function(self) -> ToolCall: """Return self so tc.function.name / tc.function.arguments work.""" return self @property - def call_id(self) -> Optional[str]: - """Codex call_id from provider_data, accessed via getattr by _build_assistant_message.""" + def call_id(self) -> str | None: + """Codex call_id from provider_data.""" return (self.provider_data or {}).get("call_id") @property - def response_item_id(self) -> Optional[str]: + def response_item_id(self) -> str | None: """Codex response_item_id from provider_data.""" return (self.provider_data or {}).get("response_item_id") @@ -87,18 +87,18 @@ class NormalizedResponse: * Others: ``None`` """ - content: Optional[str] - tool_calls: Optional[List[ToolCall]] + content: str | None + tool_calls: list[ToolCall] | None finish_reason: str # "stop", "tool_calls", "length", "content_filter" - reasoning: Optional[str] = None - usage: Optional[Usage] = None - provider_data: Optional[Dict[str, Any]] = field(default=None, repr=False) + reasoning: str | None = None + usage: Usage | None = None + provider_data: dict[str, Any] | None = field(default=None, repr=False) # ── Backward compatibility ────────────────────────────────── # The shim _nr_to_assistant_message() mapped these from provider_data. # These properties let NormalizedResponse pass through directly. @property - def reasoning_content(self) -> Optional[str]: + def reasoning_content(self) -> str | None: pd = self.provider_data or {} return pd.get("reasoning_content") @@ -117,8 +117,9 @@ class NormalizedResponse: # Factory helpers # --------------------------------------------------------------------------- + def build_tool_call( - id: Optional[str], + id: str | None, name: str, arguments: Any, **provider_fields: Any, @@ -132,7 +133,7 @@ def build_tool_call( return ToolCall(id=id, name=name, arguments=args_str, provider_data=pd) -def map_finish_reason(reason: Optional[str], mapping: Dict[str, str]) -> str: +def map_finish_reason(reason: str | None, mapping: dict[str, str]) -> str: """Translate a provider-specific stop reason to the normalised set. Falls back to ``"stop"`` for unknown or ``None`` reasons. 
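Taken together, the transport registry and the normalized types above are meant to be driven roughly as in this sketch. The model id, tool schema, and `max_tokens` value are illustrative only, and the raw provider call (which stays on AIAgent) is elided:

```python
from agent.transports import build_tool_call, get_transport, map_finish_reason

# Look up the transport registered for a provider API mode
# ("anthropic_messages" is AnthropicTransport.api_mode above).
transport = get_transport("anthropic_messages")

# build_kwargs() calls convert_messages()/convert_tools() internally.
api_kwargs = transport.build_kwargs(
    model="claude-sonnet-4-20250514",  # illustrative model id
    messages=[{"role": "user", "content": "List the files in /tmp"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "bash",
                "description": "Run a shell command",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ],
    max_tokens=1024,  # accepted params are provider-specific
)

# The actual API call stays on AIAgent; given its raw response:
#   normalized = transport.normalize_response(raw_response)
#   for tc in normalized.tool_calls or []:
#       # Backward-compat shim: tc.function returns the ToolCall itself,
#       # so the agent loop's tc.function.name / .arguments keep working.
#       print(tc.function.name, tc.function.arguments)

# The factory helpers are usable standalone: dict arguments are JSON-encoded,
# and extra keyword fields land in provider_data (read back via ToolCall.call_id).
tc = build_tool_call("call_1", "bash", {"command": "ls /tmp"}, call_id="item_1")

# Unknown or None stop reasons fall back to "stop".
assert map_finish_reason("weird_reason", {"end_turn": "stop"}) == "stop"
```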
diff --git a/providers/__init__.py b/providers/__init__.py index 4f24f7e20..809279daf 100644 --- a/providers/__init__.py +++ b/providers/__init__.py @@ -9,14 +9,17 @@ Usage: profile = get_provider_profile("kimi") # checks name + aliases """ -from __future__ import annotations +from providers.base import OMIT_TEMPERATURE, ProviderProfile -from typing import Dict, Optional +__all__ = [ + "OMIT_TEMPERATURE", + "ProviderProfile", + "get_provider_profile", + "register_provider", +] -from providers.base import ProviderProfile, OMIT_TEMPERATURE # noqa: F401 - -_REGISTRY: Dict[str, ProviderProfile] = {} -_ALIASES: Dict[str, str] = {} +_REGISTRY: dict[str, ProviderProfile] = {} +_ALIASES: dict[str, str] = {} _discovered = False @@ -27,7 +30,7 @@ def register_provider(profile: ProviderProfile) -> None: _ALIASES[alias] = profile.name -def get_provider_profile(name: str) -> Optional[ProviderProfile]: +def get_provider_profile(name: str) -> ProviderProfile | None: """Look up a provider profile by name or alias. Returns None if the provider has no profile (falls back to generic). @@ -47,6 +50,7 @@ def _discover_providers() -> None: import importlib import pkgutil + import providers as _pkg for _importer, modname, _ispkg in pkgutil.iter_modules(_pkg.__path__): @@ -56,6 +60,7 @@ def _discover_providers() -> None: importlib.import_module(f"providers.{modname}") except ImportError as e: import logging + logging.getLogger(__name__).warning( "Failed to import provider module %s: %s", modname, e ) diff --git a/providers/base.py b/providers/base.py index 8e03e62f3..61ef69d94 100644 --- a/providers/base.py +++ b/providers/base.py @@ -9,11 +9,8 @@ They do NOT own client construction, credential rotation, or streaming. Those stay on AIAgent. """ -from __future__ import annotations - from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple - +from typing import Any # Sentinel for "omit temperature entirely" (Kimi: server manages it) OMIT_TEMPERATURE = object() @@ -31,19 +28,21 @@ class ProviderProfile: # ── Auth ───────────────────────────────────────────────── env_vars: tuple = () base_url: str = "" - auth_type: str = "api_key" # api_key | oauth_device_code | oauth_external | copilot | aws + auth_type: str = ( + "api_key" # api_key | oauth_device_code | oauth_external | copilot | aws + ) # ── Client-level quirks (set once at client construction) ─ - default_headers: Dict[str, str] = field(default_factory=dict) + default_headers: dict[str, str] = field(default_factory=dict) # ── Request-level quirks ───────────────────────────────── # Temperature: None = use caller's default, OMIT_TEMPERATURE = don't send fixed_temperature: Any = None - default_max_tokens: Optional[int] = None + default_max_tokens: int | None = None # ── Hooks (override in subclass for complex providers) ─── - def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Provider-specific message preprocessing. Called AFTER codex field sanitization, BEFORE developer role swap. @@ -51,15 +50,18 @@ class ProviderProfile: """ return messages - def build_extra_body(self, *, session_id: str = None, **context) -> Dict[str, Any]: + def build_extra_body( + self, *, session_id: str | None = None, **context + ) -> dict[str, Any]: """Provider-specific extra_body fields. Merged into the API kwargs extra_body. Default: empty dict. 
""" return {} - def build_api_kwargs_extras(self, *, reasoning_config: dict = None, - **context) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def build_api_kwargs_extras( + self, *, reasoning_config: dict | None = None, **context + ) -> tuple[dict[str, Any], dict[str, Any]]: """Provider-specific kwargs that go to BOTH extra_body and top-level api_kwargs. Returns (extra_body_additions, top_level_kwargs). diff --git a/providers/deepseek.py b/providers/deepseek.py index ba7ed0d3d..a93ee1b18 100644 --- a/providers/deepseek.py +++ b/providers/deepseek.py @@ -1,7 +1,7 @@ """DeepSeek provider profile.""" -from providers.base import ProviderProfile from providers import register_provider +from providers.base import ProviderProfile deepseek = ProviderProfile( name="deepseek", diff --git a/providers/kimi.py b/providers/kimi.py index e6538a166..ac5098ccd 100644 --- a/providers/kimi.py +++ b/providers/kimi.py @@ -7,17 +7,18 @@ Kimi has dual endpoints: This module covers the chat_completions path (/v1 endpoint). """ -from typing import Any, Dict, Tuple +from typing import Any -from providers.base import ProviderProfile, OMIT_TEMPERATURE from providers import register_provider +from providers.base import OMIT_TEMPERATURE, ProviderProfile class KimiProfile(ProviderProfile): """Kimi/Moonshot — temperature omitted, thinking + reasoning_effort.""" - def build_api_kwargs_extras(self, *, reasoning_config: dict = None, - **context) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def build_api_kwargs_extras( + self, *, reasoning_config: dict | None = None, **context + ) -> tuple[dict[str, Any], dict[str, Any]]: """Kimi uses extra_body.thinking + top-level reasoning_effort.""" extra_body = {} top_level = {} diff --git a/providers/nous.py b/providers/nous.py index 42113672e..191c24c8f 100644 --- a/providers/nous.py +++ b/providers/nous.py @@ -1,20 +1,26 @@ """Nous Portal provider profile.""" -from typing import Any, Dict, Tuple +from typing import Any -from providers.base import ProviderProfile from providers import register_provider +from providers.base import ProviderProfile class NousProfile(ProviderProfile): """Nous Portal — product tags, reasoning with Nous-specific omission.""" - def build_extra_body(self, *, session_id: str = None, **context) -> Dict[str, Any]: + def build_extra_body( + self, *, session_id: str | None = None, **context + ) -> dict[str, Any]: return {"tags": ["product=hermes-agent"]} - def build_api_kwargs_extras(self, *, reasoning_config: dict = None, - supports_reasoning: bool = False, - **context) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def build_api_kwargs_extras( + self, + *, + reasoning_config: dict | None = None, + supports_reasoning: bool = False, + **context, + ) -> tuple[dict[str, Any], dict[str, Any]]: """Nous: passes full reasoning_config, but OMITS when disabled.""" extra_body = {} if supports_reasoning: diff --git a/providers/nvidia.py b/providers/nvidia.py index f707aa72a..b91116535 100644 --- a/providers/nvidia.py +++ b/providers/nvidia.py @@ -1,7 +1,7 @@ """NVIDIA NIM provider profile.""" -from providers.base import ProviderProfile from providers import register_provider +from providers.base import ProviderProfile nvidia = ProviderProfile( name="nvidia", diff --git a/providers/openrouter.py b/providers/openrouter.py index 7d74e8b9c..942721454 100644 --- a/providers/openrouter.py +++ b/providers/openrouter.py @@ -1,24 +1,30 @@ """OpenRouter provider profile.""" -from typing import Any, Dict, Tuple +from typing import Any -from providers.base import ProviderProfile from 
providers import register_provider +from providers.base import ProviderProfile class OpenRouterProfile(ProviderProfile): """OpenRouter — provider preferences, full reasoning config passthrough.""" - def build_extra_body(self, *, session_id: str = None, **context) -> Dict[str, Any]: + def build_extra_body( + self, *, session_id: str | None = None, **context + ) -> dict[str, Any]: body = {} prefs = context.get("provider_preferences") if prefs: body["provider"] = prefs return body - def build_api_kwargs_extras(self, *, reasoning_config: dict = None, - supports_reasoning: bool = False, - **context) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def build_api_kwargs_extras( + self, + *, + reasoning_config: dict | None = None, + supports_reasoning: bool = False, + **context, + ) -> tuple[dict[str, Any], dict[str, Any]]: """OpenRouter passes the FULL reasoning_config dict as extra_body.reasoning.""" extra_body = {} if supports_reasoning: diff --git a/providers/qwen.py b/providers/qwen.py index f72ea3569..ce0b4ebde 100644 --- a/providers/qwen.py +++ b/providers/qwen.py @@ -1,17 +1,17 @@ """Qwen Portal provider profile.""" import copy -from typing import Any, Dict, List, Tuple +from typing import Any -from providers.base import ProviderProfile from providers import register_provider +from providers.base import ProviderProfile class QwenProfile(ProviderProfile): """Qwen Portal — message normalization, vl_high_resolution, metadata top-level.""" - def prepare_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Normalize content to list-of-dicts format, inject cache_control on system msg. + def prepare_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Normalize content to list-of-dicts, inject cache_control on system msg. Matches the behavior of run_agent.py:_qwen_prepare_chat_messages(). 
""" @@ -39,18 +39,28 @@ class QwenProfile(ProviderProfile): for msg in prepared: if isinstance(msg, dict) and msg.get("role") == "system": content = msg.get("content") - if isinstance(content, list) and content and isinstance(content[-1], dict): + if ( + isinstance(content, list) + and content + and isinstance(content[-1], dict) + ): content[-1]["cache_control"] = {"type": "ephemeral"} break return prepared - def build_extra_body(self, *, session_id: str = None, **context) -> Dict[str, Any]: + def build_extra_body( + self, *, session_id: str | None = None, **context + ) -> dict[str, Any]: return {"vl_high_resolution_images": True} - def build_api_kwargs_extras(self, *, reasoning_config: dict = None, - qwen_session_metadata: dict = None, - **context) -> Tuple[Dict[str, Any], Dict[str, Any]]: + def build_api_kwargs_extras( + self, + *, + reasoning_config: dict | None = None, + qwen_session_metadata: dict | None = None, + **context, + ) -> tuple[dict[str, Any], dict[str, Any]]: """Qwen metadata goes to top-level api_kwargs, not extra_body.""" top_level = {} if qwen_session_metadata: diff --git a/pyproject.toml b/pyproject.toml index 7c9521731..104a9ee5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -142,8 +142,50 @@ python-version = "3.13" unknown-argument = "warn" redundant-cast = "ignore" +[tool.ruff] +include = ["agent/transports/**/*.py", "providers/**/*.py"] +target-version = "py311" +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "F", # Pyflakes + "W", # pycodestyle warnings + "I", # isort + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "RUF", # ruff-specific + "PIE", # flake8-pie + "C4", # flake8-comprehensions + "PLC", # pylint convention + "PLE", # pylint error + "RET", # flake8-return + "RSE", # flake8-raise + "FLY", # flynt + "PERF", # perflint + "FURB", # refurb + "LOG", # flake8-logging +] +ignore = [ + "E501", # line length — handled by formatter + "SIM108", # ternary instead of if/else + "RET504", # unnecessary assignment before return + "UP007", # X | None in older annotations + "PLC0415", # deferred/lazy imports are intentional + "SIM105", # try/except/pass for optional deps is idiomatic here +] + +[tool.ruff.lint.isort] +known-first-party = ["agent", "providers", "tools", "hermes_cli", "gateway", "tui_gateway"] +combine-as-imports = true + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] + [tool.ty.src] -exclude = ["**"] +include = ["agent/transports/**/*.py", "providers/**/*.py"] [[tool.ty.overrides]] include = ["**"]