@@ -665,10 +665,7 @@ def fit(
665665 clustered bootstrap. **Out-of-scope combinations raise
666666 ``NotImplementedError``**: (a) replicate weights with
667667 ``n_bootstrap > 0`` (replicate variance is closed-form;
668- bootstrap would double-count variance); (b)
669- ``heterogeneity=`` with PSU/strata that vary within group
670- (heterogeneity WLS still uses the legacy group-level IF
671- expansion; follow-up PR extends it); (c) ``n_bootstrap >
668+ bootstrap would double-count variance); (b) ``n_bootstrap >
672669 0`` with PSU that varies within group (PSU-level bootstrap
673670 still uses the legacy group-level PSU map; follow-up PR
674671 extends it). See REGISTRY.md
@@ -828,39 +825,25 @@ def fit(
828825
829826 # Cell-period IF allocator contract: strata and PSU must be
830827 # constant within each (g, t) cell, a strict relaxation of
831- # the previous within-group constancy rule. Two out-of-scope
832- # combinations are gated with NotImplementedError until the
833- # corresponding follow-up PRs extend them:
834- # - heterogeneity= + within-group-varying PSU/strata
835- # (PR 3: cell-period allocator for the WLS psi_obs)
828+ # the previous within-group constancy rule. One out-of-scope
829+ # combination remains gated with NotImplementedError until
830+ # the corresponding follow-up PR extends it:
836831 # - n_bootstrap > 0 + within-group-varying PSU
837832 # (PR 4: cell-level Hall-Mammen wild bootstrap)
838- strata_varies , psu_varies = _strata_psu_vary_within_group (
833+ _ , psu_varies = _strata_psu_vary_within_group (
839834 resolved_survey , data , group , survey_weights ,
840835 )
841- if strata_varies or psu_varies :
842- if heterogeneity is not None :
843- raise NotImplementedError (
844- "heterogeneity= is not supported under a survey "
845- "design whose PSU or strata vary within group. "
846- "The heterogeneity WLS path uses the legacy "
847- "group-level IF expansion and will be extended "
848- "to the cell-period allocator in a follow-up "
849- "PR. For now, either (a) set heterogeneity=None, "
850- "or (b) collapse PSU/strata to be constant "
851- "within each group."
852- )
853- if self .n_bootstrap > 0 :
854- raise NotImplementedError (
855- "n_bootstrap > 0 is not supported under a "
856- "survey design whose PSU varies within group. "
857- "The PSU-level Hall-Mammen wild bootstrap uses "
858- "the legacy group-level PSU map and will be "
859- "extended to cell-level PSU in a follow-up PR. "
860- "For now, use n_bootstrap=0 (analytical TSL "
861- "variance, which fully supports within-group-"
862- "varying PSU via the cell-period allocator)."
863- )
836+ if psu_varies and self .n_bootstrap > 0 :
837+ raise NotImplementedError (
838+ "n_bootstrap > 0 is not supported under a "
839+ "survey design whose PSU varies within group. "
840+ "The PSU-level Hall-Mammen wild bootstrap uses "
841+ "the legacy group-level PSU map and will be "
842+ "extended to cell-level PSU in a follow-up PR. "
843+ "For now, use n_bootstrap=0 (analytical TSL "
844+ "variance, which fully supports within-group-"
845+ "varying PSU via the cell-period allocator)."
846+ )
864847 _validate_cell_constant_strata_psu (
865848 resolved_survey , data , group , time , survey_weights ,
866849 )
@@ -3734,15 +3717,45 @@ def _compute_heterogeneity_test(
37343717 Required when ``obs_survey_info`` is supplied.
37353718 obs_survey_info : dict, optional
37363719 Observation-level survey info with keys ``group_ids`` (raw per-row
3737- group labels), ``weights`` (per-row survey weights), and ``resolved``
3738- (ResolvedSurveyDesign). When provided, the regression uses WLS with
3739- per-group weights W_g = sum of obs survey weights. SE is computed
3740- via Binder TSL IF expansion through ``compute_survey_if_variance``
3741- by default; under a replicate-weight design (BRR/Fay/JK1/JKn/SDR),
3742- dispatches to ``compute_replicate_if_variance`` for Rao-Wu-style
3743- variance. The effective df for t-critical values follows the
3744- site-level ``min(df_s, n_valid_het - 1)`` rule and the helper
3745- mutates ``replicate_n_valid_list`` so the final
3720+ group labels), ``time_ids`` (raw per-row period labels),
3721+ ``weights`` (per-row survey weights), ``resolved``
3722+ (ResolvedSurveyDesign), and ``periods`` (sorted canonical period
3723+ array matching ``Y_mat``'s column order). When provided, the
3724+ regression uses WLS with per-group weights
3725+ ``W_g = sum of obs survey weights in group g``. The group-level
3726+ WLS coefficient IF is
3727+ ``ψ_g = inv(X'WX)[1,:] @ x_g * W_g * r_g``. Two observation-level
3728+ expansions of ``ψ_g`` coexist on this path, split by variance
3729+ helper so each path uses the allocator that preserves
3730+ byte-identity for its aggregation rule:
3731+
3732+ * **Binder TSL** (``compute_survey_if_variance``): the
3733+ cell-period single-cell allocator —
3734+ ``ψ_i = ψ_g * (w_i / W_{g, out_idx})`` for obs in
3735+ ``(g, out_idx)``, zero elsewhere. Under PSU=group per-obs
3736+ distribution differs from the legacy
3737+ ``ψ_i = ψ_g * (w_i / W_g)`` but PSU-level aggregates
3738+ telescope to the same ``ψ_g``, so Binder variance is
3739+ byte-identical to the pre-cell-period release. Under
3740+ within-group-varying PSU mass lands in the post-period PSU
3741+ of the transition (DID_l post-period convention).
3742+ * **Rao-Wu replicate** (``compute_replicate_if_variance``):
3743+ the legacy group-level allocator ``ψ_i = ψ_g * (w_i / W_g)``.
3744+ Replicate variance computes ``θ_r = sum_i ratio_ir * ψ_i``
3745+ at observation level, so moving ψ_g mass onto the
3746+ post-period cell would silently change the replicate SE
3747+ whenever a replicate column's ratios vary within a group
3748+ (which the library allows — e.g., per-row BRR/Fay/SDR
3749+ matrices). Keeping the legacy allocator on this branch
3750+ preserves byte-identity of replicate SE across every
3751+ previously-supported fit. Replicate + within-group-varying
3752+ PSU is unreachable by construction (``SurveyDesign``
3753+ rejects ``replicate_weights`` combined with explicit
3754+ ``strata/psu/fpc``).
3755+
3756+ The effective df for t-critical values follows the site-level
3757+ ``min(df_s, n_valid_het - 1)`` rule and the helper mutates
3758+ ``replicate_n_valid_list`` so the final
37463759 ``_effective_df_survey(...)`` sees this site's n_valid.
37473760 replicate_n_valid_list : list[int], optional
37483761 Shared accumulator for replicate-weight ``n_valid`` counts across
@@ -3931,25 +3944,46 @@ def _compute_heterogeneity_test(
39313944 XtWX_inv = np .linalg .pinv (XtWX )
39323945 psi_g = (XtWX_inv [1 , :] @ design .T ) * W_elig * r_g # (n_eligible,)
39333946
3934- # Expand to obs level: ψ_i = ψ_g * (w_i / W_g) for i in group g.
3935- psi_obs = np .zeros (len (obs_w_raw ))
3936- for e_idx , g_idx in enumerate (eligible ):
3937- gid = gid_list [g_idx ]
3938- mask_g = (obs_gids_raw == gid ) & valid
3939- w_sum_g = obs_w_raw [mask_g ].sum ()
3940- if w_sum_g > 0 :
3941- psi_obs [mask_g ] = psi_g [e_idx ] * (
3942- obs_w_raw [mask_g ] / w_sum_g
3943- )
3944-
3945- # Dispatch: replicate-weight variance (BRR/Fay/JK1/JKn/SDR)
3946- # vs Binder TSL across stratified PSUs. Mirrors the inline
3947- # branch in _survey_se_from_group_if and the pattern in
3948- # TripleDifference:1206-1238. Heterogeneity uses WLS with
3949- # full-sample weights; theta_hat is treated as fixed per the
3950- # FWL plug-in IF convention (REGISTRY.md Note on heterogeneity
3951- # under replicate — no per-replicate refits).
3947+ # Allocator dispatch. Two observation-level expansions of
3948+ # ψ_g coexist on this path, split by variance helper:
3949+ #
3950+ # * Binder TSL (compute_survey_if_variance): cell-period
3951+ # single-cell allocator —
3952+ # ψ_i = ψ_g * (w_i / W_{g, out_idx})
3953+ # for obs in (g, out_idx), zero elsewhere. Under
3954+ # PSU=group, per-obs distribution differs from the
3955+ # legacy ψ_i = ψ_g * (w_i / W_g) but PSU-level
3956+ # aggregates telescope to ψ_g, so Binder variance is
3957+ # byte-identical. Under within-group-varying PSU, mass
3958+ # lands in the post-period PSU of the transition, which
3959+ # is what Binder needs. DID_l single-cell convention —
3960+ # see REGISTRY.md ChaisemartinDHaultfoeuille survey IF
3961+ # expansion Note.
3962+ #
3963+ # * Rao-Wu replicate (compute_replicate_if_variance):
3964+ # legacy group-level allocator —
3965+ # ψ_i = ψ_g * (w_i / W_g)
3966+ # for obs in group g. Replicate variance computes
3967+ # θ_r = sum_i ratio_ir * ψ_i at observation level, so
3968+ # moving ψ_g onto the post-period cell only would
3969+ # silently change the replicate SE whenever a
3970+ # replicate column's ratios vary within group (e.g.,
3971+ # the per-row replicate matrices this library
3972+ # accepts). The group-level allocator preserves
3973+ # byte-identity for all replicate usages under
3974+ # PSU=group. The replicate + within-group-varying
3975+ # PSU case is not reachable (SurveyDesign rejects
3976+ # replicate_weights combined with explicit psu).
39523977 if getattr (resolved , "uses_replicate_variance" , False ):
3978+ psi_obs = np .zeros (len (obs_w_raw ), dtype = np .float64 )
3979+ for e_idx , g_idx in enumerate (eligible ):
3980+ gid = gid_list [g_idx ]
3981+ mask_g = (obs_gids_raw == gid ) & valid
3982+ w_sum_g = obs_w_raw [mask_g ].sum ()
3983+ if w_sum_g > 0 :
3984+ psi_obs [mask_g ] = psi_g [e_idx ] * (
3985+ obs_w_raw [mask_g ] / w_sum_g
3986+ )
39533987 var_s , n_valid_het = compute_replicate_if_variance (
39543988 psi_obs , resolved
39553989 )
@@ -3968,6 +4002,23 @@ def _compute_heterogeneity_test(
39684002 else :
39694003 df_s_local = min (int (df_s ), int (n_valid_het ) - 1 )
39704004 else :
4005+ obs_tids = np .asarray (obs_survey_info ["time_ids" ])
4006+ periods_arr = np .asarray (obs_survey_info ["periods" ])
4007+ psi_obs = np .zeros (len (obs_w_raw ), dtype = np .float64 )
4008+ for e_idx , g_idx in enumerate (eligible ):
4009+ gid = gid_list [g_idx ]
4010+ out_idx = first_switch_idx [g_idx ] - 1 + l_h
4011+ t_val_out = periods_arr [out_idx ]
4012+ mask_cell = (
4013+ (obs_gids_raw == gid )
4014+ & (obs_tids == t_val_out )
4015+ & valid
4016+ )
4017+ w_cell = obs_w_raw [mask_cell ].sum ()
4018+ if w_cell > 0 :
4019+ psi_obs [mask_cell ] = psi_g [e_idx ] * (
4020+ obs_w_raw [mask_cell ] / w_cell
4021+ )
39714022 var_s = compute_survey_if_variance (psi_obs , resolved )
39724023 df_s_local = df_s
39734024 se_het = (
@@ -5413,12 +5464,14 @@ def _strata_psu_vary_within_group(
54135464) -> Tuple [bool , bool ]:
54145465 """Return (strata_varies_within_group, psu_varies_within_group).
54155466
5416- Diagnostic helper used to gate out-of-scope combinations for the
5417- cell-period IF allocator — heterogeneity and ``n_bootstrap > 0``
5418- currently require within-group constancy because they read
5419- ``obs_survey_info`` through the legacy group-level expansion path.
5420- PR 3 and PR 4 will extend them. Zero-weight rows are excluded from
5421- the check (subpopulation contract).
5467+ Diagnostic helper used at ``fit()`` time to gate the remaining
5468+ out-of-scope combination for the cell-period IF allocator:
5469+ ``n_bootstrap > 0`` still uses a group-level PSU map and raises
5470+ ``NotImplementedError`` when PSU varies within group. The
5471+ heterogeneity WLS path supports within-group-varying PSU/strata
5472+ via the cell-period allocator (shipped in the PR that lifted the
5473+ previous gate). Zero-weight rows are excluded from the check
5474+ (subpopulation contract).
54225475 """
54235476 if resolved is None :
54245477 return False , False
@@ -5766,10 +5819,12 @@ def _survey_se_from_group_if(
57665819 else :
57675820 # Legacy group-level allocator (no per-period attribution
57685821 # provided, or time/period info unavailable). Preserved for
5769- # paths that haven't threaded per-period attribution through
5770- # yet (e.g., the heterogeneity psi_obs construction in
5771- # _compute_heterogeneity_test — gated to within-group-constant
5772- # PSU in Stage 2 per PR 2 scope).
5822+ # defensive fallback and for unit tests that exercise the
5823+ # legacy allocator. No current caller in fit() uses this
5824+ # branch — ATT / joiners / leavers / placebos all thread
5825+ # U_centered_per_period, and heterogeneity (as of PR 3)
5826+ # constructs its own cell-period psi_obs and calls
5827+ # compute_survey_if_variance directly.
57735828 group_to_u = {gid : U_centered [idx ] for idx , gid in enumerate (eligible_groups )}
57745829 u_obs_eff = np .array ([group_to_u .get (gid , 0.0 ) for gid in gids_eff ])
57755830 unique_gids , inverse = np .unique (gids_eff , return_inverse = True )
0 commit comments