Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion csep/core/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,8 @@ def magnitude_counts(self, mag_bins=None, tol=None, retbins=False):
else:
return out
idx = bin1d_vec(self.get_magnitudes(), mag_bins, tol=tol, right_continuous=True)
numpy.add.at(out, idx, 1)
valid = idx >= 0
numpy.add.at(out, idx[valid], 1)
if retbins:
return (mag_bins, out)
else:
Expand Down
235 changes: 168 additions & 67 deletions csep/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,15 @@ def plot_cumulative_events_versus_time(
pyplot.show()
return ax


def plot_magnitude_histogram(
catalog_forecast: Union["CatalogForecast", List["CSEPCatalog"]],
observation: "CSEPCatalog",
forecast: Union["CatalogForecast", "GriddedForecast"],
observation: Optional["CSEPCatalog"] = None,
magnitude_bins: Optional[Union[List[float], numpy.ndarray]] = None,
percentile: int = 95,
log_scale: bool = True,
normalize: bool = True,
cumulative: bool = False,
intervals: bool = True,
ax: Optional["matplotlib.axes.Axes"] = None,
show: bool = False,
**kwargs: Any,
Expand All @@ -416,7 +418,7 @@ def plot_magnitude_histogram(
- :ref:`Catalog-based Forecast Plots<catalog-forecast-evaluation-exploratory>`

Args:
catalog_forecast (CatalogForecast or list of CSEPCatalog): A catalog-based forecast
forecast (CatalogForecast or list of CSEPCatalog or GriddedForecast): A forecast
or a list of observed catalogs.
observation (CSEPCatalog): The observed catalog for comparison.
magnitude_bins (list of float or numpy.ndarray, optional): The bins for magnitude
Expand All @@ -426,6 +428,12 @@ def plot_magnitude_histogram(
`95`.
log_scale (bool, optional): Whether to plot the y-axis in logarithmic scale. Defaults to
True.
normalize (bool, optional): Whether to normalize the forecast for the total number in the
observation catalog.
cumulative (bool, optional): Whether to plot cumulative counts N(M >= m). Defaults to
False.
intervals (bool, optional): Whether to display forecast uncertainty intervals.
Defaults to True.
ax (matplotlib.axes.Axes, optional): The axes object to draw the plot on. If `None`, a
new figure and axes are created. Defaults to `None`.
show (bool, optional): Whether to display the plot immediately. Defaults to `False`.
Expand Down Expand Up @@ -453,11 +461,12 @@ def plot_magnitude_histogram(
matplotlib.axes.Axes: The axes object containing the plot.

.. versionchanged:: 0.8.0
It now requires a `CatalogForecast` rather than a list of stochastic event sets. The
`plot_args` dictionary is only partially supported and will be removed in v1.0.0
.. versionadded:: 0.8.0
Added `magnitude_bins`, `percentile` and `log_scale` to fine-tune the plot.
Added parameters to customize coloring, formatting and sizing of the plot elements.
It now accepts a `CatalogForecast` or a `GriddedForecast`. An obervation `CSEPCatalog`
is now optional. Added `magnitude_bins`, `percentile` and `log_scale` to fine-tune the
plot. Added parameters to customize coloring, formatting and sizing of the plot
elements. The `plot_args` dictionary is only partially supported and will be removed in
v1.0.0

"""
if "plot_args" in kwargs:
_warning_plot_args("plot_magnitude_histogram")
Expand All @@ -467,80 +476,175 @@ def plot_magnitude_histogram(
fig, ax = pyplot.subplots(figsize=plot_args["figsize"]) if ax is None else (ax.figure, ax)

# Get magnitudes from observations and (lazily) from forecast
forecast_mws = list(map(lambda x: x.get_magnitudes(), catalog_forecast))
obs_mw = observation.get_magnitudes()
n_obs = observation.get_number_of_events()

# Get magnitude bins from args, if not from region, or lastly from standard CSEP bins.
if magnitude_bins is None:
if magnitude_bins is not None:
forecast_bins = numpy.asarray(magnitude_bins)
else:
try:
magnitude_bins = observation.region.magnitudes
forecast_bins = getattr(forecast, "magnitudes")
except AttributeError:
magnitude_bins = CSEP_MW_BINS
raise AttributeError("Forecast must be defined on a 'region', having "
"'magnitudes' attribute as left-edge magnitude bins.")

dm_forecast = numpy.median(numpy.diff(forecast_bins))
forecast_centers = forecast_bins + dm_forecast / 2.0

def get_histogram_synthetic_cat(x, mags, normed=True):
n_temp = len(x)
if normed and n_temp != 0:
temp_scale = n_obs / n_temp
hist = numpy.histogram(x, bins=mags)[0] * temp_scale
if observation is not None:
if magnitude_bins is not None:
obs_bins, obs_counts = observation.magnitude_counts(
mag_bins=magnitude_bins, retbins=True
)
elif hasattr(observation, "region") and hasattr(observation.region, "magnitudes"):
obs_bins, obs_counts = observation.magnitude_counts(mag_bins=None, retbins=True)
else:
obs_bins, obs_counts = observation.magnitude_counts(
mag_bins=CSEP_MW_BINS, retbins=True
)
nonzero_idx = numpy.nonzero(obs_counts)[0]
if nonzero_idx.size > 0:
first = nonzero_idx[0]
obs_bins = obs_bins[first:]
obs_counts = obs_counts[first:]

if len(obs_bins) > 1:
dm_obs = numpy.median(numpy.diff(obs_bins))
else:
hist = numpy.histogram(x, bins=mags)[0]
return hist
dm_obs = dm_forecast
obs_centers = obs_bins + dm_obs / 2.0
idxs = numpy.where(obs_centers >= forecast_centers[0])[0]
if idxs.size > 0:
obs_index = idxs[0]
n_obs = numpy.sum(obs_counts[obs_index:])
else:
n_obs = 0
else:
obs_counts = None
obs_centers = None
n_obs = 0

if hasattr(forecast, "catalogs"):
forecast_mws = list(map(lambda x: x.get_magnitudes(), forecast))

def get_histogram_synthetic_cat(x, mags, normed_hist=normalize):
n_syn_events = len(x)
if normed_hist and n_syn_events != 0 and n_obs != 0:
temp_scale = n_obs / n_syn_events
hist = numpy.histogram(x, bins=mags)[0] * temp_scale
else:
hist = numpy.histogram(x, bins=mags)[0]
return hist

# get histogram values
forecast_hist = numpy.array(
list(map(lambda x: get_histogram_synthetic_cat(x, magnitude_bins), forecast_mws))
)
obs_hist, bin_edges = numpy.histogram(obs_mw, bins=magnitude_bins)
bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
# get histogram values
catalog_forecast_bins = numpy.append(forecast_bins, forecast_bins[-1] + dm_forecast)
forecast_hist = numpy.array(
list(map(lambda x: get_histogram_synthetic_cat(x, catalog_forecast_bins), forecast_mws))
)

if cumulative:
hist_for_stats = numpy.cumsum(forecast_hist[:, ::-1], axis=1)[:, ::-1]
else:
hist_for_stats = forecast_hist

forecast_mean = hist_for_stats.mean(axis=0)
if intervals:
lower_p = (100.0 - percentile) / 2.0
upper_p = 100.0 - lower_p
forecast_low = numpy.percentile(hist_for_stats, lower_p, axis=0)
forecast_high = numpy.percentile(hist_for_stats, upper_p, axis=0)
else:
forecast_low = None
forecast_high = None
else:
rates = numpy.asarray(forecast.magnitude_counts())
if len(rates) != len(forecast_bins):
raise ValueError(
"Length of forecast.magnitude_counts() must match number of forecast magnitude bins."
)
if normalize and n_obs != 0:
scale = n_obs / numpy.sum(rates)
else:
scale = 1.0

if cumulative:
lam = numpy.cumsum(rates[::-1])[::-1]
else:
lam = rates

forecast_mean = lam * scale
if intervals:
alpha = (100.0 - percentile) / 200.0
low_counts = poisson.ppf(alpha, lam)
high_counts = poisson.ppf(1.0 - alpha, lam)
forecast_low = low_counts * scale
forecast_high = high_counts * scale
else:
forecast_low = None
forecast_high = None

# Compute statistics for the forecast histograms
# Compute statistics for the forecast histograms
forecast_mean = numpy.mean(forecast_hist, axis=0)
forecast_median = numpy.median(forecast_hist, axis=0)
forecast_low = numpy.percentile(forecast_hist, (100 - percentile) / 2.0, axis=0)
forecast_high = numpy.percentile(forecast_hist, 100 - (100 - percentile) / 2.0, axis=0)
forecast_err_lower = forecast_median - forecast_low
forecast_err_upper = forecast_high - forecast_median
if intervals:
low = numpy.nan_to_num(forecast_low, nan=0.0)
high = numpy.nan_to_num(forecast_high, nan=forecast_mean)
low = numpy.minimum(low, forecast_mean)
high = numpy.maximum(high, forecast_mean)
forecast_err_lower = numpy.clip(forecast_mean - low, 0.0, None)
forecast_err_upper = numpy.clip(high - forecast_mean, 0.0, None)
else:
forecast_err_lower = None
forecast_err_upper = None

# cumulative transform for observation (after n_obs calculation)
if cumulative and obs_counts is not None:
obs_counts = numpy.cumsum(obs_counts[::-1])[::-1]

# Plot observed counts
ax.plot(
bin_centers,
obs_hist,
color=plot_args["color"],
marker="o",
lw=0,
markersize=plot_args["markersize"],
label="Observation",
zorder=3,
)
if obs_counts is not None:
ax.plot(
obs_centers,
obs_counts,
color=plot_args["color"],
marker="o",
lw=0,
markersize=plot_args["markersize"],
label="Observation",
zorder=3,
)
# Plot forecast histograms as bar plot with error bars
ax.plot(
bin_centers,
forecast_centers,
forecast_mean,
".",
markersize=plot_args["markersize"],
color="darkred",
label="Forecast Mean",
)
ax.errorbar(
bin_centers,
forecast_median,
yerr=[forecast_err_lower, forecast_err_upper],
fmt="None",
color="darkred",
markersize=plot_args["markersize"],
capsize=plot_args["capsize"],
linewidth=plot_args["linewidth"],
label="Forecast (95% CI)",
)

if intervals:
ax.errorbar(
forecast_centers,
forecast_mean,
yerr=[forecast_err_lower, forecast_err_upper],
fmt="None",
color="darkred",
markersize=plot_args["markersize"],
capsize=plot_args["capsize"],
linewidth=plot_args["linewidth"],
label=f"Forecast ({percentile}% CI)",
)

# Scale x-axis
if plot_args["xlim"]:
ax.set_xlim(plot_args["xlim"])
else:
ax = _autoscale_histogram(
ax, magnitude_bins, numpy.hstack(forecast_mws), obs_mw, mass=100
)
if observation is not None and hasattr(forecast, "catalogs"):
forecast_mws = [c.get_magnitudes() for c in forecast]
ax = _autoscale_histogram(
ax,
forecast_bins,
numpy.hstack(forecast_mws),
observation.get_magnitudes(),
mass=100,
)
# Scale y-axis
if log_scale:
ax.set_yscale('log')
Expand Down Expand Up @@ -624,7 +728,7 @@ def plot_basemap(
"""

if "plot_args" in kwargs:
_warning_plot_args("plot_magnitude_histogram")
_warning_plot_args("plot_basemap")

# Initialize plot
plot_args = {**DEFAULT_PLOT_ARGS, **kwargs}
Expand Down Expand Up @@ -767,7 +871,7 @@ def plot_catalog(
the events sizing.
"""
if "plot_args" in kwargs:
_warning_plot_args("plot_magnitude_histogram")
_warning_plot_args("plot_basemap")

# Initialize plot
plot_args = {**DEFAULT_PLOT_ARGS, **kwargs.get("plot_args", {}), **kwargs}
Expand Down Expand Up @@ -935,7 +1039,7 @@ def plot_gridded_dataset(
"""

if "plot_args" in kwargs:
_warning_plot_args("plot_magnitude_histogram")
_warning_plot_args("plot_gridded_dataset")
# Initialize plot

plot_args = {**DEFAULT_PLOT_ARGS, **kwargs.get("plot_args", {}), **kwargs}
Expand Down Expand Up @@ -2691,6 +2795,3 @@ def _warning_plot_args(func_name: str):
DeprecationWarning,
stacklevel=2
)



2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
"pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"matplotlib": ("https://matplotlib.org/stable", None),
"cartopy": ('https://scitools.org.uk/cartopy/docs/latest/', None)
"cartopy": ('https://cartopy.readthedocs.io/stable/', None)
}

html_theme_options = {}
Expand Down
Loading