Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions docs/proposals/004-vendor-specific-fields/proposal.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,14 @@ type ChatCompletionRequest struct {

// Vendor-specific fields are added as inline fields
*GCPVertexAIVendorFields `json:",inline,omitempty"`
*AnthropicVendorFields `json:",inline,omitempty"`
}

// GCPVertexAIVendorFields contains GCP Vertex AI (Gemini) vendor-specific fields.
type GCPVertexAIVendorFields struct {
// GenerationConfig holds Gemini generation configuration options.
GenerationConfig *GCPVertexAIGenerationConfig `json:"generationConfig,omitempty"`
}

// GCPVertexAIGenerationConfig represents Gemini generation configuration options.
type GCPVertexAIGenerationConfig struct {
ThinkingConfig *genai.GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
}

// AnthropicVendorFields contains GCP Anthropic-specific fields.
type AnthropicVendorFields struct {
Thinking *anthropic.ThinkingConfigParamUnion `json:"thinking,omitzero"`
// SafetySettings: Safety settings in the request to block unsafe content in the response.
//
// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/SafetySetting
SafetySettings []*genai.SafetySetting `json:"safetySettings,omitzero"`
}
```

Expand Down
94 changes: 68 additions & 26 deletions internal/apischema/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"strings"
"time"

"github.com/anthropics/anthropic-sdk-go"
"github.com/openai/openai-go/v2"
"github.com/tidwall/gjson"
"google.golang.org/genai"
Expand Down Expand Up @@ -806,6 +805,71 @@ type WebSearchLocation struct {
Country string `json:"country,omitempty"`
}

// ThinkingConfig contains thinking config for reasoning models
type ThinkingUnion struct {
OfEnabled *ThinkingEnabled `json:",omitzero,inline"`
OfDisabled *ThinkingDisabled `json:",omitzero,inline"`
}

type ThinkingEnabled struct {
// Determines how many tokens the model can use for its internal reasoning process.
// Larger budgets can enable more thorough analysis for complex problems, improving
// response quality.
BudgetTokens int64 `json:"budget_tokens"`
// This field can be elided, and will marshal its zero value as "enabled".
Type string `json:"type"`

// Optional. Indicates the thinking budget in tokens.
IncludeThoughts bool `json:"includeThoughts,omitempty"`
}

type ThinkingDisabled struct {
Type string `json:"type,"`
}

// MarshalJSON implements the json.Marshaler interface for ThinkingUnion.
func (t *ThinkingUnion) MarshalJSON() ([]byte, error) {
if t.OfEnabled != nil {
return json.Marshal(t.OfEnabled)
}
if t.OfDisabled != nil {
return json.Marshal(t.OfDisabled)
}
// If both are nil, return an empty object or an error, depending on your desired behavior.
return []byte(`{}`), nil
}

// UnmarshalJSON implements the json.Unmarshaler interface for ThinkingUnion.
func (t *ThinkingUnion) UnmarshalJSON(data []byte) error {
// Use a temporary struct to determine the type
typeResult := gjson.GetBytes(data, "type")
if !typeResult.Exists() {
return errors.New("thinking config does not have a type")
}

// Based on the 'type' field, unmarshal into the correct struct.
typeVal := typeResult.String()

switch typeVal {
case "enabled":
var enabled ThinkingEnabled
if err := json.Unmarshal(data, &enabled); err != nil {
return err
}
t.OfEnabled = &enabled
case "disabled":
var disabled ThinkingDisabled
if err := json.Unmarshal(data, &disabled); err != nil {
return err
}
t.OfDisabled = &disabled
default:
return fmt.Errorf("invalid thinking union type: %s", typeVal)
}

return nil
}

type ChatCompletionRequest struct {
// Messages: A list of messages comprising the conversation so far.
// Depending on the model you use, different message types (modalities) are supported,
Expand Down Expand Up @@ -969,9 +1033,6 @@ type ChatCompletionRequest struct {
// GCPVertexAIVendorFields configures the GCP VertexAI specific fields during schema translation.
*GCPVertexAIVendorFields `json:",inline,omitempty"`

// AnthropicVendorFields configures the Anthropic specific fields during schema translation.
*AnthropicVendorFields `json:",inline,omitempty"`

// GuidedChoice: The output will be exactly one of the choices.
GuidedChoice []string `json:"guided_choice,omitzero"`

Expand All @@ -980,6 +1041,9 @@ type ChatCompletionRequest struct {

// GuidedJSON: The output will follow the JSON schema.
GuidedJSON json.RawMessage `json:"guided_json,omitzero"`

// Thinking: The thinking config for reasoning models
Thinking *ThinkingUnion `json:"thinking,omitzero"`
}

type StreamOptions struct {
Expand Down Expand Up @@ -1550,34 +1614,12 @@ func (t JSONUNIXTime) Equal(other JSONUNIXTime) bool {

// GCPVertexAIVendorFields contains GCP Vertex AI (Gemini) vendor-specific fields.
type GCPVertexAIVendorFields struct {
// GenerationConfig holds Gemini generation configuration options.
// Currently only a subset of the options are supported.
//
// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig
GenerationConfig *GCPVertexAIGenerationConfig `json:"generationConfig,omitzero"`

// SafetySettings: Safety settings in the request to block unsafe content in the response.
//
// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/SafetySetting
SafetySettings []*genai.SafetySetting `json:"safetySettings,omitzero"`
}

// GCPVertexAIGenerationConfig represents Gemini generation configuration options.
type GCPVertexAIGenerationConfig struct {
// ThinkingConfig holds Gemini thinking configuration options.
//
// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig#ThinkingConfig
ThinkingConfig *genai.ThinkingConfig `json:"thinkingConfig,omitzero"`
}

// AnthropicVendorFields contains Anthropic vendor-specific fields.
type AnthropicVendorFields struct {
// Thinking holds Anthropic thinking configuration options.
//
// https://docs.anthropic.com/en/api/messages#body-thinking
Thinking *anthropic.ThinkingConfigParamUnion `json:"thinking,omitzero"`
}

// ReasoningContentUnion content regarding the reasoning that is carried out by the model.
// Reasoning refers to a Chain of Thought (CoT) that the model generates to enhance the accuracy of its final response.
type ReasoningContentUnion struct {
Expand Down
76 changes: 0 additions & 76 deletions internal/apischema/openai/vendor_fields_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/openai/openai-go/v2/packages/param"
"github.com/stretchr/testify/require"
"google.golang.org/genai"
"k8s.io/utils/ptr"
)

func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
Expand All @@ -36,12 +35,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
"content": "Hello, world!"
}
],
"generationConfig": {
"thinkingConfig": {
"includeThoughts": true,
"thinkingBudget": 1000
}
},
"safetySettings": [{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
Expand All @@ -58,12 +51,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
},
},
GCPVertexAIVendorFields: &GCPVertexAIVendorFields{
GenerationConfig: &GCPVertexAIGenerationConfig{
ThinkingConfig: &genai.ThinkingConfig{
IncludeThoughts: true,
ThinkingBudget: ptr.To(int32(1000)),
},
},
SafetySettings: []*genai.SafetySetting{
{
Category: genai.HarmCategoryHarassment,
Expand All @@ -73,55 +60,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
},
},
},
{
name: "Request with multiple vendor fields",
jsonData: []byte(`{
"model": "claude-3",
"messages": [
{
"role": "user",
"content": "Multiple vendors test"
}
],
"generationConfig": {
"thinkingConfig": {
"includeThoughts": true,
"thinkingBudget": 1000
}
},
"thinking": {
"type": "enabled",
"budget_tokens": 1000
}
}`),
expected: &ChatCompletionRequest{
Model: "claude-3",
Messages: []ChatCompletionMessageParamUnion{
{
OfUser: &ChatCompletionUserMessageParam{
Role: ChatMessageRoleUser,
Content: StringOrUserRoleContentUnion{Value: "Multiple vendors test"},
},
},
},
AnthropicVendorFields: &AnthropicVendorFields{
Thinking: &anthropic.ThinkingConfigParamUnion{
OfEnabled: &anthropic.ThinkingConfigEnabledParam{
BudgetTokens: 1000,
Type: "enabled",
},
},
},
GCPVertexAIVendorFields: &GCPVertexAIVendorFields{
GenerationConfig: &GCPVertexAIGenerationConfig{
ThinkingConfig: &genai.ThinkingConfig{
IncludeThoughts: true,
ThinkingBudget: ptr.To(int32(1000)),
},
},
},
},
},
{
name: "Request without vendor fields",
jsonData: []byte(`{
Expand Down Expand Up @@ -207,20 +145,6 @@ func TestChatCompletionRequest_VendorFieldsExtraction(t *testing.T) {
}`),
expectedErrMsg: "invalid character",
},
{
name: "Invalid vendor field type",
jsonData: []byte(`{
"model": "gemini-1.5-pro",
"messages": [
{
"role": "user",
"content": "Test invalid vendor field type"
}
],
"generationConfig": "invalid_string_type"
}`),
expectedErrMsg: "cannot unmarshal string into Go struct field",
},
}

for _, tt := range tests {
Expand Down
29 changes: 26 additions & 3 deletions internal/extproc/translator/openai_awsbedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ type openAIToAWSBedrockTranslatorV1ChatCompletion struct {
activeToolStream bool
}

func getAwsBedrockThinkingMap(tu *openai.ThinkingUnion) map[string]any {
if tu == nil {
return nil
}

resultMap := make(map[string]any)

if tu.OfEnabled != nil {
reasoningConfigMap := map[string]any{
"type": "enabled",
"budget_tokens": tu.OfEnabled.BudgetTokens,
}
resultMap["thinking"] = reasoningConfigMap
} else if tu.OfDisabled != nil {
reasoningConfigMap := map[string]any{
"type": "disabled",
}
resultMap["thinking"] = reasoningConfigMap
}

return resultMap
}

// RequestBody implements [OpenAIChatCompletionTranslator.RequestBody].
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(_ []byte, openAIReq *openai.ChatCompletionRequest, _ bool) (
headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, err error,
Expand Down Expand Up @@ -90,12 +113,12 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(_ []byte, ope
bedrockReq.InferenceConfig.StopSequences = openAIReq.Stop.OfStringArray
}

// Handle Anthropic vendor fields if present. Currently only supports thinking fields.
if openAIReq.AnthropicVendorFields != nil && openAIReq.Thinking != nil {
// Handle thinking config
if openAIReq.Thinking != nil {
if bedrockReq.AdditionalModelRequestFields == nil {
bedrockReq.AdditionalModelRequestFields = make(map[string]interface{})
}
bedrockReq.AdditionalModelRequestFields["thinking"] = openAIReq.Thinking
bedrockReq.AdditionalModelRequestFields = getAwsBedrockThinkingMap(openAIReq.Thinking)
}

// Convert Chat Completion messages.
Expand Down
28 changes: 11 additions & 17 deletions internal/extproc/translator/openai_awsbedrock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"strings"
"testing"

"github.com/anthropics/anthropic-sdk-go"
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -895,11 +894,10 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
},
},
},
AnthropicVendorFields: &openai.AnthropicVendorFields{
Thinking: &anthropic.ThinkingConfigParamUnion{
OfEnabled: &anthropic.ThinkingConfigEnabledParam{
BudgetTokens: int64(1024),
},
Thinking: &openai.ThinkingUnion{
OfEnabled: &openai.ThinkingEnabled{
BudgetTokens: int64(1024),
Type: "enabled",
},
},
},
Expand Down Expand Up @@ -1113,12 +1111,10 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
},
},
},
AnthropicVendorFields: &openai.AnthropicVendorFields{
Thinking: &anthropic.ThinkingConfigParamUnion{
OfEnabled: &anthropic.ThinkingConfigEnabledParam{
Type: "enabled",
BudgetTokens: 1024,
},
Thinking: &openai.ThinkingUnion{
OfEnabled: &openai.ThinkingEnabled{
Type: "enabled",
BudgetTokens: 1024,
},
},
},
Expand Down Expand Up @@ -1147,11 +1143,9 @@ func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T)
},
},
},
AnthropicVendorFields: &openai.AnthropicVendorFields{
Thinking: &anthropic.ThinkingConfigParamUnion{
OfDisabled: &anthropic.ThinkingConfigDisabledParam{
Type: "disabled",
},
Thinking: &openai.ThinkingUnion{
OfDisabled: &openai.ThinkingDisabled{
Type: "disabled",
},
},
},
Expand Down
Loading
Loading