Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 95 additions & 67 deletions doubleml/irm/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,86 +428,114 @@ def _nuisance_tuning(
):
x, y = check_X_y(self._dml_data.x, self._dml_data.y, force_all_finite=False)
x, d = check_X_y(x, self._dml_data.d, force_all_finite=False)
# time indicator is used for selection (selection not available in DoubleMLData yet)
x, s = check_X_y(x, self._dml_data.s, force_all_finite=False)

if self._score == "nonignorable":
z, _ = check_X_y(self._dml_data.z, y, force_all_finite=False)
dx = np.column_stack((x, d, z))
else:
dx = np.column_stack((x, d))

if scoring_methods is None:
scoring_methods = {"ml_g": None, "ml_pi": None, "ml_m": None}

# nuisance training sets conditional on d
_, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
train_inds = [train_index for (train_index, _) in smpls]
train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]

# hyperparameter tuning for ML
g_d0_tune_res = _dml_tune(
y,
x,
train_inds_d0_s1,
self._learner["ml_g"],
param_grids["ml_g"],
scoring_methods["ml_g"],
n_folds_tune,
n_jobs_cv,
search_mode,
n_iter_randomized_search,
)
g_d1_tune_res = _dml_tune(
y,
x,
train_inds_d1_s1,
self._learner["ml_g"],
param_grids["ml_g"],
scoring_methods["ml_g"],
n_folds_tune,
n_jobs_cv,
search_mode,
n_iter_randomized_search,
)
pi_tune_res = _dml_tune(
s,
dx,
train_inds,
self._learner["ml_pi"],
param_grids["ml_pi"],
scoring_methods["ml_pi"],
n_folds_tune,
n_jobs_cv,
search_mode,
n_iter_randomized_search,
)
m_tune_res = _dml_tune(
d,
x,
train_inds,
self._learner["ml_m"],
param_grids["ml_m"],
scoring_methods["ml_m"],
n_folds_tune,
n_jobs_cv,
search_mode,
n_iter_randomized_search,
)
# Nested helper functions
def tune_learner(target, features, train_indices, learner_key):
    """Tune one nuisance learner via ``_dml_tune`` using the shared settings.

    ``learner_key`` picks the estimator from ``self._learner`` as well as the
    matching entries of ``param_grids`` and ``scoring_methods``. The remaining
    tuning settings are not parameters of this helper; they are read from the
    enclosing scope.

    Returns whatever ``_dml_tune`` returns for the given target, features and
    training index sets.
    """
    # Settings that are identical for every learner tuned in this call.
    shared_settings = (n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
    return _dml_tune(
        target,
        features,
        train_indices,
        self._learner[learner_key],
        param_grids[learner_key],
        scoring_methods[learner_key],
        *shared_settings,
    )

g_d0_best_params = [xx.best_params_ for xx in g_d0_tune_res]
g_d1_best_params = [xx.best_params_ for xx in g_d1_tune_res]
pi_best_params = [xx.best_params_ for xx in pi_tune_res]
m_best_params = [xx.best_params_ for xx in m_tune_res]
def split_inner_folds(train_inds, d, s, random_state=42):
    """Split every training index set into two equally sized, stratified halves.

    Parameters
    ----------
    train_inds : iterable of index arrays
        One index array per fold.
    d, s : arrays
        Indexed with each fold's indices; ``d + 2 * s`` is used as the
        stratification label (one distinct label per (d, s) combination when
        both are 0/1-valued -- presumably the case here, confirm with caller).
    random_state : int, default 42
        Passed to ``train_test_split`` for every fold.

    Returns
    -------
    tuple of two lists
        The first halves and the second halves, in fold order.
    """
    halves = [
        train_test_split(
            fold,
            test_size=0.5,
            stratify=d[fold] + 2 * s[fold],
            random_state=random_state,
        )
        for fold in train_inds
    ]
    return [first for first, _ in halves], [second for _, second in halves]

def filter_by_ds(inner_train1_inds, d, s):
    """Restrict each index set to observations with ``s == 1``, split by ``d``.

    Parameters
    ----------
    inner_train1_inds : iterable of index arrays
        One index array per fold.
    d, s : arrays
        Indexed with each fold's indices and compared against 0 / 1.

    Returns
    -------
    tuple of two lists
        Per fold, the indices with ``d == 0`` and ``s == 1``, and the indices
        with ``d == 1`` and ``s == 1``.
    """
    control_selected, treated_selected = [], []
    for fold in inner_train1_inds:
        d_in_fold = d[fold]
        is_selected = s[fold] == 1
        # Boolean masks are relative to ``fold``, so they index ``fold`` itself.
        control_selected.append(fold[(d_in_fold == 0) & is_selected])
        treated_selected.append(fold[(d_in_fold == 1) & is_selected])
    return control_selected, treated_selected

params = {"ml_g_d0": g_d0_best_params, "ml_g_d1": g_d1_best_params, "ml_pi": pi_best_params, "ml_m": m_best_params}
if self._score == "nonignorable":

tune_res = {"g_d0_tune": g_d0_tune_res, "g_d1_tune": g_d1_tune_res, "pi_tune": pi_tune_res, "m_tune": m_tune_res}
train_inds = [train_index for (train_index, _) in smpls]

# inner folds: split train set into two halves (pi-tuning vs. m/g-tuning)
inner_train0_inds, inner_train1_inds = split_inner_folds(train_inds, d, s)
# split inner1 by (d,s) to build g-models for treated/control
inner_train1_d0_s1, inner_train1_d1_s1 = filter_by_ds(inner_train1_inds, d, s)

# Tune ml_pi
x_d_z = np.column_stack((x, d, z))
pi_tune_res = []
pi_hat_full = np.full(shape=s.shape, fill_value=np.nan)
for inner0, inner1 in zip(inner_train0_inds, inner_train1_inds):
res = tune_learner(s, x_d_z, [inner0], "ml_pi")
best_params = res[0].best_params_

# Fit tuned model and predict
ml_pi_temp = clone(self._learner["ml_pi"])
ml_pi_temp.set_params(**best_params)
ml_pi_temp.fit(x_d_z[inner0], s[inner0])
pi_hat_full[inner1] = _predict_zero_one_propensity(ml_pi_temp, x_d_z)[inner1]
pi_tune_res.append(res[0])

# Tune ml_m with x + pi-hats
x_pi = np.column_stack([x, pi_hat_full.reshape(-1, 1)])
m_tune_res = tune_learner(d, x_pi, inner_train1_inds, "ml_m")

# Tune ml_g for d=0 and d=1
x_pi_d = np.column_stack([x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)])
g_d0_tune_res = tune_learner(y, x_pi_d, inner_train1_d0_s1, "ml_g")
g_d1_tune_res = tune_learner(y, x_pi_d, inner_train1_d1_s1, "ml_g")

