Skip to content

Commit 504ef1c

Browse files
Fix: autoregressive params info (#584)
* fix: autoregressive params info * Add test * make test less arbitrary Co-authored-by: Jesse Grabowski <[email protected]> * Fix typo --------- Co-authored-by: Jesse Grabowski <[email protected]>
1 parent 084a274 commit 504ef1c

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

pymc_extras/statespace/models/structural/components/autoregressive.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,19 @@ def populate_component_properties(self):
141141

142142
self.param_info = {
143143
f"params_{self.name}": {
144-
"shape": (k_states,) if self.k_endog == 1 else (self.k_endog, k_states),
144+
"shape": (k_endog_effective, k_states) if k_endog_effective > 1 else (k_states,),
145145
"constraints": None,
146146
"dims": (AR_PARAM_DIM,)
147-
if self.k_endog == 1
147+
if k_endog_effective == 1
148148
else (
149149
f"endog_{self.name}",
150150
f"lag_{self.name}",
151151
),
152152
},
153153
f"sigma_{self.name}": {
154-
"shape": () if self.k_endog == 1 else (self.k_endog,),
154+
"shape": (k_endog_effective,) if k_endog_effective > 1 else (),
155155
"constraints": "Positive",
156-
"dims": None if self.k_endog == 1 else (f"endog_{self.name}",),
156+
"dims": (f"endog_{self.name}",) if k_endog_effective > 1 else None,
157157
},
158158
}
159159

tests/statespace/models/structural/components/test_autoregressive.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,14 @@ def test_autoregressive_shared_and_not_shared():
163163
observed_state_names=["data_1", "data_2", "data_3"],
164164
share_states=True,
165165
)
166+
167+
# make sure param_info is correct
168+
# shound't have endog state when share_states is True
169+
assert not any(
170+
dim.startswith("endog_") for dim in shared.param_info["params_shared_ar"]["dims"]
171+
)
172+
assert shared.param_info["sigma_shared_ar"]["dims"] is None
173+
166174
individual = st.AutoregressiveComponent(
167175
order=3,
168176
name="individual_ar",

0 commit comments

Comments
 (0)