Skip to content

Commit d61c040

Browse files
authored
Merge pull request #352 from batuhanovski/ssm-hyperparameter-tuning
SSM: nonignorable nuisance tuning aligned with estimation; update tests and manual tuning utils
2 parents 94f4861 + 14ee7b9 commit d61c040

File tree

3 files changed

+203
-105
lines changed

3 files changed

+203
-105
lines changed

doubleml/irm/ssm.py

Lines changed: 95 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -428,86 +428,114 @@ def _nuisance_tuning(
428428
):
429429
x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
430430
x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
431-
# time indicator is used for selection (selection not available in DoubleMLData yet)
432431
x, s = check_X_y(x, self._dml_data.s, force_all_finite=False)
433432

434433
if self._score == "nonignorable":
435434
z, _ = check_X_y(self._dml_data.z, y, force_all_finite=False)
436-
dx = np.column_stack((x, d, z))
437-
else:
438-
dx = np.column_stack((x, d))
439435

440436
if scoring_methods is None:
441437
scoring_methods = {"ml_g": None, "ml_pi": None, "ml_m": None}
442438

443-
# nuisance training sets conditional on d
444-
_, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
445-
train_inds = [train_index for (train_index, _) in smpls]
446-
train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
447-
train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
448-
449-
# hyperparameter tuning for ML
450-
g_d0_tune_res = _dml_tune(
451-
y,
452-
x,
453-
train_inds_d0_s1,
454-
self._learner["ml_g"],
455-
param_grids["ml_g"],
456-
scoring_methods["ml_g"],
457-
n_folds_tune,
458-
n_jobs_cv,
459-
search_mode,
460-
n_iter_randomized_search,
461-
)
462-
g_d1_tune_res = _dml_tune(
463-
y,
464-
x,
465-
train_inds_d1_s1,
466-
self._learner["ml_g"],
467-
param_grids["ml_g"],
468-
scoring_methods["ml_g"],
469-
n_folds_tune,
470-
n_jobs_cv,
471-
search_mode,
472-
n_iter_randomized_search,
473-
)
474-
pi_tune_res = _dml_tune(
475-
s,
476-
dx,
477-
train_inds,
478-
self._learner["ml_pi"],
479-
param_grids["ml_pi"],
480-
scoring_methods["ml_pi"],
481-
n_folds_tune,
482-
n_jobs_cv,
483-
search_mode,
484-
n_iter_randomized_search,
485-
)
486-
m_tune_res = _dml_tune(
487-
d,
488-
x,
489-
train_inds,
490-
self._learner["ml_m"],
491-
param_grids["ml_m"],
492-
scoring_methods["ml_m"],
493-
n_folds_tune,
494-
n_jobs_cv,
495-
search_mode,
496-
n_iter_randomized_search,
497-
)
439+
# Nested helper functions
440+
def tune_learner(target, features, train_indices, learner_key):
441+
return _dml_tune(
442+
target,
443+
features,
444+
train_indices,
445+
self._learner[learner_key],
446+
param_grids[learner_key],
447+
scoring_methods[learner_key],
448+
n_folds_tune,
449+
n_jobs_cv,
450+
search_mode,
451+
n_iter_randomized_search,
452+
)
498453

499-
g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
500-
g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
501-
pi_best_params = [xx.best_params_ for xx in pi_tune_res]
502-
m_best_params = [xx.best_params_ for xx in m_tune_res]
454+
def split_inner_folds(train_inds, d, s, random_state=42):
455+
inner_train0_inds, inner_train1_inds = [], []
456+
for train_index in train_inds:
457+
stratify_vec = d[train_index] + 2 * s[train_index]
458+
inner0, inner1 = train_test_split(train_index, test_size=0.5, stratify=stratify_vec, random_state=random_state)
459+
inner_train0_inds.append(inner0)
460+
inner_train1_inds.append(inner1)
461+
return inner_train0_inds, inner_train1_inds
462+
463+
def filter_by_ds(inner_train1_inds, d, s):
464+
inner1_d0_s1, inner1_d1_s1 = [], []
465+
for inner1 in inner_train1_inds:
466+
d_fold, s_fold = d[inner1], s[inner1]
467+
mask_d0_s1 = (d_fold == 0) & (s_fold == 1)
468+
mask_d1_s1 = (d_fold == 1) & (s_fold == 1)
469+
470+
inner1_d0_s1.append(inner1[mask_d0_s1])
471+
inner1_d1_s1.append(inner1[mask_d1_s1])
472+
return inner1_d0_s1, inner1_d1_s1
503473

504-
params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
474+
if self._score == "nonignorable":
505475

