Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmd/bulwarkauth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ func NewAppConfig() (*AppConfig, error) {
config.RequestsPerSecond = getEnvAsInt("REQUESTS_PER_SECOND", 20)
config.AuthenticationAttempts = getEnvAsInt("AUTHENTICATION_ATTEMPTS", 5)
config.LockoutDurationInSecs = getEnvAsInt("LOCKOUT_DURATION_IN_SEC", 15*60)
config.DefaultTenantID = getEnv("DEFAULT_TENANT_ID", "default")

return config, nil
}
Expand Down
12 changes: 3 additions & 9 deletions cmd/bulwarkauth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func main() {
})
wd, _ := os.Getwd()
logger.Info("working directory: ", "dir", wd)
defaultTenantID := config.DefaultTenantID
defaultTenantID := "default"
err = createDefaultTenantID(context.Background(), tenantService, defaultTenantID)
if err != nil {
panic(err)
Expand Down Expand Up @@ -155,15 +155,9 @@ func createDefaultTenantID(ctx context.Context, tenantService tenants.TenantServ
}

if len(existingTenants) == 0 {
return tenantService.CreateDefault(ctx, defaultTenantID)
return tenantService.CreateDefault(ctx)
}

for _, tenant := range existingTenants {
if tenant.ID == defaultTenantID {
return nil
}
}
return tenantService.CreateDefault(ctx, defaultTenantID)
return nil
}

func corsSetting(service *echo.Echo, config *AppConfig, logger *slog.Logger) {
Expand Down
29 changes: 19 additions & 10 deletions internal/tenants/tenants.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tenants

import (
"context"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
Expand All @@ -12,22 +13,24 @@ const (
)

type Tenant struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Domain string `json:"domain"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Domain string `json:"domain"`
Created time.Time `json:"created"`
Modified time.Time `json:"modified"`
}

type TenantRepository interface {
ReadAll(ctx context.Context) ([]Tenant, error)
Read(ctx context.Context, tenantID string) (*Tenant, error)
Create(ctx context.Context, tenantID string) error
Create(ctx context.Context) error
}

type TenantService interface {
ListTenants(ctx context.Context) ([]Tenant, error)
GetTenant(ctx context.Context, tenantID string) (*Tenant, error)
CreateDefault(ctx context.Context, tenantID string) error
CreateDefault(ctx context.Context) error
}

type MongoDbTenantRepository struct {
Expand Down Expand Up @@ -74,12 +77,18 @@ func (t *MongoDbTenantRepository) Read(ctx context.Context, tenantID string) (*T
return &tenant, nil
}

func (t *MongoDbTenantRepository) Create(ctx context.Context, tenantID string) error {
func (t *MongoDbTenantRepository) Create(ctx context.Context) error {
collection := t.db.Collection(tenantCollection)

tenant := Tenant{
ID: tenantID,
ID: "default",
Name: "default",
Description: "default",
Created: time.Now(),
Modified: time.Now(),
}
_, err := collection.InsertOne(ctx, tenant)

return err
}

Expand All @@ -101,6 +110,6 @@ func (s *DefaultTenantService) GetTenant(ctx context.Context, tenantID string) (
return s.repo.Read(ctx, tenantID)
}

func (s *DefaultTenantService) CreateDefault(ctx context.Context, tenantID string) error {
return s.repo.Create(ctx, tenantID)
func (s *DefaultTenantService) CreateDefault(ctx context.Context) error {
return s.repo.Create(ctx)
}
60 changes: 8 additions & 52 deletions internal/tenants/tenants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func TestMongoDbTenantRepository_Create(t *testing.T) {
db := client.Database("bulwark-test")
tenantRepo := NewMongoDbTenantRepository(db)

tenantID := "tenant1"
err = tenantRepo.Create(context.TODO(), tenantID)
tenantID := "default"
err = tenantRepo.Create(context.TODO())
assert.NoError(t, err)

tenant, err := tenantRepo.Read(context.TODO(), tenantID)
Expand Down Expand Up @@ -66,8 +66,8 @@ func TestMongoDbTenantRepository_Read(t *testing.T) {
db := client.Database("bulwark-test")
tenantRepo := NewMongoDbTenantRepository(db)

tenantID := "tenant1"
err = tenantRepo.Create(context.TODO(), tenantID)
tenantID := "default"
err = tenantRepo.Create(context.TODO())
assert.NoError(t, err)

tenant, err := tenantRepo.Read(context.TODO(), tenantID)
Expand All @@ -76,43 +76,6 @@ func TestMongoDbTenantRepository_Read(t *testing.T) {
assert.Equal(t, tenantID, tenant.ID)
}

func TestMongoDbTenantRepository_ReadAll(t *testing.T) {
mongodb := utils.NewMongoTestUtil()
mongoServer, err := mongodb.CreateServer()
if err != nil {
t.Fatal(err)
}
defer mongoServer.Stop()

clientOptions := options.Client().ApplyURI(mongoServer.URI())
client, err := mongo.Connect(context.TODO(), clientOptions)
if err != nil {
t.Fatal(err)
}
defer func() {
err := client.Disconnect(context.TODO())
if err != nil {
t.Fatal(err)
}
}()

db := client.Database("bulwark-test")
tenantRepo := NewMongoDbTenantRepository(db)

tenant1ID := "tenant1"
tenant2ID := "tenant2"

err = tenantRepo.Create(context.TODO(), tenant1ID)
assert.NoError(t, err)

err = tenantRepo.Create(context.TODO(), tenant2ID)
assert.NoError(t, err)

tenants, err := tenantRepo.ReadAll(context.TODO())
assert.NoError(t, err)
assert.Len(t, tenants, 2)
}

func TestDefaultTenantService_CreateDefault(t *testing.T) {
mongodb := utils.NewMongoTestUtil()
mongoServer, err := mongodb.CreateServer()
Expand All @@ -136,9 +99,8 @@ func TestDefaultTenantService_CreateDefault(t *testing.T) {
db := client.Database("bulwark-test")
tenantRepo := NewMongoDbTenantRepository(db)
tenantService := NewDefaultTenantService(tenantRepo)

tenantID := "tenant1"
err = tenantService.CreateDefault(context.TODO(), tenantID)
tenantID := "default"
err = tenantService.CreateDefault(context.TODO())
assert.NoError(t, err)

tenant, err := tenantService.GetTenant(context.TODO(), tenantID)
Expand Down Expand Up @@ -171,16 +133,10 @@ func TestDefaultTenantService_ListTenants(t *testing.T) {
tenantRepo := NewMongoDbTenantRepository(db)
tenantService := NewDefaultTenantService(tenantRepo)

tenant1ID := "tenant1"
tenant2ID := "tenant2"

err = tenantService.CreateDefault(context.TODO(), tenant1ID)
assert.NoError(t, err)

err = tenantService.CreateDefault(context.TODO(), tenant2ID)
err = tenantService.CreateDefault(context.TODO())
assert.NoError(t, err)

tenants, err := tenantService.ListTenants(context.TODO())
assert.NoError(t, err)
assert.Len(t, tenants, 2)
assert.Len(t, tenants, 1)
}
Loading