diff --git a/pymc_extras/model_builder.py b/pymc_extras/model_builder.py index 6e712e5d7..be672db7b 100644 --- a/pymc_extras/model_builder.py +++ b/pymc_extras/model_builder.py @@ -444,6 +444,7 @@ def load(cls, fname: str): sampler_config=json.loads(idata.attrs["sampler_config"]), ) model.idata = idata + model.is_fitted_ = True dataset = idata.fit_data.to_dataframe() X = dataset.drop(columns=[model.output_var]) y = dataset[model.output_var] @@ -524,6 +525,8 @@ def fit( ) self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore + self.is_fitted_ = True + return self.idata # type: ignore def predict(