Skip to content

Commit d79b042

Browse files
igerberclaude
andcommitted
Address PR #208 review: remove TWFE from registry, add lo-sufficient short-circuit
- Remove TwoWayFixedEffects from power analysis registry (time="period" produces treated*period_number, not standard ATT) - Add early return in simulate_sample_size() when lower bound already achieves target power (both explicit n_range and auto-bracket paths) - Narrow docstring from "All" to "Most" built-in estimators - Add regression tests for TWFE exclusion and lo-sufficient scenarios Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 11539e7 commit d79b042

2 files changed

Lines changed: 73 additions & 30 deletions

File tree

diff_diff/power.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,6 @@ def _basic_fit_kwargs(
136136
return dict(outcome="outcome", treatment="treated", time="post")
137137

138138

139-
def _twfe_fit_kwargs(
140-
data: pd.DataFrame,
141-
n_units: int,
142-
n_periods: int,
143-
treatment_period: int,
144-
) -> Dict[str, Any]:
145-
return dict(outcome="outcome", treatment="treated", time="period", unit="unit")
146-
147-
148139
def _multiperiod_fit_kwargs(
149140
data: pd.DataFrame,
150141
n_units: int,
@@ -264,13 +255,6 @@ def _get_registry() -> Dict[str, _EstimatorProfile]:
264255
result_extractor=_extract_simple,
265256
min_n=20,
266257
),
267-
"TwoWayFixedEffects": _EstimatorProfile(
268-
default_dgp=generate_did_data,
269-
dgp_kwargs_builder=_basic_dgp_kwargs,
270-
fit_kwargs_builder=_twfe_fit_kwargs,
271-
result_extractor=_extract_simple,
272-
min_n=20,
273-
),
274258
"MultiPeriodDiD": _EstimatorProfile(
275259
default_dgp=generate_did_data,
276260
dgp_kwargs_builder=_basic_dgp_kwargs,
@@ -1221,7 +1205,7 @@ def simulate_power(
12211205
12221206
This function simulates datasets with known treatment effects and estimates
12231207
power as the fraction of simulations where the null hypothesis is rejected.
1224-
All built-in estimators are supported via an internal registry that selects
1208+
Most built-in estimators are supported via an internal registry that selects
12251209
the appropriate data-generating process and fit signature automatically.
12261210
12271211
Parameters
@@ -1987,7 +1971,24 @@ def _power_at_n(n: int) -> float:
19871971
# --- Bracket ---
19881972
if n_range is not None:
19891973
lo, hi = n_range
1990-
_power_at_n(lo) # evaluate lo to populate search_path
1974+
power_lo = _power_at_n(lo)
1975+
if power_lo >= power:
1976+
warnings.warn(
1977+
f"Power at n={lo} is {power_lo:.2f} >= target {power}. "
1978+
f"Lower bound already achieves target power. Returning lo.",
1979+
UserWarning,
1980+
)
1981+
return SimulationSampleSizeResults(
1982+
required_n=lo,
1983+
power_at_n=power_lo,
1984+
target_power=power,
1985+
alpha=alpha,
1986+
effect_size=treatment_effect,
1987+
n_simulations_per_step=n_simulations,
1988+
n_steps=len(search_path),
1989+
search_path=search_path,
1990+
estimator_name=estimator_name,
1991+
)
19911992
power_hi = _power_at_n(hi)
19921993
if power_hi < power:
19931994
warnings.warn(
@@ -1997,6 +1998,19 @@ def _power_at_n(n: int) -> float:
19971998
)
19981999
else:
19992000
lo = min_n
2001+
power_lo = _power_at_n(lo)
2002+
if power_lo >= power:
2003+
return SimulationSampleSizeResults(
2004+
required_n=lo,
2005+
power_at_n=power_lo,
2006+
target_power=power,
2007+
alpha=alpha,
2008+
effect_size=treatment_effect,
2009+
n_simulations_per_step=n_simulations,
2010+
n_steps=len(search_path),
2011+
search_path=search_path,
2012+
estimator_name=estimator_name,
2013+
)
20002014
hi = max(100, 2 * min_n)
20012015
for _ in range(10):
20022016
if _power_at_n(hi) >= power:

tests/test_power.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
_staggered_dgp_kwargs,
4444
_staggered_fit_kwargs,
4545
_trop_fit_kwargs,
46-
_twfe_fit_kwargs,
4746
)
4847

4948

@@ -671,7 +670,6 @@ class TestEstimatorRegistry:
671670

672671
EXPECTED_ESTIMATORS = [
673672
"DifferenceInDifferences",
674-
"TwoWayFixedEffects",
675673
"MultiPeriodDiD",
676674
"CallawaySantAnna",
677675
"SunAbraham",
@@ -720,7 +718,6 @@ def test_fit_kwargs_builders_return_dicts(self):
720718
dummy_df = pd.DataFrame({"period": [0, 1, 2, 3]})
721719
for builder in [
722720
_basic_fit_kwargs,
723-
_twfe_fit_kwargs,
724721
_staggered_fit_kwargs,
725722
_ddd_fit_kwargs,
726723
_trop_fit_kwargs,
@@ -803,6 +800,18 @@ def test_continuous_did_not_in_registry(self):
803800
progress=False,
804801
)
805802

803+
def test_twfe_not_in_registry(self):
804+
"""TwoWayFixedEffects is not in registry and raises without custom data_generator."""
805+
registry = _get_registry()
806+
assert "TwoWayFixedEffects" not in registry
807+
808+
with pytest.raises(ValueError, match="not in registry"):
809+
simulate_power(
810+
TwoWayFixedEffects(),
811+
n_simulations=5,
812+
progress=False,
813+
)
814+
806815
def test_unknown_estimator_raises_without_data_generator(self):
807816
"""Unknown estimator without data_generator raises ValueError."""
808817

@@ -841,15 +850,6 @@ def test_did(self):
841850
)
842851
self._assert_valid_result(result, "DifferenceInDifferences")
843852

844-
def test_twfe(self):
845-
result = simulate_power(
846-
TwoWayFixedEffects(),
847-
n_simulations=10,
848-
seed=42,
849-
progress=False,
850-
)
851-
self._assert_valid_result(result, "TwoWayFixedEffects")
852-
853853
def test_multiperiod(self):
854854
result = simulate_power(
855855
MultiPeriodDiD(),
@@ -1225,3 +1225,32 @@ def test_unbracketed_n_range_warns(self):
12251225
seed=42,
12261226
progress=False,
12271227
)
1228+
1229+
def test_lo_already_sufficient_explicit(self):
1230+
"""When lo already meets power, return lo immediately with warning."""
1231+
with pytest.warns(UserWarning, match="Lower bound already achieves"):
1232+
result = simulate_sample_size(
1233+
DifferenceInDifferences(),
1234+
treatment_effect=50.0,
1235+
sigma=0.1,
1236+
n_simulations=50,
1237+
n_range=(20, 200),
1238+
seed=42,
1239+
progress=False,
1240+
)
1241+
assert result.required_n == 20
1242+
assert result.power_at_n >= 0.80
1243+
1244+
def test_lo_already_sufficient_auto(self):
1245+
"""Auto-bracket returns min_n when effect overwhelmingly large."""
1246+
result = simulate_sample_size(
1247+
DifferenceInDifferences(),
1248+
treatment_effect=50.0,
1249+
sigma=0.1,
1250+
n_simulations=50,
1251+
seed=42,
1252+
progress=False,
1253+
)
1254+
# min_n for DifferenceInDifferences is 20
1255+
assert result.required_n == 20
1256+
assert result.power_at_n >= 0.80

0 commit comments

Comments
 (0)