Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gcgridobj/gc_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Horizontal global grids first
gmao_4x5_global = latlontools.gen_hrz_grid(lat_stride=4.0, lon_stride=5.0, half_polar=True,center_180=True)
gmao_2x25_global = latlontools.gen_hrz_grid(lat_stride=2.0, lon_stride=2.5, half_polar=True,center_180=True)
gmao_1x1_global = latlontools.gen_hrz_grid(lat_stride=1.0, lon_stride=1.0, half_polar=False,center_180=True)
gmao_05x0666_global = latlontools.gen_hrz_grid(lat_stride=0.5, lon_stride=2/3, half_polar=True,center_180=True)
gmao_05x0625_global = latlontools.gen_hrz_grid(lat_stride=0.5, lon_stride=5/8, half_polar=True,center_180=True)
gmao_025x03125_global = latlontools.gen_hrz_grid(lat_stride=0.25,lon_stride=5/16,half_polar=True,center_180=True)
Expand All @@ -20,6 +21,7 @@
# All grids
global_grid_inventory = [gmao_4x5_global,
gmao_2x25_global,
gmao_1x1_global,
gmao_05x0666_global,
gmao_05x0625_global,
gmao_025x03125_global]
Expand Down
142 changes: 111 additions & 31 deletions gcgridobj/plottools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'plot_country']

crs_plot_standard = ccrs.PlateCarree()
crs_data_standard = ccrs.PlateCarree()
crs_data_standard = ccrs.PlateCarree()

def reshape_cs(cs_data):
warnings.warn('plottools.reshape_cs is deprecated. Please use regrid.reshape_cs instead',FutureWarning)
Expand All @@ -34,16 +34,17 @@ def gen_l2c_regridder(cs_grid,ll_grid,method='conservative',grid_dir='.'):
def gen_cs_regridder(cs_grid,ll_grid,method='conservative',grid_dir='.'):
warnings.warn('plottools.gen_cs_regridder is deprecated. Please use regrid.gen_regridder instead',FutureWarning)
return regrid.gen_regridder(cs_grid,ll_grid,method,grid_dir)

def guess_cs_grid(cs_data_shape):
# This used to return face side length and is_gmao, which was inconsistent
warnings.warn('plottools.guess_cs_grid is deprecated. Please use regrid.guess_cs_grid or regrid.guess_n_cs instead',FutureWarning)
return regrid.guess_n_cs(cs_data_shape)

