Skip to content

Commit b4cd770

Browse files
igerberclaude
andcommitted
Fix cluster-ignored-with-survey and weight validation gaps from PR #218 review (round 7)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 21690d2 commit b4cd770

6 files changed

Lines changed: 267 additions & 4 deletions

File tree

diff_diff/estimators.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,14 @@ def fit(
308308
resolved_survey, cluster_ids, self.cluster
309309
)
310310

311+
# Inject cluster as effective PSU for survey variance estimation
312+
if resolved_survey is not None and effective_cluster_ids is not None:
313+
from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata
314+
resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
315+
if resolved_survey.psu is not None and survey_metadata is not None:
316+
raw_w = data[survey_design.weights].values.astype(np.float64) if survey_design.weights else np.ones(len(data), dtype=np.float64)
317+
survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
318+
311319
reg = LinearRegression(
312320
include_intercept=False, # Intercept already in X
313321
robust=self.robust,
@@ -1036,6 +1044,14 @@ def fit( # type: ignore[override]
10361044
resolved_survey, cluster_ids, self.cluster
10371045
)
10381046

1047+
# Inject cluster as effective PSU for survey variance estimation
1048+
if resolved_survey is not None and effective_cluster_ids is not None:
1049+
from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata
1050+
resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
1051+
if resolved_survey.psu is not None and survey_metadata is not None:
1052+
raw_w = data[survey_design.weights].values.astype(np.float64) if survey_design.weights else np.ones(len(data), dtype=np.float64)
1053+
survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
1054+
10391055
# Determine if survey vcov should be used
10401056
_use_survey_vcov = resolved_survey is not None and resolved_survey.needs_survey_vcov
10411057

diff_diff/linalg.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,34 @@ def solve_ols(
383383
]: ...
384384

385385

386+
_VALID_WEIGHT_TYPES = {"pweight", "fweight", "aweight"}
387+
388+
389+
def _validate_weights(weights, weight_type, n):
390+
"""Validate weights array and weight_type for solve_ols/LinearRegression."""
391+
if weight_type not in _VALID_WEIGHT_TYPES:
392+
raise ValueError(
393+
f"weight_type must be one of {_VALID_WEIGHT_TYPES}, "
394+
f"got '{weight_type}'"
395+
)
396+
if weights is not None:
397+
weights = np.asarray(weights, dtype=np.float64)
398+
if weights.shape[0] != n:
399+
raise ValueError(
400+
f"weights length ({weights.shape[0]}) must match "
401+
f"X rows ({n})"
402+
)
403+
if np.any(np.isnan(weights)):
404+
raise ValueError("Weights contain NaN values")
405+
if np.any(np.isinf(weights)):
406+
raise ValueError("Weights contain Inf values")
407+
if np.any(weights < 0):
408+
raise ValueError(
409+
"Weights must be non-negative"
410+
)
411+
return weights
412+
413+
386414
def solve_ols(
387415
X: np.ndarray,
388416
y: np.ndarray,
@@ -543,9 +571,7 @@ def solve_ols(
543571
_original_X = None
544572
_original_y = None
545573
if weights is not None:
546-
weights = np.asarray(weights, dtype=np.float64)
547-
if weights.shape[0] != n:
548-
raise ValueError(f"weights length ({weights.shape[0]}) must match X rows ({n})")
574+
weights = _validate_weights(weights, weight_type, n)
549575
_original_X = X
550576
_original_y = y
551577
sqrt_w = np.sqrt(weights)
@@ -1567,6 +1593,23 @@ def fit(
15671593
self.weights = self.survey_design.weights
15681594
self.weight_type = self.survey_design.weight_type
15691595

1596+
if self.weights is not None:
1597+
self.weights = _validate_weights(
1598+
self.weights, self.weight_type, X.shape[0]
1599+
)
1600+
1601+
# Inject cluster as PSU for survey variance when no PSU specified
1602+
if (
1603+
effective_cluster_ids is not None
1604+
and self.survey_design is not None
1605+
and _use_survey_vcov
1606+
):
1607+
from diff_diff.survey import ResolvedSurveyDesign as _RSD, _inject_cluster_as_psu
1608+
if isinstance(self.survey_design, _RSD) and self.survey_design.psu is None:
1609+
self.survey_design = _inject_cluster_as_psu(
1610+
self.survey_design, effective_cluster_ids
1611+
)
1612+
15701613
if self.robust or effective_cluster_ids is not None:
15711614
# Use solve_ols with robust/cluster SEs
15721615
# When survey vcov will be used, skip standard vcov computation

diff_diff/survey.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616

1717
import warnings
18-
from dataclasses import dataclass, field
18+
from dataclasses import dataclass, field, replace
1919
from typing import Optional, Tuple
2020

2121
import numpy as np
@@ -430,6 +430,25 @@ def _resolve_effective_cluster(resolved_survey, cluster_ids, cluster_name=None):
430430
return resolved_survey.psu
431431

432432

433+
def _inject_cluster_as_psu(resolved, cluster_ids):
434+
"""
435+
When survey design has no PSU but cluster_ids are provided,
436+
inject cluster_ids as the effective PSU for TSL variance estimation.
437+
438+
Returns a new ResolvedSurveyDesign (no mutation) or the original unchanged.
439+
"""
440+
if resolved is None or cluster_ids is None:
441+
return resolved
442+
if resolved.psu is not None:
443+
return resolved # PSU already present; _resolve_effective_cluster handles this
444+
445+
# Factorize cluster_ids for consistent integer encoding
446+
codes, uniques = pd.factorize(cluster_ids)
447+
n_clusters = len(uniques)
448+
449+
return replace(resolved, psu=codes, n_psu=n_clusters)
450+
451+
433452
def compute_survey_vcov(
434453
X: np.ndarray,
435454
residuals: np.ndarray,

diff_diff/twfe.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@ def fit( # type: ignore[override]
175175
resolved_survey, cluster_ids, self.cluster
176176
)
177177

178+
# Inject cluster as effective PSU for survey variance estimation
179+
if resolved_survey is not None and effective_cluster_ids is not None:
180+
from diff_diff.survey import _inject_cluster_as_psu, compute_survey_metadata
181+
resolved_survey = _inject_cluster_as_psu(resolved_survey, effective_cluster_ids)
182+
if resolved_survey.psu is not None and survey_metadata is not None:
183+
raw_w = data[survey_design.weights].values.astype(np.float64) if survey_design.weights else np.ones(len(data), dtype=np.float64)
184+
survey_metadata = compute_survey_metadata(resolved_survey, raw_w)
185+
178186
# Pass rank_deficient_action to LinearRegression
179187
# If "error", let LinearRegression raise immediately
180188
# If "warn" or "silent", suppress generic warning and use TWFE's context-specific

docs/methodology/REGISTRY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,6 +1883,10 @@ unequal selection probabilities).
18831883
- **Note:** When no explicit PSU is specified (weights-only or stratified-no-PSU
18841884
designs), each observation is treated as its own PSU for df purposes. Survey df
18851885
becomes `n_obs - n_strata` (or `n_obs - 1` when unstratified).
1886+
- **Note:** When survey_design specifies weights only (no PSU) and cluster=
1887+
is specified, cluster IDs are injected as effective PSUs for Taylor Series
1888+
Linearization variance estimation, matching the R `survey` package
1889+
convention that clusters are the primary sampling units.
18861890

18871891
---
18881892

tests/test_survey.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,3 +1943,176 @@ def test_matching_weights_no_warning(self):
19431943
with warnings.catch_warnings():
19441944
warnings.simplefilter("error")
19451945
reg.fit(X, y)
1946+
1947+
1948+
class TestRound7Fixes:
1949+
"""Tests for round-7 review fixes (PR #218)."""
1950+
1951+
@staticmethod
1952+
def _make_cluster_data(seed=700):
1953+
"""Create 2-period DiD data with 10 clusters of 5 obs each."""
1954+
np.random.seed(seed)
1955+
n_clusters = 10
1956+
obs_per_cluster = 5
1957+
rows = []
1958+
for c in range(n_clusters):
1959+
is_treated = c >= 5
1960+
for i in range(obs_per_cluster):
1961+
for period in [0, 1]:
1962+
y = 10.0 + c * 0.3 + np.random.randn() * 0.5
1963+
if period == 1 and is_treated:
1964+
y += 3.0
1965+
rows.append({
1966+
"unit": c * obs_per_cluster + i,
1967+
"period": period,
1968+
"treated": int(is_treated),
1969+
"y": y,
1970+
"cluster_id": c,
1971+
"w": 1.0 + 0.2 * c,
1972+
})
1973+
return pd.DataFrame(rows)
1974+
1975+
def test_cluster_injected_as_psu_did(self):
1976+
"""Cluster IDs injected as PSU produce identical SEs to explicit PSU."""
1977+
data = self._make_cluster_data()
1978+
1979+
# Fit with cluster= and weights-only survey (no PSU)
1980+
result_inject = DifferenceInDifferences(cluster="cluster_id").fit(
1981+
data, "y", "treated", "period",
1982+
survey_design=SurveyDesign(weights="w"),
1983+
)
1984+
1985+
# Fit with explicit PSU in survey design
1986+
result_explicit = DifferenceInDifferences(cluster="cluster_id").fit(
1987+
data, "y", "treated", "period",
1988+
survey_design=SurveyDesign(weights="w", psu="cluster_id"),
1989+
)
1990+
1991+
np.testing.assert_allclose(result_inject.se, result_explicit.se, atol=1e-12)
1992+
assert result_inject.survey_metadata.n_psu == 10
1993+
assert result_inject.survey_metadata.df_survey == 9
1994+
1995+
def test_cluster_injected_as_psu_twfe(self):
1996+
"""TWFE: cluster IDs injected as PSU produce identical SEs to explicit PSU."""
1997+
data = self._make_cluster_data()
1998+
1999+
result_inject = TwoWayFixedEffects(cluster="cluster_id").fit(
2000+
data, "y", "treated", "period", unit="unit",
2001+
survey_design=SurveyDesign(weights="w"),
2002+
)
2003+
2004+
result_explicit = TwoWayFixedEffects(cluster="cluster_id").fit(
2005+
data, "y", "treated", "period", unit="unit",
2006+
survey_design=SurveyDesign(weights="w", psu="cluster_id"),
2007+
)
2008+
2009+
np.testing.assert_allclose(result_inject.se, result_explicit.se, atol=1e-12)
2010+
assert result_inject.survey_metadata.n_psu == 10
2011+
assert result_inject.survey_metadata.df_survey == 9
2012+
2013+
def test_cluster_injected_as_psu_linear_regression(self):
2014+
"""Standalone LinearRegression: cluster injection matches explicit PSU."""
2015+
np.random.seed(701)
2016+
n = 50
2017+
cluster_ids = np.repeat(np.arange(10), 5)
2018+
X = np.column_stack([np.ones(n), np.random.randn(n)])
2019+
y = 1.0 + X[:, 1] * 0.5 + np.random.randn(n) * 0.4
2020+
weights = np.random.uniform(0.5, 3.0, n)
2021+
2022+
# No PSU in resolved design
2023+
resolved_no_psu = ResolvedSurveyDesign(
2024+
weights=weights, weight_type="pweight",
2025+
strata=None, psu=None, fpc=None,
2026+
n_strata=0, n_psu=0, lonely_psu="remove",
2027+
)
2028+
reg_inject = LinearRegression(
2029+
include_intercept=False, cluster_ids=cluster_ids,
2030+
survey_design=resolved_no_psu,
2031+
)
2032+
reg_inject.fit(X, y)
2033+
2034+
# Explicit PSU
2035+
codes, uniques = pd.factorize(cluster_ids)
2036+
resolved_psu = ResolvedSurveyDesign(
2037+
weights=weights, weight_type="pweight",
2038+
strata=None, psu=codes, fpc=None,
2039+
n_strata=0, n_psu=len(uniques), lonely_psu="remove",
2040+
)
2041+
reg_explicit = LinearRegression(
2042+
include_intercept=False, cluster_ids=cluster_ids,
2043+
survey_design=resolved_psu,
2044+
)
2045+
reg_explicit.fit(X, y)
2046+
2047+
np.testing.assert_allclose(reg_inject.vcov_, reg_explicit.vcov_, atol=1e-12)
2048+
2049+
def test_cluster_injection_no_effect_when_psu_present(self):
2050+
"""When PSU is already present, _inject_cluster_as_psu is a no-op."""
2051+
from diff_diff.survey import _inject_cluster_as_psu
2052+
2053+
existing_psu = np.array([0, 0, 1, 1, 2, 2])
2054+
resolved = ResolvedSurveyDesign(
2055+
weights=np.ones(6), weight_type="pweight",
2056+
strata=None, psu=existing_psu, fpc=None,
2057+
n_strata=0, n_psu=3, lonely_psu="remove",
2058+
)
2059+
result = _inject_cluster_as_psu(resolved, np.array([10, 10, 20, 20, 30, 30]))
2060+
assert result is resolved # Same object — no replacement
2061+
2062+
def test_invalid_weight_type_raises(self):
2063+
"""Invalid weight_type raises ValueError in solve_ols and LinearRegression."""
2064+
n = 20
2065+
X = np.column_stack([np.ones(n), np.random.randn(n)])
2066+
y = np.random.randn(n)
2067+
w = np.ones(n)
2068+
2069+
with pytest.raises(ValueError, match="weight_type must be one of"):
2070+
solve_ols(X, y, weights=w, weight_type="pwieght")
2071+
2072+
with pytest.raises(ValueError, match="weight_type must be one of"):
2073+
LinearRegression(weights=w, weight_type="bad").fit(X, y)
2074+
2075+
def test_nan_weights_raises(self):
2076+
"""NaN weights raise ValueError."""
2077+
n = 20
2078+
X = np.column_stack([np.ones(n), np.random.randn(n)])
2079+
y = np.random.randn(n)
2080+
w = np.ones(n)
2081+
w[5] = np.nan
2082+
2083+
with pytest.raises(ValueError, match="NaN"):
2084+
solve_ols(X, y, weights=w)
2085+
2086+
def test_negative_weights_raises(self):
2087+
"""Negative weights raise ValueError."""
2088+
n = 20
2089+
X = np.column_stack([np.ones(n), np.random.randn(n)])
2090+
y = np.random.randn(n)
2091+
w = np.ones(n)
2092+
w[3] = -0.5
2093+
2094+
with pytest.raises(ValueError, match="non-negative"):
2095+
solve_ols(X, y, weights=w)
2096+
2097+
def test_inf_weights_raises(self):
2098+
"""Inf weights raise ValueError."""
2099+
n = 20
2100+
X = np.column_stack([np.ones(n), np.random.randn(n)])
2101+
y = np.random.randn(n)
2102+
w = np.ones(n)
2103+
w[0] = np.inf
2104+
2105+
with pytest.raises(ValueError, match="Inf"):
2106+
solve_ols(X, y, weights=w)
2107+
2108+
def test_zero_weights_accepted(self):
2109+
"""Zero weights are accepted (intentional divergence from SurveyDesign)."""
2110+
n = 20
2111+
X = np.column_stack([np.ones(n), np.random.randn(n)])
2112+
y = np.random.randn(n)
2113+
w = np.ones(n)
2114+
w[0] = 0.0
2115+
2116+
# Should NOT raise
2117+
coef, resid, vcov = solve_ols(X, y, weights=w)
2118+
assert coef is not None

0 commit comments

Comments
 (0)