Skip to content

concatWithKeys impl #1107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConcatKt {
public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/ReducedGroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun concatRows (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun concatT (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun concatWithKeys (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
}

public final class org/jetbrains/kotlinx/dataframe/api/ConstructorsKt {
Expand Down
61 changes: 61 additions & 0 deletions core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@ package org.jetbrains.kotlinx.dataframe.api
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.values
import org.jetbrains.kotlinx.dataframe.impl.api.concatImpl
import org.jetbrains.kotlinx.dataframe.impl.asList
import org.jetbrains.kotlinx.dataframe.type

// region DataColumn

Expand Down Expand Up @@ -40,6 +43,64 @@ public fun <T> DataFrame<T>.concat(frames: Iterable<DataFrame<T>>): DataFrame<T>

public fun <T, G> GroupBy<T, G>.concat(): DataFrame<G> = groups.concat()

/**
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice and helpful kdoc :) Maybe something similar could be put on the documentation website?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sure, we will put it on the website!

* Concatenates all groups in this [GroupBy] into a single [DataFrame],
* preserving and including all grouping key columns that are not present in the group's columns.
*
* Key columns that share a name with a column already present in the groups are not added:
* the group's own column is kept as is (even if its content differs from the key column).
*
* This function is especially useful when grouping by expressions or renamed columns,
* and you want the resulting [DataFrame] to include those keys as part of the output.
*
* ### Example
*
* ```kotlin
* val df = dataFrameOf(
* "value" to listOf(1, 2, 3, 3),
* "type" to listOf("a", "b", "a", "b")
* )
*
* val gb = df.groupBy { expr { "Category: \${type.uppercase()}" } named "category" }
* ```
*
* A regular `concat()` will return a [DataFrame] similar to the original `df`
* (with the same columns and rows, but possibly in a different row order):
*
* ```
* gb.concat()
* ```
* | value | type |
* | :---- | :--- |
* | 1 | a |
* | 3 | a |
* | 2 | b |
* | 3 | b |
*
* But `concatWithKeys()` will include the new "category" key column:
*
* ```
* gb.concatWithKeys()
* ```
* | value | type | category |
* | :---- | :--- | :------------ |
* | 1 | a | Category: A |
* | 3 | a | Category: A |
* | 2 | b | Category: B |
* | 3 | b | Category: B |
*
* @return A new [DataFrame] where all groups are combined and additional key columns are included in each row.
*/
@Refine
@Interpretable("ConcatWithKeys")
public fun <T, G> GroupBy<T, G>.concatWithKeys(): DataFrame<G> =
    mapToFrames {
        val rowsCount = group.rowsCount()
        // Resolve the group's column names once, as a set, instead of
        // re-querying the group for every key column being filtered.
        val groupColumnNames = group.columnNames().toSet()
        val keyColumns = keys.columns()
            // Only keys whose name is not already taken by a group column are added.
            .filter { it.name !in groupColumnNames }
            .map { keyColumn ->
                // The key value is the same for every row of this group:
                // read it once and repeat it `rowsCount` times.
                val keyValue = key[keyColumn]
                DataColumn.createByType(keyColumn.name, List(rowsCount) { keyValue }, keyColumn.type)
            }
        group.addAll(keyColumns)
    }.concat()

// endregion

// region ReducedGroupBy
Expand Down
12 changes: 12 additions & 0 deletions core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,16 @@ class ConcatTests {
val b by columnOf(3.0, null)
a.concat(b) shouldBe columnOf(1, 2, 3.0, null).named("a")
}

@Test
fun `concat with keys`() {
    val df = dataFrameOf(
        "value" to listOf(1, 2, 3, 3),
        "type" to listOf("a", "b", "a", "b"),
    )
    val gb = df.groupBy { expr { "Category: ${(this["type"] as String).uppercase()}" } named "category" }
    val dfWithCategory = gb.concatWithKeys()

    // The computed key column is appended after the group's own columns.
    dfWithCategory.columnNames() shouldBe listOf("value", "type", "category")

    // Rows come out group by group, and every row carries the key of its own group.
    dfWithCategory.rowsCount() shouldBe 4
    dfWithCategory["value"].toList() shouldBe listOf(1, 3, 2, 3)
    dfWithCategory["type"].toList() shouldBe listOf("a", "a", "b", "b")
    dfWithCategory["category"].toList() shouldBe
        listOf("Category: A", "Category: A", "Category: B", "Category: B")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import org.jetbrains.kotlinx.dataframe.plugin.interpret
import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter
import kotlin.collections.plus

class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema) {
companion object {
Expand Down Expand Up @@ -420,6 +421,12 @@ private fun isIntraComparable(col: SimpleDataColumn, session: FirSession): Boole
return col.type.type.isSubtypeOf(comparable, session)
}

/**
 * Schema interpreter loaded under the name "ConcatWithKeys".
 *
 * The resulting schema consists of all columns of the groups schema, followed by
 * those key columns whose names do not collide with any group column.
 */
class ConcatWithKeys : AbstractSchemaModificationInterpreter() {
    val Arguments.receiver by groupBy()

    override fun Arguments.interpret(): PluginDataFrameSchema {
        val groupColumns = receiver.groups.columns()
        // Collect the group column names once; the original rebuilt this list for every key column.
        val groupColumnNames = groupColumns.map { it.name }.toSet()
        val additionalKeys = receiver.keys.columns().filter { it.name !in groupColumnNames }
        return PluginDataFrameSchema(groupColumns + additionalKeys)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf2
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnRange
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ConcatWithKeys
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
Expand Down Expand Up @@ -208,6 +209,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.WithoutNulls1
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.WithoutNulls2
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names


internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
val interpreter = Stdlib.interpreter(this)
if (interpreter != null) return interpreter
Expand Down Expand Up @@ -459,6 +461,7 @@ internal inline fun <reified T> String.load(): T {
"GroupByStdOf" -> GroupByStdOf()
"DataFrameXs" -> DataFrameXs()
"GroupByXs" -> GroupByXs()
"ConcatWithKeys" -> ConcatWithKeys()
else -> error("$this")
} as T
}
19 changes: 19 additions & 0 deletions plugins/kotlin-dataframe/testData/box/concatWithKeys.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

// Box test for `concatWithKeys`: the body mainly has to compile, since the
// property accesses below rely on `category` being resolvable on the typed schemas.
fun box(): String {
val df = dataFrameOf(
"value" to listOf(1, 2, 3, 3),
"type" to listOf("a", "b", "a", "b")
)
// Group by a computed string key that is exposed under the name "category".
val gb = df.groupBy { expr { "Category: ${type.uppercase()}" } named "category" }
// `category` must be accessible on the grouping keys.
val categoryKey = gb.keys.category

val dfWithCategory = gb.concatWithKeys()

// The explicit type annotation requires `category` on the concatenated result
// to be a DataColumn<String>.
val category: DataColumn<String> = dfWithCategory.category

return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ public void testColumnWithStarProjection() {
runTest("testData/box/columnWithStarProjection.kt");
}

@Test
@TestMetadata("concatWithKeys.kt")
public void testConcatWithKeys() {
    // Runs the box test data file for concatWithKeys.
    final String testDataPath = "testData/box/concatWithKeys.kt";
    runTest(testDataPath);
}

@Test
@TestMetadata("conflictingJvmDeclarations.kt")
public void testConflictingJvmDeclarations() {
Expand Down