Skip to content

Commit d7ddb19

Browse files
igerber and claude committed
Fix AI review: remove over-restrictive group-constant validation, vectorize IF expansion
- Remove _validate_group_constant_survey() call - the IF expansion
  psi_i = U[g] * (w_i / W_g) handles observation-level variation in
  weights, strata, and PSU within groups correctly
- Vectorize _survey_se_from_group_if using np.bincount + np.unique
  (was Python loops over all observations)
- Replace test_rejects_varying_weights_within_group with two positive
  tests: varying weights accepted, and varying weights change ATT
  (time-varying noise to survive first-differencing)
- Remove unused survey_weight_type variable

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 531013e commit d7ddb19

2 files changed

Lines changed: 66 additions & 37 deletions

File tree

diff_diff/chaisemartin_dhaultfoeuille.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -630,12 +630,9 @@ def fit(
630630
# ------------------------------------------------------------------
631631
# Step 3: Survey resolution
632632
# ------------------------------------------------------------------
633-
from diff_diff.survey import (
634-
_resolve_survey_for_fit,
635-
_validate_group_constant_survey,
636-
)
633+
from diff_diff.survey import _resolve_survey_for_fit
637634

638-
resolved_survey, survey_weights, survey_weight_type, survey_metadata = (
635+
resolved_survey, survey_weights, _, survey_metadata = (
639636
_resolve_survey_for_fit(survey_design, data, "analytical")
640637
)
641638

@@ -653,8 +650,9 @@ def fit(
653650
"Use strata/PSU/FPC for design-based inference via Taylor "
654651
"Series Linearization."
655652
)
656-
# Validate survey columns are constant within groups.
657-
_validate_group_constant_survey(data, group, survey_design)
653+
# No group-constant survey validation: the IF expansion
654+
# psi_i = U[g] * (w_i / W_g) handles observation-level
655+
# variation in weights, strata, and PSU within groups.
658656

659657
# Design-2 precondition: requires drop_larger_lower=False
660658
if design2 and self.drop_larger_lower:
@@ -4593,27 +4591,21 @@ def _survey_se_from_group_if(
45934591
group_ids = obs_survey_info["group_ids"]
45944592
weights = obs_survey_info["weights"]
45954593
resolved = obs_survey_info["resolved"]
4596-
n_obs = len(group_ids)
45974594

4598-
# Build group → U_centered lookup
4599-
group_to_u = {}
4600-
for idx, gid in enumerate(eligible_groups):
4601-
group_to_u[gid] = U_centered[idx]
4595+
# Build group → U_centered lookup (vectorized via factorization)
4596+
group_to_u = {gid: U_centered[idx] for idx, gid in enumerate(eligible_groups)}
4597+
4598+
# Map group IFs to observation level
4599+
u_obs = np.array([group_to_u.get(gid, 0.0) for gid in group_ids])
46024600

4603-
# Compute per-group weight totals W_g
4604-
group_to_w_total: Dict[Any, float] = {}
4605-
for i in range(n_obs):
4606-
gid = group_ids[i]
4607-
group_to_w_total[gid] = group_to_w_total.get(gid, 0.0) + weights[i]
4601+
# Compute per-group weight totals W_g via bincount
4602+
unique_gids, inverse = np.unique(group_ids, return_inverse=True)
4603+
w_totals_per_group = np.bincount(inverse, weights=weights)
4604+
w_obs_total = w_totals_per_group[inverse]
46084605

46094606
# Expand to observation level: psi_i = U[g] * (w_i / W_g)
4610-
psi = np.zeros(n_obs)
4611-
for i in range(n_obs):
4612-
gid = group_ids[i]
4613-
u_val = group_to_u.get(gid, 0.0)
4614-
w_total = group_to_w_total.get(gid, 1.0)
4615-
if w_total > 0:
4616-
psi[i] = u_val * (weights[i] / w_total)
4607+
safe_w = np.where(w_obs_total > 0, w_obs_total, 1.0)
4608+
psi = u_obs * (weights / safe_w)
46174609

46184610
variance = compute_survey_if_variance(psi, resolved)
46194611
if not np.isfinite(variance) or variance < 0:

tests/test_survey_dcdh.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -264,21 +264,58 @@ def test_rejects_aweight(self, base_data):
264264
survey_design=sd,
265265
)
266266

267-
def test_rejects_varying_weights_within_group(self, base_data):
268-
"""Weights must be constant within groups."""
267+
def test_varying_weights_within_group_accepted(self, base_data):
268+
"""Observation-level weights varying within groups are valid."""
269+
# Create multi-obs cells with varying weights
270+
rng = np.random.default_rng(1)
269271
df = base_data.copy()
270-
# Assign different weights to different observations in the same group
271-
df["pw"] = np.random.default_rng(1).uniform(0.5, 3.0, size=len(df))
272+
df2 = base_data.copy()
273+
df2["outcome"] = df2["outcome"] + rng.normal(0, 0.5, size=len(df2))
274+
multi = pd.concat([df, df2], ignore_index=True)
275+
# Observation-level weights (vary within group)
276+
multi["pw"] = rng.uniform(0.5, 3.0, size=len(multi))
272277
sd = SurveyDesign(weights="pw")
273-
with pytest.raises(ValueError, match="varies within groups"):
274-
ChaisemartinDHaultfoeuille().fit(
275-
df,
276-
outcome="outcome",
277-
group="group",
278-
time="period",
279-
treatment="treatment",
280-
survey_design=sd,
281-
)
278+
# Should succeed - no group-constant restriction
279+
result = ChaisemartinDHaultfoeuille(seed=1).fit(
280+
multi,
281+
outcome="outcome",
282+
group="group",
283+
time="period",
284+
treatment="treatment",
285+
survey_design=sd,
286+
)
287+
assert np.isfinite(result.overall_att)
288+
289+
def test_varying_weights_change_att(self, base_data):
290+
"""With multi-obs cells and varying weights, ATT differs from unweighted.
291+
292+
dCDH uses first differences Y_{g,t} - Y_{g,t-1}, so group-constant
293+
noise cancels. The noise must vary across both group AND time for
294+
weighted cell means to affect the ATT via different first differences.
295+
"""
296+
rng = np.random.default_rng(42)
297+
df = base_data.copy()
298+
df2 = base_data.copy()
299+
# Per-observation noise (varies by group AND time)
300+
df2["outcome"] = df2["outcome"] + rng.normal(0, 3.0, size=len(df2))
301+
multi = pd.concat([df, df2], ignore_index=True)
302+
# Give first copy weight=1, second copy weight=10
303+
multi["pw"] = np.where(np.arange(len(multi)) < len(df), 1.0, 10.0)
304+
sd = SurveyDesign(weights="pw")
305+
result_plain = ChaisemartinDHaultfoeuille(seed=1).fit(
306+
multi, outcome="outcome", group="group",
307+
time="period", treatment="treatment",
308+
)
309+
result_survey = ChaisemartinDHaultfoeuille(seed=1).fit(
310+
multi, outcome="outcome", group="group",
311+
time="period", treatment="treatment",
312+
survey_design=sd,
313+
)
314+
# Weighted cell means with time-varying noise produce different
315+
# first differences -> different ATT
316+
assert result_plain.overall_att != pytest.approx(
317+
result_survey.overall_att, abs=0.01
318+
)
282319

283320
def test_rejects_replicate_weights(self, base_data):
284321
"""Replicate weight variance not yet supported."""

0 commit comments

Comments
 (0)