@@ -428,86 +428,114 @@ def _nuisance_tuning(
     ):
         x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
         x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
-        # time indicator is used for selection (selection not available in DoubleMLData yet)
         x, s = check_X_y(x, self._dml_data.s, force_all_finite=False)
 
         if self._score == "nonignorable":
             z, _ = check_X_y(self._dml_data.z, y, force_all_finite=False)
-            dx = np.column_stack((x, d, z))
-        else:
-            dx = np.column_stack((x, d))
 
         if scoring_methods is None:
             scoring_methods = {"ml_g": None, "ml_pi": None, "ml_m": None}
 
-        # nuisance training sets conditional on d
-        _, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
-        train_inds = [train_index for (train_index, _) in smpls]
-        train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
-        train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
-
-        # hyperparameter tuning for ML
-        g_d0_tune_res = _dml_tune(
-            y,
-            x,
-            train_inds_d0_s1,
-            self._learner["ml_g"],
-            param_grids["ml_g"],
-            scoring_methods["ml_g"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        g_d1_tune_res = _dml_tune(
-            y,
-            x,
-            train_inds_d1_s1,
-            self._learner["ml_g"],
-            param_grids["ml_g"],
-            scoring_methods["ml_g"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        pi_tune_res = _dml_tune(
-            s,
-            dx,
-            train_inds,
-            self._learner["ml_pi"],
-            param_grids["ml_pi"],
-            scoring_methods["ml_pi"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        m_tune_res = _dml_tune(
-            d,
-            x,
-            train_inds,
-            self._learner["ml_m"],
-            param_grids["ml_m"],
-            scoring_methods["ml_m"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
+        # Nested helper functions
+        def tune_learner(target, features, train_indices, learner_key):
+            return _dml_tune(
+                target,
+                features,
+                train_indices,
+                self._learner[learner_key],
+                param_grids[learner_key],
+                scoring_methods[learner_key],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
 
-        g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
-        g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
-        pi_best_params = [xx.best_params_ for xx in pi_tune_res]
-        m_best_params = [xx.best_params_ for xx in m_tune_res]
+        def split_inner_folds(train_inds, d, s, random_state=42):
+            inner_train0_inds, inner_train1_inds = [], []
+            for train_index in train_inds:
+                stratify_vec = d[train_index] + 2 * s[train_index]
+                inner0, inner1 = train_test_split(train_index, test_size=0.5, stratify=stratify_vec, random_state=random_state)
+                inner_train0_inds.append(inner0)
+                inner_train1_inds.append(inner1)
+            return inner_train0_inds, inner_train1_inds
+
+        def filter_by_ds(inner_train1_inds, d, s):
+            inner1_d0_s1, inner1_d1_s1 = [], []
+            for inner1 in inner_train1_inds:
+                d_fold, s_fold = d[inner1], s[inner1]
+                mask_d0_s1 = (d_fold == 0) & (s_fold == 1)
+                mask_d1_s1 = (d_fold == 1) & (s_fold == 1)
+
+                inner1_d0_s1.append(inner1[mask_d0_s1])
+                inner1_d1_s1.append(inner1[mask_d1_s1])
+            return inner1_d0_s1, inner1_d1_s1
 
-        params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
+        if self._score == "nonignorable":
 
-        tune_res = {"g_d0_tune": g_d0_tune_res, "g_d1_tune": g_d1_tune_res, "pi_tune": pi_tune_res, "m_tune": m_tune_res}
+            train_inds = [train_index for (train_index, _) in smpls]
+
+            # inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
+            inner_train0_inds, inner_train1_inds = split_inner_folds(train_inds, d, s)
+            # split inner1 by (d, s) to build g-models for treated/control
+            inner_train1_d0_s1, inner_train1_d1_s1 = filter_by_ds(inner_train1_inds, d, s)
+
+            # Tune ml_pi
+            x_d_z = np.column_stack((x, d, z))
+            pi_tune_res = []
+            pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
+            for inner0, inner1 in zip(inner_train0_inds, inner_train1_inds):
+                res = tune_learner(s, x_d_z, [inner0], "ml_pi")
+                best_params = res[0].best_params_
+
+                # Fit tuned model and predict
+                ml_pi_temp = clone(self._learner["ml_pi"])
+                ml_pi_temp.set_params(**best_params)
+                ml_pi_temp.fit(x_d_z[inner0], s[inner0])
+                pi_hat_full[inner1] = _predict_zero_one_propensity(ml_pi_temp, x_d_z)[inner1]
+                pi_tune_res.append(res[0])
+
+            # Tune ml_m with x + pi-hats
+            x_pi = np.column_stack([x, pi_hat_full.reshape(-1, 1)])
+            m_tune_res = tune_learner(d, x_pi, inner_train1_inds, "ml_m")
+
+            # Tune ml_g for d=0 and d=1
+            x_pi_d = np.column_stack([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)])
+            g_d0_tune_res = tune_learner(y, x_pi_d, inner_train1_d0_s1, "ml_g")
+            g_d1_tune_res = tune_learner(y, x_pi_d, inner_train1_d1_s1, "ml_g")
 
-        res = {"params": params, "tune_res": tune_res}
+        else:
+            # nuisance training sets conditional on d
+            _, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
+            train_inds = [train_index for (train_index, _) in smpls]
+            train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
+            train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
+
+            # Tune ml_g for d=0 and d=1
+            g_d0_tune_res = tune_learner(y, x, train_inds_d0_s1, "ml_g")
+            g_d1_tune_res = tune_learner(y, x, train_inds_d1_s1, "ml_g")
+
+            # Tune ml_pi and ml_m
+            x_d = np.column_stack((x, d))
+            pi_tune_res = tune_learner(s, x_d, train_inds, "ml_pi")
+            m_tune_res = tune_learner(d, x, train_inds, "ml_m")
+
+        # Collect results
+        params = {
+            "ml_g_d0": [res.best_params_ for res in g_d0_tune_res],
+            "ml_g_d1": [res.best_params_ for res in g_d1_tune_res],
+            "ml_pi": [res.best_params_ for res in pi_tune_res],
+            "ml_m": [res.best_params_ for res in m_tune_res],
+        }
+
+        tune_res = {
+            "g_d0_tune": g_d0_tune_res,
+            "g_d1_tune": g_d1_tune_res,
+            "pi_tune": pi_tune_res,
+            "m_tune": m_tune_res,
+        }
 
-        return res
+        return {"params": params, "tune_res": tune_res}
 
 
     def _sensitivity_element_est(self, preds):
         pass
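
A note on `split_inner_folds`: it stratifies each 50/50 inner split on the joint treatment/selection cells by encoding the pair (d, s) as the single integer d + 2*s, so all four cells stay balanced across both halves. A minimal standalone sketch of that encoding, with synthetic data and illustrative names that are not part of the package:

import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
train_index = np.arange(200)      # indices of one outer training fold
d = rng.integers(0, 2, size=200)  # binary treatment
s = rng.integers(0, 2, size=200)  # binary selection indicator

# encode the four (d, s) cells as 0, 1, 2, 3 so a single stratified split
# preserves the joint treatment/selection proportions in both halves
strata = d[train_index] + 2 * s[train_index]
inner0, inner1 = train_test_split(train_index, test_size=0.5, stratify=strata, random_state=42)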
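The nonignorable branch is two-stage: `ml_pi` is tuned and fit on the first inner half, its out-of-sample selection propensities are predicted for the second half, and those pi-hats enter the feature matrices for `ml_m` and `ml_g`. A schematic sketch of that flow, with a plain `LogisticRegression` standing in for the tuned learner and synthetic data (all names here are illustrative assumptions, not the package API):

import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(1)
n = 200
x = rng.normal(size=(n, 3))     # covariates
d = rng.integers(0, 2, size=n)  # treatment
z = rng.normal(size=n)          # instrument driving nonignorable selection
s = rng.integers(0, 2, size=n)  # selection indicator
inner0, inner1 = np.arange(0, n, 2), np.arange(1, n, 2)

# stage 1: fit the selection model on the first inner half only ...
x_d_z = np.column_stack((x, d, z))
pi_model = clone(LogisticRegression()).fit(x_d_z[inner0], s[inner0])

# ... and keep only out-of-sample propensities for the second half,
# mirroring how pi_hat_full is filled fold by fold in the diff above
pi_hat = np.full(n, np.nan)
pi_hat[inner1] = pi_model.predict_proba(x_d_z)[inner1, 1]

# stage 2: the predicted propensities become an extra feature for ml_m
# (and, together with d, for ml_g)
x_pi = np.column_stack([x, pi_hat.reshape(-1, 1)])
x_pi_d = np.column_stack([x, d.reshape(-1, 1), pi_hat.reshape(-1, 1)])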