Skip to content

Commit a7a115d

Browse files
igerberclaude
andcommitted
Address PR #113 Round 4 feedback: enforce simultaneous adoption and fix NaN handling
- Add staggered adoption check in _fit_joint() that raises ValueError when units are first treated at different periods - Fix Rust solve_joint NaN weight masking: observations with NaN outcomes now get zero effective weight instead of having values imputed to 0.0 - Fix Rust average_treated initialization: use NaN instead of 0.0 so periods with all-NaN treated data are excluded from unit distance - Update methodology registry to reflect enforced simultaneous adoption - Add test_joint_rejects_staggered_adoption test Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent cef9d31 commit a7a115d

4 files changed

Lines changed: 76 additions & 21 deletions

File tree

diff_diff/trop.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,21 @@ def _fit_joint(
13181318
if n_pre_periods < 2:
13191319
raise ValueError("Need at least 2 pre-treatment periods")
13201320

1321+
# Check for staggered adoption (joint method requires simultaneous treatment)
1322+
first_treat_by_unit = []
1323+
for i in treated_unit_idx:
1324+
treated_periods_i = np.where(D[:, i] == 1)[0]
1325+
if len(treated_periods_i) > 0:
1326+
first_treat_by_unit.append(treated_periods_i[0])
1327+
1328+
unique_starts = sorted(set(first_treat_by_unit))
1329+
if len(unique_starts) > 1:
1330+
raise ValueError(
1331+
f"method='joint' requires simultaneous treatment adoption, but your data "
1332+
f"shows staggered adoption (units first treated at periods {unique_starts}). "
1333+
f"Use method='twostep' which properly handles staggered adoption designs."
1334+
)
1335+
13211336
# LOOCV grid search for tuning parameters
13221337
# Use Rust backend when available for parallel LOOCV (5-10x speedup)
13231338
best_lambda = None

docs/methodology/REGISTRY.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,9 +636,11 @@ For joint method, LOOCV works as follows:
636636
- Faster computation for large panels
637637

638638
**Assumptions**:
639-
- **Simultaneous adoption**: Bootstrap and jackknife variance estimation assume fixed
640-
`treated_periods` across all resamples. Treatment timing is inferred once from the
641-
data and held constant. For staggered adoption designs, use `method="twostep"`.
639+
- **Simultaneous adoption (enforced)**: The joint method requires all treated units
640+
to receive treatment at the same time. A `ValueError` is raised if staggered
641+
adoption is detected (units first treated at different periods). Treatment timing is
642+
inferred once and held constant for bootstrap/jackknife variance estimation.
643+
For staggered adoption designs, use `method="twostep"`.
642644

643645
**Reference**: Adapted from reference implementation. See also Athey et al. (2025).
644646

rust/src/trop.rs

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,8 @@ fn compute_joint_weights(
10751075
let n_pre = n_periods.saturating_sub(treated_periods);
10761076

10771077
// Compute average treated trajectory
1078-
let mut average_treated = Array1::<f64>::zeros(n_periods);
1078+
// Initialize to NaN so periods with all-NaN treated data stay NaN (excluded from RMSE)
1079+
let mut average_treated = Array1::<f64>::from_elem(n_periods, f64::NAN);
10791080
if !treated_unit_idx.is_empty() {
10801081
for t in 0..n_periods {
10811082
let mut sum = 0.0;
@@ -1089,6 +1090,7 @@ fn compute_joint_weights(
10891090
if count > 0 {
10901091
average_treated[t] = sum / count as f64;
10911092
}
1093+
// If count == 0, average_treated[t] stays NaN (correctly excluded)
10921094
}
10931095
}
10941096

@@ -1163,7 +1165,8 @@ fn solve_joint_no_lowrank(
11631165

11641166
for t in 0..n_periods {
11651167
for i in 0..n_units {
1166-
let w = delta[[t, i]];
1168+
// NaN outcomes get zero weight (not imputed to 0.0 with active weight)
1169+
let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
11671170
let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
11681171

11691172
sum_w += w;
@@ -1196,7 +1199,8 @@ fn solve_joint_no_lowrank(
11961199
if sum_w_by_unit[i] > 1e-10 {
11971200
let mut num = 0.0;
11981201
for t in 0..n_periods {
1199-
let w = delta[[t, i]];
1202+
// NaN outcomes get zero weight
1203+
let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
12001204
let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
12011205
num += w * (y_ti - mu - beta[t] - tau * d[[t, i]]);
12021206
}
@@ -1209,7 +1213,8 @@ fn solve_joint_no_lowrank(
12091213
if sum_w_by_period[t] > 1e-10 {
12101214
let mut num = 0.0;
12111215
for i in 0..n_units {
1212-
let w = delta[[t, i]];
1216+
// NaN outcomes get zero weight
1217+
let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
12131218
let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
12141219
num += w * (y_ti - mu - alpha[i] - tau * d[[t, i]]);
12151220
}
@@ -1222,7 +1227,8 @@ fn solve_joint_no_lowrank(
12221227
let mut denom_tau = 0.0;
12231228
for t in 0..n_periods {
12241229
for i in 0..n_units {
1225-
let w = delta[[t, i]];
1230+
// NaN outcomes get zero weight
1231+
let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
12261232
let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
12271233
let d_ti = d[[t, i]];
12281234
if d_ti > 0.5 { // Only treated observations contribute
@@ -1239,7 +1245,8 @@ fn solve_joint_no_lowrank(
12391245
let mut num_mu = 0.0;
12401246
for t in 0..n_periods {
12411247
for i in 0..n_units {
1242-
let w = delta[[t, i]];
1248+
// NaN outcomes get zero weight
1249+
let w = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
12431250
let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
12441251
num_mu += w * (y_ti - alpha[i] - beta[t] - tau * d[[t, i]]);
12451252
}
@@ -1279,21 +1286,20 @@ fn solve_joint_with_lowrank(
12791286
let l_old = l.clone();
12801287

12811288
// Step 1: Fix L, solve for (mu, alpha, beta, tau)
1282-
// Adjusted outcome: Y - L
1289+
// Adjusted outcome: Y - L (preserve NaN so solve_joint_no_lowrank masks weights)
12831290
let y_adj = Array2::from_shape_fn((n_periods, n_units), |(t, i)| {
1284-
let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
1285-
y_ti - l[[t, i]]
1291+
y[[t, i]] - l[[t, i]] // NaN - finite = NaN (preserves NaN info)
12861292
});
12871293

12881294
let (mu, alpha, beta, tau) = solve_joint_no_lowrank(&y_adj.view(), d, delta)?;
12891295

12901296
// Step 2: Fix (mu, alpha, beta, tau), update L
1291-
// Residual: R = Y - mu - alpha - beta - tau*D
1297+
// Residual: R = Y - mu - alpha - beta - tau*D (preserve NaN)
12921298
let mut r = Array2::<f64>::zeros((n_periods, n_units));
12931299
for t in 0..n_periods {
12941300
for i in 0..n_units {
1295-
let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
1296-
r[[t, i]] = y_ti - mu - alpha[i] - beta[t] - tau * d[[t, i]];
1301+
// NaN - finite = NaN (will be masked in gradient step)
1302+
r[[t, i]] = y[[t, i]] - mu - alpha[i] - beta[t] - tau * d[[t, i]];
12971303
}
12981304
}
12991305

@@ -1302,15 +1308,20 @@ fn solve_joint_with_lowrank(
13021308
let eta = if delta_max > 0.0 { 1.0 / delta_max } else { 1.0 };
13031309

13041310
// gradient_step = L + eta * delta * (R - L)
1311+
// NaN outcomes get zero weight so they don't affect gradient
13051312
let mut gradient_step = Array2::<f64>::zeros((n_periods, n_units));
13061313
for t in 0..n_periods {
13071314
for i in 0..n_units {
1315+
// Mask delta for NaN outcomes
1316+
let delta_ti = if y[[t, i]].is_finite() { delta[[t, i]] } else { 0.0 };
13081317
let delta_norm = if delta_max > 0.0 {
1309-
delta[[t, i]] / delta_max
1318+
delta_ti / delta_max
13101319
} else {
1311-
delta[[t, i]]
1320+
delta_ti
13121321
};
1313-
gradient_step[[t, i]] = l[[t, i]] + delta_norm * (r[[t, i]] - l[[t, i]]);
1322+
// r[[t,i]] may be NaN, but delta_norm=0 for NaN obs, so contribution=0
1323+
let r_contrib = if r[[t, i]].is_finite() { r[[t, i]] } else { 0.0 };
1324+
gradient_step[[t, i]] = l[[t, i]] + delta_norm * (r_contrib - l[[t, i]]);
13141325
}
13151326
}
13161327

@@ -1324,10 +1335,9 @@ fn solve_joint_with_lowrank(
13241335
}
13251336
}
13261337

1327-
// Final solve with converged L
1338+
// Final solve with converged L (preserve NaN so solve_joint_no_lowrank masks weights)
13281339
let y_adj = Array2::from_shape_fn((n_periods, n_units), |(t, i)| {
1329-
let y_ti = if y[[t, i]].is_finite() { y[[t, i]] } else { 0.0 };
1330-
y_ti - l[[t, i]]
1340+
y[[t, i]] - l[[t, i]] // NaN - finite = NaN (preserves NaN info)
13311341
});
13321342
let (mu, alpha, beta, tau) = solve_joint_no_lowrank(&y_adj.view(), d, delta)?;
13331343

tests/test_trop.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,3 +3157,31 @@ def test_joint_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data):
31573157

31583158
assert np.isfinite(results.att), "ATT should be finite even with unit having no pre-period data"
31593159
assert np.isfinite(results.se), "SE should be finite"
3160+
3161+
def test_joint_rejects_staggered_adoption(self):
3162+
"""Joint method raises ValueError for staggered adoption data.
3163+
3164+
The joint method assumes all treated units receive treatment at the
3165+
same time. With staggered adoption (units first treated at different
3166+
periods), the method's weights and variance estimation are invalid.
3167+
"""
3168+
# Create data with staggered treatment (units treated at different times)
3169+
data = []
3170+
np.random.seed(42)
3171+
for i in range(10):
3172+
# Units 0-2 first treated at t=5, units 3-4 first treated at t=7
3173+
first_treat = 5 if i < 3 else 7
3174+
is_treated_unit = i < 5 # Units 0-4 are treated, 5-9 are control
3175+
for t in range(10):
3176+
treated = 1 if is_treated_unit and t >= first_treat else 0
3177+
data.append({
3178+
'unit': i,
3179+
'time': t,
3180+
'outcome': np.random.randn(),
3181+
'treated': treated
3182+
})
3183+
df = pd.DataFrame(data)
3184+
3185+
trop = TROP(method="joint")
3186+
with pytest.raises(ValueError, match="staggered adoption"):
3187+
trop.fit(df, 'outcome', 'treated', 'unit', 'time')

0 commit comments

Comments
 (0)