Skip to content

Commit 3a708fc

Browse files
igerberclaude
andcommitted
Address PR #110 feedback round 10: absorbing-state gap detection and n_post_periods fix
P1: Fix absorbing-state validation to catch 1→0 violations across missing period gaps. The old vectorized check only looked at adjacent periods, missing violations like D[2]=1, missing [3,4], D[5]=0. Now checks each unit's observed D sequence. P3: Fix n_post_periods to count periods with actual D=1 observations, matching the docstring claim, rather than calendar periods from first treatment. Also updates methodology registry documentation for both changes. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent ebd9cb2 commit 3a708fc

3 files changed

Lines changed: 111 additions & 18 deletions

File tree

diff_diff/trop.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -912,23 +912,23 @@ def fit(
912912

913913
# Validate D is monotonic non-decreasing per unit (absorbing state)
914914
# D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods
915-
# Vectorized check: diff(D, axis=0) should never be negative
916-
# Issue 3 fix: Only check transitions where BOTH periods are observed
917-
d_diff = np.diff(D, axis=0)
918-
919-
# Valid transition mask: neither the current nor next period is missing
920-
# missing_mask[:-1] = source period missing, missing_mask[1:] = target period missing
921-
valid_transition = ~(missing_mask[:-1] | missing_mask[1:])
922-
923-
# Only flag violations where both periods are observed
924-
violations = (d_diff < 0) & valid_transition
925-
926-
if np.any(violations):
927-
# Find which units violate the absorbing state constraint
928-
violating_units_mask = np.any(violations, axis=0)
929-
violating_unit_ids = [all_units[i] for i in np.where(violating_units_mask)[0]]
915+
# Issue 3 fix (round 10): Check each unit's OBSERVED D sequence for monotonicity
916+
# This catches 1→0 violations that span missing period gaps
917+
# Example: D[2]=1, missing [3,4], D[5]=0 is a real violation even though
918+
# adjacent period transitions don't show it (the gap hides the transition)
919+
violating_units = []
920+
for unit_idx in range(n_units):
921+
# Get observed D values for this unit (where not missing)
922+
observed_mask = ~missing_mask[:, unit_idx]
923+
observed_d = D[observed_mask, unit_idx]
924+
925+
# Check if observed sequence is monotonically non-decreasing
926+
if len(observed_d) > 1 and np.any(np.diff(observed_d) < 0):
927+
violating_units.append(all_units[unit_idx])
928+
929+
if violating_units:
930930
raise ValueError(
931-
f"Treatment indicator is not an absorbing state for units: {violating_unit_ids}. "
931+
f"Treatment indicator is not an absorbing state for units: {violating_units}. "
932932
f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
933933
f"If this is event-study style data, convert to absorbing state: "
934934
f"D[t, i] = 1 for all t >= first treatment period."
@@ -960,7 +960,9 @@ def fit(
960960
raise ValueError("Could not infer post-treatment periods from D matrix")
961961

962962
n_pre_periods = first_treat_period
963-
n_post_periods = n_periods - first_treat_period
963+
# Count periods where D=1 is actually observed (matches docstring)
964+
# Per docstring: "Number of post-treatment periods (periods with D=1 observations)"
965+
n_post_periods = int(np.sum(np.any(D[first_treat_period:, :] == 1, axis=1)))
964966

965967
if n_pre_periods < 2:
966968
raise ValueError("Need at least 2 pre-treatment periods")

docs/methodology/REGISTRY.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,8 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
550550
- Handling: Raises `ValueError` with list of violating unit IDs and remediation guidance
551551
- Error message includes: "convert to absorbing state: D[t, i] = 1 for all t >= first treatment period"
552552
- **Rationale**: Event-style D (0→1→0) silently biases ATT; runtime validation prevents misuse
553-
- **Unbalanced panels**: Missing unit-period observations are allowed. Monotonicity validation only checks transitions between observed periods. A unit with D=1 at t=3 and missing data at t=5 is NOT flagged as a violation (the apparent 1→0 transition is due to missing data, not a real violation).
553+
- **Unbalanced panels**: Missing unit-period observations are allowed. Monotonicity validation checks each unit's *observed* D sequence for monotonicity, which correctly catches 1→0 violations that span missing period gaps (e.g., D[2]=1, missing [3,4], D[5]=0 is detected as a violation even though the gap hides the transition in adjacent-period checks).
554+
- **n_post_periods metadata**: Counts periods where D=1 is actually observed (at least one unit has D=1), not calendar periods from first treatment. In unbalanced panels where treated units are missing in some post-treatment periods, only periods with observed D=1 values are counted.
554555
- Wrong D specification: if user provides event-style D (only first treatment period),
555556
the absorbing-state validation will raise ValueError with helpful guidance
556557
- **LOOCV failure metadata**: When LOOCV fits fail in the Rust backend, the first failed observation coordinates (t, i) are returned to Python for informative warning messages

tests/test_trop.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2565,3 +2565,93 @@ def test_infinity_grid_values_with_final_score_computation(self, simple_panel_da
25652565
assert np.isfinite(results.loocv_score), (
25662566
"LOOCV score should be finite when computed with converted inf values"
25672567
)
2568+
2569+
def test_violation_across_missing_gap_caught(self):
2570+
"""Test that 1→0 violations spanning missing periods are caught.
2571+
2572+
Issue: If periods [3, 4] are missing and D[2]=1, D[5]=0, this is a
2573+
real violation that must be detected even though the adjacent
2574+
period transitions don't show it (the gap hides the transition).
2575+
2576+
PR #110 round 10 fix: Check each unit's observed D sequence for
2577+
monotonicity, not just adjacent periods in the full time grid.
2578+
"""
2579+
data = []
2580+
2581+
# Unit 0: control, complete
2582+
for t in range(6):
2583+
data.append({"unit": 0, "period": t, "outcome": 10.0 + t, "treated": 0})
2584+
2585+
# Unit 1: VIOLATION across gap
2586+
# Observed at [0, 1, 2, 5], missing [3, 4]
2587+
# D[2]=1, D[5]=0 is a real violation spanning the gap
2588+
for t in [0, 1, 2, 5]:
2589+
treated = 1 if t == 2 else 0 # Only treated at period 2
2590+
data.append({"unit": 1, "period": t, "outcome": 10.0 + t, "treated": treated})
2591+
2592+
# Unit 2: control, complete
2593+
for t in range(6):
2594+
data.append({"unit": 2, "period": t, "outcome": 10.0 + t, "treated": 0})
2595+
2596+
df = pd.DataFrame(data)
2597+
trop_est = TROP(
2598+
lambda_time_grid=[0.0],
2599+
lambda_unit_grid=[0.0],
2600+
lambda_nn_grid=[0.0],
2601+
n_bootstrap=5,
2602+
)
2603+
2604+
with pytest.raises(ValueError, match="absorbing state"):
2605+
trop_est.fit(
2606+
df,
2607+
outcome="outcome",
2608+
treatment="treated",
2609+
unit="unit",
2610+
time="period",
2611+
)
2612+
2613+
def test_n_post_periods_counts_observed_treatment(self):
2614+
"""Test n_post_periods counts periods with actual D=1 observations.
2615+
2616+
Per docstring: "Number of post-treatment periods (periods with D=1 observations)"
2617+
2618+
This tests that n_post_periods reflects periods where treatment is
2619+
actually observed, not just calendar periods from first treatment.
2620+
"""
2621+
data = []
2622+
2623+
# Create panel where period 5 exists but has no D=1 observations
2624+
# (all treated units are missing at period 5)
2625+
for unit in range(3):
2626+
for period in range(6):
2627+
# Units 1, 2 are treated from period 3, but missing at period 5
2628+
if unit in [1, 2] and period == 5:
2629+
continue # Skip - creates unbalanced panel
2630+
treated = 1 if (unit in [1, 2] and period >= 3) else 0
2631+
data.append({
2632+
"unit": unit,
2633+
"period": period,
2634+
"outcome": 10.0 + period,
2635+
"treated": treated,
2636+
})
2637+
2638+
df = pd.DataFrame(data)
2639+
trop_est = TROP(
2640+
lambda_time_grid=[0.0],
2641+
lambda_unit_grid=[0.0],
2642+
lambda_nn_grid=[0.0],
2643+
n_bootstrap=5,
2644+
seed=42,
2645+
)
2646+
results = trop_est.fit(
2647+
df,
2648+
outcome="outcome",
2649+
treatment="treated",
2650+
unit="unit",
2651+
time="period",
2652+
)
2653+
2654+
# Periods with D=1 observations: 3, 4 (not 5 - missing for treated units)
2655+
assert results.n_post_periods == 2, (
2656+
f"Expected 2 post-periods with D=1, got {results.n_post_periods}"
2657+
)

0 commit comments

Comments
 (0)