From 478880b7f9bf7aac61cda2dddbca8f9e7798014c Mon Sep 17 00:00:00 2001 From: JacksonBurns Date: Wed, 4 Aug 2021 16:38:01 -0400 Subject: [PATCH 1/3] add support for writing output to different file --- pylustrator/QtGuiDrag.py | 10 ++++--- pylustrator/change_tracker.py | 49 ++++++++++++++++++++++++++++------- temp.py | 13 ++++++++++ thisisatest.py | 9 +++++++ 4 files changed, 68 insertions(+), 13 deletions(-) create mode 100644 temp.py create mode 100644 thisisatest.py diff --git a/pylustrator/QtGuiDrag.py b/pylustrator/QtGuiDrag.py index 09932d2..c5a8513 100644 --- a/pylustrator/QtGuiDrag.py +++ b/pylustrator/QtGuiDrag.py @@ -154,7 +154,7 @@ def new_call(self, *args, **kwargs): Colormap.__call__ = new_call -def figure(num=None, figsize=None, force_add=False, *args, **kwargs): +def figure(num=None, figsize=None, force_add=False, output_file: str = "source", *args, **kwargs): """ overloads the matplotlib figure call and wrapps the Figure in a PlotWindow """ global figures # if num is not defined create a new number @@ -163,7 +163,7 @@ def figure(num=None, figsize=None, force_add=False, *args, **kwargs): # if number is not defined if force_add or num not in _pylab_helpers.Gcf.figs.keys(): # create a new window and store it - canvas = PlotWindow(num, figsize, *args, **kwargs).canvas + canvas = PlotWindow(num, figsize, output_file, *args, **kwargs).canvas canvas.figure.number = num canvas.figure.clf() canvas.manager.num = num @@ -640,13 +640,15 @@ def execute_action(self, act: str): class PlotWindow(QtWidgets.QWidget): fitted_to_view = False - def __init__(self, number: int, size: tuple): + def __init__(self, number: int, size: tuple, output_file: str): """ The main window of pylustrator Args: number: the id of the figure size: the size of the figure + output_file: destination for generated code. Defaults to the source file. """ + self.output_file = output_file QtWidgets.QWidget.__init__(self) self.canvas_canvas = QtWidgets.QWidget() @@ -797,7 +799,7 @@ def rasterize(self, rasterize: bool): def actionSave(self): """ save the code for the figure """ - self.fig.change_tracker.save() + self.fig.change_tracker.save(self.output_file) for _last_saved_figure, args, kwargs in getattr(self.fig, "_last_saved_figure", []): self.fig.savefig(_last_saved_figure, *args, **kwargs) diff --git a/pylustrator/change_tracker.py b/pylustrator/change_tracker.py index 7c615b4..f6df980 100644 --- a/pylustrator/change_tracker.py +++ b/pylustrator/change_tracker.py @@ -22,6 +22,7 @@ import re import sys import traceback +import os from typing import IO import matplotlib @@ -380,7 +381,7 @@ def getRef(obj): return output - def save(self): + def save(self, output_file): """ save the changes to the .py file """ header = [getReference(self.figure) + ".ax_dict = {ax.get_label(): ax for ax in " + getReference( self.figure) + ".axes}", "import matplotlib as mpl"] @@ -403,7 +404,7 @@ def save(self): if not block: block_id = getReference(self.figure, allow_using_variable_names=False) block = getTextFromFile(block_id, stack_position) - insertTextToFile(output, stack_position, block_id) + insertTextToFile(output, stack_position, block_id, output_file) self.saved = True @@ -498,7 +499,7 @@ def lineToId(line: str): return line -def insertTextToFile(new_block: str, stack_pos: traceback.FrameSummary, figure_id_line: str): +def insertTextToFile(new_block: str, stack_pos: traceback.FrameSummary, figure_id_line: str, output_file: str): """ insert a text block into a file """ figure_id_line = lineToId(figure_id_line) block = None @@ -564,9 +565,39 @@ def insertTextToFile(new_block: str, stack_pos: traceback.FrameSummary, figure_i # update the position of the entry point, as we have inserted stuff in the new file which can change the position stack_pos.lineno = lineno_stack - # now copy the temporary file over the old file - with open(stack_pos.filename + ".tmp", 'r', encoding="utf-8") as fp2: - with open(stack_pos.filename, 'w', encoding="utf-8") as fp1: - for line in fp2: - fp1.write(line) - print("save", figure_id_line, "to", stack_pos.filename, "line %d-%d" % (written, written_end)) + if output_file == "source": + # now copy the temporary file over the old file + with open(stack_pos.filename + ".tmp", 'r', encoding="utf-8") as fp2: + with open(stack_pos.filename, 'w', encoding="utf-8") as fp1: + for line in fp2: + fp1.write(line) + msg = "saved {} to {} (lines {}-{})".format( + figure_id_line, stack_pos.filename, written, written_end + ) + else: + # now copy the temporary file over the new file + with open(stack_pos.filename + ".tmp", 'r', encoding="utf-8") as fp2: + with open(output_file, 'w', encoding="utf-8") as fp1: + # write the import and plotting + fp1.write("import matplotlib.pyplot as plt\nplt.figure(1)\n") + # write the plotted data into the file + line = plt.gca().get_lines()[0] + xd = line.get_xdata() + yd = line.get_ydata() + fp1.write("plt.plot({},{})\n".format(str(xd).replace(" ",","), str(yd).replace(" ",","))) + # write only the changes made by pylustrator into the plot + start_writing = False + for line in fp2: + if "start: automatic generated code from pylustrator" in line: + start_writing = True + if start_writing: + fp1.write(line) + if "end: automatic generated code from pylustrator" in line: + break + fp1.write("mpl.pyplot.show()") + msg = "saved {} to {} (lines {}-{})".format( + figure_id_line, output_file, written, written_end + ) + print(msg) + # remove the temporary file + os.remove(stack_pos.filename + ".tmp") diff --git a/temp.py b/temp.py new file mode 100644 index 0000000..8167bf1 --- /dev/null +++ b/temp.py @@ -0,0 +1,13 @@ + +import matplotlib.pyplot as plt + +import pylustrator + +pylustrator.start() + +plt.figure(output_file="thisisatest.py") +plt.plot([1,2,3],[1,2,3]) + + +plt.show() + diff --git a/thisisatest.py b/thisisatest.py new file mode 100644 index 0000000..647e9dc --- /dev/null +++ b/thisisatest.py @@ -0,0 +1,9 @@ +import matplotlib.pyplot as plt +plt.figure(1) +plt.plot([1,2,3],[1,2,3]) +#% start: automatic generated code from pylustrator +plt.figure(1).ax_dict = {ax.get_label(): ax for ax in plt.figure(1).axes} +import matplotlib as mpl +plt.figure(1).axes[0].set_position([0.210938, 0.049583, 0.775000, 0.770000]) +#% end: automatic generated code from pylustrator +mpl.pyplot.show() \ No newline at end of file From ad5cf337b48cdadb0428d222f0d7ceea4dae67c7 Mon Sep 17 00:00:00 2001 From: JacksonBurns Date: Wed, 4 Aug 2021 17:46:22 -0400 Subject: [PATCH 2/3] new approach to plotting data retrieval --- complex_example.py | 29 +++++++++++++++++++++++++++ pylustrator/QtGuiDrag.py | 9 +++++---- pylustrator/change_tracker.py | 23 ++++++++++++---------- temp.py | 12 +++++++++--- thisisatest.py | 37 +++++++++++++++++++++++++++++++---- 5 files changed, 89 insertions(+), 21 deletions(-) create mode 100644 complex_example.py diff --git a/complex_example.py b/complex_example.py new file mode 100644 index 0000000..d9f11b3 --- /dev/null +++ b/complex_example.py @@ -0,0 +1,29 @@ +import numpy as np; np.random.seed(0) +uniform_data = np.random.rand(10, 12) + +import pylustrator + +pylustrator.start() + + + +plt.figure( + output_file="thisisatest.py", + placeholder=f""" +import matplotlib.pyplot as plt +import seaborn as sns; sns.set_theme() +ax = sns.heatmap({repr(uniform_data)}) + """ +) + +## complex imports and calls to plotting which we do not want to try and +## find from pylustrator + +## this is copied into the call to figure ## +import seaborn as sns; sns.set_theme() +import matplotlib.pyplot as plt +ax = sns.heatmap(uniform_data) +## ### ## + + +plt.show() diff --git a/pylustrator/QtGuiDrag.py b/pylustrator/QtGuiDrag.py index c5a8513..e93c13d 100644 --- a/pylustrator/QtGuiDrag.py +++ b/pylustrator/QtGuiDrag.py @@ -154,7 +154,7 @@ def new_call(self, *args, **kwargs): Colormap.__call__ = new_call -def figure(num=None, figsize=None, force_add=False, output_file: str = "source", *args, **kwargs): +def figure(num=None, figsize=None, force_add=False, output_file: str = "source", placeholder: str = "", *args, **kwargs): """ overloads the matplotlib figure call and wrapps the Figure in a PlotWindow """ global figures # if num is not defined create a new number @@ -163,7 +163,7 @@ def figure(num=None, figsize=None, force_add=False, output_file: str = "source", # if number is not defined if force_add or num not in _pylab_helpers.Gcf.figs.keys(): # create a new window and store it - canvas = PlotWindow(num, figsize, output_file, *args, **kwargs).canvas + canvas = PlotWindow(num, figsize, output_file, placeholder, *args, **kwargs).canvas canvas.figure.number = num canvas.figure.clf() canvas.manager.num = num @@ -640,7 +640,7 @@ def execute_action(self, act: str): class PlotWindow(QtWidgets.QWidget): fitted_to_view = False - def __init__(self, number: int, size: tuple, output_file: str): + def __init__(self, number: int, size: tuple, output_file: str, placeholder: str): """ The main window of pylustrator Args: @@ -649,6 +649,7 @@ def __init__(self, number: int, size: tuple, output_file: str): output_file: destination for generated code. Defaults to the source file. """ self.output_file = output_file + self.placeholder = placeholder QtWidgets.QWidget.__init__(self) self.canvas_canvas = QtWidgets.QWidget() @@ -799,7 +800,7 @@ def rasterize(self, rasterize: bool): def actionSave(self): """ save the code for the figure """ - self.fig.change_tracker.save(self.output_file) + self.fig.change_tracker.save(self.output_file, self.placeholder) for _last_saved_figure, args, kwargs in getattr(self.fig, "_last_saved_figure", []): self.fig.savefig(_last_saved_figure, *args, **kwargs) diff --git a/pylustrator/change_tracker.py b/pylustrator/change_tracker.py index f6df980..730065d 100644 --- a/pylustrator/change_tracker.py +++ b/pylustrator/change_tracker.py @@ -381,7 +381,7 @@ def getRef(obj): return output - def save(self, output_file): + def save(self, output_file, placeholder): """ save the changes to the .py file """ header = [getReference(self.figure) + ".ax_dict = {ax.get_label(): ax for ax in " + getReference( self.figure) + ".axes}", "import matplotlib as mpl"] @@ -404,7 +404,7 @@ def save(self, output_file): if not block: block_id = getReference(self.figure, allow_using_variable_names=False) block = getTextFromFile(block_id, stack_position) - insertTextToFile(output, stack_position, block_id, output_file) + insertTextToFile(output, stack_position, block_id, output_file, placeholder) self.saved = True @@ -499,7 +499,7 @@ def lineToId(line: str): return line -def insertTextToFile(new_block: str, stack_pos: traceback.FrameSummary, figure_id_line: str, output_file: str): +def insertTextToFile(new_block: str, stack_pos: traceback.FrameSummary, figure_id_line: str, output_file: str, placeholder: str): """ insert a text block into a file """ figure_id_line = lineToId(figure_id_line) block = None @@ -578,13 +578,16 @@ def insertTextToFile(new_block: str, stack_pos: traceback.FrameSummary, figure_i # now copy the temporary file over the new file with open(stack_pos.filename + ".tmp", 'r', encoding="utf-8") as fp2: with open(output_file, 'w', encoding="utf-8") as fp1: - # write the import and plotting - fp1.write("import matplotlib.pyplot as plt\nplt.figure(1)\n") - # write the plotted data into the file - line = plt.gca().get_lines()[0] - xd = line.get_xdata() - yd = line.get_ydata() - fp1.write("plt.plot({},{})\n".format(str(xd).replace(" ",","), str(yd).replace(" ",","))) + if not placeholder: + # write the import and plotting + fp1.write("import matplotlib.pyplot as plt\nplt.figure(1)\n") + # write the plotted data into the file + line = plt.gca().get_lines()[0] + xd = line.get_xdata() + yd = line.get_ydata() + fp1.write("plt.plot({},{})\n".format(str(xd).replace(" ", ","), str(yd).replace(" ", ","))) + else: + fp1.write(placeholder) # write only the changes made by pylustrator into the plot start_writing = False for line in fp2: diff --git a/temp.py b/temp.py index 8167bf1..93d257f 100644 --- a/temp.py +++ b/temp.py @@ -5,9 +5,15 @@ pylustrator.start() -plt.figure(output_file="thisisatest.py") -plt.plot([1,2,3],[1,2,3]) +a = [1,2,3] +b = [1,2,3] +plt.figure( + output_file="thisisatest.py", + placeholder="import matplotlib\nplt.plot({},{})".format( + str(a).replace(" ",","),str(b).replace(" ",",") + ) +) -plt.show() +plt.show() diff --git a/thisisatest.py b/thisisatest.py index 647e9dc..43dce47 100644 --- a/thisisatest.py +++ b/thisisatest.py @@ -1,9 +1,38 @@ +import numpy as np +import seaborn as sns; sns.set_theme() import matplotlib.pyplot as plt -plt.figure(1) -plt.plot([1,2,3],[1,2,3]) -#% start: automatic generated code from pylustrator +ax = sns.heatmap(np.array([[0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 , + 0.64589411, 0.43758721, 0.891773 , 0.96366276, 0.38344152, + 0.79172504, 0.52889492], + [0.56804456, 0.92559664, 0.07103606, 0.0871293 , 0.0202184 , + 0.83261985, 0.77815675, 0.87001215, 0.97861834, 0.79915856, + 0.46147936, 0.78052918], + [0.11827443, 0.63992102, 0.14335329, 0.94466892, 0.52184832, + 0.41466194, 0.26455561, 0.77423369, 0.45615033, 0.56843395, + 0.0187898 , 0.6176355 ], + [0.61209572, 0.616934 , 0.94374808, 0.6818203 , 0.3595079 , + 0.43703195, 0.6976312 , 0.06022547, 0.66676672, 0.67063787, + 0.21038256, 0.1289263 ], + [0.31542835, 0.36371077, 0.57019677, 0.43860151, 0.98837384, + 0.10204481, 0.20887676, 0.16130952, 0.65310833, 0.2532916 , + 0.46631077, 0.24442559], + [0.15896958, 0.11037514, 0.65632959, 0.13818295, 0.19658236, + 0.36872517, 0.82099323, 0.09710128, 0.83794491, 0.09609841, + 0.97645947, 0.4686512 ], + [0.97676109, 0.60484552, 0.73926358, 0.03918779, 0.28280696, + 0.12019656, 0.2961402 , 0.11872772, 0.31798318, 0.41426299, + 0.0641475 , 0.69247212], + [0.56660145, 0.26538949, 0.52324805, 0.09394051, 0.5759465 , + 0.9292962 , 0.31856895, 0.66741038, 0.13179786, 0.7163272 , + 0.28940609, 0.18319136], + [0.58651293, 0.02010755, 0.82894003, 0.00469548, 0.67781654, + 0.27000797, 0.73519402, 0.96218855, 0.24875314, 0.57615733, + 0.59204193, 0.57225191], + [0.22308163, 0.95274901, 0.44712538, 0.84640867, 0.69947928, + 0.29743695, 0.81379782, 0.39650574, 0.8811032 , 0.58127287, + 0.88173536, 0.69253159]])) + #% start: automatic generated code from pylustrator plt.figure(1).ax_dict = {ax.get_label(): ax for ax in plt.figure(1).axes} import matplotlib as mpl -plt.figure(1).axes[0].set_position([0.210938, 0.049583, 0.775000, 0.770000]) #% end: automatic generated code from pylustrator mpl.pyplot.show() \ No newline at end of file From 2e1b0389682e8ce302f70399b5e03486b1d5bb12 Mon Sep 17 00:00:00 2001 From: JacksonBurns Date: Thu, 19 Aug 2021 16:30:01 -0400 Subject: [PATCH 3/3] standalone file output working --- .gitignore | 1 + complex_example.py | 29 --- pylustrator/QtGuiDrag.py | 342 +++++++++++++++++++++----------- pylustrator/change_tracker.py | 355 ++++++++++++++++++++++++---------- sample_pylustrator_output.py | 72 +++++++ temp.py | 37 +++- 6 files changed, 588 insertions(+), 248 deletions(-) delete mode 100644 complex_example.py create mode 100644 sample_pylustrator_output.py diff --git a/.gitignore b/.gitignore index d7503a8..90e8c00 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea/ +.vscode/* docs/_build_html/ */__pycache__/ *.pyc diff --git a/complex_example.py b/complex_example.py deleted file mode 100644 index d9f11b3..0000000 --- a/complex_example.py +++ /dev/null @@ -1,29 +0,0 @@ -import numpy as np; np.random.seed(0) -uniform_data = np.random.rand(10, 12) - -import pylustrator - -pylustrator.start() - - - -plt.figure( - output_file="thisisatest.py", - placeholder=f""" -import matplotlib.pyplot as plt -import seaborn as sns; sns.set_theme() -ax = sns.heatmap({repr(uniform_data)}) - """ -) - -## complex imports and calls to plotting which we do not want to try and -## find from pylustrator - -## this is copied into the call to figure ## -import seaborn as sns; sns.set_theme() -import matplotlib.pyplot as plt -ax = sns.heatmap(uniform_data) -## ### ## - - -plt.show() diff --git a/pylustrator/QtGuiDrag.py b/pylustrator/QtGuiDrag.py index e93c13d..27624bc 100644 --- a/pylustrator/QtGuiDrag.py +++ b/pylustrator/QtGuiDrag.py @@ -75,8 +75,8 @@ def initialize(use_global_variable_names=False): plt.figure = figure patchColormapsWithMetaInfo() - #stack_call_position = traceback.extract_stack()[-2] - #stack_call_position.filename + # stack_call_position = traceback.extract_stack()[-2] + # stack_call_position.filename plt.keys_for_lines = keys_for_lines @@ -84,7 +84,9 @@ def initialize(use_global_variable_names=False): sf = Figure.savefig def savefig(self, filename, *args, **kwargs): - self._last_saved_figure = getattr(self, "_last_saved_figure", []) + [(filename, args, kwargs)] + self._last_saved_figure = getattr(self, "_last_saved_figure", []) + [ + (filename, args, kwargs) + ] sf(self, filename, *args, **kwargs) Figure.savefig = savefig @@ -98,14 +100,15 @@ def savefig(self, filename, *args, **kwargs): def show(hide_window: bool = False): - """ the function overloads the matplotlib show function. + """the function overloads the matplotlib show function. It opens a DragManager window instead of the default matplotlib window. """ global figures # set an application id, so that windows properly stacks them in the task bar - if sys.platform[:3] == 'win': + if sys.platform[:3] == "win": import ctypes - myappid = 'rgerum.pylustrator' # arbitrary string + + myappid = "rgerum.pylustrator" # arbitrary string ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) # iterate over figures for figure in _pylab_helpers.Gcf.figs.copy(): @@ -131,7 +134,7 @@ def show(hide_window: bool = False): class CmapColor(list): - """ a color like object that has the colormap as metadata """ + """a color like object that has the colormap as metadata""" def setMeta(self, value, cmap): self.value = value @@ -139,7 +142,7 @@ def setMeta(self, value, cmap): def patchColormapsWithMetaInfo(): - """ all colormaps now return color with metadata from which colormap the color came from """ + """all colormaps now return color with metadata from which colormap the color came from""" from matplotlib.colors import Colormap cm_call = Colormap.__call__ @@ -154,8 +157,16 @@ def new_call(self, *args, **kwargs): Colormap.__call__ = new_call -def figure(num=None, figsize=None, force_add=False, output_file: str = "source", placeholder: str = "", *args, **kwargs): - """ overloads the matplotlib figure call and wrapps the Figure in a PlotWindow """ +def figure( + num=None, + figsize=None, + force_add=False, + output_file: str = "source", + reqd_code: list = [], + *args, + **kwargs +): + """overloads the matplotlib figure call and wrapps the Figure in a PlotWindow""" global figures # if num is not defined create a new number if num is None: @@ -163,7 +174,9 @@ def figure(num=None, figsize=None, force_add=False, output_file: str = "source", # if number is not defined if force_add or num not in _pylab_helpers.Gcf.figs.keys(): # create a new window and store it - canvas = PlotWindow(num, figsize, output_file, placeholder, *args, **kwargs).canvas + canvas = PlotWindow( + num, figsize, output_file, reqd_code, *args, **kwargs + ).canvas canvas.figure.number = num canvas.figure.clf() canvas.manager.num = num @@ -172,7 +185,9 @@ def figure(num=None, figsize=None, force_add=False, output_file: str = "source", manager = _pylab_helpers.Gcf.figs[num] # set the size if it is defined if figsize is not None: - _pylab_helpers.Gcf.figs[num].window.setGeometry(100, 100, figsize[0] * 80, figsize[1] * 80) + _pylab_helpers.Gcf.figs[num].window.setGeometry( + 100, 100, figsize[0] * 80, figsize[1] * 80 + ) # set the figure as the active figure _pylab_helpers.Gcf.set_active(manager) # return the figure @@ -180,8 +195,9 @@ def figure(num=None, figsize=None, force_add=False, output_file: str = "source", def warnAboutTicks(fig): - """ warn if the tick labels and tick values do not match, to prevent users from accidently setting wrong tick values """ + """warn if the tick labels and tick values do not match, to prevent users from accidently setting wrong tick values""" import sys + for index, ax in enumerate(fig.axes): ticks = ax.get_yticks() labels = [t.get_text() for t in ax.get_yticklabels()] @@ -198,7 +214,14 @@ def warnAboutTicks(fig): ax_name = "#%d" % index else: ax_name = '"' + ax_name + '"' - print("Warning tick and label differ", t, l, "for axes", ax_name, file=sys.stderr) + print( + "Warning tick and label differ", + t, + l, + "for axes", + ax_name, + file=sys.stderr, + ) """ Window """ @@ -206,11 +229,11 @@ def warnAboutTicks(fig): class myTreeWidgetItem(QtGui.QStandardItem): def __init__(self, parent: QtWidgets.QWidget = None): - """ a tree view item to display the contents of the figure """ + """a tree view item to display the contents of the figure""" QtGui.QStandardItem.__init__(self, parent) def __lt__(self, otherItem: QtGui.QStandardItem): - """ how to sort the items """ + """how to sort the items""" if self.sort is None: return 0 return self.sort < otherItem.sort @@ -226,8 +249,10 @@ class MyTreeView(QtWidgets.QTreeView): last_selection = None last_hover = None - def __init__(self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Figure): - """ A tree view to display the contents of a figure + def __init__( + self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Figure + ): + """A tree view to display the contents of a figure Args: parent: the parent widget @@ -268,10 +293,17 @@ def __init__(self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Fi self.expand(None) - def selectionChanged(self, selection: QtCore.QItemSelection, y: QtCore.QItemSelection): - """ when the selection in the tree view changes """ + def selectionChanged( + self, selection: QtCore.QItemSelection, y: QtCore.QItemSelection + ): + """when the selection in the tree view changes""" try: - entry = selection.indexes()[0].model().itemFromIndex(selection.indexes()[0]).entry + entry = ( + selection.indexes()[0] + .model() + .itemFromIndex(selection.indexes()[0]) + .entry + ) except IndexError: entry = None if self.last_selection != entry: @@ -279,7 +311,7 @@ def selectionChanged(self, selection: QtCore.QItemSelection, y: QtCore.QItemSele self.item_selected(entry) def setCurrentIndex(self, entry: Artist): - """ set the currently selected entry """ + """set the currently selected entry""" while entry: item = self.getItemFromEntry(entry) if item is not None: @@ -291,17 +323,17 @@ def setCurrentIndex(self, entry: Artist): return def treeClicked(self, index: QtCore.QModelIndex): - """ upon selecting one of the tree elements """ + """upon selecting one of the tree elements""" data = index.model().itemFromIndex(index).entry return self.item_clicked(data) def treeActivated(self, index: QtCore.QModelIndex): - """ upon selecting one of the tree elements """ + """upon selecting one of the tree elements""" data = index.model().itemFromIndex(index).entry return self.item_activated(data) def eventFilter(self, object: QtWidgets.QWidget, event: QtCore.QEvent): - """ event filter for tree view port to handle mouse over events and marker highlighting""" + """event filter for tree view port to handle mouse over events and marker highlighting""" if event.type() == QtCore.QEvent.HoverMove: index = self.indexAt(event.pos()) try: @@ -328,24 +360,24 @@ def eventFilter(self, object: QtWidgets.QWidget, event: QtCore.QEvent): return False def queryToExpandEntry(self, entry: Artist) -> list: - """ when expanding a tree item """ + """when expanding a tree item""" if entry is None: return [self.fig] return entry.get_children() def getParentEntry(self, entry: Artist) -> Artist: - """ get the parent of an item """ + """get the parent of an item""" return entry.tree_parent def getNameOfEntry(self, entry: Artist) -> str: - """ convert an entry to a string """ + """convert an entry to a string""" try: return str(entry) except AttributeError: return "unknown" def getIconOfEntry(self, entry: Artist) -> QtGui.QIcon: - """ get the icon of an entry """ + """get the icon of an entry""" if getattr(entry, "_draggable", None): if entry._draggable.connected: return qta.icon("fa.hand-paper-o") @@ -355,11 +387,11 @@ def getEntrySortRole(self, entry: Artist): return None def getKey(self, entry: Artist) -> Artist: - """ get the key of an entry, which is the entry itself """ + """get the key of an entry, which is the entry itself""" return entry def getItemFromEntry(self, entry: Artist) -> Optional[QtWidgets.QTreeWidgetItem]: - """ get the tree view item for the given artist """ + """get the tree view item for the given artist""" if entry is None: return None key = self.getKey(entry) @@ -369,12 +401,12 @@ def getItemFromEntry(self, entry: Artist) -> Optional[QtWidgets.QTreeWidgetItem] return None def setItemForEntry(self, entry: Artist, item: QtWidgets.QTreeWidgetItem): - """ store a new artist and tree view widget pair """ + """store a new artist and tree view widget pair""" key = self.getKey(entry) self.item_lookup[key] = item def expand(self, entry: Artist, force_reload: bool = True): - """ expand the children of a tree view item """ + """expand the children of a tree view item""" query = self.queryToExpandEntry(entry) parent_item = self.getItemFromEntry(entry) parent_entry = entry @@ -396,9 +428,11 @@ def expand(self, entry: Artist, force_reload: bool = True): for row, entry in enumerate(query): entry.tree_parent = parent_entry if 1: - if (isinstance(entry, mpl.spines.Spine) or - isinstance(entry, mpl.axis.XAxis) or - isinstance(entry, mpl.axis.YAxis)): + if ( + isinstance(entry, mpl.spines.Spine) + or isinstance(entry, mpl.axis.XAxis) + or isinstance(entry, mpl.axis.YAxis) + ): continue if isinstance(entry, mpl.text.Text) and entry.get_text() == "": continue @@ -416,7 +450,7 @@ def expand(self, entry: Artist, force_reload: bool = True): self.addChild(parent_item, entry) def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None): - """ add a child to a tree view node """ + """add a child to a tree view node""" if parent_item is None: parent_item = self.model @@ -442,7 +476,9 @@ def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None): self.setItemForEntry(entry, item) # add dummy child - if self.queryToExpandEntry(entry) is not None and len(self.queryToExpandEntry(entry)): + if self.queryToExpandEntry(entry) is not None and len( + self.queryToExpandEntry(entry) + ): child = QtGui.QStandardItem("loading") child.entry = None child.setEditable(False) @@ -452,7 +488,7 @@ def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None): return item def TreeExpand(self, index): - """ expand a tree view node """ + """expand a tree view node""" # Get item and entry item = index.model().itemFromIndex(index) entry = item.entry @@ -468,8 +504,14 @@ def TreeExpand(self, index): thread.setDaemon(True) thread.start() - def updateEntry(self, entry: Artist, update_children: bool = False, insert_before: Artist = None, insert_after: Artist = None): - """ update a tree view node """ + def updateEntry( + self, + entry: Artist, + update_children: bool = False, + insert_before: Artist = None, + insert_after: Artist = None, + ): + """update a tree view node""" # get the tree view item for the database entry item = self.getItemFromEntry(entry) # if we haven't one yet, we have to create it @@ -537,7 +579,7 @@ def updateEntry(self, entry: Artist, update_children: bool = False, insert_befor self.expand(entry, force_reload=True) def deleteEntry(self, entry: Artist): - """ delete an entry from the tree """ + """delete an entry from the tree""" item = self.getItemFromEntry(entry) if item is None: return @@ -562,24 +604,29 @@ def deleteEntry(self, entry: Artist): class InfoDialog(QtWidgets.QWidget): def __init__(self, parent): - """ A dialog displaying the version number of pylustrator. + """A dialog displaying the version number of pylustrator. Args: parent: the parent widget """ QtWidgets.QWidget.__init__(self) self.setWindowTitle("Pylustrator - Info") - self.setWindowIcon(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico"))) + self.setWindowIcon( + QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico")) + ) self.layout = QtWidgets.QVBoxLayout(self) self.label = QtWidgets.QLabel("") - pixmap = QtGui.QPixmap(os.path.join(os.path.dirname(__file__), "icons", "logo.png")) + pixmap = QtGui.QPixmap( + os.path.join(os.path.dirname(__file__), "icons", "logo.png") + ) self.label.setPixmap(pixmap) self.label.setMask(pixmap.mask()) self.layout.addWidget(self.label) import pylustrator + self.label = QtWidgets.QLabel("Version " + pylustrator.__version__ + "") font = self.label.font() font.setPointSize(16) @@ -591,7 +638,9 @@ def __init__(self, parent): self.label.setAlignment(QtCore.Qt.AlignCenter) self.layout.addWidget(self.label) - self.label = QtWidgets.QLabel("Documentation") + self.label = QtWidgets.QLabel( + "Documentation" + ) self.label.setAlignment(QtCore.Qt.AlignCenter) self.label.setTextInteractionFlags(QtCore.Qt.TextBrowserInteraction) self.label.setOpenExternalLinks(True) @@ -600,7 +649,7 @@ def __init__(self, parent): class Align(QtWidgets.QWidget): def __init__(self, layout: QtWidgets.QLayout, fig: Figure): - """ A widget that allows to align the elements of a multi selection. + """A widget that allows to align the elements of a multi selection. Args: layout: the layout to which to add the widget @@ -613,13 +662,36 @@ def __init__(self, layout: QtWidgets.QLayout, fig: Figure): self.layout = QtWidgets.QHBoxLayout(self) self.layout.setContentsMargins(0, 0, 0, 0) - actions = ["left_x", "center_x", "right_x", "distribute_x", "top_y", "center_y", "bottom_y", "distribute_y", "group"] - icons = ["left_x.png", "center_x.png", "right_x.png", "distribute_x.png", "top_y.png", "center_y.png", - "bottom_y.png", "distribute_y.png", "group.png"] + actions = [ + "left_x", + "center_x", + "right_x", + "distribute_x", + "top_y", + "center_y", + "bottom_y", + "distribute_y", + "group", + ] + icons = [ + "left_x.png", + "center_x.png", + "right_x.png", + "distribute_x.png", + "top_y.png", + "center_y.png", + "bottom_y.png", + "distribute_y.png", + "group.png", + ] self.buttons = [] for index, act in enumerate(actions): - button = QtWidgets.QPushButton(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", icons[index])), - "") + button = QtWidgets.QPushButton( + QtGui.QIcon( + os.path.join(os.path.dirname(__file__), "icons", icons[index]) + ), + "", + ) self.layout.addWidget(button) button.clicked.connect(lambda x, act=act: self.execute_action(act)) self.buttons.append(button) @@ -631,7 +703,7 @@ def __init__(self, layout: QtWidgets.QLayout, fig: Figure): self.layout.addStretch() def execute_action(self, act: str): - """ execute an alignment action """ + """execute an alignment action""" self.fig.selection.align_points(act) self.fig.selection.update_selection_rectangles() self.fig.canvas.draw() @@ -640,8 +712,8 @@ def execute_action(self, act: str): class PlotWindow(QtWidgets.QWidget): fitted_to_view = False - def __init__(self, number: int, size: tuple, output_file: str, placeholder: str): - """ The main window of pylustrator + def __init__(self, number: int, size: tuple, output_file: str, reqd_code: list): + """The main window of pylustrator Args: number: the id of the figure @@ -649,13 +721,15 @@ def __init__(self, number: int, size: tuple, output_file: str, placeholder: str) output_file: destination for generated code. Defaults to the source file. """ self.output_file = output_file - self.placeholder = placeholder + self.reqd_code = reqd_code QtWidgets.QWidget.__init__(self) self.canvas_canvas = QtWidgets.QWidget() self.canvas_canvas.setMinimumHeight(400) self.canvas_canvas.setMinimumWidth(400) - self.canvas_canvas.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) + self.canvas_canvas.setSizePolicy( + QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding + ) self.canvas_canvas.setStyleSheet("background:white") self.canvas_canvas.setFocusPolicy(QtCore.Qt.StrongFocus) @@ -689,7 +763,9 @@ def mouseRelease(event): # widget layout and elements self.setWindowTitle("Figure %s - Pylustrator" % number) - self.setWindowIcon(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico"))) + self.setWindowIcon( + QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico")) + ) layout_parent = QtWidgets.QVBoxLayout(self) self.menuBar = QtWidgets.QMenuBar() @@ -744,7 +820,9 @@ def mouseRelease(event): self.treeView = MyTreeView(self, self.layout_tools, self.fig) self.treeView.item_selected = self.elementSelected - self.input_properties = QItemProperties(self.layout_tools, self.fig, self.treeView, self) + self.input_properties = QItemProperties( + self.layout_tools, self.fig, self.treeView, self + ) self.input_align = Align(self.layout_tools, self.fig) # add plot layout @@ -760,14 +838,14 @@ def mouseRelease(event): self.fig.canvas.mpl_disconnect(self.fig.canvas.manager.key_press_handler_id) - self.fig.canvas.mpl_connect('scroll_event', self.scroll_event) - self.fig.canvas.mpl_connect('key_press_event', self.canvas_key_press) - self.fig.canvas.mpl_connect('key_release_event', self.canvas_key_release) + self.fig.canvas.mpl_connect("scroll_event", self.scroll_event) + self.fig.canvas.mpl_connect("key_press_event", self.canvas_key_press) + self.fig.canvas.mpl_connect("key_release_event", self.canvas_key_release) self.control_modifier = False - self.fig.canvas.mpl_connect('button_press_event', self.button_press_event) - self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_move_event) - self.fig.canvas.mpl_connect('button_release_event', self.button_release_event) + self.fig.canvas.mpl_connect("button_press_event", self.button_press_event) + self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_move_event) + self.fig.canvas.mpl_connect("button_release_event", self.button_release_event) self.drag = None self.footer_layout = QtWidgets.QHBoxLayout() @@ -782,12 +860,13 @@ def mouseRelease(event): self.footer_layout.addWidget(self.footer_label2) from .QtGui import ColorChooserWidget + self.colorWidget = ColorChooserWidget(self, self.canvas) self.colorWidget.setMaximumWidth(150) self.layout_main.addWidget(self.colorWidget) def rasterize(self, rasterize: bool): - """ convert the figur elements to an image """ + """convert the figur elements to an image""" if len(self.fig.selection.targets): self.fig.figure_dragger.select_element(None) if rasterize: @@ -799,15 +878,21 @@ def rasterize(self, rasterize: bool): self.fig.canvas.draw() def actionSave(self): - """ save the code for the figure """ - self.fig.change_tracker.save(self.output_file, self.placeholder) - for _last_saved_figure, args, kwargs in getattr(self.fig, "_last_saved_figure", []): + """save the code for the figure""" + self.fig.change_tracker.save(self.output_file, self.reqd_code) + for _last_saved_figure, args, kwargs in getattr( + self.fig, "_last_saved_figure", [] + ): self.fig.savefig(_last_saved_figure, *args, **kwargs) def actionSaveImage(self): - """ save figure as an image """ - path = QtWidgets.QFileDialog.getSaveFileName(self, "Save Image", getattr(self.fig, "_last_saved_figure", [(None,)])[0][0], - "Images (*.png *.jpg *.pdf)") + """save figure as an image""" + path = QtWidgets.QFileDialog.getSaveFileName( + self, + "Save Image", + getattr(self.fig, "_last_saved_figure", [(None,)])[0][0], + "Images (*.png *.jpg *.pdf)", + ) if isinstance(path, tuple): path = str(path[0]) else: @@ -821,13 +906,16 @@ def actionSaveImage(self): print("Saved plot image as", path) def showInfo(self): - """ show the info dialog """ + """show the info dialog""" self.info_dialog = InfoDialog(self) self.info_dialog.show() def updateRuler(self): - """ update the ruler around the figure to show the dimensions """ - trans = transforms.Affine2D().scale(1. / 2.54, 1. / 2.54) + self.fig.dpi_scale_trans + """update the ruler around the figure to show the dimensions""" + trans = ( + transforms.Affine2D().scale(1.0 / 2.54, 1.0 / 2.54) + + self.fig.dpi_scale_trans + ) l = 17 l1 = 13 l2 = 6 @@ -853,14 +941,19 @@ def updateRuler(self): end_x = np.ceil(trans.inverted().transform((-offset + w, 0))[0]) dx = 0.1 for i, pos_cm in enumerate(np.arange(start_x, end_x, dx)): - x = (trans.transform((pos_cm, 0))[0] + offset) + x = trans.transform((pos_cm, 0))[0] + offset if i % 10 == 0: painterX.drawLine(x, l - l1 - 1, x, l - 1) text = str("%d" % np.round(pos_cm)) o = 0 - painterX.drawText(x + 3, o, self.fontMetrics().width(text), o + self.fontMetrics().height(), - QtCore.Qt.AlignLeft, - text) + painterX.drawText( + x + 3, + o, + self.fontMetrics().width(text), + o + self.fontMetrics().height(), + QtCore.Qt.AlignLeft, + text, + ) elif i % 2 == 0: painterX.drawLine(x, l - l2 - 1, x, l - 1) else: @@ -877,14 +970,19 @@ def updateRuler(self): end_y = np.ceil(trans.inverted().transform((0, +offset))[1]) dy = 0.1 for i, pos_cm in enumerate(np.arange(start_y, end_y, dy)): - y = (-trans.transform((0, pos_cm))[1] + offset) + y = -trans.transform((0, pos_cm))[1] + offset if i % 10 == 0: painterY.drawLine(l - l1 - 1, y, l - 1, y) text = str("%d" % np.round(pos_cm)) o = 0 - painterY.drawText(o, y + 3, o + self.fontMetrics().width(text), self.fontMetrics().height(), - QtCore.Qt.AlignRight, - text) + painterY.drawText( + o, + y + 3, + o + self.fontMetrics().width(text), + self.fontMetrics().height(), + QtCore.Qt.AlignRight, + text, + ) elif i % 2 == 0: painterY.drawLine(l - l2 - 1, y, l - 1, y) else: @@ -919,33 +1017,38 @@ def updateRuler(self): self.shadow.setMaximumSize(w + 100, h + 10) def showEvent(self, event: QtCore.QEvent): - """ when the window is shown """ + """when the window is shown""" self.fitToView() self.updateRuler() self.colorWidget.updateColors() def resizeEvent(self, event: QtCore.QEvent): - """ when the window is resized """ + """when the window is resized""" if self.fitted_to_view: self.fitToView(True) else: self.updateRuler() def button_press_event(self, event: QtCore.QEvent): - """ when a mouse button is pressed """ + """when a mouse button is pressed""" if event.button == 2: self.drag = np.array([event.x, event.y]) def mouse_move_event(self, event: QtCore.QEvent): - """ when the mouse is moved """ + """when the mouse is moved""" if self.drag is not None: pos = np.array([event.x, event.y]) offset = pos - self.drag offset[1] = -offset[1] self.moveCanvasCanvas(*offset) - trans = transforms.Affine2D().scale(2.54, 2.54) + self.fig.dpi_scale_trans.inverted() + trans = ( + transforms.Affine2D().scale(2.54, 2.54) + + self.fig.dpi_scale_trans.inverted() + ) pos = trans.transform((event.x, event.y)) - self.footer_label.setText("%.2f, %.2f (cm) [%d, %d]" % (pos[0], pos[1], event.x, event.y)) + self.footer_label.setText( + "%.2f, %.2f (cm) [%d, %d]" % (pos[0], pos[1], event.x, event.y) + ) if event.ydata is not None: self.footer_label2.setText("%.2f, %.2f" % (event.xdata, event.ydata)) @@ -953,29 +1056,29 @@ def mouse_move_event(self, event: QtCore.QEvent): self.footer_label2.setText("") def button_release_event(self, event: QtCore.QEvent): - """ when the mouse button is released """ + """when the mouse button is released""" if event.button == 2: self.drag = None def canvas_key_press(self, event: QtCore.QEvent): - """ when a key in the canvas widget is pressed """ + """when a key in the canvas widget is pressed""" if event.key == "control": self.control_modifier = True def canvas_key_release(self, event: QtCore.QEvent): - """ when a key in the canvas widget is released """ + """when a key in the canvas widget is released""" if event.key == "control": self.control_modifier = False def moveCanvasCanvas(self, offset_x: float, offset_y: float): - """ when the canvas is panned """ + """when the canvas is panned""" p = self.canvas_container.pos() self.canvas_container.move(p.x() + offset_x, p.y() + offset_y) self.updateRuler() def keyPressEvent(self, event: QtCore.QEvent): - """ when a key is pressed """ + """when a key is pressed""" if event.key() == QtCore.Qt.Key_Control: self.control_modifier = True if event.key() == QtCore.Qt.Key_Left: @@ -991,11 +1094,14 @@ def keyPressEvent(self, event: QtCore.QEvent): self.fitToView(True) def fitToView(self, change_dpi: bool = False): - """ fit the figure to the view """ + """fit the figure to the view""" self.fitted_to_view = True if change_dpi: w, h = self.canvas.get_width_height() - factor = min((self.canvas_canvas.width() - 30) / w, (self.canvas_canvas.height() - 30) / h) + factor = min( + (self.canvas_canvas.width() - 30) / w, + (self.canvas_canvas.height() - 30) / h, + ) self.fig.set_dpi(self.fig.get_dpi() * factor) self.fig.canvas.draw() @@ -1004,8 +1110,10 @@ def fitToView(self, change_dpi: bool = False): self.canvas_container.setMinimumSize(w, h) self.canvas_container.setMaximumSize(w, h) - self.canvas_container.move((self.canvas_canvas.width() - w) / 2 + 5, - (self.canvas_canvas.height() - h) / 2 + 5) + self.canvas_container.move( + (self.canvas_canvas.width() - w) / 2 + 5, + (self.canvas_canvas.height() - h) / 2 + 5, + ) self.updateRuler() self.fig.canvas.draw() @@ -1015,17 +1123,19 @@ def fitToView(self, change_dpi: bool = False): self.canvas_canvas.setMinimumWidth(w + 30) self.canvas_canvas.setMinimumHeight(h + 30) - self.canvas_container.move((self.canvas_canvas.width() - w) / 2 + 5, - (self.canvas_canvas.height() - h) / 2 + 5) + self.canvas_container.move( + (self.canvas_canvas.width() - w) / 2 + 5, + (self.canvas_canvas.height() - h) / 2 + 5, + ) self.updateRuler() def keyReleaseEvent(self, event: QtCore.QEvent): - """ when a key is released """ + """when a key is released""" if event.key() == QtCore.Qt.Key_Control: self.control_modifier = False def scroll_event(self, event: QtCore.QEvent): - """ when the mouse wheel is used to zoom the figure """ + """when the mouse wheel is used to zoom the figure""" if self.control_modifier: new_dpi = self.fig.get_dpi() + 10 * event.step @@ -1052,22 +1162,22 @@ def scroll_event(self, event: QtCore.QEvent): bb = self.fig.axes[0].get_position() def updateFigureSize(self): - """ update the size of the figure """ + """update the size of the figure""" w, h = self.canvas.get_width_height() self.canvas_container.setMinimumSize(w, h) self.canvas_container.setMaximumSize(w, h) def changedFigureSize(self, size: tuple): - """ change the size of the figure """ + """change the size of the figure""" self.fig.set_size_inches(np.array(size) / 2.54) self.fig.canvas.draw() def elementSelected(self, element: Artist): - """ when an element is selected """ + """when an element is selected""" self.input_properties.setElement(element) def update(self): - """ update the tree view """ + """update the tree view""" # self.input_size.setValue(np.array(self.fig.get_size_inches())*2.54) self.treeView.deleteEntry(self.fig) self.treeView.expand(None) @@ -1096,14 +1206,14 @@ def newfunc(*args): self.treeView.setCurrentIndex(self.fig) def updateTitle(self): - """ update the title of the window to display if it is saved or not """ + """update the title of the window to display if it is saved or not""" if self.fig.change_tracker.saved: self.setWindowTitle("Figure %s - Pylustrator" % self.fig.number) else: self.setWindowTitle("Figure %s* - Pylustrator" % self.fig.number) def select_element(self, element: Artist): - """ select an element """ + """select an element""" if element is None: self.treeView.setCurrentIndex(self.fig) self.input_properties.setElement(self.fig) @@ -1112,12 +1222,18 @@ def select_element(self, element: Artist): self.input_properties.setElement(element) def closeEvent(self, event: QtCore.QEvent): - """ when the window is closed, ask the user to save """ + """when the window is closed, ask the user to save""" if not self.fig.change_tracker.saved: - reply = QtWidgets.QMessageBox.question(self, 'Warning - Pylustrator', 'The figure has not been saved. ' - 'All data will be lost.\nDo you want to save it?', - QtWidgets.QMessageBox.Cancel | QtWidgets.QMessageBox.No | QtWidgets.QMessageBox.Yes, - QtWidgets.QMessageBox.Yes) + reply = QtWidgets.QMessageBox.question( + self, + "Warning - Pylustrator", + "The figure has not been saved. " + "All data will be lost.\nDo you want to save it?", + QtWidgets.QMessageBox.Cancel + | QtWidgets.QMessageBox.No + | QtWidgets.QMessageBox.Yes, + QtWidgets.QMessageBox.Yes, + ) if reply == QtWidgets.QMessageBox.Cancel: event.ignore() diff --git a/pylustrator/change_tracker.py b/pylustrator/change_tracker.py index 730065d..366864c 100644 --- a/pylustrator/change_tracker.py +++ b/pylustrator/change_tracker.py @@ -23,8 +23,13 @@ import sys import traceback import os +import inspect +import warnings +import types from typing import IO +import numpy as np + import matplotlib import matplotlib as mpl import matplotlib.pyplot as plt @@ -43,12 +48,17 @@ """ External overload """ + + class CustomStackPosition: filename = None lineno = None + def __init__(self, filename, lineno): self.filename = filename self.lineno = lineno + + custom_stack_position = None custom_prepend = "" custom_append = "" @@ -57,20 +67,24 @@ def __init__(self, filename, lineno): ("\\", "\\\\"), ("\n", "\\n"), ("\r", "\\r"), - ("\"", "\\\""), + ('"', '\\"'), ] + + def escape_string(str): for pair in escape_pairs: str = str.replace(pair[0], pair[1]) return str + def unescape_string(str): for pair in escape_pairs: str = str.replace(pair[1], pair[0]) return str + def getReference(element: Artist, allow_using_variable_names=True): - """ get the code string that represents the given Artist. """ + """get the code string that represents the given Artist.""" if element is None: return "" if isinstance(element, Figure): @@ -81,7 +95,7 @@ def getReference(element: Artist, allow_using_variable_names=True): if isinstance(element.number, (float, int)): return "plt.figure(%s)" % element.number else: - return "plt.figure(\"%s\")" % element.number + return 'plt.figure("%s")' % element.number if isinstance(element, matplotlib.lines.Line2D): index = element.axes.lines.index(element) return getReference(element.axes) + ".lines[%d]" % index @@ -119,30 +133,56 @@ def getReference(element: Artist, allow_using_variable_names=True): for index, label in enumerate(axes.get_xaxis().get_major_ticks()): if element == label.label1: - return getReference(axes) + ".get_xaxis().get_major_ticks()[%d].label1" % index + return ( + getReference(axes) + + ".get_xaxis().get_major_ticks()[%d].label1" % index + ) if element == label.label2: - return getReference(axes) + ".get_xaxis().get_major_ticks()[%d].label2" % index + return ( + getReference(axes) + + ".get_xaxis().get_major_ticks()[%d].label2" % index + ) for index, label in enumerate(axes.get_xaxis().get_minor_ticks()): if element == label.label1: - return getReference(axes) + ".get_xaxis().get_minor_ticks()[%d].label1" % index + return ( + getReference(axes) + + ".get_xaxis().get_minor_ticks()[%d].label1" % index + ) if element == label.label2: - return getReference(axes) + ".get_xaxis().get_minor_ticks()[%d].label2" % index + return ( + getReference(axes) + + ".get_xaxis().get_minor_ticks()[%d].label2" % index + ) for axes in element.figure.axes: for index, label in enumerate(axes.get_yaxis().get_major_ticks()): if element == label.label1: - return getReference(axes) + ".get_yaxis().get_major_ticks()[%d].label1" % index + return ( + getReference(axes) + + ".get_yaxis().get_major_ticks()[%d].label1" % index + ) if element == label.label2: - return getReference(axes) + ".get_yaxis().get_major_ticks()[%d].label2" % index + return ( + getReference(axes) + + ".get_yaxis().get_major_ticks()[%d].label2" % index + ) for index, label in enumerate(axes.get_yaxis().get_minor_ticks()): if element == label.label1: - return getReference(axes) + ".get_yaxis().get_minor_ticks()[%d].label1" % index + return ( + getReference(axes) + + ".get_yaxis().get_minor_ticks()[%d].label1" % index + ) if element == label.label2: - return getReference(axes) + ".get_yaxis().get_minor_ticks()[%d].label2" % index + return ( + getReference(axes) + + ".get_yaxis().get_minor_ticks()[%d].label2" % index + ) if isinstance(element, matplotlib.axes._axes.Axes): if element.get_label(): - return getReference(element.figure) + ".ax_dict[\"%s\"]" % escape_string(element.get_label()) + return getReference(element.figure) + '.ax_dict["%s"]' % escape_string( + element.get_label() + ) return getReference(element.figure) + ".axes[%d]" % element.number if isinstance(element, matplotlib.legend.Legend): @@ -151,8 +191,9 @@ def getReference(element: Artist, allow_using_variable_names=True): def setFigureVariableNames(figure: Figure): - """ get the global variable names that refer to the given figure """ + """get the global variable names that refer to the given figure""" import inspect + mpl_figure = _pylab_helpers.Gcf.figs[figure].canvas.figure calling_globals = inspect.stack()[2][0].f_globals fig_names = [ @@ -167,7 +208,8 @@ def setFigureVariableNames(figure: Figure): class ChangeTracker: - """ a class that records a list of the change to the figure """ + """a class that records a list of the change to the figure""" + changes = None saved = True @@ -193,18 +235,24 @@ def __init__(self, figure: Figure): self.load() - def addChange(self, command_obj: Artist, command: str, reference_obj: Artist = None, reference_command: str = None): - """ add a change """ + def addChange( + self, + command_obj: Artist, + command: str, + reference_obj: Artist = None, + reference_command: str = None, + ): + """add a change""" command = command.replace("\n", "\\n") if reference_obj is None: reference_obj = command_obj if reference_command is None: - reference_command, = re.match(r"(\.[^(=]*)", command).groups() + (reference_command,) = re.match(r"(\.[^(=]*)", command).groups() self.changes[reference_obj, reference_command] = (command_obj, command) self.saved = False def removeElement(self, element: Artist): - """ remove an Artis from the figure """ + """remove an Artis from the figure""" # create_key = key+".new" created_by_pylustrator = (element, ".new") in self.changes # delete changes related to this element @@ -220,14 +268,14 @@ def removeElement(self, element: Artist): self.figure.selection.remove_target(element) def addEdit(self, edit: list): - """ add an edit to the stored list of edits """ + """add an edit to the stored list of edits""" if self.last_edit < len(self.edits) - 1: - self.edits = self.edits[:self.last_edit + 1] + self.edits = self.edits[: self.last_edit + 1] self.edits.append(edit) self.last_edit = len(self.edits) - 1 def backEdit(self): - """ undo an edit in the list """ + """undo an edit in the list""" if self.last_edit < 0: return edit = self.edits[self.last_edit] @@ -236,7 +284,7 @@ def backEdit(self): self.figure.canvas.draw() def forwardEdit(self): - """ redo an edit """ + """redo an edit""" if self.last_edit >= len(self.edits) - 1: return edit = self.edits[self.last_edit + 1] @@ -245,22 +293,23 @@ def forwardEdit(self): self.figure.canvas.draw() def load(self): - """ load a set of changes from a script file. The changes are the code that pylustrator generated """ + """load a set of changes from a script file. The changes are the code that pylustrator generated""" regex = re.compile(r"(\.[^\(= ]*)(.*)") - command_obj_regexes = [getReference(self.figure), - r"plt\.figure\([^)]*\)", - r"fig", - r"\.ax_dict\[\"[^\"]*\"\]", - r"\.axes\[\d*\]", - r"\.texts\[\d*\]", - r"\.{title|_left_title|_right_title}", - r"\.lines\[\d*\]", - r"\.collections\[\d*\]", - r"\.patches\[\d*\]", - r"\.get_[xy]axis\(\)\.get_(major|minor)_ticks\(\)\[\d*\]", - r"\.get_[xy]axis\(\)\.get_label\(\)", - r"\.get_legend\(\)", - ] + command_obj_regexes = [ + getReference(self.figure), + r"plt\.figure\([^)]*\)", + r"fig", + r"\.ax_dict\[\"[^\"]*\"\]", + r"\.axes\[\d*\]", + r"\.texts\[\d*\]", + r"\.{title|_left_title|_right_title}", + r"\.lines\[\d*\]", + r"\.collections\[\d*\]", + r"\.patches\[\d*\]", + r"\.get_[xy]axis\(\)\.get_(major|minor)_ticks\(\)\[\d*\]", + r"\.get_[xy]axis\(\)\.get_label\(\)", + r"\.get_legend\(\)", + ] command_obj_regexes = [re.compile(r) for r in command_obj_regexes] fig = self.figure @@ -272,7 +321,10 @@ def load(self): block = getTextFromFile(getReference(self.figure), stack_position) if not block: - block = getTextFromFile(getReference(self.figure, allow_using_variable_names=False), stack_position) + block = getTextFromFile( + getReference(self.figure, allow_using_variable_names=False), + stack_position, + ) for line in block: line = line.strip() if line == "" or line in header or line.startswith("#"): @@ -287,7 +339,7 @@ def load(self): for r in command_obj_regexes: try: found = r.match(line).group() - line = line[len(found):] + line = line[len(found) :] command_obj += found except AttributeError: pass @@ -307,7 +359,12 @@ def load(self): reference_obj = command_obj reference_command = command - if command == ".set_xticks" or command == ".set_yticks" or command == ".set_xlabels" or command == ".set_ylabels": + if ( + command == ".set_xticks" + or command == ".set_yticks" + or command == ".set_xlabels" + or command == ".set_ylabels" + ): if line.find("minor=True") != -1: reference_command = command + "_minor" @@ -322,17 +379,25 @@ def load(self): # if the reference object is just a dummy, we ignore it if isinstance(reference_obj, Dummy): - print("WARNING: line references a missing object, will remove line on save:", raw_line, file=sys.stderr) + print( + "WARNING: line references a missing object, will remove line on save:", + raw_line, + file=sys.stderr, + ) continue self.get_reference_cached[reference_obj] = reference_obj_str - #print("---", [reference_obj, reference_command], (command_obj, command + parameter)) - self.changes[reference_obj, reference_command] = (command_obj, command + parameter) + # print("---", [reference_obj, reference_command], (command_obj, command + parameter)) + self.changes[reference_obj, reference_command] = ( + command_obj, + command + parameter, + ) self.sorted_changes() def sorted_changes(self): - """ sort the changes by their priority. For example setting to logscale needs to be executed before xlim. """ + """sort the changes by their priority. For example setting to logscale needs to be executed before xlim.""" + def getRef(obj): try: return getReference(obj) @@ -352,21 +417,43 @@ def getRef(obj): if getattr(reference_obj, "axes", None) is not None: if reference_command == ".new": index = "0" - elif reference_command == ".set_xscale" or reference_command == ".set_yscale": + elif ( + reference_command == ".set_xscale" + or reference_command == ".set_yscale" + ): index = "1" - elif reference_command == ".set_xlim" or reference_command == ".set_ylim": + elif ( + reference_command == ".set_xlim" + or reference_command == ".set_ylim" + ): index = "2" - elif reference_command == ".set_xticks" or reference_command == ".set_yticks": + elif ( + reference_command == ".set_xticks" + or reference_command == ".set_yticks" + ): index = "3" - elif reference_command == ".set_xticklabels" or reference_command == ".set_yticklabels": + elif ( + reference_command == ".set_xticklabels" + or reference_command == ".set_yticklabels" + ): index = "4" else: index = "5" - obj_indices = (getRef(reference_obj.axes), getRef(reference_obj), index, reference_command) + obj_indices = ( + getRef(reference_obj.axes), + getRef(reference_obj), + index, + reference_command, + ) else: obj_indices = (getRef(reference_obj), "", "", reference_command) indices.append( - [(reference_obj, reference_command), self.changes[reference_obj, reference_command], obj_indices]) + [ + (reference_obj, reference_command), + self.changes[reference_obj, reference_command], + obj_indices, + ] + ) except (ValueError, TypeError) as err: print(err, file=sys.stderr) @@ -381,13 +468,20 @@ def getRef(obj): return output - def save(self, output_file, placeholder): - """ save the changes to the .py file """ - header = [getReference(self.figure) + ".ax_dict = {ax.get_label(): ax for ax in " + getReference( - self.figure) + ".axes}", "import matplotlib as mpl"] + def save(self, output_file, reqd_code): + """save the changes to the .py file""" + header = [ + getReference(self.figure) + + ".ax_dict = {ax.get_label(): ax for ax in " + + getReference(self.figure) + + ".axes}", + "import matplotlib as mpl", + ] # block = getTextFromFile(header[0], self.stack_position) - output = [custom_prepend + "#% start: automatic generated code from pylustrator"] + output = [ + custom_prepend + "#% start: automatic generated code from pylustrator" + ] # add the lines from the header for line in header: output.append(line) @@ -396,7 +490,9 @@ def save(self, output_file, placeholder): output.append(line) if line.startswith("fig.add_axes"): output.append(header[1]) - output.append("#% end: automatic generated code from pylustrator" + custom_append) + output.append( + "#% end: automatic generated code from pylustrator" + custom_append + ) # print("\n".join(output)) block_id = getReference(self.figure) @@ -404,27 +500,33 @@ def save(self, output_file, placeholder): if not block: block_id = getReference(self.figure, allow_using_variable_names=False) block = getTextFromFile(block_id, stack_position) - insertTextToFile(output, stack_position, block_id, output_file, placeholder) + insertTextToFile(output, stack_position, block_id, output_file, reqd_code) self.saved = True def getTextFromFile(block_id: str, stack_pos: traceback.FrameSummary): - """ get the text which corresponds to the block_id (e.g. which figure) at the given position sepcified by stack_pos. """ + """get the text which corresponds to the block_id (e.g. which figure) at the given position sepcified by stack_pos.""" block_id = lineToId(block_id) block = None if not custom_stack_position: - if not stack_pos.filename.endswith('.py') and not stack_pos.filename.startswith("