diff --git a/chat_stream.go b/chat_stream.go index 80d16cc63..7b0bc40c2 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -65,10 +65,21 @@ type ChatCompletionStreamResponse struct { Usage *Usage `json:"usage,omitempty"` } +// ChatStreamReader is an interface for reading chat completion streams. +type ChatStreamReader interface { + Recv() (ChatCompletionStreamResponse, error) + Close() error +} + // ChatCompletionStream // Note: Perhaps it is more elegant to abstract Stream using generics. type ChatCompletionStream struct { - *streamReader[ChatCompletionStreamResponse] + reader ChatStreamReader +} + +// NewChatCompletionStream allows injecting a custom ChatStreamReader (for testing). +func NewChatCompletionStream(reader ChatStreamReader) *ChatCompletionStream { + return &ChatCompletionStream{reader: reader} } // CreateChatCompletionStream — API call to create a chat completion w/ streaming @@ -106,7 +117,37 @@ func (c *Client) CreateChatCompletionStream( return } stream = &ChatCompletionStream{ - streamReader: resp, + reader: resp, } return } + +func (s *ChatCompletionStream) Recv() (ChatCompletionStreamResponse, error) { + return s.reader.Recv() +} + +func (s *ChatCompletionStream) Close() error { + return s.reader.Close() +} + +func (s *ChatCompletionStream) Header() http.Header { + if h, ok := s.reader.(interface{ Header() http.Header }); ok { + return h.Header() + } + return http.Header{} +} + +func (s *ChatCompletionStream) GetRateLimitHeaders() map[string]interface{} { + if h, ok := s.reader.(interface{ GetRateLimitHeaders() RateLimitHeaders }); ok { + headers := h.GetRateLimitHeaders() + return map[string]interface{}{ + "x-ratelimit-limit-requests": headers.LimitRequests, + "x-ratelimit-limit-tokens": headers.LimitTokens, + "x-ratelimit-remaining-requests": headers.RemainingRequests, + "x-ratelimit-remaining-tokens": headers.RemainingTokens, + "x-ratelimit-reset-requests": headers.ResetRequests.String(), + "x-ratelimit-reset-tokens": headers.ResetTokens.String(), + } + } + return map[string]interface{}{} +} diff --git a/chat_stream_test.go b/chat_stream_test.go index eabb0f3a2..65d92a702 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -767,6 +767,34 @@ func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { } } +type mockStream struct { + calls int +} + +// Implement ChatStreamReader. +func (m *mockStream) Recv() (openai.ChatCompletionStreamResponse, error) { + m.calls++ + if m.calls == 1 { + return openai.ChatCompletionStreamResponse{ID: "mock1"}, nil + } + return openai.ChatCompletionStreamResponse{}, io.EOF +} +func (m *mockStream) Close() error { return nil } + +func TestChatCompletionStream_MockInjection(t *testing.T) { + mock := &mockStream{} + stream := openai.NewChatCompletionStream(mock) + + resp, err := stream.Recv() + if err != nil || resp.ID != "mock1" { + t.Errorf("expected mock1, got %v, err %v", resp.ID, err) + } + _, err = stream.Recv() + if !errors.Is(err, io.EOF) { + t.Errorf("expected EOF, got %v", err) + } +} + // Helper funcs. func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { diff --git a/mock_streaming_demo_test.go b/mock_streaming_demo_test.go new file mode 100644 index 000000000..d235766f2 --- /dev/null +++ b/mock_streaming_demo_test.go @@ -0,0 +1,199 @@ +package openai_test + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/sashabaranov/go-openai" +) + +// This file demonstrates how to create mock clients for go-openai streaming +// functionality. This pattern is useful when testing code that depends on +// go-openai streaming but you want to control the responses for testing. + +// MockOpenAIStreamClient demonstrates how to create a full mock client for go-openai. +type MockOpenAIStreamClient struct { + // Configure canned responses + ChatCompletionResponse openai.ChatCompletionResponse + ChatCompletionStreamErr error + + // Allow function overrides for more complex scenarios + CreateChatCompletionStreamFn func( + ctx context.Context, req openai.ChatCompletionRequest) (*openai.ChatCompletionStream, error) +} + +func (m *MockOpenAIStreamClient) CreateChatCompletionStream( + ctx context.Context, + req openai.ChatCompletionRequest, +) (*openai.ChatCompletionStream, error) { + if m.CreateChatCompletionStreamFn != nil { + return m.CreateChatCompletionStreamFn(ctx, req) + } + return nil, m.ChatCompletionStreamErr +} + +// mockStreamReader creates specific responses for testing. +type mockStreamReader struct { + responses []openai.ChatCompletionStreamResponse + index int +} + +func (m *mockStreamReader) Recv() (openai.ChatCompletionStreamResponse, error) { + if m.index >= len(m.responses) { + return openai.ChatCompletionStreamResponse{}, io.EOF + } + resp := m.responses[m.index] + m.index++ + return resp, nil +} + +func (m *mockStreamReader) Close() error { + return nil +} + +func TestMockOpenAIStreamClient_Demo(t *testing.T) { + // Create expected responses that our mock stream will return + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "test-1", + Object: "chat.completion.chunk", + Model: "gpt-3.5-turbo", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Role: "assistant", + Content: "Hello", + }, + }, + }, + }, + { + ID: "test-2", + Object: "chat.completion.chunk", + Model: "gpt-3.5-turbo", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " World", + }, + }, + }, + }, + { + ID: "test-3", + Object: "chat.completion.chunk", + Model: "gpt-3.5-turbo", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + // Create mock client with custom stream function + mockClient := &MockOpenAIStreamClient{ + CreateChatCompletionStreamFn: func( + _ context.Context, _ openai.ChatCompletionRequest, + ) (*openai.ChatCompletionStream, error) { + // Create a mock stream reader with our expected responses + mockStreamReader := &mockStreamReader{ + responses: expectedResponses, + index: 0, + } + // Return a new ChatCompletionStream with our mock reader + return openai.NewChatCompletionStream(mockStreamReader), nil + }, + } + + // Test the mock client + stream, err := mockClient.CreateChatCompletionStream( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + if err != nil { + t.Fatalf("CreateChatCompletionStream returned error: %v", err) + } + defer stream.Close() + + // Verify we get back exactly the responses we configured + fullResponse := "" + for i, expectedResponse := range expectedResponses { + receivedResponse, streamErr := stream.Recv() + if streamErr != nil { + t.Fatalf("stream.Recv() failed at index %d: %v", i, streamErr) + } + + // Additional specific checks + if receivedResponse.ID != expectedResponse.ID { + t.Errorf("Response %d ID mismatch. Expected: %s, Got: %s", + i, expectedResponse.ID, receivedResponse.ID) + } + if len(receivedResponse.Choices) > 0 && len(expectedResponse.Choices) > 0 { + expectedContent := expectedResponse.Choices[0].Delta.Content + receivedContent := receivedResponse.Choices[0].Delta.Content + if receivedContent != expectedContent { + t.Errorf("Response %d content mismatch. Expected: %s, Got: %s", + i, expectedContent, receivedContent) + } + fullResponse += receivedContent + } + } + + // Verify EOF at the end + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("Expected EOF at end of stream, got: %v", streamErr) + } + + // Verify the full assembled response + expectedFullResponse := "Hello World" + if fullResponse != expectedFullResponse { + t.Errorf("Full response mismatch. Expected: %s, Got: %s", expectedFullResponse, fullResponse) + } + + t.Log("✅ Successfully demonstrated mock OpenAI client with streaming responses!") + t.Logf(" Full response assembled: %q", fullResponse) +} + +// TestMockOpenAIStreamClient_ErrorHandling demonstrates error handling. +func TestMockOpenAIStreamClient_ErrorHandling(t *testing.T) { + expectedError := errors.New("mock stream error") + + mockClient := &MockOpenAIStreamClient{ + ChatCompletionStreamErr: expectedError, + } + + _, err := mockClient.CreateChatCompletionStream( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + if !errors.Is(err, expectedError) { + t.Errorf("Expected error %v, got %v", expectedError, err) + } + + t.Log("✅ Successfully demonstrated mock OpenAI client error handling!") +} diff --git a/stream_reader.go b/stream_reader.go index 6faefe0a7..4dbcfc4b6 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -16,6 +16,8 @@ var ( errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) ) +var _ ChatStreamReader = (*streamReader[ChatCompletionStreamResponse])(nil) + type streamable interface { ChatCompletionStreamResponse | CompletionResponse }