diff --git a/pkg/config/config.go b/pkg/config/config.go
index 197b959731..a7965bbac0 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -152,10 +152,12 @@ type ProvidersConfig struct {
 }
 
 type ProviderConfig struct {
-    APIKey     string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
-    APIBase    string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
-    Proxy      string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
-    AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
+    APIKey     string            `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
+    APIBase    string            `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
+    API        string            `json:"api,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_API"`
+    Headers    map[string]string `json:"headers,omitempty"`
+    Proxy      string            `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
+    AuthMethod string            `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
 }
 
 type GatewayConfig struct {
diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go
index 7179c4cc53..24319b7799 100644
--- a/pkg/providers/http_provider.go
+++ b/pkg/providers/http_provider.go
@@ -10,23 +10,39 @@ import (
     "bytes"
     "context"
     "encoding/json"
+    "errors"
     "fmt"
     "io"
     "net/http"
     "net/url"
     "strings"
 
+    "github.com/openai/openai-go/v3/responses"
     "github.com/sipeed/picoclaw/pkg/auth"
     "github.com/sipeed/picoclaw/pkg/config"
+    "github.com/sipeed/picoclaw/pkg/logger"
+    "github.com/sipeed/picoclaw/pkg/utils"
 )
 
 type HTTPProvider struct {
     apiKey     string
     apiBase    string
+    apiMode    string
+    headers    map[string]string
     httpClient *http.Client
 }
 
-func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
+type httpProviderError struct {
+    statusCode int
+    body       string
+    url        string
+}
+
+func (e *httpProviderError) Error() string {
+    return fmt.Sprintf("API error (%d): %s", e.statusCode, e.body)
+}
+
+func NewHTTPProvider(apiKey, apiBase, proxy, apiMode string, headers map[string]string) *HTTPProvider {
     client := &http.Client{
         Timeout: 0,
     }
@@ -43,6 +59,8 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
     return &HTTPProvider{
         apiKey:     apiKey,
         apiBase:    apiBase,
+        apiMode:    apiMode,
+        headers:    headers,
         httpClient: client,
     }
 }
@@ -52,74 +70,36 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
         return nil, fmt.Errorf("API base not configured")
     }
 
-    // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
-    if idx := strings.Index(model, "/"); idx != -1 {
-        prefix := model[:idx]
-        if prefix == "moonshot" || prefix == "nvidia" {
-            model = model[idx+1:]
-        }
-    }
-
-    requestBody := map[string]interface{}{
-        "model":    model,
-        "messages": messages,
-    }
-
-    if len(tools) > 0 {
-        requestBody["tools"] = tools
-        requestBody["tool_choice"] = "auto"
+    if isGoogleGenerativeAI(p.apiMode) {
+        return p.chatWithGoogleGenerativeAI(ctx, messages, tools, model, options)
     }
 
-    if maxTokens, ok := options["max_tokens"].(int); ok {
-        lowerModel := strings.ToLower(model)
-        if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
-            requestBody["max_completion_tokens"] = maxTokens
-        } else {
-            requestBody["max_tokens"] = maxTokens
+    useResponses := shouldPreferResponses(model, p.apiMode)
+    if useResponses {
+        resp, err := p.chatWithResponses(ctx, messages, tools, model, options)
+        if err == nil {
+            return resp, nil
         }
-    }
-
-    if temperature, ok := options["temperature"].(float64); ok {
-        lowerModel := strings.ToLower(model)
-        // Kimi k2 models only support temperature=1
-        if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
-            requestBody["temperature"] = 1.0
-        } else {
-            requestBody["temperature"] = temperature
+        if shouldFallbackFromResponses(err) {
+            logger.DebugCF("provider", "Responses endpoint unsupported, falling back to chat/completions", map[string]interface{}{
+                "model": model,
+            })
+            return p.chatWithCompletions(ctx, messages, tools, model, options)
         }
+        return nil, err
     }
 
-    jsonData, err := json.Marshal(requestBody)
-    if err != nil {
-        return nil, fmt.Errorf("failed to marshal request: %w", err)
-    }
-
-    req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
-    if err != nil {
-        return nil, fmt.Errorf("failed to create request: %w", err)
-    }
-
-    req.Header.Set("Content-Type", "application/json")
-    if p.apiKey != "" {
-        req.Header.Set("Authorization", "Bearer "+p.apiKey)
-    }
-
-    resp, err := p.httpClient.Do(req)
-    if err != nil {
-        return nil, fmt.Errorf("failed to send request: %w", err)
-    }
-    defer resp.Body.Close()
-
-    body, err := io.ReadAll(resp.Body)
-    if err != nil {
-        return nil, fmt.Errorf("failed to read response: %w", err)
+    resp, err := p.chatWithCompletions(ctx, messages, tools, model, options)
+    if err == nil {
+        return resp, nil
     }
-
-    if resp.StatusCode != http.StatusOK {
-        return nil, fmt.Errorf("API error: %s", string(body))
+    if shouldFallbackFromCompletions(err) {
+        logger.DebugCF("provider", "Chat/completions endpoint unsupported, falling back to responses", map[string]interface{}{
+            "model": model,
+        })
+        return p.chatWithResponses(ctx, messages, tools, model, options)
     }
-
-    return p.parseResponse(body)
+    return nil, err
 }
 
 func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
@@ -196,6 +176,503 @@ func (p *HTTPProvider) GetDefaultModel() string {
     return ""
 }
 
+func (p *HTTPProvider) applyHeaders(req *http.Request) {
+    if len(p.headers) > 0 {
+        for k, v := range p.headers {
+            req.Header.Set(k, v)
+        }
+    }
+
+    if req.Header.Get("Content-Type") == "" {
+        req.Header.Set("Content-Type", "application/json")
+    }
+    if p.apiKey != "" && req.Header.Get("Authorization") == "" {
+        req.Header.Set("Authorization", "Bearer "+p.apiKey)
+    }
+}
+
+func (p *HTTPProvider) chatWithCompletions(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+    model = normalizeModelForHTTP(model)
+    requestBody := map[string]interface{}{
+        "model":    model,
+        "messages": messages,
+    }
+
+    if len(tools) > 0 {
+        requestBody["tools"] = tools
+        requestBody["tool_choice"] = "auto"
+    }
+
+    if maxTokens, ok := options["max_tokens"].(int); ok {
+        lowerModel := strings.ToLower(model)
+        if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
+            requestBody["max_completion_tokens"] = maxTokens
+        } else {
+            requestBody["max_tokens"] = maxTokens
+        }
+    }
+
+    if temperature, ok := options["temperature"].(float64); ok {
+        lowerModel := strings.ToLower(model)
+        // Kimi k2 models only support temperature=1
+        if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
+            requestBody["temperature"] = 1.0
+        } else {
+            requestBody["temperature"] = temperature
+        }
+    }
+
+    jsonData, err := json.Marshal(requestBody)
+    if err != nil {
+        return nil, fmt.Errorf("failed to marshal request: %w", err)
+    }
+
+    req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
+    if err != nil {
+        return nil, fmt.Errorf("failed to create request: %w", err)
+    }
+    logger.DebugCF("provider", "HTTP request", map[string]interface{}{
+        "url":    req.URL.String(),
+        "method": req.Method,
+    })
+
+    p.applyHeaders(req)
+
+    resp, err := p.httpClient.Do(req)
+    if err != nil {
+        return nil, fmt.Errorf("failed to send request: %w", err)
+    }
+    defer resp.Body.Close()
+
+    body, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return nil, fmt.Errorf("failed to read response: %w", err)
+    }
+
+    if resp.StatusCode != http.StatusOK {
+        logger.DebugCF("provider", "HTTP response error", map[string]interface{}{
+            "status": resp.StatusCode,
+            "body":   utils.Truncate(string(body), 500),
+        })
+        return nil, &httpProviderError{statusCode: resp.StatusCode, body: string(body), url: req.URL.String()}
+    }
+
+    return p.parseResponse(body)
+}
+
+func (p *HTTPProvider) chatWithResponses(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+    model = normalizeModelForHTTP(model)
+    params := buildCodexParams(messages, tools, model, stripTemperature(options))
+
+    jsonData, err := json.Marshal(params)
+    if err != nil {
+        return nil, fmt.Errorf("failed to marshal responses request: %w", err)
+    }
+
+    req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/responses", bytes.NewReader(jsonData))
+    if err != nil {
+        return nil, fmt.Errorf("failed to create request: %w", err)
+    }
+    logger.DebugCF("provider", "HTTP request", map[string]interface{}{
+        "url":    req.URL.String(),
+        "method": req.Method,
+    })
+
+    p.applyHeaders(req)
+
+    resp, err := p.httpClient.Do(req)
+    if err != nil {
+        return nil, fmt.Errorf("failed to send request: %w", err)
+    }
+    defer resp.Body.Close()
+
+    body, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return nil, fmt.Errorf("failed to read response: %w", err)
+    }
+
+    if resp.StatusCode != http.StatusOK {
+        logger.DebugCF("provider", "HTTP response error", map[string]interface{}{
+            "status": resp.StatusCode,
+            "body":   utils.Truncate(string(body), 500),
+        })
+        return nil, &httpProviderError{statusCode: resp.StatusCode, body: string(body), url: req.URL.String()}
+    }
+
+    var apiResponse responses.Response
+    if err := json.Unmarshal(body, &apiResponse); err != nil {
+        return nil, fmt.Errorf("failed to unmarshal responses API response: %w", err)
+    }
+    return parseCodexResponse(&apiResponse), nil
+}
+
+func (p *HTTPProvider) chatWithGoogleGenerativeAI(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+    model = normalizeModelForGemini(model)
+    requestBody := buildGeminiRequest(messages, tools, options)
+
+    jsonData, err := json.Marshal(requestBody)
+    if err != nil {
+        return nil, fmt.Errorf("failed to marshal gemini request: %w", err)
+    }
+
+    endpoint := fmt.Sprintf("%s/models/%s:generateContent", strings.TrimRight(p.apiBase, "/"), url.PathEscape(model))
+    req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(jsonData))
+    if err != nil {
+        return nil, fmt.Errorf("failed to create request: %w", err)
+    }
+    logger.DebugCF("provider", "HTTP request", map[string]interface{}{
+        "url":    req.URL.String(),
+        "method": req.Method,
+    })
+
+    p.applyHeaders(req)
+
+    resp, err := p.httpClient.Do(req)
+    if err != nil {
+        return nil, fmt.Errorf("failed to send request: %w", err)
+    }
+    defer resp.Body.Close()
+
+    body, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return nil, fmt.Errorf("failed to read response: %w", err)
+    }
+
+    if resp.StatusCode != http.StatusOK {
+        logger.DebugCF("provider", "HTTP response error", map[string]interface{}{
+            "status": resp.StatusCode,
+            "body":   utils.Truncate(string(body), 500),
+        })
+        return nil, &httpProviderError{statusCode: resp.StatusCode, body: string(body), url: req.URL.String()}
+    }
+
+    return parseGeminiResponse(body)
+}
+
+func normalizeModelForHTTP(model string) string {
+    // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
+    if idx := strings.Index(model, "/"); idx != -1 {
+        prefix := model[:idx]
+        if prefix == "moonshot" || prefix == "nvidia" {
+            return model[idx+1:]
+        }
+    }
+    return model
+}
+
+func normalizeModelForGemini(model string) string {
+    if idx := strings.Index(model, "/"); idx != -1 {
+        prefix := strings.ToLower(model[:idx])
+        if prefix == "gemini" || prefix == "google" {
+            return model[idx+1:]
+        }
+    }
+    return model
+}
+
+func buildGeminiRequest(messages []Message, tools []ToolDefinition, options map[string]interface{}) map[string]interface{} {
+    requestBody := map[string]interface{}{}
+    contents := make([]map[string]interface{}, 0, len(messages))
+    systemParts := make([]map[string]interface{}, 0, 1)
+    callNameByID := map[string]string{}
+
+    for _, msg := range messages {
+        switch msg.Role {
+        case "system":
+            if msg.Content != "" {
+                systemParts = append(systemParts, map[string]interface{}{"text": msg.Content})
+            }
+
+        case "user":
+            if msg.ToolCallID != "" {
+                name := callNameByID[msg.ToolCallID]
+                if name == "" {
+                    name = "tool_result"
+                }
+                contents = append(contents, map[string]interface{}{
+                    "role": "user",
+                    "parts": []map[string]interface{}{
+                        {
+                            "functionResponse": map[string]interface{}{
+                                "name": name,
+                                "response": map[string]interface{}{
+                                    "content": msg.Content,
+                                },
+                            },
+                        },
+                    },
+                })
+                continue
+            }
+            if msg.Content != "" {
+                contents = append(contents, map[string]interface{}{
+                    "role": "user",
+                    "parts": []map[string]interface{}{
+                        {"text": msg.Content},
+                    },
+                })
+            }
+
+        case "assistant":
+            parts := make([]map[string]interface{}, 0, 1+len(msg.ToolCalls))
+            if msg.Content != "" {
+                parts = append(parts, map[string]interface{}{"text": msg.Content})
+            }
+
+            for _, tc := range msg.ToolCalls {
+                name := tc.Name
+                if name == "" && tc.Function != nil {
+                    name = tc.Function.Name
+                }
+                if name == "" {
+                    continue
+                }
+
+                args := tc.Arguments
+                if args == nil {
+                    args = map[string]interface{}{}
+                }
+                if tc.ID != "" {
+                    callNameByID[tc.ID] = name
+                }
+
+                parts = append(parts, map[string]interface{}{
+                    "functionCall": map[string]interface{}{
+                        "name": name,
+                        "args": args,
+                    },
+                })
+            }
+
+            if len(parts) > 0 {
+                contents = append(contents, map[string]interface{}{
+                    "role":  "model",
+                    "parts": parts,
+                })
+            }
+
+        case "tool":
+            name := callNameByID[msg.ToolCallID]
+            if name == "" {
+                name = "tool"
+            }
+            contents = append(contents, map[string]interface{}{
+                "role": "user",
+                "parts": []map[string]interface{}{
+                    {
+                        "functionResponse": map[string]interface{}{
+                            "name": name,
+                            "response": map[string]interface{}{
+                                "content": msg.Content,
+                            },
+                        },
+                    },
+                },
+            })
+        }
+    }
+
+    if len(contents) > 0 {
+        requestBody["contents"] = contents
+    }
+
+    if len(systemParts) > 0 {
+        requestBody["system_instruction"] = map[string]interface{}{
+            "parts": systemParts,
+        }
+    }
+
+    if len(tools) > 0 {
+        declarations := make([]map[string]interface{}, 0, len(tools))
+        for _, t := range tools {
+            declarations = append(declarations, map[string]interface{}{
+                "name":        t.Function.Name,
"description": t.Function.Description, + "parameters": t.Function.Parameters, + }) + } + requestBody["tools"] = []map[string]interface{}{ + {"functionDeclarations": declarations}, + } + requestBody["toolConfig"] = map[string]interface{}{ + "functionCallingConfig": map[string]interface{}{ + "mode": "AUTO", + }, + } + } + + generationConfig := map[string]interface{}{} + if maxTokens, ok := options["max_tokens"].(int); ok { + generationConfig["maxOutputTokens"] = maxTokens + } + if temperature, ok := options["temperature"].(float64); ok { + generationConfig["temperature"] = temperature + } + if len(generationConfig) > 0 { + requestBody["generationConfig"] = generationConfig + } + + return requestBody +} + +func parseGeminiResponse(body []byte) (*LLMResponse, error) { + var resp struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text"` + FunctionCall *struct { + Name string `json:"name"` + Args map[string]interface{} `json:"args"` + ID string `json:"id"` + } `json:"functionCall"` + } `json:"parts"` + } `json:"content"` + FinishReason string `json:"finishReason"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` + } + + if err := json.Unmarshal(body, &resp); err != nil { + return nil, fmt.Errorf("failed to unmarshal gemini response: %w", err) + } + + if len(resp.Candidates) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + }, nil + } + + candidate := resp.Candidates[0] + var content strings.Builder + toolCalls := make([]ToolCall, 0) + toolCallCount := 0 + for _, part := range candidate.Content.Parts { + if part.Text != "" { + content.WriteString(part.Text) + } + if part.FunctionCall != nil && part.FunctionCall.Name != "" { + toolCallCount++ + callID := part.FunctionCall.ID + if callID == "" { + callID = fmt.Sprintf("gemini_call_%d", toolCallCount) + } + args := part.FunctionCall.Args + if args == nil { + args = map[string]interface{}{} + } + toolCalls = append(toolCalls, ToolCall{ + ID: callID, + Name: part.FunctionCall.Name, + Arguments: args, + }) + } + } + + finishReason := mapGeminiFinishReason(candidate.FinishReason) + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + + usage := &UsageInfo{ + PromptTokens: resp.UsageMetadata.PromptTokenCount, + CompletionTokens: resp.UsageMetadata.CandidatesTokenCount, + TotalTokens: resp.UsageMetadata.TotalTokenCount, + } + if usage.PromptTokens == 0 && usage.CompletionTokens == 0 && usage.TotalTokens == 0 { + usage = nil + } + + return &LLMResponse{ + Content: content.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} + +func mapGeminiFinishReason(reason string) string { + switch strings.ToUpper(reason) { + case "MAX_TOKENS": + return "length" + case "STOP", "": + return "stop" + case "SAFETY", "PROHIBITED_CONTENT", "RECITATION": + return "content_filter" + default: + return "stop" + } +} + +func stripTemperature(options map[string]interface{}) map[string]interface{} { + if options == nil { + return nil + } + if _, ok := options["temperature"]; !ok { + return options + } + cleaned := make(map[string]interface{}, len(options)-1) + for k, v := range options { + if k == "temperature" { + continue + } + cleaned[k] = v + } + return cleaned +} + +func shouldPreferResponses(model, apiMode string) bool { + lowerMode := strings.ToLower(apiMode) + switch lowerMode { + case 
"openai-responses", "responses", "response": + return true + case "openai-completions", "chat-completions", "completions": + return false + } + + lower := strings.ToLower(model) + return strings.Contains(lower, "gpt-5") || strings.Contains(lower, "codex") || strings.Contains(lower, "o1") +} + +func isGoogleGenerativeAI(apiMode string) bool { + switch strings.ToLower(apiMode) { + case "google-generative-ai", "google", "gemini": + return true + default: + return false + } +} + +func shouldFallbackFromResponses(err error) bool { + var httpErr *httpProviderError + if errors.As(err, &httpErr) { + return isEndpointUnsupported(httpErr.statusCode) + } + return false +} + +func shouldFallbackFromCompletions(err error) bool { + var httpErr *httpProviderError + if errors.As(err, &httpErr) { + return isEndpointUnsupported(httpErr.statusCode) + } + return false +} + +func isEndpointUnsupported(statusCode int) bool { + switch statusCode { + case http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotImplemented, http.StatusGone: + return true + default: + return false + } +} + func createClaudeAuthProvider() (LLMProvider, error) { cred, err := auth.GetCredential("anthropic") if err != nil { @@ -222,7 +699,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { model := cfg.Agents.Defaults.Model providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - var apiKey, apiBase, proxy string + var apiKey, apiBase, proxy, apiMode string + var headers map[string]string lowerModel := strings.ToLower(model) @@ -233,6 +711,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { if cfg.Providers.Groq.APIKey != "" { apiKey = cfg.Providers.Groq.APIKey apiBase = cfg.Providers.Groq.APIBase + apiMode = cfg.Providers.Groq.API + headers = cfg.Providers.Groq.Headers if apiBase == "" { apiBase = "https://api.groq.com/openai/v1" } @@ -244,6 +724,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase + apiMode = cfg.Providers.OpenAI.API + headers = cfg.Providers.OpenAI.Headers if apiBase == "" { apiBase = "https://api.openai.com/v1" } @@ -255,6 +737,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { } apiKey = cfg.Providers.Anthropic.APIKey apiBase = cfg.Providers.Anthropic.APIBase + apiMode = cfg.Providers.Anthropic.API + headers = cfg.Providers.Anthropic.Headers if apiBase == "" { apiBase = "https://api.anthropic.com/v1" } @@ -262,6 +746,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case "openrouter": if cfg.Providers.OpenRouter.APIKey != "" { apiKey = cfg.Providers.OpenRouter.APIKey + apiMode = cfg.Providers.OpenRouter.API + headers = cfg.Providers.OpenRouter.Headers if cfg.Providers.OpenRouter.APIBase != "" { apiBase = cfg.Providers.OpenRouter.APIBase } else { @@ -272,6 +758,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { if cfg.Providers.Zhipu.APIKey != "" { apiKey = cfg.Providers.Zhipu.APIKey apiBase = cfg.Providers.Zhipu.APIBase + apiMode = cfg.Providers.Zhipu.API + headers = cfg.Providers.Zhipu.Headers if apiBase == "" { apiBase = "https://open.bigmodel.cn/api/paas/v4" } @@ -280,6 +768,11 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { if cfg.Providers.Gemini.APIKey != "" { apiKey = cfg.Providers.Gemini.APIKey apiBase = cfg.Providers.Gemini.APIBase + apiMode = cfg.Providers.Gemini.API + if apiMode == "" { + apiMode = "google-generative-ai" + } + headers = cfg.Providers.Gemini.Headers if apiBase == "" { apiBase = 
"https://generativelanguage.googleapis.com/v1beta" } @@ -288,6 +781,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { if cfg.Providers.VLLM.APIBase != "" { apiKey = cfg.Providers.VLLM.APIKey apiBase = cfg.Providers.VLLM.APIBase + apiMode = cfg.Providers.VLLM.API + headers = cfg.Providers.VLLM.Headers } case "claude-cli", "claudecode", "claude-code": workspace := cfg.Agents.Defaults.Workspace @@ -305,6 +800,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = cfg.Providers.Moonshot.APIKey apiBase = cfg.Providers.Moonshot.APIBase proxy = cfg.Providers.Moonshot.Proxy + apiMode = cfg.Providers.Moonshot.API + headers = cfg.Providers.Moonshot.Headers if apiBase == "" { apiBase = "https://api.moonshot.cn/v1" } @@ -312,6 +809,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): apiKey = cfg.Providers.OpenRouter.APIKey proxy = cfg.Providers.OpenRouter.Proxy + apiMode = cfg.Providers.OpenRouter.API + headers = cfg.Providers.OpenRouter.Headers if cfg.Providers.OpenRouter.APIBase != "" { apiBase = cfg.Providers.OpenRouter.APIBase } else { @@ -325,6 +824,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = cfg.Providers.Anthropic.APIKey apiBase = cfg.Providers.Anthropic.APIBase proxy = cfg.Providers.Anthropic.Proxy + apiMode = cfg.Providers.Anthropic.API + headers = cfg.Providers.Anthropic.Headers if apiBase == "" { apiBase = "https://api.anthropic.com/v1" } @@ -336,6 +837,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase proxy = cfg.Providers.OpenAI.Proxy + apiMode = cfg.Providers.OpenAI.API + headers = cfg.Providers.OpenAI.Headers if apiBase == "" { apiBase = "https://api.openai.com/v1" } @@ -344,6 +847,11 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = cfg.Providers.Gemini.APIKey apiBase = cfg.Providers.Gemini.APIBase proxy = cfg.Providers.Gemini.Proxy + apiMode = cfg.Providers.Gemini.API + if apiMode == "" { + apiMode = "google-generative-ai" + } + headers = cfg.Providers.Gemini.Headers if apiBase == "" { apiBase = "https://generativelanguage.googleapis.com/v1beta" } @@ -352,6 +860,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = cfg.Providers.Zhipu.APIKey apiBase = cfg.Providers.Zhipu.APIBase proxy = cfg.Providers.Zhipu.Proxy + apiMode = cfg.Providers.Zhipu.API + headers = cfg.Providers.Zhipu.Headers if apiBase == "" { apiBase = "https://open.bigmodel.cn/api/paas/v4" } @@ -360,6 +870,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = cfg.Providers.Groq.APIKey apiBase = cfg.Providers.Groq.APIBase proxy = cfg.Providers.Groq.Proxy + apiMode = cfg.Providers.Groq.API + headers = cfg.Providers.Groq.Headers if apiBase == "" { apiBase = "https://api.groq.com/openai/v1" } @@ -368,6 +880,8 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = cfg.Providers.Nvidia.APIKey apiBase = cfg.Providers.Nvidia.APIBase proxy = cfg.Providers.Nvidia.Proxy + apiMode = cfg.Providers.Nvidia.API + headers = cfg.Providers.Nvidia.Headers if apiBase == "" { apiBase = "https://integrate.api.nvidia.com/v1" } @@ -376,11 +890,15 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = 
cfg.Providers.VLLM.APIKey apiBase = cfg.Providers.VLLM.APIBase proxy = cfg.Providers.VLLM.Proxy + apiMode = cfg.Providers.VLLM.API + headers = cfg.Providers.VLLM.Headers default: if cfg.Providers.OpenRouter.APIKey != "" { apiKey = cfg.Providers.OpenRouter.APIKey proxy = cfg.Providers.OpenRouter.Proxy + apiMode = cfg.Providers.OpenRouter.API + headers = cfg.Providers.OpenRouter.Headers if cfg.Providers.OpenRouter.APIBase != "" { apiBase = cfg.Providers.OpenRouter.APIBase } else { @@ -400,5 +918,5 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { return nil, fmt.Errorf("no API base configured for provider (model: %s)", model) } - return NewHTTPProvider(apiKey, apiBase, proxy), nil + return NewHTTPProvider(apiKey, apiBase, proxy, apiMode, headers), nil }
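
Usage sketch for reviewers: a minimal caller of the new five-argument constructor, assuming only the names visible in this diff (providers.NewHTTPProvider, Message, LLMResponse, and the Chat signature). The x-goog-api-key header and the gemini-2.0-flash model name are illustrative, not part of this patch; whether a given Gemini endpoint accepts the default Bearer header is outside its scope, which is why the key is passed explicitly through headers here.

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/sipeed/picoclaw/pkg/providers"
)

func main() {
	// apiMode selects the wire protocol: "" keeps the default
	// chat/completions path with /responses fallback, "openai-responses"
	// prefers /responses, and "google-generative-ai" routes Chat() to
	// models/{model}:generateContent.
	// Custom headers are applied first; Content-Type and Authorization
	// are only added when absent, so an entry here can override either.
	p := providers.NewHTTPProvider(
		os.Getenv("GEMINI_API_KEY"),
		"https://generativelanguage.googleapis.com/v1beta",
		"",                     // no proxy
		"google-generative-ai", // api mode
		map[string]string{"x-goog-api-key": os.Getenv("GEMINI_API_KEY")}, // hypothetical header choice
	)

	resp, err := p.Chat(context.Background(),
		[]providers.Message{{Role: "user", Content: "ping"}},
		nil, // no tools
		"gemini/gemini-2.0-flash", // "gemini/" prefix is stripped by normalizeModelForGemini
		map[string]interface{}{"max_tokens": 128},
	)
	if err != nil {
		fmt.Println("chat failed:", err)
		return
	}
	fmt.Println(resp.Content, resp.FinishReason)
}
```

The same constructor with apiMode "" reproduces the pre-patch behavior for OpenAI-compatible backends, since the empty mode falls through to shouldPreferResponses' model-name heuristics and the chat/completions-first path.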