@@ -2394,6 +2394,8 @@ def fit(
23942394 L_max = L_max ,
23952395 alpha = self .alpha ,
23962396 rank_deficient_action = self .rank_deficient_action ,
2397+ group_ids_order = np .array (all_groups ),
2398+ obs_survey_info = _obs_survey_info ,
23972399 )
23982400
23992401 twfe_weights_df = None
@@ -3174,6 +3176,8 @@ def _compute_heterogeneity_test(
31743176 L_max : int ,
31753177 alpha : float = 0.05 ,
31763178 rank_deficient_action : str = "warn" ,
3179+ group_ids_order : Optional [np .ndarray ] = None ,
3180+ obs_survey_info : Optional [Dict [str , Any ]] = None ,
31773181) -> Dict [int , Dict [str , Any ]]:
31783182 """Test for heterogeneous treatment effects (Web Appendix Section 1.5).
31793183
@@ -3192,6 +3196,15 @@ def _compute_heterogeneity_test(
31923196 Time-invariant covariate to test for heterogeneity.
31933197 L_max : int
31943198 alpha : float
3199+ group_ids_order : np.ndarray, optional
3200+ Canonical post-filter group id list aligned to Y_mat row order.
3201+ Must accompany ``obs_survey_info``; if ``None``, plain OLS is used.
3202+ obs_survey_info : dict, optional
3203+ Observation-level survey info with keys ``group_ids`` (raw per-row
3204+ group labels), ``weights`` (per-row survey weights), and ``resolved``
3205+ (ResolvedSurveyDesign). When provided, the regression uses WLS with
3206+ per-group weights W_g = sum of obs survey weights, and SE is computed
3207+ via Binder TSL IF expansion through ``compute_survey_if_variance``.
31953208
31963209 Returns
31973210 -------
@@ -3204,6 +3217,39 @@ def _compute_heterogeneity_test(
32043217 n_groups , n_periods = Y_mat .shape
32053218 results : Dict [int , Dict [str , Any ]] = {}
32063219
3220+ # Survey setup (once, before horizon loop). When inactive, df_s=None and
3221+ # the existing plain-OLS path runs unchanged.
3222+ use_survey = (
3223+ obs_survey_info is not None and group_ids_order is not None
3224+ )
3225+ if use_survey :
3226+ from diff_diff .survey import compute_survey_if_variance
3227+
3228+ obs_gids_raw = np .asarray (obs_survey_info ["group_ids" ])
3229+ obs_w_raw = np .asarray (obs_survey_info ["weights" ], dtype = np .float64 )
3230+ resolved = obs_survey_info ["resolved" ]
3231+ df_s = (
3232+ resolved .df_survey if resolved is not None else None
3233+ )
3234+ # Contract: only obs whose group is in the canonical post-filter
3235+ # list contribute. Groups dropped upstream (Step 5b interior gaps,
3236+ # Step 6 multi-switch) appear in obs_gids_raw but must be
3237+ # zero-weighted in the IF expansion.
3238+ gid_list = (
3239+ group_ids_order .tolist ()
3240+ if hasattr (group_ids_order , "tolist" )
3241+ else list (group_ids_order )
3242+ )
3243+ gid_set = set (gid_list )
3244+ valid = np .array ([g in gid_set for g in obs_gids_raw ])
3245+ # Per-group total weight aligned to Y_mat row order
3246+ W_g_all = np .zeros (n_groups , dtype = np .float64 )
3247+ for i , gid in enumerate (gid_list ):
3248+ mask_g = (obs_gids_raw == gid ) & valid
3249+ W_g_all [i ] = obs_w_raw [mask_g ].sum ()
3250+ else :
3251+ df_s = None
3252+
32073253 for l_h in range (1 , L_max + 1 ):
32083254 # Eligible switchers at this horizon (same logic as multi-horizon DID)
32093255 eligible = []
@@ -3276,20 +3322,78 @@ def _compute_heterogeneity_test(
32763322 }
32773323 continue
32783324
3279- coefs , _residuals , vcov = solve_ols (
3280- design , dep_arr ,
3281- return_vcov = True ,
3282- rank_deficient_action = rank_deficient_action ,
3283- )
3325+ if not use_survey :
3326+ # Plain OLS path (unchanged): standard inference per Lemma 7.
3327+ coefs , _residuals , vcov = solve_ols (
3328+ design , dep_arr ,
3329+ return_vcov = True ,
3330+ rank_deficient_action = rank_deficient_action ,
3331+ )
3332+ beta_het = float (coefs [1 ])
3333+ se_het = float ("nan" )
3334+ if vcov is not None and np .isfinite (vcov [1 , 1 ]) and vcov [1 , 1 ] > 0 :
3335+ se_het = float (np .sqrt (vcov [1 , 1 ]))
3336+ t_stat , p_val , ci = safe_inference (beta_het , se_het , alpha = alpha , df = None )
3337+ else :
3338+ # Survey-aware path: WLS with per-group weights + TSL IF variance.
3339+ W_elig = W_g_all [eligible ]
3340+ # solve_ols handles sqrt-weight scaling natively when
3341+ # weight_type='pweight' (linalg.py). Skip vcov — we compute
3342+ # design-based variance ourselves below.
3343+ coefs , _residuals , _vcov_ignored = solve_ols (
3344+ design , dep_arr ,
3345+ weights = W_elig , weight_type = "pweight" ,
3346+ return_vcov = False ,
3347+ rank_deficient_action = rank_deficient_action ,
3348+ )
3349+ # Rank-deficiency short-circuit: if any coef is NaN, return NaN
3350+ # inference. Mixing solve_ols's R-style drop with a pinv-derived
3351+ # IF would describe different estimands.
3352+ if not np .all (np .isfinite (coefs )):
3353+ results [l_h ] = {
3354+ "beta" : float ("nan" ), "se" : float ("nan" ),
3355+ "t_stat" : float ("nan" ), "p_value" : float ("nan" ),
3356+ "conf_int" : (float ("nan" ), float ("nan" )),
3357+ "n_obs" : n_obs ,
3358+ }
3359+ continue
3360+
3361+ beta_het = float (coefs [1 ])
3362+ # Original-scale residuals (solve_ols applies sqrt-weight scaling
3363+ # internally and back-transforms residuals, but we need them for
3364+ # our IF computation below).
3365+ r_g = dep_arr - design @ coefs
3366+
3367+ # Group-level IF for β_X: ψ_g[X] = inv(X'WX)[1,:] @ x_g * W_g * r_g.
3368+ # Under full rank (gated above), pinv == inv. Wrap matmuls in
3369+ # errstate: macOS Accelerate BLAS can emit spurious divide/overflow
3370+ # warnings on sparse-cohort designs even though the result is finite.
3371+ with np .errstate (divide = "ignore" , invalid = "ignore" , over = "ignore" ):
3372+ XtWX = design .T @ (W_elig [:, None ] * design )
3373+ XtWX_inv = np .linalg .pinv (XtWX )
3374+ psi_g = (XtWX_inv [1 , :] @ design .T ) * W_elig * r_g # (n_eligible,)
3375+
3376+ # Expand to obs level: ψ_i = ψ_g * (w_i / W_g) for i in group g.
3377+ psi_obs = np .zeros (len (obs_w_raw ))
3378+ for e_idx , g_idx in enumerate (eligible ):
3379+ gid = gid_list [g_idx ]
3380+ mask_g = (obs_gids_raw == gid ) & valid
3381+ w_sum_g = obs_w_raw [mask_g ].sum ()
3382+ if w_sum_g > 0 :
3383+ psi_obs [mask_g ] = psi_g [e_idx ] * (
3384+ obs_w_raw [mask_g ] / w_sum_g
3385+ )
32843386
3285- # beta_het is at index 1 (index 0 is intercept)
3286- beta_het = float (coefs [1 ])
3287- # NaN-safe: if vcov is None or target coefficient variance is NaN
3288- # (rank-deficient), all inference fields are NaN.
3289- se_het = float ("nan" )
3290- if vcov is not None and np .isfinite (vcov [1 , 1 ]) and vcov [1 , 1 ] > 0 :
3291- se_het = float (np .sqrt (vcov [1 , 1 ]))
3292- t_stat , p_val , ci = safe_inference (beta_het , se_het , alpha = alpha , df = None )
3387+ # Binder TSL variance across stratified PSUs.
3388+ var_s = compute_survey_if_variance (psi_obs , resolved )
3389+ se_het = (
3390+ float (np .sqrt (var_s ))
3391+ if np .isfinite (var_s ) and var_s > 0
3392+ else float ("nan" )
3393+ )
3394+ t_stat , p_val , ci = safe_inference (
3395+ beta_het , se_het , alpha = alpha , df = df_s
3396+ )
32933397
32943398 results [l_h ] = {
32953399 "beta" : beta_het ,
@@ -4891,6 +4995,7 @@ def twowayfeweights(
48914995 time : str ,
48924996 treatment : str ,
48934997 rank_deficient_action : str = "warn" ,
4998+ survey_design : Any = None ,
48944999) -> TWFEWeightsResult :
48955000 """
48965001 Standalone TWFE decomposition diagnostic.
@@ -4910,13 +5015,35 @@ def twowayfeweights(
49105015 treatment : str
49115016 rank_deficient_action : str, default="warn"
49125017 Action when the FE design matrix is rank-deficient.
5018+ survey_design : SurveyDesign, optional
5019+ If provided, cell aggregation uses survey-weighted cell means
5020+ (matching ``fit(..., survey_design=sd).twfe_*``). Required to preserve
5021+ fit-vs-helper parity under survey-backed inputs. Only
5022+ ``weight_type='pweight'`` is supported; other types raise ValueError.
49135023
49145024 Returns
49155025 -------
49165026 TWFEWeightsResult
49175027 Object with attributes ``weights`` (DataFrame), ``fraction_negative``
49185028 (float), ``sigma_fe`` (float), and ``beta_fe`` (float).
49195029 """
5030+ # Survey resolution (optional): mirrors the fit() path so that the
5031+ # standalone helper produces identical numbers to fit(..., survey_design=sd).
5032+ survey_weights = None
5033+ if survey_design is not None :
5034+ from diff_diff .survey import _resolve_survey_for_fit
5035+
5036+ resolved , survey_weights , _ , _ = _resolve_survey_for_fit (
5037+ survey_design , data , "analytical"
5038+ )
5039+ if resolved is not None and resolved .weight_type != "pweight" :
5040+ raise ValueError (
5041+ f"twowayfeweights() survey support requires "
5042+ f"weight_type='pweight', got '{ resolved .weight_type } '. "
5043+ f"The TWFE diagnostic under survey uses survey-weighted cell "
5044+ f"means; other weight types are not supported."
5045+ )
5046+
49205047 # Validation + cell aggregation via the same helper used by
49215048 # ChaisemartinDHaultfoeuille.fit() — enforces NaN/binary/within-cell
49225049 # rules from REGISTRY.md so the standalone diagnostic does not
@@ -4927,6 +5054,7 @@ def twowayfeweights(
49275054 group = group ,
49285055 time = time ,
49295056 treatment = treatment ,
5057+ weights = survey_weights ,
49305058 )
49315059 # TWFE diagnostic assumes binary treatment (d_arr == 1 for treated mask).
49325060 if not set (cell ["d_gt" ].unique ()).issubset ({0.0 , 1.0 , 0 , 1 }):
0 commit comments