@@ -903,6 +903,36 @@ def _validate_had_panel_event_study(
903903 f"single-period WAS)."
904904 )
905905
906+ # Ordered-time-type check. Paper Appendix B.2 event-time horizons
907+ # require chronological ordering of periods (anchor at F-1, horizons
908+ # e = t - F relative to F). Phase 2a two-period panels can use the
909+ # dose invariant alone to distinguish pre from post without needing
910+ # chronological order, so string labels ("pre", "post") work there.
911+ # For multi-period event-study, multiple pre-periods all have D=0
912+ # and multiple post-periods may both have D>0, so dose alone cannot
913+ # recover chronology: we must trust the time column's natural order.
914+ # Raw lexicographic sort on object/string labels silently misorders
915+ # panels like "pre1"/"pre2"/"post1"/"post2" or month-name labels.
916+ # Require an explicitly-ordered time representation.
917+ time_dtype = data [time_col ].dtype
918+ if not (
919+ pd .api .types .is_numeric_dtype (time_dtype )
920+ or pd .api .types .is_datetime64_any_dtype (time_dtype )
921+ or (isinstance (time_dtype , pd .CategoricalDtype ) and bool (time_dtype .ordered ))
922+ ):
923+ raise ValueError (
924+ f"HAD aggregate='event_study' requires an ordered time "
925+ f"column. time_col={ time_col !r} has dtype={ time_dtype !r} , "
926+ f"which has no defined chronological order; raw sort would "
927+ f"fall back to lexicographic ordering and silently misindex "
928+ f"event-time horizons (e.g., 'pre1'/'pre2'/'post1'/'post2' "
929+ f"sorts lexicographically but not chronologically). "
930+ f"Convert time_col to numeric (e.g., integer year), "
931+ f"datetime, or ordered categorical "
932+ f"(``pd.Categorical(..., ordered=True, categories=[...])``) "
933+ f"before calling fit() with aggregate='event_study'."
934+ )
935+
906936 # NaN checks on key columns (before any filter).
907937 for col in [outcome_col , dose_col , unit_col ]:
908938 if bool (data [col ].isna ().any ()):
@@ -936,6 +966,45 @@ def _validate_had_panel_event_study(
936966 f"within unit for { n_bad } unit(s). Each unit must have "
937967 f"a single first_treat value across all observed periods."
938968 )
969+ # Cross-validate first_treat_col against observed first-positive-
970+ # dose period for every unit. A mislabeled cohort column would
971+ # otherwise silently select the wrong cohort as F_last and return
972+ # event-study estimates for the wrong units. Contract:
973+ # - declared first_treat == 0: unit must have D == 0 at all t
974+ # (never-treated)
975+ # - declared first_treat == F_g > 0: unit's first period with
976+ # D > 0 must equal F_g
977+ df_for_check = data .sort_values ([unit_col , time_col ])
978+ pos_rows = df_for_check .loc [df_for_check [dose_col ] > 0 ]
979+ actual_first_pos = pos_rows .groupby (unit_col )[time_col ].first ()
980+ declared_ft = df_for_check .groupby (unit_col )[first_treat_col ].first ()
981+ n_mismatch = 0
982+ example_mismatch : Optional [Tuple [Any , Any , Any ]] = None
983+ for u , declared in declared_ft .items ():
984+ actual = actual_first_pos .get (u , None )
985+ if declared == 0 :
986+ if actual is not None :
987+ n_mismatch += 1
988+ if example_mismatch is None :
989+ example_mismatch = (u , declared , actual )
990+ else :
991+ if actual is None or actual != declared :
992+ n_mismatch += 1
993+ if example_mismatch is None :
994+ example_mismatch = (u , declared , actual )
995+ if n_mismatch > 0 :
996+ u , declared , actual = example_mismatch # type: ignore[misc]
997+ raise ValueError (
998+ f"first_treat_col={ first_treat_col !r} disagrees with the "
999+ f"observed dose path for { n_mismatch } unit(s). Example: "
1000+ f"unit={ u !r} declares first_treat={ declared !r} but the "
1001+ f"unit's first period with D>0 is { actual !r} "
1002+ f"(None means never-treated). A mislabeled cohort column "
1003+ f"would silently select the wrong cohort as F_last in the "
1004+ f"last-cohort auto-filter. Fix the first_treat_col values "
1005+ f"to equal each unit's first positive-dose period (or 0 "
1006+ f"for never-treated) before calling fit()."
1007+ )
9391008 # Identify cohorts (nonzero first_treat values).
9401009 # Use pd.unique to preserve dtype; sort with a stable key.
9411010 ft_unique = list (pd .unique (ft_raw ))
@@ -1015,8 +1084,9 @@ def _validate_had_panel_event_study(
10151084 )
10161085
10171086 # Balanced panel on the (possibly-filtered) data: every unit appears
1018- # exactly once per period.
1019- counts = data_filtered .groupby ([unit_col , time_col ]).size ()
1087+ # exactly once per period. ``observed=False`` preserves current
1088+ # behavior on categorical time columns (pandas' default is changing).
1089+ counts = data_filtered .groupby ([unit_col , time_col ], observed = False ).size ()
10201090 if (counts != 1 ).any ():
10211091 n_bad = int ((counts != 1 ).sum ())
10221092 raise ValueError (
@@ -1057,36 +1127,35 @@ def _validate_had_panel_event_study(
10571127 f"zero dose; there is no treatment to estimate."
10581128 )
10591129
1060- # Sort by natural ordering on the time column dtype. Tuple key
1061- # ``(x is None, x)`` places None at the end and sorts the rest by
1062- # natural order (works for int/float/str/datetime when the dtype is
1063- # homogeneous; mixed dtypes would raise at comparison time, which is
1064- # the desired failure mode).
1065- t_pre_list = sorted (t_pre_list_unsorted , key = lambda x : (x is None , x ))
1066- t_post_list = sorted (t_post_list_unsorted , key = lambda x : (x is None , x ))
1130+ # Sort by natural ordering on the time column dtype. For ordered
1131+ # categorical dtypes, use the declared category order (since
1132+ # ``list(categorical)`` strips the ordered semantics and falls back
1133+ # to string comparison). For numeric / datetime, use natural Python
1134+ # order. Tuple key places None at the end.
1135+ if isinstance (time_dtype , pd .CategoricalDtype ) and time_dtype .ordered :
1136+ _cat_order = {c : i for i , c in enumerate (time_dtype .categories )}
1137+
1138+ def _sort_key (x : Any ) -> Tuple [bool , Any ]:
1139+ return (x is None , _cat_order .get (x , len (_cat_order )))
1140+
1141+ else :
1142+
1143+ def _sort_key (x : Any ) -> Tuple [bool , Any ]:
1144+ return (x is None , x )
1145+
1146+ t_pre_list = sorted (t_pre_list_unsorted , key = _sort_key )
1147+ t_post_list = sorted (t_post_list_unsorted , key = _sort_key )
10671148
10681149 # Contiguity check: all pre < all post in the natural ordering.
10691150 # The HAD dose invariant requires a single transition from all-zero
10701151 # to any-nonzero; interleaved pre/post periods indicate a malformed
10711152 # panel (e.g., dose going back to zero after treatment, or mixing
1072- # never-treated units with out-of-order labels).
1153+ # never-treated units with out-of-order labels). Uses ``_sort_key``
1154+ # so ordered categoricals respect their declared category order.
10731155 if t_pre_list and t_post_list :
10741156 max_pre = t_pre_list [- 1 ]
10751157 min_post = t_post_list [0 ]
1076- # Check all pre-periods are less than all post-periods via the
1077- # natural order. If types are comparable, direct comparison works;
1078- # otherwise fall back to the sorted-key view.
1079- try :
1080- contiguous = max_pre < min_post
1081- except TypeError :
1082- # Mixed incomparable dtypes (e.g., None vs int after removing
1083- # None above). Fall back to sorted-position check.
1084- contiguous = True
1085- for pre_p in t_pre_list :
1086- for post_p in t_post_list :
1087- if not (pre_p < post_p ):
1088- contiguous = False
1089- break
1158+ contiguous = _sort_key (max_pre ) < _sort_key (min_post )
10901159 if not contiguous :
10911160 raise ValueError (
10921161 f"HAD dose invariant violated: pre-periods (all D=0) "
@@ -1318,7 +1387,23 @@ def _aggregate_multi_period_first_differences(
13181387 equal to the LAST pre-period).
13191388 """
13201389 df = data .sort_values ([unit_col , time_col ]).reset_index (drop = True )
1321- all_periods = sorted (t_pre_list + t_post_list , key = lambda x : (x is None , x ))
1390+ # Period sort respects ordered categorical dtypes (matches
1391+ # ``_validate_had_panel_event_study``). The validator already
1392+ # enforces a numeric / datetime / ordered-categorical dtype on the
1393+ # event-study path, so ``_sort_key`` lookups are well-defined here.
1394+ time_dtype = data [time_col ].dtype
1395+ if isinstance (time_dtype , pd .CategoricalDtype ) and time_dtype .ordered :
1396+ _cat_order = {c : i for i , c in enumerate (time_dtype .categories )}
1397+
1398+ def _sort_key (x : Any ) -> Tuple [bool , Any ]:
1399+ return (x is None , _cat_order .get (x , len (_cat_order )))
1400+
1401+ else :
1402+
1403+ def _sort_key (x : Any ) -> Tuple [bool , Any ]:
1404+ return (x is None , x )
1405+
1406+ all_periods = sorted (t_pre_list + t_post_list , key = _sort_key )
13221407 # Event-time mapping: natural rank of each period relative to F.
13231408 F_idx = all_periods .index (F )
13241409 period_to_event_time : Dict [Any , int ] = {p : (i - F_idx ) for i , p in enumerate (all_periods )}
@@ -1604,9 +1689,16 @@ class HeterogeneousAdoptionDiD:
16041689 Weighted-Average-Slope (WAS) estimator with three design-dispatch
16051690 paths: Design 1' (continuous-at-zero), Design 1 continuous-near-
16061691 d_lower, and Design 1 mass-point (2SLS sample-average per paper
1607- Section 3.2.4). Phase 2a ships the single-period path only; the
1608- multi-period event-study extension (Appendix B.2) is queued for
1609- Phase 2b.
1692+ Section 3.2.4). Two aggregation modes:
1693+
1694+ - ``aggregate="overall"`` (Phase 2a, default) returns a single-period
1695+ :class:`HeterogeneousAdoptionDiDResults` on a two-period panel.
1696+ - ``aggregate="event_study"`` (Phase 2b, paper Appendix B.2) returns
1697+ a :class:`HeterogeneousAdoptionDiDEventStudyResults` with per-
1698+ event-time WAS estimates on a multi-period panel, using a uniform
1699+ ``F-1`` anchor and pointwise CIs per horizon. Staggered-timing
1700+ panels auto-filter to the last-treatment cohort plus never-treated
1701+ units (paper Appendix B.2 prescription).
16101702
16111703 Parameters
16121704 ----------
0 commit comments