Skip to content

Commit f2b2733

Browse files
committed
add universal option to gt_combinations
1 parent ea67dc7 commit f2b2733

File tree

4 files changed

+52
-16
lines changed

4 files changed

+52
-16
lines changed

doubleml/did/tests/test_did_multi_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_exception_learners():
100100

101101
@pytest.mark.ci
102102
def test_exception_gt_combinations():
103-
msg = r"gt_combinations must be one of \['standard', 'all'\]. test was passed."
103+
msg = r"gt_combinations must be one of \['standard', 'all', 'universal'\]. test was passed."
104104
with pytest.raises(ValueError, match=msg):
105105
invalid_arguments = {"gt_combinations": "test"}
106106
_ = dml.did.DoubleMLDIDMulti(**(valid_arguments | invalid_arguments))

doubleml/did/tests/test_did_multi_plot.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@ def n_rep(request):
2323
return request.param
2424

2525

26+
@pytest.fixture(scope="module", params=["standard", "all", "universal"])
27+
def gt_comb(request):
28+
return request.param
29+
30+
2631
@pytest.fixture(scope="module")
27-
def doubleml_did_fixture(did_score, panel, n_rep):
32+
def doubleml_did_fixture(did_score, panel, n_rep, gt_comb):
2833
n_obs = 1000
2934
dgp = 5 # has to be experimental (for experimental score to be valid)
3035
np.random.seed(42)
@@ -35,7 +40,7 @@ def doubleml_did_fixture(did_score, panel, n_rep):
3540
"obj_dml_data": dml_data,
3641
"ml_g": LinearRegression(),
3742
"ml_m": LogisticRegression(),
38-
"gt_combinations": "all",
43+
"gt_combinations": gt_comb,
3944
"score": did_score,
4045
"panel": panel,
4146
"n_rep": n_rep,

doubleml/did/utils/_did_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,6 @@ def _check_gt_combination(gt_combination, g_values, t_values, never_treated_valu
8484
if t_value_pre == t_value_eval:
8585
raise ValueError(f"The pre-treatment and evaluation period must be different. Got {t_value_pre} for both.")
8686

87-
if t_value_pre > t_value_eval:
88-
raise ValueError(
89-
"The pre-treatment period must be before the evaluation period. "
90-
f"Got t_value_pre {t_value_pre} and t_value_eval {t_value_eval}."
91-
)
92-
9387
# get t_value equal to g_value and adjust for anticipation periods
9488
maximal_t_pre = t_values[max(np.where(t_values == g_value)[0] - anticipation_periods, 0)]
9589
if t_value_pre >= maximal_t_pre:
@@ -128,7 +122,7 @@ def _construct_gt_combinations(setting, g_values, t_values, never_treated_value,
128122
Returns:
129123
list: List of (g_val, t_pre, t_eval) tuples
130124
"""
131-
valid_settings = ["standard", "all"]
125+
valid_settings = ["standard", "all", "universal"]
132126
if setting not in valid_settings:
133127
raise ValueError(f"gt_combinations must be one of {valid_settings}. {setting} was passed.")
134128

@@ -163,6 +157,15 @@ def _construct_gt_combinations(setting, g_values, t_values, never_treated_value,
163157
for t_pre in valid_t_pre_values:
164158
gt_combinations.append((g_val, t_pre, t_eval))
165159

160+
if setting == "universal":
161+
for g_val in treatment_groups:
162+
t_values_before_g = t_values[t_values < g_val]
163+
if len(t_values_before_g) > anticipation_periods:
164+
base_period = g_val - anticipation_periods - 1
165+
for t_eval in t_values:
166+
if t_eval != base_period:
167+
gt_combinations.append((g_val, base_period, t_eval))
168+
166169
if len(gt_combinations) == 0:
167170
raise ValueError(
168171
"No valid group-time combinations found. "

doubleml/did/utils/tests/test_did_utils.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,6 @@ def test_check_gt_combination():
9999
ValueError,
100100
"The pre-treatment and evaluation period must be different. Got 1 for both.",
101101
),
102-
(
103-
{"gt_combination": (1, 1, 0)},
104-
ValueError,
105-
"The pre-treatment period must be before the evaluation period. Got t_value_pre 1 and t_value_eval 0.",
106-
),
107102
(
108103
{"gt_combination": (-1, 0, 1)},
109104
ValueError,
@@ -171,7 +166,7 @@ def test_input_check_gt_values():
171166

172167
@pytest.mark.ci
173168
def test_construct_gt_combinations():
174-
msg = r"gt_combinations must be one of \['standard', 'all'\]. test was passed."
169+
msg = r"gt_combinations must be one of \['standard', 'all', 'universal'\]. test was passed."
175170
with pytest.raises(ValueError, match=msg):
176171
_construct_gt_combinations(
177172
setting="test",
@@ -256,6 +251,24 @@ def test_construct_gt_combinations():
256251
]
257252
assert all_combinations == expected_all
258253

254+
# Test universal setting
255+
universal_combinations = _construct_gt_combinations(
256+
setting="universal",
257+
g_values=np.array([2, 3]),
258+
t_values=np.array([0, 1, 2, 3]),
259+
never_treated_value=np.inf,
260+
anticipation_periods=0,
261+
)
262+
expected_universal = [
263+
(2, 1, 0), # g=2, pre=1, eval=0
264+
(2, 1, 2), # g=2, pre=1, eval=2
265+
(2, 1, 3), # g=2, pre=1, eval=3
266+
(3, 2, 0), # g=3, pre=2, eval=0
267+
(3, 2, 1), # g=3, pre=2, eval=1
268+
(3, 2, 3), # g=3, pre=2, eval=3
269+
]
270+
assert universal_combinations == expected_universal
271+
259272
# Test standard setting with anticipation periods
260273
standard_combinations_anticipation = _construct_gt_combinations(
261274
setting="standard",
@@ -282,6 +295,21 @@ def test_construct_gt_combinations():
282295
]
283296
assert all_combinations_anticipation == expected_all_anticipation
284297

298+
# Test universal setting with anticipation periods
299+
universal_combinations_anticipation = _construct_gt_combinations(
300+
setting="universal",
301+
g_values=np.array([2, 3]),
302+
t_values=np.array([0, 1, 2, 3]),
303+
never_treated_value=np.inf,
304+
anticipation_periods=2,
305+
)
306+
expected_universal_anticipation = [
307+
(3, 0, 1), # g=3, pre=0, eval=1 with anticipation 2
308+
(3, 0, 2),
309+
(3, 0, 3),
310+
]
311+
assert universal_combinations_anticipation == expected_universal_anticipation
312+
285313

286314
@pytest.mark.ci
287315
def test_construct_gt_index():

0 commit comments

Comments
 (0)