Skip to content

Commit 9f8b80c

Browse files
Always return dictionary from data_info (#582)
* Always return dictionary from data_info * Remove unused docstring parameter * More robust `needs_exog_data` * fix typo
1 parent 37f4588 commit 9f8b80c

File tree

2 files changed

+112
-48
lines changed

2 files changed

+112
-48
lines changed

pymc_extras/statespace/models/VARMAX.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,6 @@ def __init__(
148148
The type of Kalman Filter to use. Options are "standard", "single", "univariate", "steady_state",
149149
and "cholesky". See the docs for kalman filters for more details.
150150
151-
state_structure: str, default "fast"
152-
How to represent the state-space system. When "interpretable", each element of the state vector will have a
153-
precise meaning as either lagged data, innovations, or lagged innovations. This comes at the cost of a larger
154-
state vector, which may hurt performance.
155-
156-
When "fast", states are combined to minimize the dimension of the state vector, but lags and innovations are
157-
mixed together as a result. Only the first state (the modeled timeseries) will have an obvious interpretation
158-
in this case.
159-
160151
measurement_error: bool, default True
161152
If true, a measurement error term is added to the model.
162153
@@ -181,8 +172,10 @@ def __init__(
181172
if len(endog_names) != k_endog:
182173
raise ValueError("Length of provided endog_names does not match provided k_endog")
183174

175+
needs_exog_data = False
176+
184177
if k_exog is not None and not isinstance(k_exog, int | dict):
185-
raise ValueError("If not None, k_endog must be either an int or a dict")
178+
raise ValueError("If not None, k_exog must be either an int or a dict")
186179
if exog_state_names is not None and not isinstance(exog_state_names, list | dict):
187180
raise ValueError("If not None, exog_state_names must be either a list or a dict")
188181

@@ -208,6 +201,7 @@ def __init__(
208201
"If both k_endog and exog_state_names are provided, lengths of exog_state_names "
209202
"lists must match corresponding values in k_exog"
210203
)
204+
needs_exog_data = True
211205

212206
if k_exog is not None and exog_state_names is None:
213207
if isinstance(k_exog, int):
@@ -216,12 +210,14 @@ def __init__(
216210
exog_state_names = {
217211
name: [f"{name}_exogenous_{i}" for i in range(k)] for name, k in k_exog.items()
218212
}
213+
needs_exog_data = True
219214

220215
if k_exog is None and exog_state_names is not None:
221216
if isinstance(exog_state_names, list):
222217
k_exog = len(exog_state_names)
223218
elif isinstance(exog_state_names, dict):
224219
k_exog = {name: len(names) for name, names in exog_state_names.items()}
220+
needs_exog_data = True
225221

226222
# If exog_state_names is a dict but 1) all endog variables are among the keys, and 2) all values are the same
227223
# then we can drop back to the list case.
@@ -254,6 +250,8 @@ def __init__(
254250
mode=mode,
255251
)
256252

253+
self._needs_exog_data = needs_exog_data
254+
257255
# Save counts of the number of parameters in each category
258256
self.param_counts = {
259257
"x0": k_states * (1 - self.stationary_initialization),
@@ -337,7 +335,7 @@ def param_info(self) -> dict[str, dict[str, Any]]:
337335

338336
@property
339337
def data_info(self) -> dict[str, dict[str, Any]]:
340-
info = None
338+
info = {}
341339

342340
if isinstance(self.exog_state_names, list):
343341
info = {

tests/statespace/models/test_VARMAX.py

Lines changed: 103 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,14 @@ def test_all_prior_covariances_are_PSD(filter_output, pymc_mod, rng):
188188
def test_impulse_response(parameters, varma_mod, idata, rng):
189189
irf = varma_mod.impulse_response_function(idata.prior, random_seed=rng, **parameters)
190190

191-
assert not np.any(np.isnan(irf.irf.values))
191+
assert np.isfinite(irf.irf.values).all()
192+
193+
194+
def test_forecast(varma_mod, idata, rng):
195+
forecast = varma_mod.forecast(idata.prior, periods=10, random_seed=rng)
196+
197+
assert np.isfinite(forecast.forecast_latent.values).all()
198+
assert np.isfinite(forecast.forecast_observed.values).all()
192199

193200

194201
class TestVARMAXWithExogenous:
@@ -436,42 +443,8 @@ def test_create_varmax_with_exogenous_raises_if_args_disagree(self, data):
436443
stationary_initialization=False,
437444
)
438445

439-
@pytest.mark.parametrize(
440-
"k_exog, exog_state_names",
441-
[
442-
(2, None),
443-
(None, ["foo", "bar"]),
444-
(None, {"y1": ["a", "b"], "y2": ["c"]}),
445-
],
446-
ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
447-
)
448-
@pytest.mark.filterwarnings("ignore::UserWarning")
449-
def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
450-
endog_names = ["y1", "y2", "y3"]
451-
n_obs = 50
452-
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")
453-
454-
y = rng.normal(size=(n_obs, len(endog_names)))
455-
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)
456-
457-
if isinstance(exog_state_names, dict):
458-
exog_data = {
459-
f"{name}_exogenous_data": pd.DataFrame(
460-
rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
461-
columns=exog_names,
462-
index=time_idx,
463-
)
464-
for name, exog_names in exog_state_names.items()
465-
}
466-
else:
467-
exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
468-
exog_data = {
469-
"exogenous_data": pd.DataFrame(
470-
rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX),
471-
columns=exog_names,
472-
index=time_idx,
473-
)
474-
}
446+
def _build_varmax(self, df, k_exog, exog_state_names, exog_data):
447+
endog_names = df.columns.values.tolist()
475448

476449
mod = BayesianVARMAX(
477450
endog_names=endog_names,
@@ -512,6 +485,47 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
512485

513486
mod.build_statespace_graph(data=df)
514487

488+
return mod, m
489+
490+
@pytest.mark.parametrize(
491+
"k_exog, exog_state_names",
492+
[
493+
(2, None),
494+
(None, ["foo", "bar"]),
495+
(None, {"y1": ["a", "b"], "y2": ["c"]}),
496+
],
497+
ids=["k_exog_int", "exog_state_names_list", "exog_state_names_dict"],
498+
)
499+
@pytest.mark.filterwarnings("ignore::UserWarning")
500+
def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
501+
endog_names = ["y1", "y2", "y3"]
502+
n_obs = 50
503+
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")
504+
505+
y = rng.normal(size=(n_obs, len(endog_names)))
506+
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)
507+
508+
if isinstance(exog_state_names, dict):
509+
exog_data = {
510+
f"{name}_exogenous_data": pd.DataFrame(
511+
rng.normal(size=(n_obs, len(exog_names))).astype(floatX),
512+
columns=exog_names,
513+
index=time_idx,
514+
)
515+
for name, exog_names in exog_state_names.items()
516+
}
517+
else:
518+
exog_names = exog_state_names or [f"exogenous_{i}" for i in range(k_exog)]
519+
exog_data = {
520+
"exogenous_data": pd.DataFrame(
521+
rng.normal(size=(n_obs, k_exog or len(exog_state_names))).astype(floatX),
522+
columns=exog_names,
523+
index=time_idx,
524+
)
525+
}
526+
527+
mod, m = self._build_varmax(df, k_exog, exog_state_names, exog_data)
528+
515529
with freeze_dims_and_data(m):
516530
prior = pm.sample_prior_predictive(
517531
draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
@@ -543,3 +557,55 @@ def test_varmax_with_exog(self, rng, k_exog, exog_state_names):
543557
obs_intercept.append(np.zeros_like(obs_intercept[0]))
544558

545559
np.testing.assert_allclose(beta_dot_data, np.stack(obs_intercept, axis=-1), atol=1e-2)
560+
561+
@pytest.mark.filterwarnings("ignore::UserWarning")
562+
def test_forecast_with_exog(self, rng):
563+
endog_names = ["y1", "y2", "y3"]
564+
n_obs = 50
565+
time_idx = pd.date_range(start="2020-01-01", periods=n_obs, freq="D")
566+
567+
y = rng.normal(size=(n_obs, len(endog_names)))
568+
df = pd.DataFrame(y, columns=endog_names, index=time_idx).astype(floatX)
569+
570+
mod, m = self._build_varmax(
571+
df,
572+
k_exog=2,
573+
exog_state_names=None,
574+
exog_data={
575+
"exogenous_data": pd.DataFrame(
576+
rng.normal(size=(n_obs, 2)).astype(floatX),
577+
columns=["exogenous_0", "exogenous_1"],
578+
index=time_idx,
579+
)
580+
},
581+
)
582+
583+
assert mod._needs_exog_data
584+
585+
with freeze_dims_and_data(m):
586+
prior = pm.sample_prior_predictive(
587+
draws=10, random_seed=rng, compile_kwargs={"mode": "JAX"}
588+
)
589+
590+
with pytest.raises(
591+
ValueError,
592+
match="This model was fit using exogenous data. Forecasting cannot be performed "
593+
"without providing scenario data",
594+
):
595+
mod.forecast(prior.prior, periods=10, random_seed=rng)
596+
597+
forecast = mod.forecast(
598+
prior.prior,
599+
periods=10,
600+
random_seed=rng,
601+
scenario={
602+
"exogenous_data": pd.DataFrame(
603+
rng.normal(size=(10, 2)).astype(floatX),
604+
columns=["exogenous_0", "exogenous_1"],
605+
index=pd.date_range(start=df.index[-1], periods=10, freq="D"),
606+
)
607+
},
608+
)
609+
610+
assert np.isfinite(forecast.forecast_latent.values).all()
611+
assert np.isfinite(forecast.forecast_observed.values).all()

0 commit comments

Comments
 (0)