Skip to content

Commit e79271b

Browse files
authored
Merge pull request #283 from pymc-labs/round_to_summary
User specified number of significant figures for numerical outputs
2 parents 6283c76 + 22f93ff commit e79271b

File tree

2 files changed

+57
-41
lines changed

2 files changed

+57
-41
lines changed

causalpy/pymc_experiments.py

+53-37
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ def idata(self):
5959

6060
return self.model.idata
6161

62-
def print_coefficients(self) -> None:
62+
def print_coefficients(self, round_to=None) -> None:
6363
"""
6464
Prints the model coefficients
6565
66+
:param round_to:
67+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
68+
6669
Example
6770
--------
6871
>>> import causalpy as cp
@@ -80,13 +83,13 @@ def print_coefficients(self) -> None:
8083
... "progressbar": False
8184
... }),
8285
... )
83-
>>> result.print_coefficients() # doctest: +NUMBER
86+
>>> result.print_coefficients(round_to=1) # doctest: +NUMBER
8487
Model coefficients:
85-
Intercept 1.0, 94% HDI [1.0, 1.1]
86-
post_treatment[T.True] 0.9, 94% HDI [0.9, 1.0]
87-
group 0.1, 94% HDI [0.0, 0.2]
88+
Intercept 1, 94% HDI [1, 1]
89+
post_treatment[T.True] 1, 94% HDI [0.9, 1]
90+
group 0.2, 94% HDI [0.09, 0.2]
8891
group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6]
89-
sigma 0.0, 94% HDI [0.0, 0.1]
92+
sigma 0.08, 94% HDI [0.07, 0.1]
9093
"""
9194
print("Model coefficients:")
9295
coeffs = az.extract(self.idata.posterior, var_names="beta")
@@ -95,13 +98,13 @@ def print_coefficients(self) -> None:
9598
for name in self.labels:
9699
coeff_samples = coeffs.sel(coeffs=name)
97100
print(
98-
f"{name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]" # noqa: E501
101+
f"{name: <30}{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
99102
)
100103
# add coeff for measurement std
101104
coeff_samples = az.extract(self.model.idata.posterior, var_names="sigma")
102105
name = "sigma"
103106
print(
104-
f"{name: <30}{coeff_samples.mean().data:.2f}, 94% HDI [{coeff_samples.quantile(0.03).data:.2f}, {coeff_samples.quantile(1-0.03).data:.2f}]" # noqa: E501
107+
f"{name: <30}{round_num(coeff_samples.mean().data, round_to)}, 94% HDI [{round_num(coeff_samples.quantile(0.03).data, round_to)}, {round_num(coeff_samples.quantile(1-0.03).data, round_to)}]" # noqa: E501
105108
)
106109

107110

@@ -138,7 +141,7 @@ class PrePostFit(ExperimentalDesign):
138141
... }
139142
... ),
140143
... )
141-
>>> result.summary() # doctest: +NUMBER
144+
>>> result.summary(round_to=1) # doctest: +NUMBER
142145
==================================Pre-Post Fit==================================
143146
Formula: actual ~ 0 + a + g
144147
Model coefficients:
@@ -231,7 +234,7 @@ def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
231234
Plot the results
232235
233236
:param round_to:
234-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
237+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
235238
"""
236239
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
237240

@@ -331,15 +334,18 @@ def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
331334

332335
return fig, ax
333336

334-
def summary(self) -> None:
337+
def summary(self, round_to=None) -> None:
335338
"""
336339
Print text output summarising the results
340+
341+
:param round_to:
342+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
337343
"""
338344

339345
print(f"{self.expt_type:=^80}")
340346
print(f"Formula: {self.formula}")
341347
# TODO: extra experiment specific outputs here
342-
self.print_coefficients()
348+
self.print_coefficients(round_to)
343349

344350

345351
class InterruptedTimeSeries(PrePostFit):
@@ -420,7 +426,7 @@ def plot(self, plot_predictors=False, **kwargs):
420426
"""Plot the results
421427
422428
:param round_to:
423-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
429+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
424430
"""
425431
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
426432
if plot_predictors:
@@ -589,7 +595,7 @@ def plot(self, round_to=None):
589595
"""Plot the results.
590596
591597
:param round_to:
592-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
598+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
593599
"""
594600
fig, ax = plt.subplots()
595601

@@ -728,17 +734,19 @@ def _causal_impact_summary_stat(self, round_to=None) -> str:
728734
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
729735
return f"Causal impact = {causal_impact + ci}"
730736

731-
def summary(self) -> None:
737+
def summary(self, round_to=None) -> None:
732738
"""
733-
Print text output summarising the results
739+
Print text output summarising the results.
740+
741+
:param round_to:
742+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
734743
"""
735744

736745
print(f"{self.expt_type:=^80}")
737746
print(f"Formula: {self.formula}")
738747
print("\nResults:")
739-
# TODO: extra experiment specific outputs here
740-
print(self._causal_impact_summary_stat())
741-
self.print_coefficients()
748+
print(round_num(self._causal_impact_summary_stat(), round_to))
749+
self.print_coefficients(round_to)
742750

743751

744752
class RegressionDiscontinuity(ExperimentalDesign):
@@ -894,7 +902,7 @@ def plot(self, round_to=None):
894902
Plot the results
895903
896904
:param round_to:
897-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
905+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
898906
"""
899907
fig, ax = plt.subplots()
900908
# Plot raw data
@@ -943,9 +951,12 @@ def plot(self, round_to=None):
943951
)
944952
return fig, ax
945953

