Merge branch 'main' into rewbs/tool-use-charge-to-subscription

This commit is contained in:
Ben Barclay 2026-04-02 11:00:35 +11:00
commit a2e56d044b
175 changed files with 18848 additions and 3772 deletions

View file

@ -230,6 +230,27 @@ class TestStripThinkBlocks:
assert "line1" not in result
assert "visible" in result
def test_orphaned_closing_think_tag(self, agent):
result = agent._strip_think_blocks("some reasoning</think>actual answer")
assert "</think>" not in result
assert "actual answer" in result
def test_orphaned_closing_thinking_tag(self, agent):
result = agent._strip_think_blocks("reasoning</thinking>answer")
assert "</thinking>" not in result
assert "answer" in result
def test_orphaned_opening_think_tag(self, agent):
result = agent._strip_think_blocks("<think>orphaned reasoning without close")
assert "<think>" not in result
def test_mixed_orphaned_and_paired_tags(self, agent):
text = "stray</think><think>paired reasoning</think> visible"
result = agent._strip_think_blocks(text)
assert "</think>" not in result
assert "<think>" not in result
assert "visible" in result
class TestExtractReasoning:
def test_reasoning_field(self, agent):
@ -1223,6 +1244,42 @@ class TestConcurrentToolExecution:
)
assert result == "result"
def test_sequential_tool_callbacks_fire_in_order(self, agent):
tool_call = _mock_tool_call(name="web_search", arguments='{"query":"hello"}', call_id="c1")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
messages = []
starts = []
completes = []
agent.tool_start_callback = lambda tool_call_id, function_name, function_args: starts.append((tool_call_id, function_name, function_args))
agent.tool_complete_callback = lambda tool_call_id, function_name, function_args, function_result: completes.append((tool_call_id, function_name, function_args, function_result))
with patch("run_agent.handle_function_call", return_value='{"success": true}'):
agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")
assert starts == [("c1", "web_search", {"query": "hello"})]
assert completes == [("c1", "web_search", {"query": "hello"}, '{"success": true}')]
def test_concurrent_tool_callbacks_fire_for_each_tool(self, agent):
tc1 = _mock_tool_call(name="web_search", arguments='{"query":"one"}', call_id="c1")
tc2 = _mock_tool_call(name="web_search", arguments='{"query":"two"}', call_id="c2")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
messages = []
starts = []
completes = []
agent.tool_start_callback = lambda tool_call_id, function_name, function_args: starts.append((tool_call_id, function_name, function_args))
agent.tool_complete_callback = lambda tool_call_id, function_name, function_args, function_result: completes.append((tool_call_id, function_name, function_args, function_result))
with patch("run_agent.handle_function_call", side_effect=['{"id":1}', '{"id":2}']):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
assert starts == [
("c1", "web_search", {"query": "one"}),
("c2", "web_search", {"query": "two"}),
]
assert len(completes) == 2
assert {entry[0] for entry in completes} == {"c1", "c2"}
assert {entry[3] for entry in completes} == {'{"id":1}', '{"id":2}'}
def test_invoke_tool_handles_agent_level_tools(self, agent):
"""_invoke_tool should handle todo tool directly."""
with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
@ -1264,6 +1321,38 @@ class TestPathsOverlap:
assert not _paths_overlap(Path("src/a.py"), Path(""))
class TestParallelScopePathNormalization:
def test_extract_parallel_scope_path_normalizes_relative_to_cwd(self, tmp_path, monkeypatch):
from run_agent import _extract_parallel_scope_path
monkeypatch.chdir(tmp_path)
scoped = _extract_parallel_scope_path("write_file", {"path": "./notes.txt"})
assert scoped == tmp_path / "notes.txt"
def test_extract_parallel_scope_path_treats_relative_and_absolute_same_file_as_same_scope(self, tmp_path, monkeypatch):
from run_agent import _extract_parallel_scope_path, _paths_overlap
monkeypatch.chdir(tmp_path)
abs_path = tmp_path / "notes.txt"
rel_scoped = _extract_parallel_scope_path("write_file", {"path": "notes.txt"})
abs_scoped = _extract_parallel_scope_path("write_file", {"path": str(abs_path)})
assert rel_scoped == abs_scoped
assert _paths_overlap(rel_scoped, abs_scoped)
def test_should_parallelize_tool_batch_rejects_same_file_with_mixed_path_spellings(self, tmp_path, monkeypatch):
from run_agent import _should_parallelize_tool_batch
monkeypatch.chdir(tmp_path)
tc1 = _mock_tool_call(name="write_file", arguments='{"path":"notes.txt","content":"one"}', call_id="c1")
tc2 = _mock_tool_call(name="write_file", arguments=f'{{"path":"{tmp_path / "notes.txt"}","content":"two"}}', call_id="c2")
assert not _should_parallelize_tool_batch([tc1, tc2])
class TestHandleMaxIterations:
def test_returns_summary(self, agent):
resp = _mock_response(content="Here is a summary of what I did.")
@ -1776,6 +1865,127 @@ class TestNousCredentialRefresh:
assert isinstance(agent.client, _RebuiltClient)
class TestCredentialPoolRecovery:
def test_recover_with_pool_rotates_on_402(self, agent):
current = SimpleNamespace(label="primary")
next_entry = SimpleNamespace(label="secondary")
class _Pool:
def current(self):
return current
def mark_exhausted_and_rotate(self, *, status_code):
assert status_code == 402
return next_entry
agent._credential_pool = _Pool()
agent._swap_credential = MagicMock()
recovered, retry_same = agent._recover_with_credential_pool(
status_code=402,
has_retried_429=False,
)
assert recovered is True
assert retry_same is False
agent._swap_credential.assert_called_once_with(next_entry)
def test_recover_with_pool_retries_first_429_then_rotates(self, agent):
next_entry = SimpleNamespace(label="secondary")
class _Pool:
def current(self):
return SimpleNamespace(label="primary")
def mark_exhausted_and_rotate(self, *, status_code):
assert status_code == 429
return next_entry
agent._credential_pool = _Pool()
agent._swap_credential = MagicMock()
recovered, retry_same = agent._recover_with_credential_pool(
status_code=429,
has_retried_429=False,
)
assert recovered is False
assert retry_same is True
agent._swap_credential.assert_not_called()
recovered, retry_same = agent._recover_with_credential_pool(
status_code=429,
has_retried_429=True,
)
assert recovered is True
assert retry_same is False
agent._swap_credential.assert_called_once_with(next_entry)
def test_recover_with_pool_refreshes_on_401(self, agent):
"""401 with successful refresh should swap to refreshed credential."""
refreshed_entry = SimpleNamespace(label="refreshed-primary", id="abc")
class _Pool:
def try_refresh_current(self):
return refreshed_entry
agent._credential_pool = _Pool()
agent._swap_credential = MagicMock()
recovered, retry_same = agent._recover_with_credential_pool(
status_code=401,
has_retried_429=False,
)
assert recovered is True
agent._swap_credential.assert_called_once_with(refreshed_entry)
def test_recover_with_pool_rotates_on_401_when_refresh_fails(self, agent):
"""401 with failed refresh should rotate to next credential."""
next_entry = SimpleNamespace(label="secondary", id="def")
class _Pool:
def try_refresh_current(self):
return None # refresh failed
def mark_exhausted_and_rotate(self, *, status_code):
assert status_code == 401
return next_entry
agent._credential_pool = _Pool()
agent._swap_credential = MagicMock()
recovered, retry_same = agent._recover_with_credential_pool(
status_code=401,
has_retried_429=False,
)
assert recovered is True
assert retry_same is False
agent._swap_credential.assert_called_once_with(next_entry)
def test_recover_with_pool_401_refresh_fails_no_more_credentials(self, agent):
"""401 with failed refresh and no other credentials returns not recovered."""
class _Pool:
def try_refresh_current(self):
return None
def mark_exhausted_and_rotate(self, *, status_code):
return None # no more credentials
agent._credential_pool = _Pool()
agent._swap_credential = MagicMock()
recovered, retry_same = agent._recover_with_credential_pool(
status_code=401,
has_retried_429=False,
)
assert recovered is False
agent._swap_credential.assert_not_called()
class TestMaxTokensParam:
"""Verify _max_tokens_param returns the correct key for each provider."""
@ -2604,6 +2814,46 @@ def test_is_openai_client_closed_honors_custom_client_flag():
assert AIAgent._is_openai_client_closed(SimpleNamespace(is_closed=False)) is False
def test_is_openai_client_closed_handles_method_form():
"""Fix for issue #4377: is_closed as method (openai SDK) vs property (httpx).
The openai SDK's is_closed is a method, not a property. Prior to this fix,
getattr(client, "is_closed", False) returned the bound method object, which
is always truthy, causing the function to incorrectly report all clients as
closed and triggering unnecessary client recreation on every API call.
"""
class MethodFormClient:
"""Mimics openai.OpenAI where is_closed() is a method."""
def __init__(self, closed: bool):
self._closed = closed
def is_closed(self) -> bool:
return self._closed
# Method returning False - client is open
open_client = MethodFormClient(closed=False)
assert AIAgent._is_openai_client_closed(open_client) is False
# Method returning True - client is closed
closed_client = MethodFormClient(closed=True)
assert AIAgent._is_openai_client_closed(closed_client) is True
def test_is_openai_client_closed_falls_back_to_http_client():
"""Verify fallback to _client.is_closed when top-level is_closed is None."""
class ClientWithHttpClient:
is_closed = None # No top-level is_closed
def __init__(self, http_closed: bool):
self._client = SimpleNamespace(is_closed=http_closed)
assert AIAgent._is_openai_client_closed(ClientWithHttpClient(http_closed=False)) is False
assert AIAgent._is_openai_client_closed(ClientWithHttpClient(http_closed=True)) is True
class TestAnthropicBaseUrlPassthrough:
"""Bug fix: base_url was filtered with 'anthropic in base_url', blocking proxies."""