Skip to content

Commit c1a65bf

Browse files
igerberclaude
andcommitted
Address AI review findings for TROP split
P1: Restore absorbing-state remediation guidance in global path error message P2: Extract _soft_threshold_svd to module-level function, eliminating MRO dependency between TROPGlobalMixin and TROPLocalMixin P2: Add 5 regression tests for module split (dispatch, finite lambda_nn, error message consistency) P3: Remove unused unit_to_idx/period_to_idx in _fit_with_fixed_lambda Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c41e7f1 commit c1a65bf

3 files changed

Lines changed: 171 additions & 65 deletions

File tree

diff_diff/trop_global.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_rust_bootstrap_trop_variance_global,
2525
_rust_loocv_grid_search_global,
2626
)
27+
from diff_diff.trop_local import _soft_threshold_svd
2728
from diff_diff.trop_results import TROPResults
2829
from diff_diff.utils import safe_inference
2930

@@ -39,8 +40,6 @@ class TROPGlobalMixin:
3940
- Inference params: ``alpha``, ``n_bootstrap``, ``seed``
4041
- State: ``results_``, ``is_fitted_``
4142
42-
The ``_solve_global_with_lowrank`` method calls ``self._soft_threshold_svd``
43-
which is defined in ``TROPLocalMixin`` and resolved via Python MRO.
4443
"""
4544

4645
# Type hints for attributes accessed from the main TROP class
@@ -395,9 +394,6 @@ def _solve_global_with_lowrank(
395394
The (1-W) masking is already applied to delta by _compute_global_weights,
396395
so treated observations have zero weight and do not affect the fit.
397396
398-
Note: calls ``self._soft_threshold_svd`` which is defined in
399-
``TROPLocalMixin`` and resolved via Python MRO on the ``TROP`` class.
400-
401397
Parameters
402398
----------
403399
Y : np.ndarray
@@ -468,7 +464,7 @@ def _solve_global_with_lowrank(
468464

469465
# Proximal step: soft-threshold singular values
470466
L_inner_prev = L_inner
471-
L_inner = self._soft_threshold_svd(gradient_step, threshold)
467+
L_inner = _soft_threshold_svd(gradient_step, threshold)
472468
t_fista = t_fista_new
473469

474470
# Convergence check (L_inner_prev holds the pre-SVD value)
@@ -564,7 +560,9 @@ def _fit_global(
564560
if violating_units:
565561
raise ValueError(
566562
f"Treatment indicator is not an absorbing state for units: {violating_units}. "
567-
f"D[t, unit] must be monotonic non-decreasing."
563+
f"D[t, unit] must be monotonic non-decreasing (once treated, always treated). "
564+
f"If this is event-study style data, convert to absorbing state: "
565+
f"D[t, i] = 1 for all t >= first treatment period."
568566
)
569567

570568
# Identify treated observations

diff_diff/trop_local.py

Lines changed: 72 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,76 @@
2727
from diff_diff.trop_results import _PrecomputedStructures
2828

2929

30+
# Module-level convergence tolerance for SVD singular value truncation.
31+
# Singular values below this threshold after soft-thresholding are treated
32+
# as zero to improve numerical stability.
33+
_CONVERGENCE_TOL_SVD: float = 1e-10
34+
35+
36+
def _soft_threshold_svd(
37+
M: np.ndarray,
38+
threshold: float,
39+
convergence_tol: float = _CONVERGENCE_TOL_SVD,
40+
) -> np.ndarray:
41+
"""
42+
Apply soft-thresholding to singular values (proximal operator for nuclear norm).
43+
44+
Parameters
45+
----------
46+
M : np.ndarray
47+
Input matrix.
48+
threshold : float
49+
Soft-thresholding parameter.
50+
convergence_tol : float, default=1e-10
51+
Singular values below this after thresholding are treated as zero.
52+
53+
Returns
54+
-------
55+
np.ndarray
56+
Matrix with soft-thresholded singular values.
57+
"""
58+
if threshold <= 0:
59+
return M
60+
61+
# Handle NaN/Inf values in input
62+
if not np.isfinite(M).all():
63+
M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)
64+
65+
try:
66+
U, s, Vt = np.linalg.svd(M, full_matrices=False)
67+
except np.linalg.LinAlgError:
68+
# SVD failed, return zero matrix
69+
return np.zeros_like(M)
70+
71+
# Check for numerical issues in SVD output
72+
if not (np.isfinite(U).all() and np.isfinite(s).all() and np.isfinite(Vt).all()):
73+
# SVD produced non-finite values, return zero matrix
74+
return np.zeros_like(M)
75+
76+
s_thresh = np.maximum(s - threshold, 0)
77+
78+
# Use truncated reconstruction with only non-zero singular values
79+
nonzero_mask = s_thresh > convergence_tol
80+
if not np.any(nonzero_mask):
81+
return np.zeros_like(M)
82+
83+
# Truncate to non-zero components for numerical stability
84+
U_trunc = U[:, nonzero_mask]
85+
s_trunc = s_thresh[nonzero_mask]
86+
Vt_trunc = Vt[nonzero_mask, :]
87+
88+
# Compute result, suppressing expected numerical warnings from
89+
# ill-conditioned matrices during alternating minimization
90+
with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
91+
result = (U_trunc * s_trunc) @ Vt_trunc
92+
93+
# Replace any NaN/Inf in result with zeros
94+
if not np.isfinite(result).all():
95+
result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
96+
97+
return result
98+
99+
30100
class TROPLocalMixin:
31101
"""Mixin providing local (observation-specific) estimation for TROP.
32102
@@ -378,61 +448,8 @@ def _soft_threshold_svd(
378448
M: np.ndarray,
379449
threshold: float,
380450
) -> np.ndarray:
381-
"""
382-
Apply soft-thresholding to singular values (proximal operator for nuclear norm).
383-
384-
Parameters
385-
----------
386-
M : np.ndarray
387-
Input matrix.
388-
threshold : float
389-
Soft-thresholding parameter.
390-
391-
Returns
392-
-------
393-
np.ndarray
394-
Matrix with soft-thresholded singular values.
395-
"""
396-
if threshold <= 0:
397-
return M
398-
399-
# Handle NaN/Inf values in input
400-
if not np.isfinite(M).all():
401-
M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)
402-
403-
try:
404-
U, s, Vt = np.linalg.svd(M, full_matrices=False)
405-
except np.linalg.LinAlgError:
406-
# SVD failed, return zero matrix
407-
return np.zeros_like(M)
408-
409-
# Check for numerical issues in SVD output
410-
if not (np.isfinite(U).all() and np.isfinite(s).all() and np.isfinite(Vt).all()):
411-
# SVD produced non-finite values, return zero matrix
412-
return np.zeros_like(M)
413-
414-
s_thresh = np.maximum(s - threshold, 0)
415-
416-
# Use truncated reconstruction with only non-zero singular values
417-
nonzero_mask = s_thresh > self.CONVERGENCE_TOL_SVD
418-
if not np.any(nonzero_mask):
419-
return np.zeros_like(M)
420-
421-
# Truncate to non-zero components for numerical stability
422-
U_trunc = U[:, nonzero_mask]
423-
s_trunc = s_thresh[nonzero_mask]
424-
Vt_trunc = Vt[nonzero_mask, :]
425-
426-
# Compute result, suppressing expected numerical warnings from
427-
# ill-conditioned matrices during alternating minimization
428-
with np.errstate(divide='ignore', over='ignore', invalid='ignore'):
429-
result = (U_trunc * s_trunc) @ Vt_trunc
430-
431-
# Replace any NaN/Inf in result with zeros
432-
if not np.isfinite(result).all():
433-
result = np.nan_to_num(result, nan=0.0, posinf=0.0, neginf=0.0)
434-
435-
return result
451+
"""Delegate to module-level ``_soft_threshold_svd``."""
452+
return _soft_threshold_svd(M, threshold, self.CONVERGENCE_TOL_SVD)
436453