def plot_zonal(zonal_data,hrz_grid,vrt_grid,ax=None,show_colorbar=True,z_edge=None,vert_coord='altitude',
sec_axis=False,sec_minor=False,sec_ticklabels=True,sec_axlabel=True):
sec_axis=False,sec_minor=False,sec_ticklabels=True,sec_axlabel=True,isDiff=False,
figSize=None,latTicks=None,cbTitle=None,title=None):
'''Plot 2D data as a zonal profile


Keyword arguments:
ax -- axes to use for plotting (default None, results in new axes)
Expand All @@ -65,26 +66,53 @@ def plot_zonal(zonal_data,hrz_grid,vrt_grid,ax=None,show_colorbar=True,z_edge=No
elif z_edge is None:
z_edge = vrt_grid.z_edge_ISA() / 1000.0
alt_b = z_edge

assert len(zonal_data.shape) == 2, 'Zonal data must be 2-D'
assert len(alt_b) == zonal_data.shape[0]+1, 'Zonal data incorrectly shaped (altitude)'
assert len(lat_b) == zonal_data.shape[1]+1, 'Zonal data incorrectly shaped (latitude)'

if ax is None:
f, ax = plt.subplots(1,1,figsize=(8,5))
if figSize is None:
f, ax = plt.subplots(1,1,figsize=(8,5))
else:
f, ax = plt.subplots(1,1,figsize=figSize)
else:
f = ax.figure


im = ax.pcolormesh(lat_b,alt_b,zonal_data)

if np.min(zonal_data) < 0.0 or isDiff:
im.set_cmap('RdBu_r')
clim_max = np.max(np.abs(im.get_clim()))
im.set_clim(np.array([-1,1])*clim_max)

if latTicks is None:
latTicks = [-90,-60,-30,0,30,60,90]
labelSet = []
for tick in latTicks:
tick_str = '{:d}'.format(np.abs(tick))
if tick < -0.01:
tick_label = tick_str + '$^\circ$' + 'S'
elif tick > 0.01:
tick_label = tick_str + '$^\circ$' + 'N'
else:
tick_label = '0'
labelSet.append(tick_label)

if vert_coord == 'pressure':
ax.invert_yaxis()
ax.set_yscale('log')
ax.set_ylabel('Pressure, hPa')
ax.set_ylabel('Pressure, hPa', fontsize=18)
elif vert_coord == 'altitude':
ax.set_ylabel('Altitude, km')
ax.set_ylabel('Altitude, km', fontsize=18)
else:
raise ValueError('Vertical coordinate {:s} not recognized'.format(vert_coord))
ax.set_xlabel('Latitude', fontsize=18)
ax.set_xticks(latTicks)
ax.set_xticklabels(labelSet)
ax.tick_params(axis='both', which='major', labelsize=18)


cb_pad = 0.04
if sec_axis:
Expand Down Expand Up @@ -122,17 +150,17 @@ def p_to_z(p):
# we mark tick points on it based on our transformed vertical coord.
# The axes are also linked so that a change in the y-coord of the
# primary axis will modify the secondary one. However, changing the
# y-scale of the primary axis will not cause an appropriate change
# y-scale of the primary axis will not cause an appropriate change
# in the secondary axis. Equally, although the current ticks will
# always turn up in the right place when y-limits are changed, new
# ticks will not be produced if (for example) a very small altitude
# range is desired.
# range is desired.
def update_ax2(ax1):
y1,y2 = ax1.get_ylim()
#ax2.set_ylim(sa_fwd(y1),sa_fwd(y2))
ax2.set_ylim(y1,y2)
ax2.figure.canvas.draw()

ax2 = ax.twinx()

# log-p and z do not exactly line up - need to identify "by hand"
Expand Down Expand Up @@ -181,7 +209,7 @@ def tick_label_gen(ticks):
# Force minor ticks to also be shown (dangerous!)
if sec_minor and sec_ticklabels:
ax2.set_yticklabels(tick_label_gen(alt_minor),minor=True)

# Initialize the limits
update_ax2(ax)

Expand All @@ -190,21 +218,28 @@ def tick_label_gen(ticks):
# === END IF ===

if sec_axlabel:
ax2.set_ylabel(sec_name)
ax2.set_ylabel(sec_name, fontsize=18)
else:
ax2.set_ylabel('')

if not sec_ticklabels:
ax2.set_yticklabels([])
ax2.tick_params(axis='both', which='major', labelsize=18)

if show_colorbar:
cb = f.colorbar(im, ax=ax, shrink=0.6, orientation='vertical', pad=cb_pad)
cb = f.colorbar(im, ax=ax, shrink=1.0, orientation='vertical', pad=cb_pad)
if cbTitle is not None:
cb.set_label(cbTitle, fontsize=18)
cb.ax.tick_params(axis='y', which='major', labelsize=18)
else:
cb = None

if title is not None:
ax.set_title(title, pad=15, fontsize=18);

return im, cb

def plot_layer(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show_colorbar=True,coastlines=True):
def plot_layer(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show_colorbar=True,coastlines=True,figSize=None,isDiff=False,latTicks=None,lonTicks=None,cbTitle=None,title=None):

if crs_data is None:
crs_data = crs_data_standard
Expand All @@ -213,10 +248,39 @@ def plot_layer(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show
crs_plot = crs_plot_standard

if ax is None:
f, ax = plt.subplots(1,1,figsize=(8,5),subplot_kw={'projection':crs_plot})
if figSize is None:
f, ax = plt.subplots(1,1,figsize=(8,5),subplot_kw={'projection':crs_plot})
else:
f, ax = plt.subplots(1,1,figsize=figSize,subplot_kw={'projection':crs_plot})
else:
f = ax.figure

if latTicks is None:
latTicks = [-90,-60,-30,0,30,60,90]
latSet = []
for tick in latTicks:
tick_str = '{:d}'.format(np.abs(tick))
if tick < -0.01:
tick_label = tick_str + '$^\circ$' + 'S'
elif tick > 0.01:
tick_label = tick_str + '$^\circ$' + 'N'
else:
tick_label = '0'
latSet.append(tick_label)

if lonTicks is None:
lonTicks = [-180,-120,-60,0,60,120,180]
lonSet = []
for tick in lonTicks:
tick_str = '{:d}'.format(np.abs(tick))
if tick < -0.01:
tick_label = tick_str + '$^\circ$' + 'W'
elif tick > 0.01:
tick_label = tick_str + '$^\circ$' + 'E'
else:
tick_label = '0'
lonSet.append(tick_label)

# Test the data; if it looks cubed-sphere, throw it to the CS routines. Otherwise assume lat-lon
ld_shape = layer_data.shape
if len(ld_shape) < 2 or len(ld_shape) > 3:
Expand All @@ -231,7 +295,14 @@ def plot_layer(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show
im_obj = plot_cs(layer_data,hrz_grid=hrz_grid,ax=ax,crs_data=crs_data,crs_plot=crs_plot)
else:
# Assume lat-lon
im_obj = plot_latlon(layer_data,hrz_grid=hrz_grid,ax=ax,crs_data=crs_data,crs_plot=crs_plot)
im_obj = plot_latlon(layer_data,hrz_grid=hrz_grid,ax=ax,crs_data=crs_data,crs_plot=crs_plot,isDiff=isDiff)
ax.set_ylabel('Latitude', fontsize=18)
ax.set_yticks(latTicks)
ax.set_yticklabels(latSet)
ax.set_xlabel('Longitude', fontsize=18)
ax.set_xticks(lonTicks)
ax.set_xticklabels(lonSet)
ax.tick_params(axis='both', which='major', labelsize=18)

# If cubed-sphere, use the first image
is_cs = isinstance(im_obj, list)
Expand All @@ -241,18 +312,24 @@ def plot_layer(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show
im = im_obj

if show_colorbar:
cb = f.colorbar(im, ax=ax, shrink=0.6, orientation='vertical', pad=0.04)
cb = f.colorbar(im, ax=ax, shrink=1.0, orientation='vertical', pad=0.04)
if cbTitle is not None:
cb.set_label(cbTitle, fontsize=18)
cb.ax.tick_params(axis='y', which='major', labelsize=18)
else:
cb = None

if title is not None:
ax.set_title(title, pad=15, fontsize=18);

if coastlines:
# If user wants a different resolution, they can disable set coastlines=False
# and run this command after calling plot_layer
ax.coastlines('50m')

return im_obj, cb

def plot_latlon(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show_colorbar=True):
def plot_latlon(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show_colorbar=True,isDiff=False):
'''Plot 2D lat-lon data
'''

Expand All @@ -262,7 +339,7 @@ def plot_latlon(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,sho
#hrz_grid = gc_horizontal.get_grid(layer_data.shape)
hrz_grid = regrid.guess_ll_grid(layer_data.shape)
assert hrz_grid is not None, 'Could not auto-identify grid'

lon_b = hrz_grid['lon_b']
lat_b = hrz_grid['lat_b']

Expand All @@ -271,23 +348,26 @@ def plot_latlon(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,sho

im = ax.pcolormesh(lon_b,lat_b,layer_data,transform=crs_data)

if np.min(layer_data) < 0.0 or isDiff:
im.set_cmap('RdBu_r')

return im

def update_cs(layer_data,im_vec,hrz_grid=None,cs_threshold=5):
# WARNING: layer_data must be [6 x N x N]
if hrz_grid is None:
# Try to figure out the grid from the layer data
#n_cs, is_gmao = guess_cs_grid(layer_data.shape)
#n_cs, is_gmao = guess_cs_grid(layer_data.shape)
#hrz_grid = cubedsphere.csgrid_GMAO(n_cs)
hrz_grid = regrid.guess_cs_grid(layer_data.shape)
hrz_grid = regrid.guess_cs_grid(layer_data.shape)
masked_data = np.ma.masked_where(np.abs(hrz_grid['lon'] - 180.0) < cs_threshold, layer_data)
for i_face in range(6):
im_vec[i_face].set_array(masked_data[i_face,:,:].ravel())

def plot_cs(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show_colorbar=True,cs_threshold=5.0):

# 2019-12-17: dropped support for non-GMAO grids
#n_cs, is_gmao = regrid.guess_n_cs(layer_data.shape)
#n_cs, is_gmao = regrid.guess_n_cs(layer_data.shape)

#if is_gmao:
# # Use data as-is
Expand All @@ -300,9 +380,9 @@ def plot_cs(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show_co
# Try to figure out the grid from the layer data
#hrz_grid = cubedsphere.csgrid_GMAO(n_cs)
hrz_grid = regrid.guess_cs_grid(layer_data.shape)

masked_data = np.ma.masked_where(np.abs(hrz_grid['lon'] - 180.0) < cs_threshold, layer_data)

im_vec = []
for i_face in range(6):
im = ax.pcolormesh(hrz_grid['lon_b'][i_face,:,:],hrz_grid['lat_b'][i_face,:,:],masked_data[i_face,:,:],transform=crs_data)
Expand All @@ -320,10 +400,10 @@ def plot_cs(layer_data,hrz_grid=None,ax=None,crs_data=None,crs_plot=None,show_co
def plot_shape(state_name,state_val,shape_data_archive,classifier,
edgecolor='black',cmap=None,c_lim=(0.0,1.0),ax=None,nofail=True):
'''Plot an shape onto a set of axes'''

if ax is None:
f, ax = plt.subplots(1,1,figsize=(10,8),subplot_kw={'projection': ccrs.PlateCarree()})

if state_val is None:
facecolor = 'none'
else:
Expand All @@ -333,18 +413,18 @@ def plot_shape(state_name,state_val,shape_data_archive,classifier,
temp_cm = plt.get_cmap(cmap)
else:
temp_cm = cmap

cmap_val = (state_val - c_lim[0]) / (c_lim[1] - c_lim[0])
facecolor = temp_cm(cmap_val)

im_shp = None

for astate in shpreader.Reader(shape_data_archive).records():
if state_name == astate.attributes[classifier]:
im_shp = ax.add_geometries([astate.geometry], ccrs.PlateCarree(),
facecolor=facecolor,edgecolor=edgecolor)
break

if im_shp is None:
state_msg = 'Shape ''{:s}'' not found'.format(state_name)
if nofail:
Expand Down