Skip to content

Commit 603a3e3

Browse files
add xai support (#135)
1 parent e14de7a commit 603a3e3

File tree

6 files changed

+90
-14
lines changed

6 files changed

+90
-14
lines changed

internal/config/config.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,13 @@ func setProviderDefaults() {
242242
if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" {
243243
viper.SetDefault("providers.openrouter.apiKey", apiKey)
244244
}
245+
if apiKey := os.Getenv("XAI_API_KEY"); apiKey != "" {
246+
viper.SetDefault("providers.xai.apiKey", apiKey)
247+
}
248+
if apiKey := os.Getenv("AZURE_OPENAI_ENDPOINT"); apiKey != "" {
249+
// api-key may be empty when using Entra ID credentials – that's okay
250+
viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
251+
}
245252

246253
// Use this order to set the default models
247254
// 1. Anthropic
@@ -292,6 +299,13 @@ func setProviderDefaults() {
292299
return
293300
}
294301

302+
if viper.Get("providers.xai.apiKey") != "" {
303+
viper.SetDefault("agents.coder.model", models.XAIGrok3Beta)
304+
viper.SetDefault("agents.task.model", models.XAIGrok3Beta)
305+
viper.SetDefault("agents.title.model", models.XAiGrok3MiniFastBeta)
306+
return
307+
}
308+
295309
// AWS Bedrock configuration
296310
if hasAWSCredentials() {
297311
viper.SetDefault("agents.coder.model", models.BedrockClaude37Sonnet)
@@ -301,8 +315,6 @@ func setProviderDefaults() {
301315
}
302316

303317
if os.Getenv("AZURE_OPENAI_ENDPOINT") != "" {
304-
// api-key may be empty when using Entra ID credentials – that's okay
305-
viper.SetDefault("providers.azure.apiKey", os.Getenv("AZURE_OPENAI_API_KEY"))
306318
viper.SetDefault("agents.coder.model", models.AzureGPT41)
307319
viper.SetDefault("agents.task.model", models.AzureGPT41Mini)
308320
viper.SetDefault("agents.title.model", models.AzureGPT41Mini)

internal/llm/models/models.go

+1
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,5 @@ func init() {
8989
maps.Copy(SupportedModels, GroqModels)
9090
maps.Copy(SupportedModels, AzureModels)
9191
maps.Copy(SupportedModels, OpenRouterModels)
92+
maps.Copy(SupportedModels, XAIModels)
9293
}

internal/llm/models/xai.go

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package models
2+
3+
const (
4+
ProviderXAI ModelProvider = "xai"
5+
6+
XAIGrok3Beta ModelID = "grok-3-beta"
7+
XAIGrok3MiniBeta ModelID = "grok-3-mini-beta"
8+
XAIGrok3FastBeta ModelID = "grok-3-fast-beta"
9+
XAiGrok3MiniFastBeta ModelID = "grok-3-mini-fast-beta"
10+
)
11+
12+
var XAIModels = map[ModelID]Model{
13+
XAIGrok3Beta: {
14+
ID: XAIGrok3Beta,
15+
Name: "Grok3 Beta",
16+
Provider: ProviderXAI,
17+
APIModel: "grok-3-beta",
18+
CostPer1MIn: 3.0,
19+
CostPer1MInCached: 0,
20+
CostPer1MOut: 15,
21+
CostPer1MOutCached: 0,
22+
ContextWindow: 131_072,
23+
DefaultMaxTokens: 20_000,
24+
},
25+
XAIGrok3MiniBeta: {
26+
ID: XAIGrok3MiniBeta,
27+
Name: "Grok3 Mini Beta",
28+
Provider: ProviderXAI,
29+
APIModel: "grok-3-mini-beta",
30+
CostPer1MIn: 0.3,
31+
CostPer1MInCached: 0,
32+
CostPer1MOut: 0.5,
33+
CostPer1MOutCached: 0,
34+
ContextWindow: 131_072,
35+
DefaultMaxTokens: 20_000,
36+
},
37+
XAIGrok3FastBeta: {
38+
ID: XAIGrok3FastBeta,
39+
Name: "Grok3 Fast Beta",
40+
Provider: ProviderXAI,
41+
APIModel: "grok-3-fast-beta",
42+
CostPer1MIn: 5,
43+
CostPer1MInCached: 0,
44+
CostPer1MOut: 25,
45+
CostPer1MOutCached: 0,
46+
ContextWindow: 131_072,
47+
DefaultMaxTokens: 20_000,
48+
},
49+
XAiGrok3MiniFastBeta: {
50+
ID: XAiGrok3MiniFastBeta,
51+
Name: "Grok3 Mini Fast Beta",
52+
Provider: ProviderXAI,
53+
APIModel: "grok-3-mini-fast-beta",
54+
CostPer1MIn: 0.6,
55+
CostPer1MInCached: 0,
56+
CostPer1MOut: 4.0,
57+
CostPer1MOutCached: 0,
58+
ContextWindow: 131_072,
59+
DefaultMaxTokens: 20_000,
60+
},
61+
}

internal/llm/provider/openai.go

+3-10
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,6 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
258258
chunk := openaiStream.Current()
259259
acc.AddChunk(chunk)
260260

261-
if tool, ok := acc.JustFinishedToolCall(); ok {
262-
toolCalls = append(toolCalls, message.ToolCall{
263-
ID: tool.Id,
264-
Name: tool.Name,
265-
Input: tool.Arguments,
266-
Type: "function",
267-
})
268-
}
269-
270261
for _, choice := range chunk.Choices {
271262
if choice.Delta.Content != "" {
272263
eventChan <- ProviderEvent{
@@ -282,7 +273,9 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
282273
if err == nil || errors.Is(err, io.EOF) {
283274
// Stream completed successfully
284275
finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
285-
276+
if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
277+
toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
278+
}
286279
if len(toolCalls) > 0 {
287280
finishReason = message.FinishReasonToolUse
288281
}

internal/llm/provider/provider.go

+9
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,15 @@ func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption
132132
options: clientOptions,
133133
client: newOpenAIClient(clientOptions),
134134
}, nil
135+
case models.ProviderXAI:
136+
clientOptions.openaiOptions = append(clientOptions.openaiOptions,
137+
WithOpenAIBaseURL("https://api.x.ai/v1"),
138+
)
139+
return &baseProvider[OpenAIClient]{
140+
options: clientOptions,
141+
client: newOpenAIClient(clientOptions),
142+
}, nil
143+
135144
case models.ProviderMock:
136145
// TODO: implement mock client for test
137146
panic("not implemented")

internal/tui/tui.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ var keys = keyMap{
5656
),
5757

5858
Models: key.NewBinding(
59-
key.WithKeys("ctrl+m"),
60-
key.WithHelp("ctrl+m", "model selection"),
59+
key.WithKeys("ctrl+o"),
60+
key.WithHelp("ctrl+o", "model selection"),
6161
),
6262

6363
SwitchTheme: key.NewBinding(

0 commit comments

Comments
 (0)