437454
def _weighted_nuclear_norm_solve(
438455
self,
@@ -948,9 +965,6 @@ def _fit_with_fixed_lambda(
948965
n_units = len(all_units)
949966
n_periods = len(all_periods)
950967

951-
unit_to_idx = {u: i for i, u in enumerate(all_units)}
952-
period_to_idx = {p: i for i, p in enumerate(all_periods)}
953-
954968
# Vectorized: use pivot for O(1) reshaping instead of O(n) iterrows loop
955969
Y = (
956970
data.pivot(index=time, columns=unit, values=outcome)

tests/test_trop.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3662,3 +3662,97 @@ def test_local_bootstrap_zero_draws_returns_nan_se(self):
36623662
assert np.isnan(se), f"SE should be NaN when 0 draws succeed, got {se}"
36633663
assert len(dist) == 0
36643664

3665+
3666+
class TestTROPModuleSplit:
3667+
"""Regression tests for the trop.py -> trop_global.py / trop_local.py split."""
3668+
3669+
@staticmethod
3670+
def _make_panel():
3671+
"""Create a simple balanced panel for split regression tests."""
3672+
rng = np.random.default_rng(42)
3673+
n_units, n_periods = 8, 6
3674+
rows = []
3675+
for i in range(n_units):
3676+
treated = i < 3 # 3 treated, 5 control
3677+
for t in range(n_periods):
3678+
y = rng.normal(0, 1)
3679+
if treated and t >= 4:
3680+
y += 2.0 # treatment effect
3681+
rows.append({
3682+
"unit": i, "time": t, "outcome": y,
3683+
"treated": 1 if treated and t >= 4 else 0,
3684+
})
3685+
return pd.DataFrame(rows)
3686+
3687+
def test_global_absorbing_state_error_has_remediation_guidance(self):
3688+
"""Global path ValueError for non-absorbing D includes remediation text."""
3689+
df = self._make_panel()
3690+
# Break absorbing state: unit 0 goes 0->1->0
3691+
df.loc[(df["unit"] == 0) & (df["time"] == 5), "treated"] = 0
3692+
3693+
with pytest.raises(ValueError, match="once treated, always treated"):
3694+
TROP(method="global").fit(df, "outcome", "treated", "unit", "time")
3695+
3696+
with pytest.raises(ValueError, match="convert to absorbing state"):
3697+
TROP(method="global").fit(df, "outcome", "treated", "unit", "time")
3698+
3699+
def test_global_finite_lambda_nn_exercises_lowrank_path(self):
3700+
"""method='global' with finite lambda_nn successfully fits the low-rank solver."""
3701+
df = self._make_panel()
3702+
trop_est = TROP(
3703+
method="global",
3704+
lambda_time_grid=[0.0],
3705+
lambda_unit_grid=[0.0],
3706+
lambda_nn_grid=[0.1], # finite -> exercises _solve_global_with_lowrank
3707+
n_bootstrap=5,
3708+
seed=42,
3709+
)
3710+
with warnings.catch_warnings():
3711+
warnings.simplefilter("ignore")
3712+
result = trop_est.fit(df, "outcome", "treated", "unit", "time")
3713+
assert np.isfinite(result.att)
3714+
3715+
def test_local_finite_lambda_nn_exercises_nuclear_norm(self):
3716+
"""method='local' with finite lambda_nn exercises weighted nuclear norm solver."""
3717+
df = self._make_panel()
3718+
trop_est = TROP(
3719+
method="local",
3720+
lambda_time_grid=[0.0],
3721+
lambda_unit_grid=[0.0],
3722+
lambda_nn_grid=[0.1], # finite -> exercises _weighted_nuclear_norm_solve
3723+
n_bootstrap=5,
3724+
seed=42,
3725+
)
3726+
with warnings.catch_warnings():
3727+
warnings.simplefilter("ignore")
3728+
result = trop_est.fit(df, "outcome", "treated", "unit", "time")
3729+
assert np.isfinite(result.att)
3730+
3731+
def test_method_dispatch_global_uses_fit_global(self):
3732+
"""method='global' dispatches to _fit_global from TROPGlobalMixin."""
3733+
from unittest.mock import patch
3734+
3735+
df = self._make_panel()
3736+
trop_est = TROP(method="global", n_bootstrap=2, seed=42)
3737+
3738+
with patch.object(TROP, '_fit_global', wraps=trop_est._fit_global) as mock_fg:
3739+
with warnings.catch_warnings():
3740+
warnings.simplefilter("ignore")
3741+
trop_est.fit(df, "outcome", "treated", "unit", "time")
3742+
mock_fg.assert_called_once()
3743+
3744+
def test_method_dispatch_local_does_not_use_fit_global(self):
3745+
"""method='local' does NOT call _fit_global."""
3746+
from unittest.mock import patch
3747+
3748+
df = self._make_panel()
3749+
trop_est = TROP(method="local", n_bootstrap=2, seed=42,
3750+
lambda_time_grid=[0.0], lambda_unit_grid=[0.0],
3751+
lambda_nn_grid=[np.inf])
3752+
3753+
with patch.object(TROP, '_fit_global') as mock_fg:
3754+
with warnings.catch_warnings():
3755+
warnings.simplefilter("ignore")
3756+
trop_est.fit(df, "outcome", "treated", "unit", "time")
3757+
mock_fg.assert_not_called()
3758+

0 commit comments

Comments
 (0)