Skip to content

Commit c49edf6

Browse files
igerber and claude committed
R-style NA propagation for avg_att (P1) and scipy_lstsq tolerance (P1)
- avg_att is now NaN if ANY post-period effect is unidentified, matching R's default NA propagation semantics (mean(c(1,2,NA)) returns NA)
- Add explicit cond=1e-07 to scipy_lstsq calls for consistency with Rust backend and QR rank tolerance
- Document avg_att NA behavior in REGISTRY.md
- Add test for avg_att NaN propagation when period effect is unidentified

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 4c62776 commit c49edf6

4 files changed

Lines changed: 54 additions & 15 deletions

File tree

diff_diff/estimators.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -935,29 +935,25 @@ def fit( # type: ignore[override]
935935
effect_indices.append(idx)
936936

937937
# Compute average treatment effect
938-
# Only average over identified (non-NaN) period effects
938+
# R-style NA propagation: if ANY period effect is NaN, average is undefined
939939
effect_arr = np.array(effect_values)
940-
identified_effects = effect_arr[~np.isnan(effect_arr)]
941940

942-
if len(identified_effects) == 0:
943-
# All period effects are NaN - cannot compute average
941+
if np.any(np.isnan(effect_arr)):
942+
# Some period effects are NaN (unidentified) - cannot compute valid average
943+
# This follows R's default behavior where mean(c(1, 2, NA)) returns NA
944944
avg_att = np.nan
945945
avg_se = np.nan
946946
avg_t_stat = np.nan
947947
avg_p_value = np.nan
948948
avg_conf_int = (np.nan, np.nan)
949949
else:
950-
# Average ATT = mean of identified period-specific effects
951-
avg_att = float(np.mean(identified_effects))
950+
# All effects identified - compute average normally
951+
avg_att = float(np.mean(effect_arr))
952952

953953
# Standard error of average: need to account for covariance
954-
# Only use identified effects in the variance calculation
955-
identified_mask = ~np.isnan(effect_arr)
956-
identified_indices = [idx for idx, m in zip(effect_indices, identified_mask) if m]
957-
n_identified = len(identified_indices)
958-
959-
sub_vcov = vcov[np.ix_(identified_indices, identified_indices)]
960-
avg_var = np.sum(sub_vcov) / (n_identified ** 2)
954+
n_post = len(post_periods)
955+
sub_vcov = vcov[np.ix_(effect_indices, effect_indices)]
956+
avg_var = np.sum(sub_vcov) / (n_post ** 2)
961957

962958
if np.isnan(avg_var) or avg_var < 0:
963959
# Vcov has NaN (dropped columns) - propagate NaN

diff_diff/linalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,8 +619,9 @@ def _solve_ols_numpy(
619619
X_reduced = X[:, kept_cols]
620620

621621
# Solve the reduced system (now full-rank)
622+
# Use cond=1e-07 for consistency with Rust backend and QR rank tolerance
622623
coefficients_reduced = scipy_lstsq(
623-
X_reduced, y, lapack_driver="gelsd", check_finite=False
624+
X_reduced, y, lapack_driver="gelsd", check_finite=False, cond=1e-07
624625
)[0]
625626

626627
# Expand coefficients to full size with NaN for dropped columns
@@ -638,7 +639,8 @@ def _solve_ols_numpy(
638639
vcov = _expand_vcov_with_nan(vcov_reduced, k, kept_cols)
639640
else:
640641
# Full-rank case: proceed normally
641-
coefficients = scipy_lstsq(X, y, lapack_driver="gelsd", check_finite=False)[0]
642+
# Use cond=1e-07 for consistency with Rust backend and QR rank tolerance
643+
coefficients = scipy_lstsq(X, y, lapack_driver="gelsd", check_finite=False, cond=1e-07)[0]
642644

643645
# Compute residuals and fitted values
644646
fitted = X @ coefficients

docs/methodology/REGISTRY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ where E_i is treatment time for unit i, and δ_e are event-study coefficients.
101101
- Never-treated units: event-time indicators are all zero
102102
- Endpoint binning: distant event times can be binned
103103
- Rank-deficient design matrix (collinearity): warns and sets NA for dropped coefficients (R-style, matches `lm()`)
104+
- Average ATT (`avg_att`) is NA if any post-period effect is unidentified (R-style NA propagation)
104105

105106
**Reference implementation(s):**
106107
- R: `fixest::feols()` with `i(event_time, ref=-1)`

tests/test_estimators.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,6 +1667,46 @@ def test_rank_deficient_design_warns_and_sets_nan(self, multi_period_data):
16671667
assert results.vcov is not None
16681668
assert np.any(np.isnan(results.vcov)), "Vcov should have NaN for dropped column"
16691669

1670+
# avg_att should still be computed because all period effects are identified
1671+
assert not np.isnan(results.avg_att), "avg_att should be valid when all period effects are identified"
1672+
1673+
def test_avg_att_nan_when_period_effect_nan(self, multi_period_data):
1674+
"""Test that avg_att is NaN if any period effect is NaN (R-style NA propagation)."""
1675+
import warnings
1676+
1677+
# Remove all treated observations in period 3 to make that interaction
1678+
# unidentified (column of zeros)
1679+
data_no_treated_period3 = multi_period_data[
1680+
~((multi_period_data["treated"] == 1) & (multi_period_data["period"] == 3))
1681+
].copy()
1682+
1683+
did = MultiPeriodDiD()
1684+
1685+
with warnings.catch_warnings(record=True) as w:
1686+
warnings.simplefilter("always")
1687+
results = did.fit(
1688+
data_no_treated_period3,
1689+
outcome="outcome",
1690+
treatment="treated",
1691+
time="period",
1692+
post_periods=[3, 4, 5]
1693+
)
1694+
1695+
# Should have warning about rank deficiency (treated:period_3 is all zeros)
1696+
rank_warnings = [x for x in w if "Rank-deficient" in str(x.message)
1697+
or "collinear" in str(x.message).lower()]
1698+
assert len(rank_warnings) > 0, "Expected warning about rank deficiency"
1699+
1700+
# The treated×period_3 interaction should have NaN coefficient (unidentified)
1701+
pe_3 = results.period_effects[3]
1702+
assert np.isnan(pe_3.effect), "Period 3 effect should be NaN (unidentified)"
1703+
1704+
# avg_att should be NaN because one period effect is NaN (R-style NA propagation)
1705+
assert np.isnan(results.avg_att), "avg_att should be NaN when any period effect is NaN"
1706+
assert np.isnan(results.avg_se), "avg_se should be NaN when avg_att is NaN"
1707+
assert np.isnan(results.avg_t_stat), "avg_t_stat should be NaN when avg_att is NaN"
1708+
assert np.isnan(results.avg_p_value), "avg_p_value should be NaN when avg_att is NaN"
1709+
16701710

16711711
class TestSyntheticDiD:
16721712
"""Tests for SyntheticDiD estimator."""

0 commit comments

Comments
 (0)