Skip to content

Commit 4f28c05

Browse files
igerberclaude
andcommitted
Address PR #113 Round 7 feedback: fix Python-only LOOCV subsampling
Fix ValueError in joint method when control_obs exceeds max_loocv_samples without Rust backend. np.random.choice cannot directly sample from a list of tuples - now samples indices first, then indexes into the list (matching the pattern already used in the twostep method). Add test to verify Python-only joint LOOCV subsampling works correctly. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 9eeb436 commit 4f28c05

3 files changed

Lines changed: 85 additions & 54 deletions

File tree

diff_diff/trop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,12 +1412,12 @@ def _fit_joint(
14121412
if control_mask[t, i] and not np.isnan(Y[t, i])
14131413
]
14141414

1415-
# Subsample if needed
1416-
if len(control_obs) > self.max_loocv_samples:
1417-
rng = np.random.default_rng(self.seed)
1418-
control_obs = list(
1419-
rng.choice(control_obs, size=self.max_loocv_samples, replace=False)
1420-
)
1415+
# Subsample if needed (sample indices to avoid ValueError on list of tuples)
1416+
rng = np.random.default_rng(self.seed)
1417+
max_loocv = min(self.max_loocv_samples, len(control_obs))
1418+
if len(control_obs) > max_loocv:
1419+
indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
1420+
control_obs = [control_obs[idx] for idx in indices]
14211421

14221422
# Grid search with true LOOCV
14231423
for lambda_time_val in self.lambda_time_grid:

tests/test_rust_backend.py

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,31 +1554,18 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self):
15541554
trop_rust = TROP(**trop_params)
15551555
results_rust = trop_rust.fit(df.copy(), 'outcome', 'treated', 'unit', 'time')
15561556

1557-
# Run with Python-only backend
1558-
old_backend = os.environ.get('DIFF_DIFF_BACKEND')
1559-
try:
1560-
os.environ['DIFF_DIFF_BACKEND'] = 'python'
1561-
# Need to reimport to pick up new backend setting
1562-
# Must reload both _backend AND trop modules since trop imports
1563-
# HAS_RUST_BACKEND and Rust functions at module load time
1564-
import importlib
1565-
import sys
1566-
importlib.reload(sys.modules['diff_diff._backend'])
1567-
importlib.reload(sys.modules['diff_diff.trop'])
1568-
from diff_diff.trop import TROP as TROP_Python
1569-
1570-
trop_python = TROP_Python(**trop_params)
1557+
# Run with Python-only backend using mock.patch to avoid module reload issues
1558+
# (Module reload breaks isinstance() checks in other tests due to class identity)
1559+
from unittest.mock import patch
1560+
import sys
1561+
trop_module = sys.modules['diff_diff.trop']
1562+
1563+
with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \
1564+
patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \
1565+
patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None):
1566+
1567+
trop_python = TROP(**trop_params)
15711568
results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time')
1572-
finally:
1573-
# Restore original backend setting
1574-
if old_backend is None:
1575-
os.environ.pop('DIFF_DIFF_BACKEND', None)
1576-
else:
1577-
os.environ['DIFF_DIFF_BACKEND'] = old_backend
1578-
import importlib
1579-
import sys
1580-
importlib.reload(sys.modules['diff_diff._backend'])
1581-
importlib.reload(sys.modules['diff_diff.trop'])
15821569

15831570
# Both should produce finite results
15841571
assert np.isfinite(results_rust.att), f"Rust ATT {results_rust.att} should be finite"
@@ -1657,31 +1644,18 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self):
16571644
trop_rust = TROP(**trop_params)
16581645
results_rust = trop_rust.fit(df.copy(), 'outcome', 'treated', 'unit', 'time')
16591646

