From 0ed28ab80cec759fd15ca4a3d251d4fd02ad5bdd Mon Sep 17 00:00:00 2001 From: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com> Date: Thu, 2 Apr 2026 23:41:38 +0530 Subject: [PATCH] refactor: simplify and harden PR fixes after review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix cron ThreadPoolExecutor blocking on timeout: use shutdown(wait=False, cancel_futures=True) instead of context manager that waits indefinitely - Extract _dequeue_pending_text() to deduplicate media-placeholder logic in interrupt and normal-completion dequeue paths - Remove hasattr guards for _running_agents_ts: add class-level default so partial test construction works without scattered defensive checks - Move `import concurrent.futures` to top of cron/scheduler.py - Progress throttle: sleep remaining interval instead of busy-looping 0.1s (~15 wakeups per 1.5s window → 1 wakeup) - Deduplicate _load_stt_config() in transcription_tools.py: _has_openai_audio_backend() now delegates to _resolve_openai_audio_client_config() --- cron/scheduler.py | 35 +++++++++++--------- gateway/run.py | 63 +++++++++++++++++++----------------- tools/transcription_tools.py | 9 +++--- 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/cron/scheduler.py b/cron/scheduler.py index 906953c0a..8a54520a1 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -9,6 +9,7 @@ runs at a time if multiple processes overlap. """ import asyncio +import concurrent.futures import json import logging import os @@ -448,22 +449,24 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: # override via env var. Uses a separate thread because # run_conversation is synchronous. _cron_timeout = float(os.getenv("HERMES_CRON_TIMEOUT", 600)) - import concurrent.futures - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as _cron_pool: - _cron_future = _cron_pool.submit(agent.run_conversation, prompt) - try: - result = _cron_future.result(timeout=_cron_timeout) - except concurrent.futures.TimeoutError: - logger.error( - "Job '%s' timed out after %.0fs — interrupting agent", - job_name, _cron_timeout, - ) - if hasattr(agent, "interrupt"): - agent.interrupt("Cron job timed out") - raise TimeoutError( - f"Cron job '{job_name}' timed out after " - f"{int(_cron_timeout // 60)} minutes" - ) + _cron_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + _cron_future = _cron_pool.submit(agent.run_conversation, prompt) + try: + result = _cron_future.result(timeout=_cron_timeout) + except concurrent.futures.TimeoutError: + logger.error( + "Job '%s' timed out after %.0fs — interrupting agent", + job_name, _cron_timeout, + ) + if hasattr(agent, "interrupt"): + agent.interrupt("Cron job timed out") + _cron_pool.shutdown(wait=False, cancel_futures=True) + raise TimeoutError( + f"Cron job '{job_name}' timed out after " + f"{int(_cron_timeout // 60)} minutes" + ) + finally: + _cron_pool.shutdown(wait=False) final_response = result.get("final_response", "") or "" # Use a separate variable for log display; keep final_response clean diff --git a/gateway/run.py b/gateway/run.py index 593d00583..dfcda3fac 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -325,6 +325,21 @@ def _build_media_placeholder(event) -> str: return "\n".join(parts) +def _dequeue_pending_text(adapter, session_key: str) -> str | None: + """Consume and return the text of a pending queued message. + + Preserves media context for captionless photo/document events by + building a placeholder so the message isn't silently dropped. + """ + event = adapter.get_pending_message(session_key) + if not event: + return None + text = event.text + if not text and getattr(event, "media_urls", None): + text = _build_media_placeholder(event) + return text + + def _check_unavailable_skill(command_name: str) -> str | None: """Check if a command matches a known-but-inactive skill. @@ -431,9 +446,13 @@ def _resolve_hermes_bin() -> Optional[list[str]]: class GatewayRunner: + # Class-level defaults so partial construction in tests doesn't + # blow up on attribute access. + _running_agents_ts: Dict[str, float] = {} + """ Main gateway controller. - + Manages the lifecycle of all platform adapters and routes messages to/from the agent. """ @@ -1726,15 +1745,14 @@ class GatewayRunner: # longer than the agent timeout, it's a leaked lock from a hung or # crashed handler. Evict it so the session isn't permanently stuck. _STALE_TTL = float(os.getenv("HERMES_AGENT_TIMEOUT", 600)) + 60 # timeout + 1 min grace - _ts_dict = getattr(self, "_running_agents_ts", {}) - _stale_ts = _ts_dict.get(_quick_key, 0) + _stale_ts = self._running_agents_ts.get(_quick_key, 0) if _quick_key in self._running_agents and _stale_ts and (time.time() - _stale_ts) > _STALE_TTL: logger.warning( "Evicting stale _running_agents entry for %s (age: %.0fs)", _quick_key[:30], time.time() - _stale_ts, ) del self._running_agents[_quick_key] - _ts_dict.pop(_quick_key, None) + self._running_agents_ts.pop(_quick_key, None) if _quick_key in self._running_agents: if event.get_command() == "status": @@ -2061,8 +2079,7 @@ class GatewayRunner: # "already running" guard and spin up a duplicate agent for the # same session — corrupting the transcript. self._running_agents[_quick_key] = _AGENT_PENDING_SENTINEL - if hasattr(self, "_running_agents_ts"): - self._running_agents_ts[_quick_key] = time.time() + self._running_agents_ts[_quick_key] = time.time() try: return await self._handle_message_with_agent(event, source, _quick_key) @@ -2073,8 +2090,7 @@ class GatewayRunner: # not linger or the session would be permanently locked out. if self._running_agents.get(_quick_key) is _AGENT_PENDING_SENTINEL: del self._running_agents[_quick_key] - if hasattr(self, "_running_agents_ts"): - self._running_agents_ts.pop(_quick_key, None) + self._running_agents_ts.pop(_quick_key, None) async def _handle_message_with_agent(self, event, source, _quick_key: str): """Inner handler that runs under the _running_agents sentinel guard.""" @@ -5448,8 +5464,9 @@ class GatewayRunner: # (grammY auto-retry pattern: proactively rate-limit # instead of reacting to 429s.) _now = time.monotonic() - if _now - _last_edit_ts < _PROGRESS_EDIT_INTERVAL: - await asyncio.sleep(0.1) + _remaining = _PROGRESS_EDIT_INTERVAL - (_now - _last_edit_ts) + if _remaining > 0: + await asyncio.sleep(_remaining) continue if can_edit and progress_msg_id is not None: @@ -6064,27 +6081,13 @@ class GatewayRunner: pending = None if result and adapter and session_key: if result.get("interrupted"): - # Interrupted — consume the interrupt message - pending_event = adapter.get_pending_message(session_key) - if pending_event: - pending = pending_event.text - # Preserve media context for photo/document events - # whose text is empty (no caption). Without this, - # captionless photos are silently dropped. - if not pending and getattr(pending_event, "media_urls", None): - pending = _build_media_placeholder(pending_event) - elif result.get("interrupt_message"): + pending = _dequeue_pending_text(adapter, session_key) + if not pending and result.get("interrupt_message"): pending = result.get("interrupt_message") else: - # Normal completion — check for /queue'd messages that were - # stored without triggering an interrupt. - pending_event = adapter.get_pending_message(session_key) - if pending_event: - pending = pending_event.text - if not pending and getattr(pending_event, "media_urls", None): - pending = _build_media_placeholder(pending_event) - if pending: - logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) + pending = _dequeue_pending_text(adapter, session_key) + if pending: + logger.debug("Processing queued message after agent completion: '%s...'", pending[:40]) if pending: logger.debug("Processing pending message: '%s...'", pending[:40]) @@ -6159,7 +6162,7 @@ class GatewayRunner: tracking_task.cancel() if session_key and session_key in self._running_agents: del self._running_agents[session_key] - if session_key and hasattr(self, "_running_agents_ts"): + if session_key: self._running_agents_ts.pop(session_key, None) # Wait for cancelled tasks diff --git a/tools/transcription_tools.py b/tools/transcription_tools.py index 1a7acee9b..9a79cdfba 100644 --- a/tools/transcription_tools.py +++ b/tools/transcription_tools.py @@ -128,10 +128,11 @@ def is_stt_enabled(stt_config: Optional[dict] = None) -> bool: def _has_openai_audio_backend() -> bool: """Return True when OpenAI audio can use config credentials, env credentials, or the managed gateway.""" - stt_config = _load_stt_config() - openai_cfg = stt_config.get("openai", {}) - cfg_api_key = openai_cfg.get("api_key", "") - return bool(cfg_api_key or resolve_openai_audio_api_key() or resolve_managed_tool_gateway("openai-audio")) + try: + _resolve_openai_audio_client_config() + return True + except ValueError: + return False def _find_binary(binary_name: str) -> Optional[str]: