Skip to content

Commit d12bab2

Browse files
igerberclaude
andcommitted
Add NA validation for survey strata/PSU/cluster IDs and fix results label consistency from PR #218 review (round 15)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c5eca78 commit d12bab2

3 files changed

Lines changed: 78 additions & 11 deletions

File tree

diff_diff/results.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ class DiDResults:
3434
n_obs : int
3535
Number of observations used in estimation.
3636
n_treated : int
37-
Number of treated observations.
37+
Number of treated units/observations.
3838
n_control : int
39-
Number of control observations.
39+
Number of control units/observations.
4040
"""
4141

4242
att: float
@@ -93,8 +93,8 @@ def summary(self, alpha: Optional[float] = None) -> str:
9393
"=" * 70,
9494
"",
9595
f"{'Observations:':<25} {self.n_obs:>10}",
96-
f"{'Treated obs:':<25} {self.n_treated:>10}",
97-
f"{'Control obs:':<25} {self.n_control:>10}",
96+
f"{'Treated:':<25} {self.n_treated:>10}",
97+
f"{'Control:':<25} {self.n_control:>10}",
9898
]
9999

100100
if self.r_squared is not None:
@@ -312,9 +312,9 @@ class MultiPeriodDiDResults:
312312
n_obs : int
313313
Number of observations used in estimation.
314314
n_treated : int
315-
Number of treated observations.
315+
Number of treated units/observations.
316316
n_control : int
317-
Number of control observations.
317+
Number of control units/observations.
318318
pre_periods : list
319319
List of pre-treatment period identifiers.
320320
post_periods : list
@@ -645,9 +645,9 @@ class SyntheticDiDResults:
645645
n_obs : int
646646
Number of observations used in estimation.
647647
n_treated : int
648-
Number of treated observations.
648+
Number of treated units/observations.
649649
n_control : int
650-
Number of control observations.
650+
Number of control units/observations.
651651
unit_weights : dict
652652
Dictionary mapping control unit IDs to their synthetic weights.
653653
time_weights : dict
@@ -714,8 +714,8 @@ def summary(self, alpha: Optional[float] = None) -> str:
714714
"=" * 75,
715715
"",
716716
f"{'Observations:':<25} {self.n_obs:>10}",
717-
f"{'Treated obs:':<25} {self.n_treated:>10}",
718-
f"{'Control obs:':<25} {self.n_control:>10}",
717+
f"{'Treated:':<25} {self.n_treated:>10}",
718+
f"{'Control:':<25} {self.n_control:>10}",
719719
f"{'Pre-treatment periods:':<25} {len(self.pre_periods):>10}",
720720
f"{'Post-treatment periods:':<25} {len(self.post_periods):>10}",
721721
]

diff_diff/survey.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,13 @@ def resolve(self, data: pd.DataFrame) -> "ResolvedSurveyDesign":
128128
if self.strata is not None:
129129
if self.strata not in data.columns:
130130
raise ValueError(f"Strata column '{self.strata}' not found in data")
131-
strata_arr = _factorize_cluster_ids(data[self.strata].values)
131+
strata_vals = data[self.strata].values
132+
if pd.isna(strata_vals).any():
133+
raise ValueError(
134+
f"Strata column '{self.strata}' contains missing values. "
135+
"All observations must have valid strata identifiers."
136+
)
137+
strata_arr = _factorize_cluster_ids(strata_vals)
132138
n_strata = len(np.unique(strata_arr))
133139

134140
# --- PSU ---
@@ -138,6 +144,11 @@ def resolve(self, data: pd.DataFrame) -> "ResolvedSurveyDesign":
138144
if self.psu not in data.columns:
139145
raise ValueError(f"PSU column '{self.psu}' not found in data")
140146
psu_raw = data[self.psu].values
147+
if pd.isna(psu_raw).any():
148+
raise ValueError(
149+
f"PSU column '{self.psu}' contains missing values. "
150+
"All observations must have valid PSU identifiers."
151+
)
141152

142153
if self.nest and strata_arr is not None:
143154
# Make PSU IDs unique within strata by combining
@@ -440,6 +451,14 @@ def _inject_cluster_as_psu(resolved, cluster_ids):
440451
if resolved.psu is not None:
441452
return resolved # PSU already present; _resolve_effective_cluster handles this
442453

454+
# Validate no missing cluster IDs before factorization
455+
if pd.isna(cluster_ids).any():
456+
raise ValueError(
457+
"Cluster IDs contain missing values. "
458+
"All observations must have valid cluster identifiers "
459+
"when used as effective PSUs for survey variance estimation."
460+
)
461+
443462
# Factorize cluster_ids for consistent integer encoding
444463
codes, uniques = pd.factorize(cluster_ids)
445464
n_clusters = len(uniques)

tests/test_survey.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2824,3 +2824,51 @@ def test_multiperiod_bootstrap_survey_fallback(self):
28242824
survey_design=sd,
28252825
)
28262826
assert np.isfinite(result.avg_att)
2827+
2828+
2829+
class TestRound15Fixes:
2830+
"""Tests for PR #218 review round 15: NA validation for survey identifiers."""
2831+
2832+
def test_strata_with_na_rejected(self):
2833+
"""SurveyDesign.resolve() rejects NA values in strata column."""
2834+
df = pd.DataFrame(
2835+
{
2836+
"y": [1.0, 2.0, 3.0, 4.0],
2837+
"w": [1.0, 1.0, 1.0, 1.0],
2838+
"strat": [0, 1, None, 0], # NA in strata
2839+
}
2840+
)
2841+
sd = SurveyDesign(weights="w", weight_type="pweight", strata="strat")
2842+
with pytest.raises(ValueError, match="Strata column.*missing values"):
2843+
sd.resolve(df)
2844+
2845+
def test_psu_with_na_rejected(self):
2846+
"""SurveyDesign.resolve() rejects NA values in PSU column."""
2847+
df = pd.DataFrame(
2848+
{
2849+
"y": [1.0, 2.0, 3.0, 4.0],
2850+
"w": [1.0, 1.0, 1.0, 1.0],
2851+
"cluster": [0, 1, np.nan, 0], # NA in PSU
2852+
}
2853+
)
2854+
sd = SurveyDesign(weights="w", weight_type="pweight", psu="cluster")
2855+
with pytest.raises(ValueError, match="PSU column.*missing values"):
2856+
sd.resolve(df)
2857+
2858+
def test_cluster_as_psu_with_na_rejected(self):
2859+
"""_inject_cluster_as_psu rejects NA values in cluster IDs."""
2860+
from diff_diff.survey import _inject_cluster_as_psu
2861+
2862+
resolved = ResolvedSurveyDesign(
2863+
weights=np.ones(4),
2864+
weight_type="pweight",
2865+
strata=None,
2866+
psu=None,
2867+
fpc=None,
2868+
n_strata=0,
2869+
n_psu=0,
2870+
lonely_psu="remove",
2871+
)
2872+
cluster_ids = np.array([0, 1, np.nan, 0])
2873+
with pytest.raises(ValueError, match="Cluster IDs contain missing"):
2874+
_inject_cluster_as_psu(resolved, cluster_ids)

0 commit comments

Comments
 (0)