946-
def summary(self) -> None:
954+
def summary(self, round_to: None) -> None:
947955
"""
948956
Print text output summarising the results
957+
958+
:param round_to:
959+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
949960
"""
950961

951962
print(f"{self.expt_type:=^80}")
@@ -954,9 +965,9 @@ def summary(self) -> None:
954965
print(f"Threshold on running variable: {self.treatment_threshold}")
955966
print("\nResults:")
956967
print(
957-
f"Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f}"
968+
f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold.mean(), round_to)}"
958969
)
959-
self.print_coefficients()
970+
self.print_coefficients(round_to)
960971

961972

962973
class RegressionKink(ExperimentalDesign):
@@ -1111,7 +1122,7 @@ def plot(self, round_to=None):
11111122
Plot the results
11121123
11131124
:param round_to:
1114-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
1125+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
11151126
"""
11161127
fig, ax = plt.subplots()
11171128
# Plot raw data
@@ -1160,9 +1171,12 @@ def plot(self, round_to=None):
11601171
)
11611172
return fig, ax
11621173

1163-
def summary(self) -> None:
1174+
def summary(self, round_to=None) -> None:
11641175
"""
11651176
Print text output summarising the results
1177+
1178+
:param round_to:
1179+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
11661180
"""
11671181

11681182
print(
@@ -1173,10 +1187,10 @@ def summary(self) -> None:
11731187
Kink point on running variable: {self.kink_point}
11741188
11751189
Results:
1176-
Change in slope at kink point = {self.gradient_change.mean():.2f}
1190+
Change in slope at kink point = {round_num(self.gradient_change.mean(), round_to)}
11771191
"""
11781192
)
1179-
self.print_coefficients()
1193+
self.print_coefficients(round_to)
11801194

11811195

11821196
class PrePostNEGD(ExperimentalDesign):
@@ -1213,17 +1227,17 @@ class PrePostNEGD(ExperimentalDesign):
12131227
... }
12141228
... )
12151229
... )
1216-
>>> result.summary() # doctest: +NUMBER
1230+
>>> result.summary(round_to=1) # doctest: +NUMBER
12171231
==================Pretest/posttest Nonequivalent Group Design===================
12181232
Formula: post ~ 1 + C(group) + pre
12191233
<BLANKLINE>
12201234
Results:
1221-
Causal impact = 1.8, $CI_{94%}$[1.7, 2.1]
1235+
Causal impact = 2, $CI_{94%}$[2, 2]
12221236
Model coefficients:
1223-
Intercept -0.4, 94% HDI [-1.1, 0.2]
1224-
C(group)[T.1] 1.8, 94% HDI [1.6, 2.0]
1225-
pre 1.0, 94% HDI [0.9, 1.1]
1226-
sigma 0.5, 94% HDI [0.4, 0.5]
1237+
Intercept -0.5, 94% HDI [-1, 0.2]
1238+
C(group)[T.1] 2, 94% HDI [2, 2]
1239+
pre 1, 94% HDI [1, 1]
1240+
sigma 0.5, 94% HDI [0.5, 0.6]
12271241
"""
12281242

12291243
def __init__(
@@ -1304,7 +1318,7 @@ def plot(self, round_to=None):
13041318
"""Plot the results
13051319
13061320
:param round_to:
1307-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
1321+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
13081322
"""
13091323
fig, ax = plt.subplots(
13101324
2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
@@ -1362,20 +1376,23 @@ def _causal_impact_summary_stat(self, round_to) -> str:
13621376
r"$CI_{94%}$"
13631377
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
13641378
)
1365-
causal_impact = f"{self.causal_impact.mean():.2f}, "
1379+
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
13661380
return f"Causal impact = {causal_impact + ci}"
13671381

13681382
def summary(self, round_to=None) -> None:
13691383
"""
13701384
Print text output summarising the results
1385+
1386+
:param round_to:
1387+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
13711388
"""
13721389

13731390
print(f"{self.expt_type:=^80}")
13741391
print(f"Formula: {self.formula}")
13751392
print("\nResults:")
13761393
# TODO: extra experiment specific outputs here
13771394
print(self._causal_impact_summary_stat(round_to))
1378-
self.print_coefficients()
1395+
self.print_coefficients(round_to)
13791396

13801397
def _get_treatment_effect_coeff(self) -> str:
13811398
"""Find the beta regression coefficient corresponding to the
@@ -1452,7 +1469,6 @@ class InstrumentalVariable(ExperimentalDesign):
14521469
... formula=formula,
14531470
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
14541471
... )
1455-
14561472
"""
14571473

14581474
def __init__(

causalpy/skl_experiments.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
119119
"""Plot experiment results
120120
121121
:param round_to:
122-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
122+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
123123
"""
124124
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
125125

@@ -270,7 +270,7 @@ def plot(self, plot_predictors=False, round_to=None, **kwargs):
270270
"""Plot the results
271271
272272
:param round_to:
273-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
273+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
274274
"""
275275
fig, ax = super().plot(
276276
counterfactual_label="Synthetic control", round_to=round_to, **kwargs
@@ -415,7 +415,7 @@ def plot(self, round_to=None):
415415
"""Plot results
416416
417417
:param round_to:
418-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
418+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
419419
"""
420420
fig, ax = plt.subplots()
421421

@@ -629,7 +629,7 @@ def plot(self, round_to=None):
629629
"""Plot results
630630
631631
:param round_to:
632-
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
632+
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
633633
"""
634634
fig, ax = plt.subplots()
635635
# Plot raw data

0 commit comments

Comments
 (0)