res = {"params": params, "tune_res": tune_res}
else:
# nuisance training sets conditional on d
_, smpls_d0_s1, _, smpls_d1_s1 = _get_cond_smpls_2d(smpls, d, s)
train_inds = [train_index for (train_index, _) in smpls]
train_inds_d0_s1 = [train_index for (train_index, _) in smpls_d0_s1]
train_inds_d1_s1 = [train_index for (train_index, _) in smpls_d1_s1]

# Tune ml_g for d=0 and d=1
g_d0_tune_res = tune_learner(y, x, train_inds_d0_s1, "ml_g")
g_d1_tune_res = tune_learner(y, x, train_inds_d1_s1, "ml_g")

# Tune ml_pi and ml_m
x_d = np.column_stack((x, d))
pi_tune_res = tune_learner(s, x_d, train_inds, "ml_pi")
m_tune_res = tune_learner(d, x, train_inds, "ml_m")

# Collect results
params = {
"ml_g_d0": [res.best_params_ for res in g_d0_tune_res],
"ml_g_d1": [res.best_params_ for res in g_d1_tune_res],
"ml_pi": [res.best_params_ for res in pi_tune_res],
"ml_m": [res.best_params_ for res in m_tune_res],
}

tune_res = {
"g_d0_tune": g_d0_tune_res,
"g_d1_tune": g_d1_tune_res,
"pi_tune": pi_tune_res,
"m_tune": m_tune_res,
}

