diff --git a/datasets.go b/datasets.go
index 0b39edc..fb408c4 100644
--- a/datasets.go
+++ b/datasets.go
@@ -2,10 +2,8 @@ package gptscript
 
 import (
 	"context"
-	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"os"
 )
 
 type DatasetElementMeta struct {
@@ -15,7 +13,8 @@ type DatasetElementMeta struct {
 
 type DatasetElement struct {
 	DatasetElementMeta `json:",inline"`
-	Contents           []byte `json:"contents"`
+	Contents           string `json:"contents"`
+	BinaryContents     []byte `json:"binaryContents"`
 }
 
 type DatasetMeta struct {
@@ -24,34 +23,17 @@ type DatasetMeta struct {
 	Description string `json:"description"`
 }
 
-type Dataset struct {
-	DatasetMeta `json:",inline"`
-	BaseDir     string                        `json:"baseDir,omitempty"`
-	Elements    map[string]DatasetElementMeta `json:"elements"`
-}
-
 type datasetRequest struct {
-	Input           string   `json:"input"`
-	WorkspaceID     string   `json:"workspaceID"`
-	DatasetToolRepo string   `json:"datasetToolRepo"`
-	Env             []string `json:"env"`
-}
-
-type createDatasetArgs struct {
-	Name        string `json:"datasetName"`
-	Description string `json:"datasetDescription"`
-}
-
-type addDatasetElementArgs struct {
-	DatasetID          string `json:"datasetID"`
-	ElementName        string `json:"elementName"`
-	ElementDescription string `json:"elementDescription"`
-	ElementContent     string `json:"elementContent"`
+	Input       string   `json:"input"`
+	DatasetTool string   `json:"datasetTool"`
+	Env         []string `json:"env"`
 }
 
 type addDatasetElementsArgs struct {
-	DatasetID string           `json:"datasetID"`
-	Elements  []DatasetElement `json:"elements"`
+	DatasetID   string           `json:"datasetID"`
+	Name        string           `json:"name"`
+	Description string           `json:"description"`
+	Elements    []DatasetElement `json:"elements"`
 }
 
 type listDatasetElementArgs struct {
@@ -60,19 +42,14 @@ type listDatasetElementArgs struct {
 
 type getDatasetElementArgs struct {
 	DatasetID string `json:"datasetID"`
-	Element   string `json:"element"`
+	Element   string `json:"name"`
 }
 
-func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]DatasetMeta, error) {
-	if workspaceID == "" {
-		workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
-	}
-
+func (g *GPTScript) ListDatasets(ctx context.Context) ([]DatasetMeta, error) {
 	out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{
-		Input:           "{}",
-		WorkspaceID:     workspaceID,
-		DatasetToolRepo: g.globalOpts.DatasetToolRepo,
-		Env:             g.globalOpts.Env,
+		Input:       "{}",
+		DatasetTool: g.globalOpts.DatasetTool,
+		Env:         g.globalOpts.Env,
 	})
 	if err != nil {
 		return nil, err
@@ -85,98 +62,42 @@ func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]Dat
 	return datasets, nil
 }
 
-func (g *GPTScript) CreateDataset(ctx context.Context, workspaceID, name, description string) (Dataset, error) {
-	if workspaceID == "" {
-		workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
-	}
-
-	args := createDatasetArgs{
-		Name:        name,
-		Description: description,
-	}
-	argsJSON, err := json.Marshal(args)
-	if err != nil {
-		return Dataset{}, fmt.Errorf("failed to marshal dataset args: %w", err)
-	}
-
-	out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{
-		Input:           string(argsJSON),
-		WorkspaceID:     workspaceID,
-		DatasetToolRepo: g.globalOpts.DatasetToolRepo,
-		Env:             g.globalOpts.Env,
-	})
-	if err != nil {
-		return Dataset{}, err
-	}
-
-	var dataset Dataset
-	if err = json.Unmarshal([]byte(out), &dataset); err != nil {
-		return Dataset{}, err
-	}
-	return dataset, nil
+type DatasetOptions struct {
+	Name, Description string
 }
 
-func (g *GPTScript) AddDatasetElement(ctx context.Context, workspaceID, datasetID, elementName, elementDescription string, elementContent []byte) (DatasetElementMeta, error) {
-	if workspaceID == "" {
-		workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
-	}
-
-	args := addDatasetElementArgs{
-		DatasetID:          datasetID,
-		ElementName:        elementName,
-		ElementDescription: elementDescription,
-		ElementContent:     base64.StdEncoding.EncodeToString(elementContent),
-	}
-	argsJSON, err := json.Marshal(args)
-	if err != nil {
-		return DatasetElementMeta{}, fmt.Errorf("failed to marshal element args: %w", err)
-	}
-
-	out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{
-		Input:           string(argsJSON),
-		WorkspaceID:     workspaceID,
-		DatasetToolRepo: g.globalOpts.DatasetToolRepo,
-		Env:             g.globalOpts.Env,
-	})
-	if err != nil {
-		return DatasetElementMeta{}, err
-	}
-
-	var element DatasetElementMeta
-	if err = json.Unmarshal([]byte(out), &element); err != nil {
-		return DatasetElementMeta{}, err
-	}
-	return element, nil
+func (g *GPTScript) CreateDatasetWithElements(ctx context.Context, elements []DatasetElement, options ...DatasetOptions) (string, error) {
+	return g.AddDatasetElements(ctx, "", elements, options...)
 }
 
-func (g *GPTScript) AddDatasetElements(ctx context.Context, workspaceID, datasetID string, elements []DatasetElement) error {
-	if workspaceID == "" {
-		workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
-	}
-
+func (g *GPTScript) AddDatasetElements(ctx context.Context, datasetID string, elements []DatasetElement, options ...DatasetOptions) (string, error) {
 	args := addDatasetElementsArgs{
 		DatasetID: datasetID,
 		Elements:  elements,
 	}
+
+	for _, opt := range options {
+		if opt.Name != "" {
+			args.Name = opt.Name
+		}
+		if opt.Description != "" {
+			args.Description = opt.Description
+		}
+	}
+
 	argsJSON, err := json.Marshal(args)
 	if err != nil {
-		return fmt.Errorf("failed to marshal element args: %w", err)
+		return "", fmt.Errorf("failed to marshal element args: %w", err)
 	}
 
-	_, err = g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
-		Input:           string(argsJSON),
-		WorkspaceID:     workspaceID,
-		DatasetToolRepo: g.globalOpts.DatasetToolRepo,
-		Env:             g.globalOpts.Env,
+	return g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
+		Input:       string(argsJSON),
+		DatasetTool: g.globalOpts.DatasetTool,
+		Env:         g.globalOpts.Env,
 	})
-	return err
 }
 
-func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datasetID string) ([]DatasetElementMeta, error) {
-	if workspaceID == "" {
-		workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
-	}
-
+func (g *GPTScript) ListDatasetElements(ctx context.Context, datasetID string) ([]DatasetElementMeta, error) {
 	args := listDatasetElementArgs{
 		DatasetID: datasetID,
 	}
@@ -186,10 +107,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datase
 	}
 
 	out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{
-		Input:           string(argsJSON),
-		WorkspaceID:     workspaceID,
-		DatasetToolRepo: g.globalOpts.DatasetToolRepo,
-		Env:             g.globalOpts.Env,
+		Input:       string(argsJSON),
+		DatasetTool: g.globalOpts.DatasetTool,
+		Env:         g.globalOpts.Env,
 	})
 	if err != nil {
 		return nil, err
@@ -202,11 +122,7 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datase
 	return elements, nil
 }
 
-func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetID, elementName string) (DatasetElement, error) {
-	if workspaceID == "" {
-		workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
-	}
-
+func (g *GPTScript) GetDatasetElement(ctx context.Context, datasetID, elementName string) (DatasetElement, error) {
 	args := getDatasetElementArgs{
 		DatasetID: datasetID,
 		Element:   elementName,
@@ -217,10 +133,9 @@ func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetI
 	}
 
 	out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{
-		Input:           string(argsJSON),
-		WorkspaceID:     workspaceID,
-		DatasetToolRepo: g.globalOpts.DatasetToolRepo,
-		Env:             g.globalOpts.Env,
+		Input:       string(argsJSON),
+		DatasetTool: g.globalOpts.DatasetTool,
+		Env:         g.globalOpts.Env,
 	})
 	if err != nil {
 		return DatasetElement{}, err
diff --git a/datasets_test.go b/datasets_test.go
index 6fc7ed9..c1f1a92 100644
--- a/datasets_test.go
+++ b/datasets_test.go
@@ -2,6 +2,7 @@ package gptscript
 
 import (
 	"context"
+	"os"
 	"testing"
 
 	"github.com/stretchr/testify/require"
@@ -11,66 +12,87 @@ func TestDatasets(t *testing.T) {
 	workspaceID, err := g.CreateWorkspace(context.Background(), "directory")
 	require.NoError(t, err)
 
+	client, err := NewGPTScript(GlobalOptions{
+		OpenAIAPIKey: os.Getenv("OPENAI_API_KEY"),
+		Env:          append(os.Environ(), "GPTSCRIPT_WORKSPACE_ID="+workspaceID),
+	})
+	require.NoError(t, err)
+
 	defer func() {
 		_ = g.DeleteWorkspace(context.Background(), workspaceID)
 	}()
 
-	// Create a dataset
-	dataset, err := g.CreateDataset(context.Background(), workspaceID, "test-dataset", "This is a test dataset")
-	require.NoError(t, err)
-	require.Equal(t, "test-dataset", dataset.Name)
-	require.Equal(t, "This is a test dataset", dataset.Description)
-	require.Equal(t, 0, len(dataset.Elements))
-
-	// Add an element
-	elementMeta, err := g.AddDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element", "This is a test element", []byte("This is the content"))
+	datasetID, err := client.CreateDatasetWithElements(context.Background(), []DatasetElement{
+		{
+			DatasetElementMeta: DatasetElementMeta{
+				Name:        "test-element-1",
+				Description: "This is a test element 1",
+			},
+			Contents: "This is the content 1",
+		},
+	}, DatasetOptions{
+		Name:        "test-dataset",
+		Description: "this is a test dataset",
+	})
 	require.NoError(t, err)
-	require.Equal(t, "test-element", elementMeta.Name)
-	require.Equal(t, "This is a test element", elementMeta.Description)
 
-	// Add two more
-	err = g.AddDatasetElements(context.Background(), workspaceID, dataset.ID, []DatasetElement{
+	// Add three more elements
+	_, err = client.AddDatasetElements(context.Background(), datasetID, []DatasetElement{
 		{
 			DatasetElementMeta: DatasetElementMeta{
 				Name:        "test-element-2",
 				Description: "This is a test element 2",
 			},
-			Contents: []byte("This is the content 2"),
+			Contents: "This is the content 2",
 		},
 		{
 			DatasetElementMeta: DatasetElementMeta{
 				Name:        "test-element-3",
 				Description: "This is a test element 3",
 			},
-			Contents: []byte("This is the content 3"),
+			Contents: "This is the content 3",
+		},
+		{
+			DatasetElementMeta: DatasetElementMeta{
+				Name:        "binary-element",
+				Description: "this element has binary contents",
+			},
+			BinaryContents: []byte("binary contents"),
 		},
 	})
 	require.NoError(t, err)
 
 	// Get the first element
-	element, err := g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element")
+	element, err := client.GetDatasetElement(context.Background(), datasetID, "test-element-1")
 	require.NoError(t, err)
-	require.Equal(t, "test-element", element.Name)
-	require.Equal(t, "This is a test element", element.Description)
-	require.Equal(t, []byte("This is the content"), element.Contents)
+	require.Equal(t, "test-element-1", element.Name)
+	require.Equal(t, "This is a test element 1", element.Description)
+	require.Equal(t, "This is the content 1", element.Contents)
 
 	// Get the third element
-	element, err = g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element-3")
+	element, err = client.GetDatasetElement(context.Background(), datasetID, "test-element-3")
	require.NoError(t, err)
 	require.Equal(t, "test-element-3", element.Name)
 	require.Equal(t, "This is a test element 3", element.Description)
-	require.Equal(t, []byte("This is the content 3"), element.Contents)
+	require.Equal(t, "This is the content 3", element.Contents)
+
+	// Get the binary element
+	element, err = client.GetDatasetElement(context.Background(), datasetID, "binary-element")
+	require.NoError(t, err)
+	require.Equal(t, "binary-element", element.Name)
+	require.Equal(t, "this element has binary contents", element.Description)
+	require.Equal(t, []byte("binary contents"), element.BinaryContents)
 
 	// List elements in the dataset
-	elements, err := g.ListDatasetElements(context.Background(), workspaceID, dataset.ID)
+	elements, err := client.ListDatasetElements(context.Background(), datasetID)
 	require.NoError(t, err)
-	require.Equal(t, 3, len(elements))
+	require.Equal(t, 4, len(elements))
 
 	// List datasets
-	datasets, err := g.ListDatasets(context.Background(), workspaceID)
+	datasets, err := client.ListDatasets(context.Background())
 	require.NoError(t, err)
 	require.Equal(t, 1, len(datasets))
+	require.Equal(t, datasetID, datasets[0].ID)
 	require.Equal(t, "test-dataset", datasets[0].Name)
-	require.Equal(t, "This is a test dataset", datasets[0].Description)
-	require.Equal(t, dataset.ID, datasets[0].ID)
+	require.Equal(t, "this is a test dataset", datasets[0].Description)
 }
diff --git a/opts.go b/opts.go
index e08d217..07507e2 100644
--- a/opts.go
+++ b/opts.go
@@ -11,7 +11,7 @@ type GlobalOptions struct {
 	DefaultModelProvider string   `json:"DefaultModelProvider"`
 	CacheDir             string   `json:"CacheDir"`
 	Env                  []string `json:"env"`
-	DatasetToolRepo      string   `json:"DatasetToolRepo"`
+	DatasetTool          string   `json:"DatasetTool"`
 	WorkspaceTool        string   `json:"WorkspaceTool"`
 }
 
@@ -46,7 +46,7 @@ func completeGlobalOptions(opts ...GlobalOptions) GlobalOptions {
 		result.OpenAIBaseURL = firstSet(opt.OpenAIBaseURL, result.OpenAIBaseURL)
 		result.DefaultModel = firstSet(opt.DefaultModel, result.DefaultModel)
 		result.DefaultModelProvider = firstSet(opt.DefaultModelProvider, result.DefaultModelProvider)
-		result.DatasetToolRepo = firstSet(opt.DatasetToolRepo, result.DatasetToolRepo)
+		result.DatasetTool = firstSet(opt.DatasetTool, result.DatasetTool)
 		result.WorkspaceTool = firstSet(opt.WorkspaceTool, result.WorkspaceTool)
 		result.Env = append(result.Env, opt.Env...)
 	}
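Net effect of the patch: the dataset calls no longer take a workspaceID parameter (the workspace is read from GPTSCRIPT_WORKSPACE_ID in the client's Env), dataset creation is folded into CreateDatasetWithElements/AddDatasetElements with an optional DatasetOptions, element payloads are split into a string Contents and a []byte BinaryContents, and the DatasetToolRepo option becomes DatasetTool. Below is a minimal usage sketch mirroring the updated test; the import path, the example names, and the pre-exported GPTSCRIPT_WORKSPACE_ID are assumptions, and error handling is reduced to panics.

```go
package main

import (
	"context"
	"fmt"
	"os"

	gptscript "github.com/gptscript-ai/go-gptscript" // assumed module path
)

func main() {
	ctx := context.Background()

	// The dataset calls now read the workspace from the environment rather than
	// taking a workspaceID argument, so it is passed to the client via Env here.
	// GPTSCRIPT_WORKSPACE_ID is assumed to point at an existing workspace.
	client, err := gptscript.NewGPTScript(gptscript.GlobalOptions{
		Env: append(os.Environ(), "GPTSCRIPT_WORKSPACE_ID="+os.Getenv("GPTSCRIPT_WORKSPACE_ID")),
	})
	if err != nil {
		panic(err)
	}
	defer client.Close()

	// CreateDatasetWithElements replaces the old CreateDataset/AddDatasetElement
	// pair; the dataset name and description move into DatasetOptions.
	datasetID, err := client.CreateDatasetWithElements(ctx, []gptscript.DatasetElement{{
		DatasetElementMeta: gptscript.DatasetElementMeta{Name: "example-element", Description: "an illustrative element"},
		Contents:           "plain-text contents",
	}}, gptscript.DatasetOptions{Name: "example-dataset", Description: "an illustrative dataset"})
	if err != nil {
		panic(err)
	}

	// Binary payloads now go in BinaryContents; Contents is a plain string.
	if _, err := client.AddDatasetElements(ctx, datasetID, []gptscript.DatasetElement{{
		DatasetElementMeta: gptscript.DatasetElementMeta{Name: "binary-element", Description: "an illustrative binary element"},
		BinaryContents:     []byte{0x00, 0x01, 0x02},
	}}); err != nil {
		panic(err)
	}

	// Element lookups are now keyed only by dataset ID and element name.
	element, err := client.GetDatasetElement(ctx, datasetID, "example-element")
	if err != nil {
		panic(err)
	}
	fmt.Println(element.Contents)
}
```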