Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions components/ambient-api-server/pkg/middleware/bearer_token_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package middleware
import (
"context"
"crypto/subtle"
"strings"

"github.com/golang-jwt/jwt/v4"
"github.com/openshift-online/rh-trex-ai/pkg/auth"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
Expand All @@ -25,6 +28,9 @@ func bearerTokenGRPCUnaryInterceptor(expectedToken string) grpc.UnaryServerInter
if subtle.ConstantTimeCompare([]byte(token), []byte(expectedToken)) == 1 {
return handler(withCallerType(ctx, CallerTypeService), req)
}
if username := usernameFromJWT(token); username != "" {
return handler(auth.SetUsernameContext(ctx, username), req)
}
}
}
}
Expand All @@ -45,6 +51,10 @@ func bearerTokenGRPCStreamInterceptor(expectedToken string) grpc.StreamServerInt
if subtle.ConstantTimeCompare([]byte(token), []byte(expectedToken)) == 1 {
return handler(srv, &serviceCallerStream{ServerStream: ss, ctx: withCallerType(ss.Context(), CallerTypeService)})
}
if username := usernameFromJWT(token); username != "" {
ctx := auth.SetUsernameContext(ss.Context(), username)
return handler(srv, &serviceCallerStream{ServerStream: ss, ctx: ctx})
}
}
}
}
Expand All @@ -53,6 +63,24 @@ func bearerTokenGRPCStreamInterceptor(expectedToken string) grpc.StreamServerInt
}
}

func usernameFromJWT(tokenString string) string {
p := jwt.NewParser()
token, _, err := p.ParseUnverified(tokenString, jwt.MapClaims{})
if err != nil {
return ""
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return ""
}
for _, key := range []string{"preferred_username", "username", "sub"} {
if v, _ := claims[key].(string); v != "" && !strings.Contains(v, ":") {
return v
}
}
return ""
}

type serviceCallerStream struct {
grpc.ServerStream
ctx context.Context
Expand Down
67 changes: 67 additions & 0 deletions components/ambient-api-server/plugins/credentials/migration.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package credentials

import (
"encoding/json"

"gorm.io/gorm"

"github.com/go-gormigrate/gormigrate/v2"
"github.com/openshift-online/rh-trex-ai/pkg/api"
"github.com/openshift-online/rh-trex-ai/pkg/db"
)

Expand All @@ -30,3 +33,67 @@ func migration() *gormigrate.Migration {
},
}
}

func rolesMigration() *gormigrate.Migration {
type roleRow struct {
ID string
Name string
DisplayName string
Description string
Permissions string
BuiltIn bool
}

seed := []struct {
name string
displayName string
description string
permissions []string
}{
{
name: "credential:token-reader",
displayName: "Credential Token Reader",
description: "Retrieve the raw token value for a credential",
permissions: []string{"credential:token"},
},
{
name: "credential:reader",
displayName: "Credential Reader",
description: "Read credential metadata (name, provider, description)",
permissions: []string{"credential:read", "credential:list"},
},
}

return &gormigrate.Migration{
ID: "202603311216",
Migrate: func(tx *gorm.DB) error {
for _, r := range seed {
permsJSON, err := json.Marshal(r.permissions)
if err != nil {
return err
}
row := roleRow{
ID: api.NewID(),
Name: r.name,
DisplayName: r.displayName,
Description: r.description,
Permissions: string(permsJSON),
BuiltIn: true,
}
if err := tx.Table("roles").
Where("name = ?", r.name).
FirstOrCreate(&row).Error; err != nil {
return err
}
}
return nil
},
Rollback: func(tx *gorm.DB) error {
names := make([]string, len(seed))
for i, r := range seed {
names[i] = r.name
}
return tx.Table("roles").Where("name IN ?", names).Delete(&roleRow{}).Error
},
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ func init() {
presenters.RegisterKind(&Credential{}, "Credential")

db.RegisterMigration(migration())
db.RegisterMigration(rolesMigration())
}
12 changes: 12 additions & 0 deletions components/ambient-api-server/plugins/credentials/testmain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ import (
"github.com/golang/glog"

"github.com/ambient-code/platform/components/ambient-api-server/test"

_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/agents"
_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/inbox"
_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/projectSettings"
_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/projects"
_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/rbac"
_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/roleBindings"
_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/roles"
_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/sessions"
_ "github.com/ambient-code/platform/components/ambient-api-server/plugins/users"
_ "github.com/openshift-online/rh-trex-ai/plugins/events"
_ "github.com/openshift-online/rh-trex-ai/plugins/generic"
)

func TestMain(m *testing.M) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,10 @@ func (h *sessionGRPCHandler) WatchSessionMessages(req *pb.WatchSessionMessagesRe

if !middleware.IsServiceCaller(ctx) {
username := auth.GetUsernameFromContext(ctx)
if username != "" && (h.grpcServiceAccount == "" || username != h.grpcServiceAccount) {
if username == "" {
return status.Error(codes.PermissionDenied, "not authorized to watch this session")
}
if h.grpcServiceAccount == "" || username != h.grpcServiceAccount {
session, svcErr := h.service.Get(ctx, req.GetSessionId())
if svcErr != nil {
return grpcutil.ServiceErrorToGRPC(svcErr)
Expand Down
3 changes: 2 additions & 1 deletion components/ambient-api-server/plugins/sessions/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net"
"net/http"
"strings"
"time"

"github.com/golang/glog"
Expand Down Expand Up @@ -299,7 +300,7 @@ func (h sessionHandler) StreamRunnerEvents(w http.ResponseWriter, r *http.Reques

runnerURL := fmt.Sprintf(
"http://session-%s.%s.svc.cluster.local:8001/events/%s",
*session.KubeCrName, *session.KubeNamespace, *session.KubeCrName,
strings.ToLower(*session.KubeCrName), *session.KubeNamespace, *session.KubeCrName,
)

req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, runnerURL, nil)
Expand Down
Loading
Loading