diff --git a/gateway/run.py b/gateway/run.py index ef3fd3be5ed..40721eada6b 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1213,6 +1213,13 @@ def _resolve_runtime_agent_kwargs() -> dict: mt = model_cfg.get("max_tokens") if isinstance(mt, int): max_tokens = mt + # Fall back to a per-provider output cap (custom_providers max_output_tokens) + # only when the documented global model.max_tokens isn't set, so the global + # key always wins. + if max_tokens is None: + _runtime_mot = runtime.get("max_output_tokens") + if isinstance(_runtime_mot, int) and _runtime_mot > 0: + max_tokens = _runtime_mot return { "api_key": runtime.get("api_key"), diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index 1edb8e99e47..cca80e988ce 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -474,6 +474,21 @@ def _try_resolve_from_custom_pool( return None +def _lift_max_output_tokens(entry: Dict[str, Any], result: Dict[str, Any]) -> None: + """Propagate a per-provider output cap onto the resolved runtime dict. + + Accepts ``max_output_tokens`` or ``max_tokens`` on a ``custom_providers`` + entry so a provider block can pin its own output limit. Gateway and CLI + map this onto ``AIAgent.max_tokens`` only when the top-level + ``model.max_tokens`` isn't set, so the documented global key still wins. + """ + for _k in ("max_output_tokens", "max_tokens"): + _v = entry.get(_k) + if isinstance(_v, int) and _v > 0: + result["max_output_tokens"] = _v + return + + def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, Any]]: requested_norm = _normalize_custom_provider_name(requested_provider or "") if not requested_norm or requested_norm == "custom": @@ -541,6 +556,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An api_mode = _parse_api_mode(entry.get("api_mode") or entry.get("transport")) if api_mode: result["api_mode"] = api_mode + _lift_max_output_tokens(entry, result) return result # Also check the 'name' field if present display_name = entry.get("name", "") @@ -562,6 +578,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An api_mode = _parse_api_mode(entry.get("api_mode") or entry.get("transport")) if api_mode: result["api_mode"] = api_mode + _lift_max_output_tokens(entry, result) return result # Fall back to custom_providers: list (legacy format) @@ -611,6 +628,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An model_name = str(entry.get("model", "") or "").strip() if model_name: result["model"] = model_name + _lift_max_output_tokens(entry, result) return result return None @@ -699,6 +717,8 @@ def _resolve_named_custom_runtime( model_name = custom_provider.get("model") if model_name: pool_result["model"] = model_name + if isinstance(custom_provider.get("max_output_tokens"), int): + pool_result["max_output_tokens"] = custom_provider["max_output_tokens"] request_overrides = _custom_provider_request_overrides(custom_provider) if request_overrides: pool_result["request_overrides"] = { @@ -736,6 +756,8 @@ def _resolve_named_custom_runtime( # provider name differs from the actual model string the API expects. if custom_provider.get("model"): result["model"] = custom_provider["model"] + if isinstance(custom_provider.get("max_output_tokens"), int): + result["max_output_tokens"] = custom_provider["max_output_tokens"] request_overrides = _custom_provider_request_overrides(custom_provider) if request_overrides: result["request_overrides"] = request_overrides