diff --git a/pymc_extras/statespace/models/structural/components/autoregressive.py b/pymc_extras/statespace/models/structural/components/autoregressive.py index 5dac13315..e20c5caa7 100644 --- a/pymc_extras/statespace/models/structural/components/autoregressive.py +++ b/pymc_extras/statespace/models/structural/components/autoregressive.py @@ -141,19 +141,19 @@ def populate_component_properties(self): self.param_info = { f"params_{self.name}": { - "shape": (k_states,) if self.k_endog == 1 else (self.k_endog, k_states), + "shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,), "constraints": None, "dims": (AR_PARAM_DIM,) - if self.k_endog == 1 + if k_endog_effective == 1 else ( f"endog_{self.name}", f"lag_{self.name}", ), }, f"sigma_{self.name}": { - "shape": () if self.k_endog == 1 else (self.k_endog,), + "shape": (k_endog_effective,) if k_endog_effective > 1 else (), "constraints": "Positive", - "dims": None if self.k_endog == 1 else (f"endog_{self.name}",), + "dims": (f"endog_{self.name}",) if k_endog_effective > 1 else None, }, } diff --git a/tests/statespace/models/structural/components/test_autoregressive.py b/tests/statespace/models/structural/components/test_autoregressive.py index 8f05fe96d..33758f0d4 100644 --- a/tests/statespace/models/structural/components/test_autoregressive.py +++ b/tests/statespace/models/structural/components/test_autoregressive.py @@ -163,6 +163,14 @@ def test_autoregressive_shared_and_not_shared(): observed_state_names=["data_1", "data_2", "data_3"], share_states=True, ) + + # make sure param_info is correct + # shound't have endog state when share_states is True + assert not any( + dim.startswith("endog_") for dim in shared.param_info["params_shared_ar"]["dims"] + ) + assert shared.param_info["sigma_shared_ar"]["dims"] is None + individual = st.AutoregressiveComponent( order=3, name="individual_ar",