Skip to content

Commit c3f3580

Browse files
authored
allow prior variables (#672)
1 parent 7e2ef91 commit c3f3580

File tree

4 files changed

+84
-37
lines changed

4 files changed

+84
-37
lines changed

preliz/internal/distribution_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def num_kurtosis(dist):
145145
"ZeroInflatedPoisson": {"psi": 0.8, "mu": 4.5},
146146
"Truncated": {"lower": -10, "upper": 10},
147147
"Censored": {"lower": -10, "upper": 10},
148+
"Dirichlet": {"alpha": [1.0, 1.0, 1.0]},
148149
}
149150

150151

preliz/internal/plot_helper.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@
44

55
try:
66
from IPython import get_ipython
7-
from ipywidgets import Checkbox, FloatSlider, FloatText, IntSlider, IntText, ToggleButton
7+
from ipywidgets import (
8+
Checkbox,
9+
FloatSlider,
10+
FloatText,
11+
IntSlider,
12+
IntText,
13+
RadioButtons,
14+
ToggleButton,
15+
)
816
except ImportError:
917
pass
1018

@@ -323,7 +331,7 @@ def get_boxes(name, value, lower, upper):
323331
return text
324332

325333

326-
def get_textboxes(signature, model):
334+
def get_textboxes(signature, model, kind_plot):
327335
textboxes = {}
328336
for name, param in signature.parameters.items():
329337
if isinstance(param.default, int | float):
@@ -399,6 +407,13 @@ def get_textboxes(signature, model):
399407
tooltip="Resample",
400408
)
401409

410+
textboxes["__kind__"] = RadioButtons(
411+
options=["hist", "kde", "ecdf"],
412+
value=kind_plot,
413+
description="Kind",
414+
disabled=False,
415+
)
416+
402417
return textboxes
403418

404419

@@ -432,8 +447,9 @@ def looper(*args, **kwargs):
432447
return looper
433448

434449

435-
def plot_repr(results, kind_plot, references, iterations, ax):
450+
def plot_repr(results, kind_plot, references, iterations, stats_kwargs, ax):
436451
alpha = max(0.01, 1 - iterations * 0.009)
452+
stats_kwargs.setdefault("alpha", alpha)
437453

438454
if kind_plot == "hist":
439455
if results[0].dtype.kind == "i":
@@ -442,31 +458,29 @@ def plot_repr(results, kind_plot, references, iterations, ax):
442458
ax.set_xticks(bins + 0.5)
443459
else:
444460
bins = "auto"
445-
ax.hist(
446-
results.T,
447-
alpha=alpha,
448-
density=True,
449-
color=["0.5"] * iterations,
450-
bins=bins,
451-
histtype="step",
452-
)
453-
ax.hist(
454-
np.concatenate(results),
455-
density=True,
456-
bins=bins,
457-
color="k",
458-
ls="--",
459-
histtype="step",
460-
)
461+
462+
stats_kwargs.setdefault("bins", bins)
463+
stats_kwargs.setdefault("density", True)
464+
stats_kwargs.setdefault("histtype", "step")
465+
stats_kwargs.setdefault("alpha", alpha)
466+
stats_kwargs.setdefault("color", ["0.5"] * iterations)
467+
468+
ax.hist(results.T, **stats_kwargs)
469+
stats_kwargs.pop("color")
470+
stats_kwargs.pop("ls", None)
471+
ax.hist(np.concatenate(results), color="k", ls="--", **stats_kwargs)
461472
elif kind_plot == "kde":
473+
stats_kwargs.setdefault("color", "0.5")
462474
for result in results:
463-
ax.plot(*kde(result), "0.5", alpha=alpha)
475+
ax.plot(*kde(result), **stats_kwargs)
464476
ax.plot(*kde(np.concatenate(results)), "k--")
477+
465478
elif kind_plot == "ecdf":
479+
stats_kwargs.setdefault("color", "0.5")
466480
ax.plot(
467481
np.sort(results, axis=1).T,
468482
np.linspace(0, 1, len(results[0]), endpoint=False),
469-
color="0.5",
483+
**stats_kwargs,
470484
)
471485
a = np.concatenate(results)
472486
ax.plot(np.sort(a), np.linspace(0, 1, len(a), endpoint=False), "k--")

preliz/ppls/agnostic.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ def match_preliz_dist(all_dist_str, source, engine):
222222
return matches
223223

224224

225-
def ppl_plot_decorator(func, iterations, kind_plot, references, plot_func, engine):
225+
def ppl_plot_decorator(
226+
func, iterations, kind_plot, references, plot_func, engine, group, var_name, stats_kwargs
227+
):
226228
def looper(*args, **kwargs):
227229
kwargs.pop("__resample__")
228230
x_min = kwargs.pop("__x_min__")
@@ -234,6 +236,7 @@ def looper(*args, **kwargs):
234236
else:
235237
auto = False
236238

