diff --git a/core/api/core.api b/core/api/core.api index f03122f4c3..ec524b474e 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -1340,6 +1340,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConcatKt { public static final fun concat (Lorg/jetbrains/kotlinx/dataframe/api/ReducedGroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun concatRows (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static final fun concatT (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun concatWithKeys (Lorg/jetbrains/kotlinx/dataframe/api/GroupBy;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; } public final class org/jetbrains/kotlinx/dataframe/api/ConstructorsKt { diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt index b475b50fdb..72af0f7014 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt @@ -3,9 +3,12 @@ package org.jetbrains.kotlinx.dataframe.api import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.DataRow +import org.jetbrains.kotlinx.dataframe.annotations.Interpretable +import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.values import org.jetbrains.kotlinx.dataframe.impl.api.concatImpl import org.jetbrains.kotlinx.dataframe.impl.asList +import org.jetbrains.kotlinx.dataframe.type // region DataColumn @@ -40,6 +43,64 @@ public fun DataFrame.concat(frames: Iterable>): DataFrame public fun GroupBy.concat(): DataFrame = groups.concat() +/** + * Concatenates all groups in this [GroupBy] into a single [DataFrame], + * preserving and including all grouping key columns that are not present in the group's 
columns. + * + * Key columns that have the same name as a column inside the groups are not added: the group's own column is kept as is (even if their content differs). + * + * This function is especially useful when grouping by expressions or renamed columns, + * and you want the resulting [DataFrame] to include those keys as part of the output. + * + * ### Example + * + * ```kotlin + * val df = dataFrameOf( + * "value" to listOf(1, 2, 3, 3), + * "type" to listOf("a", "b", "a", "b") + * ) + * + * val gb = df.groupBy { expr { "Category: \${type.uppercase()}" } named "category" } + * ``` + * + * A regular `concat()` will return a [DataFrame] similar to the original `df` + * (with the same columns and rows, but in a different order): + * + * ``` + * gb.concat() + * ``` + * | value | type | + * | :---- | :--- | + * | 1 | a | + * | 3 | a | + * | 2 | b | + * | 3 | b | + * + * But `concatWithKeys()` will include the new "category" key column: + * + * ``` + * gb.concatWithKeys() + * ``` + * | value | type | category | + * | :---- | :--- | :------------ | + * | 1 | a | Category: A | + * | 3 | a | Category: A | + * | 2 | b | Category: B | + * | 3 | b | Category: B | + * + * @return A new [DataFrame] where all groups are combined and additional key columns are included in each row. 
+ */ +@Refine +@Interpretable("ConcatWithKeys") +public fun GroupBy.concatWithKeys(): DataFrame = + mapToFrames { + val rowsCount = group.rowsCount() + val keyColumns = keys.columns().filter { it.name !in group.columnNames() }.map { keyColumn -> + DataColumn.createByType(keyColumn.name, List(rowsCount) { key[keyColumn] }, keyColumn.type) + } + group.addAll(keyColumns) + }.concat() + // endregion // region ReducedGroupBy diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt index c290e500fb..25472bfd36 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/concat.kt @@ -11,4 +11,16 @@ class ConcatTests { val b by columnOf(3.0, null) a.concat(b) shouldBe columnOf(1, 2, 3.0, null).named("a") } + + @Test + fun `concat with keys`() { + val df = dataFrameOf( + "value" to listOf(1, 2, 3, 3), + "type" to listOf("a", "b", "a", "b"), + ) + val gb = df.groupBy { expr { "Category: ${(this["type"] as String).uppercase()}" } named "category" } + val dfWithCategory = gb.concatWithKeys() + + dfWithCategory.columnNames() shouldBe listOf("value", "type", "category") + } } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt index a0a8c29518..d90556b3c5 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/groupBy.kt @@ -36,6 +36,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf import org.jetbrains.kotlinx.dataframe.plugin.impl.type import org.jetbrains.kotlinx.dataframe.plugin.interpret import org.jetbrains.kotlinx.dataframe.plugin.loadInterpreter +import kotlin.collections.plus class GroupBy(val keys: 
PluginDataFrameSchema, val groups: PluginDataFrameSchema) { companion object { @@ -420,6 +421,12 @@ private fun isIntraComparable(col: SimpleDataColumn, session: FirSession): Boole return col.type.type.isSubtypeOf(comparable, session) } +class ConcatWithKeys : AbstractSchemaModificationInterpreter() { + val Arguments.receiver by groupBy() - - + override fun Arguments.interpret(): PluginDataFrameSchema { + val originalColumns = receiver.groups.columns() + return PluginDataFrameSchema( + originalColumns + receiver.keys.columns().filter { it.name !in originalColumns.map { it.name } }) + } +} diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index f1f5a7f10c..d97a427688 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -99,6 +99,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf2 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnRange +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ConcatWithKeys import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3 @@ -208,6 +209,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.WithoutNulls1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.WithoutNulls2 import org.jetbrains.kotlinx.dataframe.plugin.utils.Names + internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? 
{ val interpreter = Stdlib.interpreter(this) if (interpreter != null) return interpreter @@ -459,6 +461,7 @@ internal inline fun String.load(): T { "GroupByStdOf" -> GroupByStdOf() "DataFrameXs" -> DataFrameXs() "GroupByXs" -> GroupByXs() + "ConcatWithKeys" -> ConcatWithKeys() else -> error("$this") } as T } diff --git a/plugins/kotlin-dataframe/testData/box/concatWithKeys.kt b/plugins/kotlin-dataframe/testData/box/concatWithKeys.kt new file mode 100644 index 0000000000..fb7b31a1cc --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/concatWithKeys.kt @@ -0,0 +1,19 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + val df = dataFrameOf( + "value" to listOf(1, 2, 3, 3), + "type" to listOf("a", "b", "a", "b") + ) + val gb = df.groupBy { expr { "Category: ${type.uppercase()}" } named "category" } + val categoryKey = gb.keys.category + + val dfWithCategory = gb.concatWithKeys() + + val category: DataColumn = dfWithCategory.category + + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index ec31faff13..fb98999b3c 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -76,6 +76,12 @@ public void testColumnWithStarProjection() { runTest("testData/box/columnWithStarProjection.kt"); } + @Test + @TestMetadata("concatWithKeys.kt") + public void testConcatWithKeys() { + runTest("testData/box/concatWithKeys.kt"); + } + @Test @TestMetadata("conflictingJvmDeclarations.kt") public void testConflictingJvmDeclarations() {