Skip to content

Commit b6db034

Browse files
igerberclaude
andcommitted
Reject non-unique PSU labels across strata with nest=False and fix test fixtures from PR #218 review (round 20)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent eebc13e commit b6db034

2 files changed

Lines changed: 65 additions & 37 deletions

File tree

diff_diff/survey.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,19 +156,24 @@ def resolve(self, data: pd.DataFrame) -> "ResolvedSurveyDesign":
156156
psu_arr = _factorize_cluster_ids(combined)
157157
else:
158158
psu_arr = _factorize_cluster_ids(psu_raw)
159+
# Validate PSU labels are globally unique when nest=False
160+
# and strata are present. Repeated labels cause wrong n_psu,
161+
# df_survey, and lonely_psu="adjust" global mean.
162+
if strata_arr is not None:
163+
seen_psus: set = set()
164+
for h in np.unique(strata_arr):
165+
psu_in_h = set(psu_raw[strata_arr == h])
166+
overlap = seen_psus & psu_in_h
167+
if overlap:
168+
raise ValueError(
169+
f"PSU labels {overlap} appear in multiple strata. "
170+
"Set nest=True in SurveyDesign to make PSU IDs "
171+
"unique within strata, or use globally unique "
172+
"PSU labels."
173+
)
174+
seen_psus |= psu_in_h
159175

160-
# Count total PSUs: sum of unique PSUs within each stratum.
161-
# When nest=True, labels are already globally unique so this
162-
# is equivalent to len(np.unique(psu_arr)). When nest=False
163-
# with strata, PSU labels may repeat across strata (common in
164-
# survey data), so we count per-stratum to get the correct total.
165-
if strata_arr is not None and not self.nest:
166-
n_psu = sum(
167-
len(np.unique(psu_arr[strata_arr == h]))
168-
for h in np.unique(strata_arr)
169-
)
170-
else:
171-
n_psu = len(np.unique(psu_arr))
176+
n_psu = len(np.unique(psu_arr))
172177

173178
# --- FPC ---
174179
fpc_arr = None

tests/test_survey.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def multiperiod_data():
118118
for unit in range(n_units):
119119
is_treated = unit < n_treated
120120
stratum = unit // 20 # 3 strata (20 units each)
121-
psu = unit // 3 # 20 PSUs globally unique
121+
psu_within = (unit % 20) // 5 # 4 PSUs within each stratum
122+
psu = stratum * 4 + psu_within # globally unique PSU ID
122123
wt = 1.0 + 0.4 * stratum
123124

124125
for t in periods:
@@ -871,21 +872,16 @@ def test_nest_true(self):
871872
}
872873
)
873874

874-
# Without nest: PSU labels repeat but n_psu counts per-stratum
875+
# nest=False rejects repeated PSU labels across strata
875876
sd_no_nest = SurveyDesign(weights="w", strata="s", psu="psu", nest=False)
876-
resolved_no_nest = sd_no_nest.resolve(df)
877+
with pytest.raises(ValueError, match="PSU labels.*multiple strata"):
878+
sd_no_nest.resolve(df)
877879

878-
# With nest: PSU 0 in stratum 0 != PSU 0 in stratum 1
880+
# nest=True makes them unique: PSU 0 in stratum 0 != PSU 0 in stratum 1
879881
sd_nest = SurveyDesign(weights="w", strata="s", psu="psu", nest=True)
880882
resolved_nest = sd_nest.resolve(df)
881-
882-
# Both should produce 20 PSUs (10 per stratum × 2 strata)
883-
# nest=True makes globally unique codes; nest=False counts per-stratum
884-
assert resolved_nest.n_psu == 20
885-
assert resolved_no_nest.n_psu == 20
886-
# df_survey should match: 20 - 2 = 18
887-
assert resolved_nest.df_survey == 18
888-
assert resolved_no_nest.df_survey == 18
883+
assert resolved_nest.n_psu == 20 # 10 per stratum × 2 strata
884+
assert resolved_nest.df_survey == 18 # 20 - 2
889885

890886
def test_twfe_with_survey_design(self, twfe_panel_data):
891887
"""TwoWayFixedEffects accepts and uses survey_design."""
@@ -3058,14 +3054,13 @@ def test_did_with_fpc_only_survey(self):
30583054

30593055

30603056
class TestRound19Fixes:
3061-
"""Tests for PR #218 review round 19: per-stratum PSU counting."""
3057+
"""Tests for PR #218 review round 19: PSU nesting validation."""
30623058

3063-
def test_npsu_counts_per_stratum_with_repeated_labels(self):
3064-
"""n_psu counts unique PSUs per stratum, not globally, when labels repeat."""
3059+
def test_repeated_psu_labels_nest_false_rejected(self):
3060+
"""Repeated PSU labels across strata with nest=False are rejected."""
30653061
n = 40
30663062
strata = np.repeat([0, 1], 20)
3067-
# PSU IDs 0..9 repeat across both strata
3068-
psu_raw = np.tile(np.arange(10), 4)[:n]
3063+
psu_raw = np.tile(np.arange(10), 4)[:n] # labels repeat
30693064

30703065
df = pd.DataFrame(
30713066
{
@@ -3075,15 +3070,43 @@ def test_npsu_counts_per_stratum_with_repeated_labels(self):
30753070
"psu": psu_raw,
30763071
}
30773072
)
3078-
3079-
# nest=False with repeated labels: should count 10+10=20 PSUs
30803073
sd = SurveyDesign(weights="w", strata="s", psu="psu", nest=False)
3074+
with pytest.raises(ValueError, match="PSU labels.*multiple strata"):
3075+
sd.resolve(df)
3076+
3077+
def test_repeated_psu_labels_nest_true_accepted(self):
3078+
"""Repeated PSU labels with nest=True produce correct n_psu."""
3079+
n = 40
3080+
strata = np.repeat([0, 1], 20)
3081+
psu_raw = np.tile(np.arange(10), 4)[:n]
3082+
3083+
df = pd.DataFrame(
3084+
{
3085+
"y": np.ones(n),
3086+
"w": np.ones(n),
3087+
"s": strata,
3088+
"psu": psu_raw,
3089+
}
3090+
)
3091+
sd = SurveyDesign(weights="w", strata="s", psu="psu", nest=True)
30813092
resolved = sd.resolve(df)
3082-
assert resolved.n_psu == 20 # 10 per stratum × 2 strata
3093+
assert resolved.n_psu == 20 # 10 per stratum × 2
30833094
assert resolved.df_survey == 18 # 20 - 2
30843095

3085-
# nest=True should give same result
3086-
sd_nest = SurveyDesign(weights="w", strata="s", psu="psu", nest=True)
3087-
resolved_nest = sd_nest.resolve(df)
3088-
assert resolved_nest.n_psu == 20
3089-
assert resolved_nest.df_survey == 18
3096+
def test_unique_psu_labels_nest_false_accepted(self):
3097+
"""Globally unique PSU labels with nest=False work correctly."""
3098+
n = 40
3099+
strata = np.repeat([0, 1], 20)
3100+
psu_raw = np.arange(n) // 2 # 20 unique PSUs, no overlap
3101+
3102+
df = pd.DataFrame(
3103+
{
3104+
"y": np.ones(n),
3105+
"w": np.ones(n),
3106+
"s": strata,
3107+
"psu": psu_raw,
3108+
}
3109+
)
3110+
sd = SurveyDesign(weights="w", strata="s", psu="psu", nest=False)
3111+
resolved = sd.resolve(df)
3112+
assert resolved.n_psu == 20

0 commit comments

Comments
 (0)