diff --git a/.gitignore b/.gitignore index 4c49bd7..bd20d6f 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,32 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work + +# Env .env +.env.* + +# IDEs +.idea/ +.vscode/ + +# macOS +.DS_Store diff --git a/client.go b/client.go new file mode 100644 index 0000000..80b385b --- /dev/null +++ b/client.go @@ -0,0 +1,143 @@ +package gpt3 + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +const ( + DEFAULT_BASE_URL = "https://api.openai.com/v1" + DEFAULT_USER_AGENT = "gpt3-go" + DEFAULT_TIMEOUT = 30 +) + +var dataPrefix = []byte("data: ") +var streamTerminationPrefix = []byte("[DONE]") + +type Client interface { + Models(ctx context.Context) (*ModelsResponse, error) + Model(ctx context.Context, model string) (*ModelObject, error) + Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error) + CompletionStream(ctx context.Context, request CompletionRequest, onData func(*CompletionResponse)) error + Edits(ctx context.Context, request EditsRequest) (*EditsResponse, error) + Embeddings(ctx context.Context, request EmbeddingsRequest) (*EmbeddingsResponse, error) + Files(ctx context.Context) (*FilesResponse, error) + UploadFile(ctx context.Context, request UploadFileRequest) (*FileObject, error) + DeleteFile(ctx context.Context, fileID string) (*DeleteFileResponse, error) + File(ctx context.Context, fileID string) (*FileObject, error) + FileContent(ctx context.Context, fileID string) ([]byte, error) + CreateFineTune(ctx 
context.Context, request CreateFineTuneRequest) (*FineTuneObject, error) + ListFineTunes(ctx context.Context) (*ListFineTunesResponse, error) + FineTune(ctx context.Context, fineTuneID string) (*FineTuneObject, error) + CancelFineTune(ctx context.Context, fineTuneID string) (*FineTuneObject, error) + FineTuneEvents(ctx context.Context, request FineTuneEventsRequest) (*FineTuneEventsResponse, error) + FineTuneStreamEvents(ctx context.Context, request FineTuneEventsRequest, onData func(*FineTuneEvent)) error + DeleteFineTuneModel(ctx context.Context, modelID string) (*DeleteFineTuneModelResponse, error) + + // Deprecated + CompletionWithEngine(ctx context.Context, engine string, request CompletionRequest) (*CompletionResponse, error) + CompletionStreamWithEngine(ctx context.Context, engine string, request CompletionRequest, onData func(*CompletionResponse)) error +} + +type client struct { + baseURL string + apiKey string + orgID string + userAgent string + httpClient *http.Client + defaultModel string +} + +func NewClient(apiKey string, options ...ClientOption) (Client, error) { + c := &client{ + baseURL: DEFAULT_BASE_URL, + apiKey: apiKey, + orgID: "", + userAgent: DEFAULT_USER_AGENT, + httpClient: &http.Client{Timeout: time.Duration(DEFAULT_TIMEOUT) * time.Second}, + defaultModel: DavinciModel, + } + + for _, option := range options { + if err := option(c); err != nil { + return nil, err + } + } + + return c, nil +} + +func (c *client) newRequest(ctx context.Context, method, path string, payload interface{}) (*http.Request, error) { + bodyReader, err := jsonBodyReader(payload) + if err != nil { + return nil, err + } + url := c.baseURL + path + req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) + if err != nil { + return nil, err + } + req.Header.Set("Content-type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) + req.Header.Set("User-Agent", c.userAgent) + if len(c.orgID) > 0 { + 
req.Header.Set("OpenAI-Organization", c.orgID) + } + return req, nil +} + +// performRequest executes the request and validates the HTTP status. On success the +// response body is left OPEN and must be closed by the caller (all call sites already +// defer Close or delegate to getResponseObject); on an API error the body is consumed +// by checkForSuccess and closed here. +func (c *client) performRequest(req *http.Request) (*http.Response, error) { + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + if err := checkForSuccess(resp); err != nil { resp.Body.Close(); return nil, err } + return resp, nil +} + +func checkForSuccess(resp *http.Response) error { + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return nil + } + data, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read from body: %w", err) + } + var result APIErrorResponse + if err := json.Unmarshal(data, &result); err != nil { + // if we can't decode the json error then create an unexpected error + apiError := APIError{ + StatusCode: resp.StatusCode, + Type: "Unexpected", + Message: string(data), + } + return apiError + } + result.Error.StatusCode = resp.StatusCode + return result.Error +} + +func getResponseObject(rsp *http.Response, v interface{}) error { + defer rsp.Body.Close() + if err := json.NewDecoder(rsp.Body).Decode(v); err != nil { + return fmt.Errorf("invalid json response: %w", err) + } + return nil +} + +func jsonBodyReader(body interface{}) (io.Reader, error) { + if body == nil { + // the body is allowed to be nil so we return an empty buffer + return bytes.NewBuffer(nil), nil + } + raw, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed encoding json: %w", err) + } + return bytes.NewBuffer(raw), nil +} diff --git a/client_options.go b/client_options.go index 27e76b3..8facad8 100644 --- a/client_options.go +++ b/client_options.go @@ -8,18 +8,18 @@ import ( // ClientOption are options that can be passed when creating a new client type ClientOption func(*client) error -// WithOrg is a client option that allows you to override the organization ID +// WithOrg is a client option that allows you to set the organization ID func WithOrg(id string) ClientOption { return func(c *client) error { - c.idOrg = id + c.orgID = id return nil } } -//
WithDefaultEngine is a client option that allows you to override the default engine of the client -func WithDefaultEngine(engine string) ClientOption { +// WithDefaultModel is a client option that allows you to override the default model of the client +func WithDefaultModel(model string) ClientOption { return func(c *client) error { - c.defaultEngine = engine + c.defaultModel = model return nil } } diff --git a/client_options_test.go b/client_options_test.go new file mode 100644 index 0000000..2dcb5bb --- /dev/null +++ b/client_options_test.go @@ -0,0 +1,36 @@ +package gpt3_test + +import ( + "net/http" + "testing" + + "github.com/PullRequestInc/go-gpt3" + "github.com/stretchr/testify/assert" +) + +func TestClient(t *testing.T) { + testCases := []struct { + name string + options []gpt3.ClientOption + }{ + { + name: "test-key", + options: []gpt3.ClientOption{ + gpt3.WithOrg("test-org"), + gpt3.WithDefaultModel("test-model"), + gpt3.WithUserAgent("test-agent"), + gpt3.WithBaseURL("test-url"), + gpt3.WithHTTPClient(&http.Client{}), + gpt3.WithTimeout(10), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client, err := gpt3.NewClient(tc.name, tc.options...) 
+ assert.Nil(t, err) + assert.NotNil(t, client) + }) + } +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..440304c --- /dev/null +++ b/client_test.go @@ -0,0 +1,14 @@ +package gpt3_test + +import ( + "testing" + + "github.com/PullRequestInc/go-gpt3" + "github.com/stretchr/testify/assert" +) + +func TestInitNewClient(t *testing.T) { + client, err := gpt3.NewClient("test-key") + assert.Nil(t, err) + assert.NotNil(t, client) +} diff --git a/cmd/errors/main.go b/cmd/errors/main.go index 392e42d..336ff04 100644 --- a/cmd/errors/main.go +++ b/cmd/errors/main.go @@ -11,7 +11,10 @@ import ( ) func main() { - godotenv.Load() + err := godotenv.Load() + if err != nil { + log.Fatalln(err) + } apiKey := os.Getenv("API_KEY") if apiKey == "" { @@ -19,7 +22,10 @@ func main() { } ctx := context.Background() - client := gpt3.NewClient(apiKey) + client, err := gpt3.NewClient(apiKey) + if err != nil { + log.Fatalln(err) + } resp, err := client.Completion(ctx, gpt3.CompletionRequest{ Prompt: []string{ diff --git a/cmd/test/main.go b/cmd/test/main.go index 15b4725..64fcb12 100644 --- a/cmd/test/main.go +++ b/cmd/test/main.go @@ -19,7 +19,10 @@ func main() { } ctx := context.Background() - client := gpt3.NewClient(apiKey) + client, err := gpt3.NewClient(apiKey) + if err != nil { + log.Fatalln(err) + } resp, err := client.Completion(ctx, gpt3.CompletionRequest{ Prompt: []string{ diff --git a/completions.go b/completions.go new file mode 100644 index 0000000..73f2fb0 --- /dev/null +++ b/completions.go @@ -0,0 +1,151 @@ +package gpt3 + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" +) + +// CompletionRequest is a request for the completions API +type CompletionRequest struct { + // The model ID to use for completion + Model string `json:"model"` + // A list of string prompts to use. + // TODO there are other prompt types here for using token integers that we could add support for. 
+ Prompt []string `json:"prompt"` + // The suffix that comes after a completion of inserted text + Suffix string `json:"suffix,omitempty"` + // How many tokens to complete up to. Max of 512 + MaxTokens *int `json:"max_tokens,omitempty"` + // Sampling temperature to use + Temperature *float32 `json:"temperature,omitempty"` + // Alternative to temperature for nucleus sampling + TopP *float32 `json:"top_p,omitempty"` + // How many choice to create for each prompt + N *int `json:"n"` + // Include the probabilities of most likely tokens + LogProbs *int `json:"logprobs"` + // Echo back the prompt in addition to the completion + Echo bool `json:"echo"` + // Up to 4 sequences where the API will stop generating tokens. Response will not contain the stop sequence. + Stop []string `json:"stop,omitempty"` + // PresencePenalty number between 0 and 1 that penalizes tokens that have already appeared in the text so far. + PresencePenalty float32 `json:"presence_penalty"` + // FrequencyPenalty number between 0 and 1 that penalizes tokens on existing frequency in the text so far. + FrequencyPenalty float32 `json:"frequency_penalty"` + // BestOf generates n completions server-side and returns the "best" one. + BestOf *int `json:"best_of,omitempty"` + // LogitBias is a list of token logit biases to apply before sampling. + LogitBias []interface{} `json:"logit_bias,omitempty"` + // User is the user ID to associate with this request + User string `json:"user,omitempty"` + + // Whether to stream back results or not. Don't set this value in the request yourself + // as it will be overridden depending on if you use CompletionStream or Completion methods. 
+ Stream bool `json:"stream,omitempty"` +} + +// CompletionResponse is the full response from a request to the completions API +type CompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + Choices []CompletionResponseChoice `json:"choices"` + Usage CompletionResponseUsage `json:"usage"` +} + +// CompletionResponseChoice is one of the choices returned in the response to the Completions API +type CompletionResponseChoice struct { + Text string `json:"text"` + Index int `json:"index"` + LogProbs LogprobResult `json:"logprobs"` + FinishReason string `json:"finish_reason"` +} + +// LogprobResult represents logprob result of Choice +type LogprobResult struct { + Tokens []string `json:"tokens"` + TokenLogprobs []float32 `json:"token_logprobs"` + TopLogprobs []map[string]float32 `json:"top_logprobs"` + TextOffset []int `json:"text_offset"` +} + +// CompletionResponseUsage is the object that returns how many tokens the completion's request used +type CompletionResponseUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Completion is a single completion. 
+func (c *client) Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error) { + request.Stream = false + // Check if the request provides a model + if request.Model == "" { + request.Model = c.defaultModel + } + req, err := c.newRequest(ctx, "POST", "/completions", request) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := new(CompletionResponse) + if err := getResponseObject(resp, output); err != nil { + return nil, err + } + return output, nil +} + +func (c *client) CompletionStream(ctx context.Context, request CompletionRequest, onData func(*CompletionResponse)) error { + request.Stream = true + // Check if the request provides a model + if request.Model == "" { + request.Model = c.defaultModel + } + req, err := c.newRequest(ctx, "POST", "/completions", request) + if err != nil { + return err + } + resp, err := c.performRequest(req) + if err != nil { + return err + } + + reader := bufio.NewReader(resp.Body) + defer resp.Body.Close() + + for { + line, err := reader.ReadBytes('\n') + if err != nil { + return err + } + + // Trim whitespace + line = bytes.TrimSpace(line) + // Check if the line has data + if !bytes.HasPrefix(line, dataPrefix) { + continue + } + // Trim the data prefix + line = bytes.TrimPrefix(line, dataPrefix) + + // Check if the stream is terminated + if bytes.HasPrefix(line, streamTerminationPrefix) { + break + } + + output := new(CompletionResponse) + if err := json.Unmarshal(line, output); err != nil { + return fmt.Errorf("invalid json stream data: %v", err) + } + onData(output) + } + return nil +} diff --git a/edits.go b/edits.go new file mode 100644 index 0000000..76579e0 --- /dev/null +++ b/edits.go @@ -0,0 +1,57 @@ +package gpt3 + +import "context" + +// EditsRequest is a request for the edits API +type EditsRequest struct { + // ID of the model to use. 
You can use the List models API to see all of your available models, or see our Model overview for descriptions of them. + Model string `json:"model"` + // The input text to use as a starting point for the edit. + Input string `json:"input"` + // The instruction that tells the model how to edit the prompt. + Instruction string `json:"instruction"` + // Sampling temperature to use + Temperature *float32 `json:"temperature,omitempty"` + // Alternative to temperature for nucleus sampling + TopP *float32 `json:"top_p,omitempty"` + // How many edits to generate for the input and instruction. Defaults to 1 + N *int `json:"n"` +} + +// EditsResponse is the full response from a request to the edits API +type EditsResponse struct { + Object string `json:"object"` + Created int `json:"created"` + Choices []EditsResponseChoice `json:"choices"` + Usage EditsResponseUsage `json:"usage"` +} + +// EditsResponseChoice is one of the choices returned in the response to the Edits API +type EditsResponseChoice struct { + Text string `json:"text"` + Index int `json:"index"` +} + +// EditsResponseUsage is a structure used in the response from a request to the edits API +type EditsResponseUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +func (c *client) Edits(ctx context.Context, request EditsRequest) (*EditsResponse, error) { + req, err := c.newRequest(ctx, "POST", "/edits", request) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := new(EditsResponse) + if err := getResponseObject(resp, output); err != nil { + return nil, err + } + return output, nil +} diff --git a/embeddings.go b/embeddings.go new file mode 100644 index 0000000..7f00c9c --- /dev/null +++ b/embeddings.go @@ -0,0 +1,84 @@ +package gpt3 + +import "context" + +const ( + TextSimilarityAda001 = "text-similarity-ada-001" + 
TextSimilarityBabbage001 = "text-similarity-babbage-001" + TextSimilarityCurie001 = "text-similarity-curie-001" + TextSimilarityDavinci001 = "text-similarity-davinci-001" + TextSearchAdaDoc001 = "text-search-ada-doc-001" + TextSearchAdaQuery001 = "text-search-ada-query-001" + TextSearchBabbageDoc001 = "text-search-babbage-doc-001" + TextSearchBabbageQuery001 = "text-search-babbage-query-001" + TextSearchCurieDoc001 = "text-search-curie-doc-001" + TextSearchCurieQuery001 = "text-search-curie-query-001" + TextSearchDavinciDoc001 = "text-search-davinci-doc-001" + TextSearchDavinciQuery001 = "text-search-davinci-query-001" + CodeSearchAdaCode001 = "code-search-ada-code-001" + CodeSearchAdaText001 = "code-search-ada-text-001" + CodeSearchBabbageCode001 = "code-search-babbage-code-001" + CodeSearchBabbageText001 = "code-search-babbage-text-001" + TextEmbeddingAda002 = "text-embedding-ada-002" +) + +// EmbeddingsRequest is a request for the Embeddings API +type EmbeddingsRequest struct { + // Input text to get embeddings for, encoded as a string or array of tokens. To get embeddings + // for multiple inputs in a single request, pass an array of strings or array of token arrays. + // Each input must not exceed 2048 tokens in length. + Input []string `json:"input"` + // ID of the model to use + Model string `json:"model"` + // The request user is an optional parameter meant to be used to trace abusive requests + // back to the originating user. OpenAI states: + // "The [user] IDs should be a string that uniquely identifies each user. We recommend hashing + // their username or email address, in order to avoid sending us any identifying information. + // If you offer a preview of your product to non-logged in users, you can send a session ID + // instead." + User string `json:"user,omitempty"` +} + +// EmbeddingsResponse is the response from a create embeddings request. 
+type EmbeddingsResponse struct { + Object string `json:"object"` + Data []EmbeddingsResult `json:"data"` + Usage EmbeddingsUsage `json:"usage"` +} + +// The inner result of a create embeddings request, containing the embeddings for a single input. +type EmbeddingsResult struct { + // The type of object returned (e.g., "list", "object") + Object string `json:"object"` + // The embedding data for the input + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +// The usage stats for an embeddings response +type EmbeddingsUsage struct { + // The number of tokens used by the prompt + PromptTokens int `json:"prompt_tokens"` + // The total tokens used + TotalTokens int `json:"total_tokens"` +} + +// Embeddings creates text embeddings for a supplied slice of inputs with a provided model. +// +// See: https://beta.openai.com/docs/api-reference/embeddings +func (c *client) Embeddings(ctx context.Context, request EmbeddingsRequest) (*EmbeddingsResponse, error) { + req, err := c.newRequest(ctx, "POST", "/embeddings", request) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := EmbeddingsResponse{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..cb7c129 --- /dev/null +++ b/errors.go @@ -0,0 +1,19 @@ +package gpt3 + +import "fmt" + +// APIError represents an error that occurred on an API +type APIError struct { + StatusCode int `json:"status_code"` + Message string `json:"message"` + Type string `json:"type"` +} + +func (e APIError) Error() string { + return fmt.Sprintf("[%d:%s] %s", e.StatusCode, e.Type, e.Message) +} + +// APIErrorResponse is the full error response that has been returned by an API. 
+type APIErrorResponse struct { + Error APIError `json:"error"` +} diff --git a/files.go b/files.go new file mode 100644 index 0000000..0c615c5 --- /dev/null +++ b/files.go @@ -0,0 +1,135 @@ +package gpt3 + +import ( + "context" + "fmt" + "io/ioutil" +) + +// UploadFileRequest is a request for the Files API +type UploadFileRequest struct { + // The file name of the JSON Lines file to upload + File string `json:"file"` + // The purpose of the file. Use "fine-tune" for a file that will be used to fine-tune a model. + Purpose string `json:"purpose"` +} + +// FileObject is a single file object +type FileObject struct { + ID string `json:"id"` + Object string `json:"object"` + Bytes int `json:"bytes"` + CreatedAt int `json:"created_at"` + Filename string `json:"filename"` + Purpose string `json:"purpose"` +} + +// FilesResponse is the response from a list files request. +type FilesResponse struct { + Data []FileObject `json:"data"` + Object string `json:"object"` +} + +// DeleteFileResponse is the response from a delete file request. +type DeleteFileResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + +// Files lists the files that belong to the user's organization. +// +// See: https://beta.openai.com/docs/api-reference/files/list +func (c *client) Files(ctx context.Context) (*FilesResponse, error) { + req, err := c.newRequest(ctx, "GET", "/files", nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := FilesResponse{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// UploadFile uploads a file that contains document(s) to be used across various endpoints. 
+// +// See: https://beta.openai.com/docs/api-reference/files/upload +func (c *client) UploadFile(ctx context.Context, request UploadFileRequest) (*FileObject, error) { + req, err := c.newRequest(ctx, "POST", "/files", request) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := FileObject{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// DeleteFile deletes a file that contains document(s) to be used across various endpoints. +// +// See: https://beta.openai.com/docs/api-reference/files/delete +func (c *client) DeleteFile(ctx context.Context, fileID string) (*DeleteFileResponse, error) { + req, err := c.newRequest(ctx, "DELETE", fmt.Sprintf("/files/%s", fileID), nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := DeleteFileResponse{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// File retrieves a file that contains document(s) to be used across various endpoints. +// +// See: https://beta.openai.com/docs/api-reference/files/retrieve +func (c *client) File(ctx context.Context, fileID string) (*FileObject, error) { + req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/files/%s", fileID), nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := FileObject{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// FileContent retrieves the content of a file that contains document(s) to be used across various endpoints. 
+// +// See: https://beta.openai.com/docs/api-reference/files/retrieve-content +func (c *client) FileContent(ctx context.Context, fileID string) ([]byte, error) { + req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/files/%s/content", fileID), nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + return ioutil.ReadAll(resp.Body) +} diff --git a/fine_tunes.go b/fine_tunes.go new file mode 100644 index 0000000..6e56bdc --- /dev/null +++ b/fine_tunes.go @@ -0,0 +1,279 @@ +package gpt3 + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" +) + +// CreateFineTuneRequest is a request for the FineTune API +type CreateFineTuneRequest struct { + // The ID of an uploaded file that contains training data + TrainingFile string `json:"training_file"` + // The ID of an uploaded file that contains validation data + ValidationFile string `json:"validation_file"` + // The ID of the model to fine-tune + Model string `json:"model"` + // The number of epochs to train for + NEpochs int `json:"n_epochs"` + // The batch size to use for training + BatchSize int `json:"batch_size"` + // The learning rate to use for training + LearningRate float32 `json:"learning_rate"` + // The weight to use for loss on the prompt tokens + PromptLossWeight float32 `json:"prompt_loss_weight"` + // If set, we calculate classification-specific metrics using the validation set at the end of each epoch + ComputeClassificationMetrics bool `json:"compute_classification_metrics"` + // The number of classes in a classification task + ClassificationNClasses int `json:"classification_n_classes"` + // The positive class in binary classification + ClassificationPositiveClass string `json:"classification_positive_class"` + // If this is provided, we calculate F-beta scores at the specified beta values + ClassificationBetas []float32 `json:"classification_betas"` + // A string of up to 40 characters that will 
be added to your fine-tuned model name + Suffix string `json:"suffix"` +} + +// FineTuneRequest is a request for the FineTune API +type FineTuneRequest struct { + // The ID of the fine-tune job + FineTuneID string `json:"fine_tune_id"` +} + +// FineTuneEventsRequest is a request for the FineTune API +type FineTuneEventsRequest struct { + // The ID of the fine-tune job + FineTuneID string `json:"fine_tune_id"` + // Whether to stream events for the fine-tune job + Stream bool `json:"stream"` +} + +// DeleteFineTuneModelRequest is a request for the FineTune API +type DeleteFineTuneModelRequest struct { + // The ID of the fine-tune model to delete + Model string `json:"model"` +} + +// FineTuneEvent is a single fine tune event +type FineTuneEvent struct { + Object string `json:"object"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` +} + +// FineTuneHyperparams is the hyperparams for a fine tune request +type FineTuneHyperparams struct { + BatchSize int `json:"batch_size"` + LearningRateMultiplier float64 `json:"learning_rate_multiplier"` + NEpochs int `json:"n_epochs"` + PromptLossWeight float64 `json:"prompt_loss_weight"` +} + +// FineTuneObject is a single fine tune object +// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/retrieve +type FineTuneObject struct { + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + CreatedAt int `json:"created_at"` + Events []FineTuneEvent `json:"events"` + FineTuneModel string `json:"fine_tune_model"` + Hyperparams FineTuneHyperparams `json:"hyperparams"` + OrganizationID string `json:"organization_id"` + ResultFiles []FileObject `json:"result_files"` + Status string `json:"status"` + ValidationFiles []FileObject `json:"validation_files"` + TrainingFiles []FileObject `json:"training_files"` + UpdatedAt int `json:"updated_at"` +} + +// ListFineTunesResponse is the response from a list fine tunes request. 
+// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/list +type ListFineTunesResponse struct { + Data []FineTuneObject `json:"data"` + Object string `json:"object"` +} + +// FineTuneEventsResponse +// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/events +type FineTuneEventsResponse struct { + Data []FineTuneEvent `json:"data"` + Object string `json:"object"` +} + +// DeleteFineTuneModelResponse +// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/delete-model +type DeleteFineTuneModelResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + +// CreateFineTune creates a job that fine-tunes a model on a dataset. +// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/create +func (c *client) CreateFineTune(ctx context.Context, request CreateFineTuneRequest) (*FineTuneObject, error) { + req, err := c.newRequest(ctx, "POST", "/fine-tunes", request) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := FineTuneObject{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// ListFineTunes lists the fine-tuning jobs that belong to the user's organization. +// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/list +func (c *client) ListFineTunes(ctx context.Context) (*ListFineTunesResponse, error) { + req, err := c.newRequest(ctx, "GET", "/fine-tunes", nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := ListFineTunesResponse{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// FineTune retrieves a fine-tuning job from the user's organization. 
+// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/retrieve +func (c *client) FineTune(ctx context.Context, fineTuneID string) (*FineTuneObject, error) { + req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/fine-tunes/%s", fineTuneID), nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := FineTuneObject{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// CancelFineTune cancels a fine-tuning job from the user's organization. +// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/cancel +func (c *client) CancelFineTune(ctx context.Context, fineTuneID string) (*FineTuneObject, error) { + req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/fine-tunes/%s/cancel", fineTuneID), nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := FineTuneObject{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// FineTuneEvents lists the events that belong to a fine-tuning job. +// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/events +func (c *client) FineTuneEvents(ctx context.Context, request FineTuneEventsRequest) (*FineTuneEventsResponse, error) { + request.Stream = false + req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/fine-tunes/%s/events", request.FineTuneID), nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := FineTuneEventsResponse{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} + +// FineTuneStreamEvents streams the events that belong to a fine-tuning job. 
+// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/events +func (c *client) FineTuneStreamEvents(ctx context.Context, request FineTuneEventsRequest, onData func(*FineTuneEvent)) error { + request.Stream = true + req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/fine-tunes/%s/events", request.FineTuneID), nil) + if err != nil { + return err + } + resp, err := c.performRequest(req) + if err != nil { + return err + } + + reader := bufio.NewReader(resp.Body) + defer resp.Body.Close() + + for { + line, err := reader.ReadBytes('\n') + if err != nil { + return err + } + + line = bytes.TrimSpace(line) + if !bytes.HasPrefix(line, dataPrefix) { + continue + } + line = bytes.TrimPrefix(line, dataPrefix) + + if bytes.HasPrefix(line, streamTerminationPrefix) { + break + } + output := new(FineTuneEvent) + if err := json.Unmarshal(line, output); err != nil { + return fmt.Errorf("invalid json stream data: %v", err) + } + onData(output) + } + return nil +} + +// DeleteFineTuneModel deletes a fine-tuned model from the user's organization. 
+// +// See: https://beta.openai.com/docs/api-reference/fine-tunes/delete-model +func (c *client) DeleteFineTuneModel(ctx context.Context, modelID string) (*DeleteFineTuneModelResponse, error) { + req, err := c.newRequest(ctx, "DELETE", fmt.Sprintf("/models/%s", modelID), nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := DeleteFineTuneModelResponse{} + if err := getResponseObject(resp, &output); err != nil { + return nil, err + } + return &output, nil +} diff --git a/gpt3.go b/gpt3.go deleted file mode 100644 index 207612f..0000000 --- a/gpt3.go +++ /dev/null @@ -1,364 +0,0 @@ -package gpt3 - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "net/http" - "time" -) - -// Engine Types -const ( - TextAda001Engine = "text-ada-001" - TextBabbage001Engine = "text-babbage-001" - TextCurie001Engine = "text-curie-001" - TextDavinci001Engine = "text-davinci-001" - TextDavinci002Engine = "text-davinci-002" - TextDavinci003Engine = "text-davinci-003" - AdaEngine = "ada" - BabbageEngine = "babbage" - CurieEngine = "curie" - DavinciEngine = "davinci" - DefaultEngine = DavinciEngine -) - -type EmbeddingEngine string - -const ( - TextSimilarityAda001 = "text-similarity-ada-001" - TextSimilarityBabbage001 = "text-similarity-babbage-001" - TextSimilarityCurie001 = "text-similarity-curie-001" - TextSimilarityDavinci001 = "text-similarity-davinci-001" - TextSearchAdaDoc001 = "text-search-ada-doc-001" - TextSearchAdaQuery001 = "text-search-ada-query-001" - TextSearchBabbageDoc001 = "text-search-babbage-doc-001" - TextSearchBabbageQuery001 = "text-search-babbage-query-001" - TextSearchCurieDoc001 = "text-search-curie-doc-001" - TextSearchCurieQuery001 = "text-search-curie-query-001" - TextSearchDavinciDoc001 = "text-search-davinci-doc-001" - TextSearchDavinciQuery001 = "text-search-davinci-query-001" - CodeSearchAdaCode001 = "code-search-ada-code-001" - 
CodeSearchAdaText001 = "code-search-ada-text-001" - CodeSearchBabbageCode001 = "code-search-babbage-code-001" - CodeSearchBabbageText001 = "code-search-babbage-text-001" - TextEmbeddingAda002 = "text-embedding-ada-002" -) - -const ( - defaultBaseURL = "https://api.openai.com/v1" - defaultUserAgent = "go-gpt3" - defaultTimeoutSeconds = 30 -) - -func getEngineURL(engine string) string { - return fmt.Sprintf("%s/engines/%s/completions", defaultBaseURL, engine) -} - -// A Client is an API client to communicate with the OpenAI gpt-3 APIs -type Client interface { - // Engines lists the currently available engines, and provides basic information about each - // option such as the owner and availability. - Engines(ctx context.Context) (*EnginesResponse, error) - - // Engine retrieves an engine instance, providing basic information about the engine such - // as the owner and availability. - Engine(ctx context.Context, engine string) (*EngineObject, error) - - // Completion creates a completion with the default engine. This is the main endpoint of the API - // which auto-completes based on the given prompt. - Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error) - - // CompletionStream creates a completion with the default engine and streams the results through - // multiple calls to onData. 
- CompletionStream(ctx context.Context, request CompletionRequest, onData func(*CompletionResponse)) error - - // CompletionWithEngine is the same as Completion except allows overriding the default engine on the client - CompletionWithEngine(ctx context.Context, engine string, request CompletionRequest) (*CompletionResponse, error) - - // CompletionStreamWithEngine is the same as CompletionStream except allows overriding the default engine on the client - CompletionStreamWithEngine(ctx context.Context, engine string, request CompletionRequest, onData func(*CompletionResponse)) error - - // Given a prompt and an instruction, the model will return an edited version of the prompt. - Edits(ctx context.Context, request EditsRequest) (*EditsResponse, error) - - // Search performs a semantic search over a list of documents with the default engine. - Search(ctx context.Context, request SearchRequest) (*SearchResponse, error) - - // SearchWithEngine performs a semantic search over a list of documents with the specified engine. - SearchWithEngine(ctx context.Context, engine string, request SearchRequest) (*SearchResponse, error) - - // Returns an embedding using the provided request. - Embeddings(ctx context.Context, request EmbeddingsRequest) (*EmbeddingsResponse, error) -} - -type client struct { - baseURL string - apiKey string - userAgent string - httpClient *http.Client - defaultEngine string - idOrg string -} - -// NewClient returns a new OpenAI GPT-3 API client. 
An apiKey is required to use the client -func NewClient(apiKey string, options ...ClientOption) Client { - httpClient := &http.Client{ - Timeout: time.Duration(defaultTimeoutSeconds * time.Second), - } - - c := &client{ - userAgent: defaultUserAgent, - apiKey: apiKey, - baseURL: defaultBaseURL, - httpClient: httpClient, - defaultEngine: DefaultEngine, - idOrg: "", - } - for _, o := range options { - o(c) - } - return c -} - -func (c *client) Engines(ctx context.Context) (*EnginesResponse, error) { - req, err := c.newRequest(ctx, "GET", "/engines", nil) - if err != nil { - return nil, err - } - resp, err := c.performRequest(req) - if err != nil { - return nil, err - } - - output := new(EnginesResponse) - if err := getResponseObject(resp, output); err != nil { - return nil, err - } - return output, nil -} - -func (c *client) Engine(ctx context.Context, engine string) (*EngineObject, error) { - req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/engines/%s", engine), nil) - if err != nil { - return nil, err - } - resp, err := c.performRequest(req) - if err != nil { - return nil, err - } - - output := new(EngineObject) - if err := getResponseObject(resp, output); err != nil { - return nil, err - } - return output, nil -} - -func (c *client) Completion(ctx context.Context, request CompletionRequest) (*CompletionResponse, error) { - return c.CompletionWithEngine(ctx, c.defaultEngine, request) -} - -func (c *client) CompletionWithEngine(ctx context.Context, engine string, request CompletionRequest) (*CompletionResponse, error) { - request.Stream = false - req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/engines/%s/completions", engine), request) - if err != nil { - return nil, err - } - resp, err := c.performRequest(req) - if err != nil { - return nil, err - } - - output := new(CompletionResponse) - if err := getResponseObject(resp, output); err != nil { - return nil, err - } - return output, nil -} - -func (c *client) CompletionStream(ctx context.Context, request 
CompletionRequest, onData func(*CompletionResponse)) error { - return c.CompletionStreamWithEngine(ctx, c.defaultEngine, request, onData) -} - -var dataPrefix = []byte("data: ") -var doneSequence = []byte("[DONE]") - -func (c *client) CompletionStreamWithEngine( - ctx context.Context, - engine string, - request CompletionRequest, - onData func(*CompletionResponse), -) error { - request.Stream = true - req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/engines/%s/completions", engine), request) - if err != nil { - return err - } - resp, err := c.performRequest(req) - if err != nil { - return err - } - - reader := bufio.NewReader(resp.Body) - defer resp.Body.Close() - - for { - line, err := reader.ReadBytes('\n') - if err != nil { - return err - } - // make sure there isn't any extra whitespace before or after - line = bytes.TrimSpace(line) - // the completion API only returns data events - if !bytes.HasPrefix(line, dataPrefix) { - continue - } - line = bytes.TrimPrefix(line, dataPrefix) - - // the stream is completed when terminated by [DONE] - if bytes.HasPrefix(line, doneSequence) { - break - } - output := new(CompletionResponse) - if err := json.Unmarshal(line, output); err != nil { - return fmt.Errorf("invalid json stream data: %v", err) - } - onData(output) - } - - return nil -} - -func (c *client) Edits(ctx context.Context, request EditsRequest) (*EditsResponse, error) { - req, err := c.newRequest(ctx, "POST", "/edits", request) - if err != nil { - return nil, err - } - resp, err := c.performRequest(req) - if err != nil { - return nil, err - } - - output := new(EditsResponse) - if err := getResponseObject(resp, output); err != nil { - return nil, err - } - return output, nil -} - -func (c *client) Search(ctx context.Context, request SearchRequest) (*SearchResponse, error) { - return c.SearchWithEngine(ctx, c.defaultEngine, request) -} - -func (c *client) SearchWithEngine(ctx context.Context, engine string, request SearchRequest) (*SearchResponse, error) { - 
req, err := c.newRequest(ctx, "POST", fmt.Sprintf("/engines/%s/search", engine), request) - if err != nil { - return nil, err - } - resp, err := c.performRequest(req) - if err != nil { - return nil, err - } - output := new(SearchResponse) - if err := getResponseObject(resp, output); err != nil { - return nil, err - } - return output, nil -} - -// Embeddings creates text embeddings for a supplied slice of inputs with a provided model. -// -// See: https://beta.openai.com/docs/api-reference/embeddings -func (c *client) Embeddings(ctx context.Context, request EmbeddingsRequest) (*EmbeddingsResponse, error) { - req, err := c.newRequest(ctx, "POST", "/embeddings", request) - if err != nil { - return nil, err - } - resp, err := c.performRequest(req) - if err != nil { - return nil, err - } - - output := EmbeddingsResponse{} - if err := getResponseObject(resp, &output); err != nil { - return nil, err - } - return &output, nil -} - -func (c *client) performRequest(req *http.Request) (*http.Response, error) { - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - if err := checkForSuccess(resp); err != nil { - return nil, err - } - return resp, nil -} - -// returns an error if this response includes an error. 
-func checkForSuccess(resp *http.Response) error { - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - return nil - } - defer resp.Body.Close() - data, err := ioutil.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read from body: %w", err) - } - var result APIErrorResponse - if err := json.Unmarshal(data, &result); err != nil { - // if we can't decode the json error then create an unexpected error - apiError := APIError{ - StatusCode: resp.StatusCode, - Type: "Unexpected", - Message: string(data), - } - return apiError - } - result.Error.StatusCode = resp.StatusCode - return result.Error -} - -func getResponseObject(rsp *http.Response, v interface{}) error { - defer rsp.Body.Close() - if err := json.NewDecoder(rsp.Body).Decode(v); err != nil { - return fmt.Errorf("invalid json response: %w", err) - } - return nil -} - -func jsonBodyReader(body interface{}) (io.Reader, error) { - if body == nil { - return bytes.NewBuffer(nil), nil - } - raw, err := json.Marshal(body) - if err != nil { - return nil, fmt.Errorf("failed encoding json: %w", err) - } - return bytes.NewBuffer(raw), nil -} - -func (c *client) newRequest(ctx context.Context, method, path string, payload interface{}) (*http.Request, error) { - bodyReader, err := jsonBodyReader(payload) - if err != nil { - return nil, err - } - url := c.baseURL + path - req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) - if err != nil { - return nil, err - } - if len(c.idOrg) > 0 { - req.Header.Set("OpenAI-Organization", c.idOrg) - } - req.Header.Set("Content-type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey)) - return req, nil -} diff --git a/gpt3_test.go b/gpt3_test.go index 6b36b65..795f0f3 100644 --- a/gpt3_test.go +++ b/gpt3_test.go @@ -17,11 +17,6 @@ import ( //go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 net/http.RoundTripper -func TestInitNewClient(t *testing.T) { - client := gpt3.NewClient("test-key") - 
assert.NotNil(t, client) -} - func fakeHttpClient() (*fakes.FakeRoundTripper, *http.Client) { rt := &fakes.FakeRoundTripper{} return rt, &http.Client{ @@ -32,7 +27,8 @@ func fakeHttpClient() (*fakes.FakeRoundTripper, *http.Client) { func TestRequestCreationFails(t *testing.T) { ctx := context.Background() rt, httpClient := fakeHttpClient() - client := gpt3.NewClient("test-key", gpt3.WithHTTPClient(httpClient)) + client, err := gpt3.NewClient("test-key", gpt3.WithHTTPClient(httpClient)) + assert.Nil(t, err) rt.RoundTripReturns(nil, errors.New("request error")) type testCase struct { @@ -43,25 +39,25 @@ func TestRequestCreationFails(t *testing.T) { testCases := []testCase{ { - "Engines", + "Models", func() (interface{}, error) { - return client.Engines(ctx) + return client.Models(ctx) }, - "Get \"https://api.openai.com/v1/engines\": request error", + "Get \"https://api.openai.com/v1/models\": request error", }, { - "Engine", + "Model", func() (interface{}, error) { - return client.Engine(ctx, gpt3.DefaultEngine) + return client.Model(ctx, gpt3.DefaultModel) }, - "Get \"https://api.openai.com/v1/engines/davinci\": request error", + "Get \"https://api.openai.com/v1/models/davinci\": request error", }, { "Completion", func() (interface{}, error) { return client.Completion(ctx, gpt3.CompletionRequest{}) }, - "Post \"https://api.openai.com/v1/engines/davinci/completions\": request error", + "Post \"https://api.openai.com/v1/completions\": request error", }, { "CompletionStream", func() (interface{}, error) { @@ -71,47 +67,94 @@ func TestRequestCreationFails(t *testing.T) { } return rsp, client.CompletionStream(ctx, gpt3.CompletionRequest{}, onData) }, - "Post \"https://api.openai.com/v1/engines/davinci/completions\": request error", + "Post \"https://api.openai.com/v1/completions\": request error", }, { - "CompletionWithEngine", + "Edits", func() (interface{}, error) { - return client.CompletionWithEngine(ctx, gpt3.AdaEngine, gpt3.CompletionRequest{}) + return 
client.Edits(ctx, gpt3.EditsRequest{}) }, - "Post \"https://api.openai.com/v1/engines/ada/completions\": request error", + "Post \"https://api.openai.com/v1/edits\": request error", }, { - "CompletionStreamWithEngine", + "Embeddings", func() (interface{}, error) { - var rsp *gpt3.CompletionResponse - onData := func(data *gpt3.CompletionResponse) { - rsp = data - } - return rsp, client.CompletionStreamWithEngine(ctx, gpt3.AdaEngine, gpt3.CompletionRequest{}, onData) + return client.Embeddings(ctx, gpt3.EmbeddingsRequest{}) }, - "Post \"https://api.openai.com/v1/engines/ada/completions\": request error", + "Post \"https://api.openai.com/v1/embeddings\": request error", }, { - "Edits", + "Files", func() (interface{}, error) { - return client.Edits(ctx, gpt3.EditsRequest{}) + return client.Files(ctx) }, - "Post \"https://api.openai.com/v1/edits\": request error", + "Get \"https://api.openai.com/v1/files\": request error", }, { - "Search", + "UploadFile", func() (interface{}, error) { - return client.Search(ctx, gpt3.SearchRequest{}) + return client.UploadFile(ctx, gpt3.UploadFileRequest{}) }, - "Post \"https://api.openai.com/v1/engines/davinci/search\": request error", + "Post \"https://api.openai.com/v1/files\": request error", }, { - "SearchWithEngine", + "DeleteFile", func() (interface{}, error) { - return client.SearchWithEngine(ctx, gpt3.AdaEngine, gpt3.SearchRequest{}) + return client.DeleteFile(ctx, "file-id") }, - "Post \"https://api.openai.com/v1/engines/ada/search\": request error", + "Delete \"https://api.openai.com/v1/files/file-id\": request error", }, { - "Embeddings", + "File", func() (interface{}, error) { - return client.Embeddings(ctx, gpt3.EmbeddingsRequest{}) + return client.File(ctx, "file-id") }, - "Post \"https://api.openai.com/v1/embeddings\": request error", + "Get \"https://api.openai.com/v1/files/file-id\": request error", + }, { + "FileContent", + func() (interface{}, error) { + return client.FileContent(ctx, "file-id") + }, + "Get 
\"https://api.openai.com/v1/files/file-id/content\": request error", + }, { + "ListFineTunes", + func() (interface{}, error) { + return client.ListFineTunes(ctx) + }, + "Get \"https://api.openai.com/v1/fine-tunes\": request error", + }, { + "FineTune", + func() (interface{}, error) { + return client.FineTune(ctx, "fine-tune-id") + }, + "Get \"https://api.openai.com/v1/fine-tunes/fine-tune-id\": request error", + }, { + "CancelFineTune", + func() (interface{}, error) { + return client.CancelFineTune(ctx, "fine-tune-id") + }, + "Post \"https://api.openai.com/v1/fine-tunes/fine-tune-id/cancel\": request error", + }, { + "FineTuneEvents", + func() (interface{}, error) { + return client.FineTuneEvents(ctx, gpt3.FineTuneEventsRequest{ + FineTuneID: "fine-tune-id", + }) + }, + "Get \"https://api.openai.com/v1/fine-tunes/fine-tune-id/events\": request error", + }, + { + "FineTuneStreamEvents", + func() (interface{}, error) { + var rsp *gpt3.FineTuneEvent + onData := func(data *gpt3.FineTuneEvent) { + rsp = data + } + return rsp, client.FineTuneStreamEvents(ctx, gpt3.FineTuneEventsRequest{ + FineTuneID: "fine-tune-id", + }, onData) + }, + "Get \"https://api.openai.com/v1/fine-tunes/fine-tune-id/events\": request error", + }, { + "DeleteFineTuneModel", + func() (interface{}, error) { + return client.DeleteFineTuneModel(ctx, "model-id") + }, + "Delete \"https://api.openai.com/v1/models/model-id\": request error", }, } @@ -133,7 +176,8 @@ func (errReader) Read(p []byte) (n int, err error) { func TestResponses(t *testing.T) { ctx := context.Background() rt, httpClient := fakeHttpClient() - client := gpt3.NewClient("test-key", gpt3.WithHTTPClient(httpClient)) + client, err := gpt3.NewClient("test-key", gpt3.WithHTTPClient(httpClient)) + assert.Nil(t, err) type testCase struct { name string @@ -143,31 +187,31 @@ func TestResponses(t *testing.T) { testCases := []testCase{ { - "Engines", + "Models", func() (interface{}, error) { - return client.Engines(ctx) + return 
client.Models(ctx) }, - &gpt3.EnginesResponse{ - Data: []gpt3.EngineObject{ - gpt3.EngineObject{ - ID: "123", - Object: "list", - Owner: "owner", - Ready: true, + &gpt3.ModelsResponse{ + Data: []gpt3.ModelObject{ + { + ID: "123", + Object: "list", + OwnedBy: "organization-owner", + Permissions: []string{}, }, }, }, }, { - "Engine", + "Model", func() (interface{}, error) { - return client.Engine(ctx, gpt3.DefaultEngine) + return client.Model(ctx, gpt3.DefaultModel) }, - &gpt3.EngineObject{ - ID: "123", - Object: "list", - Owner: "owner", - Ready: true, + &gpt3.ModelObject{ + ID: "123", + Object: "list", + OwnedBy: "organization-owner", + Permissions: []string{}, }, }, { @@ -181,7 +225,7 @@ func TestResponses(t *testing.T) { Created: 123456789, Model: "davinci-12", Choices: []gpt3.CompletionResponseChoice{ - gpt3.CompletionResponseChoice{ + { Text: "output", FinishReason: "stop", }, @@ -198,77 +242,233 @@ func TestResponses(t *testing.T) { }, nil, // streaming responses are tested separately }, { - "CompletionWithEngine", + "Embeddings", func() (interface{}, error) { - return client.CompletionWithEngine(ctx, gpt3.AdaEngine, gpt3.CompletionRequest{}) + return client.Embeddings(ctx, gpt3.EmbeddingsRequest{}) }, - &gpt3.CompletionResponse{ - ID: "123", - Object: "list", - Created: 123456789, - Model: "davinci-12", - Choices: []gpt3.CompletionResponseChoice{ - gpt3.CompletionResponseChoice{ - Text: "output", - FinishReason: "stop", + &gpt3.EmbeddingsResponse{ + Object: "list", + Data: []gpt3.EmbeddingsResult{{ + Object: "object", + Embedding: []float64{0.1, 0.2, 0.3}, + Index: 0, + }}, + Usage: gpt3.EmbeddingsUsage{ + PromptTokens: 1, + TotalTokens: 2, + }, + }, + }, { + "Files", + func() (interface{}, error) { + return client.Files(ctx) + }, + &gpt3.FilesResponse{ + Object: "list", + Data: []gpt3.FileObject{ + { + ID: "123", + Object: "object", + Bytes: 123, + CreatedAt: 123456789, + Filename: "file.txt", + Purpose: "fine-tune", }, }, }, }, { - 
"CompletionStreamWithEngine", + "File", func() (interface{}, error) { - var rsp *gpt3.CompletionResponse - onData := func(data *gpt3.CompletionResponse) { - rsp = data - } - return rsp, client.CompletionStreamWithEngine(ctx, gpt3.AdaEngine, gpt3.CompletionRequest{}, onData) + return client.File(ctx, "file-id") + }, + &gpt3.FileObject{ + ID: "123", + Object: "object", + Bytes: 123, + CreatedAt: 123456789, + Filename: "file.txt", + Purpose: "fine-tune", }, - nil, // streaming responses are tested separately }, { - "Search", + "UploadFile", func() (interface{}, error) { - return client.Search(ctx, gpt3.SearchRequest{}) + return client.UploadFile(ctx, gpt3.UploadFileRequest{ + File: "file.jsonl", + Purpose: "fine-tune", + }) + }, - &gpt3.SearchResponse{ - Data: []gpt3.SearchData{ - gpt3.SearchData{ - Document: 1, - Object: "search_result", - Score: 40.312, - }, - }, + &gpt3.FileObject{ + ID: "123", + Object: "object", + Bytes: 123, + CreatedAt: 123456789, + Filename: "file.txt", + Purpose: "fine-tune", }, }, { - "SearchWithEngine", + "DeleteFile", func() (interface{}, error) { - return client.SearchWithEngine(ctx, gpt3.AdaEngine, gpt3.SearchRequest{}) + return client.DeleteFile(ctx, "file-id") }, - &gpt3.SearchResponse{ - Data: []gpt3.SearchData{ - gpt3.SearchData{ - Document: 1, - Object: "search_result", - Score: 40.312, + nil, + }, { + "ListFineTunes", + func() (interface{}, error) { + return client.ListFineTunes(ctx) + }, + &gpt3.ListFineTunesResponse{ + Object: "list", + Data: []gpt3.FineTuneObject{ + { + ID: "123", + Object: "object", + Model: "davinci-12", + CreatedAt: 123456789, + Events: []gpt3.FineTuneEvent{}, + FineTuneModel: "davince:ft:123", + Hyperparams: gpt3.FineTuneHyperparams{ + BatchSize: 1, + LearningRateMultiplier: 1.0, + NEpochs: 1, + PromptLossWeight: 1.0, + }, + OrganizationID: "org-id", + ResultFiles: []gpt3.FileObject{ + { + ID: "123", + Object: "object", + Bytes: 123, + CreatedAt: 123456789, + Filename: "file.txt", + Purpose: "fine-tune", + 
}, + }, + Status: "complete", + ValidationFiles: []gpt3.FileObject{ + { + ID: "123", + Object: "object", + Bytes: 123, + CreatedAt: 123456789, + Filename: "file.txt", + Purpose: "fine-tune", + }, + }, + TrainingFiles: []gpt3.FileObject{ + { + ID: "123", + Object: "object", + Bytes: 123, + CreatedAt: 123456789, + Filename: "file.txt", + Purpose: "fine-tune", + }, + }, + UpdatedAt: 123456789, }, }, }, }, { - "Embeddings", + "FineTune", func() (interface{}, error) { - return client.Embeddings(ctx, gpt3.EmbeddingsRequest{}) + return client.FineTune(ctx, "fine-tune-id") }, - &gpt3.EmbeddingsResponse{ + &gpt3.FineTuneObject{ + ID: "123", + Object: "object", + Model: "davinci-12", + CreatedAt: 123456789, + Events: []gpt3.FineTuneEvent{}, + FineTuneModel: "davince:ft:123", + Hyperparams: gpt3.FineTuneHyperparams{}, + OrganizationID: "org-id", + ResultFiles: []gpt3.FileObject{}, + Status: "complete", + ValidationFiles: []gpt3.FileObject{}, + TrainingFiles: []gpt3.FileObject{}, + UpdatedAt: 123456789, + }, + }, { + "CreateFineTune", + func() (interface{}, error) { + return client.CreateFineTune(ctx, gpt3.CreateFineTuneRequest{}) + }, + &gpt3.FineTuneObject{ + ID: "123", + Object: "object", + Model: "davinci-12", + CreatedAt: 123456789, + Events: []gpt3.FineTuneEvent{}, + FineTuneModel: "davince:ft:123", + Hyperparams: gpt3.FineTuneHyperparams{}, + OrganizationID: "org-id", + ResultFiles: []gpt3.FileObject{}, + Status: "complete", + ValidationFiles: []gpt3.FileObject{}, + TrainingFiles: []gpt3.FileObject{}, + UpdatedAt: 123456789, + }, + }, { + "CancelFineTune", + func() (interface{}, error) { + return client.CancelFineTune(ctx, "fine-tune-id") + }, + &gpt3.FineTuneObject{ + ID: "123", + Object: "object", + Model: "davinci-12", + CreatedAt: 123456789, + Events: []gpt3.FineTuneEvent{}, + FineTuneModel: "davince:ft:123", + Hyperparams: gpt3.FineTuneHyperparams{}, + OrganizationID: "org-id", + ResultFiles: []gpt3.FileObject{}, + Status: "complete", + ValidationFiles: 
[]gpt3.FileObject{}, + TrainingFiles: []gpt3.FileObject{}, + UpdatedAt: 123456789, + }, + }, { + "FineTuneEvents", + func() (interface{}, error) { + return client.FineTuneEvents(ctx, gpt3.FineTuneEventsRequest{ + FineTuneID: "fine-tune-id", + }) + }, + &gpt3.FineTuneEventsResponse{ Object: "list", - Data: []gpt3.EmbeddingsResult{{ - Object: "object", - Embedding: []float64{0.1, 0.2, 0.3}, - Index: 0, - }}, - Usage: gpt3.EmbeddingsUsage{ - PromptTokens: 1, - TotalTokens: 2, + Data: []gpt3.FineTuneEvent{ + { + Object: "object", + CreatedAt: 123456789, + Level: "info", + Message: "message", + }, }, }, + }, { + "FineTuneStreamEvents", + func() (interface{}, error) { + var events []gpt3.FineTuneEvent + onEvent := func(event *gpt3.FineTuneEvent) { + events = append(events, *event) + } + return nil, client.FineTuneStreamEvents(ctx, gpt3.FineTuneEventsRequest{ + FineTuneID: "fine-tune-id", + }, onEvent) + }, + nil, + }, { + "DeleteFineTuneModel", + func() (interface{}, error) { + return client.DeleteFineTuneModel(ctx, "model-id") + }, + &gpt3.DeleteFineTuneModelResponse{ + ID: "model-id", + Object: "object", + Deleted: true, + }, }, } @@ -357,3 +557,4 @@ func TestResponses(t *testing.T) { } // TODO: add streaming response tests +// TODO: add file content tests diff --git a/legacy_completions.go b/legacy_completions.go new file mode 100644 index 0000000..fa0bf22 --- /dev/null +++ b/legacy_completions.go @@ -0,0 +1,17 @@ +package gpt3 + +import "context" + +// These are the legacy methods that are deprecated and will be removed in a future version. + +func (c *client) CompletionWithEngine(ctx context.Context, engine string, request CompletionRequest) (*CompletionResponse, error) { + // CompletionWithEngine is deprecated. Use Completion instead. 
+ request.Model = engine + return c.Completion(ctx, request) +} + +func (c *client) CompletionStreamWithEngine(ctx context.Context, engine string, request CompletionRequest, onData func(*CompletionResponse)) error { + // CompletionStreamWithEngine is deprecated. Use CompletionStream instead. + request.Model = engine + return c.CompletionStream(ctx, request, onData) +} diff --git a/models.go b/models.go index 884906c..bfaa150 100644 --- a/models.go +++ b/models.go @@ -1,194 +1,73 @@ package gpt3 -import "fmt" - -// APIError represents an error that occured on an API -type APIError struct { - StatusCode int `json:"status_code"` - Message string `json:"message"` - Type string `json:"type"` -} - -func (e APIError) Error() string { - return fmt.Sprintf("[%d:%s] %s", e.StatusCode, e.Type, e.Message) -} - -// APIErrorResponse is the full error respnose that has been returned by an API. -type APIErrorResponse struct { - Error APIError `json:"error"` -} - -// EngineObject contained in an engine reponse -type EngineObject struct { - ID string `json:"id"` - Object string `json:"object"` - Owner string `json:"owner"` - Ready bool `json:"ready"` -} - -// EnginesResponse is returned from the Engines API -type EnginesResponse struct { - Data []EngineObject `json:"data"` - Object string `json:"object"` -} - -// CompletionRequest is a request for the completions API -type CompletionRequest struct { - // A list of string prompts to use. - // TODO there are other prompt types here for using token integers that we could add support for. - Prompt []string `json:"prompt"` - // How many tokens to complete up to. 
Max of 512 - MaxTokens *int `json:"max_tokens,omitempty"` - // Sampling temperature to use - Temperature *float32 `json:"temperature,omitempty"` - // Alternative to temperature for nucleus sampling - TopP *float32 `json:"top_p,omitempty"` - // How many choice to create for each prompt - N *int `json:"n"` - // Include the probabilities of most likely tokens - LogProbs *int `json:"logprobs"` - // Echo back the prompt in addition to the completion - Echo bool `json:"echo"` - // Up to 4 sequences where the API will stop generating tokens. Response will not contain the stop sequence. - Stop []string `json:"stop,omitempty"` - // PresencePenalty number between 0 and 1 that penalizes tokens that have already appeared in the text so far. - PresencePenalty float32 `json:"presence_penalty"` - // FrequencyPenalty number between 0 and 1 that penalizes tokens on existing frequency in the text so far. - FrequencyPenalty float32 `json:"frequency_penalty"` - - // Whether to stream back results or not. Don't set this value in the request yourself - // as it will be overriden depending on if you use CompletionStream or Completion methods. - Stream bool `json:"stream,omitempty"` -} - -// EditsRequest is a request for the edits API -type EditsRequest struct { - // ID of the model to use. You can use the List models API to see all of your available models, or see our Model overview for descriptions of them. - Model string `json:"model"` - // The input text to use as a starting point for the edit. - Input string `json:"input"` - // The instruction that tells the model how to edit the prompt. - Instruction string `json:"instruction"` - // Sampling temperature to use - Temperature *float32 `json:"temperature,omitempty"` - // Alternative to temperature for nucleus sampling - TopP *float32 `json:"top_p,omitempty"` - // How many edits to generate for the input and instruction. 
Defaults to 1 - N *int `json:"n"` -} - -// EmbeddingsRequest is a request for the Embeddings API -type EmbeddingsRequest struct { - // Input text to get embeddings for, encoded as a string or array of tokens. To get embeddings - // for multiple inputs in a single request, pass an array of strings or array of token arrays. - // Each input must not exceed 2048 tokens in length. - Input []string `json:"input"` - // ID of the model to use - Model string `json:"model"` - // The request user is an optional parameter meant to be used to trace abusive requests - // back to the originating user. OpenAI states: - // "The [user] IDs should be a string that uniquely identifies each user. We recommend hashing - // their username or email address, in order to avoid sending us any identifying information. - // If you offer a preview of your product to non-logged in users, you can send a session ID - // instead." - User string `json:"user,omitempty"` -} - -// LogprobResult represents logprob result of Choice -type LogprobResult struct { - Tokens []string `json:"tokens"` - TokenLogprobs []float32 `json:"token_logprobs"` - TopLogprobs []map[string]float32 `json:"top_logprobs"` - TextOffset []int `json:"text_offset"` -} - -// CompletionResponseChoice is one of the choices returned in the response to the Completions API -type CompletionResponseChoice struct { - Text string `json:"text"` - Index int `json:"index"` - LogProbs LogprobResult `json:"logprobs"` - FinishReason string `json:"finish_reason"` -} - -// CompletionResponse is the full response from a request to the completions API -type CompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - Choices []CompletionResponseChoice `json:"choices"` - Usage CompletionResponseUsage `json:"usage"` -} - -// CompletionResponseUsage is the object that returns how many tokens the completion's request used -type CompletionResponseUsage struct { - 
PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -// EditsResponse is the full response from a request to the edits API -type EditsResponse struct { - Object string `json:"object"` - Created int `json:"created"` - Choices []EditsResponseChoice `json:"choices"` - Usage EditsResponseUsage `json:"usage"` -} - -// The inner result of a create embeddings request, containing the embeddings for a single input. -type EmbeddingsResult struct { - // The type of object returned (e.g., "list", "object") - Object string `json:"object"` - // The embedding data for the input - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -// The usage stats for an embeddings response -type EmbeddingsUsage struct { - // The number of tokens used by the prompt - PromptTokens int `json:"prompt_tokens"` - // The total tokens used - TotalTokens int `json:"total_tokens"` -} - -// EmbeddingsResponse is the response from a create embeddings request. 
-// -// See: https://beta.openai.com/docs/api-reference/embeddings/create -type EmbeddingsResponse struct { - Object string `json:"object"` - Data []EmbeddingsResult `json:"data"` - Usage EmbeddingsUsage `json:"usage"` -} - -// EditsResponseChoice is one of the choices returned in the response to the Edits API -type EditsResponseChoice struct { - Text string `json:"text"` - Index int `json:"index"` -} - -// EditsResponseUsage is a structure used in the response from a request to the edits API -type EditsResponseUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -// SearchRequest is a request for the document search API -type SearchRequest struct { - Documents []string `json:"documents"` - Query string `json:"query"` -} - -// SearchData is a single search result from the document search API -type SearchData struct { - Document int `json:"document"` - Object string `json:"object"` - Score float64 `json:"score"` -} - -// SearchResponse is the full response from a request to the document search API -type SearchResponse struct { - Data []SearchData `json:"data"` - Object string `json:"object"` +import ( + "context" + "fmt" +) + +// Model Types +const ( + TextAda001Model = "text-ada-001" + TextBabbage001Model = "text-babbage-001" + TextCurie001Model = "text-curie-001" + TextDavinci001Model = "text-davinci-001" + TextDavinci002Model = "text-davinci-002" + TextDavinci003Model = "text-davinci-003" + AdaModel = "ada" + BabbageModel = "babbage" + CurieModel = "curie" + DavinciModel = "davinci" + DefaultModel = DavinciModel +) + +// ModelObject +type ModelObject struct { + ID string `json:"id"` + Object string `json:"object"` + OwnedBy string `json:"owned_by"` + Permissions []string `json:"permissions"` +} + +// ModelsResponse is returned from the Models API +type ModelsResponse struct { + Data []ModelObject `json:"data"` + Object string `json:"object"` +} + +// Models lists 
the currently available models, and provides basic information about each +// option such as the owner and permissioning. +func (c *client) Models(ctx context.Context) (*ModelsResponse, error) { + req, err := c.newRequest(ctx, "GET", "/models", nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := new(ModelsResponse) + if err := getResponseObject(resp, output); err != nil { + return nil, err + } + return output, nil +} + +// Model retrieves a single model, providing basic information about the model such +// as the owner and permissioning. +func (c *client) Model(ctx context.Context, model string) (*ModelObject, error) { + req, err := c.newRequest(ctx, "GET", fmt.Sprintf("/models/%s", model), nil) + if err != nil { + return nil, err + } + resp, err := c.performRequest(req) + if err != nil { + return nil, err + } + + output := new(ModelObject) + if err := getResponseObject(resp, output); err != nil { + return nil, err + } + return output, nil } diff --git a/models_test.go b/models_test.go new file mode 100644 index 0000000..cddd9a3 --- /dev/null +++ b/models_test.go @@ -0,0 +1 @@ +package gpt3_test diff --git a/tools/tools.go b/tools/tools.go index 1316cf0..e6fa2b8 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -1,3 +1,4 @@ +//go:build tools // +build tools package tools