1660-
# Run with Python-only backend
1661-
old_backend = os.environ.get('DIFF_DIFF_BACKEND')
1662-
try:
1663-
os.environ['DIFF_DIFF_BACKEND'] = 'python'
1664-
# Need to reimport to pick up new backend setting
1665-
# Must reload both _backend AND trop modules since trop imports
1666-
# HAS_RUST_BACKEND and Rust functions at module load time
1667-
import importlib
1668-
import sys
1669-
importlib.reload(sys.modules['diff_diff._backend'])
1670-
importlib.reload(sys.modules['diff_diff.trop'])
1671-
from diff_diff.trop import TROP as TROP_Python
1672-
1673-
trop_python = TROP_Python(**trop_params)
1647+
# Run with Python-only backend using mock.patch to avoid module reload issues
1648+
# (Module reload breaks isinstance() checks in other tests due to class identity)
1649+
from unittest.mock import patch
1650+
import sys
1651+
trop_module = sys.modules['diff_diff.trop']
1652+
1653+
with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \
1654+
patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \
1655+
patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None):
1656+
1657+
trop_python = TROP(**trop_params)
16741658
results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time')
1675-
finally:
1676-
# Restore original backend setting
1677-
if old_backend is None:
1678-
os.environ.pop('DIFF_DIFF_BACKEND', None)
1679-
else:
1680-
os.environ['DIFF_DIFF_BACKEND'] = old_backend
1681-
import importlib
1682-
import sys
1683-
importlib.reload(sys.modules['diff_diff._backend'])
1684-
importlib.reload(sys.modules['diff_diff.trop'])
16851659

16861660
# Both should produce finite results
16871661
assert np.isfinite(results_rust.att), f"Rust ATT {results_rust.att} should be finite"

tests/test_trop.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3236,3 +3236,60 @@ def test_joint_rejects_staggered_adoption(self):
32363236
trop = TROP(method="joint")
32373237
with pytest.raises(ValueError, match="staggered adoption"):
32383238
trop.fit(df, 'outcome', 'treated', 'unit', 'time')
3239+
3240+
def test_joint_python_loocv_subsampling(self):
3241+
"""Test that joint method works with Python-only LOOCV when control_obs > max_loocv_samples.
3242+
3243+
This tests the fix for PR #113 Round 7 feedback (P1): Python fallback
3244+
LOOCV sampling could raise ValueError when control_obs is a list of tuples.
3245+
"""
3246+
from unittest.mock import patch
3247+
import sys
3248+
3249+
np.random.seed(42)
3250+
# Create data with many control observations (> default max_loocv_samples=500)
3251+
n_units, n_periods = 30, 25 # 30*25 = 750 observations, most are control
3252+
n_treated = 3
3253+
n_post = 3
3254+
3255+
data = []
3256+
for i in range(n_units):
3257+
is_treated = i < n_treated
3258+
for t in range(n_periods):
3259+
post = t >= (n_periods - n_post)
3260+
y = 10.0 + i * 0.1 + t * 0.1 + np.random.randn() * 0.5
3261+
treatment_indicator = 1 if (is_treated and post) else 0
3262+
if treatment_indicator:
3263+
y += 2.0
3264+
data.append({
3265+
'unit': i,
3266+
'time': t,
3267+
'outcome': y,
3268+
'treated': treatment_indicator,
3269+
})
3270+
3271+
df = pd.DataFrame(data)
3272+
3273+
# Patch to force Python backend and set small max_loocv_samples
3274+
trop_module = sys.modules['diff_diff.trop']
3275+
3276+
with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \
3277+
patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \
3278+
patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None):
3279+
3280+
# Use small max_loocv_samples to trigger subsampling
3281+
trop_est = TROP(
3282+
method="joint",
3283+
lambda_time_grid=[1.0],
3284+
lambda_unit_grid=[1.0],
3285+
lambda_nn_grid=[0.0],
3286+
max_loocv_samples=100, # Force subsampling (control_obs > 100)
3287+
n_bootstrap=0,
3288+
seed=42
3289+
)
3290+
3291+
# This should not raise ValueError
3292+
results = trop_est.fit(df, 'outcome', 'treated', 'unit', 'time')
3293+
3294+
assert isinstance(results, TROPResults)
3295+
assert np.isfinite(results.att)

0 commit comments

Comments
 (0)