Skip to content

Commit 97789f5

Browse files
igerberclaude
andcommitted
Fix CI review R4: survey-aware heterogeneity + TWFE helper parity
- P1 #1: _compute_heterogeneity_test now accepts obs_survey_info and runs survey-aware WLS + Binder TSL IF when survey_design is active. Point estimate via solve_ols(weights=W_elig, weight_type='pweight'); group-level IF ψ_g[X] = inv(X'WX)[1,:] @ x_g * W_g * r_g, expanded to obs-level via w_i/W_g ratio, then compute_survey_if_variance for stratified/PSU variance. safe_inference uses df_survey. Rank-deficiency short-circuits to NaN to avoid point-estimate/IF mismatch between solve_ols's R-style drop and pinv's minimum-norm. - P1 #2: twowayfeweights() now accepts Optional[SurveyDesign]. When provided, resolves weights via _resolve_survey_for_fit and passes them to _validate_and_aggregate_to_cells, restoring fit-vs-helper parity under survey-backed inputs. fweight/aweight rejected. - P3: REGISTRY updates — TWFE parity sentence now includes survey; heterogeneity Note documents the TSL IF mechanics and library extension disclaimer; checklist line-651 lists survey-aware surfaces; new survey+bootstrap-fallback Note after line 652. - P2: 5 new regression tests in test_survey_dcdh.py: TestSurveyHeterogeneity (uniform-weights match, non-uniform beta change, t-dist df_survey) and TestSurveyTWFEParity (fit-vs-helper match, non-pweight rejection). All 254 targeted tests pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 4eca232 commit 97789f5

3 files changed

Lines changed: 297 additions & 16 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 141 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2394,6 +2394,8 @@ def fit(
23942394
L_max=L_max,
23952395
alpha=self.alpha,
23962396
rank_deficient_action=self.rank_deficient_action,
2397+
group_ids_order=np.array(all_groups),
2398+
obs_survey_info=_obs_survey_info,
23972399
)
23982400

