@@ -332,6 +332,87 @@ def test_summary_includes_survey(self, staggered_survey_data, survey_design_weig
332332 assert "Survey Design" in summary
333333 assert "pweight" in summary
334334
335+ def test_se_scale_invariance_fe_only (self , staggered_survey_data ):
336+ """SE must be invariant to weight rescaling (FE-only, no covariates)."""
337+ data = staggered_survey_data .copy ()
338+ data ["weight2" ] = data ["weight" ] * 3.1
339+ sd1 = SurveyDesign (weights = "weight" )
340+ sd2 = SurveyDesign (weights = "weight2" )
341+ r1 = ImputationDiD ().fit (
342+ data ,
343+ "outcome" ,
344+ "unit" ,
345+ "period" ,
346+ "first_treat" ,
347+ survey_design = sd1 ,
348+ )
349+ r2 = ImputationDiD ().fit (
350+ data ,
351+ "outcome" ,
352+ "unit" ,
353+ "period" ,
354+ "first_treat" ,
355+ survey_design = sd2 ,
356+ )
357+ assert np .isclose (r1 .overall_att , r2 .overall_att , atol = 1e-8 )
358+ assert np .isclose (
359+ r1 .overall_se , r2 .overall_se , atol = 1e-8
360+ ), f"SE not scale-invariant (FE-only): { r1 .overall_se } vs { r2 .overall_se } "
361+
362+ def test_se_scale_invariance_with_covariates (self , staggered_survey_data ):
363+ """SE must be invariant to weight rescaling (with covariates)."""
364+ data = staggered_survey_data .copy ()
365+ data ["x1" ] = np .random .default_rng (99 ).normal (0 , 1 , len (data ))
366+ data ["weight2" ] = data ["weight" ] * 3.1
367+ sd1 = SurveyDesign (weights = "weight" )
368+ sd2 = SurveyDesign (weights = "weight2" )
369+ r1 = ImputationDiD ().fit (
370+ data ,
371+ "outcome" ,
372+ "unit" ,
373+ "period" ,
374+ "first_treat" ,
375+ covariates = ["x1" ],
376+ survey_design = sd1 ,
377+ )
378+ r2 = ImputationDiD ().fit (
379+ data ,
380+ "outcome" ,
381+ "unit" ,
382+ "period" ,
383+ "first_treat" ,
384+ covariates = ["x1" ],
385+ survey_design = sd2 ,
386+ )
387+ assert np .isclose (r1 .overall_att , r2 .overall_att , atol = 1e-8 )
388+ assert np .isclose (
389+ r1 .overall_se , r2 .overall_se , atol = 1e-8
390+ ), f"SE not scale-invariant (covariates): { r1 .overall_se } vs { r2 .overall_se } "
391+
392+ def test_wrapper_imputation_did_with_survey (self , staggered_survey_data ):
393+ """imputation_did() wrapper forwards survey_design correctly."""
394+ from diff_diff import imputation_did
395+
396+ sd = SurveyDesign (weights = "weight" )
397+ r_wrapper = imputation_did (
398+ staggered_survey_data ,
399+ "outcome" ,
400+ "unit" ,
401+ "period" ,
402+ "first_treat" ,
403+ survey_design = sd ,
404+ )
405+ r_direct = ImputationDiD ().fit (
406+ staggered_survey_data ,
407+ "outcome" ,
408+ "unit" ,
409+ "period" ,
410+ "first_treat" ,
411+ survey_design = sd ,
412+ )
413+ assert np .isclose (r_wrapper .overall_att , r_direct .overall_att , atol = 1e-10 )
414+ assert r_wrapper .survey_metadata is not None
415+
335416
336417# =============================================================================
337418# TestTwoStageDiDSurvey
@@ -460,6 +541,30 @@ def test_summary_includes_survey(self, staggered_survey_data, survey_design_weig
460541 assert "Survey Design" in summary
461542 assert "pweight" in summary
462543
544+ def test_wrapper_two_stage_did_with_survey (self , staggered_survey_data ):
545+ """two_stage_did() wrapper forwards survey_design correctly."""
546+ from diff_diff import two_stage_did
547+
548+ sd = SurveyDesign (weights = "weight" )
549+ r_wrapper = two_stage_did (
550+ staggered_survey_data ,
551+ "outcome" ,
552+ "unit" ,
553+ "period" ,
554+ "first_treat" ,
555+ survey_design = sd ,
556+ )
557+ r_direct = TwoStageDiD ().fit (
558+ staggered_survey_data ,
559+ "outcome" ,
560+ "unit" ,
561+ "period" ,
562+ "first_treat" ,
563+ survey_design = sd ,
564+ )
565+ assert np .isclose (r_wrapper .overall_att , r_direct .overall_att , atol = 1e-10 )
566+ assert r_wrapper .survey_metadata is not None
567+
463568 def test_always_treated_with_survey (self , staggered_survey_data ):
464569 """TwoStageDiD with survey + always-treated units should not crash."""
465570 data = staggered_survey_data .copy ()
0 commit comments