diff --git a/docs/advanced-guide/using-openai-api/page.md b/docs/advanced-guide/using-openai-api/page.md new file mode 100644 index 000000000..6d52e6a19 --- /dev/null +++ b/docs/advanced-guide/using-openai-api/page.md @@ -0,0 +1,64 @@ +# Using OpenAI Api + +GoFr provides an injectable module that integrates OpenAI's API into the GoFr applications. Since it doesn’t come bundled with the framework, this wrapper can be injected seamlessly to extend Gofr's capabilities, enabling developers to utilize OpenAI's powerful AI models effortlessly while maintaining flexibility and scalability. + +GoFr supports any OpenAI API wrapper that implements the following interface. Any other wrapper that implements the interface can be added using `app.AddOpenAI()` method, and user's can use openai across application with `gofr.Context`. + +```go +type OpenAI interface { + // implementation of chat endpoint of openai api + CreateCompletions(ctx context.Context, r any) (any, error) +} +``` + +### Example +```go +package main + +import ( + "gofr.dev/pkg/gofr" + "gofr.dev/pkg/gofr/service/openai" +) + +func main() { + app := gofr.New() + + config := openai.Config{ + APIKey: app.Config.Get("OPENAI_API_KEY"), + Model: "gpt-3.5-turbo", + + // optional config parameters + // BaseURL: "https://api.custom.com", + // Timeout: 10 * time.Second, + // MaxIdleConns: 10, + } + + openAIClient, err := openai.NewClient(&config) + if err != nil { + return + } + + app.AddOpenAI(openAIClient) + + app.POST("/chat", Chat) + + app.Run() +} + +func Chat(ctx *gofr.Context) (any, error) { + + var req *openai.CreateCompletionsRequest + + if err := ctx.Bind(&req); err != nil { + return nil, err + } + + resp, err := ctx.Openai.CreateCompletions(ctx, req) + if err != nil { + return nil, err + } + + return resp, nil +} +``` + diff --git a/docs/navigation.js b/docs/navigation.js index 68634040f..7fb5eec29 100644 --- a/docs/navigation.js +++ b/docs/navigation.js @@ -146,6 +146,11 @@ export const navigation = [ title: 'Serving-Static Files', href: '/docs/advanced-guide/serving-static-files', desc: "Know how GoFr automatically serves static content from a static folder in the application directory." + }, + { + title: 'Using OpenAI Api', + href: '/docs/advanced-guide/using-openai-api', + desc: "Know how to integrate OpenAI api into your applications easily with GoFr's very own OpenAI api wrapper." } ], }, diff --git a/pkg/gofr/container/container.go b/pkg/gofr/container/container.go index b46a90875..4467eb346 100644 --- a/pkg/gofr/container/container.go +++ b/pkg/gofr/container/container.go @@ -63,6 +63,8 @@ type Container struct { SurrealDB SurrealDB ArangoDB ArangoDB + OpenAI OpenAI + KVStore KVStore File file.FileSystem diff --git a/pkg/gofr/container/mock_container.go b/pkg/gofr/container/mock_container.go index c2b04378c..da20e65dd 100644 --- a/pkg/gofr/container/mock_container.go +++ b/pkg/gofr/container/mock_container.go @@ -87,6 +87,9 @@ func NewMockContainer(t *testing.T, options ...options) (*Container, *Mocks) { opentsdbMock := NewMockOpenTSDBProvider(ctrl) container.OpenTSDB = opentsdbMock + openAIMock := NewMockOpenAIProvider(ctrl) + container.OpenAI = openAIMock + arangoMock := NewMockArangoDBProvider(ctrl) container.ArangoDB = arangoMock diff --git a/pkg/gofr/container/mock_services.go b/pkg/gofr/container/mock_services.go new file mode 100644 index 000000000..d0bc007b1 --- /dev/null +++ b/pkg/gofr/container/mock_services.go @@ -0,0 +1,143 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: services.go +// +// Generated by this command: +// +// mockgen -source=services.go -destination=mock_services.go -package=container +// + +// Package container is a generated GoMock package. +package container + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockOpenAI is a mock of OpenAI interface. +type MockOpenAI struct { + ctrl *gomock.Controller + recorder *MockOpenAIMockRecorder + isgomock struct{} +} + +// MockOpenAIMockRecorder is the mock recorder for MockOpenAI. +type MockOpenAIMockRecorder struct { + mock *MockOpenAI +} + +// NewMockOpenAI creates a new mock instance. +func NewMockOpenAI(ctrl *gomock.Controller) *MockOpenAI { + mock := &MockOpenAI{ctrl: ctrl} + mock.recorder = &MockOpenAIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOpenAI) EXPECT() *MockOpenAIMockRecorder { + return m.recorder +} + +// CreateCompletions mocks base method. +func (m *MockOpenAI) CreateCompletions(ctx context.Context, r any) (any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateCompletions", ctx, r) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateCompletions indicates an expected call of CreateCompletions. +func (mr *MockOpenAIMockRecorder) CreateCompletions(ctx, r any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCompletions", reflect.TypeOf((*MockOpenAI)(nil).CreateCompletions), ctx, r) +} + +// MockOpenAIProvider is a mock of OpenAIProvider interface. +type MockOpenAIProvider struct { + ctrl *gomock.Controller + recorder *MockOpenAIProviderMockRecorder + isgomock struct{} +} + +// MockOpenAIProviderMockRecorder is the mock recorder for MockOpenAIProvider. +type MockOpenAIProviderMockRecorder struct { + mock *MockOpenAIProvider +} + +// NewMockOpenAIProvider creates a new mock instance. +func NewMockOpenAIProvider(ctrl *gomock.Controller) *MockOpenAIProvider { + mock := &MockOpenAIProvider{ctrl: ctrl} + mock.recorder = &MockOpenAIProviderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOpenAIProvider) EXPECT() *MockOpenAIProviderMockRecorder { + return m.recorder +} + +// CreateCompletions mocks base method. +func (m *MockOpenAIProvider) CreateCompletions(ctx context.Context, r any) (any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateCompletions", ctx, r) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateCompletions indicates an expected call of CreateCompletions. +func (mr *MockOpenAIProviderMockRecorder) CreateCompletions(ctx, r any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCompletions", reflect.TypeOf((*MockOpenAIProvider)(nil).CreateCompletions), ctx, r) +} + +// InitMetrics mocks base method. +func (m *MockOpenAIProvider) InitMetrics() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "InitMetrics") +} + +// InitMetrics indicates an expected call of InitMetrics. +func (mr *MockOpenAIProviderMockRecorder) InitMetrics() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitMetrics", reflect.TypeOf((*MockOpenAIProvider)(nil).InitMetrics)) +} + +// UseLogger mocks base method. +func (m *MockOpenAIProvider) UseLogger(logger any) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UseLogger", logger) +} + +// UseLogger indicates an expected call of UseLogger. +func (mr *MockOpenAIProviderMockRecorder) UseLogger(logger any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseLogger", reflect.TypeOf((*MockOpenAIProvider)(nil).UseLogger), logger) +} + +// UseMetrics mocks base method. +func (m *MockOpenAIProvider) UseMetrics(metrics any) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UseMetrics", metrics) +} + +// UseMetrics indicates an expected call of UseMetrics. +func (mr *MockOpenAIProviderMockRecorder) UseMetrics(metrics any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseMetrics", reflect.TypeOf((*MockOpenAIProvider)(nil).UseMetrics), metrics) +} + +// UseTracer mocks base method. +func (m *MockOpenAIProvider) UseTracer(tracer any) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UseTracer", tracer) +} + +// UseTracer indicates an expected call of UseTracer. +func (mr *MockOpenAIProviderMockRecorder) UseTracer(tracer any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UseTracer", reflect.TypeOf((*MockOpenAIProvider)(nil).UseTracer), tracer) +} diff --git a/pkg/gofr/container/mockcontainer_test.go b/pkg/gofr/container/mockcontainer_test.go index 38264cf1f..e483a6da3 100644 --- a/pkg/gofr/container/mockcontainer_test.go +++ b/pkg/gofr/container/mockcontainer_test.go @@ -15,26 +15,13 @@ import ( ) func Test_HttpServiceMock(t *testing.T) { - test := struct { - desc string - path string - statusCode int - expectedRes string - }{ - - desc: "simple service handler", - path: "/fact", - expectedRes: `{"data":{"fact":"Cats have 3 eyelids.","length":20}}` + "\n", - statusCode: 200, - } - httpservices := []string{"cat-facts", "cat-facts1", "cat-facts2"} _, mock := NewMockContainer(t, WithMockHTTPService(httpservices...)) res := httptest.NewRecorder() res.Body = bytes.NewBufferString(`{"fact":"Cats have 3 eyelids.","length":20}` + "\n") - res.Code = test.statusCode + res.Code = 200 result := res.Result() // Setting mock expectations diff --git a/pkg/gofr/container/services.go b/pkg/gofr/container/services.go new file mode 100644 index 000000000..61f8ecd83 --- /dev/null +++ b/pkg/gofr/container/services.go @@ -0,0 +1,27 @@ +package container + +import ( + "context" +) + +// OpenAI is the interface that wraps the basic endpoint of OpenAI API. +type OpenAI interface { + // implementation of chat endpoint of OpenAI API + CreateCompletions(ctx context.Context, r any) (any, error) +} + +type OpenAIProvider interface { + OpenAI + + // UseLogger set the logger for OpenAI client + UseLogger(logger any) + + // UseMetrics set the logger for OpenAI client + UseMetrics(metrics any) + + // UseTracer set the logger for OpenAI client + UseTracer(tracer any) + + // InitMetrics is used to initializes metrics for the client + InitMetrics() +} diff --git a/pkg/gofr/service/openai/chatcompletion.go b/pkg/gofr/service/openai/chatcompletion.go new file mode 100644 index 000000000..15fadd0f7 --- /dev/null +++ b/pkg/gofr/service/openai/chatcompletion.go @@ -0,0 +1,199 @@ +package openai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +const CompletionsEndpoint = "/v1/chat/completions" + +type CreateCompletionsRequest struct { + Messages []Message `json:"messages,omitempty"` + Model string `json:"model,omitempty"` + Store bool `json:"store,omitempty"` + ReasoningEffort string `json:"reasoning_effort,omitempty"` + MetaData any `json:"metadata,omitempty"` // object or null + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]string `json:"logit_bias,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + TopLogProbs int `json:"top_logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` // deprecated + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + N int `json:"n,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Prediction any `json:"prediction,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + + Audio struct { + Voice string `json:"voice,omitempty"` + Format string `json:"format,omitempty"` + } `json:"audio,omitempty"` + + ResponseFormat any `json:"response_format,omitempty"` + Seed int `json:"seed,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + Stop any `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + + StreamOptions struct { + IncludeUsage bool `json:"include_usage,omitempty"` + } `json:"stram_options,omitempty"` + + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + + Tools []struct { + Type string `json:"type,omitempty"` + Function struct { + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Parameters any `json:"parameters,omitempty"` + Strict bool `json:"strict,omitempty"` + } `json:"function,omitempty"` + } `json:"tools,omitempty"` + + ToolChoice any `json:"tool_choice,omitempty"` + ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + Suffix string `json:"suffix,omitempty"` + User string `json:"user,omitempty"` +} + +type Message struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Name string `json:"name,omitempty"` +} + +type CreateCompletionsResponse struct { + ID string `json:"id,omitempty"` + Object string `json:"object,omitempty"` + Created int `json:"created,omitempty"` + Model string `json:"model,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + + Choices []struct { + Index int `json:"index,omitempty"` + + Message struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + ToolCalls any `json:"tool_calls,omitempty"` + } `json:"message"` + + Logprobs any `json:"logprobs,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + } `json:"choices,omitempty"` + + Usage Usage `json:"usage,omitempty"` + + Error *Error `json:"error,omitempty"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` + CompletionTokensDetails any `json:"completion_tokens_details,omitempty"` + PromptTokensDetails any `json:"prompt_tokens_details,omitempty"` +} + +type Error struct { + Message string `json:"message,omitempty"` + Type string `json:"type,omitempty"` + Param any `json:"param,omitempty"` + Code any `json:"code,omitempty"` +} + +var ( + errMissingBoth = errors.New("both messages and model fields not provided") + errMissingMessages = errors.New("messages fields not provided") + errMissingModel = errors.New("model fields not provided") + errRequestType = errors.New("invalid request type") +) + +func (e *Error) Error() string { + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +func (c *Client) CreateCompletionsRaw(ctx context.Context, r *CreateCompletionsRequest) ([]byte, error) { + return c.post(ctx, CompletionsEndpoint, r) +} + +func (c *Client) CreateCompletions(ctx context.Context, r any) (any, error) { + req, ok := r.(*CreateCompletionsRequest) + if !ok { + c.logger.Errorf("%v", errRequestType) + return nil, errRequestType + } + + tracerCtx, span := c.AddTrace(ctx, "CreateCompletions") + startTime := time.Now() + + if req.Messages == nil && req.Model == "" { + c.logger.Errorf("%v", errMissingBoth) + return nil, errMissingBoth + } + + if req.Messages == nil { + c.logger.Errorf("%v", errMissingMessages) + return nil, errMissingMessages + } + + if req.Model == "" { + c.logger.Errorf("%v", errMissingModel) + return nil, errMissingModel + } + + raw, err := c.CreateCompletionsRaw(tracerCtx, req) + if err != nil { + return nil, err + } + + var response CreateCompletionsResponse + + err = json.Unmarshal(raw, &response) + if err != nil { + return nil, err + } + + ql := &APILog{ + ID: response.ID, + Query: "CreateCompletions", + Object: response.Object, + Created: response.Created, + Model: response.Model, + ServiceTier: response.ServiceTier, + SystemFingerprint: response.SystemFingerprint, + Usage: response.Usage, + Error: response.Error, + } + + c.SendChatCompletionOperationStats(ctx, ql, startTime, "ChatCompletion", span) + + return response, err +} + +func (c *Client) SendChatCompletionOperationStats(ctx context.Context, ql *APILog, startTime time.Time, method string, span trace.Span) { + duration := time.Since(startTime).Microseconds() + + ql.Duration = duration + + c.logger.Debug(ql) + + c.metrics.RecordHistogram(ctx, "openai_api_request_duration", float64(duration)) + c.metrics.IncrementCounter(ctx, "openai_api_total_request_count") + c.metrics.DeltaUpDownCounter(ctx, "openai_api_token_usage", float64(ql.Usage.TotalTokens)) + + if span != nil { + defer span.End() + span.SetAttributes(attribute.Int64(fmt.Sprintf("openai.%v.duration", method), duration)) + } +} diff --git a/pkg/gofr/service/openai/chatcompletion_test.go b/pkg/gofr/service/openai/chatcompletion_test.go new file mode 100644 index 000000000..7fcbd61eb --- /dev/null +++ b/pkg/gofr/service/openai/chatcompletion_test.go @@ -0,0 +1,141 @@ +package openai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +type test struct { + name string + request *CreateCompletionsRequest + response *CreateCompletionsResponse + expectedError error + setupMocks func(*MockLogger, *MockMetrics) +} + +func Test_ChatCompletions(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLogger(ctrl) + mockMetrics := NewMockMetrics(ctrl) + + tests := []test{ + { + name: "successful completion request", + request: &CreateCompletionsRequest{ + Messages: []Message{{Role: "user", Content: "Hello"}}, + Model: "gpt-3.5-turbo", + }, + response: &CreateCompletionsResponse{ + ID: "test-id", + Object: "chat.completion", + Created: 1234567890, + Usage: Usage{ + PromptTokens: 10, + CompletionTokens: 20, + TotalTokens: 30, + }, + }, + expectedError: nil, + setupMocks: func(logger *MockLogger, metrics *MockMetrics) { + metrics.EXPECT().RecordHistogram(gomock.Any(), "openai_api_request_duration", gomock.Any()) + metrics.EXPECT().IncrementCounter(gomock.Any(), "openai_api_total_request_count") + metrics.EXPECT().DeltaUpDownCounter(gomock.Any(), "openai_api_token_usage", 30.0) + logger.EXPECT().Debug(gomock.Any()) + }, + }, + { + name: "missing both messages and model", + request: &CreateCompletionsRequest{}, + expectedError: errMissingBoth, + setupMocks: func(logger *MockLogger, _ *MockMetrics) { + logger.EXPECT().Errorf("%v", errMissingBoth) + }, + }, + { + name: "missing messages", + request: &CreateCompletionsRequest{ + Model: "gpt-3.5-turbo", + }, + expectedError: errMissingMessages, + setupMocks: func(logger *MockLogger, _ *MockMetrics) { + logger.EXPECT().Errorf("%v", errMissingMessages) + }, + }, + { + name: "missing model", + request: &CreateCompletionsRequest{ + Messages: []Message{{Role: "user", Content: "Hello"}}, + }, + expectedError: errMissingModel, + setupMocks: func(logger *MockLogger, _ *MockMetrics) { + logger.EXPECT().Errorf("%v", errMissingModel) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var serverURL string + + var server *httptest.Server + + if tt.response != nil { + server = setupTestServer(t, CompletionsEndpoint, tt.response) + defer server.Close() + serverURL = server.URL + } + + client := &Client{ + config: &Config{ + APIKey: "test-api-key", + BaseURL: serverURL, + }, + httpClient: http.DefaultClient, + logger: mockLogger, + metrics: mockMetrics, + } + + tt.setupMocks(mockLogger, mockMetrics) + response, err := client.CreateCompletions(context.Background(), tt.request) + + if tt.expectedError != nil { + require.ErrorIs(t, err, tt.expectedError) + assert.Nil(t, response) + } else { + require.NoError(t, err) + assert.NotNil(t, response) + } + }) + } +} + +func setupTestServer(t *testing.T, path string, response any) *httptest.Server { + t.Helper() + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, path, r.URL.Path) + assert.Equal(t, "Bearer test-api-key", r.Header.Get("Authorization")) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(response) + + if err != nil { + t.Error(err) + return + } + })) + + return server +} diff --git a/pkg/gofr/service/openai/go.mod b/pkg/gofr/service/openai/go.mod new file mode 100644 index 000000000..13650e2ac --- /dev/null +++ b/pkg/gofr/service/openai/go.mod @@ -0,0 +1,23 @@ +module gofr.dev/pkg/gofr/service/openai + +go 1.23.4 + +require ( + go.opentelemetry.io/otel v1.34.0 + go.uber.org/mock v0.5.0 +) + +require ( + github.com/stretchr/testify v1.10.0 + go.opentelemetry.io/otel/trace v1.34.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel/metric v1.34.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/pkg/gofr/service/openai/go.sum b/pkg/gofr/service/openai/go.sum new file mode 100644 index 000000000..d3b9b2bd3 --- /dev/null +++ b/pkg/gofr/service/openai/go.sum @@ -0,0 +1,31 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw= +go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I= +go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= +go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= +go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= +go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= +go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= +go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= +go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/gofr/service/openai/logger.go b/pkg/gofr/service/openai/logger.go new file mode 100644 index 000000000..c0be3f743 --- /dev/null +++ b/pkg/gofr/service/openai/logger.go @@ -0,0 +1,57 @@ +package openai + +import ( + "fmt" + "io" + "regexp" + "strings" +) + +type Logger interface { + Debug(args ...any) + Debugf(pattern string, args ...any) + Logf(pattern string, args ...any) + Errorf(pattern string, args ...any) +} + +type APILog struct { + ID string `json:"id,omitempty"` + Query string `json:"query,omitempty"` + Object string `json:"object,omitempty"` + Created int `json:"created,omitempty"` + Model string `json:"model,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + SystemFingerprint string `json:"system_fingerprint,omitempty"` + Duration int64 `json:"duration,omitempty"` + + Usage struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` + CompletionTokensDetails any `json:"completion_tokens_details,omitempty"` + PromptTokensDetails any `json:"prompt_tokens_details,omitempty"` + } `json:"usage,omitempty"` + + Error *Error `json:"error,omitempty"` +} + +func (al *APILog) PrettyPrint(writer io.Writer) { + fmt.Fprintf(writer, + "\u001B[38;5;8m%-32s \u001B[38;5;206m%-6s\u001B[0m %8d\u001B[38;5;8mµs\u001B[0m %s\n", + + clean(al.Query), + "OPENAI", + al.Duration, + clean(strings.Join([]string{al.Model, fmt.Sprint(al.Created), fmt.Sprint(al.Usage)}, " ")), + ) +} + +func clean(query string) string { + // Replace multiple consecutive whitespace characters with a single space + query = regexp.MustCompile(`\s+`).ReplaceAllString(query, " ") + + // Trim leading and trailing whitespace from the string + query = strings.TrimSpace(query) + + return query +} diff --git a/pkg/gofr/service/openai/metrics.go b/pkg/gofr/service/openai/metrics.go new file mode 100644 index 000000000..1aae9a131 --- /dev/null +++ b/pkg/gofr/service/openai/metrics.go @@ -0,0 +1,13 @@ +package openai + +import "context" + +type Metrics interface { + NewCounter(name, desc string) + NewUpDownCounter(name, desc string) + NewHistogram(name, desc string, buckets ...float64) + + IncrementCounter(ctx context.Context, name string, labels ...string) + DeltaUpDownCounter(ctx context.Context, name string, value float64, labels ...string) + RecordHistogram(ctx context.Context, name string, value float64, labels ...string) +} diff --git a/pkg/gofr/service/openai/mock_logger.go b/pkg/gofr/service/openai/mock_logger.go new file mode 100644 index 000000000..e1dcbc1e9 --- /dev/null +++ b/pkg/gofr/service/openai/mock_logger.go @@ -0,0 +1,107 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: logger.go +// +// Generated by this command: +// +// mockgen -source=logger.go -destination=mock_logger.go -package=openai +// + +// Package openai is a generated GoMock package. +package openai + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockLogger is a mock of Logger interface. +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder + isgomock struct{} +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger. +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance. +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debug mocks base method. +func (m *MockLogger) Debug(args ...any) { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debug", varargs...) +} + +// Debug indicates an expected call of Debug. +func (mr *MockLoggerMockRecorder) Debug(args ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), args...) +} + +// Debugf mocks base method. +func (m *MockLogger) Debugf(pattern string, args ...any) { + m.ctrl.T.Helper() + varargs := []any{pattern} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debugf", varargs...) +} + +// Debugf indicates an expected call of Debugf. +func (mr *MockLoggerMockRecorder) Debugf(pattern any, args ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{pattern}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockLogger)(nil).Debugf), varargs...) +} + +// Errorf mocks base method. +func (m *MockLogger) Errorf(pattern string, args ...any) { + m.ctrl.T.Helper() + varargs := []any{pattern} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Errorf", varargs...) +} + +// Errorf indicates an expected call of Errorf. +func (mr *MockLoggerMockRecorder) Errorf(pattern any, args ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{pattern}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Errorf", reflect.TypeOf((*MockLogger)(nil).Errorf), varargs...) +} + +// Logf mocks base method. +func (m *MockLogger) Logf(pattern string, args ...any) { + m.ctrl.T.Helper() + varargs := []any{pattern} + for _, a := range args { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Logf", varargs...) +} + +// Logf indicates an expected call of Logf. +func (mr *MockLoggerMockRecorder) Logf(pattern any, args ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{pattern}, args...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logf", reflect.TypeOf((*MockLogger)(nil).Logf), varargs...) +} diff --git a/pkg/gofr/service/openai/mock_metrics.go b/pkg/gofr/service/openai/mock_metrics.go new file mode 100644 index 000000000..9b16d92af --- /dev/null +++ b/pkg/gofr/service/openai/mock_metrics.go @@ -0,0 +1,133 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: metrics.go +// +// Generated by this command: +// +// mockgen -source=metrics.go -destination=mock_metrics.go -package=openai +// + +// Package openai is a generated GoMock package. +package openai + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockMetrics is a mock of Metrics interface. +type MockMetrics struct { + ctrl *gomock.Controller + recorder *MockMetricsMockRecorder + isgomock struct{} +} + +// MockMetricsMockRecorder is the mock recorder for MockMetrics. +type MockMetricsMockRecorder struct { + mock *MockMetrics +} + +// NewMockMetrics creates a new mock instance. +func NewMockMetrics(ctrl *gomock.Controller) *MockMetrics { + mock := &MockMetrics{ctrl: ctrl} + mock.recorder = &MockMetricsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMetrics) EXPECT() *MockMetricsMockRecorder { + return m.recorder +} + +// DeltaUpDownCounter mocks base method. +func (m *MockMetrics) DeltaUpDownCounter(ctx context.Context, name string, value float64, labels ...string) { + m.ctrl.T.Helper() + varargs := []any{ctx, name, value} + for _, a := range labels { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "DeltaUpDownCounter", varargs...) +} + +// DeltaUpDownCounter indicates an expected call of DeltaUpDownCounter. +func (mr *MockMetricsMockRecorder) DeltaUpDownCounter(ctx, name, value any, labels ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, name, value}, labels...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeltaUpDownCounter", reflect.TypeOf((*MockMetrics)(nil).DeltaUpDownCounter), varargs...) +} + +// IncrementCounter mocks base method. +func (m *MockMetrics) IncrementCounter(ctx context.Context, name string, labels ...string) { + m.ctrl.T.Helper() + varargs := []any{ctx, name} + for _, a := range labels { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "IncrementCounter", varargs...) +} + +// IncrementCounter indicates an expected call of IncrementCounter. +func (mr *MockMetricsMockRecorder) IncrementCounter(ctx, name any, labels ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, name}, labels...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementCounter", reflect.TypeOf((*MockMetrics)(nil).IncrementCounter), varargs...) +} + +// NewCounter mocks base method. +func (m *MockMetrics) NewCounter(name, desc string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NewCounter", name, desc) +} + +// NewCounter indicates an expected call of NewCounter. +func (mr *MockMetricsMockRecorder) NewCounter(name, desc any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewCounter", reflect.TypeOf((*MockMetrics)(nil).NewCounter), name, desc) +} + +// NewHistogram mocks base method. +func (m *MockMetrics) NewHistogram(name, desc string, buckets ...float64) { + m.ctrl.T.Helper() + varargs := []any{name, desc} + for _, a := range buckets { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "NewHistogram", varargs...) +} + +// NewHistogram indicates an expected call of NewHistogram. +func (mr *MockMetricsMockRecorder) NewHistogram(name, desc any, buckets ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{name, desc}, buckets...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewHistogram", reflect.TypeOf((*MockMetrics)(nil).NewHistogram), varargs...) +} + +// NewUpDownCounter mocks base method. +func (m *MockMetrics) NewUpDownCounter(name, desc string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "NewUpDownCounter", name, desc) +} + +// NewUpDownCounter indicates an expected call of NewUpDownCounter. +func (mr *MockMetricsMockRecorder) NewUpDownCounter(name, desc any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewUpDownCounter", reflect.TypeOf((*MockMetrics)(nil).NewUpDownCounter), name, desc) +} + +// RecordHistogram mocks base method. +func (m *MockMetrics) RecordHistogram(ctx context.Context, name string, value float64, labels ...string) { + m.ctrl.T.Helper() + varargs := []any{ctx, name, value} + for _, a := range labels { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "RecordHistogram", varargs...) +} + +// RecordHistogram indicates an expected call of RecordHistogram. +func (mr *MockMetricsMockRecorder) RecordHistogram(ctx, name, value any, labels ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, name, value}, labels...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordHistogram", reflect.TypeOf((*MockMetrics)(nil).RecordHistogram), varargs...) +} diff --git a/pkg/gofr/service/openai/openai.go b/pkg/gofr/service/openai/openai.go new file mode 100644 index 000000000..7a404f5a9 --- /dev/null +++ b/pkg/gofr/service/openai/openai.go @@ -0,0 +1,190 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "time" + + "go.opentelemetry.io/otel/trace" +) + +type Config struct { + APIKey string + Model string + BaseURL string + Timeout time.Duration + MaxIdleConns int +} + +type Client struct { + config *Config + logger Logger + metrics Metrics + tracer trace.Tracer + httpClient *http.Client +} + +var ( + errorMissingAPIKey = errors.New("API key not provided") +) + +type ClientOption func(*Client) + +func WithClientHTTP(httpClient *http.Client) func(*Client) { + return func(c *Client) { + c.httpClient = httpClient + } +} + +func WithClientTimeout(d time.Duration) func(*Client) { + return func(c *Client) { + c.httpClient.Timeout = d + } +} + +func NewClient(config *Config, opts ...ClientOption) (*Client, error) { + if config.APIKey == "" { + return nil, errorMissingAPIKey + } + + if config.BaseURL == "" { + config.BaseURL = "https://api.openai.com" + } + + if config.Model == "" { + config.Model = "gpt-4o" + } + + // Use the provided HTTP client or create a new one with defaults + c := &Client{ + config: config, + httpClient: &http.Client{ + Timeout: config.Timeout, + Transport: &http.Transport{ + MaxIdleConns: config.MaxIdleConns, + IdleConnTimeout: 120 * time.Second, + }, + }, + } + + for _, opt := range opts { + opt(c) + } + + return c, nil +} + +func (c *Client) UseLogger(logger any) { + if l, ok := logger.(Logger); ok { + c.logger = l + } +} + +func (c *Client) UseMetrics(metrics any) { + if m, ok := metrics.(Metrics); ok { + c.metrics = m + } +} + +func (c *Client) UseTracer(tracer any) { + if tracer, ok := tracer.(trace.Tracer); ok { + c.tracer = tracer + } +} + +func (c *Client) InitMetrics() { + openaiHistogramBuckets := []float64{.05, .075, .1, .125, .15, .2, .3, .5, .75, 1, 2, 3, 4, 5, 7.5, 10} + + c.metrics.NewHistogram( + "openai_api_request_duration", + "duration of OpenAPI requests in seconds", + openaiHistogramBuckets..., + ) + + c.metrics.NewCounter( + "openai_api_total_request_count", + "counts total number of requests made.", + ) + + c.metrics.NewUpDownCounter( + "openai_api_token_usage", + "counts number of tokens used.", + ) +} + +func (c *Client) AddTrace(ctx context.Context, method string) (context.Context, trace.Span) { + if c.tracer != nil { + contextWithTrace, span := c.tracer.Start(ctx, fmt.Sprintf("openai-%v", method)) + + return contextWithTrace, span + } + + return ctx, nil +} + +func (c *Client) post(ctx context.Context, url string, input any) (response []byte, err error) { + response = make([]byte, 0) + + reqJSON, err := json.Marshal(input) + if err != nil { + c.logger.Errorf("%v", err) + return response, err + } + + resp, err := c.call(ctx, http.MethodPost, url, bytes.NewReader(reqJSON)) + if err != nil { + c.logger.Errorf("%v", err) + return response, err + } + defer resp.Body.Close() + + response, err = io.ReadAll(resp.Body) + if err != nil { + c.logger.Errorf("%v", err) + } + + return response, err +} + +// Get makes a get request. +func (c *Client) get(ctx context.Context, url string) (response []byte, err error) { + resp, err := c.call(ctx, http.MethodGet, url, nil) + if err != nil { + c.logger.Errorf("%v", err) + return response, err + } + defer resp.Body.Close() + + response, err = io.ReadAll(resp.Body) + if err != nil { + c.logger.Errorf("%v", err) + } + + return response, err +} + +// Call makes a request. +func (c *Client) call(ctx context.Context, method, endpoint string, body io.Reader) (response *http.Response, err error) { + url := c.config.BaseURL + endpoint + + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + c.logger.Errorf("%v", err) + return response, err + } + + req.Header.Add("Authorization", "Bearer "+c.config.APIKey) + req.Header.Add("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + c.logger.Errorf("%v", err) + } + + return resp, err +} diff --git a/pkg/gofr/service/openai/openai_test.go b/pkg/gofr/service/openai/openai_test.go new file mode 100644 index 000000000..ebc1df1be --- /dev/null +++ b/pkg/gofr/service/openai/openai_test.go @@ -0,0 +1,460 @@ +package openai + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.uber.org/mock/gomock" +) + +var ( + errMockRead = errors.New("read error") + errNetworkError = errors.New("network error") +) + +func Test_NewClient(t *testing.T) { + tests := []struct { + name string + config *Config + opts []ClientOption + baseURL string + expected string + timeout time.Duration + expectedError error + }{ + { + name: "with default base URL", + config: &Config{APIKey: "test-key", Model: "gpt-4"}, + opts: []ClientOption{WithClientHTTP(&http.Client{})}, + expected: "https://api.openai.com", + expectedError: nil, + }, + { + name: "with custom base URL", + config: &Config{APIKey: "test-key", Model: "gpt-4", BaseURL: "https://custom.openai.com"}, + opts: []ClientOption{WithClientHTTP(&http.Client{})}, + expected: "https://custom.openai.com", + expectedError: nil, + }, + { + name: "missing api key", + config: &Config{Model: "gpt-4"}, + opts: []ClientOption{WithClientHTTP(&http.Client{})}, + expectedError: errorMissingAPIKey, + }, + { + name: "with custom timeout", + config: &Config{APIKey: "test-key", Model: "gpt-4"}, + opts: []ClientOption{WithClientTimeout(5 * time.Second)}, + expected: "https://api.openai.com", + timeout: 5 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.config, tt.opts...) + if tt.expectedError != nil { + assert.Equal(t, tt.expectedError, err) + assert.Nil(t, client) + } else { + assert.Equal(t, tt.expected, client.config.BaseURL) + + if tt.timeout > 0 { + assert.Equal(t, tt.timeout, client.httpClient.Timeout) + } + + require.NoError(t, err) + } + }) + } +} + +func Test_UseLogger(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLogger(ctrl) + + config := &Config{ + APIKey: "key", + } + + client, _ := NewClient(config) + client.UseLogger(mockLogger) + + assert.NotNil(t, client.logger) +} + +func Test_UseMetrics(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMetrics := NewMockMetrics(ctrl) + + config := &Config{ + APIKey: "key", + } + + client, _ := NewClient(config) + client.UseMetrics(mockMetrics) + + assert.NotNil(t, client.metrics) +} + +func Test_UseTracer(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tracer := otel.GetTracerProvider().Tracer("gofr-openAI") + + config := &Config{ + APIKey: "key", + } + + client, _ := NewClient(config) + client.UseTracer(tracer) + + assert.NotNil(t, client.tracer) +} + +func Test_InitMetrics(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockMetrics := NewMockMetrics(ctrl) + + config := &Config{ + APIKey: "key", + } + + client, _ := NewClient(config) + client.UseMetrics(mockMetrics) + + openaiHistogramBuckets := []float64{.05, .075, .1, .125, .15, .2, .3, .5, .75, 1, 2, 3, 4, 5, 7.5, 10} + + mockMetrics.EXPECT().NewHistogram( + "openai_api_request_duration", + "duration of OpenAPI requests in seconds", + openaiHistogramBuckets, + ) + + mockMetrics.EXPECT().NewCounter( + "openai_api_total_request_count", + "counts total number of requests made.", + ) + + mockMetrics.EXPECT().NewUpDownCounter( + "openai_api_token_usage", + "counts number of tokens used.", + ) + + client.InitMetrics() +} + +func Test_AddTrace(t *testing.T) { + config := &Config{ + APIKey: "test-key", + } + + client, _ := NewClient(config) + tracer := otel.GetTracerProvider().Tracer("gofr-openAI") + client.UseTracer(tracer) + + ctx := context.Background() + resultCtx, span := client.AddTrace(ctx, "test-method") + + assert.NotNil(t, span) + assert.NotEqual(t, ctx, resultCtx) +} + +type mockTransport struct { + response *http.Response + err error +} + +func (m *mockTransport) RoundTrip(*http.Request) (*http.Response, error) { + return m.response, m.err +} + +func Test_Call(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLogger(ctrl) + mockMetrics := NewMockMetrics(ctrl) + + config := &Config{ + APIKey: "test-key", + BaseURL: "https://api.openai.com", + } + + tests := []struct { + name string + method string + endpoint string + body io.Reader + setupMocks func(*http.Client) + wantErr bool + }{ + { + name: "successful request", + method: http.MethodPost, + endpoint: "/v1/chat/completions", + body: strings.NewReader(`{"test":"data"}`), + setupMocks: func(client *http.Client) { + client.Transport = &mockTransport{ + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"response":"success"}`)), + }, + } + }, + wantErr: false, + }, + { + name: "failed request", + method: http.MethodPost, + endpoint: "/v1/chat/completions", + body: strings.NewReader(`{"test":"data"}`), + setupMocks: func(client *http.Client) { + client.Transport = &mockTransport{ + err: errNetworkError, + } + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpClient := &http.Client{} + if tt.setupMocks != nil { + tt.setupMocks(httpClient) + } + + client, _ := NewClient(config, WithClientHTTP(httpClient)) + client.UseLogger(mockLogger) + client.UseMetrics(mockMetrics) + + mockLogger.EXPECT().Errorf(gomock.Any(), gomock.Any()).AnyTimes() + + resp, err := client.call(context.Background(), tt.method, tt.endpoint, tt.body) + if resp != nil { + defer resp.Body.Close() + } + + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, http.StatusOK, resp.StatusCode) + } + }) + } +} +func Test_Get(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLogger(ctrl) + mockMetrics := NewMockMetrics(ctrl) + + config := &Config{ + APIKey: "test-key", + BaseURL: "https://api.openai.com", + } + + tests := []struct { + name string + url string + setupMocks func(*http.Client) + want []byte + wantErr bool + }{ + { + name: "successful GET request", + url: "/v1/models", + setupMocks: func(client *http.Client) { + client.Transport = &mockTransport{ + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"data":"test"}`)), + }, + } + }, + want: []byte(`{"data":"test"}`), + wantErr: false, + }, + { + name: "network error", + url: "/v1/models", + setupMocks: func(client *http.Client) { + client.Transport = &mockTransport{ + err: errNetworkError, + } + }, + want: []byte{}, + wantErr: true, + }, + { + name: "error reading response body", + url: "/v1/models", + setupMocks: func(client *http.Client) { + client.Transport = &mockTransport{ + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(&errorReader{}), + }, + } + }, + want: []byte{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpClient := &http.Client{} + if tt.setupMocks != nil { + tt.setupMocks(httpClient) + } + + client, _ := NewClient(config, WithClientHTTP(httpClient)) + client.UseLogger(mockLogger) + client.UseMetrics(mockMetrics) + + mockLogger.EXPECT().Errorf(gomock.Any(), gomock.Any()).AnyTimes() + + got, err := client.get(context.Background(), tt.url) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +// ErrMockRead is a static error for mock read operations + +// errorReader is a mock reader that always returns an error. +type errorReader struct{} + +func (*errorReader) Read(_ []byte) (n int, err error) { + return 0, errMockRead +} + +func Test_Post(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockLogger := NewMockLogger(ctrl) + mockMetrics := NewMockMetrics(ctrl) + + config := &Config{ + APIKey: "test-key", + BaseURL: "https://api.openai.com", + } + + tests := []struct { + name string + url string + input any + setupMocks func(*http.Client) + want []byte + wantErr bool + }{ + { + name: "successful POST request", + url: "/v1/completions", + input: map[string]string{ + "prompt": "test", + }, + setupMocks: func(client *http.Client) { + client.Transport = &mockTransport{ + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"result":"success"}`)), + }, + } + }, + want: []byte(`{"result":"success"}`), + wantErr: false, + }, + { + name: "invalid input JSON", + url: "/v1/completions", + input: make(chan int), // Unmarshalable type + setupMocks: func(_ *http.Client) {}, + want: []byte{}, + wantErr: true, + }, + { + name: "network error", + url: "/v1/completions", + input: map[string]string{ + "prompt": "test", + }, + setupMocks: func(client *http.Client) { + client.Transport = &mockTransport{ + err: errNetworkError, + } + }, + want: []byte{}, + wantErr: true, + }, + { + name: "error reading response body", + url: "/v1/completions", + input: map[string]string{ + "prompt": "test", + }, + setupMocks: func(client *http.Client) { + client.Transport = &mockTransport{ + response: &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(&errorReader{}), + }, + } + }, + want: []byte{}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + httpClient := &http.Client{} + if tt.setupMocks != nil { + tt.setupMocks(httpClient) + } + + client, _ := NewClient(config, WithClientHTTP(httpClient)) + client.UseLogger(mockLogger) + client.UseMetrics(mockMetrics) + + mockLogger.EXPECT().Errorf(gomock.Any(), gomock.Any()).AnyTimes() + + got, err := client.post(context.Background(), tt.url, tt.input) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} diff --git a/pkg/gofr/services.go b/pkg/gofr/services.go new file mode 100644 index 000000000..061233d06 --- /dev/null +++ b/pkg/gofr/services.go @@ -0,0 +1,20 @@ +package gofr + +import ( + "go.opentelemetry.io/otel" + + "gofr.dev/pkg/gofr/container" +) + +// AddOpenAI sets the OpenAI wrapper in the app's container. +func (a *App) AddOpenAI(openAI container.OpenAIProvider) { + openAI.UseLogger(a.Logger()) + openAI.UseMetrics(a.Metrics()) + + tracer := otel.GetTracerProvider().Tracer("gofr-openai") + openAI.UseTracer(tracer) + + openAI.InitMetrics() + + a.container.OpenAI = openAI +} diff --git a/pkg/gofr/services_test.go b/pkg/gofr/services_test.go new file mode 100644 index 000000000..b2f98a9b1 --- /dev/null +++ b/pkg/gofr/services_test.go @@ -0,0 +1,35 @@ +package gofr + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "gofr.dev/pkg/gofr/config" + "gofr.dev/pkg/gofr/container" +) + +func TestApp_AddOpenAI(t *testing.T) { + t.Run("Adding OpenAI", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + c := container.NewContainer(config.NewMockConfig(nil)) + + app := &App{ + container: c, + } + + mock := container.NewMockOpenAIProvider(ctrl) + + mock.EXPECT().UseLogger(app.Logger()) + mock.EXPECT().UseMetrics(app.Metrics()) + mock.EXPECT().UseTracer(gomock.Any()) + mock.EXPECT().InitMetrics() + + app.AddOpenAI(mock) + + assert.Equal(t, mock, app.container.OpenAI) + }) +}