Skip to content

Commit 6273674

Browse files
igerberclaude
andcommitted
Address PR #110 feedback round 8: three LOOCV/validation fixes
Issue 1: Final LOOCV score infinity conversion - Convert inf values before calling loocv_score_for_params in Rust - Ensures final score uses same converted values that LOOCV evaluated - λ_time/λ_unit=∞ → 0.0, λ_nn=∞ → 1e10 Issue 2: Rust LOOCV failed observation metadata - Extend loocv_score_for_params to return Option<(usize, usize)> - Track first failed observation (t, i) for informative warnings - Python now includes coordinates in LOOCV failure warnings Issue 3: D matrix validation for unbalanced panels - Track missing values before fillna(0) with missing_mask - Only validate monotonicity between observed periods - Missing data no longer triggers false absorbing-state violations Tests: 4 new tests in TestPR110FeedbackRound8 class Docs: Updated REGISTRY.md with unbalanced panel support Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent f52b100 commit 6273674

4 files changed

Lines changed: 296 additions & 21 deletions

File tree

diff_diff/trop.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -900,21 +900,32 @@ def fit(
900900
.reindex(index=all_periods, columns=all_units)
901901
.values
902902
)
903-
D = (
903+
904+
# For D matrix, track missing values BEFORE fillna to support unbalanced panels
905+
# Issue 3 fix: Missing observations should not trigger spurious violations
906+
D_raw = (
904907
data.pivot(index=time, columns=unit, values=treatment)
905908
.reindex(index=all_periods, columns=all_units)
906-
.fillna(0)
907-
.astype(int)
908-
.values
909909
)
910+
missing_mask = pd.isna(D_raw).values # True where originally missing
911+
D = D_raw.fillna(0).astype(int).values
910912

911913
# Validate D is monotonic non-decreasing per unit (absorbing state)
912914
# D[t, i] must satisfy: once D=1, it must stay 1 for all subsequent periods
913915
# Vectorized check: diff(D, axis=0) should never be negative
916+
# Issue 3 fix: Only check transitions where BOTH periods are observed
914917
d_diff = np.diff(D, axis=0)
915-
if np.any(d_diff < 0):
918+
919+
# Valid transition mask: neither the current nor next period is missing
920+
# missing_mask[:-1] = source period missing, missing_mask[1:] = target period missing
921+
valid_transition = ~(missing_mask[:-1] | missing_mask[1:])
922+
923+
# Only flag violations where both periods are observed
924+
violations = (d_diff < 0) & valid_transition
925+
926+
if np.any(violations):
916927
# Find which units violate the absorbing state constraint
917-
violating_units_mask = np.any(d_diff < 0, axis=0)
928+
violating_units_mask = np.any(violations, axis=0)
918929
violating_unit_ids = [all_units[i] for i in np.where(violating_units_mask)[0]]
919930
raise ValueError(
920931
f"Treatment indicator is not an absorbing state for units: {violating_unit_ids}. "
@@ -977,31 +988,43 @@ def fit(
977988
lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
978989
lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
979990

980-
best_lt, best_lu, best_ln, best_score, n_valid, n_attempted = _rust_loocv_grid_search(
991+
result = _rust_loocv_grid_search(
981992
Y, D.astype(np.float64), control_mask_u8,
982993
time_dist_matrix,
983994
lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
984995
self.max_loocv_samples, self.max_iter, self.tol,
985996
self.seed if self.seed is not None else 0
986997
)
998+
# Unpack result - 7 values including optional first_failed_obs
999+
best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed_obs = result
9871000
# Only accept finite scores - infinite means all fits failed
9881001
if np.isfinite(best_score):
9891002
best_lambda = (best_lt, best_lu, best_ln)
9901003
# else: best_lambda stays None, triggering defaults fallback
9911004
# Emit warnings consistent with Python implementation
9921005
if n_valid == 0:
1006+
# Include failed observation coordinates if available (Issue 2 fix)
1007+
obs_info = ""
1008+
if first_failed_obs is not None:
1009+
t_idx, i_idx = first_failed_obs
1010+
obs_info = f" First failure at observation ({t_idx}, {i_idx})."
9931011
warnings.warn(
9941012
f"LOOCV: All {n_attempted} fits failed for "
9951013
f"λ=({best_lt}, {best_lu}, {best_ln}). "
996-
"Returning infinite score.",
1014+
f"Returning infinite score.{obs_info}",
9971015
UserWarning
9981016
)
9991017
elif n_attempted > 0 and (n_attempted - n_valid) > 0.1 * n_attempted:
10001018
n_failed = n_attempted - n_valid
1019+
# Include failed observation coordinates if available
1020+
obs_info = ""
1021+
if first_failed_obs is not None:
1022+
t_idx, i_idx = first_failed_obs
1023+
obs_info = f" First failure at observation ({t_idx}, {i_idx})."
10011024
warnings.warn(
10021025
f"LOOCV: {n_failed}/{n_attempted} fits failed for "
10031026
f"λ=({best_lt}, {best_lu}, {best_ln}). "
1004-
"This may indicate numerical instability.",
1027+
f"This may indicate numerical instability.{obs_info}",
10051028
UserWarning
10061029
)
10071030
except Exception as e:

docs/methodology/REGISTRY.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,10 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
550550
- Handling: Raises `ValueError` with list of violating unit IDs and remediation guidance
551551
- Error message includes: "convert to absorbing state: D[t, i] = 1 for all t >= first treatment period"
552552
- **Rationale**: Event-style D (0→1→0) silently biases ATT; runtime validation prevents misuse
553+
- **Unbalanced panels**: Missing unit-period observations are allowed. Monotonicity validation only checks transitions between observed periods. A unit with D=1 at t=3 and missing data at t=5 is NOT flagged as a violation (the apparent 1→0 transition is due to missing data, not a real violation).
553554
- Wrong D specification: if user provides event-style D (only first treatment period),
554555
the absorbing-state validation will raise ValueError with helpful guidance
556+
- **LOOCV failure metadata**: When LOOCV fits fail in the Rust backend, the first failed observation coordinates (t, i) are returned to Python for informative warning messages
555557

556558
**Reference implementation(s):**
557559
- Authors' replication code (forthcoming)
@@ -566,6 +568,7 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
566568
- [x] ATT averages over all D==1 cells (general assignment patterns)
567569
- [x] No post_periods parameter (D matrix determines treatment timing)
568570
- [x] D matrix semantics documented (absorbing state, not event indicator)
571+
- [x] Unbalanced panels supported (missing observations don't trigger false violations)
569572

570573
---
571574

rust/src/trop.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ fn univariate_loocv_search(
220220
},
221221
};
222222

223-
let (score, _) = loocv_score_for_params(
223+
let (score, _, _) = loocv_score_for_params(
224224
y, d, control_mask, time_dist, control_obs,
225225
lambda_time, lambda_unit, lambda_nn,
226226
max_iter, tol,
@@ -318,9 +318,10 @@ fn cycling_parameter_search(
318318
/// * `seed` - Random seed for subsampling
319319
///
320320
/// # Returns
321-
/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score, n_valid, n_attempted)
321+
/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score, n_valid, n_attempted, first_failed_obs)
322322
/// where n_valid and n_attempted are the counts for the best parameter combination,
323323
/// allowing Python to emit warnings when >10% of fits fail.
324+
/// first_failed_obs is Some((t, i)) if a fit failed during final score computation, None otherwise.
324325
#[pyfunction]
325326
#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))]
326327
#[allow(clippy::too_many_arguments)]
@@ -337,7 +338,7 @@ pub fn loocv_grid_search<'py>(
337338
max_iter: usize,
338339
tol: f64,
339340
seed: u64,
340-
) -> PyResult<(f64, f64, f64, f64, usize, usize)> {
341+
) -> PyResult<(f64, f64, f64, f64, usize, usize, Option<(usize, usize)>)> {
341342
let y_arr = y.as_array();
342343
let d_arr = d.as_array();
343344
let control_mask_arr = control_mask.as_array();
@@ -383,14 +384,24 @@ pub fn loocv_grid_search<'py>(
383384
max_iter, tol, 10,
384385
);
385386

386-
// Compute final score
387-
let (best_score, n_valid) = loocv_score_for_params(
387+
// Convert infinity values BEFORE computing final score (Issue 1 fix)
388+
// Per paper Equations 2-3:
389+
// - λ_time/λ_unit=∞ → uniform weights → use 0.0
390+
// - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
391+
// This ensures final score computation matches what LOOCV evaluated.
392+
let best_time_eff = if best_time.is_infinite() { 0.0 } else { best_time };
393+
let best_unit_eff = if best_unit.is_infinite() { 0.0 } else { best_unit };
394+
let best_nn_eff = if best_nn.is_infinite() { 1e10 } else { best_nn };
395+
396+
// Compute final score with converted values
397+
let (best_score, n_valid, first_failed) = loocv_score_for_params(
388398
&y_arr, &d_arr, &control_mask_arr, &time_dist_arr, &control_obs,
389-
best_time, best_unit, best_nn,
399+
best_time_eff, best_unit_eff, best_nn_eff,
390400
max_iter, tol,
391401
);
392402

393-
Ok((best_time, best_unit, best_nn, best_score, n_valid, n_attempted))
403+
// Return ORIGINAL grid values (for user visibility) but score computed with converted
404+
Ok((best_time, best_unit, best_nn, best_score, n_valid, n_attempted, first_failed))
394405
}
395406

396407
/// Get sampled control observations for LOOCV.
@@ -429,7 +440,8 @@ fn get_control_observations(
429440
/// Compute LOOCV score for a specific parameter combination.
430441
///
431442
/// # Returns
432-
/// (score, n_valid) - the LOOCV score and number of successful fits
443+
/// (score, n_valid, first_failed_obs) - the LOOCV score, number of successful fits,
444+
/// and the first failed observation (t, i) if any fit failed, None otherwise.
433445
#[allow(clippy::too_many_arguments)]
434446
fn loocv_score_for_params(
435447
y: &ArrayView2<f64>,
@@ -442,7 +454,7 @@ fn loocv_score_for_params(
442454
lambda_nn: f64,
443455
max_iter: usize,
444456
tol: f64,
445-
) -> (f64, usize) {
457+
) -> (f64, usize, Option<(usize, usize)>) {
446458
let n_periods = y.nrows();
447459
let n_units = y.ncols();
448460

@@ -484,17 +496,18 @@ fn loocv_score_for_params(
484496
None => {
485497
// Per Equation 5: Q(λ) must sum over ALL D==0 cells
486498
// Any failure means this λ cannot produce valid estimates for all cells
487-
return (f64::INFINITY, n_valid);
499+
// Return the failed observation (t, i) for warning metadata
500+
return (f64::INFINITY, n_valid, Some((t, i)));
488501
}
489502
}
490503
}
491504

492505
if n_valid == 0 {
493-
(f64::INFINITY, 0)
506+
(f64::INFINITY, 0, None)
494507
} else {
495508
// Return SUM of squared pseudo-treatment effects per Equation 5 (page 8):
496509
// Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
497-
(tau_sq_sum, n_valid)
510+
(tau_sq_sum, n_valid, None)
498511
}
499512
}
500513

0 commit comments

Comments
 (0)