Commit 3cc8679

SSM: nonignorable nuisance tuning aligned with estimation; update tests and manual tuning utils
1 parent 94f4861 commit 3cc8679

3 files changed: +320 −101 lines changed

doubleml/irm/ssm.py

Lines changed: 212 additions & 63 deletions
@@ -440,74 +440,223 @@ def _nuisance_tuning(
         if scoring_methods is None:
             scoring_methods = {"ml_g": None, "ml_pi": None, "ml_m": None}
 
-        # nuisance training sets conditional on d
-        _, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
-        train_inds = [train_index for (train_index, _) in smpls]
-        train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
-        train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
-
-        # hyperparameter tuning for ML
-        g_d0_tune_res = _dml_tune(
-            y,
-            x,
-            train_inds_d0_s1,
-            self._learner["ml_g"],
-            param_grids["ml_g"],
-            scoring_methods["ml_g"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        g_d1_tune_res = _dml_tune(
-            y,
-            x,
-            train_inds_d1_s1,
-            self._learner["ml_g"],
-            param_grids["ml_g"],
-            scoring_methods["ml_g"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        pi_tune_res = _dml_tune(
-            s,
-            dx,
-            train_inds,
-            self._learner["ml_pi"],
-            param_grids["ml_pi"],
-            scoring_methods["ml_pi"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
-        m_tune_res = _dml_tune(
-            d,
-            x,
-            train_inds,
-            self._learner["ml_m"],
-            param_grids["ml_m"],
-            scoring_methods["ml_m"],
-            n_folds_tune,
-            n_jobs_cv,
-            search_mode,
-            n_iter_randomized_search,
-        )
+        if self._score == "nonignorable":
+
+            train_inds = [train_index for (train_index, _) in smpls]
+
+            # inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
+            def get_inner_train_inds(train_inds, d, s, random_state=42):
+                inner_train0_inds = []
+                inner_train1_inds = []
+
+                for train_index in train_inds:
+                    d_fold = d[train_index]
+                    s_fold = s[train_index]
+                    stratify_vec = d_fold + 2 * s_fold
+
+                    inner0, inner1 = train_test_split(
+                        train_index, test_size=0.5, stratify=stratify_vec, random_state=random_state
+                    )
+
+                    inner_train0_inds.append(inner0)
+                    inner_train1_inds.append(inner1)
+
+                return inner_train0_inds, inner_train1_inds
+
+            inner_train0_inds, inner_train1_inds = get_inner_train_inds(train_inds, d, s)
+
+            # split inner1 by (d,s) to build g-models for treated/control
+            def filter_inner1_by_ds(inner_train1_inds, d, s):
+                inner1_d0_s1 = []
+                inner1_d1_s1 = []
+
+                for inner1 in inner_train1_inds:
+                    d_fold = d[inner1]
+                    s_fold = s[inner1]
+
+                    mask_d0_s1 = (d_fold == 0) & (s_fold == 1)
+                    mask_d1_s1 = (d_fold == 1) & (s_fold == 1)
+
+                    inner1_d0_s1.append(inner1[mask_d0_s1])
+                    inner1_d1_s1.append(inner1[mask_d1_s1])
+
+                return inner1_d0_s1, inner1_d1_s1
+
+            inner_train1_d0_s1, inner_train1_d1_s1 = filter_inner1_by_ds(inner_train1_inds, d, s)
+
+            x_d_z = np.concatenate([x, d.reshape(-1, 1), z.reshape(-1, 1)], axis=1)
+
+            # ml_pi: tune on inner0, predict pi-hat on inner1
+            pi_hat_list = []
+            pi_tune_res_nonignorable = []
+
+            for inner0, inner1 in zip(inner_train0_inds, inner_train1_inds):
+
+                # tune pi on inner0
+                pi_tune_res = _dml_tune(
+                    s,
+                    x_d_z,
+                    [inner0],
+                    self._learner["ml_pi"],
+                    param_grids["ml_pi"],
+                    scoring_methods["ml_pi"],
+                    n_folds_tune,
+                    n_jobs_cv,
+                    search_mode,
+                    n_iter_randomized_search,
+                )
+                best_params = pi_tune_res[0].best_params_
+
+                # fit tuned model
+                ml_pi_temp = clone(self._learner["ml_pi"])
+                ml_pi_temp.set_params(**best_params)
+                ml_pi_temp.fit(x_d_z[inner0], s[inner0])
+
+                # predict proba on inner1
+                pi_hat_all = _predict_zero_one_propensity(ml_pi_temp, x_d_z)
+                pi_hat = pi_hat_all[inner1]
+                pi_hat_list.append((inner1, pi_hat))  # (index, value) tuple
+
+                # save best params
+                pi_tune_res_nonignorable.append(pi_tune_res[0])
+
+            pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
+
+            for inner1, pi_hat in pi_hat_list:
+                pi_hat_full[inner1] = pi_hat
+
+            # ml_m: tune with x + pi-hats
+            x_pi = np.concatenate([x, pi_hat_full.reshape(-1, 1)], axis=1)
+
+            m_tune_res = _dml_tune(
+                d,
+                x_pi,
+                inner_train1_inds,
+                self._learner["ml_m"],
+                param_grids["ml_m"],
+                scoring_methods["ml_m"],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
+
+            # ml_g: tune with x + d + pi-hats for d=0, d=1
+            x_pi_d = np.concatenate([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)], axis=1)
+
+            g_d0_tune_res = _dml_tune(
+                y,
+                x_pi_d,
+                inner_train1_d0_s1,
+                self._learner["ml_g"],
+                param_grids["ml_g"],
+                scoring_methods["ml_g"],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
+            g_d1_tune_res = _dml_tune(
+                y,
+                x_pi_d,
+                inner_train1_d1_s1,
+                self._learner["ml_g"],
+                param_grids["ml_g"],
+                scoring_methods["ml_g"],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
+
+            g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
+            g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
+            pi_best_params = [xx.best_params_ for xx in pi_tune_res_nonignorable]
+            m_best_params = [xx.best_params_ for xx in m_tune_res]
+
+            params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
+
+            tune_res = {
+                "g_d0_tune": g_d0_tune_res,
+                "g_d1_tune": g_d1_tune_res,
+                "pi_tune": pi_tune_res_nonignorable,
+                "m_tune": m_tune_res,
+            }
+
+            res = {"params": params, "tune_res": tune_res}
+
+            return res
+
+        else:
+
+            # nuisance training sets conditional on d
+            _, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
+            train_inds = [train_index for (train_index, _) in smpls]
+            train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
+            train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
+
+            # hyperparameter tuning for ML
+            g_d0_tune_res = _dml_tune(
+                y,
+                x,
+                train_inds_d0_s1,
+                self._learner["ml_g"],
+                param_grids["ml_g"],
+                scoring_methods["ml_g"],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
+            g_d1_tune_res = _dml_tune(
+                y,
+                x,
+                train_inds_d1_s1,
+                self._learner["ml_g"],
+                param_grids["ml_g"],
+                scoring_methods["ml_g"],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
+            pi_tune_res = _dml_tune(
+                s,
+                dx,
+                train_inds,
+                self._learner["ml_pi"],
+                param_grids["ml_pi"],
+                scoring_methods["ml_pi"],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
+            m_tune_res = _dml_tune(
+                d,
+                x,
+                train_inds,
+                self._learner["ml_m"],
+                param_grids["ml_m"],
+                scoring_methods["ml_m"],
+                n_folds_tune,
+                n_jobs_cv,
+                search_mode,
+                n_iter_randomized_search,
+            )
 
-        g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
-        g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
-        pi_best_params = [xx.best_params_ for xx in pi_tune_res]
-        m_best_params = [xx.best_params_ for xx in m_tune_res]
+            g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
+            g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
+            pi_best_params = [xx.best_params_ for xx in pi_tune_res]
+            m_best_params = [xx.best_params_ for xx in m_tune_res]
 
-        params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
+            params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
 
-        tune_res = {"g_d0_tune": g_d0_tune_res, "g_d1_tune": g_d1_tune_res, "pi_tune": pi_tune_res, "m_tune": m_tune_res}
+            tune_res = {"g_d0_tune": g_d0_tune_res, "g_d1_tune": g_d1_tune_res, "pi_tune": pi_tune_res, "m_tune": m_tune_res}
 
-        res = {"params": params, "tune_res": tune_res}
+            res = {"params": params, "tune_res": tune_res}
 
-        return res
+            return res
 
     def _sensitivity_element_est(self, preds):
         pass
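
Note on the inner split above: stratifying on stratify_vec = d_fold + 2 * s_fold encodes the four (d, s) cells as 0..3, so both inner halves keep the joint treatment/selection composition and the (d=0, s=1) and (d=1, s=1) subsets used for the g-learners stay populated. A minimal standalone sketch of this idea (synthetic data, not part of the commit; train_test_split is the real sklearn API, everything else is illustrative):

import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n = 400
d = rng.integers(0, 2, n)  # treatment indicator
s = rng.integers(0, 2, n)  # selection indicator
train_index = np.arange(n)

# encode the four (d, s) cells as 0..3 and split 50/50 within each cell
stratify_vec = d + 2 * s
inner0, inner1 = train_test_split(train_index, test_size=0.5, stratify=stratify_vec, random_state=42)

for name, idx in [("inner0", inner0), ("inner1", inner1)]:
    print(name, np.bincount(d[idx] + 2 * s[idx], minlength=4) / len(idx))
# both halves print (near-)identical cell frequencies

The two-stage layout (pi tuned and fitted on inner0; m and g tuned on inner1 with the predicted pi-hat appended as a regressor) mirrors how the nuisance functions enter the nonignorable score at estimation time, which is the alignment the commit message refers to.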

doubleml/irm/tests/_utils_ssm_manual.py

Lines changed: 43 additions & 5 deletions
@@ -273,17 +273,14 @@ def var_selection(theta, psi_a, psi_b, n_obs):
     return var
 
 
-def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m):
+def tune_nuisance_ssm_mar(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m):
     d0_s1 = np.intersect1d(np.where(d == 0)[0], np.where(s == 1)[0])
     d1_s1 = np.intersect1d(np.where(d == 1)[0], np.where(s == 1)[0])
 
     g0_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=d0_s1)
     g1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=d1_s1)
 
-    if score == "nonignorable":
-        dx = np.column_stack((x, d, z))
-    else:
-        dx = np.column_stack((x, d))
+    dx = np.column_stack((x, d))
 
     pi_tune_res = tune_grid_search(s, dx, ml_pi, smpls, param_grid_pi, n_folds_tune)
 
@@ -295,3 +292,44 @@ def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tu
     m_best_params = [xx.best_params_ for xx in m_tune_res]
 
     return g0_best_params, g1_best_params, pi_best_params, m_best_params
+
+
+def tune_nuisance_ssm_nonignorable(
+    y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m
+):
+
+    train_inds = [tr for (tr, _) in smpls]
+
+    inner0_list, inner1_list = [], []
+    for tr in train_inds:
+        i0, i1 = train_test_split(tr, test_size=0.5, stratify=d[tr] + 2 * s[tr], random_state=42)
+        inner0_list.append(i0)
+        inner1_list.append(i1)
+
+    X_dz = np.c_[x, d.reshape(-1, 1), z.reshape(-1, 1)]
+    pi_tune_res = tune_grid_search(s, X_dz, ml_pi, [(i0, np.array([])) for i0 in inner0_list], param_grid_pi, n_folds_tune)
+    pi_best_params = [gs.best_params_ for gs in pi_tune_res]
+
+    pi_hat_full = np.full_like(s, np.nan, dtype=float)
+    for i0, i1, gs in zip(inner0_list, inner1_list, pi_tune_res):
+        ml_pi_temp = clone(ml_pi)
+        ml_pi_temp.set_params(**gs.best_params_)
+        ml_pi_temp.fit(X_dz[i0], s[i0])
+        ph = _predict_zero_one_propensity(ml_pi_temp, X_dz)
+        pi_hat_full[i1] = ph[i1]
+
+    X_pi = np.c_[x, pi_hat_full]
+    m_tune_res = tune_grid_search(d, X_pi, ml_m, [(i1, np.array([])) for i1 in inner1_list], param_grid_m, n_folds_tune)
+    m_best_params = [gs.best_params_ for gs in m_tune_res]
+
+    X_pi_d = np.c_[x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)]
+    inner1_d0_s1 = [i1[(d[i1] == 0) & (s[i1] == 1)] for i1 in inner1_list]
+    inner1_d1_s1 = [i1[(d[i1] == 1) & (s[i1] == 1)] for i1 in inner1_list]
+
+    g0_tune_res = tune_grid_search(y, X_pi_d, ml_g, [(idx, np.array([])) for idx in inner1_d0_s1], param_grid_g, n_folds_tune)
+    g1_tune_res = tune_grid_search(y, X_pi_d, ml_g, [(idx, np.array([])) for idx in inner1_d1_s1], param_grid_g, n_folds_tune)
+
+    g0_best_params = [gs.best_params_ for gs in g0_tune_res]
+    g1_best_params = [gs.best_params_ for gs in g1_tune_res]
+
+    return g0_best_params, g1_best_params, pi_best_params, m_best_params
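
For reference, a standalone analogue of tune_nuisance_ssm_nonignorable (synthetic data; GridSearchCV and LogisticRegression stand in for the repo's tune_grid_search helper and the user-supplied learners, so this sketches the two-stage pattern, not the exact API):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split

rng = np.random.default_rng(1)
n = 500
x = rng.normal(size=(n, 3))
z = rng.normal(size=n)                       # instrument driving selection
d = rng.integers(0, 2, n)                    # treatment
s = (x[:, 0] + z + rng.normal(size=n) > 0).astype(int)  # selection

i0, i1 = train_test_split(np.arange(n), test_size=0.5, stratify=d + 2 * s, random_state=42)

# stage 1: tune pi = P(S=1 | X, D, Z) on the first inner half
X_dz = np.c_[x, d, z]
pi_search = GridSearchCV(LogisticRegression(), {"C": [0.1, 1.0, 10.0]}, cv=2)
pi_search.fit(X_dz[i0], s[i0])
pi_hat = pi_search.predict_proba(X_dz)[:, 1]  # out-of-half pi-hat for i1

# stage 2: tune m = P(D=1 | X, pi-hat) on the second half, with pi-hat as a generated regressor
X_pi = np.c_[x, pi_hat]
m_search = GridSearchCV(LogisticRegression(), {"C": [0.1, 1.0, 10.0]}, cv=2)
m_search.fit(X_pi[i1], d[i1])

print(pi_search.best_params_, m_search.best_params_)

Tuning m (and g) only on indices disjoint from those used to fit pi keeps the generated pi-hat regressor out-of-sample, matching the nuisance structure used at estimation time.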
