mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
fix(gateway): honor per-provider max_output_tokens in max_tokens chain
Widens ViewWay's #20741 fix to the sibling config surface: a custom_providers entry can pin its own output cap via max_output_tokens (or max_tokens). _get_named_custom_provider now lifts it onto the resolved runtime at all three return sites, and the gateway uses it as a fallback only when the documented global model.max_tokens isn't set, so the global key always wins. Precedence: HERMES_MAX_TOKENS > model.max_tokens > provider max_output_tokens > None. Closes the same #20741 truncation for users who configure the cap per-provider rather than globally. Picks up the intent of #19782 (alexcam1901), reimplemented to feed ViewWay's max_tokens pipeline.
This commit is contained in:
parent
1c909e75e1
commit
14275d7baa
2 changed files with 29 additions and 0 deletions
|
|
@ -1213,6 +1213,13 @@ def _resolve_runtime_agent_kwargs() -> dict:
|
|||
mt = model_cfg.get("max_tokens")
|
||||
if isinstance(mt, int):
|
||||
max_tokens = mt
|
||||
# Fall back to a per-provider output cap (custom_providers max_output_tokens)
|
||||
# only when the documented global model.max_tokens isn't set, so the global
|
||||
# key always wins.
|
||||
if max_tokens is None:
|
||||
_runtime_mot = runtime.get("max_output_tokens")
|
||||
if isinstance(_runtime_mot, int) and _runtime_mot > 0:
|
||||
max_tokens = _runtime_mot
|
||||
|
||||
return {
|
||||
"api_key": runtime.get("api_key"),
|
||||
|
|
|
|||
|
|
@ -474,6 +474,21 @@ def _try_resolve_from_custom_pool(
|
|||
return None
|
||||
|
||||
|
||||
def _lift_max_output_tokens(entry: Dict[str, Any], result: Dict[str, Any]) -> None:
|
||||
"""Propagate a per-provider output cap onto the resolved runtime dict.
|
||||
|
||||
Accepts ``max_output_tokens`` or ``max_tokens`` on a ``custom_providers``
|
||||
entry so a provider block can pin its own output limit. Gateway and CLI
|
||||
map this onto ``AIAgent.max_tokens`` only when the top-level
|
||||
``model.max_tokens`` isn't set, so the documented global key still wins.
|
||||
"""
|
||||
for _k in ("max_output_tokens", "max_tokens"):
|
||||
_v = entry.get(_k)
|
||||
if isinstance(_v, int) and _v > 0:
|
||||
result["max_output_tokens"] = _v
|
||||
return
|
||||
|
||||
|
||||
def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, Any]]:
|
||||
requested_norm = _normalize_custom_provider_name(requested_provider or "")
|
||||
if not requested_norm or requested_norm == "custom":
|
||||
|
|
@ -541,6 +556,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
api_mode = _parse_api_mode(entry.get("api_mode") or entry.get("transport"))
|
||||
if api_mode:
|
||||
result["api_mode"] = api_mode
|
||||
_lift_max_output_tokens(entry, result)
|
||||
return result
|
||||
# Also check the 'name' field if present
|
||||
display_name = entry.get("name", "")
|
||||
|
|
@ -562,6 +578,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
api_mode = _parse_api_mode(entry.get("api_mode") or entry.get("transport"))
|
||||
if api_mode:
|
||||
result["api_mode"] = api_mode
|
||||
_lift_max_output_tokens(entry, result)
|
||||
return result
|
||||
|
||||
# Fall back to custom_providers: list (legacy format)
|
||||
|
|
@ -611,6 +628,7 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An
|
|||
model_name = str(entry.get("model", "") or "").strip()
|
||||
if model_name:
|
||||
result["model"] = model_name
|
||||
_lift_max_output_tokens(entry, result)
|
||||
return result
|
||||
|
||||
return None
|
||||
|
|
@ -699,6 +717,8 @@ def _resolve_named_custom_runtime(
|
|||
model_name = custom_provider.get("model")
|
||||
if model_name:
|
||||
pool_result["model"] = model_name
|
||||
if isinstance(custom_provider.get("max_output_tokens"), int):
|
||||
pool_result["max_output_tokens"] = custom_provider["max_output_tokens"]
|
||||
request_overrides = _custom_provider_request_overrides(custom_provider)
|
||||
if request_overrides:
|
||||
pool_result["request_overrides"] = {
|
||||
|
|
@ -736,6 +756,8 @@ def _resolve_named_custom_runtime(
|
|||
# provider name differs from the actual model string the API expects.
|
||||
if custom_provider.get("model"):
|
||||
result["model"] = custom_provider["model"]
|
||||
if isinstance(custom_provider.get("max_output_tokens"), int):
|
||||
result["max_output_tokens"] = custom_provider["max_output_tokens"]
|
||||
request_overrides = _custom_provider_request_overrides(custom_provider)
|
||||
if request_overrides:
|
||||
result["request_overrides"] = request_overrides
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue