@@ -133,7 +133,7 @@ def __init__(
133133 self .kernel_bandwidth = kernel_bandwidth
134134 self .is_fitted_ = False
135135 self .results_ : Optional [EfficientDiDResults ] = None
136- self ._survey_se_ctx : Optional [ tuple ] = None
136+ self ._unit_resolved_survey = None
137137 self ._validate_params ()
138138
139139 def _validate_params (self ) -> None :
@@ -361,9 +361,45 @@ def fit(
361361 all_units = sorted (df [unit ].unique ())
362362 n_units = len (all_units )
363363
364- # Build unit-to-first-panel-row index (for unit-level survey collapse)
365- _first_rows = df .groupby (unit ).cumcount () == 0
366- self ._unit_first_panel_row = np .where (_first_rows )[0 ]
364+ # Build unit-to-first-panel-row index aligned to all_units (sorted)
365+ # order. The previous approach (groupby cumcount == 0) yielded
366+ # first-appearance order which can differ from sorted order when the
367+ # input DataFrame is not pre-sorted by unit.
368+ first_pos : Dict [Any , int ] = {}
369+ for i , u in enumerate (df [unit ].values ):
370+ if u not in first_pos :
371+ first_pos [u ] = i
372+ self ._unit_first_panel_row = np .array ([first_pos [u ] for u in all_units ])
373+
374+ # Build unit-level ResolvedSurveyDesign once (avoids repeated
375+ # construction in _compute_survey_eif_se and ensures consistent
376+ # unit-level df for safe_inference t-distribution).
377+ if resolved_survey is not None :
378+ from diff_diff .survey import ResolvedSurveyDesign
379+
380+ row_idx = self ._unit_first_panel_row
381+ unit_weights_s = resolved_survey .weights [row_idx ]
382+ unit_strata = (
383+ resolved_survey .strata [row_idx ] if resolved_survey .strata is not None else None
384+ )
385+ unit_psu = resolved_survey .psu [row_idx ] if resolved_survey .psu is not None else None
386+ unit_fpc = resolved_survey .fpc [row_idx ] if resolved_survey .fpc is not None else None
387+ n_strata_u = len (np .unique (unit_strata )) if unit_strata is not None else 0
388+ n_psu_u = len (np .unique (unit_psu )) if unit_psu is not None else 0
389+ self ._unit_resolved_survey = ResolvedSurveyDesign (
390+ weights = unit_weights_s ,
391+ weight_type = resolved_survey .weight_type ,
392+ strata = unit_strata ,
393+ psu = unit_psu ,
394+ fpc = unit_fpc ,
395+ n_strata = n_strata_u ,
396+ n_psu = n_psu_u ,
397+ lonely_psu = resolved_survey .lonely_psu ,
398+ )
399+ # Use unit-level df (not panel-level) for t-distribution
400+ self ._survey_df = self ._unit_resolved_survey .df_survey
401+ else :
402+ self ._unit_resolved_survey = None
367403
368404 period_to_col = {p : i for i , p in enumerate (time_periods )}
369405 period_1 = time_periods [0 ]
@@ -686,11 +722,8 @@ def fit(
686722
687723 # Analytical SE = sqrt(mean(EIF^2) / n) [paper p.21]
688724 # With survey: use TSL variance via compute_survey_vcov
689- if resolved_survey is not None :
690- se_gt = self ._compute_survey_eif_se (
691- eif_vals ,
692- resolved_survey ,
693- )
725+ if self ._unit_resolved_survey is not None :
726+ se_gt = self ._compute_survey_eif_se (eif_vals )
694727 else :
695728 se_gt = float (np .sqrt (np .mean (eif_vals ** 2 ) / n_units ))
696729
@@ -714,12 +747,6 @@ def fit(
714747 "Check data has sufficient observations."
715748 )
716749
717- # ----- Store survey context for aggregation SE helpers -----
718- # Temporarily store survey context for use in aggregation helpers.
719- # This avoids threading survey args through the deeply nested
720- # aggregation methods that are also used by the bootstrap mixin.
721- self ._survey_se_ctx = resolved_survey if resolved_survey is not None else None
722-
723750 # ----- Aggregation -----
724751 overall_att , overall_se = self ._aggregate_overall (
725752 group_time_effects , eif_by_gt , n_units , cohort_fractions , unit_cohorts
@@ -752,9 +779,6 @@ def fit(
752779 unit_cohorts = unit_cohorts ,
753780 )
754781
755- # Clean up temporary survey context
756- self ._survey_se_ctx = None
757-
758782 # ----- Bootstrap -----
759783 bootstrap_results = None
760784 if self .n_bootstrap > 0 and eif_by_gt :
@@ -855,63 +879,27 @@ def fit(
855879
856880 # -- Survey SE helpers ----------------------------------------------------
857881
def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:
    """Compute the SE of unit-level EIF scores via Taylor Series Linearization.

    Relies on ``self._unit_resolved_survey``, the unit-level survey design
    that ``fit()`` builds once by subsetting the panel-level design to each
    unit's first panel row. Keeping the design at unit level means every
    unit contributes a single row to the variance computation, instead of
    one row per period, which would understate SEs for weights-only and
    stratified-no-PSU designs.

    Parameters
    ----------
    eif_vals : np.ndarray
        EIF scores at unit level (one value per unit).
        NOTE(review): presumably ordered like the sorted unit list that
        ``fit()`` aligns the unit-level design to -- confirm with callers.

    Returns
    -------
    float
        Square root of the absolute value of the single variance entry
        returned by ``compute_survey_vcov``.
    """
    # Imported locally, as the surrounding code does for diff_diff.survey.
    from diff_diff.survey import compute_survey_vcov

    # Intercept-only design matrix: the 1x1 vcov is the variance of the
    # mean-type estimand the EIF scores describe.
    intercept = np.ones((len(eif_vals), 1))
    vcov = compute_survey_vcov(intercept, eif_vals, self._unit_resolved_survey)

    # abs() presumably guards against a tiny negative variance caused by
    # floating-point round-off -- TODO confirm.
    return float(np.sqrt(np.abs(vcov[0, 0])))
903894
904895 def _eif_se (self , eif_vals : np .ndarray , n_units : int ) -> float :
905896 """Compute SE from aggregated EIF scores.
906897
907- Dispatches to survey TSL when ``_survey_se_ctx `` is set (during
908- fit), otherwise uses the standard analytical formula.
898+ Dispatches to survey TSL when ``_unit_resolved_survey `` is set
899+ (during fit), otherwise uses the standard analytical formula.
909900 """
910- if self ._survey_se_ctx is not None :
911- return self ._compute_survey_eif_se (
912- eif_vals ,
913- self ._survey_se_ctx ,
914- )
901+ if self ._unit_resolved_survey is not None :
902+ return self ._compute_survey_eif_se (eif_vals )
915903 return float (np .sqrt (np .mean (eif_vals ** 2 ) / n_units ))
916904
917905 # -- Aggregation helpers --------------------------------------------------
0 commit comments