Increase the per-chat timeout (to work around API issues) and tweak logging

This commit is contained in:
Shannon Sands 2026-02-05 14:54:34 +10:00
parent 487487406d
commit beac2ee06a
2 changed files with 12 additions and 5 deletions

View file

@@ -438,12 +438,12 @@ class AtroposAgent:
- `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`. - `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`.
- `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting. - `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
""" """
# Hard guardrail: never allow a single chat completion to block for more than 2 minutes. # Hard guardrail: never allow a single chat completion to block for too long.
# This is essential for RL data-gen stability; long hangs should be treated as failures (score=0). # This is essential for RL data-gen stability; long hangs should be treated as failures (score=0).
timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S") timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
timeout_s_default = 120.0 timeout_s_default = 240.0
timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default
timeout_s = min(timeout_s, 120.0) timeout_s = min(timeout_s, 240.0)
wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S") wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
wait_every_s = float(wait_every_raw) if wait_every_raw else None wait_every_s = float(wait_every_raw) if wait_every_raw else None

View file

@@ -37,8 +37,11 @@ class SweSmithOracleEnvConfig(AgentEnvConfig):
python_only: bool = Field(default=True, description="Filter to Python-evaluable rows") python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
score_include_fail_to_pass: bool = Field( score_include_fail_to_pass: bool = Field(
default=False, default=True,
description="If true, score tests on PASS_TO_PASS FAIL_TO_PASS; else PASS_TO_PASS only.", description=(
"If true (default), score tests on PASS_TO_PASS FAIL_TO_PASS. "
"Disable to only run PASS_TO_PASS (faster but weaker signal)."
),
) )
prompt_mode: str = Field( prompt_mode: str = Field(
@@ -347,6 +350,10 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
# Training correctness: do not reward trajectories that never actually used tools. # Training correctness: do not reward trajectories that never actually used tools.
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0: if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
print(
f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): no tool calls; score=0.0",
flush=True,
)
return 0.0, { return 0.0, {
"verification_mode": "dataset_tests", "verification_mode": "dataset_tests",
"error": "No tool calls were made by the agent", "error": "No tool calls were made by the agent",