Add more debugging: we are hitting endpoint errors or some other slowdown

This commit is contained in:
Shannon Sands 2026-02-05 08:59:14 +10:00
parent ea7aa0b0d4
commit 5b82190460
3 changed files with 68 additions and 5 deletions

View file

@ -333,6 +333,61 @@ class AtroposAgent:
text = str(dumped)
print(text[:200_000], flush=True)
async def _chat_completion_with_debug(
self, *, managed: Any, step_num: int, chat_kwargs: Dict[str, Any]
) -> Any:
"""
Call `managed.chat_completion()` with optional timeout + richer failure logging.
Debug env vars:
- `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`.
- `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
"""
timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
timeout_s = float(timeout_s_raw) if timeout_s_raw else None
wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
wait_every_s = float(wait_every_raw) if wait_every_raw else None
async def _await_call() -> Any:
if not wait_every_s or wait_every_s <= 0:
return await managed.chat_completion(**chat_kwargs)
task = asyncio.create_task(managed.chat_completion(**chat_kwargs))
t0 = time.perf_counter()
while True:
try:
return await asyncio.wait_for(task, timeout=wait_every_s)
except TimeoutError:
waited = time.perf_counter() - t0
print(
f"[AtroposAgent] step={step_num} still waiting for chat_completion... ({waited:.1f}s)",
flush=True,
)
try:
if timeout_s and timeout_s > 0:
return await asyncio.wait_for(_await_call(), timeout=timeout_s)
return await _await_call()
except Exception as e:
detail: Dict[str, Any] = {
"step": step_num,
"exc_type": type(e).__name__,
"exc_str": str(e),
}
if isinstance(e, httpx.HTTPStatusError):
try:
detail["status_code"] = e.response.status_code
detail["response_text"] = e.response.text[:20_000]
except Exception:
pass
elif isinstance(e, httpx.RequestError):
detail["request"] = repr(getattr(e, "request", None))
print("\n=== ATROPOS_DEBUG_AGENT_CHAT_FAILURE ===", flush=True)
print(detail, flush=True)
raise
async def run(
self,
task: str,
@ -386,7 +441,9 @@ class AtroposAgent:
flush=True,
)
self._debug_dump_request(step_num=step_num + 1, chat_kwargs=chat_kwargs)
response = await managed.chat_completion(**chat_kwargs)
response = await self._chat_completion_with_debug(
managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs
)
self._debug_dump_response(step_num=step_num + 1, response=response)
print(
f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
@ -541,7 +598,7 @@ class AtroposAgent:
chat_kwargs["temperature"] = self.config.temperature
self._debug_dump_request(step_num=1, chat_kwargs=chat_kwargs)
response = await managed.chat_completion(**chat_kwargs)
response = await self._chat_completion_with_debug(managed=managed, step_num=1, chat_kwargs=chat_kwargs)
self._debug_dump_response(step_num=1, response=response)
current_node = None

View file

@ -6,7 +6,7 @@ providing helpers for running agent trajectories with queued/batched tool calls.
"""
from __future__ import annotations
import os
import asyncio
import time
import uuid
@ -320,7 +320,13 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} start", flush=True)
task = self.build_task(item)
agent_config = self.build_agent_config(item)
print(f"Starting trajectory {trajectory_id} with task: {task}")
if os.getenv("ATROPOS_DEBUG_PRINT_TASK") == "1":
print(f"Starting trajectory {trajectory_id} with task: {task}", flush=True)
else:
# Avoid printing the full task prompt by default (can be huge/noisy).
one_line = " ".join(str(task).splitlines()).strip()
preview = one_line[:240] + ("" if len(one_line) > 240 else "")
print(f"Starting trajectory {trajectory_id} (task preview): {preview}", flush=True)
async def _exec(call):
return await self._tool_executor.execute(trajectory_id, call)

View file

@ -110,7 +110,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
api_key=api_key,
num_max_requests_at_once=1,
num_requests_for_eval=1,
timeout=300,
timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
),
]