Open ai wrapper #1371

Open: yash-sojitra wants to merge 29 commits into development (base) from yash-sojitra:OpenAI-wrapper
Commits (29); the diff below shows changes from 4 commits.
aa48ff4  wrapped openai api chat completion endpoint (yash-sojitra, Dec 24, 2024)
251ca2a  formatting changes (yash-sojitra, Dec 24, 2024)
1db7000  increased code coverage (yash-sojitra, Jan 7, 2025)
d716392  Merge branch 'development' into OpenAI-wrapper (Umang01-hash, Jan 8, 2025)
bc4ce3c  resolved changes (yash-sojitra, Jan 8, 2025)
63e2e5d  added usage struct to package level, updated get method, and resolved… (yash-sojitra, Jan 9, 2025)
2d869c3  formatting changes (yash-sojitra, Jan 9, 2025)
fcc32fa  Merge branch 'development' into OpenAI-wrapper (yash-sojitra, Jan 9, 2025)
d78f784  Merge branch 'development' into OpenAI-wrapper (yash-sojitra, Jan 10, 2025)
71fa839  typo fix (yash-sojitra, Jan 10, 2025)
4f47c74  resolved changes, added more tests to client, increased code coverage (yash-sojitra, Jan 13, 2025)
7c41988  Merge branch 'development' into OpenAI-wrapper (Umang01-hash, Jan 21, 2025)
266c1ac  Merge branch 'development' into OpenAI-wrapper (Umang01-hash, Jan 27, 2025)
ee06e2f  Merge branch 'OpenAI-wrapper' of https://github.com/yash-sojitra/gofr… (yash-sojitra, Jan 27, 2025)
630755e  added code for injecting openai package into container. added a new p… (yash-sojitra, Jan 30, 2025)
7929e5b  added documentation and wrote injection code (yash-sojitra, Jan 30, 2025)
09d9e18  Merge branch 'development' into OpenAI-wrapper (Umang01-hash, Jan 30, 2025)
c209562  resolved changes and fixed mockcontainer_Test that was failing in cod… (yash-sojitra, Jan 30, 2025)
96f2da6  solved linting error in services.go file (yash-sojitra, Jan 31, 2025)
30cae8f  more linter changes (yash-sojitra, Jan 31, 2025)
658a91a  resolve linters (Umang01-hash, Jan 31, 2025)
59c03eb  remove deprecated methods in test (Umang01-hash, Jan 31, 2025)
c20bc1e  resolved changes (yash-sojitra, Jan 31, 2025)
dd89f56  Merge branch 'development' into OpenAI-wrapper (Umang01-hash, Feb 4, 2025)
b9f4447  Merge branch 'development' into OpenAI-wrapper (yash-sojitra, Feb 4, 2025)
50e0700  Merge branch 'development' into OpenAI-wrapper (Umang01-hash, Feb 5, 2025)
88e7cb4  resolved changes (yash-sojitra, Feb 6, 2025)
e26d9ab  resolved changes and conflics (yash-sojitra, Feb 6, 2025)
8acefc2  Merge branch 'OpenAI-wrapper' of https://github.com/yash-sojitra/gofr… (yash-sojitra, Feb 6, 2025)
184 changes: 184 additions & 0 deletions pkg/gofr/datasource/openai/chatcompletion.go
@@ -0,0 +1,184 @@
package openai

import (
"context"
"encoding/json"
"errors"
"fmt"
"time"

"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

const CompletionsEndpoint = "/v1/chat/completions"

type CreateCompletionsRequest struct {
Messages []Message `json:"messages,omitempty"`
Model string `json:"model,omitempty"`
Store bool `json:"store,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
MetaData interface{} `json:"metadata,omitempty"` // object or null
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
LogitBias map[string]string `json:"logit_bias,omitempty"`
LogProbs int `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` // deprecated
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
N int `json:"n,omitempty"`
Modalities []string `json:"modalities,omitempty"`
Prediction interface{} `json:"prediction,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`

Audio struct {
Voice string `json:"voice,omitempty"`
Format string `json:"format,omitempty"`
} `json:"audio,omitempty"`

ResponseFormat interface{} `json:"response_format,omitempty"`
Seed int `json:"seed,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Stop interface{} `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`

StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
} `json:"stram_options,omitempty"`

Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`

Tools []struct {
Type string `json:"type,omitempty"`
Function struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Parameters interface{} `json:"parameters,omitempty"`
Strict bool `json:"strict,omitempty"`
} `json:"function,omitempty"`
} `json:"tools,omitempty"`

ToolChoice interface{} `json:"tool_choice,omitempty"`
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
Suffix string `json:"suffix,omitempty"`
User string `json:"user,omitempty"`
}

type Message struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Name string `json:"name,omitempty"`
}

type CreateCompletionsResponse struct {
ID string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int `json:"created,omitempty"`
Model string `json:"model,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`

Choices []struct {
Index int `json:"index,omitempty"`

Message struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Refusal string `json:"refusal,omitempty"`
ToolCalls interface{} `json:"tool_calls,omitempty"`
} `json:"message"`

Logprobs interface{} `json:"logprobs,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
} `json:"choices,omitempty"`

Usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
CompletionTokenDetails interface{} `json:"completion_tokens_details,omitempty"`
PromptTokenDetails interface{} `json:"prompt_tokens_details,omitempty"`
} `json:"usage,omitempty"`

Error *Error `json:"error,omitempty"`
}

type Error struct {
Message string `json:"message,omitempty"`
Type string `json:"type,omitempty"`
Param interface{} `json:"param,omitempty"`
Code interface{} `json:"code,omitempty"`
}

var (
ErrMissingBoth = errors.New("both messages and model fields not provided")
ErrMissingMessages = errors.New("messages field not provided")
ErrMissingModel = errors.New("model field not provided")
)

func (e *Error) Error() string {
return fmt.Sprintf("%v: %s", e.Code, e.Message)
}

func (c *Client) CreateCompletionsRaw(ctx context.Context, r *CreateCompletionsRequest) ([]byte, error) {
return c.Post(ctx, CompletionsEndpoint, r)
}

func (c *Client) CreateCompletions(ctx context.Context, r *CreateCompletionsRequest) (response *CreateCompletionsResponse, err error) {
tracerCtx, span := c.AddTrace(ctx, "CreateCompletions")
startTime := time.Now()

if r.Messages == nil && r.Model == "" {
c.logger.Errorf("%v", ErrMissingBoth)
return nil, ErrMissingBoth
}

if r.Messages == nil {
c.logger.Errorf("%v", ErrMissingMessages)
return nil, ErrMissingMessages
}

if r.Model == "" {
c.logger.Errorf("%v", ErrMissingModel)
return nil, ErrMissingModel
}

raw, err := c.CreateCompletionsRaw(tracerCtx, r)
if err != nil {
return response, err
}

err = json.Unmarshal(raw, &response)
if err != nil {
return nil, err
}

ql := &OpenAiAPILog{
ID: response.ID,
Object: response.Object,
Created: response.Created,
Model: response.Model,
ServiceTier: response.ServiceTier,
SystemFingerprint: response.SystemFingerprint,
Usage: response.Usage,
Error: response.Error,
}

c.SendChatCompletionOperationStats(ql, startTime, "ChatCompletion", span)

return response, err
}

func (c *Client) SendChatCompletionOperationStats(ql *OpenAiAPILog, startTime time.Time, method string, span trace.Span) {
duration := time.Since(startTime).Microseconds()

ql.Duration = duration

c.logger.Debug(ql)

c.metrics.RecordHistogram(context.Background(), "openai_api_request_duration", float64(duration))
c.metrics.RecordRequestCount(context.Background(), "openai_api_total_request_count")
c.metrics.RecordTokenUsage(context.Background(), "openai_api_token_usage", ql.Usage.PromptTokens, ql.Usage.CompletionTokens)

if span != nil {
defer span.End()
span.SetAttributes(attribute.Int64(fmt.Sprintf("openai.%v.duration", method), duration))
}
}
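
Reviewer note (not part of the diff): a minimal sketch of the smallest payload the new endpoint accepts, written as if it sat in the same openai package as the code above; the helper name, model, and prompt value are illustrative.

package openai

// minimalChatRequest builds the smallest request that passes the guard
// clauses in CreateCompletions: only Messages and Model are required, and
// every other field carries `omitempty`, so it is dropped from the JSON
// body when left unset.
func minimalChatRequest(prompt string) *CreateCompletionsRequest {
	return &CreateCompletionsRequest{
		Model: "gpt-3.5-turbo",
		Messages: []Message{
			{Role: "user", Content: prompt},
		},
	}
}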
147 changes: 147 additions & 0 deletions pkg/gofr/datasource/openai/chatcompletion_test.go
@@ -0,0 +1,147 @@
package openai

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)

//nolint:funlen // Function length is intentional due to complexity
func Test_ChatCompletions(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockLogger := NewMockLogger(ctrl)
mockMetrics := NewMockMetrics(ctrl)

tests := []struct {
name string
request *CreateCompletionsRequest
response *CreateCompletionsResponse
expectedError error
setupMocks func(*MockLogger, *MockMetrics)
}{
{
name: "successful completion request",
request: &CreateCompletionsRequest{
Messages: []Message{{Role: "user", Content: "Hello"}},
Model: "gpt-3.5-turbo",
},
response: &CreateCompletionsResponse{
ID: "test-id",
Object: "chat.completion",
Created: 1234567890,
Usage: struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
CompletionTokenDetails interface{} `json:"completion_tokens_details,omitempty"`
PromptTokenDetails interface{} `json:"prompt_tokens_details,omitempty"`
}{
PromptTokens: 10,
CompletionTokens: 20,
TotalTokens: 30,
},
},
expectedError: nil,
setupMocks: func(logger *MockLogger, metrics *MockMetrics) {
metrics.EXPECT().RecordHistogram(gomock.Any(), "openai_api_request_duration", gomock.Any())
metrics.EXPECT().RecordRequestCount(gomock.Any(), "openai_api_total_request_count")
metrics.EXPECT().RecordTokenUsage(gomock.Any(), "openai_api_token_usage", 10, 20)
logger.EXPECT().Debug(gomock.Any())
},
},
{
name: "missing both messages and model",
request: &CreateCompletionsRequest{},
expectedError: ErrMissingBoth,
setupMocks: func(logger *MockLogger, _ *MockMetrics) {
logger.EXPECT().Errorf("%v", ErrMissingBoth)
},
},
{
name: "missing messages",
request: &CreateCompletionsRequest{
Model: "gpt-3.5-turbo",
},
expectedError: ErrMissingMessages,
setupMocks: func(logger *MockLogger, _ *MockMetrics) {
logger.EXPECT().Errorf("%v", ErrMissingMessages)
},
},
{
name: "missing model",
request: &CreateCompletionsRequest{
Messages: []Message{{Role: "user", Content: "Hello"}},
},
expectedError: ErrMissingModel,
setupMocks: func(logger *MockLogger, _ *MockMetrics) {
logger.EXPECT().Errorf("%v", ErrMissingModel)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var serverURL string

var server *httptest.Server

if tt.response != nil {
server = setupTestServer(t, CompletionsEndpoint, tt.response)
defer server.Close()
serverURL = server.URL
}

client := &Client{
config: &Config{
APIKey: "test-api-key",
BaseURL: serverURL,
},
httpClient: http.DefaultClient,
logger: mockLogger,
metrics: mockMetrics,
}

tt.setupMocks(mockLogger, mockMetrics)

response, err := client.CreateCompletions(context.Background(), tt.request)

if tt.expectedError != nil {
assert.Equal(t, tt.expectedError, err)
assert.Nil(t, response)
} else {
require.NoError(t, err)
assert.NotNil(t, response)
}
})
}
}

func setupTestServer(t *testing.T, path string, response interface{}) *httptest.Server {
t.Helper()

server := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, path, r.URL.Path)
assert.Equal(t, "Bearer test-api-key", r.Header.Get("Authorization"))
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))

w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(response)

if err != nil {
t.Error(err)
return
}
}))

return server
}
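
Reviewer note (not part of the diff): a companion sketch, written in the same style as the table test above, that drives the happy path end to end and reads the first choice back the way a caller would. It reuses setupTestServer and the generated mocks from this file; the stub response values and the test name are illustrative.

func Test_ChatCompletions_ReadsFirstChoice(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockLogger := NewMockLogger(ctrl)
	mockMetrics := NewMockMetrics(ctrl)

	// Stub reply encoded by setupTestServer; a plain map avoids restating the
	// anonymous Choices struct and still decodes into CreateCompletionsResponse.
	stub := map[string]interface{}{
		"id":    "example-id",
		"model": "gpt-3.5-turbo",
		"choices": []map[string]interface{}{
			{
				"index":         0,
				"message":       map[string]string{"role": "assistant", "content": "Hi there"},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]int{"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
	}

	server := setupTestServer(t, CompletionsEndpoint, stub)
	defer server.Close()

	// Same expectations as the success case in the table test above.
	mockMetrics.EXPECT().RecordHistogram(gomock.Any(), "openai_api_request_duration", gomock.Any())
	mockMetrics.EXPECT().RecordRequestCount(gomock.Any(), "openai_api_total_request_count")
	mockMetrics.EXPECT().RecordTokenUsage(gomock.Any(), "openai_api_token_usage", 10, 20)
	mockLogger.EXPECT().Debug(gomock.Any())

	client := &Client{
		config:     &Config{APIKey: "test-api-key", BaseURL: server.URL},
		httpClient: http.DefaultClient,
		logger:     mockLogger,
		metrics:    mockMetrics,
	}

	resp, err := client.CreateCompletions(context.Background(), &CreateCompletionsRequest{
		Messages: []Message{{Role: "user", Content: "Hello"}},
		Model:    "gpt-3.5-turbo",
	})

	require.NoError(t, err)
	require.Len(t, resp.Choices, 1)
	assert.Equal(t, "assistant", resp.Choices[0].Message.Role)
	assert.Equal(t, "Hi there", resp.Choices[0].Message.Content)
	assert.Equal(t, 30, resp.Usage.TotalTokens)
}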