Skip to content

Commit c40c46c

Browse files
igerberclaude
andcommitted
Address PR #97 review: fix base_period fallback and add validation
- Remove fallback to non-consecutive base periods in varying mode - Add base_period parameter to CallawaySantAnna docstring - Add validation and warning for empty post-treatment effect sets - Add tests for no-fallback behavior and NaN result with warning Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 673a830 commit c40c46c

4 files changed

Lines changed: 118 additions & 5 deletions

File tree

diff_diff/staggered.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,15 @@ class CallawaySantAnna(
209209
- "warn": Issue warning and drop linearly dependent columns (default)
210210
- "error": Raise ValueError
211211
- "silent": Drop columns silently without warning
212+
base_period : str, default="varying"
213+
Method for selecting the base (reference) period for computing
214+
ATT(g,t). Options:
215+
- "varying": For pre-treatment periods (t < g - anticipation), use
216+
t-1 as base (consecutive comparisons). For post-treatment, use
217+
g-1-anticipation. Requires t-1 to exist in data.
218+
- "universal": Always use g-1-anticipation as base period.
219+
Both produce identical post-treatment effects. Matches R's
220+
did::att_gt() base_period parameter.
212221
213222
Attributes
214223
----------
@@ -462,11 +471,8 @@ def _compute_att_gt_fast(
462471
base_period_val = g - 1 - self.anticipation
463472

464473
if base_period_val not in period_to_col:
465-
# Find closest earlier period
466-
earlier = [p for p in time_periods if p < base_period_val]
467-
if not earlier:
468-
return None, 0.0, 0, 0, None
469-
base_period_val = max(earlier)
474+
# Base period must exist; no fallback to maintain methodological consistency
475+
return None, 0.0, 0, 0, None
470476

471477
# Check if periods exist in the data
472478
if base_period_val not in period_to_col or t not in period_to_col:

diff_diff/staggered_aggregation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,17 @@ def _aggregate_simple(
7272
gt_pairs.append((g, t))
7373
groups_for_gt.append(g)
7474

75+
# Guard against empty post-treatment set
76+
if len(effects) == 0:
77+
import warnings
78+
warnings.warn(
79+
"No post-treatment effects available for overall ATT aggregation. "
80+
"This can occur when cohorts lack post-treatment periods in the data.",
81+
UserWarning,
82+
stacklevel=2
83+
)
84+
return np.nan, np.nan
85+
7586
effects = np.array(effects)
7687
weights = np.array(weights_list, dtype=float)
7788
groups_for_gt = np.array(groups_for_gt)

diff_diff/staggered_bootstrap.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,16 @@ def _run_multiplier_bootstrap(
323323
group_time_effects[gt]['n_treated'] for gt in gt_pairs
324324
], dtype=float)
325325
post_n_treated = all_n_treated[post_treatment_mask]
326+
327+
# Guard against empty post-treatment set
328+
if len(post_treatment_indices) == 0:
329+
warnings.warn(
330+
"No post-treatment effects for bootstrap aggregation.",
331+
UserWarning,
332+
stacklevel=2
333+
)
334+
# Return results with NaN for overall ATT - will be handled by caller
335+
326336
overall_weights_post = post_n_treated / np.sum(post_n_treated)
327337

328338
# Original point estimates

tests/test_staggered.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,3 +2205,89 @@ def test_default_base_period_is_varying(self):
22052205
cs = CallawaySantAnna()
22062206
assert cs.base_period == "varying"
22072207
assert cs.get_params()["base_period"] == "varying"
2208+
2209+
def test_varying_mode_no_fallback_to_nonconsecutive(self):
2210+
"""Varying mode skips pre-treatment effects where t-1 doesn't exist."""
2211+
# Create data where first period (e.g., period 1) has no t-1 predecessor
2212+
data = generate_staggered_data(
2213+
n_units=100,
2214+
n_periods=6, # periods 1-6
2215+
n_cohorts=2,
2216+
treatment_effect=2.0,
2217+
seed=42
2218+
)
2219+
2220+
# Identify the earliest time period in data
2221+
min_period = data['time'].min()
2222+
2223+
cs = CallawaySantAnna(base_period="varying")
2224+
results = cs.fit(
2225+
data,
2226+
outcome='outcome',
2227+
unit='unit',
2228+
time='time',
2229+
first_treat='first_treat'
2230+
)
2231+
2232+
# In varying mode, ATT(g, min_period) should NOT be computed for
2233+
# any cohort g because t-1 (period 0) doesn't exist
2234+
for (g, t) in results.group_time_effects.keys():
2235+
if t == min_period:
2236+
# This should not happen - the (g, min_period) pair should be skipped
2237+
pytest.fail(
2238+
f"ATT({g}, {t}) should not exist because t-1 doesn't exist. "
2239+
"Fallback to non-consecutive base period was incorrectly applied."
2240+
)
2241+
2242+
def test_no_post_treatment_effects_returns_nan_with_warning(self):
2243+
"""Warn and return NaN when no post-treatment effects exist."""
2244+
import warnings
2245+
2246+
# Create data where the treatment cohort treats AFTER the last observed period
2247+
# so there are no post-treatment periods (t >= g never holds)
2248+
n_units = 50
2249+
n_periods = 5
2250+
np.random.seed(42)
2251+
2252+
data = []
2253+
for unit in range(n_units):
2254+
for t in range(1, n_periods + 1):
2255+
# Treated units get treated at period 6 (beyond data range)
2256+
# Data only goes to period 5, so no post-treatment periods exist
2257+
first_treat = n_periods + 1 if unit < n_units // 2 else 0
2258+
outcome = np.random.randn()
2259+
data.append({
2260+
'unit': unit,
2261+
'time': t,
2262+
'outcome': outcome,
2263+
'first_treat': first_treat
2264+
})
2265+
2266+
df = pd.DataFrame(data)
2267+
2268+
cs = CallawaySantAnna(base_period="varying")
2269+
2270+
with warnings.catch_warnings(record=True) as w:
2271+
warnings.simplefilter("always")
2272+
results = cs.fit(
2273+
df,
2274+
outcome='outcome',
2275+
unit='unit',
2276+
time='time',
2277+
first_treat='first_treat'
2278+
)
2279+
2280+
# Should have emitted a warning about no post-treatment effects
2281+
warning_messages = [str(warning.message) for warning in w]
2282+
has_warning = any(
2283+
"No post-treatment effects" in msg for msg in warning_messages
2284+
)
2285+
assert has_warning, (
2286+
f"Expected warning about no post-treatment effects, got: {warning_messages}"
2287+
)
2288+
2289+
# Overall ATT should be NaN
2290+
assert np.isnan(results.overall_att), (
2291+
f"Expected NaN for overall_att when no post-treatment effects exist, "
2292+
f"got {results.overall_att}"
2293+
)

0 commit comments

Comments
 (0)