diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java index 63cbc441f8a82..0156d101ed8db 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java @@ -73,7 +73,6 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.Deque; import java.util.HashMap; @@ -103,7 +102,6 @@ import static org.apache.kafka.test.StreamsTestUtils.TaskBuilder.statefulTask; import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.core.IsEqual.equalTo; @@ -113,7 +111,6 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; @@ -2188,72 +2185,65 @@ public void shouldReAddRevivedTasksToStateUpdater() { @Test public void shouldReviveCorruptTasks() { - final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); + final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); - final AtomicBoolean enforcedCheckpoint = new AtomicBoolean(false); - final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager) { - @Override - public void postCommit(final boolean enforceCheckpoint) { - if (enforceCheckpoint) { - enforcedCheckpoint.set(true); - } - super.postCommit(enforceCheckpoint); - } - }; + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.task(taskId00)).thenReturn(task00); + when(tasks.allTasksPerId()).thenReturn(singletonMap(taskId00, task00)); + when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00)); - // `handleAssignment` - when(consumer.assignment()) - .thenReturn(assignment) - .thenReturn(taskId00Partitions); - when(activeTaskCreator.createTasks(any(), eq(taskId00Assignment))).thenReturn(singletonList(task00)); + when(task00.prepareCommit(false)).thenReturn(emptyMap()); + doNothing().when(task00).postCommit(anyBoolean()); + when(task00.changelogPartitions()).thenReturn(taskId00ChangelogPartitions); - taskManager.handleAssignment(taskId00Assignment, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), tp -> assertThat(tp, is(empty()))), is(true)); - assertThat(task00.state(), is(Task.State.RUNNING)); + when(consumer.assignment()).thenReturn(taskId00Partitions); - task00.setChangelogOffsets(singletonMap(t1p0, 0L)); - taskManager.handleCorruption(singleton(taskId00)); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); - assertThat(task00.commitPrepared, is(true)); - assertThat(task00.state(), is(Task.State.CREATED)); - assertThat(task00.partitionsForOffsetReset, equalTo(taskId00Partitions)); - assertThat(enforcedCheckpoint.get(), is(true)); - assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, task00))); - assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + taskManager.handleCorruption(singleton(taskId00)); - verify(stateManager).markChangelogAsCorrupted(taskId00Partitions); + verify(task00).prepareCommit(false); + verify(task00).postCommit(true); + verify(task00).addPartitionsForOffsetReset(taskId00Partitions); + verify(task00).changelogPartitions(); + verify(task00).closeDirty(); + verify(task00).revive(); + verify(tasks).removeTask(task00); + verify(tasks).addPendingTasksToInit(Set.of(task00)); + verify(consumer, never()).commitSync(emptyMap()); } @Test public void shouldReviveCorruptTasksEvenIfTheyCannotCloseClean() { - final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); + final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); - final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager) { - @Override - public void suspend() { - super.suspend(); - throw new RuntimeException("oops"); - } - }; + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.task(taskId00)).thenReturn(task00); + when(tasks.allTasksPerId()).thenReturn(singletonMap(taskId00, task00)); + when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00)); - when(consumer.assignment()) - .thenReturn(assignment) - .thenReturn(taskId00Partitions); - when(activeTaskCreator.createTasks(any(), eq(taskId00Assignment))).thenReturn(singletonList(task00)); + when(task00.prepareCommit(false)).thenReturn(emptyMap()); + when(task00.changelogPartitions()).thenReturn(taskId00ChangelogPartitions); + doThrow(new RuntimeException("oops")).when(task00).suspend(); - taskManager.handleAssignment(taskId00Assignment, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), tp -> assertThat(tp, is(empty()))), is(true)); - assertThat(task00.state(), is(Task.State.RUNNING)); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); - task00.setChangelogOffsets(singletonMap(t1p0, 0L)); taskManager.handleCorruption(singleton(taskId00)); - assertThat(task00.commitPrepared, is(true)); - assertThat(task00.state(), is(Task.State.CREATED)); - assertThat(task00.partitionsForOffsetReset, equalTo(taskId00Partitions)); - assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, task00))); - assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); - verify(stateManager).markChangelogAsCorrupted(taskId00Partitions); + verify(task00).prepareCommit(false); + verify(task00).suspend(); + verify(task00, never()).postCommit(anyBoolean()); // postCommit is NOT called + verify(task00).closeDirty(); + verify(task00).revive(); + verify(tasks).removeTask(task00); + verify(tasks).addPendingTasksToInit(Set.of(task00)); + verify(task00).addPartitionsForOffsetReset(emptySet()); } @Test @@ -2326,431 +2316,558 @@ public void shouldNotCommitNonCorruptedRestoringActiveTasksAndNotCommitRunningSt @Test public void shouldCleanAndReviveCorruptedStandbyTasksBeforeCommittingNonCorruptedTasks() { - final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); - - final StateMachineTask corruptedStandby = new StateMachineTask(taskId00, taskId00Partitions, false, stateManager); - final StateMachineTask runningNonCorruptedActive = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { - @Override - public Map prepareCommit(final boolean clean) { - throw new TaskMigratedException("You dropped out of the group!", new RuntimeException()); - } - }; - - // handleAssignment - when(activeTaskCreator.createTasks(any(), eq(taskId01Assignment))) - .thenReturn(singleton(runningNonCorruptedActive)); - when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singleton(corruptedStandby)); + final StandbyTask corruptedStandby = standbyTask(taskId00, taskId00ChangelogPartitions) + .inState(State.RUNNING) + .withInputPartitions(taskId00Partitions).build(); + final StreamTask runningNonCorruptedActive = statefulTask(taskId01, taskId01ChangelogPartitions) + .inState(State.RUNNING) + .withInputPartitions(taskId01Partitions).build(); - when(consumer.assignment()).thenReturn(assignment); + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.task(taskId00)).thenReturn(corruptedStandby); + when(tasks.allTasksPerId()).thenReturn(mkMap( + mkEntry(taskId00, corruptedStandby), + mkEntry(taskId01, runningNonCorruptedActive) + )); + when(tasks.activeTaskIds()).thenReturn(Set.of(taskId01)); - taskManager.handleAssignment(taskId01Assignment, taskId00Assignment); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + when(runningNonCorruptedActive.commitNeeded()).thenReturn(true); + when(runningNonCorruptedActive.prepareCommit(true)) + .thenThrow(new TaskMigratedException("You dropped out of the group!", new RuntimeException())); - // make sure this will be committed and throw - assertThat(runningNonCorruptedActive.state(), is(Task.State.RUNNING)); - assertThat(corruptedStandby.state(), is(Task.State.RUNNING)); + when(corruptedStandby.changelogPartitions()).thenReturn(taskId00ChangelogPartitions); + when(corruptedStandby.prepareCommit(false)).thenReturn(emptyMap()); + doNothing().when(corruptedStandby).suspend(); + doNothing().when(corruptedStandby).postCommit(anyBoolean()); - runningNonCorruptedActive.setCommitNeeded(); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); - corruptedStandby.setChangelogOffsets(singletonMap(t1p0, 0L)); assertThrows(TaskMigratedException.class, () -> taskManager.handleCorruption(singleton(taskId00))); + // verifying the entire task lifecycle + final InOrder taskOrder = inOrder(corruptedStandby, runningNonCorruptedActive); + taskOrder.verify(corruptedStandby).prepareCommit(false); + taskOrder.verify(corruptedStandby).suspend(); + taskOrder.verify(corruptedStandby).postCommit(true); + taskOrder.verify(corruptedStandby).closeDirty(); + taskOrder.verify(corruptedStandby).revive(); + taskOrder.verify(runningNonCorruptedActive).prepareCommit(true); - assertThat(corruptedStandby.commitPrepared, is(true)); - assertThat(corruptedStandby.state(), is(Task.State.CREATED)); - verify(stateManager).markChangelogAsCorrupted(taskId00Partitions); + verify(tasks).removeTask(corruptedStandby); + verify(tasks).addPendingTasksToInit(Set.of(corruptedStandby)); } @Test public void shouldNotAttemptToCommitInHandleCorruptedDuringARebalance() { - final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); - when(stateDirectory.listNonEmptyTaskDirectories()).thenReturn(new ArrayList<>()); - - final StateMachineTask corruptedActive = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); - - // make sure this will attempt to be committed and throw - final StateMachineTask uncorruptedActive = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager); - final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); - uncorruptedActive.setCommitNeeded(); + final StreamTask corruptedActive = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); - // handleAssignment - final Map> firstAssignement = new HashMap<>(); - firstAssignement.putAll(taskId00Assignment); - firstAssignement.putAll(taskId01Assignment); - when(activeTaskCreator.createTasks(any(), eq(firstAssignement))) - .thenReturn(asList(corruptedActive, uncorruptedActive)); + final StreamTask uncorruptedActive = statefulTask(taskId01, taskId01ChangelogPartitions) + .withInputPartitions(taskId01Partitions) + .inState(State.RUNNING) + .build(); - when(consumer.assignment()) - .thenReturn(assignment) - .thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions)); + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.task(taskId00)).thenReturn(corruptedActive); + when(tasks.allTasksPerId()).thenReturn(mkMap( + mkEntry(taskId00, corruptedActive), + mkEntry(taskId01, uncorruptedActive) + )); + when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00, taskId01)); - uncorruptedActive.setCommittableOffsetsAndMetadata(offsets); + when(uncorruptedActive.commitNeeded()).thenReturn(true); + when(uncorruptedActive.prepareCommit(true)).thenReturn(emptyMap()); - taskManager.handleAssignment(firstAssignement, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + when(corruptedActive.prepareCommit(false)).thenReturn(emptyMap()); + doNothing().when(corruptedActive).postCommit(anyBoolean()); - assertThat(uncorruptedActive.state(), is(Task.State.RUNNING)); + when(consumer.assignment()).thenReturn(taskId00Partitions); - assertThat(uncorruptedActive.commitPrepared, is(false)); - assertThat(uncorruptedActive.commitNeeded, is(true)); - assertThat(uncorruptedActive.commitCompleted, is(false)); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); taskManager.handleRebalanceStart(singleton(topic1)); assertThat(taskManager.rebalanceInProgress(), is(true)); + taskManager.handleCorruption(singleton(taskId00)); - assertThat(uncorruptedActive.commitPrepared, is(false)); - assertThat(uncorruptedActive.commitNeeded, is(true)); - assertThat(uncorruptedActive.commitCompleted, is(false)); + verify(uncorruptedActive, never()).prepareCommit(anyBoolean()); + verify(uncorruptedActive, never()).postCommit(anyBoolean()); - assertThat(uncorruptedActive.state(), is(State.RUNNING)); + verify(corruptedActive).changelogPartitions(); + verify(corruptedActive).postCommit(true); + verify(corruptedActive).addPartitionsForOffsetReset(taskId00Partitions); + verify(consumer, never()).commitSync(emptyMap()); } + @SuppressWarnings("removal") @Test - public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitWithAlos() { - final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); - - final StateMachineTask corruptedActive = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); - final StateMachineTask uncorruptedActive = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { - @Override - public void markChangelogAsCorrupted(final Collection partitions) { - fail("Should not try to mark changelogs as corrupted for uncorrupted task"); - } - }; - final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); - uncorruptedActive.setCommittableOffsetsAndMetadata(offsets); - - // handleAssignment - final Map> firstAssignment = new HashMap<>(); - firstAssignment.putAll(taskId00Assignment); - firstAssignment.putAll(taskId01Assignment); - when(activeTaskCreator.createTasks(any(), eq(firstAssignment))) - .thenReturn(asList(corruptedActive, uncorruptedActive)); + public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringHandleCorruptedWithEOS() { + final StreamTask corruptedActive = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); - when(consumer.assignment()) - .thenReturn(assignment) - .thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions)); + // this task will time out during commit + final StreamTask uncorruptedActive = statefulTask(taskId01, taskId01ChangelogPartitions) + .withInputPartitions(taskId01Partitions) + .inState(State.RUNNING) + .build(); - doThrow(new TimeoutException()).when(consumer).commitSync(offsets); + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.task(taskId00)).thenReturn(corruptedActive); + when(tasks.allTasksPerId()).thenReturn(mkMap( + mkEntry(taskId00, corruptedActive), + mkEntry(taskId01, uncorruptedActive) + )); + when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00, taskId01)); - taskManager.handleAssignment(firstAssignment, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + final StreamsProducer producer = mock(StreamsProducer.class); + when(activeTaskCreator.streamsProducer()).thenReturn(producer); + final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId"); + when(consumer.groupMetadata()).thenReturn(groupMetadata); + when(consumer.assignment()).thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions)); - assertThat(uncorruptedActive.state(), is(Task.State.RUNNING)); - assertThat(corruptedActive.state(), is(Task.State.RUNNING)); + // mock uncorrupted task to indicate that it needs commit and will return offsets + final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); + when(tasks.tasks(singleton(taskId01))).thenReturn(Set.of(uncorruptedActive)); + when(uncorruptedActive.commitNeeded()).thenReturn(true); + when(uncorruptedActive.prepareCommit(true)).thenReturn(offsets); + when(uncorruptedActive.prepareCommit(false)).thenReturn(emptyMap()); + when(uncorruptedActive.changelogPartitions()).thenReturn(taskId01ChangelogPartitions); + doNothing().when(uncorruptedActive).suspend(); + doNothing().when(uncorruptedActive).closeDirty(); + doNothing().when(uncorruptedActive).revive(); + doNothing().when(uncorruptedActive).markChangelogAsCorrupted(taskId01ChangelogPartitions); + + // corrupted task doesn't need commit + when(corruptedActive.commitNeeded()).thenReturn(false); + when(corruptedActive.prepareCommit(false)).thenReturn(emptyMap()); + when(corruptedActive.changelogPartitions()).thenReturn(taskId00ChangelogPartitions); + doNothing().when(corruptedActive).suspend(); + doNothing().when(corruptedActive).postCommit(true); + doNothing().when(corruptedActive).closeDirty(); + doNothing().when(corruptedActive).revive(); - // make sure this will be committed and throw - uncorruptedActive.setCommitNeeded(); - corruptedActive.setChangelogOffsets(singletonMap(t1p0, 0L)); + doThrow(new TimeoutException()).when(producer).commitTransaction(offsets, groupMetadata); - assertThat(uncorruptedActive.commitPrepared, is(false)); - assertThat(uncorruptedActive.commitNeeded, is(true)); - assertThat(uncorruptedActive.commitCompleted, is(false)); - assertThat(corruptedActive.commitPrepared, is(false)); - assertThat(corruptedActive.commitNeeded, is(false)); - assertThat(corruptedActive.commitCompleted, is(false)); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks); taskManager.handleCorruption(singleton(taskId00)); - assertThat(uncorruptedActive.commitPrepared, is(true)); - assertThat(uncorruptedActive.commitNeeded, is(false)); - assertThat(uncorruptedActive.commitCompleted, is(false)); //if not corrupted, we should close dirty without committing - assertThat(corruptedActive.commitPrepared, is(true)); - assertThat(corruptedActive.commitNeeded, is(false)); - assertThat(corruptedActive.commitCompleted, is(true)); //if corrupted, should enforce checkpoint with corrupted tasks removed - - assertThat(corruptedActive.state(), is(Task.State.CREATED)); - assertThat(uncorruptedActive.state(), is(Task.State.CREATED)); - verify(stateManager).markChangelogAsCorrupted(taskId00Partitions); + // 1. verify corrupted task was closed dirty and revived + final InOrder corruptedOrder = inOrder(corruptedActive, tasks); + corruptedOrder.verify(corruptedActive).prepareCommit(false); + corruptedOrder.verify(corruptedActive).suspend(); + corruptedOrder.verify(corruptedActive).postCommit(true); + corruptedOrder.verify(corruptedActive).closeDirty(); + corruptedOrder.verify(tasks).removeTask(corruptedActive); + corruptedOrder.verify(corruptedActive).revive(); + corruptedOrder.verify(tasks).addPendingTasksToInit(Set.of(corruptedActive)); + + // 2. verify uncorrupted task attempted commit, failed with timeout, then was closed dirty and revived + final InOrder uncorruptedOrder = inOrder(uncorruptedActive, producer, tasks); + uncorruptedOrder.verify(uncorruptedActive).prepareCommit(true); + uncorruptedOrder.verify(producer).commitTransaction(offsets, groupMetadata); // tries to commit, throws TimeoutException + uncorruptedOrder.verify(uncorruptedActive).suspend(); + uncorruptedOrder.verify(uncorruptedActive).postCommit(true); + uncorruptedOrder.verify(uncorruptedActive).closeDirty(); + uncorruptedOrder.verify(tasks).removeTask(uncorruptedActive); + uncorruptedOrder.verify(uncorruptedActive).revive(); + uncorruptedOrder.verify(tasks).addPendingTasksToInit(Set.of(uncorruptedActive)); + + // verify both tasks had their input partitions reset + verify(corruptedActive).addPartitionsForOffsetReset(taskId00Partitions); + verify(uncorruptedActive).addPartitionsForOffsetReset(taskId01Partitions); } - @SuppressWarnings("removal") @Test - public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringHandleCorruptedWithEOS() { - final TaskManager taskManager = setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, false); - final StreamsProducer producer = mock(StreamsProducer.class); - when(activeTaskCreator.streamsProducer()).thenReturn(producer); - final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); - - final AtomicBoolean corruptedTaskChangelogMarkedAsCorrupted = new AtomicBoolean(false); - final StateMachineTask corruptedActiveTask = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager) { - @Override - public void markChangelogAsCorrupted(final Collection partitions) { - super.markChangelogAsCorrupted(partitions); - corruptedTaskChangelogMarkedAsCorrupted.set(true); - } - }; - - final AtomicBoolean uncorruptedTaskChangelogMarkedAsCorrupted = new AtomicBoolean(false); - final StateMachineTask uncorruptedActiveTask = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { - @Override - public void markChangelogAsCorrupted(final Collection partitions) { - super.markChangelogAsCorrupted(partitions); - uncorruptedTaskChangelogMarkedAsCorrupted.set(true); - } - }; - final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); - uncorruptedActiveTask.setCommittableOffsetsAndMetadata(offsets); - - // handleAssignment - final Map> firstAssignment = new HashMap<>(); - firstAssignment.putAll(taskId00Assignment); - firstAssignment.putAll(taskId01Assignment); - when(activeTaskCreator.createTasks(any(), eq(firstAssignment))) - .thenReturn(asList(corruptedActiveTask, uncorruptedActiveTask)); - - when(consumer.assignment()) - .thenReturn(assignment) - .thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions)); - - final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId"); - when(consumer.groupMetadata()).thenReturn(groupMetadata); - - doThrow(new TimeoutException()).when(producer).commitTransaction(offsets, groupMetadata); + public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitWithAlos() { + final StreamTask corruptedActive = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); - taskManager.handleAssignment(firstAssignment, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + // this task will time out during commit + final StreamTask uncorruptedActive = statefulTask(taskId01, taskId01ChangelogPartitions) + .withInputPartitions(taskId01Partitions) + .inState(State.RUNNING) + .build(); - assertThat(uncorruptedActiveTask.state(), is(Task.State.RUNNING)); - assertThat(corruptedActiveTask.state(), is(Task.State.RUNNING)); + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.task(taskId00)).thenReturn(corruptedActive); + when(tasks.allTasksPerId()).thenReturn(mkMap( + mkEntry(taskId00, corruptedActive), + mkEntry(taskId01, uncorruptedActive) + )); + when(tasks.activeTaskIds()).thenReturn(Set.of(taskId00, taskId01)); + when(tasks.activeTasks()).thenReturn(Set.of(corruptedActive, uncorruptedActive)); - // make sure this will be committed and throw - uncorruptedActiveTask.setCommitNeeded(); + // we need to mock uncorrupted task to indicate that it needs commit and will return offsets + final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); + when(uncorruptedActive.commitNeeded()).thenReturn(true); + when(uncorruptedActive.prepareCommit(true)).thenReturn(offsets); + when(uncorruptedActive.changelogPartitions()).thenReturn(taskId01ChangelogPartitions); + doNothing().when(uncorruptedActive).suspend(); + doNothing().when(uncorruptedActive).closeDirty(); + doNothing().when(uncorruptedActive).revive(); + + // corrupted task doesn't need commit + when(corruptedActive.commitNeeded()).thenReturn(false); + when(corruptedActive.prepareCommit(false)).thenReturn(emptyMap()); + when(corruptedActive.changelogPartitions()).thenReturn(taskId00ChangelogPartitions); + doNothing().when(corruptedActive).suspend(); + doNothing().when(corruptedActive).postCommit(anyBoolean()); + doNothing().when(corruptedActive).closeDirty(); + doNothing().when(corruptedActive).revive(); - final Map corruptedActiveTaskChangelogOffsets = singletonMap(t1p0changelog, 0L); - corruptedActiveTask.setChangelogOffsets(corruptedActiveTaskChangelogOffsets); - final Map uncorruptedActiveTaskChangelogOffsets = singletonMap(t1p1changelog, 0L); - uncorruptedActiveTask.setChangelogOffsets(uncorruptedActiveTaskChangelogOffsets); + doThrow(new TimeoutException()).when(consumer).commitSync(offsets); + when(consumer.assignment()).thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions)); - assertThat(uncorruptedActiveTask.commitPrepared, is(false)); - assertThat(uncorruptedActiveTask.commitNeeded, is(true)); - assertThat(uncorruptedActiveTask.commitCompleted, is(false)); - assertThat(corruptedActiveTask.commitPrepared, is(false)); - assertThat(corruptedActiveTask.commitNeeded, is(false)); - assertThat(corruptedActiveTask.commitCompleted, is(false)); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); taskManager.handleCorruption(singleton(taskId00)); - assertThat(uncorruptedActiveTask.commitPrepared, is(true)); - assertThat(uncorruptedActiveTask.commitNeeded, is(false)); - assertThat(uncorruptedActiveTask.commitCompleted, is(true)); //if corrupted due to timeout on commit, should enforce checkpoint with corrupted tasks removed - assertThat(corruptedActiveTask.commitPrepared, is(true)); - assertThat(corruptedActiveTask.commitNeeded, is(false)); - assertThat(corruptedActiveTask.commitCompleted, is(true)); //if corrupted, should enforce checkpoint with corrupted tasks removed - - assertThat(corruptedActiveTask.state(), is(Task.State.CREATED)); - assertThat(uncorruptedActiveTask.state(), is(Task.State.CREATED)); - assertThat(corruptedTaskChangelogMarkedAsCorrupted.get(), is(true)); - assertThat(uncorruptedTaskChangelogMarkedAsCorrupted.get(), is(true)); - verify(stateManager).markChangelogAsCorrupted(taskId00ChangelogPartitions); - verify(stateManager).markChangelogAsCorrupted(taskId01ChangelogPartitions); + // 1. verify corrupted task was closed dirty and revived + final InOrder corruptedOrder = inOrder(corruptedActive, tasks); + corruptedOrder.verify(corruptedActive).prepareCommit(false); + corruptedOrder.verify(corruptedActive).suspend(); + corruptedOrder.verify(corruptedActive).postCommit(true); + corruptedOrder.verify(corruptedActive).closeDirty(); + corruptedOrder.verify(tasks).removeTask(corruptedActive); + corruptedOrder.verify(corruptedActive).revive(); + corruptedOrder.verify(tasks).addPendingTasksToInit(Set.of(corruptedActive)); + + // 2. verify uncorrupted task attempted commit, failed with timeout, then was closed dirty and revived + final InOrder uncorruptedOrder = inOrder(uncorruptedActive, consumer, tasks); + uncorruptedOrder.verify(uncorruptedActive).prepareCommit(true); + uncorruptedOrder.verify(consumer).commitSync(offsets); // attempt commit, throws TimeoutException + uncorruptedOrder.verify(uncorruptedActive).prepareCommit(false); + uncorruptedOrder.verify(uncorruptedActive).suspend(); + uncorruptedOrder.verify(uncorruptedActive).closeDirty(); + uncorruptedOrder.verify(tasks).removeTask(uncorruptedActive); + uncorruptedOrder.verify(uncorruptedActive).revive(); + uncorruptedOrder.verify(tasks).addPendingTasksToInit(Set.of(uncorruptedActive)); + + // verify both tasks had their input partitions reset + verify(corruptedActive).addPartitionsForOffsetReset(taskId00Partitions); + verify(uncorruptedActive).addPartitionsForOffsetReset(taskId01Partitions); } @Test public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringRevocationWithAlos() { - final StateMachineTask revokedActiveTask = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); - final Map offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); - revokedActiveTask.setCommittableOffsetsAndMetadata(offsets00); - revokedActiveTask.setCommitNeeded(); + // task being revoked - needs commit + final StreamTask revokedActiveTask = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); - final StateMachineTask unrevokedActiveTaskWithCommitNeeded = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { - @Override - public void markChangelogAsCorrupted(final Collection partitions) { - fail("Should not try to mark changelogs as corrupted for uncorrupted task"); - } - }; - final Map offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); - unrevokedActiveTaskWithCommitNeeded.setCommittableOffsetsAndMetadata(offsets01); - unrevokedActiveTaskWithCommitNeeded.setCommitNeeded(); + // unrevoked task that needs commit - this will also be affected by timeout + final StreamTask unrevokedActiveTaskWithCommit = statefulTask(taskId01, taskId01ChangelogPartitions) + .withInputPartitions(taskId01Partitions) + .inState(State.RUNNING) + .build(); - final StateMachineTask unrevokedActiveTaskWithoutCommitNeeded = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager); + // unrevoked task without commit needed - this should stay RUNNING + final StreamTask unrevokedActiveTaskWithoutCommit = statefulTask(taskId02, taskId02ChangelogPartitions) + .withInputPartitions(taskId02Partitions) + .inState(State.RUNNING) + .build(); - final Map expectedCommittedOffsets = new HashMap<>(); - expectedCommittedOffsets.putAll(offsets00); - expectedCommittedOffsets.putAll(offsets01); + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.allTasks()).thenReturn(Set.of(revokedActiveTask, unrevokedActiveTaskWithCommit, unrevokedActiveTaskWithoutCommit)); - final Map> assignmentActive = mkMap( - mkEntry(taskId00, taskId00Partitions), - mkEntry(taskId01, taskId01Partitions), - mkEntry(taskId02, taskId02Partitions) - ); + when(consumer.assignment()).thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions)); - when(consumer.assignment()) - .thenReturn(assignment) - .thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions)); + // revoked task needs commit + final Map revokedTaskOffsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + when(revokedActiveTask.commitNeeded()).thenReturn(true); + when(revokedActiveTask.prepareCommit(true)).thenReturn(revokedTaskOffsets); + when(revokedActiveTask.changelogPartitions()).thenReturn(taskId00ChangelogPartitions); + doNothing().when(revokedActiveTask).suspend(); + doNothing().when(revokedActiveTask).closeDirty(); + doNothing().when(revokedActiveTask).revive(); - when(activeTaskCreator.createTasks(any(), eq(assignmentActive))) - .thenReturn(asList(revokedActiveTask, unrevokedActiveTaskWithCommitNeeded, unrevokedActiveTaskWithoutCommitNeeded)); + // unrevoked task with commit also takes part in commit + final Map unrevokedTaskOffsets = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + when(unrevokedActiveTaskWithCommit.commitNeeded()).thenReturn(true); + when(unrevokedActiveTaskWithCommit.prepareCommit(true)).thenReturn(unrevokedTaskOffsets); + when(unrevokedActiveTaskWithCommit.changelogPartitions()).thenReturn(taskId01ChangelogPartitions); + doNothing().when(unrevokedActiveTaskWithCommit).suspend(); + doNothing().when(unrevokedActiveTaskWithCommit).closeDirty(); + doNothing().when(unrevokedActiveTaskWithCommit).revive(); + + // unrevoked task without commit needed + when(unrevokedActiveTaskWithoutCommit.commitNeeded()).thenReturn(false); + // mock timeout during commit - all offsets from tasks needing commit + final Map expectedCommittedOffsets = new HashMap<>(); + expectedCommittedOffsets.putAll(revokedTaskOffsets); + expectedCommittedOffsets.putAll(unrevokedTaskOffsets); doThrow(new TimeoutException()).when(consumer).commitSync(expectedCommittedOffsets); - taskManager.handleAssignment(assignmentActive, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); - assertThat(revokedActiveTask.state(), is(Task.State.RUNNING)); - assertThat(unrevokedActiveTaskWithCommitNeeded.state(), is(State.RUNNING)); - assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), is(Task.State.RUNNING)); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); taskManager.handleRevocation(taskId00Partitions); - assertThat(revokedActiveTask.state(), is(State.SUSPENDED)); - assertThat(unrevokedActiveTaskWithCommitNeeded.state(), is(State.CREATED)); - assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), is(State.RUNNING)); + // 1. verify that the revoked task was suspended, closed dirty, and revived + final InOrder revokedOrder = inOrder(revokedActiveTask, tasks); + revokedOrder.verify(revokedActiveTask).prepareCommit(true); + revokedOrder.verify(revokedActiveTask).suspend(); + revokedOrder.verify(revokedActiveTask).closeDirty(); + revokedOrder.verify(tasks).removeTask(revokedActiveTask); + revokedOrder.verify(revokedActiveTask).revive(); + revokedOrder.verify(tasks).addPendingTasksToInit(argThat(set -> set.contains(revokedActiveTask))); + + // 2. verify that the unrevoked task with commit also tried to commit and was closed dirty due to timeout + final InOrder unrevokedOrder = inOrder(unrevokedActiveTaskWithCommit, consumer, tasks); + unrevokedOrder.verify(unrevokedActiveTaskWithCommit).prepareCommit(true); + unrevokedOrder.verify(consumer).commitSync(expectedCommittedOffsets); // timeout thrown here + unrevokedOrder.verify(unrevokedActiveTaskWithCommit).suspend(); + unrevokedOrder.verify(unrevokedActiveTaskWithCommit).closeDirty(); + unrevokedOrder.verify(tasks).removeTask(unrevokedActiveTaskWithCommit); + unrevokedOrder.verify(unrevokedActiveTaskWithCommit).revive(); + unrevokedOrder.verify(tasks).addPendingTasksToInit(argThat(set -> set.contains(unrevokedActiveTaskWithCommit))); + + // 3. verify that the unrevoked task without commit needed was not affected + verify(unrevokedActiveTaskWithoutCommit, never()).prepareCommit(anyBoolean()); + verify(unrevokedActiveTaskWithoutCommit, never()).suspend(); + verify(unrevokedActiveTaskWithoutCommit, never()).closeDirty(); + + // input partitions were reset for affected tasks + verify(revokedActiveTask).addPartitionsForOffsetReset(taskId00Partitions); + verify(unrevokedActiveTaskWithCommit).addPartitionsForOffsetReset(taskId01Partitions); + verify(unrevokedActiveTaskWithoutCommit, never()).addPartitionsForOffsetReset(any()); } @SuppressWarnings("removal") @Test public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringRevocationWithEOS() { - final TaskManager taskManager = setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, false); - final StreamsProducer producer = mock(StreamsProducer.class); - when(activeTaskCreator.streamsProducer()).thenReturn(producer); - final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); - - final StateMachineTask revokedActiveTask = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); - final Map revokedActiveTaskOffsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); - revokedActiveTask.setCommittableOffsetsAndMetadata(revokedActiveTaskOffsets); - revokedActiveTask.setCommitNeeded(); - - final AtomicBoolean unrevokedTaskChangelogMarkedAsCorrupted = new AtomicBoolean(false); - final StateMachineTask unrevokedActiveTask = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { - @Override - public void markChangelogAsCorrupted(final Collection partitions) { - super.markChangelogAsCorrupted(partitions); - unrevokedTaskChangelogMarkedAsCorrupted.set(true); - } - }; - final Map unrevokedTaskOffsets = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); - unrevokedActiveTask.setCommittableOffsetsAndMetadata(unrevokedTaskOffsets); - unrevokedActiveTask.setCommitNeeded(); - - final StateMachineTask unrevokedActiveTaskWithoutCommitNeeded = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager); - - final Map expectedCommittedOffsets = new HashMap<>(); - expectedCommittedOffsets.putAll(revokedActiveTaskOffsets); - expectedCommittedOffsets.putAll(unrevokedTaskOffsets); + // task being revoked - needs commit + final StreamTask revokedActiveTask = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); - final Map> assignmentActive = mkMap( - mkEntry(taskId00, taskId00Partitions), - mkEntry(taskId01, taskId01Partitions), - mkEntry(taskId02, taskId02Partitions) - ); + // unrevoked task that needs commit - this will also be affected by timeout + final StreamTask unrevokedActiveTaskWithCommit = statefulTask(taskId01, taskId01ChangelogPartitions) + .withInputPartitions(taskId01Partitions) + .inState(State.RUNNING) + .build(); - when(consumer.assignment()) - .thenReturn(assignment) - .thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions)); + // unrevoked task without commit needed - this should remain RUNNING + final StreamTask unrevokedActiveTaskWithoutCommit = statefulTask(taskId02, taskId02ChangelogPartitions) + .withInputPartitions(taskId02Partitions) + .inState(State.RUNNING) + .build(); - when(activeTaskCreator.createTasks(any(), eq(assignmentActive))) - .thenReturn(asList(revokedActiveTask, unrevokedActiveTask, unrevokedActiveTaskWithoutCommitNeeded)); + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.allTasks()).thenReturn(Set.of(revokedActiveTask, unrevokedActiveTaskWithCommit, unrevokedActiveTaskWithoutCommit)); + when(tasks.tasks(Set.of(taskId00, taskId01))).thenReturn(Set.of(revokedActiveTask, unrevokedActiveTaskWithCommit)); + final StreamsProducer producer = mock(StreamsProducer.class); + when(activeTaskCreator.streamsProducer()).thenReturn(producer); final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId"); when(consumer.groupMetadata()).thenReturn(groupMetadata); - + when(consumer.assignment()).thenReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions)); + + // revoked task needs commit + final Map revokedTaskOffsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + when(revokedActiveTask.commitNeeded()).thenReturn(true); + when(revokedActiveTask.prepareCommit(true)).thenReturn(revokedTaskOffsets); + when(revokedActiveTask.changelogPartitions()).thenReturn(taskId00ChangelogPartitions); + doNothing().when(revokedActiveTask).suspend(); + doNothing().when(revokedActiveTask).closeDirty(); + doNothing().when(revokedActiveTask).revive(); + doNothing().when(revokedActiveTask).markChangelogAsCorrupted(taskId00ChangelogPartitions); + + // unrevoked task with commit also takes part in EOS-v2 commit + final Map unrevokedTaskOffsets = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + when(unrevokedActiveTaskWithCommit.commitNeeded()).thenReturn(true); + when(unrevokedActiveTaskWithCommit.prepareCommit(true)).thenReturn(unrevokedTaskOffsets); + when(unrevokedActiveTaskWithCommit.changelogPartitions()).thenReturn(taskId01ChangelogPartitions); + doNothing().when(unrevokedActiveTaskWithCommit).suspend(); + doNothing().when(unrevokedActiveTaskWithCommit).closeDirty(); + doNothing().when(unrevokedActiveTaskWithCommit).revive(); + doNothing().when(unrevokedActiveTaskWithCommit).markChangelogAsCorrupted(taskId01ChangelogPartitions); + + // unrevoked task without commit needed + when(unrevokedActiveTaskWithoutCommit.commitNeeded()).thenReturn(false); + + // mock timeout during commit - all offsets from tasks needing commit + final Map expectedCommittedOffsets = new HashMap<>(); + expectedCommittedOffsets.putAll(revokedTaskOffsets); + expectedCommittedOffsets.putAll(unrevokedTaskOffsets); doThrow(new TimeoutException()).when(producer).commitTransaction(expectedCommittedOffsets, groupMetadata); - taskManager.handleAssignment(assignmentActive, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); - assertThat(revokedActiveTask.state(), is(Task.State.RUNNING)); - assertThat(unrevokedActiveTask.state(), is(Task.State.RUNNING)); - assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), is(State.RUNNING)); - - final Map revokedActiveTaskChangelogOffsets = singletonMap(t1p0changelog, 0L); - revokedActiveTask.setChangelogOffsets(revokedActiveTaskChangelogOffsets); - final Map unrevokedActiveTaskChangelogOffsets = singletonMap(t1p1changelog, 0L); - unrevokedActiveTask.setChangelogOffsets(unrevokedActiveTaskChangelogOffsets); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks); taskManager.handleRevocation(taskId00Partitions); - assertThat(unrevokedTaskChangelogMarkedAsCorrupted.get(), is(true)); - assertThat(revokedActiveTask.state(), is(State.SUSPENDED)); - assertThat(unrevokedActiveTask.state(), is(State.CREATED)); - assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), is(State.RUNNING)); - verify(stateManager).markChangelogAsCorrupted(taskId00ChangelogPartitions); - verify(stateManager).markChangelogAsCorrupted(taskId01ChangelogPartitions); + // 1. verify that the revoked task was suspended, closed dirty, and revived + final InOrder revokedOrder = inOrder(revokedActiveTask, tasks); + revokedOrder.verify(revokedActiveTask).prepareCommit(true); + revokedOrder.verify(revokedActiveTask).suspend(); + revokedOrder.verify(revokedActiveTask).closeDirty(); + revokedOrder.verify(tasks).removeTask(revokedActiveTask); + revokedOrder.verify(revokedActiveTask).revive(); + revokedOrder.verify(tasks).addPendingTasksToInit(argThat(set -> set.contains(revokedActiveTask))); + + // 2. verify that the unrevoked task with commit also tried to commit and was closed dirty due to timeout + final InOrder unrevokedOrder = inOrder(unrevokedActiveTaskWithCommit, producer, tasks); + unrevokedOrder.verify(unrevokedActiveTaskWithCommit).prepareCommit(true); + unrevokedOrder.verify(producer).commitTransaction(expectedCommittedOffsets, groupMetadata); // timeout thrown here + unrevokedOrder.verify(unrevokedActiveTaskWithCommit).suspend(); + unrevokedOrder.verify(unrevokedActiveTaskWithCommit).closeDirty(); + unrevokedOrder.verify(tasks).removeTask(unrevokedActiveTaskWithCommit); + unrevokedOrder.verify(unrevokedActiveTaskWithCommit).revive(); + unrevokedOrder.verify(tasks).addPendingTasksToInit(argThat(set -> set.contains(unrevokedActiveTaskWithCommit))); + + // 3. verify that the unrevoked task without commit needed was not affected + verify(unrevokedActiveTaskWithoutCommit, never()).prepareCommit(anyBoolean()); + verify(unrevokedActiveTaskWithoutCommit, never()).suspend(); + verify(unrevokedActiveTaskWithoutCommit, never()).closeDirty(); + + // verify input partitions were reset for affected tasks + verify(revokedActiveTask).addPartitionsForOffsetReset(taskId00Partitions); + verify(unrevokedActiveTaskWithCommit).addPartitionsForOffsetReset(taskId01Partitions); + verify(unrevokedActiveTaskWithoutCommit, never()).addPartitionsForOffsetReset(any()); } @Test public void shouldCloseStandbyUnassignedTasksWhenCreatingNewTasks() { - final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false, stateManager); + final StandbyTask task00 = standbyTask(taskId00, taskId00ChangelogPartitions) + .inState(State.RUNNING) + .withInputPartitions(taskId00Partitions) + .build(); - when(consumer.assignment()).thenReturn(assignment); - when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singletonList(task00)); + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.drainPendingTasksToInit()).thenReturn(emptySet()); - taskManager.handleAssignment(emptyMap(), taskId00Assignment); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); - assertThat(task00.state(), is(Task.State.RUNNING)); + taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); + + when(stateUpdater.tasks()).thenReturn(Set.of(task00)); + + // mock future for removing task from StateUpdater + final CompletableFuture future = new CompletableFuture<>(); + when(stateUpdater.remove(task00.id())).thenReturn(future); + future.complete(new StateUpdater.RemovedTaskResult(task00)); taskManager.handleAssignment(emptyMap(), emptyMap()); - assertThat(task00.state(), is(Task.State.CLOSED)); - assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); - assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + + verify(stateUpdater).remove(task00.id()); + verify(task00).suspend(); + verify(task00).closeClean(); + + verify(activeTaskCreator).createTasks(any(), eq(emptyMap())); + verify(standbyTaskCreator).createTasks(emptyMap()); } @Test public void shouldAddNonResumedSuspendedTasks() { - final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); - final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, false, stateManager); + final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); + final StandbyTask task01 = standbyTask(taskId01, taskId01ChangelogPartitions) + .withInputPartitions(taskId01Partitions) + .inState(State.RUNNING) + .build(); - when(consumer.assignment()).thenReturn(assignment); - when(activeTaskCreator.createTasks(any(), eq(taskId00Assignment))).thenReturn(singletonList(task00)); - when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01)); + final TasksRegistry tasks = mock(TasksRegistry.class); - taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); - assertThat(task00.state(), is(Task.State.RUNNING)); - assertThat(task01.state(), is(Task.State.RUNNING)); + when(tasks.allNonFailedTasks()).thenReturn(Set.of(task00)); + + when(tasks.drainPendingTasksToInit()).thenReturn(emptySet()); + when(tasks.hasPendingTasksToInit()).thenReturn(false); + + taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); + + when(stateUpdater.tasks()).thenReturn(Set.of(task01)); + when(stateUpdater.restoresActiveTasks()).thenReturn(false); + when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(false); taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); - assertThat(task00.state(), is(Task.State.RUNNING)); - assertThat(task01.state(), is(Task.State.RUNNING)); - // expect these calls twice (because we're going to tryToCompleteRestoration twice) + // checkStateUpdater should return true (all tasks ready, no pending work) + assertTrue(taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)); + + verify(stateUpdater, never()).add(any(Task.class)); verify(activeTaskCreator).createTasks(any(), eq(emptyMap())); - verify(consumer, times(2)).assignment(); - verify(consumer, times(2)).resume(assignment); + verify(standbyTaskCreator).createTasks(emptyMap()); + + // verify idempotence + taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); + assertTrue(taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)); + verify(stateUpdater, never()).add(any(Task.class)); } @Test public void shouldUpdateInputPartitionsAfterRebalance() { - final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); + final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); - when(consumer.assignment()).thenReturn(assignment); - when(activeTaskCreator.createTasks(any(), eq(taskId00Assignment))).thenReturn(singletonList(task00)); + final TasksRegistry tasks = mock(TasksRegistry.class); + final Set newPartitionsSet = Set.of(t1p1); - taskManager.handleAssignment(taskId00Assignment, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); - assertThat(task00.state(), is(Task.State.RUNNING)); + when(tasks.allNonFailedTasks()).thenReturn(Set.of(task00)); + when(tasks.drainPendingTasksToInit()).thenReturn(emptySet()); + when(tasks.hasPendingTasksToInit()).thenReturn(false); + when(tasks.updateActiveTaskInputPartitions(task00, newPartitionsSet)).thenReturn(true); + + taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); + + when(stateUpdater.tasks()).thenReturn(emptySet()); + when(stateUpdater.restoresActiveTasks()).thenReturn(false); + when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(false); - final Set newPartitionsSet = Set.of(t1p1); final Map> taskIdSetMap = singletonMap(taskId00, newPartitionsSet); taskManager.handleAssignment(taskIdSetMap, emptyMap()); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + verify(task00).updateInputPartitions(eq(newPartitionsSet), any()); + assertTrue(taskManager.checkStateUpdater(time.milliseconds(), noOpResetter)); assertThat(task00.state(), is(Task.State.RUNNING)); - assertEquals(newPartitionsSet, task00.inputPartitions()); - // expect these calls twice (because we're going to tryToCompleteRestoration twice) - verify(consumer, times(2)).resume(assignment); - verify(consumer, times(2)).assignment(); verify(activeTaskCreator).createTasks(any(), eq(emptyMap())); + verify(standbyTaskCreator).createTasks(emptyMap()); } @Test public void shouldAddNewActiveTasks() { + // task in created state + final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions) + .inState(State.CREATED) + .withInputPartitions(taskId00Partitions) + .build(); + final Map> assignment = taskId00Assignment; - final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); + final TasksRegistry tasks = mock(TasksRegistry.class); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks); + // first, we need to handle assignment -- creates tasks and adds to pending initialization when(activeTaskCreator.createTasks(any(), eq(assignment))).thenReturn(singletonList(task00)); taskManager.handleAssignment(assignment, emptyMap()); - assertThat(task00.state(), is(Task.State.CREATED)); + verify(tasks).addPendingTasksToInit(singletonList(task00)); - taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { }); + // next, drain pending tasks, initialize them, and then add to stateupdater + when(tasks.drainPendingTasksToInit()).thenReturn(Set.of(task00)); - assertThat(task00.state(), is(Task.State.RUNNING)); - assertThat(taskManager.activeTaskMap(), Matchers.equalTo(singletonMap(taskId00, task00))); - assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); - verify(changeLogReader).enforceRestoreActive(); - verify(consumer).assignment(); - verify(consumer).resume(eq(emptySet())); + taskManager.checkStateUpdater(time.milliseconds(), noOpResetter); + + verify(task00).initializeIfNeeded(); + verify(stateUpdater).add(task00); + + // last, drain the restored tasks from stateupdater and transition to running + when(stateUpdater.restoresActiveTasks()).thenReturn(true); + when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(Set.of(task00)); + + taskManager.checkStateUpdater(time.milliseconds(), noOpResetter); + + verifyTransitionToRunningOfRestoredTask(Set.of(task00), tasks); } @Test @@ -2842,70 +2959,84 @@ public void shouldSuspendActiveTasksDuringRevocation() { @SuppressWarnings("removal") @Test public void shouldCommitAllActiveTasksThatNeedCommittingOnHandleRevocationWithEosV2() { + // task being revoked, needs commit + final StreamTask task00 = statefulTask(taskId00, taskId00ChangelogPartitions) + .withInputPartitions(taskId00Partitions) + .inState(State.RUNNING) + .build(); + + // unrevoked task that needs commit, this should also be committed with EOS-v2 + final StreamTask task01 = statefulTask(taskId01, taskId01ChangelogPartitions) + .withInputPartitions(taskId01Partitions) + .inState(State.RUNNING) + .build(); + + // unrevoked task that doesn't need commit, should not be committed + final StreamTask task02 = statefulTask(taskId02, taskId02ChangelogPartitions) + .withInputPartitions(taskId02Partitions) + .inState(State.RUNNING) + .build(); + + // standby task should not be committed + final StandbyTask task10 = standbyTask(taskId10, emptySet()) + .withInputPartitions(taskId10Partitions) + .inState(State.RUNNING) + .build(); + + final TasksRegistry tasks = mock(TasksRegistry.class); + + when(tasks.allTasks()).thenReturn(Set.of(task00, task01, task02, task10)); + final StreamsProducer producer = mock(StreamsProducer.class); - final TaskManager taskManager = setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, false); + when(activeTaskCreator.streamsProducer()).thenReturn(producer); + final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId"); + when(consumer.groupMetadata()).thenReturn(groupMetadata); - final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); final Map offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); - task00.setCommittableOffsetsAndMetadata(offsets00); - task00.setCommitNeeded(); + when(task00.commitNeeded()).thenReturn(true); + when(task00.prepareCommit(true)).thenReturn(offsets00); + doNothing().when(task00).postCommit(anyBoolean()); + doNothing().when(task00).suspend(); - final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager); final Map offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); - task01.setCommittableOffsetsAndMetadata(offsets01); - task01.setCommitNeeded(); + when(task01.commitNeeded()).thenReturn(true); + when(task01.prepareCommit(true)).thenReturn(offsets01); + doNothing().when(task01).postCommit(anyBoolean()); - final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager); - final Map offsets02 = singletonMap(t1p2, new OffsetAndMetadata(2L, null)); - task02.setCommittableOffsetsAndMetadata(offsets02); + // task02 does not need commit + when(task02.commitNeeded()).thenReturn(false); - final StateMachineTask task10 = new StateMachineTask(taskId10, taskId10Partitions, false, stateManager); + // standby task should not take part in commit + when(task10.commitNeeded()).thenReturn(false); + // expected committed offsets, only task00 and task01 (both need commit) final Map expectedCommittedOffsets = new HashMap<>(); expectedCommittedOffsets.putAll(offsets00); expectedCommittedOffsets.putAll(offsets01); - final Map> assignmentActive = mkMap( - mkEntry(taskId00, taskId00Partitions), - mkEntry(taskId01, taskId01Partitions), - mkEntry(taskId02, taskId02Partitions) - ); - - final Map> assignmentStandby = mkMap( - mkEntry(taskId10, taskId10Partitions) - ); - when(consumer.assignment()).thenReturn(assignment); - - when(activeTaskCreator.createTasks(any(), eq(assignmentActive))) - .thenReturn(asList(task00, task01, task02)); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks); - when(activeTaskCreator.streamsProducer()).thenReturn(producer); - when(standbyTaskCreator.createTasks(assignmentStandby)) - .thenReturn(singletonList(task10)); - - final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId"); - when(consumer.groupMetadata()).thenReturn(groupMetadata); + taskManager.handleRevocation(taskId00Partitions); - task00.committedOffsets(); - task01.committedOffsets(); - task02.committedOffsets(); - task10.committedOffsets(); + // Verify the commit transaction was called with offsets from task00 and task01 + verify(producer).commitTransaction(expectedCommittedOffsets, groupMetadata); - taskManager.handleAssignment(assignmentActive, assignmentStandby); - assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); - assertThat(task00.state(), is(Task.State.RUNNING)); - assertThat(task01.state(), is(Task.State.RUNNING)); - assertThat(task02.state(), is(Task.State.RUNNING)); - assertThat(task10.state(), is(Task.State.RUNNING)); + // Verify task00 (revoked) was suspended and committed + verify(task00).prepareCommit(true); + verify(task00).postCommit(true); + verify(task00).suspend(); - taskManager.handleRevocation(taskId00Partitions); + // Verify task01 (unrevoked but needs commit) was also committed + verify(task01).prepareCommit(true); + verify(task01).postCommit(false); - assertThat(task00.commitNeeded, is(false)); - assertThat(task01.commitNeeded, is(false)); - assertThat(task02.commitPrepared, is(false)); - assertThat(task10.commitPrepared, is(false)); + // Verify task02 (doesn't need commit) was not committed + verify(task02, never()).prepareCommit(anyBoolean()); + verify(task02, never()).postCommit(anyBoolean()); - verify(producer).commitTransaction(expectedCommittedOffsets, groupMetadata); + // Verify standby task10 was not committed + verify(task10, never()).prepareCommit(anyBoolean()); + verify(task10, never()).postCommit(anyBoolean()); } @Test @@ -3772,6 +3903,19 @@ public void shouldCommitViaConsumerIfEosDisabled() { @SuppressWarnings("removal") @Test public void shouldCommitViaProducerIfEosV2Enabled() { + final StreamTask task01 = statefulTask(taskId01, taskId01ChangelogPartitions) + .withInputPartitions(taskId01Partitions) + .inState(State.RUNNING) + .build(); + + final StreamTask task02 = statefulTask(taskId02, taskId02ChangelogPartitions) + .withInputPartitions(taskId02Partitions) + .inState(State.RUNNING) + .build(); + + final TasksRegistry tasks = mock(TasksRegistry.class); + when(tasks.allTasks()).thenReturn(Set.of(task01, task02)); + final StreamsProducer producer = mock(StreamsProducer.class); when(activeTaskCreator.streamsProducer()).thenReturn(producer); @@ -3781,22 +3925,27 @@ public void shouldCommitViaProducerIfEosV2Enabled() { allOffsets.putAll(offsetsT01); allOffsets.putAll(offsetsT02); - final TaskManager taskManager = setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, false); + when(task01.commitNeeded()).thenReturn(true); + when(task01.prepareCommit(true)).thenReturn(offsetsT01); + doNothing().when(task01).postCommit(false); - final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager); - task01.setCommittableOffsetsAndMetadata(offsetsT01); - task01.setCommitNeeded(); - taskManager.addTask(task01); - final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager); - task02.setCommittableOffsetsAndMetadata(offsetsT02); - task02.setCommitNeeded(); - taskManager.addTask(task02); + when(task02.commitNeeded()).thenReturn(true); + when(task02.prepareCommit(true)).thenReturn(offsetsT02); + doNothing().when(task02).postCommit(false); when(consumer.groupMetadata()).thenReturn(new ConsumerGroupMetadata("appId")); + final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, tasks); + taskManager.commitAll(); verify(producer).commitTransaction(allOffsets, new ConsumerGroupMetadata("appId")); + verify(task01, times(2)).commitNeeded(); + verify(task01).prepareCommit(true); + verify(task01).postCommit(false); + verify(task02, times(2)).commitNeeded(); + verify(task02).prepareCommit(true); + verify(task02).postCommit(false); verifyNoMoreInteractions(producer); }