-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Add chunked and windowed operators #1558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
c0bf01b
86a503d
dbe84ea
6155c22
c74bfcf
07f72e9
9392bbd
8e7ee37
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
@file:JvmMultifileClass | ||
@file:JvmName("FlowKt") | ||
|
||
package kotlinx.coroutines.flow | ||
|
||
import kotlinx.coroutines.* | ||
import kotlin.jvm.* | ||
import kotlin.math.* | ||
|
||
/** | ||
* Returns a flow of lists each not exceeding the given [size]. | ||
* The last list in the resulting flow may have less elements than the given [size]. | ||
* | ||
* @param size the number of elements to take in each list, must be positive and can be greater than the number of elements in this flow. | ||
*/ | ||
|
||
@FlowPreview | ||
public fun <T> Flow<T>.chunked(size: Int): Flow<List<T>> = chunked(size) { it.toList() } | ||
|
||
/** | ||
* Chunks a flow of elements into flow of lists, each not exceeding the given [size] | ||
* and applies the given [transform] function to an each. | ||
* | ||
* Note that the list passed to the [transform] function is ephemeral and is valid only inside that function. | ||
* You should not store it or allow it to escape in some way, unless you made a snapshot of it. | ||
* The last list may have less elements than the given [size]. | ||
* | ||
* This is more efficient, than using flow.chunked(n).map { ... } | ||
* | ||
* @param size the number of elements to take in each list, must be positive and can be greater than the number of elements in this flow. | ||
*/ | ||
|
||
@FlowPreview | ||
public fun <T, R> Flow<T>.chunked(size: Int, transform: suspend (List<T>) -> R): Flow<R> { | ||
require(size > 0) { "Size should be greater than 0, but was $size" } | ||
return windowed(size, size, true, transform) | ||
} | ||
|
||
/** | ||
* Returns a flow of snapshots of the window of the given [size] | ||
* sliding along this flow with the given [step], where each | ||
* snapshot is a list. | ||
* | ||
* Several last lists may have less elements than the given [size]. | ||
* | ||
* Both [size] and [step] must be positive and can be greater than the number of elements in this flow. | ||
* @param size the number of elements to take in each window | ||
* @param step the number of elements to move the window forward by on an each step | ||
* @param partialWindows controls whether or not to keep partial windows in the end if any. | ||
*/ | ||
|
||
@FlowPreview | ||
public fun <T> Flow<T>.windowed(size: Int, step: Int, partialWindows: Boolean): Flow<List<T>> = | ||
windowed(size, step, partialWindows) { it.toList() } | ||
|
||
/** | ||
* Returns a flow of results of applying the given [transform] function to | ||
* an each list representing a view over the window of the given [size] | ||
* sliding along this collection with the given [step]. | ||
* | ||
* Note that the list passed to the [transform] function is ephemeral and is valid only inside that function. | ||
* You should not store it or allow it to escape in some way, unless you made a snapshot of it. | ||
* Several last lists may have less elements than the given [size]. | ||
* | ||
* This is more efficient, than using flow.windowed(...).map { ... } | ||
* | ||
* Both [size] and [step] must be positive and can be greater than the number of elements in this collection. | ||
* @param size the number of elements to take in each window | ||
* @param step the number of elements to move the window forward by on an each step. | ||
* @param partialWindows controls whether or not to keep partial windows in the end if any. | ||
*/ | ||
|
||
@OptIn(ExperimentalStdlibApi::class) | ||
@FlowPreview | ||
public fun <T, R> Flow<T>.windowed(size: Int, step: Int, partialWindows: Boolean, transform: suspend (List<T>) -> R): Flow<R> { | ||
require(size > 0 && step > 0) { "Size and step should be greater than 0, but was size: $size, step: $step" } | ||
|
||
return flow { | ||
val buffer = ArrayDeque<T>(size) | ||
val toDrop = min(step, size) | ||
val toSkip = max(step - size, 0) | ||
var skipped = toSkip | ||
|
||
collect { value -> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quick observation. If the upstream flow throws an exception, you could end up with a partial buffer here.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have fear, that we can loose exception this way or otherwise obscure it. What if emitting partial buffer gets cancelled on some suspension point, before we rethrow exception? Or even, what if exception will be thrown from emitting our partial buffer, before we manage to rethrow original exception? Perhaps we could fix this by using a finally block somewhere. But, I think, it would be simpler, to just let exception propagate right away, and loose unfortunate partial buffer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Loosing partial buffer is somewhat analogous to coroutineScope cancelling all its children immediately, on detecting exception anywhere. It is just more consistent this way. |
||
if (toSkip == skipped) buffer.addLast(value) | ||
else skipped++ | ||
|
||
if (buffer.size == size) { | ||
emit(transform(buffer)) | ||
repeat(toDrop) { buffer.removeFirst() } | ||
skipped = 0 | ||
} | ||
} | ||
|
||
while (partialWindows && buffer.isNotEmpty()) { | ||
emit(transform(buffer)) | ||
repeat(min(toDrop, buffer.size)) { buffer.removeFirst() } | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
package kotlinx.coroutines.internal | ||
|
||
internal class RingBuffer<T>(val capacity: Int) : AbstractList<T>(), RandomAccess { | ||
circusmagnus marked this conversation as resolved.
Show resolved
Hide resolved
|
||
init { | ||
require(capacity >= 0) { "ring buffer capacity should not be negative but it is $capacity" } | ||
} | ||
|
||
private val buffer = arrayOfNulls<Any?>(capacity) | ||
private var startIndex: Int = 0 | ||
|
||
override var size: Int = 0 | ||
private set | ||
|
||
override fun get(index: Int): T { | ||
require(index in 0 until size) { "Index out of bounds: $index" } | ||
@Suppress("UNCHECKED_CAST") | ||
return buffer[startIndex.forward(index)] as T | ||
} | ||
|
||
fun isFull() = size == capacity | ||
|
||
override fun iterator(): Iterator<T> = object : AbstractIterator<T>() { | ||
private var count = size | ||
private var index = startIndex | ||
|
||
override fun computeNext() { | ||
if (count == 0) { | ||
done() | ||
} else { | ||
@Suppress("UNCHECKED_CAST") | ||
setNext(buffer[index] as T) | ||
index = index.forward(1) | ||
count-- | ||
} | ||
} | ||
} | ||
|
||
@Suppress("UNCHECKED_CAST") | ||
override fun <T> toArray(array: Array<T>): Array<T> { | ||
val result: Array<T?> = | ||
if (array.size < this.size) array.copyOf(this.size) else array as Array<T?> | ||
|
||
val size = this.size | ||
|
||
var widx = 0 | ||
var idx = startIndex | ||
|
||
while (widx < size && idx < capacity) { | ||
result[widx] = buffer[idx] as T | ||
widx++ | ||
idx++ | ||
} | ||
|
||
idx = 0 | ||
while (widx < size) { | ||
result[widx] = buffer[idx] as T | ||
widx++ | ||
idx++ | ||
} | ||
if (result.size > this.size) result[this.size] = null | ||
|
||
return result as Array<T> | ||
} | ||
|
||
override fun toArray(): Array<Any?> { | ||
return toArray(arrayOfNulls(size)) | ||
} | ||
|
||
/** | ||
* Add [element] to the buffer or fail with [IllegalStateException] if no free space available in the buffer | ||
*/ | ||
fun add(element: T) { | ||
check(!isFull()) { "Ring buffer is full" } | ||
|
||
buffer[startIndex.forward(size)] = element | ||
size++ | ||
} | ||
|
||
/** | ||
* Removes [n] first elements from the buffer or fails with [IllegalArgumentException] if not enough elements in the buffer to remove | ||
*/ | ||
fun removeFirst(n: Int) { | ||
require(n >= 0) { "n shouldn't be negative but it is $n" } | ||
require(n <= size) { "n shouldn't be greater than the buffer size: n = $n, size = $size" } | ||
|
||
if (n > 0) { | ||
val start = startIndex | ||
val end = start.forward(n) | ||
|
||
if (start > end) { | ||
buffer.fill(null, start, capacity) | ||
buffer.fill(null, 0, end) | ||
} else { | ||
buffer.fill(null, start, end) | ||
} | ||
|
||
startIndex = end | ||
size -= n | ||
} | ||
} | ||
|
||
|
||
@Suppress("NOTHING_TO_INLINE") | ||
private inline fun Int.forward(n: Int): Int = (this + n) % capacity | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package kotlinx.coroutines.flow.operators | ||
|
||
import kotlinx.coroutines.* | ||
import kotlinx.coroutines.channels.Channel | ||
import kotlinx.coroutines.flow.* | ||
import kotlin.test.Test | ||
import kotlin.test.assertEquals | ||
|
||
class ChunkedTest : TestBase() { | ||
|
||
private val flow = flow { | ||
emit(1) | ||
emit(2) | ||
emit(3) | ||
emit(4) | ||
} | ||
|
||
@Test | ||
fun `Chunks correct number of emissions with possible partial window at the end`() = runTest { | ||
assertEquals(2, flow.chunked(2).count()) | ||
assertEquals(2, flow.chunked(3).count()) | ||
assertEquals(1, flow.chunked(5).count()) | ||
} | ||
|
||
@Test | ||
fun `Throws IllegalArgumentException for chunk of size less than 1`() { | ||
assertFailsWith<IllegalArgumentException> { flow.chunked(0) } | ||
assertFailsWith<IllegalArgumentException> { flow.chunked(-1) } | ||
} | ||
|
||
@Test | ||
fun `No emissions with empty flow`() = runTest { | ||
assertEquals(0, flowOf<Int>().chunked(2).count()) | ||
} | ||
|
||
@Test | ||
fun testErrorCancelsUpstream() = runTest { | ||
val latch = Channel<Unit>() | ||
val flow = flow { | ||
coroutineScope { | ||
launch(start = CoroutineStart.ATOMIC) { | ||
latch.send(Unit) | ||
hang { expect(3) } | ||
} | ||
emit(1) | ||
expect(1) | ||
emit(2) | ||
expectUnreached() | ||
} | ||
}.chunked<Int, Int>(2) { chunk -> | ||
expect(2) // 2 | ||
latch.receive() | ||
throw TestException() | ||
}.catch { emit(42) } | ||
|
||
assertEquals(42, flow.single()) | ||
finish(4) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package kotlinx.coroutines.flow.operators | ||
|
||
import kotlinx.coroutines.* | ||
import kotlinx.coroutines.channels.Channel | ||
import kotlinx.coroutines.flow.* | ||
import kotlin.test.Test | ||
import kotlin.test.assertEquals | ||
|
||
class WindowedTest : TestBase() { | ||
|
||
private val flow = flow { | ||
emit(1) | ||
emit(2) | ||
emit(3) | ||
emit(4) | ||
} | ||
|
||
@Test | ||
fun `Throws IllegalArgumentException for window of size or step less than 1`() { | ||
assertFailsWith<IllegalArgumentException> { flow.windowed(0, 1, false) } | ||
assertFailsWith<IllegalArgumentException> { flow.windowed(-1, 2, false) } | ||
assertFailsWith<IllegalArgumentException> { flow.windowed(2, 0, false) } | ||
assertFailsWith<IllegalArgumentException> { flow.windowed(5, -2, false) } | ||
} | ||
|
||
@Test | ||
fun `No emissions with empty flow`() = runTest { | ||
assertEquals(0, flowOf<Int>().windowed(2, 2, false).count()) | ||
} | ||
|
||
@Test | ||
fun `Emits correct sum with overlapping non partial windows`() = runTest { | ||
assertEquals(15, flow.windowed(3, 1, false) { window -> | ||
window.sum() | ||
}.sum()) | ||
} | ||
|
||
@Test | ||
fun `Emits correct sum with overlapping partial windows`() = runTest { | ||
assertEquals(13, flow.windowed(3, 2, true) { window -> | ||
window.sum() | ||
}.sum()) | ||
} | ||
|
||
@Test | ||
fun `Emits correct number of overlapping windows for long sequence of overlapping partial windows`() = runTest { | ||
val elements = generateSequence(1) { it + 1 }.take(100) | ||
val flow = elements.asFlow().windowed(100, 1, true) | ||
assertEquals(100, flow.count()) | ||
} | ||
|
||
@Test | ||
fun `Emits correct sum with partial windows set apart`() = runTest { | ||
assertEquals(7, flow.windowed(2, 3, true) { window -> | ||
window.sum() | ||
}.sum()) | ||
} | ||
|
||
@Test | ||
fun testErrorCancelsUpstream() = runTest { | ||
val latch = Channel<Unit>() | ||
val flow = flow { | ||
coroutineScope { | ||
launch(start = CoroutineStart.ATOMIC) { | ||
latch.send(Unit) | ||
hang { expect(3) } | ||
} | ||
emit(1) | ||
expect(1) | ||
emit(2) | ||
expectUnreached() | ||
} | ||
}.windowed<Int, Int>(2, 3, false) { window -> | ||
expect(2) // 2 | ||
latch.receive() | ||
throw TestException() | ||
}.catch { emit(42) } | ||
|
||
assertEquals(42, flow.single()) | ||
finish(4) | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.