Skip to content

Commit 673a830

Browse files
igerberclaude
andcommitted
Add base_period parameter to CallawaySantAnna for pre-treatment effects
Implement the base_period parameter matching R's did::att_gt() API to enable computation of pre-treatment ATT(g,t) values for parallel trends assessment. Two modes are supported: - "varying" (default): Pre-treatment uses t-1 as base (consecutive comparisons) - "universal": All comparisons use g-anticipation-1 as base Both modes produce identical post-treatment ATT(g,t) values. They differ only in how pre-treatment effects are computed. The overall ATT aggregation only includes post-treatment effects, matching R's behavior. Changes: - Add base_period parameter to CallawaySantAnna.__init__ with validation - Modify _compute_att_gt_fast to select base period based on mode - Update fit() to compute pre-treatment ATT(g,t) where t < g - anticipation - Filter _aggregate_simple and bootstrap to only aggregate post-treatment effects - Add base_period to CallawaySantAnnaResults and display in summary() - Update methodology registry with base_period edge case documentation - Add 11 new tests for pre-treatment effects Validated against R's did package v2.3.0 with max numerical difference of 4.91e-05. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 9953d2d commit 673a830

6 files changed

Lines changed: 315 additions & 16 deletions

File tree

diff_diff/staggered.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def __init__(
292292
bootstrap_weight_type: Optional[str] = None,
293293
seed: Optional[int] = None,
294294
rank_deficient_action: str = "warn",
295+
base_period: str = "varying",
295296
):
296297
import warnings
297298

@@ -333,6 +334,12 @@ def __init__(
333334
f"got '{rank_deficient_action}'"
334335
)
335336

