mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
refactor(run_agent): extract 10 more helpers to agent/agent_runtime_helpers.py
Final extraction pass — the methods left over after run_conversation and __init__ moved out. Together these 10 cover ~813 LOC of medium- sized helpers: * switch_model (194 LOC) — model switching mid-session * _invoke_tool (87) — central tool dispatch with overrides * _repair_tool_call (72) — argument JSON repair entrypoint * _sanitize_api_messages (71) — role-filter for API send * _looks_like_codex_intermediate_ack (72) — codex transcript heuristic * _copy_reasoning_content_for_api (70) — reasoning preservation * _cleanup_dead_connections (70) — periodic dead-socket sweep * _extract_api_error_context (65) — error-dump context builder * _apply_pending_steer_to_tool_results (63) — /steer injection * _force_close_tcp_sockets (59) — aggressive socket cleanup AIAgent keeps thin forwarder methods for all 10 (staticmethods preserved where present). Names tests patch on run_agent (handle_function_call, AIAgent class attrs, logger) routed through _ra() so the patch surface is preserved. tests/run_agent/ + tests/agent/: 4313 passed (same pre-existing test_auxiliary_client failure as on main). run_agent.py: 4634 -> 3821 lines (-813). Final total: 16083 -> 3821 (-12262, 76% reduction).
This commit is contained in:
parent
9f408989c4
commit
94c3e0ab8e
2 changed files with 890 additions and 804 deletions
823
run_agent.py
823
run_agent.py
|
|
@ -492,198 +492,9 @@ class AIAgent:
|
|||
logger.debug("LM Studio preload skipped: %s", err)
|
||||
|
||||
def switch_model(self, new_model, new_provider, api_key='', base_url='', api_mode=''):
|
||||
"""Switch the model/provider in-place for a live agent.
|
||||
|
||||
Called by the /model command handlers (CLI and gateway) after
|
||||
``model_switch.switch_model()`` has resolved credentials and
|
||||
validated the model. This method performs the actual runtime
|
||||
swap: rebuilding clients, updating caching flags, and refreshing
|
||||
the context compressor.
|
||||
|
||||
The implementation mirrors ``_try_activate_fallback()`` for the
|
||||
client-swap logic but also updates ``_primary_runtime`` so the
|
||||
change persists across turns (unlike fallback which is
|
||||
turn-scoped).
|
||||
"""
|
||||
from hermes_cli.providers import determine_api_mode
|
||||
|
||||
# ── Determine api_mode if not provided ──
|
||||
if not api_mode:
|
||||
api_mode = determine_api_mode(new_provider, base_url)
|
||||
|
||||
# Defense-in-depth: ensure OpenCode base_url doesn't carry a trailing
|
||||
# /v1 into the anthropic_messages client, which would cause the SDK to
|
||||
# hit /v1/v1/messages. `model_switch.switch_model()` already strips
|
||||
# this, but we guard here so any direct callers (future code paths,
|
||||
# tests) can't reintroduce the double-/v1 404 bug.
|
||||
if (
|
||||
api_mode == "anthropic_messages"
|
||||
and new_provider in {"opencode-zen", "opencode-go"}
|
||||
and isinstance(base_url, str)
|
||||
and base_url
|
||||
):
|
||||
base_url = re.sub(r"/v1/?$", "", base_url)
|
||||
|
||||
old_model = self.model
|
||||
old_provider = self.provider
|
||||
|
||||
# Clear the per-config context_length override so the new model's
|
||||
# actual context window is resolved via get_model_context_length()
|
||||
# instead of inheriting the stale value from the previous model.
|
||||
self._config_context_length = None
|
||||
|
||||
# ── Swap core runtime fields ──
|
||||
self.model = new_model
|
||||
self.provider = new_provider
|
||||
# Use new base_url when provided; only fall back to current when the
|
||||
# new provider genuinely has no endpoint (e.g. native SDK providers).
|
||||
# Without this guard the old provider's URL (e.g. Ollama's localhost
|
||||
# address) would persist silently after switching to a cloud provider
|
||||
# that returns an empty base_url string.
|
||||
if base_url:
|
||||
self.base_url = base_url
|
||||
self.api_mode = api_mode
|
||||
# Invalidate transport cache — new api_mode may need a different transport
|
||||
if hasattr(self, "_transport_cache"):
|
||||
self._transport_cache.clear()
|
||||
if api_key:
|
||||
self.api_key = api_key
|
||||
|
||||
# ── Build new client ──
|
||||
if api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import (
|
||||
build_anthropic_client,
|
||||
resolve_anthropic_token,
|
||||
_is_oauth_token,
|
||||
)
|
||||
# Only fall back to ANTHROPIC_TOKEN when the provider is actually Anthropic.
|
||||
# Other anthropic_messages providers (MiniMax, Alibaba, etc.) must use their own
|
||||
# API key — falling back would send Anthropic credentials to third-party endpoints.
|
||||
_is_native_anthropic = new_provider == "anthropic"
|
||||
effective_key = (api_key or self.api_key or resolve_anthropic_token() or "") if _is_native_anthropic else (api_key or self.api_key or "")
|
||||
self.api_key = effective_key
|
||||
self._anthropic_api_key = effective_key
|
||||
self._anthropic_base_url = base_url or getattr(self, "_anthropic_base_url", None)
|
||||
self._anthropic_client = build_anthropic_client(
|
||||
effective_key, self._anthropic_base_url,
|
||||
timeout=get_provider_request_timeout(self.provider, self.model),
|
||||
)
|
||||
self._is_anthropic_oauth = _is_oauth_token(effective_key) if _is_native_anthropic else False
|
||||
self.client = None
|
||||
self._client_kwargs = {}
|
||||
else:
|
||||
effective_key = api_key or self.api_key
|
||||
effective_base = base_url or self.base_url
|
||||
self._client_kwargs = {
|
||||
"api_key": effective_key,
|
||||
"base_url": effective_base,
|
||||
}
|
||||
_sm_timeout = get_provider_request_timeout(self.provider, self.model)
|
||||
if _sm_timeout is not None:
|
||||
self._client_kwargs["timeout"] = _sm_timeout
|
||||
self.client = self._create_openai_client(
|
||||
dict(self._client_kwargs),
|
||||
reason="switch_model",
|
||||
shared=True,
|
||||
)
|
||||
|
||||
# ── Re-evaluate prompt caching ──
|
||||
self._use_prompt_caching, self._use_native_cache_layout = (
|
||||
self._anthropic_prompt_cache_policy(
|
||||
provider=new_provider,
|
||||
base_url=self.base_url,
|
||||
api_mode=api_mode,
|
||||
model=new_model,
|
||||
)
|
||||
)
|
||||
|
||||
# ── LM Studio: preload before probing context length ──
|
||||
self._ensure_lmstudio_runtime_loaded()
|
||||
|
||||
# ── Update context compressor ──
|
||||
if hasattr(self, "context_compressor") and self.context_compressor:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
# Re-read custom_providers from live config so per-model
|
||||
# context_length overrides are honored when switching to a
|
||||
# custom provider mid-session (closes #15779).
|
||||
_sm_custom_providers = None
|
||||
try:
|
||||
from hermes_cli.config import load_config, get_compatible_custom_providers
|
||||
_sm_cfg = load_config()
|
||||
_sm_custom_providers = get_compatible_custom_providers(_sm_cfg)
|
||||
except Exception:
|
||||
_sm_custom_providers = None
|
||||
new_context_length = get_model_context_length(
|
||||
self.model,
|
||||
base_url=self.base_url,
|
||||
api_key=self.api_key,
|
||||
provider=self.provider,
|
||||
config_context_length=getattr(self, "_config_context_length", None),
|
||||
custom_providers=_sm_custom_providers,
|
||||
)
|
||||
self.context_compressor.update_model(
|
||||
model=self.model,
|
||||
context_length=new_context_length,
|
||||
base_url=self.base_url,
|
||||
api_key=getattr(self, "api_key", ""),
|
||||
provider=self.provider,
|
||||
api_mode=self.api_mode,
|
||||
)
|
||||
|
||||
# ── Invalidate cached system prompt so it rebuilds next turn ──
|
||||
self._cached_system_prompt = None
|
||||
|
||||
# ── Update _primary_runtime so the change persists across turns ──
|
||||
_cc = self.context_compressor if hasattr(self, "context_compressor") and self.context_compressor else None
|
||||
self._primary_runtime = {
|
||||
"model": self.model,
|
||||
"provider": self.provider,
|
||||
"base_url": self.base_url,
|
||||
"api_mode": self.api_mode,
|
||||
"api_key": getattr(self, "api_key", ""),
|
||||
"client_kwargs": dict(self._client_kwargs),
|
||||
"use_prompt_caching": self._use_prompt_caching,
|
||||
"use_native_cache_layout": self._use_native_cache_layout,
|
||||
"compressor_model": getattr(_cc, "model", self.model) if _cc else self.model,
|
||||
"compressor_base_url": getattr(_cc, "base_url", self.base_url) if _cc else self.base_url,
|
||||
"compressor_api_key": getattr(_cc, "api_key", "") if _cc else "",
|
||||
"compressor_provider": getattr(_cc, "provider", self.provider) if _cc else self.provider,
|
||||
"compressor_context_length": _cc.context_length if _cc else 0,
|
||||
"compressor_threshold_tokens": _cc.threshold_tokens if _cc else 0,
|
||||
}
|
||||
if api_mode == "anthropic_messages":
|
||||
self._primary_runtime.update({
|
||||
"anthropic_api_key": self._anthropic_api_key,
|
||||
"anthropic_base_url": self._anthropic_base_url,
|
||||
"is_anthropic_oauth": self._is_anthropic_oauth,
|
||||
})
|
||||
|
||||
# ── Reset fallback state ──
|
||||
self._fallback_activated = False
|
||||
self._fallback_index = 0
|
||||
|
||||
# When the user deliberately swaps primary providers (e.g. openrouter
|
||||
# → anthropic), drop any fallback entries that target the OLD primary
|
||||
# or the NEW one. The chain was seeded from config at agent init for
|
||||
# the original provider — without pruning, a failed turn on the new
|
||||
# primary silently re-activates the provider the user just rejected,
|
||||
# which is exactly what was reported during TUI v2 blitz testing
|
||||
# ("switched to anthropic, tui keeps trying openrouter").
|
||||
old_norm = (old_provider or "").strip().lower()
|
||||
new_norm = (new_provider or "").strip().lower()
|
||||
fallback_chain = list(getattr(self, "_fallback_chain", []) or [])
|
||||
if old_norm and new_norm and old_norm != new_norm:
|
||||
fallback_chain = [
|
||||
entry for entry in fallback_chain
|
||||
if (entry.get("provider") or "").strip().lower() not in {old_norm, new_norm}
|
||||
]
|
||||
self._fallback_chain = fallback_chain
|
||||
self._fallback_model = fallback_chain[0] if fallback_chain else None
|
||||
|
||||
logging.info(
|
||||
"Model switched in-place: %s (%s) -> %s (%s)",
|
||||
old_model, old_provider, new_model, new_provider,
|
||||
)
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.switch_model``."""
|
||||
from agent.agent_runtime_helpers import switch_model
|
||||
return switch_model(self, new_model, new_provider, api_key, base_url, api_mode)
|
||||
|
||||
def _safe_print(self, *args, **kwargs):
|
||||
"""Print that silently handles broken pipes / closed stdout.
|
||||
|
|
@ -1134,71 +945,9 @@ class AIAgent:
|
|||
assistant_content: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> bool:
|
||||
"""Detect a planning/ack message that should continue instead of ending the turn."""
|
||||
if any(isinstance(msg, dict) and msg.get("role") == "tool" for msg in messages):
|
||||
return False
|
||||
|
||||
assistant_text = self._strip_think_blocks(assistant_content or "").strip().lower()
|
||||
if not assistant_text:
|
||||
return False
|
||||
if len(assistant_text) > 1200:
|
||||
return False
|
||||
|
||||
has_future_ack = bool(
|
||||
re.search(r"\b(i['’]ll|i will|let me|i can do that|i can help with that)\b", assistant_text)
|
||||
)
|
||||
if not has_future_ack:
|
||||
return False
|
||||
|
||||
action_markers = (
|
||||
"look into",
|
||||
"look at",
|
||||
"inspect",
|
||||
"scan",
|
||||
"check",
|
||||
"analyz",
|
||||
"review",
|
||||
"explore",
|
||||
"read",
|
||||
"open",
|
||||
"run",
|
||||
"test",
|
||||
"fix",
|
||||
"debug",
|
||||
"search",
|
||||
"find",
|
||||
"walkthrough",
|
||||
"report back",
|
||||
"summarize",
|
||||
)
|
||||
workspace_markers = (
|
||||
"directory",
|
||||
"current directory",
|
||||
"current dir",
|
||||
"cwd",
|
||||
"repo",
|
||||
"repository",
|
||||
"codebase",
|
||||
"project",
|
||||
"folder",
|
||||
"filesystem",
|
||||
"file tree",
|
||||
"files",
|
||||
"path",
|
||||
)
|
||||
|
||||
user_text = (user_message or "").strip().lower()
|
||||
user_targets_workspace = (
|
||||
any(marker in user_text for marker in workspace_markers)
|
||||
or "~/" in user_text
|
||||
or "/" in user_text
|
||||
)
|
||||
assistant_mentions_action = any(marker in assistant_text for marker in action_markers)
|
||||
assistant_targets_workspace = any(
|
||||
marker in assistant_text for marker in workspace_markers
|
||||
)
|
||||
return (user_targets_workspace or assistant_targets_workspace) and assistant_mentions_action
|
||||
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.looks_like_codex_intermediate_ack``."""
|
||||
from agent.agent_runtime_helpers import looks_like_codex_intermediate_ack
|
||||
return looks_like_codex_intermediate_ack(self, user_message, assistant_content, messages)
|
||||
|
||||
def _extract_reasoning(self, assistant_message) -> Optional[str]:
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.extract_reasoning``."""
|
||||
|
|
@ -1547,68 +1296,9 @@ class AIAgent:
|
|||
|
||||
@staticmethod
|
||||
def _extract_api_error_context(error: Exception) -> Dict[str, Any]:
|
||||
"""Extract structured rate-limit details from provider errors."""
|
||||
context: Dict[str, Any] = {}
|
||||
|
||||
body = getattr(error, "body", None)
|
||||
payload = None
|
||||
if isinstance(body, dict):
|
||||
payload = body.get("error") if isinstance(body.get("error"), dict) else body
|
||||
if isinstance(payload, dict):
|
||||
reason = payload.get("code") or payload.get("error")
|
||||
if isinstance(reason, str) and reason.strip():
|
||||
context["reason"] = reason.strip()
|
||||
message = payload.get("message") or payload.get("error_description")
|
||||
if isinstance(message, str) and message.strip():
|
||||
context["message"] = message.strip()
|
||||
for key in ("resets_at", "reset_at"):
|
||||
value = payload.get(key)
|
||||
if value not in {None, ""}:
|
||||
context["reset_at"] = value
|
||||
break
|
||||
retry_after = payload.get("retry_after")
|
||||
if retry_after not in {None, ""} and "reset_at" not in context:
|
||||
try:
|
||||
context["reset_at"] = time.time() + float(retry_after)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
response = getattr(error, "response", None)
|
||||
headers = getattr(response, "headers", None)
|
||||
if headers:
|
||||
retry_after = headers.get("retry-after") or headers.get("Retry-After")
|
||||
if retry_after and "reset_at" not in context:
|
||||
try:
|
||||
context["reset_at"] = time.time() + float(retry_after)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
ratelimit_reset = headers.get("x-ratelimit-reset")
|
||||
if ratelimit_reset and "reset_at" not in context:
|
||||
context["reset_at"] = ratelimit_reset
|
||||
|
||||
if "message" not in context:
|
||||
raw_message = str(error).strip()
|
||||
if raw_message:
|
||||
context["message"] = raw_message[:500]
|
||||
|
||||
if "reset_at" not in context:
|
||||
message = context.get("message") or ""
|
||||
if isinstance(message, str):
|
||||
delay_match = re.search(r"quotaResetDelay[:\s\"]+(\\d+(?:\\.\\d+)?)(ms|s)", message, re.IGNORECASE)
|
||||
if delay_match:
|
||||
value = float(delay_match.group(1))
|
||||
seconds = value / 1000.0 if delay_match.group(2).lower() == "ms" else value
|
||||
context["reset_at"] = time.time() + seconds
|
||||
else:
|
||||
sec_match = re.search(
|
||||
r"retry\s+(?:after\s+)?(\d+(?:\.\d+)?)\s*(?:sec|secs|seconds|s\b)",
|
||||
message,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if sec_match:
|
||||
context["reset_at"] = time.time() + float(sec_match.group(1))
|
||||
|
||||
return context
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.extract_api_error_context``."""
|
||||
from agent.agent_runtime_helpers import extract_api_error_context
|
||||
return extract_api_error_context(error)
|
||||
|
||||
def _usage_summary_for_api_request_hook(self, response: Any) -> Optional[Dict[str, Any]]:
|
||||
"""Token buckets for ``post_api_request`` plugins (no raw ``response`` object)."""
|
||||
|
|
@ -1965,67 +1655,9 @@ class AIAgent:
|
|||
return "\n".join(lines)
|
||||
|
||||
def _apply_pending_steer_to_tool_results(self, messages: list, num_tool_msgs: int) -> None:
|
||||
"""Append any pending /steer text to the last tool result in this turn.
|
||||
|
||||
Called at the end of a tool-call batch, before the next API call.
|
||||
The steer is appended to the last ``role:"tool"`` message's content
|
||||
with a clear marker so the model understands it came from the user
|
||||
and NOT from the tool itself. Role alternation is preserved —
|
||||
nothing new is inserted, we only modify existing content.
|
||||
|
||||
Args:
|
||||
messages: The running messages list.
|
||||
num_tool_msgs: Number of tool results appended in this batch;
|
||||
used to locate the tail slice safely.
|
||||
"""
|
||||
if num_tool_msgs <= 0 or not messages:
|
||||
return
|
||||
steer_text = self._drain_pending_steer()
|
||||
if not steer_text:
|
||||
return
|
||||
# Find the last tool-role message in the recent tail. Skipping
|
||||
# non-tool messages defends against future code appending
|
||||
# something else at the boundary.
|
||||
target_idx = None
|
||||
for j in range(len(messages) - 1, max(len(messages) - num_tool_msgs - 1, -1), -1):
|
||||
msg = messages[j]
|
||||
if isinstance(msg, dict) and msg.get("role") == "tool":
|
||||
target_idx = j
|
||||
break
|
||||
if target_idx is None:
|
||||
# No tool result in this batch (e.g. all skipped by interrupt);
|
||||
# put the steer back so the caller's fallback path can deliver
|
||||
# it as a normal next-turn user message.
|
||||
_lock = getattr(self, "_pending_steer_lock", None)
|
||||
if _lock is not None:
|
||||
with _lock:
|
||||
if self._pending_steer:
|
||||
self._pending_steer = self._pending_steer + "\n" + steer_text
|
||||
else:
|
||||
self._pending_steer = steer_text
|
||||
else:
|
||||
existing = getattr(self, "_pending_steer", None)
|
||||
self._pending_steer = (existing + "\n" + steer_text) if existing else steer_text
|
||||
return
|
||||
marker = f"\n\nUser guidance: {steer_text}"
|
||||
existing_content = messages[target_idx].get("content", "")
|
||||
if not isinstance(existing_content, str):
|
||||
# Anthropic multimodal content blocks — preserve them and append
|
||||
# a text block at the end.
|
||||
try:
|
||||
blocks = list(existing_content) if existing_content else []
|
||||
blocks.append({"type": "text", "text": marker.lstrip()})
|
||||
messages[target_idx]["content"] = blocks
|
||||
except Exception:
|
||||
# Fall back to string replacement if content shape is unexpected.
|
||||
messages[target_idx]["content"] = f"{existing_content}{marker}"
|
||||
else:
|
||||
messages[target_idx]["content"] = existing_content + marker
|
||||
logger.info(
|
||||
"Delivered /steer to agent after tool batch (%d chars): %s",
|
||||
len(steer_text),
|
||||
steer_text[:120] + ("..." if len(steer_text) > 120 else ""),
|
||||
)
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.apply_pending_steer_to_tool_results``."""
|
||||
from agent.agent_runtime_helpers import apply_pending_steer_to_tool_results
|
||||
return apply_pending_steer_to_tool_results(self, messages, num_tool_msgs)
|
||||
|
||||
def _touch_activity(self, desc: str) -> None:
|
||||
"""Update the last-activity timestamp and description (thread-safe)."""
|
||||
|
|
@ -2383,74 +2015,9 @@ class AIAgent:
|
|||
|
||||
@staticmethod
|
||||
def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Fix orphaned tool_call / tool_result pairs before every LLM call.
|
||||
|
||||
Runs unconditionally — not gated on whether the context compressor
|
||||
is present — so orphans from session loading or manual message
|
||||
manipulation are always caught.
|
||||
"""
|
||||
# --- Role allowlist: drop messages with roles the API won't accept ---
|
||||
filtered = []
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
if role not in AIAgent._VALID_API_ROLES:
|
||||
logger.debug(
|
||||
"Pre-call sanitizer: dropping message with invalid role %r",
|
||||
role,
|
||||
)
|
||||
continue
|
||||
filtered.append(msg)
|
||||
messages = filtered
|
||||
|
||||
surviving_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = AIAgent._get_tool_call_id_static(tc)
|
||||
if cid:
|
||||
surviving_call_ids.add(cid)
|
||||
|
||||
result_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
cid = msg.get("tool_call_id")
|
||||
if cid:
|
||||
result_call_ids.add(cid)
|
||||
|
||||
# 1. Drop tool results with no matching assistant call
|
||||
orphaned_results = result_call_ids - surviving_call_ids
|
||||
if orphaned_results:
|
||||
messages = [
|
||||
m for m in messages
|
||||
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
]
|
||||
logger.debug(
|
||||
"Pre-call sanitizer: removed %d orphaned tool result(s)",
|
||||
len(orphaned_results),
|
||||
)
|
||||
|
||||
# 2. Inject stub results for calls whose result was dropped
|
||||
missing_results = surviving_call_ids - result_call_ids
|
||||
if missing_results:
|
||||
patched: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = AIAgent._get_tool_call_id_static(tc)
|
||||
if cid in missing_results:
|
||||
patched.append({
|
||||
"role": "tool",
|
||||
"name": AIAgent._get_tool_call_name_static(tc),
|
||||
"content": "[Result unavailable — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
})
|
||||
messages = patched
|
||||
logger.debug(
|
||||
"Pre-call sanitizer: added %d stub tool result(s)",
|
||||
len(missing_results),
|
||||
)
|
||||
return messages
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.sanitize_api_messages``."""
|
||||
from agent.agent_runtime_helpers import sanitize_api_messages
|
||||
return sanitize_api_messages(messages)
|
||||
|
||||
@staticmethod
|
||||
def _is_thinking_only_assistant(msg: Dict[str, Any]) -> bool:
|
||||
|
|
@ -2564,76 +2131,9 @@ class AIAgent:
|
|||
return unique if len(unique) < len(tool_calls) else tool_calls
|
||||
|
||||
def _repair_tool_call(self, tool_name: str) -> str | None:
|
||||
"""Attempt to repair a mismatched tool name before aborting.
|
||||
|
||||
Models sometimes emit variants of a tool name that differ only
|
||||
in casing, separators, or class-like suffixes. Normalize
|
||||
aggressively before falling back to fuzzy match:
|
||||
|
||||
1. Lowercase direct match.
|
||||
2. Lowercase + hyphens/spaces -> underscores.
|
||||
3. CamelCase -> snake_case (TodoTool -> todo_tool).
|
||||
4. Strip trailing ``_tool`` / ``-tool`` / ``tool`` suffix that
|
||||
Claude-style models sometimes tack on (TodoTool_tool ->
|
||||
TodoTool -> Todo -> todo). Applied twice so double-tacked
|
||||
suffixes like ``TodoTool_tool`` reduce all the way.
|
||||
5. Fuzzy match (difflib, cutoff=0.7).
|
||||
|
||||
See #14784 for the original reports (TodoTool_tool, Patch_tool,
|
||||
BrowserClick_tool were all returning "Unknown tool" before).
|
||||
|
||||
Returns the repaired name if found in valid_tool_names, else None.
|
||||
"""
|
||||
import re
|
||||
from difflib import get_close_matches
|
||||
|
||||
if not tool_name:
|
||||
return None
|
||||
|
||||
def _norm(s: str) -> str:
|
||||
return s.lower().replace("-", "_").replace(" ", "_")
|
||||
|
||||
def _camel_snake(s: str) -> str:
|
||||
return re.sub(r"(?<!^)(?=[A-Z])", "_", s).lower()
|
||||
|
||||
def _strip_tool_suffix(s: str) -> str | None:
|
||||
lc = s.lower()
|
||||
for suffix in ("_tool", "-tool", "tool"):
|
||||
if lc.endswith(suffix):
|
||||
return s[: -len(suffix)].rstrip("_-")
|
||||
return None
|
||||
|
||||
# Cheap fast-paths first — these cover the common case.
|
||||
lowered = tool_name.lower()
|
||||
if lowered in self.valid_tool_names:
|
||||
return lowered
|
||||
normalized = _norm(tool_name)
|
||||
if normalized in self.valid_tool_names:
|
||||
return normalized
|
||||
|
||||
# Build the full candidate set for class-like emissions.
|
||||
cands: set[str] = {tool_name, lowered, normalized, _camel_snake(tool_name)}
|
||||
# Strip trailing tool-suffix up to twice — TodoTool_tool needs it.
|
||||
for _ in range(2):
|
||||
extra: set[str] = set()
|
||||
for c in cands:
|
||||
stripped = _strip_tool_suffix(c)
|
||||
if stripped:
|
||||
extra.add(stripped)
|
||||
extra.add(_norm(stripped))
|
||||
extra.add(_camel_snake(stripped))
|
||||
cands |= extra
|
||||
|
||||
for c in cands:
|
||||
if c and c in self.valid_tool_names:
|
||||
return c
|
||||
|
||||
# Fuzzy match as last resort.
|
||||
matches = get_close_matches(lowered, self.valid_tool_names, n=1, cutoff=0.7)
|
||||
if matches:
|
||||
return matches[0]
|
||||
|
||||
return None
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.repair_tool_call``."""
|
||||
from agent.agent_runtime_helpers import repair_tool_call
|
||||
return repair_tool_call(self, tool_name)
|
||||
|
||||
def _invalidate_system_prompt(self):
|
||||
"""Forwarder — see ``agent.system_prompt.invalidate_system_prompt``."""
|
||||
|
|
@ -2745,62 +2245,9 @@ class AIAgent:
|
|||
|
||||
@staticmethod
|
||||
def _force_close_tcp_sockets(client: Any) -> int:
|
||||
"""Force-close underlying TCP sockets to prevent CLOSE-WAIT accumulation.
|
||||
|
||||
When a provider drops a connection mid-stream, httpx's ``client.close()``
|
||||
performs a graceful shutdown which leaves sockets in CLOSE-WAIT until the
|
||||
OS times them out (often minutes). This method walks the httpx transport
|
||||
pool and issues ``socket.shutdown(SHUT_RDWR)`` + ``socket.close()`` to
|
||||
force an immediate TCP RST, freeing the file descriptors.
|
||||
|
||||
Returns the number of sockets force-closed.
|
||||
"""
|
||||
import socket as _socket
|
||||
|
||||
closed = 0
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return 0
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return 0
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return 0
|
||||
# httpx uses httpcore connection pools; connections live in
|
||||
# _connections (list) or _pool (list) depending on version.
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
for conn in list(connections):
|
||||
stream = (
|
||||
getattr(conn, "_network_stream", None)
|
||||
or getattr(conn, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
sock = getattr(stream, "stream", None)
|
||||
if sock is not None:
|
||||
sock = getattr(sock, "_sock", None)
|
||||
if sock is None:
|
||||
continue
|
||||
try:
|
||||
sock.shutdown(_socket.SHUT_RDWR)
|
||||
except OSError:
|
||||
pass
|
||||
try:
|
||||
sock.close()
|
||||
except OSError:
|
||||
pass
|
||||
closed += 1
|
||||
except Exception as exc:
|
||||
logger.debug("Force-close TCP sockets sweep error: %s", exc)
|
||||
return closed
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.force_close_tcp_sockets``."""
|
||||
from agent.agent_runtime_helpers import force_close_tcp_sockets
|
||||
return force_close_tcp_sockets(client)
|
||||
|
||||
def _close_openai_client(self, client: Any, *, reason: str, shared: bool) -> None:
|
||||
if client is None:
|
||||
|
|
@ -2860,74 +2307,9 @@ class AIAgent:
|
|||
return self.client
|
||||
|
||||
def _cleanup_dead_connections(self) -> bool:
|
||||
"""Detect and clean up dead TCP connections on the primary client.
|
||||
|
||||
Inspects the httpx connection pool for sockets in unhealthy states
|
||||
(CLOSE-WAIT, errors). If any are found, force-closes all sockets
|
||||
and rebuilds the primary client from scratch.
|
||||
|
||||
Returns True if dead connections were found and cleaned up.
|
||||
"""
|
||||
client = getattr(self, "client", None)
|
||||
if client is None:
|
||||
return False
|
||||
try:
|
||||
http_client = getattr(client, "_client", None)
|
||||
if http_client is None:
|
||||
return False
|
||||
transport = getattr(http_client, "_transport", None)
|
||||
if transport is None:
|
||||
return False
|
||||
pool = getattr(transport, "_pool", None)
|
||||
if pool is None:
|
||||
return False
|
||||
connections = (
|
||||
getattr(pool, "_connections", None)
|
||||
or getattr(pool, "_pool", None)
|
||||
or []
|
||||
)
|
||||
dead_count = 0
|
||||
for conn in list(connections):
|
||||
# Check for connections that are idle but have closed sockets
|
||||
stream = (
|
||||
getattr(conn, "_network_stream", None)
|
||||
or getattr(conn, "_stream", None)
|
||||
)
|
||||
if stream is None:
|
||||
continue
|
||||
sock = getattr(stream, "_sock", None)
|
||||
if sock is None:
|
||||
sock = getattr(stream, "stream", None)
|
||||
if sock is not None:
|
||||
sock = getattr(sock, "_sock", None)
|
||||
if sock is None:
|
||||
continue
|
||||
# Probe socket health with a non-blocking recv peek
|
||||
import socket as _socket
|
||||
try:
|
||||
sock.setblocking(False)
|
||||
data = sock.recv(1, _socket.MSG_PEEK | _socket.MSG_DONTWAIT)
|
||||
if data == b"":
|
||||
dead_count += 1
|
||||
except BlockingIOError:
|
||||
pass # No data available — socket is healthy
|
||||
except OSError:
|
||||
dead_count += 1
|
||||
finally:
|
||||
try:
|
||||
sock.setblocking(True)
|
||||
except OSError:
|
||||
pass
|
||||
if dead_count > 0:
|
||||
logger.warning(
|
||||
"Found %d dead connection(s) in client pool — rebuilding client",
|
||||
dead_count,
|
||||
)
|
||||
self._replace_primary_openai_client(reason="dead_connection_cleanup")
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.debug("Dead connection check error: %s", exc)
|
||||
return False
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.cleanup_dead_connections``."""
|
||||
from agent.agent_runtime_helpers import cleanup_dead_connections
|
||||
return cleanup_dead_connections(self)
|
||||
|
||||
@staticmethod
|
||||
def _api_kwargs_have_image_parts(api_kwargs: dict) -> bool:
|
||||
|
|
@ -4039,74 +3421,9 @@ class AIAgent:
|
|||
)
|
||||
|
||||
def _copy_reasoning_content_for_api(self, source_msg: dict, api_msg: dict) -> None:
|
||||
"""Copy provider-facing reasoning fields onto an API replay message."""
|
||||
if source_msg.get("role") != "assistant":
|
||||
return
|
||||
|
||||
# 1. Explicit reasoning_content already set — preserve it verbatim
|
||||
# (includes DeepSeek/Kimi's own space-placeholder written at creation
|
||||
# time, and any valid reasoning content from the same provider).
|
||||
#
|
||||
# Exception: sessions persisted BEFORE #17341 have empty-string
|
||||
# placeholders pinned at creation time. DeepSeek V4 Pro rejects
|
||||
# those with HTTP 400. When the active provider enforces the
|
||||
# thinking-mode echo, upgrade "" → " " on replay so stale history
|
||||
# doesn't 400 the user on the next turn.
|
||||
existing = source_msg.get("reasoning_content")
|
||||
if isinstance(existing, str):
|
||||
if existing == "" and self._needs_thinking_reasoning_pad():
|
||||
api_msg["reasoning_content"] = " "
|
||||
else:
|
||||
api_msg["reasoning_content"] = existing
|
||||
return
|
||||
|
||||
needs_thinking_pad = self._needs_thinking_reasoning_pad()
|
||||
|
||||
# 2. Cross-provider poisoned history (#15748): on DeepSeek/Kimi,
|
||||
# if the source turn has tool_calls AND a 'reasoning' field but no
|
||||
# 'reasoning_content' key, the 'reasoning' text was written by a
|
||||
# prior provider (e.g. MiniMax) — DeepSeek's own _build_assistant_message
|
||||
# pins reasoning_content at creation time for tool-call turns, so the
|
||||
# shape (reasoning set, reasoning_content absent, tool_calls present)
|
||||
# is unreachable from same-provider DeepSeek history after this fix.
|
||||
# Inject a single space to satisfy the API without leaking another
|
||||
# provider's chain of thought to DeepSeek/Kimi. Space (not "")
|
||||
# because DeepSeek V4 Pro rejects empty-string reasoning_content
|
||||
# in thinking mode (refs #17341).
|
||||
normalized_reasoning = source_msg.get("reasoning")
|
||||
if (
|
||||
needs_thinking_pad
|
||||
and source_msg.get("tool_calls")
|
||||
and isinstance(normalized_reasoning, str)
|
||||
and normalized_reasoning
|
||||
):
|
||||
api_msg["reasoning_content"] = " "
|
||||
return
|
||||
|
||||
# 3. Healthy session: promote 'reasoning' field to 'reasoning_content'
|
||||
# for providers that use the internal 'reasoning' key.
|
||||
# This must happen before the unconditional empty-string fallback so
|
||||
# genuine reasoning content is not overwritten (#15812 regression in
|
||||
# PR #15478).
|
||||
if isinstance(normalized_reasoning, str) and normalized_reasoning:
|
||||
api_msg["reasoning_content"] = normalized_reasoning
|
||||
return
|
||||
|
||||
# 4. DeepSeek / Kimi thinking mode: all assistant messages need
|
||||
# reasoning_content. Inject a single space to satisfy the provider's
|
||||
# requirement when no explicit reasoning content is present. Covers
|
||||
# both tool-call turns (already-poisoned history with no reasoning
|
||||
# at all) and plain text turns. Space (not "") because DeepSeek V4
|
||||
# Pro tightened validation and rejects empty string with HTTP 400
|
||||
# ("The reasoning content in the thinking mode must be passed back
|
||||
# to the API"). Refs #17341.
|
||||
if needs_thinking_pad:
|
||||
api_msg["reasoning_content"] = " "
|
||||
return
|
||||
|
||||
# 5. reasoning_content was present but not a string (e.g. None after
|
||||
# context compaction). Don't pass null to the API.
|
||||
api_msg.pop("reasoning_content", None)
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.copy_reasoning_content_for_api``."""
|
||||
from agent.agent_runtime_helpers import copy_reasoning_content_for_api
|
||||
return copy_reasoning_content_for_api(self, source_msg, api_msg)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_tool_calls_for_strict_api(api_msg: dict) -> dict:
|
||||
|
|
@ -4251,89 +3568,9 @@ class AIAgent:
|
|||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
|
||||
tool_call_id: Optional[str] = None, messages: list = None,
|
||||
pre_tool_block_checked: bool = False) -> str:
|
||||
"""Invoke a single tool and return the result string. No display logic.
|
||||
|
||||
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
|
||||
tools. Used by the concurrent execution path; the sequential path retains
|
||||
its own inline invocation for backward-compatible display handling.
|
||||
"""
|
||||
# Check plugin hooks for a block directive before executing anything.
|
||||
block_message: Optional[str] = None
|
||||
if not pre_tool_block_checked:
|
||||
try:
|
||||
from hermes_cli.plugins import get_pre_tool_call_block_message
|
||||
block_message = get_pre_tool_call_block_message(
|
||||
function_name, function_args, task_id=effective_task_id or "",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if block_message is not None:
|
||||
return json.dumps({"error": block_message}, ensure_ascii=False)
|
||||
|
||||
if function_name == "todo":
|
||||
from tools.todo_tool import todo_tool as _todo_tool
|
||||
return _todo_tool(
|
||||
todos=function_args.get("todos"),
|
||||
merge=function_args.get("merge", False),
|
||||
store=self._todo_store,
|
||||
)
|
||||
elif function_name == "session_search":
|
||||
session_db = self._get_session_db_for_recall()
|
||||
if not session_db:
|
||||
from hermes_state import format_session_db_unavailable
|
||||
return json.dumps({"success": False, "error": format_session_db_unavailable()})
|
||||
from tools.session_search_tool import session_search as _session_search
|
||||
return _session_search(
|
||||
query=function_args.get("query", ""),
|
||||
role_filter=function_args.get("role_filter"),
|
||||
limit=function_args.get("limit", 3),
|
||||
db=session_db,
|
||||
current_session_id=self.session_id,
|
||||
)
|
||||
elif function_name == "memory":
|
||||
target = function_args.get("target", "memory")
|
||||
from tools.memory_tool import memory_tool as _memory_tool
|
||||
result = _memory_tool(
|
||||
action=function_args.get("action"),
|
||||
target=target,
|
||||
content=function_args.get("content"),
|
||||
old_text=function_args.get("old_text"),
|
||||
store=self._memory_store,
|
||||
)
|
||||
# Bridge: notify external memory provider of built-in memory writes
|
||||
if self._memory_manager and function_args.get("action") in {"add", "replace"}:
|
||||
try:
|
||||
self._memory_manager.on_memory_write(
|
||||
function_args.get("action", ""),
|
||||
target,
|
||||
function_args.get("content", ""),
|
||||
metadata=self._build_memory_write_metadata(
|
||||
task_id=effective_task_id,
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
elif self._memory_manager and self._memory_manager.has_tool(function_name):
|
||||
return self._memory_manager.handle_tool_call(function_name, function_args)
|
||||
elif function_name == "clarify":
|
||||
from tools.clarify_tool import clarify_tool as _clarify_tool
|
||||
return _clarify_tool(
|
||||
question=function_args.get("question", ""),
|
||||
choices=function_args.get("choices"),
|
||||
callback=self.clarify_callback,
|
||||
)
|
||||
elif function_name == "delegate_task":
|
||||
return self._dispatch_delegate_task(function_args)
|
||||
else:
|
||||
return handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
tool_call_id=tool_call_id,
|
||||
session_id=self.session_id or "",
|
||||
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
|
||||
skip_pre_tool_call_hook=True,
|
||||
)
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.invoke_tool``."""
|
||||
from agent.agent_runtime_helpers import invoke_tool
|
||||
return invoke_tool(self, function_name, function_args, effective_task_id, tool_call_id, messages, pre_tool_block_checked)
|
||||
|
||||
@staticmethod
|
||||
def _wrap_verbose(label: str, text: str, indent: str = " ") -> str:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue