Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions config/config.example.json
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@
"moonshot": {
"api_key": "sk-xxx",
"api_base": ""
},
"mistral": {
"api_key": "",
"api_base": ""
}
},
"tools": {
Expand Down
11 changes: 11 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ type ProvidersConfig struct {
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
GitHubCopilot ProviderConfig `json:"github_copilot"`
Mistral ProviderConfig `json:"mistral"`
}

type ProviderConfig struct {
Expand Down Expand Up @@ -304,6 +305,7 @@ func DefaultConfig() *Config {
Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
Mistral: ProviderConfig{},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
Expand Down Expand Up @@ -405,6 +407,9 @@ func (c *Config) GetAPIKey() string {
if c.Providers.ShengSuanYun.APIKey != "" {
return c.Providers.ShengSuanYun.APIKey
}
if c.Providers.Mistral.APIKey != "" {
return c.Providers.Mistral.APIKey
}
return ""
}

Expand All @@ -423,6 +428,12 @@ func (c *Config) GetAPIBase() string {
if c.Providers.VLLM.APIKey != "" && c.Providers.VLLM.APIBase != "" {
return c.Providers.VLLM.APIBase
}
if c.Providers.Mistral.APIKey != "" {
if c.Providers.Mistral.APIBase != "" {
return c.Providers.Mistral.APIBase
}
return "https://api.mistral.ai/v1"
}
return ""
}

Expand Down
251 changes: 230 additions & 21 deletions pkg/providers/http_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,125 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
}

requestBody := map[string]interface{}{
"model": model,
"messages": messages,
// Determine the endpoint and request format
// Mistral /v1/conversations uses "inputs" and "completion_args"
useConversations := strings.Contains(p.apiBase, "/conversations")

var requestBody map[string]interface{}
if useConversations {
// Mistral conversations API: filter out system messages and put them in instructions
// Also convert tool messages to user messages
var filteredInputs []Message
var systemContent string

for _, msg := range messages {
if msg.Role == "system" {
// Collect system content
if systemContent != "" {
systemContent += "\n\n"
}
systemContent += msg.Content
} else if msg.Role == "tool" {
// Convert tool results to user messages for Mistral
// Format: "The result of tool_name is: <result>"
toolName := ""
if len(msg.ToolCalls) > 0 {
toolName = msg.ToolCalls[0].Name
}
if toolName != "" {
filteredInputs = append(filteredInputs, Message{
Role: "user",
Content: "The result of " + toolName + " is: " + msg.Content,
})
} else {
filteredInputs = append(filteredInputs, Message{
Role: "user",
Content: msg.Content,
})
}
} else {
// Only keep user and assistant messages
filteredInputs = append(filteredInputs, msg)
}
}

// Convert filteredInputs to []map[string]interface{} for JSON
inputsForJSON := make([]map[string]interface{}, len(filteredInputs))
for i, msg := range filteredInputs {
inputsForJSON[i] = map[string]interface{}{
"role": msg.Role,
"content": msg.Content,
}
}

// Mistral conversations API format
requestBody = map[string]interface{}{
"model": model,
"inputs": inputsForJSON,
}

// Add instructions from system message
if systemContent != "" {
requestBody["instructions"] = systemContent
}

// Add completion_args for Mistral conversations API
completionArgs := map[string]interface{}{}
if maxTokens, ok := options["max_tokens"].(int); ok {
completionArgs["max_tokens"] = maxTokens
}
if temperature, ok := options["temperature"].(float64); ok {
completionArgs["temperature"] = temperature
}
if topP, ok := options["top_p"].(float64); ok {
completionArgs["top_p"] = topP
}
if len(completionArgs) > 0 {
requestBody["completion_args"] = completionArgs
}
} else {
// Standard OpenAI-compatible format
requestBody = map[string]interface{}{
"model": model,
"messages": messages,
}
}

// Add tools for Mistral conversations API
// For Mistral, use built-in web_search instead of custom tools
if useConversations && len(tools) > 0 {
// Convert tools to Mistral format
mistralTools := convertToolsForMistral(tools)
if len(mistralTools) > 0 {
requestBody["tools"] = mistralTools
}
}

if len(tools) > 0 {
// Add tools for non-conversations API only
if len(tools) > 0 && !useConversations {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}

if maxTokens, ok := options["max_tokens"].(int); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
// Only add top-level params for non-conversations API
if !useConversations {
if maxTokens, ok := options["max_tokens"].(int); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
}
}
}

if temperature, ok := options["temperature"].(float64); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
if temperature, ok := options["temperature"].(float64); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
}

Expand All @@ -95,7 +188,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+"/chat/completions", bytes.NewReader(jsonData))
// Determine the endpoint - use /chat/completions unless already using /conversations
endpoint := "/chat/completions"
if strings.Contains(p.apiBase, "/conversations") {
endpoint = ""
}

req, err := http.NewRequestWithContext(ctx, "POST", p.apiBase+endpoint, bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
Expand All @@ -117,13 +216,75 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("API request failed:\n URL: %s\n Status: %d\n Body: %s", p.apiBase+endpoint, resp.StatusCode, string(body))
}

return p.parseResponse(body)
}

func (p *HTTPProvider) parseResponse(body []byte) (*LLMResponse, error) {
// First, try to parse as conversations API response format
// The /v1/conversations API returns {"outputs": [{"type": "message.output", "content": [...]}]}
// When tools are used, there can be multiple outputs (tool.execution + message.output)
var convResponse struct {
Outputs []struct {
Type string `json:"type"`
Content interface{} `json:"content"` // Can be string, array of {type, text}, or nil
Name string `json:"name"` // For tool.execution
Arguments string `json:"arguments"` // For tool.execution
} `json:"outputs"`
Usage *UsageInfo `json:"usage"`
}

if err := json.Unmarshal(body, &convResponse); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}

// Check if it's a conversations API response
if len(convResponse.Outputs) > 0 {
// Parse as conversations API format
toolCalls := []ToolCall{}
content := ""

for _, output := range convResponse.Outputs {
if output.Type == "tool.execution" {
// This is a tool call
args := make(map[string]interface{})
if output.Arguments != "" {
json.Unmarshal([]byte(output.Arguments), &args)
}
toolCalls = append(toolCalls, ToolCall{
ID: output.Name, // Use tool name as ID
Name: output.Name,
Arguments: args,
})
} else if output.Type == "message.output" || output.Type == "message" {
// This is the actual message content
switch c := output.Content.(type) {
case string:
content = c
case []interface{}:
// Array of {type, text} objects
for _, item := range c {
if itemMap, ok := item.(map[string]interface{}); ok {
if text, ok := itemMap["text"].(string); ok {
content += text
}
}
}
}
}
}

return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: "stop",
Usage: convResponse.Usage,
}, nil
}

// Fallback: try standard OpenAI format
var apiResponse struct {
Choices []struct {
Message struct {
Expand Down Expand Up @@ -322,6 +483,16 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
apiBase = "localhost:4321"
}
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
case "mistral":
if cfg.Providers.Mistral.APIKey != "" {
apiKey = cfg.Providers.Mistral.APIKey
apiBase = cfg.Providers.Mistral.APIBase
proxy = cfg.Providers.Mistral.Proxy
// Mistral /v1/conversations endpoint for better rate limits
if apiBase == "" {
apiBase = "https://api.mistral.ai/v1/conversations"
}
}

}

Expand All @@ -338,7 +509,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
apiBase = "https://api.moonshot.cn/v1"
}

case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/") || strings.HasPrefix(model, "mistral/"):
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
Expand Down Expand Up @@ -401,6 +572,14 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
apiBase = "https://integrate.api.nvidia.com/v1"
}

case (strings.Contains(lowerModel, "mistral") || strings.HasPrefix(model, "mistral/")) && cfg.Providers.Mistral.APIKey != "":
apiKey = cfg.Providers.Mistral.APIKey
apiBase = cfg.Providers.Mistral.APIBase
proxy = cfg.Providers.Mistral.Proxy
if apiBase == "" {
apiBase = "https://api.mistral.ai/v1"
}

case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
Expand Down Expand Up @@ -431,3 +610,33 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {

return NewHTTPProvider(apiKey, apiBase, proxy), nil
}

// convertToolsForMistral converts PicoClaw tools to Mistral format
// For Mistral conversations API, we use built-in web_search instead of custom tools
func convertToolsForMistral(tools []ToolDefinition) []map[string]interface{} {
mistralTools := []map[string]interface{}{}

for _, tool := range tools {
if tool.Type == "function" {
// Check if this is a web_search tool - use Mistral's built-in instead
if tool.Function.Name == "web_search" || tool.Function.Name == "search" {
// Use Mistral's built-in web_search
mistralTools = append(mistralTools, map[string]interface{}{
"type": "web_search",
})
} else {
// Keep custom function in Mistral format
mistralTools = append(mistralTools, map[string]interface{}{
"type": "function",
"function": map[string]interface{}{
"name": tool.Function.Name,
"description": tool.Function.Description,
"parameters": tool.Function.Parameters,
},
})
}
}
}

return mistralTools
}