diff --git a/api/authentication/apikey_handlers.go b/api/authentication/apikey_handlers.go new file mode 100644 index 0000000..a0a6886 --- /dev/null +++ b/api/authentication/apikey_handlers.go @@ -0,0 +1,153 @@ +package authentication + +import ( + "errors" + "net/http" + "time" + + "github.com/labstack/echo/v4" + "github.com/latebit-io/bulwarkauth/api/problem" + "github.com/latebit-io/bulwarkauth/internal/authentication" +) + +type CreateApiKeyRequest struct { + TenantID string `json:"tenantId"` + AccessToken string `json:"accessToken"` + Name string `json:"name"` + Expires *time.Time `json:"expires"` +} + +func (r CreateApiKeyRequest) Validate() error { + if r.TenantID == "" { + return errors.New("tenantId required") + } + if r.AccessToken == "" { + return errors.New("accessToken required") + } + if r.Name == "" { + return errors.New("name required") + } + if r.Expires != nil && r.Expires.Before(time.Now().UTC()) { + return errors.New("expires must be in the future") + } + return nil +} + +type ListApiKeyRequest struct { + TenantID string `json:"tenantId"` + AccessToken string `json:"accessToken"` +} + +func (r ListApiKeyRequest) Validate() error { + if r.TenantID == "" { + return errors.New("tenantId required") + } + if r.AccessToken == "" { + return errors.New("accessToken required") + } + return nil +} + +type DeleteApiKeyRequest struct { + TenantID string `json:"tenantId"` + AccessToken string `json:"accessToken"` +} + +func (r DeleteApiKeyRequest) Validate() error { + if r.TenantID == "" { + return errors.New("tenantId required") + } + if r.AccessToken == "" { + return errors.New("accessToken required") + } + return nil +} + +type ApiKeyHandlers struct { + apiKeyService authentication.ApiKeyService +} + +func NewApiKeyHandlers(service authentication.ApiKeyService) *ApiKeyHandlers { + return &ApiKeyHandlers{apiKeyService: service} +} + +func (h *ApiKeyHandlers) Create(c echo.Context) error { + req := new(CreateApiKeyRequest) + err := c.Bind(req) + if err != nil { + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + if err := req.Validate(); err != nil { + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + created, err := h.apiKeyService.Create(c.Request().Context(), req.TenantID, req.AccessToken, req.Name, req.Expires) + if err != nil { + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + return c.JSON(http.StatusCreated, created) +} + +func (h *ApiKeyHandlers) List(c echo.Context) error { + req := new(ListApiKeyRequest) + err := c.Bind(req) + if err != nil { + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + if err := req.Validate(); err != nil { + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + keys, err := h.apiKeyService.List(c.Request().Context(), req.TenantID, req.AccessToken) + if err != nil { + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + return c.JSON(http.StatusOK, keys) +} + +func (h *ApiKeyHandlers) Delete(c echo.Context) error { + req := new(DeleteApiKeyRequest) + err := c.Bind(req) + if err != nil { + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + if err := req.Validate(); err != nil { + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + prefix := c.Param("prefix") + if prefix == "" { + httpError := problem.NewBadRequest(errors.New("prefix required")) + return echo.NewHTTPError(httpError.Status, httpError) + } + + err = h.apiKeyService.Delete(c.Request().Context(), req.TenantID, req.AccessToken, prefix) + if err != nil { + var notFound authentication.ApiKeyNotFoundError + if errors.As(err, ¬Found) { + return echo.NewHTTPError(http.StatusNotFound, problem.Details{ + Type: "https://latebit.io/bulwark/errors/", + Title: "API Key Not Found", + Status: http.StatusNotFound, + Detail: err.Error(), + }) + } + httpError := problem.NewBadRequest(err) + return echo.NewHTTPError(httpError.Status, httpError) + } + + return c.NoContent(http.StatusNoContent) +} diff --git a/api/authentication/apikey_routes.go b/api/authentication/apikey_routes.go new file mode 100644 index 0000000..5a4db99 --- /dev/null +++ b/api/authentication/apikey_routes.go @@ -0,0 +1,9 @@ +package authentication + +import "github.com/labstack/echo/v4" + +func ApiKeyRoutes(e *echo.Echo, handler *ApiKeyHandlers, middleware ...echo.MiddlewareFunc) { + e.POST("/api/apikeys", handler.Create, middleware...) + e.POST("/api/apikeys/list", handler.List, middleware...) + e.DELETE("/api/apikeys/:prefix", handler.Delete, middleware...) +} diff --git a/cmd/bulwarkauth/main.go b/cmd/bulwarkauth/main.go index 3fb375b..59227f8 100644 --- a/cmd/bulwarkauth/main.go +++ b/cmd/bulwarkauth/main.go @@ -130,6 +130,10 @@ func main() { socialService.AddValidator(google) socialHandlers := authenticationapi.NewSocialHandlers(socialService) authenticationapi.SocialRoutes(service, socialHandlers) + apiKeyRepo := authentication.NewMongoDbApiRepository(mongodb) + apiKeyService := authentication.NewDefaultApiKeyService(apiKeyRepo, encrypt, tokenizer, accountsRepo) + apiKeyHandlers := authenticationapi.NewApiKeyHandlers(apiKeyService) + authenticationapi.ApiKeyRoutes(service, apiKeyHandlers, ratelimiter) corsSetting(service, config, logger) apiKeySetting(service, config, logger) diff --git a/internal/accounts/accounts.go b/internal/accounts/accounts.go index 1ed23b0..4f978b6 100644 --- a/internal/accounts/accounts.go +++ b/internal/accounts/accounts.go @@ -44,6 +44,7 @@ type Verification struct { } type Account struct { + ID string `bson:"id" json:"-"` TenantID string `bson:"tenantId" json:"tenantId"` Email string `bson:"email"` IsVerified bool `bson:"isVerified"` diff --git a/internal/authentication/apikey.go b/internal/authentication/apikey.go new file mode 100644 index 0000000..52e9923 --- /dev/null +++ b/internal/authentication/apikey.go @@ -0,0 +1,281 @@ +package authentication + +import ( + "context" + "errors" + "fmt" + "log" + "strings" + "time" + + "github.com/google/uuid" + "github.com/latebit-io/bulwarkauth/internal/accounts" + "github.com/latebit-io/bulwarkauth/internal/encryption" + "github.com/latebit-io/bulwarkauth/internal/tokens" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +const ( + apiKeyPrefix = "bwa" + randomLength = 8 + apiKeyCollection = "apiKeys" +) + +type ApiKey struct { + ID string `json:"id" bson:"id"` + TenantID string `json:"tenantId" bson:"tenantId"` + AccountID string `json:"-" bson:"accountId"` + Name string `json:"name" bson:"name"` + KeyHash string `json:"-" bson:"key"` + KeyPrefix string `json:"keyPrefix" bson:"keyPrefix"` + IsEnabled bool `json:"isEnabled" bson:"isEnabled"` + Expires *time.Time `json:"expires" bson:"expires"` + Created time.Time `json:"created" bson:"created"` + Modified time.Time `json:"modified" bson:"modified"` +} + +type ApiRepository interface { + Create(ctx context.Context, apiKey *ApiKey) error + Read(ctx context.Context, tenantID, accountID, keyPrefix string) (*ApiKey, error) + List(ctx context.Context, tenantID, accountID string) ([]ApiKey, error) + Delete(ctx context.Context, tenantID, accountID, keyPrefix string) error +} + +type ApiKeyService interface { + Authenticate(ctx context.Context, tenantID, email, apiKey string) (*Authenticated, error) + Create(ctx context.Context, tenantID, accessToken, name string, expires *time.Time) (string, error) + List(ctx context.Context, tenantID, accessToken string) ([]ApiKey, error) + Delete(ctx context.Context, tenantID, accessToken, keyPrefix string) error +} + +// MongoDbApiRepository MongoDB implementation of ApiRepository. +type MongoDbApiRepository struct { + db *mongo.Database +} + +func NewMongoDbApiRepository(db *mongo.Database) *MongoDbApiRepository { + collection := db.Collection(apiKeyCollection) + _, err := collection.Indexes().CreateOne(context.Background(), mongo.IndexModel{ + Keys: bson.D{{Key: "tenantId", Value: 1}, {Key: "accountId", Value: 1}, {Key: "keyPrefix", Value: 1}}, + Options: options.Index().SetUnique(true), + }) + if err != nil { + log.Fatal(err) + } + return &MongoDbApiRepository{db: db} +} + +func (m *MongoDbApiRepository) Create(ctx context.Context, apiKey *ApiKey) error { + collection := m.db.Collection(apiKeyCollection) + _, err := collection.InsertOne(ctx, apiKey) + return err +} + +func (m *MongoDbApiRepository) Read(ctx context.Context, tenantID, accountID, keyPrefix string) (*ApiKey, error) { + collection := m.db.Collection(apiKeyCollection) + filter := bson.M{"tenantId": tenantID, "accountId": accountID, "keyPrefix": keyPrefix} + result := collection.FindOne(ctx, filter) + + if errors.Is(result.Err(), mongo.ErrNoDocuments) { + return nil, ApiKeyNotFoundError{Value: keyPrefix} + } + + var apiKey ApiKey + err := result.Decode(&apiKey) + if err != nil { + return nil, err + } + + return &apiKey, nil +} + +func (m *MongoDbApiRepository) List(ctx context.Context, tenantID, accountID string) ([]ApiKey, error) { + collection := m.db.Collection(apiKeyCollection) + filter := bson.M{"tenantId": tenantID, "accountId": accountID} + cursor, err := collection.Find(ctx, filter) + if err != nil { + return nil, err + } + defer cursor.Close(ctx) + + var apiKeys []ApiKey + for cursor.Next(ctx) { + var apiKey ApiKey + if err := cursor.Decode(&apiKey); err != nil { + return nil, err + } + apiKeys = append(apiKeys, apiKey) + } + return apiKeys, nil +} + +func (m *MongoDbApiRepository) Delete(ctx context.Context, tenantID, accountID, keyPrefix string) error { + collection := m.db.Collection(apiKeyCollection) + result, err := collection.DeleteOne(ctx, bson.M{"tenantId": tenantID, "accountId": accountID, "keyPrefix": keyPrefix}) + if err != nil { + return err + } + if result.DeletedCount == 0 { + return ApiKeyNotFoundError{Value: keyPrefix} + } + return nil +} + +// DefaultApiKeyService implementation of ApiKeyService. +type DefaultApiKeyService struct { + apiRepo ApiRepository + encryption Encryption + tokenizer tokens.Tokenizer + accountRepo accounts.AccountRepository +} + +func NewDefaultApiKeyService(repo ApiRepository, encryption Encryption, tokenizer tokens.Tokenizer, accountRepo accounts.AccountRepository) *DefaultApiKeyService { + return &DefaultApiKeyService{ + apiRepo: repo, + encryption: encryption, + tokenizer: tokenizer, + accountRepo: accountRepo, + } +} + +func (s *DefaultApiKeyService) Authenticate(ctx context.Context, tenantID, email, apiKey string) (*Authenticated, error) { + splitKey := strings.Split(apiKey, "_") + if len(splitKey) != 3 { + return nil, ApiKeyInvalidError{Value: email} + } + + account, err := s.accountRepo.Read(ctx, tenantID, email) + if err != nil { + return nil, err + } + + keyPrefix := fmt.Sprintf("%s_%s", splitKey[0], splitKey[1]) + apiKeyResult, err := s.apiRepo.Read(ctx, tenantID, account.ID, keyPrefix) + if err != nil { + return nil, err + } + + if !apiKeyResult.IsEnabled { + return nil, ApiKeyDisabledError{Value: email} + } + + if apiKeyResult.Expires != nil && apiKeyResult.Expires.Before(time.Now()) { + return nil, ApiKeyExpiredError{Value: email} + } + + verified, err := s.encryption.Verify(apiKeyResult.KeyHash, splitKey[2]) + if err != nil { + return nil, err + } + if !verified { + return nil, ApiKeyInvalidError{Value: email} + } + + accessToken, err := s.tokenizer.CreateAccessToken(ctx, tenantID, email, "apikey", account.Roles) + if err != nil { + return nil, err + } + refreshToken, err := s.tokenizer.CreateRefreshToken(ctx, tenantID, email, "apikey") + if err != nil { + return nil, err + } + + return &Authenticated{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }, nil +} + +func (s *DefaultApiKeyService) Create(ctx context.Context, tenantID, accessToken, name string, expires *time.Time) (string, error) { + claims, err := s.tokenizer.ValidateAccessToken(ctx, accessToken) + if err != nil { + return "", err + } + + if claims.TenantID != tenantID { + return "", errors.New("token invalid") + } + + accountID, err := s.resolveAccountID(ctx, tenantID, claims.Subject) + if err != nil { + return "", err + } + + return s.generate(ctx, tenantID, accountID, name, expires) +} + +func (s *DefaultApiKeyService) List(ctx context.Context, tenantID, accessToken string) ([]ApiKey, error) { + claims, err := s.tokenizer.ValidateAccessToken(ctx, accessToken) + if err != nil { + return nil, err + } + + if claims.TenantID != tenantID { + return nil, errors.New("token invalid") + } + + accountID, err := s.resolveAccountID(ctx, tenantID, claims.Subject) + if err != nil { + return nil, err + } + + return s.apiRepo.List(ctx, tenantID, accountID) +} + +func (s *DefaultApiKeyService) Delete(ctx context.Context, tenantID, accessToken, keyPrefix string) error { + claims, err := s.tokenizer.ValidateAccessToken(ctx, accessToken) + if err != nil { + return err + } + + if claims.TenantID != tenantID { + return errors.New("token invalid") + } + + accountID, err := s.resolveAccountID(ctx, tenantID, claims.Subject) + if err != nil { + return err + } + + return s.apiRepo.Delete(ctx, tenantID, accountID, keyPrefix) +} + +// resolveAccountID looks up the account by email and returns the internal account ID. +func (s *DefaultApiKeyService) resolveAccountID(ctx context.Context, tenantID, email string) (string, error) { + account, err := s.accountRepo.Read(ctx, tenantID, email) + if err != nil { + return "", err + } + return account.ID, nil +} + +func (s *DefaultApiKeyService) generate(ctx context.Context, tenantID, accountID, name string, expire *time.Time) (string, error) { + id := uuid.New().String() + randomStr := encryption.GenerateRandomString(randomLength) + keyPrefix := fmt.Sprintf("%s_%s", apiKeyPrefix, randomStr) + key := uuid.New().String() + hashKey, err := s.encryption.Encrypt(key) + if err != nil { + return "", err + } + apiKey := &ApiKey{ + ID: id, + TenantID: tenantID, + AccountID: accountID, + Name: name, + KeyHash: hashKey, + KeyPrefix: keyPrefix, + IsEnabled: true, + Expires: expire, + Created: time.Now(), + Modified: time.Now(), + } + + if err := s.apiRepo.Create(ctx, apiKey); err != nil { + return "", err + } + + return fmt.Sprintf("%s_%s", apiKey.KeyPrefix, key), nil +} diff --git a/internal/authentication/apikey_errors.go b/internal/authentication/apikey_errors.go new file mode 100644 index 0000000..46ed275 --- /dev/null +++ b/internal/authentication/apikey_errors.go @@ -0,0 +1,35 @@ +package authentication + +import "fmt" + +type ApiKeyNotFoundError struct { + Value string +} + +func (e ApiKeyNotFoundError) Error() string { + return fmt.Sprintf("api key not found: %s", e.Value) +} + +type ApiKeyInvalidError struct { + Value string +} + +func (e ApiKeyInvalidError) Error() string { + return fmt.Sprintf("api key invalid: %s", e.Value) +} + +type ApiKeyDisabledError struct { + Value string +} + +func (e ApiKeyDisabledError) Error() string { + return fmt.Sprintf("api key is disabled: %s", e.Value) +} + +type ApiKeyExpiredError struct { + Value string +} + +func (e ApiKeyExpiredError) Error() string { + return fmt.Sprintf("api key is expired: %s", e.Value) +} diff --git a/internal/authentication/apikey_service_test.go b/internal/authentication/apikey_service_test.go new file mode 100644 index 0000000..aaf1a73 --- /dev/null +++ b/internal/authentication/apikey_service_test.go @@ -0,0 +1,452 @@ +package authentication + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/latebit-io/bulwarkauth/internal/accounts" + "github.com/latebit-io/bulwarkauth/internal/encryption" + "github.com/latebit-io/bulwarkauth/internal/tokens" + "github.com/latebit-io/bulwarkauth/internal/utils" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +type apiKeyTestEnv struct { + db *mongo.Database + service *DefaultApiKeyService + tokenizer tokens.Tokenizer + accountRepo accounts.AccountRepository + cleanup func() +} + +func setupApiKeyServiceTest(t *testing.T) *apiKeyTestEnv { + t.Helper() + ctx := context.Background() + mongodb := utils.NewMongoTestUtil() + mongoServer, err := mongodb.CreateServer() + if err != nil { + t.Fatal(err) + } + + clientOptions := options.Client().ApplyURI(mongoServer.URI()) + client, err := mongo.Connect(ctx, clientOptions) + if err != nil { + mongoServer.Stop() + t.Fatal(err) + } + + db := client.Database("testdb") + encrypt := encryption.NewDefaultEncryption(12) + + signingRepo := tokens.NewDefaultSigningKeyRepository(db) + signingService := tokens.NewDefaultSigningKeyService(signingRepo) + err = signingService.Initialize(ctx) + if err != nil { + client.Disconnect(ctx) + mongoServer.Stop() + t.Fatal(err) + } + + tokenizer := tokens.NewDefaultTokenizer("test", "test", "test", 3600, 3600, signingService) + accountRepo := accounts.NewMongodbAccountRepository(db, encrypt) + apiKeyRepo := NewMongoDbApiRepository(db) + service := NewDefaultApiKeyService(apiKeyRepo, encrypt, tokenizer, accountRepo) + + return &apiKeyTestEnv{ + db: db, + service: service, + tokenizer: tokenizer, + accountRepo: accountRepo, + cleanup: func() { + client.Disconnect(ctx) + mongoServer.Stop() + }, + } +} + +func createTestAccount(t *testing.T, env *apiKeyTestEnv, tenantID, email, password string) { + t.Helper() + ctx := context.Background() + err := env.accountRepo.Create(ctx, tenantID, email, password) + if err != nil { + t.Fatalf("Failed to create test account: %v", err) + } + err = env.accountRepo.Verify(ctx, tenantID, email) + if err != nil { + t.Fatalf("Failed to verify test account: %v", err) + } +} + +func createTestAccessToken(t *testing.T, env *apiKeyTestEnv, tenantID, email string) string { + t.Helper() + token, err := env.tokenizer.CreateAccessToken(context.Background(), tenantID, email, "test-client", []string{}) + if err != nil { + t.Fatalf("Failed to create access token: %v", err) + } + return token +} + +// extractKeyPrefix returns the "bwa_" portion from a full API key "bwa__". +func extractKeyPrefix(apiKey string) string { + parts := strings.Split(apiKey, "_") + if len(parts) != 3 { + return "" + } + return fmt.Sprintf("%s_%s", parts[0], parts[1]) +} + +func TestApiKeyService_Create(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + + rawKey, err := env.service.Create(context.Background(), tenantID, accessToken, "my-api-key", nil) + if err != nil { + t.Fatalf("Failed to create api key: %v", err) + } + + parts := strings.Split(rawKey, "_") + if len(parts) != 3 { + t.Fatalf("Expected key format bwa_prefix_secret, got '%s'", rawKey) + } + + if parts[0] != apiKeyPrefix { + t.Errorf("Expected key to start with '%s', got '%s'", apiKeyPrefix, parts[0]) + } + + if len(parts[1]) != randomLength { + t.Errorf("Expected prefix random length %d, got %d", randomLength, len(parts[1])) + } + + if !strings.HasPrefix(rawKey, apiKeyPrefix+"_") { + t.Errorf("Expected key to start with '%s_', got '%s'", apiKeyPrefix, rawKey) + } +} + +func TestApiKeyService_Create_WithExpiry(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + expires := time.Now().Add(24 * time.Hour) + + rawKey, err := env.service.Create(context.Background(), tenantID, accessToken, "expiring-key", &expires) + if err != nil { + t.Fatalf("Failed to create api key: %v", err) + } + + // Verify the key was created by listing and checking the expiry + keys, err := env.service.List(context.Background(), tenantID, accessToken) + if err != nil { + t.Fatalf("Failed to list keys: %v", err) + } + + keyPrefix := extractKeyPrefix(rawKey) + var found bool + for _, k := range keys { + if k.KeyPrefix == keyPrefix { + found = true + if k.Expires == nil { + t.Fatal("Expected expires to be set") + } + } + } + if !found { + t.Fatalf("Created key with prefix '%s' not found in list", keyPrefix) + } +} + +func TestApiKeyService_Create_InvalidToken(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + _, err := env.service.Create(context.Background(), "tenant1", "invalid-token", "my-key", nil) + if err == nil { + t.Fatal("Expected error with invalid token") + } +} + +func TestApiKeyService_Create_TenantMismatch(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + + _, err := env.service.Create(context.Background(), "tenant2", accessToken, "my-key", nil) + if err == nil { + t.Fatal("Expected error with tenant mismatch") + } +} + +func TestApiKeyService_List(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + ctx := context.Background() + + _, err := env.service.Create(ctx, tenantID, accessToken, "key-1", nil) + if err != nil { + t.Fatalf("Failed to create first key: %v", err) + } + + _, err = env.service.Create(ctx, tenantID, accessToken, "key-2", nil) + if err != nil { + t.Fatalf("Failed to create second key: %v", err) + } + + keys, err := env.service.List(ctx, tenantID, accessToken) + if err != nil { + t.Fatalf("Failed to list keys: %v", err) + } + + if len(keys) != 2 { + t.Fatalf("Expected 2 keys, got %d", len(keys)) + } +} + +func TestApiKeyService_List_IsolatedByAccount(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + ctx := context.Background() + + createTestAccount(t, env, tenantID, "user1@example.com", "Password123!") + createTestAccount(t, env, tenantID, "user2@example.com", "Password123!") + + token1 := createTestAccessToken(t, env, tenantID, "user1@example.com") + token2 := createTestAccessToken(t, env, tenantID, "user2@example.com") + + _, err := env.service.Create(ctx, tenantID, token1, "user1-key", nil) + if err != nil { + t.Fatalf("Failed to create key for user1: %v", err) + } + + _, err = env.service.Create(ctx, tenantID, token2, "user2-key", nil) + if err != nil { + t.Fatalf("Failed to create key for user2: %v", err) + } + + keys1, err := env.service.List(ctx, tenantID, token1) + if err != nil { + t.Fatalf("Failed to list keys for user1: %v", err) + } + + if len(keys1) != 1 { + t.Fatalf("Expected 1 key for user1, got %d", len(keys1)) + } + + if keys1[0].Name != "user1-key" { + t.Errorf("Expected key name 'user1-key', got '%s'", keys1[0].Name) + } +} + +func TestApiKeyService_Delete(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + ctx := context.Background() + + rawKey, err := env.service.Create(ctx, tenantID, accessToken, "to-delete", nil) + if err != nil { + t.Fatalf("Failed to create key: %v", err) + } + + keyPrefix := extractKeyPrefix(rawKey) + err = env.service.Delete(ctx, tenantID, accessToken, keyPrefix) + if err != nil { + t.Fatalf("Failed to delete key: %v", err) + } + + keys, err := env.service.List(ctx, tenantID, accessToken) + if err != nil { + t.Fatalf("Failed to list keys: %v", err) + } + + if len(keys) != 0 { + t.Errorf("Expected 0 keys after delete, got %d", len(keys)) + } +} + +func TestApiKeyService_Delete_NotFound(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + + err := env.service.Delete(context.Background(), tenantID, accessToken, "bwa_nonexist") + if err == nil { + t.Fatal("Expected error deleting non-existent key") + } + + _, ok := err.(ApiKeyNotFoundError) + if !ok { + t.Fatalf("Expected ApiKeyNotFoundError, got %T: %v", err, err) + } +} + +func TestApiKeyService_Authenticate(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + ctx := context.Background() + + rawKey, err := env.service.Create(ctx, tenantID, accessToken, "auth-key", nil) + if err != nil { + t.Fatalf("Failed to create key: %v", err) + } + + authenticated, err := env.service.Authenticate(ctx, tenantID, email, rawKey) + if err != nil { + t.Fatalf("Failed to authenticate with api key: %v", err) + } + + if authenticated.AccessToken == "" { + t.Error("Expected access token to be set") + } + + if authenticated.RefreshToken == "" { + t.Error("Expected refresh token to be set") + } +} + +func TestApiKeyService_Authenticate_InvalidKey(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + ctx := context.Background() + + rawKey, err := env.service.Create(ctx, tenantID, accessToken, "auth-key", nil) + if err != nil { + t.Fatalf("Failed to create key: %v", err) + } + + keyPrefix := extractKeyPrefix(rawKey) + wrongKey := keyPrefix + "_wrongsecret" + _, err = env.service.Authenticate(ctx, tenantID, email, wrongKey) + if err == nil { + t.Fatal("Expected error with invalid key") + } +} + +func TestApiKeyService_Authenticate_BadFormat(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + _, err := env.service.Authenticate(context.Background(), "tenant1", "test@example.com", "not-a-valid-key") + if err == nil { + t.Fatal("Expected error with bad key format") + } + + _, ok := err.(ApiKeyInvalidError) + if !ok { + t.Fatalf("Expected ApiKeyInvalidError, got %T: %v", err, err) + } +} + +func TestApiKeyService_Authenticate_DisabledKey(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + ctx := context.Background() + + rawKey, err := env.service.Create(ctx, tenantID, accessToken, "disabled-key", nil) + if err != nil { + t.Fatalf("Failed to create key: %v", err) + } + + // Disable the key directly in the database + keyPrefix := extractKeyPrefix(rawKey) + collection := env.db.Collection(apiKeyCollection) + _, err = collection.UpdateOne(ctx, + map[string]string{"keyPrefix": keyPrefix}, + map[string]interface{}{"$set": map[string]bool{"isEnabled": false}}, + ) + if err != nil { + t.Fatalf("Failed to disable key: %v", err) + } + + _, err = env.service.Authenticate(ctx, tenantID, email, rawKey) + if err == nil { + t.Fatal("Expected error with disabled key") + } + + _, ok := err.(ApiKeyDisabledError) + if !ok { + t.Fatalf("Expected ApiKeyDisabledError, got %T: %v", err, err) + } +} + +func TestApiKeyService_Authenticate_ExpiredKey(t *testing.T) { + env := setupApiKeyServiceTest(t) + defer env.cleanup() + + tenantID := "tenant1" + email := "test@example.com" + createTestAccount(t, env, tenantID, email, "Password123!") + + accessToken := createTestAccessToken(t, env, tenantID, email) + ctx := context.Background() + + expired := time.Now().Add(-1 * time.Hour) + rawKey, err := env.service.Create(ctx, tenantID, accessToken, "expired-key", &expired) + if err != nil { + t.Fatalf("Failed to create key: %v", err) + } + + _, err = env.service.Authenticate(ctx, tenantID, email, rawKey) + if err == nil { + t.Fatal("Expected error with expired key") + } + + _, ok := err.(ApiKeyExpiredError) + if !ok { + t.Fatalf("Expected ApiKeyExpiredError, got %T: %v", err, err) + } +} diff --git a/internal/authentication/apikey_test.go b/internal/authentication/apikey_test.go new file mode 100644 index 0000000..9d7e2b6 --- /dev/null +++ b/internal/authentication/apikey_test.go @@ -0,0 +1,259 @@ +package authentication + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/latebit-io/bulwarkauth/internal/utils" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +func setupApiKeyTestDB(t *testing.T) (*mongo.Database, func()) { + t.Helper() + ctx := context.Background() + mongodb := utils.NewMongoTestUtil() + mongoServer, err := mongodb.CreateServer() + if err != nil { + t.Fatal(err) + } + + clientOptions := options.Client().ApplyURI(mongoServer.URI()) + client, err := mongo.Connect(ctx, clientOptions) + if err != nil { + mongoServer.Stop() + t.Fatal(err) + } + + db := client.Database("testdb") + return db, func() { + client.Disconnect(ctx) + mongoServer.Stop() + } +} + +func newTestApiKey(tenantID, accountID, name, keyPrefix, keyHash string) *ApiKey { + now := time.Now() + return &ApiKey{ + ID: uuid.New().String(), + TenantID: tenantID, + AccountID: accountID, + Name: name, + KeyHash: keyHash, + KeyPrefix: keyPrefix, + IsEnabled: true, + Created: now, + Modified: now, + } +} + +func TestMongoDbApiRepository_Create(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + apiKey := newTestApiKey("tenant1", "account1", "my-key", "bwa_abcd1234", "hashedSecret") + + err := repo.Create(ctx, apiKey) + if err != nil { + t.Fatalf("Failed to create api key: %v", err) + } + + result, err := repo.Read(ctx, "tenant1", "account1", "bwa_abcd1234") + if err != nil { + t.Fatalf("Failed to read api key: %v", err) + } + + if result.Name != "my-key" { + t.Errorf("Expected name 'my-key', got '%s'", result.Name) + } + if result.AccountID != "account1" { + t.Errorf("Expected accountID 'account1', got '%s'", result.AccountID) + } + if result.KeyPrefix != "bwa_abcd1234" { + t.Errorf("Expected keyPrefix 'bwa_abcd1234', got '%s'", result.KeyPrefix) + } + if !result.IsEnabled { + t.Error("Expected api key to be enabled") + } +} + +func TestMongoDbApiRepository_Create_DuplicateKeyPrefix(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + apiKey1 := newTestApiKey("tenant1", "account1", "key-1", "bwa_same1234", "hash1") + apiKey2 := newTestApiKey("tenant1", "account1", "key-2", "bwa_same1234", "hash2") + + err := repo.Create(ctx, apiKey1) + if err != nil { + t.Fatalf("Failed to create first api key: %v", err) + } + + err = repo.Create(ctx, apiKey2) + if err == nil { + t.Fatal("Expected error creating duplicate key prefix for same account") + } +} + +func TestMongoDbApiRepository_Create_SamePrefixDifferentAccount(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + apiKey1 := newTestApiKey("tenant1", "account1", "key-1", "bwa_same1234", "hash1") + apiKey2 := newTestApiKey("tenant1", "account2", "key-1", "bwa_same1234", "hash2") + + err := repo.Create(ctx, apiKey1) + if err != nil { + t.Fatalf("Failed to create first api key: %v", err) + } + + err = repo.Create(ctx, apiKey2) + if err != nil { + t.Fatalf("Same prefix for different accounts should be allowed: %v", err) + } +} + +func TestMongoDbApiRepository_Read_NotFound(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + _, err := repo.Read(ctx, "tenant1", "account1", "bwa_nonexist") + if err == nil { + t.Fatal("Expected error for non-existent key") + } + + _, ok := err.(ApiKeyNotFoundError) + if !ok { + t.Fatalf("Expected ApiKeyNotFoundError, got %T: %v", err, err) + } +} + +func TestMongoDbApiRepository_List(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + apiKey1 := newTestApiKey("tenant1", "account1", "key-1", "bwa_prefix01", "hash1") + apiKey2 := newTestApiKey("tenant1", "account1", "key-2", "bwa_prefix02", "hash2") + apiKey3 := newTestApiKey("tenant1", "account2", "key-3", "bwa_prefix03", "hash3") + + for _, key := range []*ApiKey{apiKey1, apiKey2, apiKey3} { + if err := repo.Create(ctx, key); err != nil { + t.Fatalf("Failed to create api key: %v", err) + } + } + + keys, err := repo.List(ctx, "tenant1", "account1") + if err != nil { + t.Fatalf("Failed to list api keys: %v", err) + } + + if len(keys) != 2 { + t.Fatalf("Expected 2 keys for account1, got %d", len(keys)) + } +} + +func TestMongoDbApiRepository_List_Empty(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + keys, err := repo.List(ctx, "tenant1", "nonexistent") + if err != nil { + t.Fatalf("Failed to list api keys: %v", err) + } + + if keys != nil { + t.Errorf("Expected nil for empty list, got %v", keys) + } +} + +func TestMongoDbApiRepository_Delete(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + apiKey := newTestApiKey("tenant1", "account1", "my-key", "bwa_delete01", "hash1") + err := repo.Create(ctx, apiKey) + if err != nil { + t.Fatalf("Failed to create api key: %v", err) + } + + err = repo.Delete(ctx, "tenant1", "account1", "bwa_delete01") + if err != nil { + t.Fatalf("Failed to delete api key: %v", err) + } + + _, err = repo.Read(ctx, "tenant1", "account1", "bwa_delete01") + if err == nil { + t.Fatal("Expected error reading deleted key") + } +} + +func TestMongoDbApiRepository_Delete_NotFound(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + err := repo.Delete(ctx, "tenant1", "account1", "bwa_nonexist") + if err == nil { + t.Fatal("Expected error deleting non-existent key") + } + + _, ok := err.(ApiKeyNotFoundError) + if !ok { + t.Fatalf("Expected ApiKeyNotFoundError, got %T: %v", err, err) + } +} + +func TestMongoDbApiRepository_List_IsolatedByTenant(t *testing.T) { + db, cleanup := setupApiKeyTestDB(t) + defer cleanup() + + repo := NewMongoDbApiRepository(db) + ctx := context.Background() + + apiKey1 := newTestApiKey("tenant1", "account1", "key-1", "bwa_prefix01", "hash1") + apiKey2 := newTestApiKey("tenant2", "account1", "key-2", "bwa_prefix01", "hash2") + + for _, key := range []*ApiKey{apiKey1, apiKey2} { + if err := repo.Create(ctx, key); err != nil { + t.Fatalf("Failed to create api key: %v", err) + } + } + + keys, err := repo.List(ctx, "tenant1", "account1") + if err != nil { + t.Fatalf("Failed to list api keys: %v", err) + } + + if len(keys) != 1 { + t.Fatalf("Expected 1 key for tenant1/account1, got %d", len(keys)) + } + + if keys[0].TenantID != "tenant1" { + t.Errorf("Expected tenantId 'tenant1', got '%s'", keys[0].TenantID) + } +} diff --git a/internal/encryption/encryption.go b/internal/encryption/encryption.go index 5ab54d6..933a7dc 100644 --- a/internal/encryption/encryption.go +++ b/internal/encryption/encryption.go @@ -1,7 +1,9 @@ package encryption import ( + "crypto/rand" "errors" + "math/big" "golang.org/x/crypto/bcrypt" ) @@ -46,3 +48,19 @@ func (d DefaultEncryption) Verify(password, verifyPassword string) (bool, error) } return true, nil } + +const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +// GenerateRandomString generates a cryptographically random string of the specified length. +func GenerateRandomString(length int) string { + b := make([]byte, length) + max := big.NewInt(int64(len(charset))) + for i := range b { + n, err := rand.Int(rand.Reader, max) + if err != nil { + panic(err) + } + b[i] = charset[n.Int64()] + } + return string(b) +}