Skip to content

Commit a9bf9ce

Browse files
igerberclaude
andcommitted
Fix zero-SE inference, full-census FPC, fweight contract, and absorbed sample counts from PR #218 review (round 10)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5c5c2d4 commit a9bf9ce

4 files changed

Lines changed: 173 additions & 45 deletions

File tree

diff_diff/estimators.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ def fit(
246246
absorbed_vars = []
247247
n_absorbed_effects = 0
248248

249+
# Save raw treatment counts before absorb demeaning
250+
n_treated_raw = int(np.sum(data[treatment].values.astype(float)))
251+
n_control_raw = len(data) - n_treated_raw
252+
249253
if absorb:
250254
# FWL theorem: demean ALL regressors alongside outcome.
251255
# Regressors collinear with absorbed FE (e.g., treatment after
@@ -358,9 +362,9 @@ def fit(
358362

359363
r_squared = compute_r_squared(y, residuals)
360364

361-
# Count observations
362-
n_treated = int(np.sum(d))
363-
n_control = int(np.sum(1 - d))
365+
# Count observations (use raw counts to avoid demeaned values from absorb)
366+
n_treated = n_treated_raw
367+
n_control = n_control_raw
364368

365369
# Create coefficient dictionary
366370
coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}
@@ -985,6 +989,10 @@ def fit( # type: ignore[override]
985989
working_data = data.copy()
986990
n_absorbed_effects = 0
987991

992+
# Save raw treatment counts before absorb demeaning
993+
n_treated_raw = int(np.sum(data[treatment].values.astype(float)))
994+
n_control_raw = len(data) - n_treated_raw
995+
988996
# Pre-compute non_ref_periods (needed for absorb demeaning)
989997
non_ref_periods = [p for p in all_periods if p != reference_period]
990998

@@ -1216,9 +1224,9 @@ def fit( # type: ignore[override]
12161224
avg_att, avg_se, alpha=self.alpha, df=df
12171225
)
12181226

1219-
# Count observations
1220-
n_treated = int(np.sum(d))
1221-
n_control = int(np.sum(1 - d))
1227+
# Count observations (use raw counts to avoid demeaned values from absorb)
1228+
n_treated = n_treated_raw
1229+
n_control = n_control_raw
12221230

12231231
# Create coefficient dictionary
12241232
coef_dict = {name: coef for name, coef in zip(var_names, coefficients)}

diff_diff/linalg.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,25 +1835,6 @@ def get_inference(
18351835
coef = float(self.coefficients_[index])
18361836
se = float(np.sqrt(self.vcov_[index, index]))
18371837

1838-
# Handle zero or negative SE (indicates perfect fit or numerical issues)
1839-
if se <= 0:
1840-
import warnings
1841-
1842-
warnings.warn(
1843-
f"Standard error is zero or negative (se={se}) for coefficient at index {index}. "
1844-
"This may indicate perfect multicollinearity or numerical issues.",
1845-
UserWarning,
1846-
)
1847-
# NOTE: Deliberately uses ±inf (not NaN via safe_inference) for zero-SE coefficients.
1848-
if coef > 0:
1849-
t_stat = np.inf
1850-
elif coef < 0:
1851-
t_stat = -np.inf
1852-
else:
1853-
t_stat = 0.0
1854-
else:
1855-
t_stat = coef / se
1856-
18571838
# Use instance alpha if not provided
18581839
effective_alpha = alpha if alpha is not None else self.alpha
18591840

@@ -1877,11 +1858,12 @@ def get_inference(
18771858
)
18781859
effective_df = None
18791860

1880-
# Compute p-value
1881-
p_value = _compute_p_value(t_stat, df=effective_df)
1861+
# Use project-standard NaN-safe inference (returns all-NaN when SE <= 0)
1862+
from diff_diff.utils import safe_inference
18821863

1883-
# Compute confidence interval
1884-
conf_int = _compute_confidence_interval(coef, se, effective_alpha, df=effective_df)
1864+
t_stat, p_value, conf_int = safe_inference(
1865+
coef, se, alpha=effective_alpha, df=effective_df
1866+
)
18851867

18861868
return InferenceResult(
18871869
coefficient=coef,

diff_diff/survey.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,13 @@ def resolve(self, data: pd.DataFrame) -> "ResolvedSurveyDesign":
105105
if np.any(raw_weights <= 0):
106106
raise ValueError("Weights must be strictly positive")
107107

108-
# fweight validation: should be positive integers
108+
# fweight validation: must be positive integers
109109
if self.weight_type == "fweight":
110110
fractional = raw_weights - np.round(raw_weights)
111111
if np.any(np.abs(fractional) > 1e-10):
112-
warnings.warn(
113-
"Frequency weights (fweight) should be positive integers. "
114-
"Fractional values detected; rounding will not be applied.",
115-
UserWarning,
116-
stacklevel=2,
112+
raise ValueError(
113+
"Frequency weights (fweight) must be positive integers. "
114+
"Fractional values detected. Use pweight for non-integer weights."
117115
)
118116

119117
# Normalize: pweights/aweights to sum=n (mean=1); fweights unchanged
@@ -493,7 +491,7 @@ def compute_survey_vcov(
493491
strata = resolved.strata
494492
psu = resolved.psu
495493

496-
certainty_strata_count = 0
494+
legitimate_zero_count = 0
497495

498496
if strata is None and psu is None:
499497
# No survey structure beyond weights — use implicit per-observation PSUs
@@ -521,6 +519,8 @@ def compute_survey_vcov(
521519
if resolved.fpc is not None:
522520
N_h = resolved.fpc[0]
523521
f_h = n_psu / N_h
522+
if f_h >= 1.0:
523+
legitimate_zero_count += 1
524524
adjustment = (1.0 - f_h) * (n_psu / (n_psu - 1))
525525
meat = adjustment * (centered.T @ centered)
526526
else:
@@ -558,7 +558,7 @@ def compute_survey_vcov(
558558
if resolved.lonely_psu == "remove":
559559
continue # Skip this stratum
560560
elif resolved.lonely_psu == "certainty":
561-
certainty_strata_count += 1
561+
legitimate_zero_count += 1
562562
continue # f_h = 1, so (1-f_h) = 0, zero contribution
563563
elif resolved.lonely_psu == "adjust":
564564
# Center around overall mean instead of stratum mean
@@ -572,6 +572,8 @@ def compute_survey_vcov(
572572
if resolved.fpc is not None:
573573
N_h = resolved.fpc[mask_h][0]
574574
f_h = n_psu_h / N_h
575+
if f_h >= 1.0:
576+
legitimate_zero_count += 1
575577

576578
# Stratum mean of PSU scores
577579
psu_mean_h = psu_scores_h.mean(axis=0, keepdims=True)
@@ -584,8 +586,8 @@ def compute_survey_vcov(
584586

585587
# Guard: if no stratum contributed variance, check why
586588
if not np.any(meat != 0):
587-
if certainty_strata_count > 0:
588-
# All zero variance came from certainty PSUs — legitimate zero
589+
if legitimate_zero_count > 0:
590+
# All zero variance came from legitimate sources (certainty PSUs or full-census FPC)
589591
return np.zeros((k, k))
590592
return np.full((k, k), np.nan)
591593

tests/test_survey.py

Lines changed: 142 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -940,20 +940,17 @@ def test_multiperiod_with_survey_design(self, multiperiod_data):
940940
# Average ATT should be close to 2.5
941941
assert abs(result.avg_att - 2.5) < 1.5
942942

943-
def test_fweight_warning_for_fractional(self):
944-
"""Fractional fweights emit a UserWarning."""
943+
def test_fweight_error_for_fractional(self):
944+
"""Fractional fweights raise ValueError."""
945945
df = pd.DataFrame(
946946
{
947947
"y": [1, 2, 3],
948948
"w": [1.5, 2.0, 3.0], # 1.5 is fractional
949949
}
950950
)
951951
sd = SurveyDesign(weights="w", weight_type="fweight")
952-
with warnings.catch_warnings(record=True) as w:
953-
warnings.simplefilter("always")
952+
with pytest.raises(ValueError, match="Frequency weights.*must be positive integers"):
954953
sd.resolve(df)
955-
fweight_warnings = [x for x in w if "Frequency weights" in str(x.message)]
956-
assert len(fweight_warnings) >= 1
957954

958955
def test_lonely_psu_remove_warning(self):
959956
"""Singleton stratum with lonely_psu='remove' emits warning."""
@@ -2443,3 +2440,142 @@ def test_multiperiod_fweight_df_rounding(self):
24432440
assert np.isfinite(result.avg_att)
24442441
assert np.isfinite(result.avg_se)
24452442
assert result.avg_se > 0
2443+
2444+
2445+
class TestRound10Fixes:
2446+
"""Tests for PR #218 review round 10 fixes."""
2447+
2448+
def test_zero_se_estimator_nan_inference(self):
2449+
"""Zero-SE path in LinearRegression.get_inference() returns NaN, not ±inf."""
2450+
# Build a design where all strata are certainty PSUs → zero vcov → zero SE
2451+
np.random.seed(42)
2452+
n = 40
2453+
strata = np.repeat([0, 1, 2, 3], 10)
2454+
psu = strata.copy() # 1 PSU per stratum → all certainty
2455+
df = pd.DataFrame(
2456+
{
2457+
"outcome": np.random.randn(n),
2458+
"treated": np.array([1] * 20 + [0] * 20),
2459+
"post": np.tile([0, 1], 20),
2460+
"w": np.ones(n),
2461+
"strat": strata,
2462+
"cluster": psu,
2463+
}
2464+
)
2465+
sd = SurveyDesign(
2466+
weights="w",
2467+
weight_type="pweight",
2468+
strata="strat",
2469+
psu="cluster",
2470+
lonely_psu="certainty",
2471+
)
2472+
did = DifferenceInDifferences()
2473+
with warnings.catch_warnings():
2474+
warnings.simplefilter("ignore")
2475+
result = did.fit(
2476+
df,
2477+
outcome="outcome",
2478+
treatment="treated",
2479+
time="post",
2480+
survey_design=sd,
2481+
)
2482+
# SE should be 0 (all certainty strata), inference should be NaN
2483+
assert result.se == 0.0
2484+
assert np.isnan(result.t_stat)
2485+
assert np.isnan(result.p_value)
2486+
assert np.isnan(result.conf_int[0])
2487+
assert np.isnan(result.conf_int[1])
2488+
2489+
def test_full_census_fpc_stratified_zero_vcov(self):
2490+
"""Full-census FPC (f_h=1) returns zero vcov, not NaN."""
2491+
np.random.seed(42)
2492+
n = 60
2493+
strata = np.repeat([0, 1, 2], 20)
2494+
psu = np.tile(np.arange(5), 12) # 5 PSUs per stratum
2495+
2496+
X = np.column_stack([np.ones(n), np.random.randn(n)])
2497+
y = np.random.randn(n)
2498+
residuals = np.random.randn(n)
2499+
weights = np.ones(n)
2500+
2501+
# FPC = n_psu per stratum (full census: f_h = 5/5 = 1)
2502+
fpc = np.array([5.0] * n)
2503+
2504+
resolved = ResolvedSurveyDesign(
2505+
weights=weights,
2506+
weight_type="pweight",
2507+
strata=strata,
2508+
psu=psu,
2509+
fpc=fpc,
2510+
n_strata=3,
2511+
n_psu=15,
2512+
lonely_psu="remove",
2513+
)
2514+
vcov = compute_survey_vcov(X, residuals, resolved=resolved)
2515+
# Full census → zero variance → zero vcov
2516+
np.testing.assert_array_equal(vcov, np.zeros((2, 2)))
2517+
2518+
def test_full_census_fpc_unstratified_zero_vcov(self):
2519+
"""Unstratified full-census FPC returns zero vcov, not NaN."""
2520+
np.random.seed(42)
2521+
n = 30
2522+
psu = np.repeat(np.arange(6), 5) # 6 PSUs
2523+
2524+
X = np.column_stack([np.ones(n), np.random.randn(n)])
2525+
y = np.random.randn(n)
2526+
residuals = np.random.randn(n)
2527+
weights = np.ones(n)
2528+
2529+
# FPC = n_psu (full census: f_h = 6/6 = 1)
2530+
fpc = np.array([6.0] * n)
2531+
2532+
resolved = ResolvedSurveyDesign(
2533+
weights=weights,
2534+
weight_type="pweight",
2535+
strata=None,
2536+
psu=psu,
2537+
fpc=fpc,
2538+
n_strata=0,
2539+
n_psu=6,
2540+
lonely_psu="remove",
2541+
)
2542+
vcov = compute_survey_vcov(X, residuals, resolved=resolved)
2543+
# Full census → (1-f_h)=0 → zero meat → zero vcov
2544+
np.testing.assert_array_equal(vcov, np.zeros((2, 2)))
2545+
2546+
def test_absorbed_did_sample_counts(self):
2547+
"""n_treated/n_control reflect raw data, not demeaned values after absorb."""
2548+
np.random.seed(42)
2549+
n_units = 20
2550+
n_times = 4
2551+
rows = []
2552+
for u in range(n_units):
2553+
for t in range(n_times):
2554+
rows.append(
2555+
{
2556+
"unit": u,
2557+
"time": t,
2558+
"treated": 1 if u < 8 else 0,
2559+
"post": 1 if t >= 2 else 0,
2560+
"outcome": np.random.randn(),
2561+
"region": u % 3,
2562+
}
2563+
)
2564+
df = pd.DataFrame(rows)
2565+
2566+
did = DifferenceInDifferences()
2567+
with warnings.catch_warnings():
2568+
warnings.simplefilter("ignore")
2569+
result = did.fit(
2570+
df,
2571+
outcome="outcome",
2572+
treatment="treated",
2573+
time="post",
2574+
absorb=["region"],
2575+
)
2576+
2577+
# Raw counts: 8 treated units * 4 times = 32 treated obs
2578+
raw_treated = int(df["treated"].sum())
2579+
raw_control = len(df) - raw_treated
2580+
assert result.n_treated == raw_treated
2581+
assert result.n_control == raw_control

0 commit comments

Comments
 (0)