diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 707d40fe8..1a5d43bb2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,34 +8,16 @@ jobs: runs-on: ubuntu-latest steps: - - name: Set up Go 1.x - uses: actions/setup-go@v5 - with: - go-version: 1.22 - - name: Check out code - uses: actions/checkout@v2 + uses: actions/checkout@v4 - - name: Cache Funnel binary - uses: actions/cache@v3 + - name: Set up Go 1.x + uses: actions/setup-go@v5 with: - path: ./funnel - key: ${{ runner.os }}-funnel-bin-${{ hashFiles('**/go.sum') }}-${{ github.ref }} - restore-keys: | - ${{ runner.os }}-funnel-bin-${{ github.ref }} - ${{ runner.os }}-funnel-bin- + go-version-file: go.mod - - name: Build Funnel (if cache doesn't exist) - run: | - if [ ! -f ./funnel ]; then - make build - fi - - - name: Cache Funnel binary (after build) - uses: actions/cache@v3 - with: - path: ./funnel - key: ${{ runner.os }}-funnel-bin-${{ hashFiles('**/go.sum') }}-${{ github.ref }} + - name: Build Funnel + run: make build - name: Upload Funnel binary as artifact uses: actions/upload-artifact@v4 diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 4e161220b..07b6c859c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -12,12 +12,14 @@ jobs: name: lint runs-on: ubuntu-latest steps: - - uses: actions/setup-go@v3 + - name: Check out code + uses: actions/checkout@v4 + + - name: Set up Go 1.x + uses: actions/setup-go@v5 with: - go-version: 1.21 + go-version-file: go.mod - - uses: actions/checkout@v3 - - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: @@ -31,7 +33,7 @@ jobs: --skip-dirs "funnel-work-dir" \ -e '.*bundle.go' -e ".*pb.go" -e ".*pb.gw.go" \ ./... 
- + - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: @@ -44,12 +46,13 @@ jobs: unitTest: runs-on: ubuntu-latest steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go 1.x uses: actions/setup-go@v5 with: - go-version: 1.21 - - name: Check out code - uses: actions/checkout@v2 + go-version-file: go.mod - name: Unit Tests run: make test-verbose @@ -61,13 +64,13 @@ jobs: runs-on: ubuntu-latest needs: build steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go 1.x uses: actions/setup-go@v5 with: - go-version: 1.21 - - - name: Check out code - uses: actions/checkout@v2 + go-version-file: go.mod - name: Download funnel bin uses: actions/download-artifact@v4 @@ -85,17 +88,19 @@ jobs: runs-on: ubuntu-latest needs: build steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go 1.x uses: actions/setup-go@v5 with: - go-version: 1.21 - - name: Check out code - uses: actions/checkout@v2 + go-version-file: go.mod - name: Download funnel bin uses: actions/download-artifact@v4 with: name: funnel + - name: Badger Test run: | chmod +x funnel @@ -105,12 +110,13 @@ jobs: runs-on: ubuntu-latest needs: build steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go 1.x uses: actions/setup-go@v5 with: - go-version: 1.21 - - name: Check out code - uses: actions/checkout@v2 + go-version-file: go.mod - name: Download funnel bin uses: actions/download-artifact@v4 @@ -126,12 +132,13 @@ jobs: runs-on: ubuntu-latest needs: build steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Set up Go 1.x uses: actions/setup-go@v5 with: - go-version: 1.21 - - name: Check out code - uses: actions/checkout@v2 + go-version-file: go.mod - name: Download funnel bin uses: actions/download-artifact@v4 @@ -144,4 +151,3 @@ jobs: make start-generic-s3 sleep 10 make test-generic-s3 - diff --git a/Makefile b/Makefile index cbf330a2f..1cfbb7af4 100644 --- a/Makefile +++ b/Makefile @@ -35,7 
+35,7 @@ build: @touch version/version.go @go build -ldflags '$(VERSION_LDFLAGS)' -buildvcs=false . -# Build an unoptimized version of the code for use during debugging +# Build an unoptimized version of the code for use during debugging # https://go.dev/doc/gdb debug: @go install -gcflags=all="-N -l" @@ -119,7 +119,7 @@ test-verbose: start-elasticsearch: @docker rm -f funnel-es-test > /dev/null 2>&1 || echo - @docker run -d --name funnel-es-test -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" docker.elastic.co/elasticsearch/elasticsearch:5.6.3 > /dev/null + @docker run -d --name funnel-es-test -p 9200:9200 -p 9300:9300 -e "discovery.type=single-node" -e "xpack.security.enabled=false" docker.io/elastic/elasticsearch:8.17.1 > /dev/null test-elasticsearch: @go test ./tests/core/ -funnel-config `pwd`/tests/elastic.config.yml @@ -140,7 +140,7 @@ test-badger: start-dynamodb: @docker rm -f funnel-dynamodb-test > /dev/null 2>&1 || echo - @docker run -d --name funnel-dynamodb-test -p 18000:8000 docker.io/dwmkerr/dynamodb:38 -sharedDb > /dev/null + @docker run -d --name funnel-dynamodb-test -p 18000:8000 docker.io/amazon/dynamodb-local > /dev/null test-dynamodb: @go test ./tests/core/ -funnel-config `pwd`/tests/dynamo.config.yml @@ -228,7 +228,7 @@ snapshot: release-dep docker: docker build -t quay.io/ohsu-comp-bio/funnel:latest ./ -# Create a release on Github using GoReleaser +# Create a release on Github using GoReleaser release: @go get github.com/buchanae/github-release-notes @goreleaser \ diff --git a/cmd/server/run.go b/cmd/server/run.go index 5ba82e98a..5daee75dc 100644 --- a/cmd/server/run.go +++ b/cmd/server/run.go @@ -269,6 +269,7 @@ func NewServer(ctx context.Context, conf config.Config, log *logger.Logger) (*Se BasicAuth: conf.Server.BasicAuth, OidcAuth: conf.Server.OidcAuth, DisableHTTPCache: conf.Server.DisableHTTPCache, + TaskAccess: conf.Server.TaskAccess, Log: log, Tasks: &server.TaskService{ Name: 
conf.Server.ServiceName, diff --git a/cmd/task/list.go b/cmd/task/list.go index a0fc8f7b0..431562ce7 100644 --- a/cmd/task/list.go +++ b/cmd/task/list.go @@ -18,7 +18,8 @@ func List(server, taskView, pageToken, stateFilter string, tagsFilter []string, return err } - _, err = getTaskView(taskView) + taskViewInt, err := getTaskView(taskView) + taskView = tes.View_name[taskViewInt] if err != nil { return err } diff --git a/compute/hpc_backend.go b/compute/hpc_backend.go index 022e7e395..fc4f6a60c 100644 --- a/compute/hpc_backend.go +++ b/compute/hpc_backend.go @@ -104,11 +104,6 @@ func (b *HPCBackend) Cancel(ctx context.Context, taskID string) error { return err } - // only cancel tasks in a QUEUED state - if task.State != tes.State_QUEUED { - return nil - } - backendID := getBackendTaskID(task, b.Name) if backendID == "" { return fmt.Errorf("no %s_id found in metadata for task %s", b.Name, taskID) diff --git a/config/config.go b/config/config.go index 45f14ae93..4bb2b4fde 100644 --- a/config/config.go +++ b/config/config.go @@ -54,6 +54,7 @@ type Config struct { type BasicCredential struct { User string Password string + Admin bool } type OidcAuth struct { @@ -63,6 +64,7 @@ type OidcAuth struct { RedirectURL string RequireScope string RequireAudience string + Admins []string } // RPCClient describes configuration for gRPC clients @@ -88,6 +90,12 @@ type Server struct { BasicAuth []BasicCredential OidcAuth OidcAuth DisableHTTPCache bool + + // Defines task access and visibility by options: + // "All" (default) - all tasks are visible to everyone + // "Owner" - tasks are visible to the users who created them + // "OwnerOrAdmin" - extends "Owner" by allowing Admin-users to see everything + TaskAccess string } // HTTPAddress returns the HTTP address based on HostName and HTTPPort @@ -236,8 +244,13 @@ type MongoDB struct { // Elastic configures access to an Elasticsearch database. 
type Elastic struct { - IndexPrefix string - URL string + IndexPrefix string + URL string + Username string + Password string + CloudID string + APIKey string + ServiceToken string } // Kafka configure access to a Kafka topic for task event reading/writing. diff --git a/config/datastore/README.md b/config/datastore/README.md new file mode 100644 index 000000000..aefd7e47b --- /dev/null +++ b/config/datastore/README.md @@ -0,0 +1,26 @@ +# Google Datastore Usage + +When Funnel is configured to use the Google Datastore as its database, some +additional configuration steps need to be taken. + +## Datastore Access + +Authentication to the Google Datastore needs to be configured through Google +Cloud CLI as described here: +https://cloud.google.com/datastore/docs/reference/libraries?hl=en#authentication + +## Datastore Indexes + +For retrieving a list of tasks, Funnel needs [composite +indexes](https://cloud.google.com/datastore/docs/concepts/indexes?hl=en) to be +defined in the Datastore using the Google Cloud CLI and the +[index.yaml](./index.yaml) file: + +```shell +gcloud datastore indexes create path/to/index.yaml --database='funnel' +``` + +Note that it will take a bit of time before the indexes are ready for accepting +requests. You can see the status of those indexes through the Google Cloud +console: https://console.cloud.google.com/datastore/databases/ (**Indexes** +view under the target database). 
diff --git a/config/datastore/index.yaml b/config/datastore/index.yaml new file mode 100644 index 000000000..ff9d33eac --- /dev/null +++ b/config/datastore/index.yaml @@ -0,0 +1,51 @@ +# These index-definitions need to be imported to Datastore (if `database: datastore`) +# Using Google Cloud CLI: gcloud datastore indexes create path/to/index.yaml + +indexes: + +- kind: Task + properties: + - name: Owner + - name: State + - name: TagStrings + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: Owner + - name: State + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: Owner + - name: TagStrings + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: Owner + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: State + - name: TagStrings + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: State + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: TagStrings + - name: CreationTime + direction: desc diff --git a/config/default-config.yaml b/config/default-config.yaml index e0d71b2b9..67c363aa7 100644 --- a/config/default-config.yaml +++ b/config/default-config.yaml @@ -32,6 +32,9 @@ Server: # If used, make sure to properly restrict access to the config file # (e.g. 
chmod 600 funnel.config.yml) # BasicAuth: + # - User: admin + # Password: oejf023moq + # Admin: true # - User: user1 # Password: abc123 # - User: user2 @@ -44,7 +47,6 @@ Server: # # Example: https://example.org/oidc/.well-knwon/openid-configuration # ServiceConfigURL: # # Client ID and secret are sent with the token introspection request - # # (Basic authentication): # ClientId: # ClientSecret: # # The URL where OIDC should redirect after login (keep the path '/login') @@ -53,11 +55,23 @@ Server: # RequireScope: # # Optional: if specified, this audience value must be in the token: # RequireAudience: + # # List of usernames (JWT sub-claim) to be granted Admin-role: + # Admins: + # - admin.username.one@example.org + # - admin.username.two@example.org # Include a "Cache-Control: no-store" HTTP header in Get/List responses # to prevent caching by intermediary services. DisableHTTPCache: true + # Defines task access and visibility by options: + # "All" (default) - all tasks are visible to everyone + # "Owner" - tasks are visible to the users who created them + # "OwnerOrAdmin" - extends "Owner" by allowing Admin-users to see everything + # Owner is the username associated with the task. + # Owners (usernames) were not recorded on tasks before Funnel 0.11.1. + TaskAccess: All + RPCClient: # RPC server address ServerAddress: localhost:9090 @@ -158,6 +172,16 @@ Elastic: IndexPrefix: funnel # URL of the elasticsearch server. URL: http://localhost:9200 + # Optional. Username for HTTP Basic Authentication. + Username: + # Optional. Password for HTTP Basic Authentication. + Password: + # Optional. Endpoint for the Elastic Service (https://elastic.co/cloud). + CloudID: + # Optional. Base64-encoded token for authorization; if set, overrides username/password and service token. + APIKey: + # Optional. Service token for authorization; if set, overrides username/password. + ServiceToken: # Google Cloud Datastore task database. 
Datastore: @@ -170,7 +194,7 @@ Datastore: MongoDB: # Addrs holds the addresses for the seed servers. Addrs: - - mongodb://localhost + - localhost # Database is the database name used within MongoDB to store funnel data. Database: funnel # Timeout is the amount of time to wait for a server to respond when @@ -316,7 +340,7 @@ AWSBatch: # Kubernetes describes the configuration for the Kubernetes compute backend. Kubernetes: # The executor used to execute tasks. Available executors: docker, kubernetes - Executor: "docker" + Executor: "docker" # Turn off task state reconciler. When enabled, Funnel communicates with Kubernetes # to find tasks that are stuck in a queued state or errored and # updates the task state accordingly. diff --git a/config/default.go b/config/default.go index f873a859c..0dbdc4214 100644 --- a/config/default.go +++ b/config/default.go @@ -30,6 +30,7 @@ func DefaultConfig() Config { RPCPort: "9090", ServiceName: "Funnel", DisableHTTPCache: true, + TaskAccess: "All", } c := Config{ @@ -106,7 +107,7 @@ func DefaultConfig() Config { IndexPrefix: "funnel", }, MongoDB: MongoDB{ - Addrs: []string{"mongodb://localhost"}, + Addrs: []string{"localhost"}, Timeout: Duration(time.Minute * 5), Database: "funnel", }, diff --git a/database/badger/events.go b/database/badger/events.go index 05a988707..98f7735b9 100644 --- a/database/badger/events.go +++ b/database/badger/events.go @@ -8,6 +8,7 @@ import ( badger "github.com/dgraph-io/badger/v2" proto "github.com/golang/protobuf/proto" "github.com/ohsu-comp-bio/funnel/events" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" "github.com/ohsu-comp-bio/funnel/util" ) @@ -22,15 +23,8 @@ func (db *Badger) WriteEvent(ctx context.Context, req *events.Event) error { RandomizationFactor: 0.5, MaxTries: 50, ShouldRetry: func(err error) bool { - // Don't retry not found errors. - if err == tes.ErrNotFound { - return false - } - // Don't retry on state transition errors. 
- if _, ok := err.(*tes.TransitionError); ok { - return false - } - return true + _, isTransitionError := err.(*tes.TransitionError) + return !isTransitionError && err != tes.ErrNotFound && err != tes.ErrNotPermitted }, } @@ -54,11 +48,17 @@ func (db *Badger) writeEvent(ctx context.Context, req *events.Event) error { return fmt.Errorf("marshaling task to bytes: %s", err) } + // Store the username as task owner: + err = txn.Set(ownerKey(task.Id), []byte(server.GetUsername(ctx))) + if err != nil { + return fmt.Errorf("storing owner info: %s", err) + } + return txn.Set(taskKey(task.Id), val) } // The rest of the events below all update a task, so we need to make sure it exists. - task, err := db.getTask(txn, req.Id) + task, err := getTask(txn, req.Id, ctx) if err != nil { return err } diff --git a/database/badger/new.go b/database/badger/new.go index 36f3bbec0..d803d0041 100644 --- a/database/badger/new.go +++ b/database/badger/new.go @@ -37,6 +37,7 @@ func (db *Badger) Close() { } var taskKeyPrefix = []byte("tasks") +var ownerKeyPrefix = []byte("owners") func taskKey(id string) []byte { idb := []byte(id) @@ -45,3 +46,12 @@ func taskKey(id string) []byte { key = append(key, idb...) return key } + +func ownerKey(id string) []byte { + return append(ownerKeyPrefix, []byte(id)...) +} + +func ownerKeyFromTaskKey(taskKey []byte) []byte { + taskId := taskKey[len(taskKeyPrefix):] + return append(ownerKeyPrefix, taskId...) 
+} diff --git a/database/badger/tes.go b/database/badger/tes.go index 4d4546e49..fd656e330 100644 --- a/database/badger/tes.go +++ b/database/badger/tes.go @@ -1,11 +1,13 @@ package badger import ( + "bytes" "context" "fmt" badger "github.com/dgraph-io/badger/v2" proto "github.com/golang/protobuf/proto" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" ) @@ -14,7 +16,7 @@ func (db *Badger) GetTask(ctx context.Context, req *tes.GetTaskRequest) (*tes.Ta var task *tes.Task err := db.db.View(func(txn *badger.Txn) error { - t, err := db.getTask(txn, req.Id) + t, err := getTask(txn, req.Id, ctx) task = t return err }) @@ -60,6 +62,18 @@ func (db *Badger) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (*te taskLoop: for ; it.Valid() && len(tasks) < pageSize; it.Next() { + // Iterator items are reverse-ordered by keys (starting with + // task-keys). So by the time the task-key prefix is passed, only + // owner-keys remains, and they can be skipped. + if !bytes.HasPrefix(it.Item().Key(), taskKeyPrefix) { + break + } + + taskOwner := getTaskOwner(txn, ownerKeyFromTaskKey(it.Item().Key())) + if !isAccessible(ctx, taskOwner) { + continue + } + var val []byte err := it.Item().Value(func(d []byte) error { val = copyBytes(d) @@ -118,11 +132,14 @@ func (db *Badger) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (*te return &out, nil } -func (db *Badger) getTask(txn *badger.Txn, id string) (*tes.Task, error) { +func getTask(txn *badger.Txn, id string, ctx context.Context) (*tes.Task, error) { item, err := txn.Get(taskKey(id)) if err == badger.ErrKeyNotFound { return nil, tes.ErrNotFound } + if !isAccessible(ctx, getTaskOwner(txn, ownerKey(id))) { + return nil, tes.ErrNotPermitted + } if err != nil { return nil, fmt.Errorf("loading item: %s", err) } @@ -144,6 +161,21 @@ func (db *Badger) getTask(txn *badger.Txn, id string) (*tes.Task, error) { return task, nil } +func getTaskOwner(txn *badger.Txn, ownerKey []byte) string { + 
taskOwner := "" + if item, err := txn.Get(ownerKey); err == nil { + _ = item.Value(func(d []byte) error { + taskOwner = string(d) + return nil + }) + } + return taskOwner +} + +func isAccessible(ctx context.Context, taskOwner string) bool { + return server.GetUser(ctx).IsAccessible(taskOwner) +} + func copyBytes(in []byte) []byte { out := make([]byte, len(in)) copy(out, in) diff --git a/database/boltdb/events.go b/database/boltdb/events.go index e5f649a71..0ebed6753 100644 --- a/database/boltdb/events.go +++ b/database/boltdb/events.go @@ -8,6 +8,7 @@ import ( "github.com/boltdb/bolt" proto "github.com/golang/protobuf/proto" "github.com/ohsu-comp-bio/funnel/events" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" ) @@ -38,6 +39,7 @@ func (taskBolt *BoltDB) WriteEvent(ctx context.Context, req *events.Event) error err = taskBolt.db.Update(func(tx *bolt.Tx) error { tx.Bucket(TaskBucket).Put(idBytes, taskString) tx.Bucket(TaskState).Put(idBytes, []byte(tes.State_QUEUED.String())) + tx.Bucket(TaskOwner).Put(idBytes, []byte(server.GetUsername(ctx))) return nil }) if err != nil { @@ -52,7 +54,7 @@ func (taskBolt *BoltDB) WriteEvent(ctx context.Context, req *events.Event) error // Check that the task exists err = taskBolt.db.View(func(tx *bolt.Tx) error { - _, err := getTaskView(tx, req.Id, tes.View_MINIMAL) + _, err := getTaskView(tx, req.Id, tes.View_MINIMAL, nil) return err }) if err != nil { @@ -65,6 +67,9 @@ func (taskBolt *BoltDB) WriteEvent(ctx context.Context, req *events.Event) error switch req.Type { case events.Type_TASK_STATE: err = taskBolt.db.Update(func(tx *bolt.Tx) error { + if err := checkOwner(tx, req.Id, ctx); err != nil { + return err + } return transitionTaskState(tx, req.Id, req.GetState()) }) diff --git a/database/boltdb/new.go b/database/boltdb/new.go index 105520390..d64e893d0 100644 --- a/database/boltdb/new.go +++ b/database/boltdb/new.go @@ -19,6 +19,9 @@ var TaskBucket = []byte("tasks") // task ID -> nil var 
TasksQueued = []byte("tasks-queued") +// TaskOwner maps: task ID -> owner string +var TaskOwner = []byte("tasks-owner") + // TaskState maps: task ID -> state string var TaskState = []byte("tasks-state") @@ -73,6 +76,9 @@ func (taskBolt *BoltDB) Init() error { if tx.Bucket(TasksQueued) == nil { tx.CreateBucket(TasksQueued) } + if tx.Bucket(TaskOwner) == nil { + tx.CreateBucket(TaskOwner) + } if tx.Bucket(TaskState) == nil { tx.CreateBucket(TaskState) } diff --git a/database/boltdb/scheduler.go b/database/boltdb/scheduler.go index dfdac5d0f..35f7aed2b 100644 --- a/database/boltdb/scheduler.go +++ b/database/boltdb/scheduler.go @@ -35,7 +35,7 @@ func (taskBolt *BoltDB) ReadQueue(n int) []*tes.Task { c := tx.Bucket(TasksQueued).Cursor() for k, _ := c.First(); k != nil && len(tasks) < n; k, _ = c.Next() { id := string(k) - task, _ := getTaskView(tx, id, tes.View_FULL) + task, _ := getTaskView(tx, id, tes.View_FULL, nil) tasks = append(tasks, task) } return nil diff --git a/database/boltdb/tes.go b/database/boltdb/tes.go index e360c56fb..433e307cc 100644 --- a/database/boltdb/tes.go +++ b/database/boltdb/tes.go @@ -7,6 +7,7 @@ import ( "github.com/boltdb/bolt" proto "github.com/golang/protobuf/proto" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" "golang.org/x/net/context" ) @@ -22,55 +23,66 @@ func getTaskState(tx *bolt.Tx, id string) tes.State { return tes.State(v) } -func loadMinimalTaskView(tx *bolt.Tx, id string, task *tes.Task) error { +func loadTask(tx *bolt.Tx, id string, task *tes.Task, ctx context.Context) error { b := tx.Bucket(TaskBucket).Get([]byte(id)) if b == nil { return tes.ErrNotFound } + + if err := checkOwner(tx, id, ctx); err != nil { + return err + } + + if task != nil { + proto.Unmarshal(b, task) + task.State = getTaskState(tx, id) + } + + return nil +} + +func loadMinimalTaskView(tx *bolt.Tx, id string, task *tes.Task, ctx context.Context) error { + if err := loadTask(tx, id, nil, ctx); err != nil { + return err + 
} task.Id = id task.State = getTaskState(tx, id) return nil } -func loadBasicTaskView(tx *bolt.Tx, id string, task *tes.Task) error { - b := tx.Bucket(TaskBucket).Get([]byte(id)) - if b == nil { - return tes.ErrNotFound +func loadBasicTaskView(tx *bolt.Tx, id string, task *tes.Task, ctx context.Context) error { + err := loadTask(tx, id, task, ctx) + if err != nil { + return err } - proto.Unmarshal(b, task) + loadTaskLogs(tx, task) // remove content from inputs - inputs := []*tes.Input{} for _, v := range task.Inputs { v.Content = "" - inputs = append(inputs, v) } - task.Inputs = inputs - return loadMinimalTaskView(tx, id, task) + return nil } -func loadFullTaskView(tx *bolt.Tx, id string, task *tes.Task) error { - b := tx.Bucket(TaskBucket).Get([]byte(id)) - if b == nil { - return tes.ErrNotFound +func loadFullTaskView(tx *bolt.Tx, id string, task *tes.Task, ctx context.Context) error { + err := loadTask(tx, id, task, ctx) + if err != nil { + return err } - proto.Unmarshal(b, task) loadTaskLogs(tx, task) // Load executor stdout/err for _, tl := range task.Logs { for j, el := range tl.Logs { - key := fmt.Sprint(id, j) + key := []byte(fmt.Sprint(id, j)) - b := tx.Bucket(ExecutorStdout).Get([]byte(key)) - if b != nil { + if b := tx.Bucket(ExecutorStdout).Get(key); b != nil { el.Stdout = string(b) } - b = tx.Bucket(ExecutorStderr).Get([]byte(key)) - if b != nil { + if b := tx.Bucket(ExecutorStderr).Get(key); b != nil { el.Stderr = string(b) } } @@ -87,7 +99,7 @@ func loadFullTaskView(tx *bolt.Tx, id string, task *tes.Task) error { task.Logs[0].SystemLogs = syslogs } - return loadMinimalTaskView(tx, id, task) + return nil } func loadTaskLogs(tx *bolt.Tx, task *tes.Task) { @@ -122,24 +134,24 @@ func (taskBolt *BoltDB) GetTask(ctx context.Context, req *tes.GetTaskRequest) (* if !ok { return fmt.Errorf("Unknown view: %s", req.View) } - task, err = getTaskView(tx, req.Id, tes.View(tes.View_value[req.View])) + task, err = getTaskView(tx, req.Id, 
tes.View(tes.View_value[req.View]), ctx) return err }) return task, err } -func getTaskView(tx *bolt.Tx, id string, view tes.View) (*tes.Task, error) { +func getTaskView(tx *bolt.Tx, id string, view tes.View, ctx context.Context) (*tes.Task, error) { var err error task := &tes.Task{} switch { case view == tes.View_MINIMAL: - err = loadMinimalTaskView(tx, id, task) + err = loadMinimalTaskView(tx, id, task, ctx) case view == tes.View_BASIC: - err = loadBasicTaskView(tx, id, task) + err = loadBasicTaskView(tx, id, task, ctx) case view == tes.View_FULL: - err = loadFullTaskView(tx, id, task) + err = loadFullTaskView(tx, id, task, ctx) default: err = fmt.Errorf("Unknown view: %s", view.String()) } @@ -154,6 +166,7 @@ func (taskBolt *BoltDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest if req.View == tes.Minimal.String() && req.GetTags() != nil { view = tes.View_BASIC.String() } + viewMode := tes.View(tes.View_value[view]) pageSize := tes.GetPageSize(req.GetPageSize()) taskBolt.db.View(func(tx *bolt.Tx) error { @@ -176,7 +189,12 @@ func (taskBolt *BoltDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest taskLoop: for ; k != nil && i < pageSize; k, _ = c.Prev() { - task, _ := getTaskView(tx, string(k), tes.View_FULL) + taskId := string(k) + + task, err := getTaskView(tx, taskId, tes.View_BASIC, ctx) + if err != nil { + continue taskLoop // Skip the task as access to it was not confirmed + } if req.State != tes.Unknown && req.State != task.State { continue taskLoop @@ -193,7 +211,9 @@ func (taskBolt *BoltDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest } } - task, _ = getTaskView(tx, string(k), tes.View(tes.View_value[view])) + if viewMode != tes.View_BASIC { + task, _ = getTaskView(tx, taskId, viewMode, ctx) + } tasks = append(tasks, task) i++ @@ -211,3 +231,20 @@ func (taskBolt *BoltDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest return &out, nil } + +func checkOwner(tx *bolt.Tx, taskId string, ctx context.Context) error { + 
// Skip access-check for system-related operations where ctx is undefined: + if ctx == nil || server.GetUser(ctx).CanSeeAllTasks() { + return nil + } + + taskOwner := "" + if owner := tx.Bucket(TaskOwner).Get([]byte(taskId)); owner != nil { + taskOwner = string(owner) + } + + if server.GetUser(ctx).IsAccessible(taskOwner) { + return nil + } + return tes.ErrNotPermitted +} diff --git a/database/datastore/events.go b/database/datastore/events.go index 4178e1930..8a95628be 100644 --- a/database/datastore/events.go +++ b/database/datastore/events.go @@ -2,9 +2,11 @@ package datastore import ( "context" + "fmt" "cloud.google.com/go/datastore" "github.com/ohsu-comp-bio/funnel/events" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" ) @@ -14,10 +16,16 @@ func (d *Datastore) WriteEvent(ctx context.Context, e *events.Event) error { switch e.Type { case events.Type_TASK_CREATED: - putKeys, putData := marshalTask(e.GetTask()) + putKeys, putData := marshalTask(e.GetTask(), ctx) _, err := d.client.PutMulti(ctx, putKeys, putData) return err + case events.Type_TASK_STATE: + return d.taskUpdateInTransaction(ctx, e, updateState) + + case events.Type_SYSTEM_LOG: + return d.appendTaskSystemLog(ctx, e) + case events.Type_EXECUTOR_STDOUT: _, err := d.client.Put(ctx, stdoutKey(e.Id, e.Attempt, e.Index), marshalEvent(e)) return err @@ -26,64 +34,131 @@ func (d *Datastore) WriteEvent(ctx context.Context, e *events.Event) error { _, err := d.client.Put(ctx, stderrKey(e.Id, e.Attempt, e.Index), marshalEvent(e)) return err - case events.Type_SYSTEM_LOG: - _, err := d.client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { - props := datastore.PropertyList{} - err := tx.Get(sysLogsKey(e.Id, e.Attempt), &props) - if err != nil && err != datastore.ErrNoSuchEntity { - return err - } - - p := &part{} - err = datastore.LoadStruct(p, props) - if err != nil { - return err - } - - _, err = tx.Put(sysLogsKey(e.Id, e.Attempt), &part{ - Type: sysLogsPart, - 
Attempt: int(e.Attempt), - Index: int(e.Index), - SystemLogs: append(p.SystemLogs, e.SysLogString()), - }) + default: + return d.taskUpdateInTransaction(ctx, e, updateTaskLog) + } +} + +type taskUpdater func(ctx context.Context, task *task, e *events.Event) error + +func (d *Datastore) taskUpdateInTransaction(ctx context.Context, event *events.Event, update taskUpdater) error { + _, err := d.client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + task := &task{} + taskKey := taskKey(event.Id) + + if err := tx.Get(taskKey, task); err == datastore.ErrNoSuchEntity { + return tes.ErrNotFound + } else if err != nil { return err - }) + } else if !server.GetUser(ctx).IsAccessible(task.Owner) { + return tes.ErrNotPermitted + } + + if err := update(ctx, task, event); err != nil { + return err + } + + _, err := tx.Put(taskKey, task) + return err + }) + return err +} + +// This method is focused on updating the State of a Task. +// In Datastore, the whole Task-entity is updated, though just one field changes. +func updateState(ctx context.Context, task *task, e *events.Event) error { + from := tes.State(task.State) + to := e.GetState() + if err := tes.ValidateTransition(from, to); err != nil { return err + } + task.State = int32(to.Number()) + return nil +} + +// This method is focused on adding/updating a TaskLog of a Task. +// In Datastore, the whole Task-entity gets updated. 
+func updateTaskLog(ctx context.Context, task *task, e *events.Event) error { + targetLog := getTaskLog(task, e) + + switch e.Type { + + case events.Type_TASK_START_TIME: + targetLog.StartTime = e.GetStartTime() + + case events.Type_TASK_END_TIME: + targetLog.EndTime = e.GetEndTime() + + case events.Type_TASK_OUTPUTS: + targetLog.Outputs = e.GetOutputs().Value + + case events.Type_TASK_METADATA: + targetLog.Metadata = mergeKvs(targetLog.Metadata, e.GetMetadata().Value) + + case events.Type_EXECUTOR_START_TIME: + getExecutorLog(targetLog, e).StartTime = e.GetStartTime() + + case events.Type_EXECUTOR_END_TIME: + getExecutorLog(targetLog, e).EndTime = e.GetEndTime() + + case events.Type_EXECUTOR_EXIT_CODE: + getExecutorLog(targetLog, e).ExitCode = e.GetExitCode() default: - _, err := d.client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { - props := datastore.PropertyList{} - err := tx.Get(taskKey(e.Id), &props) - if err == datastore.ErrNoSuchEntity { - return tes.ErrNotFound - } - if err != nil { - return err - } - - task := &tes.Task{} - if err := unmarshalTask(task, props); err != nil { - return err - } - - if e.Type == events.Type_TASK_STATE { - from := task.State - to := e.GetState() - if err := tes.ValidateTransition(from, to); err != nil { - return err - } - } - - tb := events.TaskBuilder{Task: task} - err = tb.WriteEvent(context.Background(), e) - if err != nil { - return err - } - - putKeys, putData := marshalTask(task) - _, err = tx.PutMulti(putKeys, putData) + return fmt.Errorf("[Datastore] function updateTaskLog does not support event: %q", e.Type.String()) + } + + return nil +} + +// This method is focused on adding/updating SystemLogs of a TaskLog. +// In Datastore, SystemLogs are stored under separate keys, so the Task-entity is not updated. 
+func (d *Datastore) appendTaskSystemLog(ctx context.Context, event *events.Event) error { + _, err := d.client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + p := &part{} + err := tx.Get(sysLogsKey(event.Id, event.Attempt), p) + if err != nil && err != datastore.ErrNoSuchEntity { return err + } + + _, err = tx.Put(sysLogsKey(event.Id, event.Attempt), &part{ + Type: sysLogsPart, + Attempt: int(event.Attempt), + Index: int(event.Index), + SystemLogs: append(p.SystemLogs, event.SysLogString()), }) return err + }) + return err +} + +// Retrieves the tasklog from the provided task as referenced in the event (Attempt). +// If the Attempt refers to a non-existing tasklog, it is added to task.TaskLogs. +func getTaskLog(task *task, e *events.Event) *tasklog { + targetLogIndex := int(e.Attempt) + + // Grow slice length if necessary + for range targetLogIndex + 1 - len(task.TaskLogs) { + item := tasklog{ + TaskLog: &tes.TaskLog{}, + Metadata: []kv{}, + } + task.TaskLogs = append(task.TaskLogs, item) + } + + return &task.TaskLogs[targetLogIndex] +} + +// Retrieves the ExecutorLog from the provided tasklog as referenced in the +// event (Index). If the Index refers to a non-existing executor log, it is +// added to taskLog.Logs. 
+func getExecutorLog(taskLog *tasklog, e *events.Event) *tes.ExecutorLog { + execLogIndex := int(e.Index) + + // Grow slice length if necessary + for j := len(taskLog.Logs); j <= execLogIndex; j++ { + taskLog.Logs = append(taskLog.Logs, &tes.ExecutorLog{}) } + + return taskLog.Logs[execLogIndex] } diff --git a/database/datastore/props.go b/database/datastore/props.go index 2d132d43c..0a5554821 100644 --- a/database/datastore/props.go +++ b/database/datastore/props.go @@ -1,11 +1,13 @@ package datastore import ( + "context" "fmt" "net/url" "cloud.google.com/go/datastore" "github.com/ohsu-comp-bio/funnel/events" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" ) @@ -88,6 +90,7 @@ type part struct { type task struct { Id, CreationTime string `datastore:",omitempty"` // nolint State int32 + Owner string Name, Description string `datastore:",noindex,omitempty"` Executors []executor `datastore:",noindex,omitempty"` Inputs []param `datastore:",noindex,omitempty"` @@ -100,8 +103,8 @@ type task struct { } type tasklog struct { - *tes.TaskLog Metadata []kv `datastore:",noindex,omitempty"` + *tes.TaskLog } type resources struct { @@ -122,10 +125,11 @@ type param struct { Type int32 `datastore:",noindex,omitempty"` } -func marshalTask(t *tes.Task) ([]*datastore.Key, []interface{}) { +func marshalTask(t *tes.Task, ctx context.Context) ([]*datastore.Key, []interface{}) { z := &task{ Id: t.Id, State: int32(t.State), + Owner: server.GetUsername(ctx), CreationTime: t.CreationTime, Name: t.Name, Description: t.Description, @@ -197,20 +201,17 @@ func marshalTask(t *tes.Task) ([]*datastore.Key, []interface{}) { return keys, data } -func unmarshalTask(z *tes.Task, props datastore.PropertyList) error { - c := &task{} - err := datastore.LoadStruct(c, props) - if err != nil { - return err +func (c *task) unmarshal() *tes.Task { + z := &tes.Task{ + Id: c.Id, + CreationTime: c.CreationTime, + State: tes.State(c.State), + Name: c.Name, + Description: 
c.Description, + Volumes: c.Volumes, + Tags: unmarshalMap(c.Tags), } - z.Id = c.Id - z.CreationTime = c.CreationTime - z.State = tes.State(c.State) - z.Name = c.Name - z.Description = c.Description - z.Volumes = c.Volumes - z.Tags = unmarshalMap(c.Tags) if c.Resources != nil { z.Resources = &tes.Resources{ CpuCores: int32(c.Resources.CpuCores), @@ -254,7 +255,11 @@ func unmarshalTask(z *tes.Task, props datastore.PropertyList) error { tl.Metadata = unmarshalMap(i.Metadata) z.Logs = append(z.Logs, tl) } - return nil + // Ensure at least one empty log for the JSON response: + if len(z.Logs) == 0 { + z.Logs = []*tes.TaskLog{{}} + } + return z } func unmarshalPart(t *tes.Task, props datastore.PropertyList) error { @@ -323,3 +328,24 @@ func stringifyMap(m map[string]string) []string { func encodeKV(k, v string) string { return url.QueryEscape(k) + ":" + url.QueryEscape(v) } + +func mergeKvs(existing []kv, updates map[string]string) []kv { + tmp := map[string]string{} + + // Populate 'tmp' with existing values: + for _, kv := range existing { + tmp[kv.Key] = kv.Value + } + + // Update 'tmp' with new values: + for k, v := range updates { + tmp[k] = v + } + + // Convert 'tmp' to a kv-list: + result := []kv{} + for k, v := range tmp { + result = append(result, kv{Key: k, Value: v}) + } + return result +} diff --git a/database/datastore/tes.go b/database/datastore/tes.go index 8938aafe8..5d3ec24c5 100644 --- a/database/datastore/tes.go +++ b/database/datastore/tes.go @@ -2,6 +2,7 @@ package datastore import ( "cloud.google.com/go/datastore" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" "golang.org/x/net/context" "google.golang.org/api/iterator" @@ -10,21 +11,21 @@ import ( // GetTask implements the TES GetTask interface. 
func (d *Datastore) GetTask(ctx context.Context, req *tes.GetTaskRequest) (*tes.Task, error) { key := taskKey(req.Id) + entity := &task{} - var props datastore.PropertyList - err := d.client.Get(ctx, key, &props) + err := d.client.Get(ctx, key, entity) if err == datastore.ErrNoSuchEntity { return nil, tes.ErrNotFound } if err != nil { return nil, err } - - task := &tes.Task{} - if err := unmarshalTask(task, props); err != nil { - return nil, err + if !server.GetUser(ctx).IsAccessible(entity.Owner) { + return nil, tes.ErrNotPermitted } + task := entity.unmarshal() + switch req.View { case tes.View_MINIMAL.String(): task = task.GetMinimalView() @@ -74,7 +75,7 @@ func (d *Datastore) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (* page := req.PageToken size := tes.GetPageSize(req.GetPageSize()) - q := datastore.NewQuery("Task").KeysOnly().Limit(size).Order("-CreationTime") + q := datastore.NewQuery("Task").Limit(size).Order("-CreationTime") if page != "" { c, err := datastore.DecodeCursor(page) @@ -84,6 +85,10 @@ func (d *Datastore) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (* q = q.Start(c) } + if userInfo := server.GetUser(ctx); !userInfo.CanSeeAllTasks() { + q = q.FilterField("Owner", "=", userInfo.Username) + } + if req.State != tes.Unknown { q = q.FilterField("State", "=", int32(req.State)) } @@ -93,33 +98,20 @@ func (d *Datastore) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (* } var tasks []*tes.Task - var keys []*datastore.Key var parts []*datastore.Key byID := map[string]*tes.Task{} it := d.client.Run(ctx, q) for { - key, err := it.Next(nil) + t := &task{} + _, err := it.Next(t) if err == iterator.Done { break } if err != nil { return nil, err } - keys = append(keys, key) - } - - proplists := make([]datastore.PropertyList, len(keys)) - err := d.client.GetMulti(ctx, keys, proplists) - if err != nil { - return nil, err - } - - for _, props := range proplists { - task := &tes.Task{} - if err := unmarshalTask(task, props); 
err != nil { - return nil, err - } + task := t.unmarshal() switch req.View { case tes.View_MINIMAL.String(): @@ -142,7 +134,7 @@ func (d *Datastore) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (* resp := &tes.ListTasksResponse{Tasks: tasks} - if len(keys) == size { + if len(tasks) == size { c, err := it.Cursor() if err != nil { return nil, err diff --git a/database/dynamodb/events.go b/database/dynamodb/events.go index c9a8a0bb6..939bb0fbe 100644 --- a/database/dynamodb/events.go +++ b/database/dynamodb/events.go @@ -58,6 +58,9 @@ func (db *DynamoDB) WriteEvent(ctx context.Context, e *events.Event) error { if response.Item == nil { return tes.ErrNotFound } + if !isAccessible(ctx, response) { + return tes.ErrNotPermitted + } err = dynamodbattribute.UnmarshalMap(response.Item, ¤t) if err != nil { @@ -244,9 +247,7 @@ func (db *DynamoDB) WriteEvent(ctx context.Context, e *events.Event) error { item.ExpressionAttributeNames = expr.Names() item.ExpressionAttributeValues = expr.Values() item.UpdateExpression = expr.Update() - if *expr.Condition() != "" { - item.ConditionExpression = expr.Condition() - } + item.ConditionExpression = expr.Condition() _, err = db.client.UpdateItemWithContext(ctx, item) return checkErrNotFound(err) diff --git a/database/dynamodb/tes.go b/database/dynamodb/tes.go index c1ba09659..6114131e9 100644 --- a/database/dynamodb/tes.go +++ b/database/dynamodb/tes.go @@ -2,12 +2,12 @@ package dynamodb import ( "fmt" - "strconv" - "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/aws/aws-sdk-go/service/dynamodb/expression" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" "golang.org/x/net/context" ) @@ -29,10 +29,12 @@ func (db *DynamoDB) GetTask(ctx context.Context, req *tes.GetTaskRequest) (*tes. 
if err != nil { return nil, err } - if response.Item == nil { return nil, tes.ErrNotFound } + if !isAccessible(ctx, response) { + return nil, tes.ErrNotPermitted + } err = dynamodbattribute.UnmarshalMap(response.Item, &task) if err != nil { @@ -45,57 +47,56 @@ func (db *DynamoDB) GetTask(ctx context.Context, req *tes.GetTaskRequest) (*tes. // ListTasks returns a list of taskIDs func (db *DynamoDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (*tes.ListTasksResponse, error) { - var tasks []*tes.Task - var query *dynamodb.QueryInput - pageSize := int64(tes.GetPageSize(req.GetPageSize())) + filters := []expression.ConditionBuilder{} - query = &dynamodb.QueryInput{ - TableName: aws.String(db.taskTable), - Limit: aws.Int64(pageSize), - ScanIndexForward: aws.Bool(false), - ConsistentRead: aws.Bool(true), - KeyConditionExpression: aws.String(fmt.Sprintf("%s = :v1", db.partitionKey)), - ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ - ":v1": { - S: aws.String(db.partitionValue), - }, - }, + if userInfo := server.GetUser(ctx); !userInfo.CanSeeAllTasks() { + filters = append(filters, expression.Name("owner").Equal(expression.Value(userInfo.Username))) } - filterParts := []string{} if req.State != tes.Unknown { - query.ExpressionAttributeNames = map[string]*string{ - "#state": aws.String("state"), - } - query.ExpressionAttributeValues[":stateFilter"] = &dynamodb.AttributeValue{ - N: aws.String(strconv.Itoa(int(req.State))), - } - filterParts = append(filterParts, "#state = :stateFilter") + num := int(req.State) + filters = append(filters, expression.Name("state").Equal(expression.Value(num))) } for k, v := range req.GetTags() { - tmpl := "tags.%s = :%sFilter" - filterParts = append(filterParts, fmt.Sprintf(tmpl, k, k)) + var fieldValue expression.ValueBuilder if v == "" { - query.ExpressionAttributeValues[fmt.Sprintf(":%sFilter", k)] = &dynamodb.AttributeValue{ - NULL: aws.Bool(true), - } + fieldValue = expression.Value(expression.Null) } else { 
- query.ExpressionAttributeValues[fmt.Sprintf(":%sFilter", k)] = &dynamodb.AttributeValue{ - S: aws.String(v), - } + fieldValue = expression.Value(v) } + filters = append(filters, expression.Name("tags."+k).Equal(fieldValue)) } - if len(filterParts) > 0 { - query.FilterExpression = aws.String(strings.Join(filterParts, " AND ")) - } + exp := expression.NewBuilder(). + WithKeyCondition(expression.Key(db.partitionKey).Equal(expression.Value(db.partitionValue))) if req.View == tes.View_MINIMAL.String() { - query.ExpressionAttributeNames = map[string]*string{ - "#state": aws.String("state"), - } - query.ProjectionExpression = aws.String("id, #state") + exp = exp.WithProjection(expression.NamesList(expression.Name("id"), expression.Name("state"))) + } + + if len(filters) == 1 { + exp.WithFilter(filters[0]) + } else if len(filters) > 1 { + exp.WithFilter(expression.And(filters[0], filters[1], filters[2:]...)) + } + + eb, err := exp.Build() + if err != nil { + return nil, err + } + + pageSize := int64(tes.GetPageSize(req.GetPageSize())) + query := &dynamodb.QueryInput{ + TableName: aws.String(db.taskTable), + Limit: aws.Int64(pageSize), + ScanIndexForward: aws.Bool(false), + ConsistentRead: aws.Bool(true), + KeyConditionExpression: eb.KeyCondition(), + ExpressionAttributeNames: eb.Names(), + ExpressionAttributeValues: eb.Values(), + FilterExpression: eb.Filter(), + ProjectionExpression: eb.Projection(), } if req.PageToken != "" { @@ -125,6 +126,7 @@ func (db *DynamoDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (* } } + var tasks []*tes.Task err = dynamodbattribute.UnmarshalListOfMaps(response.Items, &tasks) if err != nil { return nil, fmt.Errorf("failed to DynamoDB unmarshal Tasks, %v", err) @@ -134,7 +136,7 @@ func (db *DynamoDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (* Tasks: tasks, } - if response.LastEvaluatedKey != nil { + if len(tasks) > 0 && response.LastEvaluatedKey != nil { out.NextPageToken = *response.LastEvaluatedKey["id"].S 
} diff --git a/database/dynamodb/util.go b/database/dynamodb/util.go index 06d1dd1e6..8b8243bce 100644 --- a/database/dynamodb/util.go +++ b/database/dynamodb/util.go @@ -3,7 +3,6 @@ package dynamodb import ( "context" "fmt" - "log" "strconv" "time" @@ -11,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" ) @@ -253,6 +253,10 @@ func (db *DynamoDB) createTask(ctx context.Context, task *tes.Task) error { S: aws.String(strconv.FormatInt(time.Now().UnixNano(), 10)), } + av["owner"] = &dynamodb.AttributeValue{ + S: aws.String(server.GetUsername(ctx)), + } + // Add nil fields to make updates easier av["logs"] = &dynamodb.AttributeValue{ L: []*dynamodb.AttributeValue{ @@ -315,86 +319,26 @@ func (db *DynamoDB) createTaskInputContent(ctx context.Context, task *tes.Task) return nil } -func (db *DynamoDB) deleteTask(ctx context.Context, id string) error { - var item *dynamodb.DeleteItemInput - var err error - - item = &dynamodb.DeleteItemInput{ - TableName: aws.String(db.taskTable), - Key: map[string]*dynamodb.AttributeValue{ - db.partitionKey: { - S: aws.String(db.partitionValue), - }, - "id": { - S: aws.String(id), - }, - }, - } - _, err = db.client.DeleteItemWithContext(ctx, item) - if err != nil { - return err - } - - query := &dynamodb.QueryInput{ - TableName: aws.String(db.contentTable), - Limit: aws.Int64(10), - ScanIndexForward: aws.Bool(false), - ConsistentRead: aws.Bool(true), - KeyConditionExpression: aws.String("id = :v1"), - ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ - ":v1": { - S: aws.String(id), - }, +func (db *DynamoDB) taskKey(id string) map[string]*dynamodb.AttributeValue { + return map[string]*dynamodb.AttributeValue{ + db.partitionKey: { + S: aws.String(db.partitionValue), }, - ExpressionAttributeNames: map[string]*string{ - "#index": 
aws.String("index"), + "id": { + S: aws.String(id), }, - ProjectionExpression: aws.String("id, #index"), - } - - err = db.client.QueryPagesWithContext( - ctx, - query, - func(page *dynamodb.QueryOutput, lastPage bool) bool { - for _, res := range page.Items { - item = &dynamodb.DeleteItemInput{ - TableName: aws.String(db.contentTable), - Key: map[string]*dynamodb.AttributeValue{ - "id": res["id"], - "index": res["index"], - }, - } - // TODO handle error without panic - _, err := db.client.DeleteItem(item) - if err != nil { - log.Fatalf("failed to delete content item: %v", err) - } - } - return page.LastEvaluatedKey != nil - }) - - if err != nil { - return err } - - return nil } func (db *DynamoDB) getMinimalView(ctx context.Context, id string) (*dynamodb.GetItemOutput, error) { item := &dynamodb.GetItemInput{ TableName: aws.String(db.taskTable), - Key: map[string]*dynamodb.AttributeValue{ - db.partitionKey: { - S: aws.String(db.partitionValue), - }, - "id": { - S: aws.String(id), - }, - }, + Key: db.taskKey(id), ExpressionAttributeNames: map[string]*string{ + "#owner": aws.String("owner"), "#state": aws.String("state"), }, - ProjectionExpression: aws.String("id, #state"), + ProjectionExpression: aws.String("id, #owner, #state"), } return db.client.GetItemWithContext(ctx, item) } @@ -402,32 +346,13 @@ func (db *DynamoDB) getMinimalView(ctx context.Context, id string) (*dynamodb.Ge func (db *DynamoDB) getBasicView(ctx context.Context, id string) (*dynamodb.GetItemOutput, error) { item := &dynamodb.GetItemInput{ TableName: aws.String(db.taskTable), - Key: map[string]*dynamodb.AttributeValue{ - db.partitionKey: { - S: aws.String(db.partitionValue), - }, - "id": { - S: aws.String(id), - }, - }, + Key: db.taskKey(id), } return db.client.GetItemWithContext(ctx, item) } func (db *DynamoDB) getFullView(ctx context.Context, id string) (*dynamodb.GetItemOutput, error) { - item := &dynamodb.GetItemInput{ - TableName: aws.String(db.taskTable), - Key: 
map[string]*dynamodb.AttributeValue{ - db.partitionKey: { - S: aws.String(db.partitionValue), - }, - "id": { - S: aws.String(id), - }, - }, - } - - resp, err := db.client.GetItemWithContext(ctx, item) + resp, err := db.getBasicView(ctx, id) if err != nil || resp.Item == nil { return resp, err } @@ -454,22 +379,23 @@ func (db *DynamoDB) getFullView(ctx context.Context, id string) (*dynamodb.GetIt return resp, nil } - -func (db *DynamoDB) getContent(ctx context.Context, in map[string]*dynamodb.AttributeValue) error { - query := &dynamodb.QueryInput{ - TableName: aws.String(db.contentTable), - Limit: aws.Int64(10), +func (db *DynamoDB) queryInput(table string, id *dynamodb.AttributeValue, limit int64) *dynamodb.QueryInput { + return &dynamodb.QueryInput{ + TableName: aws.String(table), + Limit: aws.Int64(limit), ScanIndexForward: aws.Bool(false), ConsistentRead: aws.Bool(true), KeyConditionExpression: aws.String("id = :v1"), ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ - ":v1": in["id"], + ":v1": id, }, } +} - err := db.client.QueryPagesWithContext( +func (db *DynamoDB) getContent(ctx context.Context, in map[string]*dynamodb.AttributeValue) error { + return db.client.QueryPagesWithContext( ctx, - query, + db.queryInput(db.contentTable, in["id"], 10), func(page *dynamodb.QueryOutput, lastPage bool) bool { for _, item := range page.Items { i, _ := strconv.ParseInt(*item["index"].N, 10, 64) @@ -478,27 +404,12 @@ func (db *DynamoDB) getContent(ctx context.Context, in map[string]*dynamodb.Attr return page.LastEvaluatedKey != nil }, ) - if err != nil { - return err - } - return nil } func (db *DynamoDB) getExecutorOutput(ctx context.Context, in map[string]*dynamodb.AttributeValue, val string, table string) error { - query := &dynamodb.QueryInput{ - TableName: aws.String(table), - Limit: aws.Int64(10), - ScanIndexForward: aws.Bool(false), - ConsistentRead: aws.Bool(true), - KeyConditionExpression: aws.String("id = :v1"), - ExpressionAttributeValues: 
map[string]*dynamodb.AttributeValue{ - ":v1": in["id"], - }, - } - - err := db.client.QueryPagesWithContext( + return db.client.QueryPagesWithContext( ctx, - query, + db.queryInput(table, in["id"], 10), func(page *dynamodb.QueryOutput, lastPage bool) bool { for _, item := range page.Items { i, _ := strconv.ParseInt(*item["index"].N, 10, 64) @@ -512,27 +423,12 @@ func (db *DynamoDB) getExecutorOutput(ctx context.Context, in map[string]*dynamo return page.LastEvaluatedKey != nil }, ) - if err != nil { - return err - } - return nil } func (db *DynamoDB) getSystemLogs(ctx context.Context, in map[string]*dynamodb.AttributeValue) error { - query := &dynamodb.QueryInput{ - TableName: aws.String(db.syslogsTable), - Limit: aws.Int64(50), - ScanIndexForward: aws.Bool(false), - ConsistentRead: aws.Bool(true), - KeyConditionExpression: aws.String("id = :v1"), - ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ - ":v1": in["id"], - }, - } - - err := db.client.QueryPagesWithContext( + return db.client.QueryPagesWithContext( ctx, - query, + db.queryInput(db.syslogsTable, in["id"], 50), func(page *dynamodb.QueryOutput, lastPage bool) bool { for _, item := range page.Items { i, _ := strconv.ParseInt(*item["attempt"].N, 10, 64) @@ -541,8 +437,12 @@ func (db *DynamoDB) getSystemLogs(ctx context.Context, in map[string]*dynamodb.A return page.LastEvaluatedKey != nil }, ) - if err != nil { - return err +} + +func isAccessible(ctx context.Context, response *dynamodb.GetItemOutput) bool { + taskOwner := "" + if attrValue, ok := response.Item["owner"]; ok { + taskOwner = *attrValue.S } - return nil + return server.GetUser(ctx).IsAccessible(taskOwner) } diff --git a/database/elastic/events.go b/database/elastic/events.go index 92f890d35..684251d35 100644 --- a/database/elastic/events.go +++ b/database/elastic/events.go @@ -1,14 +1,16 @@ package elastic import ( - "bytes" "context" + "encoding/json" - "github.com/golang/protobuf/jsonpb" + 
"github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/result" + "github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/scriptlanguage" "github.com/ohsu-comp-bio/funnel/events" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" "github.com/ohsu-comp-bio/funnel/util" - elastic "gopkg.in/olivere/elastic.v5" ) var updateTaskLogs = ` @@ -63,66 +65,72 @@ for (; params.index > ctx._source.logs[params.attempt].logs.length - 1; ) { ctx._source.logs[params.attempt].logs[params.index][params.field] = params.value; ` -func taskLogUpdate(attempt uint32, field string, value interface{}) *elastic.Script { - return elastic.NewScript(updateTaskLogs). - Lang("painless"). - Param("attempt", attempt). - Param("field", field). - Param("value", value) +func asRawMessage(v interface{}) json.RawMessage { + b, _ := json.Marshal(v) + return json.RawMessage(b) } -func execLogUpdate(attempt, index uint32, field string, value interface{}) *elastic.Script { - return elastic.NewScript(updateExecutorLogs). - Lang("painless"). - Param("attempt", attempt). - Param("index", index). - Param("field", field). - Param("value", value) +func taskLogUpdate(attempt uint32, field string, value interface{}) *types.Script { + return &types.Script{ + Lang: &scriptlanguage.Painless, + Source: &updateTaskLogs, + Params: map[string]json.RawMessage{ + "attempt": asRawMessage(attempt), + "field": asRawMessage(field), + "value": asRawMessage(value), + }, + } +} + +func execLogUpdate(attempt, index uint32, field string, value interface{}) *types.Script { + return &types.Script{ + Lang: &scriptlanguage.Painless, + Source: &updateExecutorLogs, + Params: map[string]json.RawMessage{ + "attempt": asRawMessage(attempt), + "index": asRawMessage(index), + "field": asRawMessage(field), + "value": asRawMessage(value), + }, + } } // WriteEvent writes a task update event. 
func (es *Elastic) WriteEvent(ctx context.Context, ev *events.Event) error { - u := es.client.Update(). - Index(es.taskIndex). - Type("task"). - RetryOnConflict(10). - Id(ev.Id) + u := es.client.Update(es.taskIndex, ev.Id).RetryOnConflict(10) switch ev.Type { case events.Type_TASK_CREATED: task := ev.GetTask() - mar := jsonpb.Marshaler{} - s, err := mar.MarshalToString(task) - if err != nil { - return err - } - _, err = es.client.Index(). - Index(es.taskIndex). - Type("task"). + res, err := es.client.Index(es.taskIndex). Id(task.Id). - BodyString(s). + Document(task). Do(ctx) + + if err == nil { + _, err = es.client.Update(res.Index_, res.Id_). + IfSeqNo(int64ToStr(res.SeqNo_)). + IfPrimaryTerm(int64ToStr(res.PrimaryTerm_)). + Doc(map[string]string{"owner": server.GetUsername(ctx)}). + Do(ctx) + } + return err case events.Type_TASK_STATE: retrier := util.NewRetrier() retrier.ShouldRetry = func(err error) bool { - if elastic.IsConflict(err) || elastic.IsConnErr(err) { - return true - } - return false + _, isTransitionError := err.(*tes.TransitionError) + return !isTransitionError && err != tes.ErrNotFound && err != tes.ErrNotPermitted } return retrier.Retry(ctx, func() error { // get current state & version - res, err := es.getTask(ctx, &tes.GetTaskRequest{Id: ev.Id}) - if err != nil { - return err - } - - task := &tes.Task{} - err = jsonpb.Unmarshal(bytes.NewReader(*res.Source), task) + task, seqNo, primaryTerm, err := es.getTask(ctx, &tes.GetTaskRequest{ + Id: ev.Id, + View: tes.View_MINIMAL.String(), + }) if err != nil { return err } @@ -135,11 +143,9 @@ func (es *Elastic) WriteEvent(ctx context.Context, ev *events.Event) error { } // apply version restriction and set update - _, err = es.client.Update(). - Index(es.taskIndex). - Type("task"). - Id(ev.Id). - Version(*res.Version). + _, err = es.client.Update(es.taskIndex, ev.Id). + IfSeqNo(seqNo). + IfPrimaryTerm(primaryTerm). Doc(map[string]string{"state": to.String()}). 
Do(ctx) return err @@ -176,8 +182,8 @@ func (es *Elastic) WriteEvent(ctx context.Context, ev *events.Event) error { u = u.Script(taskLogUpdate(ev.Attempt, "system_logs", ev.SysLogString())) } - _, err := u.Do(ctx) - if elastic.IsNotFound(err) { + resp, err := u.Do(ctx) + if resp.Result == result.Noop || resp.Result == result.Notfound { return tes.ErrNotFound } return err diff --git a/database/elastic/new.go b/database/elastic/new.go index 07a9cae22..a9b793419 100644 --- a/database/elastic/new.go +++ b/database/elastic/new.go @@ -2,21 +2,22 @@ package elastic import ( "context" - "time" + "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/ohsu-comp-bio/funnel/compute/scheduler" "github.com/ohsu-comp-bio/funnel/config" - elastic "gopkg.in/olivere/elastic.v5" ) -var minimal = elastic.NewFetchSourceContext(true).Include("id", "state") -var basic = elastic.NewFetchSourceContext(true). - Exclude("logs.logs.stderr", "logs.logs.stdout", "inputs.content", "logs.system_logs") +var ( + minimalInclude = []string{"id", "state", "owner"} + basicExclude = []string{"inputs.content", "logs.logs.stderr", "logs.logs.stdout", "logs.system_logs"} +) // Elastic provides an elasticsearch database server backend. type Elastic struct { scheduler.UnimplementedSchedulerServiceServer - client *elastic.Client + client *elasticsearch.TypedClient conf config.Elastic taskIndex string nodeIndex string @@ -24,15 +25,14 @@ type Elastic struct { // NewElastic returns a new Elastic instance. func NewElastic(conf config.Elastic) (*Elastic, error) { - client, err := elastic.NewClient( - elastic.SetURL(conf.URL), - elastic.SetSniff(false), - elastic.SetRetrier( - elastic.NewBackoffRetrier( - elastic.NewExponentialBackoff(time.Millisecond*50, time.Minute), - ), - ), - ) + client, err := elasticsearch.NewTypedClient(elasticsearch.Config{ + Addresses: []string{conf.URL}, // A list of Elasticsearch nodes to use. 
+ Username: conf.Username, // Username for HTTP Basic Authentication. + Password: conf.Password, // Password for HTTP Basic Authentication. + CloudID: conf.CloudID, // Endpoint for the Elastic Service (https://elastic.co/cloud). + APIKey: conf.APIKey, // Base64-encoded token for authorization; if set, overrides username/password and service token. + ServiceToken: conf.ServiceToken, // Service token for authorization; if set, overrides username/password. + }) if err != nil { return nil, err } @@ -47,56 +47,43 @@ func NewElastic(conf config.Elastic) (*Elastic, error) { // Close closes the database client. func (es *Elastic) Close() { - es.client.Stop() + // no-op } -func (es *Elastic) initIndex(ctx context.Context, name, body string) error { - exists, err := es.client. - IndexExists(name). - Do(ctx) - - if err != nil { - return err - } else if !exists { - if _, err := es.client.CreateIndex(name).Body(body).Do(ctx); err != nil { - return err +func (es *Elastic) initIndex(ctx context.Context, name string, properties *map[string]types.Property) error { + exists, err := es.client.Indices.Exists(name).Do(ctx) + if err == nil && !exists { + var mappings *types.TypeMapping = nil + if properties != nil { + mappings = &types.TypeMapping{ + Properties: *properties, + } } + _, err = es.client.Indices.Create(name).Mappings(mappings).Do(ctx) } - return nil + return err } // Init creates the Elasticsearch indices. 
func (es *Elastic) Init() error { ctx := context.Background() - taskMappings := `{ - "mappings": { - "task":{ - "properties":{ - "id": { - "type": "keyword" - }, - "state": { - "type": "keyword" - }, - "inputs": { - "type": "nested" - }, - "logs": { - "type": "nested", - "properties": { - "logs": { - "type": "nested" - } - } - } - } - } - } - }` - if err := es.initIndex(ctx, es.taskIndex, taskMappings); err != nil { + + taskProperties := &map[string]types.Property{ + "id": types.KeywordProperty{}, + "state": types.KeywordProperty{}, + "owner": types.KeywordProperty{}, + "inputs": types.NestedProperty{}, + "logs": types.NestedProperty{ + Properties: map[string]types.Property{ + "logs": types.NestedProperty{}, + }, + }, + } + + if err := es.initIndex(ctx, es.taskIndex, taskProperties); err != nil { return err } - if err := es.initIndex(ctx, es.nodeIndex, ""); err != nil { + if err := es.initIndex(ctx, es.nodeIndex, nil); err != nil { return err } return nil diff --git a/database/elastic/scheduler.go b/database/elastic/scheduler.go index 06d22c453..40262a58e 100644 --- a/database/elastic/scheduler.go +++ b/database/elastic/scheduler.go @@ -1,30 +1,46 @@ package elastic import ( - "bytes" "fmt" - "github.com/golang/protobuf/jsonpb" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/refresh" + "github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/result" + "github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/sortorder" "github.com/ohsu-comp-bio/funnel/compute/scheduler" "github.com/ohsu-comp-bio/funnel/tes" "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - elastic "gopkg.in/olivere/elastic.v5" +) + +var ( + readQueueQuery *types.Query = &types.Query{ + Bool: &types.BoolQuery{ + Filter: []types.Query{ + { + Term: map[string]types.TermQuery{ + "state": {Value: tes.State_QUEUED.String()}, + }, + }, + }, + }, + } + readQueueSort types.SortOptions = 
types.SortOptions{ + SortOptions: map[string]types.FieldSort{ + "id": {Order: &sortorder.Asc}, + }, + } ) // ReadQueue returns a slice of queued Tasks. Up to "n" tasks are returned. func (es *Elastic) ReadQueue(n int) []*tes.Task { - ctx := context.Background() - - q := elastic.NewTermQuery("state", tes.State_QUEUED.String()) - res, err := es.client.Search(). - Index(es.taskIndex). - Type("task"). + res, err := es.client.Search().Index(es.taskIndex). + Query(readQueueQuery). + SourceExcludes_(basicExclude...). Size(n). - Sort("id", true). - Query(q). - Do(ctx) + Sort(readQueueSort). + Do(context.Background()) if err != nil { fmt.Println(err) return nil @@ -32,28 +48,19 @@ func (es *Elastic) ReadQueue(n int) []*tes.Task { var tasks []*tes.Task for _, hit := range res.Hits.Hits { - t := &tes.Task{} - err := jsonpb.Unmarshal(bytes.NewReader(*hit.Source), t) - if err != nil { - continue + task := &tes.Task{} + if err := customJson.Unmarshal(hit.Source_, task); err == nil { + tasks = append(tasks, task) } - - t = t.GetBasicView() - tasks = append(tasks, t) } - return tasks } // GetNode gets a node func (es *Elastic) GetNode(ctx context.Context, req *scheduler.GetNodeRequest) (*scheduler.Node, error) { - res, err := es.client.Get(). - Index(es.nodeIndex). - Type("node"). - Id(req.Id). 
- Do(ctx) + res, err := es.client.Get(es.nodeIndex, req.Id).Do(ctx) - if elastic.IsNotFound(err) { + if !res.Found { return nil, status.Errorf(codes.NotFound, "%v: nodeID: %s", err.Error(), req.Id) } if err != nil { @@ -61,12 +68,12 @@ func (es *Elastic) GetNode(ctx context.Context, req *scheduler.GetNodeRequest) ( } node := &scheduler.Node{} - err = jsonpb.Unmarshal(bytes.NewReader(*res.Source), node) + err = customJson.Unmarshal(res.Source_, node) if err != nil { return nil, err } // Must happen after the unmarshal - node.Version = *res.Version + node.Version = *res.Version_ return node, nil } @@ -75,62 +82,54 @@ func (es *Elastic) GetNode(ctx context.Context, req *scheduler.GetNodeRequest) ( // For optimisic locking, if the node already exists and node.Version // doesn't match the version in the database, an error is returned. func (es *Elastic) PutNode(ctx context.Context, node *scheduler.Node) (*scheduler.PutNodeResponse, error) { - g := es.client.Get(). - Index(es.nodeIndex). - Type("node"). - Preference("_primary"). - Id(node.Id) + g := es.client.Get(es.nodeIndex, node.Id).Preference("_primary") - // If the version is 0, then this should be creating a new node. + // If the version is 0, then this should be creating a new node. if node.GetVersion() != 0 { - g = g.Version(node.GetVersion()) + v := node.GetVersion() + g.Version(int64ToStr(&v)) } res, err := g.Do(ctx) - - if err != nil && !elastic.IsNotFound(err) { + if err != nil { return nil, err } existing := &scheduler.Node{} - if err == nil { - jsonpb.Unmarshal(bytes.NewReader(*res.Source), existing) - } - - err = scheduler.UpdateNode(ctx, es, node, existing) + err = customJson.Unmarshal(res.Source_, existing) if err != nil { return nil, err } - mar := jsonpb.Marshaler{} - s, err := mar.MarshalToString(node) + err = scheduler.UpdateNode(ctx, es, node, existing) if err != nil { return nil, err } - i := es.client.Index(). - Index(es.nodeIndex). - Type("node"). + i := es.client.Index(es.nodeIndex). 
Id(node.Id). - Refresh("true"). - BodyString(s) + Refresh(refresh.True). + Document(node) if node.GetVersion() != 0 { - i = i.Version(node.GetVersion()) + v := node.GetVersion() + i = i.Version(int64ToStr(&v)) + } + resp, err := i.Do(ctx) + if resp.Result != result.Created && resp.Result != result.Updated { + return nil, fmt.Errorf( + "Node [%s] was not recorded in ElasticSearch; response was: %s", + node.Id, resp.Result) } - _, err = i.Do(ctx) return &scheduler.PutNodeResponse{}, err } // DeleteNode deletes a node by ID. func (es *Elastic) DeleteNode(ctx context.Context, node *scheduler.Node) (*scheduler.DeleteNodeResponse, error) { - _, err := es.client.Delete(). - Index(es.nodeIndex). - Type("node"). - Id(node.Id). - Version(node.Version). - Refresh("true"). + _, err := es.client.Delete(es.nodeIndex, node.Id). + Version(int64ToStr(&node.Version)). + Refresh(refresh.True). Do(ctx) return &scheduler.DeleteNodeResponse{}, err } @@ -139,7 +138,6 @@ func (es *Elastic) DeleteNode(ctx context.Context, node *scheduler.Node) (*sched func (es *Elastic) ListNodes(ctx context.Context, req *scheduler.ListNodesRequest) (*scheduler.ListNodesResponse, error) { res, err := es.client.Search(). Index(es.nodeIndex). - Type("node"). Version(true). Size(1000). 
Do(ctx) @@ -151,11 +149,11 @@ func (es *Elastic) ListNodes(ctx context.Context, req *scheduler.ListNodesReques resp := &scheduler.ListNodesResponse{} for _, hit := range res.Hits.Hits { node := &scheduler.Node{} - err = jsonpb.Unmarshal(bytes.NewReader(*hit.Source), node) + err = customJson.Unmarshal(hit.Source_, node) if err != nil { return nil, err } - node.Version = *hit.Version + node.Version = *hit.Version_ resp.Nodes = append(resp.Nodes, node) } diff --git a/database/elastic/tes.go b/database/elastic/tes.go index 00ca0ed16..dbac132c8 100644 --- a/database/elastic/tes.go +++ b/database/elastic/tes.go @@ -1,98 +1,148 @@ package elastic import ( - "bytes" + "encoding/json" "fmt" + "strconv" - "github.com/golang/protobuf/jsonpb" + "google.golang.org/protobuf/encoding/protojson" + + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + "github.com/elastic/go-elasticsearch/v8/typedapi/types/enums/sortorder" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" "golang.org/x/net/context" - elastic "gopkg.in/olivere/elastic.v5" ) -func (es *Elastic) getTask(ctx context.Context, req *tes.GetTaskRequest) (*elastic.GetResult, error) { - g := es.client.Get(). - Index(es.taskIndex). - Type("task"). - Id(req.Id) +// Custom unmarshaller where unknown JSON properties do not cause an error. +var customJson = protojson.UnmarshalOptions{ + DiscardUnknown: true, +} + +func int64ToStr(i *int64) string { + return strconv.FormatInt(*i, 10) +} + +type TaskOwner struct { + Owner string `json:"owner"` +} + +func (es *Elastic) getTask(ctx context.Context, req *tes.GetTaskRequest) (*tes.Task, string, string, error) { + g := es.client.Get(es.taskIndex, req.Id) switch req.View { - case tes.View_BASIC.String(): - g = g.FetchSource(true).FetchSourceContext(basic) case tes.View_MINIMAL.String(): - g = g.FetchSource(true).FetchSourceContext(minimal) + g = g.SourceIncludes_(minimalInclude...) 
+ case tes.View_BASIC.String(): + g = g.SourceExcludes_(basicExclude...) } res, err := g.Do(ctx) - if elastic.IsNotFound(err) { - return nil, tes.ErrNotFound + + if err != nil { + return nil, "", "", err + } + + if !res.Found { + return nil, "", "", tes.ErrNotFound + } + + if userInfo := server.GetUser(ctx); !userInfo.CanSeeAllTasks() { + partial := TaskOwner{} + _ = json.Unmarshal(res.Source_, &partial) + if !userInfo.IsAccessible(partial.Owner) { + return nil, "", "", tes.ErrNotPermitted + } } - return res, err + + seqNo := int64ToStr(res.SeqNo_) + primaryTerm := int64ToStr(res.PrimaryTerm_) + + task := tes.Task{} + err = customJson.Unmarshal(res.Source_, &task) + return &task, seqNo, primaryTerm, err } // GetTask gets a task by ID. func (es *Elastic) GetTask(ctx context.Context, req *tes.GetTaskRequest) (*tes.Task, error) { - res, err := es.getTask(ctx, req) - if err != nil { - return nil, err - } - task := &tes.Task{} - err = jsonpb.Unmarshal(bytes.NewReader(*res.Source), task) + task, _, _, err := es.getTask(ctx, req) return task, err } // ListTasks lists tasks, duh. func (es *Elastic) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (*tes.ListTasksResponse, error) { - pageSize := tes.GetPageSize(req.GetPageSize()) - q := es.client.Search(). - Index(es.taskIndex). 
- Type("task") + filters := map[string]string{} - if req.PageToken != "" { - q = q.SearchAfter(req.PageToken) + if userInfo := server.GetUser(ctx); !userInfo.CanSeeAllTasks() { + filters["owner"] = userInfo.Username } - filterParts := []elastic.Query{} if req.State != tes.Unknown { - filterParts = append(filterParts, elastic.NewTermQuery("state", req.State.String())) + filters["state"] = req.State.String() } for k, v := range req.GetTags() { - filterParts = append(filterParts, elastic.NewMatchQuery(fmt.Sprintf("tags.%s.keyword", k), v)) + field := fmt.Sprintf("tags.%s.keyword", k) + filters[field] = v } - if len(filterParts) > 0 { - q = q.Query(elastic.NewBoolQuery().Filter(filterParts...)) + sort := types.SortOptions{ + SortOptions: map[string]types.FieldSort{ + "id": {Order: &sortorder.Desc}, + }, } - q = q.Sort("id", false).Size(pageSize) + query := types.Query{ + Bool: &types.BoolQuery{ + Filter: []types.Query{}, + }, + } + + for key, value := range filters { + query.Bool.Filter = append(query.Bool.Filter, types.Query{ + Term: map[string]types.TermQuery{ + key: {Value: value}, + }, + }) + } + + search := es.client.Search(). + Index(es.taskIndex). + Query(&query). + Size(pageSize). + Sort(sort). + ErrorTrace(true) + + if req.PageToken != "" { + search.SearchAfter(req.PageToken) + } switch req.View { case tes.View_BASIC.String(): - q = q.FetchSource(true).FetchSourceContext(basic) + search.SourceExcludes_(basicExclude...) case tes.View_MINIMAL.String(): - q = q.FetchSource(true).FetchSourceContext(minimal) + search.SourceIncludes_(minimalInclude...) 
} - res, err := q.Do(ctx) + res, err := search.Do(ctx) if err != nil { return nil, err } resp := &tes.ListTasksResponse{} for i, hit := range res.Hits.Hits { - t := &tes.Task{} - err := jsonpb.Unmarshal(bytes.NewReader(*hit.Source), t) + task := &tes.Task{} + err := customJson.Unmarshal(hit.Source_, task) if err != nil { return nil, err } if i == pageSize-1 { - resp.NextPageToken = t.Id + resp.NextPageToken = task.Id } - resp.Tasks = append(resp.Tasks, t) + resp.Tasks = append(resp.Tasks, task) } return resp, nil diff --git a/database/mongodb/counts.go b/database/mongodb/counts.go index ed95c11ec..dd749ba81 100644 --- a/database/mongodb/counts.go +++ b/database/mongodb/counts.go @@ -4,8 +4,8 @@ import ( "context" "github.com/ohsu-comp-bio/funnel/tes" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" ) type stateCount struct { @@ -26,13 +26,19 @@ func (db *MongoDB) TaskStateCounts(ctx context.Context) (map[string]int32, error }, }} - cursor, err := db.tasks(db.client).Aggregate(context.TODO(), mongo.Pipeline{stateStage, groupStage}) + mctx, cancel := db.wrap(ctx) + defer cancel() + + cursor, err := db.tasks().Aggregate(mctx, mongo.Pipeline{stateStage, groupStage}) if err != nil { return nil, err } + mctx, cancel = db.wrap(ctx) + defer cancel() + recs := []stateCount{} - err = cursor.All(context.TODO(), &recs) + err = cursor.All(mctx, &recs) if err != nil { return nil, err } diff --git a/database/mongodb/events.go b/database/mongodb/events.go index f06932a42..a356e8237 100644 --- a/database/mongodb/events.go +++ b/database/mongodb/events.go @@ -6,15 +6,17 @@ import ( "time" "github.com/ohsu-comp-bio/funnel/events" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" "github.com/ohsu-comp-bio/funnel/util" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/v2/bson" + 
"go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" ) // WriteEvent creates an event for the server to handle. func (db *MongoDB) WriteEvent(ctx context.Context, req *events.Event) error { - tasks := db.tasks(db.client) + tasks := db.tasks() update := bson.M{} selector := bson.M{"id": req.Id} @@ -29,36 +31,36 @@ func (db *MongoDB) WriteEvent(ctx context.Context, req *events.Event) error { SystemLogs: []string{}, }, } - _, err := tasks.InsertOne(context.TODO(), &task) - return err + return db.insertTask(ctx, task) case events.Type_TASK_STATE: retrier := util.NewRetrier() retrier.ShouldRetry = func(err error) bool { - return err == tes.ErrConcurrentStateChange + _, isTransitionError := err.(*tes.TransitionError) + return !isTransitionError && err != tes.ErrNotFound && err != tes.ErrNotPermitted } return retrier.Retry(ctx, func() error { // get current state & version - current := make(map[string]interface{}) - opts := options.FindOne().SetProjection(bson.M{"state": 1, "version": 1}) - err := tasks.FindOne(context.TODO(), bson.M{"id": req.Id}, opts).Decode(¤t) + state, version, err := db.findTaskStateAndVersion(ctx, req.Id) if err != nil { - return tes.ErrNotFound + return err } // validate state transition - from := tes.State(current["state"].(int32)) to := req.GetState() - if err = tes.ValidateTransition(from, to); err != nil { + if err = tes.ValidateTransition(state, to); err != nil { return err } // apply version restriction and set update - selector["version"] = current["version"] + selector["version"] = version update = bson.M{"$set": bson.M{"state": to, "version": time.Now().UnixNano()}} - result, err := tasks.UpdateOne(context.TODO(), selector, update) + mctx, cancel := db.wrap(ctx) + defer cancel() + + result, err := tasks.UpdateOne(mctx, selector, update) if result.MatchedCount == 0 { return tes.ErrConcurrentStateChange } @@ -137,7 +139,52 @@ func (db *MongoDB) WriteEvent(ctx context.Context, req *events.Event) error { } } 
- opts := options.Update().SetUpsert(true) - _, err := tasks.UpdateOne(context.TODO(), selector, update, opts) + mctx, cancel := db.wrap(ctx) + defer cancel() + + opts := options.UpdateOne().SetUpsert(true) + _, err := tasks.UpdateOne(mctx, selector, update, opts) return err } + +func (db *MongoDB) insertTask(ctx context.Context, task *tes.Task) error { + mctx, cancel := db.wrap(ctx) + defer cancel() + + tasks := db.tasks() + result, err := tasks.InsertOne(mctx, &task) + + if err == nil { + mctx, cancel := db.wrap(ctx) + defer cancel() + + updateOwner := bson.M{"$set": bson.M{"owner": server.GetUsername(ctx)}} + _, err = tasks.UpdateOne(mctx, bson.M{"_id": result.InsertedID}, updateOwner) + } + + return err +} + +func (db *MongoDB) findTaskStateAndVersion(ctx context.Context, taskId string) (tes.State, interface{}, error) { + mctx, cancel := db.wrap(ctx) + defer cancel() + + props := make(map[string]interface{}) + opts := options.FindOne().SetProjection(bson.M{"state": 1, "version": 1, "owner": 1}) + err := db.tasks().FindOne(mctx, bson.M{"id": taskId}, opts).Decode(&props) + + if err == mongo.ErrNoDocuments { + return tes.State_UNKNOWN, nil, tes.ErrNotFound + } else if err != nil { + return tes.State_UNKNOWN, nil, err + } + + taskOwner := props["owner"].(string) + if !server.GetUser(ctx).IsAccessible(taskOwner) { + return tes.State_UNKNOWN, nil, tes.ErrNotPermitted + } + + state := tes.State(props["state"].(int32)) + version := props["version"] + return state, version, nil +} diff --git a/database/mongodb/new.go b/database/mongodb/new.go index 37dd044e3..52db1e93a 100644 --- a/database/mongodb/new.go +++ b/database/mongodb/new.go @@ -3,100 +3,132 @@ package mongodb import ( "context" "fmt" + "time" "github.com/ohsu-comp-bio/funnel/compute/scheduler" "github.com/ohsu-comp-bio/funnel/config" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/v2/bson" + 
"go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" ) // MongoDB provides an MongoDB database server backend. type MongoDB struct { scheduler.UnimplementedSchedulerServiceServer - client *mongo.Client - conf config.MongoDB - active bool + client *mongo.Client + database *mongo.Database + conf config.MongoDB + active bool } func NewMongoDB(conf config.MongoDB) (*MongoDB, error) { - client, err := mongo.Connect( - context.TODO(), - options.Client().ApplyURI(conf.Addrs[0])) + opts := options.Client(). + SetHosts(conf.Addrs). + SetAppName("funnel") + + if len(conf.Username) > 0 && len(conf.Password) > 0 { + opts = opts.SetAuth(options.Credential{ + Username: conf.Username, + Password: conf.Password, + }) + } + client, err := mongo.Connect(opts) if err != nil { return nil, err } db := &MongoDB{ - client: client, - conf: conf, - active: true, + client: client, + database: client.Database(conf.Database), + conf: conf, + active: true, } return db, nil } -func (db *MongoDB) tasks(client *mongo.Client) *mongo.Collection { - return client.Database(db.conf.Database).Collection("tasks") +func (db *MongoDB) context() (context.Context, context.CancelFunc) { + return db.wrap(context.Background()) } -func (db *MongoDB) nodes(client *mongo.Client) *mongo.Collection { - return client.Database(db.conf.Database).Collection("nodes") +func (db *MongoDB) wrap(ctx context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(ctx, time.Duration(db.conf.Timeout)) } -// Init creates tables in MongoDB. 
-func (db *MongoDB) Init() error { - tasks := db.tasks(db.client) - nodes := db.nodes(db.client) +func (db *MongoDB) collection(name string) *mongo.Collection { + return db.database.Collection(name) +} + +func (db *MongoDB) nodes() *mongo.Collection { + return db.collection("nodes") +} + +func (db *MongoDB) tasks() *mongo.Collection { + return db.collection("tasks") +} - names, err := db.client.Database(db.conf.Database).ListCollectionNames(context.TODO(), bson.D{}, nil) +func (db *MongoDB) createCollection(name string, indexKeys *bson.D) error { + ctx, cancel := db.context() + defer cancel() + + err := db.database.CreateCollection(ctx, name, nil) if err != nil { - return err + return fmt.Errorf( + "error creating collection [%s] in database [%s]: %v", + name, db.conf.Database, err) + } + + indexModel := mongo.IndexModel{ + Keys: indexKeys, + Options: options.Index().SetUnique(true).SetSparse(true), } - var tasksFound bool - var nodesFound bool + + ctx, cancel = db.context() + defer cancel() + + _, err = db.collection(name).Indexes().CreateOne(ctx, indexModel) + return err +} + +func (db *MongoDB) findCollections(names ...string) (map[string]bool, error) { + ctx, cancel := db.context() + defer cancel() + + filter := bson.M{"name": bson.M{"$in": names}} + names, err := db.database.ListCollectionNames(ctx, filter, nil) + if err != nil { + return nil, err + } + + result := make(map[string]bool) for _, name := range names { - switch name { - case "tasks": - tasksFound = true - case "nodes": - nodesFound = true - } + result[name] = true } + return result, nil +} - if !tasksFound { - err = db.client.Database(db.conf.Database).CreateCollection(context.Background(), "tasks", nil) - if err != nil { - return fmt.Errorf("error creating tasks collection in database %s: %v", db.conf.Database, err) - } +// Init creates tables in MongoDB. 
+func (db *MongoDB) Init() error { + found, err := db.findCollections("tasks", "nodes") + if err != nil { + return err + } - indexModel := mongo.IndexModel{ - Keys: bson.D{ - {Key: "-id", Value: -1}, - {Key: "-creationtime", Value: -1}, - }, - Options: options.Index().SetUnique(true).SetSparse(true), + if !found["tasks"] { + indexKeys := &bson.D{ + {Key: "-id", Value: -1}, + {Key: "-creationtime", Value: -1}, } - _, err = tasks.Indexes().CreateOne(context.TODO(), indexModel) - if err != nil { + if err := db.createCollection("tasks", indexKeys); err != nil { return err } } - if !nodesFound { - err = db.client.Database(db.conf.Database).CreateCollection(context.Background(), "nodes", nil) - if err != nil { - return fmt.Errorf("error creating nodes collection in database %s: %v", db.conf.Database, err) - } - - indexModel := mongo.IndexModel{ - Keys: bson.D{ - {Key: "-id", Value: -1}, - }, - Options: options.Index().SetUnique(true).SetSparse(true), + if !found["nodes"] { + indexKeys := &bson.D{ + {Key: "-id", Value: -1}, } - _, err = nodes.Indexes().CreateOne(context.TODO(), indexModel) - if err != nil { + if err := db.createCollection("nodes", indexKeys); err != nil { return err } } @@ -107,7 +139,9 @@ func (db *MongoDB) Init() error { // Close closes the database session. 
func (db *MongoDB) Close() { if db.active { - db.client.Disconnect(context.TODO()) + ctx, cancel := db.context() + defer cancel() + db.client.Disconnect(ctx) } db.active = false } diff --git a/database/mongodb/scheduler.go b/database/mongodb/scheduler.go index 9b69a1223..11ef93ff2 100644 --- a/database/mongodb/scheduler.go +++ b/database/mongodb/scheduler.go @@ -5,9 +5,9 @@ import ( "github.com/ohsu-comp-bio/funnel/compute/scheduler" "github.com/ohsu-comp-bio/funnel/tes" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -15,16 +15,22 @@ import ( // ReadQueue returns a slice of queued Tasks. Up to "n" tasks are returned. func (db *MongoDB) ReadQueue(n int) []*tes.Task { + ctx, cancel := db.context() + defer cancel() + fmt.Println("Reading queue!") - var tasks []*tes.Task opts := options.Find().SetSort(bson.M{"creationtime": 1}).SetLimit(int64(n)) - cursor, err := db.tasks(db.client).Find(context.TODO(), bson.M{"state": tes.State_QUEUED}, opts) + cursor, err := db.tasks().Find(ctx, bson.M{"state": tes.State_QUEUED}, opts) if err != nil { fmt.Println(err) return nil } - err = cursor.All(context.TODO(), &tasks) + ctx, cancel = db.context() + defer cancel() + + var tasks []*tes.Task + err = cursor.All(ctx, &tasks) if err != nil { fmt.Println(err) return nil @@ -37,7 +43,7 @@ func (db *MongoDB) ReadQueue(n int) []*tes.Task { // and status updates, such as completed tasks. The server responds with updated // information for the node, such as canceled tasks. 
func (db *MongoDB) PutNode(ctx context.Context, node *scheduler.Node) (*scheduler.PutNodeResponse, error) { - nodes := db.nodes(db.client) + nodes := db.nodes() q := bson.M{"id": node.Id} @@ -45,30 +51,42 @@ func (db *MongoDB) PutNode(ctx context.Context, node *scheduler.Node) (*schedule q["version"] = node.GetVersion() } + mctx, cancel := db.wrap(ctx) + defer cancel() + var existing scheduler.Node - err := nodes.FindOne(context.TODO(), bson.M{"id": node.Id}).Decode(&existing) + err := nodes.FindOne(mctx, bson.M{"id": node.Id}).Decode(&existing) if err != nil { return nil, err } + mctx, cancel = db.wrap(ctx) + defer cancel() + db.GetTask(ctx, &tes.GetTaskRequest{Id: "foo"}) - err = scheduler.UpdateNode(ctx, db, node, &existing) + err = scheduler.UpdateNode(mctx, db, node, &existing) if err != nil { return nil, err } node.Version = node.GetVersion() + 1 - opts := options.Update().SetUpsert(true) - _, err = nodes.UpdateOne(context.TODO(), q, node, opts) + mctx, cancel = db.wrap(ctx) + defer cancel() + + opts := options.UpdateOne().SetUpsert(true) + _, err = nodes.UpdateOne(mctx, q, node, opts) return &scheduler.PutNodeResponse{}, err } // GetNode gets a node func (db *MongoDB) GetNode(ctx context.Context, req *scheduler.GetNodeRequest) (*scheduler.Node, error) { + mctx, cancel := db.wrap(ctx) + defer cancel() + var node scheduler.Node - err := db.nodes(db.client).FindOne(context.TODO(), bson.M{"id": req.Id}).Decode(&node) + err := db.nodes().FindOne(mctx, bson.M{"id": req.Id}).Decode(&node) if err == mongo.ErrNoDocuments { return nil, status.Errorf(codes.NotFound, "%v: nodeID: %s", err, req.Id) } @@ -78,8 +96,11 @@ func (db *MongoDB) GetNode(ctx context.Context, req *scheduler.GetNodeRequest) ( // DeleteNode deletes a node func (db *MongoDB) DeleteNode(ctx context.Context, req *scheduler.Node) (*scheduler.DeleteNodeResponse, error) { + mctx, cancel := db.wrap(ctx) + defer cancel() + fmt.Println("DeleteNode", req.Id) - _, err := 
db.nodes(db.client).DeleteOne(context.TODO(), bson.M{"id": req.Id}) + _, err := db.nodes().DeleteOne(mctx, bson.M{"id": req.Id}) fmt.Println("DeleteNode", req.Id, err) if err == mongo.ErrNoDocuments { return nil, status.Errorf(codes.NotFound, "%v: nodeID: %s", err, req.Id) @@ -89,13 +110,16 @@ func (db *MongoDB) DeleteNode(ctx context.Context, req *scheduler.Node) (*schedu // ListNodes is an API endpoint that returns a list of nodes. func (db *MongoDB) ListNodes(ctx context.Context, req *scheduler.ListNodesRequest) (*scheduler.ListNodesResponse, error) { + mctx, cancel := db.wrap(ctx) + defer cancel() + var nodes []*scheduler.Node - cursor, err := db.nodes(db.client).Find(context.TODO(), nil) + cursor, err := db.nodes().Find(mctx, nil) if err != nil { return nil, err } - err = cursor.All(context.TODO(), &nodes) + err = cursor.All(mctx, &nodes) if err != nil { return nil, err } diff --git a/database/mongodb/tes.go b/database/mongodb/tes.go index 5a0ba322d..2d6f68138 100644 --- a/database/mongodb/tes.go +++ b/database/mongodb/tes.go @@ -3,9 +3,10 @@ package mongodb import ( "fmt" + "github.com/ohsu-comp-bio/funnel/server" "github.com/ohsu-comp-bio/funnel/tes" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo/options" "golang.org/x/net/context" ) @@ -15,11 +16,14 @@ var basicView = bson.M{ "logs.logs.stderr": 0, "inputs.content": 0, } -var minimalView = bson.M{"id": 1, "state": 1} +var minimalView = bson.M{"id": 1, "owner": 1, "state": 1} + +type TaskOwner struct { + Owner string `bson:"owner"` +} // GetTask gets a task, which describes a running task func (db *MongoDB) GetTask(ctx context.Context, req *tes.GetTaskRequest) (*tes.Task, error) { - var task tes.Task var opts = options.FindOne() switch req.View { @@ -29,11 +33,24 @@ func (db *MongoDB) GetTask(ctx context.Context, req *tes.GetTaskRequest) (*tes.T opts = opts.SetProjection(minimalView) } - err := 
db.tasks(db.client).FindOne(context.TODO(), bson.M{"id": req.Id}, opts).Decode(&task) - if err != nil { + mctx, cancel := db.wrap(ctx) + defer cancel() + + result := db.tasks().FindOne(mctx, bson.M{"id": req.Id}, opts) + if result.Err() != nil { return nil, tes.ErrNotFound } + var owner TaskOwner + _ = result.Decode(&owner) + if !server.GetUser(ctx).IsAccessible(owner.Owner) { + return nil, tes.ErrNotPermitted + } + + var task tes.Task + if err := result.Decode(&task); err != nil { + return nil, tes.ErrNotFound + } return &task, nil } @@ -55,6 +72,10 @@ func (db *MongoDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (*t query["name"] = bson.M{"$regex": fmt.Sprintf("^%s", req.NamePrefix)} } + if userInfo := server.GetUser(ctx); !userInfo.CanSeeAllTasks() { + query["owner"] = bson.M{"$eq": userInfo.Username} + } + for k, v := range req.GetTags() { if v == "" { query[fmt.Sprintf("tags.%s", k)] = bson.M{"$exists": true} @@ -72,13 +93,19 @@ func (db *MongoDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (*t opts = opts.SetProjection(minimalView) } - cursor, err := db.tasks(db.client).Find(context.TODO(), query, opts) + mctx, cancel := db.wrap(ctx) + defer cancel() + + cursor, err := db.tasks().Find(mctx, query, opts) if err != nil { return nil, err } + mctx, cancel = db.wrap(ctx) + defer cancel() + var tasks []*tes.Task - err = cursor.All(context.TODO(), &tasks) + err = cursor.All(mctx, &tasks) if err != nil { return nil, err } @@ -86,8 +113,7 @@ func (db *MongoDB) ListTasks(ctx context.Context, req *tes.ListTasksRequest) (*t out := tes.ListTasksResponse{ Tasks: tasks, } - // TODO figure out when not to return a next page token - if len(tasks) > 0 { + if len(tasks) == pageSize { out.NextPageToken = tasks[len(tasks)-1].Id } diff --git a/deployments/kubernetes/helm/values.yaml b/deployments/kubernetes/helm/values.yaml index 1860b2c60..e6b005845 100644 --- a/deployments/kubernetes/helm/values.yaml +++ b/deployments/kubernetes/helm/values.yaml @@ 
-203,7 +203,7 @@ Datastore: MongoDB: # Addrs holds the addresses for the seed servers. Addrs: - - mongodb://localhost + - localhost # Database is the database name used within MongoDB to store funnel data. Database: funnel # Timeout is the amount of time to wait for a server to respond when diff --git a/docs/funnel-config-examples/default-config.yaml b/docs/funnel-config-examples/default-config.yaml index 21b7b409a..675dc45db 100644 --- a/docs/funnel-config-examples/default-config.yaml +++ b/docs/funnel-config-examples/default-config.yaml @@ -32,15 +32,46 @@ Server: # If used, make sure to properly restrict access to the config file # (e.g. chmod 600 funnel.config.yml) # BasicAuth: + # - User: admin + # Password: oejf023moq + # Admin: true # - User: user1 # Password: abc123 # - User: user2 # Password: foobar + # Require Bearer JWT authentication for the server APIs. + # Server won't launch when configuration URL cannot be loaded. + # OidcAuth: + # # URL of the OIDC service configuration (activates OIDC configuration): + # # Example: https://example.org/oidc/.well-known/openid-configuration + # ServiceConfigURL: + # # Client ID and secret are sent with the token introspection request + # ClientId: + # ClientSecret: + # # The URL where OIDC should redirect after login (keep the path '/login') + # RedirectURL: "http://localhost:8000/login" + # # Optional: if specified, this scope value must be in the token: + # RequireScope: + # # Optional: if specified, this audience value must be in the token: + # RequireAudience: + # # List of usernames (JWT sub-claim) to be granted Admin-role: + # Admins: + # - admin.username.one@example.org + # - admin.username.two@example.org + # Include a "Cache-Control: no-store" HTTP header in Get/List responses # to prevent caching by intermediary services. 
DisableHTTPCache: true + # Defines task access and visibility by options: + # "All" (default) - all tasks are visible to everyone + # "Owner" - tasks are visible to the users who created them + # "OwnerOrAdmin" - extends "Owner" by allowing Admin-users to see everything + # Owner is the username associated with the task. + # Owners (usernames) were not recorded on tasks before Funnel 0.11.1. + TaskAccess: All + RPCClient: # RPC server address ServerAddress: localhost:9090 @@ -141,6 +172,16 @@ Elastic: IndexPrefix: funnel # URL of the elasticsearch server. URL: http://localhost:9200 + # Optional. Username for HTTP Basic Authentication. + Username: + # Optional. Password for HTTP Basic Authentication. + Password: + # Optional. Endpoint for the Elastic Service (https://elastic.co/cloud). + CloudID: + # Optional. Base64-encoded token for authorization; if set, overrides username/password and service token. + APIKey: + # Optional. Service token for authorization; if set, overrides username/password. + ServiceToken: # Google Cloud Datastore task database. Datastore: @@ -153,7 +194,7 @@ Datastore: MongoDB: # Addrs holds the addresses for the seed servers. Addrs: - - mongodb://localhost + - localhost # Database is the database name used within MongoDB to store funnel data. Database: funnel # Timeout is the amount of time to wait for a server to respond when @@ -299,7 +340,7 @@ AWSBatch: # Kubernetes describes the configuration for the Kubernetes compute backend. Kubernetes: # The executor used to execute tasks. Available executors: docker, kubernetes - Executor: "kubernetes" + Executor: "kubernetes" # Turn off task state reconciler. When enabled, Funnel communicates with Kubernetes # to find tasks that are stuck in a queued state or errored and # updates the task state accordingly. 
diff --git a/go.mod b/go.mod index 7ac145690..399d7a702 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module github.com/ohsu-comp-bio/funnel go 1.23.0 -toolchain go1.23.2 - require ( cloud.google.com/go/datastore v1.20.0 cloud.google.com/go/pubsub v1.47.0 @@ -51,7 +49,6 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.6 github.com/stretchr/testify v1.10.0 - go.mongodb.org/mongo-driver v1.17.2 golang.org/x/crypto v0.32.0 golang.org/x/net v0.34.0 golang.org/x/oauth2 v0.26.0 @@ -66,6 +63,7 @@ require ( ) require ( + github.com/elastic/go-elasticsearch/v8 v8.17.0 github.com/gogo/protobuf v1.3.2 github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.0 github.com/hashicorp/go-multierror v1.1.1 @@ -95,6 +93,7 @@ require ( github.com/eapache/go-resiliency v1.7.0 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 // indirect github.com/eapache/queue v1.1.0 // indirect + github.com/elastic/elastic-transport-go/v8 v8.6.0 // indirect github.com/emicklei/go-restful/v3 v3.12.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect @@ -139,7 +138,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect - github.com/montanaflynn/stats v0.7.1 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/nsf/termbox-go v1.1.1 // indirect github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 // indirect diff --git a/go.sum b/go.sum index bc4829bd2..a550de185 100644 --- a/go.sum +++ b/go.sum @@ -37,7 +37,6 @@ github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2 h1:7Ip0wMmLHLRJdrloDxZfhMm0xrLXZS8+COSu2bXmEQs= github.com/armon/circbuf v0.0.0-20190214190532-5111143e8da2/go.mod 
h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= -github.com/aws/aws-sdk-go v1.29.11/go.mod h1:1KvfttTE3SPKMpo8g2c6jL3ZKfXtFvKscTgahTma5Xg= github.com/aws/aws-sdk-go v1.55.6 h1:cSg4pvZ3m8dgYcgqB97MrcdjUmZ1BeMYKUxMMB89IPk= github.com/aws/aws-sdk-go v1.55.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= @@ -98,6 +97,10 @@ github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 h1:Oy0F4A github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3/go.mod h1:YvSRo5mw33fLEx1+DlK6L2VV43tJt5Eyel9n9XBcR+0= github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/elastic/elastic-transport-go/v8 v8.6.0 h1:Y2S/FBjx1LlCv5m6pWAF2kDJAHoSjSRSJCApolgfthA= +github.com/elastic/elastic-transport-go/v8 v8.6.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= +github.com/elastic/go-elasticsearch/v8 v8.17.0 h1:e9cWksE/Fr7urDRmGPGp47Nsp4/mvNOrU8As1l2HQQ0= +github.com/elastic/go-elasticsearch/v8 v8.17.0/go.mod h1:lGMlgKIbYoRvay3xWBeKahAiJOgmFDsjZC39nmO3H64= github.com/elazarl/go-bindata-assetfs v1.0.1 h1:m0kkaHRKEu7tUIUFVwhGGGYClXvyl4RE03qmvRTNfbw= github.com/elazarl/go-bindata-assetfs v1.0.1/go.mod h1:v+YaWX3bdea5J/mo8dSETolEo7R71Vk1u8bnjau5yw4= github.com/emicklei/go-restful/v3 v3.12.1 h1:PJMDIM/ak7btuL8Ex0iYET9hxM3CI2sjZtzpL63nKAU= @@ -145,7 +148,6 @@ github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= 
-github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.6.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= @@ -159,7 +161,6 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f h1:16RtHeWGkJMc80Etb8RPCcKevXGldr57+LOyZt8zOlg= github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f/go.mod h1:ijRvpgDJDI262hYq/IQVYgf8hd8IHUs93Ol0kvMBAx4= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -248,7 +249,6 @@ github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZ github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/jlaffaye/ftp v0.2.0 h1:lXNvW7cBu7R/68bknOX3MrRIIqZ61zELs1P2RAiA3lg= github.com/jlaffaye/ftp v0.2.0/go.mod h1:is2Ds5qkhceAPy2xD6RLI6hmp/qysSoymZ+Z2uTnspI= -github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 
h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -291,7 +291,6 @@ github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczG github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= github.com/magiconair/properties v1.7.4-0.20170902060319-8d7837e64d3c/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/mailru/easyjson v0.7.1/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7ldAVICs= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/maruel/panicparse v1.6.2 h1:tZuGQTlbOY5jCprrWMJTikREqKPn+UAKdR4CHSpj834= @@ -324,8 +323,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= -github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= -github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -340,7 +337,6 @@ github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80 h1:nZspmSkneBbtxU9Top github.com/oasdiff/yaml v0.0.0-20241210131133-6b86fb107d80/go.mod h1:7tFDb+Y51LcDpn26GccuUgQXUk6t0CXZsivKjyimYX8= github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349 
h1:t05Ww3DxZutOqbMN+7OIuqDwXbhl32HiZGpLy26BAPc= github.com/oasdiff/yaml3 v0.0.0-20241210130736-a94c01f36349/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= -github.com/olivere/elastic/v7 v7.0.12/go.mod h1:14rWX28Pnh3qCKYRVnSGXWLf9MbLonYS/4FDCY3LAPo= github.com/onsi/ginkgo/v2 v2.21.0 h1:7rg/4f3rB88pb5obDgNZrNHrQ4e6WpjonchcpuBRnZM= github.com/onsi/ginkgo/v2 v2.21.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo= github.com/onsi/gomega v1.35.1 h1:Cwbd75ZBPxFSuZ6T+rN/WCb/gOc6YgFBXLlZLhC7Ds4= @@ -394,9 +390,6 @@ github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMT github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/smartystreets/assertions v1.0.1/go.mod h1:kHHU4qYBaI3q23Pp3VPrmWhuIUrLW/7eUrw0BU5VaoM= -github.com/smartystreets/go-aws-auth v0.0.0-20180515143844-0c1422d1fdb9/go.mod h1:SnhjPscd9TpLiy1LpzGSKh3bXCfxxXuqd9xmQJy3slM= -github.com/smartystreets/gunit v1.1.3/go.mod h1:EH5qMBab2UclzXUcpR8b93eHsIlp9u+pDQIRp5DZNzQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= @@ -516,7 +509,6 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod 
h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= @@ -533,7 +525,6 @@ golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht golang.org/x/sync v0.0.0-20170517211232-f52d1811a629/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -545,7 +536,6 @@ golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys 
v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -605,7 +595,6 @@ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/genproto v0.0.0-20170918111702-1e559d0a00ee/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= @@ -617,7 +606,6 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287 h1: google.golang.org/genproto/googleapis/rpc v0.0.0-20250127172529-29210b9bc287/go.mod h1:8BS3B93F/U1juMFq9+EDk+qOT5CO1R9IzXxG3PTqiRk= google.golang.org/grpc v1.2.1-0.20170921194603-d4b75ebd4f9f/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= @@ -645,8 +633,6 @@ gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSP gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/inf.v0 v0.9.1 
h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= -gopkg.in/olivere/elastic.v5 v5.0.86 h1:xFy6qRCGAmo5Wjx96srho9BitLhZl2fcnpuidPwduXM= -gopkg.in/olivere/elastic.v5 v5.0.86/go.mod h1:M3WNlsF+WhYn7api4D87NIflwTV/c0iVs8cqfWhK+68= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/server/auth.go b/server/auth.go index edb1ebc19..f730432dc 100644 --- a/server/auth.go +++ b/server/auth.go @@ -2,7 +2,9 @@ package server import ( "encoding/base64" + "fmt" "net/http" + "os" "strings" "github.com/ohsu-comp-bio/funnel/config" @@ -14,15 +16,26 @@ import ( ) type Authentication struct { - basic map[string]string - oidc *OidcConfig + admins map[string]bool + basic map[string]string + oidc *OidcConfig } +const ( + AccessAll = "All" + AccessOwner = "Owner" + AccessOwnerOrAdmin = "OwnerOrAdmin" +) + // Extracted info about the current user, which is exposed through Context. type UserInfo struct { // Public users are non-authenticated, in case Funnel configuration does // not require OIDC nor Basic authentication. IsPublic bool + // Administrators are defined by the configuration file: + // 1) Basic-authentication: a user with the `Admin: true` property. + // 2) OIDC-authentication: one of the usernames under the `Admins` property. + IsAdmin bool // Username of an authenticated user (subject field from JWT). 
Username string // In case of OIDC authentication, the provided Bearer token, which can be @@ -39,22 +52,54 @@ var ( errTokenRequired = status.Errorf(codes.Unauthenticated, "Basic/Bearer authorization token missing") errInvalidBasicToken = status.Errorf(codes.Unauthenticated, "Basic-authentication failed") errInvalidBearerToken = status.Errorf(codes.Unauthenticated, "Bearer authorization token not accepted") - publicUserInfo = UserInfo{IsPublic: true, Username: ""} + publicUserInfo = UserInfo{IsPublic: true, IsAdmin: false, Username: ""} + systemUserInfo = UserInfo{IsPublic: false, IsAdmin: true, Username: "SYSTEM"} UserInfoKey = userInfoContextKey("user-info") + accessMode = AccessAll ) -func NewAuthentication(creds []config.BasicCredential, oidc config.OidcAuth) *Authentication { +func GetUser(ctx context.Context) *UserInfo { + if userInfo, ok := ctx.Value(UserInfoKey).(*UserInfo); ok { + return userInfo + } + return &systemUserInfo +} + +func GetUsername(ctx context.Context) string { + return GetUser(ctx).Username +} + +func NewAuthentication( + creds []config.BasicCredential, + oidc config.OidcAuth, + taskAccess string, +) *Authentication { basicCreds := make(map[string]string) + adminUsers := make(map[string]bool) for _, cred := range creds { credBytes := []byte(cred.User + ":" + cred.Password) fullValue := "Basic " + base64.StdEncoding.EncodeToString(credBytes) basicCreds[fullValue] = cred.User + if cred.Admin { + adminUsers[cred.User] = true + } + } + + if taskAccess == AccessAll || taskAccess == AccessOwner || taskAccess == AccessOwnerOrAdmin { + accessMode = taskAccess + } else if taskAccess == "" { + accessMode = AccessAll + } else { + fmt.Printf("[ERROR] Bad configuration value for Server.TaskAccess (%s). 
"+ + "Expected 'All', 'Owner', or 'OwnerOrAdmin'.\n", accessMode) + os.Exit(1) } return &Authentication{ - basic: basicCreds, - oidc: initOidcConfig(oidc), + admins: adminUsers, + basic: basicCreds, + oidc: initOidcConfig(oidc), } } @@ -91,16 +136,15 @@ func (a *Authentication) Interceptor( authorized = username != "" if authorized { - ctx = context.WithValue(ctx, UserInfoKey, &UserInfo{Username: username}) + isAdmin := a.admins[username] + ctx = context.WithValue(ctx, UserInfoKey, + &UserInfo{Username: username, IsAdmin: isAdmin}) } } else if a.oidc != nil && strings.HasPrefix(authorization, "Bearer ") { authErr = errInvalidBearerToken - jwtString := strings.TrimPrefix(authorization, "Bearer ") - subject := a.oidc.ParseJwtSubject(jwtString) - authorized = subject != "" - - if authorized { - ctx = context.WithValue(ctx, UserInfoKey, &UserInfo{Username: subject, Token: jwtString}) + if userInfo := a.oidc.Authorize(authorization); userInfo != nil { + ctx = context.WithValue(ctx, UserInfoKey, userInfo) + authorized = true } } @@ -155,3 +199,33 @@ func (a *Authentication) handleBasicAuth(w http.ResponseWriter, req *http.Reques http.Error(w, msg, http.StatusUnauthorized) } } + +// Reports whether the current user can access data with the specified owner. +// Evaluation depends on configuration (Server.TaskAccess), current username, +// and the username recorded in the task. For public users and unknown task +// owners, the username is an empty string. +func (u *UserInfo) IsAccessible(dataOwner string) bool { + if *u == systemUserInfo || accessMode == AccessAll { + return true + } + + isOwner := u != nil && u.Username == dataOwner + if accessMode == AccessOwner { + return isOwner + } + + if accessMode == AccessOwnerOrAdmin { + return isOwner || u != nil && u.IsAdmin + } + + return false +} + +// Reports whether the current user can access all tasks considering the +// configuration (Server.TaskAccess) and whether the user has Admin status. 
+// If the result is false, data access must be verified (see: IsAccessible). +func (u *UserInfo) CanSeeAllTasks() bool { + return *u == systemUserInfo || + accessMode == AccessAll || + accessMode == AccessOwnerOrAdmin && u != nil && u.IsAdmin +} diff --git a/server/auth_oidc.go b/server/auth_oidc.go index 7da907762..c2d9cf2eb 100644 --- a/server/auth_oidc.go +++ b/server/auth_oidc.go @@ -37,6 +37,7 @@ type IntrospectionResponse struct { // OIDC configuration structure used for validating input from request. type OidcConfig struct { local config.OidcAuth + admins map[string]bool remote OidcRemoteConfig oauth2 oauth2.Config jwks jwk.Cache @@ -87,6 +88,11 @@ func (c *OidcConfig) initConfig() { c.oauth2.Endpoint.AuthStyle = oauth2.AuthStyleInParams c.oauth2.Endpoint.AuthURL = c.remote.AuthorizationEndpoint c.oauth2.Endpoint.TokenURL = c.remote.TokenEndpoint + + c.admins = map[string]bool{} + for _, username := range c.local.Admins { + c.admins[username] = true + } } func (c *OidcConfig) initJwks() { @@ -203,6 +209,17 @@ func (c *OidcConfig) computeState(req *http.Request) string { return hex.EncodeToString(b)[:10] } +func (c *OidcConfig) Authorize(authorization string) *UserInfo { + jwtString := strings.TrimPrefix(authorization, "Bearer ") + subject := c.ParseJwtSubject(jwtString) + isAdmin := c.admins[subject] + + if subject == "" { + return nil + } + return &UserInfo{Username: subject, Token: jwtString, IsAdmin: isAdmin} +} + func (c *OidcConfig) ParseJwtSubject(jwtString string) string { keySet, err := c.jwks.Get(context.Background(), c.remote.JwksURI) if err != nil { diff --git a/server/marshal.go b/server/marshal.go index 0da374224..2de3938d2 100644 --- a/server/marshal.go +++ b/server/marshal.go @@ -77,6 +77,7 @@ func (mclean *CustomMarshal) MarshalList(list *tes.ListTasksResponse) ([]byte, e for _, task := range list.Tasks { minTask := mclean.TranslateTask(task, view).(*tes.TaskMin) minList.Tasks = append(minList.Tasks, minTask) + minList.NextPageToken = 
list.NextPageToken } return mclean.m.Marshal(minList) } @@ -86,6 +87,7 @@ func (mclean *CustomMarshal) MarshalList(list *tes.ListTasksResponse) ([]byte, e for _, task := range list.Tasks { basicTask := mclean.TranslateTask(task, view).(*tes.TaskBasic) basicList.Tasks = append(basicList.Tasks, basicTask) + basicList.NextPageToken = list.NextPageToken } return mclean.m.Marshal(basicList) } diff --git a/server/server.go b/server/server.go index 60897c973..a69dac903 100644 --- a/server/server.go +++ b/server/server.go @@ -32,6 +32,7 @@ type Server struct { HTTPPort string BasicAuth []config.BasicCredential OidcAuth config.OidcAuth + TaskAccess string Tasks tes.TaskServiceServer Events events.EventServiceServer Nodes scheduler.SchedulerServiceServer @@ -124,7 +125,7 @@ func (s *Server) Serve(pctx context.Context) error { return err } - auth := NewAuthentication(s.BasicAuth, s.OidcAuth) + auth := NewAuthentication(s.BasicAuth, s.OidcAuth, s.TaskAccess) grpcServer := grpc.NewServer( grpc.UnaryInterceptor( diff --git a/server/tes.go b/server/tes.go index 20b705d27..351c1d826 100644 --- a/server/tes.go +++ b/server/tes.go @@ -79,18 +79,25 @@ func (ts *TaskService) ListTasks(ctx context.Context, req *tes.ListTasksRequest) // CancelTask cancels a task func (ts *TaskService) CancelTask(ctx context.Context, req *tes.CancelTaskRequest) (*tes.CancelTaskResponse, error) { + result := &tes.CancelTaskResponse{} + + // updated database and other event streams (includes access-checking) + err := ts.Event.WriteEvent(ctx, events.NewState(req.Id, tes.Canceled)) + if err == tes.ErrNotFound { + return result, status.Errorf(codes.NotFound, "%v: taskID: %s", err.Error(), req.Id) + } else if err == tes.ErrNotPermitted { + return result, status.Errorf(codes.PermissionDenied, "%v: taskID: %s", err.Error(), req.Id) + } else if err != nil { + return result, err + } + // dispatch to compute backend - err := ts.Compute.WriteEvent(ctx, events.NewState(req.Id, tes.Canceled)) + err = 
ts.Compute.WriteEvent(ctx, events.NewState(req.Id, tes.Canceled)) if err != nil { ts.Log.Error("compute backend failed to cancel task", "taskID", req.Id, "error", err) } - // updated database and other event streams - err = ts.Event.WriteEvent(ctx, events.NewState(req.Id, tes.Canceled)) - if err == tes.ErrNotFound { - err = status.Errorf(codes.NotFound, "%v: taskID: %s", err.Error(), req.Id) - } - return &tes.CancelTaskResponse{}, err + return result, err } // GetServiceInfo returns service metadata. diff --git a/tes/utils.go b/tes/utils.go index 936118562..ab2e29a77 100644 --- a/tes/utils.go +++ b/tes/utils.go @@ -54,6 +54,10 @@ func Base64Decode(raw string) (*Task, error) { var ErrNotFound = errors.New("task not found") var ErrConcurrentStateChange = errors.New("Concurrent stage change") +// ErrNotPermitted is returned when the owner of a task does not match the +// current non-admin user. +var ErrNotPermitted = errors.New("permission denied") + // Shorthand for task views const ( Minimal = View_MINIMAL @@ -167,7 +171,7 @@ func GetPageSize(reqSize int32) int { // default page size var pageSize = 256 - if reqSize != 0 { + if reqSize > 0 { pageSize = int(reqSize) // max page size diff --git a/tes/utils_test.go b/tes/utils_test.go index b903286e6..d4d13dc7a 100644 --- a/tes/utils_test.go +++ b/tes/utils_test.go @@ -16,21 +16,11 @@ func TestBase64Encode(t *testing.T) { }, } - // TODO: Investigate strange whitespace behavior in Github Actions - // expected := "ewogICJleGVjdXRvcnMiOiBbCiAgICB7CiAgICAgICJjb21tYW5kIjogWwogICAgICAgICJlY2hvIiwKICAgICAgICAiaGVsbG8gd29ybGQiCiAgICAgIF0sCiAgICAgICJpbWFnZSI6ICJhbHBpbmUiCiAgICB9CiAgXSwKICAiaWQiOiAidGFzazEiCn0=" - expected := "ewogICJleGVjdXRvcnMiOiAgWwogICAgewogICAgICAiY29tbWFuZCI6ICBbCiAgICAgICAgImVjaG8iLAogICAgICAgICJoZWxsbyB3b3JsZCIKICAgICAgXSwKICAgICAgImltYWdlIjogICJhbHBpbmUiCiAgICB9CiAgXSwKICAiaWQiOiAgInRhc2sxIgp9" - encoded, err := Base64Encode(task) if err != nil { t.Fatal(err) } - if encoded != expected { - 
t.Logf("expected: %+v", expected) - t.Logf("actual: %+v", encoded) - t.Fatal("unexpected value returned from Base64Encode") - } - decoded, err := Base64Decode(encoded) if err != nil { t.Fatal(err) diff --git a/tests/auth/auth_test.go b/tests/auth/auth_test.go index 43aaa4fd1..a7d38876e 100644 --- a/tests/auth/auth_test.go +++ b/tests/auth/auth_test.go @@ -33,27 +33,27 @@ func TestBasicAuthFail(t *testing.T) { Id: "1", View: tes.View_MINIMAL.String(), }) - if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 403") { - t.Fatal("expected error") + if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 401") { + t.Fatal("expected error", err) } _, err = fun.HTTP.ListTasks(ctx, &tes.ListTasksRequest{ View: tes.View_MINIMAL.String(), }) - if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 403") { - t.Fatal("expected error") + if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 401") { + t.Fatal("expected error", err) } _, err = fun.HTTP.CreateTask(ctx, extask) - if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 403") { - t.Fatal("expected error") + if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 401") { + t.Fatal("expected error", err) } _, err = fun.HTTP.CancelTask(ctx, &tes.CancelTaskRequest{ Id: "1", }) - if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 403") { - t.Fatal("expected error") + if err == nil || !strings.Contains(err.Error(), "STATUS CODE - 401") { + t.Fatal("expected error", err) } // RPC client @@ -61,27 +61,27 @@ func TestBasicAuthFail(t *testing.T) { Id: "1", View: tes.View_MINIMAL.String(), }) - if err == nil || !strings.Contains(err.Error(), "PermissionDenied") { - t.Fatal("expected error") + if err == nil || !strings.Contains(err.Error(), "Unauthenticated") { + t.Fatal("expected error", err) } _, err = fun.RPC.ListTasks(ctx, &tes.ListTasksRequest{ View: tes.View_MINIMAL.String(), }) - if err == nil || !strings.Contains(err.Error(), "PermissionDenied") { - 
t.Fatal("expected error") + if err == nil || !strings.Contains(err.Error(), "Unauthenticated") { + t.Fatal("expected error", err) } _, err = fun.RPC.CreateTask(ctx, tests.HelloWorld()) - if err == nil || !strings.Contains(err.Error(), "PermissionDenied") { - t.Fatal("expected error") + if err == nil || !strings.Contains(err.Error(), "Unauthenticated") { + t.Fatal("expected error", err) } _, err = fun.RPC.CancelTask(ctx, &tes.CancelTaskRequest{ Id: "1", }) - if err == nil || !strings.Contains(err.Error(), "PermissionDenied") { - t.Fatal("expected error") + if err == nil || !strings.Contains(err.Error(), "Unauthenticated") { + t.Fatal("expected error", err) } } diff --git a/tests/core/basic_test.go b/tests/core/basic_test.go index bbcca85f0..af2d7f83c 100644 --- a/tests/core/basic_test.go +++ b/tests/core/basic_test.go @@ -300,14 +300,16 @@ func TestListTaskView(t *testing.T) { `) fun.Wait(id) + time.Sleep(500 * time.Millisecond) + tasks = fun.ListView(tes.View_MINIMAL) task = tasks[0] - if task.Id == "" { + if task.Id != id { t.Fatal("expected task ID in minimal view") } - if task.State == tes.State_UNKNOWN { - t.Fatal("expected state in minimal view") + if task.State != tes.State_COMPLETE { + t.Fatal("expected the COMPLETE state in minimal view, got ", task.State) } if task.Name != "" { t.Fatal("unexpected task name included in minimal view") @@ -328,11 +330,11 @@ func TestListTaskView(t *testing.T) { tasks = fun.ListView(tes.View_BASIC) task = tasks[0] - if task.Id == "" { + if task.Id != id { t.Fatal("expected task ID in basic view") } - if task.State == tes.State_UNKNOWN { - t.Fatal("expected state in basic view") + if task.State != tes.State_COMPLETE { + t.Fatal("expected the COMPLETE state in basic view, got ", task.State) } if task.Name == "" { t.Fatal("expected task name to be included basic view") @@ -365,11 +367,11 @@ func TestListTaskView(t *testing.T) { tasks = fun.ListView(tes.View_FULL) task = tasks[0] - if task.Id == "" { + if task.Id != id { 
t.Fatal("expected task ID in full view") } - if task.State == tes.State_UNKNOWN { - t.Fatal("expected state in full view") + if task.State != tes.State_COMPLETE { + t.Fatal("expected the COMPLETE state in full view, got ", task.State) } if task.Name == "" { t.Fatal("expected task name to be included full view") @@ -525,9 +527,19 @@ func TestSingleCharLog(t *testing.T) { --sh 'echo a; sleep 100' `) fun.WaitForRunning(id) - time.Sleep(time.Second * 5) - task := fun.Get(id) - if task.Logs[0].Logs[0].Stdout != "a\n" { + + // The EXECUTOR_STDOUT event may take some time, so let's wait max 10 seconds + stdout := "" + for range time.NewTicker(10 * time.Second).C { + task := fun.Get(id) + stdout = task.Logs[0].Logs[0].Stdout + if stdout != "" { + t.Log("Non-empty Stdout detected.") + break + } + } + + if stdout != "a\n" { t.Fatal("Missing logs") } fun.Cancel(id) @@ -793,10 +805,17 @@ func TestSmallPaginationAndSortOrder(t *testing.T) { t.Fatal("expected empty database") } - for i := 0; i < 150; i++ { - f.Run(`--sh 'echo 1'`) + taskIds := make([]string, 150) + for i := range 150 { + taskIds[149-i] = f.Run(`--sh 'echo 1'`) + } + + request := &tes.GetTaskRequest{View: tes.View_BASIC.String()} + for _, taskId := range taskIds { + request.Id = taskId + task, err := f.RPC.GetTask(ctx, request) + t.Log("GetTask", request.Id, task.State.String(), err) } - time.Sleep(time.Second * 5) r4, err := f.RPC.ListTasks(ctx, &tes.ListTasksRequest{ PageSize: 50, @@ -896,6 +915,8 @@ func TestListTaskFilterState(t *testing.T) { } f.Wait(id3) + time.Sleep(500 * time.Millisecond) + r, err := f.HTTP.ListTasks(ctx, &tes.ListTasksRequest{ View: tes.Full.String(), }) @@ -933,10 +954,10 @@ func TestListTaskFilterState(t *testing.T) { t.Fatal(err) } if len(r.Tasks) != 1 { - t.Error("expected 1 tasks", r.Tasks) + t.Fatal("expected 1 tasks", r.Tasks) } if r.Tasks[0].Id != id3 { - t.Error("unexpected canceled task IDs", r.Tasks) + t.Fatal("unexpected canceled task IDs", r.Tasks) } } @@ -1040,6 +1061,8 @@ 
func TestListTaskMultipleFilters(t *testing.T) { } f.Wait(id3) + time.Sleep(500 * time.Millisecond) + r, err := f.HTTP.ListTasks(ctx, &tes.ListTasksRequest{ View: tes.View_FULL.String(), }) @@ -1081,10 +1104,10 @@ func TestListTaskMultipleFilters(t *testing.T) { t.Fatal(err) } if len(r.Tasks) != 1 { - t.Error("expected 1 tasks", r.Tasks) + t.Fatal("expected 1 tasks", r.Tasks) } if r.Tasks[0].Id != id2 { - t.Error("unexpected task IDs", r.Tasks) + t.Fatal("unexpected task IDs", r.Tasks) } r, _ = f.HTTP.ListTasks(ctx, &tes.ListTasksRequest{ @@ -1097,7 +1120,7 @@ func TestListTaskMultipleFilters(t *testing.T) { t.Fatal(err) } if len(r.Tasks) != 0 { - t.Error("expected 0 tasks", r.Tasks) + t.Fatal("expected 0 tasks", r.Tasks) } r, _ = f.HTTP.ListTasks(ctx, &tes.ListTasksRequest{ @@ -1110,10 +1133,10 @@ func TestListTaskMultipleFilters(t *testing.T) { t.Fatal(err) } if len(r.Tasks) != 1 { - t.Error("expected 1 tasks", r.Tasks) + t.Fatal("expected 1 tasks", r.Tasks) } if r.Tasks[0].Id != id3 { - t.Error("unexpected task IDs", r.Tasks) + t.Fatal("unexpected task IDs", r.Tasks) } } @@ -1216,7 +1239,9 @@ func TestMetadataEvent(t *testing.T) { } if len(task.Logs[0].Metadata) != 3 { - t.Error("unexpected number of items in task metadata", task) + t.Errorf("expected 3 items in task metadata, got %d: %s", + len(task.Logs[0].Metadata), + task.Logs[0].Metadata) } for k, v := range task.Logs[0].Metadata { diff --git a/tests/core/task_access_test.go b/tests/core/task_access_test.go new file mode 100644 index 000000000..2df290062 --- /dev/null +++ b/tests/core/task_access_test.go @@ -0,0 +1,218 @@ +package core + +import ( + "context" + "strings" + "testing" + + "github.com/ohsu-comp-bio/funnel/config" + "github.com/ohsu-comp-bio/funnel/server" + "github.com/ohsu-comp-bio/funnel/tes" + "github.com/ohsu-comp-bio/funnel/tests" + "google.golang.org/grpc/status" +) + +// These tests verify that access to tasks can be controlled through the +// configuration parameter: Server.TaskAccess 
(All, Owner, OwnerOrAdmin). + +func TestTaskAccessAll(t *testing.T) { + f := initServer(t, server.AccessAll) + + // STEP 1: Create a task for each user + f.SwitchUser("User1") + task1Id := f.Run(`--sh 'echo 1' --tag scope=TestTaskAccessAll`) + f.SwitchUser("User2") + task2Id := f.Run(`--sh 'echo 1' --tag scope=TestTaskAccessAll`) + + // STEP 2: Both users should see the tasks (get) + checkTaskGet(t, f, "User1", task1Id, true) + checkTaskGet(t, f, "User1", task2Id, true) + + checkTaskGet(t, f, "User2", task1Id, true) + checkTaskGet(t, f, "User2", task2Id, true) + + // STEP 3: Both users should see the tasks (list) + listTasksFilter := &tes.ListTasksRequest{ + TagKey: []string{"scope"}, + TagValue: []string{"TestTaskAccessAll"}, + } + + checkTaskList(t, f, "User1", listTasksFilter, 2) + checkTaskList(t, f, "User2", listTasksFilter, 2) + + // STEP 4: No user should get a permission denied error when cancelling the task + checkTaskCancel(t, f, "User1", task1Id, true) + checkTaskCancel(t, f, "User1", task2Id, true) + + checkTaskCancel(t, f, "User2", task1Id, true) + checkTaskCancel(t, f, "User2", task2Id, true) +} + +func TestTaskAccessOwner(t *testing.T) { + f := initServer(t, server.AccessOwner) + + // STEP 1: Create a task for each user + f.SwitchUser("User1") + task1Id := f.Run(`--sh 'echo 1' --tag scope=TestTaskAccessOwner`) + f.SwitchUser("User2") + task2Id := f.Run(`--sh 'echo 1' --tag scope=TestTaskAccessOwner`) + + // STEP 2: Both users should see just their own tasks (get) + checkTaskGet(t, f, "User1", task1Id, true) + checkTaskGet(t, f, "User1", task2Id, false) + + checkTaskGet(t, f, "User2", task1Id, false) + checkTaskGet(t, f, "User2", task2Id, true) + + // Even Admin-user cannot see the tasks: + checkTaskGet(t, f, "Admin", task1Id, false) + checkTaskGet(t, f, "Admin", task2Id, false) + + // STEP 3: Both users should see just their own tasks (list) + listTasksFilter := &tes.ListTasksRequest{ + TagKey: []string{"scope"}, + TagValue: 
[]string{"TestTaskAccessOwner"}, + } + + checkTaskList(t, f, "User1", listTasksFilter, 1) + checkTaskList(t, f, "User2", listTasksFilter, 1) + checkTaskList(t, f, "Admin", listTasksFilter, 0) + + // STEP 4: Users get a permission denied error when cancelling a task of another user + checkTaskCancel(t, f, "User1", task1Id, true) + checkTaskCancel(t, f, "User1", task2Id, false) + + checkTaskCancel(t, f, "User2", task1Id, false) + checkTaskCancel(t, f, "User2", task2Id, true) + + // Even Admin-user cannot cancel the tasks: + checkTaskCancel(t, f, "Admin", task1Id, false) + checkTaskCancel(t, f, "Admin", task2Id, false) +} + +func TestTaskAccessOwnerOrAdmin(t *testing.T) { + f := initServer(t, server.AccessOwnerOrAdmin) + + // STEP 1: Create a task for each user + f.SwitchUser("User1") + task1Id := f.Run(`--sh 'echo 1' --tag scope=TestTaskAccessOwnerOrAdmin`) + f.SwitchUser("User2") + task2Id := f.Run(`--sh 'echo 1' --tag scope=TestTaskAccessOwnerOrAdmin`) + + // STEP 2: Both users should see just their own tasks (get) + checkTaskGet(t, f, "User1", task1Id, true) + checkTaskGet(t, f, "User1", task2Id, false) + + checkTaskGet(t, f, "User2", task1Id, false) + checkTaskGet(t, f, "User2", task2Id, true) + + // Admin-user can see ALL tasks: + checkTaskGet(t, f, "Admin", task1Id, true) + checkTaskGet(t, f, "Admin", task2Id, true) + + // STEP 3: Both users should see just their tasks (list) + listTasksFilter := &tes.ListTasksRequest{ + TagKey: []string{"scope"}, + TagValue: []string{"TestTaskAccessOwnerOrAdmin"}, + } + + checkTaskList(t, f, "User1", listTasksFilter, 1) + checkTaskList(t, f, "User2", listTasksFilter, 1) + checkTaskList(t, f, "Admin", listTasksFilter, 2) + + // STEP 4: Users get a permission denied error when cancelling a task of another user + checkTaskCancel(t, f, "User1", task1Id, true) + checkTaskCancel(t, f, "User1", task2Id, false) + + checkTaskCancel(t, f, "User2", task1Id, false) + checkTaskCancel(t, f, "User2", task2Id, true) + + // Admin-user can 
cancel ALL tasks: + checkTaskCancel(t, f, "Admin", task1Id, true) + checkTaskCancel(t, f, "Admin", task2Id, true) +} + +func initServer(t *testing.T, taskAccess string) *tests.Funnel { + tests.SetLogOutput(log, t) + + c := tests.DefaultConfig() + c.Compute = "noop" + c.Server.TaskAccess = taskAccess + + c.Server.BasicAuth = []config.BasicCredential{ + {User: "User1", Password: "user1-password"}, + {User: "User2", Password: "user2-password"}, + {User: "Admin", Password: "admin-password", Admin: true}, + } + + f := tests.NewFunnel(c) + f.StartServer() + return f +} + +func checkTaskGet(t *testing.T, f *tests.Funnel, username string, taskId string, expectSuccess bool) { + f.SwitchUser(username) + + // We are checking each task-view to be sure that view mode does not affect access. + for _, view := range tes.View_name { + t.Log("GetTask", taskId, "with", view, "view as", username) + + request := &tes.GetTaskRequest{Id: taskId, View: view} + response, err := f.RPC.GetTask(context.Background(), request) + + if expectSuccess { + if err != nil { + t.Fatal("expected GetTask to succeed but got error:", err) + } else if response.Id != taskId { + t.Fatal("GetTask to returned a different task ID:", response.Id) + } + } else if err == nil { + t.Fatal("expected GetTask to fail (permission denied) but there was no error", response) + } else { + checkPermissionDenied(t, err) + } + } +} + +func checkTaskList(t *testing.T, f *tests.Funnel, username string, request *tes.ListTasksRequest, expectedCount int) { + f.SwitchUser(username) + t.Log("ListTasks as", username) + + response, err := f.RPC.ListTasks(context.Background(), request) + + if err != nil { + t.Fatal("ListTasks returned an error:", err) + } + + if len(response.Tasks) != expectedCount { + t.Fatal("expected", expectedCount, "tasks, got", len(response.Tasks)) + } +} + +func checkTaskCancel(t *testing.T, f *tests.Funnel, username string, taskId string, expectSuccess bool) { + f.SwitchUser(username) + t.Log("CancelTask", 
taskId, "as", username) + + _, err := f.RPC.CancelTask(context.Background(), &tes.CancelTaskRequest{Id: taskId}) + + if expectSuccess { + if err != nil { + t.Fatal("expected CancelTask to fail with the state-change error but got:", err) + } + } else { + checkPermissionDenied(t, err) + } +} + +func checkPermissionDenied(t *testing.T, err error) { + s := status.Convert(err) + if s == nil { + t.Fatal("expected grpc status error but received:", err) + } + + expectedPrefix := tes.ErrNotPermitted.Error() + + if !strings.HasPrefix(s.Message(), expectedPrefix) { + t.Fatal("expected error-prefix [", expectedPrefix, "] but got:", s.Message()) + } +} diff --git a/tests/funnel_utils.go b/tests/funnel_utils.go index 07eef8d3d..07590550c 100644 --- a/tests/funnel_utils.go +++ b/tests/funnel_utils.go @@ -147,6 +147,22 @@ func (f *Funnel) PollForServerStart() error { } } +// Changes the user who will be making the RPC calls. +// Panics if that user is not defined in the configuration. +func (f *Funnel) SwitchUser(username string) { + for _, cred := range f.Conf.Server.BasicAuth { + if cred.User == username { + if f.Conf.RPCClient.User != username { + f.Conf.RPCClient.User = cred.User + f.Conf.RPCClient.Password = cred.Password + f.addRPCClient() + } + return + } + } + panic("Cannot switch to an undefined user: " + username) +} + // WaitForDockerDestroy waits for a "destroy" event // from docker for the given container ID // diff --git a/tests/mongo.config.yml b/tests/mongo.config.yml index e4e424fa2..5c56c0cbb 100644 --- a/tests/mongo.config.yml +++ b/tests/mongo.config.yml @@ -4,5 +4,5 @@ EventWriters: - log MongoDB: - Addrs: - - mongodb://localhost:27000 + Addrs: + - localhost:27000 diff --git a/website/content/docs/compute/aws-batch.md b/website/content/docs/compute/aws-batch.md index 20c580b8c..bebc256da 100644 --- a/website/content/docs/compute/aws-batch.md +++ b/website/content/docs/compute/aws-batch.md @@ -9,25 +9,25 @@ menu: # AWS Batch This guide covers deploying a Funnel 
server that leverages [DynamoDB][0] for storage -and [AWS Batch][1] for task execution. +and [AWS Batch][1] for task execution. ## Setup -Get started by creating a compute environment, job queue and job definition using either -the Funnel CLI or the AWS Batch web console. To manage the permissions of instanced -AWS Batch jobs create a new IAM role. For the Funnel configuration outlined +Get started by creating a compute environment, job queue and job definition using either +the Funnel CLI or the AWS Batch web console. To manage the permissions of instanced +AWS Batch jobs create a new IAM role. For the Funnel configuration outlined in this document, this role will need to provide read and write access to both S3 and DynamoDB. -_Note_: We recommend creating the Job Definition with Funnel by running: `funnel aws batch create-job-definition`. -Funnel expects the JobDefinition to start a Funnel worker process with a specific configuration. -Only advanced users should consider making any substantial changes to this Job Definition. +_Note_: We recommend creating the Job Definition with Funnel by running: `funnel aws batch create-job-definition`. +Funnel expects the JobDefinition to start a Funnel worker process with a specific configuration. +Only advanced users should consider making any substantial changes to this Job Definition. -AWS Batch tasks, by default, launch the ECS Optimized AMI which includes -an 8GB volume for the operating system and a 22GB volume for Docker image and metadata -storage. The default Docker configuration allocates up to 10GB of this storage to +AWS Batch tasks, by default, launch the ECS Optimized AMI which includes +an 8GB volume for the operating system and a 22GB volume for Docker image and metadata +storage. The default Docker configuration allocates up to 10GB of this storage to each container instance. [Read more about the default AMI][8]. Due to these limitations, we -recommend [creating a custom AMI][7]. 
Because AWS Batch has the same requirements for your -AMI as Amazon ECS, use the default Amazon ECS-optimized Amazon Linux AMI as a base and change it +recommend [creating a custom AMI][7]. Because AWS Batch has the same requirements for your +AMI as Amazon ECS, use the default Amazon ECS-optimized Amazon Linux AMI as a base and change it to better suit your tasks. ### Steps @@ -37,7 +37,7 @@ to better suit your tasks. * [Create an EC2ContainerTaskRole with policies for managing access to S3 and DynamoDB][5] * [Create a Job Definition][6] -For more information check out AWS Batch's [getting started guide][2]. +For more information check out AWS Batch's [getting started guide][2]. ### Quickstart @@ -51,8 +51,8 @@ This command will create a compute environment, job queue, IAM role and job defi ## Configuring the Funnel Server Below is an example configuration. Note that the `Key` -and `Secret` fields are left blank in the configuration of the components. This is because -Funnel will, by default, try to will try to automatically load credentials from the environment. +and `Secret` fields are left blank in the configuration of the components. This is because +Funnel will, by default, try to automatically load credentials from the environment. Alternatively, you may explicitly set the credentials in the config. ```YAML @@ -69,11 +69,11 @@ Dynamodb: Batch: JobDefinition: "funnel-job-def" - JobQueue: "funnel-job-queue" + JobQueue: "funnel-job-queue" Region: "us-west-2" Key: "" Secret: "" - + AmazonS3: Key: "" Secret: "" diff --git a/website/content/docs/databases/datastore.md b/website/content/docs/databases/datastore.md index 33694f604..ea31d8c6d 100644 --- a/website/content/docs/databases/datastore.md +++ b/website/content/docs/databases/datastore.md @@ -15,7 +15,10 @@ special requirements on the context of requests and requires a separate library. Two entity types are used, "Task" and "TaskPart" (for larger pieces of task content, such as stdout/err logs). 
-Funnel will, by default, try to will try to automatically load credentials from the environment. Alternatively, you may explicitly set the credentials in the config. +Funnel will, by default, try to automatically load credentials from the +environment. Alternatively, you may explicitly set the credentials in the config. +You can read more about providing the credentials +[here](https://cloud.google.com/docs/authentication/application-default-credentials). Config: ```yaml @@ -28,3 +31,64 @@ Datastore: # from the environment. CredentialsFile: "" ``` + +Please also import some [composite +indexes](https://cloud.google.com/datastore/docs/concepts/indexes?hl=en) +to support the task-list queries. +This is typically done through command-line by referencing an **index.yaml** +file (do not change the filename) with the following content: + +```shell +gcloud datastore indexes create path/to/index.yaml --database='funnel' +``` + +```yaml +indexes: + +- kind: Task + properties: + - name: Owner + - name: State + - name: TagStrings + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: Owner + - name: State + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: Owner + - name: TagStrings + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: Owner + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: State + - name: TagStrings + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: State + - name: CreationTime + direction: desc + +- kind: Task + properties: + - name: TagStrings + - name: CreationTime + direction: desc +``` \ No newline at end of file diff --git a/website/content/docs/databases/dynamodb.md b/website/content/docs/databases/dynamodb.md index 9fc583a42..3e536c217 100644 --- a/website/content/docs/databases/dynamodb.md +++ b/website/content/docs/databases/dynamodb.md @@ -8,7 +8,7 @@ menu: # DynamoDB Funnel supports storing task data in 
DynamoDB. Storing scheduler data is not supported currently, so using the node scheduler with DynamoDB won't work. Using AWS Batch for compute scheduling may be a better option. -Funnel will, by default, try to will try to automatically load credentials from the environment. Alternatively, you may explicitly set the credentials in the config. +Funnel will, by default, try to automatically load credentials from the environment. Alternatively, you may explicitly set the credentials in the config. Available Config: ```yaml diff --git a/website/content/docs/databases/elasticsearch.md b/website/content/docs/databases/elasticsearch.md index 40509ed53..e397348a5 100644 --- a/website/content/docs/databases/elasticsearch.md +++ b/website/content/docs/databases/elasticsearch.md @@ -7,7 +7,7 @@ menu: # Elasticsearch -Funnel supports storing tasks and scheduler data in Elasticsearch. +Funnel supports storing tasks and scheduler data in Elasticsearch (v8). Config: ```yaml @@ -17,4 +17,14 @@ Elastic: # Prefix to use for indexes IndexPrefix: "funnel" URL: http://localhost:9200 + # Optional. Username for HTTP Basic Authentication. + Username: + # Optional. Password for HTTP Basic Authentication. + Password: + # Optional. Endpoint for the Elastic Service (https://elastic.co/cloud). + CloudID: + # Optional. Base64-encoded token for authorization; if set, overrides username/password and service token. + APIKey: + # Optional. Service token for authorization; if set, overrides username/password. + ServiceToken: ``` diff --git a/website/content/docs/security/basic.md b/website/content/docs/security/basic.md index efb143a04..0b19e0714 100644 --- a/website/content/docs/security/basic.md +++ b/website/content/docs/security/basic.md @@ -7,19 +7,36 @@ menu: --- # Basic Auth -By default, a Funnel server allows open access to its API endpoints, but it -can be configured to require basic password authentication. 
To enable this, +By default, a Funnel server allows open access to its API endpoints, but it +can be configured to require basic password authentication. To enable this, include users and passwords in your config file: ```yaml Server: BasicAuth: + - User: admin + Password: someReallyComplexSecret + Admin: true - User: funnel Password: abc123 + + TaskAccess: OwnerOrAdmin ``` +The `TaskAccess` property configures the visibility and access-mode for tasks: + +* `All` (default) - all tasks are visible to everyone +* `Owner` - tasks are visible to the users who created them +* `OwnerOrAdmin` - extends `Owner` by allowing Admin-users (`Admin: true`) + to access everything + +As new tasks are created, the username behind the request is recorded as the +owner of the task. Depending on the `TaskAccess` property, if owner-based +access-mode is enabled, the owner of the task is compared to the username of the current +request to decide if the user may see and interact with the task. + If you are using BoltDB or Badger, the Funnel worker communicates to the server via gRPC -so you will also need to configure the RPC client. +so you will also need to configure the RPC client. ```yaml RPCClient: @@ -27,7 +44,7 @@ RPCClient: Password: abc123 ``` -Make sure to properly protect the configuration file so that it's not readable +Make sure to properly protect the configuration file so that it's not readable by everyone: ```bash diff --git a/website/content/docs/security/oauth2.md b/website/content/docs/security/oauth2.md index f1a7927a1..4b4232da7 100644 --- a/website/content/docs/security/oauth2.md +++ b/website/content/docs/security/oauth2.md @@ -21,8 +21,6 @@ token is still active (i.e., no token invalidation before expiring). Optionally, Funnel can also validate the scope and audience claims to contain specific values. -It is not possible to configure administrative users for the OAuth2 authentication mode. 
- To enable JWT authentication, specify `OidcAuth` section in your config file: ```yaml @@ -44,8 +42,27 @@ Server: # The URL where OIDC should redirect after login (keep the path '/login') RedirectURL: "http://localhost:8000/login" + + # List of OIDC subjects promoted to Admin status. + Admins: + - user.one@example.org + - user.two@example.org + + TaskAccess: OwnerOrAdmin ``` +The `TaskAccess` property configures the visibility and access-mode for tasks: + +* `All` (default) - all tasks are visible to everyone +* `Owner` - tasks are visible to the users who created them +* `OwnerOrAdmin` - extends `Owner` by allowing Admin-users (defined under + `Admins`) to access everything + +As new tasks are created, the username behind the request is recorded as the +owner of the task. Depending on the `TaskAccess` property, if owner-based +access-mode is enabled, the owner of the task is compared to the username of the current +request to decide if the user may see and interact with the task. + Make sure to properly protect the configuration file so that it's not readable by everyone: diff --git a/website/static/funnel-config-examples/default-config.yaml b/website/static/funnel-config-examples/default-config.yaml index 21b7b409a..675dc45db 100644 --- a/website/static/funnel-config-examples/default-config.yaml +++ b/website/static/funnel-config-examples/default-config.yaml @@ -32,15 +32,46 @@ Server: # If used, make sure to properly restrict access to the config file # (e.g. chmod 600 funnel.config.yml) # BasicAuth: + # - User: admin + # Password: oejf023moq + # Admin: true # - User: user1 # Password: abc123 # - User: user2 # Password: foobar + # Require Bearer JWT authentication for the server APIs. + # Server won't launch when configuration URL cannot be loaded. 
+ # OidcAuth: + # # URL of the OIDC service configuration (activates OIDC configuration): + # # Example: https://example.org/oidc/.well-known/openid-configuration + # ServiceConfigURL: + # # Client ID and secret are sent with the token introspection request + # ClientId: + # ClientSecret: + # # The URL where OIDC should redirect after login (keep the path '/login') + # RedirectURL: "http://localhost:8000/login" + # # Optional: if specified, this scope value must be in the token: + # RequireScope: + # # Optional: if specified, this audience value must be in the token: + # RequireAudience: + # # List of usernames (JWT sub-claim) to be granted Admin-role: + # Admins: + # - admin.username.one@example.org + # - admin.username.two@example.org + # Include a "Cache-Control: no-store" HTTP header in Get/List responses # to prevent caching by intermediary services. DisableHTTPCache: true + # Defines task access and visibility by options: + # "All" (default) - all tasks are visible to everyone + # "Owner" - tasks are visible to the users who created them + # "OwnerOrAdmin" - extends "Owner" by allowing Admin-users to see everything + # Owner is the username associated with the task. + # Owners (usernames) were not recorded to tasks before Funnel 0.11.1. + TaskAccess: All + RPCClient: # RPC server address ServerAddress: localhost:9090 @@ -141,6 +172,16 @@ Elastic: IndexPrefix: funnel # URL of the elasticsearch server. URL: http://localhost:9200 + # Optional. Username for HTTP Basic Authentication. + Username: + # Optional. Password for HTTP Basic Authentication. + Password: + # Optional. Endpoint for the Elastic Service (https://elastic.co/cloud). + CloudID: + # Optional. Base64-encoded token for authorization; if set, overrides username/password and service token. + APIKey: + # Optional. Service token for authorization; if set, overrides username/password. + ServiceToken: # Google Cloud Datastore task database. 
Datastore: @@ -153,7 +194,7 @@ Datastore: MongoDB: # Addrs holds the addresses for the seed servers. Addrs: - - mongodb://localhost + - localhost # Database is the database name used within MongoDB to store funnel data. Database: funnel # Timeout is the amount of time to wait for a server to respond when @@ -299,7 +340,7 @@ AWSBatch: # Kubernetes describes the configuration for the Kubernetes compute backend. Kubernetes: # The executor used to execute tasks. Available executors: docker, kubernetes - Executor: "kubernetes" + Executor: "kubernetes" # Turn off task state reconciler. When enabled, Funnel communicates with Kubernetes # to find tasks that are stuck in a queued state or errored and # updates the task state accordingly.