Skip to content

Commit 4b5157c

Browse files
igerber and claude committed
Recompute size_gt from surviving comparison cohorts only
Move the size_gt computation to after the inner gc loop so that it reflects only the comparison cohorts (gc) that actually survived identification. This prevents incorrect influence-function (IF) rescaling and an incorrect analytical SE when some comparison cohorts are skipped due to empty subgroups.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8ce7c5d commit 4b5157c

1 file changed

Lines changed: 18 additions & 14 deletions

File tree

diff_diff/staggered_triple_diff.py

Lines changed: 18 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -324,16 +324,10 @@ def fit(
324324
if n_treated == 0:
325325
continue
326326

327-
# Compute size_gt: union of treated + all possible sub-groups
328-
# across ALL comparison cohorts for this (g, t)
329-
all_gc_units = treated_mask.copy()
330-
for gc in valid_gc:
331-
all_gc_units |= (unit_cohorts == gc) | (unit_cohorts == g)
332-
size_gt = int(np.sum(all_gc_units))
333-
334327
att_vec = []
335-
inf_matrix = []
328+
inf_raw = [] # unrescaled IFs
336329
gc_labels = []
330+
gc_cell_sizes = [] # size_gt_ctrl per surviving gc
337331

338332
for gc in valid_gc:
339333
result = self._compute_ddd_gt_gc(
@@ -346,18 +340,28 @@ def fit(
346340
if not np.isfinite(att_gc):
347341
continue
348342

349-
# R's att_gt IF rescaling per g_c: (size_gt / size_gt_ctrl)
350-
# For single g_c, size_gt == size_gt_ctrl so this is 1.0
351-
if size_gt_ctrl > 0:
352-
inf_gc = inf_gc * (size_gt / size_gt_ctrl)
353-
354343
att_vec.append(att_gc)
355-
inf_matrix.append(inf_gc)
344+
inf_raw.append(inf_gc)
356345
gc_labels.append(gc)
346+
gc_cell_sizes.append(size_gt_ctrl)
357347

358348
if not att_vec:
359349
continue
360350

351+
# Compute size_gt from SURVIVING comparison cohorts only
352+
# (not from all initially valid gc's)
353+
surviving_units = treated_mask.copy()
354+
for gc in gc_labels:
355+
surviving_units |= (unit_cohorts == gc) | (unit_cohorts == g)
356+
size_gt = int(np.sum(surviving_units))
357+
358+
# Apply IF rescaling now that size_gt is known
359+
inf_matrix = []
360+
for inf_gc, size_gt_ctrl in zip(inf_raw, gc_cell_sizes):
361+
if size_gt_ctrl > 0:
362+
inf_gc = inf_gc * (size_gt / size_gt_ctrl)
363+
inf_matrix.append(inf_gc)
364+
361365
att_gmm, inf_gmm, gmm_w, se_gt = self._combine_gmm(
362366
np.array(att_vec), np.array(inf_matrix), n_units,
363367
)

0 commit comments

Comments
 (0)