@@ -8,7 +8,9 @@ import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedBody
88import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue
99import org.jetbrains.kotlinx.dataframe.api.GroupBy
1010import org.jetbrains.kotlinx.dataframe.api.GroupedRowFilter
11+ import org.jetbrains.kotlinx.dataframe.api.asFrameColumn
1112import org.jetbrains.kotlinx.dataframe.api.asGroupBy
13+ import org.jetbrains.kotlinx.dataframe.api.cast
1214import org.jetbrains.kotlinx.dataframe.api.concat
1315import org.jetbrains.kotlinx.dataframe.api.convert
1416import org.jetbrains.kotlinx.dataframe.api.getColumn
@@ -18,6 +20,7 @@ import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
1820import org.jetbrains.kotlinx.dataframe.api.pathOf
1921import org.jetbrains.kotlinx.dataframe.api.remove
2022import org.jetbrains.kotlinx.dataframe.api.rename
23+ import org.jetbrains.kotlinx.dataframe.api.take
2124import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
2225import org.jetbrains.kotlinx.dataframe.impl.aggregation.AggregatableInternal
2326import org.jetbrains.kotlinx.dataframe.impl.aggregation.GroupByReceiverImpl
@@ -27,8 +30,10 @@ import org.jetbrains.kotlinx.dataframe.impl.api.GroupedDataRowImpl
2730import org.jetbrains.kotlinx.dataframe.impl.api.insertImpl
2831import org.jetbrains.kotlinx.dataframe.impl.api.removeImpl
2932import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet
33+ import org.jetbrains.kotlinx.dataframe.impl.schema.createEmptyDataFrame
3034import org.jetbrains.kotlinx.dataframe.ncol
3135import org.jetbrains.kotlinx.dataframe.nrow
36+ import org.jetbrains.kotlinx.dataframe.size
3237import org.jetbrains.kotlinx.dataframe.values
3338
3439/* *
@@ -74,14 +79,25 @@ internal fun <T, G, R> aggregateGroupBy(
7479 body : AggregateGroupedBody <G , R >,
7580): DataFrame <T > {
7681 val defaultAggregateName = " aggregated"
77-
82+ val groupedDfIsEmpty = df.size().nrow == 0
7883 val column = df.getColumn(selector)
79-
8084 val removed = df.removeImpl(columns = selector)
81-
8285 val hasKeyColumns = removed.df.ncol > 0
8386
84- val groupedFrame = column.values.map {
87+ val groups =
88+ if (groupedDfIsEmpty) {
89+ // if the grouped dataframe is empty, make sure the provided AggregateGroupedBody is called at least once
90+ // to create aggregated columns. We empty them below.
91+ listOf (
92+ column.asFrameColumn().schema.value
93+ .createEmptyDataFrame()
94+ .cast(),
95+ )
96+ } else {
97+ column.values
98+ }
99+
100+ val groupedFrame = groups.map {
85101 if (it == null ) {
86102 null
87103 } else {
@@ -101,6 +117,11 @@ internal fun <T, G, R> aggregateGroupBy(
101117 builder.compute()
102118 }
103119 }.concat()
120+ .let {
121+ // empty the aggregated columns that were created by calling the provided AggregateGroupedBody once
122+ // if the grouped dataframe is empty
123+ if (groupedDfIsEmpty) it.take(0 ) else it
124+ }
104125
105126 val removedNode = removed.removedColumns.single()
106127 val insertPath = removedNode.pathFromRoot().dropLast(1 )
0 commit comments