diff --git a/agent/anthropic_adapter.py b/agent/anthropic_adapter.py index bf2b8a62c..d8d181cc1 100644 --- a/agent/anthropic_adapter.py +++ b/agent/anthropic_adapter.py @@ -1525,3 +1525,42 @@ def normalize_anthropic_response( ), finish_reason, ) + + +def normalize_anthropic_response_v2( + response, + strip_tool_prefix: bool = False, +) -> "NormalizedResponse": + """Normalize Anthropic response to NormalizedResponse. + + Wraps the existing normalize_anthropic_response() and maps its output + to the shared transport types. This allows incremental migration — + one call site at a time — without changing the original function. + """ + from agent.transports.types import NormalizedResponse, build_tool_call + + assistant_msg, finish_reason = normalize_anthropic_response(response, strip_tool_prefix) + + tool_calls = None + if assistant_msg.tool_calls: + tool_calls = [ + build_tool_call( + id=tc.id, + name=tc.function.name, + arguments=tc.function.arguments, + ) + for tc in assistant_msg.tool_calls + ] + + provider_data = {} + if getattr(assistant_msg, "reasoning_details", None): + provider_data["reasoning_details"] = assistant_msg.reasoning_details + + return NormalizedResponse( + content=assistant_msg.content, + tool_calls=tool_calls, + finish_reason=finish_reason, + reasoning=getattr(assistant_msg, "reasoning", None), + usage=None, # Anthropic usage is on the raw response, not the normaliser + provider_data=provider_data or None, + ) diff --git a/agent/transports/__init__.py b/agent/transports/__init__.py new file mode 100644 index 000000000..6ee1c5117 --- /dev/null +++ b/agent/transports/__init__.py @@ -0,0 +1 @@ +"""Transport layer types for provider response normalization.""" diff --git a/agent/transports/types.py b/agent/transports/types.py new file mode 100644 index 000000000..2b048fcaa --- /dev/null +++ b/agent/transports/types.py @@ -0,0 +1,100 @@ +"""Shared types for normalized provider responses. + +These dataclasses define the canonical shape that all provider adapters +normalize responses to. The shared surface is intentionally minimal — +only fields that every downstream consumer reads are top-level. +Protocol-specific state goes in ``provider_data`` dicts (response-level +and per-tool-call) so that protocol-aware code paths can access it +without polluting the shared type. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class ToolCall: + """A normalized tool call from any provider. + + ``id`` is the protocol's canonical identifier — what gets used in + ``tool_call_id`` / ``tool_use_id`` when constructing tool result + messages. May be ``None`` when the provider omits it; the agent + fills it via ``_deterministic_call_id()`` before storing in history. + + ``provider_data`` carries per-tool-call protocol metadata that only + protocol-aware code reads: + + * Codex: ``{"call_id": "call_XXX", "response_item_id": "fc_XXX"}`` + * Gemini: ``{"extra_content": {"google": {"thought_signature": "..."}}}`` + * Others: ``None`` + """ + + id: Optional[str] + name: str + arguments: str # JSON string + provider_data: Optional[Dict[str, Any]] = field(default=None, repr=False) + + +@dataclass +class Usage: + """Token usage from an API response.""" + + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + cached_tokens: int = 0 + + +@dataclass +class NormalizedResponse: + """Normalized API response from any provider. + + Shared fields are truly cross-provider — every caller can rely on + them without branching on api_mode. Protocol-specific state goes in + ``provider_data`` so that only protocol-aware code paths read it. + + Response-level ``provider_data`` examples: + + * Anthropic: ``{"reasoning_details": [...]}`` + * Codex: ``{"codex_reasoning_items": [...]}`` + * Others: ``None`` + """ + + content: Optional[str] + tool_calls: Optional[List[ToolCall]] + finish_reason: str # "stop", "tool_calls", "length", "content_filter" + reasoning: Optional[str] = None + usage: Optional[Usage] = None + provider_data: Optional[Dict[str, Any]] = field(default=None, repr=False) + + +# --------------------------------------------------------------------------- +# Factory helpers +# --------------------------------------------------------------------------- + +def build_tool_call( + id: Optional[str], + name: str, + arguments: Any, + **provider_fields: Any, +) -> ToolCall: + """Build a ``ToolCall``, auto-serialising *arguments* if it's a dict. + + Any extra keyword arguments are collected into ``provider_data``. + """ + args_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments) + pd = dict(provider_fields) if provider_fields else None + return ToolCall(id=id, name=name, arguments=args_str, provider_data=pd) + + +def map_finish_reason(reason: Optional[str], mapping: Dict[str, str]) -> str: + """Translate a provider-specific stop reason to the normalised set. + + Falls back to ``"stop"`` for unknown or ``None`` reasons. + """ + if reason is None: + return "stop" + return mapping.get(reason, "stop") diff --git a/run_agent.py b/run_agent.py index 49240d70f..e69d30ff2 100644 --- a/run_agent.py +++ b/run_agent.py @@ -10778,10 +10778,33 @@ class AIAgent: if self.api_mode == "codex_responses": assistant_message, finish_reason = self._normalize_codex_response(response) elif self.api_mode == "anthropic_messages": - from agent.anthropic_adapter import normalize_anthropic_response - assistant_message, finish_reason = normalize_anthropic_response( + from agent.anthropic_adapter import normalize_anthropic_response_v2 + _nr = normalize_anthropic_response_v2( response, strip_tool_prefix=self._is_anthropic_oauth ) + # Back-compat shim: downstream code expects SimpleNamespace with + # .content, .tool_calls, .reasoning, .reasoning_content, + # .reasoning_details attributes. This shim makes the cost of the + # old interface visible — it vanishes when the full transport + # wiring lands (PR 3+). + assistant_message = SimpleNamespace( + content=_nr.content, + tool_calls=[ + SimpleNamespace( + id=tc.id, + type="function", + function=SimpleNamespace(name=tc.name, arguments=tc.arguments), + ) + for tc in (_nr.tool_calls or []) + ] or None, + reasoning=_nr.reasoning, + reasoning_content=None, + reasoning_details=( + _nr.provider_data.get("reasoning_details") + if _nr.provider_data else None + ), + ) + finish_reason = _nr.finish_reason else: assistant_message = response.choices[0].message diff --git a/tests/agent/test_anthropic_normalize_v2.py b/tests/agent/test_anthropic_normalize_v2.py new file mode 100644 index 000000000..9d5c16139 --- /dev/null +++ b/tests/agent/test_anthropic_normalize_v2.py @@ -0,0 +1,238 @@ +"""Regression tests: normalize_anthropic_response_v2 vs v1. + +Constructs mock Anthropic responses and asserts that the v2 function +(returning NormalizedResponse) produces identical field values to the +original v1 function (returning SimpleNamespace + finish_reason). +""" + +import json +import pytest +from types import SimpleNamespace + +from agent.anthropic_adapter import ( + normalize_anthropic_response, + normalize_anthropic_response_v2, +) +from agent.transports.types import NormalizedResponse, ToolCall + + +# --------------------------------------------------------------------------- +# Helpers to build mock Anthropic SDK responses +# --------------------------------------------------------------------------- + +def _text_block(text: str): + return SimpleNamespace(type="text", text=text) + + +def _thinking_block(thinking: str, signature: str = "sig_abc"): + return SimpleNamespace(type="thinking", thinking=thinking, signature=signature) + + +def _tool_use_block(id: str, name: str, input: dict): + return SimpleNamespace(type="tool_use", id=id, name=name, input=input) + + +def _response(content_blocks, stop_reason="end_turn"): + return SimpleNamespace( + content=content_blocks, + stop_reason=stop_reason, + usage=SimpleNamespace( + input_tokens=10, + output_tokens=5, + ), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestTextOnly: + """Text-only response — no tools, no thinking.""" + + def setup_method(self): + self.resp = _response([_text_block("Hello world")]) + self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp) + self.v2 = normalize_anthropic_response_v2(self.resp) + + def test_type(self): + assert isinstance(self.v2, NormalizedResponse) + + def test_content_matches(self): + assert self.v2.content == self.v1_msg.content + + def test_finish_reason_matches(self): + assert self.v2.finish_reason == self.v1_finish + + def test_no_tool_calls(self): + assert self.v2.tool_calls is None + assert self.v1_msg.tool_calls is None + + def test_no_reasoning(self): + assert self.v2.reasoning is None + assert self.v1_msg.reasoning is None + + +class TestWithToolCalls: + """Response with tool calls.""" + + def setup_method(self): + self.resp = _response( + [ + _text_block("I'll check that"), + _tool_use_block("toolu_abc", "terminal", {"command": "ls"}), + _tool_use_block("toolu_def", "read_file", {"path": "/tmp"}), + ], + stop_reason="tool_use", + ) + self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp) + self.v2 = normalize_anthropic_response_v2(self.resp) + + def test_finish_reason(self): + assert self.v2.finish_reason == "tool_calls" + assert self.v1_finish == "tool_calls" + + def test_tool_call_count(self): + assert len(self.v2.tool_calls) == 2 + assert len(self.v1_msg.tool_calls) == 2 + + def test_tool_call_ids_match(self): + for i in range(2): + assert self.v2.tool_calls[i].id == self.v1_msg.tool_calls[i].id + + def test_tool_call_names_match(self): + assert self.v2.tool_calls[0].name == "terminal" + assert self.v2.tool_calls[1].name == "read_file" + for i in range(2): + assert self.v2.tool_calls[i].name == self.v1_msg.tool_calls[i].function.name + + def test_tool_call_arguments_match(self): + for i in range(2): + assert self.v2.tool_calls[i].arguments == self.v1_msg.tool_calls[i].function.arguments + + def test_content_preserved(self): + assert self.v2.content == self.v1_msg.content + assert "check that" in self.v2.content + + +class TestWithThinking: + """Response with thinking blocks (Claude 3.5+ extended thinking).""" + + def setup_method(self): + self.resp = _response([ + _thinking_block("Let me think about this carefully..."), + _text_block("The answer is 42."), + ]) + self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp) + self.v2 = normalize_anthropic_response_v2(self.resp) + + def test_reasoning_matches(self): + assert self.v2.reasoning == self.v1_msg.reasoning + assert "think about this" in self.v2.reasoning + + def test_reasoning_details_in_provider_data(self): + v1_details = self.v1_msg.reasoning_details + v2_details = self.v2.provider_data.get("reasoning_details") if self.v2.provider_data else None + assert v1_details is not None + assert v2_details is not None + assert len(v2_details) == len(v1_details) + + def test_content_excludes_thinking(self): + assert self.v2.content == "The answer is 42." + + +class TestMixed: + """Response with thinking + text + tool calls.""" + + def setup_method(self): + self.resp = _response( + [ + _thinking_block("Planning my approach..."), + _text_block("I'll run the command"), + _tool_use_block("toolu_xyz", "terminal", {"command": "pwd"}), + ], + stop_reason="tool_use", + ) + self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp) + self.v2 = normalize_anthropic_response_v2(self.resp) + + def test_all_fields_present(self): + assert self.v2.content is not None + assert self.v2.tool_calls is not None + assert self.v2.reasoning is not None + assert self.v2.finish_reason == "tool_calls" + + def test_content_matches(self): + assert self.v2.content == self.v1_msg.content + + def test_reasoning_matches(self): + assert self.v2.reasoning == self.v1_msg.reasoning + + def test_tool_call_matches(self): + assert self.v2.tool_calls[0].id == self.v1_msg.tool_calls[0].id + assert self.v2.tool_calls[0].name == self.v1_msg.tool_calls[0].function.name + + +class TestStopReasons: + """Verify finish_reason mapping matches between v1 and v2.""" + + @pytest.mark.parametrize("stop_reason,expected", [ + ("end_turn", "stop"), + ("tool_use", "tool_calls"), + ("max_tokens", "length"), + ("stop_sequence", "stop"), + ("refusal", "content_filter"), + ("model_context_window_exceeded", "length"), + ("unknown_future_reason", "stop"), + ]) + def test_stop_reason_mapping(self, stop_reason, expected): + resp = _response([_text_block("x")], stop_reason=stop_reason) + v1_msg, v1_finish = normalize_anthropic_response(resp) + v2 = normalize_anthropic_response_v2(resp) + assert v2.finish_reason == v1_finish == expected + + +class TestStripToolPrefix: + """Verify mcp_ prefix stripping works identically.""" + + def test_prefix_stripped(self): + resp = _response( + [_tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})], + stop_reason="tool_use", + ) + v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=True) + v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=True) + assert v1_msg.tool_calls[0].function.name == "terminal" + assert v2.tool_calls[0].name == "terminal" + + def test_prefix_kept(self): + resp = _response( + [_tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})], + stop_reason="tool_use", + ) + v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=False) + v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=False) + assert v1_msg.tool_calls[0].function.name == "mcp_terminal" + assert v2.tool_calls[0].name == "mcp_terminal" + + +class TestEdgeCases: + """Edge cases: empty content, no blocks, etc.""" + + def test_empty_content_blocks(self): + resp = _response([]) + v1_msg, v1_finish = normalize_anthropic_response(resp) + v2 = normalize_anthropic_response_v2(resp) + assert v2.content == v1_msg.content + assert v2.content is None + + def test_no_reasoning_details_means_none_provider_data(self): + resp = _response([_text_block("hi")]) + v2 = normalize_anthropic_response_v2(resp) + assert v2.provider_data is None + + def test_v2_returns_dataclass_not_namespace(self): + resp = _response([_text_block("hi")]) + v2 = normalize_anthropic_response_v2(resp) + assert isinstance(v2, NormalizedResponse) + assert not isinstance(v2, SimpleNamespace) diff --git a/tests/agent/transports/__init__.py b/tests/agent/transports/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/agent/transports/test_types.py b/tests/agent/transports/test_types.py new file mode 100644 index 000000000..0be18c688 --- /dev/null +++ b/tests/agent/transports/test_types.py @@ -0,0 +1,151 @@ +"""Tests for agent/transports/types.py — dataclass construction + helpers.""" + +import json +import pytest + +from agent.transports.types import ( + NormalizedResponse, + ToolCall, + Usage, + build_tool_call, + map_finish_reason, +) + + +# --------------------------------------------------------------------------- +# ToolCall +# --------------------------------------------------------------------------- + +class TestToolCall: + def test_basic_construction(self): + tc = ToolCall(id="call_abc", name="terminal", arguments='{"cmd": "ls"}') + assert tc.id == "call_abc" + assert tc.name == "terminal" + assert tc.arguments == '{"cmd": "ls"}' + assert tc.provider_data is None + + def test_none_id(self): + tc = ToolCall(id=None, name="read_file", arguments="{}") + assert tc.id is None + + def test_provider_data(self): + tc = ToolCall( + id="call_x", + name="t", + arguments="{}", + provider_data={"call_id": "call_x", "response_item_id": "fc_x"}, + ) + assert tc.provider_data["call_id"] == "call_x" + assert tc.provider_data["response_item_id"] == "fc_x" + + +# --------------------------------------------------------------------------- +# Usage +# --------------------------------------------------------------------------- + +class TestUsage: + def test_defaults(self): + u = Usage() + assert u.prompt_tokens == 0 + assert u.completion_tokens == 0 + assert u.total_tokens == 0 + assert u.cached_tokens == 0 + + def test_explicit(self): + u = Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150, cached_tokens=80) + assert u.total_tokens == 150 + + +# --------------------------------------------------------------------------- +# NormalizedResponse +# --------------------------------------------------------------------------- + +class TestNormalizedResponse: + def test_text_only(self): + r = NormalizedResponse(content="hello", tool_calls=None, finish_reason="stop") + assert r.content == "hello" + assert r.tool_calls is None + assert r.finish_reason == "stop" + assert r.reasoning is None + assert r.usage is None + assert r.provider_data is None + + def test_with_tool_calls(self): + tcs = [ToolCall(id="call_1", name="terminal", arguments='{"cmd":"pwd"}')] + r = NormalizedResponse(content=None, tool_calls=tcs, finish_reason="tool_calls") + assert r.finish_reason == "tool_calls" + assert len(r.tool_calls) == 1 + assert r.tool_calls[0].name == "terminal" + + def test_with_reasoning(self): + r = NormalizedResponse( + content="answer", + tool_calls=None, + finish_reason="stop", + reasoning="I thought about it", + ) + assert r.reasoning == "I thought about it" + + def test_with_provider_data(self): + r = NormalizedResponse( + content=None, + tool_calls=None, + finish_reason="stop", + provider_data={"reasoning_details": [{"type": "thinking", "thinking": "hmm"}]}, + ) + assert r.provider_data["reasoning_details"][0]["type"] == "thinking" + + +# --------------------------------------------------------------------------- +# build_tool_call +# --------------------------------------------------------------------------- + +class TestBuildToolCall: + def test_dict_arguments_serialized(self): + tc = build_tool_call(id="call_1", name="terminal", arguments={"cmd": "ls"}) + assert tc.arguments == json.dumps({"cmd": "ls"}) + assert tc.provider_data is None + + def test_string_arguments_passthrough(self): + tc = build_tool_call(id="call_2", name="read_file", arguments='{"path": "/tmp"}') + assert tc.arguments == '{"path": "/tmp"}' + + def test_provider_fields(self): + tc = build_tool_call( + id="call_3", + name="terminal", + arguments="{}", + call_id="call_3", + response_item_id="fc_3", + ) + assert tc.provider_data == {"call_id": "call_3", "response_item_id": "fc_3"} + + def test_none_id(self): + tc = build_tool_call(id=None, name="t", arguments="{}") + assert tc.id is None + + +# --------------------------------------------------------------------------- +# map_finish_reason +# --------------------------------------------------------------------------- + +class TestMapFinishReason: + ANTHROPIC_MAP = { + "end_turn": "stop", + "tool_use": "tool_calls", + "max_tokens": "length", + "stop_sequence": "stop", + "refusal": "content_filter", + } + + def test_known_reason(self): + assert map_finish_reason("end_turn", self.ANTHROPIC_MAP) == "stop" + assert map_finish_reason("tool_use", self.ANTHROPIC_MAP) == "tool_calls" + assert map_finish_reason("max_tokens", self.ANTHROPIC_MAP) == "length" + assert map_finish_reason("refusal", self.ANTHROPIC_MAP) == "content_filter" + + def test_unknown_reason_defaults_to_stop(self): + assert map_finish_reason("something_new", self.ANTHROPIC_MAP) == "stop" + + def test_none_reason(self): + assert map_finish_reason(None, self.ANTHROPIC_MAP) == "stop"