Commit 527618f

igerber and claude committed
Fix IPW PS correction sign and merge main
Flip PS nuisance correction sign in panel survey IPW and RCS IPW paths. R adds the correction to inf.control then subtracts (att = treat - control), so the net effect on ATT IF is subtraction. DR paths are unaffected because their M2 residual (m_control - control_change) already has opposite sign.

Also merges main (StaggeredTripleDifference from #245), resolves TODO.md conflict, and updates survey-roadmap.md to reflect Phase 7a implementation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2 parents c0c7e5a + 07e37fa commit 527618f
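
The sign argument in the commit message can be checked with a small numeric sketch. The function names (`att_if_r_style`, `att_if_direct`) and the toy per-unit values are hypothetical and are not taken from the library or from R. They only illustrate that adding a correction to the control influence function and then forming treat - control is the same as subtracting that correction from the ATT influence function.

```python
def att_if_r_style(inf_treat, inf_control, ps_correction):
    """Add the correction to the control component, then form treat - control."""
    adjusted_control = [c + k for c, k in zip(inf_control, ps_correction)]
    return [t - c for t, c in zip(inf_treat, adjusted_control)]


def att_if_direct(inf_treat, inf_control, ps_correction):
    """Subtract the correction directly from the ATT influence function."""
    return [t - c - k for t, c, k in zip(inf_treat, inf_control, ps_correction)]


# Toy per-unit values (hypothetical; chosen to be exact in binary floating point).
inf_treat = [1.0, -0.5, 0.25]
inf_control = [0.5, 0.25, -1.0]
ps_correction = [0.125, -0.25, 0.5]

r_style = att_if_r_style(inf_treat, inf_control, ps_correction)
direct = att_if_direct(inf_treat, inf_control, ps_correction)

# The pre-fix behaviour described in the message would correspond to adding
# the correction to the ATT influence function instead.
wrong_sign = [t - c + k for t, c, k in zip(inf_treat, inf_control, ps_correction)]
```

The two formulations agree element by element, while the added-correction variant differs wherever the correction is non-zero.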

19 files changed

Lines changed: 3606 additions & 38 deletions

ROADMAP.md

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ Extend the existing `TripleDifference` estimator to handle staggered adoption se
 - Event study aggregation and pre-treatment placebo effects
 - Multiplier bootstrap for valid inference in staggered settings

-**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). *Working Paper*. R package: `triplediff`.
+**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). "Better Understanding Triple Differences Estimators." *Working Paper*. R package: `triplediff`.

 ---

TODO.md

Lines changed: 3 additions & 1 deletion
@@ -65,6 +65,9 @@ Deferred items from PR reviews that were not addressed before merge.
 | Survey design resolution/collapse patterns are inconsistent across panel estimators — ContinuousDiD rebuilds unit-level design in SE code, EfficientDiD builds once in fit(), StackedDiD re-resolves on stacked data; extract shared helpers for panel-to-unit collapse, post-filter re-resolution, and metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low |
 | Survey metadata formatting dedup — **Resolved**. Extracted `_format_survey_block()` helper in `results.py`, replaced 13 occurrences across 11 files. | `results.py` + 10 results files || Resolved |
 | TROP: `fit()` and `_fit_global()` share ~150 lines of near-identical data setup (panel pivoting, absorbing-state validation, first-treatment detection, effective rank, NaN warnings). Both bootstrap methods also duplicate the stratified resampling loop. Extract shared helpers to eliminate cross-file sync risk. | `trop.py`, `trop_global.py`, `trop_local.py` || Low |
+| StaggeredTripleDifference R cross-validation: CSV fixtures not committed (gitignored); tests skip without local R + triplediff. Commit fixtures or generate deterministically. | `tests/test_methodology_staggered_triple_diff.py` | #245 | Medium |
+| StaggeredTripleDifference R parity: benchmark only tests no-covariate path (xformla=~1). Add covariate-adjusted scenarios and aggregation SE parity assertions. | `benchmarks/R/benchmark_staggered_triplediff.R` | #245 | Medium |
+| StaggeredTripleDifference: per-cohort group-effect SEs include WIF (conservative vs R's wif=NULL). Documented in REGISTRY. Could override mixin for exact R match. | `staggered_triple_diff.py` | #245 | Low |

 #### Performance

@@ -77,7 +80,6 @@ Deferred items from PR reviews that were not addressed before merge.

 | Issue | Location | PR | Priority |
 |-------|----------|----|----------|
-| Plotly renderers silently ignore styling kwargs (marker, markersize, linewidth, capsize, ci_linewidth) that the matplotlib backend honors; thread them through or reject when `backend="plotly"` | `visualization/_event_study.py`, `_diagnostic.py`, `_power.py` | #222 | Medium |
 | R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
 | CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low |
 | ~376 `duplicate object description` Sphinx warnings — restructure `docs/api/*.rst` to avoid duplicate `:members:` + `autosummary` | `docs/api/*.rst` || Low |
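
The StaggeredTripleDifference cross-validation item added to TODO.md notes that tests skip when the R-generated fixtures are absent. A minimal sketch of such a guard, using only the standard library, might look as follows. The helper name `missing_fixtures` and the directory layout are assumptions for illustration. The file names follow the pattern written by the R benchmark script in this commit.

```python
from pathlib import Path
import tempfile


def missing_fixtures(fixture_dir, data_keys,
                     results_name="staggered_ddd_r_results.json"):
    """Return the names of expected fixture files absent from fixture_dir."""
    fixture_dir = Path(fixture_dir)
    expected = [fixture_dir / results_name]
    expected += [fixture_dir / f"staggered_ddd_data_{key}.csv" for key in data_keys]
    return [path.name for path in expected if not path.is_file()]


# Demonstration on a temporary directory that holds only one of the CSVs.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "staggered_ddd_data_s42_dgp1.csv").write_text("id,time,y\n")
    missing = missing_fixtures(tmp, ["s42_dgp1", "s123_dgp1"])
```

A test module could call `pytest.skip` when the returned list is non-empty, which makes the reason for the skip visible in the test report instead of failing on a missing file.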
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+#!/usr/bin/env Rscript
+# Benchmark: Staggered Triple Difference (R `triplediff` package)
+#
+# Generates golden values for cross-validation against Python
+# StaggeredTripleDifference estimator.
+#
+# Usage:
+# Rscript benchmark_staggered_triplediff.R
+
+library(triplediff)
+library(jsonlite)
+library(data.table)
+
+cat("=== Staggered DDD Benchmark Generator ===\n")
+
+output_dir <- file.path(dirname(dirname(getwd())), "benchmarks", "data", "synthetic")
+# Handle running from project root or benchmarks/R
+if (!dir.exists(output_dir)) {
+    output_dir <- "benchmarks/data/synthetic"
+}
+if (!dir.exists(output_dir)) {
+    dir.create(output_dir, recursive = TRUE)
+}
+
+results <- list()
+
+# Scenario definitions
+scenarios <- list(
+    list(seed=42, dgp=1, method="dr", cg="nevertreated", key="s42_dgp1_dr_nt"),
+    list(seed=42, dgp=1, method="ipw", cg="nevertreated", key="s42_dgp1_ipw_nt"),
+    list(seed=42, dgp=1, method="reg", cg="nevertreated", key="s42_dgp1_reg_nt"),
+    list(seed=42, dgp=1, method="dr", cg="notyettreated", key="s42_dgp1_dr_nyt"),
+    list(seed=42, dgp=1, method="ipw", cg="notyettreated", key="s42_dgp1_ipw_nyt"),
+    list(seed=42, dgp=1, method="reg", cg="notyettreated", key="s42_dgp1_reg_nyt"),
+    list(seed=123, dgp=1, method="dr", cg="nevertreated", key="s123_dgp1_dr_nt"),
+    list(seed=123, dgp=1, method="dr", cg="notyettreated", key="s123_dgp1_dr_nyt"),
+    list(seed=99, dgp=1, method="dr", cg="nevertreated", key="s99_dgp1_dr_nt"),
+    list(seed=99, dgp=1, method="dr", cg="notyettreated", key="s99_dgp1_dr_nyt")
+)
+
+for (sc in scenarios) {
+    cat(sprintf(" Running scenario: %s ...\n", sc$key))
+
+    set.seed(sc$seed)
+    dgp <- gen_dgp_mult_periods(size = 500, dgp_type = sc$dgp)
+    data <- dgp$data
+
+    # Save data CSV (one per seed+dgp combo, reused across methods)
+    data_key <- sprintf("s%d_dgp%d", sc$seed, sc$dgp)
+    csv_path <- file.path(output_dir, sprintf("staggered_ddd_data_%s.csv", data_key))
+    if (!file.exists(csv_path)) {
+        fwrite(data, csv_path)
+        cat(sprintf(" Saved data: %s\n", csv_path))
+    }
+
+    # Run DDD estimation
+    res <- tryCatch({
+        ddd(yname = "y", tname = "time", idname = "id",
+            gname = "state", pname = "partition",
+            xformla = ~1, # no covariates for cross-validation
+            data = data,
+            control_group = sc$cg,
+            base_period = "varying",
+            est_method = sc$method,
+            panel = TRUE)
+    }, error = function(e) {
+        cat(sprintf(" ERROR: %s\n", e$message))
+        return(NULL)
+    })
+
+    if (is.null(res)) next
+
+    # Group-time results
+    gt_results <- data.frame(
+        group = res$groups,
+        period = res$periods,
+        att = res$ATT,
+        se = res$se
+    )
+
+    # Event study aggregation
+    agg_es <- tryCatch({
+        agg_ddd(res, type = "eventstudy")
+    }, error = function(e) {
+        cat(sprintf(" Event study agg failed: %s\n", e$message))
+        NULL
+    })
+
+    es_results <- NULL
+    overall_att_es <- NA
+    overall_se_es <- NA
+    if (!is.null(agg_es)) {
+        a <- agg_es$aggte_ddd
+        es_results <- data.frame(
+            event_time = a$egt,
+            att = a$att.egt,
+            se = a$se.egt
+        )
+        overall_att_es <- a$overall.att
+        overall_se_es <- a$overall.se
+    }
+
+    # Simple aggregation
+    agg_simple <- tryCatch({
+        agg_ddd(res, type = "simple")
+    }, error = function(e) {
+        cat(sprintf(" Simple agg failed: %s\n", e$message))
+        NULL
+    })
+
+    overall_att_simple <- NA
+    overall_se_simple <- NA
+    if (!is.null(agg_simple)) {
+        a <- agg_simple$aggte_ddd
+        overall_att_simple <- a$overall.att
+        overall_se_simple <- a$overall.se
+    }
+
+    # Store results
+    results[[sc$key]] <- list(
+        seed = sc$seed,
+        dgp_type = sc$dgp,
+        est_method = sc$method,
+        control_group = sc$cg,
+        n = res$n,
+        gt_att = as.list(gt_results$att),
+        gt_se = as.list(gt_results$se),
+        gt_groups = as.list(gt_results$group),
+        gt_periods = as.list(gt_results$period),
+        overall_att_simple = overall_att_simple,
+        overall_se_simple = overall_se_simple,
+        overall_att_es = overall_att_es,
+        overall_se_es = overall_se_es,
+        es_event_times = if (!is.null(es_results)) as.list(es_results$event_time) else NULL,
+        es_att = if (!is.null(es_results)) as.list(es_results$att) else NULL,
+        es_se = if (!is.null(es_results)) as.list(es_results$se) else NULL
+    )
+
+    cat(sprintf(" GT ATT: %s\n", paste(round(res$ATT, 4), collapse=", ")))
+    cat(sprintf(" Overall ATT (simple): %.4f\n", overall_att_simple))
+}
+
+# Save all results as JSON
+json_path <- file.path(output_dir, "staggered_ddd_r_results.json")
+writeLines(toJSON(results, auto_unbox = TRUE, pretty = TRUE, digits = 10), json_path)
+cat(sprintf("\nResults saved to: %s\n", json_path))
+cat("Done.\n")
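
On the Python side, the JSON written by this script could be consumed roughly as follows. This is a sketch under assumptions: the helper name `max_abs_diff`, the tolerance, and the stand-in Python estimates are hypothetical, and the inline JSON string is made-up data that only mimics the keys stored by the script (`gt_att`, `gt_se`, `overall_att_simple`). It is not output from `triplediff`.

```python
import json
import math


def max_abs_diff(golden, candidate):
    """Largest absolute gap between two equal-length numeric sequences.

    Entries that are None in the golden values are skipped.
    """
    if len(golden) != len(candidate):
        raise ValueError("length mismatch between golden and candidate values")
    gaps = [abs(g - c) for g, c in zip(golden, candidate) if g is not None]
    return max(gaps, default=0.0)


# Made-up stand-in for one scenario entry; a real run would instead read
# staggered_ddd_r_results.json from the output directory.
golden_json = """
{"s42_dgp1_dr_nt": {"est_method": "dr",
                    "gt_att": [1.5, 2.25],
                    "gt_se": [0.5, 0.75],
                    "overall_att_simple": 2.0}}
"""
scenario = json.loads(golden_json)["s42_dgp1_dr_nt"]

# Hypothetical Python-side estimates to compare against the golden values.
python_gt_att = [1.5, 2.25 + 1e-9]
python_overall = 2.0

gt_gap = max_abs_diff(scenario["gt_att"], python_gt_att)
overall_ok = math.isclose(scenario["overall_att_simple"], python_overall, abs_tol=1e-6)
```

Reporting the largest gap rather than a bare pass/fail makes it easier to see how far a scenario is from parity when an assertion on `gt_gap` fails.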
