Skip to content

Commit e6f410d

Browse files
igerberclaude
andcommitted
Use fixed cohort-level survey weight sums for CS aggregation
Replace per-cell survey_weight_sum with fixed cohort-level weight sums computed once from the full unit-level sample, matching R's did::aggte() which uses pg = n_g / N (cohort shares from the full sample, not varying by cell). On unbalanced panels, cell-specific sums can differ from fixed cohort shares. The WIF already uses fixed pg — this aligns the aggregation point estimate with the WIF computation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 66adcdc commit e6f410d

1 file changed

Lines changed: 34 additions & 6 deletions

File tree

diff_diff/staggered_aggregation.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,26 @@ def _aggregate_simple(
6262
gt_pairs = []
6363
groups_for_gt = []
6464

65+
# For survey: compute fixed per-cohort weight sums from the full
66+
# unit-level sample (matching R's did::aggte pg = n_g / N).
67+
survey_cohort_weights = None
68+
if precomputed is not None and precomputed.get("survey_weights") is not None:
69+
sw = precomputed["survey_weights"]
70+
unit_cohorts = precomputed["unit_cohorts"]
71+
survey_cohort_weights = {}
72+
for g in np.unique(unit_cohorts):
73+
if g > 0: # exclude never-treated (0)
74+
survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
75+
6576
for (g, t), data in group_time_effects.items():
6677
# Only include post-treatment effects (t >= g - anticipation)
6778
# Pre-treatment effects are for parallel trends, not overall ATT
6879
if t < g - self.anticipation:
6980
continue
7081
effects.append(data["effect"])
71-
# Use survey_weight_sum for aggregation when available
72-
if data.get("survey_weight_sum") is not None:
73-
weights_list.append(data["survey_weight_sum"])
82+
# Use fixed cohort-level survey weight sum for aggregation
83+
if survey_cohort_weights is not None and g in survey_cohort_weights:
84+
weights_list.append(survey_cohort_weights[g])
7485
else:
7586
weights_list.append(data["n_treated"])
7687
gt_pairs.append((g, t))
@@ -478,12 +489,25 @@ def _aggregate_event_study(
478489
# Organize effects by relative time, keeping track of (g,t) pairs
479490
effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, float]]] = {}
480491

492+
# Fixed per-cohort survey weights for aggregation
493+
survey_cohort_weights = None
494+
if precomputed is not None and precomputed.get("survey_weights") is not None:
495+
sw = precomputed["survey_weights"]
496+
unit_cohorts = precomputed["unit_cohorts"]
497+
survey_cohort_weights = {}
498+
for g in np.unique(unit_cohorts):
499+
if g > 0:
500+
survey_cohort_weights[g] = float(np.sum(sw[unit_cohorts == g]))
501+
481502
for (g, t), data in group_time_effects.items():
482503
e = t - g # Relative time
483504
if e not in effects_by_e:
484505
effects_by_e[e] = []
485-
# Use survey_weight_sum for aggregation when available
486-
w = data.get("survey_weight_sum", data["n_treated"])
506+
w = (
507+
survey_cohort_weights[g]
508+
if survey_cohort_weights is not None and g in survey_cohort_weights
509+
else data["n_treated"]
510+
)
487511
effects_by_e[e].append(
488512
(
489513
(g, t), # Keep track of the (g,t) pair
@@ -507,7 +531,11 @@ def _aggregate_event_study(
507531
e = t - g
508532
if e not in balanced_effects:
509533
balanced_effects[e] = []
510-
w = data.get("survey_weight_sum", data["n_treated"])
534+
w = (
535+
survey_cohort_weights[g]
536+
if survey_cohort_weights is not None and g in survey_cohort_weights
537+
else data["n_treated"]
538+
)
511539
balanced_effects[e].append(
512540
(
513541
(g, t),

0 commit comments

Comments
 (0)