hermes-agent/agent/turn_context.py
teknium1 54870847cb refactor(agent): extract run_conversation prologue into agent/turn_context.py
Phase 1 of the god-file decomposition plan. run_conversation's ~470-line
once-per-turn setup block (stdio guarding, retry-counter resets, user-message
sanitization, todo/nudge hydration, system-prompt restore-or-build,
crash-resilience persistence, preflight compression, the pre_llm_call hook, and
external-memory prefetch) is moved verbatim into build_turn_context(), which
returns a TurnContext dataclass the loop unpacks.

Behavior-neutral move-and-name refactor: the builder mutates `agent` exactly as
the inline code did; only the locals the loop reads back are returned.

- run_conversation: 4602 -> 4217 LOC (-385)
- agent/conversation_loop.py: 4965 -> ~4580 LOC
- new agent/turn_context.py: focused, dependency-injected, unit-tested in isolation

Tests: tests/run_agent/ 1570 passed / 0 failed under per-file process isolation.
Relocation follow-ups: 413_compression mocks now patch both module references;
nudge/on_turn_start source-inspection guards point at the extracted module.
2026-06-07 22:17:35 -07:00

388 lines
16 KiB
Python

"""Per-turn setup for ``run_conversation`` (the turn prologue).

``run_conversation`` opened with ~470 lines of straight-line setup before the
tool-calling loop ever started: stdio guarding, runtime-main wiring, retry-counter
resets, user-message sanitization, todo/nudge-counter hydration, system-prompt
restore-or-build, crash-resilience persistence, preflight context compression, the
``pre_llm_call`` plugin hook, and external-memory prefetch.

All of that is *prologue* — it runs once per turn, has no back-references into the
loop, and produces a fixed set of values the loop then consumes. ``TurnContext``
captures those produced values; ``build_turn_context`` performs the setup work and
returns one. ``run_conversation`` is left to unpack the context and run the loop,
shrinking the orchestrator by the full prologue.

The builder still mutates ``agent`` heavily (counters, thread id, cached prompt,
session DB) exactly as the inline code did — those side effects are the point. The
``TurnContext`` it returns carries only the *locals* the loop reads back.

Behavior is identical to the original inline prologue; this is a pure
move-and-name refactor with no semantic change.
"""
from __future__ import annotations
import logging
import threading
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from agent.iteration_budget import IterationBudget
from agent.model_metadata import estimate_request_tokens_rough
logger = logging.getLogger(__name__)
@dataclass
class TurnContext:
    """Bundle of prologue results that the turn loop unpacks and consumes.

    Attributes:
        user_message: Inbound message after sanitization (surrogates
            stripped).
        original_user_message: Clean message kept for transcripts and memory
            queries; never carries nudge injection.
        messages: Working message list for the turn; the loop appends to it.
        conversation_history: Prior history, or ``None`` when preflight
            compression reset it (a new session was created).
        active_system_prompt: Cached system prompt in effect for the turn;
            compression may have rebuilt it.
        effective_task_id: Task identifier used for this turn.
        turn_id: Identifier of this turn.
        current_turn_user_idx: Position of the current user turn inside
            ``messages``.
        should_review_memory: Whether the post-turn memory review should fire.
        plugin_user_context: Context returned by ``pre_llm_call`` plugins, to
            be appended to the user message.
        ext_prefetch_cache: External-memory prefetch result, reused across
            loop iterations.
    """

    # Required fields: always produced by the prologue.
    user_message: str
    original_user_message: Any
    messages: List[Dict[str, Any]]
    conversation_history: Optional[List[Dict[str, Any]]]
    active_system_prompt: Optional[str]
    effective_task_id: str
    turn_id: str
    current_turn_user_idx: int

    # Optional fields: default to "nothing to do / nothing contributed".
    should_review_memory: bool = False
    plugin_user_context: str = ""
    ext_prefetch_cache: str = ""
