diff --git a/flytepropeller/events/admin_eventsink.go b/flytepropeller/events/admin_eventsink.go index cb4b88a69a..3da6cca421 100644 --- a/flytepropeller/events/admin_eventsink.go +++ b/flytepropeller/events/admin_eventsink.go @@ -57,7 +57,11 @@ func (s *adminEventSink) Sink(ctx context.Context, message proto.Message) error if s.filter.Contains(ctx, id) { logger.Debugf(ctx, "event '%s' has already been sent", string(id)) - return nil + return &errors.EventError{ + Code: errors.AlreadyExists, + Cause: fmt.Errorf("event has already been sent"), + Message: "Event Already Exists", + } } // Validate submission with rate limiter and send admin event diff --git a/flytepropeller/events/admin_eventsink_test.go b/flytepropeller/events/admin_eventsink_test.go index c13b7ad47f..510371d056 100644 --- a/flytepropeller/events/admin_eventsink_test.go +++ b/flytepropeller/events/admin_eventsink_test.go @@ -184,13 +184,16 @@ func TestAdminFilterContains(t *testing.T) { filter.OnContainsMatch(mock.Anything, mock.Anything).Return(true) wfErr := adminEventSink.Sink(ctx, wfEvent) - assert.NoError(t, wfErr) + assert.Error(t, wfErr) + assert.True(t, errors.IsAlreadyExists(wfErr)) nodeErr := adminEventSink.Sink(ctx, nodeEvent) - assert.NoError(t, nodeErr) + assert.Error(t, nodeErr) + assert.True(t, errors.IsAlreadyExists(nodeErr)) taskErr := adminEventSink.Sink(ctx, taskEvent) - assert.NoError(t, taskErr) + assert.Error(t, taskErr) + assert.True(t, errors.IsAlreadyExists(taskErr)) } func TestIDFromMessage(t *testing.T) { diff --git a/flytepropeller/events/errors/errors.go b/flytepropeller/events/errors/errors.go index 879b8b07d7..2d3e02e0df 100644 --- a/flytepropeller/events/errors/errors.go +++ b/flytepropeller/events/errors/errors.go @@ -33,7 +33,11 @@ type EventError struct { } func (r EventError) Error() string { - return fmt.Sprintf("%s: %s, caused by [%s]", r.Code, r.Message, r.Cause.Error()) + var cause string + if r.Cause != nil { + cause = r.Cause.Error() + } + return fmt.Sprintf("%s: %s, caused by [%s]", r.Code, r.Message, cause) } func (r *EventError) Is(target error) bool { diff --git a/flytepropeller/pkg/controller/config/config.go b/flytepropeller/pkg/controller/config/config.go index 419386eddd..f058212322 100644 --- a/flytepropeller/pkg/controller/config/config.go +++ b/flytepropeller/pkg/controller/config/config.go @@ -259,6 +259,7 @@ const ( type EventConfig struct { RawOutputPolicy RawOutputPolicy `json:"raw-output-policy" pflag:",How output data should be passed along in execution events."` FallbackToOutputReference bool `json:"fallback-to-output-reference" pflag:",Whether output data should be sent by reference when it is too large to be sent inline in execution events."` + ErrorOnAlreadyExists bool `json:"error-on-already-exists" pflag:",Whether to return an error when an event already exists."` } // ParallelismBehavior defines how ArrayNode should handle subNode parallelism by default diff --git a/flytepropeller/pkg/controller/config/config_flags.go b/flytepropeller/pkg/controller/config/config_flags.go index ea0b428c2f..d2dc0971ff 100755 --- a/flytepropeller/pkg/controller/config/config_flags.go +++ b/flytepropeller/pkg/controller/config/config_flags.go @@ -100,6 +100,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "max-streak-length"), defaultConfig.MaxStreakLength, "Maximum number of consecutive rounds that one propeller worker can use for one workflow - >1 => turbo-mode is enabled.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "event-config.raw-output-policy"), defaultConfig.EventConfig.RawOutputPolicy, "How output data should be passed along in execution events.") cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "event-config.fallback-to-output-reference"), defaultConfig.EventConfig.FallbackToOutputReference, "Whether output data should be sent by reference when it is too large to be sent inline in execution events.") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "event-config.error-on-already-exists"), defaultConfig.EventConfig.ErrorOnAlreadyExists, "Whether to return an error when an event already exists.") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "include-shard-key-label"), defaultConfig.IncludeShardKeyLabel, "Include the specified shard key label in the k8s FlyteWorkflow CRD label selector") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "exclude-shard-key-label"), defaultConfig.ExcludeShardKeyLabel, "Exclude the specified shard key label from the k8s FlyteWorkflow CRD label selector") cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "include-project-label"), defaultConfig.IncludeProjectLabel, "Include the specified project label in the k8s FlyteWorkflow CRD label selector") diff --git a/flytepropeller/pkg/controller/config/config_flags_test.go b/flytepropeller/pkg/controller/config/config_flags_test.go index bce7238f60..66a14381af 100755 --- a/flytepropeller/pkg/controller/config/config_flags_test.go +++ b/flytepropeller/pkg/controller/config/config_flags_test.go @@ -799,6 +799,20 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_event-config.error-on-already-exists", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("event-config.error-on-already-exists", testValue) + if vBool, err := cmdFlags.GetBool("event-config.error-on-already-exists"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.EventConfig.ErrorOnAlreadyExists) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_include-shard-key-label", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/flytepropeller/pkg/controller/nodes/array/handler.go b/flytepropeller/pkg/controller/nodes/array/handler.go index 315041cb51..a101ed5a30 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler.go +++ b/flytepropeller/pkg/controller/nodes/array/handler.go @@ -11,6 +11,7 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/array/errorcollector" "github.com/flyteorg/flyte/flytepropeller/events" + eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/validators" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" @@ -21,6 +22,7 @@ import ( "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/interfaces" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/task/k8s" "github.com/flyteorg/flyte/flytestdlib/bitarray" + stdConfig "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" @@ -112,6 +114,10 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut // update state for subNodes if err := eventRecorder.finalize(ctx, nCtx, idlcore.TaskExecution_ABORTED, 0, a.eventConfig); err != nil { + // a task event with abort phase is already emitted when handling ArrayNodePhaseFailing + if eventsErr.IsAlreadyExists(err) { + return nil + } logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) return err } @@ -579,12 +585,35 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu // increment taskPhaseVersion if we detect any changes in subNode state. if incrementTaskPhaseVersion { - arrayNodeState.TaskPhaseVersion = arrayNodeState.TaskPhaseVersion + 1 + arrayNodeState.TaskPhaseVersion++ } - if err := eventRecorder.finalize(ctx, nCtx, taskPhase, arrayNodeState.TaskPhaseVersion, a.eventConfig); err != nil { - logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) - return handler.UnknownTransition, err + const maxRetries = 3 + retries := 0 + for retries <= maxRetries { + err := eventRecorder.finalize(ctx, nCtx, taskPhase, arrayNodeState.TaskPhaseVersion, a.eventConfig) + + if err == nil { + break + } + + // Handle potential race condition if FlyteWorkflow CRD fails to get synced + if eventsErr.IsAlreadyExists(err) { + if !incrementTaskPhaseVersion { + break + } + logger.Warnf(ctx, "Event version already exists, bumping version and retrying (%d/%d): [%s]", retries+1, maxRetries, err.Error()) + arrayNodeState.TaskPhaseVersion++ + } else { + logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error()) + return handler.UnknownTransition, err + } + + retries++ + if retries > maxRetries { + logger.Errorf(ctx, "ArrayNode event recording failed after %d retries: [%s]", maxRetries, err.Error()) + return handler.UnknownTransition, err + } } // if the ArrayNode phase has changed we need to reset the taskPhaseVersion to 0 @@ -632,9 +661,21 @@ func New(nodeExecutor interfaces.Node, eventConfig *config.EventConfig, scope pr return nil, err } + eventConfigCopy, err := stdConfig.DeepCopyConfig(eventConfig) + if err != nil { + return nil, err + } + + deepCopiedEventConfig, ok := eventConfigCopy.(*config.EventConfig) + if !ok { + return nil, fmt.Errorf("deep copy error: expected *config.EventConfig, but got %T", eventConfigCopy) + } + + deepCopiedEventConfig.ErrorOnAlreadyExists = true + arrayScope := scope.NewSubScope("array") return &arrayNodeHandler{ - eventConfig: eventConfig, + eventConfig: deepCopiedEventConfig, gatherOutputsRequestChannel: make(chan *gatherOutputsRequest), metrics: newMetrics(arrayScope), nodeExecutionRequestChannel: make(chan *nodeExecutionRequest), diff --git a/flytepropeller/pkg/controller/nodes/array/handler_test.go b/flytepropeller/pkg/controller/nodes/array/handler_test.go index ee1fc5b80b..d27b412c1f 100644 --- a/flytepropeller/pkg/controller/nodes/array/handler_test.go +++ b/flytepropeller/pkg/controller/nodes/array/handler_test.go @@ -12,7 +12,8 @@ import ( idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/event" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" - pluginmocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + pluginiomocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + eventsErr "github.com/flyteorg/flyte/flytepropeller/events/errors" eventmocks "github.com/flyteorg/flyte/flytepropeller/events/mocks" "github.com/flyteorg/flyte/flytepropeller/pkg/apis/flyteworkflow/v1alpha1" "github.com/flyteorg/flyte/flytepropeller/pkg/controller/config" @@ -50,7 +51,7 @@ func createArrayNodeHandler(ctx context.Context, t *testing.T, nodeHandler inter // mock components adminClient := launchplan.NewFailFastLaunchPlanExecutor() enqueueWorkflowFunc := func(workflowID v1alpha1.WorkflowID) {} - eventConfig := &config.EventConfig{} + eventConfig := &config.EventConfig{ErrorOnAlreadyExists: true} mockEventSink := eventmocks.NewMockEventSink() mockHandlerFactory := &mocks.HandlerFactory{} mockHandlerFactory.OnGetHandlerMatch(mock.Anything).Return(nodeHandler, nil) @@ -135,7 +136,7 @@ func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder inte nCtx.OnEventsRecorder().Return(eventRecorder) // InputReader - inputFilePaths := &pluginmocks.InputFilePaths{} + inputFilePaths := &pluginiomocks.InputFilePaths{} inputFilePaths.OnGetInputPath().Return(storage.DataReference("s3://bucket/input")) nCtx.OnInputReader().Return( newStaticInputReader( @@ -459,6 +460,24 @@ func uint32Ptr(v uint32) *uint32 { return &v } +type fakeEventRecorder struct { + taskErr error + phaseVersionFailures uint32 + recordTaskEventCallCount int +} + +func (f *fakeEventRecorder) RecordNodeEvent(ctx context.Context, event *event.NodeExecutionEvent, eventConfig *config.EventConfig) error { + return nil +} + +func (f *fakeEventRecorder) RecordTaskEvent(ctx context.Context, event *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { + f.recordTaskEventCallCount++ + if f.phaseVersionFailures == 0 || event.PhaseVersion < f.phaseVersionFailures { + return f.taskErr + } + return nil +} + func TestHandleArrayNodePhaseExecuting(t *testing.T) { ctx := context.Background() @@ -492,11 +511,18 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { subNodeTaskPhases []core.Phase subNodeTransitions []handler.Transition expectedArrayNodePhase v1alpha1.ArrayNodePhase + expectedArrayNodeSubPhases []v1alpha1.NodePhase expectedTransitionPhase handler.EPhase expectedExternalResourcePhases []idlcore.TaskExecution_Phase currentWfParallelism uint32 maxWfParallelism uint32 incrementParallelismCount uint32 + useFakeEventRecorder bool + eventRecorderFailures uint32 + eventRecorderError error + expectedTaskPhaseVersion uint32 + expectHandleError bool + expectedEventingCalls int }{ { name: "StartAllSubNodes", @@ -514,6 +540,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTaskPhaseVersion: 1, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, incrementParallelismCount: 1, @@ -533,6 +560,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTaskPhaseVersion: 1, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING}, incrementParallelismCount: 1, @@ -553,6 +581,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTaskPhaseVersion: 1, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, currentWfParallelism: 0, @@ -573,6 +602,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTaskPhaseVersion: 1, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING}, currentWfParallelism: workflowMaxParallelism - 1, @@ -591,6 +621,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { }, subNodeTransitions: []handler.Transition{}, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTaskPhaseVersion: 0, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{}, currentWfParallelism: workflowMaxParallelism, @@ -612,6 +643,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTaskPhaseVersion: 1, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, incrementParallelismCount: 1, @@ -632,6 +664,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTaskPhaseVersion: 0, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_SUCCEEDED}, }, @@ -652,6 +685,7 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoFailure(0, "", "", &handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseSucceeding, + expectedTaskPhaseVersion: 0, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_SUCCEEDED, idlcore.TaskExecution_FAILED}, }, @@ -671,9 +705,78 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(&handler.ExecutionInfo{})), }, expectedArrayNodePhase: v1alpha1.ArrayNodePhaseFailing, + expectedTaskPhaseVersion: 0, expectedTransitionPhase: handler.EPhaseRunning, expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_FAILED, idlcore.TaskExecution_SUCCEEDED}, }, + { + name: "EventingAlreadyExists_EventuallySucceeds", + parallelism: uint32Ptr(0), + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedArrayNodePhase: v1alpha1.ArrayNodePhaseExecuting, + expectedTaskPhaseVersion: 2, + expectedTransitionPhase: handler.EPhaseRunning, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, + useFakeEventRecorder: true, + eventRecorderFailures: 2, + eventRecorderError: &eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}, + incrementParallelismCount: 1, + expectedEventingCalls: 2, + }, + { + name: "EventingAlreadyExists_EventuallyFails", + parallelism: uint32Ptr(0), + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, + useFakeEventRecorder: true, + eventRecorderFailures: 5, + eventRecorderError: &eventsErr.EventError{Code: eventsErr.AlreadyExists, Cause: fmt.Errorf("err")}, + expectHandleError: true, + expectedEventingCalls: 4, + }, + { + name: "EventingFails", + parallelism: uint32Ptr(0), + subNodePhases: []v1alpha1.NodePhase{ + v1alpha1.NodePhaseQueued, + v1alpha1.NodePhaseQueued, + }, + subNodeTaskPhases: []core.Phase{ + core.PhaseRunning, + core.PhaseRunning, + }, + subNodeTransitions: []handler.Transition{ + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoRunning(&handler.ExecutionInfo{})), + }, + expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_RUNNING, idlcore.TaskExecution_RUNNING}, + useFakeEventRecorder: true, + eventRecorderError: fmt.Errorf("err"), + expectHandleError: true, + expectedEventingCalls: 1, + }, } for _, test := range tests { @@ -684,6 +787,15 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { }, scope) assert.NoError(t, err) + var eventRecorder interfaces.EventRecorder + if test.useFakeEventRecorder { + eventRecorder = &fakeEventRecorder{ + phaseVersionFailures: test.eventRecorderFailures, + taskErr: test.eventRecorderError, + } + } else { + eventRecorder = newBufferedEventRecorder() + } // initialize ArrayNodeState arrayNodeState := &handler.ArrayNodeState{ Phase: v1alpha1.ArrayNodePhaseExecuting, @@ -705,18 +817,12 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { for i, nodePhase := range test.subNodePhases { arrayNodeState.SubNodePhases.SetItem(i, bitarray.Item(nodePhase)) } - for i, taskPhase := range test.subNodeTaskPhases { - arrayNodeState.SubNodeTaskPhases.SetItem(i, bitarray.Item(taskPhase)) - } - - // create NodeExecutionContext - eventRecorder := newBufferedEventRecorder() nodeSpec := arrayNodeSpec nodeSpec.ArrayNode.Parallelism = test.parallelism nodeSpec.ArrayNode.MinSuccessRatio = test.minSuccessRatio - nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, test.currentWfParallelism, workflowMaxParallelism) + nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &nodeSpec, arrayNodeState, test.currentWfParallelism, workflowMaxParallelism) // initialize ArrayNodeHandler nodeHandler := &mocks.NodeHandler{} @@ -745,22 +851,41 @@ func TestHandleArrayNodePhaseExecuting(t *testing.T) { // evaluate node transition, err := arrayNodeHandler.Handle(ctx, nCtx) - assert.NoError(t, err) + + fakeEventRecorder, ok := eventRecorder.(*fakeEventRecorder) + if ok { + assert.Equal(t, test.expectedEventingCalls, fakeEventRecorder.recordTaskEventCallCount) + } + + if !test.expectHandleError { + assert.NoError(t, err) + } else { + assert.Error(t, err) + return + } // validate results assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase) assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase()) + assert.Equal(t, test.expectedTaskPhaseVersion, arrayNodeState.TaskPhaseVersion) - if len(test.expectedExternalResourcePhases) > 0 { - assert.Equal(t, 1, len(eventRecorder.taskExecutionEvents)) + for i, expectedPhase := range test.expectedArrayNodeSubPhases { + assert.Equal(t, expectedPhase, v1alpha1.NodePhase(arrayNodeState.SubNodePhases.GetItem(i))) + } - externalResources := eventRecorder.taskExecutionEvents[0].Metadata.GetExternalResources() - assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources)) - for i, expectedPhase := range test.expectedExternalResourcePhases { - assert.Equal(t, expectedPhase, externalResources[i].Phase) + bufferedEventRecorder, ok := eventRecorder.(*bufferedEventRecorder) + if ok { + if len(test.expectedExternalResourcePhases) > 0 { + assert.Equal(t, 1, len(bufferedEventRecorder.taskExecutionEvents)) + + externalResources := bufferedEventRecorder.taskExecutionEvents[0].Metadata.GetExternalResources() + assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources)) + for i, expectedPhase := range test.expectedExternalResourcePhases { + assert.Equal(t, expectedPhase, externalResources[i].Phase) + } + } else { + assert.Equal(t, 0, len(bufferedEventRecorder.taskExecutionEvents)) } - } else { - assert.Equal(t, 0, len(eventRecorder.taskExecutionEvents)) } nCtx.ExecutionContext().(*execmocks.ExecutionContext).AssertNumberOfCalls(t, "IncrementParallelism", int(test.incrementParallelismCount)) diff --git a/flytepropeller/pkg/controller/nodes/node_exec_context.go b/flytepropeller/pkg/controller/nodes/node_exec_context.go index a579b241f3..7de31100c6 100644 --- a/flytepropeller/pkg/controller/nodes/node_exec_context.go +++ b/flytepropeller/pkg/controller/nodes/node_exec_context.go @@ -36,6 +36,9 @@ type eventRecorder struct { func (e eventRecorder) RecordTaskEvent(ctx context.Context, ev *event.TaskExecutionEvent, eventConfig *config.EventConfig) error { if err := e.taskEventRecorder.RecordTaskEvent(ctx, ev, eventConfig); err != nil { if eventsErr.IsAlreadyExists(err) { + if eventConfig.ErrorOnAlreadyExists { + return err + } logger.Warningf(ctx, "Failed to record taskEvent, error [%s]. Trying to record state: %s. Ignoring this error!", err.Error(), ev.Phase) return nil } else if eventsErr.IsEventAlreadyInTerminalStateError(err) {