Skip to content

Commit 2dc1888

Browse files
igerberclaude
andcommitted
Address second AI review round on survey tutorial
- Fix StackedDiD support table: show "Full (pweight only)" for strata/PSU/FPC since it supports full TSL on composed weights, just restricted to pweight weight type (P1) - Fix replicate weight section: remove misleading TSL vs JK1 equivalence claim; show JK1 as standalone API demo and note that stratified replicates (JKn) are needed for stratified designs (P1) - Add input validation for weight_variation and cohort_periods in generate_survey_did_data() with negative tests (P2) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7d381de commit 2dc1888

3 files changed

Lines changed: 37 additions & 18 deletions

File tree

diff_diff/prep_dgp.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,14 @@ def generate_survey_did_data(
12201220

12211221
if cohort_periods is None:
12221222
cohort_periods = [3, 5]
1223+
if not cohort_periods:
1224+
raise ValueError("cohort_periods must be a non-empty list of integers")
1225+
1226+
valid_wv = ("none", "moderate", "high")
1227+
if weight_variation not in valid_wv:
1228+
raise ValueError(
1229+
f"weight_variation must be one of {valid_wv}, got {weight_variation!r}"
1230+
)
12231231

12241232
# --- Survey structure: assign units to strata and PSUs ---
12251233
n_psu_total = n_strata * psu_per_stratum

docs/tutorials/16_survey_did.ipynb

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -486,21 +486,15 @@
486486
" first_treat='first_treat', survey_design=sd_rep,\n",
487487
")\n",
488488
"\n",
489-
"# Compare TSL vs replicate SE\n",
490-
"# (Use the same data with strata/PSU/FPC for TSL)\n",
491-
"sd_tsl = SurveyDesign(weights='weight', strata='stratum', psu='psu', fpc='fpc')\n",
492-
"cs_tsl = CallawaySantAnna(control_group='never_treated')\n",
493-
"results_tsl = cs_tsl.fit(\n",
494-
" df_rep, outcome='outcome', unit='unit', time='period',\n",
495-
" first_treat='first_treat', survey_design=sd_tsl,\n",
496-
")\n",
497-
"\n",
498-
"print(f\"{'Method':20s} {'ATT':>10s} {'SE':>10s}\")\n",
499-
"print(\"-\" * 42)\n",
500-
"print(f\"{'TSL (strata/PSU/FPC)':20s} {results_tsl.overall_att:>10.4f} {results_tsl.overall_se:>10.4f}\")\n",
501-
"print(f\"{'JK1 (replicate wts)':20s} {results_rep.overall_att:>10.4f} {results_rep.overall_se:>10.4f}\")\n",
502-
"print(f\"\\nBoth methods correctly account for the survey design.\")\n",
503-
"print(f\"In practice, use whichever your survey provides.\")"
489+
"print(f\"Overall ATT: {results_rep.overall_att:.4f} (SE: {results_rep.overall_se:.4f})\")\n",
490+
"print(f\"Survey d.f.: {results_rep.survey_metadata.df_survey}\")\n",
491+
"print(f\"Replicate method: {results_rep.survey_metadata.replicate_method}\")\n",
492+
"print(f\"Number of replicates: {results_rep.survey_metadata.n_replicates}\")\n",
493+
"print()\n",
494+
"print(\"Note: JK1 replicates are unstratified (global delete-one-PSU).\")\n",
495+
"print(\"If your survey uses stratified sampling, stratified replicates (JKn)\")\n",
496+
"print(\"provide design-consistent variance estimation. Use whichever method\")\n",
497+
"print(\"your survey documentation specifies.\")"
504498
]
505499
},
506500
{
@@ -697,7 +691,7 @@
697691
"| **CallawaySantAnna** | Full | Full | Full | Multiplier at PSU |\n",
698692
"| **TripleDifference** | Full | Full | Full (analytical) | -- |\n",
699693
"| **SunAbraham** | Full | Full | -- | Rao-Wu rescaled |\n",
700-
"| **StackedDiD** | pweight only | pweight only | -- | -- |\n",
694+
"| **StackedDiD** | pweight only | Full (pweight only) | -- | -- |\n",
701695
"| **ImputationDiD** | Full | Partial (no FPC) | -- | Multiplier at PSU |\n",
702696
"| **TwoStageDiD** | Full | Partial (no FPC) | -- | Multiplier at PSU |\n",
703697
"| **ContinuousDiD** | Full | Full | Full (analytical) | Multiplier at PSU |\n",
@@ -707,9 +701,10 @@
707701
"| **BaconDecomposition** | Diagnostic | Diagnostic | -- | -- |\n",
708702
"\n",
709703
"**Legend:**\n",
710-
"- **Full**: Weights + strata/PSU/FPC + Taylor Series Linearization variance\n",
704+
"- **Full**: Weights (all types) + strata/PSU/FPC + Taylor Series Linearization variance\n",
705+
"- **Full (pweight only)**: Full TSL support with strata/PSU/FPC, but only accepts `pweight` weight type (`fweight`/`aweight` rejected because Q-weight composition changes their semantics)\n",
711706
"- **Partial (no FPC)**: Weights + strata (for df) + PSU (for clustering); FPC raises `NotImplementedError`\n",
712-
"- **pweight only**: Sampling weights for point estimates; no strata/PSU/FPC\n",
707+
"- **pweight only**: Sampling weights for point estimates; no strata/PSU/FPC design elements\n",
713708
"- **Diagnostic**: Weighted descriptive statistics only (no inference)\n",
714709
"- **--**: Not supported\n",
715710
"\n",

tests/test_prep.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,3 +1322,19 @@ def test_repeated_cross_section(self):
13221322
assert data["unit"].nunique() == 60 # unique across all periods
13231323
# No unit appears in more than one period
13241324
assert data.groupby("unit")["period"].nunique().max() == 1
1325+
1326+
def test_invalid_weight_variation(self):
1327+
"""Test that invalid weight_variation raises ValueError."""
1328+
import pytest
1329+
from diff_diff.prep import generate_survey_did_data
1330+
1331+
with pytest.raises(ValueError, match="weight_variation must be"):
1332+
generate_survey_did_data(weight_variation="invalid", seed=42)
1333+
1334+
def test_empty_cohort_periods(self):
1335+
"""Test that empty cohort_periods raises ValueError."""
1336+
import pytest
1337+
from diff_diff.prep import generate_survey_did_data
1338+
1339+
with pytest.raises(ValueError, match="cohort_periods must be"):
1340+
generate_survey_did_data(cohort_periods=[], seed=42)

0 commit comments

Comments
 (0)