diff --git a/cmd/api/migrations/20240226153000_create_collections.go b/cmd/api/migrations/20240226153000_create_collections.go new file mode 100644 index 0000000..666cde4 --- /dev/null +++ b/cmd/api/migrations/20240226153000_create_collections.go @@ -0,0 +1,21 @@ +//go:build !migrations + +package migrations + +import ( + "gofr.dev/pkg/gofr/migration" +) + +func createCollections() migration.Migrate { + return migration.Migrate{ + UP: func(d migration.Datasource) error { + collections := []string{"tenants", "users", "api_keys"} + for _, coll := range collections { + if err := d.Mongo.CreateCollection(d.Context, coll); err != nil { + return err + } + } + return nil + }, + } +} diff --git a/cmd/api/migrations/20240226153100_create_indexes.go b/cmd/api/migrations/20240226153100_create_indexes.go new file mode 100644 index 0000000..f9e92b6 --- /dev/null +++ b/cmd/api/migrations/20240226153100_create_indexes.go @@ -0,0 +1,35 @@ +//go:build !migrations + +package migrations + +import ( + "gofr.dev/pkg/gofr/migration" +) + +func createIndexes() migration.Migrate { + return migration.Migrate{ + UP: func(d migration.Datasource) error { + // Create unique index on tenant name + if err := d.Mongo.CreateIndex(d.Context, "tenants", map[string]interface{}{"name": 1}, true); err != nil { + return err + } + + // Create unique index on user email + if err := d.Mongo.CreateIndex(d.Context, "users", map[string]interface{}{"email": 1}, true); err != nil { + return err + } + + // Create index on user's tenant_id for faster queries + if err := d.Mongo.CreateIndex(d.Context, "users", map[string]interface{}{"tenant_id": 1}, false); err != nil { + return err + } + + // Create unique index on API key + if err := d.Mongo.CreateIndex(d.Context, "api_keys", map[string]interface{}{"key": 1}, true); err != nil { + return err + } + + return nil + }, + } +} diff --git a/cmd/api/migrations/all.go b/cmd/api/migrations/all.go index b66c9c8..45eadad 100644 --- a/cmd/api/migrations/all.go +++ b/cmd/api/migrations/all.go @@ -1,3 +1,5 @@ +//go:build !migrations + package migrations import ( diff --git a/cmd/event-ingest/main.go b/cmd/event-ingest/main.go index cff3c0a..50757bd 100644 --- a/cmd/event-ingest/main.go +++ b/cmd/event-ingest/main.go @@ -32,7 +32,6 @@ func main() { if err != nil { log.Fatalf("Failed to connect to gRPC server: %v", err) } - defer conn.Close() // Create gRPC event forwarder @@ -46,7 +45,7 @@ func main() { jwtMiddleware.Validate, middleware.AuthenticateAPIKey, middleware.RequireRole("admin", "event_publisher"), - func(cc *customctx.Context) (interface{}, error) { + func(cc customctx.Context) (interface{}, error) { return httpServer.HandleEvent(cc) }, )) @@ -58,24 +57,35 @@ func main() { // combineMiddleware chains multiple middleware functions together func combineMiddleware(middlewares ...interface{}) gofr.Handler { return func(c *gofr.Context) (interface{}, error) { + // Create the initial custom context from the GoFr context cc := customctx.NewCustomContext(c) - var handler func(*customctx.Context) (interface{}, error) + // Define the final handler that will be called after applying all middleware + finalHandler := func(ctx customctx.Context) (interface{}, error) { + return nil, eventingest.NewInternalError("No handler provided") + } - // Apply middlewares in reverse order + // Apply middlewares in reverse order to build the middleware chain for i := len(middlewares) - 1; i >= 0; i-- { switch m := middlewares[i].(type) { - case func(*customctx.Context) (interface{}, error): - handler = m - case func(func(*customctx.Context) (interface{}, error)) func(*customctx.Context) (interface{}, error): - handler = m(handler) - case func(gofr.Handler) gofr.Handler: - return m(func(*gofr.Context) (interface{}, error) { - return handler(cc) - })(c) + case func(customctx.Context) (interface{}, error): + // Set the final handler to the current one if no other handler is set + if i == len(middlewares)-1 { + finalHandler = m + } else { + // Wrap the final handler in the current function + // nextHandler := finalHandler + finalHandler = func(ctx customctx.Context) (interface{}, error) { + return m(ctx) + } + } + case func(func(customctx.Context) (interface{}, error)) func(customctx.Context) (interface{}, error): + // Wrap the final handler in middleware if it's a middleware function + finalHandler = m(finalHandler) } } - return handler(cc) + // Execute the final middleware chain with the custom context + return finalHandler(cc) } } diff --git a/cmd/eventrunner/main_test.go b/cmd/eventrunner/main_test.go deleted file mode 100644 index 66e63e1..0000000 --- a/cmd/eventrunner/main_test.go +++ /dev/null @@ -1,211 +0,0 @@ -package main - -import ( - "context" - "log" - "os" - "strings" - "testing" - "time" - - natspubsub "github.com/carverauto/gofr-nats" - "github.com/nats-io/nats-server/v2/server" - "gofr.dev/pkg/gofr" - "gofr.dev/pkg/gofr/datasource/pubsub" - "gofr.dev/pkg/gofr/logging" - "gofr.dev/pkg/gofr/testutil" -) - -type mockMetrics struct{} - -func (*mockMetrics) IncrementCounter(_ context.Context, _ string, _ ...string) {} - -func runNATSServer() (*server.Server, error) { - opts := &server.Options{ - ConfigFile: "configs/nats-server.conf", - JetStream: true, - Port: -1, - Trace: true, - } - - return server.NewServer(opts) -} - -func TestExampleSubscriber(t *testing.T) { - // Start the embedded NATS server - natsServer, err := runNATSServer() - if err != nil { - t.Fatalf("Failed to start NATS server: %v", err) - } - defer natsServer.Shutdown() - - natsServer.Start() - - if !natsServer.ReadyForConnections(5 * time.Second) { - t.Fatal("NATS server failed to start") - } - - serverURL := natsServer.ClientURL() - - // Set environment variable for NATS server URL - os.Setenv("PUBSUB_BROKER", serverURL) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - logs := testutil.StdoutOutputForFunc(func() { - // Initialize test data - client := initializeTest(t, serverURL) - defer client.Close() - - // Start the main application - go runMain(ctx) - - // Publish test messages - publishTestMessages(t, client) - - // Wait for messages to be processed - time.Sleep(10 * time.Second) - }) - - // Cancel the context to stop the application gracefully - cancel() - - // Verify logs - verifyLogs(t, logs) -} - -func initializeTest(t *testing.T, serverURL string) pubsub.Client { - t.Helper() - - conf := &natspubsub.Config{ - Server: serverURL, - Stream: natspubsub.StreamConfig{ - Stream: "sample-stream", - Subjects: []string{"order-logs", "products"}, - MaxDeliver: 4, - }, - Consumer: "test-consumer", - MaxWait: 5 * time.Second, - MaxPullWait: 5, - BatchSize: 10, - } - - mockMetrics := &mockMetrics{} - logger := logging.NewMockLogger(logging.DEBUG) - - client, err := natspubsub.New(conf, logger, mockMetrics) - if err != nil { - t.Fatalf("Error initializing NATS client: %v", err) - } - - return client -} - -func publishTestMessages(t *testing.T, client pubsub.Client) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - err := client.Publish(ctx, "order-logs", []byte(`{"orderId":"123","status":"pending"}`)) - if err != nil { - t.Errorf("Error publishing to 'order-logs': %v", err) - } - - err = client.Publish(ctx, "products", []byte(`{"productId":"69","price":"19.99"}`)) - if err != nil { - t.Errorf("Error publishing to 'products': %v", err) - } - - log.Println("Test messages published") -} - -func verifyLogs(t *testing.T, logs string) { - testCases := []struct { - desc string - expectedLog string - }{ - { - desc: "NATS connection", - expectedLog: "connected to NATS server", - }, - { - desc: "valid order", - expectedLog: "Received order", - }, - { - desc: "valid product", - expectedLog: "Received product", - }, - } - - for i, tc := range testCases { - if !strings.Contains(logs, tc.expectedLog) { - t.Errorf("TEST[%d] Failed.\n%s\nExpected log: %s\nActual logs: %s", - i, tc.desc, tc.expectedLog, logs) - } - } - - // Check for unexpected errors - if strings.Contains(logs, "subscriber not initialized") { - t.Errorf("Subscriber initialization error detected in logs") - } - - if strings.Contains(logs, "failed to connect to NATS server") { - t.Errorf("NATS connection error detected in logs") - } -} - -func runMain(ctx context.Context) { - app := gofr.New() - - app.Subscribe("products", func(c *gofr.Context) error { - var productInfo struct { - ProductID string `json:"productId"` - Price string `json:"price"` - } - - err := c.Bind(&productInfo) - if err != nil { - log.Printf("Error binding product data: %v", err) - c.Logger.Error(err) - return nil - } - - c.Logger.Info("Received product", productInfo) - return nil - }) - - app.Subscribe("order-logs", func(c *gofr.Context) error { - var orderStatus struct { - OrderID string `json:"orderId"` - Status string `json:"status"` - } - - err := c.Bind(&orderStatus) - if err != nil { - log.Printf("Error binding order data: %v", err) - c.Logger.Error(err) - return nil - } - - c.Logger.Info("Received order", orderStatus) - return nil - }) - - go func() { - <-ctx.Done() - log.Println("Context canceled, stopping application") - - shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - err := app.Shutdown(shutdownCtx) - if err != nil { - log.Printf("Error shutting down application: %v", err) - } - }() - - log.Println("Starting application") - app.Run() - log.Println("Application stopped") -} diff --git a/cmd/eventrunner/migrations/20241003051401_create_events_table.go b/cmd/eventrunner/migrations/20241003051401_create_events_table.go index 56dbc3b..6f5af0a 100644 --- a/cmd/eventrunner/migrations/20241003051401_create_events_table.go +++ b/cmd/eventrunner/migrations/20241003051401_create_events_table.go @@ -1,3 +1,5 @@ +//go:build !migrations + package migrations import ( diff --git a/cmd/eventrunner/migrations/all.go b/cmd/eventrunner/migrations/all.go index c356514..3bf1144 100644 --- a/cmd/eventrunner/migrations/all.go +++ b/cmd/eventrunner/migrations/all.go @@ -1,3 +1,5 @@ +//go:build !migrations + package migrations import "gofr.dev/pkg/gofr/migration" diff --git a/pkg/api/handlers/handlers.go b/pkg/api/handlers/handlers.go index 76c7516..58978ec 100644 --- a/pkg/api/handlers/handlers.go +++ b/pkg/api/handlers/handlers.go @@ -2,25 +2,25 @@ package handlers import ( "github.com/carverauto/eventrunner/pkg/api/models" + "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" "gofr.dev/pkg/gofr" ) type TenantHandler struct{} -func (*TenantHandler) Create(c *gofr.Context) (interface{}, error) { +func (*TenantHandler) Create(c *gofr.Context) (models.Tenant, error) { var tenant models.Tenant if err := c.Bind(&tenant); err != nil { - return nil, err + return models.Tenant{}, err } result, err := c.Mongo.InsertOne(c, "tenants", tenant) if err != nil { - return nil, err + return models.Tenant{}, err } - tenant.ID = result.(primitive.ObjectID) + tenant.ID = result.(uuid.UUID) return tenant, nil } @@ -47,13 +47,13 @@ func (*UserHandler) Create(c *gofr.Context) (interface{}, error) { return nil, err } - user.ID = result.(primitive.ObjectID) + user.ID = result.(uuid.UUID) return user, nil } func (*UserHandler) GetAll(c *gofr.Context) (interface{}, error) { - tenantID, err := primitive.ObjectIDFromHex(c.Param("tenant_id")) + tenantID, err := uuid.Parse(c.Param("tenant_id")) if err != nil { return nil, err } diff --git a/pkg/api/middleware/jwt.go b/pkg/api/middleware/jwt.go index 745c6e9..7e041fb 100644 --- a/pkg/api/middleware/jwt.go +++ b/pkg/api/middleware/jwt.go @@ -2,12 +2,13 @@ package middleware import ( "context" + "strings" + "github.com/carverauto/eventrunner/pkg/config" customctx "github.com/carverauto/eventrunner/pkg/context" "github.com/carverauto/eventrunner/pkg/eventingest" "github.com/coreos/go-oidc/v3/oidc" "gofr.dev/pkg/gofr" - "strings" ) // JWTMiddleware is a middleware that validates JWT tokens. @@ -39,6 +40,7 @@ func (m *JWTMiddleware) Validate(next func(customctx.Context) (interface{}, erro // Safely retrieve Authorization header from context authHeaderValue := c.Request.Context().Value("Authorization") + authHeader, ok := authHeaderValue.(string) if !ok || authHeader == "" { return nil, eventingest.NewAuthError("Missing or invalid authorization header") diff --git a/pkg/api/middleware/rbac.go b/pkg/api/middleware/rbac.go index ce91551..6d59d12 100644 --- a/pkg/api/middleware/rbac.go +++ b/pkg/api/middleware/rbac.go @@ -6,16 +6,27 @@ import ( ) // AuthenticateAPIKey checks if the API key is valid and active, otherwise returns an error. -func AuthenticateAPIKey(next func(CustomContext) (interface{}, error)) func(CustomContext) (interface{}, error) { - return func(cc CustomContext) (interface{}, error) { +// AuthenticateAPIKey checks if the API key is valid and active, otherwise returns an error. +func AuthenticateAPIKey(next func(customctx.Context) (interface{}, error)) func(customctx.Context) (interface{}, error) { + return func(cc customctx.Context) (interface{}, error) { apiKey, ok := cc.GetAPIKey() if !ok || apiKey == "" { return nil, eventingest.NewAuthError("Missing API Key") } - tenantID, customerID, err := cc.FindAPIKey(apiKey) - if err != nil { - return nil, eventingest.NewAuthError("Invalid API Key") + /* + tenantID, customerID, err := cc.GetAPIKey(apiKey) + if err != nil { + return nil, eventingest.NewAuthError("Invalid API Key") + } + */ + tenantID, ok := cc.GetUUIDClaim("tenant_id") + if !ok { + return nil, eventingest.NewAuthError("Missing tenant ID") + } + customerID, ok := cc.GetUUIDClaim("customer_id") + if !ok { + return nil, eventingest.NewAuthError("Missing customer ID") } cc.SetClaim("api_key", apiKey) @@ -27,12 +38,8 @@ func AuthenticateAPIKey(next func(CustomContext) (interface{}, error)) func(Cust } // RequireRole checks if the user has the required role to access the resource, otherwise returns an error. -// The user's role is stored in the JWT token. The roles parameter is a list of roles that are allowed -// to access the resource. -func RequireRole(roles ...string) func( - func(customctx.Context) (interface{}, error)) func(customctx.Context) (interface{}, error) { - return func( - next func(customctx.Context) (interface{}, error)) func(customctx.Context) (interface{}, error) { +func RequireRole(roles ...string) func(func(customctx.Context) (interface{}, error)) func(customctx.Context) (interface{}, error) { + return func(next func(customctx.Context) (interface{}, error)) func(customctx.Context) (interface{}, error) { return func(cc customctx.Context) (interface{}, error) { userRole, ok := cc.GetStringClaim("user_role") if !ok { diff --git a/pkg/api/middleware/rbac_test.go b/pkg/api/middleware/rbac_test.go index b15ac30..a8d62f3 100644 --- a/pkg/api/middleware/rbac_test.go +++ b/pkg/api/middleware/rbac_test.go @@ -3,6 +3,7 @@ package middleware import ( "testing" + customctx "github.com/carverauto/eventrunner/pkg/context" "github.com/carverauto/eventrunner/pkg/eventingest" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -13,7 +14,7 @@ func TestAuthenticateAPIKey(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockContext := NewMockCustomContext(ctrl) + mockContext := customctx.NewMockContext(ctrl) tests := []struct { name string @@ -28,8 +29,14 @@ func TestAuthenticateAPIKey(t *testing.T) { tenantID := uuid.New() customerID := uuid.New() + // Set up mock expectations for API key retrieval mockContext.EXPECT().GetAPIKey().Return(apiKey, true) - mockContext.EXPECT().FindAPIKey(apiKey).Return(tenantID, customerID, nil) + + // Set up mock expectations for claims + mockContext.EXPECT().GetUUIDClaim("tenant_id").Return(tenantID, true) + mockContext.EXPECT().GetUUIDClaim("customer_id").Return(customerID, true) + + // Set up mock expectations for setting claims mockContext.EXPECT().SetClaim("api_key", apiKey) mockContext.EXPECT().SetClaim("tenant_id", tenantID) mockContext.EXPECT().SetClaim("customer_id", customerID) @@ -45,15 +52,33 @@ func TestAuthenticateAPIKey(t *testing.T) { expectedResult: nil, expectedError: eventingest.NewAuthError("Missing API Key"), }, + { + name: "Missing Tenant or Customer ID Claims", + setupMocks: func() { + apiKey := "valid_key" + + // Set up mock expectations for API key retrieval + mockContext.EXPECT().GetAPIKey().Return(apiKey, true) + + // Mock missing tenant or customer ID claims + mockContext.EXPECT().GetUUIDClaim("tenant_id").Return(uuid.Nil, false) + }, + expectedResult: nil, + expectedError: eventingest.NewAuthError("Missing tenant ID"), + }, { name: "Invalid API Key", setupMocks: func() { apiKey := "invalid_key" + + // Set up mock expectations for API key retrieval mockContext.EXPECT().GetAPIKey().Return(apiKey, true) - mockContext.EXPECT().FindAPIKey(apiKey).Return(uuid.Nil, uuid.Nil, eventingest.NewAuthError("Invalid API Key")) + + // Mock missing tenant and customer ID claims + mockContext.EXPECT().GetUUIDClaim("tenant_id").Return(uuid.Nil, false) }, expectedResult: nil, - expectedError: eventingest.NewAuthError("Invalid API Key"), + expectedError: eventingest.NewAuthError("Missing tenant ID"), }, } @@ -61,7 +86,7 @@ func TestAuthenticateAPIKey(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tt.setupMocks() - middleware := AuthenticateAPIKey(func(cc CustomContext) (interface{}, error) { + middleware := AuthenticateAPIKey(func(cc customctx.Context) (interface{}, error) { return "success", nil })