diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 75ba3d1153..5e1be74ad4 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -2039,6 +2039,66 @@ class DiscordAdapter(BasePlatformAdapter): except Exception as e: return SendResult(success=False, error=str(e)) + async def send_model_picker( + self, + chat_id: str, + providers: list, + current_model: str, + current_provider: str, + session_key: str, + on_model_selected, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send an interactive select-menu model picker. + + Two-step drill-down: provider dropdown → model dropdown. + Uses Discord embeds + Select menus via ``ModelPickerView``. + """ + if not self._client or not DISCORD_AVAILABLE: + return SendResult(success=False, error="Not connected") + + try: + # Resolve target channel (use thread_id if present) + target_id = chat_id + if metadata and metadata.get("thread_id"): + target_id = metadata["thread_id"] + + channel = self._client.get_channel(int(target_id)) + if not channel: + channel = await self._client.fetch_channel(int(target_id)) + + try: + from hermes_cli.providers import get_label + provider_label = get_label(current_provider) + except Exception: + provider_label = current_provider + + embed = discord.Embed( + title="⚙ Model Configuration", + description=( + f"Current model: `{current_model or 'unknown'}`\n" + f"Provider: {provider_label}\n\n" + f"Select a provider:" + ), + color=discord.Color.blue(), + ) + + view = ModelPickerView( + providers=providers, + current_model=current_model, + current_provider=current_provider, + session_key=session_key, + on_model_selected=on_model_selected, + allowed_user_ids=self._allowed_user_ids, + ) + + msg = await channel.send(embed=embed, view=view) + return SendResult(success=True, message_id=str(msg.id)) + + except Exception as e: + logger.warning("[%s] send_model_picker failed: %s", self.name, e) + return SendResult(success=False, error=str(e)) + def _get_parent_channel_id(self, channel: Any) -> Optional[str]: """Return the parent channel ID for a Discord thread-like channel, if present.""" parent = getattr(channel, "parent", None) @@ -2530,3 +2590,219 @@ if DISCORD_AVAILABLE: self.resolved = True for child in self.children: child.disabled = True + + class ModelPickerView(discord.ui.View): + """Interactive select-menu view for model switching. + + Two-step drill-down: provider dropdown → model dropdown. + Edits the original message in-place as the user navigates. + Times out after 2 minutes. + """ + + def __init__( + self, + providers: list, + current_model: str, + current_provider: str, + session_key: str, + on_model_selected, + allowed_user_ids: set, + ): + super().__init__(timeout=120) + self.providers = providers + self.current_model = current_model + self.current_provider = current_provider + self.session_key = session_key + self.on_model_selected = on_model_selected + self.allowed_user_ids = allowed_user_ids + self.resolved = False + self._selected_provider: str = "" + + self._build_provider_select() + + def _check_auth(self, interaction: discord.Interaction) -> bool: + if not self.allowed_user_ids: + return True + return str(interaction.user.id) in self.allowed_user_ids + + def _build_provider_select(self): + """Build the provider dropdown menu.""" + self.clear_items() + options = [] + for p in self.providers: + count = p.get("total_models", len(p.get("models", []))) + label = f"{p['name']} ({count} models)" + desc = "current" if p.get("is_current") else None + options.append( + discord.SelectOption( + label=label[:100], + value=p["slug"], + default=bool(p.get("is_current")), + description=desc, + ) + ) + if not options: + return + + select = discord.ui.Select( + placeholder="Choose a provider...", + options=options[:25], + custom_id="model_provider_select", + ) + select.callback = self._on_provider_selected + self.add_item(select) + + cancel_btn = discord.ui.Button( + label="Cancel", style=discord.ButtonStyle.red, custom_id="model_cancel" + ) + cancel_btn.callback = self._on_cancel + self.add_item(cancel_btn) + + def _build_model_select(self, provider_slug: str): + """Build the model dropdown for a specific provider.""" + self.clear_items() + provider = next( + (p for p in self.providers if p["slug"] == provider_slug), None + ) + if not provider: + return + + models = provider.get("models", []) + options = [] + for model_id in models[:25]: + short = model_id.split("/")[-1] if "/" in model_id else model_id + options.append( + discord.SelectOption( + label=short[:100], + value=model_id[:100], + ) + ) + if not options: + return + + select = discord.ui.Select( + placeholder=f"Choose a model from {provider.get('name', provider_slug)}...", + options=options, + custom_id="model_model_select", + ) + select.callback = self._on_model_selected + self.add_item(select) + + back_btn = discord.ui.Button( + label="◀ Back", style=discord.ButtonStyle.grey, custom_id="model_back" + ) + back_btn.callback = self._on_back + self.add_item(back_btn) + + cancel_btn = discord.ui.Button( + label="Cancel", style=discord.ButtonStyle.red, custom_id="model_cancel2" + ) + cancel_btn.callback = self._on_cancel + self.add_item(cancel_btn) + + async def _on_provider_selected(self, interaction: discord.Interaction): + if not self._check_auth(interaction): + await interaction.response.send_message( + "You're not authorized~", ephemeral=True + ) + return + + provider_slug = interaction.data["values"][0] + self._selected_provider = provider_slug + provider = next( + (p for p in self.providers if p["slug"] == provider_slug), None + ) + pname = provider.get("name", provider_slug) if provider else provider_slug + + self._build_model_select(provider_slug) + + total = provider.get("total_models", 0) if provider else 0 + shown = min(len(provider.get("models", [])), 25) if provider else 0 + extra = f"\n*{total - shown} more available — type `/model ` directly*" if total > shown else "" + + await interaction.response.edit_message( + embed=discord.Embed( + title="⚙ Model Configuration", + description=f"Provider: **{pname}**\nSelect a model:{extra}", + color=discord.Color.blue(), + ), + view=self, + ) + + async def _on_model_selected(self, interaction: discord.Interaction): + if self.resolved: + await interaction.response.send_message( + "Already resolved~", ephemeral=True + ) + return + if not self._check_auth(interaction): + await interaction.response.send_message( + "You're not authorized~", ephemeral=True + ) + return + + self.resolved = True + model_id = interaction.data["values"][0] + + try: + result_text = await self.on_model_selected( + str(interaction.channel_id), + model_id, + self._selected_provider, + ) + except Exception as exc: + result_text = f"Error switching model: {exc}" + + self.clear_items() + await interaction.response.edit_message( + embed=discord.Embed( + title="⚙ Model Switched", + description=result_text, + color=discord.Color.green(), + ), + view=self, + ) + + async def _on_back(self, interaction: discord.Interaction): + if not self._check_auth(interaction): + await interaction.response.send_message( + "You're not authorized~", ephemeral=True + ) + return + + self._build_provider_select() + + try: + from hermes_cli.providers import get_label + provider_label = get_label(self.current_provider) + except Exception: + provider_label = self.current_provider + + await interaction.response.edit_message( + embed=discord.Embed( + title="⚙ Model Configuration", + description=( + f"Current model: `{self.current_model or 'unknown'}`\n" + f"Provider: {provider_label}\n\n" + f"Select a provider:" + ), + color=discord.Color.blue(), + ), + view=self, + ) + + async def _on_cancel(self, interaction: discord.Interaction): + self.resolved = True + self.clear_items() + await interaction.response.edit_message( + embed=discord.Embed( + title="⚙ Model Configuration", + description="Model selection cancelled.", + color=discord.Color.greyple(), + ), + view=self, + ) + + async def on_timeout(self): + self.resolved = True + self.clear_items() diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 7575c10f37..0362b9f96a 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -151,6 +151,8 @@ class TelegramAdapter(BasePlatformAdapter): self._dm_topics: Dict[str, int] = {} # DM Topics config from extra.dm_topics self._dm_topics_config: List[Dict[str, Any]] = self.config.extra.get("dm_topics", []) + # Interactive model picker state per chat + self._model_picker_state: Dict[str, dict] = {} def _fallback_ips(self) -> list[str]: """Return validated fallback IPs from config (populated by _apply_env_overrides).""" @@ -1008,14 +1010,252 @@ class TelegramAdapter(BasePlatformAdapter): logger.warning("[%s] send_update_prompt failed: %s", self.name, e) return SendResult(success=False, error=str(e)) + async def send_model_picker( + self, + chat_id: str, + providers: list, + current_model: str, + current_provider: str, + session_key: str, + on_model_selected, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send an interactive inline-keyboard model picker. + + Two-step drill-down: provider selection → model selection. + Edits the same message in-place as the user navigates. + """ + if not self._bot: + return SendResult(success=False, error="Not connected") + + try: + from hermes_cli.providers import get_label + except ImportError: + def get_label(slug): + return slug + + try: + # Build provider buttons — 2 per row + buttons: list = [] + for p in providers: + count = p.get("total_models", len(p.get("models", []))) + label = f"{p['name']} ({count})" + if p.get("is_current"): + label = f"✓ {label}" + # Compact callback data: mp: (max 64 bytes) + buttons.append( + InlineKeyboardButton(label, callback_data=f"mp:{p['slug']}") + ) + + rows = [buttons[i : i + 2] for i in range(0, len(buttons), 2)] + rows.append([InlineKeyboardButton("✗ Cancel", callback_data="mx")]) + keyboard = InlineKeyboardMarkup(rows) + + provider_label = get_label(current_provider) + text = ( + f"⚙ *Model Configuration*\n\n" + f"Current model: `{current_model or 'unknown'}`\n" + f"Provider: {provider_label}\n\n" + f"Select a provider:" + ) + + thread_id = metadata.get("thread_id") if metadata else None + msg = await self._bot.send_message( + chat_id=int(chat_id), + text=text, + parse_mode=ParseMode.MARKDOWN, + reply_markup=keyboard, + message_thread_id=int(thread_id) if thread_id else None, + ) + + # Store picker state keyed by chat_id + self._model_picker_state[str(chat_id)] = { + "msg_id": msg.message_id, + "providers": providers, + "session_key": session_key, + "on_model_selected": on_model_selected, + "current_model": current_model, + "current_provider": current_provider, + } + + return SendResult(success=True, message_id=str(msg.message_id)) + except Exception as e: + logger.warning("[%s] send_model_picker failed: %s", self.name, e) + return SendResult(success=False, error=str(e)) + + async def _handle_model_picker_callback( + self, query, data: str, chat_id: str + ) -> None: + """Handle model picker inline keyboard callbacks (mp:/mm:/mb:/mx:).""" + state = self._model_picker_state.get(chat_id) + if not state: + await query.answer(text="Picker expired — use /model again.") + return + + try: + from hermes_cli.providers import get_label + except ImportError: + def get_label(slug): + return slug + + if data.startswith("mp:"): + # --- Provider selected: show model buttons --- + provider_slug = data[3:] + provider = next( + (p for p in state["providers"] if p["slug"] == provider_slug), + None, + ) + if not provider: + await query.answer(text="Provider not found.") + return + + models = provider.get("models", []) + state["selected_provider"] = provider_slug + state["selected_provider_name"] = provider.get("name", provider_slug) + state["model_list"] = models + + buttons: list = [] + for i, model_id in enumerate(models): + # Short display label: strip vendor prefix + short = model_id.split("/")[-1] if "/" in model_id else model_id + # Truncate long model names for button label (max ~40 chars) + if len(short) > 38: + short = short[:35] + "..." + buttons.append( + InlineKeyboardButton(short, callback_data=f"mm:{i}") + ) + + rows = [buttons[i : i + 2] for i in range(0, len(buttons), 2)] + rows.append([ + InlineKeyboardButton("◀ Back", callback_data="mb"), + InlineKeyboardButton("✗ Cancel", callback_data="mx"), + ]) + keyboard = InlineKeyboardMarkup(rows) + + pname = provider.get("name", provider_slug) + total = provider.get("total_models", len(models)) + shown = len(models) + extra = f"\n_{total - shown} more available — type `/model ` directly_" if total > shown else "" + + await query.edit_message_text( + text=( + f"⚙ *Model Configuration*\n\n" + f"Provider: *{pname}*\n" + f"Select a model:{extra}" + ), + parse_mode=ParseMode.MARKDOWN, + reply_markup=keyboard, + ) + await query.answer() + + elif data.startswith("mm:"): + # --- Model selected: perform the switch --- + try: + idx = int(data[3:]) + except ValueError: + await query.answer(text="Invalid selection.") + return + + model_list = state.get("model_list", []) + if idx < 0 or idx >= len(model_list): + await query.answer(text="Invalid model index.") + return + + model_id = model_list[idx] + provider_slug = state.get("selected_provider", "") + callback = state.get("on_model_selected") + + if not callback: + await query.answer(text="Picker expired.") + return + + try: + result_text = await callback(chat_id, model_id, provider_slug) + except Exception as exc: + logger.error("Model picker switch failed: %s", exc) + result_text = f"Error switching model: {exc}" + + # Edit message to show confirmation, remove buttons + try: + await query.edit_message_text( + text=result_text, + parse_mode=ParseMode.MARKDOWN, + reply_markup=None, + ) + except Exception: + # Markdown parse failure — retry as plain text + try: + await query.edit_message_text( + text=result_text, + parse_mode=None, + reply_markup=None, + ) + except Exception: + pass + await query.answer(text="Model switched!") + + # Clean up state + self._model_picker_state.pop(chat_id, None) + + elif data == "mb": + # --- Back to provider list --- + buttons = [] + for p in state["providers"]: + count = p.get("total_models", len(p.get("models", []))) + label = f"{p['name']} ({count})" + if p.get("is_current"): + label = f"✓ {label}" + buttons.append( + InlineKeyboardButton(label, callback_data=f"mp:{p['slug']}") + ) + + rows = [buttons[i : i + 2] for i in range(0, len(buttons), 2)] + rows.append([InlineKeyboardButton("✗ Cancel", callback_data="mx")]) + keyboard = InlineKeyboardMarkup(rows) + + try: + provider_label = get_label(state["current_provider"]) + except Exception: + provider_label = state["current_provider"] + + await query.edit_message_text( + text=( + f"⚙ *Model Configuration*\n\n" + f"Current model: `{state['current_model'] or 'unknown'}`\n" + f"Provider: {provider_label}\n\n" + f"Select a provider:" + ), + parse_mode=ParseMode.MARKDOWN, + reply_markup=keyboard, + ) + await query.answer() + + elif data == "mx": + # --- Cancel --- + self._model_picker_state.pop(chat_id, None) + await query.edit_message_text( + text="Model selection cancelled.", + reply_markup=None, + ) + await query.answer() + async def _handle_callback_query( self, update: "Update", context: "ContextTypes.DEFAULT_TYPE" ) -> None: - """Handle inline keyboard button clicks (update prompts).""" + """Handle inline keyboard button clicks.""" query = update.callback_query if not query or not query.data: return data = query.data + + # --- Model picker callbacks --- + if data.startswith(("mp:", "mm:", "mb", "mx")): + chat_id = str(query.message.chat_id) if query.message else None + if chat_id: + await self._handle_model_picker_callback(query, data, chat_id) + return + + # --- Update prompt callbacks --- if not data.startswith("update_prompt:"): return answer = data.split(":", 1)[1] # "y" or "n" diff --git a/gateway/run.py b/gateway/run.py index 9d5ac5aa2c..08be2b9db2 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3464,11 +3464,11 @@ class GatewayRunner: lines.append(f"_(Requested page {requested_page} was out of range, showing page {page}.)_") return "\n".join(lines) - async def _handle_model_command(self, event: MessageEvent) -> str: + async def _handle_model_command(self, event: MessageEvent) -> Optional[str]: """Handle /model command — switch model for this session. Supports: - /model — show current model info + /model — interactive picker (Telegram/Discord) or text list /model — switch for this session only /model --global — switch and persist to config.yaml /model --provider — switch provider + model @@ -3516,8 +3516,118 @@ class GatewayRunner: current_base_url = override.get("base_url", current_base_url) current_api_key = override.get("api_key", current_api_key) - # No args: show authenticated providers with models + # No args: show interactive picker (Telegram/Discord) or text list if not model_input and not explicit_provider: + # Try interactive picker if the platform supports it + adapter = self.adapters.get(source.platform) + has_picker = ( + adapter is not None + and getattr(type(adapter), "send_model_picker", None) is not None + ) + + if has_picker: + try: + providers = list_authenticated_providers( + current_provider=current_provider, + user_providers=user_provs, + max_models=8, + ) + except Exception: + providers = [] + + if providers: + # Build a callback closure for when the user picks a model. + # Captures self + locals needed for the switch logic. + _self = self + _session_key = session_key + _cur_model = current_model + _cur_provider = current_provider + _cur_base_url = current_base_url + _cur_api_key = current_api_key + + async def _on_model_selected( + _chat_id: str, model_id: str, provider_slug: str + ) -> str: + """Perform the model switch and return confirmation text.""" + result = _switch_model( + raw_input=model_id, + current_provider=_cur_provider, + current_model=_cur_model, + current_base_url=_cur_base_url, + current_api_key=_cur_api_key, + is_global=False, + explicit_provider=provider_slug, + ) + if not result.success: + return f"Error: {result.error_message}" + + # Update cached agent in-place + cached_entry = None + _cache_lock = getattr(_self, "_agent_cache_lock", None) + _cache = getattr(_self, "_agent_cache", None) + if _cache_lock and _cache is not None: + with _cache_lock: + cached_entry = _cache.get(_session_key) + if cached_entry and cached_entry[0] is not None: + try: + cached_entry[0].switch_model( + new_model=result.new_model, + new_provider=result.target_provider, + api_key=result.api_key, + base_url=result.base_url, + api_mode=result.api_mode, + ) + except Exception as exc: + logger.warning("Picker model switch failed for cached agent: %s", exc) + + # Store model note + session override + if not hasattr(_self, "_pending_model_notes"): + _self._pending_model_notes = {} + _self._pending_model_notes[_session_key] = ( + f"[Note: model was just switched from {_cur_model} to {result.new_model} " + f"via {result.provider_label or result.target_provider}. " + f"Adjust your self-identification accordingly.]" + ) + if not hasattr(_self, "_session_model_overrides"): + _self._session_model_overrides = {} + _self._session_model_overrides[_session_key] = { + "model": result.new_model, + "provider": result.target_provider, + "api_key": result.api_key, + "base_url": result.base_url, + "api_mode": result.api_mode, + } + + # Build confirmation text + plabel = result.provider_label or result.target_provider + lines = [f"Model switched to `{result.new_model}`"] + lines.append(f"Provider: {plabel}") + mi = result.model_info + if mi: + if mi.context_window: + lines.append(f"Context: {mi.context_window:,} tokens") + if mi.max_output: + lines.append(f"Max output: {mi.max_output:,} tokens") + if mi.has_cost_data(): + lines.append(f"Cost: {mi.format_cost()}") + lines.append(f"Capabilities: {mi.format_capabilities()}") + lines.append("_(session only — use `/model --global` to persist)_") + return "\n".join(lines) + + metadata = {"thread_id": source.thread_id} if source.thread_id else None + result = await adapter.send_model_picker( + chat_id=source.chat_id, + providers=providers, + current_model=current_model, + current_provider=current_provider, + session_key=session_key, + on_model_selected=_on_model_selected, + metadata=metadata, + ) + if result.success: + return None # Picker sent — adapter handles the response + + # Fallback: text list (for platforms without picker or if picker failed) provider_label = get_label(current_provider) lines = [f"Current: `{current_model or 'unknown'}` on {provider_label}", ""]