@@ -870,6 +870,157 @@ def test_replicate_if_no_divide_by_zero_warning(self):
870870 assert np .isfinite (v )
871871
872872
873+ class TestReplicateEdgeCases :
874+ """Regression tests for analysis-weight df and rscales centering."""
875+
876+ def test_df_survey_combined_weights_false (self ):
877+ """df_survey uses analysis-weight rank when combined_weights=False."""
878+ from diff_diff .survey import ResolvedSurveyDesign
879+
880+ np .random .seed (42 )
881+ n = 50
882+ R = 5
883+ weights = 1.0 + np .random .exponential (0.5 , n )
884+ # Perturbation factors (not full weights)
885+ rep_factors = np .random .uniform (0.8 , 1.2 , (n , R ))
886+
887+ resolved = ResolvedSurveyDesign (
888+ weights = weights , weight_type = "pweight" ,
889+ strata = None , psu = None , fpc = None ,
890+ n_strata = 0 , n_psu = 0 , lonely_psu = "remove" ,
891+ replicate_weights = rep_factors ,
892+ replicate_method = "BRR" , n_replicates = R ,
893+ combined_weights = False ,
894+ )
895+ # df should match rank of analysis weights (rep * full-sample)
896+ analysis_weights = rep_factors * weights [:, np .newaxis ]
897+ expected_rank = int (np .linalg .matrix_rank (analysis_weights ))
898+ expected_df = max (expected_rank - 1 , 1 )
899+ assert resolved .df_survey == expected_df
900+
901+ # Verify it differs from raw perturbation-factor rank when weights
902+ # cause a rank reduction (e.g., zero full-sample weights)
903+ weights_with_zeros = weights .copy ()
904+ weights_with_zeros [:10 ] = 0.0 # subpopulation-zeroed
905+ resolved2 = ResolvedSurveyDesign (
906+ weights = weights_with_zeros , weight_type = "pweight" ,
907+ strata = None , psu = None , fpc = None ,
908+ n_strata = 0 , n_psu = 0 , lonely_psu = "remove" ,
909+ replicate_weights = rep_factors ,
910+ replicate_method = "BRR" , n_replicates = R ,
911+ combined_weights = False ,
912+ )
913+ raw_rank = int (np .linalg .matrix_rank (rep_factors ))
914+ analysis_rank = int (np .linalg .matrix_rank (
915+ rep_factors * weights_with_zeros [:, np .newaxis ]
916+ ))
917+ # Analysis rank should be <= raw rank when zero weights present
918+ assert analysis_rank <= raw_rank
919+ assert resolved2 .df_survey == max (analysis_rank - 1 , 1 )
920+
921+ def test_rscales_zero_centering_vcov (self ):
922+ """mse=False with zero rscales: center only on rscales > 0 replicates."""
923+ from diff_diff .survey import compute_replicate_vcov , ResolvedSurveyDesign
924+ from diff_diff .linalg import solve_ols
925+
926+ np .random .seed (42 )
927+ n = 100
928+ R = 6
929+ x = np .random .randn (n )
930+ y = 1.0 + 2.0 * x + np .random .randn (n ) * 0.5
931+ X = np .column_stack ([np .ones (n ), x ])
932+ w = np .ones (n )
933+
934+ # Build JK1-style replicates
935+ cluster_size = n // R
936+ rep_arr = np .ones ((n , R ))
937+ for r in range (R ):
938+ start = r * cluster_size
939+ end = min ((r + 1 ) * cluster_size , n )
940+ rep_arr [start :end , :] = 0.0
941+ # Correct column r only
942+ rep_arr [:, r ] = np .where (
943+ (np .arange (n ) >= start ) & (np .arange (n ) < end ), 0.0 ,
944+ R / (R - 1 )
945+ )
946+
947+ # rscales with one zero entry
948+ rscales = np .array ([1.0 , 1.0 , 0.0 , 1.0 , 1.0 , 1.0 ])
949+
950+ coef , _ , _ = solve_ols (X , y , weights = w )
951+
952+ resolved = ResolvedSurveyDesign (
953+ weights = w , weight_type = "pweight" ,
954+ strata = None , psu = None , fpc = None ,
955+ n_strata = 0 , n_psu = 0 , lonely_psu = "remove" ,
956+ replicate_weights = rep_arr ,
957+ replicate_method = "BRR" , n_replicates = R ,
958+ replicate_rscales = rscales , mse = False ,
959+ )
960+ vcov , _nv = compute_replicate_vcov (X , y , coef , resolved )
961+
962+ # Manual computation: center only on replicates with rscales > 0
963+ coef_reps = []
964+ for r in range (R ):
965+ c_r , _ , _ = solve_ols (X , y , weights = rep_arr [:, r ])
966+ coef_reps .append (c_r )
967+ coef_reps = np .array (coef_reps )
968+ pos_mask = rscales > 0
969+ center = np .mean (coef_reps [pos_mask ], axis = 0 )
970+ diffs = coef_reps - center [np .newaxis , :]
971+ V_manual = np .zeros ((2 , 2 ))
972+ for r in range (R ):
973+ V_manual += rscales [r ] * np .outer (diffs [r ], diffs [r ])
974+
975+ assert np .allclose (np .diag (vcov ), np .diag (V_manual ), rtol = 1e-10 )
976+
977+ def test_rscales_zero_centering_if (self ):
978+ """mse=False with zero rscales: IF path centers only on rscales > 0."""
979+ from diff_diff .survey import compute_replicate_if_variance , ResolvedSurveyDesign
980+
981+ np .random .seed (42 )
982+ n = 50
983+ R = 5
984+ psi = np .random .randn (n ) * 0.1
985+ w = np .ones (n )
986+
987+ # Build simple replicates
988+ rep_arr = np .ones ((n , R ))
989+ for r in range (R ):
990+ start = r * (n // R )
991+ end = min ((r + 1 ) * (n // R ), n )
992+ rep_arr [start :end , r ] = 0.0
993+ rep_arr [:, r ] = np .where (
994+ (np .arange (n ) >= start ) & (np .arange (n ) < end ), 0.0 ,
995+ R / (R - 1 )
996+ )
997+
998+ rscales = np .array ([1.0 , 0.0 , 1.0 , 1.0 , 1.0 ])
999+
1000+ resolved = ResolvedSurveyDesign (
1001+ weights = w , weight_type = "pweight" ,
1002+ strata = None , psu = None , fpc = None ,
1003+ n_strata = 0 , n_psu = 0 , lonely_psu = "remove" ,
1004+ replicate_weights = rep_arr ,
1005+ replicate_method = "BRR" , n_replicates = R ,
1006+ replicate_rscales = rscales , mse = False ,
1007+ )
1008+ var , _nv = compute_replicate_if_variance (psi , resolved )
1009+
1010+ # Manual: theta_r = sum((w_r/w) * psi), center on rscales > 0 only
1011+ theta_full = float (np .sum (psi ))
1012+ theta_reps = np .array ([
1013+ float (np .sum (np .divide (rep_arr [:, r ], w , out = np .zeros (n ), where = w > 0 ) * psi ))
1014+ for r in range (R )
1015+ ])
1016+ pos_mask = rscales > 0
1017+ center = float (np .mean (theta_reps [pos_mask ]))
1018+ diffs = theta_reps - center
1019+ var_manual = float (np .sum (rscales * diffs ** 2 ))
1020+
1021+ assert var == pytest .approx (var_manual , rel = 1e-10 )
1022+
1023+
8731024# =============================================================================
8741025# Estimator-Level Replicate Weight Tests
8751026# =============================================================================
0 commit comments