Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions api/datasets/datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"

"github.com/braintrustdata/braintrust-sdk-go/internal/https"
Expand Down Expand Up @@ -104,31 +105,31 @@ func (a *API) Fetch(ctx context.Context, datasetID string, cursor string, limit
// Query searches for datasets by name, version, or other criteria.
func (a *API) Query(ctx context.Context, params QueryParams) (*QueryResponse, error) {
// Build query parameters
queryParams := make(map[string]string)
queryParams := url.Values{}

if params.ID != "" {
queryParams["id"] = params.ID
queryParams.Set("id", params.ID)
}
if params.Name != "" {
queryParams["dataset_name"] = params.Name
queryParams.Set("dataset_name", params.Name)
}
if params.Version != "" {
queryParams["version"] = params.Version
queryParams.Set("version", params.Version)
}
if params.ProjectID != "" {
queryParams["project_id"] = params.ProjectID
queryParams.Set("project_id", params.ProjectID)
}
if params.ProjectName != "" {
queryParams["project_name"] = params.ProjectName
queryParams.Set("project_name", params.ProjectName)
}
if params.Limit > 0 {
queryParams["limit"] = strconv.Itoa(params.Limit)
queryParams.Set("limit", strconv.Itoa(params.Limit))
}
if params.StartingAfter != "" {
queryParams["starting_after"] = params.StartingAfter
queryParams.Set("starting_after", params.StartingAfter)
}
if params.EndingBefore != "" {
queryParams["ending_before"] = params.EndingBefore
queryParams.Set("ending_before", params.EndingBefore)
}

resp, err := a.client.GET(ctx, "/v1/dataset", queryParams)
Expand Down
120 changes: 115 additions & 5 deletions api/experiments/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"

"github.com/braintrustdata/braintrust-sdk-go/internal/https"
Expand Down Expand Up @@ -61,19 +62,34 @@ func (a *API) Register(ctx context.Context, name, projectID string, opts Registe

// List returns a list of experiments filtered by the given parameters.
func (a *API) List(ctx context.Context, params ListParams) (*ListResponse, error) {
queryParams := make(map[string]string)
queryParams := url.Values{}

if params.ProjectID != "" {
queryParams["project_id"] = params.ProjectID
queryParams.Set("project_id", params.ProjectID)
}
if params.ProjectName != "" {
queryParams.Set("project_name", params.ProjectName)
}
if params.ExperimentName != "" {
queryParams["experiment_name"] = params.ExperimentName
queryParams.Set("experiment_name", params.ExperimentName)
}
if params.OrgName != "" {
queryParams["org_name"] = params.OrgName
queryParams.Set("org_name", params.OrgName)
}
if params.Limit > 0 {
queryParams["limit"] = strconv.Itoa(params.Limit)
queryParams.Set("limit", strconv.Itoa(params.Limit))
}
if params.StartingAfter != "" {
queryParams.Set("starting_after", params.StartingAfter)
}
if params.EndingBefore != "" {
queryParams.Set("ending_before", params.EndingBefore)
}
if len(params.IDs) > 0 {
// Add multiple values for the ids parameter
for _, id := range params.IDs {
queryParams.Add("ids", id)
}
}

resp, err := a.client.GET(ctx, "/v1/experiment", queryParams)
Expand Down Expand Up @@ -110,6 +126,100 @@ func (a *API) Get(ctx context.Context, experimentID string) (*Experiment, error)
return &result, nil
}

// Update partially updates an experiment by its ID.
// Only the fields provided in params will be updated.
func (a *API) Update(ctx context.Context, experimentID string, params UpdateParams) (*Experiment, error) {
if experimentID == "" {
return nil, fmt.Errorf("experiment ID is required")
}

resp, err := a.client.PATCH(ctx, "/v1/experiment/"+experimentID, params)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()

var result Experiment
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("error decoding response: %w", err)
}

return &result, nil
}

// InsertEvents inserts events into an experiment.
func (a *API) InsertEvents(ctx context.Context, experimentID string, events []ExperimentEvent) (*InsertEventsResponse, error) {
if experimentID == "" {
return nil, fmt.Errorf("experiment ID is required")
}
if len(events) == 0 {
return nil, fmt.Errorf("at least one event is required")
}

reqBody := InsertEventsRequest{Events: events}
resp, err := a.client.POST(ctx, "/v1/experiment/"+experimentID+"/insert", reqBody)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()

var result InsertEventsResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("error decoding response: %w", err)
}

return &result, nil
}

// FetchEvents retrieves events from an experiment with optional pagination.
// This uses the POST variant of the fetch endpoint, which accepts filter parameters in the request body.
func (a *API) FetchEvents(ctx context.Context, experimentID string, params FetchEventsParams) (*FetchEventsResponse, error) {
if experimentID == "" {
return nil, fmt.Errorf("experiment ID is required")
}

resp, err := a.client.POST(ctx, "/v1/experiment/"+experimentID+"/fetch", params)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()

var result FetchEventsResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("error decoding response: %w", err)
}

return &result, nil
}

// Summarize returns summary statistics for an experiment, including score averages and comparisons.
func (a *API) Summarize(ctx context.Context, experimentID string, params SummarizeParams) (*SummarizeResponse, error) {
if experimentID == "" {
return nil, fmt.Errorf("experiment ID is required")
}

queryParams := url.Values{}
if params.SummarizeScores {
queryParams.Set("summarize_scores", "true")
}
if params.ComparisonExperimentID != "" {
queryParams.Set("comparison_experiment_id", params.ComparisonExperimentID)
}

resp, err := a.client.GET(ctx, "/v1/experiment/"+experimentID+"/summarize", queryParams)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()

var result SummarizeResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("error decoding response: %w", err)
}

return &result, nil
}

// Delete deletes an experiment by its ID.
func (a *API) Delete(ctx context.Context, experimentID string) error {
if experimentID == "" {
Expand Down
Loading