Skip to content

Commit 7f84f03

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4ea5fbc commit 7f84f03

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

pymc_extras/model_builder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,9 @@ def sample_prior_predictive(
628628

629629
return prior_predictive_samples
630630

631-
def sample_posterior_predictive(self, X_pred, extend_idata, combined, predictions = True, **kwargs):
631+
def sample_posterior_predictive(
632+
self, X_pred, extend_idata, combined, predictions=True, **kwargs
633+
):
632634
"""
633635
Sample from the model's posterior predictive distribution.
634636
@@ -652,15 +654,15 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, prediction
652654
self._data_setter(X_pred)
653655

654656
with self.model: # sample with new input data
655-
post_pred = pm.sample_posterior_predictive(self.idata, predictions=predictions, **kwargs)
657+
post_pred = pm.sample_posterior_predictive(
658+
self.idata, predictions=predictions, **kwargs
659+
)
656660
if extend_idata:
657661
self.idata.extend(post_pred, join="right")
658662

659663
group_name = "predictions" if predictions else "posterior_predictive"
660664

661-
posterior_predictive_samples = az.extract(
662-
post_pred, group_name, combined=combined
663-
)
665+
posterior_predictive_samples = az.extract(post_pred, group_name, combined=combined)
664666

665667
return posterior_predictive_samples
666668

tests/test_model_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def test_id():
305305

306306
assert model_builder.id == expected_id
307307

308+
308309
@pytest.mark.parametrize("predictions", [True, False])
309310
def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
310311
x_pred = np.random.uniform(0, 1, 100)
@@ -324,7 +325,7 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
324325
combined=False,
325326
predictions=predictions,
326327
)
327-
328+
328329
pp_after = fitted_model_instance.idata.posterior_predictive[output_var].values
329330

330331
# Check predictions group presence
@@ -335,4 +336,4 @@ def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
335336
else:
336337
assert "predictions" not in fitted_model_instance.idata.groups()
337338
# Posterior predictive should be updated
338-
np.testing.assert_array_not_equal(pp_before, pp_after)
339+
np.testing.assert_array_not_equal(pp_before, pp_after)

0 commit comments

Comments
 (0)