Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions clawbot/core/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ def _match_provider(
"""Match provider config and its registry name. Returns (config, spec_name)."""
model_lower = (model or self.agents.defaults.model).lower()

# Explicit routing by "<provider>/..." prefix — deterministic and works
# for providers with no keywords (e.g. "custom/my-model").
if "/" in model_lower:
prefix = model_lower.split("/", 1)[0]
spec = find_by_name(prefix)
if spec:
p = getattr(self.providers, spec.name, None)
if p is not None:
return p, spec.name

# First pass: match by keyword in model name + has API key
for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None)
Expand Down
8 changes: 7 additions & 1 deletion clawbot/providers/custom_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@ async def chat(
max_tokens: int = 4096,
temperature: float = 0.7,
) -> LLMResponse:
# The config stores models as "custom/<real_id>" so provider routing
# can find us, but the upstream OpenAI-compatible endpoint only knows
# the bare id — strip the routing prefix before sending.
resolved_model = model or self.default_model
if resolved_model.startswith("custom/"):
resolved_model = resolved_model[len("custom/") :]
kwargs: dict[str, Any] = {
"model": model or self.default_model,
"model": resolved_model,
"messages": messages,
"max_tokens": max(1, max_tokens),
"temperature": temperature,
Expand Down
4 changes: 2 additions & 2 deletions clawforce-ui/src/components/OnboardingWizardModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,11 @@ export function OnboardingWizardModal({
model={wizardAgent.model}
savedProviders={wizardAgent.providers as Record<string, Record<string, unknown>> | undefined}
onModelChange={(v) => update({ model: v })}
onProviderKeyChange={(provider, key) => {
onProviderChange={(provider, patch) => {
update({
providers: {
...wizardAgent.providers,
[provider]: { ...(wizardAgent.providers?.[provider] || {}), apiKey: key },
[provider]: { ...(wizardAgent.providers?.[provider] || {}), ...patch },
},
});
}}
Expand Down
1 change: 1 addition & 0 deletions clawforce-ui/src/components/agent-detail/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ export const PROVIDER_DEFS: ProviderDef[] = [
{ field: "chatgpt", label: "ChatGPT Plus", keywords: ["chatgpt"], oauth: true },
{ field: "openai_codex", label: "OpenAI Codex", keywords: ["openai-codex", "codex"], oauth: true },
// API key / token providers
{ field: "custom", label: "Custom (OpenAI-compatible)", keywords: ["custom/"] },
{ field: "anthropic", label: "Anthropic", keywords: ["anthropic", "claude"] },
{ field: "openai", label: "OpenAI", keywords: ["openai", "gpt", "o1", "o3", "o4"] },
{ field: "openrouter", label: "OpenRouter", keywords: ["openrouter"] },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,11 @@ export function GeneralTab({ agentId, agent, update, updateTools }: { agentId: s
model={agent.model}
savedProviders={agent.providers as Record<string, Record<string, unknown>> | undefined}
onModelChange={(v) => update({ model: v })}
onProviderKeyChange={(provider, key) => {
onProviderChange={(provider, patch) => {
update({
providers: {
...agent.providers,
[provider]: { ...(agent.providers?.[provider] || {}), apiKey: key },
[provider]: { ...(agent.providers?.[provider] || {}), ...patch },
},
});
}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ export function ModelProviderSection({
model,
savedProviders,
onModelChange,
onProviderKeyChange,
onProviderChange,
}: {
agentId: string;
model: string;
savedProviders?: Record<string, Record<string, unknown>>;
onModelChange: (model: string) => void;
onProviderKeyChange: (provider: string, apiKey: string) => void;
onProviderChange: (provider: string, patch: { apiKey?: string; apiBase?: string }) => void;
}) {
// Derive initial provider from current model
const detected = detectProvider(model);
Expand Down Expand Up @@ -49,6 +49,20 @@ export function ModelProviderSection({
}
}, [savedKey]);

// Custom provider: OpenAI-compatible base URL (snake/camel tolerant)
const savedBase = selectedProvider && savedProviders?.[selectedProvider]
? ((savedProviders[selectedProvider].apiBase ?? savedProviders[selectedProvider].api_base ?? "") as string)
: "";
const [apiBase, setApiBase] = useState(savedBase);

const prevSavedBaseRef = useRef(savedBase);
useEffect(() => {
if (savedBase !== prevSavedBaseRef.current) {
prevSavedBaseRef.current = savedBase;
setApiBase(savedBase);
}
}, [savedBase]);

const [models, setModels] = useState<FetchedModel[]>([]);
const [loadingModels, setLoadingModels] = useState(false);
const [modelError, setModelError] = useState("");
Expand Down Expand Up @@ -157,17 +171,27 @@ export function ModelProviderSection({
if (providerDef?.oauth) return; // handled by the OAuth effect above
const isStatic = ["bedrock", "azure"].includes(selectedProvider);
const hasSavedKey = !!(savedKey && savedKey.length > 0);
if (isStatic || hasSavedKey) {
const isCustom = selectedProvider === "custom";
const hasSavedBase = !!(savedBase && savedBase.length > 0);
// Custom needs a base URL; API key is optional (self-hosted endpoints).
const canAutoFetch = isStatic || (isCustom ? hasSavedBase : hasSavedKey);
if (canAutoFetch) {
setLoadingModels(true);
setModelError("");
// For saved (redacted) keys, pass agentId so backend uses stored key
// For saved (redacted) keys/bases, pass agentId so backend uses stored values
const keyToSend = savedKey.startsWith("***") ? "" : savedKey;
api.providers.listModels(selectedProvider, keyToSend, keyToSend ? undefined : agentId)
const baseToSend = savedBase.startsWith("***") ? "" : savedBase;
api.providers.listModels(
selectedProvider,
keyToSend,
keyToSend && (!isCustom || baseToSend) ? undefined : agentId,
isCustom ? baseToSend : undefined,
)
.then((r) => setModels(r.models))
.catch(() => {})
.finally(() => setLoadingModels(false));
}
}, [selectedProvider, agentId, savedKey, providerDef?.oauth]);
}, [selectedProvider, agentId, savedKey, savedBase, providerDef?.oauth]);

async function handleOAuthAuthorize() {
setOauthLoading(true);
Expand All @@ -187,13 +211,24 @@ export function ModelProviderSection({

function doFetch() {
if (!selectedProvider) return;
// Use explicit key if available, otherwise fall back to stored key via agentId
const isCustom = selectedProvider === "custom";
const hasExplicitKey = apiKey.length > 0 && !apiKey.startsWith("***");
if (!hasExplicitKey && !savedKey) return;
const hasExplicitBase = apiBase.length > 0 && !apiBase.startsWith("***");
// Custom requires a base URL; other providers require a key.
if (isCustom) {
if (!hasExplicitBase && !savedBase) return;
} else if (!hasExplicitKey && !savedKey) {
return;
}
setLoadingModels(true);
setModelError("");
setModels([]);
api.providers.listModels(selectedProvider, hasExplicitKey ? apiKey : "", agentId)
api.providers.listModels(
selectedProvider,
hasExplicitKey ? apiKey : "",
agentId,
isCustom ? (hasExplicitBase ? apiBase : "") : undefined,
)
.then((r) => setModels(r.models))
.catch((err) => {
const msg = err instanceof Error ? err.message : String(err);
Expand All @@ -213,16 +248,23 @@ export function ModelProviderSection({
? ((savedProviders[field].apiKey ?? savedProviders[field].api_key ?? "") as string)
: "";
setApiKey(typeof saved === "string" && !saved.startsWith("***") ? saved : "");
const savedBaseForField = field && savedProviders?.[field]
? ((savedProviders[field].apiBase ?? savedProviders[field].api_base ?? "") as string)
: "";
setApiBase(typeof savedBaseForField === "string" && !savedBaseForField.startsWith("***") ? savedBaseForField : "");
setModels([]);
setModelError("");
setModelSearch("");
}

function handleModelSelect(modelId: string) {
const fullModel = `${selectedProvider}/${modelId}`;
// Save the provider API key into agent.providers so it's included in "Save Changes"
if (needsKey && apiKey) {
onProviderKeyChange(selectedProvider, apiKey);
// Persist API key / base URL into agent.providers so they're included in "Save Changes"
if (needsKey) {
const patch: { apiKey?: string; apiBase?: string } = {};
if (apiKey && !apiKey.startsWith("***")) patch.apiKey = apiKey;
if (selectedProvider === "custom" && apiBase && !apiBase.startsWith("***")) patch.apiBase = apiBase;
if (Object.keys(patch).length > 0) onProviderChange(selectedProvider, patch);
}
// OAuth providers: no key to propagate — credentials live in the OS credential store
onModelChange(fullModel);
Expand All @@ -235,6 +277,9 @@ export function ModelProviderSection({
: "";

const hasUsableKey = apiKey.length > 0 || !!(savedKey && savedKey.length > 0);
const hasUsableBase = apiBase.length > 0 || !!(savedBase && savedBase.length > 0);
// Custom only needs a base URL to fetch models; API key is optional.
const canFetchModels = selectedProvider === "custom" ? hasUsableBase : hasUsableKey;

const filteredModels = modelSearch
? models.filter((m) =>
Expand Down Expand Up @@ -343,6 +388,24 @@ export function ModelProviderSection({
)}
</div>

{/* Base URL input — only for the Custom (OpenAI-compatible) provider */}
{selectedProvider === "custom" && (
<div>
<label className={css.label}>Base URL</label>
<input
className={`${css.input} w-full`}
type="text"
value={apiBase}
onChange={(e) => setApiBase(e.target.value)}
onKeyDown={(e) => { if (e.key === "Enter") doFetch(); }}
placeholder="https://api.example.com/v1"
/>
<p className="mt-1 text-[10px] text-claude-text-muted">
Must be an OpenAI API compatible endpoint (e.g. vLLM, LM Studio, a private gateway).
</p>
</div>
)}

{/* Row 2: Model + Fetch models */}
{selectedProvider && (
<div>
Expand All @@ -368,7 +431,9 @@ export function ModelProviderSection({
? "Loading models…"
: models.length === 0
? (needsKey
? (savedKey ? "Click Fetch models to load" : "Enter API key and fetch models")
? (selectedProvider === "custom"
? (hasUsableBase ? "Click Fetch models to load" : "Enter base URL and fetch models")
: (savedKey ? "Click Fetch models to load" : "Enter API key and fetch models"))
: providerDef?.oauth
? (oauthAuthorized ? "Select a model…" : "Connect first to browse models")
: "Select a provider first")
Expand Down Expand Up @@ -438,7 +503,7 @@ export function ModelProviderSection({
<button
type="button"
onClick={doFetch}
disabled={!hasUsableKey || loadingModels}
disabled={!canFetchModels || loadingModels}
className={`${css.btn} shrink-0 text-claude-accent ring-1 ring-claude-accent/30 hover:bg-claude-accent/5 disabled:opacity-40 disabled:cursor-not-allowed`}
>
{loadingModels ? (
Expand All @@ -456,11 +521,14 @@ export function ModelProviderSection({
<button
type="button"
onClick={() => {
const custom = prompt("Enter model ID (e.g. claude-sonnet-4-20250514):", currentModelDisplay);
if (custom !== null && custom.trim()) {
const fullModel = selectedProvider ? `${selectedProvider}/${custom.trim()}` : custom.trim();
if (needsKey && apiKey) {
onProviderKeyChange(selectedProvider, apiKey);
const manual = prompt("Enter model ID (e.g. claude-sonnet-4-20250514):", currentModelDisplay);
if (manual !== null && manual.trim()) {
const fullModel = selectedProvider ? `${selectedProvider}/${manual.trim()}` : manual.trim();
if (needsKey) {
const patch: { apiKey?: string; apiBase?: string } = {};
if (apiKey && !apiKey.startsWith("***")) patch.apiKey = apiKey;
if (selectedProvider === "custom" && apiBase && !apiBase.startsWith("***")) patch.apiBase = apiBase;
if (Object.keys(patch).length > 0) onProviderChange(selectedProvider, patch);
}
onModelChange(fullModel);
}
Expand Down
4 changes: 2 additions & 2 deletions clawforce-ui/src/lib/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,10 @@ export const api = {
post<{ ok: boolean }>(`/plans/${planId}/workspace-folder/${path}`),
},
providers: {
listModels: (provider: string, apiKey: string, agentId?: string) =>
listModels: (provider: string, apiKey: string, agentId?: string, apiBase?: string) =>
post<{ provider: string; prefix: string; models: { id: string; name: string }[] }>(
"/providers/models",
{ provider, api_key: apiKey, agent_id: agentId || "" },
{ provider, api_key: apiKey, agent_id: agentId || "", api_base: apiBase || "" },
),
oauthStatus: (provider: string, agentId?: string) =>
request<{ provider: string; authorized: boolean; account_id?: string }>(
Expand Down
52 changes: 52 additions & 0 deletions clawforce/apis/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class ListModelsRequest(BaseModel):
provider: str
api_key: str = ""
agent_id: str = ""
api_base: str = ""


@router.post("/api/providers/models")
Expand All @@ -225,6 +226,57 @@ async def list_provider_models(
provider is read from the agent's persisted config.
"""
provider = body.provider.lower()

# Custom (OpenAI-compatible) provider: user-supplied base URL, GET {base}/models.
if provider == "custom":
api_base = body.api_base
api_key = body.api_key
if body.agent_id:
stored_cfg = agent_config_store.get_config(body.agent_id) or {}
stored = (stored_cfg.get("providers") or {}).get("custom") or {}
if not api_base or api_base.startswith("***"):
api_base = stored.get("api_base") or stored.get("apiBase") or ""
if not api_key or api_key.startswith("***"):
api_key = stored.get("api_key") or stored.get("apiKey") or ""
if not api_base:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Base URL is required for custom provider",
)
url = api_base.rstrip("/") + "/models"
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:
async with httpx.AsyncClient(timeout=15) as client:
resp = await client.get(url, headers=headers)
if resp.status_code == 401:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
if resp.status_code != 200:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Provider returned {resp.status_code}: {resp.text[:200]}",
)
data = resp.json()
models = [
{"id": m["id"], "name": m.get("id", m["id"])}
for m in sorted(data.get("data", []), key=lambda m: m["id"])
]
return {"provider": "custom", "prefix": "custom", "models": models}
except httpx.TimeoutException:
raise HTTPException(
status_code=status.HTTP_504_GATEWAY_TIMEOUT,
detail="Timed out connecting to provider",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Failed to fetch models: {str(e)[:200]}",
)

ep = PROVIDER_ENDPOINTS.get(provider)
if not ep:
raise HTTPException(
Expand Down
Loading