@@ -354,83 +354,9 @@ def _collapse_survey_to_unit_level(resolved_survey, df, unit_col, all_units):
354354 Survey design columns are constant within units (validated upstream).
355355 This extracts one row per unit, aligned to ``all_units`` ordering.
356356 """
357- from diff_diff .survey import ResolvedSurveyDesign
358-
359- n_units = len (all_units )
360- # Use groupby().first() to get one value per unit, then reindex
361- unit_groups = df .groupby (unit_col )
362-
363- weights_unit = (
364- pd .Series (resolved_survey .weights , index = df .index )
365- .groupby (df [unit_col ])
366- .first ()
367- .reindex (all_units )
368- .values
369- )
370-
371- strata_unit = None
372- if resolved_survey .strata is not None :
373- strata_unit = (
374- pd .Series (resolved_survey .strata , index = df .index )
375- .groupby (df [unit_col ])
376- .first ()
377- .reindex (all_units )
378- .values
379- )
380-
381- psu_unit = None
382- if resolved_survey .psu is not None :
383- psu_unit = (
384- pd .Series (resolved_survey .psu , index = df .index )
385- .groupby (df [unit_col ])
386- .first ()
387- .reindex (all_units )
388- .values
389- )
357+ from diff_diff .survey import collapse_survey_to_unit_level
390358
391- fpc_unit = None
392- if resolved_survey .fpc is not None :
393- fpc_unit = (
394- pd .Series (resolved_survey .fpc , index = df .index )
395- .groupby (df [unit_col ])
396- .first ()
397- .reindex (all_units )
398- .values
399- )
400-
401- # Collapse replicate weights to unit level (same groupby pattern)
402- rep_weights_unit = None
403- if resolved_survey .replicate_weights is not None :
404- R = resolved_survey .replicate_weights .shape [1 ]
405- rep_weights_unit = np .zeros ((n_units , R ))
406- for r in range (R ):
407- rep_weights_unit [:, r ] = (
408- pd .Series (resolved_survey .replicate_weights [:, r ], index = df .index )
409- .groupby (df [unit_col ])
410- .first ()
411- .reindex (all_units )
412- .values
413- )
414-
415- return ResolvedSurveyDesign (
416- weights = weights_unit .astype (np .float64 ),
417- weight_type = resolved_survey .weight_type ,
418- strata = strata_unit ,
419- psu = psu_unit ,
420- fpc = fpc_unit ,
421- n_strata = resolved_survey .n_strata ,
422- n_psu = resolved_survey .n_psu ,
423- lonely_psu = resolved_survey .lonely_psu ,
424- replicate_weights = rep_weights_unit ,
425- replicate_method = resolved_survey .replicate_method ,
426- fay_rho = resolved_survey .fay_rho ,
427- n_replicates = resolved_survey .n_replicates ,
428- replicate_strata = resolved_survey .replicate_strata ,
429- combined_weights = resolved_survey .combined_weights ,
430- replicate_scale = resolved_survey .replicate_scale ,
431- replicate_rscales = resolved_survey .replicate_rscales ,
432- mse = resolved_survey .mse ,
433- )
359+ return collapse_survey_to_unit_level (resolved_survey , df , unit_col , all_units )
434360
435361 def _precompute_structures (
436362 self ,
0 commit comments