@@ -19,7 +19,7 @@ def summarise(
1919 df : pd .DataFrame ,
2020 * args : tuple [dict | tuple ],
2121 by : Any = None ,
22- ) -> pd .DataFrame | pd . Series :
22+ ) -> pd .DataFrame :
2323 """
2424
2525 !!! info "New in version 0.31.0"
@@ -42,13 +42,15 @@ def summarise(
4242 - **dictionary argument**:
4343 If the argument is a dictionary,
4444 the value in the `{key:value}` pairing
45- should be either a string, a callable or a tuple.
45+ should be either a string, a callable, or a tuple.
4646
4747 - If the value in the dictionary
4848 is a string or a callable,
4949 the key of the dictionary
5050 should be an existing column name.
5151
52+ The function is applied on the `df[column_name]` series.
53+
5254 !!!note
5355
5456 - If the value is a string,
@@ -57,28 +59,24 @@ def summarise(
5759
5860 - If the value of the dictionary is a tuple,
5961 it should be of length 2, and of the form
60- `(column_name, mutation_func )`,
62+ `(column_name, aggfunc )`,
6163 where `column_name` should exist in the DataFrame,
62- and `mutation_func` should be either a string or a callable.
63-
64- !!!note
64+ and `aggfunc` should be either a string or a callable.
6565
66- - If `mutation_func` is a string,
67- the string should be a pandas string function,
68- e.g "sum", "mean", etc.
66+ This option allows for custom renaming of the aggregation output,
67+ where the key in the dictionary can be a new column name.
6968
70- The key in the dictionary can be a new column name.
7169
7270 - **tuple argument**:
7371 If the argument is a tuple, it should be of length 2,
7472 and of the form
75- `(column_name, mutation_func )`,
73+ `(column_name, aggfunc )`,
7674 where column_name should exist in the DataFrame,
77- and `mutation_func ` should be either a string or a callable.
75+ and `aggfunc ` should be either a string or a callable.
7876
7977 !!!note
8078
81- - if `mutation_func ` is a string,
79+ - if `aggfunc ` is a string,
8280 the string should be a pandas string function,
8381 e.g "sum", "mean", etc.
8482
@@ -89,6 +87,7 @@ def summarise(
8987 as such multiple columns can be processed here -
9088 they will be processed individually.
9189
90+
9291 - **callable argument**:
9392 If the argument is a callable, the callable is applied
9493 on the DataFrame or GroupBy object.
@@ -169,7 +168,7 @@ def summarise(
169168 ValueError: If a tuple is passed and the length is not 2.
170169
171170 Returns:
172- A pandas DataFrame or Series with aggregated columns.
171+ A pandas DataFrame with aggregated columns.
173172
174173 """ # noqa: E501
175174
@@ -184,18 +183,52 @@ def summarise(
184183 if is_scalar (by ):
185184 by = [by ]
186185 by = df .groupby (by , sort = False , observed = True )
187- dictionary = {}
186+ contents = []
188187 for arg in args :
189- aggregate = _mutator (arg , df = df , by = by )
190- dictionary .update (aggregate )
191- values = map (is_scalar , dictionary .values ())
192- if all (values ):
193- return pd .Series (dictionary )
194- return pd .concat (dictionary , axis = "columns" , sort = False , copy = False )
188+ aggregate = _aggfunc (arg , df = df , by = by )
189+ contents .extend (aggregate )
190+ counts = 0
191+ for entry in contents :
192+ if isinstance (entry , pd .DataFrame ):
193+ length = entry .columns .nlevels
194+ elif isinstance (entry .name , tuple ):
195+ length = len (entry .name )
196+ else :
197+ length = 1
198+ counts = max (counts , length )
199+ contents_ = []
200+ for entry in contents :
201+ if isinstance (entry , pd .DataFrame ):
202+ length_ = entry .columns .nlevels
203+ length = counts - length_
204+ if length :
205+ patch = ["" ] * length
206+ columns = [
207+ entry .columns .get_level_values (n ) for n in range (length_ )
208+ ]
209+ columns .append (patch )
210+ names = [* entry .columns .names ]
211+ names .extend ([None ] * length )
212+ columns = pd .MultiIndex .from_arrays (columns , names = names )
213+ entry .columns = columns
214+ elif is_scalar (entry .name ):
215+ length = counts - 1
216+ if length :
217+ patch = ["" ] * length
218+ name = (entry .name , * patch )
219+ entry .name = name
220+ elif isinstance (entry .name , tuple ):
221+ length = counts - len (entry .name )
222+ if length :
223+ patch = ["" ] * length
224+ name = (* entry .name , * patch )
225+ entry .name = name
226+ contents_ .append (entry )
227+ return pd .concat (contents_ , axis = 1 , copy = False , sort = False )
195228
196229
197230@singledispatch
198- def _mutator (arg , df , by ):
231+ def _aggfunc (arg , df , by ):
199232 if by is None :
200233 val = df
201234 else :
@@ -204,41 +237,68 @@ def _mutator(arg, df, by):
204237 if isinstance (outcome , pd .Series ):
205238 if not outcome .name :
206239 raise ValueError ("Ensure the pandas Series object has a name" )
207- return {outcome .name : outcome }
208- # assumption: a mapping - DataFrame/dictionary/...
209- return {** outcome }
240+ return [outcome ]
241+ if isinstance (outcome , pd .DataFrame ):
242+ return [outcome ]
243+ raise TypeError (
244+ "The output from the aggregation should be a named Series or a DataFrame"
245+ )
210246
211247
212- @_mutator .register (dict )
@_aggfunc.register(tuple)
def _(arg, df, by):
    """Dispatch function for a ``(column_name, aggfunc)`` tuple.

    The tuple is expanded into a ``{column: aggfunc}`` mapping and handed
    back to ``_aggfunc``, so the dictionary dispatch does the actual work.

    Args:
        arg: A tuple of length 2, of the form ``(column_name, aggfunc)``.
        df: The DataFrame being summarised.
        by: The GroupBy object built from ``df``, or ``None``.

    Returns:
        Whatever ``_aggfunc`` returns for the expanded mapping.

    Raises:
        ValueError: If the tuple is not of length 2.
    """
    if len(arg) != 2:
        raise ValueError("the tuple has to be a length of 2")
    column_selector, aggfunc = arg
    # The selector is resolved against the columns of ``df``; every label
    # that comes back is paired with the same aggregation.
    column_names = get_index_labels(
        arg=[column_selector], df=df, axis="columns"
    )
    mapping = dict.fromkeys(column_names, aggfunc)
    return _aggfunc(mapping, df=df, by=by)
257+
258+
@_aggfunc.register(dict)
def _(arg, df, by):
    """Dispatch function for dictionary.

    Each ``{key: value}`` pair produces one aggregated pandas object:

    - If ``value`` is a ``(column, func)`` tuple, ``func`` is applied to
      ``column`` and the result is named ``key``. The result must end up
      as a pandas Series.
    - Otherwise ``value`` is the aggregation itself (a string or a
      callable) and ``key`` is the column it is applied to.

    Args:
        arg: Mapping of output/column name to aggregation spec.
        df: The DataFrame being summarised.
        by: The GroupBy object built from ``df``, or ``None``.

    Returns:
        A list with one aggregated pandas object per dictionary entry,
        in the order of the dictionary.

    Raises:
        ValueError: If a tuple value is not of length 2.
        TypeError: If a tuple value does not aggregate to a Series.
    """
    # Aggregate on the GroupBy object when grouping was requested,
    # otherwise directly on the DataFrame.
    val = df if by is None else by

    contents = []
    for column_name, aggfunc in arg.items():
        if isinstance(aggfunc, tuple):
            if len(aggfunc) != 2:
                raise ValueError("the tuple has to be a length of 2")
            column, func = aggfunc
            column_ = _handle_tuple_groupby_selection(by=by, column=column)
            column = _apply_func_to_obj(aggfunc=func, obj=val[column_])
            if isinstance(column, pd.DataFrame) and column.shape[-1] == 1:
                # Squeeze the column axis only: a bare ``squeeze()`` would
                # also collapse a single-row frame down to a scalar and
                # discard its row (group) index.
                column = column.squeeze(axis="columns")
            column = _convert_obj_to_named_series(
                obj=column,
                column_name=column_name,
                function=func,
            )
            if not isinstance(column, pd.Series):
                raise TypeError(
                    "Expected a pandas Series object; "
                    f"instead got {type(column)}"
                )
        else:
            column_ = _handle_tuple_groupby_selection(
                by=by, column=column_name
            )
            column = _apply_func_to_obj(aggfunc=aggfunc, obj=val[column_])
            column = _convert_obj_to_named_series(
                obj=column,
                column_name=column_name,
                function=aggfunc,
            )
            column = _rename_column_in_by(
                column=column, column_name=column_name, by=by
            )
        contents.append(column)
    return contents
242302
243303
244304def _process_maybe_callable (func : callable , obj ):
@@ -257,8 +317,39 @@ def _process_maybe_string(func: str, obj):
257317 return obj .agg (func )
258318
259319
260- def _process_within_dict ( mutator , obj ):
def _apply_func_to_obj(aggfunc, obj):
    """Handle str/callables within a dictionary.

    Args:
        aggfunc: The aggregation, either a string or a callable.
        obj: The pandas object the aggregation is applied to.

    Returns:
        The result of ``_process_maybe_string`` when ``aggfunc`` is a
        string, otherwise the result of ``_process_maybe_callable``.
    """
    if isinstance(aggfunc, str):
        return _process_maybe_string(func=aggfunc, obj=obj)
    return _process_maybe_callable(func=aggfunc, obj=obj)
325+
326+
def _handle_tuple_groupby_selection(by: Any, column: Any) -> Any:
    """Properly handle a tuple column selection in the presence of a groupby.

    Args:
        by: The GroupBy object, or ``None`` when no grouping is active.
        column: The column label about to be used as a selection key.

    Returns:
        ``[column]`` when a groupby is active and ``column`` is a tuple;
        otherwise ``column`` unchanged.
    """
    # NOTE(review): presumably the list wrapper makes the GroupBy selection
    # treat the tuple as one column label rather than several keys - confirm.
    if by is not None and isinstance(column, tuple):
        return [column]
    return column
334+
335+
def _convert_obj_to_named_series(obj, function: Any, column_name: Any):
    """Turn an aggregation result into a Series named ``column_name``.

    Args:
        obj: The aggregation result.
        function: The aggregation that produced ``obj``, a string or a
            callable. Only used to label a scalar result.
        column_name: The name to give the resulting Series.

    Returns:
        - ``obj`` itself, renamed in place, if it is already a Series.
        - ``obj`` unchanged if it is neither a Series nor a scalar.
        - Otherwise a one-element Series holding the scalar, indexed by
          the name of the aggregation.
    """
    if isinstance(obj, pd.Series):
        obj.name = column_name
        return obj
    if not is_scalar(obj):
        return obj
    if isinstance(function, str):
        function_name = function
    else:
        # Callables such as ``functools.partial`` objects or callable
        # instances have no ``__name__``; fall back to the type name
        # instead of raising AttributeError.
        function_name = getattr(
            function, "__name__", type(function).__name__
        )
    return pd.Series(data=obj, index=[function_name], name=column_name)
347+
348+
def _rename_column_in_by(column, column_name, by):
    """Prefix a grouped DataFrame result with the column it came from.

    Args:
        column: The aggregation result.
        column_name: The label the aggregation was applied to.
        by: The GroupBy object, or ``None`` when no grouping is active.

    Returns:
        ``column``. When a groupby is active, ``column`` is a DataFrame and
        ``column_name`` is a scalar, its columns are replaced in place by a
        MultiIndex whose first level is ``column_name`` and whose second
        level is the original columns. In every other case it is returned
        untouched.
    """
    if by is None:
        return column
    if isinstance(column, pd.DataFrame) and is_scalar(column_name):
        column.columns = pd.MultiIndex.from_product(
            [[column_name], column.columns]
        )
    return column
0 commit comments