Skip to content

Commit 14ee7b9

Browse files
committed
refactor tuning process in DoubleMLSSM class for clarity and maintainability
1 parent 3cc8679 commit 14ee7b9

File tree

1 file changed

+76
-197
lines changed

1 file changed

+76
-197
lines changed

doubleml/irm/ssm.py

Lines changed: 76 additions & 197 deletions
Original file line number | Diff line number | Diff line change
@@ -428,235 +428,114 @@ def _nuisance_tuning(
428428
):
429429
x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
430430
x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
431-
# time indicator is used for selection (selection not available in DoubleMLData yet)
432431
x, s = check_X_y(x, self._dml_data.s, force_all_finite=False)
433432

434433
if self._score == "nonignorable":
435434
z, _ = check_X_y(self._dml_data.z, y, force_all_finite=False)
436-
dx = np.column_stack((x, d, z))
437-
else:
438-
dx = np.column_stack((x, d))
439435

440436
if scoring_methods is None:
441437
scoring_methods = {"ml_g": None, "ml_pi": None, "ml_m": None}
442438

439+
# Nested helper functions
440+
def tune_learner(target, features, train_indices, learner_key):
441+
return _dml_tune(
442+
target,
443+
features,
444+
train_indices,
445+
self._learner[learner_key],
446+
param_grids[learner_key],
447+
scoring_methods[learner_key],
448+
n_folds_tune,
449+
n_jobs_cv,
450+
search_mode,
451+
n_iter_randomized_search,
452+
)
453+
454+
def split_inner_folds(train_inds, d, s, random_state=42):
455+
inner_train0_inds, inner_train1_inds = [], []
456+
for train_index in train_inds:
457+
stratify_vec = d[train_index] + 2 * s[train_index]
458+
inner0, inner1 = train_test_split(train_index, test_size=0.5, stratify=stratify_vec, random_state=random_state)
459+
inner_train0_inds.append(inner0)
460+
inner_train1_inds.append(inner1)
461+
return inner_train0_inds, inner_train1_inds
462+
463+
def filter_by_ds(inner_train1_inds, d, s):
464+
inner1_d0_s1, inner1_d1_s1 = [], []
465+
for inner1 in inner_train1_inds:
466+
d_fold, s_fold = d[inner1], s[inner1]
467+
mask_d0_s1 = (d_fold == 0) & (s_fold == 1)
468+
mask_d1_s1 = (d_fold == 1) & (s_fold == 1)
469+
470+
inner1_d0_s1.append(inner1[mask_d0_s1])
471+
inner1_d1_s1.append(inner1[mask_d1_s1])
472+
return inner1_d0_s1, inner1_d1_s1
473+
443474
if self._score == "nonignorable":
444475

445476
train_inds = [train_index for (train_index, _) in smpls]
446477

