fix(bedrock): preserve reasoningContent across converse normalization

This commit is contained in:
Molvikar 2026-05-05 22:30:50 +03:00 committed by Teknium
parent f0dd5b9c10
commit 8d363f8d54
3 changed files with 33 additions and 2 deletions

View file

@ -631,11 +631,18 @@ def normalize_converse_response(response: Dict) -> SimpleNamespace:
stop_reason = response.get("stopReason", "end_turn") stop_reason = response.get("stopReason", "end_turn")
text_parts = [] text_parts = []
reasoning_parts = []
tool_calls = [] tool_calls = []
for block in content_blocks: for block in content_blocks:
if "text" in block: if "text" in block:
text_parts.append(block["text"]) text_parts.append(block["text"])
elif "reasoningContent" in block:
reasoning = block["reasoningContent"]
if isinstance(reasoning, dict):
thinking_text = reasoning.get("text", "")
if thinking_text:
reasoning_parts.append(str(thinking_text))
elif "toolUse" in block: elif "toolUse" in block:
tu = block["toolUse"] tu = block["toolUse"]
tool_calls.append(SimpleNamespace( tool_calls.append(SimpleNamespace(
@ -652,6 +659,7 @@ def normalize_converse_response(response: Dict) -> SimpleNamespace:
role="assistant", role="assistant",
content="\n".join(text_parts) if text_parts else None, content="\n".join(text_parts) if text_parts else None,
tool_calls=tool_calls if tool_calls else None, tool_calls=tool_calls if tool_calls else None,
reasoning_content="\n\n".join(reasoning_parts) if reasoning_parts else None,
) )
# Build usage stats # Build usage stats
@ -732,6 +740,7 @@ def stream_converse_with_callbacks(
``normalize_converse_response()``. ``normalize_converse_response()``.
""" """
text_parts: List[str] = [] text_parts: List[str] = []
reasoning_parts: List[str] = []
tool_calls: List[SimpleNamespace] = [] tool_calls: List[SimpleNamespace] = []
current_tool: Optional[Dict] = None current_tool: Optional[Dict] = None
current_text_buffer: List[str] = [] current_text_buffer: List[str] = []
@ -777,8 +786,10 @@ def stream_converse_with_callbacks(
reasoning = delta["reasoningContent"] reasoning = delta["reasoningContent"]
if isinstance(reasoning, dict): if isinstance(reasoning, dict):
thinking_text = reasoning.get("text", "") thinking_text = reasoning.get("text", "")
if thinking_text and on_reasoning_delta: if thinking_text:
on_reasoning_delta(thinking_text) reasoning_parts.append(str(thinking_text))
if on_reasoning_delta:
on_reasoning_delta(thinking_text)
elif "contentBlockStop" in event: elif "contentBlockStop" in event:
if current_tool is not None: if current_tool is not None:
@ -817,6 +828,7 @@ def stream_converse_with_callbacks(
role="assistant", role="assistant",
content="\n".join(text_parts) if text_parts else None, content="\n".join(text_parts) if text_parts else None,
tool_calls=tool_calls if tool_calls else None, tool_calls=tool_calls if tool_calls else None,
reasoning_content="\n\n".join(reasoning_parts) if reasoning_parts else None,
) )
usage = SimpleNamespace( usage = SimpleNamespace(

View file

@ -994,6 +994,7 @@ class TestStreamConverseWithCallbacks:
events, on_reasoning_delta=lambda t: reasoning.append(t), events, on_reasoning_delta=lambda t: reasoning.append(t),
) )
assert reasoning == ["Let me think..."] assert reasoning == ["Let me think..."]
assert result.choices[0].message.reasoning_content == "Let me think..."
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -142,6 +142,24 @@ class TestBedrockNormalize:
assert len(nr.tool_calls) == 1 assert len(nr.tool_calls) == 1
assert nr.tool_calls[0].name == "terminal" assert nr.tool_calls[0].name == "terminal"
def test_raw_reasoning_content_response(self, transport):
raw = {
"output": {
"message": {
"role": "assistant",
"content": [
{"reasoningContent": {"text": "Let me think..."}},
{"text": "Answer."},
],
}
},
"stopReason": "end_turn",
"usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
}
nr = transport.normalize_response(raw)
assert nr.reasoning == "Let me think..."
assert nr.content == "Answer."
def test_already_normalized_response(self, transport): def test_already_normalized_response(self, transport):
"""Test normalize_response handles already-normalized SimpleNamespace (from dispatch site).""" """Test normalize_response handles already-normalized SimpleNamespace (from dispatch site)."""
pre_normalized = SimpleNamespace( pre_normalized = SimpleNamespace(