23992401
twfe_weights_df = None
@@ -3174,6 +3176,8 @@ def _compute_heterogeneity_test(
31743176
L_max: int,
31753177
alpha: float = 0.05,
31763178
rank_deficient_action: str = "warn",
3179+
group_ids_order: Optional[np.ndarray] = None,
3180+
obs_survey_info: Optional[Dict[str, Any]] = None,
31773181
) -> Dict[int, Dict[str, Any]]:
31783182
"""Test for heterogeneous treatment effects (Web Appendix Section 1.5).
31793183
@@ -3192,6 +3196,15 @@ def _compute_heterogeneity_test(
31923196
Time-invariant covariate to test for heterogeneity.
31933197
L_max : int
31943198
alpha : float
3199+
group_ids_order : np.ndarray, optional
3200+
Canonical post-filter group id list aligned to Y_mat row order.
3201+
Required when ``obs_survey_info`` is supplied.
3202+
obs_survey_info : dict, optional
3203+
Observation-level survey info with keys ``group_ids`` (raw per-row
3204+
group labels), ``weights`` (per-row survey weights), and ``resolved``
3205+
(ResolvedSurveyDesign). When provided, the regression uses WLS with
3206+
per-group weights W_g = sum of obs survey weights, and SE is computed
3207+
via Binder TSL IF expansion through ``compute_survey_if_variance``.
31953208
31963209
Returns
31973210
-------
@@ -3204,6 +3217,39 @@ def _compute_heterogeneity_test(
32043217
n_groups, n_periods = Y_mat.shape
32053218
results: Dict[int, Dict[str, Any]] = {}
32063219

3220+
# Survey setup (once, before horizon loop). When inactive, df_s=None and
3221+
# the existing plain-OLS path runs unchanged.
3222+
use_survey = (
3223+
obs_survey_info is not None and group_ids_order is not None
3224+
)
3225+
if use_survey:
3226+
from diff_diff.survey import compute_survey_if_variance
3227+
3228+
obs_gids_raw = np.asarray(obs_survey_info["group_ids"])
3229+
obs_w_raw = np.asarray(obs_survey_info["weights"], dtype=np.float64)
3230+
resolved = obs_survey_info["resolved"]
3231+
df_s = (
3232+
resolved.df_survey if resolved is not None else None
3233+
)
3234+
# Contract: only obs whose group is in the canonical post-filter
3235+
# list contribute. Groups dropped upstream (Step 5b interior gaps,
3236+
# Step 6 multi-switch) appear in obs_gids_raw but must be
3237+
# zero-weighted in the IF expansion.
3238+
gid_list = (
3239+
group_ids_order.tolist()
3240+
if hasattr(group_ids_order, "tolist")
3241+
else list(group_ids_order)
3242+
)
3243+
gid_set = set(gid_list)
3244+
valid = np.array([g in gid_set for g in obs_gids_raw])
3245+
# Per-group total weight aligned to Y_mat row order
3246+
W_g_all = np.zeros(n_groups, dtype=np.float64)
3247+
for i, gid in enumerate(gid_list):
3248+
mask_g = (obs_gids_raw == gid) & valid
3249+
W_g_all[i] = obs_w_raw[mask_g].sum()
3250+
else:
3251+
df_s = None
3252+
32073253
for l_h in range(1, L_max + 1):
32083254
# Eligible switchers at this horizon (same logic as multi-horizon DID)
32093255
eligible = []
@@ -3276,20 +3322,78 @@ def _compute_heterogeneity_test(
32763322
}
32773323
continue
32783324

3279-
coefs, _residuals, vcov = solve_ols(
3280-
design, dep_arr,
3281-
return_vcov=True,
3282-
rank_deficient_action=rank_deficient_action,
3283-
)
3325+
if not use_survey:
3326+
# Plain OLS path (unchanged): standard inference per Lemma 7.
3327+
coefs, _residuals, vcov = solve_ols(
3328+
design, dep_arr,
3329+
return_vcov=True,
3330+
rank_deficient_action=rank_deficient_action,
3331+
)
3332+
beta_het = float(coefs[1])
3333+
se_het = float("nan")
3334+
if vcov is not None and np.isfinite(vcov[1, 1]) and vcov[1, 1] > 0:
3335+
se_het = float(np.sqrt(vcov[1, 1]))
3336+
t_stat, p_val, ci = safe_inference(beta_het, se_het, alpha=alpha, df=None)
3337+
else:
3338+
# Survey-aware path: WLS with per-group weights + TSL IF variance.
3339+
W_elig = W_g_all[eligible]
3340+
# solve_ols handles sqrt-weight scaling natively when
3341+
# weight_type='pweight' (linalg.py). Skip vcov — we compute
3342+
# design-based variance ourselves below.
3343+
coefs, _residuals, _vcov_ignored = solve_ols(
3344+
design, dep_arr,
3345+
weights=W_elig, weight_type="pweight",
3346+
return_vcov=False,
3347+
rank_deficient_action=rank_deficient_action,
3348+
)
3349+
# Rank-deficiency short-circuit: if any coef is NaN, return NaN
3350+
# inference. Mixing solve_ols's R-style drop with a pinv-derived
3351+
# IF would describe different estimands.
3352+
if not np.all(np.isfinite(coefs)):
3353+
results[l_h] = {
3354+
"beta": float("nan"), "se": float("nan"),
3355+
"t_stat": float("nan"), "p_value": float("nan"),
3356+
"conf_int": (float("nan"), float("nan")),
3357+
"n_obs": n_obs,
3358+
}
3359+
continue
3360+
3361+
beta_het = float(coefs[1])
3362+
# Original-scale residuals (solve_ols applies sqrt-weight scaling
3363+
# internally and back-transforms residuals, but we need them for
3364+
# our IF computation below).
3365+
r_g = dep_arr - design @ coefs
3366+
3367+
# Group-level IF for β_X: ψ_g[X] = inv(X'WX)[1,:] @ x_g * W_g * r_g.
3368+
# Under full rank (gated above), pinv == inv. Wrap matmuls in
3369+
# errstate: macOS Accelerate BLAS can emit spurious divide/overflow
3370+
# warnings on sparse-cohort designs even though the result is finite.
3371+
with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
3372+
XtWX = design.T @ (W_elig[:, None] * design)
3373+
XtWX_inv = np.linalg.pinv(XtWX)
3374+
psi_g = (XtWX_inv[1, :] @ design.T) * W_elig * r_g # (n_eligible,)
3375+
3376+
# Expand to obs level: ψ_i = ψ_g * (w_i / W_g) for i in group g.
3377+
psi_obs = np.zeros(len(obs_w_raw))
3378+
for e_idx, g_idx in enumerate(eligible):
3379+
gid = gid_list[g_idx]
3380+
mask_g = (obs_gids_raw == gid) & valid
3381+
w_sum_g = obs_w_raw[mask_g].sum()
3382+
if w_sum_g > 0:
3383+
psi_obs[mask_g] = psi_g[e_idx] * (
3384+
obs_w_raw[mask_g] / w_sum_g
3385+
)
32843386

3285-
# beta_het is at index 1 (index 0 is intercept)
3286-
beta_het = float(coefs[1])
3287-
# NaN-safe: if vcov is None or target coefficient variance is NaN
3288-
# (rank-deficient), all inference fields are NaN.
3289-
se_het = float("nan")
3290-
if vcov is not None and np.isfinite(vcov[1, 1]) and vcov[1, 1] > 0:
3291-
se_het = float(np.sqrt(vcov[1, 1]))
3292-
t_stat, p_val, ci = safe_inference(beta_het, se_het, alpha=alpha, df=None)
3387+
# Binder TSL variance across stratified PSUs.
3388+
var_s = compute_survey_if_variance(psi_obs, resolved)
3389+
se_het = (
3390+
float(np.sqrt(var_s))
3391+
if np.isfinite(var_s) and var_s > 0
3392+
else float("nan")
3393+
)
3394+
t_stat, p_val, ci = safe_inference(
3395+
beta_het, se_het, alpha=alpha, df=df_s
3396+
)
32933397

32943398
results[l_h] = {
32953399
"beta": beta_het,
@@ -4891,6 +4995,7 @@ def twowayfeweights(
48914995
time: str,
48924996
treatment: str,
48934997
rank_deficient_action: str = "warn",
4998+
survey_design: Any = None,
48944999
) -> TWFEWeightsResult:
48955000
"""
48965001
Standalone TWFE decomposition diagnostic.
@@ -4910,13 +5015,35 @@ def twowayfeweights(
49105015
treatment : str
49115016
rank_deficient_action : str, default="warn"
49125017
Action when the FE design matrix is rank-deficient.
5018+
survey_design : SurveyDesign, optional
5019+
If provided, cell aggregation uses survey-weighted cell means
5020+
(matching ``fit(..., survey_design=sd).twfe_*``). Required to preserve
5021+
fit-vs-helper parity under survey-backed inputs. Only
5022+
``weight_type='pweight'`` is supported; other types raise ValueError.
49135023
49145024
Returns
49155025
-------
49165026
TWFEWeightsResult
49175027
Object with attributes ``weights`` (DataFrame), ``fraction_negative``
49185028
(float), ``sigma_fe`` (float), and ``beta_fe`` (float).
49195029
"""
5030+
# Survey resolution (optional): mirrors the fit() path so that the
5031+
# standalone helper produces identical numbers to fit(..., survey_design=sd).
5032+
survey_weights = None
5033+
if survey_design is not None:
5034+
from diff_diff.survey import _resolve_survey_for_fit
5035+
5036+
resolved, survey_weights, _, _ = _resolve_survey_for_fit(
5037+
survey_design, data, "analytical"
5038+
)
5039+
if resolved is not None and resolved.weight_type != "pweight":
5040+
raise ValueError(
5041+
f"twowayfeweights() survey support requires "
5042+
f"weight_type='pweight', got '{resolved.weight_type}'. "
5043+
f"The TWFE diagnostic under survey uses survey-weighted cell "
5044+
f"means; other weight types are not supported."
5045+
)
5046+
49205047
# Validation + cell aggregation via the same helper used by
49215048
# ChaisemartinDHaultfoeuille.fit() — enforces NaN/binary/within-cell
49225049
# rules from REGISTRY.md so the standalone diagnostic does not
@@ -4927,6 +5054,7 @@ def twowayfeweights(
49275054
group=group,
49285055
time=time,
49295056
treatment=treatment,
5057+
weights=survey_weights,
49305058
)
49315059
# TWFE diagnostic assumes binary treatment (d_arr == 1 for treated mask).
49325060
if not set(cell["d_gt"].unique()).issubset({0.0, 1.0, 0, 1}):

0 commit comments

Comments
 (0)