fix(agent): guard against non-dict model_extra in tool call normalization

Providers like NVIDIA NIM can return model_extra as a string rather
than a dict.  The previous (model_extra or {}).get(...) pattern
crashes with AttributeError on truthy non-dict values.  Use isinstance
check instead, applied in both the non-streaming normalize_response
path and the streaming accumulation path.

Supersedes #15157.
This commit is contained in:
Tranquil-Flow 2026-04-25 09:21:16 +10:00
parent 00c3d848d8
commit 20818d68c6
3 changed files with 128 additions and 2 deletions

View file

@ -315,7 +315,7 @@ class ChatCompletionsTransport(ProviderTransport):
tc_provider_data: Dict[str, Any] = {} tc_provider_data: Dict[str, Any] = {}
extra = getattr(tc, "extra_content", None) extra = getattr(tc, "extra_content", None)
if extra is None and hasattr(tc, "model_extra"): if extra is None and hasattr(tc, "model_extra"):
extra = (tc.model_extra or {}).get("extra_content") extra = (tc.model_extra if isinstance(tc.model_extra, dict) else {}).get("extra_content")
if extra is not None: if extra is not None:
if hasattr(extra, "model_dump"): if hasattr(extra, "model_dump"):
try: try:

View file

@ -6154,7 +6154,7 @@ class AIAgent:
entry["function"]["arguments"] += tc_delta.function.arguments entry["function"]["arguments"] += tc_delta.function.arguments
extra = getattr(tc_delta, "extra_content", None) extra = getattr(tc_delta, "extra_content", None)
if extra is None and hasattr(tc_delta, "model_extra"): if extra is None and hasattr(tc_delta, "model_extra"):
extra = (tc_delta.model_extra or {}).get("extra_content") extra = (tc_delta.model_extra if isinstance(tc_delta.model_extra, dict) else {}).get("extra_content")
if extra is not None: if extra is not None:
if hasattr(extra, "model_dump"): if hasattr(extra, "model_dump"):
extra = extra.model_dump() extra = extra.model_dump()

View file

@ -0,0 +1,126 @@
"""Tests for model_extra type guard in tool call normalization.
Providers like NVIDIA NIM may return model_extra as a string instead
of a dict, causing AttributeError on .get() calls. The isinstance
guard prevents this crash.
"""
import unittest
from types import SimpleNamespace
from agent.transports.chat_completions import ChatCompletionsTransport
from agent.transports.types import ToolCall
class TestModelExtraTypeGuard(unittest.TestCase):
"""Ensure the isinstance(dict) guard handles all model_extra types."""
def _extract(self, model_extra):
"""Replicate the guarded extraction pattern used in production."""
return (model_extra if isinstance(model_extra, dict) else {}).get(
"extra_content"
)
def test_string_no_crash(self):
"""String model_extra must not raise AttributeError."""
self.assertIsNone(self._extract("unexpected_string"))
def test_none_no_crash(self):
self.assertIsNone(self._extract(None))
def test_dict_extracts_extra_content(self):
self.assertEqual(
self._extract({"extra_content": {"key": "val"}}),
{"key": "val"},
)
def test_empty_dict(self):
self.assertIsNone(self._extract({}))
def test_integer_no_crash(self):
self.assertIsNone(self._extract(42))
def test_list_no_crash(self):
self.assertIsNone(self._extract(["a", "b"]))
def test_bool_no_crash(self):
"""Boolean True is truthy but not a dict."""
self.assertIsNone(self._extract(True))
class TestNormalizeResponseModelExtraGuard(unittest.TestCase):
"""Integration: normalize_response must not crash on non-dict model_extra."""
def test_string_model_extra_normalize(self):
"""Tool call with string model_extra should normalize without error."""
transport = ChatCompletionsTransport.__new__(ChatCompletionsTransport)
tc = SimpleNamespace(
id="call_1",
type="function",
function=SimpleNamespace(name="test_tool", arguments='{"x": 1}'),
extra_content=None,
model_extra="nvidia_nim_extra_string",
)
choice = SimpleNamespace(
index=0,
message=SimpleNamespace(
role="assistant",
content=None,
tool_calls=[tc],
refusal=None,
),
finish_reason="tool_calls",
)
response = SimpleNamespace(
id="resp_1",
choices=[choice],
usage=SimpleNamespace(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
model="test-model",
)
result = transport.normalize_response(response)
self.assertEqual(len(result.tool_calls), 1)
self.assertEqual(result.tool_calls[0].name, "test_tool")
def test_dict_model_extra_with_extra_content(self):
"""Dict model_extra with extra_content should be preserved."""
transport = ChatCompletionsTransport.__new__(ChatCompletionsTransport)
tc = SimpleNamespace(
id="call_1",
type="function",
function=SimpleNamespace(name="test_tool", arguments='{}'),
extra_content=None,
model_extra={"extra_content": {"thought_signature": "abc123"}},
)
choice = SimpleNamespace(
index=0,
message=SimpleNamespace(
role="assistant",
content=None,
tool_calls=[tc],
refusal=None,
),
finish_reason="tool_calls",
)
response = SimpleNamespace(
id="resp_1",
choices=[choice],
usage=SimpleNamespace(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
model="test-model",
)
result = transport.normalize_response(response)
self.assertEqual(len(result.tool_calls), 1)
self.assertEqual(
result.tool_calls[0].provider_data.get("extra_content"),
{"thought_signature": "abc123"},
)