506-
tune_res = {"g_d0_tune": g_d0_tune_res, "g_d1_tune": g_d1_tune_res, "pi_tune": pi_tune_res, "m_tune": m_tune_res}
476+
train_inds = [train_index for (train_index, _) in smpls]
477+
478+
# inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
479+
inner_train0_inds, inner_train1_inds = split_inner_folds(train_inds, d, s)
480+
# split inner1 by (d,s) to build g-models for treated/control
481+
inner_train1_d0_s1, inner_train1_d1_s1 = filter_by_ds(inner_train1_inds, d, s)
482+
483+
# Tune ml_pi
484+
x_d_z = np.column_stack((x, d, z))
485+
pi_tune_res = []
486+
pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
487+
for inner0, inner1 in zip(inner_train0_inds, inner_train1_inds):
488+
res = tune_learner(s, x_d_z, [inner0], "ml_pi")
489+
best_params = res[0].best_params_
490+
491+
# Fit tuned model and predict
492+
ml_pi_temp = clone(self._learner["ml_pi"])
493+
ml_pi_temp.set_params(**best_params)
494+
ml_pi_temp.fit(x_d_z[inner0], s[inner0])
495+
pi_hat_full[inner1] = _predict_zero_one_propensity(ml_pi_temp, x_d_z)[inner1]
496+
pi_tune_res.append(res[0])
497+
498+
# Tune ml_m with x + pi-hats
499+
x_pi = np.column_stack([x, pi_hat_full.reshape(-1, 1)])
500+
m_tune_res = tune_learner(d, x_pi, inner_train1_inds, "ml_m")
501+
502+
# Tune ml_g for d=0 and d=1
503+
x_pi_d = np.column_stack([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)])
504+
g_d0_tune_res = tune_learner(y, x_pi_d, inner_train1_d0_s1, "ml_g")
505+
g_d1_tune_res = tune_learner(y, x_pi_d, inner_train1_d1_s1, "ml_g")
507506

508-
res = {"params": params, "tune_res": tune_res}
507+
else:
508+
# nuisance training sets conditional on d
509+
_, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
510+
train_inds = [train_index for (train_index, _) in smpls]
511+
train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
512+
train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]
513+
514+
# Tune ml_g for d=0 and d=1
515+
g_d0_tune_res = tune_learner(y, x, train_inds_d0_s1, "ml_g")
516+
g_d1_tune_res = tune_learner(y, x, train_inds_d1_s1, "ml_g")
517+
518+
# Tune ml_pi and ml_m
519+
x_d = np.column_stack((x, d))
520+
pi_tune_res = tune_learner(s, x_d, train_inds, "ml_pi")
521+
m_tune_res = tune_learner(d, x, train_inds, "ml_m")
522+
523+
# Collect results
524+
params = {
525+
"ml_g_d0": [res.best_params_ for res in g_d0_tune_res],
526+
"ml_g_d1": [res.best_params_ for res in g_d1_tune_res],
527+
"ml_pi": [res.best_params_ for res in pi_tune_res],
528+
"ml_m": [res.best_params_ for res in m_tune_res],
529+
}
530+
531+
tune_res = {
532+
"g_d0_tune": g_d0_tune_res,
533+
"g_d1_tune": g_d1_tune_res,
534+
"pi_tune": pi_tune_res,
535+
"m_tune": m_tune_res,
536+
}
509537

510-
return res
538+
return {"params": params, "tune_res": tune_res}
511539

512540
def _sensitivity_element_est(self, preds):
513541
pass

doubleml/irm/tests/_utils_ssm_manual.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,17 +273,14 @@ def var_selection(theta, psi_a, psi_b, n_obs):
273273
return var
274274

275275

276-
def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m):
276+
def tune_nuisance_ssm_mar(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m):
277277
d0_s1 = np.intersect1d(np.where(d == 0)[0], np.where(s == 1)[0])
278278
d1_s1 = np.intersect1d(np.where(d == 1)[0], np.where(s == 1)[0])
279279

280280
g0_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=d0_s1)
281281
g1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=d1_s1)
282282

283-
if score == "nonignorable":
284-
dx = np.column_stack((x, d, z))
285-
else:
286-
dx = np.column_stack((x, d))
283+
dx = np.column_stack((x, d))
287284

288285
pi_tune_res = tune_grid_search(s, dx, ml_pi, smpls, param_grid_pi, n_folds_tune)
289286

@@ -295,3 +292,44 @@ def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tu
295292
m_best_params = [xx.best_params_ for xx in m_tune_res]
296293

