11import enum
2+ import re
3+ from typing import Optional
24
35import pandas
46from matplotlib .axes import Axes
57from matplotlib .collections import PathCollection
68from matplotlib .lines import Line2D
7- from matplotlib .patches import Rectangle , Wedge
9+ from matplotlib .patches import Rectangle , Wedge , PathPatch
810from matplotlib .pyplot import Figure
911import IPython
1012
1113from IPython .core .formatters import BaseFormatter
14+ from matplotlib .text import Text
1215from traitlets .traitlets import Unicode , ObjectName
1316
1417
@@ -20,21 +23,34 @@ class PlotType(enum.Enum):
2023 UNKNOWN = "unknown"
2124
2225
23- def get_type_of_plot (ax : Axes ) -> PlotType :
26+ def _extract_units (label : str ) -> Optional [str ]:
27+ """
28+ Function to extract units from labels
29+ """
30+ # Look for units in parentheses or brackets
31+ match = re .search (r"\s\((.*?)\)|\[(.*?)\]" , label )
32+ if match :
33+ return match .group (1 ) or match .group (2 ) # return the matched unit
34+ return None # No units found
35+
36+
37+ def _get_type_of_plot (ax : Axes ) -> PlotType :
38+ objects = list (filter (lambda obj : not isinstance (obj , Text ), ax ._children ))
39+
2440 # Check for Line plots
25- if any (isinstance (line , Line2D ) for line in ax . get_lines () ):
41+ if all (isinstance (line , Line2D ) for line in objects ):
2642 return PlotType .LINE
2743
2844 # Check for Scatter plots
29- if any (isinstance (collection , PathCollection ) for collection in ax . collections ):
45+ if all (isinstance (path , PathCollection ) for path in objects ):
3046 return PlotType .SCATTER
3147
3248 # Check for Pie plots
33- if any (isinstance (artist , Wedge ) for artist in ax . patches ):
49+ if all (isinstance (artist , Wedge ) for artist in objects ):
3450 return PlotType .PIE
3551
3652 # Check for Bar plots
37- if any (isinstance (rect , Rectangle ) for rect in ax . patches ):
53+ if all (isinstance (rect , Rectangle ) for rect in objects ):
3854 return PlotType .BAR
3955
4056 return PlotType .UNKNOWN
@@ -53,17 +69,19 @@ def _figure_repr_e2b_data_(self: Figure):
5369 ax_data = {
5470 "title" : ax .get_title (),
5571 "x_label" : ax .get_xlabel (),
72+ "x_unit" : _extract_units (ax .get_xlabel ()),
5673 "x_ticks" : ax .get_xticks (),
5774 "x_tick_labels" : [label .get_text () for label in ax .get_xticklabels ()],
5875 "x_scale" : ax .get_xscale (),
5976 "y_label" : ax .get_ylabel (),
77+ "y_unit" : _extract_units (ax .get_ylabel ()),
6078 "y_ticks" : ax .get_yticks (),
6179 "y_tick_labels" : [label .get_text () for label in ax .get_yticklabels ()],
6280 "y_scale" : ax .get_yscale (),
6381 "data" : [],
6482 }
6583
66- plot_type = get_type_of_plot (ax )
84+ plot_type = _get_type_of_plot (ax )
6785 ax_data ["type" ] = plot_type .value
6886
6987 if plot_type == PlotType .LINE :
0 commit comments