diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 4f8c9a0a4..687fe9cd3 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -656,6 +656,96 @@ class AnthropicAuxiliaryClient: close_fn() +# --------------------------------------------------------------------------- +# Bedrock Converse adapter — wraps boto3 bedrock-runtime into an +# OpenAI-client-compatible interface for auxiliary tasks. +# --------------------------------------------------------------------------- + +class _BedrockCompletionsAdapter: + """Translates ``chat.completions.create(**kwargs)`` into a Bedrock + Converse API call and returns an OpenAI-shaped SimpleNamespace.""" + + def __init__(self, region: str, model: str): + self._region = region + self._model = model + + def create(self, **kwargs) -> Any: + from agent.bedrock_adapter import ( + _get_bedrock_runtime_client, + convert_messages_to_converse, + normalize_converse_response, + ) + + messages = kwargs.get("messages", []) + model = kwargs.get("model", self._model) + + system_prompt, converse_messages = convert_messages_to_converse(messages) + + converse_kwargs: Dict[str, Any] = { + "modelId": model, + "messages": converse_messages, + } + if system_prompt: + converse_kwargs["system"] = system_prompt + + # Forward max_tokens / temperature if provided + inference_config: Dict[str, Any] = {} + if kwargs.get("max_tokens"): + inference_config["maxTokens"] = kwargs["max_tokens"] + if kwargs.get("temperature") is not None: + inference_config["temperature"] = kwargs["temperature"] + if inference_config: + converse_kwargs["inferenceConfig"] = inference_config + + client = _get_bedrock_runtime_client(self._region) + response = client.converse(**converse_kwargs) + return normalize_converse_response(response) + + +class _BedrockChatShim: + def __init__(self, adapter: "_BedrockCompletionsAdapter"): + self.completions = adapter + + +class BedrockAuxiliaryClient: + """OpenAI-client-compatible wrapper that routes to AWS Bedrock Converse API.""" + + def __init__(self, region: str, model: str): + self._region = region + self._model = model + adapter = _BedrockCompletionsAdapter(region, model) + self.chat = _BedrockChatShim(adapter) + # Provide dummy api_key/base_url so callers that inspect them don't crash + self.api_key = "bedrock-iam" + self.base_url = f"https://bedrock-runtime.{region}.amazonaws.com" + + def close(self): + pass + + +class _AsyncBedrockCompletionsAdapter: + def __init__(self, sync_adapter: _BedrockCompletionsAdapter): + self._sync = sync_adapter + + async def create(self, **kwargs) -> Any: + import asyncio + return await asyncio.to_thread(self._sync.create, **kwargs) + + +class _AsyncBedrockChatShim: + def __init__(self, adapter: _AsyncBedrockCompletionsAdapter): + self.completions = adapter + + +class AsyncBedrockAuxiliaryClient: + def __init__(self, sync_wrapper: "BedrockAuxiliaryClient"): + sync_adapter = sync_wrapper.chat.completions + async_adapter = _AsyncBedrockCompletionsAdapter(sync_adapter) + self.chat = _AsyncBedrockChatShim(async_adapter) + self.api_key = sync_wrapper.api_key + self.base_url = sync_wrapper.base_url + + class _AsyncAnthropicCompletionsAdapter: def __init__(self, sync_adapter: _AnthropicCompletionsAdapter): self._sync = sync_adapter @@ -1473,6 +1563,8 @@ def _to_async_client(sync_client, model: str): return AsyncCodexAuxiliaryClient(sync_client), model if isinstance(sync_client, AnthropicAuxiliaryClient): return AsyncAnthropicAuxiliaryClient(sync_client), model + if isinstance(sync_client, BedrockAuxiliaryClient): + return AsyncBedrockAuxiliaryClient(sync_client), model try: from agent.gemini_native_adapter import GeminiNativeClient, AsyncGeminiNativeClient @@ -1621,6 +1713,23 @@ def resolve_provider_client( return (_to_async_client(client, final_model) if async_mode else (client, final_model)) + # ── AWS Bedrock (Converse API via IAM) ──────────────────────────── + if provider == "bedrock": + try: + from agent.bedrock_adapter import has_aws_credentials, resolve_bedrock_region + if not has_aws_credentials(): + logger.warning("resolve_provider_client: bedrock requested " + "but no AWS credentials found") + return None, None + region = resolve_bedrock_region() + final_model = model or _read_main_model() or "us.anthropic.claude-sonnet-4-6" + client = BedrockAuxiliaryClient(region, final_model) + return (_to_async_client(client, final_model) if async_mode + else (client, final_model)) + except Exception as exc: + logger.warning("resolve_provider_client: bedrock setup failed: %s", exc) + return None, None + # ── Nous Portal (OAuth) ────────────────────────────────────────── if provider == "nous": # Detect vision tasks: either explicit model override from