-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMyFigureUtils.py
168 lines (142 loc) · 5.99 KB
/
MyFigureUtils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 23 14:57:23 2015
@author: Ken
"""
from inspect import getmembers, isclass
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
def splay_figures(x_origin=1367, y_top=30, y_bottom=570, width=640,
                  height=500):
    """Get all figures and spread them across my secondary monitor

    Figures are tiled on a 2-row x 3-column grid of ``width`` x ``height``
    windows. The grid's columns start at ``x_origin`` and are ``width``
    pixels apart. The top row is placed at ``y_top`` and the bottom row at
    ``y_bottom``. The defaults reproduce the original hard-coded layout.

    With 2, 3 or 4 figures the right-hand slots of the grid are used. With
    more than 6 figures the slots wrap around, so extra windows stack on top
    of earlier ones instead of raising IndexError.

    Requires a backend whose figure manager exposes
    ``window.setGeometry(x, y, w, h)`` (presumably a Qt backend).
    As a side effect, the last figure positioned becomes the current figure.

    Inputs
    ------
    x_origin : int
        Screen x coordinate of the left-most column
    y_top, y_bottom : int
        Screen y coordinates of the top and bottom rows
    width, height : int
        Size of each figure window in pixels
    """
    fig_nums = plt.get_fignums()
    columns = [x_origin + width * k for k in range(3)]
    # Slot order: top row left-to-right, then bottom row left-to-right
    slots = [(x, y, width, height) for y in (y_top, y_bottom)
             for x in columns]
    # With only a few figures, prefer the right-hand side of the monitor
    preferred = {2: [2, 5], 3: [2, 4, 5], 4: [1, 2, 4, 5]}
    if len(fig_nums) in preferred:
        slots = [slots[k] for k in preferred[len(fig_nums)]]
    for i, num in enumerate(fig_nums):
        x, y, w, h = slots[i % len(slots)]
        plt.figure(num)
        # Plain ints: GUI bindings may reject numpy integer types
        plt.get_current_fig_manager().window.setGeometry(
            int(x), int(y), int(w), int(h))
def _rasterize_artist(artist, zorder):
    """Rasterize one artist and drop it below its axes' rasterization zorder."""
    ax = artist.axes
    # Figure-level artists have no axes; set_rasterized alone applies there
    if ax is not None:
        ax.set_rasterization_zorder(zorder)
    artist.set_zorder(zorder - 1)
    artist.set_rasterized(True)


def raster_and_save(fname, rasterize_list=None, fig=None, dpi=None,
                    savefig_kw=None):
    """Save a figure with raster and vector components

    This function lets you specify which objects to rasterize at the export
    stage, rather than within each plotting call. Rasterizing certain
    components of a complex figure can significantly reduce file size.

    Inputs
    ------
    fname : str
        Output filename with extension
    rasterize_list : list (or object)
        List of objects to rasterize (or a single object to rasterize).
        Each entry may be an artist (e.g. the return value of ``scatter`` or
        ``pcolormesh``), a contour set from ``contour``/``contourf``, or a
        list/tuple of artists (e.g. the patches from ``bar`` or the lines
        from ``plot``).
    fig : matplotlib figure object
        Defaults to current figure
    dpi : int
        Resolution (dots per inch) for rasterizing
    savefig_kw : dict
        Extra keywords to pass to ``fig.savefig``. The dict is copied, so the
        caller's dict is never modified.

    If rasterize_list is not specified, then all contour, pcolor, and
    collection objects (e.g., ``scatter, fill_between`` etc) will be
    rasterized

    Note: does not work correctly with round=True in Basemap

    Example
    -------
    Rasterize the contour, pcolor, and scatter plots, but not the line

    >>> from numpy.random import random
    >>> X, Y, Z = random((9, 9)), random((9, 9)), random((9, 9))
    >>> fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2)
    >>> cax1 = ax1.contourf(Z)
    >>> cax2 = ax2.scatter(X, Y, s=Z)
    >>> cax3 = ax3.pcolormesh(Z)
    >>> cax4 = ax4.plot(Z[:, 0])
    >>> rasterize_list = [cax1, cax2, cax3]
    >>> raster_and_save('out.svg', rasterize_list, fig=fig, dpi=300)
    """
    # Behave like pyplot and act on current figure if no figure is specified
    fig = plt.gcf() if fig is None else fig

    # Work on a copy so neither the caller's dict nor a shared default
    # is ever mutated (the dpi key is added below)
    savefig_kw = {} if savefig_kw is None else dict(savefig_kw)

    # Need to set_rasterization_zorder in order for rasterizing to work
    zorder = -5  # Somewhat arbitrary, just ensuring less than 0

    if rasterize_list is None:
        # Have a guess at stuff that should be rasterised, by matching the
        # artists' string representations
        types_to_raster = ['QuadMesh', 'Contour', 'collections']
        rasterize_list = []
        print("""
No rasterize_list specified, so the following objects will
be rasterized: """)
        # Get all axes, and then get objects within axes
        for ax in fig.get_axes():
            for item in ax.get_children():
                if any(x in str(item) for x in types_to_raster):
                    rasterize_list.append(item)
        print('\n'.join([str(x) for x in rasterize_list]))
    elif not isinstance(rasterize_list, list):
        # Allow rasterize_list to be input as an object to rasterize
        rasterize_list = [rasterize_list]

    for item in rasterize_list:
        if isinstance(item, matplotlib.artist.Artist):
            # Any single artist. This includes ContourSet on newer
            # matplotlib, where it is itself a Collection.
            _rasterize_artist(item, zorder)
        elif isinstance(item, matplotlib.contour.ContourSet):
            # Older matplotlib: a ContourSet is not an Artist, so each
            # per-level collection must be set individually
            curr_ax = item.ax.axes
            curr_ax.set_rasterization_zorder(zorder)
            for contour_level in item.collections:
                contour_level.set_zorder(zorder - 1)
                contour_level.set_rasterized(True)
        else:
            # Otherwise treat the entry as a sequence of artists, e.g. a list
            # of patches, a BarContainer, or the lines returned by ax.plot
            for artist in item:
                _rasterize_artist(artist, zorder)

    # dpi is a savefig keyword argument, but treat it as special since it is
    # important to this function
    if dpi is not None:
        savefig_kw['dpi'] = dpi
    # Save resulting figure
    fig.savefig(fname, **savefig_kw)
# Demo of raster_and_save: build a 2x2 figure and write out.svg with the
# contour, scatter and pcolormesh plots rasterized while the line stays vector.
# Compare with == (equality); `is` on a string literal is an identity test
# that is not guaranteed to hold and warns on Python >= 3.8.
if __name__ == '__main__':
    from numpy.random import random
    X, Y, Z = random((9, 9)), random((9, 9)), random((9, 9))
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(ncols=2, nrows=2)
    cax1 = ax1.contourf(Z)
    cax2 = ax2.scatter(X, Y, s=Z)
    cax3 = ax3.pcolormesh(Z)
    ax4.plot(Z[:, 0])  # deliberately not rasterized
    rasterize_list = [cax1, cax2, cax3]
    raster_and_save('out.svg', rasterize_list, fig=fig, dpi=300)