Skip to content

Commit 7e41d76

Browse files
igerberclaude
andcommitted
Fix TROP joint method Rust/Python parity issues (PR #113 feedback)
Address P1 review feedback: - P1-2: Align nuclear-norm threshold scaling by using eta * lambda_nn for soft-threshold SVD step in Python (matching Rust implementation) - P1-1: Add comprehensive NaN handling in _compute_joint_weights, _solve_joint_no_lowrank, and _solve_joint_with_lowrank Add tests for NaN handling parity between backends. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent b519f67 commit 7e41d76

3 files changed

Lines changed: 158 additions & 9 deletions

File tree

diff_diff/trop.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -910,23 +910,32 @@ def _compute_joint_weights(
910910
delta_time = np.exp(-lambda_time * dist_time)
911911

912912
# Unit weights: RMSE to average treated trajectory over pre-periods
913-
# Compute average treated trajectory
914-
average_treated = np.mean(Y[:, treated_unit_idx], axis=1)
913+
# Compute average treated trajectory (use nanmean to handle NaN)
914+
average_treated = np.nanmean(Y[:, treated_unit_idx], axis=1)
915915

916916
# Pre-period mask: 1 in pre, 0 in post
917917
pre_mask = np.ones(n_periods, dtype=float)
918918
pre_mask[-treated_periods:] = 0.0
919919

920920
# Compute RMS distance for each unit
921921
# dist_unit[i] = sqrt(sum_pre(avg_tr - Y_i)^2 / n_pre)
922-
diff_sq = ((average_treated[:, np.newaxis] - Y) ** 2) * pre_mask[:, np.newaxis]
922+
# Use NaN-safe operations: treat NaN differences as 0 (excluded)
923+
diff = average_treated[:, np.newaxis] - Y
924+
diff_sq = np.where(np.isfinite(diff), diff ** 2, 0.0) * pre_mask[:, np.newaxis]
925+
926+
# Count valid observations per unit in pre-period
927+
valid_count = np.sum(
928+
np.isfinite(Y) * pre_mask[:, np.newaxis], axis=0
929+
)
923930
sum_sq = np.sum(diff_sq, axis=0)
924931
n_pre = np.sum(pre_mask)
925932

926933
if n_pre == 0:
927934
raise ValueError("No pre-treatment periods")
928935

929-
dist_unit = np.sqrt(sum_sq / n_pre)
936+
# Use valid count per unit (avoid division by zero)
937+
valid_count = np.maximum(valid_count, 1)
938+
dist_unit = np.sqrt(sum_sq / valid_count)
930939
delta_unit = np.exp(-lambda_unit * dist_unit)
931940

932941
# Outer product: (n_periods x n_units)
@@ -1050,6 +1059,15 @@ def _solve_joint_no_lowrank(
10501059
y = Y.flatten() # length n_periods * n_units
10511060
w = D.flatten()
10521061
weights = delta.flatten()
1062+
1063+
# Handle NaN values: zero weight for NaN outcomes/weights, impute with 0
1064+
# This ensures NaN observations don't contribute to estimation
1065+
valid_y = np.isfinite(y)
1066+
valid_w = np.isfinite(weights)
1067+
valid_mask = valid_y & valid_w
1068+
weights = np.where(valid_mask, weights, 0.0)
1069+
y = np.where(valid_mask, y, 0.0)
1070+
10531071
sqrt_weights = np.sqrt(np.maximum(weights, 0))
10541072

10551073
# Build design matrix: [intercept, unit_dummies, time_dummies, treatment]
@@ -1132,20 +1150,24 @@ def _solve_joint_with_lowrank(
11321150
"""
11331151
n_periods, n_units = Y.shape
11341152

1153+
# Handle NaN values: impute with 0 for computations
1154+
# The solver will also zero weights for NaN observations
1155+
Y_safe = np.where(np.isfinite(Y), Y, 0.0)
1156+
11351157
# Initialize L = 0
11361158
L = np.zeros((n_periods, n_units))
11371159

11381160
for iteration in range(max_iter):
11391161
L_old = L.copy()
11401162

11411163
# Step 1: Fix L, solve for (mu, alpha, beta, tau)
1142-
# Adjusted outcome: Y - L
1143-
Y_adj = Y - L
1164+
# Adjusted outcome: Y - L (using NaN-safe Y)
1165+
Y_adj = Y_safe - L
11441166
mu, alpha, beta, tau = self._solve_joint_no_lowrank(Y_adj, D, delta)
11451167

11461168
# Step 2: Fix (mu, alpha, beta, tau), update L
1147-
# Residual: R = Y - mu - alpha - beta - tau*D
1148-
R = Y - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - tau * D
1169+
# Residual: R = Y - mu - alpha - beta - tau*D (using NaN-safe Y)
1170+
R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] - tau * D
11491171

11501172
# Weighted proximal step for L (soft-threshold SVD)
11511173
# Normalize weights
@@ -1160,7 +1182,9 @@ def _solve_joint_with_lowrank(
11601182
gradient_step = L + delta_norm * (R - L)
11611183

11621184
# Soft-threshold singular values
1163-
L = self._soft_threshold_svd(gradient_step, lambda_nn)
1185+
# Use eta * lambda_nn for proper proximal step size (matches Rust)
1186+
eta = 1.0 / delta_max if delta_max > 0 else 1.0
1187+
L = self._soft_threshold_svd(gradient_step, eta * lambda_nn)
11641188

11651189
# Check convergence
11661190
if np.max(np.abs(L - L_old)) < tol:

tests/test_rust_backend.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,65 @@ def test_trop_joint_and_twostep_agree_in_direction(self):
13641364
# Both should have same sign (both positive for true_effect=2.0)
13651365
assert np.sign(results_joint.att) == np.sign(results_twostep.att)
13661366

1367+
def test_trop_joint_handles_nan_outcomes(self):
1368+
"""Test TROP joint method handles NaN outcome values gracefully."""
1369+
import pandas as pd
1370+
from diff_diff import TROP
1371+
1372+
np.random.seed(42)
1373+
n_units, n_periods = 20, 10
1374+
n_treated = 5
1375+
n_post = 3
1376+
true_effect = 2.0
1377+
1378+
data = []
1379+
for i in range(n_units):
1380+
is_treated = i < n_treated
1381+
for t in range(n_periods):
1382+
post = t >= (n_periods - n_post)
1383+
y = 10.0 + i * 0.2 + t * 0.3 + np.random.randn() * 0.5
1384+
treatment_indicator = 1 if (is_treated and post) else 0
1385+
if treatment_indicator:
1386+
y += true_effect
1387+
data.append({
1388+
'unit': i,
1389+
'time': t,
1390+
'outcome': y,
1391+
'treated': treatment_indicator,
1392+
})
1393+
1394+
df = pd.DataFrame(data)
1395+
1396+
# Introduce NaN values in control observations (pre-treatment periods)
1397+
# Set 5% of control pre-treatment observations to NaN
1398+
nan_indices = []
1399+
for idx, row in df.iterrows():
1400+
if row['treated'] == 0 and row['time'] < (n_periods - n_post):
1401+
if np.random.rand() < 0.05:
1402+
nan_indices.append(idx)
1403+
df.loc[nan_indices, 'outcome'] = np.nan
1404+
1405+
n_nan = len(nan_indices)
1406+
assert n_nan > 0, "Should have introduced some NaN values"
1407+
1408+
trop = TROP(
1409+
method="joint",
1410+
lambda_time_grid=[0.0, 1.0],
1411+
lambda_unit_grid=[0.0, 1.0],
1412+
lambda_nn_grid=[0.0, 0.1],
1413+
n_bootstrap=20,
1414+
seed=42
1415+
)
1416+
results = trop.fit(df, 'outcome', 'treated', 'unit', 'time')
1417+
1418+
# Results should be finite (NaN observations are excluded)
1419+
assert np.isfinite(results.att), f"ATT {results.att} should be finite with NaN data"
1420+
assert np.isfinite(results.se), f"SE {results.se} should be finite with NaN data"
1421+
assert results.se >= 0, "SE should be non-negative"
1422+
1423+
# ATT should still be positive (true effect is positive)
1424+
assert results.att > 0, f"ATT {results.att:.2f} should be positive"
1425+
13671426

13681427
class TestFallbackWhenNoRust:
13691428
"""Test that pure Python fallback works when Rust is unavailable."""

tests/test_trop.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2948,3 +2948,69 @@ def test_joint_loocv_score_internal(self, simple_panel_data):
29482948
treated_periods, n_units, n_periods
29492949
)
29502950
assert np.isfinite(score2) or np.isinf(score2), "Score should be finite or inf"
2951+
2952+
def test_joint_handles_nan_outcomes(self, simple_panel_data):
2953+
"""Joint method handles NaN outcome values gracefully."""
2954+
# Introduce NaN in some control observations
2955+
data = simple_panel_data.copy()
2956+
control_mask = data['treated'] == 0
2957+
control_indices = data[control_mask].index.tolist()
2958+
2959+
# Set 5 random control observations to NaN
2960+
np.random.seed(42)
2961+
nan_indices = np.random.choice(control_indices, size=5, replace=False)
2962+
data.loc[nan_indices, 'outcome'] = np.nan
2963+
2964+
trop_est = TROP(
2965+
method="joint",
2966+
lambda_time_grid=[0.0, 1.0],
2967+
lambda_unit_grid=[0.0, 1.0],
2968+
lambda_nn_grid=[0.0, 0.1],
2969+
n_bootstrap=10,
2970+
seed=42,
2971+
)
2972+
results = trop_est.fit(
2973+
data,
2974+
outcome="outcome",
2975+
treatment="treated",
2976+
unit="unit",
2977+
time="period",
2978+
)
2979+
2980+
# Results should be finite (NaN observations excluded)
2981+
assert np.isfinite(results.att), "ATT should be finite with NaN data"
2982+
assert np.isfinite(results.se), "SE should be finite with NaN data"
2983+
# ATT should be positive (true effect is 3.0)
2984+
assert results.att > 0, "ATT should be positive"
2985+
2986+
def test_joint_with_lowrank_handles_nan(self, simple_panel_data):
2987+
"""Joint method with low-rank handles NaN values correctly."""
2988+
# Introduce NaN in some control observations
2989+
data = simple_panel_data.copy()
2990+
control_mask = data['treated'] == 0
2991+
control_indices = data[control_mask].index.tolist()
2992+
2993+
# Set 3 random control observations to NaN
2994+
np.random.seed(123)
2995+
nan_indices = np.random.choice(control_indices, size=3, replace=False)
2996+
data.loc[nan_indices, 'outcome'] = np.nan
2997+
2998+
trop_est = TROP(
2999+
method="joint",
3000+
lambda_time_grid=[0.0],
3001+
lambda_unit_grid=[0.0],
3002+
lambda_nn_grid=[0.1], # Finite lambda_nn enables low-rank
3003+
n_bootstrap=10,
3004+
seed=42,
3005+
)
3006+
results = trop_est.fit(
3007+
data,
3008+
outcome="outcome",
3009+
treatment="treated",
3010+
unit="unit",
3011+
time="period",
3012+
)
3013+
3014+
# Results should be finite
3015+
assert np.isfinite(results.att), "ATT should be finite with NaN data"
3016+
assert np.isfinite(results.se), "SE should be finite with NaN data"

0 commit comments

Comments
 (0)