diff --git a/agent/agent_init.py b/agent/agent_init.py index c39712d4d02..be9a09dd2f5 100644 --- a/agent/agent_init.py +++ b/agent/agent_init.py @@ -71,6 +71,71 @@ def _ra(): return run_agent +def _normalized_custom_base_url(value: Any) -> str: + if not isinstance(value, str): + return "" + return value.strip().rstrip("/") + + +def _custom_provider_model_matches(agent_model: str, entry: Dict[str, Any]) -> bool: + provider_model = str(entry.get("model", "") or "").strip().lower() + if not provider_model: + return True + return provider_model == str(agent_model or "").strip().lower() + + +def _custom_provider_extra_body_for_agent( + *, + provider: str, + model: str, + base_url: str, + custom_providers: List[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: + if (provider or "").strip().lower() != "custom": + return None + + target_url = _normalized_custom_base_url(base_url) + if not target_url: + return None + + fallback: Optional[Dict[str, Any]] = None + for entry in custom_providers or []: + if not isinstance(entry, dict): + continue + if _normalized_custom_base_url(entry.get("base_url")) != target_url: + continue + extra_body = entry.get("extra_body") + if not isinstance(extra_body, dict) or not extra_body: + continue + provider_model = str(entry.get("model", "") or "").strip() + if provider_model: + if _custom_provider_model_matches(model, entry): + return dict(extra_body) + elif fallback is None: + fallback = dict(extra_body) + + return fallback + + +def _merge_custom_provider_extra_body(agent, custom_providers: List[Dict[str, Any]]) -> None: + extra_body = _custom_provider_extra_body_for_agent( + provider=agent.provider, + model=agent.model, + base_url=agent.base_url, + custom_providers=custom_providers, + ) + if not extra_body: + return + + overrides = dict(getattr(agent, "request_overrides", {}) or {}) + merged_extra_body = dict(extra_body) + existing_extra_body = overrides.get("extra_body") + if isinstance(existing_extra_body, dict): + merged_extra_body.update(existing_extra_body) + overrides["extra_body"] = merged_extra_body + agent.request_overrides = overrides + + def init_agent( agent, base_url: str = None, @@ -1213,6 +1278,7 @@ def init_agent( # Store for reuse by _check_compression_model_feasibility (auxiliary # compression model context-length detection needs the same list). agent._custom_providers = _custom_providers + _merge_custom_provider_extra_body(agent, _custom_providers) # Check custom_providers per-model context_length if _config_context_length is None and _custom_providers: diff --git a/hermes_cli/config.py b/hermes_cli/config.py index de8ca79cd88..8d4484fad0e 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -3017,7 +3017,7 @@ def _normalize_custom_provider_entry( "api_mode", "transport", "model", "default_model", "models", "context_length", "rate_limit_delay", "request_timeout_seconds", "stale_timeout_seconds", - "discover_models", + "discover_models", "extra_body", } for camel, snake in _CAMEL_ALIASES.items(): if camel in entry and snake not in entry: @@ -3112,6 +3112,10 @@ def _normalize_custom_provider_entry( if isinstance(discover_models, bool): normalized["discover_models"] = discover_models + extra_body = entry.get("extra_body") + if isinstance(extra_body, dict): + normalized["extra_body"] = dict(extra_body) + return normalized @@ -3272,7 +3276,7 @@ _KNOWN_ROOT_KEYS = { # Valid fields inside a custom_providers list entry _VALID_CUSTOM_PROVIDER_FIELDS = { "name", "base_url", "api_key", "api_mode", "model", "models", - "context_length", "rate_limit_delay", + "context_length", "rate_limit_delay", "extra_body", # key_env is read at runtime by runtime_provider.py and auxiliary_client.py # — include it here so the set accurately describes the supported schema. "key_env", diff --git a/hermes_cli/runtime_provider.py b/hermes_cli/runtime_provider.py index 73aa5c45571..c40316e02cc 100644 --- a/hermes_cli/runtime_provider.py +++ b/hermes_cli/runtime_provider.py @@ -528,6 +528,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An "api_key": resolved_api_key, "model": entry.get("default_model", ""), } + extra_body = entry.get("extra_body") + if isinstance(extra_body, dict): + result["extra_body"] = dict(extra_body) # The v11→v12 migration writes the API mode under the new # ``transport`` field, but hand-edited configs may still # use the legacy ``api_mode`` spelling. Accept both — @@ -553,6 +556,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An "api_key": resolved_api_key, "model": entry.get("default_model", ""), } + extra_body = entry.get("extra_body") + if isinstance(extra_body, dict): + result["extra_body"] = dict(extra_body) api_mode = _parse_api_mode(entry.get("api_mode") or entry.get("transport")) if api_mode: result["api_mode"] = api_mode @@ -596,6 +602,9 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An result["key_env"] = key_env if provider_key: result["provider_key"] = provider_key + extra_body = entry.get("extra_body") + if isinstance(extra_body, dict): + result["extra_body"] = dict(extra_body) api_mode = _parse_api_mode(entry.get("api_mode")) if api_mode: result["api_mode"] = api_mode @@ -607,6 +616,13 @@ def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, An return None +def _custom_provider_request_overrides(custom_provider: Dict[str, Any]) -> Dict[str, Any]: + extra_body = custom_provider.get("extra_body") + if not isinstance(extra_body, dict) or not extra_body: + return {} + return {"extra_body": dict(extra_body)} + + def _resolve_named_custom_runtime( *, requested_provider: str, @@ -683,6 +699,12 @@ def _resolve_named_custom_runtime( model_name = custom_provider.get("model") if model_name: pool_result["model"] = model_name + request_overrides = _custom_provider_request_overrides(custom_provider) + if request_overrides: + pool_result["request_overrides"] = { + **dict(pool_result.get("request_overrides") or {}), + **request_overrides, + } return pool_result _cp_is_openai_url = base_url_host_matches(base_url, "openai.com") or base_url_host_matches(base_url, "openai.azure.com") @@ -714,6 +736,9 @@ def _resolve_named_custom_runtime( # provider name differs from the actual model string the API expects. if custom_provider.get("model"): result["model"] = custom_provider["model"] + request_overrides = _custom_provider_request_overrides(custom_provider) + if request_overrides: + result["request_overrides"] = request_overrides return result diff --git a/tests/agent/test_custom_provider_extra_body.py b/tests/agent/test_custom_provider_extra_body.py new file mode 100644 index 00000000000..23556ae62de --- /dev/null +++ b/tests/agent/test_custom_provider_extra_body.py @@ -0,0 +1,93 @@ +from types import SimpleNamespace + +from agent.agent_init import _merge_custom_provider_extra_body + + +def test_custom_provider_extra_body_merges_into_request_overrides(): + agent = SimpleNamespace( + provider="custom", + model="google/gemma-4-31b-it", + base_url="https://example.test/v1", + request_overrides={"service_tier": "priority"}, + ) + + _merge_custom_provider_extra_body( + agent, + [ + { + "name": "gemma", + "base_url": "https://example.test/v1/", + "model": "google/gemma-4-31b-it", + "extra_body": { + "enable_thinking": True, + "reasoning_effort": "high", + }, + } + ], + ) + + assert agent.request_overrides == { + "service_tier": "priority", + "extra_body": { + "enable_thinking": True, + "reasoning_effort": "high", + }, + } + + +def test_custom_provider_extra_body_preserves_caller_override(): + agent = SimpleNamespace( + provider="custom", + model="google/gemma-4-31b-it", + base_url="https://example.test/v1", + request_overrides={ + "extra_body": { + "reasoning_effort": "low", + "caller_only": True, + } + }, + ) + + _merge_custom_provider_extra_body( + agent, + [ + { + "name": "gemma", + "base_url": "https://example.test/v1", + "model": "google/gemma-4-31b-it", + "extra_body": { + "enable_thinking": True, + "reasoning_effort": "high", + }, + } + ], + ) + + assert agent.request_overrides["extra_body"] == { + "enable_thinking": True, + "reasoning_effort": "low", + "caller_only": True, + } + + +def test_custom_provider_extra_body_ignores_other_custom_models(): + agent = SimpleNamespace( + provider="custom", + model="other-model", + base_url="https://example.test/v1", + request_overrides={}, + ) + + _merge_custom_provider_extra_body( + agent, + [ + { + "name": "gemma", + "base_url": "https://example.test/v1", + "model": "google/gemma-4-31b-it", + "extra_body": {"enable_thinking": True}, + } + ], + ) + + assert agent.request_overrides == {} diff --git a/tests/hermes_cli/test_runtime_provider_resolution.py b/tests/hermes_cli/test_runtime_provider_resolution.py index 3adffabb461..394216c9171 100644 --- a/tests/hermes_cli/test_runtime_provider_resolution.py +++ b/tests/hermes_cli/test_runtime_provider_resolution.py @@ -1631,6 +1631,33 @@ def test_named_custom_runtime_propagates_model_direct_path(monkeypatch): assert resolved["provider"] == "custom" +def test_named_custom_runtime_propagates_extra_body_direct_path(monkeypatch): + """Custom provider extra_body should become runtime request_overrides.""" + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-gemma") + monkeypatch.setattr( + rp, "_get_named_custom_provider", + lambda p: { + "name": "my-gemma", + "base_url": "http://localhost:8000/v1", + "api_key": "test-key", + "model": "google/gemma-4-31b-it", + "extra_body": { + "enable_thinking": True, + "reasoning_effort": "high", + }, + }, + ) + monkeypatch.setattr(rp, "_try_resolve_from_custom_pool", lambda *a, **k: None) + + resolved = rp.resolve_runtime_provider(requested="my-gemma") + assert resolved["request_overrides"] == { + "extra_body": { + "enable_thinking": True, + "reasoning_effort": "high", + } + } + + def test_named_custom_runtime_propagates_model_pool_path(monkeypatch): """Model should propagate even when credential pool handles credentials.""" monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server") @@ -1662,6 +1689,36 @@ def test_named_custom_runtime_propagates_model_pool_path(monkeypatch): assert resolved["api_key"] == "pool-key", "pool credentials should be used" +def test_named_custom_runtime_propagates_extra_body_pool_path(monkeypatch): + """Custom provider extra_body should survive credential-pool resolution.""" + monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-gemma") + monkeypatch.setattr( + rp, "_get_named_custom_provider", + lambda p: { + "name": "my-gemma", + "base_url": "http://localhost:8000/v1", + "api_key": "test-key", + "model": "google/gemma-4-31b-it", + "extra_body": {"enable_thinking": True}, + }, + ) + monkeypatch.setattr( + rp, "_try_resolve_from_custom_pool", + lambda *a, **k: { + "provider": "custom", + "api_mode": "chat_completions", + "base_url": "http://localhost:8000/v1", + "api_key": "pool-key", + "source": "pool:custom:my-gemma", + }, + ) + + resolved = rp.resolve_runtime_provider(requested="my-gemma") + assert resolved["request_overrides"] == { + "extra_body": {"enable_thinking": True} + } + + def test_named_custom_runtime_no_model_when_absent(monkeypatch): """When custom_providers entry has no model field, runtime should not either.""" monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-server") @@ -2150,6 +2207,24 @@ class TestProviderEntryApiKeyEnvAlias: key_env so the set stays in sync with what the runtime actually reads.""" from hermes_cli.config import _VALID_CUSTOM_PROVIDER_FIELDS assert "key_env" in _VALID_CUSTOM_PROVIDER_FIELDS + + def test_extra_body_is_supported_schema(self): + from hermes_cli.config import ( + _VALID_CUSTOM_PROVIDER_FIELDS, + _normalize_custom_provider_entry, + ) + entry = { + "name": "vendor", + "base_url": "https://api.vendor.example.com/v1", + "extra_body": { + "chat_template_kwargs": {"enable_thinking": True}, + "include_reasoning": True, + }, + } + normalized = _normalize_custom_provider_entry(dict(entry), provider_key="vendor") + assert normalized is not None + assert "extra_body" in _VALID_CUSTOM_PROVIDER_FIELDS + assert normalized["extra_body"] == entry["extra_body"] # ============================================================================= # Tencent TokenHub — API-key provider runtime resolution # ============================================================================= diff --git a/tests/providers/test_transport_parity.py b/tests/providers/test_transport_parity.py index 8c1fb6eb4f1..5d1856cd84b 100644 --- a/tests/providers/test_transport_parity.py +++ b/tests/providers/test_transport_parity.py @@ -236,7 +236,7 @@ class TestQwenParity: class TestCustomOllamaParity: - """Custom/Ollama: num_ctx, think=false — now tested via profile.""" + """Custom/Ollama: num_ctx, thinking controls — now tested via profile.""" def test_ollama_num_ctx(self, transport): kw = transport.build_kwargs( diff --git a/website/docs/integrations/providers.md b/website/docs/integrations/providers.md index 6969bcc7e60..13515a87692 100644 --- a/website/docs/integrations/providers.md +++ b/website/docs/integrations/providers.md @@ -1228,6 +1228,26 @@ custom_providers: api_mode: anthropic_messages # for Anthropic-compatible proxies ``` +Some OpenAI-compatible endpoints need provider-specific request body fields. Add an `extra_body` map to the matching custom provider and Hermes will merge it into each chat-completions request for that endpoint: + +```yaml +custom_providers: + - name: gemma-local + base_url: http://localhost:8080/v1 + model: google/gemma-4-31b-it + extra_body: + enable_thinking: true + reasoning_effort: high +``` + +Use the shape your server documents. For example, vLLM Gemma deployments and some NVIDIA NIM endpoints expect `enable_thinking` under `chat_template_kwargs` instead of as a top-level `extra_body` field: + +```yaml +extra_body: + chat_template_kwargs: + enable_thinking: true +``` + The `hermes model` → Custom Endpoint wizard now prompts for `api_mode` explicitly and persists your answer to `config.yaml`. URL-based auto-detection (e.g. `/anthropic` paths → `anthropic_messages`) still happens as a fallback when the field is left blank. Switch between them mid-session with the triple syntax: