diff --git a/thicket/stats/display_boxplot.py b/thicket/stats/display_boxplot.py index ba9bd139..c928e986 100644 --- a/thicket/stats/display_boxplot.py +++ b/thicket/stats/display_boxplot.py @@ -11,7 +11,7 @@ from ..utils import verify_thicket_structures -def display_boxplot(thicket, nodes=None, columns=None, **kwargs): +def display_boxplot(thicket, nodes=None, columns=None, column_mapping=None, legend_title="Performance counter", return_mpl=True, **kwargs): """Display a boxplot for each user passed node(s) and column(s). The passed nodes and columns must be from the performance data table. @@ -24,9 +24,14 @@ def display_boxplot(thicket, nodes=None, columns=None, **kwargs): column (list): List of hardware/timing metrics to view on the y-axis. Note, if using a columnar joined thicket a list of tuples must be passed in with the format (column index, column name). + column_mapping (dict): Dict mapping the names in 'columns' to the desired names + on the plot + legend_title (str): Title to use for the legend of the boxplot + return_mpl (bool): If True, return a matplotlib Axes object. Otherwise, return + a seaborn FacetGrid object Returns: - (matplotlib Axes): Object for managing boxplot. + (matplotlib.pyplot.Axes or seaborn.FacetGrid): Object for managing boxplot. """ if columns is None or nodes is None: raise ValueError( @@ -50,11 +55,18 @@ def display_boxplot(thicket, nodes=None, columns=None, **kwargs): # thicket object without columnar index if thicket.dataframe.columns.nlevels == 1: + df = thicket.dataframe.reset_index() + mapped_columns = columns + if column_mapping is not None: + if not isinstance(column_mapping, dict): + raise TypeError("'column_mapping' must be a dict") + df.rename(columns=column_mapping, inplace=True) + mapped_columns = [column_mapping[c] for c in columns] df = pd.melt( - thicket.dataframe.reset_index(), + df, id_vars=["node", "name"], - value_vars=columns, - var_name="Performance counter", + value_vars=mapped_columns, + var_name=legend_title, value_name=" ", ) @@ -69,13 +81,6 @@ def display_boxplot(thicket, nodes=None, columns=None, **kwargs): filtered_df = df.loc[position].rename( columns={"node": "hatchet node", "name": "node"} ) - - if len(columns) > 1: - return sns.boxplot( - data=filtered_df, x="node", y=" ", hue="Performance counter", **kwargs - ) - else: - return sns.boxplot(data=filtered_df, x="node", y=" ", **kwargs) # columnar joined thicket object else: @@ -85,16 +90,22 @@ def column_name_mapper(current_cols): return str(current_cols) - cols = [str(c) for c in columns] + mapped_columns = [str(c) for c in columns] df_subset = thicket.dataframe[[("name", ""), *columns]].reset_index() df_subset.columns = df_subset.columns.to_flat_index().map(column_name_mapper) df_subset["name"] = thicket.dataframe["name"].tolist() + + if column_mapping is not None: + if not isinstance(column_mapping, dict): + raise TypeError("'column_mapping' must be a dict") + df_subset.rename(columns={str(k): str(v) for k, v, in column_mapping.items()}, inplace=True) + mapped_columns = [str(column_mapping[c]) for c in columns] df = pd.melt( df_subset, id_vars=["node", "name"], - value_vars=cols, - var_name="Performance counter", + value_vars=mapped_columns, + var_name=legend_title, value_name=" ", ) @@ -110,9 +121,15 @@ def column_name_mapper(current_cols): columns={"node": "hatchet node", "name": "node"} ) - if len(columns) > 1: - return sns.boxplot( - data=filtered_df, x="node", y=" ", hue="Performance counter", **kwargs - ) - else: - return sns.boxplot(data=filtered_df, x="node", y=" ", **kwargs) + mod_kwargs = kwargs.copy() + if len(columns) > 1: + mod_kwargs["hue"] = legend_title + + if return_mpl: + return sns.boxplot( + data=filtered_df, x="node", y=" ", **mod_kwargs + ) + else: + return sns.catplot( + data=filtered_df, x="node", y=" ", kind="box", **mod_kwargs + )