Skip to content

Commit 09df7f6

Browse files
samukweku, samuel.oranyeli, ericmjl
authored
[ENH] Fix summarise for MultiIndex (#1460)
* updates to jn.summarise
* cleanup
* cleanup
* singledispatch for tuple
* minor fix for mutate
* remove default parameter in max
* remove default parameter in max
* remove default parameter in max

---------

Co-authored-by: samuel.oranyeli <[email protected]>
Co-authored-by: Eric Ma <[email protected]>
1 parent 2a2e9ba commit 09df7f6

File tree

3 files changed: 289 additions and 73 deletions.

janitor/functions/mutate.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,8 @@ def mutate(
177177
A pandas DataFrame or Series with aggregated columns.
178178
""" # noqa: E501
179179
check("copy", copy, [bool])
180+
if copy:
181+
df = df.copy(deep=None)
180182
if by is not None:
181183
if isinstance(by, DataFrameGroupBy):
182184
# it is assumed that by is created from df
@@ -188,8 +190,7 @@ def mutate(
188190
if is_scalar(by):
189191
by = [by]
190192
by = df.groupby(by, sort=False, observed=True)
191-
if copy:
192-
df = df.copy(deep=None)
193+
193194
for arg in args:
194195
df = _mutator(arg, df=df, by=by)
195196
return df
@@ -226,11 +227,9 @@ def _(arg, df, by):
226227
for column_name, mutator in arg.items():
227228
if isinstance(mutator, tuple):
228229
column, func = mutator
229-
column = _process_within_dict(mutator=func, obj=val[column])
230+
column = _apply_func_to_obj(mutator=func, obj=val[column])
230231
else:
231-
column = _process_within_dict(
232-
mutator=mutator, obj=val[column_name]
233-
)
232+
column = _apply_func_to_obj(mutator=mutator, obj=val[column_name])
234233
df[column_name] = column
235234
return df
236235

@@ -262,7 +261,7 @@ def _process_maybe_string(func: str, obj):
262261
return obj.transform(func)
263262

264263

265-
def _process_within_dict(mutator, obj):
264+
def _apply_func_to_obj(mutator, obj):
266265
"""Handle str/callables within a dictionary"""
267266
if isinstance(mutator, str):
268267
return _process_maybe_string(func=mutator, obj=obj)

janitor/functions/summarise.py

Lines changed: 141 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def summarise(
1919
df: pd.DataFrame,
2020
*args: tuple[dict | tuple],
2121
by: Any = None,
22-
) -> pd.DataFrame | pd.Series:
22+
) -> pd.DataFrame:
2323
"""
2424
2525
!!! info "New in version 0.31.0"
@@ -42,13 +42,15 @@ def summarise(
4242
- **dictionary argument**:
4343
If the argument is a dictionary,
4444
the value in the `{key:value}` pairing
45-
should be either a string, a callable or a tuple.
45+
should be either a string, a callable, or a tuple.
4646
4747
- If the value in the dictionary
4848
is a string or a callable,
4949
the key of the dictionary
5050
should be an existing column name.
5151
52+
The function is applied on the `df[column_name]` series.
53+
5254
!!!note
5355
5456
- If the value is a string,
@@ -57,28 +59,24 @@ def summarise(
5759
5860
- If the value of the dictionary is a tuple,
5961
it should be of length 2, and of the form
60-
`(column_name, mutation_func)`,
62+
`(column_name, aggfunc)`,
6163
where `column_name` should exist in the DataFrame,
62-
and `mutation_func` should be either a string or a callable.
63-
64-
!!!note
64+
and `aggfunc` should be either a string or a callable.
6565
66-
- If `mutation_func` is a string,
67-
the string should be a pandas string function,
68-
e.g "sum", "mean", etc.
66+
This option allows for custom renaming of the aggregation output,
67+
where the key in the dictionary can be a new column name.
6968
70-
The key in the dictionary can be a new column name.
7169
7270
- **tuple argument**:
7371
If the argument is a tuple, it should be of length 2,
7472
and of the form
75-
`(column_name, mutation_func)`,
73+
`(column_name, aggfunc)`,
7674
where column_name should exist in the DataFrame,
77-
and `mutation_func` should be either a string or a callable.
75+
and `aggfunc` should be either a string or a callable.
7876
7977
!!!note
8078
81-
- if `mutation_func` is a string,
79+
- if `aggfunc` is a string,
8280
the string should be a pandas string function,
8381
e.g "sum", "mean", etc.
8482
@@ -89,6 +87,7 @@ def summarise(
8987
as such multiple columns can be processed here -
9088
they will be processed individually.
9189
90+
9291
- **callable argument**:
9392
If the argument is a callable, the callable is applied
9493
on the DataFrame or GroupBy object.
@@ -169,7 +168,7 @@ def summarise(
169168
ValueError: If a tuple is passed and the length is not 2.
170169
171170
Returns:
172-
A pandas DataFrame or Series with aggregated columns.
171+
A pandas DataFrame with aggregated columns.
173172
174173
""" # noqa: E501
175174

@@ -184,18 +183,52 @@ def summarise(
184183
if is_scalar(by):
185184
by = [by]
186185
by = df.groupby(by, sort=False, observed=True)
187-
dictionary = {}
186+
contents = []
188187
for arg in args:
189-
aggregate = _mutator(arg, df=df, by=by)
190-
dictionary.update(aggregate)
191-
values = map(is_scalar, dictionary.values())
192-
if all(values):
193-
return pd.Series(dictionary)
194-
return pd.concat(dictionary, axis="columns", sort=False, copy=False)
188+
aggregate = _aggfunc(arg, df=df, by=by)
189+
contents.extend(aggregate)
190+
counts = 0
191+
for entry in contents:
192+
if isinstance(entry, pd.DataFrame):
193+
length = entry.columns.nlevels
194+
elif isinstance(entry.name, tuple):
195+
length = len(entry.name)
196+
else:
197+
length = 1
198+
counts = max(counts, length)
199+
contents_ = []
200+
for entry in contents:
201+
if isinstance(entry, pd.DataFrame):
202+
length_ = entry.columns.nlevels
203+
length = counts - length_
204+
if length:
205+
patch = [""] * length
206+
columns = [
207+
entry.columns.get_level_values(n) for n in range(length_)
208+
]
209+
columns.append(patch)
210+
names = [*entry.columns.names]
211+
names.extend([None] * length)
212+
columns = pd.MultiIndex.from_arrays(columns, names=names)
213+
entry.columns = columns
214+
elif is_scalar(entry.name):
215+
length = counts - 1
216+
if length:
217+
patch = [""] * length
218+
name = (entry.name, *patch)
219+
entry.name = name
220+
elif isinstance(entry.name, tuple):
221+
length = counts - len(entry.name)
222+
if length:
223+
patch = [""] * length
224+
name = (*entry.name, *patch)
225+
entry.name = name
226+
contents_.append(entry)
227+
return pd.concat(contents_, axis=1, copy=False, sort=False)
195228

196229

197230
@singledispatch
198-
def _mutator(arg, df, by):
231+
def _aggfunc(arg, df, by):
199232
if by is None:
200233
val = df
201234
else:
@@ -204,41 +237,68 @@ def _mutator(arg, df, by):
204237
if isinstance(outcome, pd.Series):
205238
if not outcome.name:
206239
raise ValueError("Ensure the pandas Series object has a name")
207-
return {outcome.name: outcome}
208-
# assumption: a mapping - DataFrame/dictionary/...
209-
return {**outcome}
240+
return [outcome]
241+
if isinstance(outcome, pd.DataFrame):
242+
return [outcome]
243+
raise TypeError(
244+
"The output from the aggregation should be a named Series or a DataFrame"
245+
)
210246

211247

212-
@_mutator.register(dict)
248+
@_aggfunc.register(tuple)
249+
def _(arg, df, by):
250+
"""Dispatch function for tuple"""
251+
if len(arg) != 2:
252+
raise ValueError("the tuple has to be a length of 2")
253+
column_name, aggfunc = arg
254+
column_names = get_index_labels(arg=[column_name], df=df, axis="columns")
255+
mapping = {column_name: aggfunc for column_name in column_names}
256+
return _aggfunc(mapping, df=df, by=by)
257+
258+
259+
@_aggfunc.register(dict)
213260
def _(arg, df, by):
214261
"""Dispatch function for dictionary"""
215262
if by is None:
216263
val = df
217264
else:
218265
val = by
219266

220-
dictionary = {}
221-
for column_name, mutator in arg.items():
222-
if isinstance(mutator, tuple):
223-
column, func = mutator
224-
column = _process_within_dict(mutator=func, obj=val[column])
267+
contents = []
268+
for column_name, aggfunc in arg.items():
269+
if isinstance(aggfunc, tuple):
270+
if len(aggfunc) != 2:
271+
raise ValueError("the tuple has to be a length of 2")
272+
column, func = aggfunc
273+
column_ = _handle_tuple_groupby_selection(by=by, column=column)
274+
column = _apply_func_to_obj(aggfunc=func, obj=val[column_])
275+
if isinstance(column, pd.DataFrame) and column.shape[-1] == 1:
276+
column = column.squeeze()
277+
column = _convert_obj_to_named_series(
278+
obj=column,
279+
column_name=column_name,
280+
function=func,
281+
)
282+
if not isinstance(column, pd.Series):
283+
raise TypeError(
284+
"Expected a pandas Series object; "
285+
f"instead got {type(column)}"
286+
)
225287
else:
226-
column = _process_within_dict(
227-
mutator=mutator, obj=val[column_name]
288+
column_ = _handle_tuple_groupby_selection(
289+
by=by, column=column_name
228290
)
229-
dictionary[column_name] = column
230-
return dictionary
231-
232-
233-
@_mutator.register(tuple)
234-
def _(arg, df, by):
235-
"""Dispatch function for tuple"""
236-
if len(arg) != 2:
237-
raise ValueError("the tuple has to be a length of 2")
238-
column_names, mutator = arg
239-
column_names = get_index_labels(arg=[column_names], df=df, axis="columns")
240-
mapping = {column_name: mutator for column_name in column_names}
241-
return _mutator(mapping, df=df, by=by)
291+
column = _apply_func_to_obj(aggfunc=aggfunc, obj=val[column_])
292+
column = _convert_obj_to_named_series(
293+
obj=column,
294+
column_name=column_name,
295+
function=aggfunc,
296+
)
297+
column = _rename_column_in_by(
298+
column=column, column_name=column_name, by=by
299+
)
300+
contents.append(column)
301+
return contents
242302

243303

244304
def _process_maybe_callable(func: callable, obj):
@@ -257,8 +317,39 @@ def _process_maybe_string(func: str, obj):
257317
return obj.agg(func)
258318

259319

260-
def _process_within_dict(mutator, obj):
320+
def _apply_func_to_obj(aggfunc, obj):
261321
"""Handle str/callables within a dictionary"""
262-
if isinstance(mutator, str):
263-
return _process_maybe_string(func=mutator, obj=obj)
264-
return _process_maybe_callable(func=mutator, obj=obj)
322+
if isinstance(aggfunc, str):
323+
return _process_maybe_string(func=aggfunc, obj=obj)
324+
return _process_maybe_callable(func=aggfunc, obj=obj)
325+
326+
327+
def _handle_tuple_groupby_selection(by: Any, column: Any):
328+
"""
329+
Properly handle a tuple column selection in the presence of a groupby
330+
"""
331+
if (by is not None) and isinstance(column, tuple):
332+
return [column]
333+
return column
334+
335+
336+
def _convert_obj_to_named_series(obj, function: Any, column_name: Any):
337+
if isinstance(obj, pd.Series):
338+
obj.name = column_name
339+
return obj
340+
if not is_scalar(obj):
341+
return obj
342+
if isinstance(function, str):
343+
function_name = function
344+
else:
345+
function_name = function.__name__
346+
return pd.Series(data=obj, index=[function_name], name=column_name)
347+
348+
349+
def _rename_column_in_by(column, column_name, by):
350+
if by is None:
351+
return column
352+
elif isinstance(column, pd.DataFrame) and is_scalar(column_name):
353+
columns = pd.MultiIndex.from_product([[column_name], column.columns])
354+
column.columns = columns
355+
return column

0 commit comments

Comments
 (0)