337+
if base_period not in ["varying", "universal"]:
338+
raise ValueError(
339+
f"base_period must be 'varying' or 'universal', "
340+
f"got '{base_period}'"
341+
)
342+
336343
self.control_group = control_group
337344
self.anticipation = anticipation
338345
self.estimation_method = estimation_method
@@ -344,6 +351,7 @@ def __init__(
344351
self.bootstrap_weight_type = bootstrap_weights
345352
self.seed = seed
346353
self.rank_deficient_action = rank_deficient_action
354+
self.base_period = base_period
347355

348356
self.is_fitted_ = False
349357
self.results_: Optional[CallawaySantAnnaResults] = None
@@ -441,20 +449,30 @@ def _compute_att_gt_fast(
441449
all_units = precomputed['all_units']
442450
covariate_by_period = precomputed['covariate_by_period']
443451

444-
# Base period for comparison
445-
base_period = g - 1 - self.anticipation
446-
if base_period not in period_to_col:
452+
# Base period selection based on mode
453+
if self.base_period == "universal":
454+
# Universal: always use g - 1 - anticipation
455+
base_period_val = g - 1 - self.anticipation
456+
else: # varying
457+
if t < g - self.anticipation:
458+
# Pre-treatment: use t - 1 (consecutive comparison)
459+
base_period_val = t - 1
460+
else:
461+
# Post-treatment: use g - 1 - anticipation
462+
base_period_val = g - 1 - self.anticipation
463+
464+
if base_period_val not in period_to_col:
447465
# Find closest earlier period
448-
earlier = [p for p in time_periods if p < g - self.anticipation]
466+
earlier = [p for p in time_periods if p < base_period_val]
449467
if not earlier:
450468
return None, 0.0, 0, 0, None
451-
base_period = max(earlier)
469+
base_period_val = max(earlier)
452470

453471
# Check if periods exist in the data
454-
if base_period not in period_to_col or t not in period_to_col:
472+
if base_period_val not in period_to_col or t not in period_to_col:
455473
return None, 0.0, 0, 0, None
456474

457-
base_col = period_to_col[base_period]
475+
base_col = period_to_col[base_period_val]
458476
post_col = period_to_col[t]
459477

460478
# Get treated units mask (cohort g)
@@ -499,7 +517,7 @@ def _compute_att_gt_fast(
499517
X_treated = None
500518
X_control = None
501519
if covariates and covariate_by_period is not None:
502-
cov_matrix = covariate_by_period[base_period]
520+
cov_matrix = covariate_by_period[base_period_val]
503521
X_treated = cov_matrix[treated_valid]
504522
X_control = cov_matrix[control_valid]
505523

@@ -640,9 +658,21 @@ def fit(
640658
group_time_effects = {}
641659
influence_func_info = {} # Store influence functions for bootstrap
642660

661+
# Get minimum period for determining valid pre-treatment periods
662+
min_period = min(time_periods)
663+
643664
for g in treatment_groups:
644-
# Periods for which we compute effects (t >= g - anticipation)
645-
valid_periods = [t for t in time_periods if t >= g - self.anticipation]
665+
# Compute valid periods including pre-treatment
666+
if self.base_period == "universal":
667+
# Universal: all periods except the base period (which is normalized to 0)
668+
universal_base = g - 1 - self.anticipation
669+
valid_periods = [t for t in time_periods if t != universal_base]
670+
else:
671+
# Varying: post-treatment + pre-treatment where t-1 exists
672+
valid_periods = [
673+
t for t in time_periods
674+
if t >= g - self.anticipation or t > min_period
675+
]
646676

647677
for t in valid_periods:
648678
att_gt, se_gt, n_treat, n_ctrl, inf_info = self._compute_att_gt_fast(
@@ -768,6 +798,7 @@ def fit(
768798
n_control_units=n_control_units,
769799
alpha=self.alpha,
770800
control_group=self.control_group,
801+
base_period=self.base_period,
771802
event_study_effects=event_study_effects,
772803
group_effects=group_effects,
773804
bootstrap_results=bootstrap_results,
@@ -1043,6 +1074,7 @@ def get_params(self) -> Dict[str, Any]:
10431074
"bootstrap_weight_type": self.bootstrap_weight_type,
10441075
"seed": self.seed,
10451076
"rank_deficient_action": self.rank_deficient_action,
1077+
"base_period": self.base_period,
10461078
}
10471079

10481080
def set_params(self, **params) -> "CallawaySantAnna":

diff_diff/staggered_aggregation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ class CallawaySantAnnaAggregationMixin:
3131
# Type hints for attributes accessed from the main class
3232
alpha: float
3333

34+
# Type hint for anticipation attribute accessed from main class
35+
anticipation: int
36+
3437
def _aggregate_simple(
3538
self,
3639
group_time_effects: Dict,
@@ -49,13 +52,21 @@ def _aggregate_simple(
4952
shared control units. This includes the wif (weight influence function)
5053
adjustment from R's `did` package that accounts for uncertainty in
5154
estimating the group-size weights.
55+
56+
Note: Only post-treatment effects (t >= g - anticipation) are included
57+
in the overall ATT. Pre-treatment effects are computed for parallel
58+
trends assessment but are not aggregated into the overall ATT.
5259
"""
5360
effects = []
5461
weights_list = []
5562
gt_pairs = []
5663
groups_for_gt = []
5764

5865
for (g, t), data in group_time_effects.items():
66+
# Only include post-treatment effects (t >= g - anticipation)
67+
# Pre-treatment effects are for parallel trends, not overall ATT
68+
if t < g - self.anticipation:
69+
continue
5970
effects.append(data['effect'])
6071
weights_list.append(data['n_treated'])
6172
gt_pairs.append((g, t))

diff_diff/staggered_bootstrap.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ class CallawaySantAnnaBootstrapMixin:
248248
bootstrap_weight_type: str
249249
alpha: float
250250
seed: Optional[int]
251+
anticipation: int
251252

252253
def _run_multiplier_bootstrap(
253254
self,
@@ -310,15 +311,23 @@ def _run_multiplier_bootstrap(
310311
gt_pairs = list(group_time_effects.keys())
311312
n_gt = len(gt_pairs)
312313

313-
# Compute aggregation weights for overall ATT
314-
overall_weights = np.array([
314+
# Identify post-treatment (g,t) pairs for overall ATT
315+
# Pre-treatment effects are for parallel trends assessment, not aggregated
316+
post_treatment_mask = np.array([
317+
t >= g - self.anticipation for (g, t) in gt_pairs
318+
])
319+
post_treatment_indices = np.where(post_treatment_mask)[0]
320+
321+
# Compute aggregation weights for overall ATT (post-treatment only)
322+
all_n_treated = np.array([
315323
group_time_effects[gt]['n_treated'] for gt in gt_pairs
316324
], dtype=float)
317-
overall_weights = overall_weights / np.sum(overall_weights)
325+
post_n_treated = all_n_treated[post_treatment_mask]
326+
overall_weights_post = post_n_treated / np.sum(post_n_treated)
318327

319328
# Original point estimates
320329
original_atts = np.array([group_time_effects[gt]['effect'] for gt in gt_pairs])
321-
original_overall = np.sum(overall_weights * original_atts)
330+
original_overall = np.sum(overall_weights_post * original_atts[post_treatment_mask])
322331

323332
# Prepare event study and group aggregation info if needed
324333
event_study_info = None
@@ -382,11 +391,11 @@ def _run_multiplier_bootstrap(
382391
# Let non-finite values propagate - they will be handled at statistics computation
383392
bootstrap_atts_gt[:, j] = original_atts[j] + perturbations
384393

385-
# Vectorized overall ATT: matrix-vector multiply
394+
# Vectorized overall ATT: matrix-vector multiply (post-treatment only)
386395
# Shape: (n_bootstrap,)
387396
# Suppress RuntimeWarnings for edge cases - non-finite values handled at statistics computation
388397
with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
389-
bootstrap_overall = bootstrap_atts_gt @ overall_weights
398+
bootstrap_overall = bootstrap_atts_gt[:, post_treatment_indices] @ overall_weights_post
390399

391400
# Vectorized event study aggregation
392401
# Non-finite values handled at statistics computation stage

diff_diff/staggered_results.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class CallawaySantAnnaResults:
106106
n_control_units: int
107107
alpha: float = 0.05
108108
control_group: str = "never_treated"
109+
base_period: str = "varying"
109110
event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None)
110111
group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None)
111112
influence_functions: Optional["np.ndarray"] = field(default=None, repr=False)
@@ -149,6 +150,7 @@ def summary(self, alpha: Optional[float] = None) -> str:
149150
f"{'Treatment cohorts:':<30} {len(self.groups):>10}",
150151
f"{'Time periods:':<30} {len(self.time_periods):>10}",
151152
f"{'Control group:':<30} {self.control_group:>10}",
153+
f"{'Base period:':<30} {self.base_period:>10}",
152154
"",
153155
]
154156

docs/methodology/REGISTRY.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,11 @@ Aggregations:
209209
- Bootstrap: Drops non-finite samples, warns, and adjusts p-value floor accordingly
210210
- Threshold: Returns NaN if <50% of bootstrap samples are valid
211211
- **Note**: This is a defensive enhancement over reference implementations (R's `did::att_gt`, Stata's `csdid`) which may error or produce unhandled inf/nan in edge cases without informative warnings
212+
- Base period selection (`base_period` parameter):
213+
- "varying" (default): Pre-treatment uses t-1 as base (consecutive comparisons)
214+
- "universal": All comparisons use g-anticipation-1 as base
215+
- Both produce identical post-treatment ATT(g,t); differ only pre-treatment
216+
- Matches R `did::att_gt()` base_period parameter
212217

213218
**Reference implementation(s):**
214219
- R: `did::att_gt()` (Callaway & Sant'Anna's official package)

0 commit comments

Comments
 (0)