diff --git a/.gitignore b/.gitignore
index d7503a8..90e8c00 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
.idea/
+.vscode/*
docs/_build_html/
*/__pycache__/
*.pyc
diff --git a/pylustrator/QtGuiDrag.py b/pylustrator/QtGuiDrag.py
index 09932d2..27624bc 100644
--- a/pylustrator/QtGuiDrag.py
+++ b/pylustrator/QtGuiDrag.py
@@ -75,8 +75,8 @@ def initialize(use_global_variable_names=False):
plt.figure = figure
patchColormapsWithMetaInfo()
- #stack_call_position = traceback.extract_stack()[-2]
- #stack_call_position.filename
+ # stack_call_position = traceback.extract_stack()[-2]
+ # stack_call_position.filename
plt.keys_for_lines = keys_for_lines
@@ -84,7 +84,9 @@ def initialize(use_global_variable_names=False):
sf = Figure.savefig
def savefig(self, filename, *args, **kwargs):
- self._last_saved_figure = getattr(self, "_last_saved_figure", []) + [(filename, args, kwargs)]
+ self._last_saved_figure = getattr(self, "_last_saved_figure", []) + [
+ (filename, args, kwargs)
+ ]
sf(self, filename, *args, **kwargs)
Figure.savefig = savefig
@@ -98,14 +100,15 @@ def savefig(self, filename, *args, **kwargs):
def show(hide_window: bool = False):
- """ the function overloads the matplotlib show function.
+ """the function overloads the matplotlib show function.
It opens a DragManager window instead of the default matplotlib window.
"""
global figures
# set an application id, so that windows properly stacks them in the task bar
- if sys.platform[:3] == 'win':
+ if sys.platform[:3] == "win":
import ctypes
- myappid = 'rgerum.pylustrator' # arbitrary string
+
+ myappid = "rgerum.pylustrator" # arbitrary string
ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid)
# iterate over figures
for figure in _pylab_helpers.Gcf.figs.copy():
@@ -131,7 +134,7 @@ def show(hide_window: bool = False):
class CmapColor(list):
- """ a color like object that has the colormap as metadata """
+ """a color like object that has the colormap as metadata"""
def setMeta(self, value, cmap):
self.value = value
@@ -139,7 +142,7 @@ def setMeta(self, value, cmap):
def patchColormapsWithMetaInfo():
- """ all colormaps now return color with metadata from which colormap the color came from """
+ """all colormaps now return color with metadata from which colormap the color came from"""
from matplotlib.colors import Colormap
cm_call = Colormap.__call__
@@ -154,8 +157,16 @@ def new_call(self, *args, **kwargs):
Colormap.__call__ = new_call
-def figure(num=None, figsize=None, force_add=False, *args, **kwargs):
- """ overloads the matplotlib figure call and wrapps the Figure in a PlotWindow """
+def figure(
+ num=None,
+ figsize=None,
+ force_add=False,
+ output_file: str = "source",
+ reqd_code: list = [],
+ *args,
+ **kwargs
+):
+ """overloads the matplotlib figure call and wrapps the Figure in a PlotWindow"""
global figures
# if num is not defined create a new number
if num is None:
@@ -163,7 +174,9 @@ def figure(num=None, figsize=None, force_add=False, *args, **kwargs):
# if number is not defined
if force_add or num not in _pylab_helpers.Gcf.figs.keys():
# create a new window and store it
- canvas = PlotWindow(num, figsize, *args, **kwargs).canvas
+ canvas = PlotWindow(
+ num, figsize, output_file, reqd_code, *args, **kwargs
+ ).canvas
canvas.figure.number = num
canvas.figure.clf()
canvas.manager.num = num
@@ -172,7 +185,9 @@ def figure(num=None, figsize=None, force_add=False, *args, **kwargs):
manager = _pylab_helpers.Gcf.figs[num]
# set the size if it is defined
if figsize is not None:
- _pylab_helpers.Gcf.figs[num].window.setGeometry(100, 100, figsize[0] * 80, figsize[1] * 80)
+ _pylab_helpers.Gcf.figs[num].window.setGeometry(
+ 100, 100, figsize[0] * 80, figsize[1] * 80
+ )
# set the figure as the active figure
_pylab_helpers.Gcf.set_active(manager)
# return the figure
@@ -180,8 +195,9 @@ def figure(num=None, figsize=None, force_add=False, *args, **kwargs):
def warnAboutTicks(fig):
- """ warn if the tick labels and tick values do not match, to prevent users from accidently setting wrong tick values """
+ """warn if the tick labels and tick values do not match, to prevent users from accidently setting wrong tick values"""
import sys
+
for index, ax in enumerate(fig.axes):
ticks = ax.get_yticks()
labels = [t.get_text() for t in ax.get_yticklabels()]
@@ -198,7 +214,14 @@ def warnAboutTicks(fig):
ax_name = "#%d" % index
else:
ax_name = '"' + ax_name + '"'
- print("Warning tick and label differ", t, l, "for axes", ax_name, file=sys.stderr)
+ print(
+ "Warning tick and label differ",
+ t,
+ l,
+ "for axes",
+ ax_name,
+ file=sys.stderr,
+ )
""" Window """
@@ -206,11 +229,11 @@ def warnAboutTicks(fig):
class myTreeWidgetItem(QtGui.QStandardItem):
def __init__(self, parent: QtWidgets.QWidget = None):
- """ a tree view item to display the contents of the figure """
+ """a tree view item to display the contents of the figure"""
QtGui.QStandardItem.__init__(self, parent)
def __lt__(self, otherItem: QtGui.QStandardItem):
- """ how to sort the items """
+ """how to sort the items"""
if self.sort is None:
return 0
return self.sort < otherItem.sort
@@ -226,8 +249,10 @@ class MyTreeView(QtWidgets.QTreeView):
last_selection = None
last_hover = None
- def __init__(self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Figure):
- """ A tree view to display the contents of a figure
+ def __init__(
+ self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Figure
+ ):
+ """A tree view to display the contents of a figure
Args:
parent: the parent widget
@@ -268,10 +293,17 @@ def __init__(self, parent: QtWidgets.QWidget, layout: QtWidgets.QLayout, fig: Fi
self.expand(None)
- def selectionChanged(self, selection: QtCore.QItemSelection, y: QtCore.QItemSelection):
- """ when the selection in the tree view changes """
+ def selectionChanged(
+ self, selection: QtCore.QItemSelection, y: QtCore.QItemSelection
+ ):
+ """when the selection in the tree view changes"""
try:
- entry = selection.indexes()[0].model().itemFromIndex(selection.indexes()[0]).entry
+ entry = (
+ selection.indexes()[0]
+ .model()
+ .itemFromIndex(selection.indexes()[0])
+ .entry
+ )
except IndexError:
entry = None
if self.last_selection != entry:
@@ -279,7 +311,7 @@ def selectionChanged(self, selection: QtCore.QItemSelection, y: QtCore.QItemSele
self.item_selected(entry)
def setCurrentIndex(self, entry: Artist):
- """ set the currently selected entry """
+ """set the currently selected entry"""
while entry:
item = self.getItemFromEntry(entry)
if item is not None:
@@ -291,17 +323,17 @@ def setCurrentIndex(self, entry: Artist):
return
def treeClicked(self, index: QtCore.QModelIndex):
- """ upon selecting one of the tree elements """
+ """upon selecting one of the tree elements"""
data = index.model().itemFromIndex(index).entry
return self.item_clicked(data)
def treeActivated(self, index: QtCore.QModelIndex):
- """ upon selecting one of the tree elements """
+ """upon selecting one of the tree elements"""
data = index.model().itemFromIndex(index).entry
return self.item_activated(data)
def eventFilter(self, object: QtWidgets.QWidget, event: QtCore.QEvent):
- """ event filter for tree view port to handle mouse over events and marker highlighting"""
+ """event filter for tree view port to handle mouse over events and marker highlighting"""
if event.type() == QtCore.QEvent.HoverMove:
index = self.indexAt(event.pos())
try:
@@ -328,24 +360,24 @@ def eventFilter(self, object: QtWidgets.QWidget, event: QtCore.QEvent):
return False
def queryToExpandEntry(self, entry: Artist) -> list:
- """ when expanding a tree item """
+ """when expanding a tree item"""
if entry is None:
return [self.fig]
return entry.get_children()
def getParentEntry(self, entry: Artist) -> Artist:
- """ get the parent of an item """
+ """get the parent of an item"""
return entry.tree_parent
def getNameOfEntry(self, entry: Artist) -> str:
- """ convert an entry to a string """
+ """convert an entry to a string"""
try:
return str(entry)
except AttributeError:
return "unknown"
def getIconOfEntry(self, entry: Artist) -> QtGui.QIcon:
- """ get the icon of an entry """
+ """get the icon of an entry"""
if getattr(entry, "_draggable", None):
if entry._draggable.connected:
return qta.icon("fa.hand-paper-o")
@@ -355,11 +387,11 @@ def getEntrySortRole(self, entry: Artist):
return None
def getKey(self, entry: Artist) -> Artist:
- """ get the key of an entry, which is the entry itself """
+ """get the key of an entry, which is the entry itself"""
return entry
def getItemFromEntry(self, entry: Artist) -> Optional[QtWidgets.QTreeWidgetItem]:
- """ get the tree view item for the given artist """
+ """get the tree view item for the given artist"""
if entry is None:
return None
key = self.getKey(entry)
@@ -369,12 +401,12 @@ def getItemFromEntry(self, entry: Artist) -> Optional[QtWidgets.QTreeWidgetItem]
return None
def setItemForEntry(self, entry: Artist, item: QtWidgets.QTreeWidgetItem):
- """ store a new artist and tree view widget pair """
+ """store a new artist and tree view widget pair"""
key = self.getKey(entry)
self.item_lookup[key] = item
def expand(self, entry: Artist, force_reload: bool = True):
- """ expand the children of a tree view item """
+ """expand the children of a tree view item"""
query = self.queryToExpandEntry(entry)
parent_item = self.getItemFromEntry(entry)
parent_entry = entry
@@ -396,9 +428,11 @@ def expand(self, entry: Artist, force_reload: bool = True):
for row, entry in enumerate(query):
entry.tree_parent = parent_entry
if 1:
- if (isinstance(entry, mpl.spines.Spine) or
- isinstance(entry, mpl.axis.XAxis) or
- isinstance(entry, mpl.axis.YAxis)):
+ if (
+ isinstance(entry, mpl.spines.Spine)
+ or isinstance(entry, mpl.axis.XAxis)
+ or isinstance(entry, mpl.axis.YAxis)
+ ):
continue
if isinstance(entry, mpl.text.Text) and entry.get_text() == "":
continue
@@ -416,7 +450,7 @@ def expand(self, entry: Artist, force_reload: bool = True):
self.addChild(parent_item, entry)
def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None):
- """ add a child to a tree view node """
+ """add a child to a tree view node"""
if parent_item is None:
parent_item = self.model
@@ -442,7 +476,9 @@ def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None):
self.setItemForEntry(entry, item)
# add dummy child
- if self.queryToExpandEntry(entry) is not None and len(self.queryToExpandEntry(entry)):
+ if self.queryToExpandEntry(entry) is not None and len(
+ self.queryToExpandEntry(entry)
+ ):
child = QtGui.QStandardItem("loading")
child.entry = None
child.setEditable(False)
@@ -452,7 +488,7 @@ def addChild(self, parent_item: QtWidgets.QWidget, entry: Artist, row=None):
return item
def TreeExpand(self, index):
- """ expand a tree view node """
+ """expand a tree view node"""
# Get item and entry
item = index.model().itemFromIndex(index)
entry = item.entry
@@ -468,8 +504,14 @@ def TreeExpand(self, index):
thread.setDaemon(True)
thread.start()
- def updateEntry(self, entry: Artist, update_children: bool = False, insert_before: Artist = None, insert_after: Artist = None):
- """ update a tree view node """
+ def updateEntry(
+ self,
+ entry: Artist,
+ update_children: bool = False,
+ insert_before: Artist = None,
+ insert_after: Artist = None,
+ ):
+ """update a tree view node"""
# get the tree view item for the database entry
item = self.getItemFromEntry(entry)
# if we haven't one yet, we have to create it
@@ -537,7 +579,7 @@ def updateEntry(self, entry: Artist, update_children: bool = False, insert_befor
self.expand(entry, force_reload=True)
def deleteEntry(self, entry: Artist):
- """ delete an entry from the tree """
+ """delete an entry from the tree"""
item = self.getItemFromEntry(entry)
if item is None:
return
@@ -562,24 +604,29 @@ def deleteEntry(self, entry: Artist):
class InfoDialog(QtWidgets.QWidget):
def __init__(self, parent):
- """ A dialog displaying the version number of pylustrator.
+ """A dialog displaying the version number of pylustrator.
Args:
parent: the parent widget
"""
QtWidgets.QWidget.__init__(self)
self.setWindowTitle("Pylustrator - Info")
- self.setWindowIcon(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico")))
+ self.setWindowIcon(
+ QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico"))
+ )
self.layout = QtWidgets.QVBoxLayout(self)
self.label = QtWidgets.QLabel("")
- pixmap = QtGui.QPixmap(os.path.join(os.path.dirname(__file__), "icons", "logo.png"))
+ pixmap = QtGui.QPixmap(
+ os.path.join(os.path.dirname(__file__), "icons", "logo.png")
+ )
self.label.setPixmap(pixmap)
self.label.setMask(pixmap.mask())
self.layout.addWidget(self.label)
import pylustrator
+
self.label = QtWidgets.QLabel("Version " + pylustrator.__version__ + "")
font = self.label.font()
font.setPointSize(16)
@@ -591,7 +638,9 @@ def __init__(self, parent):
self.label.setAlignment(QtCore.Qt.AlignCenter)
self.layout.addWidget(self.label)
- self.label = QtWidgets.QLabel("Documentation")
+ self.label = QtWidgets.QLabel(
+ "Documentation"
+ )
self.label.setAlignment(QtCore.Qt.AlignCenter)
self.label.setTextInteractionFlags(QtCore.Qt.TextBrowserInteraction)
self.label.setOpenExternalLinks(True)
@@ -600,7 +649,7 @@ def __init__(self, parent):
class Align(QtWidgets.QWidget):
def __init__(self, layout: QtWidgets.QLayout, fig: Figure):
- """ A widget that allows to align the elements of a multi selection.
+ """A widget that allows to align the elements of a multi selection.
Args:
layout: the layout to which to add the widget
@@ -613,13 +662,36 @@ def __init__(self, layout: QtWidgets.QLayout, fig: Figure):
self.layout = QtWidgets.QHBoxLayout(self)
self.layout.setContentsMargins(0, 0, 0, 0)
- actions = ["left_x", "center_x", "right_x", "distribute_x", "top_y", "center_y", "bottom_y", "distribute_y", "group"]
- icons = ["left_x.png", "center_x.png", "right_x.png", "distribute_x.png", "top_y.png", "center_y.png",
- "bottom_y.png", "distribute_y.png", "group.png"]
+ actions = [
+ "left_x",
+ "center_x",
+ "right_x",
+ "distribute_x",
+ "top_y",
+ "center_y",
+ "bottom_y",
+ "distribute_y",
+ "group",
+ ]
+ icons = [
+ "left_x.png",
+ "center_x.png",
+ "right_x.png",
+ "distribute_x.png",
+ "top_y.png",
+ "center_y.png",
+ "bottom_y.png",
+ "distribute_y.png",
+ "group.png",
+ ]
self.buttons = []
for index, act in enumerate(actions):
- button = QtWidgets.QPushButton(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", icons[index])),
- "")
+ button = QtWidgets.QPushButton(
+ QtGui.QIcon(
+ os.path.join(os.path.dirname(__file__), "icons", icons[index])
+ ),
+ "",
+ )
self.layout.addWidget(button)
button.clicked.connect(lambda x, act=act: self.execute_action(act))
self.buttons.append(button)
@@ -631,7 +703,7 @@ def __init__(self, layout: QtWidgets.QLayout, fig: Figure):
self.layout.addStretch()
def execute_action(self, act: str):
- """ execute an alignment action """
+ """execute an alignment action"""
self.fig.selection.align_points(act)
self.fig.selection.update_selection_rectangles()
self.fig.canvas.draw()
@@ -640,19 +712,24 @@ def execute_action(self, act: str):
class PlotWindow(QtWidgets.QWidget):
fitted_to_view = False
- def __init__(self, number: int, size: tuple):
- """ The main window of pylustrator
+ def __init__(self, number: int, size: tuple, output_file: str, reqd_code: list):
+ """The main window of pylustrator
Args:
number: the id of the figure
size: the size of the figure
+ output_file: destination for generated code. Defaults to the source file.
"""
+ self.output_file = output_file
+ self.reqd_code = reqd_code
QtWidgets.QWidget.__init__(self)
self.canvas_canvas = QtWidgets.QWidget()
self.canvas_canvas.setMinimumHeight(400)
self.canvas_canvas.setMinimumWidth(400)
- self.canvas_canvas.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
+ self.canvas_canvas.setSizePolicy(
+ QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding
+ )
self.canvas_canvas.setStyleSheet("background:white")
self.canvas_canvas.setFocusPolicy(QtCore.Qt.StrongFocus)
@@ -686,7 +763,9 @@ def mouseRelease(event):
# widget layout and elements
self.setWindowTitle("Figure %s - Pylustrator" % number)
- self.setWindowIcon(QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico")))
+ self.setWindowIcon(
+ QtGui.QIcon(os.path.join(os.path.dirname(__file__), "icons", "logo.ico"))
+ )
layout_parent = QtWidgets.QVBoxLayout(self)
self.menuBar = QtWidgets.QMenuBar()
@@ -741,7 +820,9 @@ def mouseRelease(event):
self.treeView = MyTreeView(self, self.layout_tools, self.fig)
self.treeView.item_selected = self.elementSelected
- self.input_properties = QItemProperties(self.layout_tools, self.fig, self.treeView, self)
+ self.input_properties = QItemProperties(
+ self.layout_tools, self.fig, self.treeView, self
+ )
self.input_align = Align(self.layout_tools, self.fig)
# add plot layout
@@ -757,14 +838,14 @@ def mouseRelease(event):
self.fig.canvas.mpl_disconnect(self.fig.canvas.manager.key_press_handler_id)
- self.fig.canvas.mpl_connect('scroll_event', self.scroll_event)
- self.fig.canvas.mpl_connect('key_press_event', self.canvas_key_press)
- self.fig.canvas.mpl_connect('key_release_event', self.canvas_key_release)
+ self.fig.canvas.mpl_connect("scroll_event", self.scroll_event)
+ self.fig.canvas.mpl_connect("key_press_event", self.canvas_key_press)
+ self.fig.canvas.mpl_connect("key_release_event", self.canvas_key_release)
self.control_modifier = False
- self.fig.canvas.mpl_connect('button_press_event', self.button_press_event)
- self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_move_event)
- self.fig.canvas.mpl_connect('button_release_event', self.button_release_event)
+ self.fig.canvas.mpl_connect("button_press_event", self.button_press_event)
+ self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_move_event)
+ self.fig.canvas.mpl_connect("button_release_event", self.button_release_event)
self.drag = None
self.footer_layout = QtWidgets.QHBoxLayout()
@@ -779,12 +860,13 @@ def mouseRelease(event):
self.footer_layout.addWidget(self.footer_label2)
from .QtGui import ColorChooserWidget
+
self.colorWidget = ColorChooserWidget(self, self.canvas)
self.colorWidget.setMaximumWidth(150)
self.layout_main.addWidget(self.colorWidget)
def rasterize(self, rasterize: bool):
- """ convert the figur elements to an image """
+ """convert the figur elements to an image"""
if len(self.fig.selection.targets):
self.fig.figure_dragger.select_element(None)
if rasterize:
@@ -796,15 +878,21 @@ def rasterize(self, rasterize: bool):
self.fig.canvas.draw()
def actionSave(self):
- """ save the code for the figure """
- self.fig.change_tracker.save()
- for _last_saved_figure, args, kwargs in getattr(self.fig, "_last_saved_figure", []):
+ """save the code for the figure"""
+ self.fig.change_tracker.save(self.output_file, self.reqd_code)
+ for _last_saved_figure, args, kwargs in getattr(
+ self.fig, "_last_saved_figure", []
+ ):
self.fig.savefig(_last_saved_figure, *args, **kwargs)
def actionSaveImage(self):
- """ save figure as an image """
- path = QtWidgets.QFileDialog.getSaveFileName(self, "Save Image", getattr(self.fig, "_last_saved_figure", [(None,)])[0][0],
- "Images (*.png *.jpg *.pdf)")
+ """save figure as an image"""
+ path = QtWidgets.QFileDialog.getSaveFileName(
+ self,
+ "Save Image",
+ getattr(self.fig, "_last_saved_figure", [(None,)])[0][0],
+ "Images (*.png *.jpg *.pdf)",
+ )
if isinstance(path, tuple):
path = str(path[0])
else:
@@ -818,13 +906,16 @@ def actionSaveImage(self):
print("Saved plot image as", path)
def showInfo(self):
- """ show the info dialog """
+ """show the info dialog"""
self.info_dialog = InfoDialog(self)
self.info_dialog.show()
def updateRuler(self):
- """ update the ruler around the figure to show the dimensions """
- trans = transforms.Affine2D().scale(1. / 2.54, 1. / 2.54) + self.fig.dpi_scale_trans
+ """update the ruler around the figure to show the dimensions"""
+ trans = (
+ transforms.Affine2D().scale(1.0 / 2.54, 1.0 / 2.54)
+ + self.fig.dpi_scale_trans
+ )
l = 17
l1 = 13
l2 = 6
@@ -850,14 +941,19 @@ def updateRuler(self):
end_x = np.ceil(trans.inverted().transform((-offset + w, 0))[0])
dx = 0.1
for i, pos_cm in enumerate(np.arange(start_x, end_x, dx)):
- x = (trans.transform((pos_cm, 0))[0] + offset)
+ x = trans.transform((pos_cm, 0))[0] + offset
if i % 10 == 0:
painterX.drawLine(x, l - l1 - 1, x, l - 1)
text = str("%d" % np.round(pos_cm))
o = 0
- painterX.drawText(x + 3, o, self.fontMetrics().width(text), o + self.fontMetrics().height(),
- QtCore.Qt.AlignLeft,
- text)
+ painterX.drawText(
+ x + 3,
+ o,
+ self.fontMetrics().width(text),
+ o + self.fontMetrics().height(),
+ QtCore.Qt.AlignLeft,
+ text,
+ )
elif i % 2 == 0:
painterX.drawLine(x, l - l2 - 1, x, l - 1)
else:
@@ -874,14 +970,19 @@ def updateRuler(self):
end_y = np.ceil(trans.inverted().transform((0, +offset))[1])
dy = 0.1
for i, pos_cm in enumerate(np.arange(start_y, end_y, dy)):
- y = (-trans.transform((0, pos_cm))[1] + offset)
+ y = -trans.transform((0, pos_cm))[1] + offset
if i % 10 == 0:
painterY.drawLine(l - l1 - 1, y, l - 1, y)
text = str("%d" % np.round(pos_cm))
o = 0
- painterY.drawText(o, y + 3, o + self.fontMetrics().width(text), self.fontMetrics().height(),
- QtCore.Qt.AlignRight,
- text)
+ painterY.drawText(
+ o,
+ y + 3,
+ o + self.fontMetrics().width(text),
+ self.fontMetrics().height(),
+ QtCore.Qt.AlignRight,
+ text,
+ )
elif i % 2 == 0:
painterY.drawLine(l - l2 - 1, y, l - 1, y)
else:
@@ -916,33 +1017,38 @@ def updateRuler(self):
self.shadow.setMaximumSize(w + 100, h + 10)
def showEvent(self, event: QtCore.QEvent):
- """ when the window is shown """
+ """when the window is shown"""
self.fitToView()
self.updateRuler()
self.colorWidget.updateColors()
def resizeEvent(self, event: QtCore.QEvent):
- """ when the window is resized """
+ """when the window is resized"""
if self.fitted_to_view:
self.fitToView(True)
else:
self.updateRuler()
def button_press_event(self, event: QtCore.QEvent):
- """ when a mouse button is pressed """
+ """when a mouse button is pressed"""
if event.button == 2:
self.drag = np.array([event.x, event.y])
def mouse_move_event(self, event: QtCore.QEvent):
- """ when the mouse is moved """
+ """when the mouse is moved"""
if self.drag is not None:
pos = np.array([event.x, event.y])
offset = pos - self.drag
offset[1] = -offset[1]
self.moveCanvasCanvas(*offset)
- trans = transforms.Affine2D().scale(2.54, 2.54) + self.fig.dpi_scale_trans.inverted()
+ trans = (
+ transforms.Affine2D().scale(2.54, 2.54)
+ + self.fig.dpi_scale_trans.inverted()
+ )
pos = trans.transform((event.x, event.y))
- self.footer_label.setText("%.2f, %.2f (cm) [%d, %d]" % (pos[0], pos[1], event.x, event.y))
+ self.footer_label.setText(
+ "%.2f, %.2f (cm) [%d, %d]" % (pos[0], pos[1], event.x, event.y)
+ )
if event.ydata is not None:
self.footer_label2.setText("%.2f, %.2f" % (event.xdata, event.ydata))
@@ -950,29 +1056,29 @@ def mouse_move_event(self, event: QtCore.QEvent):
self.footer_label2.setText("")
def button_release_event(self, event: QtCore.QEvent):
- """ when the mouse button is released """
+ """when the mouse button is released"""
if event.button == 2:
self.drag = None
def canvas_key_press(self, event: QtCore.QEvent):
- """ when a key in the canvas widget is pressed """
+ """when a key in the canvas widget is pressed"""
if event.key == "control":
self.control_modifier = True
def canvas_key_release(self, event: QtCore.QEvent):
- """ when a key in the canvas widget is released """
+ """when a key in the canvas widget is released"""
if event.key == "control":
self.control_modifier = False
def moveCanvasCanvas(self, offset_x: float, offset_y: float):
- """ when the canvas is panned """
+ """when the canvas is panned"""
p = self.canvas_container.pos()
self.canvas_container.move(p.x() + offset_x, p.y() + offset_y)
self.updateRuler()
def keyPressEvent(self, event: QtCore.QEvent):
- """ when a key is pressed """
+ """when a key is pressed"""
if event.key() == QtCore.Qt.Key_Control:
self.control_modifier = True
if event.key() == QtCore.Qt.Key_Left:
@@ -988,11 +1094,14 @@ def keyPressEvent(self, event: QtCore.QEvent):
self.fitToView(True)
def fitToView(self, change_dpi: bool = False):
- """ fit the figure to the view """
+ """fit the figure to the view"""
self.fitted_to_view = True
if change_dpi:
w, h = self.canvas.get_width_height()
- factor = min((self.canvas_canvas.width() - 30) / w, (self.canvas_canvas.height() - 30) / h)
+ factor = min(
+ (self.canvas_canvas.width() - 30) / w,
+ (self.canvas_canvas.height() - 30) / h,
+ )
self.fig.set_dpi(self.fig.get_dpi() * factor)
self.fig.canvas.draw()
@@ -1001,8 +1110,10 @@ def fitToView(self, change_dpi: bool = False):
self.canvas_container.setMinimumSize(w, h)
self.canvas_container.setMaximumSize(w, h)
- self.canvas_container.move((self.canvas_canvas.width() - w) / 2 + 5,
- (self.canvas_canvas.height() - h) / 2 + 5)
+ self.canvas_container.move(
+ (self.canvas_canvas.width() - w) / 2 + 5,
+ (self.canvas_canvas.height() - h) / 2 + 5,
+ )
self.updateRuler()
self.fig.canvas.draw()
@@ -1012,17 +1123,19 @@ def fitToView(self, change_dpi: bool = False):
self.canvas_canvas.setMinimumWidth(w + 30)
self.canvas_canvas.setMinimumHeight(h + 30)
- self.canvas_container.move((self.canvas_canvas.width() - w) / 2 + 5,
- (self.canvas_canvas.height() - h) / 2 + 5)
+ self.canvas_container.move(
+ (self.canvas_canvas.width() - w) / 2 + 5,
+ (self.canvas_canvas.height() - h) / 2 + 5,
+ )
self.updateRuler()
def keyReleaseEvent(self, event: QtCore.QEvent):
- """ when a key is released """
+ """when a key is released"""
if event.key() == QtCore.Qt.Key_Control:
self.control_modifier = False
def scroll_event(self, event: QtCore.QEvent):
- """ when the mouse wheel is used to zoom the figure """
+ """when the mouse wheel is used to zoom the figure"""
if self.control_modifier:
new_dpi = self.fig.get_dpi() + 10 * event.step
@@ -1049,22 +1162,22 @@ def scroll_event(self, event: QtCore.QEvent):
bb = self.fig.axes[0].get_position()
def updateFigureSize(self):
- """ update the size of the figure """
+ """update the size of the figure"""
w, h = self.canvas.get_width_height()
self.canvas_container.setMinimumSize(w, h)
self.canvas_container.setMaximumSize(w, h)
def changedFigureSize(self, size: tuple):
- """ change the size of the figure """
+ """change the size of the figure"""
self.fig.set_size_inches(np.array(size) / 2.54)
self.fig.canvas.draw()
def elementSelected(self, element: Artist):
- """ when an element is selected """
+ """when an element is selected"""
self.input_properties.setElement(element)
def update(self):
- """ update the tree view """
+ """update the tree view"""
# self.input_size.setValue(np.array(self.fig.get_size_inches())*2.54)
self.treeView.deleteEntry(self.fig)
self.treeView.expand(None)
@@ -1093,14 +1206,14 @@ def newfunc(*args):
self.treeView.setCurrentIndex(self.fig)
def updateTitle(self):
- """ update the title of the window to display if it is saved or not """
+ """update the title of the window to display if it is saved or not"""
if self.fig.change_tracker.saved:
self.setWindowTitle("Figure %s - Pylustrator" % self.fig.number)
else:
self.setWindowTitle("Figure %s* - Pylustrator" % self.fig.number)
def select_element(self, element: Artist):
- """ select an element """
+ """select an element"""
if element is None:
self.treeView.setCurrentIndex(self.fig)
self.input_properties.setElement(self.fig)
@@ -1109,12 +1222,18 @@ def select_element(self, element: Artist):
self.input_properties.setElement(element)
def closeEvent(self, event: QtCore.QEvent):
- """ when the window is closed, ask the user to save """
+ """when the window is closed, ask the user to save"""
if not self.fig.change_tracker.saved:
- reply = QtWidgets.QMessageBox.question(self, 'Warning - Pylustrator', 'The figure has not been saved. '
- 'All data will be lost.\nDo you want to save it?',
- QtWidgets.QMessageBox.Cancel | QtWidgets.QMessageBox.No | QtWidgets.QMessageBox.Yes,
- QtWidgets.QMessageBox.Yes)
+ reply = QtWidgets.QMessageBox.question(
+ self,
+ "Warning - Pylustrator",
+ "The figure has not been saved. "
+ "All data will be lost.\nDo you want to save it?",
+ QtWidgets.QMessageBox.Cancel
+ | QtWidgets.QMessageBox.No
+ | QtWidgets.QMessageBox.Yes,
+ QtWidgets.QMessageBox.Yes,
+ )
if reply == QtWidgets.QMessageBox.Cancel:
event.ignore()
diff --git a/pylustrator/change_tracker.py b/pylustrator/change_tracker.py
index 7c615b4..366864c 100644
--- a/pylustrator/change_tracker.py
+++ b/pylustrator/change_tracker.py
@@ -22,8 +22,14 @@
import re
import sys
import traceback
+import os
+import inspect
+import warnings
+import types
from typing import IO
+import numpy as np
+
import matplotlib
import matplotlib as mpl
import matplotlib.pyplot as plt
@@ -42,12 +48,17 @@
""" External overload """
+
+
class CustomStackPosition:
filename = None
lineno = None
+
def __init__(self, filename, lineno):
self.filename = filename
self.lineno = lineno
+
+
custom_stack_position = None
custom_prepend = ""
custom_append = ""
@@ -56,20 +67,24 @@ def __init__(self, filename, lineno):
("\\", "\\\\"),
("\n", "\\n"),
("\r", "\\r"),
- ("\"", "\\\""),
+ ('"', '\\"'),
]
+
+
def escape_string(str):
for pair in escape_pairs:
str = str.replace(pair[0], pair[1])
return str
+
def unescape_string(str):
for pair in escape_pairs:
str = str.replace(pair[1], pair[0])
return str
+
def getReference(element: Artist, allow_using_variable_names=True):
- """ get the code string that represents the given Artist. """
+ """get the code string that represents the given Artist."""
if element is None:
return ""
if isinstance(element, Figure):
@@ -80,7 +95,7 @@ def getReference(element: Artist, allow_using_variable_names=True):
if isinstance(element.number, (float, int)):
return "plt.figure(%s)" % element.number
else:
- return "plt.figure(\"%s\")" % element.number
+ return 'plt.figure("%s")' % element.number
if isinstance(element, matplotlib.lines.Line2D):
index = element.axes.lines.index(element)
return getReference(element.axes) + ".lines[%d]" % index
@@ -118,30 +133,56 @@ def getReference(element: Artist, allow_using_variable_names=True):
for index, label in enumerate(axes.get_xaxis().get_major_ticks()):
if element == label.label1:
- return getReference(axes) + ".get_xaxis().get_major_ticks()[%d].label1" % index
+ return (
+ getReference(axes)
+ + ".get_xaxis().get_major_ticks()[%d].label1" % index
+ )
if element == label.label2:
- return getReference(axes) + ".get_xaxis().get_major_ticks()[%d].label2" % index
+ return (
+ getReference(axes)
+ + ".get_xaxis().get_major_ticks()[%d].label2" % index
+ )
for index, label in enumerate(axes.get_xaxis().get_minor_ticks()):
if element == label.label1:
- return getReference(axes) + ".get_xaxis().get_minor_ticks()[%d].label1" % index
+ return (
+ getReference(axes)
+ + ".get_xaxis().get_minor_ticks()[%d].label1" % index
+ )
if element == label.label2:
- return getReference(axes) + ".get_xaxis().get_minor_ticks()[%d].label2" % index
+ return (
+ getReference(axes)
+ + ".get_xaxis().get_minor_ticks()[%d].label2" % index
+ )
for axes in element.figure.axes:
for index, label in enumerate(axes.get_yaxis().get_major_ticks()):
if element == label.label1:
- return getReference(axes) + ".get_yaxis().get_major_ticks()[%d].label1" % index
+ return (
+ getReference(axes)
+ + ".get_yaxis().get_major_ticks()[%d].label1" % index
+ )
if element == label.label2:
- return getReference(axes) + ".get_yaxis().get_major_ticks()[%d].label2" % index
+ return (
+ getReference(axes)
+ + ".get_yaxis().get_major_ticks()[%d].label2" % index
+ )
for index, label in enumerate(axes.get_yaxis().get_minor_ticks()):
if element == label.label1:
- return getReference(axes) + ".get_yaxis().get_minor_ticks()[%d].label1" % index
+ return (
+ getReference(axes)
+ + ".get_yaxis().get_minor_ticks()[%d].label1" % index
+ )
if element == label.label2:
- return getReference(axes) + ".get_yaxis().get_minor_ticks()[%d].label2" % index
+ return (
+ getReference(axes)
+ + ".get_yaxis().get_minor_ticks()[%d].label2" % index
+ )
if isinstance(element, matplotlib.axes._axes.Axes):
if element.get_label():
- return getReference(element.figure) + ".ax_dict[\"%s\"]" % escape_string(element.get_label())
+ return getReference(element.figure) + '.ax_dict["%s"]' % escape_string(
+ element.get_label()
+ )
return getReference(element.figure) + ".axes[%d]" % element.number
if isinstance(element, matplotlib.legend.Legend):
@@ -150,8 +191,9 @@ def getReference(element: Artist, allow_using_variable_names=True):
def setFigureVariableNames(figure: Figure):
- """ get the global variable names that refer to the given figure """
+ """get the global variable names that refer to the given figure"""
import inspect
+
mpl_figure = _pylab_helpers.Gcf.figs[figure].canvas.figure
calling_globals = inspect.stack()[2][0].f_globals
fig_names = [
@@ -166,7 +208,8 @@ def setFigureVariableNames(figure: Figure):
class ChangeTracker:
- """ a class that records a list of the change to the figure """
+ """a class that records a list of the change to the figure"""
+
changes = None
saved = True
@@ -192,18 +235,24 @@ def __init__(self, figure: Figure):
self.load()
- def addChange(self, command_obj: Artist, command: str, reference_obj: Artist = None, reference_command: str = None):
- """ add a change """
+ def addChange(
+ self,
+ command_obj: Artist,
+ command: str,
+ reference_obj: Artist = None,
+ reference_command: str = None,
+ ):
+ """add a change"""
command = command.replace("\n", "\\n")
if reference_obj is None:
reference_obj = command_obj
if reference_command is None:
- reference_command, = re.match(r"(\.[^(=]*)", command).groups()
+ (reference_command,) = re.match(r"(\.[^(=]*)", command).groups()
self.changes[reference_obj, reference_command] = (command_obj, command)
self.saved = False
def removeElement(self, element: Artist):
- """ remove an Artis from the figure """
+ """remove an Artis from the figure"""
# create_key = key+".new"
created_by_pylustrator = (element, ".new") in self.changes
# delete changes related to this element
@@ -219,14 +268,14 @@ def removeElement(self, element: Artist):
self.figure.selection.remove_target(element)
def addEdit(self, edit: list):
- """ add an edit to the stored list of edits """
+ """add an edit to the stored list of edits"""
if self.last_edit < len(self.edits) - 1:
- self.edits = self.edits[:self.last_edit + 1]
+ self.edits = self.edits[: self.last_edit + 1]
self.edits.append(edit)
self.last_edit = len(self.edits) - 1
def backEdit(self):
- """ undo an edit in the list """
+ """undo an edit in the list"""
if self.last_edit < 0:
return
edit = self.edits[self.last_edit]
@@ -235,7 +284,7 @@ def backEdit(self):
self.figure.canvas.draw()
def forwardEdit(self):
- """ redo an edit """
+ """redo an edit"""
if self.last_edit >= len(self.edits) - 1:
return
edit = self.edits[self.last_edit + 1]
@@ -244,22 +293,23 @@ def forwardEdit(self):
self.figure.canvas.draw()
def load(self):
- """ load a set of changes from a script file. The changes are the code that pylustrator generated """
+ """load a set of changes from a script file. The changes are the code that pylustrator generated"""
regex = re.compile(r"(\.[^\(= ]*)(.*)")
- command_obj_regexes = [getReference(self.figure),
- r"plt\.figure\([^)]*\)",
- r"fig",
- r"\.ax_dict\[\"[^\"]*\"\]",
- r"\.axes\[\d*\]",
- r"\.texts\[\d*\]",
- r"\.{title|_left_title|_right_title}",
- r"\.lines\[\d*\]",
- r"\.collections\[\d*\]",
- r"\.patches\[\d*\]",
- r"\.get_[xy]axis\(\)\.get_(major|minor)_ticks\(\)\[\d*\]",
- r"\.get_[xy]axis\(\)\.get_label\(\)",
- r"\.get_legend\(\)",
- ]
+ command_obj_regexes = [
+ getReference(self.figure),
+ r"plt\.figure\([^)]*\)",
+ r"fig",
+ r"\.ax_dict\[\"[^\"]*\"\]",
+ r"\.axes\[\d*\]",
+ r"\.texts\[\d*\]",
+ r"\.{title|_left_title|_right_title}",
+ r"\.lines\[\d*\]",
+ r"\.collections\[\d*\]",
+ r"\.patches\[\d*\]",
+ r"\.get_[xy]axis\(\)\.get_(major|minor)_ticks\(\)\[\d*\]",
+ r"\.get_[xy]axis\(\)\.get_label\(\)",
+ r"\.get_legend\(\)",
+ ]
command_obj_regexes = [re.compile(r) for r in command_obj_regexes]
fig = self.figure
@@ -271,7 +321,10 @@ def load(self):
block = getTextFromFile(getReference(self.figure), stack_position)
if not block:
- block = getTextFromFile(getReference(self.figure, allow_using_variable_names=False), stack_position)
+ block = getTextFromFile(
+ getReference(self.figure, allow_using_variable_names=False),
+ stack_position,
+ )
for line in block:
line = line.strip()
if line == "" or line in header or line.startswith("#"):
@@ -286,7 +339,7 @@ def load(self):
for r in command_obj_regexes:
try:
found = r.match(line).group()
- line = line[len(found):]
+ line = line[len(found) :]
command_obj += found
except AttributeError:
pass
@@ -306,7 +359,12 @@ def load(self):
reference_obj = command_obj
reference_command = command
- if command == ".set_xticks" or command == ".set_yticks" or command == ".set_xlabels" or command == ".set_ylabels":
+ if (
+ command == ".set_xticks"
+ or command == ".set_yticks"
+ or command == ".set_xlabels"
+ or command == ".set_ylabels"
+ ):
if line.find("minor=True") != -1:
reference_command = command + "_minor"
@@ -321,17 +379,25 @@ def load(self):
# if the reference object is just a dummy, we ignore it
if isinstance(reference_obj, Dummy):
- print("WARNING: line references a missing object, will remove line on save:", raw_line, file=sys.stderr)
+ print(
+ "WARNING: line references a missing object, will remove line on save:",
+ raw_line,
+ file=sys.stderr,
+ )
continue
self.get_reference_cached[reference_obj] = reference_obj_str
- #print("---", [reference_obj, reference_command], (command_obj, command + parameter))
- self.changes[reference_obj, reference_command] = (command_obj, command + parameter)
+ # print("---", [reference_obj, reference_command], (command_obj, command + parameter))
+ self.changes[reference_obj, reference_command] = (
+ command_obj,
+ command + parameter,
+ )
self.sorted_changes()
def sorted_changes(self):
- """ sort the changes by their priority. For example setting to logscale needs to be executed before xlim. """
+ """sort the changes by their priority. For example setting to logscale needs to be executed before xlim."""
+
def getRef(obj):
try:
return getReference(obj)
@@ -351,21 +417,43 @@ def getRef(obj):
if getattr(reference_obj, "axes", None) is not None:
if reference_command == ".new":
index = "0"
- elif reference_command == ".set_xscale" or reference_command == ".set_yscale":
+ elif (
+ reference_command == ".set_xscale"
+ or reference_command == ".set_yscale"
+ ):
index = "1"
- elif reference_command == ".set_xlim" or reference_command == ".set_ylim":
+ elif (
+ reference_command == ".set_xlim"
+ or reference_command == ".set_ylim"
+ ):
index = "2"
- elif reference_command == ".set_xticks" or reference_command == ".set_yticks":
+ elif (
+ reference_command == ".set_xticks"
+ or reference_command == ".set_yticks"
+ ):
index = "3"
- elif reference_command == ".set_xticklabels" or reference_command == ".set_yticklabels":
+ elif (
+ reference_command == ".set_xticklabels"
+ or reference_command == ".set_yticklabels"
+ ):
index = "4"
else:
index = "5"
- obj_indices = (getRef(reference_obj.axes), getRef(reference_obj), index, reference_command)
+ obj_indices = (
+ getRef(reference_obj.axes),
+ getRef(reference_obj),
+ index,
+ reference_command,
+ )
else:
obj_indices = (getRef(reference_obj), "", "", reference_command)
indices.append(
- [(reference_obj, reference_command), self.changes[reference_obj, reference_command], obj_indices])
+ [
+ (reference_obj, reference_command),
+ self.changes[reference_obj, reference_command],
+ obj_indices,
+ ]
+ )
except (ValueError, TypeError) as err:
print(err, file=sys.stderr)
@@ -380,13 +468,20 @@ def getRef(obj):
return output
- def save(self):
- """ save the changes to the .py file """
- header = [getReference(self.figure) + ".ax_dict = {ax.get_label(): ax for ax in " + getReference(
- self.figure) + ".axes}", "import matplotlib as mpl"]
+ def save(self, output_file, reqd_code):
+ """save the changes to the .py file"""
+ header = [
+ getReference(self.figure)
+ + ".ax_dict = {ax.get_label(): ax for ax in "
+ + getReference(self.figure)
+ + ".axes}",
+ "import matplotlib as mpl",
+ ]
# block = getTextFromFile(header[0], self.stack_position)
- output = [custom_prepend + "#% start: automatic generated code from pylustrator"]
+ output = [
+ custom_prepend + "#% start: automatic generated code from pylustrator"
+ ]
# add the lines from the header
for line in header:
output.append(line)
@@ -395,7 +490,9 @@ def save(self):
output.append(line)
if line.startswith("fig.add_axes"):
output.append(header[1])
- output.append("#% end: automatic generated code from pylustrator" + custom_append)
+ output.append(
+ "#% end: automatic generated code from pylustrator" + custom_append
+ )
# print("\n".join(output))
block_id = getReference(self.figure)
@@ -403,27 +500,33 @@ def save(self):
if not block:
block_id = getReference(self.figure, allow_using_variable_names=False)
block = getTextFromFile(block_id, stack_position)
- insertTextToFile(output, stack_position, block_id)
+ insertTextToFile(output, stack_position, block_id, output_file, reqd_code)
self.saved = True
def getTextFromFile(block_id: str, stack_pos: traceback.FrameSummary):
- """ get the text which corresponds to the block_id (e.g. which figure) at the given position sepcified by stack_pos. """
+ """get the text which corresponds to the block_id (e.g. which figure) at the given position sepcified by stack_pos."""
block_id = lineToId(block_id)
block = None
if not custom_stack_position:
- if not stack_pos.filename.endswith('.py') and not stack_pos.filename.startswith("