Skip to content

Commit 227ae50

Browse files
igerberclaude
andcommitted
Address CI review: pscore fallback, zero-mass subgroups, test coverage
Fix survey-weighted propensity fallback to use np.average(PA4, weights=sw) instead of unweighted np.mean(PA4). Add zero-weight subgroup detection in _compute_ddd_gt_gc for subpopulation/domain designs. Fix existing test expecting NotImplementedError for invalid survey_design type (now TypeError from _resolve_survey_for_fit). Add covariate-adjusted survey tests for reg/ipw/dr and combined_weights=False replicate test. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b552cb8 commit 227ae50

3 files changed

Lines changed: 260 additions & 115 deletions

File tree

diff_diff/staggered_triple_diff.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -924,16 +924,27 @@ def _compute_ddd_gt_gc(
924924
n_b = int(np.sum(sub_b_mask))
925925
n_c = int(np.sum(sub_c_mask))
926926

927-
if n_treated == 0 or n_a == 0 or n_b == 0 or n_c == 0:
928-
empty = []
929-
if n_treated == 0:
930-
empty.append(f"(S={g},Q=1)")
931-
if n_a == 0:
932-
empty.append(f"(S={g},Q=0)")
933-
if n_b == 0:
934-
empty.append(f"(S={g_c},Q=1)")
935-
if n_c == 0:
936-
empty.append(f"(S={g_c},Q=0)")
927+
# Check for empty subgroups (by count or by survey weight mass)
928+
empty = []
929+
if n_treated == 0:
930+
empty.append(f"(S={g},Q=1)")
931+
if n_a == 0:
932+
empty.append(f"(S={g},Q=0)")
933+
if n_b == 0:
934+
empty.append(f"(S={g_c},Q=1)")
935+
if n_c == 0:
936+
empty.append(f"(S={g_c},Q=0)")
937+
# Zero survey-weight mass after subpopulation filtering = effectively empty
938+
if not empty and survey_weights is not None:
939+
if np.sum(survey_weights[treated_mask]) <= 0:
940+
empty.append(f"(S={g},Q=1,mass=0)")
941+
if np.sum(survey_weights[sub_a_mask]) <= 0:
942+
empty.append(f"(S={g},Q=0,mass=0)")
943+
if np.sum(survey_weights[sub_b_mask]) <= 0:
944+
empty.append(f"(S={g_c},Q=1,mass=0)")
945+
if np.sum(survey_weights[sub_c_mask]) <= 0:
946+
empty.append(f"(S={g_c},Q=0,mass=0)")
947+
if empty:
937948
warnings.warn(
938949
f"Empty subgroup(s) {', '.join(empty)} for "
939950
f"(g={g}, g_c={g_c}, t={t}). "
@@ -1294,7 +1305,16 @@ def _compute_pscore(
12941305
UserWarning,
12951306
stacklevel=5,
12961307
)
1297-
pscore = np.full(n_pair, np.mean(PA4))
1308+
# Use survey-weighted treated share when weights available
1309+
if survey_weights is not None:
1310+
pos = survey_weights > 0
1311+
if np.any(pos):
1312+
p_uc = np.average(PA4[pos], weights=survey_weights[pos])
1313+
else:
1314+
p_uc = np.mean(PA4)
1315+
else:
1316+
p_uc = np.mean(PA4)
1317+
pscore = np.full(n_pair, p_uc)
12981318
pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim)
12991319
# No hessian for unconditional fallback
13001320
return pscore, None

0 commit comments

Comments
 (0)