Skip to content

Commit d5b5c9f

Browse files
committed
Add units + improve graph type detection
1 parent f853217 commit d5b5c9f

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

python/tests/graphs/test_bar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ async def test_graph_bar(async_sandbox: AsyncCodeInterpreter):
3636
assert graphs[0]['x_label'] == "Authors"
3737
assert graphs[0]['y_label'] == "Number of Books Sold"
3838

39+
assert graphs[0]['x_unit'] is None
40+
assert graphs[0]['y_unit'] is None
41+
3942
assert all(isinstance(x, int) for x in graphs[0]['x_ticks'])
4043
assert all(isinstance(y, float) for y in graphs[0]['y_ticks'])
4144

@@ -53,4 +56,3 @@ async def test_graph_bar(async_sandbox: AsyncCodeInterpreter):
5356
assert all(isinstance(x, (int, float)) for x in data['heights'])
5457
assert all(isinstance(y, (int, float)) for y in data['x'])
5558
assert all(isinstance(y, (int, float)) for y in data['y'])
56-

python/tests/graphs/test_line.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
plt.plot(x, y_cos, label='cos(x)')
1818
1919
# Add labels and title
20-
plt.xlabel('x')
21-
plt.ylabel('y')
20+
plt.xlabel("Time (s)")
21+
plt.ylabel("Amplitude (Hz)")
2222
plt.title('Plot of sin(x) and cos(x)')
2323
2424
# Display the plot
@@ -37,8 +37,10 @@ async def test_line_graph(async_sandbox: AsyncCodeInterpreter):
3737

3838
assert graphs[0]['type'] == "line"
3939
assert graphs[0]['title'] == "Plot of sin(x) and cos(x)"
40-
assert graphs[0]['x_label'] == "x"
41-
assert graphs[0]['y_label'] == "y"
40+
assert graphs[0]['x_label'] == "Time (s)"
41+
assert graphs[0]['y_label'] == "Amplitude (Hz)"
42+
assert graphs[0]['x_unit'] == "s"
43+
assert graphs[0]['y_unit'] == "Hz"
4244

4345
assert all(isinstance(x, float) for x in graphs[0]['x_ticks'])
4446
assert all(isinstance(y, float) for y in graphs[0]['y_ticks'])

python/tests/graphs/test_pie.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ async def test_pie_graph(async_sandbox: AsyncCodeInterpreter):
3939
assert graphs[0]['title'] == "Will I wake up early tomorrow?"
4040
assert graphs[0]['x_label'] == "x"
4141
assert graphs[0]['y_label'] == "y"
42+
assert graphs[0]['x_unit'] is None
43+
assert graphs[0]['y_unit'] is None
4244

4345
assert len(graphs[0]['x_ticks']) == 0
4446
assert len(graphs[0]['y_ticks']) == 0

template/startup_scripts/0002_data.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import enum
2+
import re
3+
from typing import Optional
24

35
import pandas
46
from matplotlib.axes import Axes
57
from matplotlib.collections import PathCollection
68
from matplotlib.lines import Line2D
7-
from matplotlib.patches import Rectangle, Wedge
9+
from matplotlib.patches import Rectangle, Wedge, PathPatch
810
from matplotlib.pyplot import Figure
911
import IPython
1012

1113
from IPython.core.formatters import BaseFormatter
14+
from matplotlib.text import Text
1215
from traitlets.traitlets import Unicode, ObjectName
1316

1417

@@ -20,21 +23,34 @@ class PlotType(enum.Enum):
2023
UNKNOWN = "unknown"
2124

2225

23-
def get_type_of_plot(ax: Axes) -> PlotType:
26+
def _extract_units(label: str) -> Optional[str]:
27+
"""
28+
Function to extract units from labels
29+
"""
30+
# Look for units in parentheses or brackets
31+
match = re.search(r"\s\((.*?)\)|\[(.*?)\]", label)
32+
if match:
33+
return match.group(1) or match.group(2) # return the matched unit
34+
return None # No units found
35+
36+
37+
def _get_type_of_plot(ax: Axes) -> PlotType:
38+
objects = list(filter(lambda obj: not isinstance(obj, Text), ax._children))
39+
2440
# Check for Line plots
25-
if any(isinstance(line, Line2D) for line in ax.get_lines()):
41+
if all(isinstance(line, Line2D) for line in objects):
2642
return PlotType.LINE
2743

2844
# Check for Scatter plots
29-
if any(isinstance(collection, PathCollection) for collection in ax.collections):
45+
if all(isinstance(path, PathCollection) for path in objects):
3046
return PlotType.SCATTER
3147

3248
# Check for Pie plots
33-
if any(isinstance(artist, Wedge) for artist in ax.patches):
49+
if all(isinstance(artist, Wedge) for artist in objects):
3450
return PlotType.PIE
3551

3652
# Check for Bar plots
37-
if any(isinstance(rect, Rectangle) for rect in ax.patches):
53+
if all(isinstance(rect, Rectangle) for rect in objects):
3854
return PlotType.BAR
3955

4056
return PlotType.UNKNOWN
@@ -53,17 +69,19 @@ def _figure_repr_e2b_data_(self: Figure):
5369
ax_data = {
5470
"title": ax.get_title(),
5571
"x_label": ax.get_xlabel(),
72+
"x_unit": _extract_units(ax.get_xlabel()),
5673
"x_ticks": ax.get_xticks(),
5774
"x_tick_labels": [label.get_text() for label in ax.get_xticklabels()],
5875
"x_scale": ax.get_xscale(),
5976
"y_label": ax.get_ylabel(),
77+
"y_unit": _extract_units(ax.get_ylabel()),
6078
"y_ticks": ax.get_yticks(),
6179
"y_tick_labels": [label.get_text() for label in ax.get_yticklabels()],
6280
"y_scale": ax.get_yscale(),
6381
"data": [],
6482
}
6583

66-
plot_type = get_type_of_plot(ax)
84+
plot_type = _get_type_of_plot(ax)
6785
ax_data["type"] = plot_type.value
6886

6987
if plot_type == PlotType.LINE:

0 commit comments

Comments
 (0)