Skip to content

Commit 5068a6d

Browse files
committed
Made sure that aggregating an empty grouped dataframe still generates the aggregation columns, and added a test for it. Issue #1531
1 parent 7c189f0 commit 5068a6d

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/GroupByImpl.kt

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import org.jetbrains.kotlinx.dataframe.aggregation.AggregateGroupedBody
88
import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue
99
import org.jetbrains.kotlinx.dataframe.api.GroupBy
1010
import org.jetbrains.kotlinx.dataframe.api.GroupedRowFilter
11+
import org.jetbrains.kotlinx.dataframe.api.asFrameColumn
1112
import org.jetbrains.kotlinx.dataframe.api.asGroupBy
13+
import org.jetbrains.kotlinx.dataframe.api.cast
1214
import org.jetbrains.kotlinx.dataframe.api.concat
1315
import org.jetbrains.kotlinx.dataframe.api.convert
1416
import org.jetbrains.kotlinx.dataframe.api.getColumn
@@ -18,6 +20,7 @@ import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
1820
import org.jetbrains.kotlinx.dataframe.api.pathOf
1921
import org.jetbrains.kotlinx.dataframe.api.remove
2022
import org.jetbrains.kotlinx.dataframe.api.rename
23+
import org.jetbrains.kotlinx.dataframe.api.take
2124
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
2225
import org.jetbrains.kotlinx.dataframe.impl.aggregation.AggregatableInternal
2326
import org.jetbrains.kotlinx.dataframe.impl.aggregation.GroupByReceiverImpl
@@ -27,8 +30,10 @@ import org.jetbrains.kotlinx.dataframe.impl.api.GroupedDataRowImpl
2730
import org.jetbrains.kotlinx.dataframe.impl.api.insertImpl
2831
import org.jetbrains.kotlinx.dataframe.impl.api.removeImpl
2932
import org.jetbrains.kotlinx.dataframe.impl.columns.toColumnSet
33+
import org.jetbrains.kotlinx.dataframe.impl.schema.createEmptyDataFrame
3034
import org.jetbrains.kotlinx.dataframe.ncol
3135
import org.jetbrains.kotlinx.dataframe.nrow
36+
import org.jetbrains.kotlinx.dataframe.size
3237
import org.jetbrains.kotlinx.dataframe.values
3338

3439
/**
@@ -74,14 +79,25 @@ internal fun <T, G, R> aggregateGroupBy(
7479
body: AggregateGroupedBody<G, R>,
7580
): DataFrame<T> {
7681
val defaultAggregateName = "aggregated"
77-
82+
val groupedDfIsEmpty = df.size().nrow == 0
7883
val column = df.getColumn(selector)
79-
8084
val removed = df.removeImpl(columns = selector)
81-
8285
val hasKeyColumns = removed.df.ncol > 0
8386

84-
val groupedFrame = column.values.map {
87+
val groups =
88+
if (groupedDfIsEmpty) {
89+
// if the grouped dataframe is empty, make sure the provided AggregateGroupedBody is called at least once
90+
// to create aggregated columns. We empty them below.
91+
listOf(
92+
column.asFrameColumn().schema.value
93+
.createEmptyDataFrame()
94+
.cast(),
95+
)
96+
} else {
97+
column.values
98+
}
99+
100+
val groupedFrame = groups.map {
85101
if (it == null) {
86102
null
87103
} else {
@@ -101,6 +117,11 @@ internal fun <T, G, R> aggregateGroupBy(
101117
builder.compute()
102118
}
103119
}.concat()
120+
.let {
121+
// empty the aggregated columns that were created by calling the provided AggregateGroupedBody once
122+
// if the grouped dataframe is empty
123+
if (groupedDfIsEmpty) it.take(0) else it
124+
}
104125

105126
val removedNode = removed.removedColumns.single()
106127
val insertPath = removedNode.pathFromRoot().dropLast(1)

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/testSets/person/DataFrameTests.kt

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ import org.jetbrains.kotlinx.dataframe.api.fill
6262
import org.jetbrains.kotlinx.dataframe.api.fillNulls
6363
import org.jetbrains.kotlinx.dataframe.api.filter
6464
import org.jetbrains.kotlinx.dataframe.api.first
65+
import org.jetbrains.kotlinx.dataframe.api.firstOrNull
6566
import org.jetbrains.kotlinx.dataframe.api.forEach
6667
import org.jetbrains.kotlinx.dataframe.api.forEachIndexed
6768
import org.jetbrains.kotlinx.dataframe.api.frameColumn
@@ -94,14 +95,17 @@ import org.jetbrains.kotlinx.dataframe.api.match
9495
import org.jetbrains.kotlinx.dataframe.api.matches
9596
import org.jetbrains.kotlinx.dataframe.api.max
9697
import org.jetbrains.kotlinx.dataframe.api.maxBy
98+
import org.jetbrains.kotlinx.dataframe.api.maxByOrNull
9799
import org.jetbrains.kotlinx.dataframe.api.mean
98100
import org.jetbrains.kotlinx.dataframe.api.meanFor
99101
import org.jetbrains.kotlinx.dataframe.api.meanOf
100102
import org.jetbrains.kotlinx.dataframe.api.median
103+
import org.jetbrains.kotlinx.dataframe.api.medianOrNull
101104
import org.jetbrains.kotlinx.dataframe.api.merge
102105
import org.jetbrains.kotlinx.dataframe.api.min
103106
import org.jetbrains.kotlinx.dataframe.api.minBy
104107
import org.jetbrains.kotlinx.dataframe.api.minOf
108+
import org.jetbrains.kotlinx.dataframe.api.minOrNull
105109
import org.jetbrains.kotlinx.dataframe.api.minus
106110
import org.jetbrains.kotlinx.dataframe.api.move
107111
import org.jetbrains.kotlinx.dataframe.api.moveTo
@@ -710,6 +714,37 @@ class DataFrameTests : BaseTest() {
710714
res.size() shouldBe 2
711715
}
712716

717+
// Issue #1531
// Aggregating an empty grouped DataFrame must still create the aggregation columns,
// while the resulting DataFrame itself stays empty.
@Test
fun `groupBy empty df should generate empty aggregation cols`() {
    val empty = typed.take(0)
    val resDf = empty.groupBy { name }.aggregate {
        count() into "n"
        count { age > 25 } into "old count"
        medianOrNull { age } into "median age"
        minOrNull { age } into "min age"
        all { weight != null } into "all with weights"
        maxByOrNull { age }?.city into "oldest origin"
        sortBy { age }.firstOrNull()?.city into "youngest origin"
        // No pivot columns appear in the expected list below; presumably an empty
        // frame has no `city` values to pivot on.
        pivot { city.map { "from $it" } }.count()
        age.toList() into "ages"
    }

    resDf.columnNames() shouldBe listOf(
        "name",
        "n",
        "old count",
        "median age",
        "min age",
        "all with weights",
        "oldest origin",
        "youngest origin",
        "ages",
    )
    // The aggregation columns must exist but hold no rows: the input had no groups.
    resDf.rowsCount() shouldBe 0

    resDf.alsoDebug()
}
747+
713748
@Test
714749
fun `groupBy`() {
715750
fun AnyFrame.check() {

0 commit comments

Comments
 (0)