@@ -428,235 +428,114 @@ def _nuisance_tuning(
428
428
):
429
429
x , y = check_X_y (self ._dml_data .x , self ._dml_data .y , force_all_finite = False )
430
430
x , d = check_X_y (x , self ._dml_data .d , force_all_finite = False )
431
- # time indicator is used for selection (selection not available in DoubleMLData yet)
432
431
x , s = check_X_y (x , self ._dml_data .s , force_all_finite = False )
433
432
434
433
if self ._score == "nonignorable" :
435
434
z , _ = check_X_y (self ._dml_data .z , y , force_all_finite = False )
436
- dx = np .column_stack ((x , d , z ))
437
- else :
438
- dx = np .column_stack ((x , d ))
439
435
440
436
if scoring_methods is None :
441
437
scoring_methods = {"ml_g" : None , "ml_pi" : None , "ml_m" : None }
442
438
439
+ # Nested helper functions
440
+ def tune_learner (target , features , train_indices , learner_key ):
441
+ return _dml_tune (
442
+ target ,
443
+ features ,
444
+ train_indices ,
445
+ self ._learner [learner_key ],
446
+ param_grids [learner_key ],
447
+ scoring_methods [learner_key ],
448
+ n_folds_tune ,
449
+ n_jobs_cv ,
450
+ search_mode ,
451
+ n_iter_randomized_search ,
452
+ )
453
+
454
+ def split_inner_folds (train_inds , d , s , random_state = 42 ):
455
+ inner_train0_inds , inner_train1_inds = [], []
456
+ for train_index in train_inds :
457
+ stratify_vec = d [train_index ] + 2 * s [train_index ]
458
+ inner0 , inner1 = train_test_split (train_index , test_size = 0.5 , stratify = stratify_vec , random_state = random_state )
459
+ inner_train0_inds .append (inner0 )
460
+ inner_train1_inds .append (inner1 )
461
+ return inner_train0_inds , inner_train1_inds
462
+
463
+ def filter_by_ds (inner_train1_inds , d , s ):
464
+ inner1_d0_s1 , inner1_d1_s1 = [], []
465
+ for inner1 in inner_train1_inds :
466
+ d_fold , s_fold = d [inner1 ], s [inner1 ]
467
+ mask_d0_s1 = (d_fold == 0 ) & (s_fold == 1 )
468
+ mask_d1_s1 = (d_fold == 1 ) & (s_fold == 1 )
469
+
470
+ inner1_d0_s1 .append (inner1 [mask_d0_s1 ])
471
+ inner1_d1_s1 .append (inner1 [mask_d1_s1 ])
472
+ return inner1_d0_s1 , inner1_d1_s1
473
+
443
474
if self ._score == "nonignorable" :
444
475
445
476
train_inds = [train_index for (train_index , _ ) in smpls ]
446
477
447
478
# inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
448
- def get_inner_train_inds (train_inds , d , s , random_state = 42 ):
449
- inner_train0_inds = []
450
- inner_train1_inds = []
451
-
452
- for train_index in train_inds :
453
- d_fold = d [train_index ]
454
- s_fold = s [train_index ]
455
- stratify_vec = d_fold + 2 * s_fold
456
-
457
- inner0 , inner1 = train_test_split (
458
- train_index , test_size = 0.5 , stratify = stratify_vec , random_state = random_state
459
- )
460
-
461
- inner_train0_inds .append (inner0 )
462
- inner_train1_inds .append (inner1 )
463
-
464
- return inner_train0_inds , inner_train1_inds
465
-
466
- inner_train0_inds , inner_train1_inds = get_inner_train_inds (train_inds , d , s )
467
-
479
+ inner_train0_inds , inner_train1_inds = split_inner_folds (train_inds , d , s )
468
480
# split inner1 by (d,s) to build g-models for treated/control
469
- def filter_inner1_by_ds (inner_train1_inds , d , s ):
470
- inner1_d0_s1 = []
471
- inner1_d1_s1 = []
472
-
473
- for inner1 in inner_train1_inds :
474
- d_fold = d [inner1 ]
475
- s_fold = s [inner1 ]
476
-
477
- mask_d0_s1 = (d_fold == 0 ) & (s_fold == 1 )
478
- mask_d1_s1 = (d_fold == 1 ) & (s_fold == 1 )
479
-
480
- inner1_d0_s1 .append (inner1 [mask_d0_s1 ])
481
- inner1_d1_s1 .append (inner1 [mask_d1_s1 ])
482
-
483
- return inner1_d0_s1 , inner1_d1_s1
484
-
485
- inner_train1_d0_s1 , inner_train1_d1_s1 = filter_inner1_by_ds (inner_train1_inds , d , s )
486
-
487
- x_d_z = np .concatenate ([x , d .reshape (- 1 , 1 ), z .reshape (- 1 , 1 )], axis = 1 )
488
-
489
- # ml_pi: tune on inner0, predict pi-hat on inner1
490
- pi_hat_list = []
491
- pi_tune_res_nonignorable = []
481
+ inner_train1_d0_s1 , inner_train1_d1_s1 = filter_by_ds (inner_train1_inds , d , s )
492
482
483
+ # Tune ml_pi
484
+ x_d_z = np .column_stack ((x , d , z ))
485
+ pi_tune_res = []
486
+ pi_hat_full = np .full (shape = s .shape , fill_value = np .nan )
493
487
for inner0 , inner1 in zip (inner_train0_inds , inner_train1_inds ):
488
+ res = tune_learner (s , x_d_z , [inner0 ], "ml_pi" )
489
+ best_params = res [0 ].best_params_
494
490
495
- # tune pi on inner0
496
- pi_tune_res = _dml_tune (
497
- s ,
498
- x_d_z ,
499
- [inner0 ],
500
- self ._learner ["ml_pi" ],
501
- param_grids ["ml_pi" ],
502
- scoring_methods ["ml_pi" ],
503
- n_folds_tune ,
504
- n_jobs_cv ,
505
- search_mode ,
506
- n_iter_randomized_search ,
507
- )
508
- best_params = pi_tune_res [0 ].best_params_
509
-
510
- # fit tuned model
491
+ # Fit tuned model and predict
511
492
ml_pi_temp = clone (self ._learner ["ml_pi" ])
512
493
ml_pi_temp .set_params (** best_params )
513
494
ml_pi_temp .fit (x_d_z [inner0 ], s [inner0 ])
495
+ pi_hat_full [inner1 ] = _predict_zero_one_propensity (ml_pi_temp , x_d_z )[inner1 ]
496
+ pi_tune_res .append (res [0 ])
514
497
515
- # predict proba on inner1
516
- pi_hat_all = _predict_zero_one_propensity (ml_pi_temp , x_d_z )
517
- pi_hat = pi_hat_all [inner1 ]
518
- pi_hat_list .append ((inner1 , pi_hat )) # (index, value) tuple
519
-
520
- # save best params
521
- pi_tune_res_nonignorable .append (pi_tune_res [0 ])
522
-
523
- pi_hat_full = np .full (shape = s .shape , fill_value = np .nan )
524
-
525
- for inner1 , pi_hat in pi_hat_list :
526
- pi_hat_full [inner1 ] = pi_hat
527
-
528
- # ml_m: tune with x + pi-hats
529
- x_pi = np .concatenate ([x , pi_hat_full .reshape (- 1 , 1 )], axis = 1 )
530
-
531
- m_tune_res = _dml_tune (
532
- d ,
533
- x_pi ,
534
- inner_train1_inds ,
535
- self ._learner ["ml_m" ],
536
- param_grids ["ml_m" ],
537
- scoring_methods ["ml_m" ],
538
- n_folds_tune ,
539
- n_jobs_cv ,
540
- search_mode ,
541
- n_iter_randomized_search ,
542
- )
543
-
544
- # ml_g: tune with x + d + pi-hats for d=0, d=1
545
- x_pi_d = np .concatenate ([x , d .reshape (- 1 , 1 ), pi_hat_full .reshape (- 1 , 1 )], axis = 1 )
546
-
547
- g_d0_tune_res = _dml_tune (
548
- y ,
549
- x_pi_d ,
550
- inner_train1_d0_s1 ,
551
- self ._learner ["ml_g" ],
552
- param_grids ["ml_g" ],
553
- scoring_methods ["ml_g" ],
554
- n_folds_tune ,
555
- n_jobs_cv ,
556
- search_mode ,
557
- n_iter_randomized_search ,
558
- )
559
- g_d1_tune_res = _dml_tune (
560
- y ,
561
- x_pi_d ,
562
- inner_train1_d1_s1 ,
563
- self ._learner ["ml_g" ],
564
- param_grids ["ml_g" ],
565
- scoring_methods ["ml_g" ],
566
- n_folds_tune ,
567
- n_jobs_cv ,
568
- search_mode ,
569
- n_iter_randomized_search ,
570
- )
571
-
572
- g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
573
- g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
574
- pi_best_params = [xx .best_params_ for xx in pi_tune_res_nonignorable ]
575
- m_best_params = [xx .best_params_ for xx in m_tune_res ]
498
+ # Tune ml_m with x + pi-hats
499
+ x_pi = np .column_stack ([x , pi_hat_full .reshape (- 1 , 1 )])
500
+ m_tune_res = tune_learner (d , x_pi , inner_train1_inds , "ml_m" )
576
501
577
- params = {"ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
578
-
579
- tune_res = {
580
- "g_d0_tune" : g_d0_tune_res ,
581
- "g_d1_tune" : g_d1_tune_res ,
582
- "pi_tune" : pi_tune_res_nonignorable ,
583
- "m_tune" : m_tune_res ,
584
- }
585
-
586
- res = {"params" : params , "tune_res" : tune_res }
587
-
588
- return res
502
+ # Tune ml_g for d=0 and d=1
503
+ x_pi_d = np .column_stack ([x , d .reshape (- 1 , 1 ), pi_hat_full .reshape (- 1 , 1 )])
504
+ g_d0_tune_res = tune_learner (y , x_pi_d , inner_train1_d0_s1 , "ml_g" )
505
+ g_d1_tune_res = tune_learner (y , x_pi_d , inner_train1_d1_s1 , "ml_g" )
589
506
590
507
else :
591
-
592
508
# nuisance training sets conditional on d
593
509
_ , smpls_d0_s1 , _ , smpls_d1_s1 = _get_cond_smpls_2d (smpls , d , s )
594
510
train_inds = [train_index for (train_index , _ ) in smpls ]
595
511
train_inds_d0_s1 = [train_index for (train_index , _ ) in smpls_d0_s1 ]
596
512
train_inds_d1_s1 = [train_index for (train_index , _ ) in smpls_d1_s1 ]
597
513
598
- # hyperparameter tuning for ML
599
- g_d0_tune_res = _dml_tune (
600
- y ,
601
- x ,
602
- train_inds_d0_s1 ,
603
- self ._learner ["ml_g" ],
604
- param_grids ["ml_g" ],
605
- scoring_methods ["ml_g" ],
606
- n_folds_tune ,
607
- n_jobs_cv ,
608
- search_mode ,
609
- n_iter_randomized_search ,
610
- )
611
- g_d1_tune_res = _dml_tune (
612
- y ,
613
- x ,
614
- train_inds_d1_s1 ,
615
- self ._learner ["ml_g" ],
616
- param_grids ["ml_g" ],
617
- scoring_methods ["ml_g" ],
618
- n_folds_tune ,
619
- n_jobs_cv ,
620
- search_mode ,
621
- n_iter_randomized_search ,
622
- )
623
- pi_tune_res = _dml_tune (
624
- s ,
625
- dx ,
626
- train_inds ,
627
- self ._learner ["ml_pi" ],
628
- param_grids ["ml_pi" ],
629
- scoring_methods ["ml_pi" ],
630
- n_folds_tune ,
631
- n_jobs_cv ,
632
- search_mode ,
633
- n_iter_randomized_search ,
634
- )
635
- m_tune_res = _dml_tune (
636
- d ,
637
- x ,
638
- train_inds ,
639
- self ._learner ["ml_m" ],
640
- param_grids ["ml_m" ],
641
- scoring_methods ["ml_m" ],
642
- n_folds_tune ,
643
- n_jobs_cv ,
644
- search_mode ,
645
- n_iter_randomized_search ,
646
- )
647
-
648
- g_d0_best_params = [xx .best_params_ for xx in g_d0_tune_res ]
649
- g_d1_best_params = [xx .best_params_ for xx in g_d1_tune_res ]
650
- pi_best_params = [xx .best_params_ for xx in pi_tune_res ]
651
- m_best_params = [xx .best_params_ for xx in m_tune_res ]
652
-
653
- params = {"ml_g_d0" : g_d0_best_params , "ml_g_d1" : g_d1_best_params , "ml_pi" : pi_best_params , "ml_m" : m_best_params }
654
-
655
- tune_res = {"g_d0_tune" : g_d0_tune_res , "g_d1_tune" : g_d1_tune_res , "pi_tune" : pi_tune_res , "m_tune" : m_tune_res }
514
+ # Tune ml_g for d=0 and d=1
515
+ g_d0_tune_res = tune_learner (y , x , train_inds_d0_s1 , "ml_g" )
516
+ g_d1_tune_res = tune_learner (y , x , train_inds_d1_s1 , "ml_g" )
517
+
518
+ # Tune ml_pi and ml_m
519
+ x_d = np .column_stack ((x , d ))
520
+ pi_tune_res = tune_learner (s , x_d , train_inds , "ml_pi" )
521
+ m_tune_res = tune_learner (d , x , train_inds , "ml_m" )
522
+
523
+ # Collect results
524
+ params = {
525
+ "ml_g_d0" : [res .best_params_ for res in g_d0_tune_res ],
526
+ "ml_g_d1" : [res .best_params_ for res in g_d1_tune_res ],
527
+ "ml_pi" : [res .best_params_ for res in pi_tune_res ],
528
+ "ml_m" : [res .best_params_ for res in m_tune_res ],
529
+ }
656
530
657
- res = {"params" : params , "tune_res" : tune_res }
531
+ tune_res = {
532
+ "g_d0_tune" : g_d0_tune_res ,
533
+ "g_d1_tune" : g_d1_tune_res ,
534
+ "pi_tune" : pi_tune_res ,
535
+ "m_tune" : m_tune_res ,
536
+ }
658
537
659
- return res
538
+ return { "params" : params , "tune_res" : tune_res }
660
539
661
540
def _sensitivity_element_est(self, preds):
    # Stub: sensitivity-analysis elements are not implemented for this model.
    # `preds` (nuisance predictions) is accepted but unused; the method
    # intentionally returns None. NOTE(review): presumably kept to satisfy the
    # abstract base-class interface — confirm against the parent class.
    pass
0 commit comments