447478
# inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
448-
def get_inner_train_inds(train_inds, d, s, random_state=42):
449-
inner_train0_inds = []
450-
inner_train1_inds = []
451-
452-
for train_index in train_inds:
453-
d_fold = d[train_index]
454-
s_fold = s[train_index]
455-
stratify_vec = d_fold + 2 * s_fold
456-
457-
inner0, inner1 = train_test_split(
458-
train_index, test_size=0.5, stratify=stratify_vec, random_state=random_state
459-
)
460-
461-
inner_train0_inds.append(inner0)
462-
inner_train1_inds.append(inner1)
463-
464-
return inner_train0_inds, inner_train1_inds
465-
466-
inner_train0_inds, inner_train1_inds = get_inner_train_inds(train_inds, d, s)
467-
479+
inner_train0_inds, inner_train1_inds = split_inner_folds(train_inds, d, s)
468480
# split inner1 by (d,s) to build g-models for treated/control
469-
def filter_inner1_by_ds(inner_train1_inds, d, s):
470-
inner1_d0_s1 = []
471-
inner1_d1_s1 = []
472-
473-
for inner1 in inner_train1_inds:
474-
d_fold = d[inner1]
475-
s_fold = s[inner1]
476-
477-
mask_d0_s1 = (d_fold == 0) & (s_fold == 1)
478-
mask_d1_s1 = (d_fold == 1) & (s_fold == 1)
479-
480-
inner1_d0_s1.append(inner1[mask_d0_s1])
481-
inner1_d1_s1.append(inner1[mask_d1_s1])
482-
483-
return inner1_d0_s1, inner1_d1_s1
484-
485-
inner_train1_d0_s1, inner_train1_d1_s1 = filter_inner1_by_ds(inner_train1_inds, d, s)
486-
487-
x_d_z = np.concatenate([x, d.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
488-
489-
# ml_pi: tune on inner0, predict pi-hat on inner1
490-
pi_hat_list = []
491-
pi_tune_res_nonignorable = []
481+
inner_train1_d0_s1, inner_train1_d1_s1 = filter_by_ds(inner_train1_inds, d, s)
492482

483+
# Tune ml_pi
484+
x_d_z = np.column_stack((x, d, z))
485+
pi_tune_res = []
486+
pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
493487
for inner0, inner1 in zip(inner_train0_inds, inner_train1_inds):
488+
res = tune_learner(s, x_d_z, [inner0], "ml_pi")
489+
best_params = res[0].best_params_
494490

495-
# tune pi on inner0
496-
pi_tune_res = _dml_tune(
497-
s,
498-
x_d_z,
499-
[inner0],
500-
self._learner["ml_pi"],
501-
param_grids["ml_pi"],
502-
scoring_methods["ml_pi"],
503-
n_folds_tune,
504-
n_jobs_cv,
505-
search_mode,
506-
n_iter_randomized_search,
507-
)
508-
best_params = pi_tune_res[0].best_params_
509-
510-
# fit tuned model
491+
# Fit tuned model and predict
511492
ml_pi_temp = clone(self._learner["ml_pi"])
512493
ml_pi_temp.set_params(**best_params)
513494
ml_pi_temp.fit(x_d_z[inner0], s[inner0])
495+
pi_hat_full[inner1] = _predict_zero_one_propensity(ml_pi_temp, x_d_z)[inner1]
496+
pi_tune_res.append(res[0])
514497

515-
# predict proba on inner1
516-
pi_hat_all = _predict_zero_one_propensity(ml_pi_temp, x_d_z)
517-
pi_hat = pi_hat_all[inner1]
518-
pi_hat_list.append((inner1, pi_hat)) # (index, value) tuple
519-
520-
# save best params
521-
pi_tune_res_nonignorable.append(pi_tune_res[0])
522-
523-
pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
524-
525-
for inner1, pi_hat in pi_hat_list:
526-
pi_hat_full[inner1] = pi_hat
527-
528-
# ml_m: tune with x + pi-hats
529-
x_pi = np.concatenate([x, pi_hat_full.reshape(-1, 1)], axis=1)
530-
531-
m_tune_res = _dml_tune(
532-
d,
533-
x_pi,
534-
inner_train1_inds,
535-
self._learner["ml_m"],
536-
param_grids["ml_m"],
537-
scoring_methods["ml_m"],
538-
n_folds_tune,
539-
n_jobs_cv,
540-
search_mode,
541-
n_iter_randomized_search,
542-
)
543-
544-
# ml_g: tune with x + d + pi-hats for d=0, d=1
545-
x_pi_d = np.concatenate([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)], axis=1)
546-
547-
g_d0_tune_res = _dml_tune(
548-
y,
549-
x_pi_d,
550-
inner_train1_d0_s1,
551-
self._learner["ml_g"],
552-
param_grids["ml_g"],
553-
scoring_methods["ml_g"],
554-
n_folds_tune,
555-
n_jobs_cv,
556-
search_mode,
557-
n_iter_randomized_search,
558-
)
559-
g_d1_tune_res = _dml_tune(
560-
y,
561-
x_pi_d,
562-
inner_train1_d1_s1,
563-
self._learner["ml_g"],
564-
param_grids["ml_g"],
565-
scoring_methods["ml_g"],
566-
n_folds_tune,
567-
n_jobs_cv,
568-
search_mode,
569-
n_iter_randomized_search,
570-
)
571-
572-
g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
573-
g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
574-
pi_best_params = [xx.best_params_ for xx in pi_tune_res_nonignorable]
575-
m_best_params = [xx.best_params_ for xx in m_tune_res]
498+
# Tune ml_m with x + pi-hats
499+
x_pi = np.column_stack([x, pi_hat_full.reshape(-1, 1)])
500+
m_tune_res = tune_learner(d, x_pi, inner_train1_inds, "ml_m")
576501

577-
params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
578-
579-
tune_res = {
580-
"g_d0_tune": g_d0_tune_res,
581-
"g_d1_tune": g_d1_tune_res,
582-
"pi_tune": pi_tune_res_nonignorable,
583-
"m_tune": m_tune_res,
584-
}
585-
586-
res = {"params": params, "tune_res": tune_res}
587-
588-
return res
502+
# Tune ml_g for d=0 and d=1
503+
x_pi_d = np.column_stack([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)])
504+
g_d0_tune_res = tune_learner(y, x_pi_d, inner_train1_d0_s1, "ml_g")
505+
g_d1_tune_res = tune_learner(y, x_pi_d, inner_train1_d1_s1, "ml_g")
589506

