diff --git a/README.md b/README.md index 1831af3e9c..d6a3d56963 100644 --- a/README.md +++ b/README.md @@ -679,6 +679,16 @@ The subagent has access to tools (message, web_search, etc.) and can communicate | `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | +### Provider Architecture + +PicoClaw routes providers by protocol family: + +- OpenAI-compatible protocol: OpenRouter, OpenAI-compatible gateways, Groq, Zhipu, and vLLM-style endpoints. +- Anthropic protocol: Claude-native API behavior. +- Codex/OAuth path: OpenAI OAuth/token authentication route. + +This keeps the runtime lightweight while making new OpenAI-compatible backends mostly a config operation (`api_base` + `api_key`). +
Zhipu diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index be2360aac5..e930d45f44 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -299,6 +299,24 @@ func TestConvertConfig(t *testing.T) { }) } +func TestSupportedProvidersCompatibility(t *testing.T) { + expected := []string{ + "anthropic", + "openai", + "openrouter", + "groq", + "zhipu", + "vllm", + "gemini", + } + + for _, provider := range expected { + if !supportedProviders[provider] { + t.Fatalf("supportedProviders missing expected key %q", provider) + } + } +} + func TestMergeConfig(t *testing.T) { t.Run("fills empty fields", func(t *testing.T) { existing := config.DefaultConfig() diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go new file mode 100644 index 0000000000..8f46aa70cf --- /dev/null +++ b/pkg/providers/anthropic/provider.go @@ -0,0 +1,248 @@ +package anthropicprovider + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strings" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition + +const defaultBaseURL = "https://api.anthropic.com" + +type Provider struct { + client *anthropic.Client + tokenSource func() (string, error) + baseURL string +} + +func NewProvider(token string) *Provider { + return NewProviderWithBaseURL(token, "") +} + +func NewProviderWithBaseURL(token, apiBase string) *Provider { + baseURL := normalizeBaseURL(apiBase) + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL(baseURL), + ) + return &Provider{ + client: &client, + 
baseURL: baseURL, + } +} + +func NewProviderWithClient(client *anthropic.Client) *Provider { + return &Provider{ + client: client, + baseURL: defaultBaseURL, + } +} + +func NewProviderWithTokenSource(token string, tokenSource func() (string, error)) *Provider { + return NewProviderWithTokenSourceAndBaseURL(token, tokenSource, "") +} + +func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *Provider { + p := NewProviderWithBaseURL(token, apiBase) + p.tokenSource = tokenSource + return p +} + +func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) 
+ if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseResponse(resp), nil +} + +func (p *Provider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func (p *Provider) BaseURL() string { + return p.baseURL +} + +func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := 
options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateTools(tools) + } + + return params, nil +} + +func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != nil { + log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err) + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + 
CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} + +func normalizeBaseURL(apiBase string) string { + base := strings.TrimSpace(apiBase) + if base == "" { + return defaultBaseURL + } + + base = strings.TrimRight(base, "/") + if strings.HasSuffix(base, "/v1") { + base = strings.TrimSuffix(base, "/v1") + } + if base == "" { + return defaultBaseURL + } + + return base +} diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go new file mode 100644 index 0000000000..6a1dabafbe --- /dev/null +++ b/pkg/providers/anthropic/provider_test.go @@ -0,0 +1,265 @@ +package anthropicprovider + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if 
params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens 
= %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! 
How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestProvider_GetDefaultModel(t *testing.T) { + p := NewProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func TestProvider_NewProviderWithBaseURL_NormalizesV1Suffix(t *testing.T) { + p := NewProviderWithBaseURL("token", "https://api.anthropic.com/v1/") + if got := p.BaseURL(); got != "https://api.anthropic.com" { + t.Fatalf("BaseURL() = %q, want %q", got, "https://api.anthropic.com") + } +} + +func TestProvider_ChatUsesTokenSource(t *testing.T) { + var requests int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + atomic.AddInt32(&requests, 1) + + if got := r.Header.Get("Authorization"); got != "Bearer refreshed-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "ok"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 1, + "output_tokens": 1, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + p := NewProviderWithTokenSourceAndBaseURL("stale-token", func() (string, error) { + return "refreshed-token", nil + }, server.URL) + + _, err := 
p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if got := atomic.LoadInt32(&requests); got != 1 { + t.Fatalf("requests = %d, want 1", got) + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go index ae6aca96d0..c72f5b0efa 100644 --- a/pkg/providers/claude_provider.go +++ b/pkg/providers/claude_provider.go @@ -2,200 +2,57 @@ package providers import ( "context" - "encoding/json" "fmt" - - "github.com/anthropics/anthropic-sdk-go" - "github.com/anthropics/anthropic-sdk-go/option" - "github.com/sipeed/picoclaw/pkg/auth" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) type ClaudeProvider struct { - client *anthropic.Client - tokenSource func() (string, error) + delegate *anthropicprovider.Provider } func NewClaudeProvider(token string) *ClaudeProvider { - client := anthropic.NewClient( - option.WithAuthToken(token), - option.WithBaseURL("https://api.anthropic.com"), - ) - return &ClaudeProvider{client: &client} -} - -func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { - p := NewClaudeProvider(token) - p.tokenSource = tokenSource - return p -} - -func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - var opts []option.RequestOption - if p.tokenSource != nil { - tok, err := p.tokenSource() - if err != nil { - return nil, fmt.Errorf("refreshing token: %w", err) - } - opts = append(opts, option.WithAuthToken(tok)) + return &ClaudeProvider{ + delegate: anthropicprovider.NewProvider(token), } - - params, err := 
buildClaudeParams(messages, tools, model, options) - if err != nil { - return nil, err - } - - resp, err := p.client.Messages.New(ctx, params, opts...) - if err != nil { - return nil, fmt.Errorf("claude API call: %w", err) - } - - return parseClaudeResponse(resp), nil } -func (p *ClaudeProvider) GetDefaultModel() string { - return "claude-sonnet-4-5-20250929" -} - -func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { - var system []anthropic.TextBlockParam - var anthropicMessages []anthropic.MessageParam - - for _, msg := range messages { - switch msg.Role { - case "system": - system = append(system, anthropic.TextBlockParam{Text: msg.Content}) - case "user": - if msg.ToolCallID != "" { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "assistant": - if len(msg.ToolCalls) > 0 { - var blocks []anthropic.ContentBlockParamUnion - if msg.Content != "" { - blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) - } - for _, tc := range msg.ToolCalls { - blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) - } - anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) - } else { - anthropicMessages = append(anthropicMessages, - anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), - ) - } - case "tool": - anthropicMessages = append(anthropicMessages, - anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), - ) - } - } - - maxTokens := int64(4096) - if mt, ok := options["max_tokens"].(int); ok { - maxTokens = int64(mt) - } - - params := anthropic.MessageNewParams{ - Model: anthropic.Model(model), - Messages: anthropicMessages, - 
MaxTokens: maxTokens, - } - - if len(system) > 0 { - params.System = system - } - - if temp, ok := options["temperature"].(float64); ok { - params.Temperature = anthropic.Float(temp) +func NewClaudeProviderWithBaseURL(token, apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithBaseURL(token, apiBase), } +} - if len(tools) > 0 { - params.Tools = translateToolsForClaude(tools) +func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSource(token, tokenSource), } - - return params, nil } -func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { - result := make([]anthropic.ToolUnionParam, 0, len(tools)) - for _, t := range tools { - tool := anthropic.ToolParam{ - Name: t.Function.Name, - InputSchema: anthropic.ToolInputSchemaParam{ - Properties: t.Function.Parameters["properties"], - }, - } - if desc := t.Function.Description; desc != "" { - tool.Description = anthropic.String(desc) - } - if req, ok := t.Function.Parameters["required"].([]interface{}); ok { - required := make([]string, 0, len(req)) - for _, r := range req { - if s, ok := r.(string); ok { - required = append(required, s) - } - } - tool.InputSchema.Required = required - } - result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) +func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider { + return &ClaudeProvider{ + delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase), } - return result } -func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { - var content string - var toolCalls []ToolCall - - for _, block := range resp.Content { - switch block.Type { - case "text": - tb := block.AsText() - content += tb.Text - case "tool_use": - tu := block.AsToolUse() - var args 
map[string]interface{} - if err := json.Unmarshal(tu.Input, &args); err != nil { - args = map[string]interface{}{"raw": string(tu.Input)} - } - toolCalls = append(toolCalls, ToolCall{ - ID: tu.ID, - Name: tu.Name, - Arguments: args, - }) - } - } +func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *ClaudeProvider { + return &ClaudeProvider{delegate: delegate} +} - finishReason := "stop" - switch resp.StopReason { - case anthropic.StopReasonToolUse: - finishReason = "tool_calls" - case anthropic.StopReasonMaxTokens: - finishReason = "length" - case anthropic.StopReasonEndTurn: - finishReason = "stop" +func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + resp, err := p.delegate.Chat(ctx, messages, tools, model, options) + if err != nil { + return nil, err } + return resp, nil +} - return &LLMResponse{ - Content: content, - ToolCalls: toolCalls, - FinishReason: finishReason, - Usage: &UsageInfo{ - PromptTokens: int(resp.Usage.InputTokens), - CompletionTokens: int(resp.Usage.OutputTokens), - TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), - }, - } +func (p *ClaudeProvider) GetDefaultModel() string { + return p.delegate.GetDefaultModel() } func createClaudeTokenSource() func() (string, error) { return func() (string, error) { - cred, err := auth.GetCredential("anthropic") + cred, err := getCredential("anthropic") if err != nil { return "", fmt.Errorf("loading auth credentials: %w", err) } diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go index bbad2d2692..13bbde1fc1 100644 --- a/pkg/providers/claude_provider_test.go +++ b/pkg/providers/claude_provider_test.go @@ -8,140 +8,9 @@ import ( "github.com/anthropics/anthropic-sdk-go" anthropicoption "github.com/anthropics/anthropic-sdk-go/option" + anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic" ) -func 
TestBuildClaudeParams_BasicMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "Hello"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ - "max_tokens": 1024, - }) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if string(params.Model) != "claude-sonnet-4-5-20250929" { - t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") - } - if params.MaxTokens != 1024 { - t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_SystemMessage(t *testing.T) { - messages := []Message{ - {Role: "system", Content: "You are helpful"}, - {Role: "user", Content: "Hi"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.System) != 1 { - t.Fatalf("len(System) = %d, want 1", len(params.System)) - } - if params.System[0].Text != "You are helpful" { - t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") - } - if len(params.Messages) != 1 { - t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) - } -} - -func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { - messages := []Message{ - {Role: "user", Content: "What's the weather?"}, - { - Role: "assistant", - Content: "", - ToolCalls: []ToolCall{ - { - ID: "call_1", - Name: "get_weather", - Arguments: map[string]interface{}{"city": "SF"}, - }, - }, - }, - {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, - } - params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Messages) != 3 { - t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) - } -} 
- -func TestBuildClaudeParams_WithTools(t *testing.T) { - tools := []ToolDefinition{ - { - Type: "function", - Function: ToolFunctionDefinition{ - Name: "get_weather", - Description: "Get weather for a city", - Parameters: map[string]interface{}{ - "type": "object", - "properties": map[string]interface{}{ - "city": map[string]interface{}{"type": "string"}, - }, - "required": []interface{}{"city"}, - }, - }, - }, - } - params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) - if err != nil { - t.Fatalf("buildClaudeParams() error: %v", err) - } - if len(params.Tools) != 1 { - t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) - } -} - -func TestParseClaudeResponse_TextOnly(t *testing.T) { - resp := &anthropic.Message{ - Content: []anthropic.ContentBlockUnion{}, - Usage: anthropic.Usage{ - InputTokens: 10, - OutputTokens: 20, - }, - } - result := parseClaudeResponse(resp) - if result.Usage.PromptTokens != 10 { - t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) - } - if result.Usage.CompletionTokens != 20 { - t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) - } - if result.FinishReason != "stop" { - t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") - } -} - -func TestParseClaudeResponse_StopReasons(t *testing.T) { - tests := []struct { - stopReason anthropic.StopReason - want string - }{ - {anthropic.StopReasonEndTurn, "stop"}, - {anthropic.StopReasonMaxTokens, "length"}, - {anthropic.StopReasonToolUse, "tool_calls"}, - } - for _, tt := range tests { - resp := &anthropic.Message{ - StopReason: tt.stopReason, - } - result := parseClaudeResponse(resp) - if result.FinishReason != tt.want { - t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) - } - } -} - func TestClaudeProvider_ChatRoundTrip(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { if r.URL.Path != "/v1/messages" { @@ -175,8 +44,8 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) { })) defer server.Close() - provider := NewClaudeProvider("test-token") - provider.client = createAnthropicTestClient(server.URL, "test-token") + delegate := anthropicprovider.NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token")) + provider := newClaudeProviderWithDelegate(delegate) messages := []Message{{Role: "user", Content: "Hello"}} resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go new file mode 100644 index 0000000000..e39cfe32b8 --- /dev/null +++ b/pkg/providers/factory.go @@ -0,0 +1,360 @@ +package providers + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +const defaultAnthropicAPIBase = "https://api.anthropic.com/v1" + +var getCredential = auth.GetCredential + +type providerType int + +const ( + providerTypeHTTPCompat providerType = iota + providerTypeClaudeAuth + providerTypeCodexAuth + providerTypeCodexCLIToken + providerTypeClaudeCLI + providerTypeCodexCLI + providerTypeGitHubCopilot +) + +type providerSelection struct { + providerType providerType + apiKey string + apiBase string + proxy string + model string + workspace string + connectMode string + enableWebSearch bool +} + +func createClaudeAuthProvider(apiBase string) (LLMProvider, error) { + if apiBase == "" { + apiBase = defaultAnthropicAPIBase + } + cred, err := getCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. 
Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSourceAndBaseURL(cred.AccessToken, createClaudeTokenSource(), apiBase), nil +} + +func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) { + cred, err := getCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()) + p.enableWebSearch = enableWebSearch + return p, nil +} + +func resolveProviderSelection(cfg *config.Config) (providerSelection, error) { + model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) + lowerModel := strings.ToLower(model) + + sel := providerSelection{ + providerType: providerTypeHTTPCompat, + model: model, + } + + // First, prefer explicit provider configuration. 
+ if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch + if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { + sel.providerType = providerTypeCodexCLIToken + return sel, nil + } + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + sel.apiKey = cfg.Providers.Zhipu.APIKey + 
sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy + } + case "shengsuanyun": + if cfg.Providers.ShengSuanYun.APIKey != "" { + sel.apiKey = cfg.Providers.ShengSuanYun.APIKey + sel.apiBase = cfg.Providers.ShengSuanYun.APIBase + sel.proxy = cfg.Providers.ShengSuanYun.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://router.shengsuanyun.com/api/v1" + } + } + case "nvidia": + if cfg.Providers.Nvidia.APIKey != "" { + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + } + case "claude-cli", "claude-code", "claudecode": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." + } + sel.providerType = providerTypeClaudeCLI + sel.workspace = workspace + return sel, nil + case "codex-cli", "codex-code": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." 
+ } + sel.providerType = providerTypeCodexCLI + sel.workspace = workspace + return sel, nil + case "deepseek": + if cfg.Providers.DeepSeek.APIKey != "" { + sel.apiKey = cfg.Providers.DeepSeek.APIKey + sel.apiBase = cfg.Providers.DeepSeek.APIBase + sel.proxy = cfg.Providers.DeepSeek.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.deepseek.com/v1" + } + if model != "deepseek-chat" && model != "deepseek-reasoner" { + sel.model = "deepseek-chat" + } + } + case "github_copilot", "copilot": + sel.providerType = providerTypeGitHubCopilot + if cfg.Providers.GitHubCopilot.APIBase != "" { + sel.apiBase = cfg.Providers.GitHubCopilot.APIBase + } else { + sel.apiBase = "localhost:4321" + } + sel.connectMode = cfg.Providers.GitHubCopilot.ConnectMode + return sel, nil + } + } + + // Fallback: infer provider from model and configured keys. + if sel.apiKey == "" && sel.apiBase == "" { + switch { + case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": + sel.apiKey = cfg.Providers.Moonshot.APIKey + sel.apiBase = cfg.Providers.Moonshot.APIBase + sel.proxy = cfg.Providers.Moonshot.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.moonshot.cn/v1" + } + case strings.HasPrefix(model, "openrouter/") || + strings.HasPrefix(model, "anthropic/") || + strings.HasPrefix(model, "openai/") || + strings.HasPrefix(model, "meta-llama/") || + strings.HasPrefix(model, "deepseek/") || + strings.HasPrefix(model, "google/"): + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && + (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if 
cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + sel.apiBase = cfg.Providers.Anthropic.APIBase + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + sel.providerType = providerTypeClaudeAuth + return sel, nil + } + sel.apiKey = cfg.Providers.Anthropic.APIKey + sel.apiBase = cfg.Providers.Anthropic.APIBase + sel.proxy = cfg.Providers.Anthropic.Proxy + if sel.apiBase == "" { + sel.apiBase = defaultAnthropicAPIBase + } + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && + (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + sel.enableWebSearch = cfg.Providers.OpenAI.WebSearch + if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { + sel.providerType = providerTypeCodexCLIToken + return sel, nil + } + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + sel.providerType = providerTypeCodexAuth + return sel, nil + } + sel.apiKey = cfg.Providers.OpenAI.APIKey + sel.apiBase = cfg.Providers.OpenAI.APIBase + sel.proxy = cfg.Providers.OpenAI.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.openai.com/v1" + } + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": + sel.apiKey = cfg.Providers.Gemini.APIKey + sel.apiBase = cfg.Providers.Gemini.APIBase + sel.proxy = cfg.Providers.Gemini.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + sel.apiKey = cfg.Providers.Zhipu.APIKey + sel.apiBase = cfg.Providers.Zhipu.APIBase + sel.proxy = cfg.Providers.Zhipu.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) 
&& cfg.Providers.Groq.APIKey != "": + sel.apiKey = cfg.Providers.Groq.APIKey + sel.apiBase = cfg.Providers.Groq.APIBase + sel.proxy = cfg.Providers.Groq.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://api.groq.com/openai/v1" + } + case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": + sel.apiKey = cfg.Providers.Nvidia.APIKey + sel.apiBase = cfg.Providers.Nvidia.APIBase + sel.proxy = cfg.Providers.Nvidia.Proxy + if sel.apiBase == "" { + sel.apiBase = "https://integrate.api.nvidia.com/v1" + } + case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": + sel.apiKey = cfg.Providers.Ollama.APIKey + sel.apiBase = cfg.Providers.Ollama.APIBase + sel.proxy = cfg.Providers.Ollama.Proxy + if sel.apiBase == "" { + sel.apiBase = "http://localhost:11434/v1" + } + case cfg.Providers.VLLM.APIBase != "": + sel.apiKey = cfg.Providers.VLLM.APIKey + sel.apiBase = cfg.Providers.VLLM.APIBase + sel.proxy = cfg.Providers.VLLM.Proxy + default: + if cfg.Providers.OpenRouter.APIKey != "" { + sel.apiKey = cfg.Providers.OpenRouter.APIKey + sel.proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + sel.apiBase = cfg.Providers.OpenRouter.APIBase + } else { + sel.apiBase = "https://openrouter.ai/api/v1" + } + } else { + return providerSelection{}, fmt.Errorf("no API key configured for model: %s", model) + } + } + } + + if sel.providerType == providerTypeHTTPCompat { + if sel.apiKey == "" && !strings.HasPrefix(model, "bedrock/") { + return providerSelection{}, fmt.Errorf("no API key configured for provider (model: %s)", model) + } + if sel.apiBase == "" { + return providerSelection{}, fmt.Errorf("no API base configured for provider (model: %s)", model) + } + } + + return sel, nil +} + +func CreateProvider(cfg *config.Config) (LLMProvider, error) { + sel, err := resolveProviderSelection(cfg) + if err != nil { + return nil, err 
+ } + + switch sel.providerType { + case providerTypeClaudeAuth: + return createClaudeAuthProvider(sel.apiBase) + case providerTypeCodexAuth: + return createCodexAuthProvider(sel.enableWebSearch) + case providerTypeCodexCLIToken: + c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()) + c.enableWebSearch = sel.enableWebSearch + return c, nil + case providerTypeClaudeCLI: + return NewClaudeCliProvider(sel.workspace), nil + case providerTypeCodexCLI: + return NewCodexCliProvider(sel.workspace), nil + case providerTypeGitHubCopilot: + return NewGitHubCopilotProvider(sel.apiBase, sel.connectMode, sel.model) + default: + return NewHTTPProvider(sel.apiKey, sel.apiBase, sel.proxy), nil + } +} diff --git a/pkg/providers/factory_test.go b/pkg/providers/factory_test.go new file mode 100644 index 0000000000..e31737eb97 --- /dev/null +++ b/pkg/providers/factory_test.go @@ -0,0 +1,299 @@ +package providers + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestResolveProviderSelection(t *testing.T) { + tests := []struct { + name string + setup func(*config.Config) + wantType providerType + wantAPIBase string + wantProxy string + wantErrSubstr string + }{ + { + name: "explicit claude-cli provider routes to cli provider type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "claude-cli" + cfg.Agents.Defaults.Workspace = "/tmp/ws" + }, + wantType: providerTypeClaudeCLI, + }, + { + name: "explicit copilot provider routes to github copilot type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "copilot" + }, + wantType: providerTypeGitHubCopilot, + wantAPIBase: "localhost:4321", + }, + { + name: "explicit deepseek provider uses deepseek defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "deepseek" + cfg.Agents.Defaults.Model = "deepseek/deepseek-chat" + cfg.Providers.DeepSeek.APIKey = "deepseek-key" + 
cfg.Providers.DeepSeek.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.deepseek.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit shengsuanyun provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "shengsuanyun" + cfg.Providers.ShengSuanYun.APIKey = "ssy-key" + cfg.Providers.ShengSuanYun.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://router.shengsuanyun.com/api/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "explicit nvidia provider uses defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "nvidia" + cfg.Providers.Nvidia.APIKey = "nvapi-test" + cfg.Providers.Nvidia.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://integrate.api.nvidia.com/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "openrouter model uses openrouter defaults", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://openrouter.ai/api/v1", + }, + { + name: "anthropic oauth routes to claude auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929" + cfg.Providers.Anthropic.AuthMethod = "oauth" + }, + wantType: providerTypeClaudeAuth, + }, + { + name: "openai oauth routes to codex auth provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "gpt-4o" + cfg.Providers.OpenAI.AuthMethod = "oauth" + }, + wantType: providerTypeCodexAuth, + }, + { + name: "openai codex-cli auth routes to codex cli token provider", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "gpt-4o" + cfg.Providers.OpenAI.AuthMethod = "codex-cli" + }, + wantType: providerTypeCodexCLIToken, + }, + { + name: "explicit codex-code provider routes to 
codex cli provider type", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Provider = "codex-code" + cfg.Agents.Defaults.Workspace = "/tmp/ws" + }, + wantType: providerTypeCodexCLI, + }, + { + name: "zhipu model uses zhipu base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "glm-4.7" + cfg.Providers.Zhipu.APIKey = "zhipu-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://open.bigmodel.cn/api/paas/v4", + }, + { + name: "groq model uses groq base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "groq/llama-3.3-70b" + cfg.Providers.Groq.APIKey = "gsk-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.groq.com/openai/v1", + }, + { + name: "ollama model uses ollama base default", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "ollama/qwen2.5:14b" + cfg.Providers.Ollama.APIKey = "ollama-key" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "http://localhost:11434/v1", + }, + { + name: "moonshot model keeps proxy and default base", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "moonshot/kimi-k2.5" + cfg.Providers.Moonshot.APIKey = "moonshot-key" + cfg.Providers.Moonshot.Proxy = "http://127.0.0.1:7890" + }, + wantType: providerTypeHTTPCompat, + wantAPIBase: "https://api.moonshot.cn/v1", + wantProxy: "http://127.0.0.1:7890", + }, + { + name: "missing keys returns model config error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "custom-model" + }, + wantErrSubstr: "no API key configured for model", + }, + { + name: "openrouter prefix without key returns provider key error", + setup: func(cfg *config.Config) { + cfg.Agents.Defaults.Model = "openrouter/auto" + }, + wantErrSubstr: "no API key configured for provider", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.DefaultConfig() + tt.setup(cfg) + + got, err := resolveProviderSelection(cfg) + if 
tt.wantErrSubstr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr) + } + if !strings.Contains(err.Error(), tt.wantErrSubstr) { + t.Fatalf("error = %q, want substring %q", err.Error(), tt.wantErrSubstr) + } + return + } + + if err != nil { + t.Fatalf("resolveProviderSelection() error = %v", err) + } + if got.providerType != tt.wantType { + t.Fatalf("providerType = %v, want %v", got.providerType, tt.wantType) + } + if tt.wantAPIBase != "" && got.apiBase != tt.wantAPIBase { + t.Fatalf("apiBase = %q, want %q", got.apiBase, tt.wantAPIBase) + } + if tt.wantProxy != "" && got.proxy != tt.wantProxy { + t.Fatalf("proxy = %q, want %q", got.proxy, tt.wantProxy) + } + }) + } +} + +func TestCreateProviderReturnsHTTPProviderForOpenRouter(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Model = "openrouter/auto" + cfg.Providers.OpenRouter.APIKey = "sk-or-test" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*HTTPProvider); !ok { + t.Fatalf("provider type = %T, want *HTTPProvider", provider) + } +} + +func TestCreateProviderReturnsCodexCliProviderForCodexCode(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "codex-code" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexCliProvider); !ok { + t.Fatalf("provider type = %T, want *CodexCliProvider", provider) + } +} + +func TestCreateProviderReturnsCodexProviderForCodexCliAuthMethod(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "openai" + cfg.Providers.OpenAI.AuthMethod = "codex-cli" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexProvider); !ok { + t.Fatalf("provider type = %T, want *CodexProvider", provider) + } +} + +func 
TestCreateProviderReturnsClaudeProviderForAnthropicOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "anthropic" { + t.Fatalf("provider = %q, want anthropic", provider) + } + return &auth.AuthCredential{ + AccessToken: "anthropic-token", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "anthropic" + cfg.Providers.Anthropic.AuthMethod = "oauth" + cfg.Providers.Anthropic.APIBase = "https://proxy.example.com/v1" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + claudeProvider, ok := provider.(*ClaudeProvider) + if !ok { + t.Fatalf("provider type = %T, want *ClaudeProvider", provider) + } + if got := claudeProvider.delegate.BaseURL(); got != "https://proxy.example.com" { + t.Fatalf("anthropic baseURL = %q, want %q", got, "https://proxy.example.com") + } +} + +func TestCreateProviderReturnsCodexProviderForOpenAIOAuth(t *testing.T) { + originalGetCredential := getCredential + t.Cleanup(func() { getCredential = originalGetCredential }) + + getCredential = func(provider string) (*auth.AuthCredential, error) { + if provider != "openai" { + t.Fatalf("provider = %q, want openai", provider) + } + return &auth.AuthCredential{ + AccessToken: "openai-token", + AccountID: "acct_123", + }, nil + } + + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "openai" + cfg.Providers.OpenAI.AuthMethod = "oauth" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider() error = %v", err) + } + + if _, ok := provider.(*CodexProvider); !ok { + t.Fatalf("provider type = %T, want *CodexProvider", provider) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 946aa29d22..e39a19e90e 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go 
@@ -7,448 +7,24 @@ package providers import ( - "bytes" "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "time" - - "github.com/sipeed/picoclaw/pkg/auth" - "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers/openai_compat" ) type HTTPProvider struct { - apiKey string - apiBase string - httpClient *http.Client + delegate *openai_compat.Provider } func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { - client := &http.Client{ - Timeout: 120 * time.Second, - } - - if proxy != "" { - proxyURL, err := url.Parse(proxy) - if err == nil { - client.Transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - } - } - return &HTTPProvider{ - apiKey: apiKey, - apiBase: strings.TrimRight(apiBase, "/"), - httpClient: client, + delegate: openai_compat.NewProvider(apiKey, apiBase, proxy), } } func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - if p.apiBase == "" { - return nil, fmt.Errorf("API base not configured") - } - - // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b) - if idx := strings.Index(model, "/"); idx != -1 { - prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" { - model = model[idx+1:] - } - } - - requestBody := map[string]interface{}{ - "model": model, - "messages": messages, - } - - if len(tools) > 0 { - requestBody["tools"] = tools - requestBody["tool_choice"] = "auto" - } - - if maxTokens, ok := options["max_tokens"].(int); ok { - lowerModel := strings.ToLower(model) - if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { - requestBody["max_completion_tokens"] = maxTokens - } else { - requestBody["max_tokens"] = maxTokens - } - } - - if temperature, ok := 
options["temperature"].(float64); ok { - lowerModel := strings.ToLower(model) - // Kimi k2 models only support temperature=1 - if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { - requestBody["temperature"] = 1.0 - } else { - requestBody["temperature"] = temperature - } - } - - jsonData, err := json.Marshal(requestBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - req.Header.Set("Authorization", "Bearer "+p.apiKey) - } - - resp, err := p.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) - } - - return p.parseResponse(body) -} - -func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) { - var apiResponse struct { - Choices []struct { - Message struct { - Content string `json:"content"` - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function *struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } `json:"message"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` - Usage *UsageInfo `json:"usage"` - } - - if err := json.Unmarshal(body, &apiResponse); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - if len(apiResponse.Choices) == 0 { - return &LLMResponse{ - Content: "", - FinishReason: "stop", - }, nil - } - - choice := 
apiResponse.Choices[0] - - toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) - for _, tc := range choice.Message.ToolCalls { - arguments := make(map[string]interface{}) - name := "" - - // Handle OpenAI format with nested function object - if tc.Type == "function" && tc.Function != nil { - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } else if tc.Function != nil { - // Legacy format without type field - name = tc.Function.Name - if tc.Function.Arguments != "" { - if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { - arguments["raw"] = tc.Function.Arguments - } - } - } - - toolCalls = append(toolCalls, ToolCall{ - ID: tc.ID, - Name: name, - Arguments: arguments, - }) - } - - return &LLMResponse{ - Content: choice.Message.Content, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, - }, nil + return p.delegate.Chat(ctx, messages, tools, model, options) } func (p *HTTPProvider) GetDefaultModel() string { return "" } - -func createClaudeAuthProvider() (LLMProvider, error) { - cred, err := auth.GetCredential("anthropic") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") - } - return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil -} - -func createCodexAuthProvider(enableWebSearch bool) (LLMProvider, error) { - cred, err := auth.GetCredential("openai") - if err != nil { - return nil, fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return nil, fmt.Errorf("no credentials for openai. 
Run: picoclaw auth login --provider openai") - } - p := NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()) - p.enableWebSearch = enableWebSearch - return p, nil -} - -func CreateProvider(cfg *config.Config) (LLMProvider, error) { - model := cfg.Agents.Defaults.Model - providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - - var apiKey, apiBase, proxy string - - lowerModel := strings.ToLower(model) - - // First, try to use explicitly configured provider - if providerName != "" { - switch providerName { - case "groq": - if cfg.Providers.Groq.APIKey != "" { - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - } - case "openai", "gpt": - if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { - if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { - c := NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()) - c.enableWebSearch = cfg.Providers.OpenAI.WebSearch - return c, nil - } - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider(cfg.Providers.OpenAI.WebSearch) - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - } - case "anthropic", "claude": - if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - } - case "openrouter": - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = 
cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } - case "zhipu", "glm": - if cfg.Providers.Zhipu.APIKey != "" { - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - } - case "gemini", "google": - if cfg.Providers.Gemini.APIKey != "" { - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - } - case "vllm": - if cfg.Providers.VLLM.APIBase != "" { - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - } - case "shengsuanyun": - if cfg.Providers.ShengSuanYun.APIKey != "" { - apiKey = cfg.Providers.ShengSuanYun.APIKey - apiBase = cfg.Providers.ShengSuanYun.APIBase - if apiBase == "" { - apiBase = "https://router.shengsuanyun.com/api/v1" - } - } - case "claude-cli", "claudecode", "claude-code": - workspace := cfg.WorkspacePath() - if workspace == "" { - workspace = "." - } - return NewClaudeCliProvider(workspace), nil - case "codex-cli", "codex-code": - workspace := cfg.WorkspacePath() - if workspace == "" { - workspace = "." 
- } - return NewCodexCliProvider(workspace), nil - case "deepseek": - if cfg.Providers.DeepSeek.APIKey != "" { - apiKey = cfg.Providers.DeepSeek.APIKey - apiBase = cfg.Providers.DeepSeek.APIBase - if apiBase == "" { - apiBase = "https://api.deepseek.com/v1" - } - if model != "deepseek-chat" && model != "deepseek-reasoner" { - model = "deepseek-chat" - } - } - case "github_copilot", "copilot": - if cfg.Providers.GitHubCopilot.APIBase != "" { - apiBase = cfg.Providers.GitHubCopilot.APIBase - } else { - apiBase = "localhost:4321" - } - return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model) - - } - - } - - // Fallback: detect provider from model name - if apiKey == "" && apiBase == "" { - switch { - case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": - apiKey = cfg.Providers.Moonshot.APIKey - apiBase = cfg.Providers.Moonshot.APIBase - proxy = cfg.Providers.Moonshot.Proxy - if apiBase == "" { - apiBase = "https://api.moonshot.cn/v1" - } - - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - proxy 
= cfg.Providers.Anthropic.Proxy - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } - - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider(cfg.Providers.OpenAI.WebSearch) - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - proxy = cfg.Providers.OpenAI.Proxy - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - proxy = cfg.Providers.Gemini.Proxy - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - proxy = cfg.Providers.Zhipu.Proxy - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - proxy = cfg.Providers.Groq.Proxy - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - - case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": - apiKey = cfg.Providers.Nvidia.APIKey - apiBase = cfg.Providers.Nvidia.APIBase - proxy = cfg.Providers.Nvidia.Proxy - if apiBase == "" { - apiBase = "https://integrate.api.nvidia.com/v1" - } - case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, 
"ollama/")) && cfg.Providers.Ollama.APIKey != "": - fmt.Println("Ollama provider selected based on model name prefix") - apiKey = cfg.Providers.Ollama.APIKey - apiBase = cfg.Providers.Ollama.APIBase - proxy = cfg.Providers.Ollama.Proxy - if apiBase == "" { - apiBase = "http://localhost:11434/v1" - } - fmt.Println("Ollama apiBase:", apiBase) - case cfg.Providers.VLLM.APIBase != "": - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - proxy = cfg.Providers.VLLM.Proxy - - default: - if cfg.Providers.OpenRouter.APIKey != "" { - apiKey = cfg.Providers.OpenRouter.APIKey - proxy = cfg.Providers.OpenRouter.Proxy - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" - } - } else { - return nil, fmt.Errorf("no API key configured for model: %s", model) - } - } - } - - if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { - return nil, fmt.Errorf("no API key configured for provider (model: %s)", model) - } - - if apiBase == "" { - return nil, fmt.Errorf("no API base configured for provider (model: %s)", model) - } - - return NewHTTPProvider(apiKey, apiBase, proxy), nil -} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go new file mode 100644 index 0000000000..9b404dd773 --- /dev/null +++ b/pkg/providers/openai_compat/provider.go @@ -0,0 +1,232 @@ +package openai_compat + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" +) + +type ToolCall = protocoltypes.ToolCall +type FunctionCall = protocoltypes.FunctionCall +type LLMResponse = protocoltypes.LLMResponse +type UsageInfo = protocoltypes.UsageInfo +type Message = protocoltypes.Message +type ToolDefinition = protocoltypes.ToolDefinition +type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition + +type Provider struct { + apiKey 
string + apiBase string + httpClient *http.Client +} + +func NewProvider(apiKey, apiBase, proxy string) *Provider { + client := &http.Client{ + Timeout: 120 * time.Second, + } + + if proxy != "" { + parsed, err := url.Parse(proxy) + if err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(parsed), + } + } else { + log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err) + } + } + + return &Provider{ + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + httpClient: client, + } +} + +func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p.apiBase == "" { + return nil, fmt.Errorf("API base not configured") + } + + model = normalizeModel(model, p.apiBase) + + requestBody := map[string]interface{}{ + "model": model, + "messages": messages, + } + + if len(tools) > 0 { + requestBody["tools"] = tools + requestBody["tool_choice"] = "auto" + } + + if maxTokens, ok := asInt(options["max_tokens"]); ok { + lowerModel := strings.ToLower(model) + if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { + requestBody["max_completion_tokens"] = maxTokens + } else { + requestBody["max_tokens"] = maxTokens + } + } + + if temperature, ok := asFloat(options["temperature"]); ok { + lowerModel := strings.ToLower(model) + // Kimi k2 models only support temperature=1. 
+ if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { + requestBody["temperature"] = 1.0 + } else { + requestBody["temperature"] = temperature + } + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if p.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+p.apiKey) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) + } + + return parseResponse(body) +} + +func parseResponse(body []byte) (*LLMResponse, error) { + var apiResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function *struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *UsageInfo `json:"usage"` + } + + if err := json.Unmarshal(body, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + if len(apiResponse.Choices) == 0 { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + }, nil + } + + choice := apiResponse.Choices[0] + toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls)) + for _, tc := range choice.Message.ToolCalls { + arguments 
:= make(map[string]interface{}) + name := "" + + if tc.Function != nil { + name = tc.Function.Name + if tc.Function.Arguments != "" { + if err := json.Unmarshal([]byte(tc.Function.Arguments), &arguments); err != nil { + log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err) + arguments["raw"] = tc.Function.Arguments + } + } + } + + toolCalls = append(toolCalls, ToolCall{ + ID: tc.ID, + Name: name, + Arguments: arguments, + }) + } + + return &LLMResponse{ + Content: choice.Message.Content, + ToolCalls: toolCalls, + FinishReason: choice.FinishReason, + Usage: apiResponse.Usage, + }, nil +} + +func normalizeModel(model, apiBase string) string { + idx := strings.Index(model, "/") + if idx == -1 { + return model + } + + if strings.Contains(strings.ToLower(apiBase), "openrouter.ai") { + return model + } + + prefix := strings.ToLower(model[:idx]) + switch prefix { + case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu": + return model[idx+1:] + default: + return model + } +} + +func asInt(v interface{}) (int, bool) { + switch val := v.(type) { + case int: + return val, true + case int64: + return int(val), true + case float64: + return int(val), true + case float32: + return int(val), true + default: + return 0, false + } +} + +func asFloat(v interface{}) (float64, bool) { + switch val := v.(type) { + case float64: + return val, true + case float32: + return float64(val), true + case int: + return float64(val), true + case int64: + return float64(val), true + default: + return 0, false + } +} diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go new file mode 100644 index 0000000000..94779b39cb --- /dev/null +++ b/pkg/providers/openai_compat/provider_test.go @@ -0,0 +1,277 @@ +package openai_compat + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestProviderChat_UsesMaxCompletionTokensForGLM(t 
// Verifies GLM models get "max_completion_tokens" and never "max_tokens".
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
	var requestBody map[string]interface{}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The provider must hit the standard chat-completions path.
		if r.URL.Path != "/chat/completions" {
			http.Error(w, "not found", http.StatusNotFound)
			return
		}
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp := map[string]interface{}{
			"choices": []map[string]interface{}{
				{
					"message":       map[string]interface{}{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer server.Close()

	p := NewProvider("key", server.URL, "")
	_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if _, ok := requestBody["max_completion_tokens"]; !ok {
		t.Fatalf("expected max_completion_tokens in request body")
	}
	if _, ok := requestBody["max_tokens"]; ok {
		t.Fatalf("did not expect max_tokens key for glm model")
	}
}

// Verifies tool_calls in the response are decoded into ToolCall values,
// including JSON-string arguments and token usage.
func TestProviderChat_ParsesToolCalls(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := map[string]interface{}{
			"choices": []map[string]interface{}{
				{
					"message": map[string]interface{}{
						"content": "",
						"tool_calls": []map[string]interface{}{
							{
								"id":   "call_1",
								"type": "function",
								"function": map[string]interface{}{
									"name":      "get_weather",
									"arguments": "{\"city\":\"SF\"}",
								},
							},
						},
					},
					"finish_reason": "tool_calls",
				},
			},
			"usage": map[string]interface{}{
				"prompt_tokens":     10,
				"completion_tokens": 5,
				"total_tokens":      15,
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer server.Close()

	p := NewProvider("key", server.URL, "")
	out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if len(out.ToolCalls) != 1 {
		t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
	}
	if out.ToolCalls[0].Name != "get_weather" {
		t.Fatalf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
	}
	if out.ToolCalls[0].Arguments["city"] != "SF" {
		t.Fatalf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
	}
}

// A non-200 status must surface as an error to the caller.
func TestProviderChat_HTTPError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Error(w, "bad request", http.StatusBadRequest)
	}))
	defer server.Close()

	p := NewProvider("key", server.URL, "")
	_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-4o", nil)
	if err == nil {
		t.Fatal("expected error, got nil")
	}
}

// Two quirks in one request: the "moonshot/" routing prefix is stripped
// from the model, and the Kimi k2 temperature is forced to 1.0.
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
	var requestBody map[string]interface{}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp := map[string]interface{}{
			"choices": []map[string]interface{}{
				{
					"message":       map[string]interface{}{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer server.Close()

	p := NewProvider("key", server.URL, "")
	_, err := p.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hi"}},
		nil,
		"moonshot/kimi-k2.5",
		map[string]interface{}{"temperature": 0.3},
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if requestBody["model"] != "kimi-k2.5" {
		t.Fatalf("model = %v, want kimi-k2.5", requestBody["model"])
	}
	if requestBody["temperature"] != 1.0 {
		t.Fatalf("temperature = %v, want 1.0", requestBody["temperature"])
	}
}

// Table test for prefix stripping: only the first routing segment is
// removed, so nested vendor prefixes (groq/openai/…) survive.
func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
	tests := []struct {
		name      string
		input     string
		wantModel string
	}{
		{
			name:      "strips groq prefix and keeps nested model",
			input:     "groq/openai/gpt-oss-120b",
			wantModel: "openai/gpt-oss-120b",
		},
		{
			name:      "strips ollama prefix",
			input:     "ollama/qwen2.5:14b",
			wantModel: "qwen2.5:14b",
		},
		{
			name:      "strips deepseek prefix",
			input:     "deepseek/deepseek-chat",
			wantModel: "deepseek-chat",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var requestBody map[string]interface{}

			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
					http.Error(w, err.Error(), http.StatusBadRequest)
					return
				}
				resp := map[string]interface{}{
					"choices": []map[string]interface{}{
						{
							"message":       map[string]interface{}{"content": "ok"},
							"finish_reason": "stop",
						},
					},
				}
				w.Header().Set("Content-Type", "application/json")
				json.NewEncoder(w).Encode(resp)
			}))
			defer server.Close()

			p := NewProvider("key", server.URL, "")
			_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, tt.input, nil)
			if err != nil {
				t.Fatalf("Chat() error = %v", err)
			}

			if requestBody["model"] != tt.wantModel {
				t.Fatalf("model = %v, want %s", requestBody["model"], tt.wantModel)
			}
		})
	}
}

// Verifies a configured proxy is wired into the client's Transport and
// that the proxy function returns it for an arbitrary outbound request.
func TestProvider_ProxyConfigured(t *testing.T) {
	proxyURL := "http://127.0.0.1:8080"
	p := NewProvider("key", "https://example.com", proxyURL)

	transport, ok := p.httpClient.Transport.(*http.Transport)
	if !ok || transport == nil {
		t.Fatalf("expected http transport with proxy, got %T", p.httpClient.Transport)
	}

	req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
	gotProxy, err := transport.Proxy(req)
	if err != nil {
		t.Fatalf("proxy function returned error: %v", err)
	}
	if gotProxy == nil || gotProxy.String() != proxyURL {
		t.Fatalf("proxy = %v, want %s", gotProxy, proxyURL)
	}
}

// Options may arrive as float64 (JSON-decoded) or plain int; both must be
// forwarded numerically intact.
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
	var requestBody map[string]interface{}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		resp := map[string]interface{}{
			"choices": []map[string]interface{}{
				{
					"message":       map[string]interface{}{"content": "ok"},
					"finish_reason": "stop",
				},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(resp)
	}))
	defer server.Close()

	p := NewProvider("key", server.URL, "")
	_, err := p.Chat(
		t.Context(),
		[]Message{{Role: "user", Content: "hi"}},
		nil,
		"gpt-4o",
		map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
	)
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}

	if requestBody["max_tokens"] != float64(512) {
		t.Fatalf("max_tokens = %v, want 512", requestBody["max_tokens"])
	}
	if requestBody["temperature"] != float64(1) {
		t.Fatalf("temperature = %v, want 1", requestBody["temperature"])
	}
}

// normalizeModel strips prefixes for non-OpenRouter bases but leaves the
// full "vendor/model" ID intact when the base is openrouter.ai.
func TestNormalizeModel_UsesAPIBase(t *testing.T) {
	if got := normalizeModel("deepseek/deepseek-chat", "https://api.deepseek.com/v1"); got != "deepseek-chat" {
		t.Fatalf("normalizeModel(deepseek) = %q, want %q", got, "deepseek-chat")
	}
	if got := normalizeModel("openrouter/auto", "https://openrouter.ai/api/v1"); got != "openrouter/auto" {
		t.Fatalf("normalizeModel(openrouter) = %q, want %q", got, "openrouter/auto")
	}
}
// NOTE: these types live in package protocoltypes — shared wire-level
// request/response types used by the protocol-specific provider
// implementations so they need not import each other.

// ToolCall is a single tool invocation requested by the model. The
// OpenAI-style fields (Type/Function, raw JSON arguments) coexist with the
// flattened internal form (Name/Arguments) so one type serves both.
type ToolCall struct {
	ID        string                 `json:"id"`
	Type      string                 `json:"type,omitempty"`
	Function  *FunctionCall          `json:"function,omitempty"`
	Name      string                 `json:"name,omitempty"`
	Arguments map[string]interface{} `json:"arguments,omitempty"`
}

// FunctionCall is the raw OpenAI-style function payload; Arguments is an
// undecoded JSON string.
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// LLMResponse is one completed model turn: text content, any requested
// tool calls, the finish reason, and optional token usage.
type LLMResponse struct {
	Content      string     `json:"content"`
	ToolCalls    []ToolCall `json:"tool_calls,omitempty"`
	FinishReason string     `json:"finish_reason"`
	Usage        *UsageInfo `json:"usage,omitempty"`
}

// UsageInfo mirrors the OpenAI-style token accounting block.
type UsageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Message is one chat turn sent to or received from a model. ToolCallID
// links a "tool" role message back to the call it answers.
type Message struct {
	Role       string     `json:"role"`
	Content    string     `json:"content"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
}

// ToolDefinition declares one callable tool in a request.
type ToolDefinition struct {
	Type     string                 `json:"type"`
	Function ToolFunctionDefinition `json:"function"`
}

// ToolFunctionDefinition describes a tool's name, human-readable purpose,
// and JSON-schema parameters.
type ToolFunctionDefinition struct {
	Name        string                 `json:"name"`
	Description string                 `json:"description"`
	Parameters  map[string]interface{} `json:"parameters"`
}

// ---- pkg/providers/types.go (post-patch) ----
// The providers package re-exports the shared types as aliases so existing
// callers keep compiling after the move to protocoltypes.

type ToolCall = protocoltypes.ToolCall
type FunctionCall = protocoltypes.FunctionCall
type LLMResponse = protocoltypes.LLMResponse
type UsageInfo = protocoltypes.UsageInfo
type Message = protocoltypes.Message
type ToolDefinition = protocoltypes.ToolDefinition
type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition

// LLMProvider is the interface every protocol implementation satisfies.
type LLMProvider interface {
	Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
	GetDefaultModel() string
}