@@ -440,74 +440,223 @@ def _nuisance_tuning(
440
440
if scoring_methods is None :
441
441
scoring_methods = {"ml_g" : None , "ml_pi" : None , "ml_m" : None }
442
442
443
- # nuisance training sets conditional on d
444
- _ , smpls_d0_s1 , _ , smpls_d1_s1 = _get_cond_smpls_2d (smpls , d , s )
445
- train_inds = [train_index for (train_index , _ ) in smpls ]
446
- train_inds_d0_s1 = [train_index for (train_index , _ ) in smpls_d0_s1 ]
447
- train_inds_d1_s1 = [train_index for (train_index , _ ) in smpls_d1_s1 ]
448
-
449
- # hyperparameter tuning for ML
450
- g_d0_tune_res = _dml_tune (
451
- y ,
452
- x ,
453
- train_inds_d0_s1 ,
454
- self ._learner ["ml_g" ],
455
- param_grids ["ml_g" ],
456
- scoring_methods ["ml_g" ],
457
- n_folds_tune ,
458
- n_jobs_cv ,
459
- search_mode ,
460
- n_iter_randomized_search ,
461
- )
462
- g_d1_tune_res = _dml_tune (
463
- y ,
464
- x ,
465
- train_inds_d1_s1 ,
466
- self ._learner ["ml_g" ],
467
- param_grids ["ml_g" ],
468
- scoring_methods ["ml_g" ],
469
- n_folds_tune ,
470
- n_jobs_cv ,
471
- search_mode ,
472
- n_iter_randomized_search ,
473
- )
474
- pi_tune_res = _dml_tune (
475
- s ,
476
- dx ,
477
- train_inds ,
478
- self ._learner ["ml_pi" ],
479
- param_grids ["ml_pi" ],
480
- scoring_methods ["ml_pi" ],
481
- n_folds_tune ,
482
- n_jobs_cv ,
483
- search_mode ,
484
- n_iter_randomized_search ,
485
- )
486
- m_tune_res = _dml_tune (
487
- d ,
488
- x ,
489
- train_inds ,
490
- self ._learner ["ml_m" ],
491
- param_grids ["ml_m" ],
492
- scoring_methods ["ml_m" ],
493
- n_folds_tune ,
494
- n_jobs_cv ,
495
- search_mode ,
496
- n_iter_randomized_search ,
497
- )
443
+ if self ._score == "nonignorable" :
444
+
445
+ train_inds = [train_index for (train_index , _ ) in smpls ]
446
+
447
+ # inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
448
+ def get_inner_train_inds (train_inds , d , s , random_state = 42 ):
449
+ inner_train0_inds = []
450
+ inner_train1_inds = []
451
+
452
+ for train_index in train_inds :
453
+ d_fold = d [train_index ]
454
+ s_fold = s [train_index ]
455
+ stratify_vec = d_fold + 2 * s_fold
456
+
457
+ inner0 , inner1 = train_test_split (
458
+ train_index , test_size = 0.5 , stratify = stratify_vec , random_state = random_state
459
+ )
460
+
461
+ inner_train0_inds .append (inner0 )
462
+ inner_train1_inds .append (inner1 )
463
+
464
+ return inner_train0_inds , inner_train1_inds
465
+
466
+ inner_train0_inds , inner_train1_inds = get_inner_train_inds (train_inds , d , s )
467
+
468
+ # split inner1 by (d,s) to build g-models for treated/control
469
+ def filter_inner1_by_ds (inner_train1_inds , d , s ):
470
+ inner1_d0_s1 = []
471
+ inner1_d1_s1 = []
472
+
473
+ for inner1 in inner_train1_inds :
474
+ d_fold = d [inner1 ]
475
+ s_fold = s [inner1 ]
476
+
477
+ mask_d0_s1 = (d_fold == 0 ) & (s_fold == 1 )
478
+ mask_d1_s1 = (d_fold == 1 ) & (s_fold == 1 )
479
+
480
+ inner1_d0_s1 .append (inner1 [mask_d0_s1 ])
481
+ inner1_d1_s1 .append (inner1 [mask_d1_s1 ])
482
+
483
+ return inner1_d0_s1 , inner1_d1_s1
484
+
485
+ inner_train1_d0_s1 , inner_train1_d1_s1 = filter_inner1_by_ds (inner_train1_inds , d , s )
486
+
487
+ x_d_z = np .concatenate ([x , d .reshape (- 1 , 1 ), z .reshape (- 1 , 1 )], axis = 1 )
488
+
489
+ # ml_pi: tune on inner0, predict pi-hat on inner1
490
+ pi_hat_list = []
491
+ pi_tune_res_nonignorable = []
492
+
493
+ for inner0 , inner1 in zip (inner_train0_inds , inner_train1_inds ):
494
+
495
+ # tune pi on inner0
496
+ pi_tune_res = _dml_tune (
497
+ s ,
498
+ x_d_z ,
499
+ [inner0 ],
500
+ self ._learner ["ml_pi" ],
501
+ param_grids ["ml_pi" ],
502
+ scoring_methods ["ml_pi" ],
503
+ n_folds_tune ,
504
+ n_jobs_cv ,
505
+ search_mode ,
506
+ n_iter_randomized_search ,
507
+ )
508
+ best_params = pi_tune_res [0 ].best_params_
509
+
510
+ # fit tuned model
511
+ ml_pi_temp = clone (self ._learner ["ml_pi" ])
512
+ ml_pi_temp .set_params (** best_params )
513
+ ml_pi_temp .fit (x_d_z [inner0 ], s [inner0 ])
514
+
515
+ # predict proba on inner1
516
+ pi_hat_all = _predict_zero_one_propensity (ml_pi_temp , x_d_z )
517
+ pi_hat = pi_hat_all [inner1 ]
518
+ pi_hat_list .append ((inner1 , pi_hat )) # (index, value) tuple
519
+
520
+ # save best params
521
+ pi_tune_res_nonignorable .append (pi_tune_res [0 ])
522
+
523
+ pi_hat_full = np .full (shape = s .shape , fill_value = np .nan )
524
+
525
+ for inner1 , pi_hat in pi_hat_list :
526
+ pi_hat_full [inner1 ] = pi_hat
527
+
528
+ # ml_m: tune with x + pi-hats
529
+ x_pi = np .concatenate ([x , pi_hat_full .reshape (- 1 , 1 )], axis = 1 )
530
+
531
+ m_tune_res = _dml_tune (
532
+ d ,
533
+ x_pi ,
534
+ inner_train1_inds ,
535
+ self ._learner ["ml_m" ],
536
+ param_grids ["ml_m" ],
537
+ scoring_methods ["ml_m" ],
538
+ n_folds_tune ,
539
+ n_jobs_cv ,
540
+ search_mode ,
541
+ n_iter_randomized_search ,
542
+ )
543
+
544
+ # ml_g: tune with x + d + pi-hats for d=0, d=1
545
+ x_pi_d = np .concatenate ([x , d .reshape (- 1 , 1 ), pi_hat_full .reshape (- 1 , 1 )], axis = 1 )
546
+
547
+ g_d0_tune_res = _dml_tune (
548
+ y ,
549
+ x_pi_d ,
550
+ inner_train1_d0_s1 ,
551
+ self ._learner ["ml_g" ],
552
+ param_grids ["ml_g" ],
553
+ scoring_methods ["ml_g" ],
554
+ n_folds_tune ,
555
+ n_jobs_cv ,
556
+ search_mode ,
557
+ n_iter_randomized_search ,
558
+ )
559
+ g_d1_tune_res = _dml_tune (
560
+ y ,
561
+ x_pi_d ,
562
+ inner_train1_d1_s1 ,
563
+ self ._learner ["ml_g" ],
564
+ param_grids ["ml_g" ],
565
+ scoring_methods ["ml_g" ],
566
+ n_folds_tune ,
567
+ n_jobs_cv ,
568
+ search_mode ,
569
+ n_iter_randomized_search ,
570
+ )
571
+
572
+ g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
573
+ g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
574
+ pi_best_params = [xx .best_params_ for xx in pi_tune_res_nonignorable ]
575
+ m_best_params = [xx .best_params_ for xx in m_tune_res ]
576
+
577
+ params = {"ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
578
+
579
+ tune_res = {
580
+ "g_d0_tune" : g_d0_tune_res ,
581
+ "g_d1_tune" : g_d1_tune_res ,
582
+ "pi_tune" : pi_tune_res_nonignorable ,
583
+ "m_tune" : m_tune_res ,
584
+ }
585
+
586
+ res = {"params" : params , "tune_res" : tune_res }
587
+
588
+ return res
589
+
590
+ else :
591
+
592
+ # nuisance training sets conditional on d
593
+ _ , smpls_d0_s1 , _ , smpls_d1_s1 = _get_cond_smpls_2d (smpls , d , s )
594
+ train_inds = [train_index for (train_index , _ ) in smpls ]
595
+ train_inds_d0_s1 = [train_index for (train_index , _ ) in smpls_d0_s1 ]
596
+ train_inds_d1_s1 = [train_index for (train_index , _ ) in smpls_d1_s1 ]
597
+
598
+ # hyperparameter tuning for ML
599
+ g_d0_tune_res = _dml_tune (
600
+ y ,
601
+ x ,
602
+ train_inds_d0_s1 ,
603
+ self ._learner ["ml_g" ],
604
+ param_grids ["ml_g" ],
605
+ scoring_methods ["ml_g" ],
606
+ n_folds_tune ,
607
+ n_jobs_cv ,
608
+ search_mode ,
609
+ n_iter_randomized_search ,
610
+ )
611
+ g_d1_tune_res = _dml_tune (
612
+ y ,
613
+ x ,
614
+ train_inds_d1_s1 ,
615
+ self ._learner ["ml_g" ],
616
+ param_grids ["ml_g" ],
617
+ scoring_methods ["ml_g" ],
618
+ n_folds_tune ,
619
+ n_jobs_cv ,
620
+ search_mode ,
621
+ n_iter_randomized_search ,
622
+ )
623
+ pi_tune_res = _dml_tune (
624
+ s ,
625
+ dx ,
626
+ train_inds ,
627
+ self ._learner ["ml_pi" ],
628
+ param_grids ["ml_pi" ],
629
+ scoring_methods ["ml_pi" ],
630
+ n_folds_tune ,
631
+ n_jobs_cv ,
632
+ search_mode ,
633
+ n_iter_randomized_search ,
634
+ )
635
+ m_tune_res = _dml_tune (
636
+ d ,
637
+ x ,
638
+ train_inds ,
639
+ self ._learner ["ml_m" ],
640
+ param_grids ["ml_m" ],
641
+ scoring_methods ["ml_m" ],
642
+ n_folds_tune ,
643
+ n_jobs_cv ,
644
+ search_mode ,
645
+ n_iter_randomized_search ,
646
+ )
498
647
499
- g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
500
- g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
501
- pi_best_params = [xx .best_params_ for xx in pi_tune_res ]
502
- m_best_params = [xx .best_params_ for xx in m_tune_res ]
648
+ g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
649
+ g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
650
+ pi_best_params = [xx .best_params_ for xx in pi_tune_res ]
651
+ m_best_params = [xx .best_params_ for xx in m_tune_res ]
503
652
504
- params = {"ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
653
+ params = {"ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
505
654
506
- tune_res = {"g_d0_tune" : g_d0_tune_res , "g_d1_tune" : g_d1_tune_res , "pi_tune" : pi_tune_res , "m_tune" : m_tune_res }
655
+ tune_res = {"g_d0_tune" : g_d0_tune_res , "g_d1_tune" : g_d1_tune_res , "pi_tune" : pi_tune_res , "m_tune" : m_tune_res }
507
656
508
- res = {"params" : params , "tune_res" : tune_res }
657
+ res = {"params" : params , "tune_res" : tune_res }
509
658
510
- return res
659
+ return res
511
660
512
661
def _sensitivity_element_est (self , preds ):
513
662
pass
0 commit comments