fix(gateway): honor per-provider max_output_tokens in max_tokens chain

Widens ViewWay's #20741 fix to the sibling config surface: a
custom_providers entry can pin its own output cap via max_output_tokens
(or max_tokens). _get_named_custom_provider now lifts it onto the
resolved runtime at all three return sites, and the gateway uses it as a
fallback only when the documented global model.max_tokens isn't set, so
the global key always wins.

Precedence: HERMES_MAX_TOKENS > model.max_tokens > provider
max_output_tokens > None. Closes the same #20741 truncation for users who
configure the cap per-provider rather than globally.

Picks up the intent of #19782 (alexcam1901), reimplemented to feed
ViewWay's max_tokens pipeline.
This commit is contained in:
teknium1 2026-06-05 07:00:14 -07:00 committed by Teknium
parent 1c909e75e1
commit 14275d7baa
2 changed files with 29 additions and 0 deletions

View file

@ -1213,6 +1213,13 @@ def _resolve_runtime_agent_kwargs() -> dict:
mt = model_cfg.get("max_tokens")
if isinstance(mt, int):
max_tokens = mt
# Fall back to a per-provider output cap (custom_providers max_output_tokens)
# only when the documented global model.max_tokens isn't set, so the global
# key always wins.
if max_tokens is None:
_runtime_mot = runtime.get("max_output_tokens")
if isinstance(_runtime_mot, int) and _runtime_mot > 0:
max_tokens = _runtime_mot
return {
"api_key": runtime.get("api_key"),

View file

@ -474,6 +474,21 @@ def _try_resolve_from_custom_pool(
return None
def _lift_max_output_tokens(entry: Dict[str, Any], result: Dict[str, Any]) -> None:
"""Propagate a per-provider output cap onto the resolved runtime dict.
Accepts ``max_output_tokens`` or ``max_tokens`` on a ``custom_providers``
entry so a provider block can pin its own output limit. Gateway and CLI
map this onto ``AIAgent.max_tokens`` only when the top-level
``model.max_tokens`` isn't set, so the documented global key still wins.
"""
for _k in ("max_output_tokens", "max_tokens"):
_v = entry.get(_k)
if isinstance(_v, int) and _v > 0:
result["max_output_tokens"] = _v
return
def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, Any]]:
requested_norm = _normalize_custom_provider_name(requested_provider or "")
if not requested_norm or requested_norm == "custom":
@ -541,6 +556,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
api_mode = _parse_api_mode(entry.get("api_mode") or entry.get("transport"))
if api_mode:
result["api_mode"] = api_mode
_lift_max_output_tokens(entry, result)
return result
# Also check the 'name' field if present
display_name = entry.get("name", "")
@ -562,6 +578,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
api_mode = _parse_api_mode(entry.get("api_mode") or entry.get("transport"))
if api_mode:
result["api_mode"] = api_mode
_lift_max_output_tokens(entry, result)
return result
# Fall back to custom_providers: list (legacy format)
@ -611,6 +628,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
model_name = str(entry.get("model", "") or "").strip()
if model_name:
result["model"] = model_name
_lift_max_output_tokens(entry, result)
return result
return None
@ -699,6 +717,8 @@ def _resolve_named_custom_runtime(
model_name = custom_provider.get("model")
if model_name:
pool_result["model"] = model_name
if isinstance(custom_provider.get("max_output_tokens"), int):
pool_result["max_output_tokens"] = custom_provider["max_output_tokens"]
request_overrides = _custom_provider_request_overrides(custom_provider)
if request_overrides:
pool_result["request_overrides"] = {
@ -736,6 +756,8 @@ def _resolve_named_custom_runtime(
# provider name differs from the actual model string the API expects.
if custom_provider.get("model"):
result["model"] = custom_provider["model"]
if isinstance(custom_provider.get("max_output_tokens"), int):
result["max_output_tokens"] = custom_provider["max_output_tokens"]
request_overrides = _custom_provider_request_overrides(custom_provider)
if request_overrides:
result["request_overrides"] = request_overrides