239+
var_to_plot = var_name
237240
y_min = kwargs.pop("__y_min__", None)
238241
y_max = kwargs.pop("__y_max__", None)
239242
set_ylim = kwargs.pop("__set_ylim__", False)
@@ -244,6 +247,8 @@ def looper(*args, **kwargs):
244247
else:
245248
auto_ylim = False
246249

250+
kind = kwargs.pop("__kind__", kind_plot)
251+
247252
if engine == "preliz":
248253
results = []
249254
for _ in range(iterations):
@@ -255,28 +260,30 @@ def looper(*args, **kwargs):
255260
elif engine == "bambi":
256261
model = func(*args, **kwargs)
257262
model.build()
263+
if var_to_plot is None:
264+
var_to_plot = model.observed_RVs[0].name
265+
258266
with disable_pymc_sampling_logs():
259267
idata = model.prior_predictive(iterations)
260-
results = (
261-
idata["prior_predictive"]
262-
.stack(sample=("chain", "draw"))[model.response_component.response.name]
263-
.values.T
264-
)
268+
results = idata[group].stack(sample=("chain", "draw"))[var_to_plot].values.T
269+
if group == "prior":
270+
results = np.atleast_2d(results)
265271

266272
elif engine == "pymc":
267273
with func(*args, **kwargs) as model:
268-
obs_name = model.observed_RVs[0].name
274+
if var_to_plot is None:
275+
var_to_plot = model.observed_RVs[0].name
269276
with disable_pymc_sampling_logs():
270277
idata = sample_prior_predictive(samples=iterations)
271-
results = (
272-
idata["prior_predictive"].stack(sample=("chain", "draw"))[obs_name].values.T
273-
)
278+
results = idata[group].stack(sample=("chain", "draw"))[var_to_plot].values.T
279+
if group == "prior":
280+
results = np.atleast_2d(results)
274281

275282
_, ax = plt.subplots()
276283
ax.set_xlim(x_min, x_max, auto=auto)
277284
ax.set_ylim(y_min, y_max, auto=auto_ylim)
278285
if plot_func is None:
279-
plot_repr(results, kind_plot, references, iterations, ax)
286+
plot_repr(results, kind, references, iterations, stats_kwargs, ax)
280287
else:
281288
plot_func(results, ax)
282289

preliz/predictive/predictive_explorer.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,23 @@
1212

1313

1414
def predictive_explorer(
15-
fmodel, samples=50, kind_plot="ecdf", references=None, plot_func=None, engine="auto"
15+
fmodel,
16+
samples=50,
17+
kind_plot="ecdf",
18+
references=None,
19+
plot_func=None,
20+
engine="auto",
21+
group="prior_predictive",
22+
var_name=None,
23+
stats_kwargs=None,
1624
):
1725
"""
1826
Explore how changing parameters in the prior affects the prior predictive distribution.
1927
20-
Use this function to interactively explore how a prior predictive distribution changes when the
21-
priors are changed.
28+
Use this function to interactively explore how a prior predictive distribution changes
29+
when the priors are changed. It also allows you to visualize how one prior changes when
30+
another prior is changed, this can be useful for prior that are not set independently, but
31+
are dependent on each other.
2232
2333
Parameters
2434
----------
@@ -40,13 +50,28 @@ def predictive_explorer(
4050
Library used to define the fmodel. Either `preliz`, `pymc` or `bambi`. Default is `auto`.
4151
The function will automatically select the appropriate library to use based on the fmodel
4252
provided.
53+
group : str, optional
54+
Which group to use. Ignored if the model is defined in `preliz`.
55+
Defaults to "prior_predictive". You can also pass "prior".
56+
var_name: str, optional
57+
The name of the variable to plot. Ignored if the model is defined in `preliz`.
58+
If "group=prior_predictive" it defaults to the first variable in `observed_RVs`.
59+
For "prior" it defaults to the last variable in `free_RVs`.
60+
stats_kwargs : dict, optional
61+
Additional keyword arguments to pass to the statistics function.
62+
Defaults to an empty dictionary.
4363
"""
64+
if stats_kwargs is None:
65+
stats_kwargs = {}
4466
source, signature, engine = inspect_source(fmodel)
4567
model = parse_function_for_pred_textboxes(source, signature, engine)
46-
textboxes = get_textboxes(signature, model)
47-
new_fmodel = ppl_plot_decorator(fmodel, samples, kind_plot, references, plot_func, engine)
68+
textboxes = get_textboxes(signature, model, kind_plot)
69+
new_fmodel = ppl_plot_decorator(
70+
fmodel, samples, kind_plot, references, plot_func, engine, group, var_name, stats_kwargs
71+
)
4872
out = interactive_output(new_fmodel, textboxes)
4973
default_names = [
74+
"__kind__",
5075
"__set_xlim__",
5176
"__x_min__",
5277
"__x_max__",

0 commit comments

Comments
 (0)