return res
return {"params": params, "tune_res": tune_res}

def _sensitivity_element_est(self, preds):
pass
48 changes: 43 additions & 5 deletions doubleml/irm/tests/_utils_ssm_manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,14 @@ def var_selection(theta, psi_a, psi_b, n_obs):
return var


def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m):
def tune_nuisance_ssm_mar(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m):
d0_s1 = np.intersect1d(np.where(d == 0)[0], np.where(s == 1)[0])
d1_s1 = np.intersect1d(np.where(d == 1)[0], np.where(s == 1)[0])

g0_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=d0_s1)
g1_tune_res = tune_grid_search(y, x, ml_g, smpls, param_grid_g, n_folds_tune, train_cond=d1_s1)

if score == "nonignorable":
dx = np.column_stack((x, d, z))
else:
dx = np.column_stack((x, d))
dx = np.column_stack((x, d))

pi_tune_res = tune_grid_search(s, dx, ml_pi, smpls, param_grid_pi, n_folds_tune)

Expand All @@ -295,3 +292,44 @@ def tune_nuisance_ssm(y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, score, n_folds_tu
m_best_params = [xx.best_params_ for xx in m_tune_res]

return g0_best_params, g1_best_params, pi_best_params, m_best_params


def tune_nuisance_ssm_nonignorable(
    y, x, d, z, s, ml_g, ml_pi, ml_m, smpls, n_folds_tune, param_grid_g, param_grid_pi, param_grid_m
):
    """Manually tune the nuisance learners for the "nonignorable" score.

    Steps, per training index set in ``smpls`` (the test indices are ignored):

    1. Split the training indices into two halves, stratified on
       ``d + 2 * s`` with ``random_state=42``.
    2. Tune ``ml_pi`` on the first halves with features ``[x, d, z]``, refit a
       clone with the best parameters and store its predictions for the
       second halves.
    3. Tune ``ml_m`` on the second halves with features ``[x, pi_hat]``.
    4. Tune ``ml_g`` on the second-half observations with ``s == 1``,
       separately for ``d == 0`` and ``d == 1``, with features
       ``[x, d, pi_hat]``.

    Returns
    -------
    tuple of four lists
        ``best_params_`` per fold for g (d == 0), g (d == 1), pi and m.
    """

    def _as_train_only_smpls(index_sets):
        # tune_grid_search is handed (train, test) pairs; the test part is left empty.
        return [(idx, np.array([])) for idx in index_sets]

    # Step 1: stratified 50/50 split of every training fold.
    first_halves, second_halves = [], []
    for train_idx, _ in smpls:
        strata = d[train_idx] + 2 * s[train_idx]
        half_a, half_b = train_test_split(train_idx, test_size=0.5, stratify=strata, random_state=42)
        first_halves.append(half_a)
        second_halves.append(half_b)

    # Step 2: tune the selection model on the first halves ...
    feats_x_d_z = np.c_[x, d.reshape(-1, 1), z.reshape(-1, 1)]
    pi_tune_res = tune_grid_search(
        s, feats_x_d_z, ml_pi, _as_train_only_smpls(first_halves), param_grid_pi, n_folds_tune
    )

    # ... and predict on the second halves with the tuned model.
    # NOTE(review): one shared array is filled fold by fold, so if second halves
    # of different folds overlap, later folds overwrite earlier predictions.
    pi_hat_full = np.full_like(s, np.nan, dtype=float)
    for half_a, half_b, search in zip(first_halves, second_halves, pi_tune_res):
        tuned_pi = clone(ml_pi)
        tuned_pi.set_params(**search.best_params_)
        tuned_pi.fit(feats_x_d_z[half_a], s[half_a])
        predictions = _predict_zero_one_propensity(tuned_pi, feats_x_d_z)
        pi_hat_full[half_b] = predictions[half_b]

    # Step 3: treatment model on the second halves, with pi-hat as extra feature.
    feats_x_pi = np.c_[x, pi_hat_full]
    m_tune_res = tune_grid_search(
        d, feats_x_pi, ml_m, _as_train_only_smpls(second_halves), param_grid_m, n_folds_tune
    )

    # Step 4: outcome models on the selected (s == 1) part of the second halves.
    feats_x_d_pi = np.c_[x, d.reshape(-1, 1), pi_hat_full.reshape(-1, 1)]
    second_d0_s1 = [half_b[(d[half_b] == 0) & (s[half_b] == 1)] for half_b in second_halves]
    second_d1_s1 = [half_b[(d[half_b] == 1) & (s[half_b] == 1)] for half_b in second_halves]

    g0_tune_res = tune_grid_search(
        y, feats_x_d_pi, ml_g, _as_train_only_smpls(second_d0_s1), param_grid_g, n_folds_tune
    )
    g1_tune_res = tune_grid_search(
        y, feats_x_d_pi, ml_g, _as_train_only_smpls(second_d1_s1), param_grid_g, n_folds_tune
    )

    def _best_params(searches):
        return [search.best_params_ for search in searches]

    return (
        _best_params(g0_tune_res),
        _best_params(g1_tune_res),
        _best_params(pi_tune_res),
        _best_params(m_tune_res),
    )
98 changes: 65 additions & 33 deletions doubleml/irm/tests/test_ssm_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import doubleml as dml

from ...tests._utils import draw_smpls
from ._utils_ssm_manual import fit_selection, tune_nuisance_ssm
from ._utils_ssm_manual import fit_selection, tune_nuisance_ssm_mar, tune_nuisance_ssm_nonignorable


@pytest.fixture(scope="module", params=[RandomForestRegressor(random_state=42)])
Expand Down Expand Up @@ -115,41 +115,73 @@ def dml_ssm_fixture(
np.random.seed(42)
smpls = all_smpls[0]
if tune_on_folds:
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm(
y,
x,
d,
z,
s,
clone(learner_g),
clone(learner_m),
clone(learner_m),
smpls,
score,
n_folds_tune,
par_grid["ml_g"],
par_grid["ml_pi"],
par_grid["ml_m"],
)
if score == "missing-at-random":
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_mar(
y,
x,
d,
z,
s,
clone(learner_g),
clone(learner_m),
clone(learner_m),
smpls,
n_folds_tune,
par_grid["ml_g"],
par_grid["ml_pi"],
par_grid["ml_m"],
)
elif score == "nonignorable":
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_nonignorable(
y,
x,
d,
z,
s,
clone(learner_g),
clone(learner_m),
clone(learner_m),
smpls,
n_folds_tune,
par_grid["ml_g"],
par_grid["ml_pi"],
par_grid["ml_m"],
)

else:
xx = [(np.arange(len(y)), np.array([]))]
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm(
y,
x,
d,
z,
s,
clone(learner_g),
clone(learner_m),
clone(learner_m),
xx,
score,
n_folds_tune,
par_grid["ml_g"],
par_grid["ml_pi"],
par_grid["ml_m"],
)
if score == "missing-at-random":
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_mar(
y,
x,
d,
z,
s,
clone(learner_g),
clone(learner_m),
clone(learner_m),
xx,
n_folds_tune,
par_grid["ml_g"],
par_grid["ml_pi"],
par_grid["ml_m"],
)
elif score == "nonignorable":
g0_best_params, g1_best_params, pi_best_params, m_best_params = tune_nuisance_ssm_nonignorable(
y,
x,
d,
z,
s,
clone(learner_g),
clone(learner_m),
clone(learner_m),
xx,
n_folds_tune,
par_grid["ml_g"],
par_grid["ml_pi"],
par_grid["ml_m"],
)

g0_best_params = g0_best_params * n_folds
g1_best_params = g1_best_params * n_folds
Expand Down
Loading