Skip to content

Commit 07e37fa

Browse files
authored
Merge pull request #245 from igerber/staggered-ddd
Add Staggered Triple Difference estimator
2 parents 93abaea + 4b5157c commit 07e37fa

12 files changed

Lines changed: 3304 additions & 1 deletion

ROADMAP.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Extend the existing `TripleDifference` estimator to handle staggered adoption se
3030
- Event study aggregation and pre-treatment placebo effects
3131
- Multiplier bootstrap for valid inference in staggered settings
3232

33-
**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). *Working Paper*. R package: `triplediff`.
33+
**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). "Better Understanding Triple Differences Estimators." *Working Paper*. R package: `triplediff`.
3434

3535
### Enhanced Visualization
3636

TODO.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ Deferred items from PR reviews that were not addressed before merge.
5959
| TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. | `prep_dgp.py`, `power.py` | #208 | Low |
6060
| Survey design resolution/collapse patterns inconsistent across panel estimators — extract shared helpers for panel-to-unit collapse, post-filter re-resolution, metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low |
6161
| TROP: `fit()` and `_fit_global()` share ~150 lines of near-identical data setup. Extract shared helpers to eliminate cross-file sync risk. | `trop.py`, `trop_global.py`, `trop_local.py` | — | Low |
62+
| StaggeredTripleDifference R cross-validation: CSV fixtures not committed (gitignored); tests skip without local R + triplediff. Commit fixtures or generate deterministically. | `tests/test_methodology_staggered_triple_diff.py` | #245 | Medium |
63+
| StaggeredTripleDifference R parity: benchmark only tests no-covariate path (xformla=~1). Add covariate-adjusted scenarios and aggregation SE parity assertions. | `benchmarks/R/benchmark_staggered_triplediff.R` | #245 | Medium |
64+
| StaggeredTripleDifference: per-cohort group-effect SEs include WIF (conservative vs R's wif=NULL). Documented in REGISTRY. Could override mixin for exact R match. | `staggered_triple_diff.py` | #245 | Low |
6265

6366
#### Performance
6467

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
#!/usr/bin/env Rscript
2+
# Benchmark: Staggered Triple Difference (R `triplediff` package)
3+
#
4+
# Generates golden values for cross-validation against Python
5+
# StaggeredTripleDifference estimator.
6+
#
7+
# Usage:
8+
# Rscript benchmark_staggered_triplediff.R
9+
10+
library(triplediff)
11+
library(jsonlite)
12+
library(data.table)
13+
14+
cat("=== Staggered DDD Benchmark Generator ===\n")
15+
16+
# Resolve the directory that receives the generated CSV/JSON fixtures.
#
# The script may be launched from the project root or from benchmarks/R.
# The first candidate assumes the working directory is two levels below the
# project root (i.e. benchmarks/R).
cwd <- getwd()
output_dir <- file.path(dirname(dirname(cwd)), "benchmarks", "data", "synthetic")
if (!dir.exists(output_dir)) {
  # Only fall back to a path relative to the working directory when we are NOT
  # inside benchmarks/R. Otherwise a fresh checkout (fixture directory not yet
  # created) would end up with benchmarks/R/benchmarks/data/synthetic.
  run_from_script_dir <- basename(cwd) == "R" &&
    basename(dirname(cwd)) == "benchmarks"
  if (!run_from_script_dir) {
    output_dir <- file.path("benchmarks", "data", "synthetic")
  }
}
if (!dir.exists(output_dir)) {
  dir.create(output_dir, recursive = TRUE)
}

# Accumulates one named entry per successfully estimated scenario.
results <- list()
26+
27+
# Scenario definitions
#
# Each scenario is a list with fields:
#   seed   - RNG seed passed to set.seed() before data generation
#   dgp    - DGP type forwarded to the data generator
#   method - estimation method ("dr", "ipw" or "reg")
#   cg     - control group ("nevertreated" or "notyettreated")
#   key    - unique name under which the scenario's results are stored
scenarios <- local({
  # Short tags used in the result keys for each control-group choice.
  cg_tags <- c(nevertreated = "nt", notyettreated = "nyt")

  make_scenario <- function(seed, method, cg) {
    list(
      seed = seed,
      dgp = 1,
      method = method,
      cg = cg,
      key = sprintf("s%d_dgp%d_%s_%s", seed, 1, method, cg_tags[[cg]])
    )
  }

  list(
    # Seed 42: every estimation method under both control groups.
    make_scenario(42, "dr", "nevertreated"),
    make_scenario(42, "ipw", "nevertreated"),
    make_scenario(42, "reg", "nevertreated"),
    make_scenario(42, "dr", "notyettreated"),
    make_scenario(42, "ipw", "notyettreated"),
    make_scenario(42, "reg", "notyettreated"),
    # Additional seeds: doubly robust only.
    make_scenario(123, "dr", "nevertreated"),
    make_scenario(123, "dr", "notyettreated"),
    make_scenario(99, "dr", "nevertreated"),
    make_scenario(99, "dr", "notyettreated")
  )
})
40+
41+
# Data keys (seed+dgp combos) whose CSV has already been written in THIS run.
# Tracking this in memory instead of checking file.exists() means a CSV left
# over from an earlier run is refreshed, so the data on disk always matches
# the golden values regenerated below.
saved_data_keys <- character(0)

# Run every scenario: generate data, estimate group-time effects, aggregate,
# and store the outcome in `results` under the scenario key. A scenario whose
# estimation fails is reported on the console and skipped.
for (sc in scenarios) {
  cat(sprintf(" Running scenario: %s ...\n", sc$key))

  set.seed(sc$seed)
  dgp <- gen_dgp_mult_periods(size = 500, dgp_type = sc$dgp)
  data <- dgp$data

  # Save data CSV (one per seed+dgp combo, reused across methods)
  data_key <- sprintf("s%d_dgp%d", sc$seed, sc$dgp)
  csv_path <- file.path(output_dir, sprintf("staggered_ddd_data_%s.csv", data_key))
  if (!(data_key %in% saved_data_keys)) {
    fwrite(data, csv_path)
    saved_data_keys <- c(saved_data_keys, data_key)
    cat(sprintf(" Saved data: %s\n", csv_path))
  }

  # Run DDD estimation; NULL signals failure.
  res <- tryCatch(
    ddd(
      yname = "y", tname = "time", idname = "id",
      gname = "state", pname = "partition",
      xformla = ~1, # no covariates for cross-validation
      data = data,
      control_group = sc$cg,
      base_period = "varying",
      est_method = sc$method,
      panel = TRUE
    ),
    error = function(e) {
      cat(sprintf(" ERROR: %s\n", conditionMessage(e)))
      NULL
    }
  )

  if (is.null(res)) next

  # Group-time results (data.frame() also checks the vectors line up).
  gt_results <- data.frame(
    group = res$groups,
    period = res$periods,
    att = res$ATT,
    se = res$se
  )

  # Event study aggregation
  agg_es <- tryCatch(
    agg_ddd(res, type = "eventstudy"),
    error = function(e) {
      cat(sprintf(" Event study agg failed: %s\n", conditionMessage(e)))
      NULL
    }
  )

  # Placeholders are kept as plain NA / NULL so the serialized output for a
  # failed aggregation is the same as before.
  es_results <- NULL
  overall_att_es <- NA
  overall_se_es <- NA
  if (!is.null(agg_es)) {
    a <- agg_es$aggte_ddd
    es_results <- data.frame(
      event_time = a$egt,
      att = a$att.egt,
      se = a$se.egt
    )
    overall_att_es <- a$overall.att
    overall_se_es <- a$overall.se
  }

  # Simple aggregation
  agg_simple <- tryCatch(
    agg_ddd(res, type = "simple"),
    error = function(e) {
      cat(sprintf(" Simple agg failed: %s\n", conditionMessage(e)))
      NULL
    }
  )

  overall_att_simple <- NA
  overall_se_simple <- NA
  if (!is.null(agg_simple)) {
    a <- agg_simple$aggte_ddd
    overall_att_simple <- a$overall.att
    overall_se_simple <- a$overall.se
  }

  # Store results. Vectors go through as.list() so that, with
  # auto_unbox = TRUE at serialization time, they stay arrays even when they
  # hold a single value.
  results[[sc$key]] <- list(
    seed = sc$seed,
    dgp_type = sc$dgp,
    est_method = sc$method,
    control_group = sc$cg,
    n = res$n,
    gt_att = as.list(gt_results$att),
    gt_se = as.list(gt_results$se),
    gt_groups = as.list(gt_results$group),
    gt_periods = as.list(gt_results$period),
    overall_att_simple = overall_att_simple,
    overall_se_simple = overall_se_simple,
    overall_att_es = overall_att_es,
    overall_se_es = overall_se_es,
    es_event_times = if (!is.null(es_results)) as.list(es_results$event_time) else NULL,
    es_att = if (!is.null(es_results)) as.list(es_results$att) else NULL,
    es_se = if (!is.null(es_results)) as.list(es_results$se) else NULL
  )

  cat(sprintf(" GT ATT: %s\n", paste(round(res$ATT, 4), collapse = ", ")))
  cat(sprintf(" Overall ATT (simple): %.4f\n", overall_att_simple))
}
142+
143+
# Serialize the collected scenario results to a single JSON file in the
# output directory, then report where it was written.
json_path <- file.path(output_dir, "staggered_ddd_r_results.json")
writeLines(
  text = toJSON(results, auto_unbox = TRUE, pretty = TRUE, digits = 10),
  con = json_path
)
cat(sprintf("\nResults saved to: %s\n", json_path))
cat("Done.\n")

0 commit comments

Comments
 (0)