diff --git a/cmd/api/main.go b/cmd/api/main.go index 9a9e29e..23e4825 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -13,6 +13,7 @@ import ( "github.com/gin-gonic/gin" "github.com/voidrunnerhq/voidrunner/internal/api/routes" "github.com/voidrunnerhq/voidrunner/internal/config" + "github.com/voidrunnerhq/voidrunner/internal/database" "github.com/voidrunnerhq/voidrunner/pkg/logger" ) @@ -25,12 +26,46 @@ func main() { log := logger.New(cfg.Logger.Level, cfg.Logger.Format) + // Initialize database connection + dbConn, err := database.NewConnection(&cfg.Database, log.Logger) + if err != nil { + log.Error("failed to initialize database connection", "error", err) + os.Exit(1) + } + defer dbConn.Close() + + // Run database migrations + migrateConfig := &database.MigrateConfig{ + DatabaseConfig: &cfg.Database, + MigrationsPath: "file://migrations", + Logger: log.Logger, + } + + if err := database.MigrateUp(migrateConfig); err != nil { + log.Error("failed to run database migrations", "error", err) + os.Exit(1) + } + + // Initialize repositories + repos := database.NewRepositories(dbConn) + + // Perform database health check + healthCtx, healthCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer healthCancel() + + if err := dbConn.HealthCheck(healthCtx); err != nil { + log.Error("database health check failed", "error", err) + os.Exit(1) + } + + log.Info("database initialized successfully") + if cfg.IsProduction() { gin.SetMode(gin.ReleaseMode) } router := gin.New() - routes.Setup(router, cfg, log) + routes.Setup(router, cfg, log, repos) srv := &http.Server{ Addr: fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port), diff --git a/go.mod b/go.mod index f2819f2..fbeed57 100644 --- a/go.mod +++ b/go.mod @@ -21,9 +21,17 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.26.0 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/golang-migrate/migrate/v4 v4.18.3 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.7.5 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect @@ -31,9 +39,11 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.0 // indirect + go.uber.org/atomic v1.7.0 // indirect golang.org/x/arch v0.18.0 // indirect golang.org/x/crypto v0.39.0 // indirect golang.org/x/net v0.41.0 // indirect + golang.org/x/sync v0.15.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/text v0.26.0 // indirect google.golang.org/protobuf v1.36.6 // indirect diff --git a/go.sum b/go.sum index b427cbc..0468b9f 100644 --- a/go.sum +++ b/go.sum @@ -25,9 +25,24 @@ github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang-migrate/migrate/v4 v4.18.3 h1:EYGkoOsvgHHfm5U/naS1RP/6PL/Xv3S4B/swMiAmDLs= +github.com/golang-migrate/migrate/v4 v4.18.3/go.mod h1:99BKpIi6ruaaXRM1A77eqZ+FWPQ3cfRa+ZVy5bmWMaY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= +github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.6.0-pre.2 h1:SCkYm/XGeCcXItAv0Xofqsa4JPdDDkyNcG1Ush5cBLQ= @@ -40,6 +55,8 @@ github.com/klauspost/cpuid/v2 v2.2.10/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQe github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -65,12 +82,16 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA= github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/arch v0.18.0 h1:WN9poc33zL4AzGxqf8VtpKUnGvMi8O9lhNyBMF/85qc= golang.org/x/arch v0.18.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/internal/api/routes/routes.go b/internal/api/routes/routes.go index 6896a9b..f588f65 100644 --- a/internal/api/routes/routes.go +++ b/internal/api/routes/routes.go @@ -5,12 +5,13 @@ import ( "github.com/voidrunnerhq/voidrunner/internal/api/handlers" "github.com/voidrunnerhq/voidrunner/internal/api/middleware" "github.com/voidrunnerhq/voidrunner/internal/config" + "github.com/voidrunnerhq/voidrunner/internal/database" "github.com/voidrunnerhq/voidrunner/pkg/logger" ) -func Setup(router *gin.Engine, cfg *config.Config, log *logger.Logger) { +func Setup(router *gin.Engine, cfg *config.Config, log *logger.Logger, repos *database.Repositories) { setupMiddleware(router, cfg, log) - setupRoutes(router) + setupRoutes(router, repos) } func setupMiddleware(router *gin.Engine, cfg *config.Config, log *logger.Logger) { @@ -22,7 +23,7 @@ func setupMiddleware(router *gin.Engine, cfg *config.Config, log *logger.Logger) router.Use(middleware.ErrorHandler()) } -func setupRoutes(router *gin.Engine) { +func setupRoutes(router *gin.Engine, repos *database.Repositories) { healthHandler := handlers.NewHealthHandler() router.GET("/health", healthHandler.Health) @@ -35,5 +36,17 @@ func setupRoutes(router *gin.Engine) { "message": "pong", }) }) + + // Future API routes will use repos here + // userHandler := handlers.NewUserHandler(repos.Users) + // taskHandler := handlers.NewTaskHandler(repos.Tasks) + // executionHandler := handlers.NewTaskExecutionHandler(repos.TaskExecutions) + + // v1.POST("/users", userHandler.Create) + // v1.GET("/users/:id", userHandler.GetByID) + // v1.POST("/tasks", taskHandler.Create) + // v1.GET("/tasks/:id", taskHandler.GetByID) + // v1.POST("/tasks/:id/executions", executionHandler.Create) + // v1.GET("/executions/:id", executionHandler.GetByID) } } \ No newline at end of file diff --git a/internal/database/connection.go b/internal/database/connection.go new file mode 100644 index 0000000..f92d8a4 --- /dev/null +++ b/internal/database/connection.go @@ -0,0 +1,150 @@ +package database + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/voidrunnerhq/voidrunner/internal/config" +) + +// Connection represents a database connection pool +type Connection struct { + Pool *pgxpool.Pool + logger *slog.Logger +} + +// NewConnection creates a new database connection pool +func NewConnection(cfg *config.DatabaseConfig, logger *slog.Logger) (*Connection, error) { + if cfg == nil { + return nil, fmt.Errorf("database configuration is required") + } + + if logger == nil { + logger = slog.Default() + } + + connStr := fmt.Sprintf( + "postgres://%s:%s@%s:%s/%s?sslmode=%s", + cfg.User, + cfg.Password, + cfg.Host, + cfg.Port, + cfg.Database, + cfg.SSLMode, + ) + + poolConfig, err := pgxpool.ParseConfig(connStr) + if err != nil { + return nil, fmt.Errorf("failed to parse database connection string: %w", err) + } + + // Configure connection pool settings for optimal performance + poolConfig.MaxConns = 25 // Maximum number of connections + poolConfig.MinConns = 5 // Minimum number of connections + poolConfig.MaxConnLifetime = time.Hour * 1 // Maximum connection lifetime + poolConfig.MaxConnIdleTime = time.Minute * 30 // Maximum connection idle time + poolConfig.HealthCheckPeriod = time.Minute * 5 // Health check frequency + + // Connection timeout settings + poolConfig.ConnConfig.ConnectTimeout = time.Second * 10 + poolConfig.ConnConfig.RuntimeParams["statement_timeout"] = "30s" + poolConfig.ConnConfig.RuntimeParams["idle_in_transaction_session_timeout"] = "60s" + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + pool, err := pgxpool.NewWithConfig(ctx, poolConfig) + if err != nil { + return nil, fmt.Errorf("failed to create database pool: %w", err) + } + + // Test the connection + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + logger.Info("database connection pool created successfully", + "host", cfg.Host, + "port", cfg.Port, + "database", cfg.Database, + "max_conns", poolConfig.MaxConns, + "min_conns", poolConfig.MinConns, + ) + + return &Connection{ + Pool: pool, + logger: logger, + }, nil +} + +// Close closes the database connection pool +func (c *Connection) Close() { + if c.Pool != nil { + c.logger.Info("closing database connection pool") + c.Pool.Close() + } +} + +// Ping checks if the database connection is alive +func (c *Connection) Ping(ctx context.Context) error { + return c.Pool.Ping(ctx) +} + +// Stats returns connection pool statistics +func (c *Connection) Stats() *pgxpool.Stat { + return c.Pool.Stat() +} + +// LogStats logs connection pool statistics +func (c *Connection) LogStats() { + stats := c.Stats() + c.logger.Info("database connection pool stats", + "total_conns", stats.TotalConns(), + "idle_conns", stats.IdleConns(), + "acquired_conns", stats.AcquiredConns(), + "constructing_conns", stats.ConstructingConns(), + "acquire_count", stats.AcquireCount(), + "acquire_duration", stats.AcquireDuration(), + "acquired_conns_duration", stats.AcquiredConns(), + "canceled_acquire_count", stats.CanceledAcquireCount(), + "empty_acquire_count", stats.EmptyAcquireCount(), + "max_conns", stats.MaxConns(), + "new_conns_count", stats.NewConnsCount(), + ) +} + +// HealthCheck performs a comprehensive health check of the database connection +func (c *Connection) HealthCheck(ctx context.Context) error { + // Check if pool is available + if c.Pool == nil { + return fmt.Errorf("database pool is not initialized") + } + + // Ping the database + if err := c.Pool.Ping(ctx); err != nil { + return fmt.Errorf("database ping failed: %w", err) + } + + // Check pool statistics + stats := c.Stats() + if stats.TotalConns() == 0 { + return fmt.Errorf("no database connections available") + } + + // Execute a simple query to ensure the database is responsive + var result int + err := c.Pool.QueryRow(ctx, "SELECT 1").Scan(&result) + if err != nil { + return fmt.Errorf("database query test failed: %w", err) + } + + if result != 1 { + return fmt.Errorf("unexpected database query result: %d", result) + } + + return nil +} \ No newline at end of file diff --git a/internal/database/integration_test.go b/internal/database/integration_test.go new file mode 100644 index 0000000..c25cb26 --- /dev/null +++ b/internal/database/integration_test.go @@ -0,0 +1,441 @@ +package database + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/voidrunnerhq/voidrunner/internal/config" + "github.com/voidrunnerhq/voidrunner/internal/models" +) + +// TestMain sets up integration test environment +func TestMain(m *testing.M) { + // Skip integration tests if no database connection is available + if os.Getenv("INTEGRATION_TESTS") != "true" { + os.Exit(0) + } + + // Run integration tests + code := m.Run() + os.Exit(code) +} + +// setupTestDatabase creates a test database connection and runs migrations +func setupTestDatabase(t *testing.T) (*Connection, *Repositories) { + t.Helper() + + // Skip if not running integration tests + if os.Getenv("INTEGRATION_TESTS") != "true" { + t.Skip("Skipping integration test - set INTEGRATION_TESTS=true to run") + } + + // Load test database configuration + cfg := &config.DatabaseConfig{ + Host: getEnvOrDefault("TEST_DB_HOST", "localhost"), + Port: getEnvOrDefault("TEST_DB_PORT", "5432"), + User: getEnvOrDefault("TEST_DB_USER", "postgres"), + Password: getEnvOrDefault("TEST_DB_PASSWORD", ""), + Database: getEnvOrDefault("TEST_DB_NAME", "voidrunner_test"), + SSLMode: getEnvOrDefault("TEST_DB_SSL_MODE", "disable"), + } + + // Create database connection + conn, err := NewConnection(cfg, nil) + require.NoError(t, err, "Failed to create database connection") + + // Run migrations + migrateConfig := &MigrateConfig{ + DatabaseConfig: cfg, + MigrationsPath: "file://../../migrations", + Logger: nil, + } + + err = MigrateUp(migrateConfig) + require.NoError(t, err, "Failed to run database migrations") + + // Create repositories + repos := NewRepositories(conn) + + // Clean up function + t.Cleanup(func() { + cleanupTestData(t, repos) + conn.Close() + }) + + return conn, repos +} + +// cleanupTestData removes all test data from the database +func cleanupTestData(t *testing.T, repos *Repositories) { + t.Helper() + + // Clean up in reverse order due to foreign key constraints + // Note: In a real test environment, you might want to use transactions + // or a separate test database that gets reset between tests +} + +func TestUserRepository_Integration(t *testing.T) { + _, repos := setupTestDatabase(t) + ctx := context.Background() + + t.Run("user CRUD operations", func(t *testing.T) { + // Create a test user + user := &models.User{ + Email: "integration.test@example.com", + PasswordHash: "hashed_password_123", + } + + // Test Create + err := repos.Users.Create(ctx, user) + require.NoError(t, err) + assert.NotEqual(t, uuid.Nil, user.ID) + assert.False(t, user.CreatedAt.IsZero()) + assert.False(t, user.UpdatedAt.IsZero()) + + // Test GetByID + retrievedUser, err := repos.Users.GetByID(ctx, user.ID) + require.NoError(t, err) + assert.Equal(t, user.Email, retrievedUser.Email) + assert.Equal(t, user.PasswordHash, retrievedUser.PasswordHash) + + // Test GetByEmail + userByEmail, err := repos.Users.GetByEmail(ctx, user.Email) + require.NoError(t, err) + assert.Equal(t, user.ID, userByEmail.ID) + + // Test Update + user.Email = "updated.integration.test@example.com" + err = repos.Users.Update(ctx, user) + require.NoError(t, err) + + updatedUser, err := repos.Users.GetByID(ctx, user.ID) + require.NoError(t, err) + assert.Equal(t, "updated.integration.test@example.com", updatedUser.Email) + assert.True(t, updatedUser.UpdatedAt.After(updatedUser.CreatedAt)) + + // Test Count + count, err := repos.Users.Count(ctx) + require.NoError(t, err) + assert.Greater(t, count, int64(0)) + + // Test List + users, err := repos.Users.List(ctx, 10, 0) + require.NoError(t, err) + assert.NotEmpty(t, users) + + // Test Delete + err = repos.Users.Delete(ctx, user.ID) + require.NoError(t, err) + + // Verify deletion + _, err = repos.Users.GetByID(ctx, user.ID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +func TestTaskRepository_Integration(t *testing.T) { + _, repos := setupTestDatabase(t) + ctx := context.Background() + + t.Run("task CRUD operations", func(t *testing.T) { + // First create a user + user := &models.User{ + Email: "task.test@example.com", + PasswordHash: "hashed_password_123", + } + err := repos.Users.Create(ctx, user) + require.NoError(t, err) + + // Clean up user at the end + defer repos.Users.Delete(ctx, user.ID) + + // Create test metadata + metadata, _ := json.Marshal(map[string]interface{}{ + "environment": "test", + "priority": "high", + "tags": []string{"integration", "test"}, + }) + + // Create a test task + task := &models.Task{ + UserID: user.ID, + Name: "Integration Test Task", + Description: stringPtr("Test task for integration testing"), + ScriptContent: "print('Hello from integration test')", + ScriptType: models.ScriptTypePython, + Status: models.TaskStatusPending, + Priority: 1, + TimeoutSeconds: 30, + Metadata: metadata, + } + + // Test Create + err = repos.Tasks.Create(ctx, task) + require.NoError(t, err) + assert.NotEqual(t, uuid.Nil, task.ID) + assert.False(t, task.CreatedAt.IsZero()) + assert.False(t, task.UpdatedAt.IsZero()) + + // Test GetByID + retrievedTask, err := repos.Tasks.GetByID(ctx, task.ID) + require.NoError(t, err) + assert.Equal(t, task.Name, retrievedTask.Name) + assert.Equal(t, task.ScriptContent, retrievedTask.ScriptContent) + assert.Equal(t, task.ScriptType, retrievedTask.ScriptType) + assert.JSONEq(t, string(task.Metadata), string(retrievedTask.Metadata)) + + // Test GetByUserID + userTasks, err := repos.Tasks.GetByUserID(ctx, user.ID, 10, 0) + require.NoError(t, err) + assert.Len(t, userTasks, 1) + assert.Equal(t, task.ID, userTasks[0].ID) + + // Test GetByStatus + pendingTasks, err := repos.Tasks.GetByStatus(ctx, models.TaskStatusPending, 10, 0) + require.NoError(t, err) + assert.NotEmpty(t, pendingTasks) + + // Test UpdateStatus + err = repos.Tasks.UpdateStatus(ctx, task.ID, models.TaskStatusRunning) + require.NoError(t, err) + + updatedTask, err := repos.Tasks.GetByID(ctx, task.ID) + require.NoError(t, err) + assert.Equal(t, models.TaskStatusRunning, updatedTask.Status) + + // Test SearchByMetadata + metadataQuery := `{"environment": "test"}` + searchResults, err := repos.Tasks.SearchByMetadata(ctx, metadataQuery, 10, 0) + require.NoError(t, err) + assert.NotEmpty(t, searchResults) + + // Test Count operations + totalCount, err := repos.Tasks.Count(ctx) + require.NoError(t, err) + assert.Greater(t, totalCount, int64(0)) + + userTaskCount, err := repos.Tasks.CountByUserID(ctx, user.ID) + require.NoError(t, err) + assert.Equal(t, int64(1), userTaskCount) + + runningTaskCount, err := repos.Tasks.CountByStatus(ctx, models.TaskStatusRunning) + require.NoError(t, err) + assert.Greater(t, runningTaskCount, int64(0)) + + // Test Delete + err = repos.Tasks.Delete(ctx, task.ID) + require.NoError(t, err) + + // Verify deletion + _, err = repos.Tasks.GetByID(ctx, task.ID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +func TestTaskExecutionRepository_Integration(t *testing.T) { + _, repos := setupTestDatabase(t) + ctx := context.Background() + + t.Run("task execution CRUD operations", func(t *testing.T) { + // First create a user and task + user := &models.User{ + Email: "execution.test@example.com", + PasswordHash: "hashed_password_123", + } + err := repos.Users.Create(ctx, user) + require.NoError(t, err) + + task := &models.Task{ + UserID: user.ID, + Name: "Execution Test Task", + ScriptContent: "print('test')", + ScriptType: models.ScriptTypePython, + Status: models.TaskStatusPending, + Priority: 1, + TimeoutSeconds: 30, + Metadata: json.RawMessage(`{}`), + } + err = repos.Tasks.Create(ctx, task) + require.NoError(t, err) + + // Clean up at the end + defer func() { + repos.Tasks.Delete(ctx, task.ID) + repos.Users.Delete(ctx, user.ID) + }() + + // Create a test execution + execution := &models.TaskExecution{ + TaskID: task.ID, + Status: models.ExecutionStatusPending, + } + + // Test Create + err = repos.TaskExecutions.Create(ctx, execution) + require.NoError(t, err) + assert.NotEqual(t, uuid.Nil, execution.ID) + assert.False(t, execution.CreatedAt.IsZero()) + + // Test GetByID + retrievedExecution, err := repos.TaskExecutions.GetByID(ctx, execution.ID) + require.NoError(t, err) + assert.Equal(t, execution.TaskID, retrievedExecution.TaskID) + assert.Equal(t, execution.Status, retrievedExecution.Status) + + // Test GetByTaskID + taskExecutions, err := repos.TaskExecutions.GetByTaskID(ctx, task.ID, 10, 0) + require.NoError(t, err) + assert.Len(t, taskExecutions, 1) + assert.Equal(t, execution.ID, taskExecutions[0].ID) + + // Test GetLatestByTaskID + latestExecution, err := repos.TaskExecutions.GetLatestByTaskID(ctx, task.ID) + require.NoError(t, err) + assert.Equal(t, execution.ID, latestExecution.ID) + + // Test Update with execution results + startTime := time.Now() + endTime := startTime.Add(2 * time.Second) + execution.Status = models.ExecutionStatusCompleted + execution.ReturnCode = intPtr(0) + execution.Stdout = stringPtr("Test output") + execution.Stderr = stringPtr("") + execution.ExecutionTimeMs = intPtr(2000) + execution.MemoryUsageBytes = int64Ptr(1024 * 1024) // 1MB + execution.StartedAt = &startTime + execution.CompletedAt = &endTime + + err = repos.TaskExecutions.Update(ctx, execution) + require.NoError(t, err) + + updatedExecution, err := repos.TaskExecutions.GetByID(ctx, execution.ID) + require.NoError(t, err) + assert.Equal(t, models.ExecutionStatusCompleted, updatedExecution.Status) + assert.Equal(t, 0, *updatedExecution.ReturnCode) + assert.Equal(t, "Test output", *updatedExecution.Stdout) + + // Test GetByStatus + completedExecutions, err := repos.TaskExecutions.GetByStatus(ctx, models.ExecutionStatusCompleted, 10, 0) + require.NoError(t, err) + assert.NotEmpty(t, completedExecutions) + + // Test Count operations + totalCount, err := repos.TaskExecutions.Count(ctx) + require.NoError(t, err) + assert.Greater(t, totalCount, int64(0)) + + taskExecutionCount, err := repos.TaskExecutions.CountByTaskID(ctx, task.ID) + require.NoError(t, err) + assert.Equal(t, int64(1), taskExecutionCount) + + completedCount, err := repos.TaskExecutions.CountByStatus(ctx, models.ExecutionStatusCompleted) + require.NoError(t, err) + assert.Greater(t, completedCount, int64(0)) + + // Test Delete + err = repos.TaskExecutions.Delete(ctx, execution.ID) + require.NoError(t, err) + + // Verify deletion + _, err = repos.TaskExecutions.GetByID(ctx, execution.ID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +// BenchmarkDatabaseOperations provides performance benchmarks +func BenchmarkDatabaseOperations(b *testing.B) { + if os.Getenv("INTEGRATION_TESTS") != "true" { + b.Skip("Skipping benchmark - set INTEGRATION_TESTS=true to run") + } + + _, repos := setupBenchmarkDatabase(b) + ctx := context.Background() + + // Create test user for benchmarks + user := &models.User{ + Email: "benchmark@example.com", + PasswordHash: "hashed_password", + } + repos.Users.Create(ctx, user) + defer repos.Users.Delete(ctx, user.ID) + + b.Run("UserRepository_Create", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + user := &models.User{ + Email: "bench.user." + string(rune(i)) + "@example.com", + PasswordHash: "hashed_password", + } + repos.Users.Create(ctx, user) + repos.Users.Delete(ctx, user.ID) // Clean up + } + }) + + b.Run("UserRepository_GetByID", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + repos.Users.GetByID(ctx, user.ID) + } + }) + + b.Run("TaskRepository_Create", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + task := &models.Task{ + UserID: user.ID, + Name: "Benchmark Task", + ScriptContent: "print('benchmark')", + ScriptType: models.ScriptTypePython, + Status: models.TaskStatusPending, + Priority: 1, + TimeoutSeconds: 30, + Metadata: json.RawMessage(`{}`), + } + repos.Tasks.Create(ctx, task) + repos.Tasks.Delete(ctx, task.ID) // Clean up + } + }) +} + +// Helper functions +func setupBenchmarkDatabase(b *testing.B) (*Connection, *Repositories) { + b.Helper() + + cfg := &config.DatabaseConfig{ + Host: getEnvOrDefault("TEST_DB_HOST", "localhost"), + Port: getEnvOrDefault("TEST_DB_PORT", "5432"), + User: getEnvOrDefault("TEST_DB_USER", "postgres"), + Password: getEnvOrDefault("TEST_DB_PASSWORD", ""), + Database: getEnvOrDefault("TEST_DB_NAME", "voidrunner_test"), + SSLMode: getEnvOrDefault("TEST_DB_SSL_MODE", "disable"), + } + + conn, err := NewConnection(cfg, nil) + if err != nil { + b.Fatalf("Failed to create database connection: %v", err) + } + + repos := NewRepositories(conn) + + b.Cleanup(func() { + conn.Close() + }) + + return conn, repos +} + +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} \ No newline at end of file diff --git a/internal/database/interfaces.go b/internal/database/interfaces.go new file mode 100644 index 0000000..0e5c35d --- /dev/null +++ b/internal/database/interfaces.go @@ -0,0 +1,67 @@ +package database + +import ( + "context" + + "github.com/google/uuid" + "github.com/voidrunnerhq/voidrunner/internal/models" +) + +// UserRepository defines the interface for user data operations +type UserRepository interface { + Create(ctx context.Context, user *models.User) error + GetByID(ctx context.Context, id uuid.UUID) (*models.User, error) + GetByEmail(ctx context.Context, email string) (*models.User, error) + Update(ctx context.Context, user *models.User) error + Delete(ctx context.Context, id uuid.UUID) error + List(ctx context.Context, limit, offset int) ([]*models.User, error) + Count(ctx context.Context) (int64, error) +} + +// TaskRepository defines the interface for task data operations +type TaskRepository interface { + Create(ctx context.Context, task *models.Task) error + GetByID(ctx context.Context, id uuid.UUID) (*models.Task, error) + GetByUserID(ctx context.Context, userID uuid.UUID, limit, offset int) ([]*models.Task, error) + GetByStatus(ctx context.Context, status models.TaskStatus, limit, offset int) ([]*models.Task, error) + Update(ctx context.Context, task *models.Task) error + UpdateStatus(ctx context.Context, id uuid.UUID, status models.TaskStatus) error + Delete(ctx context.Context, id uuid.UUID) error + List(ctx context.Context, limit, offset int) ([]*models.Task, error) + Count(ctx context.Context) (int64, error) + CountByUserID(ctx context.Context, userID uuid.UUID) (int64, error) + CountByStatus(ctx context.Context, status models.TaskStatus) (int64, error) + SearchByMetadata(ctx context.Context, query string, limit, offset int) ([]*models.Task, error) +} + +// TaskExecutionRepository defines the interface for task execution data operations +type TaskExecutionRepository interface { + Create(ctx context.Context, execution *models.TaskExecution) error + GetByID(ctx context.Context, id uuid.UUID) (*models.TaskExecution, error) + GetByTaskID(ctx context.Context, taskID uuid.UUID, limit, offset int) ([]*models.TaskExecution, error) + GetLatestByTaskID(ctx context.Context, taskID uuid.UUID) (*models.TaskExecution, error) + GetByStatus(ctx context.Context, status models.ExecutionStatus, limit, offset int) ([]*models.TaskExecution, error) + Update(ctx context.Context, execution *models.TaskExecution) error + UpdateStatus(ctx context.Context, id uuid.UUID, status models.ExecutionStatus) error + Delete(ctx context.Context, id uuid.UUID) error + List(ctx context.Context, limit, offset int) ([]*models.TaskExecution, error) + Count(ctx context.Context) (int64, error) + CountByTaskID(ctx context.Context, taskID uuid.UUID) (int64, error) + CountByStatus(ctx context.Context, status models.ExecutionStatus) (int64, error) +} + +// Repositories aggregates all repository interfaces +type Repositories struct { + Users UserRepository + Tasks TaskRepository + TaskExecutions TaskExecutionRepository +} + +// NewRepositories creates a new repositories instance +func NewRepositories(conn *Connection) *Repositories { + return &Repositories{ + Users: NewUserRepository(conn), + Tasks: NewTaskRepository(conn), + TaskExecutions: NewTaskExecutionRepository(conn), + } +} \ No newline at end of file diff --git a/internal/database/migrate.go b/internal/database/migrate.go new file mode 100644 index 0000000..8857d98 --- /dev/null +++ b/internal/database/migrate.go @@ -0,0 +1,210 @@ +package database + +import ( + "database/sql" + "errors" + "fmt" + "log/slog" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/postgres" + _ "github.com/golang-migrate/migrate/v4/source/file" + _ "github.com/jackc/pgx/v5/stdlib" // pgx driver for database/sql + "github.com/voidrunnerhq/voidrunner/internal/config" +) + +// MigrateConfig holds migration configuration +type MigrateConfig struct { + DatabaseConfig *config.DatabaseConfig + MigrationsPath string + Logger *slog.Logger +} + +// Migrator handles database migrations +type Migrator struct { + migrate *migrate.Migrate + logger *slog.Logger +} + +// NewMigrator creates a new database migrator +func NewMigrator(cfg *MigrateConfig) (*Migrator, error) { + if cfg == nil { + return nil, fmt.Errorf("migration configuration is required") + } + + if cfg.DatabaseConfig == nil { + return nil, fmt.Errorf("database configuration is required") + } + + if cfg.Logger == nil { + cfg.Logger = slog.Default() + } + + if cfg.MigrationsPath == "" { + cfg.MigrationsPath = "file://migrations" + } + + // Create database connection string for sql.DB + connStr := fmt.Sprintf( + "postgres://%s:%s@%s:%s/%s?sslmode=%s", + cfg.DatabaseConfig.User, + cfg.DatabaseConfig.Password, + cfg.DatabaseConfig.Host, + cfg.DatabaseConfig.Port, + cfg.DatabaseConfig.Database, + cfg.DatabaseConfig.SSLMode, + ) + + // Open database connection using database/sql + db, err := sql.Open("pgx", connStr) + if err != nil { + return nil, fmt.Errorf("failed to open database connection: %w", err) + } + + // Test the connection + if err := db.Ping(); err != nil { + db.Close() + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + driver, err := postgres.WithInstance(db, &postgres.Config{}) + if err != nil { + db.Close() + return nil, fmt.Errorf("failed to create postgres driver: %w", err) + } + + m, err := migrate.NewWithDatabaseInstance( + cfg.MigrationsPath, + "postgres", + driver, + ) + if err != nil { + db.Close() + return nil, fmt.Errorf("failed to create migrator: %w", err) + } + + return &Migrator{ + migrate: m, + logger: cfg.Logger, + }, nil +} + +// Up applies all pending migrations +func (m *Migrator) Up() error { + m.logger.Info("applying database migrations") + + err := m.migrate.Up() + if err != nil { + if errors.Is(err, migrate.ErrNoChange) { + m.logger.Info("no migrations to apply") + return nil + } + return fmt.Errorf("failed to apply migrations: %w", err) + } + + m.logger.Info("database migrations applied successfully") + return nil +} + +// Down rolls back one migration +func (m *Migrator) Down() error { + m.logger.Info("rolling back database migration") + + err := m.migrate.Steps(-1) + if err != nil { + if errors.Is(err, migrate.ErrNoChange) { + m.logger.Info("no migrations to roll back") + return nil + } + return fmt.Errorf("failed to roll back migration: %w", err) + } + + m.logger.Info("database migration rolled back successfully") + return nil +} + +// Reset rolls back all migrations +func (m *Migrator) Reset() error { + m.logger.Info("resetting database (rolling back all migrations)") + + err := m.migrate.Drop() + if err != nil { + return fmt.Errorf("failed to reset database: %w", err) + } + + m.logger.Info("database reset successfully") + return nil +} + +// Version returns the current migration version +func (m *Migrator) Version() (uint, bool, error) { + version, dirty, err := m.migrate.Version() + if err != nil { + if errors.Is(err, migrate.ErrNilVersion) { + return 0, false, nil + } + return 0, false, fmt.Errorf("failed to get migration version: %w", err) + } + + return version, dirty, nil +} + +// ForceVersion forces the migration version (use with caution) +func (m *Migrator) ForceVersion(version int) error { + m.logger.Warn("forcing migration version", "version", version) + + err := m.migrate.Force(version) + if err != nil { + return fmt.Errorf("failed to force migration version: %w", err) + } + + m.logger.Info("migration version forced successfully", "version", version) + return nil +} + +// Close closes the migrator +func (m *Migrator) Close() error { + if m.migrate != nil { + sourceErr, dbErr := m.migrate.Close() + if sourceErr != nil { + return fmt.Errorf("failed to close migration source: %w", sourceErr) + } + if dbErr != nil { + return fmt.Errorf("failed to close migration database: %w", dbErr) + } + } + return nil +} + +// MigrateUp is a convenience function to apply migrations +func MigrateUp(cfg *MigrateConfig) error { + migrator, err := NewMigrator(cfg) + if err != nil { + return err + } + defer migrator.Close() + + return migrator.Up() +} + +// MigrateDown is a convenience function to roll back migrations +func MigrateDown(cfg *MigrateConfig) error { + migrator, err := NewMigrator(cfg) + if err != nil { + return err + } + defer migrator.Close() + + return migrator.Down() +} + +// MigrateReset is a convenience function to reset the database +func MigrateReset(cfg *MigrateConfig) error { + migrator, err := NewMigrator(cfg) + if err != nil { + return err + } + defer migrator.Close() + + return migrator.Reset() +} \ No newline at end of file diff --git a/internal/database/task_execution_repository.go b/internal/database/task_execution_repository.go new file mode 100644 index 0000000..401b5ab --- /dev/null +++ b/internal/database/task_execution_repository.go @@ -0,0 +1,368 @@ +package database + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/voidrunnerhq/voidrunner/internal/models" +) + +// taskExecutionRepository implements TaskExecutionRepository interface +type taskExecutionRepository struct { + conn *Connection +} + +// NewTaskExecutionRepository creates a new task execution repository +func NewTaskExecutionRepository(conn *Connection) TaskExecutionRepository { + return &taskExecutionRepository{ + conn: conn, + } +} + +// Create creates a new task execution +func (r *taskExecutionRepository) Create(ctx context.Context, execution *models.TaskExecution) error { + if execution == nil { + return fmt.Errorf("task execution cannot be nil") + } + + if execution.ID == uuid.Nil { + execution.ID = models.NewID() + } + + query := ` + INSERT INTO task_executions (id, task_id, status, return_code, stdout, stderr, execution_time_ms, memory_usage_bytes, started_at, completed_at, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW()) + RETURNING created_at + ` + + err := r.conn.Pool.QueryRow(ctx, query, + execution.ID, + execution.TaskID, + execution.Status, + execution.ReturnCode, + execution.Stdout, + execution.Stderr, + execution.ExecutionTimeMs, + execution.MemoryUsageBytes, + execution.StartedAt, + execution.CompletedAt, + ).Scan(&execution.CreatedAt) + + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "23505": // unique_violation + return fmt.Errorf("task execution with ID %s already exists", execution.ID) + case "23503": // foreign_key_violation + if strings.Contains(pgErr.Detail, "task_id") { + return fmt.Errorf("task with ID %s does not exist", execution.TaskID) + } + case "23514": // check_violation + return fmt.Errorf("task execution validation failed: %s", pgErr.Detail) + } + } + return fmt.Errorf("failed to create task execution: %w", err) + } + + return nil +} + +// GetByID retrieves a task execution by ID +func (r *taskExecutionRepository) GetByID(ctx context.Context, id uuid.UUID) (*models.TaskExecution, error) { + query := ` + SELECT id, task_id, status, return_code, stdout, stderr, execution_time_ms, memory_usage_bytes, started_at, completed_at, created_at + FROM task_executions + WHERE id = $1 + ` + + var execution models.TaskExecution + err := r.conn.Pool.QueryRow(ctx, query, id).Scan( + &execution.ID, + &execution.TaskID, + &execution.Status, + &execution.ReturnCode, + &execution.Stdout, + &execution.Stderr, + &execution.ExecutionTimeMs, + &execution.MemoryUsageBytes, + &execution.StartedAt, + &execution.CompletedAt, + &execution.CreatedAt, + ) + + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("task execution with ID %s not found", id) + } + return nil, fmt.Errorf("failed to get task execution by ID: %w", err) + } + + return &execution, nil +} + +// GetByTaskID retrieves task executions by task ID with pagination +func (r *taskExecutionRepository) GetByTaskID(ctx context.Context, taskID uuid.UUID, limit, offset int) ([]*models.TaskExecution, error) { + if limit <= 0 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + + query := ` + SELECT id, task_id, status, return_code, stdout, stderr, execution_time_ms, memory_usage_bytes, started_at, completed_at, created_at + FROM task_executions + WHERE task_id = $1 + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + ` + + rows, err := r.conn.Pool.Query(ctx, query, taskID, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to get task executions by task ID: %w", err) + } + defer rows.Close() + + return r.scanTaskExecutions(rows) +} + +// GetLatestByTaskID retrieves the latest task execution for a task +func (r *taskExecutionRepository) GetLatestByTaskID(ctx context.Context, taskID uuid.UUID) (*models.TaskExecution, error) { + query := ` + SELECT id, task_id, status, return_code, stdout, stderr, execution_time_ms, memory_usage_bytes, started_at, completed_at, created_at + FROM task_executions + WHERE task_id = $1 + ORDER BY created_at DESC + LIMIT 1 + ` + + var execution models.TaskExecution + err := r.conn.Pool.QueryRow(ctx, query, taskID).Scan( + &execution.ID, + &execution.TaskID, + &execution.Status, + &execution.ReturnCode, + &execution.Stdout, + &execution.Stderr, + &execution.ExecutionTimeMs, + &execution.MemoryUsageBytes, + &execution.StartedAt, + &execution.CompletedAt, + &execution.CreatedAt, + ) + + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("no task executions found for task ID %s", taskID) + } + return nil, fmt.Errorf("failed to get latest task execution by task ID: %w", err) + } + + return &execution, nil +} + +// GetByStatus retrieves task executions by status with pagination +func (r *taskExecutionRepository) GetByStatus(ctx context.Context, status models.ExecutionStatus, limit, offset int) ([]*models.TaskExecution, error) { + if limit <= 0 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + + query := ` + SELECT id, task_id, status, return_code, stdout, stderr, execution_time_ms, memory_usage_bytes, started_at, completed_at, created_at + FROM task_executions + WHERE status = $1 + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + ` + + rows, err := r.conn.Pool.Query(ctx, query, status, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to get task executions by status: %w", err) + } + defer rows.Close() + + return r.scanTaskExecutions(rows) +} + +// Update updates a task execution +func (r *taskExecutionRepository) Update(ctx context.Context, execution *models.TaskExecution) error { + if execution == nil { + return fmt.Errorf("task execution cannot be nil") + } + + query := ` + UPDATE task_executions + SET status = $2, return_code = $3, stdout = $4, stderr = $5, execution_time_ms = $6, memory_usage_bytes = $7, started_at = $8, completed_at = $9 + WHERE id = $1 + ` + + result, err := r.conn.Pool.Exec(ctx, query, + execution.ID, + execution.Status, + execution.ReturnCode, + execution.Stdout, + execution.Stderr, + execution.ExecutionTimeMs, + execution.MemoryUsageBytes, + execution.StartedAt, + execution.CompletedAt, + ) + + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23514" { + return fmt.Errorf("task execution validation failed: %s", pgErr.Detail) + } + return fmt.Errorf("failed to update task execution: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("task execution with ID %s not found", execution.ID) + } + + return nil +} + +// UpdateStatus updates only the status of a task execution +func (r *taskExecutionRepository) UpdateStatus(ctx context.Context, id uuid.UUID, status models.ExecutionStatus) error { + query := ` + UPDATE task_executions + SET status = $2 + WHERE id = $1 + ` + + result, err := r.conn.Pool.Exec(ctx, query, id, status) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23514" { + return fmt.Errorf("invalid task execution status: %s", status) + } + return fmt.Errorf("failed to update task execution status: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("task execution with ID %s not found", id) + } + + return nil +} + +// Delete deletes a task execution +func (r *taskExecutionRepository) Delete(ctx context.Context, id uuid.UUID) error { + query := `DELETE FROM task_executions WHERE id = $1` + + result, err := r.conn.Pool.Exec(ctx, query, id) + if err != nil { + return fmt.Errorf("failed to delete task execution: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("task execution with ID %s not found", id) + } + + return nil +} + +// List retrieves task executions with pagination +func (r *taskExecutionRepository) List(ctx context.Context, limit, offset int) ([]*models.TaskExecution, error) { + if limit <= 0 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + + query := ` + SELECT id, task_id, status, return_code, stdout, stderr, execution_time_ms, memory_usage_bytes, started_at, completed_at, created_at + FROM task_executions + ORDER BY created_at DESC + LIMIT $1 OFFSET $2 + ` + + rows, err := r.conn.Pool.Query(ctx, query, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to list task executions: %w", err) + } + defer rows.Close() + + return r.scanTaskExecutions(rows) +} + +// Count returns the total number of task executions +func (r *taskExecutionRepository) Count(ctx context.Context) (int64, error) { + query := `SELECT COUNT(*) FROM task_executions` + + var count int64 + err := r.conn.Pool.QueryRow(ctx, query).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count task executions: %w", err) + } + + return count, nil +} + +// CountByTaskID returns the total number of task executions for a task +func (r *taskExecutionRepository) CountByTaskID(ctx context.Context, taskID uuid.UUID) (int64, error) { + query := `SELECT COUNT(*) FROM task_executions WHERE task_id = $1` + + var count int64 + err := r.conn.Pool.QueryRow(ctx, query, taskID).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count task executions by task ID: %w", err) + } + + return count, nil +} + +// CountByStatus returns the total number of task executions with a specific status +func (r *taskExecutionRepository) CountByStatus(ctx context.Context, status models.ExecutionStatus) (int64, error) { + query := `SELECT COUNT(*) FROM task_executions WHERE status = $1` + + var count int64 + err := r.conn.Pool.QueryRow(ctx, query, status).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count task executions by status: %w", err) + } + + return count, nil +} + +// scanTaskExecutions is a helper function to scan task execution rows +func (r *taskExecutionRepository) scanTaskExecutions(rows pgx.Rows) ([]*models.TaskExecution, error) { + var executions []*models.TaskExecution + for rows.Next() { + var execution models.TaskExecution + err := rows.Scan( + &execution.ID, + &execution.TaskID, + &execution.Status, + &execution.ReturnCode, + &execution.Stdout, + &execution.Stderr, + &execution.ExecutionTimeMs, + &execution.MemoryUsageBytes, + &execution.StartedAt, + &execution.CompletedAt, + &execution.CreatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan task execution row: %w", err) + } + executions = append(executions, &execution) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating task execution rows: %w", err) + } + + return executions, nil +} \ No newline at end of file diff --git a/internal/database/task_execution_repository_test.go b/internal/database/task_execution_repository_test.go new file mode 100644 index 0000000..ad8775c --- /dev/null +++ b/internal/database/task_execution_repository_test.go @@ -0,0 +1,440 @@ +package database + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/voidrunnerhq/voidrunner/internal/models" +) + +func TestTaskExecutionRepository_Create(t *testing.T) { + tests := []struct { + name string + execution *models.TaskExecution + wantError bool + errorMsg string + }{ + { + name: "successful task execution creation", + execution: &models.TaskExecution{ + TaskID: uuid.New(), + Status: models.ExecutionStatusPending, + }, + wantError: false, + }, + { + name: "nil task execution", + execution: nil, + wantError: true, + errorMsg: "task execution cannot be nil", + }, + { + name: "task execution with results", + execution: &models.TaskExecution{ + TaskID: uuid.New(), + Status: models.ExecutionStatusCompleted, + ReturnCode: intPtr(0), + Stdout: stringPtr("Hello, World!"), + Stderr: stringPtr(""), + ExecutionTimeMs: intPtr(1500), + MemoryUsageBytes: int64Ptr(1024 * 1024), // 1MB + StartedAt: timePtr(time.Now().Add(-2 * time.Second)), + CompletedAt: timePtr(time.Now()), + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskExecutionRepository_GetByID(t *testing.T) { + tests := []struct { + name string + executionID uuid.UUID + wantError bool + errorMsg string + }{ + { + name: "successful get by ID", + executionID: uuid.New(), + wantError: false, + }, + { + name: "task execution not found", + executionID: uuid.New(), + wantError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskExecutionRepository_GetByTaskID(t *testing.T) { + tests := []struct { + name string + taskID uuid.UUID + limit int + offset int + wantError bool + }{ + { + name: "successful get by task ID", + taskID: uuid.New(), + limit: 10, + offset: 0, + wantError: false, + }, + { + name: "default limit for zero limit", + taskID: uuid.New(), + limit: 0, + offset: 0, + wantError: false, + }, + { + name: "default offset for negative offset", + taskID: uuid.New(), + limit: 10, + offset: -1, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskExecutionRepository_GetLatestByTaskID(t *testing.T) { + tests := []struct { + name string + taskID uuid.UUID + wantError bool + errorMsg string + }{ + { + name: "successful get latest by task ID", + taskID: uuid.New(), + wantError: false, + }, + { + name: "no executions found", + taskID: uuid.New(), + wantError: true, + errorMsg: "no task executions found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskExecutionRepository_GetByStatus(t *testing.T) { + tests := []struct { + name string + status models.ExecutionStatus + limit int + offset int + wantError bool + }{ + { + name: "successful get by status", + status: models.ExecutionStatusRunning, + limit: 10, + offset: 0, + wantError: false, + }, + { + name: "get completed executions", + status: models.ExecutionStatusCompleted, + limit: 5, + offset: 0, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskExecutionRepository_Update(t *testing.T) { + tests := []struct { + name string + execution *models.TaskExecution + wantError bool + errorMsg string + }{ + { + name: "successful update", + execution: &models.TaskExecution{ + ID: uuid.New(), + TaskID: uuid.New(), + Status: models.ExecutionStatusCompleted, + ReturnCode: intPtr(0), + Stdout: stringPtr("Output"), + Stderr: stringPtr(""), + ExecutionTimeMs: intPtr(1000), + MemoryUsageBytes: int64Ptr(512 * 1024), + StartedAt: timePtr(time.Now().Add(-1 * time.Second)), + CompletedAt: timePtr(time.Now()), + }, + wantError: false, + }, + { + name: "nil task execution", + execution: nil, + wantError: true, + errorMsg: "task execution cannot be nil", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskExecutionRepository_UpdateStatus(t *testing.T) { + tests := []struct { + name string + executionID uuid.UUID + status models.ExecutionStatus + wantError bool + errorMsg string + }{ + { + name: "successful status update", + executionID: uuid.New(), + status: models.ExecutionStatusRunning, + wantError: false, + }, + { + name: "task execution not found", + executionID: uuid.New(), + status: models.ExecutionStatusCompleted, + wantError: true, + errorMsg: "not found", + }, + { + name: "invalid status", + executionID: uuid.New(), + status: "invalid_status", + wantError: true, + errorMsg: "invalid task execution status", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskExecutionRepository_Delete(t *testing.T) { + tests := []struct { + name string + executionID uuid.UUID + wantError bool + errorMsg string + }{ + { + name: "successful delete", + executionID: uuid.New(), + wantError: false, + }, + { + name: "task execution not found", + executionID: uuid.New(), + wantError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskExecutionRepository_Count(t *testing.T) { + t.Run("successful count", func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) +} + +func TestTaskExecutionRepository_CountByTaskID(t *testing.T) { + t.Run("successful count by task ID", func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) +} + +func TestTaskExecutionRepository_CountByStatus(t *testing.T) { + t.Run("successful count by status", func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) +} + +// Mock tests for business logic validation +func TestTaskExecutionRepository_CreateValidation(t *testing.T) { + repo := &taskExecutionRepository{conn: nil} // Mock repository + + t.Run("nil task execution validation", func(t *testing.T) { + err := repo.Create(context.Background(), nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "task execution cannot be nil") + }) +} + +func TestTaskExecutionRepository_UpdateValidation(t *testing.T) { + repo := &taskExecutionRepository{conn: nil} // Mock repository + + t.Run("nil task execution validation", func(t *testing.T) { + err := repo.Update(context.Background(), nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "task execution cannot be nil") + }) +} + +func TestTaskExecutionRepository_ScanTaskExecutions(t *testing.T) { + t.Run("scan task executions with nil rows", func(t *testing.T) { + // This would test the scanTaskExecutions method with mock rows + // In a real implementation, you would use testify/mock or similar + t.Skip("Requires mock implementation") + }) +} + +// Helper functions for testing +func createTestTaskExecution(t *testing.T, taskID uuid.UUID, status models.ExecutionStatus) *models.TaskExecution { + t.Helper() + return &models.TaskExecution{ + ID: uuid.New(), + TaskID: taskID, + Status: status, + CreatedAt: time.Now(), + } +} + +func createCompletedTaskExecution(t *testing.T, taskID uuid.UUID) *models.TaskExecution { + t.Helper() + startTime := time.Now().Add(-2 * time.Second) + endTime := time.Now() + + return &models.TaskExecution{ + ID: uuid.New(), + TaskID: taskID, + Status: models.ExecutionStatusCompleted, + ReturnCode: intPtr(0), + Stdout: stringPtr("Task completed successfully"), + Stderr: stringPtr(""), + ExecutionTimeMs: intPtr(int(endTime.Sub(startTime).Milliseconds())), + MemoryUsageBytes: int64Ptr(1024 * 1024), // 1MB + StartedAt: &startTime, + CompletedAt: &endTime, + CreatedAt: startTime, + } +} + +func assertTaskExecutionEqual(t *testing.T, expected, actual *models.TaskExecution) { + t.Helper() + assert.Equal(t, expected.ID, actual.ID) + assert.Equal(t, expected.TaskID, actual.TaskID) + assert.Equal(t, expected.Status, actual.Status) + + if expected.ReturnCode != nil && actual.ReturnCode != nil { + assert.Equal(t, *expected.ReturnCode, *actual.ReturnCode) + } else { + assert.Equal(t, expected.ReturnCode, actual.ReturnCode) + } + + if expected.Stdout != nil && actual.Stdout != nil { + assert.Equal(t, *expected.Stdout, *actual.Stdout) + } else { + assert.Equal(t, expected.Stdout, actual.Stdout) + } + + if expected.Stderr != nil && actual.Stderr != nil { + assert.Equal(t, *expected.Stderr, *actual.Stderr) + } else { + assert.Equal(t, expected.Stderr, actual.Stderr) + } + + if expected.ExecutionTimeMs != nil && actual.ExecutionTimeMs != nil { + assert.Equal(t, *expected.ExecutionTimeMs, *actual.ExecutionTimeMs) + } else { + assert.Equal(t, expected.ExecutionTimeMs, actual.ExecutionTimeMs) + } + + if expected.MemoryUsageBytes != nil && actual.MemoryUsageBytes != nil { + assert.Equal(t, *expected.MemoryUsageBytes, *actual.MemoryUsageBytes) + } else { + assert.Equal(t, expected.MemoryUsageBytes, actual.MemoryUsageBytes) + } + + if expected.StartedAt != nil && actual.StartedAt != nil { + assert.WithinDuration(t, *expected.StartedAt, *actual.StartedAt, time.Second) + } else { + assert.Equal(t, expected.StartedAt, actual.StartedAt) + } + + if expected.CompletedAt != nil && actual.CompletedAt != nil { + assert.WithinDuration(t, *expected.CompletedAt, *actual.CompletedAt, time.Second) + } else { + assert.Equal(t, expected.CompletedAt, actual.CompletedAt) + } + + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) +} + +// Helper functions for pointer values +func intPtr(i int) *int { + return &i +} + +func stringPtr(s string) *string { + return &s +} + +func int64Ptr(i int64) *int64 { + return &i +} + +func timePtr(t time.Time) *time.Time { + return &t +} + +// Benchmark tests +func BenchmarkTaskExecutionRepository_Create(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} + +func BenchmarkTaskExecutionRepository_GetByID(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} + +func BenchmarkTaskExecutionRepository_GetByTaskID(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} \ No newline at end of file diff --git a/internal/database/task_repository.go b/internal/database/task_repository.go new file mode 100644 index 0000000..c4cda2b --- /dev/null +++ b/internal/database/task_repository.go @@ -0,0 +1,365 @@ +package database + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/voidrunnerhq/voidrunner/internal/models" +) + +// taskRepository implements TaskRepository interface +type taskRepository struct { + conn *Connection +} + +// NewTaskRepository creates a new task repository +func NewTaskRepository(conn *Connection) TaskRepository { + return &taskRepository{ + conn: conn, + } +} + +// Create creates a new task +func (r *taskRepository) Create(ctx context.Context, task *models.Task) error { + if task == nil { + return fmt.Errorf("task cannot be nil") + } + + if task.ID == uuid.Nil { + task.ID = models.NewID() + } + + query := ` + INSERT INTO tasks (id, user_id, name, description, script_content, script_type, status, priority, timeout_seconds, metadata, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW(), NOW()) + RETURNING created_at, updated_at + ` + + err := r.conn.Pool.QueryRow(ctx, query, + task.ID, + task.UserID, + task.Name, + task.Description, + task.ScriptContent, + task.ScriptType, + task.Status, + task.Priority, + task.TimeoutSeconds, + task.Metadata, + ).Scan(&task.CreatedAt, &task.UpdatedAt) + + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "23505": // unique_violation + return fmt.Errorf("task with ID %s already exists", task.ID) + case "23503": // foreign_key_violation + if strings.Contains(pgErr.Detail, "user_id") { + return fmt.Errorf("user with ID %s does not exist", task.UserID) + } + case "23514": // check_violation + return fmt.Errorf("task validation failed: %s", pgErr.Detail) + } + } + return fmt.Errorf("failed to create task: %w", err) + } + + return nil +} + +// GetByID retrieves a task by ID +func (r *taskRepository) GetByID(ctx context.Context, id uuid.UUID) (*models.Task, error) { + query := ` + SELECT id, user_id, name, description, script_content, script_type, status, priority, timeout_seconds, metadata, created_at, updated_at + FROM tasks + WHERE id = $1 + ` + + var task models.Task + err := r.conn.Pool.QueryRow(ctx, query, id).Scan( + &task.ID, + &task.UserID, + &task.Name, + &task.Description, + &task.ScriptContent, + &task.ScriptType, + &task.Status, + &task.Priority, + &task.TimeoutSeconds, + &task.Metadata, + &task.CreatedAt, + &task.UpdatedAt, + ) + + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("task with ID %s not found", id) + } + return nil, fmt.Errorf("failed to get task by ID: %w", err) + } + + return &task, nil +} + +// GetByUserID retrieves tasks by user ID with pagination +func (r *taskRepository) GetByUserID(ctx context.Context, userID uuid.UUID, limit, offset int) ([]*models.Task, error) { + if limit <= 0 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + + query := ` + SELECT id, user_id, name, description, script_content, script_type, status, priority, timeout_seconds, metadata, created_at, updated_at + FROM tasks + WHERE user_id = $1 + ORDER BY priority DESC, created_at DESC + LIMIT $2 OFFSET $3 + ` + + rows, err := r.conn.Pool.Query(ctx, query, userID, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to get tasks by user ID: %w", err) + } + defer rows.Close() + + return r.scanTasks(rows) +} + +// GetByStatus retrieves tasks by status with pagination +func (r *taskRepository) GetByStatus(ctx context.Context, status models.TaskStatus, limit, offset int) ([]*models.Task, error) { + if limit <= 0 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + + query := ` + SELECT id, user_id, name, description, script_content, script_type, status, priority, timeout_seconds, metadata, created_at, updated_at + FROM tasks + WHERE status = $1 + ORDER BY priority DESC, created_at DESC + LIMIT $2 OFFSET $3 + ` + + rows, err := r.conn.Pool.Query(ctx, query, status, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to get tasks by status: %w", err) + } + defer rows.Close() + + return r.scanTasks(rows) +} + +// Update updates a task +func (r *taskRepository) Update(ctx context.Context, task *models.Task) error { + if task == nil { + return fmt.Errorf("task cannot be nil") + } + + query := ` + UPDATE tasks + SET name = $2, description = $3, script_content = $4, script_type = $5, status = $6, priority = $7, timeout_seconds = $8, metadata = $9, updated_at = NOW() + WHERE id = $1 + RETURNING updated_at + ` + + err := r.conn.Pool.QueryRow(ctx, query, + task.ID, + task.Name, + task.Description, + task.ScriptContent, + task.ScriptType, + task.Status, + task.Priority, + task.TimeoutSeconds, + task.Metadata, + ).Scan(&task.UpdatedAt) + + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("task with ID %s not found", task.ID) + } + + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "23514": // check_violation + return fmt.Errorf("task validation failed: %s", pgErr.Detail) + } + } + return fmt.Errorf("failed to update task: %w", err) + } + + return nil +} + +// UpdateStatus updates only the status of a task +func (r *taskRepository) UpdateStatus(ctx context.Context, id uuid.UUID, status models.TaskStatus) error { + query := ` + UPDATE tasks + SET status = $2, updated_at = NOW() + WHERE id = $1 + ` + + result, err := r.conn.Pool.Exec(ctx, query, id, status) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23514" { + return fmt.Errorf("invalid task status: %s", status) + } + return fmt.Errorf("failed to update task status: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("task with ID %s not found", id) + } + + return nil +} + +// Delete deletes a task +func (r *taskRepository) Delete(ctx context.Context, id uuid.UUID) error { + query := `DELETE FROM tasks WHERE id = $1` + + result, err := r.conn.Pool.Exec(ctx, query, id) + if err != nil { + return fmt.Errorf("failed to delete task: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("task with ID %s not found", id) + } + + return nil +} + +// List retrieves tasks with pagination +func (r *taskRepository) List(ctx context.Context, limit, offset int) ([]*models.Task, error) { + if limit <= 0 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + + query := ` + SELECT id, user_id, name, description, script_content, script_type, status, priority, timeout_seconds, metadata, created_at, updated_at + FROM tasks + ORDER BY priority DESC, created_at DESC + LIMIT $1 OFFSET $2 + ` + + rows, err := r.conn.Pool.Query(ctx, query, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to list tasks: %w", err) + } + defer rows.Close() + + return r.scanTasks(rows) +} + +// Count returns the total number of tasks +func (r *taskRepository) Count(ctx context.Context) (int64, error) { + query := `SELECT COUNT(*) FROM tasks` + + var count int64 + err := r.conn.Pool.QueryRow(ctx, query).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count tasks: %w", err) + } + + return count, nil +} + +// CountByUserID returns the total number of tasks for a user +func (r *taskRepository) CountByUserID(ctx context.Context, userID uuid.UUID) (int64, error) { + query := `SELECT COUNT(*) FROM tasks WHERE user_id = $1` + + var count int64 + err := r.conn.Pool.QueryRow(ctx, query, userID).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count tasks by user ID: %w", err) + } + + return count, nil +} + +// CountByStatus returns the total number of tasks with a specific status +func (r *taskRepository) CountByStatus(ctx context.Context, status models.TaskStatus) (int64, error) { + query := `SELECT COUNT(*) FROM tasks WHERE status = $1` + + var count int64 + err := r.conn.Pool.QueryRow(ctx, query, status).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count tasks by status: %w", err) + } + + return count, nil +} + +// SearchByMetadata searches tasks by metadata using JSON operators +func (r *taskRepository) SearchByMetadata(ctx context.Context, query string, limit, offset int) ([]*models.Task, error) { + if limit <= 0 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + + sqlQuery := ` + SELECT id, user_id, name, description, script_content, script_type, status, priority, timeout_seconds, metadata, created_at, updated_at + FROM tasks + WHERE metadata @> $1 + ORDER BY priority DESC, created_at DESC + LIMIT $2 OFFSET $3 + ` + + rows, err := r.conn.Pool.Query(ctx, sqlQuery, query, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to search tasks by metadata: %w", err) + } + defer rows.Close() + + return r.scanTasks(rows) +} + +// scanTasks is a helper function to scan task rows +func (r *taskRepository) scanTasks(rows pgx.Rows) ([]*models.Task, error) { + var tasks []*models.Task + for rows.Next() { + var task models.Task + err := rows.Scan( + &task.ID, + &task.UserID, + &task.Name, + &task.Description, + &task.ScriptContent, + &task.ScriptType, + &task.Status, + &task.Priority, + &task.TimeoutSeconds, + &task.Metadata, + &task.CreatedAt, + &task.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan task row: %w", err) + } + tasks = append(tasks, &task) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating task rows: %w", err) + } + + return tasks, nil +} \ No newline at end of file diff --git a/internal/database/task_repository_test.go b/internal/database/task_repository_test.go new file mode 100644 index 0000000..f2ea529 --- /dev/null +++ b/internal/database/task_repository_test.go @@ -0,0 +1,398 @@ +package database + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/voidrunnerhq/voidrunner/internal/models" +) + +func TestTaskRepository_Create(t *testing.T) { + tests := []struct { + name string + task *models.Task + wantError bool + errorMsg string + }{ + { + name: "successful task creation", + task: &models.Task{ + UserID: uuid.New(), + Name: "Test Task", + ScriptContent: "print('hello world')", + ScriptType: models.ScriptTypePython, + Status: models.TaskStatusPending, + Priority: 1, + TimeoutSeconds: 30, + Metadata: json.RawMessage(`{"key":"value"}`), + }, + wantError: false, + }, + { + name: "nil task", + task: nil, + wantError: true, + errorMsg: "task cannot be nil", + }, + { + name: "invalid script type", + task: &models.Task{ + UserID: uuid.New(), + Name: "Test Task", + ScriptContent: "print('hello world')", + ScriptType: "invalid", + Status: models.TaskStatusPending, + Priority: 1, + TimeoutSeconds: 30, + }, + wantError: true, + errorMsg: "validation failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskRepository_GetByID(t *testing.T) { + tests := []struct { + name string + taskID uuid.UUID + wantError bool + errorMsg string + }{ + { + name: "successful get by ID", + taskID: uuid.New(), + wantError: false, + }, + { + name: "task not found", + taskID: uuid.New(), + wantError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskRepository_GetByUserID(t *testing.T) { + tests := []struct { + name string + userID uuid.UUID + limit int + offset int + wantError bool + }{ + { + name: "successful get by user ID", + userID: uuid.New(), + limit: 10, + offset: 0, + wantError: false, + }, + { + name: "default limit for zero limit", + userID: uuid.New(), + limit: 0, + offset: 0, + wantError: false, + }, + { + name: "default offset for negative offset", + userID: uuid.New(), + limit: 10, + offset: -1, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskRepository_GetByStatus(t *testing.T) { + tests := []struct { + name string + status models.TaskStatus + limit int + offset int + wantError bool + }{ + { + name: "successful get by status", + status: models.TaskStatusPending, + limit: 10, + offset: 0, + wantError: false, + }, + { + name: "get completed tasks", + status: models.TaskStatusCompleted, + limit: 5, + offset: 0, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskRepository_Update(t *testing.T) { + tests := []struct { + name string + task *models.Task + wantError bool + errorMsg string + }{ + { + name: "successful update", + task: &models.Task{ + BaseModel: models.BaseModel{ + ID: uuid.New(), + }, + UserID: uuid.New(), + Name: "Updated Task", + ScriptContent: "print('updated')", + ScriptType: models.ScriptTypePython, + Status: models.TaskStatusRunning, + Priority: 2, + TimeoutSeconds: 60, + }, + wantError: false, + }, + { + name: "nil task", + task: nil, + wantError: true, + errorMsg: "task cannot be nil", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskRepository_UpdateStatus(t *testing.T) { + tests := []struct { + name string + taskID uuid.UUID + status models.TaskStatus + wantError bool + errorMsg string + }{ + { + name: "successful status update", + taskID: uuid.New(), + status: models.TaskStatusRunning, + wantError: false, + }, + { + name: "task not found", + taskID: uuid.New(), + status: models.TaskStatusCompleted, + wantError: true, + errorMsg: "not found", + }, + { + name: "invalid status", + taskID: uuid.New(), + status: "invalid_status", + wantError: true, + errorMsg: "invalid task status", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskRepository_SearchByMetadata(t *testing.T) { + tests := []struct { + name string + query string + limit int + offset int + wantError bool + }{ + { + name: "successful metadata search", + query: `{"environment": "production"}`, + limit: 10, + offset: 0, + wantError: false, + }, + { + name: "empty query", + query: `{}`, + limit: 10, + offset: 0, + wantError: false, + }, + { + name: "complex query", + query: `{"tags": ["urgent", "api"]}`, + limit: 5, + offset: 0, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskRepository_Delete(t *testing.T) { + tests := []struct { + name string + taskID uuid.UUID + wantError bool + errorMsg string + }{ + { + name: "successful delete", + taskID: uuid.New(), + wantError: false, + }, + { + name: "task not found", + taskID: uuid.New(), + wantError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestTaskRepository_Count(t *testing.T) { + t.Run("successful count", func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) +} + +func TestTaskRepository_CountByUserID(t *testing.T) { + t.Run("successful count by user ID", func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) +} + +func TestTaskRepository_CountByStatus(t *testing.T) { + t.Run("successful count by status", func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) +} + +// Mock tests for business logic validation +func TestTaskRepository_CreateValidation(t *testing.T) { + repo := &taskRepository{conn: nil} // Mock repository + + t.Run("nil task validation", func(t *testing.T) { + err := repo.Create(context.Background(), nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "task cannot be nil") + }) +} + +func TestTaskRepository_UpdateValidation(t *testing.T) { + repo := &taskRepository{conn: nil} // Mock repository + + t.Run("nil task validation", func(t *testing.T) { + err := repo.Update(context.Background(), nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "task cannot be nil") + }) +} + +func TestTaskRepository_ScanTasks(t *testing.T) { + t.Run("scan tasks with nil rows", func(t *testing.T) { + // This would test the scanTasks method with mock rows + // In a real implementation, you would use testify/mock or similar + t.Skip("Requires mock implementation") + }) +} + +// Helper functions for testing +func createTestTask(t *testing.T, userID uuid.UUID, name string) *models.Task { + t.Helper() + metadata, _ := json.Marshal(map[string]interface{}{ + "environment": "test", + "created_by": "test_suite", + }) + + return &models.Task{ + BaseModel: models.BaseModel{ + ID: uuid.New(), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + UserID: userID, + Name: name, + ScriptContent: "print('test')", + ScriptType: models.ScriptTypePython, + Status: models.TaskStatusPending, + Priority: 1, + TimeoutSeconds: 30, + Metadata: metadata, + } +} + +func assertTaskEqual(t *testing.T, expected, actual *models.Task) { + t.Helper() + assert.Equal(t, expected.ID, actual.ID) + assert.Equal(t, expected.UserID, actual.UserID) + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, expected.ScriptContent, actual.ScriptContent) + assert.Equal(t, expected.ScriptType, actual.ScriptType) + assert.Equal(t, expected.Status, actual.Status) + assert.Equal(t, expected.Priority, actual.Priority) + assert.Equal(t, expected.TimeoutSeconds, actual.TimeoutSeconds) + assert.JSONEq(t, string(expected.Metadata), string(actual.Metadata)) + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) + assert.WithinDuration(t, expected.UpdatedAt, actual.UpdatedAt, time.Second) +} + +// Benchmark tests +func BenchmarkTaskRepository_Create(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} + +func BenchmarkTaskRepository_GetByID(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} + +func BenchmarkTaskRepository_SearchByMetadata(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} \ No newline at end of file diff --git a/internal/database/user_repository.go b/internal/database/user_repository.go new file mode 100644 index 0000000..b55b57f --- /dev/null +++ b/internal/database/user_repository.go @@ -0,0 +1,229 @@ +package database + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/voidrunnerhq/voidrunner/internal/models" +) + +// userRepository implements UserRepository interface +type userRepository struct { + conn *Connection +} + +// NewUserRepository creates a new user repository +func NewUserRepository(conn *Connection) UserRepository { + return &userRepository{ + conn: conn, + } +} + +// Create creates a new user +func (r *userRepository) Create(ctx context.Context, user *models.User) error { + if user == nil { + return fmt.Errorf("user cannot be nil") + } + + if user.ID == uuid.Nil { + user.ID = models.NewID() + } + + query := ` + INSERT INTO users (id, email, password_hash, created_at, updated_at) + VALUES ($1, $2, $3, NOW(), NOW()) + RETURNING created_at, updated_at + ` + + err := r.conn.Pool.QueryRow(ctx, query, user.ID, user.Email, user.PasswordHash). + Scan(&user.CreatedAt, &user.UpdatedAt) + + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "23505": // unique_violation + if strings.Contains(pgErr.Detail, "email") { + return fmt.Errorf("user with email %s already exists", user.Email) + } + return fmt.Errorf("user with ID %s already exists", user.ID) + } + } + return fmt.Errorf("failed to create user: %w", err) + } + + return nil +} + +// GetByID retrieves a user by ID +func (r *userRepository) GetByID(ctx context.Context, id uuid.UUID) (*models.User, error) { + query := ` + SELECT id, email, password_hash, created_at, updated_at + FROM users + WHERE id = $1 + ` + + var user models.User + err := r.conn.Pool.QueryRow(ctx, query, id).Scan( + &user.ID, + &user.Email, + &user.PasswordHash, + &user.CreatedAt, + &user.UpdatedAt, + ) + + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("user with ID %s not found", id) + } + return nil, fmt.Errorf("failed to get user by ID: %w", err) + } + + return &user, nil +} + +// GetByEmail retrieves a user by email +func (r *userRepository) GetByEmail(ctx context.Context, email string) (*models.User, error) { + if email == "" { + return nil, fmt.Errorf("email cannot be empty") + } + + query := ` + SELECT id, email, password_hash, created_at, updated_at + FROM users + WHERE email = $1 + ` + + var user models.User + err := r.conn.Pool.QueryRow(ctx, query, email).Scan( + &user.ID, + &user.Email, + &user.PasswordHash, + &user.CreatedAt, + &user.UpdatedAt, + ) + + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("user with email %s not found", email) + } + return nil, fmt.Errorf("failed to get user by email: %w", err) + } + + return &user, nil +} + +// Update updates a user +func (r *userRepository) Update(ctx context.Context, user *models.User) error { + if user == nil { + return fmt.Errorf("user cannot be nil") + } + + query := ` + UPDATE users + SET email = $2, password_hash = $3, updated_at = NOW() + WHERE id = $1 + RETURNING updated_at + ` + + err := r.conn.Pool.QueryRow(ctx, query, user.ID, user.Email, user.PasswordHash). + Scan(&user.UpdatedAt) + + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("user with ID %s not found", user.ID) + } + + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + switch pgErr.Code { + case "23505": // unique_violation + if strings.Contains(pgErr.Detail, "email") { + return fmt.Errorf("user with email %s already exists", user.Email) + } + } + } + return fmt.Errorf("failed to update user: %w", err) + } + + return nil +} + +// Delete deletes a user +func (r *userRepository) Delete(ctx context.Context, id uuid.UUID) error { + query := `DELETE FROM users WHERE id = $1` + + result, err := r.conn.Pool.Exec(ctx, query, id) + if err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } + + if result.RowsAffected() == 0 { + return fmt.Errorf("user with ID %s not found", id) + } + + return nil +} + +// List retrieves users with pagination +func (r *userRepository) List(ctx context.Context, limit, offset int) ([]*models.User, error) { + if limit <= 0 { + limit = 10 + } + if offset < 0 { + offset = 0 + } + + query := ` + SELECT id, email, password_hash, created_at, updated_at + FROM users + ORDER BY created_at DESC + LIMIT $1 OFFSET $2 + ` + + rows, err := r.conn.Pool.Query(ctx, query, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to list users: %w", err) + } + defer rows.Close() + + var users []*models.User + for rows.Next() { + var user models.User + err := rows.Scan( + &user.ID, + &user.Email, + &user.PasswordHash, + &user.CreatedAt, + &user.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan user row: %w", err) + } + users = append(users, &user) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating user rows: %w", err) + } + + return users, nil +} + +// Count returns the total number of users +func (r *userRepository) Count(ctx context.Context) (int64, error) { + query := `SELECT COUNT(*) FROM users` + + var count int64 + err := r.conn.Pool.QueryRow(ctx, query).Scan(&count) + if err != nil { + return 0, fmt.Errorf("failed to count users: %w", err) + } + + return count, nil +} \ No newline at end of file diff --git a/internal/database/user_repository_test.go b/internal/database/user_repository_test.go new file mode 100644 index 0000000..1e70994 --- /dev/null +++ b/internal/database/user_repository_test.go @@ -0,0 +1,291 @@ +package database + +import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/voidrunnerhq/voidrunner/internal/models" +) + +func TestUserRepository_Create(t *testing.T) { + tests := []struct { + name string + user *models.User + wantError bool + errorMsg string + }{ + { + name: "successful user creation", + user: &models.User{ + Email: "test@example.com", + PasswordHash: "hashed_password", + }, + wantError: false, + }, + { + name: "nil user", + user: nil, + wantError: true, + errorMsg: "user cannot be nil", + }, + { + name: "duplicate email", + user: &models.User{ + Email: "duplicate@example.com", + PasswordHash: "hashed_password", + }, + wantError: true, + errorMsg: "already exists", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // This is a unit test template + // In a real implementation, you would use a test database or mock + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestUserRepository_GetByID(t *testing.T) { + tests := []struct { + name string + userID uuid.UUID + wantError bool + errorMsg string + }{ + { + name: "successful get by ID", + userID: uuid.New(), + wantError: false, + }, + { + name: "user not found", + userID: uuid.New(), + wantError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestUserRepository_GetByEmail(t *testing.T) { + tests := []struct { + name string + email string + wantError bool + errorMsg string + }{ + { + name: "successful get by email", + email: "test@example.com", + wantError: false, + }, + { + name: "empty email", + email: "", + wantError: true, + errorMsg: "email cannot be empty", + }, + { + name: "user not found", + email: "nonexistent@example.com", + wantError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestUserRepository_Update(t *testing.T) { + tests := []struct { + name string + user *models.User + wantError bool + errorMsg string + }{ + { + name: "successful update", + user: &models.User{ + BaseModel: models.BaseModel{ + ID: uuid.New(), + }, + Email: "updated@example.com", + PasswordHash: "new_hashed_password", + }, + wantError: false, + }, + { + name: "nil user", + user: nil, + wantError: true, + errorMsg: "user cannot be nil", + }, + { + name: "user not found", + user: &models.User{ + BaseModel: models.BaseModel{ + ID: uuid.New(), + }, + Email: "notfound@example.com", + PasswordHash: "hashed_password", + }, + wantError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestUserRepository_Delete(t *testing.T) { + tests := []struct { + name string + userID uuid.UUID + wantError bool + errorMsg string + }{ + { + name: "successful delete", + userID: uuid.New(), + wantError: false, + }, + { + name: "user not found", + userID: uuid.New(), + wantError: true, + errorMsg: "not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestUserRepository_List(t *testing.T) { + tests := []struct { + name string + limit int + offset int + wantError bool + }{ + { + name: "successful list with valid pagination", + limit: 10, + offset: 0, + wantError: false, + }, + { + name: "default limit for zero limit", + limit: 0, + offset: 0, + wantError: false, + }, + { + name: "default offset for negative offset", + limit: 10, + offset: -1, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) + } +} + +func TestUserRepository_Count(t *testing.T) { + t.Run("successful count", func(t *testing.T) { + t.Skip("Integration test - requires database connection") + }) +} + +// Mock tests for business logic validation +func TestUserRepository_CreateValidation(t *testing.T) { + repo := &userRepository{conn: nil} // Mock repository + + t.Run("nil user validation", func(t *testing.T) { + err := repo.Create(context.Background(), nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "user cannot be nil") + }) +} + +func TestUserRepository_GetByEmailValidation(t *testing.T) { + repo := &userRepository{conn: nil} // Mock repository + + t.Run("empty email validation", func(t *testing.T) { + _, err := repo.GetByEmail(context.Background(), "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "email cannot be empty") + }) +} + +func TestUserRepository_UpdateValidation(t *testing.T) { + repo := &userRepository{conn: nil} // Mock repository + + t.Run("nil user validation", func(t *testing.T) { + err := repo.Update(context.Background(), nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "user cannot be nil") + }) +} + +// Helper functions for testing +func createTestUser(t *testing.T, email string) *models.User { + t.Helper() + return &models.User{ + BaseModel: models.BaseModel{ + ID: uuid.New(), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + Email: email, + PasswordHash: "test_password_hash", + } +} + +func assertUserEqual(t *testing.T, expected, actual *models.User) { + t.Helper() + assert.Equal(t, expected.ID, actual.ID) + assert.Equal(t, expected.Email, actual.Email) + assert.Equal(t, expected.PasswordHash, actual.PasswordHash) + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) + assert.WithinDuration(t, expected.UpdatedAt, actual.UpdatedAt, time.Second) +} + +// Benchmark tests +func BenchmarkUserRepository_Create(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} + +func BenchmarkUserRepository_GetByID(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} + +func BenchmarkUserRepository_GetByEmail(b *testing.B) { + b.Skip("Integration benchmark - requires database connection") +} \ No newline at end of file diff --git a/internal/models/base.go b/internal/models/base.go new file mode 100644 index 0000000..659fee6 --- /dev/null +++ b/internal/models/base.go @@ -0,0 +1,24 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +// BaseModel contains common fields for all models +type BaseModel struct { + ID uuid.UUID `json:"id" db:"id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +// NewID generates a new UUID +func NewID() uuid.UUID { + return uuid.New() +} + +// ValidateID checks if an ID is valid +func ValidateID(id string) (uuid.UUID, error) { + return uuid.Parse(id) +} \ No newline at end of file diff --git a/internal/models/base_test.go b/internal/models/base_test.go new file mode 100644 index 0000000..2265fe1 --- /dev/null +++ b/internal/models/base_test.go @@ -0,0 +1,67 @@ +package models + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewID(t *testing.T) { + id1 := NewID() + id2 := NewID() + + // IDs should be valid UUIDs + assert.NotEqual(t, uuid.Nil, id1) + assert.NotEqual(t, uuid.Nil, id2) + + // IDs should be different + assert.NotEqual(t, id1, id2) +} + +func TestValidateID(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + { + name: "valid UUID string", + input: "550e8400-e29b-41d4-a716-446655440000", + wantErr: false, + }, + { + name: "valid UUID string with different format", + input: NewID().String(), + wantErr: false, + }, + { + name: "invalid UUID string", + input: "invalid-uuid", + wantErr: true, + }, + { + name: "empty string", + input: "", + wantErr: true, + }, + { + name: "UUID string with wrong length", + input: "550e8400-e29b-41d4-a716", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := ValidateID(tt.input) + if tt.wantErr { + assert.Error(t, err) + assert.Equal(t, uuid.Nil, id) + } else { + assert.NoError(t, err) + assert.NotEqual(t, uuid.Nil, id) + } + }) + } +} \ No newline at end of file diff --git a/internal/models/task.go b/internal/models/task.go new file mode 100644 index 0000000..5754181 --- /dev/null +++ b/internal/models/task.go @@ -0,0 +1,181 @@ +package models + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/google/uuid" +) + +// TaskStatus represents the status of a task +type TaskStatus string + +const ( + TaskStatusPending TaskStatus = "pending" + TaskStatusRunning TaskStatus = "running" + TaskStatusCompleted TaskStatus = "completed" + TaskStatusFailed TaskStatus = "failed" + TaskStatusTimeout TaskStatus = "timeout" + TaskStatusCancelled TaskStatus = "cancelled" +) + +// ScriptType represents the type of script +type ScriptType string + +const ( + ScriptTypePython ScriptType = "python" + ScriptTypeJavaScript ScriptType = "javascript" + ScriptTypeBash ScriptType = "bash" + ScriptTypeGo ScriptType = "go" +) + +// Task represents a task in the system +type Task struct { + BaseModel + UserID uuid.UUID `json:"user_id" db:"user_id"` + Name string `json:"name" db:"name"` + Description *string `json:"description,omitempty" db:"description"` + ScriptContent string `json:"script_content" db:"script_content"` + ScriptType ScriptType `json:"script_type" db:"script_type"` + Status TaskStatus `json:"status" db:"status"` + Priority int `json:"priority" db:"priority"` + TimeoutSeconds int `json:"timeout_seconds" db:"timeout_seconds"` + Metadata json.RawMessage `json:"metadata" db:"metadata"` +} + +// CreateTaskRequest represents the request to create a new task +type CreateTaskRequest struct { + Name string `json:"name" validate:"required,min=1,max=255"` + Description *string `json:"description,omitempty" validate:"omitempty,max=1000"` + ScriptContent string `json:"script_content" validate:"required,min=1,max=65535"` + ScriptType ScriptType `json:"script_type" validate:"required,oneof=python javascript bash go"` + Priority *int `json:"priority,omitempty" validate:"omitempty,min=0,max=10"` + TimeoutSeconds *int `json:"timeout_seconds,omitempty" validate:"omitempty,min=1,max=3600"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +// UpdateTaskRequest represents the request to update a task +type UpdateTaskRequest struct { + Name *string `json:"name,omitempty" validate:"omitempty,min=1,max=255"` + Description *string `json:"description,omitempty" validate:"omitempty,max=1000"` + ScriptContent *string `json:"script_content,omitempty" validate:"omitempty,min=1,max=65535"` + ScriptType *ScriptType `json:"script_type,omitempty" validate:"omitempty,oneof=python javascript bash go"` + Priority *int `json:"priority,omitempty" validate:"omitempty,min=0,max=10"` + TimeoutSeconds *int `json:"timeout_seconds,omitempty" validate:"omitempty,min=1,max=3600"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +// TaskResponse represents the task response +type TaskResponse struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + Name string `json:"name"` + Description *string `json:"description,omitempty"` + ScriptContent string `json:"script_content"` + ScriptType ScriptType `json:"script_type"` + Status TaskStatus `json:"status"` + Priority int `json:"priority"` + TimeoutSeconds int `json:"timeout_seconds"` + Metadata json.RawMessage `json:"metadata"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +// ToResponse converts Task to TaskResponse +func (t *Task) ToResponse() TaskResponse { + return TaskResponse{ + ID: t.ID, + UserID: t.UserID, + Name: t.Name, + Description: t.Description, + ScriptContent: t.ScriptContent, + ScriptType: t.ScriptType, + Status: t.Status, + Priority: t.Priority, + TimeoutSeconds: t.TimeoutSeconds, + Metadata: t.Metadata, + CreatedAt: t.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), + UpdatedAt: t.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"), + } +} + +// ValidateTaskName validates the task name +func ValidateTaskName(name string) error { + if name == "" { + return fmt.Errorf("task name is required") + } + + name = strings.TrimSpace(name) + if len(name) == 0 { + return fmt.Errorf("task name cannot be empty") + } + + if len(name) > 255 { + return fmt.Errorf("task name is too long (max 255 characters)") + } + + return nil +} + +// ValidateScriptType validates the script type +func ValidateScriptType(scriptType ScriptType) error { + switch scriptType { + case ScriptTypePython, ScriptTypeJavaScript, ScriptTypeBash, ScriptTypeGo: + return nil + default: + return fmt.Errorf("invalid script type: %s", scriptType) + } +} + +// ValidateScriptContent validates the script content +func ValidateScriptContent(content string) error { + if content == "" { + return fmt.Errorf("script content is required") + } + + content = strings.TrimSpace(content) + if len(content) == 0 { + return fmt.Errorf("script content cannot be empty") + } + + if len(content) > 65535 { + return fmt.Errorf("script content is too long (max 65535 characters)") + } + + // Basic security checks + if strings.Contains(strings.ToLower(content), "rm -rf") { + return fmt.Errorf("potentially dangerous script content detected") + } + + return nil +} + +// ValidateTaskStatus validates the task status +func ValidateTaskStatus(status TaskStatus) error { + switch status { + case TaskStatusPending, TaskStatusRunning, TaskStatusCompleted, TaskStatusFailed, TaskStatusTimeout, TaskStatusCancelled: + return nil + default: + return fmt.Errorf("invalid task status: %s", status) + } +} + +// ValidatePriority validates the priority value +func ValidatePriority(priority int) error { + if priority < 0 || priority > 10 { + return fmt.Errorf("priority must be between 0 and 10") + } + return nil +} + +// ValidateTimeout validates the timeout value +func ValidateTimeout(timeout int) error { + if timeout <= 0 { + return fmt.Errorf("timeout must be greater than 0") + } + if timeout > 3600 { + return fmt.Errorf("timeout cannot exceed 3600 seconds") + } + return nil +} \ No newline at end of file diff --git a/internal/models/task_execution.go b/internal/models/task_execution.go new file mode 100644 index 0000000..fbbee5d --- /dev/null +++ b/internal/models/task_execution.go @@ -0,0 +1,132 @@ +package models + +import ( + "fmt" + "time" + + "github.com/google/uuid" +) + +// ExecutionStatus represents the status of a task execution +type ExecutionStatus string + +const ( + ExecutionStatusPending ExecutionStatus = "pending" + ExecutionStatusRunning ExecutionStatus = "running" + ExecutionStatusCompleted ExecutionStatus = "completed" + ExecutionStatusFailed ExecutionStatus = "failed" + ExecutionStatusTimeout ExecutionStatus = "timeout" + ExecutionStatusCancelled ExecutionStatus = "cancelled" +) + +// TaskExecution represents a task execution in the system +type TaskExecution struct { + ID uuid.UUID `json:"id" db:"id"` + TaskID uuid.UUID `json:"task_id" db:"task_id"` + Status ExecutionStatus `json:"status" db:"status"` + ReturnCode *int `json:"return_code,omitempty" db:"return_code"` + Stdout *string `json:"stdout,omitempty" db:"stdout"` + Stderr *string `json:"stderr,omitempty" db:"stderr"` + ExecutionTimeMs *int `json:"execution_time_ms,omitempty" db:"execution_time_ms"` + MemoryUsageBytes *int64 `json:"memory_usage_bytes,omitempty" db:"memory_usage_bytes"` + StartedAt *time.Time `json:"started_at,omitempty" db:"started_at"` + CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"` + CreatedAt time.Time `json:"created_at" db:"created_at"` +} + +// CreateTaskExecutionRequest represents the request to create a new task execution +type CreateTaskExecutionRequest struct { + TaskID uuid.UUID `json:"task_id" validate:"required"` +} + +// UpdateTaskExecutionRequest represents the request to update a task execution +type UpdateTaskExecutionRequest struct { + Status *ExecutionStatus `json:"status,omitempty"` + ReturnCode *int `json:"return_code,omitempty" validate:"omitempty,min=0,max=255"` + Stdout *string `json:"stdout,omitempty"` + Stderr *string `json:"stderr,omitempty"` + ExecutionTimeMs *int `json:"execution_time_ms,omitempty" validate:"omitempty,min=0"` + MemoryUsageBytes *int64 `json:"memory_usage_bytes,omitempty" validate:"omitempty,min=0"` + StartedAt *time.Time `json:"started_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` +} + +// TaskExecutionResponse represents the task execution response +type TaskExecutionResponse struct { + ID uuid.UUID `json:"id"` + TaskID uuid.UUID `json:"task_id"` + Status ExecutionStatus `json:"status"` + ReturnCode *int `json:"return_code,omitempty"` + Stdout *string `json:"stdout,omitempty"` + Stderr *string `json:"stderr,omitempty"` + ExecutionTimeMs *int `json:"execution_time_ms,omitempty"` + MemoryUsageBytes *int64 `json:"memory_usage_bytes,omitempty"` + StartedAt *string `json:"started_at,omitempty"` + CompletedAt *string `json:"completed_at,omitempty"` + CreatedAt string `json:"created_at"` +} + +// ToResponse converts TaskExecution to TaskExecutionResponse +func (te *TaskExecution) ToResponse() TaskExecutionResponse { + response := TaskExecutionResponse{ + ID: te.ID, + TaskID: te.TaskID, + Status: te.Status, + ReturnCode: te.ReturnCode, + Stdout: te.Stdout, + Stderr: te.Stderr, + ExecutionTimeMs: te.ExecutionTimeMs, + MemoryUsageBytes: te.MemoryUsageBytes, + CreatedAt: te.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), + } + + if te.StartedAt != nil { + startedAtStr := te.StartedAt.Format("2006-01-02T15:04:05Z07:00") + response.StartedAt = &startedAtStr + } + + if te.CompletedAt != nil { + completedAtStr := te.CompletedAt.Format("2006-01-02T15:04:05Z07:00") + response.CompletedAt = &completedAtStr + } + + return response +} + +// ValidateExecutionStatus validates the execution status +func ValidateExecutionStatus(status ExecutionStatus) error { + switch status { + case ExecutionStatusPending, ExecutionStatusRunning, ExecutionStatusCompleted, + ExecutionStatusFailed, ExecutionStatusTimeout, ExecutionStatusCancelled: + return nil + default: + return fmt.Errorf("invalid execution status: %s", status) + } +} + +// IsTerminal returns true if the execution status is terminal (completed, failed, timeout, cancelled) +func (te *TaskExecution) IsTerminal() bool { + return te.Status == ExecutionStatusCompleted || + te.Status == ExecutionStatusFailed || + te.Status == ExecutionStatusTimeout || + te.Status == ExecutionStatusCancelled +} + +// IsRunning returns true if the execution is currently running +func (te *TaskExecution) IsRunning() bool { + return te.Status == ExecutionStatusRunning +} + +// IsPending returns true if the execution is pending +func (te *TaskExecution) IsPending() bool { + return te.Status == ExecutionStatusPending +} + +// GetDuration returns the execution duration in milliseconds +func (te *TaskExecution) GetDuration() *int { + if te.StartedAt != nil && te.CompletedAt != nil { + duration := int(te.CompletedAt.Sub(*te.StartedAt).Milliseconds()) + return &duration + } + return te.ExecutionTimeMs +} \ No newline at end of file diff --git a/internal/models/task_execution_test.go b/internal/models/task_execution_test.go new file mode 100644 index 0000000..a58c0b1 --- /dev/null +++ b/internal/models/task_execution_test.go @@ -0,0 +1,266 @@ +package models + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestValidateExecutionStatus(t *testing.T) { + tests := []struct { + name string + status ExecutionStatus + wantErr bool + }{ + { + name: "valid pending status", + status: ExecutionStatusPending, + wantErr: false, + }, + { + name: "valid running status", + status: ExecutionStatusRunning, + wantErr: false, + }, + { + name: "valid completed status", + status: ExecutionStatusCompleted, + wantErr: false, + }, + { + name: "valid failed status", + status: ExecutionStatusFailed, + wantErr: false, + }, + { + name: "valid timeout status", + status: ExecutionStatusTimeout, + wantErr: false, + }, + { + name: "valid cancelled status", + status: ExecutionStatusCancelled, + wantErr: false, + }, + { + name: "invalid status", + status: "invalid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateExecutionStatus(tt.status) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTaskExecution_IsTerminal(t *testing.T) { + tests := []struct { + name string + status ExecutionStatus + expected bool + }{ + { + name: "pending is not terminal", + status: ExecutionStatusPending, + expected: false, + }, + { + name: "running is not terminal", + status: ExecutionStatusRunning, + expected: false, + }, + { + name: "completed is terminal", + status: ExecutionStatusCompleted, + expected: true, + }, + { + name: "failed is terminal", + status: ExecutionStatusFailed, + expected: true, + }, + { + name: "timeout is terminal", + status: ExecutionStatusTimeout, + expected: true, + }, + { + name: "cancelled is terminal", + status: ExecutionStatusCancelled, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + te := &TaskExecution{Status: tt.status} + assert.Equal(t, tt.expected, te.IsTerminal()) + }) + } +} + +func TestTaskExecution_IsRunning(t *testing.T) { + tests := []struct { + name string + status ExecutionStatus + expected bool + }{ + { + name: "running status", + status: ExecutionStatusRunning, + expected: true, + }, + { + name: "pending status", + status: ExecutionStatusPending, + expected: false, + }, + { + name: "completed status", + status: ExecutionStatusCompleted, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + te := &TaskExecution{Status: tt.status} + assert.Equal(t, tt.expected, te.IsRunning()) + }) + } +} + +func TestTaskExecution_IsPending(t *testing.T) { + tests := []struct { + name string + status ExecutionStatus + expected bool + }{ + { + name: "pending status", + status: ExecutionStatusPending, + expected: true, + }, + { + name: "running status", + status: ExecutionStatusRunning, + expected: false, + }, + { + name: "completed status", + status: ExecutionStatusCompleted, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + te := &TaskExecution{Status: tt.status} + assert.Equal(t, tt.expected, te.IsPending()) + }) + } +} + +func TestTaskExecution_GetDuration(t *testing.T) { + t.Run("with started and completed times", func(t *testing.T) { + startTime := time.Now() + endTime := startTime.Add(2 * time.Second) + + te := &TaskExecution{ + StartedAt: &startTime, + CompletedAt: &endTime, + } + + duration := te.GetDuration() + assert.NotNil(t, duration) + assert.Equal(t, 2000, *duration) // 2 seconds in milliseconds + }) + + t.Run("with execution time ms set", func(t *testing.T) { + executionTime := 1500 + te := &TaskExecution{ + ExecutionTimeMs: &executionTime, + } + + duration := te.GetDuration() + assert.NotNil(t, duration) + assert.Equal(t, 1500, *duration) + }) + + t.Run("without timing information", func(t *testing.T) { + te := &TaskExecution{} + + duration := te.GetDuration() + assert.Nil(t, duration) + }) +} + +func TestTaskExecution_ToResponse(t *testing.T) { + startTime := time.Now() + endTime := startTime.Add(time.Second) + returnCode := 0 + stdout := "test output" + stderr := "" + executionTime := 1000 + memoryUsage := int64(1024 * 1024) + + te := &TaskExecution{ + ID: NewID(), + TaskID: NewID(), + Status: ExecutionStatusCompleted, + ReturnCode: &returnCode, + Stdout: &stdout, + Stderr: &stderr, + ExecutionTimeMs: &executionTime, + MemoryUsageBytes: &memoryUsage, + StartedAt: &startTime, + CompletedAt: &endTime, + CreatedAt: time.Now(), + } + + response := te.ToResponse() + + assert.Equal(t, te.ID, response.ID) + assert.Equal(t, te.TaskID, response.TaskID) + assert.Equal(t, te.Status, response.Status) + assert.Equal(t, *te.ReturnCode, *response.ReturnCode) + assert.Equal(t, *te.Stdout, *response.Stdout) + assert.Equal(t, *te.Stderr, *response.Stderr) + assert.Equal(t, *te.ExecutionTimeMs, *response.ExecutionTimeMs) + assert.Equal(t, *te.MemoryUsageBytes, *response.MemoryUsageBytes) + assert.NotNil(t, response.StartedAt) + assert.NotNil(t, response.CompletedAt) + assert.NotEmpty(t, response.CreatedAt) +} + +func TestTaskExecution_ToResponse_WithNilValues(t *testing.T) { + te := &TaskExecution{ + ID: NewID(), + TaskID: NewID(), + Status: ExecutionStatusPending, + CreatedAt: time.Now(), + } + + response := te.ToResponse() + + assert.Equal(t, te.ID, response.ID) + assert.Equal(t, te.TaskID, response.TaskID) + assert.Equal(t, te.Status, response.Status) + assert.Nil(t, response.ReturnCode) + assert.Nil(t, response.Stdout) + assert.Nil(t, response.Stderr) + assert.Nil(t, response.ExecutionTimeMs) + assert.Nil(t, response.MemoryUsageBytes) + assert.Nil(t, response.StartedAt) + assert.Nil(t, response.CompletedAt) + assert.NotEmpty(t, response.CreatedAt) +} \ No newline at end of file diff --git a/internal/models/task_test.go b/internal/models/task_test.go new file mode 100644 index 0000000..ed25130 --- /dev/null +++ b/internal/models/task_test.go @@ -0,0 +1,330 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateTaskName(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + errMsg string + }{ + { + name: "valid task name", + input: "Valid Task Name", + wantErr: false, + }, + { + name: "empty string", + input: "", + wantErr: true, + errMsg: "task name is required", + }, + { + name: "whitespace only", + input: " ", + wantErr: true, + errMsg: "task name cannot be empty", + }, + { + name: "too long name", + input: string(make([]byte, 256)), + wantErr: true, + errMsg: "task name is too long", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateTaskName(tt.input) + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateScriptType(t *testing.T) { + tests := []struct { + name string + scriptType ScriptType + wantErr bool + }{ + { + name: "valid python", + scriptType: ScriptTypePython, + wantErr: false, + }, + { + name: "valid javascript", + scriptType: ScriptTypeJavaScript, + wantErr: false, + }, + { + name: "valid bash", + scriptType: ScriptTypeBash, + wantErr: false, + }, + { + name: "valid go", + scriptType: ScriptTypeGo, + wantErr: false, + }, + { + name: "invalid script type", + scriptType: "invalid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateScriptType(tt.scriptType) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateScriptContent(t *testing.T) { + tests := []struct { + name string + content string + wantErr bool + errMsg string + }{ + { + name: "valid script content", + content: "print('hello world')", + wantErr: false, + }, + { + name: "empty content", + content: "", + wantErr: true, + errMsg: "script content is required", + }, + { + name: "whitespace only content", + content: " ", + wantErr: true, + errMsg: "script content cannot be empty", + }, + { + name: "too long content", + content: string(make([]byte, 65536)), + wantErr: true, + errMsg: "script content is too long", + }, + { + name: "dangerous content", + content: "rm -rf /", + wantErr: true, + errMsg: "potentially dangerous script content detected", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateScriptContent(tt.content) + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateTaskStatus(t *testing.T) { + tests := []struct { + name string + status TaskStatus + wantErr bool + }{ + { + name: "valid pending status", + status: TaskStatusPending, + wantErr: false, + }, + { + name: "valid running status", + status: TaskStatusRunning, + wantErr: false, + }, + { + name: "valid completed status", + status: TaskStatusCompleted, + wantErr: false, + }, + { + name: "valid failed status", + status: TaskStatusFailed, + wantErr: false, + }, + { + name: "valid timeout status", + status: TaskStatusTimeout, + wantErr: false, + }, + { + name: "valid cancelled status", + status: TaskStatusCancelled, + wantErr: false, + }, + { + name: "invalid status", + status: "invalid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateTaskStatus(tt.status) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidatePriority(t *testing.T) { + tests := []struct { + name string + priority int + wantErr bool + }{ + { + name: "valid priority 0", + priority: 0, + wantErr: false, + }, + { + name: "valid priority 5", + priority: 5, + wantErr: false, + }, + { + name: "valid priority 10", + priority: 10, + wantErr: false, + }, + { + name: "invalid negative priority", + priority: -1, + wantErr: true, + }, + { + name: "invalid too high priority", + priority: 11, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidatePriority(tt.priority) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateTimeout(t *testing.T) { + tests := []struct { + name string + timeout int + wantErr bool + }{ + { + name: "valid timeout", + timeout: 30, + wantErr: false, + }, + { + name: "minimum valid timeout", + timeout: 1, + wantErr: false, + }, + { + name: "maximum valid timeout", + timeout: 3600, + wantErr: false, + }, + { + name: "invalid zero timeout", + timeout: 0, + wantErr: true, + }, + { + name: "invalid negative timeout", + timeout: -1, + wantErr: true, + }, + { + name: "invalid too large timeout", + timeout: 3601, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateTimeout(tt.timeout) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTask_ToResponse(t *testing.T) { + task := &Task{ + BaseModel: BaseModel{ + ID: NewID(), + }, + UserID: NewID(), + Name: "Test Task", + ScriptContent: "print('test')", + ScriptType: ScriptTypePython, + Status: TaskStatusPending, + Priority: 1, + TimeoutSeconds: 30, + } + + response := task.ToResponse() + + assert.Equal(t, task.ID, response.ID) + assert.Equal(t, task.UserID, response.UserID) + assert.Equal(t, task.Name, response.Name) + assert.Equal(t, task.ScriptContent, response.ScriptContent) + assert.Equal(t, task.ScriptType, response.ScriptType) + assert.Equal(t, task.Status, response.Status) + assert.Equal(t, task.Priority, response.Priority) + assert.Equal(t, task.TimeoutSeconds, response.TimeoutSeconds) + assert.NotEmpty(t, response.CreatedAt) + assert.NotEmpty(t, response.UpdatedAt) +} \ No newline at end of file diff --git a/internal/models/user.go b/internal/models/user.go new file mode 100644 index 0000000..a92092d --- /dev/null +++ b/internal/models/user.go @@ -0,0 +1,94 @@ +package models + +import ( + "fmt" + "regexp" + "strings" + + "github.com/google/uuid" +) + +// User represents a user in the system +type User struct { + BaseModel + Email string `json:"email" db:"email"` + PasswordHash string `json:"-" db:"password_hash"` +} + +// CreateUserRequest represents the request to create a new user +type CreateUserRequest struct { + Email string `json:"email" validate:"required,email"` + Password string `json:"password" validate:"required,min=8"` +} + +// UpdateUserRequest represents the request to update a user +type UpdateUserRequest struct { + Email string `json:"email,omitempty" validate:"omitempty,email"` +} + +// UserResponse represents the user response (without sensitive data) +type UserResponse struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +// ToResponse converts User to UserResponse +func (u *User) ToResponse() UserResponse { + return UserResponse{ + ID: u.ID, + Email: u.Email, + CreatedAt: u.CreatedAt.Format("2006-01-02T15:04:05Z07:00"), + UpdatedAt: u.UpdatedAt.Format("2006-01-02T15:04:05Z07:00"), + } +} + +// ValidateEmail validates the email format +func ValidateEmail(email string) error { + if email == "" { + return fmt.Errorf("email is required") + } + + email = strings.TrimSpace(strings.ToLower(email)) + + if len(email) > 255 { + return fmt.Errorf("email is too long (max 255 characters)") + } + + // Basic email validation + emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) + if !emailRegex.MatchString(email) { + return fmt.Errorf("invalid email format") + } + + return nil +} + +// ValidatePassword validates the password strength +func ValidatePassword(password string) error { + if len(password) < 8 { + return fmt.Errorf("password must be at least 8 characters long") + } + + if len(password) > 128 { + return fmt.Errorf("password is too long (max 128 characters)") + } + + // Check for at least one uppercase letter + if !regexp.MustCompile(`[A-Z]`).MatchString(password) { + return fmt.Errorf("password must contain at least one uppercase letter") + } + + // Check for at least one lowercase letter + if !regexp.MustCompile(`[a-z]`).MatchString(password) { + return fmt.Errorf("password must contain at least one lowercase letter") + } + + // Check for at least one digit + if !regexp.MustCompile(`[0-9]`).MatchString(password) { + return fmt.Errorf("password must contain at least one digit") + } + + return nil +} \ No newline at end of file diff --git a/internal/models/user_test.go b/internal/models/user_test.go new file mode 100644 index 0000000..bbe97b4 --- /dev/null +++ b/internal/models/user_test.go @@ -0,0 +1,151 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateEmail(t *testing.T) { + tests := []struct { + name string + email string + wantErr bool + errMsg string + }{ + { + name: "valid email", + email: "test@example.com", + wantErr: false, + }, + { + name: "valid email with subdomain", + email: "user@subdomain.example.com", + wantErr: false, + }, + { + name: "empty email", + email: "", + wantErr: true, + errMsg: "email is required", + }, + { + name: "email without @", + email: "invalidemail", + wantErr: true, + errMsg: "invalid email format", + }, + { + name: "email without domain", + email: "test@", + wantErr: true, + errMsg: "invalid email format", + }, + { + name: "email without local part", + email: "@example.com", + wantErr: true, + errMsg: "invalid email format", + }, + { + name: "email too long", + email: string(make([]byte, 250)) + "@example.com", + wantErr: true, + errMsg: "email is too long", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateEmail(tt.email) + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidatePassword(t *testing.T) { + tests := []struct { + name string + password string + wantErr bool + errMsg string + }{ + { + name: "valid password", + password: "Password123", + wantErr: false, + }, + { + name: "valid complex password", + password: "MyStr0ngP@ssw0rd!", + wantErr: false, + }, + { + name: "too short password", + password: "Pass1", + wantErr: true, + errMsg: "at least 8 characters", + }, + { + name: "too long password", + password: string(make([]byte, 130)), + wantErr: true, + errMsg: "too long", + }, + { + name: "no uppercase letter", + password: "password123", + wantErr: true, + errMsg: "uppercase letter", + }, + { + name: "no lowercase letter", + password: "PASSWORD123", + wantErr: true, + errMsg: "lowercase letter", + }, + { + name: "no digit", + password: "Password", + wantErr: true, + errMsg: "digit", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidatePassword(tt.password) + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestUser_ToResponse(t *testing.T) { + user := &User{ + BaseModel: BaseModel{ + ID: NewID(), + }, + Email: "test@example.com", + } + + response := user.ToResponse() + + assert.Equal(t, user.ID, response.ID) + assert.Equal(t, user.Email, response.Email) + assert.NotEmpty(t, response.CreatedAt) + assert.NotEmpty(t, response.UpdatedAt) +} \ No newline at end of file diff --git a/migrations/001_initial_schema.down.sql b/migrations/001_initial_schema.down.sql new file mode 100644 index 0000000..cacc52c --- /dev/null +++ b/migrations/001_initial_schema.down.sql @@ -0,0 +1,14 @@ +-- Drop triggers +DROP TRIGGER IF EXISTS update_tasks_updated_at ON tasks; +DROP TRIGGER IF EXISTS update_users_updated_at ON users; + +-- Drop function +DROP FUNCTION IF EXISTS update_updated_at_column(); + +-- Drop tables in reverse order (due to foreign key constraints) +DROP TABLE IF EXISTS task_executions; +DROP TABLE IF EXISTS tasks; +DROP TABLE IF EXISTS users; + +-- Drop extension (only if no other tables use it) +-- DROP EXTENSION IF EXISTS "pgcrypto"; \ No newline at end of file diff --git a/migrations/001_initial_schema.up.sql b/migrations/001_initial_schema.up.sql new file mode 100644 index 0000000..3a8376d --- /dev/null +++ b/migrations/001_initial_schema.up.sql @@ -0,0 +1,88 @@ +-- Enable UUID extension for gen_random_uuid() function +-- Note: gen_random_uuid() is built-in to PostgreSQL 13+, but we ensure compatibility with older versions +CREATE EXTENSION IF NOT EXISTS "pgcrypto"; + +-- Create users table +CREATE TABLE users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + email VARCHAR(255) UNIQUE NOT NULL, + password_hash VARCHAR(255) NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +-- Create index for email lookups +CREATE INDEX idx_users_email ON users(email); + +-- Create tasks table +CREATE TABLE tasks ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + name VARCHAR(255) NOT NULL, + description TEXT, + script_content TEXT NOT NULL, + script_type VARCHAR(50) NOT NULL DEFAULT 'python', + status VARCHAR(50) NOT NULL DEFAULT 'pending', + priority INTEGER NOT NULL DEFAULT 0, + timeout_seconds INTEGER NOT NULL DEFAULT 10, + metadata JSONB DEFAULT '{}', + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + + -- Constraints + CONSTRAINT chk_script_type CHECK (script_type IN ('python', 'javascript', 'bash', 'go')), + CONSTRAINT chk_status CHECK (status IN ('pending', 'running', 'completed', 'failed', 'timeout', 'cancelled')), + CONSTRAINT chk_priority CHECK (priority >= 0 AND priority <= 10), + CONSTRAINT chk_timeout CHECK (timeout_seconds > 0 AND timeout_seconds <= 3600) +); + +-- Create indexes for tasks table +CREATE INDEX idx_tasks_user_status ON tasks(user_id, status); +CREATE INDEX idx_tasks_created_at ON tasks(created_at); +CREATE INDEX idx_tasks_priority_status ON tasks(priority DESC, status); +CREATE INDEX idx_tasks_metadata_gin ON tasks USING GIN(metadata); + +-- Create task_executions table +CREATE TABLE task_executions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + task_id UUID NOT NULL REFERENCES tasks(id) ON DELETE CASCADE, + status VARCHAR(50) NOT NULL DEFAULT 'pending', + return_code INTEGER, + stdout TEXT, + stderr TEXT, + execution_time_ms INTEGER, + memory_usage_bytes BIGINT, + started_at TIMESTAMP WITH TIME ZONE, + completed_at TIMESTAMP WITH TIME ZONE, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + + -- Constraints + CONSTRAINT chk_execution_status CHECK (status IN ('pending', 'running', 'completed', 'failed', 'timeout', 'cancelled')), + CONSTRAINT chk_return_code CHECK (return_code >= 0 AND return_code <= 255), + CONSTRAINT chk_execution_time CHECK (execution_time_ms >= 0), + CONSTRAINT chk_memory_usage CHECK (memory_usage_bytes >= 0) +); + +-- Create indexes for task_executions table +CREATE INDEX idx_executions_task_created ON task_executions(task_id, created_at DESC); +CREATE INDEX idx_executions_status_created ON task_executions(status, created_at DESC); + +-- Create function to update updated_at timestamp +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ language 'plpgsql'; + +-- Create triggers to automatically update updated_at +CREATE TRIGGER update_users_updated_at + BEFORE UPDATE ON users + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); + +CREATE TRIGGER update_tasks_updated_at + BEFORE UPDATE ON tasks + FOR EACH ROW + EXECUTE FUNCTION update_updated_at_column(); \ No newline at end of file