Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions hermes_cli/runtime_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from hermes_constants import OPENROUTER_BASE_URL


def _normalize_custom_provider_name(value: str) -> str:
return value.strip().lower().replace(" ", "-")


def _get_model_config() -> Dict[str, Any]:
config = load_config()
model_cfg = config.get("model")
Expand Down Expand Up @@ -45,6 +49,69 @@ def resolve_requested_provider(requested: Optional[str] = None) -> str:
return "auto"


def _get_named_custom_provider(requested_provider: str) -> Optional[Dict[str, Any]]:
requested_norm = _normalize_custom_provider_name(requested_provider or "")
if not requested_norm or requested_norm == "custom":
return None

config = load_config()
custom_providers = config.get("custom_providers")
if not isinstance(custom_providers, list):
return None

for entry in custom_providers:
if not isinstance(entry, dict):
continue
name = entry.get("name")
base_url = entry.get("base_url")
if not isinstance(name, str) or not isinstance(base_url, str):
continue
name_norm = _normalize_custom_provider_name(name)
menu_key = f"custom:{name_norm}"
if requested_norm not in {name_norm, menu_key}:
continue
return {
"name": name.strip(),
"base_url": base_url.strip(),
"api_key": str(entry.get("api_key", "") or "").strip(),
}

return None


def _resolve_named_custom_runtime(
*,
requested_provider: str,
explicit_api_key: Optional[str] = None,
explicit_base_url: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
custom_provider = _get_named_custom_provider(requested_provider)
if not custom_provider:
return None

base_url = (
(explicit_base_url or "").strip()
or custom_provider.get("base_url", "")
).rstrip("/")
if not base_url:
return None

api_key = (
(explicit_api_key or "").strip()
or custom_provider.get("api_key", "")
or os.getenv("OPENAI_API_KEY", "").strip()
or os.getenv("OPENROUTER_API_KEY", "").strip()
)

return {
"provider": "openrouter",
"api_mode": "chat_completions",
"base_url": base_url,
"api_key": api_key,
"source": f"custom_provider:{custom_provider.get('name', requested_provider)}",
}


def _resolve_openrouter_runtime(
*,
requested_provider: str,
Expand Down Expand Up @@ -120,6 +187,15 @@ def resolve_runtime_provider(
"""Resolve runtime provider credentials for agent execution."""
requested_provider = resolve_requested_provider(requested)

custom_runtime = _resolve_named_custom_runtime(
requested_provider=requested_provider,
explicit_api_key=explicit_api_key,
explicit_base_url=explicit_base_url,
)
if custom_runtime:
custom_runtime["requested_provider"] = requested_provider
return custom_runtime

provider = resolve_provider(
requested_provider,
explicit_api_key=explicit_api_key,
Expand Down
68 changes: 68 additions & 0 deletions tests/test_runtime_provider_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,74 @@ def test_custom_endpoint_auto_provider_prefers_openai_key(monkeypatch):
assert resolved["api_key"] == "sk-vllm-key"


def test_named_custom_provider_uses_saved_credentials(monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.setattr(
rp,
"load_config",
lambda: {
"custom_providers": [
{
"name": "Local",
"base_url": "http://1.2.3.4:1234/v1",
"api_key": "local-provider-key",
}
]
},
)
monkeypatch.setattr(
rp,
"resolve_provider",
lambda *a, **k: (_ for _ in ()).throw(
AssertionError(
"resolve_provider should not be called for named custom providers"
)
),
)

resolved = rp.resolve_runtime_provider(requested="local")

assert resolved["provider"] == "openrouter"
assert resolved["api_mode"] == "chat_completions"
assert resolved["base_url"] == "http://1.2.3.4:1234/v1"
assert resolved["api_key"] == "local-provider-key"
assert resolved["requested_provider"] == "local"
assert resolved["source"] == "custom_provider:Local"


def test_named_custom_provider_falls_back_to_openai_api_key(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "env-openai-key")
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
monkeypatch.setattr(
rp,
"load_config",
lambda: {
"custom_providers": [
{
"name": "Local LLM",
"base_url": "http://localhost:1234/v1",
}
]
},
)
monkeypatch.setattr(
rp,
"resolve_provider",
lambda *a, **k: (_ for _ in ()).throw(
AssertionError(
"resolve_provider should not be called for named custom providers"
)
),
)

resolved = rp.resolve_runtime_provider(requested="custom:local-llm")

assert resolved["base_url"] == "http://localhost:1234/v1"
assert resolved["api_key"] == "env-openai-key"
assert resolved["requested_provider"] == "custom:local-llm"


def test_resolve_runtime_provider_nous_api(monkeypatch):
"""Nous Portal API key provider resolves via the api_key path."""
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "nous-api")
Expand Down
Loading