diff --git a/config/config.example.json b/config/config.example.json
index aa75c8338..dd282cc45 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -107,6 +107,10 @@
     "moonshot": {
       "api_key": "sk-xxx",
       "api_base": ""
+    },
+    "mistral": {
+      "api_key": "",
+      "api_base": ""
     }
   },
   "tools": {
diff --git a/pkg/config/config.go b/pkg/config/config.go
index d76ec8095..4ade9602a 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -179,6 +179,7 @@ type ProvidersConfig struct {
 	ShengSuanYun  ProviderConfig `json:"shengsuanyun"`
 	DeepSeek      ProviderConfig `json:"deepseek"`
 	GitHubCopilot ProviderConfig `json:"github_copilot"`
+	Mistral       ProviderConfig `json:"mistral"`
 }
 
 type ProviderConfig struct {
@@ -304,6 +305,7 @@ func DefaultConfig() *Config {
 			Nvidia:       ProviderConfig{},
 			Moonshot:     ProviderConfig{},
 			ShengSuanYun: ProviderConfig{},
+			Mistral:      ProviderConfig{},
 		},
 		Gateway: GatewayConfig{
 			Host: "0.0.0.0",
@@ -405,6 +407,9 @@ func (c *Config) GetAPIKey() string {
 	if c.Providers.ShengSuanYun.APIKey != "" {
 		return c.Providers.ShengSuanYun.APIKey
 	}
+	if c.Providers.Mistral.APIKey != "" {
+		return c.Providers.Mistral.APIKey
+	}
 	return ""
 }
 
@@ -423,6 +428,12 @@ func (c *Config) GetAPIBase() string {
 	if c.Providers.VLLM.APIKey != "" && c.Providers.VLLM.APIBase != "" {
 		return c.Providers.VLLM.APIBase
 	}
+	if c.Providers.Mistral.APIKey != "" {
+		if c.Providers.Mistral.APIBase != "" {
+			return c.Providers.Mistral.APIBase
+		}
+		return "https://api.mistral.ai/v1"
+	}
 	return ""
 }
diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go
index 17eb6214c..88e1baa56 100644
--- a/pkg/providers/http_provider.go
+++ b/pkg/providers/http_provider.go
@@ -61,32 +61,125 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
 		}
 	}
 
-	requestBody := map[string]interface{}{
-		"model":    model,
-		"messages": messages,
+	// Determine the request format.
+	// Mistral's /v1/conversations API uses "inputs" and "completion_args" instead of "messages".
+	useConversations := strings.Contains(p.apiBase, "/conversations")
+
+	var requestBody map[string]interface{}
+	if useConversations {
+		// Mistral conversations API: filter out system messages and put them in instructions.
+		// Also convert tool messages to user messages.
+		var filteredInputs []Message
+		var systemContent string
+
+		for _, msg := range messages {
+			if msg.Role == "system" {
+				// Collect system content
+				if systemContent != "" {
+					systemContent += "\n\n"
+				}
+				systemContent += msg.Content
+			} else if msg.Role == "tool" {
+				// Convert tool results to user messages for Mistral.
+				// Format: "The result of <tool_name> is: <result>"
+				toolName := ""
+				if len(msg.ToolCalls) > 0 {
+					toolName = msg.ToolCalls[0].Name
+				}
+				if toolName != "" {
+					filteredInputs = append(filteredInputs, Message{
+						Role:    "user",
+						Content: "The result of " + toolName + " is: " + msg.Content,
+					})
+				} else {
+					filteredInputs = append(filteredInputs, Message{
+						Role:    "user",
+						Content: msg.Content,
+					})
+				}
+			} else {
+				// Only keep user and assistant messages
+				filteredInputs = append(filteredInputs, msg)
+			}
+		}
+
+		// Convert filteredInputs to []map[string]interface{} for JSON
+		inputsForJSON := make([]map[string]interface{}, len(filteredInputs))
+		for i, msg := range filteredInputs {
+			inputsForJSON[i] = map[string]interface{}{
+				"role":    msg.Role,
+				"content": msg.Content,
+			}
+		}
+
+		// Mistral conversations API format
+		requestBody = map[string]interface{}{
+			"model":  model,
+			"inputs": inputsForJSON,
+		}
+
+		// Add instructions from the collected system messages
+		if systemContent != "" {
+			requestBody["instructions"] = systemContent
+		}
+
+		// Add completion_args for the Mistral conversations API
+		completionArgs := map[string]interface{}{}
+		if maxTokens, ok := options["max_tokens"].(int); ok {
+			completionArgs["max_tokens"] = maxTokens
+		}
+		if temperature, ok := options["temperature"].(float64); ok {
+			completionArgs["temperature"] = temperature
+		}
+		if topP, ok := options["top_p"].(float64); ok {
+			completionArgs["top_p"] = topP
+		}
+		if len(completionArgs) > 0 {
+			requestBody["completion_args"] = completionArgs
+		}
+	} else {
+		// Standard OpenAI-compatible format
+		requestBody = map[string]interface{}{
+			"model":    model,
+			"messages": messages,
+		}
+	}
+
+	// Add tools for the Mistral conversations API.
+	// Web search tools map to Mistral's built-in web_search; other functions pass through (see convertToolsForMistral).
+	if useConversations && len(tools) > 0 {
+		// Convert tools to Mistral format
+		mistralTools := convertToolsForMistral(tools)
+		if len(mistralTools) > 0 {
+			requestBody["tools"] = mistralTools
+		}
+	}
 
-	if len(tools) > 0 {
+	// Add tools for the non-conversations API only
+	if len(tools) > 0 && !useConversations {
 		requestBody["tools"] = tools
 		requestBody["tool_choice"] = "auto"
 	}
 
-	if maxTokens, ok := options["max_tokens"].(int); ok {
-		lowerModel := strings.ToLower(model)
-		if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
-			requestBody["max_completion_tokens"] = maxTokens
-		} else {
-			requestBody["max_tokens"] = maxTokens
+	// Only add top-level sampling params for the non-conversations API
+	if !useConversations {
+		if maxTokens, ok := options["max_tokens"].(int); ok {
+			lowerModel := strings.ToLower(model)
+			if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
+				requestBody["max_completion_tokens"] = maxTokens
+			} else {
+				requestBody["max_tokens"] = maxTokens
+			}
 		}
-	}
 
-	if temperature, ok := options["temperature"].(float64); ok {
-		lowerModel := strings.ToLower(model)
-		// Kimi k2 models only support temperature=1
-		if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
-			requestBody["temperature"] = 1.0
-		} else {
-			requestBody["temperature"] = temperature
+		if temperature, ok := options["temperature"].(float64); ok {
+			lowerModel := strings.ToLower(model)
+			// Kimi k2 models only support temperature=1
+			if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
+				requestBody["temperature"] = 1.0
+			} else {
+				requestBody["temperature"] = temperature
+			}
 		}
 	}
 
@@ -95,7 +188,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
 		return nil, fmt.Errorf("failed to marshal request: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
+	// Append /chat/completions unless the base URL already targets /conversations
+	endpoint := "/chat/completions"
+	if useConversations {
+		endpoint = ""
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+endpoint, bytes.NewReader(jsonData))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create request: %w", err)
 	}
@@ -117,13 +216,75 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
 	}
 
 	if resp.StatusCode != http.StatusOK {
-		return nil, fmt.Errorf("API request failed:\n  Status: %d\n  Body: %s", resp.StatusCode, string(body))
+		return nil, fmt.Errorf("API request failed:\n  URL: %s\n  Status: %d\n  Body: %s", p.apiBase+endpoint, resp.StatusCode, string(body))
 	}
 
 	return p.parseResponse(body)
 }
 func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
+	// First, try to parse as the conversations API response format.
+	// The /v1/conversations API returns {"outputs": [{"type": "message.output", "content": [...]}]}.
+	// When tools are used, there can be multiple outputs (tool.execution + message.output).
+	var convResponse struct {
+		Outputs []struct {
+			Type      string      `json:"type"`
+			Content   interface{} `json:"content"`   // Can be string, array of {type, text}, or nil
+			Name      string      `json:"name"`      // For tool.execution
+			Arguments string      `json:"arguments"` // For tool.execution
+		} `json:"outputs"`
+		Usage *UsageInfo `json:"usage"`
+	}
+
+	if err := json.Unmarshal(body, &convResponse); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+	}
+
+	// Check if it's a conversations API response
+	if len(convResponse.Outputs) > 0 {
+		// Parse as conversations API format
+		toolCalls := []ToolCall{}
+		content := ""
+
+		for _, output := range convResponse.Outputs {
+			if output.Type == "tool.execution" {
+				// This is a tool call; argument parsing is best-effort, malformed arguments yield an empty map
+				args := make(map[string]interface{})
+				if output.Arguments != "" {
+					json.Unmarshal([]byte(output.Arguments), &args)
+				}
+				toolCalls = append(toolCalls, ToolCall{
+					ID:        output.Name, // No call ID is parsed in this format, so reuse the tool name
+					Name:      output.Name,
+					Arguments: args,
+				})
+			} else if output.Type == "message.output" || output.Type == "message" {
+				// This is the actual message content
+				switch c := output.Content.(type) {
+				case string:
+					content = c
+				case []interface{}:
+					// Array of {type, text} objects
+					for _, item := range c {
+						if itemMap, ok := item.(map[string]interface{}); ok {
+							if text, ok := itemMap["text"].(string); ok {
+								content += text
+							}
+						}
+					}
+				}
+			}
+		}
+
+		return &LLMResponse{
+			Content:      content,
+			ToolCalls:    toolCalls,
+			FinishReason: "stop",
+			Usage:        convResponse.Usage,
+		}, nil
+	}
+
+	// Fallback: try the standard OpenAI format
 	var apiResponse struct {
 		Choices []struct {
 			Message struct {
@@ -322,6 +483,16 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
 			apiBase = "localhost:4321"
 		}
 		return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
+	case "mistral":
+		if cfg.Providers.Mistral.APIKey != "" {
+			apiKey = cfg.Providers.Mistral.APIKey
+			apiBase = cfg.Providers.Mistral.APIBase
+			proxy = cfg.Providers.Mistral.Proxy
+			// Default to the Mistral /v1/conversations endpoint for better rate limits
+			if apiBase == "" {
+				apiBase = "https://api.mistral.ai/v1/conversations"
+			}
+		}
 	}
 
@@ -338,7 +509,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
 			apiBase = "https://api.moonshot.cn/v1"
 		}
 
-	case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
+	case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/") || strings.HasPrefix(model, "mistral/"):
 		apiKey = cfg.Providers.OpenRouter.APIKey
 		proxy = cfg.Providers.OpenRouter.Proxy
 		if cfg.Providers.OpenRouter.APIBase != "" {
@@ -401,6 +572,14 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
 			apiBase = "https://integrate.api.nvidia.com/v1"
 		}
 
+	case (strings.Contains(lowerModel, "mistral") || strings.HasPrefix(model, "mistral/")) && cfg.Providers.Mistral.APIKey != "":
+		apiKey = cfg.Providers.Mistral.APIKey
+		apiBase = cfg.Providers.Mistral.APIBase
+		proxy = cfg.Providers.Mistral.Proxy
+		if apiBase == "" {
+			apiBase = "https://api.mistral.ai/v1"
+		}
+
 	case cfg.Providers.VLLM.APIBase != "":
 		apiKey = cfg.Providers.VLLM.APIKey
 		apiBase = cfg.Providers.VLLM.APIBase
@@ -431,3 +610,33 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
 
 	return NewHTTPProvider(apiKey, apiBase, proxy), nil
 }
+
+// convertToolsForMistral converts PicoClaw tools to the Mistral conversations format.
+// Web search tools become Mistral's built-in web_search; other functions pass through unchanged.
+func convertToolsForMistral(tools []ToolDefinition) []map[string]interface{} {
+	mistralTools := []map[string]interface{}{}
+
+	for _, tool := range tools {
+		if tool.Type == "function" {
+			// Check if this is a web_search tool - use Mistral's built-in instead
+			if tool.Function.Name == "web_search" || tool.Function.Name == "search" {
+				// Use Mistral's built-in web_search
+				mistralTools = append(mistralTools, map[string]interface{}{
+					"type": "web_search",
+				})
+			} else {
+				// Keep custom functions in Mistral's function format
+				mistralTools = append(mistralTools, map[string]interface{}{
+					"type": "function",
+					"function": map[string]interface{}{
+						"name":        tool.Function.Name,
+						"description": tool.Function.Description,
+						"parameters":  tool.Function.Parameters,
+					},
+				})
+			}
+		}
+	}
+
+	return mistralTools
+}
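Reviewer note: a self-contained sketch of how the new parseResponse branch handles a conversations-style body. The sample JSON is illustrative, not captured from the live API, and the struct below mirrors only the fields parsed in the diff (usage is omitted).

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Two outputs: a tool execution followed by the final message,
	// matching the shape the new parseResponse branch expects.
	body := []byte(`{
		"outputs": [
			{"type": "tool.execution", "name": "web_search", "arguments": "{\"query\":\"paris weather\"}"},
			{"type": "message.output", "content": [{"type": "text", "text": "It is 18°C in Paris."}]}
		]
	}`)

	var convResponse struct {
		Outputs []struct {
			Type      string      `json:"type"`
			Content   interface{} `json:"content"`
			Name      string      `json:"name"`
			Arguments string      `json:"arguments"`
		} `json:"outputs"`
	}
	if err := json.Unmarshal(body, &convResponse); err != nil {
		panic(err)
	}

	for _, output := range convResponse.Outputs {
		switch output.Type {
		case "tool.execution":
			fmt.Printf("tool call: %s args=%s\n", output.Name, output.Arguments)
		case "message.output", "message":
			// Content arrives as an array of {type, text} parts; concatenate the text.
			if parts, ok := output.Content.([]interface{}); ok {
				for _, item := range parts {
					if m, ok := item.(map[string]interface{}); ok {
						if text, ok := m["text"].(string); ok {
							fmt.Println("content:", text)
						}
					}
				}
			}
		}
	}
}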
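Reviewer note: a sketch of the tool mapping convertToolsForMistral performs. ToolDefinition is re-declared minimally here for illustration; the real types live in pkg/providers.

package main

import (
	"encoding/json"
	"fmt"
)

// Minimal stand-ins for the repo's types, for illustration only.
type FunctionDef struct {
	Name        string
	Description string
	Parameters  map[string]interface{}
}

type ToolDefinition struct {
	Type     string
	Function FunctionDef
}

func main() {
	tools := []ToolDefinition{
		{Type: "function", Function: FunctionDef{Name: "web_search", Description: "search the web"}},
		{Type: "function", Function: FunctionDef{Name: "read_file", Description: "read a local file",
			Parameters: map[string]interface{}{"type": "object"}}},
	}

	// Same branching as convertToolsForMistral: web_search/search become the
	// built-in {"type": "web_search"}; everything else stays a function tool.
	mistralTools := []map[string]interface{}{}
	for _, tool := range tools {
		if tool.Type != "function" {
			continue
		}
		if tool.Function.Name == "web_search" || tool.Function.Name == "search" {
			mistralTools = append(mistralTools, map[string]interface{}{"type": "web_search"})
		} else {
			mistralTools = append(mistralTools, map[string]interface{}{
				"type": "function",
				"function": map[string]interface{}{
					"name":        tool.Function.Name,
					"description": tool.Function.Description,
					"parameters":  tool.Function.Parameters,
				},
			})
		}
	}

	out, _ := json.MarshalIndent(mistralTools, "", "  ")
	fmt.Println(string(out))
}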