def build_turn_context(
    agent,
    user_message: str,
    system_message: Optional[str],
    conversation_history: Optional[List[Dict[str, Any]]],
    task_id: Optional[str],
    stream_callback,
    persist_user_message: Optional[str],
    *,
    restore_or_build_system_prompt,
    install_safe_stdio,
    sanitize_surrogates,
    summarize_user_message_for_log,
    set_session_context,
    set_current_write_origin,
    ra,
) -> TurnContext:
    """Run the once-per-turn setup and return the loop's input context.

    The callables/helpers the original prologue referenced from the
    ``conversation_loop`` module are passed in explicitly to keep this module
    free of an import cycle with ``agent.conversation_loop``.

    The steps below run in a fixed order and mutate ``agent`` in place; only
    the locals the loop reads back are returned. Steps wrapped in
    ``try``/``except Exception`` are best-effort: a failure is logged and the
    prologue continues. Every other step lets its exception propagate.

    Args:
        agent: Agent whose per-turn state is reset and populated (mutated).
        user_message: Inbound user message; passed through
            ``sanitize_surrogates`` when it is a ``str``.
        system_message: Forwarded to ``restore_or_build_system_prompt`` and
            ``agent._compress_context``.
        conversation_history: Prior messages. Shallow-copied into the working
            list so the caller's list object is not appended to.
        task_id: Task identifier; a fresh ``uuid4`` string is used when falsy.
        stream_callback: Stored on ``agent._stream_callback``.
        persist_user_message: Optional override stored on
            ``agent._persist_user_message_override``; when not ``None`` it
            also becomes ``original_user_message``.
        restore_or_build_system_prompt: Called as
            ``(agent, system_message, conversation_history)`` when
            ``agent._cached_system_prompt`` is ``None``.
        install_safe_stdio: Zero-argument callable invoked first.
        sanitize_surrogates: Callable applied to ``str`` messages.
        summarize_user_message_for_log: Callable producing the text used for
            the log and console previews.
        set_session_context: Called with ``agent.session_id``.
        set_current_write_origin: Called with ``agent._memory_write_origin``
            (``"assistant_tool"`` when the attribute is missing).
        ra: Zero-argument callable returning an object that exposes
            ``_set_interrupt(flag, thread_id)``.

    Returns:
        A ``TurnContext`` carrying the sanitized message, the working message
        list, the (possibly reset) history, the active system prompt, the
        task/turn ids, the user-turn index, the memory-review flag, plugin
        context and the external-memory prefetch result.
    """
    # Guard stdio against OSError from broken pipes (systemd/headless/daemon).
    install_safe_stdio()
    agent._ensure_db_session()

    # Tell auxiliary_client what the live main provider/model are for this turn.
    try:
        from agent.auxiliary_client import set_runtime_main

        set_runtime_main(
            getattr(agent, "provider", "") or "",
            getattr(agent, "model", "") or "",
            base_url=getattr(agent, "base_url", "") or "",
            api_key=getattr(agent, "api_key", "") or "",
            api_mode=getattr(agent, "api_mode", "") or "",
        )
    except Exception:
        # Best-effort wiring: never block the turn, but leave a trace.
        logger.debug("set_runtime_main failed (ignored)", exc_info=True)

    # Tag log records on this thread with the session ID for ``hermes logs``.
    set_session_context(agent.session_id)
    # Bind the skill write-origin ContextVar for this thread.
    set_current_write_origin(getattr(agent, "_memory_write_origin", "assistant_tool"))
    # Restore the primary runtime if the previous turn activated fallback.
    agent._restore_primary_runtime()

    # Sanitize surrogate characters from user input.
    if isinstance(user_message, str):
        user_message = sanitize_surrogates(user_message)
    if isinstance(persist_user_message, str):
        persist_user_message = sanitize_surrogates(persist_user_message)

    # Store stream callback for _interruptible_api_call to pick up.
    agent._stream_callback = stream_callback
    agent._persist_user_message_idx = None
    agent._persist_user_message_override = persist_user_message

    # Generate unique task_id if not provided to isolate VMs between tasks.
    effective_task_id = task_id or str(uuid.uuid4())
    agent._current_task_id = effective_task_id
    turn_id = f"{agent.session_id or 'session'}:{effective_task_id}:{uuid.uuid4().hex[:8]}"
    agent._current_turn_id = turn_id
    agent._current_api_request_id = ""

    # Reset retry counters and iteration budget at the start of each turn.
    agent._invalid_tool_retries = 0
    agent._invalid_json_retries = 0
    agent._empty_content_retries = 0
    agent._incomplete_scratchpad_retries = 0
    agent._codex_incomplete_retries = 0
    agent._thinking_prefill_retries = 0
    agent._post_tool_empty_retried = False
    agent._last_content_with_tools = None
    agent._last_content_tools_all_housekeeping = False
    agent._mute_post_response = False
    agent._unicode_sanitization_passes = 0
    agent._tool_guardrails.reset_for_turn()
    agent._tool_guardrail_halt_decision = None
    agent._vision_supported = True

    # Pre-turn connection health check: clean up dead TCP connections.
    if agent.api_mode != "anthropic_messages":
        try:
            if agent._cleanup_dead_connections():
                agent._emit_status(
                    "🔌 Detected stale connections from a previous provider "
                    "issue — cleaned up automatically. Proceeding with fresh "
                    "connection."
                )
        except Exception:
            # Best-effort health check: a failure here must not abort the turn.
            logger.debug("Pre-turn dead-connection cleanup failed (ignored)", exc_info=True)

    # Replay compression warning through status_callback for gateway platforms.
    if agent._compression_warning:
        agent._replay_compression_warning()
        agent._compression_warning = None  # send once

    # NOTE: _turns_since_memory and _iters_since_skill are NOT reset here.
    agent.iteration_budget = IterationBudget(agent.max_iterations)

    # Log conversation turn start for debugging/observability.
    _preview_text = summarize_user_message_for_log(user_message)
    _msg_preview = (_preview_text[:80] + "...") if len(_preview_text) > 80 else _preview_text
    _msg_preview = _msg_preview.replace("\n", " ")
    logger.info(
        "conversation turn: session=%s model=%s provider=%s platform=%s history=%d msg=%r",
        agent.session_id or "none", agent.model, agent.provider or "unknown",
        agent.platform or "unknown", len(conversation_history or []),
        _msg_preview,
    )

    # Initialize conversation (copy to avoid mutating the caller's list).
    messages = list(conversation_history) if conversation_history else []

    # Hydrate todo store from conversation history.
    if conversation_history and not agent._todo_store.has_items():
        agent._hydrate_todo_store(conversation_history)

    # Hydrate per-session nudge counters from persisted history (issue #22357).
    if conversation_history and agent._user_turn_count == 0:
        prior_user_turns = sum(
            1 for m in conversation_history if m.get("role") == "user"
        )
        if prior_user_turns > 0:
            agent._user_turn_count = prior_user_turns
            if agent._memory_nudge_interval > 0 and agent._turns_since_memory == 0:
                agent._turns_since_memory = prior_user_turns % agent._memory_nudge_interval

    # Track user turns for memory flush and periodic nudge logic.
    agent._user_turn_count += 1

    # Reset the streaming context scrubber at the top of each turn.
    scrubber = getattr(agent, "_stream_context_scrubber", None)
    if scrubber is not None:
        scrubber.reset()
    # Reset the think scrubber for the same reason.
    think_scrubber = getattr(agent, "_stream_think_scrubber", None)
    if think_scrubber is not None:
        think_scrubber.reset()

    # Preserve the original user message (no nudge injection).
    original_user_message = persist_user_message if persist_user_message is not None else user_message

    # Track memory nudge trigger (turn-based, checked here).
    should_review_memory = False
    if (agent._memory_nudge_interval > 0
            and "memory" in agent.valid_tool_names
            and agent._memory_store):
        agent._turns_since_memory += 1
        if agent._turns_since_memory >= agent._memory_nudge_interval:
            should_review_memory = True
            agent._turns_since_memory = 0

    # Add user message.
    user_msg = {"role": "user", "content": user_message}
    messages.append(user_msg)
    current_turn_user_idx = len(messages) - 1
    agent._persist_user_message_idx = current_turn_user_idx

    if not agent.quiet_mode:
        _print_preview = summarize_user_message_for_log(user_message)
        agent._safe_print(
            f"💬 Starting conversation: '{_print_preview[:60]}"
            f"{'...' if len(_print_preview) > 60 else ''}'"
        )

    # ── System prompt (cached per session for prefix caching) ──
    if agent._cached_system_prompt is None:
        restore_or_build_system_prompt(agent, system_message, conversation_history)
    active_system_prompt = agent._cached_system_prompt

    # Crash-resilience: persist the inbound user turn as soon as the session row exists.
    try:
        agent._persist_session(messages, conversation_history)
    except Exception:
        logger.warning(
            "Early turn-start session persistence failed for session=%s",
            agent.session_id or "none",
            exc_info=True,
        )

    # ── Preflight context compression ──
    # Only considered when there are more messages than the compressor's
    # protected head + tail (+1).
    if (
        agent.compression_enabled
        and len(messages) > agent.context_compressor.protect_first_n
        + agent.context_compressor.protect_last_n + 1
    ):
        _preflight_tokens = estimate_request_tokens_rough(
            messages,
            system_prompt=active_system_prompt or "",
            tools=agent.tools or None,
        )
        _compressor = agent.context_compressor
        # Compressors without the deferral hook never defer.
        _defer_preflight = getattr(
            _compressor,
            "should_defer_preflight_to_real_usage",
            lambda _tokens: False,
        )
        _preflight_deferred = _defer_preflight(_preflight_tokens)
        if not _preflight_deferred:
            _last = _compressor.last_prompt_tokens
            # Do NOT overwrite the -1 sentinel (#36718).
            if _last >= 0 and _preflight_tokens > _last:
                _compressor.last_prompt_tokens = _preflight_tokens
        if _preflight_deferred:
            logger.info(
                "Skipping preflight compression: rough estimate ~%s >= %s, "
                "but last real provider prompt was %s after compression",
                f"{_preflight_tokens:,}",
                f"{_compressor.threshold_tokens:,}",
                f"{_compressor.last_real_prompt_tokens:,}",
            )
        elif _compressor.should_compress(_preflight_tokens):
            logger.info(
                "Preflight compression: ~%s tokens >= %s threshold (model %s, ctx %s)",
                f"{_preflight_tokens:,}",
                f"{_compressor.threshold_tokens:,}",
                agent.model,
                f"{_compressor.context_length:,}",
            )
            agent._emit_status(
                f"📦 Preflight compression: ~{_preflight_tokens:,} tokens "
                f">= {_compressor.threshold_tokens:,} threshold. "
                "This may take a moment."
            )
            # At most three passes; stop early when a pass does not shrink the
            # list or the estimate no longer calls for compression.
            # NOTE(review): current_turn_user_idx and
            # agent._persist_user_message_idx were computed before this loop
            # and are not recomputed here — confirm the compressed list keeps
            # the user turn at that index or that the loop accounts for it.
            for _pass in range(3):
                _orig_len = len(messages)
                messages, active_system_prompt = agent._compress_context(
                    messages, system_message, approx_tokens=_preflight_tokens,
                    task_id=effective_task_id,
                )
                if len(messages) >= _orig_len:
                    break  # Cannot compress further
                conversation_history = None
                agent._empty_content_retries = 0
                agent._thinking_prefill_retries = 0
                agent._last_content_with_tools = None
                agent._last_content_tools_all_housekeeping = False
                agent._mute_post_response = False
                _preflight_tokens = estimate_request_tokens_rough(
                    messages,
                    system_prompt=active_system_prompt or "",
                    tools=agent.tools or None,
                )
                if not _compressor.should_compress(_preflight_tokens):
                    break

    # Plugin hook: pre_llm_call (context injected into user message, not system prompt).
    # NOTE(review): conversation_history is None after a successful preflight
    # compression pass, so is_first_turn is reported as True in that case even
    # when history was supplied — confirm this is intended.
    plugin_user_context = ""
    try:
        from hermes_cli.plugins import invoke_hook as _invoke_hook

        _pre_results = _invoke_hook(
            "pre_llm_call",
            session_id=agent.session_id,
            task_id=effective_task_id,
            turn_id=turn_id,
            user_message=original_user_message,
            conversation_history=list(messages),
            is_first_turn=(not bool(conversation_history)),
            model=agent.model,
            platform=getattr(agent, "platform", None) or "",
            sender_id=getattr(agent, "_user_id", None) or "",
        )
        # Accept either {"context": ...} dicts or non-blank strings.
        _ctx_parts: list[str] = []
        for r in _pre_results:
            if isinstance(r, dict) and r.get("context"):
                _ctx_parts.append(str(r["context"]))
            elif isinstance(r, str) and r.strip():
                _ctx_parts.append(r)
        if _ctx_parts:
            plugin_user_context = "\n\n".join(_ctx_parts)
    except Exception as exc:
        logger.warning("pre_llm_call hook failed: %s", exc)

    # Per-turn file-mutation verifier state.
    agent._turn_failed_file_mutations = {}

    # Record the execution thread so interrupt()/clear_interrupt() can scope
    # the tool-level interrupt signal to THIS agent's thread only.
    agent._execution_thread_id = threading.current_thread().ident

    # Clear stale per-thread interrupt state, preserving a pending interrupt.
    ra()._set_interrupt(False, agent._execution_thread_id)
    if agent._interrupt_requested:
        ra()._set_interrupt(True, agent._execution_thread_id)
        agent._interrupt_thread_signal_pending = False
    else:
        agent._interrupt_message = None
        agent._interrupt_thread_signal_pending = False

    # Notify memory providers of the new turn (BEFORE prefetch_all).
    if agent._memory_manager:
        try:
            _turn_msg = original_user_message if isinstance(original_user_message, str) else ""
            agent._memory_manager.on_turn_start(agent._user_turn_count, _turn_msg)
        except Exception:
            # Best-effort notification: keep going, but record the failure.
            logger.debug("Memory manager on_turn_start failed (ignored)", exc_info=True)

    # External memory provider: prefetch once before the tool loop.
    ext_prefetch_cache = ""
    if agent._memory_manager:
        try:
            _query = original_user_message if isinstance(original_user_message, str) else ""
            ext_prefetch_cache = agent._memory_manager.prefetch_all(_query) or ""
        except Exception:
            # Best-effort prefetch: the turn proceeds with an empty cache.
            logger.debug("External memory prefetch failed (ignored)", exc_info=True)

    return TurnContext(
        user_message=user_message,
        original_user_message=original_user_message,
        messages=messages,
        conversation_history=conversation_history,
        active_system_prompt=active_system_prompt,
        effective_task_id=effective_task_id,
        turn_id=turn_id,
        current_turn_user_idx=current_turn_user_idx,
        should_review_memory=should_review_memory,
        plugin_user_context=plugin_user_context,
        ext_prefetch_cache=ext_prefetch_cache,
    )