590507
else:
591-
592508
# nuisance training sets conditional on d
593509
_, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
594510
train_inds = [train_index for (train_index, _) in smpls]
595511
train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
596512
train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
597513

598-
# hyperparameter tuning for ML
599-
g_d0_tune_res = _dml_tune(
600-
y,
601-
x,
602-
train_inds_d0_s1,
603-
self._learner["ml_g"],
604-
param_grids["ml_g"],
605-
scoring_methods["ml_g"],
606-
n_folds_tune,
607-
n_jobs_cv,
608-
search_mode,
609-
n_iter_randomized_search,
610-
)
611-
g_d1_tune_res = _dml_tune(
612-
y,
613-
x,
614-
train_inds_d1_s1,
615-
self._learner["ml_g"],
616-
param_grids["ml_g"],
617-
scoring_methods["ml_g"],
618-
n_folds_tune,
619-
n_jobs_cv,
620-
search_mode,
621-
n_iter_randomized_search,
622-
)
623-
pi_tune_res = _dml_tune(
624-
s,
625-
dx,
626-
train_inds,
627-
self._learner["ml_pi"],
628-
param_grids["ml_pi"],
629-
scoring_methods["ml_pi"],
630-
n_folds_tune,
631-
n_jobs_cv,
632-
search_mode,
633-
n_iter_randomized_search,
634-
)
635-
m_tune_res = _dml_tune(
636-
d,
637-
x,
638-
train_inds,
639-
self._learner["ml_m"],
640-
param_grids["ml_m"],
641-
scoring_methods["ml_m"],
642-
n_folds_tune,
643-
n_jobs_cv,
644-
search_mode,
645-
n_iter_randomized_search,
646-
)
647-
648-
g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
649-
g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
650-
pi_best_params = [xx.best_params_ for xx in pi_tune_res]
651-
m_best_params = [xx.best_params_ for xx in m_tune_res]
652-
653-
params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
654-
655-
tune_res = {"g_d0_tune": g_d0_tune_res, "g_d1_tune": g_d1_tune_res, "pi_tune": pi_tune_res, "m_tune": m_tune_res}
514+
# Tune ml_g for d=0 and d=1
515+
g_d0_tune_res = tune_learner(y, x, train_inds_d0_s1, "ml_g")
516+
g_d1_tune_res = tune_learner(y, x, train_inds_d1_s1, "ml_g")
517+
518+
# Tune ml_pi and ml_m
519+
x_d = np.column_stack((x, d))
520+
pi_tune_res = tune_learner(s, x_d, train_inds, "ml_pi")
521+
m_tune_res = tune_learner(d, x, train_inds, "ml_m")
522+
523+
# Collect results
524+
params = {
525+
"ml_g_d0": [res.best_params_ for res in g_d0_tune_res],
526+
"ml_g_d1": [res.best_params_ for res in g_d1_tune_res],
527+
"ml_pi": [res.best_params_ for res in pi_tune_res],
528+
"ml_m": [res.best_params_ for res in m_tune_res],
529+
}
656530

657-
res = {"params": params, "tune_res": tune_res}
531+
tune_res = {
532+
"g_d0_tune": g_d0_tune_res,
533+
"g_d1_tune": g_d1_tune_res,
534+
"pi_tune": pi_tune_res,
535+
"m_tune": m_tune_res,
536+
}
658537

659-
return res
538+
return {"params": params, "tune_res": tune_res}
660539

661540
def _sensitivity_element_est(self, preds):
662541
pass

0 commit comments

Comments
 (0)