Skip to content

Commit b519f67

Browse files
igerberclaude
andcommitted
Add unified LOOCV for TROP joint method with Rust acceleration
Implements proper leave-one-out cross-validation for the TROP joint method, matching the twostep method's approach per the paper's Equation 5. Adds Rust backend acceleration for parallel LOOCV grid search and bootstrap variance estimation (5-15x speedup). Changes: - rust/src/trop.rs: Add loocv_grid_search_joint(), bootstrap_trop_variance_joint(), and supporting joint model fitting functions - diff_diff/trop.py: Update _fit_joint() to use LOOCV, add _loocv_score_joint() Python fallback, integrate Rust acceleration for bootstrap variance - diff_diff/_backend.py: Export new Rust functions - docs/methodology/REGISTRY.md: Document unified LOOCV approach for joint method - tests/: Add comprehensive tests for joint method LOOCV and Rust/Python parity Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent addd96f commit b519f67

12 files changed

Lines changed: 2329 additions & 47 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,6 @@ scripts/
8080
# Launch directories (local only)
8181
launch/
8282
launch-video/
83+
84+
# Reference implementations (local only)
85+
trop_avg_ref/

CLAUDE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ test bootstrap::tests::test_webb_mean_approx_zero ... ok
138138
- `TROPResults` - Results with ATT, factors, loadings, unit/time weights
139139
- `trop()` - Convenience function for quick estimation
140140
- Three robustness components: factor adjustment, unit weights, time weights
141+
- Two estimation methods via `method` parameter:
142+
- `"twostep"` (default): Per-observation model fitting (Algorithm 2 of paper)
143+
- `"joint"`: Weighted least squares with homogeneous treatment effect (faster)
141144
- Automatic rank selection via cross-validation, information criterion, or elbow detection
142145
- Bootstrap and placebo-based variance estimation
143146

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,7 @@ trop = TROP(
12671267

12681268
```python
12691269
TROP(
1270+
method='twostep', # Estimation method: 'twostep' (default) or 'joint'
12701271
lambda_time_grid=None, # Time decay grid (default: [0, 0.1, 0.5, 1, 2, 5])
12711272
lambda_unit_grid=None, # Unit distance grid (default: [0, 0.1, 0.5, 1, 2, 5])
12721273
lambda_nn_grid=None, # Nuclear norm grid (default: [0, 0.01, 0.1, 1, 10])
@@ -1279,6 +1280,10 @@ TROP(
12791280
)
12801281
```
12811282

1283+
**Estimation methods:**
1284+
- `'twostep'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive.
1285+
- `'joint'`: Joint weighted least squares optimization. Estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Faster but assumes homogeneous treatment effects.
1286+
12821287
**Convenience function:**
12831288

12841289
```python

diff_diff/_backend.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,13 @@
2323
project_simplex as _rust_project_simplex,
2424
solve_ols as _rust_solve_ols,
2525
compute_robust_vcov as _rust_compute_robust_vcov,
26-
# TROP estimator acceleration
26+
# TROP estimator acceleration (twostep method)
2727
compute_unit_distance_matrix as _rust_unit_distance_matrix,
2828
loocv_grid_search as _rust_loocv_grid_search,
2929
bootstrap_trop_variance as _rust_bootstrap_trop_variance,
30+
# TROP estimator acceleration (joint method)
31+
loocv_grid_search_joint as _rust_loocv_grid_search_joint,
32+
bootstrap_trop_variance_joint as _rust_bootstrap_trop_variance_joint,
3033
)
3134
_rust_available = True
3235
except ImportError:
@@ -36,10 +39,13 @@
3639
_rust_project_simplex = None
3740
_rust_solve_ols = None
3841
_rust_compute_robust_vcov = None
39-
# TROP estimator acceleration
42+
# TROP estimator acceleration (twostep method)
4043
_rust_unit_distance_matrix = None
4144
_rust_loocv_grid_search = None
4245
_rust_bootstrap_trop_variance = None
46+
# TROP estimator acceleration (joint method)
47+
_rust_loocv_grid_search_joint = None
48+
_rust_bootstrap_trop_variance_joint = None
4349

4450
# Determine final backend based on environment variable and availability
4551
if _backend_env == 'python':
@@ -50,10 +56,13 @@
5056
_rust_project_simplex = None
5157
_rust_solve_ols = None
5258
_rust_compute_robust_vcov = None
53-
# TROP estimator acceleration
59+
# TROP estimator acceleration (twostep method)
5460
_rust_unit_distance_matrix = None
5561
_rust_loocv_grid_search = None
5662
_rust_bootstrap_trop_variance = None
63+
# TROP estimator acceleration (joint method)
64+
_rust_loocv_grid_search_joint = None
65+
_rust_bootstrap_trop_variance_joint = None
5766
elif _backend_env == 'rust':
5867
# Force Rust mode - fail if not available
5968
if not _rust_available:
@@ -73,8 +82,11 @@
7382
'_rust_project_simplex',
7483
'_rust_solve_ols',
7584
'_rust_compute_robust_vcov',
76-
# TROP estimator acceleration
85+
# TROP estimator acceleration (twostep method)
7786
'_rust_unit_distance_matrix',
7887
'_rust_loocv_grid_search',
7988
'_rust_bootstrap_trop_variance',
89+
# TROP estimator acceleration (joint method)
90+
'_rust_loocv_grid_search_joint',
91+
'_rust_bootstrap_trop_variance_joint',
8092
]

0 commit comments

Comments
 (0)