Skip to content

Commit bbda382

Browse files
committed
simplify _construct_gt_combinations
1 parent f2b2733 commit bbda382

File tree

1 file changed

+31
-33
lines changed

1 file changed

+31
-33
lines changed

doubleml/did/utils/_did_utils.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,11 @@ def _construct_gt_combinations(setting, g_values, t_values, never_treated_value,
115115
"""Construct treatment-time combinations for difference-in-differences analysis.
116116
117117
Parameters:
118-
setting (str): Strategy for constructing combinations ('standard' only)
118+
setting (str): Strategy for constructing combinations. One of 'standard', 'all', 'universal'.
119119
g_values (array): Treatment group values, must be sorted
120120
t_values (array): Time period values, must be sorted
121+
never_treated_value (int, float or pd.NaT): Value indicating never-treated units.
122+
anticipation_periods (int): Number of anticipation periods.
121123
122124
Returns:
123125
list: List of (g_val, t_pre, t_eval) tuples
@@ -133,38 +135,34 @@ def _construct_gt_combinations(setting, g_values, t_values, never_treated_value,
133135
raise ValueError("t_values must be sorted in ascending order.")
134136

135137
gt_combinations = []
136-
if setting == "standard":
137-
for g_val in treatment_groups:
138-
t_values_before_g = t_values[t_values < g_val]
139-
if len(t_values_before_g) > anticipation_periods:
140-
first_eval_index = anticipation_periods + 1 # first relevant evaluation period index
141-
t_before_g = t_values_before_g[-first_eval_index]
142-
143-
# collect all evaluation periods
144-
for i_t_eval, t_eval in enumerate(t_values[first_eval_index:]):
145-
t_previous = t_values[i_t_eval] # refers to t-anticipation_periods-1
146-
t_pre = min(t_previous, t_before_g) # if t_previous larger than g_val, use t_before_g
147-
gt_combinations.append((g_val, t_pre, t_eval))
148-
149-
if setting == "all":
150-
for g_val in treatment_groups:
151-
t_values_before_g = t_values[t_values < g_val]
152-
if len(t_values_before_g) > anticipation_periods:
153-
first_eval_index = anticipation_periods + 1 # first relevant evaluation period index
154-
for t_eval in t_values[first_eval_index:]:
155-
# all t-values before g_val - anticipation_periods
156-
valid_t_pre_values = t_values[t_values <= min(g_val, t_eval)][:-first_eval_index]
157-
for t_pre in valid_t_pre_values:
158-
gt_combinations.append((g_val, t_pre, t_eval))
159-
160-
if setting == "universal":
161-
for g_val in treatment_groups:
162-
t_values_before_g = t_values[t_values < g_val]
163-
if len(t_values_before_g) > anticipation_periods:
164-
base_period = g_val - anticipation_periods - 1
165-
for t_eval in t_values:
166-
if t_eval != base_period:
167-
gt_combinations.append((g_val, base_period, t_eval))
138+
for g_val in treatment_groups:
139+
t_values_before_g = t_values[t_values < g_val]
140+
if len(t_values_before_g) <= anticipation_periods:
141+
continue
142+
first_eval_index = anticipation_periods + 1 # first relevant evaluation period index
143+
144+
if setting == "standard":
145+
t_before_g = t_values_before_g[-first_eval_index]
146+
combinations = [
147+
(g_val, min(t_values[i_t_eval], t_before_g), t_eval)
148+
for i_t_eval, t_eval in enumerate(t_values[first_eval_index:])
149+
]
150+
gt_combinations.extend(combinations)
151+
152+
elif setting == "all":
153+
combinations = [
154+
(g_val, t_pre, t_eval)
155+
for t_eval in t_values[first_eval_index:]
156+
for t_pre in t_values[t_values <= min(g_val, t_eval)][:-first_eval_index]
157+
]
158+
gt_combinations.extend(combinations)
159+
160+
elif setting == "universal":
161+
# The base period is the last period before treatment, accounting for anticipation.
162+
# `g_val - anticipation_periods - 1` is not robust for non-integer or non-consecutive periods.
163+
base_period = t_values_before_g[-first_eval_index]
164+
combinations = [(g_val, base_period, t_eval) for t_eval in t_values if t_eval != base_period]
165+
gt_combinations.extend(combinations)
168166

169167
if len(gt_combinations) == 0:
170168
raise ValueError(

0 commit comments

Comments
 (0)