diff --git a/thicket/external/console.py b/thicket/external/console.py index b1546ec1..aa3eb3e3 100644 --- a/thicket/external/console.py +++ b/thicket/external/console.py @@ -1,3 +1,4 @@ +import math import warnings import numpy as np @@ -53,6 +54,8 @@ def render(self, roots, dataframe, **kwargs): self.min_value = kwargs["min_value"] self.max_value = kwargs["max_value"] self.indices = kwargs["indices"] + self.hist_data = kwargs["hist_data"] + self.histogram = kwargs["histogram"] if self.color: self.colors = self.colors_enabled @@ -271,6 +274,49 @@ def render_frame(self, node, dataframe, indent="", child_indent=""): result = "{indent}{metric_str} {name_str}".format( indent=indent, metric_str=metric_str, name_str=name_str ) + + if self.histogram: + hist_data = self.hist_data.loc[df_index] + nprofs = len(hist_data) + # Auto choose num bins + nintervals = min(math.ceil(math.sqrt(nprofs)), 20) + # Add min/max to tree + min_num = hist_data.min() + max_num = hist_data.max() + result += ( + f" ({min_num:.{self.precision}f}, {max_num:.{self.precision}f}) " + ) + # Define unicode bars + bar_list = [ + "_", + "\u2581", + "\u2582", + "\u2583", + "\u2584", + "\u2585", + "\u2586", + "\u2587", + "\u2588", + ] + try: + # Compute histogram intervals using pandas binning + binned = pd.cut(hist_data, bins=nintervals) + hist = binned.value_counts().sort_index() + # Normalize values to the number of bars + normalized_hist = ( + (len(bar_list) - 1) + * (hist - hist.min()) + / (hist.max() - hist.min()) + ) + normalized_hist = normalized_hist.apply(np.ceil).astype(int) + # Add histogram to tree + for idx in normalized_hist.values: + result += bar_list[idx] + except ( + ValueError or pd.errors.IntCastingNaNError + ): # NA or inf cannot be binned + pass + if self.context in dataframe.columns: result += " {c.faint}{context}{c.end}\n".format( context=dataframe.loc[df_index, self.context], c=self.colors diff --git a/thicket/thicket.py b/thicket/thicket.py index c77cde7a..1bef55c2 100644 --- a/thicket/thicket.py +++ b/thicket/thicket.py @@ -170,7 +170,6 @@ def profile_hasher(obj, hex_len=11): Returns: (int): hash of the object """ - return int(md5(obj.encode("utf-8")).hexdigest()[:hex_len], 16) @staticmethod @@ -1000,6 +999,7 @@ def tree( min_value=None, max_value=None, indices=None, + histogram=False, ): """Visualize the Thicket as a tree @@ -1021,6 +1021,7 @@ def tree( min_value (int, optional): Overwrites the min value for the coloring legend. Defaults to None. max_value (int, optional): Overwrites the max value for the coloring legend. Defaults to None. indices(tuple, list, optional): Index/indices to display on the DataFrame. Defaults to None. + histogram (bool, optional): Whether to show a histogram next to each node of the data for all profiles. Defaults to False. Returns: (str): String representation of the tree, ready to print @@ -1132,6 +1133,8 @@ def tree( min_value=min_value, max_value=max_value, indices=idx_dict, + hist_data=self.dataframe[metric_column], + histogram=histogram, ) @staticmethod