@@ -378,12 +378,13 @@ def test_brr_replicate(self, sddd_data):
378378 rng = np .random .default_rng (99 )
379379 n_units = data ["unit" ].nunique ()
380380 R = 20
381- # Generate unit-level replicate weights
381+ # Generate combined replicate weights: rep_r = weight * factor_r
382+ # (combined_weights=True means each column includes full-sample weight)
382383 unit_ids = sorted (data ["unit" ].unique ())
383- rep_matrix = 1.0 + rng .standard_normal ((n_units , R )) * 0.1
384- rep_matrix = np .abs (rep_matrix ) # Ensure positive
384+ base_w = data .groupby ("unit" )["weight" ].first ().reindex (unit_ids ).values
385385 for r in range (R ):
386- unit_w = dict (zip (unit_ids , rep_matrix [:, r ]))
386+ factor = np .abs (1.0 + rng .standard_normal (n_units ) * 0.1 )
387+ unit_w = dict (zip (unit_ids , base_w * factor ))
387388 data [f"rep_{ r } " ] = data ["unit" ].map (unit_w )
388389
389390 rep_cols = [f"rep_{ r } " for r in range (R )]
@@ -422,9 +423,10 @@ def test_brr_with_bootstrap_rejected(self, sddd_data):
422423 n_units = data ["unit" ].nunique ()
423424 unit_ids = sorted (data ["unit" ].unique ())
424425 R = 10
425- rep_matrix = np . abs ( 1.0 + rng . standard_normal (( n_units , R )) * 0.1 )
426+ base_w = data . groupby ( "unit" )[ "weight" ]. first (). reindex ( unit_ids ). values
426427 for r in range (R ):
427- unit_w = dict (zip (unit_ids , rep_matrix [:, r ]))
428+ factor = np .abs (1.0 + rng .standard_normal (n_units ) * 0.1 )
429+ unit_w = dict (zip (unit_ids , base_w * factor ))
428430 data [f"rep_{ r } " ] = data ["unit" ].map (unit_w )
429431
430432 rep_cols = [f"rep_{ r } " for r in range (R )]
@@ -449,6 +451,87 @@ def test_brr_with_bootstrap_rejected(self, sddd_data):
449451 )
450452
451453
454+ # ---------------------------------------------------------------------------
455+ # Replicate-weight scale invariance
456+ # ---------------------------------------------------------------------------
457+
458+
459+ def _make_brr_data (sddd_data , rng_seed = 99 , R = 20 ):
460+ """Helper: build combined BRR replicate weights for sddd_data."""
461+ data = sddd_data .copy ()
462+ rng = np .random .default_rng (rng_seed )
463+ unit_ids = sorted (data ["unit" ].unique ())
464+ n_units = len (unit_ids )
465+ base_w = data .groupby ("unit" )["weight" ].first ().reindex (unit_ids ).values
466+ for r in range (R ):
467+ factor = np .abs (1.0 + rng .standard_normal (n_units ) * 0.1 )
468+ unit_w = dict (zip (unit_ids , base_w * factor ))
469+ data [f"rep_{ r } " ] = data ["unit" ].map (unit_w )
470+ rep_cols = [f"rep_{ r } " for r in range (R )]
471+ return data , rep_cols
472+
473+
474+ class TestReplicateScaleInvariance :
475+ """Rescaling all weights + replicates by constant k must not change results."""
476+
477+ @pytest .mark .parametrize ("agg" , ["simple" , "event_study" , "group" ])
478+ def test_scale_invariance (self , sddd_data , agg ):
479+ data , rep_cols = _make_brr_data (sddd_data )
480+ k = 5.0
481+
482+ sd1 = SurveyDesign (
483+ weights = "weight" ,
484+ replicate_weights = rep_cols ,
485+ replicate_method = "BRR" ,
486+ )
487+ est = StaggeredTripleDifference (estimation_method = "reg" )
488+ res1 = est .fit (
489+ data ,
490+ "outcome" ,
491+ "unit" ,
492+ "period" ,
493+ "first_treat" ,
494+ "eligibility" ,
495+ aggregate = agg ,
496+ survey_design = sd1 ,
497+ )
498+
499+ # Scale all weights and replicate columns by k
500+ data_k = data .copy ()
501+ data_k ["weight" ] = data_k ["weight" ] * k
502+ for col in rep_cols :
503+ data_k [col ] = data_k [col ] * k
504+
505+ sd2 = SurveyDesign (
506+ weights = "weight" ,
507+ replicate_weights = rep_cols ,
508+ replicate_method = "BRR" ,
509+ )
510+ res2 = est .fit (
511+ data_k ,
512+ "outcome" ,
513+ "unit" ,
514+ "period" ,
515+ "first_treat" ,
516+ "eligibility" ,
517+ aggregate = agg ,
518+ survey_design = sd2 ,
519+ )
520+
521+ np .testing .assert_allclose (
522+ res2 .overall_att ,
523+ res1 .overall_att ,
524+ atol = 1e-12 ,
525+ err_msg = f"ATT changed with weight rescaling (agg={ agg } )" ,
526+ )
527+ np .testing .assert_allclose (
528+ res2 .overall_se ,
529+ res1 .overall_se ,
530+ rtol = 1e-6 ,
531+ err_msg = f"SE changed with weight rescaling (agg={ agg } )" ,
532+ )
533+
534+
452535# ---------------------------------------------------------------------------
453536# Survey-weighted aggregation point estimates
454537# ---------------------------------------------------------------------------
0 commit comments