diff --git a/.gitignore b/.gitignore index d7503a8..90e8c00 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .idea/ +.vscode/* docs/_build_html/ */__pycache__/ *.pyc diff --git a/pylustrator/QtGuiDrag.py b/pylustrator/QtGuiDrag.py index 09932d2..27624bc 100644 --- a/pylustrator/QtGuiDrag.py +++ b/pylustrator/QtGuiDrag.py @@ -75,8 +75,8 @@ def initialize(use_global_variable_names=False): plt.figure = figure patchColormapsWithMetaInfo() - #stack_call_position = traceback.extract_stack()[-2] - #stack_call_position.filename + # stack_call_position = traceback.extract_stack()[-2] + # stack_call_position.filename plt.keys_for_lines = keys_for_lines @@ -84,7 +84,9 @@ def initialize(use_global_variable_names=False): sf = Figure.savefig def savefig(self, filename, *args, **kwargs): - self._last_saved_figure = getattr(self, "_last_saved_figure", []) + [(filename, args, kwargs)] + self._last_saved_figure = getattr(self, "_last_saved_figure", []) + [ + (filename, args, kwargs) + ] sf(self, filename, *args, **kwargs) Figure.savefig = savefig @@ -98,14 +100,15 @@ def savefig(self, filename, *args, **kwargs): def show(hide_window: bool = False): - """ the function overloads the matplotlib show function. + """the function overloads the matplotlib show function. It opens a DragManager window instead of the default matplotlib window. """ global figures # set an application id, so that windows properly stacks them in the task bar - if sys.platform[:3] == 'win': + if sys.platform[:3] == "win": import ctypes - myappid = 'rgerum.pylustrator' # arbitrary string + + myappid = "rgerum.pylustrator" # arbitrary string ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) # iterate over figures for figure in _pylab_helpers.Gcf.figs.copy(): @@ -131,7 +134,7 @@ def show(hide_window: bool = False): class CmapColor(list): - """ a color like object that has the colormap as metadata """ + """a color like object that has the colormap as metadata""" def setMeta(self, value, cmap): self.value = value @@ -139,7 +142,7 @@ def setMeta(self, value, cmap): def patchColormapsWithMetaInfo(): - """ all colormaps now return color with metadata from which colormap the color came from """ + """all colormaps now return color with metadata from which colormap the color came from""" from matplotlib.colors import Colormap cm_call = Colormap.__call__ @@ -154,8 +157,16 @@ def new_call(self, *args, **kwargs): Colormap.__call__ = new_call -def figure(num=None, figsize=None, force_add=False, *args, **kwargs): - """ overloads the matplotlib figure call and wrapps the Figure in a PlotWindow """ +def figure( + num=None, + figsize=None, + force_add=False, + output_file: str = "source", + reqd_code: list = [], + *args, + **kwargs +): + """overloads the matplotlib figure call and wrapps the Figure in a PlotWindow""" global figures # if num is not defined create a new number if num is None: @@ -163,7 +174,9 @@ def figure(num=None, figsize=None, force_add=False, *args, **kwargs): # if number is not defined if force_add or num not in _pylab_helpers.Gcf.figs.keys(): # create a new window and store it - canvas = PlotWindow(num, figsize, *args, **kwargs).canvas + canvas = PlotWindow( + num, figsize, output_file, reqd_code, *args, **kwargs + ).canvas canvas.figure.number = num canvas.figure.clf() canvas.manager.num = num @@ -172,7 +185,9 @@ def figure(num=None, figsize=None, force_add=False, *args, **kwargs): manager = _pylab_helpers.Gcf.figs[num] # set the size if it is defined if figsize is not None: - _pylab_helpers.Gcf.figs[num].window.setGeometry(100, 100, figsize[0] * 80, figsize[1] * 80) + _pylab_helpers.Gcf.figs[num].window.setGeometry( + 100, 100, figsize[0] * 80, figsize[1] * 80 + ) # set the figure as the active figure _pylab_helpers.Gcf.set_active(manager) # return the figure @@ -180,8 +195,9 @@ def figure(num=None, figsize=None, force_add=False, *args, **kwargs): def warnAboutTicks(fig): - """ warn if the tick labels and tick values do not match, to prevent users from accidently setting wrong tick values """ + """warn if the tick labels and tick values do not match, to prevent users from accidently setting wrong tick values""" import sys + for index, ax in enumerate(fig.axes): ticks = ax.get_yticks() labels = [t.get_text() for t in ax.get_yticklabels()] @@ -198,7 +214,14 @@ def warnAboutTicks(fig): ax_name = "#%d" % index else: ax_name = '"' + ax_name + '"' - print("Warning tick and label differ", t, l, "for axes", ax_name, file=sys.stderr) + print( + "Warning tick and label differ", + t, + l, + "for axes", + ax_name, + file=sys.stderr, + ) """ Window """ @@ -206,11 +229,11 @@ def warnAboutTicks(fig): class myTreeWidgetItem(QtGui.QStandardItem): def __init__(self, parent: QtWidgets.QWidget = None): - """ a tree view item to display the contents of the figure """ + """a tree view item to display the contents of the figure""" QtGui.QStandardItem.__init__(self, parent) def __lt__(self, otherItem: QtGui.QStandardItem): - """ how to sort the items """ + """how to sort the items""" if self.sort is None: return 0 return self.sort < otherItem.sort @@ -226,8 +249,10 @@ class MyTreeView(QtWidgets.QTreeView): last_selection = None last_hover = None - def __init__(self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Figure): - """ A tree view to display the contents of a figure + def __init__( + self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Figure + ): + """A tree view to display the contents of a figure Args: parent: the parent widget @@ -268,10 +293,17 @@ def __init__(self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Fi self.expand(None) - def selectionChanged(self, selection: QtCore.QItemSelection, y: QtCore.QItemSelection): - """ when the selection in the tree view changes """ + def selectionChanged( + self, selection: QtCore.QItemSelection, y: QtCore.QItemSelection + ): + """when the selection in the tree view changes""" try: - entry = selection.indexes()[0].model().itemFromIndex(selection.indexes()[0]).entry + entry = ( + selection.indexes()[0] + .model() + .itemFromIndex(selection.indexes()[0]) + .entry + ) except IndexError: entry = None if self.last_selection != entry: @@ -279,7 +311,7 @@ def selectionChanged(self, selection: QtCore.QItemSelection, y: QtCore.QItemSele self.item_selected(entry) def setCurrentIndex(self, entry: Artist): - """ set the currently selected entry """ + """set the currently selected entry""" while entry: item = self.getItemFromEntry(entry) if item is not None: @@ -291,17 +323,17 @@ def setCurrentIndex(self, entry: Artist): return def treeClicked(self, index: QtCore.QModelIndex): - """ upon selecting one of the tree elements """ + """upon selecting one of the tree elements""" data = index.model().itemFromIndex(index).entry return self.item_clicked(data) def treeActivated(self, index: QtCore.QModelIndex): - """ upon selecting one of the tree elements """ + """upon selecting one of the tree elements""" data = index.model().itemFromIndex(index).entry return self.item_activated(data) def eventFilter(self, object: QtWidgets.QWidget, event: QtCore.QEvent): - """ event filter for tree view port to handle mouse over events and marker highlighting""" + """event filter for tree view port to handle mouse over events and marker highlighting""" if event.type() == QtCore.QEvent.HoverMove: index = self.indexAt(event.pos()) try: @@ -328,24 +360,24 @@ def eventFilter(self, object: QtWidgets.QWidget, event: QtCore.QEvent): return False def queryToExpandEntry(self, entry: Artist) -> list: - """ when expanding a tree item """ + """when expanding a tree item""" if entry is None: return [self.fig] return entry.get_children() def getParentEntry(self, entry: Artist) -> Artist: - """ get the parent of an item """ + """get the parent of an item""" return entry.tree_parent def getNameOfEntry(self, entry: Artist) -> str: - """ convert an entry to a string """ + """convert an entry to a string""" try: return str(entry) except AttributeError: return "unknown" def getIconOfEntry(self, entry: Artist) -> QtGui.QIcon: - """ get the icon of an entry """ + """get the icon of an entry""" if getattr(entry, "_draggable", None): if entry._draggable.connected: return qta.icon("fa.hand-paper-o") @@ -355,11 +387,11 @@ def getEntrySortRole(self, entry: Artist): return None def getKey(self, entry: Artist) -> Artist: - """ get the key of an entry, which is the entry itself """ + """get the key of an entry, which is the entry itself""" return entry def getItemFromEntry(self, entry: Artist) -> Optional[QtWidgets.QTreeWidgetItem]: - """ get the tree view item for the given artist """ + """get the tree view item for the given artist""" if entry is None: return None key = self.getKey(entry) @@ -369,12 +401,12 @@ def getItemFromEntry(self, entry: Artist) -> Optional[QtWidgets.QTreeWidgetItem] return None def setItemForEntry(self, entry: Artist, item: QtWidgets.QTreeWidgetItem): - """ store a new artist and tree view widget pair """ + """store a new artist and tree view widget pair""" key = self.getKey(entry) self.item_lookup[key] = item def expand(self, entry: Artist, force_reload: bool = True): - """ expand the children of a tree view item """ + """expand the children of a tree view item""" query = self.queryToExpandEntry(entry) parent_item = self.getItemFromEntry(entry) parent_entry = entry @@ -396,9 +428,11 @@ def expand(self, entry: Artist, force_reload: bool = True): for row, entry in enumerate(query): entry.tree_parent = parent_entry if 1: - if (isinstance(entry, mpl.spines.Spine) or - isinstance(entry, mpl.axis.XAxis) or - isinstance(entry, mpl.axis.YAxis)): + if ( + isinstance(entry, mpl.spines.Spine) + or isinstance(entry, mpl.axis.XAxis) + or isinstance(entry, mpl.axis.YAxis) + ): continue if isinstance(entry, mpl.text.Text) and entry.get_text() == "": continue @@ -416,7 +450,7 @@ def expand(self, entry: Artist, force_reload: bool = True): self.addChild(parent_item, entry) def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None): - """ add a child to a tree view node """ + """add a child to a tree view node""" if parent_item is None: parent_item = self.model @@ -442,7 +476,9 @@ def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None): self.setItemForEntry(entry, item) # add dummy child - if self.queryToExpandEntry(entry) is not None and len(self.queryToExpandEntry(entry)): + if self.queryToExpandEntry(entry) is not None and len( + self.queryToExpandEntry(entry) + ): child = QtGui.QStandardItem("loading") child.entry = None child.setEditable(False) @@ -452,7 +488,7 @@ def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None): return item def TreeExpand(self, index): - """ expand a tree view node """ + """expand a tree view node""" # Get item and entry item = index.model().itemFromIndex(index) entry = item.entry @@ -468,8 +504,14 @@ def TreeExpand(self, index): thread.setDaemon(True) thread.start() - def updateEntry(self, entry: Artist, update_children: bool = False, insert_before: Artist = None, insert_after: Artist = None): - """ update a tree view node """ + def updateEntry( + self, + entry: Artist, + update_children: bool = False, + insert_before: Artist = None, + insert_after: Artist = None, + ): + """update a tree view node""" # get the tree view item for the database entry item = self.getItemFromEntry(entry) # if we haven't one yet, we have to create it @@ -537,7 +579,7 @@ def updateEntry(self, entry: Artist, update_children: bool = False, insert_befor self.expand(entry, force_reload=True) def deleteEntry(self, entry: Artist): - """ delete an entry from the tree """ + """delete an entry from the tree""" item = self.getItemFromEntry(entry) if item is None: return @@ -562,24 +604,29 @@ def deleteEntry(self, entry: Artist): class InfoDialog(QtWidgets.QWidget): def __init__(self, parent): - """ A dialog displaying the version number of pylustrator. + """A dialog displaying the version number of pylustrator. Args: parent: the parent widget """ QtWidgets.QWidget.__init__(self) self.setWindowTitle("Pylustrator - Info") - self.setWindowIcon(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico"))) + self.setWindowIcon( + QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico")) + ) self.layout = QtWidgets.QVBoxLayout(self) self.label = QtWidgets.QLabel("") - pixmap = QtGui.QPixmap(os.path.join(os.path.dirname(__file__), "icons", "logo.png")) + pixmap = QtGui.QPixmap( + os.path.join(os.path.dirname(__file__), "icons", "logo.png") + ) self.label.setPixmap(pixmap) self.label.setMask(pixmap.mask()) self.layout.addWidget(self.label) import pylustrator + self.label = QtWidgets.QLabel("Version " + pylustrator.__version__ + "") font = self.label.font() font.setPointSize(16) @@ -591,7 +638,9 @@ def __init__(self, parent): self.label.setAlignment(QtCore.Qt.AlignCenter) self.layout.addWidget(self.label) - self.label = QtWidgets.QLabel("Documentation") + self.label = QtWidgets.QLabel( + "Documentation" + ) self.label.setAlignment(QtCore.Qt.AlignCenter) self.label.setTextInteractionFlags(QtCore.Qt.TextBrowserInteraction) self.label.setOpenExternalLinks(True) @@ -600,7 +649,7 @@ def __init__(self, parent): class Align(QtWidgets.QWidget): def __init__(self, layout: QtWidgets.QLayout, fig: Figure): - """ A widget that allows to align the elements of a multi selection. + """A widget that allows to align the elements of a multi selection. Args: layout: the layout to which to add the widget @@ -613,13 +662,36 @@ def __init__(self, layout: QtWidgets.QLayout, fig: Figure): self.layout = QtWidgets.QHBoxLayout(self) self.layout.setContentsMargins(0, 0, 0, 0) - actions = ["left_x", "center_x", "right_x", "distribute_x", "top_y", "center_y", "bottom_y", "distribute_y", "group"] - icons = ["left_x.png", "center_x.png", "right_x.png", "distribute_x.png", "top_y.png", "center_y.png", - "bottom_y.png", "distribute_y.png", "group.png"] + actions = [ + "left_x", + "center_x", + "right_x", + "distribute_x", + "top_y", + "center_y", + "bottom_y", + "distribute_y", + "group", + ] + icons = [ + "left_x.png", + "center_x.png", + "right_x.png", + "distribute_x.png", + "top_y.png", + "center_y.png", + "bottom_y.png", + "distribute_y.png", + "group.png", + ] self.buttons = [] for index, act in enumerate(actions): - button = QtWidgets.QPushButton(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", icons[index])), - "") + button = QtWidgets.QPushButton( + QtGui.QIcon( + os.path.join(os.path.dirname(__file__), "icons", icons[index]) + ), + "", + ) self.layout.addWidget(button) button.clicked.connect(lambda x, act=act: self.execute_action(act)) self.buttons.append(button) @@ -631,7 +703,7 @@ def __init__(self, layout: QtWidgets.QLayout, fig: Figure): self.layout.addStretch() def execute_action(self, act: str): - """ execute an alignment action """ + """execute an alignment action""" self.fig.selection.align_points(act) self.fig.selection.update_selection_rectangles() self.fig.canvas.draw() @@ -640,19 +712,24 @@ def execute_action(self, act: str): class PlotWindow(QtWidgets.QWidget): fitted_to_view = False - def __init__(self, number: int, size: tuple): - """ The main window of pylustrator + def __init__(self, number: int, size: tuple, output_file: str, reqd_code: list): + """The main window of pylustrator Args: number: the id of the figure size: the size of the figure + output_file: destination for generated code. Defaults to the source file. """ + self.output_file = output_file + self.reqd_code = reqd_code QtWidgets.QWidget.__init__(self) self.canvas_canvas = QtWidgets.QWidget() self.canvas_canvas.setMinimumHeight(400) self.canvas_canvas.setMinimumWidth(400) - self.canvas_canvas.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) + self.canvas_canvas.setSizePolicy( + QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding + ) self.canvas_canvas.setStyleSheet("background:white") self.canvas_canvas.setFocusPolicy(QtCore.Qt.StrongFocus) @@ -686,7 +763,9 @@ def mouseRelease(event): # widget layout and elements self.setWindowTitle("Figure %s - Pylustrator" % number) - self.setWindowIcon(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico"))) + self.setWindowIcon( + QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico")) + ) layout_parent = QtWidgets.QVBoxLayout(self) self.menuBar = QtWidgets.QMenuBar() @@ -741,7 +820,9 @@ def mouseRelease(event): self.treeView = MyTreeView(self, self.layout_tools, self.fig) self.treeView.item_selected = self.elementSelected - self.input_properties = QItemProperties(self.layout_tools, self.fig, self.treeView, self) + self.input_properties = QItemProperties( + self.layout_tools, self.fig, self.treeView, self + ) self.input_align = Align(self.layout_tools, self.fig) # add plot layout @@ -757,14 +838,14 @@ def mouseRelease(event): self.fig.canvas.mpl_disconnect(self.fig.canvas.manager.key_press_handler_id) - self.fig.canvas.mpl_connect('scroll_event', self.scroll_event) - self.fig.canvas.mpl_connect('key_press_event', self.canvas_key_press) - self.fig.canvas.mpl_connect('key_release_event', self.canvas_key_release) + self.fig.canvas.mpl_connect("scroll_event", self.scroll_event) + self.fig.canvas.mpl_connect("key_press_event", self.canvas_key_press) + self.fig.canvas.mpl_connect("key_release_event", self.canvas_key_release) self.control_modifier = False - self.fig.canvas.mpl_connect('button_press_event', self.button_press_event) - self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_move_event) - self.fig.canvas.mpl_connect('button_release_event', self.button_release_event) + self.fig.canvas.mpl_connect("button_press_event", self.button_press_event) + self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_move_event) + self.fig.canvas.mpl_connect("button_release_event", self.button_release_event) self.drag = None self.footer_layout = QtWidgets.QHBoxLayout() @@ -779,12 +860,13 @@ def mouseRelease(event): self.footer_layout.addWidget(self.footer_label2) from .QtGui import ColorChooserWidget + self.colorWidget = ColorChooserWidget(self, self.canvas) self.colorWidget.setMaximumWidth(150) self.layout_main.addWidget(self.colorWidget) def rasterize(self, rasterize: bool): - """ convert the figur elements to an image """ + """convert the figur elements to an image""" if len(self.fig.selection.targets): self.fig.figure_dragger.select_element(None) if rasterize: @@ -796,15 +878,21 @@ def rasterize(self, rasterize: bool): self.fig.canvas.draw() def actionSave(self): - """ save the code for the figure """ - self.fig.change_tracker.save() - for _last_saved_figure, args, kwargs in getattr(self.fig, "_last_saved_figure", []): + """save the code for the figure""" + self.fig.change_tracker.save(self.output_file, self.reqd_code) + for _last_saved_figure, args, kwargs in getattr( + self.fig, "_last_saved_figure", [] + ): self.fig.savefig(_last_saved_figure, *args, **kwargs) def actionSaveImage(self): - """ save figure as an image """ - path = QtWidgets.QFileDialog.getSaveFileName(self, "Save Image", getattr(self.fig, "_last_saved_figure", [(None,)])[0][0], - "Images (*.png *.jpg *.pdf)") + """save figure as an image""" + path = QtWidgets.QFileDialog.getSaveFileName( + self, + "Save Image", + getattr(self.fig, "_last_saved_figure", [(None,)])[0][0], + "Images (*.png *.jpg *.pdf)", + ) if isinstance(path, tuple): path = str(path[0]) else: @@ -818,13 +906,16 @@ def actionSaveImage(self): print("Saved plot image as", path) def showInfo(self): - """ show the info dialog """ + """show the info dialog""" self.info_dialog = InfoDialog(self) self.info_dialog.show() def updateRuler(self): - """ update the ruler around the figure to show the dimensions """ - trans = transforms.Affine2D().scale(1. / 2.54, 1. / 2.54) + self.fig.dpi_scale_trans + """update the ruler around the figure to show the dimensions""" + trans = ( + transforms.Affine2D().scale(1.0 / 2.54, 1.0 / 2.54) + + self.fig.dpi_scale_trans + ) l = 17 l1 = 13 l2 = 6 @@ -850,14 +941,19 @@ def updateRuler(self): end_x = np.ceil(trans.inverted().transform((-offset + w, 0))[0]) dx = 0.1 for i, pos_cm in enumerate(np.arange(start_x, end_x, dx)): - x = (trans.transform((pos_cm, 0))[0] + offset) + x = trans.transform((pos_cm, 0))[0] + offset if i % 10 == 0: painterX.drawLine(x, l - l1 - 1, x, l - 1) text = str("%d" % np.round(pos_cm)) o = 0 - painterX.drawText(x + 3, o, self.fontMetrics().width(text), o + self.fontMetrics().height(), - QtCore.Qt.AlignLeft, - text) + painterX.drawText( + x + 3, + o, + self.fontMetrics().width(text), + o + self.fontMetrics().height(), + QtCore.Qt.AlignLeft, + text, + ) elif i % 2 == 0: painterX.drawLine(x, l - l2 - 1, x, l - 1) else: @@ -874,14 +970,19 @@ def updateRuler(self): end_y = np.ceil(trans.inverted().transform((0, +offset))[1]) dy = 0.1 for i, pos_cm in enumerate(np.arange(start_y, end_y, dy)): - y = (-trans.transform((0, pos_cm))[1] + offset) + y = -trans.transform((0, pos_cm))[1] + offset if i % 10 == 0: painterY.drawLine(l - l1 - 1, y, l - 1, y) text = str("%d" % np.round(pos_cm)) o = 0 - painterY.drawText(o, y + 3, o + self.fontMetrics().width(text), self.fontMetrics().height(), - QtCore.Qt.AlignRight, - text) + painterY.drawText( + o, + y + 3, + o + self.fontMetrics().width(text), + self.fontMetrics().height(), + QtCore.Qt.AlignRight, + text, + ) elif i % 2 == 0: painterY.drawLine(l - l2 - 1, y, l - 1, y) else: @@ -916,33 +1017,38 @@ def updateRuler(self): self.shadow.setMaximumSize(w + 100, h + 10) def showEvent(self, event: QtCore.QEvent): - """ when the window is shown """ + """when the window is shown""" self.fitToView() self.updateRuler() self.colorWidget.updateColors() def resizeEvent(self, event: QtCore.QEvent): - """ when the window is resized """ + """when the window is resized""" if self.fitted_to_view: self.fitToView(True) else: self.updateRuler() def button_press_event(self, event: QtCore.QEvent): - """ when a mouse button is pressed """ + """when a mouse button is pressed""" if event.button == 2: self.drag = np.array([event.x, event.y]) def mouse_move_event(self, event: QtCore.QEvent): - """ when the mouse is moved """ + """when the mouse is moved""" if self.drag is not None: pos = np.array([event.x, event.y]) offset = pos - self.drag offset[1] = -offset[1] self.moveCanvasCanvas(*offset) - trans = transforms.Affine2D().scale(2.54, 2.54) + self.fig.dpi_scale_trans.inverted() + trans = ( + transforms.Affine2D().scale(2.54, 2.54) + + self.fig.dpi_scale_trans.inverted() + ) pos = trans.transform((event.x, event.y)) - self.footer_label.setText("%.2f, %.2f (cm) [%d, %d]" % (pos[0], pos[1], event.x, event.y)) + self.footer_label.setText( + "%.2f, %.2f (cm) [%d, %d]" % (pos[0], pos[1], event.x, event.y) + ) if event.ydata is not None: self.footer_label2.setText("%.2f, %.2f" % (event.xdata, event.ydata)) @@ -950,29 +1056,29 @@ def mouse_move_event(self, event: QtCore.QEvent): self.footer_label2.setText("") def button_release_event(self, event: QtCore.QEvent): - """ when the mouse button is released """ + """when the mouse button is released""" if event.button == 2: self.drag = None def canvas_key_press(self, event: QtCore.QEvent): - """ when a key in the canvas widget is pressed """ + """when a key in the canvas widget is pressed""" if event.key == "control": self.control_modifier = True def canvas_key_release(self, event: QtCore.QEvent): - """ when a key in the canvas widget is released """ + """when a key in the canvas widget is released""" if event.key == "control": self.control_modifier = False def moveCanvasCanvas(self, offset_x: float, offset_y: float): - """ when the canvas is panned """ + """when the canvas is panned""" p = self.canvas_container.pos() self.canvas_container.move(p.x() + offset_x, p.y() + offset_y) self.updateRuler() def keyPressEvent(self, event: QtCore.QEvent): - """ when a key is pressed """ + """when a key is pressed""" if event.key() == QtCore.Qt.Key_Control: self.control_modifier = True if event.key() == QtCore.Qt.Key_Left: @@ -988,11 +1094,14 @@ def keyPressEvent(self, event: QtCore.QEvent): self.fitToView(True) def fitToView(self, change_dpi: bool = False): - """ fit the figure to the view """ + """fit the figure to the view""" self.fitted_to_view = True if change_dpi: w, h = self.canvas.get_width_height() - factor = min((self.canvas_canvas.width() - 30) / w, (self.canvas_canvas.height() - 30) / h) + factor = min( + (self.canvas_canvas.width() - 30) / w, + (self.canvas_canvas.height() - 30) / h, + ) self.fig.set_dpi(self.fig.get_dpi() * factor) self.fig.canvas.draw() @@ -1001,8 +1110,10 @@ def fitToView(self, change_dpi: bool = False): self.canvas_container.setMinimumSize(w, h) self.canvas_container.setMaximumSize(w, h) - self.canvas_container.move((self.canvas_canvas.width() - w) / 2 + 5, - (self.canvas_canvas.height() - h) / 2 + 5) + self.canvas_container.move( + (self.canvas_canvas.width() - w) / 2 + 5, + (self.canvas_canvas.height() - h) / 2 + 5, + ) self.updateRuler() self.fig.canvas.draw() @@ -1012,17 +1123,19 @@ def fitToView(self, change_dpi: bool = False): self.canvas_canvas.setMinimumWidth(w + 30) self.canvas_canvas.setMinimumHeight(h + 30) - self.canvas_container.move((self.canvas_canvas.width() - w) / 2 + 5, - (self.canvas_canvas.height() - h) / 2 + 5) + self.canvas_container.move( + (self.canvas_canvas.width() - w) / 2 + 5, + (self.canvas_canvas.height() - h) / 2 + 5, + ) self.updateRuler() def keyReleaseEvent(self, event: QtCore.QEvent): - """ when a key is released """ + """when a key is released""" if event.key() == QtCore.Qt.Key_Control: self.control_modifier = False def scroll_event(self, event: QtCore.QEvent): - """ when the mouse wheel is used to zoom the figure """ + """when the mouse wheel is used to zoom the figure""" if self.control_modifier: new_dpi = self.fig.get_dpi() + 10 * event.step @@ -1049,22 +1162,22 @@ def scroll_event(self, event: QtCore.QEvent): bb = self.fig.axes[0].get_position() def updateFigureSize(self): - """ update the size of the figure """ + """update the size of the figure""" w, h = self.canvas.get_width_height() self.canvas_container.setMinimumSize(w, h) self.canvas_container.setMaximumSize(w, h) def changedFigureSize(self, size: tuple): - """ change the size of the figure """ + """change the size of the figure""" self.fig.set_size_inches(np.array(size) / 2.54) self.fig.canvas.draw() def elementSelected(self, element: Artist): - """ when an element is selected """ + """when an element is selected""" self.input_properties.setElement(element) def update(self): - """ update the tree view """ + """update the tree view""" # self.input_size.setValue(np.array(self.fig.get_size_inches())*2.54) self.treeView.deleteEntry(self.fig) self.treeView.expand(None) @@ -1093,14 +1206,14 @@ def newfunc(*args): self.treeView.setCurrentIndex(self.fig) def updateTitle(self): - """ update the title of the window to display if it is saved or not """ + """update the title of the window to display if it is saved or not""" if self.fig.change_tracker.saved: self.setWindowTitle("Figure %s - Pylustrator" % self.fig.number) else: self.setWindowTitle("Figure %s* - Pylustrator" % self.fig.number) def select_element(self, element: Artist): - """ select an element """ + """select an element""" if element is None: self.treeView.setCurrentIndex(self.fig) self.input_properties.setElement(self.fig) @@ -1109,12 +1222,18 @@ def select_element(self, element: Artist): self.input_properties.setElement(element) def closeEvent(self, event: QtCore.QEvent): - """ when the window is closed, ask the user to save """ + """when the window is closed, ask the user to save""" if not self.fig.change_tracker.saved: - reply = QtWidgets.QMessageBox.question(self, 'Warning - Pylustrator', 'The figure has not been saved. ' - 'All data will be lost.\nDo you want to save it?', - QtWidgets.QMessageBox.Cancel | QtWidgets.QMessageBox.No | QtWidgets.QMessageBox.Yes, - QtWidgets.QMessageBox.Yes) + reply = QtWidgets.QMessageBox.question( + self, + "Warning - Pylustrator", + "The figure has not been saved. " + "All data will be lost.\nDo you want to save it?", + QtWidgets.QMessageBox.Cancel + | QtWidgets.QMessageBox.No + | QtWidgets.QMessageBox.Yes, + QtWidgets.QMessageBox.Yes, + ) if reply == QtWidgets.QMessageBox.Cancel: event.ignore() diff --git a/pylustrator/change_tracker.py b/pylustrator/change_tracker.py index 7c615b4..366864c 100644 --- a/pylustrator/change_tracker.py +++ b/pylustrator/change_tracker.py @@ -22,8 +22,14 @@ import re import sys import traceback +import os +import inspect +import warnings +import types from typing import IO +import numpy as np + import matplotlib import matplotlib as mpl import matplotlib.pyplot as plt @@ -42,12 +48,17 @@ """ External overload """ + + class CustomStackPosition: filename = None lineno = None + def __init__(self, filename, lineno): self.filename = filename self.lineno = lineno + + custom_stack_position = None custom_prepend = "" custom_append = "" @@ -56,20 +67,24 @@ def __init__(self, filename, lineno): ("\\", "\\\\"), ("\n", "\\n"), ("\r", "\\r"), - ("\"", "\\\""), + ('"', '\\"'), ] + + def escape_string(str): for pair in escape_pairs: str = str.replace(pair[0], pair[1]) return str + def unescape_string(str): for pair in escape_pairs: str = str.replace(pair[1], pair[0]) return str + def getReference(element: Artist, allow_using_variable_names=True): - """ get the code string that represents the given Artist. """ + """get the code string that represents the given Artist.""" if element is None: return "" if isinstance(element, Figure): @@ -80,7 +95,7 @@ def getReference(element: Artist, allow_using_variable_names=True): if isinstance(element.number, (float, int)): return "plt.figure(%s)" % element.number else: - return "plt.figure(\"%s\")" % element.number + return 'plt.figure("%s")' % element.number if isinstance(element, matplotlib.lines.Line2D): index = element.axes.lines.index(element) return getReference(element.axes) + ".lines[%d]" % index @@ -118,30 +133,56 @@ def getReference(element: Artist, allow_using_variable_names=True): for index, label in enumerate(axes.get_xaxis().get_major_ticks()): if element == label.label1: - return getReference(axes) + ".get_xaxis().get_major_ticks()[%d].label1" % index + return ( + getReference(axes) + + ".get_xaxis().get_major_ticks()[%d].label1" % index + ) if element == label.label2: - return getReference(axes) + ".get_xaxis().get_major_ticks()[%d].label2" % index + return ( + getReference(axes) + + ".get_xaxis().get_major_ticks()[%d].label2" % index + ) for index, label in enumerate(axes.get_xaxis().get_minor_ticks()): if element == label.label1: - return getReference(axes) + ".get_xaxis().get_minor_ticks()[%d].label1" % index + return ( + getReference(axes) + + ".get_xaxis().get_minor_ticks()[%d].label1" % index + ) if element == label.label2: - return getReference(axes) + ".get_xaxis().get_minor_ticks()[%d].label2" % index + return ( + getReference(axes) + + ".get_xaxis().get_minor_ticks()[%d].label2" % index + ) for axes in element.figure.axes: for index, label in enumerate(axes.get_yaxis().get_major_ticks()): if element == label.label1: - return getReference(axes) + ".get_yaxis().get_major_ticks()[%d].label1" % index + return ( + getReference(axes) + + ".get_yaxis().get_major_ticks()[%d].label1" % index + ) if element == label.label2: - return getReference(axes) + ".get_yaxis().get_major_ticks()[%d].label2" % index + return ( + getReference(axes) + + ".get_yaxis().get_major_ticks()[%d].label2" % index + ) for index, label in enumerate(axes.get_yaxis().get_minor_ticks()): if element == label.label1: - return getReference(axes) + ".get_yaxis().get_minor_ticks()[%d].label1" % index + return ( + getReference(axes) + + ".get_yaxis().get_minor_ticks()[%d].label1" % index + ) if element == label.label2: - return getReference(axes) + ".get_yaxis().get_minor_ticks()[%d].label2" % index + return ( + getReference(axes) + + ".get_yaxis().get_minor_ticks()[%d].label2" % index + ) if isinstance(element, matplotlib.axes._axes.Axes): if element.get_label(): - return getReference(element.figure) + ".ax_dict[\"%s\"]" % escape_string(element.get_label()) + return getReference(element.figure) + '.ax_dict["%s"]' % escape_string( + element.get_label() + ) return getReference(element.figure) + ".axes[%d]" % element.number if isinstance(element, matplotlib.legend.Legend): @@ -150,8 +191,9 @@ def getReference(element: Artist, allow_using_variable_names=True): def setFigureVariableNames(figure: Figure): - """ get the global variable names that refer to the given figure """ + """get the global variable names that refer to the given figure""" import inspect + mpl_figure = _pylab_helpers.Gcf.figs[figure].canvas.figure calling_globals = inspect.stack()[2][0].f_globals fig_names = [ @@ -166,7 +208,8 @@ def setFigureVariableNames(figure: Figure): class ChangeTracker: - """ a class that records a list of the change to the figure """ + """a class that records a list of the change to the figure""" + changes = None saved = True @@ -192,18 +235,24 @@ def __init__(self, figure: Figure): self.load() - def addChange(self, command_obj: Artist, command: str, reference_obj: Artist = None, reference_command: str = None): - """ add a change """ + def addChange( + self, + command_obj: Artist, + command: str, + reference_obj: Artist = None, + reference_command: str = None, + ): + """add a change""" command = command.replace("\n", "\\n") if reference_obj is None: reference_obj = command_obj if reference_command is None: - reference_command, = re.match(r"(\.[^(=]*)", command).groups() + (reference_command,) = re.match(r"(\.[^(=]*)", command).groups() self.changes[reference_obj, reference_command] = (command_obj, command) self.saved = False def removeElement(self, element: Artist): - """ remove an Artis from the figure """ + """remove an Artis from the figure""" # create_key = key+".new" created_by_pylustrator = (element, ".new") in self.changes # delete changes related to this element @@ -219,14 +268,14 @@ def removeElement(self, element: Artist): self.figure.selection.remove_target(element) def addEdit(self, edit: list): - """ add an edit to the stored list of edits """ + """add an edit to the stored list of edits""" if self.last_edit < len(self.edits) - 1: - self.edits = self.edits[:self.last_edit + 1] + self.edits = self.edits[: self.last_edit + 1] self.edits.append(edit) self.last_edit = len(self.edits) - 1 def backEdit(self): - """ undo an edit in the list """ + """undo an edit in the list""" if self.last_edit < 0: return edit = self.edits[self.last_edit] @@ -235,7 +284,7 @@ def backEdit(self): self.figure.canvas.draw() def forwardEdit(self): - """ redo an edit """ + """redo an edit""" if self.last_edit >= len(self.edits) - 1: return edit = self.edits[self.last_edit + 1] @@ -244,22 +293,23 @@ def forwardEdit(self): self.figure.canvas.draw() def load(self): - """ load a set of changes from a script file. The changes are the code that pylustrator generated """ + """load a set of changes from a script file. The changes are the code that pylustrator generated""" regex = re.compile(r"(\.[^\(= ]*)(.*)") - command_obj_regexes = [getReference(self.figure), - r"plt\.figure\([^)]*\)", - r"fig", - r"\.ax_dict\[\"[^\"]*\"\]", - r"\.axes\[\d*\]", - r"\.texts\[\d*\]", - r"\.{title|_left_title|_right_title}", - r"\.lines\[\d*\]", - r"\.collections\[\d*\]", - r"\.patches\[\d*\]", - r"\.get_[xy]axis\(\)\.get_(major|minor)_ticks\(\)\[\d*\]", - r"\.get_[xy]axis\(\)\.get_label\(\)", - r"\.get_legend\(\)", - ] + command_obj_regexes = [ + getReference(self.figure), + r"plt\.figure\([^)]*\)", + r"fig", + r"\.ax_dict\[\"[^\"]*\"\]", + r"\.axes\[\d*\]", + r"\.texts\[\d*\]", + r"\.{title|_left_title|_right_title}", + r"\.lines\[\d*\]", + r"\.collections\[\d*\]", + r"\.patches\[\d*\]", + r"\.get_[xy]axis\(\)\.get_(major|minor)_ticks\(\)\[\d*\]", + r"\.get_[xy]axis\(\)\.get_label\(\)", + r"\.get_legend\(\)", + ] command_obj_regexes = [re.compile(r) for r in command_obj_regexes] fig = self.figure @@ -271,7 +321,10 @@ def load(self): block = getTextFromFile(getReference(self.figure), stack_position) if not block: - block = getTextFromFile(getReference(self.figure, allow_using_variable_names=False), stack_position) + block = getTextFromFile( + getReference(self.figure, allow_using_variable_names=False), + stack_position, + ) for line in block: line = line.strip() if line == "" or line in header or line.startswith("#"): @@ -286,7 +339,7 @@ def load(self): for r in command_obj_regexes: try: found = r.match(line).group() - line = line[len(found):] + line = line[len(found) :] command_obj += found except AttributeError: pass @@ -306,7 +359,12 @@ def load(self): reference_obj = command_obj reference_command = command - if command == ".set_xticks" or command == ".set_yticks" or command == ".set_xlabels" or command == ".set_ylabels": + if ( + command == ".set_xticks" + or command == ".set_yticks" + or command == ".set_xlabels" + or command == ".set_ylabels" + ): if line.find("minor=True") != -1: reference_command = command + "_minor" @@ -321,17 +379,25 @@ def load(self): # if the reference object is just a dummy, we ignore it if isinstance(reference_obj, Dummy): - print("WARNING: line references a missing object, will remove line on save:", raw_line, file=sys.stderr) + print( + "WARNING: line references a missing object, will remove line on save:", + raw_line, + file=sys.stderr, + ) continue self.get_reference_cached[reference_obj] = reference_obj_str - #print("---", [reference_obj, reference_command], (command_obj, command + parameter)) - self.changes[reference_obj, reference_command] = (command_obj, command + parameter) + # print("---", [reference_obj, reference_command], (command_obj, command + parameter)) + self.changes[reference_obj, reference_command] = ( + command_obj, + command + parameter, + ) self.sorted_changes() def sorted_changes(self): - """ sort the changes by their priority. For example setting to logscale needs to be executed before xlim. """ + """sort the changes by their priority. For example setting to logscale needs to be executed before xlim.""" + def getRef(obj): try: return getReference(obj) @@ -351,21 +417,43 @@ def getRef(obj): if getattr(reference_obj, "axes", None) is not None: if reference_command == ".new": index = "0" - elif reference_command == ".set_xscale" or reference_command == ".set_yscale": + elif ( + reference_command == ".set_xscale" + or reference_command == ".set_yscale" + ): index = "1" - elif reference_command == ".set_xlim" or reference_command == ".set_ylim": + elif ( + reference_command == ".set_xlim" + or reference_command == ".set_ylim" + ): index = "2" - elif reference_command == ".set_xticks" or reference_command == ".set_yticks": + elif ( + reference_command == ".set_xticks" + or reference_command == ".set_yticks" + ): index = "3" - elif reference_command == ".set_xticklabels" or reference_command == ".set_yticklabels": + elif ( + reference_command == ".set_xticklabels" + or reference_command == ".set_yticklabels" + ): index = "4" else: index = "5" - obj_indices = (getRef(reference_obj.axes), getRef(reference_obj), index, reference_command) + obj_indices = ( + getRef(reference_obj.axes), + getRef(reference_obj), + index, + reference_command, + ) else: obj_indices = (getRef(reference_obj), "", "", reference_command) indices.append( - [(reference_obj, reference_command), self.changes[reference_obj, reference_command], obj_indices]) + [ + (reference_obj, reference_command), + self.changes[reference_obj, reference_command], + obj_indices, + ] + ) except (ValueError, TypeError) as err: print(err, file=sys.stderr) @@ -380,13 +468,20 @@ def getRef(obj): return output - def save(self): - """ save the changes to the .py file """ - header = [getReference(self.figure) + ".ax_dict = {ax.get_label(): ax for ax in " + getReference( - self.figure) + ".axes}", "import matplotlib as mpl"] + def save(self, output_file, reqd_code): + """save the changes to the .py file""" + header = [ + getReference(self.figure) + + ".ax_dict = {ax.get_label(): ax for ax in " + + getReference(self.figure) + + ".axes}", + "import matplotlib as mpl", + ] # block = getTextFromFile(header[0], self.stack_position) - output = [custom_prepend + "#% start: automatic generated code from pylustrator"] + output = [ + custom_prepend + "#% start: automatic generated code from pylustrator" + ] # add the lines from the header for line in header: output.append(line) @@ -395,7 +490,9 @@ def save(self): output.append(line) if line.startswith("fig.add_axes"): output.append(header[1]) - output.append("#% end: automatic generated code from pylustrator" + custom_append) + output.append( + "#% end: automatic generated code from pylustrator" + custom_append + ) # print("\n".join(output)) block_id = getReference(self.figure) @@ -403,27 +500,33 @@ def save(self): if not block: block_id = getReference(self.figure, allow_using_variable_names=False) block = getTextFromFile(block_id, stack_position) - insertTextToFile(output, stack_position, block_id) + insertTextToFile(output, stack_position, block_id, output_file, reqd_code) self.saved = True def getTextFromFile(block_id: str, stack_pos: traceback.FrameSummary): - """ get the text which corresponds to the block_id (e.g. which figure) at the given position sepcified by stack_pos. """ + """get the text which corresponds to the block_id (e.g. which figure) at the given position sepcified by stack_pos.""" block_id = lineToId(block_id) block = None if not custom_stack_position: - if not stack_pos.filename.endswith('.py') and not stack_pos.filename.startswith("