Skip to content

Commit a15ce5a

Browse files
igerberclaude
andcommitted
Add result_extractor param, fix stale power.rst docs, and add missing API symbols
- Add `result_extractor` parameter to simulate_power, simulate_mde, and simulate_sample_size for unregistered estimators with non-standard schemas - Fix power.rst: correct PowerAnalysis method names, example code, and add SimulationMDEResults/SimulationSampleSizeResults/simulate_mde/simulate_sample_size - Add 4 missing symbols to docs/api/index.rst autosummary - Add api/power.rst to doc snippet smoke tests - Add tests for custom result_extractor and MDE forwarding Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5d2fdaf commit a15ce5a

5 files changed

Lines changed: 189 additions & 80 deletions

File tree

diff_diff/power.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,7 @@ def simulate_power(
12141214
data_generator: Optional[Callable] = None,
12151215
data_generator_kwargs: Optional[Dict[str, Any]] = None,
12161216
estimator_kwargs: Optional[Dict[str, Any]] = None,
1217+
result_extractor: Optional[Callable] = None,
12171218
progress: bool = True,
12181219
) -> SimulationPowerResults:
12191220
"""
@@ -1257,6 +1258,11 @@ def simulate_power(
12571258
Additional keyword arguments for data generator.
12581259
estimator_kwargs : dict, optional
12591260
Additional keyword arguments for estimator.fit().
1261+
result_extractor : callable, optional
1262+
Custom function to extract results from the estimator output.
1263+
Takes the estimator result object and returns a tuple of
1264+
``(att, se, p_value, conf_int)``. Useful for unregistered
1265+
estimators with non-standard result schemas.
12601266
progress : bool, default=True
12611267
Whether to print progress updates.
12621268
@@ -1439,6 +1445,8 @@ def simulate_power(
14391445
# --- Extract results ---
14401446
if profile is not None:
14411447
att, se, p_val, ci = profile.result_extractor(result)
1448+
elif result_extractor is not None:
1449+
att, se, p_val, ci = result_extractor(result)
14421450
else:
14431451
att = result.att if hasattr(result, "att") else result.avg_att
14441452
se = result.se if hasattr(result, "se") else result.avg_se
@@ -1717,6 +1725,7 @@ def simulate_mde(
17171725
data_generator: Optional[Callable] = None,
17181726
data_generator_kwargs: Optional[Dict[str, Any]] = None,
17191727
estimator_kwargs: Optional[Dict[str, Any]] = None,
1728+
result_extractor: Optional[Callable] = None,
17201729
progress: bool = True,
17211730
) -> SimulationMDEResults:
17221731
"""
@@ -1759,6 +1768,9 @@ def simulate_mde(
17591768
Additional keyword arguments for data generator.
17601769
estimator_kwargs : dict, optional
17611770
Additional keyword arguments for estimator.fit().
1771+
result_extractor : callable, optional
1772+
Custom function to extract results from the estimator output.
1773+
Forwarded to ``simulate_power()``.
17621774
progress : bool, default=True
17631775
Whether to print progress updates.
17641776
@@ -1789,6 +1801,7 @@ def simulate_mde(
17891801
data_generator=data_generator,
17901802
data_generator_kwargs=data_generator_kwargs,
17911803
estimator_kwargs=estimator_kwargs,
1804+
result_extractor=result_extractor,
17921805
progress=False,
17931806
)
17941807

@@ -1911,6 +1924,7 @@ def simulate_sample_size(
19111924
data_generator: Optional[Callable] = None,
19121925
data_generator_kwargs: Optional[Dict[str, Any]] = None,
19131926
estimator_kwargs: Optional[Dict[str, Any]] = None,
1927+
result_extractor: Optional[Callable] = None,
19141928
progress: bool = True,
19151929
) -> SimulationSampleSizeResults:
19161930
"""
@@ -1951,6 +1965,9 @@ def simulate_sample_size(
19511965
Additional keyword arguments for data generator.
19521966
estimator_kwargs : dict, optional
19531967
Additional keyword arguments for estimator.fit().
1968+
result_extractor : callable, optional
1969+
Custom function to extract results from the estimator output.
1970+
Forwarded to ``simulate_power()``.
19541971
progress : bool, default=True
19551972
Whether to print progress updates.
19561973
@@ -1988,6 +2005,7 @@ def simulate_sample_size(
19882005
data_generator=data_generator,
19892006
data_generator_kwargs=data_generator_kwargs,
19902007
estimator_kwargs=estimator_kwargs,
2008+
result_extractor=result_extractor,
19912009
progress=False,
19922010
)
19932011

docs/api/index.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,14 @@ Power analysis for study design:
148148
diff_diff.PowerAnalysis
149149
diff_diff.PowerResults
150150
diff_diff.SimulationPowerResults
151+
diff_diff.SimulationMDEResults
152+
diff_diff.SimulationSampleSizeResults
151153
diff_diff.compute_power
152154
diff_diff.compute_mde
153155
diff_diff.compute_sample_size
154156
diff_diff.simulate_power
157+
diff_diff.simulate_mde
158+
diff_diff.simulate_sample_size
155159

156160
Pre-Trends Power Analysis
157161
-------------------------

docs/api/power.rst

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ Main class for analytical power calculations.
3030

3131
.. autosummary::
3232

33-
~PowerAnalysis.compute_power
34-
~PowerAnalysis.compute_mde
35-
~PowerAnalysis.compute_sample_size
33+
~PowerAnalysis.power
34+
~PowerAnalysis.mde
35+
~PowerAnalysis.sample_size
36+
~PowerAnalysis.power_curve
37+
~PowerAnalysis.sample_size_curve
3638

3739
Example
3840
~~~~~~~
@@ -41,29 +43,19 @@ Example
4143
4244
from diff_diff import PowerAnalysis
4345
44-
# Create power analysis object
45-
pa = PowerAnalysis(
46-
effect_size=0.5,
47-
n_treated=100,
48-
n_control=100,
49-
n_pre=4,
50-
n_post=4,
51-
sigma=1.0,
52-
rho=0.5, # Within-unit correlation
53-
alpha=0.05
54-
)
46+
pa = PowerAnalysis(alpha=0.05, power=0.80)
5547
5648
# Compute power
57-
power = pa.compute_power()
58-
print(f"Power: {power:.2%}")
49+
result = pa.power(effect_size=0.5, n_treated=100, n_control=100, sigma=1.0)
50+
print(f"Power: {result.power:.2%}")
5951
6052
# Compute MDE at 80% power
61-
mde = pa.compute_mde(power=0.80)
62-
print(f"MDE: {mde:.3f}")
53+
result = pa.mde(n_treated=100, n_control=100, sigma=1.0)
54+
print(f"MDE: {result.mde:.3f}")
6355
6456
# Required sample size
65-
n = pa.compute_sample_size(power=0.80)
66-
print(f"Required N per group: {n}")
57+
result = pa.sample_size(effect_size=0.5, sigma=1.0)
58+
print(f"Required N: {result.required_n}")
6759
6860
PowerResults
6961
------------
@@ -85,6 +77,26 @@ Results from simulation-based power analysis.
8577
:undoc-members:
8678
:show-inheritance:
8779

80+
SimulationMDEResults
81+
--------------------
82+
83+
Results from simulation-based MDE search.
84+
85+
.. autoclass:: diff_diff.SimulationMDEResults
86+
:members:
87+
:undoc-members:
88+
:show-inheritance:
89+
90+
SimulationSampleSizeResults
91+
---------------------------
92+
93+
Results from simulation-based sample size search.
94+
95+
.. autoclass:: diff_diff.SimulationSampleSizeResults
96+
:members:
97+
:undoc-members:
98+
:show-inheritance:
99+
88100
Convenience Functions
89101
---------------------
90102

@@ -116,6 +128,20 @@ Simulation-based power for any DiD estimator.
116128

117129
.. autofunction:: diff_diff.simulate_power
118130

131+
simulate_mde
132+
~~~~~~~~~~~~~
133+
134+
Simulation-based MDE for any DiD estimator.
135+
136+
.. autofunction:: diff_diff.simulate_mde
137+
138+
simulate_sample_size
139+
~~~~~~~~~~~~~~~~~~~~
140+
141+
Simulation-based sample size for any DiD estimator.
142+
143+
.. autofunction:: diff_diff.simulate_sample_size
144+
119145
Complete Example
120146
----------------
121147

@@ -125,8 +151,8 @@ Complete Example
125151
PowerAnalysis,
126152
compute_mde,
127153
simulate_power,
154+
simulate_mde,
128155
DifferenceInDifferences,
129-
plot_power_curve,
130156
)
131157
132158
# Quick MDE calculation
@@ -145,20 +171,23 @@ Complete Example
145171
# Simulation-based power for DiD estimator
146172
sim_results = simulate_power(
147173
estimator=DifferenceInDifferences(),
148-
effect_size=0.5,
149-
n_treated=100,
150-
n_control=100,
151-
n_periods=8,
152-
treatment_start=4,
174+
treatment_effect=5.0,
175+
n_units=100,
176+
n_periods=4,
177+
treatment_period=2,
153178
sigma=1.0,
154-
n_simulations=1000
179+
n_simulations=20,
155180
)
156181
print(f"Simulated power: {sim_results.power:.2%}")
157182
158-
# Power curve
159-
pa = PowerAnalysis(n_treated=100, n_control=100, n_pre=4, n_post=4, sigma=1.0)
160-
ax = plot_power_curve(pa, effect_range=(0, 1), n_points=50)
161-
ax.figure.savefig('power_curve.png')
183+
# Simulation-based MDE
184+
mde_results = simulate_mde(
185+
estimator=DifferenceInDifferences(),
186+
n_units=100,
187+
n_simulations=10,
188+
max_steps=5,
189+
)
190+
print(f"Simulated MDE: {mde_results.mde:.3f}")
162191
163192
See Also
164193
--------

0 commit comments

Comments
 (0)