Skip to content

Commit 9eeb436

Browse files
igerberclaude
andcommitted
Address PR #113 Round 6 feedback: fix staggered adoption check and test reloads
- Fix simultaneous-adoption check to use observed periods only, avoiding false positives on unbalanced panels where missing entries were filled as 0 - Add zero-weight guard in Python joint solver matching Rust's behavior - Fix backend parity tests to properly reload trop module using sys.modules Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 4a27ec5 commit 9eeb436

2 files changed

Lines changed: 35 additions & 13 deletions

File tree

diff_diff/trop.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,6 +1079,11 @@ def _solve_joint_no_lowrank(
10791079

10801080
sqrt_weights = np.sqrt(np.maximum(weights, 0))
10811081

1082+
# Check for all-zero weights (matches Rust's sum_w < 1e-10 check)
1083+
sum_w = np.sum(weights)
1084+
if sum_w < 1e-10:
1085+
raise ValueError("All weights are zero - cannot estimate")
1086+
10821087
# Build design matrix: [intercept, unit_dummies, time_dummies, treatment]
10831088
# Total columns: 1 + n_units + n_periods + 1
10841089
# But we need to drop one unit and one time dummy for identification
@@ -1321,11 +1326,17 @@ def _fit_joint(
13211326
raise ValueError("Need at least 2 pre-treatment periods")
13221327

13231328
# Check for staggered adoption (joint method requires simultaneous treatment)
1329+
# Use only observed periods (skip missing) to avoid false positives on unbalanced panels
13241330
first_treat_by_unit = []
13251331
for i in treated_unit_idx:
1326-
treated_periods_i = np.where(D[:, i] == 1)[0]
1327-
if len(treated_periods_i) > 0:
1328-
first_treat_by_unit.append(treated_periods_i[0])
1332+
observed_mask = ~missing_mask[:, i]
1333+
# Get D values for observed periods only
1334+
observed_d = D[observed_mask, i]
1335+
observed_periods = np.where(observed_mask)[0]
1336+
# Find first treatment among observed periods
1337+
treated_idx = np.where(observed_d == 1)[0]
1338+
if len(treated_idx) > 0:
1339+
first_treat_by_unit.append(observed_periods[treated_idx[0]])
13291340

13301341
unique_starts = sorted(set(first_treat_by_unit))
13311342
if len(unique_starts) > 1:

tests/test_rust_backend.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,11 +1559,15 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self):
15591559
try:
15601560
os.environ['DIFF_DIFF_BACKEND'] = 'python'
15611561
# Need to reimport to pick up new backend setting
1562+
# Must reload both _backend AND trop modules since trop imports
1563+
# HAS_RUST_BACKEND and Rust functions at module load time
15621564
import importlib
1563-
import diff_diff._backend
1564-
importlib.reload(diff_diff._backend)
1565+
import sys
1566+
importlib.reload(sys.modules['diff_diff._backend'])
1567+
importlib.reload(sys.modules['diff_diff.trop'])
1568+
from diff_diff.trop import TROP as TROP_Python
15651569

1566-
trop_python = TROP(**trop_params)
1570+
trop_python = TROP_Python(**trop_params)
15671571
results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time')
15681572
finally:
15691573
# Restore original backend setting
@@ -1572,8 +1576,9 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self):
15721576
else:
15731577
os.environ['DIFF_DIFF_BACKEND'] = old_backend
15741578
import importlib
1575-
import diff_diff._backend
1576-
importlib.reload(diff_diff._backend)
1579+
import sys
1580+
importlib.reload(sys.modules['diff_diff._backend'])
1581+
importlib.reload(sys.modules['diff_diff.trop'])
15771582

15781583
# Both should produce finite results
15791584
assert np.isfinite(results_rust.att), f"Rust ATT {results_rust.att} should be finite"
@@ -1656,11 +1661,16 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self):
16561661
old_backend = os.environ.get('DIFF_DIFF_BACKEND')
16571662
try:
16581663
os.environ['DIFF_DIFF_BACKEND'] = 'python'
1664+
# Need to reimport to pick up new backend setting
1665+
# Must reload both _backend AND trop modules since trop imports
1666+
# HAS_RUST_BACKEND and Rust functions at module load time
16591667
import importlib
1660-
import diff_diff._backend
1661-
importlib.reload(diff_diff._backend)
1668+
import sys
1669+
importlib.reload(sys.modules['diff_diff._backend'])
1670+
importlib.reload(sys.modules['diff_diff.trop'])
1671+
from diff_diff.trop import TROP as TROP_Python
16621672

1663-
trop_python = TROP(**trop_params)
1673+
trop_python = TROP_Python(**trop_params)
16641674
results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time')
16651675
finally:
16661676
# Restore original backend setting
@@ -1669,8 +1679,9 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self):
16691679
else:
16701680
os.environ['DIFF_DIFF_BACKEND'] = old_backend
16711681
import importlib
1672-
import diff_diff._backend
1673-
importlib.reload(diff_diff._backend)
1682+
import sys
1683+
importlib.reload(sys.modules['diff_diff._backend'])
1684+
importlib.reload(sys.modules['diff_diff.trop'])
16741685

16751686
# Both should produce finite results
16761687
assert np.isfinite(results_rust.att), f"Rust ATT {results_rust.att} should be finite"

0 commit comments

Comments
 (0)