Skip to content

Commit 46f7222

Browse files
committed
start vectorizing inside scan
1 parent a95e8ed commit 46f7222

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

pymc_extras/statespace/filters/kalman_filter.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,10 @@ def _build_graph(
282282

283283
if len(sequences) > 0:
284284
sequences = self.add_check_on_time_varying_shapes(data, sequences)
285-
285+
# y, a, P, c, d, T, Z, R, H, Q = self.unpack_args(args)
286+
# s0 and s1 = states, o0 and 01 are observed states, r0 and r1 are random states
286287
results, updates = pytensor.scan(
287-
self.kalman_step,
288+
pt.vectorize(self.kalman_step, signature="(o0,), (s0,), (s0, s1), (s0,), (o0,), (s0, s1), (o0, s0), (s0, r0), (o0, o1), (r0, r1) -> (s0,), (s0,), (o0,), (s0, s1), (P_hat,), (obs_cov,), (ll,)",
288289
sequences=[data, *sequences],
289290
outputs_info=[None, a0, None, None, P0, None, None],
290291
non_sequences=non_sequences,
@@ -320,18 +321,17 @@ def build_graph(
320321
"""
321322
Build the vectorized computation graph for the Kalman filter.
322323
"""
323-
signature = self._make_gufunc_signature(
324-
[data, a0, P0, c, d, T, Z, R, H, Q],
325-
)
324+
# signature = self._make_gufunc_signature(
325+
# [data, a0, P0, c, d, T, Z, R, H, Q],
326+
# )
326327
fn = partial(
327328
self._build_graph,
328329
mode=mode,
329330
return_updates=return_updates,
330331
missing_fill_value=missing_fill_value,
331332
cov_jitter=cov_jitter,
332333
)
333-
filter_outputs = pt.vectorize(fn, signature=signature)(data, a0, P0, c, d, T, Z, R, H, Q)
334-
# filter_outputs = fn(data, a0, P0, c, d, T, Z, R, H, Q)
334+
filter_outputs = fn(data, a0, P0, c, d, T, Z, R, H, Q)
335335
for output, name in zip(filter_outputs, ALL_KF_OUTPUT_NAMES):
336336
output.name = name
337337

pymc_extras/statespace/filters/kalman_smoother.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,15 @@ def build_graph(
169169
"""
170170
Build the vectorized computation graph for the Kalman smoother.
171171
"""
172-
signature = self._make_gufunc_signature(
173-
[T, R, Q, filtered_states, filtered_covariances],
174-
)
172+
# signature = self._make_gufunc_signature(
173+
# [T, R, Q, filtered_states, filtered_covariances],
174+
# )
175175
fn = partial(
176176
self._build_graph,
177177
mode=mode,
178178
cov_jitter=cov_jitter,
179179
)
180-
return pt.vectorize(fn, signature=signature)(T, R, Q, filtered_states, filtered_covariances)
181-
# return fn(T, R, Q, filtered_states, filtered_covariances)
180+
return fn(T, R, Q, filtered_states, filtered_covariances)
182181

183182
def smoother_step(self, *args):
184183
a, P, a_smooth, P_smooth, T, R, Q = self.unpack_args(args)

tests/statespace/test_kalman_filter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def test_batched_standard_filter(filter_func):
336336
pt.as_tensor(x, name=name)
337337
for x, name in zip(make_test_inputs(p, m, r, n, rng, batch_size=8), input_names)
338338
]
339+
inputs[0] = pt.moveaxis(inputs[0], -2, 0) # Move time dimension to the front of the data array
339340
kf = StandardFilter()
340341
outputs = kf.build_graph(*inputs)
341342
fn = pytensor.function([], outputs)

0 commit comments

Comments
 (0)