@@ -1030,6 +1030,21 @@ def fit(
10301030 T_g = T_g_arr ,
10311031 L_max = L_max ,
10321032 )
1033+ # Surface A11 warnings from multi-horizon computation
1034+ mh_a11 = multi_horizon_dids .pop ("_a11_warnings" , None )
1035+ if mh_a11 :
1036+ warnings .warn (
1037+ f"Multi-horizon control-availability violations in "
1038+ f"{ len (mh_a11 )} (group, horizon) pair(s): affected "
1039+ f"DID_{{g,l}} values are zeroed but their switcher "
1040+ f"counts are retained in N_l (matching the A11 "
1041+ f"zero-retention convention). Examples: "
1042+ + ", " .join (mh_a11 [:3 ])
1043+ + (f" (and { len (mh_a11 ) - 3 } more)" if len (mh_a11 ) > 3 else "" ),
1044+ UserWarning ,
1045+ stacklevel = 2 ,
1046+ )
1047+
10331048 multi_horizon_if = _compute_per_group_if_multi_horizon (
10341049 D_mat = D_mat ,
10351050 Y_mat = Y_mat ,
@@ -1051,7 +1066,10 @@ def fit(
10511066
10521067 multi_horizon_se = {}
10531068 multi_horizon_inference = {}
1054- for l_h in range (2 , L_max + 1 ):
1069+ # Compute inference for ALL horizons 1..L_max (including l=1)
1070+ # so the event_study_effects dict uses a consistent estimand
1071+ # (per-group DID_{g,l}) across all horizons.
1072+ for l_h in range (1 , L_max + 1 ):
10551073 U_l = multi_horizon_if [l_h ]
10561074 # Cohort IDs for this horizon: (D_{g,1}, F_g, S_g) triples
10571075 # are the same as Phase 1 (cohort identity depends on first
@@ -1315,7 +1333,12 @@ def fit(
13151333 [g not in singleton_baseline_set_b for g in all_groups ], dtype = bool
13161334 )
13171335 mh_boot_inputs = {}
1318- for l_h in range (2 , L_max + 1 ):
1336+ # Include ALL horizons 1..L_max so the sup-t critical
1337+ # value is calibrated over the same set that receives
1338+ # cband_conf_int. For l=1, use the per-group IF (not
1339+ # the Phase 1 per-period IF) so the bootstrap matches
1340+ # the event_study_effects[1] estimand.
1341+ for l_h in range (1 , L_max + 1 ):
13191342 h_data = multi_horizon_dids .get (l_h )
13201343 if h_data is None or h_data ["N_l" ] == 0 :
13211344 continue
@@ -1400,22 +1423,24 @@ def fit(
14001423 # ------------------------------------------------------------------
14011424 # Step 20: Build the results dataclass
14021425 # ------------------------------------------------------------------
1403- # event_study_effects: l=1 always mirrors the Phase 1 DID_M output.
1404- # When L_max >= 2, horizons 2..L_max are populated from the Phase 2
1405- # multi-horizon computation.
1406- event_study_effects : Dict [int , Dict [str , Any ]] = {
1407- 1 : {
1408- "effect" : overall_att ,
1409- "se" : overall_se ,
1410- "t_stat" : overall_t ,
1411- "p_value" : overall_p ,
1412- "conf_int" : overall_ci ,
1413- "n_obs" : N_S ,
1426+ # event_study_effects: when L_max is None, l=1 mirrors Phase 1
1427+ # DID_M (per-period path). When L_max >= 2, ALL horizons including
1428+ # l=1 use the per-group DID_{g,l} path for a consistent estimand.
1429+ if multi_horizon_inference is not None and 1 in multi_horizon_inference :
1430+ # Phase 2 mode: use per-group path for all horizons
1431+ event_study_effects : Dict [int , Dict [str , Any ]] = dict (multi_horizon_inference )
1432+ else :
1433+ # Phase 1 mode (L_max=None): l=1 from per-period path
1434+ event_study_effects = {
1435+ 1 : {
1436+ "effect" : overall_att ,
1437+ "se" : overall_se ,
1438+ "t_stat" : overall_t ,
1439+ "p_value" : overall_p ,
1440+ "conf_int" : overall_ci ,
1441+ "n_obs" : N_S ,
1442+ }
14141443 }
1415- }
1416- if multi_horizon_inference is not None :
1417- for l_h , inf_dict in multi_horizon_inference .items ():
1418- event_study_effects [l_h ] = inf_dict
14191444
14201445 # Phase 2: propagate bootstrap results to event_study_effects
14211446 if bootstrap_results is not None and bootstrap_results .event_study_ses :
@@ -1514,7 +1539,7 @@ def fit(
15141539 denom = n_data ["denominator" ]
15151540 eff = n_data ["effect" ]
15161541 # SE via delta method: SE(DID^n_l) = SE(DID_l) / delta^D_l
1517- se_did_l = multi_horizon_se .get (l_h , float ("nan" )) if l_h >= 2 else overall_se
1542+ se_did_l = multi_horizon_se .get (l_h , float ("nan" ))
15181543 se_norm = se_did_l / denom if np .isfinite (denom ) and denom > 0 else float ("nan" )
15191544 t_n , p_n , ci_n = safe_inference (eff , se_norm , alpha = self .alpha , df = None )
15201545 normalized_effects_out [l_h ] = {
@@ -2119,6 +2144,7 @@ def _compute_multi_horizon_dids(
21192144 baseline_f [int (d )] = first_switch_idx [mask ]
21202145
21212146 results : Dict [int , Dict [str , Any ]] = {}
2147+ a11_multi_warnings : List [str ] = []
21222148 N_1 = 0 # will be set at l=1 for switcher_fraction
21232149
21242150 for l in range (1 , L_max + 1 ): # noqa: E741
@@ -2187,6 +2213,10 @@ def _compute_multi_horizon_dids(
21872213 # matching the A11 zero-retention convention: the group's
21882214 # switcher count is still in N_l.
21892215 did_g_l [g ] = 0.0
2216+ a11_multi_warnings .append (
2217+ f"horizon { l } , group_idx { g } : "
2218+ f"no baseline-matched controls at outcome period"
2219+ )
21902220 continue
21912221
21922222 ctrl_changes = Y_mat [ctrl_pool , out_idx ] - Y_mat [ctrl_pool , ref_idx ]
@@ -2206,6 +2236,10 @@ def _compute_multi_horizon_dids(
22062236 "switcher_fraction" : N_l / N_1 if N_1 > 0 else float ("nan" ),
22072237 }
22082238
2239+ # Attach A11 warnings to the results for the caller to surface
2240+ if a11_multi_warnings :
2241+ results ["_a11_warnings" ] = a11_multi_warnings # type: ignore[assignment]
2242+
22092243 return results
22102244
22112245
@@ -2393,8 +2427,9 @@ def _compute_multi_horizon_placebos(
23932427 forward_idx = ref_idx + l
23942428 d_base = int (baselines [g ])
23952429
2396- # Switcher's backward outcome change
2397- switcher_change = Y_mat [g , backward_idx ] - Y_mat [g , ref_idx ]
2430+ # Switcher's backward outcome change: reference minus pre-period
2431+ # (matching Phase 1 convention: Y_{ref} - Y_{earlier})
2432+ switcher_change = Y_mat [g , ref_idx ] - Y_mat [g , backward_idx ]
23982433
23992434 # Control pool: same baseline, not switched by forward_idx
24002435 ctrl_indices = baseline_groups [d_base ]
@@ -2410,7 +2445,7 @@ def _compute_multi_horizon_placebos(
24102445 pl_g_l [g ] = 0.0
24112446 continue
24122447
2413- ctrl_changes = Y_mat [ctrl_pool , backward_idx ] - Y_mat [ctrl_pool , ref_idx ]
2448+ ctrl_changes = Y_mat [ctrl_pool , ref_idx ] - Y_mat [ctrl_pool , backward_idx ]
24142449 ctrl_avg = float (ctrl_changes .mean ())
24152450 pl_g_l [g ] = switcher_change - ctrl_avg
24162451
@@ -2522,9 +2557,14 @@ def _compute_cost_benefit_delta(
25222557 dose_l = 0.0
25232558 for g in np .where (eligible )[0 ]:
25242559 f_g = first_switch_idx [g ]
2525- col = f_g - 1 + l
2526- if col < D_mat .shape [1 ]:
2527- dose_l += abs (float (D_mat [g , col ] - baselines [g ]))
2560+ # Cumulative dose: delta^D_{g,l} = sum_{k=0}^{l-1} |D_{g,F_g+k} - D_{g,1}|
2561+ # For binary treatment this equals l only when all l periods are in range and differ from the baseline (each such period contributes 1).
2562+ cum_dose = 0.0
2563+ for k in range (l ):
2564+ col_k = f_g + k
2565+ if col_k < D_mat .shape [1 ]:
2566+ cum_dose += abs (float (D_mat [g , col_k ] - baselines [g ]))
2567+ dose_l += cum_dose
25282568 per_horizon_dose [l ] = dose_l
25292569 total_dose += dose_l
25302570
@@ -2572,9 +2612,12 @@ def _compute_cost_benefit_delta(
25722612 if switch_direction [g ] != direction :
25732613 continue
25742614 f_g = first_switch_idx [g ]
2575- col = f_g - 1 + l
2576- if col < D_mat .shape [1 ]:
2577- dose_l += abs (float (D_mat [g , col ] - baselines [g ]))
2615+ cum_dose = 0.0
2616+ for k in range (l ):
2617+ col_k = f_g + k
2618+ if col_k < D_mat .shape [1 ]:
2619+ cum_dose += abs (float (D_mat [g , col_k ] - baselines [g ]))
2620+ dose_l += cum_dose
25782621 dir_horizon_dose [l ] = dose_l
25792622 dir_dose += dose_l
25802623
0 commit comments