297294
return g0_best_params, g1_best_params, pi_best_params, m_best_params
295+
296+
297+
def tune_nuisance_ssm_nonignorable(
298+
y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m
299+
):
300+
301+
train_inds = [tr for (tr, _) in smpls]
302+
303+
inner0_list, inner1_list = [], []
304+
for tr in train_inds:
305+
i0, i1 = train_test_split(tr, test_size=0.5, stratify=d[tr] + 2 * s[tr], random_state=42)
306+
inner0_list.append(i0)
307+
inner1_list.append(i1)
308+
309+
X_dz = np.c_[x, d.reshape(-1, 1), z.reshape(-1, 1)]
310+
pi_tune_res = tune_grid_search(s, X_dz, ml_pi, [(i0, np.array([])) for i0 in inner0_list], param_grid_pi, n_folds_tune)
311+
pi_best_params = [gs.best_params_ for gs in pi_tune_res]
312+
313+
pi_hat_full = np.full_like(s, np.nan, dtype=float)
314+
for i0, i1, gs in zip(inner0_list, inner1_list, pi_tune_res):
315+
ml_pi_temp = clone(ml_pi)
316+
ml_pi_temp.set_params(**gs.best_params_)
317+
ml_pi_temp.fit(X_dz[i0], s[i0])
318+
ph = _predict_zero_one_propensity(ml_pi_temp, X_dz)
319+
pi_hat_full[i1] = ph[i1]
320+
321+
X_pi = np.c_[x, pi_hat_full]
322+
m_tune_res = tune_grid_search(d, X_pi, ml_m, [(i1, np.array([])) for i1 in inner1_list], param_grid_m, n_folds_tune)
323+
m_best_params = [gs.best_params_ for gs in m_tune_res]
324+
325+
X_pi_d = np.c_[x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)]
326+
inner1_d0_s1 = [i1[(d[i1] == 0) & (s[i1] == 1)] for i1 in inner1_list]
327+
inner1_d1_s1 = [i1[(d[i1] == 1) & (s[i1] == 1)] for i1 in inner1_list]
328+
329+
g0_tune_res = tune_grid_search(y, X_pi_d, ml_g, [(idx, np.array([])) for idx in inner1_d0_s1], param_grid_g, n_folds_tune)
330+
g1_tune_res = tune_grid_search(y, X_pi_d, ml_g, [(idx, np.array([])) for idx in inner1_d1_s1], param_grid_g, n_folds_tune)
331+
332+
g0_best_params = [gs.best_params_ for gs in g0_tune_res]
333+
g1_best_params = [gs.best_params_ for gs in g1_tune_res]
334+
335+
return g0_best_params, g1_best_params, pi_best_params, m_best_params

doubleml/irm/tests/test_ssm_tune.py

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import doubleml as dml
1010

1111
from ...tests._utils import draw_smpls
12-
from ._utils_ssm_manual import fit_selection, tune_nuisance_ssm
12+
from ._utils_ssm_manual import fit_selection, tune_nuisance_ssm_mar, tune_nuisance_ssm_nonignorable
1313

1414

1515
@pytest.fixture(scope="module", params=[RandomForestRegressor(random_state=42)])
@@ -115,41 +115,73 @@ def dml_ssm_fixture(
115115
np.random.seed(42)
116116
smpls = all_smpls[0]
117117
if tune_on_folds:
118-
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm(
119-
y,
120-
x,
121-
d,
122-
z,
123-
s,
124-
clone(learner_g),
125-
clone(learner_m),
126-
clone(learner_m),
127-
smpls,
128-
score,
129-
n_folds_tune,
130-
par_grid["ml_g"],
131-
par_grid["ml_pi"],
132-
par_grid["ml_m"],
133-
)
118+
if score == "missing-at-random":
119+
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_mar(
120+
y,
121+
x,
122+
d,
123+
z,
124+
s,
125+
clone(learner_g),
126+
clone(learner_m),
127+
clone(learner_m),
128+
smpls,
129+
n_folds_tune,
130+
par_grid["ml_g"],
131+
par_grid["ml_pi"],
132+
par_grid["ml_m"],
133+
)
134+
elif score == "nonignorable":
135+
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_nonignorable(
136+
y,
137+
x,
138+
d,
139+
z,
140+
s,
141+
clone(learner_g),
142+
clone(learner_m),
143+
clone(learner_m),
144+
smpls,
145+
n_folds_tune,
146+
par_grid["ml_g"],
147+
par_grid["ml_pi"],
148+
par_grid["ml_m"],
149+
)
134150

135151
else:
136152
xx = [(np.arange(len(y)), np.array([]))]
137-
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm(
138-
y,
139-
x,
140-
d,
141-
z,
142-
s,
143-
clone(learner_g),
144-
clone(learner_m),
145-
clone(learner_m),
146-
xx,
147-
score,
148-
n_folds_tune,
149-
par_grid["ml_g"],
150-
par_grid["ml_pi"],
151-
par_grid["ml_m"],
152-
)
153+
if score == "missing-at-random":
154+
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_mar(
155+
y,
156+
x,
157+
d,
158+
z,
159+
s,
160+
clone(learner_g),
161+
clone(learner_m),
162+
clone(learner_m),
163+
xx,
164+
n_folds_tune,
165+
par_grid["ml_g"],
166+
par_grid["ml_pi"],
167+
par_grid["ml_m"],
168+
)
169+
elif score == "nonignorable":
170+
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_nonignorable(
171+
y,
172+
x,
173+
d,
174+
z,
175+
s,
176+
clone(learner_g),
177+
clone(learner_m),
178+
clone(learner_m),
179+
xx,
180+
n_folds_tune,
181+
par_grid["ml_g"],
182+
par_grid["ml_pi"],
183+
par_grid["ml_m"],
184+
)
153185

154186
g0_best_params = g0_best_params * n_folds
155187
g1_best_params = g1_best_params * n_folds

0 commit comments

Comments
 (0)