diff --git a/cmd/hatchet-migrate/migrate/migrations/20260410190713_v1_0_97.go b/cmd/hatchet-migrate/migrate/migrations/20260410190713_v1_0_97.go new file mode 100644 index 0000000000..0239444c45 --- /dev/null +++ b/cmd/hatchet-migrate/migrate/migrations/20260410190713_v1_0_97.go @@ -0,0 +1,122 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + + "github.com/pressly/goose/v3" +) + +func init() { + goose.AddMigrationNoTxContext(up20260410190713, down20260410190713) +} + +func v1RunsOlapTenantStatusInsAtIdxName(table string) string { + return fmt.Sprintf("ix_%s_tenant_status_ins_at", table) +} + +func up20260410190713(ctx context.Context, db *sql.DB) error { + // drop the old outdated index first + // note: can't do this concurrently or in parts (i.e. dropping children first) + // see: https://stackoverflow.com/a/76167838 + stmt := "DROP INDEX IF EXISTS ix_v1_runs_olap_tenant_id" + + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("drop old index on %s: %w", v1RunsOlapTable, err) + } + + grandchildPartitions, err := listLeafPartitions(ctx, db, v1RunsOlapTable, 2) + if err != nil { + return err + } + + for _, partition := range grandchildPartitions { + stmt := fmt.Sprintf( + `CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (tenant_id, readable_status, inserted_at DESC)`, + quoteIdent(v1RunsOlapTenantStatusInsAtIdxName(partition)), + quoteIdent(partition), + ) + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("create index concurrently on %s: %w", partition, err) + } + } + + childPartitions, err := listLeafPartitions(ctx, db, v1RunsOlapTable, 1) + if err != nil { + return err + } + + for _, partition := range childPartitions { + stmt := fmt.Sprintf( + `CREATE INDEX IF NOT EXISTS %s ON %s (tenant_id, readable_status, inserted_at DESC)`, + quoteIdent(v1RunsOlapTenantStatusInsAtIdxName(partition)), + quoteIdent(partition), + ) + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("create index on partition %s: %w", partition, err) + } + } + + stmt = fmt.Sprintf( + `CREATE INDEX IF NOT EXISTS %s ON %s (tenant_id, readable_status, inserted_at DESC)`, + quoteIdent(v1RunsOlapTenantStatusInsAtIdxName(v1RunsOlapTable)), + quoteIdent(v1RunsOlapTable), + ) + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("create index on %s: %w", v1RunsOlapTable, err) + } + + return nil +} + +func down20260410190713(ctx context.Context, db *sql.DB) error { + // drop the new index first so we can rebuild the old one bottom-up + stmt := "DROP INDEX IF EXISTS ix_v1_runs_olap_tenant_status_ins_at" + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("drop new index on %s: %w", v1RunsOlapTable, err) + } + + grandchildPartitions, err := listLeafPartitions(ctx, db, v1RunsOlapTable, 2) + if err != nil { + return err + } + + for _, partition := range grandchildPartitions { + stmt := fmt.Sprintf( + `CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (tenant_id, inserted_at, id, readable_status, kind)`, + quoteIdent(idxNameForPartition(partition)), + quoteIdent(partition), + ) + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("create index concurrently on %s: %w", partition, err) + } + } + + childPartitions, err := listLeafPartitions(ctx, db, v1RunsOlapTable, 1) + if err != nil { + return err + } + + for _, partition := range childPartitions { + stmt := fmt.Sprintf( + `CREATE INDEX IF NOT EXISTS %s ON %s (tenant_id, inserted_at, id, readable_status, kind)`, + quoteIdent(idxNameForPartition(partition)), + quoteIdent(partition), + ) + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("create index on partition %s: %w", partition, err) + } + } + + stmt = fmt.Sprintf( + `CREATE INDEX IF NOT EXISTS %s ON %s (tenant_id, inserted_at, id, readable_status, kind)`, + quoteIdent(idxNameForPartition(v1RunsOlapTable)), + quoteIdent(v1RunsOlapTable), + ) + if _, err := db.ExecContext(ctx, stmt); err != nil { + return fmt.Errorf("create index on %s: %w", v1RunsOlapTable, err) + } + + return nil +} diff --git a/cmd/hatchet-migrate/migrate/migrations/20260410202520_v1_0_98.sql b/cmd/hatchet-migrate/migrate/migrations/20260410202520_v1_0_98.sql new file mode 100644 index 0000000000..bf690fc643 --- /dev/null +++ b/cmd/hatchet-migrate/migrate/migrations/20260410202520_v1_0_98.sql @@ -0,0 +1,41 @@ +-- +goose Up +-- +goose StatementBegin +DROP FUNCTION create_v1_olap_partition_with_date_and_status(text, date); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +CREATE OR REPLACE FUNCTION create_v1_olap_partition_with_date_and_status( + targetTableName text, + targetDate date +) RETURNS integer + LANGUAGE plpgsql AS +$$ +DECLARE + targetDateStr varchar; + targetDatePlusOneDayStr varchar; + newTableName varchar; +BEGIN + SELECT to_char(targetDate, 'YYYYMMDD') INTO targetDateStr; + SELECT to_char(targetDate + INTERVAL '1 day', 'YYYYMMDD') INTO targetDatePlusOneDayStr; + SELECT format('%s_%s', targetTableName, targetDateStr) INTO newTableName; + IF NOT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = newTableName) THEN + EXECUTE format('CREATE TABLE %s (LIKE %s INCLUDING INDEXES) PARTITION BY LIST (readable_status)', newTableName, targetTableName); + END IF; + + PERFORM create_v1_partition_with_status(newTableName, 'QUEUED'); + PERFORM create_v1_partition_with_status(newTableName, 'RUNNING'); + PERFORM create_v1_partition_with_status(newTableName, 'COMPLETED'); + PERFORM create_v1_partition_with_status(newTableName, 'CANCELLED'); + PERFORM create_v1_partition_with_status(newTableName, 'FAILED'); + PERFORM create_v1_partition_with_status(newTableName, 'EVICTED'); + + -- If it's not already attached, attach the partition + IF NOT EXISTS (SELECT 1 FROM pg_inherits WHERE inhrelid = newTableName::regclass) THEN + EXECUTE format('ALTER TABLE %s ATTACH PARTITION %s FOR VALUES FROM (''%s'') TO (''%s'')', targetTableName, newTableName, targetDateStr, targetDatePlusOneDayStr); + END IF; + + RETURN 1; +END; +$$; +-- +goose StatementEnd diff --git a/internal/operation/pool.go b/internal/operation/pool.go index b8a303454c..455560b10b 100644 --- a/internal/operation/pool.go +++ b/internal/operation/pool.go @@ -109,11 +109,11 @@ func (p *TenantOperationPool) Cleanup() { }) } -func (p *TenantOperationPool) setTenants(tenants []*sqlcv1.Tenant) { +func (p *TenantOperationPool) setTenants(tenants []uuid.UUID) { tenantMap := make(map[uuid.UUID]bool) for _, t := range tenants { - tenantMap[t.ID] = true + tenantMap[t] = true } // init ops for new tenants diff --git a/internal/queueutils/pool.go b/internal/queueutils/pool.go index a6bb9c330f..c6aa4ac510 100644 --- a/internal/queueutils/pool.go +++ b/internal/queueutils/pool.go @@ -3,8 +3,8 @@ package queueutils import ( "time" + "github.com/google/uuid" "github.com/hatchet-dev/hatchet/internal/syncx" - "github.com/hatchet-dev/hatchet/pkg/repository/sqlcv1" "github.com/rs/zerolog" ) @@ -33,11 +33,11 @@ func (p *OperationPool) WithJitter(maxJitter time.Duration) *OperationPool { return p } -func (p *OperationPool) SetTenants(tenants []*sqlcv1.Tenant) { +func (p *OperationPool) SetTenants(tenants []uuid.UUID) { tenantMap := make(map[string]bool) for _, t := range tenants { - tenantMap[t.ID.String()] = true + tenantMap[t.String()] = true } // delete tenants that are not in the list diff --git a/internal/services/controllers/olap/controller.go b/internal/services/controllers/olap/controller.go index 2ebc5eee1b..20d06bf6ad 100644 --- a/internal/services/controllers/olap/controller.go +++ b/internal/services/controllers/olap/controller.go @@ -304,38 +304,38 @@ func (o *OLAPControllerImpl) Start() (func() error, error) { } // Default poll interval - pollIntervalSec := 2 + // pollIntervalSec := 2 // Override with config value if available - if o.olapConfig != nil && o.olapConfig.PollInterval > 0 { - pollIntervalSec = o.olapConfig.PollInterval - } + // if o.olapConfig != nil && o.olapConfig.PollInterval > 0 { + // pollIntervalSec = o.olapConfig.PollInterval + // } - _, err = o.s.NewJob( - gocron.DurationJob(time.Second*time.Duration(pollIntervalSec)), - gocron.NewTask( - o.runTaskStatusUpdates(ctx), - ), - gocron.WithSingletonMode(gocron.LimitModeReschedule), - ) + // _, err = o.s.NewJob( + // gocron.DurationJob(time.Second*time.Duration(pollIntervalSec)), + // gocron.NewTask( + // o.runTaskStatusUpdates(ctx), + // ), + // gocron.WithSingletonMode(gocron.LimitModeReschedule), + // ) - if err != nil { - cancel() - return nil, fmt.Errorf("could not schedule task status updates: %w", err) - } + // if err != nil { + // cancel() + // return nil, fmt.Errorf("could not schedule task status updates: %w", err) + // } - _, err = o.s.NewJob( - gocron.DurationJob(time.Second*time.Duration(pollIntervalSec)), - gocron.NewTask( - o.runDAGStatusUpdates(ctx), - ), - gocron.WithSingletonMode(gocron.LimitModeReschedule), - ) + // _, err = o.s.NewJob( + // gocron.DurationJob(time.Second*time.Duration(pollIntervalSec)), + // gocron.NewTask( + // o.runDAGStatusUpdates(ctx), + // ), + // gocron.WithSingletonMode(gocron.LimitModeReschedule), + // ) - if err != nil { - cancel() - return nil, fmt.Errorf("could not schedule dag status updates: %w", err) - } + // if err != nil { + // cancel() + // return nil, fmt.Errorf("could not schedule dag status updates: %w", err) + // } _, err = o.s.NewJob( gocron.DurationJob(time.Second*60), @@ -766,6 +766,7 @@ func (tc *OLAPControllerImpl) handleCreateMonitoringEvent(ctx context.Context, t durableInvocationCounts := make([]int32, 0) workerIds := make([]uuid.UUID, 0) workflowIds := make([]uuid.UUID, 0) + workflowRunIDs := make([]uuid.UUID, 0) eventTypes := make([]sqlcv1.V1EventTypeOlap, 0) readableStatuses := make([]sqlcv1.V1ReadableStatusOlap, 0) eventPayloads := make([]string, 0) @@ -790,6 +791,7 @@ func (tc *OLAPControllerImpl) handleCreateMonitoringEvent(ctx context.Context, t taskIds = append(taskIds, msg.TaskId) taskInsertedAts = append(taskInsertedAts, taskMeta.InsertedAt) workflowIds = append(workflowIds, taskMeta.WorkflowID) + workflowRunIDs = append(workflowRunIDs, taskMeta.WorkflowRunID) retryCounts = append(retryCounts, msg.RetryCount) durableInvocationCounts = append(durableInvocationCounts, msg.DurableInvocationCount) eventTypes = append(eventTypes, msg.EventType) @@ -913,12 +915,51 @@ func (tc *OLAPControllerImpl) handleCreateMonitoringEvent(ctx context.Context, t opts = append(opts, event) } - err = tc.repo.OLAP().CreateTaskEvents(ctx, tenantId, opts) + notFoundEvents, err := tc.repo.OLAP().CreateTaskEvents(ctx, tenantId, opts) if err != nil { return err } + if len(notFoundEvents) > 0 { + notFoundTaskIDs := make(map[int64]struct{}, len(notFoundEvents)) + for _, e := range notFoundEvents { + notFoundTaskIDs[e.TaskID] = struct{}{} + } + + requeueCount := 0 + + for _, msg := range msgs { + if _, ok := notFoundTaskIDs[msg.TaskId]; !ok { + continue + } + + if msg.RequeueCount >= 10 { + tc.l.Error().Ctx(ctx).Msgf("giving up on requeuing monitoring event for task %d after %d attempts", msg.TaskId, msg.RequeueCount) + continue + } + + requeued := *msg + requeued.RequeueCount++ + + requeueMsg, requeueErr := tasktypes.MonitoringEventMessageFromInternal(tenantId, requeued) + if requeueErr != nil { + tc.l.Error().Ctx(ctx).Err(requeueErr).Msgf("could not create requeue message for task %d", msg.TaskId) + continue + } + + if requeueErr = tc.mq.SendMessage(ctx, msgqueue.OLAP_QUEUE, requeueMsg); requeueErr != nil { + tc.l.Error().Ctx(ctx).Err(requeueErr).Msgf("could not requeue monitoring event for task %d", msg.TaskId) + } else { + requeueCount++ + } + } + + if requeueCount > 0 { + tc.l.Warn().Ctx(ctx).Msgf("requeued %d monitoring events for tasks not yet in OLAP table", requeueCount) + } + } + tc.synthesizeEngineSpans(ctx, tenantId, spanEvents) if !tc.repo.OLAP().PayloadStore().ExternalStoreEnabled() { diff --git a/internal/services/controllers/olap/process_alerts.go b/internal/services/controllers/olap/process_alerts.go index 07446f13af..ac58626cc9 100644 --- a/internal/services/controllers/olap/process_alerts.go +++ b/internal/services/controllers/olap/process_alerts.go @@ -16,7 +16,6 @@ func (o *OLAPControllerImpl) runTenantProcessAlerts(ctx context.Context) func() return func() { o.l.Debug().Ctx(ctx).Msgf("partition: processing tenant alerts") - // list all tenants tenants, err := o.p.ListTenantsForController(ctx, sqlcv1.TenantMajorEngineVersionV1) if err != nil { @@ -26,10 +25,8 @@ func (o *OLAPControllerImpl) runTenantProcessAlerts(ctx context.Context) func() o.processTenantAlertOperations.SetTenants(tenants) - for i := range tenants { - tenantId := tenants[i].ID.String() - - o.processTenantAlertOperations.RunOrContinue(tenantId) + for _, tenantId := range tenants { + o.processTenantAlertOperations.RunOrContinue(tenantId.String()) } } } diff --git a/internal/services/controllers/olap/process_dag_status_updates.go b/internal/services/controllers/olap/process_dag_status_updates.go index 525da55c9a..ab2f08ba27 100644 --- a/internal/services/controllers/olap/process_dag_status_updates.go +++ b/internal/services/controllers/olap/process_dag_status_updates.go @@ -22,21 +22,13 @@ func (o *OLAPControllerImpl) runDAGStatusUpdates(ctx context.Context) func() { for shouldContinue { o.l.Debug().Ctx(ctx).Msgf("partition: running status updates for dags") - // list all tenants - tenants, err := o.p.ListTenantsForController(ctx, sqlcv1.TenantMajorEngineVersionV1) + tenantIds, err := o.p.ListTenantsForController(ctx, sqlcv1.TenantMajorEngineVersionV1) if err != nil { o.l.Error().Ctx(ctx).Err(err).Msg("could not list tenants") return } - tenantIds := make([]uuid.UUID, 0, len(tenants)) - - for _, tenant := range tenants { - tenantId := tenant.ID - tenantIds = append(tenantIds, tenantId) - } - var rows []v1.UpdateDAGStatusRow shouldContinue, rows, err = o.repo.OLAP().UpdateDAGStatuses(ctx, tenantIds) diff --git a/internal/services/controllers/olap/process_task_status_updates.go b/internal/services/controllers/olap/process_task_status_updates.go index 39eb102713..ce3b97e30a 100644 --- a/internal/services/controllers/olap/process_task_status_updates.go +++ b/internal/services/controllers/olap/process_task_status_updates.go @@ -22,21 +22,13 @@ func (o *OLAPControllerImpl) runTaskStatusUpdates(ctx context.Context) func() { for shouldContinue { o.l.Debug().Ctx(ctx).Msgf("partition: running status updates for tasks") - // list all tenants - tenants, err := o.p.ListTenantsForController(ctx, sqlcv1.TenantMajorEngineVersionV1) + tenantIds, err := o.p.ListTenantsForController(ctx, sqlcv1.TenantMajorEngineVersionV1) if err != nil { o.l.Error().Ctx(ctx).Err(err).Msg("could not list tenants") return } - tenantIds := make([]uuid.UUID, 0, len(tenants)) - - for _, tenant := range tenants { - tenantId := tenant.ID - tenantIds = append(tenantIds, tenantId) - } - var rows []v1.UpdateTaskStatusRow shouldContinue, rows, err = o.repo.OLAP().UpdateTaskStatuses(ctx, tenantIds) diff --git a/internal/services/controllers/retention/shared.go b/internal/services/controllers/retention/shared.go index 618e0eb428..1ba678b3bf 100644 --- a/internal/services/controllers/retention/shared.go +++ b/internal/services/controllers/retention/shared.go @@ -7,6 +7,7 @@ import ( "golang.org/x/sync/errgroup" + "github.com/google/uuid" "github.com/hatchet-dev/hatchet/pkg/repository/sqlcv1" ) @@ -20,9 +21,8 @@ func GetDataRetentionExpiredTime(duration string) (time.Time, error) { return time.Now().UTC().Add(-d), nil } -func (rc *RetentionControllerImpl) ForTenants(ctx context.Context, f func(ctx context.Context, tenant sqlcv1.Tenant) error) error { +func (rc *RetentionControllerImpl) ForTenants(ctx context.Context, f func(ctx context.Context, tenant uuid.UUID) error) error { - // list all tenants tenants, err := rc.p.ListTenantsForController(ctx, sqlcv1.TenantMajorEngineVersionV0) if err != nil { @@ -34,7 +34,7 @@ func (rc *RetentionControllerImpl) ForTenants(ctx context.Context, f func(ctx co for i := range tenants { index := i g.Go(func() error { - return f(ctx, *tenants[index]) + return f(ctx, tenants[index]) }) } diff --git a/internal/services/partition/partition.go b/internal/services/partition/partition.go index 07c463439f..fa3e6ddc10 100644 --- a/internal/services/partition/partition.go +++ b/internal/services/partition/partition.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-co-op/gocron/v2" + "github.com/google/uuid" "github.com/rs/zerolog" "github.com/hatchet-dev/hatchet/pkg/cleanup" @@ -176,7 +177,7 @@ func (p *Partition) GetInternalTenantForController(ctx context.Context) (*sqlcv1 return p.repo.GetInternalTenantForController(ctx, p.GetControllerPartitionId()) } -func (p *Partition) ListTenantsForController(ctx context.Context, majorVersion sqlcv1.TenantMajorEngineVersion) ([]*sqlcv1.Tenant, error) { +func (p *Partition) ListTenantsForController(ctx context.Context, majorVersion sqlcv1.TenantMajorEngineVersion) ([]uuid.UUID, error) { return p.repo.ListTenantsByControllerPartition(ctx, p.GetControllerPartitionId(), majorVersion) } diff --git a/internal/services/shared/tasktypes/v1/olap.go b/internal/services/shared/tasktypes/v1/olap.go index 979d25bb37..d696ee808d 100644 --- a/internal/services/shared/tasktypes/v1/olap.go +++ b/internal/services/shared/tasktypes/v1/olap.go @@ -101,6 +101,8 @@ type CreateMonitoringEventPayload struct { EventTimestamp time.Time `json:"event_timestamp" validate:"required"` EventPayload string `json:"event_payload" validate:"required"` EventMessage string `json:"event_message,omitempty"` + + RequeueCount int32 `json:"requeue_count,omitempty"` } func MonitoringEventMessageFromActionEvent(tenantId uuid.UUID, taskId int64, retryCount int32, durableInvocationCount int32, request *contracts.StepActionEvent) (*msgqueue.Message, error) { diff --git a/pkg/repository/olap.go b/pkg/repository/olap.go index 0c60db36ba..a8206e1ed4 100644 --- a/pkg/repository/olap.go +++ b/pkg/repository/olap.go @@ -237,7 +237,7 @@ type OLAPRepository interface { ListWorkflowRunDisplayNames(ctx context.Context, tenantId uuid.UUID, externalIds []uuid.UUID) ([]*sqlcv1.ListWorkflowRunDisplayNamesRow, error) ReadTaskRunMetrics(ctx context.Context, tenantId uuid.UUID, opts ReadTaskRunMetricsOpts) ([]TaskRunMetric, error) CreateTasks(ctx context.Context, tenantId uuid.UUID, tasks []*V1TaskWithPayload) error - CreateTaskEvents(ctx context.Context, tenantId uuid.UUID, events []sqlcv1.CreateTaskEventsOLAPParams) error + CreateTaskEvents(ctx context.Context, tenantId uuid.UUID, events []sqlcv1.CreateTaskEventsOLAPParams) (notFound []sqlcv1.CreateTaskEventsOLAPParams, err error) CreateDAGs(ctx context.Context, tenantId uuid.UUID, dags []*DAGWithData) error GetTaskPointMetrics(ctx context.Context, tenantId uuid.UUID, startTimestamp *time.Time, endTimestamp *time.Time, bucketInterval time.Duration) ([]*sqlcv1.GetTaskPointMetricsRow, error) UpdateTaskStatuses(ctx context.Context, tenantIds []uuid.UUID) (bool, []UpdateTaskStatusRow, error) @@ -1550,10 +1550,101 @@ func getCacheKey(event sqlcv1.CreateTaskEventsOLAPParams) string { return fmt.Sprintf("%d-%s-%d-%d", event.TaskID, event.EventType, event.RetryCount, event.DurableInvocationCount) } -func (r *OLAPRepositoryImpl) writeTaskEventBatch(ctx context.Context, tenantId uuid.UUID, events []sqlcv1.CreateTaskEventsOLAPParams) error { +func (r *OLAPRepositoryImpl) prepareStatusUpdateBatch(ctx context.Context, tenantId uuid.UUID, events []sqlcv1.CreateTaskEventsOLAPParams) sqlcv1.UpdateTaskStatusesFromMQParams { + type statusRetryCountWorkerIdTuple struct { + Status sqlcv1.V1ReadableStatusOlap + RetryCount int32 + WorkerId *uuid.UUID + } + + taskIdInsertedAtToMeta := make(map[IdInsertedAt]statusRetryCountWorkerIdTuple) + + for _, event := range events { + statusAndRetryCount, seen := taskIdInsertedAtToMeta[IdInsertedAt{ + ID: event.TaskID, + InsertedAt: event.TaskInsertedAt, + }] + + if !seen || event.RetryCount > statusAndRetryCount.RetryCount || compareStatuses(event.ReadableStatus, statusAndRetryCount.Status) { + taskIdInsertedAtToMeta[IdInsertedAt{ + ID: event.TaskID, + InsertedAt: event.TaskInsertedAt, + }] = statusRetryCountWorkerIdTuple{ + Status: event.ReadableStatus, + RetryCount: event.RetryCount, + WorkerId: event.WorkerID, + } + } + } + + tenantIds := make([]uuid.UUID, 0) + taskIds := make([]int64, 0) + taskInsertedAts := make([]pgtype.Timestamptz, 0) + statuses := make([]sqlcv1.V1ReadableStatusOlap, 0) + workerIds := make([]uuid.UUID, 0) + retryCounts := make([]int32, 0) + + for idInsertedAt, meta := range taskIdInsertedAtToMeta { + tenantIds = append(tenantIds, tenantId) + taskIds = append(taskIds, idInsertedAt.ID) + taskInsertedAts = append(taskInsertedAts, idInsertedAt.InsertedAt) + statuses = append(statuses, meta.Status) + retryCounts = append(retryCounts, meta.RetryCount) + + if meta.WorkerId != nil { + workerIds = append(workerIds, *meta.WorkerId) + } else { + workerIds = append(workerIds, uuid.Nil) + } + } + + return sqlcv1.UpdateTaskStatusesFromMQParams{ + Tenantids: tenantIds, + Taskids: taskIds, + Taskinsertedats: taskInsertedAts, + Statuses: statuses, + Workerids: workerIds, + Retrycounts: retryCounts, + } +} + +func (r *OLAPRepositoryImpl) prepareDAGStatusUpdateBatch(taskRows []*sqlcv1.UpdateTaskStatusesFromMQRow) sqlcv1.UpdateDAGStatusesFromMQParams { + type dagKey struct { + DagID int64 + DagInsertedAt pgtype.Timestamptz + } + + seen := make(map[dagKey]struct{}) + tenantIds := make([]uuid.UUID, 0) + dagIds := make([]int64, 0) + dagInsertedAts := make([]pgtype.Timestamptz, 0) + + for _, row := range taskRows { + if !row.DagID.Valid { + continue + } + + key := dagKey{DagID: row.DagID.Int64, DagInsertedAt: row.DagInsertedAt} + + if _, ok := seen[key]; !ok { + seen[key] = struct{}{} + tenantIds = append(tenantIds, row.TenantID) + dagIds = append(dagIds, row.DagID.Int64) + dagInsertedAts = append(dagInsertedAts, row.DagInsertedAt) + } + } + + return sqlcv1.UpdateDAGStatusesFromMQParams{ + Tenantids: tenantIds, + Dagids: dagIds, + Daginsertedats: dagInsertedAts, + } +} + +func (r *OLAPRepositoryImpl) writeTaskEventBatch(ctx context.Context, tenantId uuid.UUID, events []sqlcv1.CreateTaskEventsOLAPParams) ([]sqlcv1.CreateTaskEventsOLAPParams, error) { // skip any events which have a corresponding event already eventsToWrite := make([]sqlcv1.CreateTaskEventsOLAPParams, 0) - tmpEventsToWrite := make([]sqlcv1.CreateTaskEventsOLAPTmpParams, 0) + eventsForStatusUpdate := make([]sqlcv1.CreateTaskEventsOLAPParams, 0, len(events)) payloadsToWrite := make([]StoreOLAPPayloadOpts, 0) for _, event := range events { @@ -1566,18 +1657,10 @@ func (r *OLAPRepositoryImpl) writeTaskEventBatch(ctx context.Context, tenantId u } eventsToWrite = append(eventsToWrite, event) - - tmpEventsToWrite = append(tmpEventsToWrite, sqlcv1.CreateTaskEventsOLAPTmpParams{ - TenantID: event.TenantID, - TaskID: event.TaskID, - TaskInsertedAt: event.TaskInsertedAt, - EventType: event.EventType, - RetryCount: event.RetryCount, - ReadableStatus: event.ReadableStatus, - WorkerID: event.WorkerID, - }) } + eventsForStatusUpdate = append(eventsForStatusUpdate, event) + if event.ExternalID != nil { // randomly jitter the inserted at time by +/- 300ms to make collisions virtually impossible dummyInsertedAt := time.Now().Add(time.Duration(rand.Intn(2*300+1)-300) * time.Millisecond) @@ -1590,43 +1673,75 @@ func (r *OLAPRepositoryImpl) writeTaskEventBatch(ctx context.Context, tenantId u } } - if len(eventsToWrite) == 0 { - return nil + if len(eventsForStatusUpdate) == 0 { + return nil, nil } tx, commit, rollback, err := sqlchelpers.PrepareTx(ctx, r.pool, r.l) if err != nil { - return err + return nil, err } defer rollback() - _, err = r.queries.CreateTaskEventsOLAP(ctx, tx, eventsToWrite) + if len(eventsToWrite) > 0 { + _, err = r.queries.CreateTaskEventsOLAP(ctx, tx, eventsToWrite) - if err != nil { - return err + if err != nil { + return nil, err + } } - _, err = r.queries.CreateTaskEventsOLAPTmp(ctx, tx, tmpEventsToWrite) + statusUpdates := r.prepareStatusUpdateBatch(ctx, tenantId, eventsForStatusUpdate) + + taskRows, err := r.queries.UpdateTaskStatusesFromMQ(ctx, tx, statusUpdates) if err != nil { - return err + return nil, err } - _, err = r.PutPayloads(ctx, tx, tenantId, payloadsToWrite...) + foundTaskIDs := make(map[int64]struct{}, len(taskRows)) + updatedTaskCount := 0 + for _, row := range taskRows { + foundTaskIDs[row.TaskID] = struct{}{} - if err != nil { - return err + if row.WasUpdated { + updatedTaskCount++ + } + } + var notFoundEvents []sqlcv1.CreateTaskEventsOLAPParams + for _, event := range eventsForStatusUpdate { + if _, ok := foundTaskIDs[event.TaskID]; !ok { + notFoundEvents = append(notFoundEvents, event) + } + } + + dagStatusUpdates := r.prepareDAGStatusUpdateBatch(taskRows) + + if len(dagStatusUpdates.Dagids) > 0 { + _, err = r.queries.UpdateDAGStatusesFromMQ(ctx, tx, dagStatusUpdates) + + if err != nil { + return nil, err + } + } + + if len(payloadsToWrite) > 0 { + _, err = r.PutPayloads(ctx, tx, tenantId, payloadsToWrite...) + + if err != nil { + return nil, err + } } if err := commit(ctx); err != nil { - return err + return nil, err } r.saveEventsToCache(eventsToWrite) - return nil + return notFoundEvents, nil } func (r *OLAPRepositoryImpl) UpdateTaskStatuses(ctx context.Context, tenantIds []uuid.UUID) (bool, []UpdateTaskStatusRow, error) { @@ -1642,7 +1757,7 @@ func (r *OLAPRepositoryImpl) UpdateTaskStatuses(ctx context.Context, tenantIds [ // if any of the partitions are saturated, we return true isSaturated := false - for i := 0; i < NUM_PARTITIONS; i++ { + for i := range NUM_PARTITIONS { partitionNumber := i innerCtx, innerSpan := telemetry.NewSpan(ctx, "olap_repository.update_task_statuses.partition") @@ -1738,6 +1853,26 @@ func (r *OLAPRepositoryImpl) UpdateTaskStatuses(ctx context.Context, tenantIds [ return isSaturated, rows, nil } +func compareStatuses(status1, status2 sqlcv1.V1ReadableStatusOlap) bool { + ordering := map[sqlcv1.V1ReadableStatusOlap]int{ + sqlcv1.V1ReadableStatusOlapQUEUED: 1, + sqlcv1.V1ReadableStatusOlapRUNNING: 2, + sqlcv1.V1ReadableStatusOlapEVICTED: 3, + sqlcv1.V1ReadableStatusOlapCANCELLED: 4, + sqlcv1.V1ReadableStatusOlapFAILED: 5, + sqlcv1.V1ReadableStatusOlapCOMPLETED: 6, + } + + left, leftOk := ordering[status1] + right, rightOk := ordering[status2] + + if !leftOk || !rightOk { + return false + } + + return left > right +} + func (r *OLAPRepositoryImpl) UpdateDAGStatuses(ctx context.Context, tenantIds []uuid.UUID) (bool, []UpdateDAGStatusRow, error) { ctx, span := telemetry.NewSpan(ctx, "olap_repository.update_dag_statuses") defer span.End() @@ -1972,7 +2107,7 @@ func (r *OLAPRepositoryImpl) writeDAGBatch(ctx context.Context, tenantId uuid.UU return nil } -func (r *OLAPRepositoryImpl) CreateTaskEvents(ctx context.Context, tenantId uuid.UUID, events []sqlcv1.CreateTaskEventsOLAPParams) error { +func (r *OLAPRepositoryImpl) CreateTaskEvents(ctx context.Context, tenantId uuid.UUID, events []sqlcv1.CreateTaskEventsOLAPParams) ([]sqlcv1.CreateTaskEventsOLAPParams, error) { return r.writeTaskEventBatch(ctx, tenantId, events) } diff --git a/pkg/repository/sqlcv1/olap.sql b/pkg/repository/sqlcv1/olap.sql index 2d7a74e3b2..4b42243557 100644 --- a/pkg/repository/sqlcv1/olap.sql +++ b/pkg/repository/sqlcv1/olap.sql @@ -2,9 +2,9 @@ SELECT create_v1_hash_partitions('v1_task_events_olap_tmp'::text, @partitions::int), create_v1_hash_partitions('v1_task_status_updates_tmp'::text, @partitions::int), - create_v1_olap_partition_with_date_and_status('v1_tasks_olap'::text, @date::date), - create_v1_olap_partition_with_date_and_status('v1_runs_olap'::text, @date::date), - create_v1_olap_partition_with_date_and_status('v1_dags_olap'::text, @date::date), + create_v1_range_partition('v1_tasks_olap'::text, @date::date), + create_v1_range_partition('v1_runs_olap'::text, @date::date), + create_v1_range_partition('v1_dags_olap'::text, @date::date), create_v1_range_partition('v1_payloads_olap'::text, @date::date) ; @@ -1004,6 +1004,7 @@ WITH tenants AS ( FROM locked_events ) + SELECT -- Little wonky, but we return the count of events that were processed in each row. Potential edge case -- where there are no tasks updated with a non-zero count, but this should be very rare and we'll get @@ -1013,6 +1014,164 @@ SELECT FROM all_result_tasks t; +-- name: UpdateTaskStatusesFromMQ :many +WITH inputs AS ( + SELECT + UNNEST(@tenantIds::UUID[]) AS tenant_id, + UNNEST(@taskIds::bigint[]) AS task_id, + UNNEST(@taskInsertedAts::timestamptz[]) AS task_inserted_at, + UNNEST(@statuses::v1_readable_status_olap[]) AS readable_status, + UNNEST(@workerIds::UUID[]) AS worker_id, + UNNEST(@retryCounts::int[]) AS retry_count +), locked_tasks AS ( + SELECT * + FROM v1_tasks_olap + WHERE (inserted_at, id, tenant_id) IN ( + SELECT task_inserted_at, task_id, tenant_id + FROM inputs + ) + FOR UPDATE +), updated_tasks AS ( + UPDATE v1_tasks_olap t + SET + readable_status = i.readable_status, + latest_retry_count = i.retry_count, + latest_worker_id = CASE + WHEN i.worker_id != '00000000-0000-0000-0000-000000000000'::uuid THEN i.worker_id + ELSE t.latest_worker_id + END + FROM locked_tasks lt + JOIN inputs i ON (i.tenant_id, i.task_id, i.task_inserted_at) = (lt.tenant_id, lt.id, lt.inserted_at) + WHERE + (t.inserted_at, t.id, t.tenant_id) = (lt.inserted_at, lt.id, lt.tenant_id) + AND ( + -- If the retry count is greater than the latest retry count, update the status + ( + i.retry_count > lt.latest_retry_count + AND i.readable_status != lt.readable_status + ) OR + -- If the retry count is equal, only update if the new status has higher priority + ( + i.retry_count = lt.latest_retry_count + AND v1_status_to_priority(i.readable_status) > v1_status_to_priority(lt.readable_status) + ) OR + -- EVICTED is non-terminal and reversible (durable restore moves it back to RUNNING) + ( + i.retry_count = lt.latest_retry_count + AND lt.readable_status = 'EVICTED' + AND i.readable_status != 'EVICTED' + ) + ) + RETURNING + t.tenant_id, t.id, t.inserted_at, t.readable_status, t.external_id, t.latest_worker_id, t.workflow_id, t.dag_id, t.dag_inserted_at, (t.dag_id IS NOT NULL)::boolean AS is_dag_task +) + +SELECT + t.tenant_id, + t.id AS task_id, + t.inserted_at AS task_inserted_at, + t.readable_status, + t.external_id, + t.latest_worker_id, + t.workflow_id, + t.dag_id, + t.dag_inserted_at, + (t.dag_id IS NOT NULL)::BOOLEAN AS is_dag_task, + (SELECT EXISTS ( + SELECT 1 + FROM updated_tasks ut + WHERE (ut.tenant_id, ut.id, ut.inserted_at) = (t.tenant_id, t.id, t.inserted_at) + )) AS was_updated +FROM v1_tasks_olap t +WHERE (t.inserted_at, t.id, t.tenant_id) IN ( + SELECT task_inserted_at, task_id, tenant_id + FROM inputs +) +; + +-- name: UpdateDAGStatusesFromMQ :many +WITH inputs AS ( + SELECT + UNNEST(@tenantIds::UUID[]) AS tenant_id, + UNNEST(@dagIds::bigint[]) AS dag_id, + UNNEST(@dagInsertedAts::timestamptz[]) AS dag_inserted_at +), locked_dags AS ( + SELECT * + FROM v1_dags_olap d + WHERE (d.inserted_at, d.id, d.tenant_id) IN ( + SELECT dag_inserted_at, dag_id, tenant_id + FROM inputs + ) + FOR UPDATE +), relevant_tasks AS ( + SELECT + d.tenant_id, + d.id AS dag_id, + d.inserted_at AS dag_inserted_at, + t.readable_status + FROM + locked_dags d + JOIN + v1_dag_to_task_olap dt ON + (d.id, d.inserted_at) = (dt.dag_id, dt.dag_inserted_at) + JOIN + v1_tasks_olap t ON + (dt.task_id, dt.task_inserted_at) = (t.id, t.inserted_at) +), dag_task_counts AS ( + SELECT + d.id, + d.inserted_at, + d.readable_status, + d.tenant_id, + d.total_tasks, + COUNT(t.dag_id) AS task_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'COMPLETED') AS completed_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'FAILED') AS failed_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'CANCELLED') AS cancelled_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'QUEUED') AS queued_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'RUNNING') AS running_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'EVICTED') AS evicted_count + FROM + locked_dags d + LEFT JOIN + relevant_tasks t ON (d.tenant_id, d.id, d.inserted_at) = (t.tenant_id, t.dag_id, t.dag_inserted_at) + GROUP BY + d.id, d.inserted_at, d.readable_status, d.tenant_id, d.total_tasks +), dag_new_statuses AS ( + SELECT + dtc.id, + dtc.inserted_at, + dtc.tenant_id, + CASE + -- If we only have queued events, we should keep the status as is + WHEN dtc.queued_count = dtc.task_count THEN dtc.readable_status + -- If the task count is not equal to the total tasks, we should set the status to running + WHEN dtc.task_count != dtc.total_tasks THEN 'RUNNING' + -- If we have any running or queued tasks, we should set the status to running + WHEN dtc.running_count > 0 OR dtc.queued_count > 0 THEN 'RUNNING' + -- If all tasks are evicted, mark DAG as evicted + WHEN dtc.evicted_count = dtc.task_count AND dtc.task_count = dtc.total_tasks THEN 'EVICTED' + WHEN dtc.failed_count > 0 THEN 'FAILED' + WHEN dtc.cancelled_count > 0 THEN 'CANCELLED' + WHEN dtc.completed_count = dtc.task_count THEN 'COMPLETED' + ELSE 'RUNNING' + END::v1_readable_status_olap AS new_readable_status + FROM + dag_task_counts dtc +) + +UPDATE v1_dags_olap d +SET + readable_status = dns.new_readable_status +FROM + dag_new_statuses dns +WHERE + (d.inserted_at, d.id, d.tenant_id) = (dns.inserted_at, dns.id, dns.tenant_id) + AND dns.new_readable_status != d.readable_status +RETURNING + d.tenant_id, d.id, d.inserted_at, d.readable_status, d.external_id, d.workflow_id +; + -- name: FindMinInsertedAtForDAGStatusUpdates :one WITH tenants AS ( SELECT UNNEST( diff --git a/pkg/repository/sqlcv1/olap.sql.go b/pkg/repository/sqlcv1/olap.sql.go index a0feecd147..1ebb75982c 100644 --- a/pkg/repository/sqlcv1/olap.sql.go +++ b/pkg/repository/sqlcv1/olap.sql.go @@ -429,9 +429,9 @@ const createOLAPPartitions = `-- name: CreateOLAPPartitions :exec SELECT create_v1_hash_partitions('v1_task_events_olap_tmp'::text, $1::int), create_v1_hash_partitions('v1_task_status_updates_tmp'::text, $1::int), - create_v1_olap_partition_with_date_and_status('v1_tasks_olap'::text, $2::date), - create_v1_olap_partition_with_date_and_status('v1_runs_olap'::text, $2::date), - create_v1_olap_partition_with_date_and_status('v1_dags_olap'::text, $2::date), + create_v1_range_partition('v1_tasks_olap'::text, $2::date), + create_v1_range_partition('v1_runs_olap'::text, $2::date), + create_v1_range_partition('v1_dags_olap'::text, $2::date), create_v1_range_partition('v1_payloads_olap'::text, $2::date) ` @@ -3650,6 +3650,131 @@ func (q *Queries) UpdateDAGStatuses(ctx context.Context, db DBTX, arg UpdateDAGS return items, nil } +const updateDAGStatusesFromMQ = `-- name: UpdateDAGStatusesFromMQ :many +WITH inputs AS ( + SELECT + UNNEST($1::UUID[]) AS tenant_id, + UNNEST($2::bigint[]) AS dag_id, + UNNEST($3::timestamptz[]) AS dag_inserted_at +), locked_dags AS ( + SELECT id, inserted_at, tenant_id, external_id, display_name, workflow_id, workflow_version_id, readable_status, input, additional_metadata, parent_task_external_id, total_tasks + FROM v1_dags_olap d + WHERE (d.inserted_at, d.id, d.tenant_id) IN ( + SELECT dag_inserted_at, dag_id, tenant_id + FROM inputs + ) + FOR UPDATE +), relevant_tasks AS ( + SELECT + d.tenant_id, + d.id AS dag_id, + d.inserted_at AS dag_inserted_at, + t.readable_status + FROM + locked_dags d + JOIN + v1_dag_to_task_olap dt ON + (d.id, d.inserted_at) = (dt.dag_id, dt.dag_inserted_at) + JOIN + v1_tasks_olap t ON + (dt.task_id, dt.task_inserted_at) = (t.id, t.inserted_at) +), dag_task_counts AS ( + SELECT + d.id, + d.inserted_at, + d.readable_status, + d.tenant_id, + d.total_tasks, + COUNT(t.dag_id) AS task_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'COMPLETED') AS completed_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'FAILED') AS failed_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'CANCELLED') AS cancelled_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'QUEUED') AS queued_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'RUNNING') AS running_count, + COUNT(t.dag_id) FILTER (WHERE t.readable_status = 'EVICTED') AS evicted_count + FROM + locked_dags d + LEFT JOIN + relevant_tasks t ON (d.tenant_id, d.id, d.inserted_at) = (t.tenant_id, t.dag_id, t.dag_inserted_at) + GROUP BY + d.id, d.inserted_at, d.readable_status, d.tenant_id, d.total_tasks +), dag_new_statuses AS ( + SELECT + dtc.id, + dtc.inserted_at, + dtc.tenant_id, + CASE + -- If we only have queued events, we should keep the status as is + WHEN dtc.queued_count = dtc.task_count THEN dtc.readable_status + -- If the task count is not equal to the total tasks, we should set the status to running + WHEN dtc.task_count != dtc.total_tasks THEN 'RUNNING' + -- If we have any running or queued tasks, we should set the status to running + WHEN dtc.running_count > 0 OR dtc.queued_count > 0 THEN 'RUNNING' + -- If all tasks are evicted, mark DAG as evicted + WHEN dtc.evicted_count = dtc.task_count AND dtc.task_count = dtc.total_tasks THEN 'EVICTED' + WHEN dtc.failed_count > 0 THEN 'FAILED' + WHEN dtc.cancelled_count > 0 THEN 'CANCELLED' + WHEN dtc.completed_count = dtc.task_count THEN 'COMPLETED' + ELSE 'RUNNING' + END::v1_readable_status_olap AS new_readable_status + FROM + dag_task_counts dtc +) + +UPDATE v1_dags_olap d +SET + readable_status = dns.new_readable_status +FROM + dag_new_statuses dns +WHERE + (d.inserted_at, d.id, d.tenant_id) = (dns.inserted_at, dns.id, dns.tenant_id) + AND dns.new_readable_status != d.readable_status +RETURNING + d.tenant_id, d.id, d.inserted_at, d.readable_status, d.external_id, d.workflow_id +` + +type UpdateDAGStatusesFromMQParams struct { + Tenantids []uuid.UUID `json:"tenantids"` + Dagids []int64 `json:"dagids"` + Daginsertedats []pgtype.Timestamptz `json:"daginsertedats"` +} + +type UpdateDAGStatusesFromMQRow struct { + TenantID uuid.UUID `json:"tenant_id"` + ID int64 `json:"id"` + InsertedAt pgtype.Timestamptz `json:"inserted_at"` + ReadableStatus V1ReadableStatusOlap `json:"readable_status"` + ExternalID uuid.UUID `json:"external_id"` + WorkflowID uuid.UUID `json:"workflow_id"` +} + +func (q *Queries) UpdateDAGStatusesFromMQ(ctx context.Context, db DBTX, arg UpdateDAGStatusesFromMQParams) ([]*UpdateDAGStatusesFromMQRow, error) { + rows, err := db.Query(ctx, updateDAGStatusesFromMQ, arg.Tenantids, arg.Dagids, arg.Daginsertedats) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*UpdateDAGStatusesFromMQRow + for rows.Next() { + var i UpdateDAGStatusesFromMQRow + if err := rows.Scan( + &i.TenantID, + &i.ID, + &i.InsertedAt, + &i.ReadableStatus, + &i.ExternalID, + &i.WorkflowID, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateTaskStatuses = `-- name: UpdateTaskStatuses :many WITH tenants AS ( SELECT UNNEST( @@ -3882,6 +4007,7 @@ WITH tenants AS ( FROM locked_events ) + SELECT -- Little wonky, but we return the count of events that were processed in each row. Potential edge case -- where there are no tasks updated with a non-zero count, but this should be very rare and we'll get @@ -3945,3 +4071,140 @@ func (q *Queries) UpdateTaskStatuses(ctx context.Context, db DBTX, arg UpdateTas } return items, nil } + +const updateTaskStatusesFromMQ = `-- name: UpdateTaskStatusesFromMQ :many +WITH inputs AS ( + SELECT + UNNEST($1::UUID[]) AS tenant_id, + UNNEST($2::bigint[]) AS task_id, + UNNEST($3::timestamptz[]) AS task_inserted_at, + UNNEST($4::v1_readable_status_olap[]) AS readable_status, + UNNEST($5::UUID[]) AS worker_id, + UNNEST($6::int[]) AS retry_count +), locked_tasks AS ( + SELECT tenant_id, id, inserted_at, external_id, queue, action_id, step_id, workflow_id, workflow_version_id, workflow_run_id, schedule_timeout, step_timeout, priority, sticky, desired_worker_id, display_name, input, additional_metadata, readable_status, latest_retry_count, latest_worker_id, dag_id, dag_inserted_at, parent_task_external_id + FROM v1_tasks_olap + WHERE (inserted_at, id, tenant_id) IN ( + SELECT task_inserted_at, task_id, tenant_id + FROM inputs + ) + FOR UPDATE +), updated_tasks AS ( + UPDATE v1_tasks_olap t + SET + readable_status = i.readable_status, + latest_retry_count = i.retry_count, + latest_worker_id = CASE + WHEN i.worker_id != '00000000-0000-0000-0000-000000000000'::uuid THEN i.worker_id + ELSE t.latest_worker_id + END + FROM locked_tasks lt + JOIN inputs i ON (i.tenant_id, i.task_id, i.task_inserted_at) = (lt.tenant_id, lt.id, lt.inserted_at) + WHERE + (t.inserted_at, t.id, t.tenant_id) = (lt.inserted_at, lt.id, lt.tenant_id) + AND ( + -- If the retry count is greater than the latest retry count, update the status + ( + i.retry_count > lt.latest_retry_count + AND i.readable_status != lt.readable_status + ) OR + -- If the retry count is equal, only update if the new status has higher priority + ( + i.retry_count = lt.latest_retry_count + AND v1_status_to_priority(i.readable_status) > v1_status_to_priority(lt.readable_status) + ) OR + -- EVICTED is non-terminal and reversible (durable restore moves it back to RUNNING) + ( + i.retry_count = lt.latest_retry_count + AND lt.readable_status = 'EVICTED' + AND i.readable_status != 'EVICTED' + ) + ) + RETURNING + t.tenant_id, t.id, t.inserted_at, t.readable_status, t.external_id, t.latest_worker_id, t.workflow_id, t.dag_id, t.dag_inserted_at, (t.dag_id IS NOT NULL)::boolean AS is_dag_task +) + +SELECT + t.tenant_id, + t.id AS task_id, + t.inserted_at AS task_inserted_at, + t.readable_status, + t.external_id, + t.latest_worker_id, + t.workflow_id, + t.dag_id, + t.dag_inserted_at, + (t.dag_id IS NOT NULL)::BOOLEAN AS is_dag_task, + (SELECT EXISTS ( + SELECT 1 + FROM updated_tasks ut + WHERE (ut.tenant_id, ut.id, ut.inserted_at) = (t.tenant_id, t.id, t.inserted_at) + )) AS was_updated +FROM v1_tasks_olap t +WHERE (t.inserted_at, t.id, t.tenant_id) IN ( + SELECT task_inserted_at, task_id, tenant_id + FROM inputs +) +` + +type UpdateTaskStatusesFromMQParams struct { + Tenantids []uuid.UUID `json:"tenantids"` + Taskids []int64 `json:"taskids"` + Taskinsertedats []pgtype.Timestamptz `json:"taskinsertedats"` + Statuses []V1ReadableStatusOlap `json:"statuses"` + Workerids []uuid.UUID `json:"workerids"` + Retrycounts []int32 `json:"retrycounts"` +} + +type UpdateTaskStatusesFromMQRow struct { + TenantID uuid.UUID `json:"tenant_id"` + TaskID int64 `json:"task_id"` + TaskInsertedAt pgtype.Timestamptz `json:"task_inserted_at"` + ReadableStatus V1ReadableStatusOlap `json:"readable_status"` + ExternalID uuid.UUID `json:"external_id"` + LatestWorkerID *uuid.UUID `json:"latest_worker_id"` + WorkflowID uuid.UUID `json:"workflow_id"` + DagID pgtype.Int8 `json:"dag_id"` + DagInsertedAt pgtype.Timestamptz `json:"dag_inserted_at"` + IsDagTask bool `json:"is_dag_task"` + WasUpdated bool `json:"was_updated"` +} + +func (q *Queries) UpdateTaskStatusesFromMQ(ctx context.Context, db DBTX, arg UpdateTaskStatusesFromMQParams) ([]*UpdateTaskStatusesFromMQRow, error) { + rows, err := db.Query(ctx, updateTaskStatusesFromMQ, + arg.Tenantids, + arg.Taskids, + arg.Taskinsertedats, + arg.Statuses, + arg.Workerids, + arg.Retrycounts, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []*UpdateTaskStatusesFromMQRow + for rows.Next() { + var i UpdateTaskStatusesFromMQRow + if err := rows.Scan( + &i.TenantID, + &i.TaskID, + &i.TaskInsertedAt, + &i.ReadableStatus, + &i.ExternalID, + &i.LatestWorkerID, + &i.WorkflowID, + &i.DagID, + &i.DagInsertedAt, + &i.IsDagTask, + &i.WasUpdated, + ); err != nil { + return nil, err + } + items = append(items, &i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/pkg/repository/sqlcv1/tenants.sql b/pkg/repository/sqlcv1/tenants.sql index 685e387456..261f24da8a 100644 --- a/pkg/repository/sqlcv1/tenants.sql +++ b/pkg/repository/sqlcv1/tenants.sql @@ -81,10 +81,8 @@ WHERE AND "slug" = 'internal'; -- name: ListTenantsByControllerPartitionId :many -SELECT - * -FROM - "Tenant" as tenants +SELECT "id" +FROM "Tenant" WHERE "controllerPartitionId" = sqlc.arg('controllerPartitionId')::text AND "version" = @majorVersion::"TenantMajorEngineVersion" diff --git a/pkg/repository/sqlcv1/tenants.sql.go b/pkg/repository/sqlcv1/tenants.sql.go index a44235f00d..e767f9285a 100644 --- a/pkg/repository/sqlcv1/tenants.sql.go +++ b/pkg/repository/sqlcv1/tenants.sql.go @@ -994,10 +994,8 @@ func (q *Queries) ListTenants(ctx context.Context, db DBTX) ([]*Tenant, error) { } const listTenantsByControllerPartitionId = `-- name: ListTenantsByControllerPartitionId :many -SELECT - id, "createdAt", "updatedAt", "deletedAt", version, "uiVersion", name, slug, "analyticsOptOut", "alertMemberEmails", "controllerPartitionId", "workerPartitionId", "dataRetentionPeriod", "schedulerPartitionId", "canUpgradeV1", "onboardingData", environment -FROM - "Tenant" as tenants +SELECT "id" +FROM "Tenant" WHERE "controllerPartitionId" = $1::text AND "version" = $2::"TenantMajorEngineVersion" @@ -1009,37 +1007,19 @@ type ListTenantsByControllerPartitionIdParams struct { Majorversion TenantMajorEngineVersion `json:"majorversion"` } -func (q *Queries) ListTenantsByControllerPartitionId(ctx context.Context, db DBTX, arg ListTenantsByControllerPartitionIdParams) ([]*Tenant, error) { +func (q *Queries) ListTenantsByControllerPartitionId(ctx context.Context, db DBTX, arg ListTenantsByControllerPartitionIdParams) ([]uuid.UUID, error) { rows, err := db.Query(ctx, listTenantsByControllerPartitionId, arg.ControllerPartitionId, arg.Majorversion) if err != nil { return nil, err } defer rows.Close() - var items []*Tenant + var items []uuid.UUID for rows.Next() { - var i Tenant - if err := rows.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.DeletedAt, - &i.Version, - &i.UiVersion, - &i.Name, - &i.Slug, - &i.AnalyticsOptOut, - &i.AlertMemberEmails, - &i.ControllerPartitionId, - &i.WorkerPartitionId, - &i.DataRetentionPeriod, - &i.SchedulerPartitionId, - &i.CanUpgradeV1, - &i.OnboardingData, - &i.Environment, - ); err != nil { + var id uuid.UUID + if err := rows.Scan(&id); err != nil { return nil, err } - items = append(items, &i) + items = append(items, id) } if err := rows.Err(); err != nil { return nil, err diff --git a/pkg/repository/tenant.go b/pkg/repository/tenant.go index 801c6395b1..a57495d695 100644 --- a/pkg/repository/tenant.go +++ b/pkg/repository/tenant.go @@ -137,7 +137,7 @@ type TenantRepository interface { GetInternalTenantForController(ctx context.Context, controllerPartitionId string) (*sqlcv1.Tenant, error) // ListTenantsByPartition lists all tenants in the given partition - ListTenantsByControllerPartition(ctx context.Context, controllerPartitionId string, majorVersion sqlcv1.TenantMajorEngineVersion) ([]*sqlcv1.Tenant, error) + ListTenantsByControllerPartition(ctx context.Context, controllerPartitionId string, majorVersion sqlcv1.TenantMajorEngineVersion) ([]uuid.UUID, error) ListTenantsByWorkerPartition(ctx context.Context, workerPartitionId string, majorVersion sqlcv1.TenantMajorEngineVersion) ([]*sqlcv1.Tenant, error) @@ -743,7 +743,7 @@ func (r *tenantRepository) GetInternalTenantForController(ctx context.Context, c return tenant, nil } -func (r *tenantRepository) ListTenantsByControllerPartition(ctx context.Context, controllerPartitionId string, majorVersion sqlcv1.TenantMajorEngineVersion) ([]*sqlcv1.Tenant, error) { +func (r *tenantRepository) ListTenantsByControllerPartition(ctx context.Context, controllerPartitionId string, majorVersion sqlcv1.TenantMajorEngineVersion) ([]uuid.UUID, error) { if controllerPartitionId == "" { return nil, fmt.Errorf("partitionId is required") } diff --git a/sdks/python/hatchet_sdk/context/context.py b/sdks/python/hatchet_sdk/context/context.py index a01c89ba17..35439ff288 100644 --- a/sdks/python/hatchet_sdk/context/context.py +++ b/sdks/python/hatchet_sdk/context/context.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, cast, overload from warnings import warn from pydantic import BaseModel, TypeAdapter @@ -98,6 +98,52 @@ def _compute_memo_key(task_run_external_id: str, *args: Any, **kwargs: Any) -> b return h.digest() +TSagaOperationResult = TypeVar("TSagaOperationResult") + + +class SagaOperation(Generic[TSagaOperationResult]): + def __init__( + self, + operation_fn: Callable[[], TSagaOperationResult], + compensation_fn: Callable[[], None], + ): + self._compensation_fn = compensation_fn + self._operation_fn = operation_fn + + def apply(self) -> TSagaOperationResult: + return self._operation_fn() + + def rollback(self) -> None: + return self._compensation_fn() + + +class Saga: + def __init__(self) -> None: + self._stack: list[SagaOperation[Any]] = [] + + def add( + self, + operation_fn: Callable[[], TSagaOperationResult], + compensation_fn: Callable[[], None], + ) -> TSagaOperationResult: + operation = SagaOperation(operation_fn, compensation_fn) + self._stack.append(operation) + + try: + return operation.apply() + except Exception as e: + self._rollback() + raise e + + def _rollback(self) -> None: + while self._stack: + operation = self._stack.pop() + try: + operation.rollback() + except Exception: + logger.exception("Error during compensation") + + class Context: def __init__( self, @@ -721,6 +767,9 @@ def get_task_run_error( return TaskRunError.deserialize(error) + def begin_compensation_chain(self) -> Saga: + return Saga() + @dataclass class DurableSpawnResult: diff --git a/sql/schema/v1-olap.sql b/sql/schema/v1-olap.sql index 7968c83db2..c2b2f5034d 100644 --- a/sql/schema/v1-olap.sql +++ b/sql/schema/v1-olap.sql @@ -89,40 +89,6 @@ BEGIN END; $$; -CREATE OR REPLACE FUNCTION create_v1_olap_partition_with_date_and_status( - targetTableName text, - targetDate date -) RETURNS integer - LANGUAGE plpgsql AS -$$ -DECLARE - targetDateStr varchar; - targetDatePlusOneDayStr varchar; - newTableName varchar; -BEGIN - SELECT to_char(targetDate, 'YYYYMMDD') INTO targetDateStr; - SELECT to_char(targetDate + INTERVAL '1 day', 'YYYYMMDD') INTO targetDatePlusOneDayStr; - SELECT format('%s_%s', targetTableName, targetDateStr) INTO newTableName; - IF NOT EXISTS (SELECT 1 FROM pg_tables WHERE tablename = newTableName) THEN - EXECUTE format('CREATE TABLE %s (LIKE %s INCLUDING INDEXES) PARTITION BY LIST (readable_status)', newTableName, targetTableName); - END IF; - - PERFORM create_v1_partition_with_status(newTableName, 'QUEUED'); - PERFORM create_v1_partition_with_status(newTableName, 'RUNNING'); - PERFORM create_v1_partition_with_status(newTableName, 'COMPLETED'); - PERFORM create_v1_partition_with_status(newTableName, 'CANCELLED'); - PERFORM create_v1_partition_with_status(newTableName, 'FAILED'); - PERFORM create_v1_partition_with_status(newTableName, 'EVICTED'); - - -- If it's not already attached, attach the partition - IF NOT EXISTS (SELECT 1 FROM pg_inherits WHERE inhrelid = newTableName::regclass) THEN - EXECUTE format('ALTER TABLE %s ATTACH PARTITION %s FOR VALUES FROM (''%s'') TO (''%s'')', targetTableName, newTableName, targetDateStr, targetDatePlusOneDayStr); - END IF; - - RETURN 1; -END; -$$; - CREATE OR REPLACE FUNCTION create_v1_hash_partitions( targetTableName text, num_partitions INT @@ -197,8 +163,6 @@ CREATE INDEX v1_tasks_olap_workflow_id_idx ON v1_tasks_olap (tenant_id, workflow CREATE INDEX v1_tasks_olap_worker_id_idx ON v1_tasks_olap (tenant_id, latest_worker_id) WHERE latest_worker_id IS NOT NULL; -SELECT create_v1_olap_partition_with_date_and_status('v1_tasks_olap', CURRENT_DATE); - -- DAG DEFINITIONS -- CREATE TABLE v1_dags_olap ( id BIGINT NOT NULL, @@ -218,8 +182,6 @@ CREATE TABLE v1_dags_olap ( CREATE INDEX v1_dags_olap_workflow_id_idx ON v1_dags_olap (tenant_id, workflow_id); -SELECT create_v1_olap_partition_with_date_and_status('v1_dags_olap', CURRENT_DATE); - -- RUN DEFINITIONS -- CREATE TYPE v1_run_kind AS ENUM ('TASK', 'DAG'); @@ -240,10 +202,8 @@ CREATE TABLE v1_runs_olap ( PRIMARY KEY (inserted_at, id, readable_status, kind) ) PARTITION BY RANGE(inserted_at); -SELECT create_v1_olap_partition_with_date_and_status('v1_runs_olap', CURRENT_DATE); - CREATE INDEX ix_v1_runs_olap_parent_task_external_id ON v1_runs_olap (parent_task_external_id) WHERE parent_task_external_id IS NOT NULL; -CREATE INDEX ix_v1_runs_olap_tenant_id ON v1_runs_olap (tenant_id, inserted_at, id, readable_status, kind); +CREATE INDEX ix_v1_runs_olap_tenant_status_ins_at ON v1_runs_olap (tenant_id, readable_status, inserted_at DESC); -- LOOKUP TABLES -- CREATE TABLE v1_lookup_table_olap (