Fix/trop paper conformance #84

Merged
igerber merged 3 commits into main from fix/trop-paper-conformance
Jan 19, 2026

igerber (Owner) commented Jan 19, 2026

No description provided.

igerber and others added 2 commits January 19, 2026 11:56
Address four issues identified in the implementation assessment:

Issue A (Critical): Expand control unit selection
- Modified _compute_observation_weights() to use D[t,j]==0 instead of
  only never-treated units, allowing pre-treatment observations of
  eventually-treated units to serve as controls

Issue B (Moderate): Per-observation distance excludes target period
- Distance computation now correctly excludes target period t as
  specified in Equation 3 of the paper (1{u ≠ t})
- Added compute_unit_distance_for_obs() to Rust backend

Issue C (Moderate-High): Weighted nuclear norm solver
- Implemented _weighted_nuclear_norm_solve() using iterative weighted
  soft-impute (Mazumder et al. 2010)
- Properly handles W=0 observations to prevent L from absorbing
  treatment effects

Issue D (Minor): Stratified bootstrap sampling
- Modified _bootstrap_variance() to sample control and treated units
  separately, preserving treatment ratio per Algorithm 3

Also updated:
- Tutorial notebook with precise paper terminology
- Added TestPaperConformanceFixes test class with 5 new tests
- All 38 TROP tests pass

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Updates compute_weight_matrix:
- Allow ALL units in weight computation (not just untreated at target)
- The (1 - D_js) masking in the loss handles treatment exclusion
- Normalize time and unit weights to sum to 1 (probability weights)
- Distance still excludes target period per Equation 3

Updates estimate_model:
- Use weighted proximal gradient for L update instead of direct soft-thresholding
- L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L)))
- Step size η ≤ 1/max(W) for convergence
- For W=0 cells (treated obs), L remains unchanged

These changes align the Rust backend with the Python implementation
which was already fixed in the previous commit.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
igerber (Owner, Author) commented Jan 19, 2026

Code Review: PR #84 - Fix/trop paper conformance

Author: igerber
Branch: fix/trop-paper-conformance -> main
Files Changed: 4


Executive Summary

This PR addresses four conformance issues between the TROP estimator implementation and the original paper (Athey, Imbens, Qu & Viviano, 2025). The fixes are well-documented with clear references to paper equations and algorithms. The changes are methodologically correct and improve the estimator's alignment with the paper specification. Test coverage is comprehensive.


Part 1: Methodology Review

Issue A: Control Set Definition (Lines 959-1070 in trop.py)

Paper Reference: Equation 2 (page 7)

Change: The control set for weight computation now includes pre-treatment observations of eventually-treated units, not just never-treated units.

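Assessment: Correct. The paper's objective sums over all observations for which (1 - W_js) is non-zero, where W_js is the treatment indicator (D in the code, not the observation-weight matrix W discussed under Issue C). Since W_js = 0 in the pre-treatment periods of eventually-treated units, those observations belong in the control set.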

Code Review:

# Valid control units at time t: D[t, j] == 0
valid_control_at_t = D_stored[t, :] == 0

This correctly identifies units not treated at time t, including pre-treatment periods of eventually-treated units.
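To make the difference between the two control-set rules concrete, here is a minimal sketch. The toy matrix and the helper names `never_treated_mask` and `control_mask_at` are hypothetical and exist only for illustration; only the `D[t, :] == 0` comparison comes from the PR.

```python
import numpy as np

# Hypothetical treatment matrix D[t, j]: 4 periods x 3 units.
# Unit 0 is never treated, unit 1 adopts at t=2, unit 2 adopts at t=3.
D = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0, 1, 0],
    [0, 1, 1],
])

def never_treated_mask(D):
    """Old rule: only units that are never treated may serve as controls."""
    return D.sum(axis=0) == 0

def control_mask_at(D, t):
    """New rule: any unit with D[t, j] == 0 is a valid control at period t."""
    return D[t, :] == 0

old_mask = never_treated_mask(D)      # selects unit 0 only
new_mask_t2 = control_mask_at(D, 2)   # selects units 0 and 2
```

At t=2, unit 2 has not adopted yet, so the new rule admits it as a control while the never-treated rule discards it.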

Issue B: Distance Computation Excludes Target Period (Lines 615-660 in trop.py)

Paper Reference: Equation 3 (page 7) specifies 1{u != t} indicator

Change: _compute_unit_distance_for_obs() now excludes the target period when computing pairwise distances.

Assessment: Correct. The indicator function 1{u != t} in Equation 3 explicitly excludes the target period from the distance computation to avoid contamination.

Code Review:

valid = np.ones(n_periods, dtype=bool)
valid[target_period] = False  # Exclude target period

Properly implements the paper specification.
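A standalone sketch of a distance that honors the 1{u != t} indicator might look as follows. It assumes the distance is a root-mean-square gap over periods in which both units are untreated, and that pairs with no usable periods get an infinite distance; both are assumptions, and `unit_distance_excluding_target` is a hypothetical name, not the method in trop.py.

```python
import numpy as np

def unit_distance_excluding_target(Y, D, j, i, target_period):
    """RMS outcome gap between units i and j, skipping the target period.

    Y and D are (n_periods, n_units) arrays of outcomes and treatment flags.
    """
    n_periods = Y.shape[0]
    valid = np.ones(n_periods, dtype=bool)
    valid[target_period] = False                # the 1{u != t} indicator
    valid &= (D[:, i] == 0) & (D[:, j] == 0)    # assumed: both units untreated
    if not valid.any():
        return np.inf                           # assumed convention: no overlap
    diff = Y[valid, i] - Y[valid, j]
    return float(np.sqrt(np.mean(diff ** 2)))

# Two units that agree everywhere except in the last period.
Y = np.array([[1.0, 1.0],
              [2.0, 2.0],
              [3.0, 9.0]])
D = np.zeros_like(Y, dtype=int)

d_excl = unit_distance_excluding_target(Y, D, j=1, i=0, target_period=2)
d_other = unit_distance_excluding_target(Y, D, j=1, i=0, target_period=0)
```

With the last period as target, the gap in that period is ignored and the distance is zero; with a different target period, the same gap enters the distance.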

Issue C: Weighted Nuclear Norm Solver (Lines 1133-1226 in trop.py)

Paper Reference: Equation 2 (page 7)

Change: New _weighted_nuclear_norm_solve() method implements proper weighted proximal gradient descent for the nuclear norm optimization.

Assessment: Correct. The previous implementation used unweighted soft-thresholding. The paper's objective includes observation weights in the squared loss term, requiring a weighted soft-impute approach.

Algorithm Implementation:

# Proximal gradient iteration: L_{k+1} = prox_{lambda}(L_k + W * (R - L_k))
gradient_step = L + W_norm * (R_masked - L)
L = self._soft_threshold_svd(gradient_step, lambda_nn)

This follows the standard proximal gradient algorithm for weighted matrix completion (Mazumder et al. 2010).

Minor Concern: The step size is implicitly 1 with normalized weights. The Rust implementation computes eta = 1/max(W) for convergence guarantees. The Python implementation normalizes weights instead. Both approaches are valid but differ slightly.
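The relationship between the two step-size conventions can be checked with a small sketch. The functions below are a generic weighted proximal-gradient loop written for illustration, assuming the objective is 0.5 * sum(W * (R - L)^2) + lambda_nn * ||L||_*; they are not the code from trop.py or trop.rs, and the names and stopping rule are hypothetical.

```python
import numpy as np

def soft_threshold_svd(M, tau):
    """Proximal operator of tau * nuclear norm: shrink singular values by tau."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return (U * np.maximum(s - tau, 0.0)) @ Vt

def weighted_nuclear_norm_solve(R, W, lambda_nn, n_iter=200, tol=1e-8):
    """Proximal gradient with step size eta = 1 / max(W); assumes max(W) > 0."""
    eta = 1.0 / W.max()
    L = np.zeros_like(R)
    for _ in range(n_iter):
        # Gradient step on the weighted squared loss; cells with W == 0 stay put.
        step = L + eta * W * (R - L)
        L_new = soft_threshold_svd(step, eta * lambda_nn)
        if np.linalg.norm(L_new - L) <= tol * (1.0 + np.linalg.norm(L)):
            return L_new
        L = L_new
    return L

rng = np.random.default_rng(0)
R = rng.normal(size=(6, 4))
W = rng.uniform(0.1, 2.0, size=(6, 4))
W[0, 0] = 0.0   # a treated cell carries no weight in the fit
lam = 0.5

# Explicit step size versus pre-normalized weights with a rescaled penalty.
L_a = weighted_nuclear_norm_solve(R, W, lam)
L_b = weighted_nuclear_norm_solve(R, W / W.max(), lam / W.max())
```

In this sketch the two runs coincide only because the penalty is divided by max(W) along with the weights. If weights are normalized while lambda_nn is left untouched, the effective penalty differs, which may be the source of the slight difference noted above.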

Issue D: Stratified Bootstrap (Lines 1544-1578 in trop.py)

Paper Reference: Algorithm 3 (page 27)

Change: Bootstrap now samples control and treated units separately (stratified sampling) to preserve the treatment ratio.

Assessment: Correct. The paper specifies sampling N_0 control rows and N_1 treated rows separately.

Code Review:

sampled_control = rng.choice(control_units, size=n_control_units, replace=True)
sampled_treated = rng.choice(treated_units, size=n_treated_units, replace=True)
sampled_units = np.concatenate([sampled_control, sampled_treated])

Properly implements Algorithm 3's stratified sampling.
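The effect of stratification is easy to verify in isolation. The wrapper `stratified_bootstrap_units` and the unit indices below are hypothetical; the three sampling lines mirror the snippet quoted above.

```python
import numpy as np

def stratified_bootstrap_units(control_units, treated_units, rng):
    """Resample each group with replacement, keeping both group sizes fixed."""
    sampled_control = rng.choice(control_units, size=len(control_units), replace=True)
    sampled_treated = rng.choice(treated_units, size=len(treated_units), replace=True)
    return np.concatenate([sampled_control, sampled_treated])

rng = np.random.default_rng(42)
control_units = np.array([0, 1, 2, 3, 4, 5, 6])   # N_0 = 7
treated_units = np.array([7, 8, 9])               # N_1 = 3

draw = stratified_bootstrap_units(control_units, treated_units, rng)
n_treated_drawn = int(np.isin(draw, treated_units).sum())
```

Every draw contains exactly N_1 treated units, whereas pooled resampling over all ten units could return a draw with no treated unit at all.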

Rust Backend Consistency

The Rust implementation in rust/src/trop.rs mirrors all four fixes:

  • Issue A+B: compute_weight_matrix() now receives y and d parameters for dynamic control sets and per-observation distances
  • Issue B: compute_unit_distance_for_obs() excludes target period
  • Issue C: estimate_model() uses weighted proximal gradient with proper step size eta = 1/max(W)
  • Issue D: bootstrap_trop_variance() implements stratified sampling

Assessment: Both Python and Rust implementations are consistent.


Part 2: Issues Found

Critical Issues

None identified.

Medium Issues

  1. Potential Performance Regression in Distance Computation (trop.py:1020-1030)

    The fix for Issue A+B now computes per-observation distances in a loop:

    for j in range(n_units):
        if valid_control_at_t[j] and j != i:
            dist = self._compute_unit_distance_for_obs(Y_stored, D_stored, j, i, t)

    This costs O(n_units * n_periods) per treated observation, which may be slower than a lookup in a pre-computed distance matrix. Consider whether the pre-computed unit_dist_matrix could still serve as an approximation, for example when lambda_unit is positive but small.

  2. Documentation Mismatch in Rust Backend (rust/src/trop.rs:413-415)

    The comments mark these parameters as "not used", yet they are still part of the function signature:

    _control_units: &[usize],  // Kept for API compatibility but not used
    _unit_dist: &ArrayView2<f64>,  // Not used - we compute per-observation distances

    These parameters should be removed if truly unused, or the comments should be updated if they serve a purpose.
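For the per-observation distance loop in medium issue 1, one option is to compute the distances from unit i to every other unit in a single broadcasted pass. This is a sketch under the same assumptions as before (RMS gap over periods where both units are untreated, infinite distance when no period qualifies); `distances_to_unit_excluding_target` is a hypothetical helper, not part of the PR.

```python
import numpy as np

def distances_to_unit_excluding_target(Y, D, i, t):
    """Distance from unit i to every unit, skipping period t.

    Y and D are (n_periods, n_units) arrays of outcomes and treatment flags.
    """
    n_periods = Y.shape[0]
    valid = np.ones(n_periods, dtype=bool)
    valid[t] = False                                   # exclude the target period
    # both[u, j] is True when period u is usable for the pair (i, j).
    both = valid[:, None] & (D[:, [i]] == 0) & (D == 0)
    sq = np.where(both, (Y - Y[:, [i]]) ** 2, 0.0)
    counts = both.sum(axis=0)
    with np.errstate(divide="ignore", invalid="ignore"):
        dist = np.sqrt(sq.sum(axis=0) / counts)
    dist[counts == 0] = np.inf                         # no usable periods
    return dist

Y = np.array([[1.0, 1.0, 2.0],
              [2.0, 2.0, 4.0],
              [3.0, 9.0, 6.0]])
D = np.array([[0, 0, 0],
              [0, 0, 1],
              [0, 0, 1]])   # unit 2 is treated from period 1 onward

dist_t2 = distances_to_unit_excluding_target(Y, D, i=0, t=2)
dist_t0 = distances_to_unit_excluding_target(Y, D, i=0, t=0)
```

The asymptotic cost per treated observation is still O(n_units * n_periods); the sketch only removes the Python-level loop over j.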

Minor Issues

  1. Tutorial Notebook Formatting (10_trop.ipynb)

    Some markdown cells were collapsed into single-line strings, which is valid but harder to edit manually. This appears to be a Jupyter notebook saving artifact.

  2. Test Assertion Message (test_trop.py:1279)

    assert results.att > 0, f"ATT={results.att:.3f} should be positive"

    Consider adding the expected value for better debugging: f"ATT={results.att:.3f} should be positive (true_att={true_att})"


Part 3: Security Assessment

No security issues identified. This PR modifies statistical estimation code without any:

  • User input handling changes
  • File I/O changes
  • Network operations
  • Dynamic code execution

Part 4: Documentation Assessment

Strengths:

  • Extensive inline comments explaining paper references (e.g., "Following Equation 2 (page 7)")
  • Docstrings updated with Issue A/B/C/D fix explanations
  • Tutorial notebook updated with clarifications on:
    • Distance computation methodology
    • Bootstrap variance estimation methodology
    • Jackknife being an implementation convenience, not paper-specified

Gaps:

  • CHANGELOG.md should be updated with these fixes
  • Consider adding a "Paper Conformance" section to the TROP API documentation

Part 5: Performance Assessment

Potential Concerns:

  1. Per-Observation Distance Computation: The fix for Issues A+B requires computing distances that exclude the target period for each treated observation. This changes the cost from a one-time computation over O(n_units^2) unit pairs to O(n_treated_obs * n_units * n_periods) work overall.

  2. Mitigation: The Rust backend parallelizes this computation, and the pre-computed distance matrix is still used as a cache when available.

Recommendation: Run benchmarks on larger datasets to quantify any performance impact. The correctness improvement likely justifies moderate slowdown.


Part 6: Maintainability Assessment

Strengths:

  • Clean separation of concerns: each fix is isolated to specific methods
  • TypedDict updated with new fields (D, Y in _PrecomputedStructures)
  • Tests are well-organized in TestPaperConformanceFixes class
  • Both Python and Rust implementations updated consistently

No Breaking Changes: The public API remains unchanged. Internal method signatures were modified but these are private (_compute_observation_weights, _weighted_nuclear_norm_solve).


Recommendations

Must Fix (before merge)

None - this PR is ready for merge.

Should Fix

  1. Update CHANGELOG.md with v2.x.x entry documenting these conformance fixes
  2. Clean up unused parameters in Rust compute_weight_matrix() or update comments

Nice to Have

  1. Add performance benchmarks comparing before/after this PR
  2. Expand test coverage for edge cases with staggered adoption timing
  3. Add "Paper Conformance" section to API docs referencing these fixes

Final Assessment

| Category | Rating | Notes |
|---|---|---|
| Methodology | | Fixes align with paper specification |
| Code Quality | | Clean implementation, consistent Python/Rust |
| Security | | No security concerns |
| Documentation | | Good inline docs, tutorial updated |
| Performance | ⚠️ | Potential regression, mitigated by Rust backend |
| Maintainability | | No breaking changes, good test coverage |

Overall Verdict: Approved

This PR correctly addresses four methodological discrepancies between the implementation and the TROP paper. The fixes are well-documented, thoroughly tested (292 new test lines), and consistently implemented in both Python and Rust backends. The changes improve statistical correctness and should be merged.


Review generated by Claude Code

Address code review feedback to remove unused API parameters:

Rust backend (trop.rs):
- Remove control_unit_idx and unit_dist_matrix from loocv_grid_search
- Remove control_unit_idx, treated_obs_t/i, unit_dist_matrix from
  bootstrap_trop_variance
- Remove unit_dist_boot computation in bootstrap (no longer needed)
- Remove control_units and unit_dist from internal functions

Python (trop.py):
- Update _rust_loocv_grid_search call to use new signature
- Update _rust_bootstrap_trop_variance call to use new signature
- Remove unused variable preparation for removed parameters

Tests (test_rust_backend.py):
- Update test calls to use new API signatures
- Remove unused variable assignments

The precomputed unit_dist_matrix is no longer needed by the Rust backend
since per-observation distances are computed dynamically to properly
exclude the target period per Equation 3 of the paper.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@igerber igerber merged commit 77e2c65 into main Jan 19, 2026
4 checks passed
@igerber igerber deleted the fix/trop-paper-conformance branch January 19, 2026 18:42