diff --git a/csep/core/catalogs.py b/csep/core/catalogs.py index 8178be55..e9a28501 100644 --- a/csep/core/catalogs.py +++ b/csep/core/catalogs.py @@ -733,7 +733,8 @@ def magnitude_counts(self, mag_bins=None, tol=None, retbins=False): else: return out idx = bin1d_vec(self.get_magnitudes(), mag_bins, tol=tol, right_continuous=True) - numpy.add.at(out, idx, 1) + valid = idx >= 0 + numpy.add.at(out, idx[valid], 1) if retbins: return (mag_bins, out) else: diff --git a/csep/plots.py b/csep/plots.py index cf801c6f..fea00940 100644 --- a/csep/plots.py +++ b/csep/plots.py @@ -395,13 +395,15 @@ def plot_cumulative_events_versus_time( pyplot.show() return ax - def plot_magnitude_histogram( - catalog_forecast: Union["CatalogForecast", List["CSEPCatalog"]], - observation: "CSEPCatalog", + forecast: Union["CatalogForecast", "GriddedForecast"], + observation: Optional["CSEPCatalog"] = None, magnitude_bins: Optional[Union[List[float], numpy.ndarray]] = None, percentile: int = 95, log_scale: bool = True, + normalize: bool = True, + cumulative: bool = False, + intervals: bool = True, ax: Optional["matplotlib.axes.Axes"] = None, show: bool = False, **kwargs: Any, @@ -416,7 +418,7 @@ def plot_magnitude_histogram( - :ref:`Catalog-based Forecast Plots` Args: - catalog_forecast (CatalogForecast or list of CSEPCatalog): A catalog-based forecast + forecast (CatalogForecast or list of CSEPCatalog or GriddedForecast): A forecast or a list of observed catalogs. observation (CSEPCatalog): The observed catalog for comparison. magnitude_bins (list of float or numpy.ndarray, optional): The bins for magnitude @@ -426,6 +428,12 @@ def plot_magnitude_histogram( `95`. log_scale (bool, optional): Whether to plot the y-axis in logarithmic scale. Defaults to True. + normalize (bool, optional): Whether to scale the forecast to the total number of events in + the observation catalog. Defaults to True. + cumulative (bool, optional): Whether to plot cumulative counts N(M >= m). Defaults to + False. 
+ intervals (bool, optional): Whether to display forecast uncertainty intervals. + Defaults to True. ax (matplotlib.axes.Axes, optional): The axes object to draw the plot on. If `None`, a new figure and axes are created. Defaults to `None`. show (bool, optional): Whether to display the plot immediately. Defaults to `False`. @@ -453,11 +461,12 @@ def plot_magnitude_histogram( matplotlib.axes.Axes: The axes object containing the plot. .. versionchanged:: 0.8.0 - It now requires a `CatalogForecast` rather than a list of stochastic event sets. The - `plot_args` dictionary is only partially supported and will be removed in v1.0.0 - .. versionadded:: 0.8.0 - Added `magnitude_bins`, `percentile` and `log_scale` to fine-tune the plot. - Added parameters to customize coloring, formatting and sizing of the plot elements. + It now accepts a `CatalogForecast` or a `GriddedForecast`. An observation `CSEPCatalog` + is now optional. Added `magnitude_bins`, `percentile` and `log_scale` to fine-tune the + plot. Added parameters to customize coloring, formatting and sizing of the plot + elements. The `plot_args` dictionary is only partially supported and will be removed in + v1.0.0 + """ if "plot_args" in kwargs: _warning_plot_args("plot_magnitude_histogram") @@ -467,80 +476,175 @@ def plot_magnitude_histogram( fig, ax = pyplot.subplots(figsize=plot_args["figsize"]) if ax is None else (ax.figure, ax) # Get magnitudes from observations and (lazily) from forecast - forecast_mws = list(map(lambda x: x.get_magnitudes(), catalog_forecast)) - obs_mw = observation.get_magnitudes() - n_obs = observation.get_number_of_events() - - # Get magnitude bins from args, if not from region, or lastly from standard CSEP bins. 
- if magnitude_bins is None: + if magnitude_bins is not None: + forecast_bins = numpy.asarray(magnitude_bins) + else: try: - magnitude_bins = observation.region.magnitudes + forecast_bins = getattr(forecast, "magnitudes") except AttributeError: - magnitude_bins = CSEP_MW_BINS + raise AttributeError("Forecast must be defined on a 'region', having " + "'magnitudes' attribute as left-edge magnitude bins.") + + dm_forecast = numpy.median(numpy.diff(forecast_bins)) + forecast_centers = forecast_bins + dm_forecast / 2.0 - def get_histogram_synthetic_cat(x, mags, normed=True): - n_temp = len(x) - if normed and n_temp != 0: - temp_scale = n_obs / n_temp - hist = numpy.histogram(x, bins=mags)[0] * temp_scale + if observation is not None: + if magnitude_bins is not None: + obs_bins, obs_counts = observation.magnitude_counts( + mag_bins=magnitude_bins, retbins=True + ) + elif hasattr(observation, "region") and hasattr(observation.region, "magnitudes"): + obs_bins, obs_counts = observation.magnitude_counts(mag_bins=None, retbins=True) + else: + obs_bins, obs_counts = observation.magnitude_counts( + mag_bins=CSEP_MW_BINS, retbins=True + ) + nonzero_idx = numpy.nonzero(obs_counts)[0] + if nonzero_idx.size > 0: + first = nonzero_idx[0] + obs_bins = obs_bins[first:] + obs_counts = obs_counts[first:] + + if len(obs_bins) > 1: + dm_obs = numpy.median(numpy.diff(obs_bins)) else: - hist = numpy.histogram(x, bins=mags)[0] - return hist + dm_obs = dm_forecast + obs_centers = obs_bins + dm_obs / 2.0 + idxs = numpy.where(obs_centers >= forecast_centers[0])[0] + if idxs.size > 0: + obs_index = idxs[0] + n_obs = numpy.sum(obs_counts[obs_index:]) + else: + n_obs = 0 + else: + obs_counts = None + obs_centers = None + n_obs = 0 + + if hasattr(forecast, "catalogs"): + forecast_mws = list(map(lambda x: x.get_magnitudes(), forecast)) + + def get_histogram_synthetic_cat(x, mags, normed_hist=normalize): + n_syn_events = len(x) + if normed_hist and n_syn_events != 0 and n_obs != 0: + temp_scale = 
n_obs / n_syn_events + hist = numpy.histogram(x, bins=mags)[0] * temp_scale + else: + hist = numpy.histogram(x, bins=mags)[0] + return hist - # get histogram values - forecast_hist = numpy.array( - list(map(lambda x: get_histogram_synthetic_cat(x, magnitude_bins), forecast_mws)) - ) - obs_hist, bin_edges = numpy.histogram(obs_mw, bins=magnitude_bins) - bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2 + # get histogram values + catalog_forecast_bins = numpy.append(forecast_bins, forecast_bins[-1] + dm_forecast) + forecast_hist = numpy.array( + list(map(lambda x: get_histogram_synthetic_cat(x, catalog_forecast_bins), forecast_mws)) + ) + if cumulative: + hist_for_stats = numpy.cumsum(forecast_hist[:, ::-1], axis=1)[:, ::-1] + else: + hist_for_stats = forecast_hist + + forecast_mean = hist_for_stats.mean(axis=0) + if intervals: + lower_p = (100.0 - percentile) / 2.0 + upper_p = 100.0 - lower_p + forecast_low = numpy.percentile(hist_for_stats, lower_p, axis=0) + forecast_high = numpy.percentile(hist_for_stats, upper_p, axis=0) + else: + forecast_low = None + forecast_high = None + else: + rates = numpy.asarray(forecast.magnitude_counts()) + if len(rates) != len(forecast_bins): + raise ValueError( + "Length of forecast.magnitude_counts() must match number of forecast magnitude bins." 
+ ) + if normalize and n_obs != 0 and numpy.sum(rates) > 0: + scale = n_obs / numpy.sum(rates) + else: + scale = 1.0 + + if cumulative: + lam = numpy.cumsum(rates[::-1])[::-1] + else: + lam = rates + + forecast_mean = lam * scale + if intervals: + alpha = (100.0 - percentile) / 200.0 + low_counts = poisson.ppf(alpha, lam) + high_counts = poisson.ppf(1.0 - alpha, lam) + forecast_low = low_counts * scale + forecast_high = high_counts * scale + else: + forecast_low = None + forecast_high = None + + # Turn the interval bounds into non-negative error-bar lengths around the forecast mean # Compute statistics for the forecast histograms - forecast_mean = numpy.mean(forecast_hist, axis=0) - forecast_median = numpy.median(forecast_hist, axis=0) - forecast_low = numpy.percentile(forecast_hist, (100 - percentile) / 2.0, axis=0) - forecast_high = numpy.percentile(forecast_hist, 100 - (100 - percentile) / 2.0, axis=0) - forecast_err_lower = forecast_median - forecast_low - forecast_err_upper = forecast_high - forecast_median + if intervals: + low = numpy.nan_to_num(forecast_low, nan=0.0) + high = numpy.nan_to_num(forecast_high, nan=forecast_mean) + low = numpy.minimum(low, forecast_mean) + high = numpy.maximum(high, forecast_mean) + forecast_err_lower = numpy.clip(forecast_mean - low, 0.0, None) + forecast_err_upper = numpy.clip(high - forecast_mean, 0.0, None) + else: + forecast_err_lower = None + forecast_err_upper = None + + # cumulative transform for observation (after n_obs calculation) + if cumulative and obs_counts is not None: + obs_counts = numpy.cumsum(obs_counts[::-1])[::-1] # Plot observed counts - ax.plot( - bin_centers, - obs_hist, - color=plot_args["color"], - marker="o", - lw=0, - markersize=plot_args["markersize"], - label="Observation", - zorder=3, - ) + if obs_counts is not None: + ax.plot( + obs_centers, + obs_counts, + color=plot_args["color"], + marker="o", + lw=0, + markersize=plot_args["markersize"], + label="Observation", + zorder=3, + ) # Plot forecast histograms as bar plot with error bars ax.plot( - bin_centers, 
+ forecast_centers, forecast_mean, ".", markersize=plot_args["markersize"], color="darkred", label="Forecast Mean", ) - ax.errorbar( - bin_centers, - forecast_median, - yerr=[forecast_err_lower, forecast_err_upper], - fmt="None", - color="darkred", - markersize=plot_args["markersize"], - capsize=plot_args["capsize"], - linewidth=plot_args["linewidth"], - label="Forecast (95% CI)", - ) + + if intervals: + ax.errorbar( + forecast_centers, + forecast_mean, + yerr=[forecast_err_lower, forecast_err_upper], + fmt="None", + color="darkred", + markersize=plot_args["markersize"], + capsize=plot_args["capsize"], + linewidth=plot_args["linewidth"], + label=f"Forecast ({percentile}% CI)", + ) # Scale x-axis if plot_args["xlim"]: ax.set_xlim(plot_args["xlim"]) else: - ax = _autoscale_histogram( - ax, magnitude_bins, numpy.hstack(forecast_mws), obs_mw, mass=100 - ) + if observation is not None and hasattr(forecast, "catalogs"): + forecast_mws = [c.get_magnitudes() for c in forecast] + ax = _autoscale_histogram( + ax, + forecast_bins, + numpy.hstack(forecast_mws), + observation.get_magnitudes(), + mass=100, + ) # Scale y-axis if log_scale: ax.set_yscale('log') @@ -624,7 +728,7 @@ def plot_basemap( """ if "plot_args" in kwargs: - _warning_plot_args("plot_magnitude_histogram") + _warning_plot_args("plot_basemap") # Initialize plot plot_args = {**DEFAULT_PLOT_ARGS, **kwargs} @@ -767,7 +871,7 @@ def plot_catalog( the events sizing. 
""" if "plot_args" in kwargs: - _warning_plot_args("plot_magnitude_histogram") + _warning_plot_args("plot_catalog") # Initialize plot plot_args = {**DEFAULT_PLOT_ARGS, **kwargs.get("plot_args", {}), **kwargs} @@ -935,7 +1039,7 @@ def plot_gridded_dataset( """ if "plot_args" in kwargs: - _warning_plot_args("plot_magnitude_histogram") + _warning_plot_args("plot_gridded_dataset") # Initialize plot plot_args = {**DEFAULT_PLOT_ARGS, **kwargs.get("plot_args", {}), **kwargs} @@ -2691,6 +2795,3 @@ def _warning_plot_args(func_name: str): DeprecationWarning, stacklevel=2 ) - - - diff --git a/docs/conf.py b/docs/conf.py index 5ee4fe7e..4c285342 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -103,7 +103,7 @@ "pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None), "scipy": ("https://docs.scipy.org/doc/scipy/", None), "matplotlib": ("https://matplotlib.org/stable", None), - "cartopy": ('https://scitools.org.uk/cartopy/docs/latest/', None) + "cartopy": ('https://cartopy.readthedocs.io/stable/', None) } html_theme_options = {} diff --git a/tests/test_plots.py b/tests/test_plots.py index 15eb1b2b..28b3594e 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -24,6 +24,8 @@ CatalogPseudolikelihoodTestResult, CalibrationTestResult, ) +from csep.utils.calc import bin1d_vec +from csep.utils.constants import CSEP_MW_BINS from csep.plots import ( plot_cumulative_events_versus_time, plot_magnitude_versus_time, @@ -241,7 +243,95 @@ def tearDown(self): class TestPlotMagnitudeHistogram(TestPlots): + + class DummyCatalog: + def __init__(self, mags, region_mags=None): + import numpy + self._mags = numpy.asarray(mags) + if region_mags is not None: + self.region = type("R", (), {})() + self.region.magnitudes = numpy.asarray(region_mags) + else: + self.region = None + + def get_magnitudes(self): + return self._mags + + def magnitude_counts(self, mag_bins=None, tol=None, retbins=False): + # Minimal AbstractBaseCatalog.magnitude_counts behaviour + + + if mag_bins is None: + if 
self.region is not None and hasattr(self.region, "magnitudes"): + mag_bins = self.region.magnitudes + else: + mag_bins = CSEP_MW_BINS + + mag_bins = numpy.asarray(mag_bins) + out = numpy.zeros(len(mag_bins)) + if self._mags.size > 0: + idx = bin1d_vec(self._mags, mag_bins, tol=tol, right_continuous=True) + valid = idx >= 0 + numpy.add.at(out, idx[valid], 1) + + if retbins: + return mag_bins, out + return out + + class DummyCatalogForecast: + def __init__(self, catalogs, region_mags): + import numpy + self.catalogs = list(catalogs) + self.region = type("R", (), {})() + self.region.magnitudes = numpy.asarray(region_mags) + self.n_cat = len(self.catalogs) + self._idx = 0 + + def __iter__(self): + self._idx = 0 + return self + + def __next__(self): + if self._idx >= self.n_cat: + self._idx = 0 + raise StopIteration() + cat = self.catalogs[self._idx] + self._idx += 1 + return cat + + @property + def magnitudes(self): + return self.region.magnitudes + + class DummyGriddedForecast: + def __init__(self, rates, region_mags): + import numpy + self._rates = numpy.asarray(rates) + self.region = type("R", (), {})() + self.region.magnitudes = numpy.asarray(region_mags) + + @property + def magnitudes(self): + return self.region.magnitudes + + def magnitude_counts(self): + return self._rates + + class DummyGriddedForecastBadCounts(DummyGriddedForecast): + def magnitude_counts(self): + import numpy + rates = super().magnitude_counts() + if rates.size > 1: + return rates[:-1] + return rates + + class DummyForecastNoMagnitudes: + pass + + # --- setUp ------------------------------------------------------------- + def setUp(self): + super().setUp() def gr_dist(num_events, mag_min=3.0, mag_max=8.0, b_val=1.0): U = numpy.random.uniform(0, 1, num_events) @@ -249,15 +339,28 @@ def gr_dist(num_events, mag_min=3.0, mag_max=8.0, b_val=1.0): magnitudes = magnitudes[magnitudes <= mag_max] return magnitudes - self.mock_forecast = [MagicMock(), MagicMock(), MagicMock()] - for i in 
self.mock_forecast: - i.get_magnitudes.return_value = gr_dist(5000) + # Regular magnitude bins for dummy forecasts + self.region_mags = numpy.arange(3.0, 8.0, 0.1) + + # Dummy catalog-based forecast: 3 catalogs with GR magnitudes + dummy_cats = [ + self.DummyCatalog(gr_dist(5000, b_val=1.0), region_mags=self.region_mags) + for _ in range(3) + ] + self.dummy_catalog_forecast = self.DummyCatalogForecast(dummy_cats, self.region_mags) - self.mock_cat = MagicMock() - self.mock_cat.get_magnitudes.return_value = gr_dist(500, b_val=1.2) - self.mock_cat.get_number_of_events.return_value = 500 - self.mock_cat.region.magnitudes = numpy.arange(3.0, 8.0, 0.1) + # Dummy observation catalog + self.dummy_observation = self.DummyCatalog( + gr_dist(500, b_val=1.2), region_mags=self.region_mags + ) + # Dummy gridded forecast (Poisson / rate-based) + self.dummy_gridded_forecast = self.DummyGriddedForecast( + rates=numpy.array([10.0, 5.0, 1.0, 0.5, 0.1]), + region_mags=numpy.array([3.0, 3.1, 3.2, 3.3, 3.4]), + ) + + # Real data from artifacts (integration test) cat_file_m5 = os.path.join( self.artifacts, "example_csep2_forecasts", @@ -272,29 +375,178 @@ def gr_dist(num_events, mag_min=3.0, mag_max=8.0, b_val=1.0): "ucerf3-landers_short.csv", ) - self.stochastic_event_sets = csep.load_catalog_forecast(forecast_file) + self.stochastic_event_sets = csep.load_catalog_forecast( + forecast_file, + region=csep.regions.california_relm_region(magnitudes=[4.0, 5.0, 6.0]), + ) os.makedirs(self.save_dir, exist_ok=True) - def test_plot_magnitude_histogram_basic(self): - # Test with basic arguments - plot_magnitude_histogram( - self.mock_forecast, self.mock_cat, show=show_plots, density=True + # --- tests ------------------------------------------------------------- + + def test_basic_catalog_forecast(self): + ax = plot_magnitude_histogram( + self.dummy_catalog_forecast, + self.dummy_observation, + normalize=True, + cumulative=False, + intervals=True, + show=show_plots, ) + self.assertIsNotNone(ax) 
+ self.assertEqual(ax.get_yscale(), "log") - # Verify that magnitudes were retrieved - for catalog in self.mock_forecast: - catalog.get_magnitudes.assert_called_once() - self.mock_cat.get_magnitudes.assert_called_once() - self.mock_cat.get_number_of_events.assert_called_once() + def test_ucerf_example(self): + ax = plot_magnitude_histogram( + self.stochastic_event_sets, + self.comcat, + show=show_plots, + ) + self.assertIsNotNone(ax) - def test_plot_magnitude_histogram_ucerf(self): - # Test with basic arguments - plot_magnitude_histogram(self.stochastic_event_sets, self.comcat, show=show_plots) + def test_catalog_cumulative(self): + ax = plot_magnitude_histogram( + self.dummy_catalog_forecast, + self.dummy_observation, + cumulative=True, + intervals=True, + show=show_plots, + ) + self.assertIsNotNone(ax) + y = ax.lines[1].get_ydata() + self.assertTrue(numpy.all(numpy.diff(y) <= 1e-8)) + + def test_catalog_no_intervals(self): + ax = plot_magnitude_histogram( + self.dummy_catalog_forecast, + self.dummy_observation, + cumulative=False, + intervals=False, + show=show_plots, + ) + self.assertIsNotNone(ax) + self.assertGreaterEqual(len(ax.lines), 2) + + def test_gridded_basic(self): + ax = plot_magnitude_histogram( + self.dummy_gridded_forecast, + observation=None, + normalize=False, + cumulative=False, + intervals=True, + show=show_plots, + ) + self.assertIsNotNone(ax) + self.assertEqual(ax.get_yscale(), "log") + + def test_gridded_cumulative(self): + ax = plot_magnitude_histogram( + self.dummy_gridded_forecast, + observation=None, + normalize=False, + cumulative=True, + intervals=True, + show=show_plots, + ) + self.assertIsNotNone(ax) + y = ax.lines[0].get_ydata() + self.assertTrue(numpy.all(numpy.diff(y) <= 1e-8)) + + def test_gridded_no_intervals(self): + ax = plot_magnitude_histogram( + self.dummy_gridded_forecast, + observation=None, + cumulative=False, + intervals=False, + show=show_plots, + ) + self.assertIsNotNone(ax) + + def test_mismatch_rates_bins(self): + 
bad_forecast = self.DummyGriddedForecastBadCounts( + rates=numpy.array([10.0, 5.0, 1.0, 0.5, 0.1]), + region_mags=numpy.array([3.0, 3.1, 3.2, 3.3, 3.4]), + ) + with self.assertRaises(ValueError): + plot_magnitude_histogram( + bad_forecast, + observation=None, + show=False, + ) + + def test_missing_magnitudes_attr(self): + forecast = self.DummyForecastNoMagnitudes() + with self.assertRaises(AttributeError): + plot_magnitude_histogram( + forecast, + observation=None, + show=False, + ) + + def test_default_labels(self): + ax = plot_magnitude_histogram( + self.dummy_gridded_forecast, + observation=None, + show=False, + ) + self.assertEqual(ax.get_xlabel(), "Magnitude") + self.assertEqual(ax.get_ylabel(), "Event count") + self.assertEqual(ax.get_title(), "Magnitude Histogram") + + def test_custom_labels(self): + ax = plot_magnitude_histogram( + self.dummy_gridded_forecast, + observation=None, + xlabel="Mw", + ylabel="Number of events", + title="Custom Title", + show=False, + ) + self.assertEqual(ax.get_xlabel(), "Mw") + self.assertEqual(ax.get_ylabel(), "Number of events") + self.assertEqual(ax.get_title(), "Custom Title") + + def test_obs_default_bins(self): + obs = self.DummyCatalog(self.dummy_observation.get_magnitudes(), region_mags=None) + ax = plot_magnitude_histogram( + self.dummy_gridded_forecast, + observation=obs, + show=False, + ) + self.assertIsNotNone(ax) + + def test_random_gridded_forecasts(self): + rng = numpy.random.default_rng(1234) + for _ in range(10): + n_bins = rng.integers(3, 15) + m_min = rng.uniform(2.0, 5.0) + m_max = m_min + rng.uniform(0.5, 3.0) + mag_bins = numpy.linspace(m_min, m_max, n_bins) + + rates = rng.uniform(0.0, 20.0, size=n_bins) + + forecast = self.DummyGriddedForecast(rates=rates, region_mags=mag_bins) + cumulative = bool(rng.integers(0, 2)) + log_scale = bool(rng.integers(0, 2)) + normed = bool(rng.integers(0, 2)) + plot_intervals = bool(rng.integers(0, 2)) + + ax = plot_magnitude_histogram( + forecast, + observation=None, + 
magnitude_bins=mag_bins, + cumulative=cumulative, + log_scale=log_scale, + normalize=normed, + intervals=plot_intervals, + show=False, + ) + self.assertIsNotNone(ax) def tearDown(self): plt.close("all") gc.collect() + super().tearDown() class TestPlotDistributionTests(TestPlots):