diff --git a/ROADMAP.md b/ROADMAP.md index 07b9b9e9..4eb43c2b 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -30,7 +30,7 @@ Extend the existing `TripleDifference` estimator to handle staggered adoption se - Event study aggregation and pre-treatment placebo effects - Multiplier bootstrap for valid inference in staggered settings -**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). *Working Paper*. R package: `triplediff`. +**Reference**: [Ortiz-Villavicencio & Sant'Anna (2025)](https://arxiv.org/abs/2505.09942). "Better Understanding Triple Differences Estimators." *Working Paper*. R package: `triplediff`. ### Enhanced Visualization diff --git a/TODO.md b/TODO.md index 5ae73fa9..bdfcc398 100644 --- a/TODO.md +++ b/TODO.md @@ -59,6 +59,9 @@ Deferred items from PR reviews that were not addressed before merge. | TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. | `prep_dgp.py`, `power.py` | #208 | Low | | Survey design resolution/collapse patterns inconsistent across panel estimators — extract shared helpers for panel-to-unit collapse, post-filter re-resolution, metadata recomputation | `continuous_did.py`, `efficient_did.py`, `stacked_did.py` | #226 | Low | | TROP: `fit()` and `_fit_global()` share ~150 lines of near-identical data setup. Extract shared helpers to eliminate cross-file sync risk. | `trop.py`, `trop_global.py`, `trop_local.py` | — | Low | +| StaggeredTripleDifference R cross-validation: CSV fixtures not committed (gitignored); tests skip without local R + triplediff. Commit fixtures or generate deterministically. | `tests/test_methodology_staggered_triple_diff.py` | #245 | Medium | +| StaggeredTripleDifference R parity: benchmark only tests no-covariate path (xformla=~1). Add covariate-adjusted scenarios and aggregation SE parity assertions. | `benchmarks/R/benchmark_staggered_triplediff.R` | #245 | Medium | +| StaggeredTripleDifference: per-cohort group-effect SEs include WIF (conservative vs R's wif=NULL). Documented in REGISTRY. Could override mixin for exact R match. | `staggered_triple_diff.py` | #245 | Low | #### Performance diff --git a/benchmarks/R/benchmark_staggered_triplediff.R b/benchmarks/R/benchmark_staggered_triplediff.R new file mode 100644 index 00000000..ed96480a --- /dev/null +++ b/benchmarks/R/benchmark_staggered_triplediff.R @@ -0,0 +1,147 @@ +#!/usr/bin/env Rscript +# Benchmark: Staggered Triple Difference (R `triplediff` package) +# +# Generates golden values for cross-validation against Python +# StaggeredTripleDifference estimator. +# +# Usage: +# Rscript benchmark_staggered_triplediff.R + +library(triplediff) +library(jsonlite) +library(data.table) + +cat("=== Staggered DDD Benchmark Generator ===\n") + +output_dir <- file.path(dirname(dirname(getwd())), "benchmarks", "data", "synthetic") +# Handle running from project root or benchmarks/R +if (!dir.exists(output_dir)) { + output_dir <- "benchmarks/data/synthetic" +} +if (!dir.exists(output_dir)) { + dir.create(output_dir, recursive = TRUE) +} + +results <- list() + +# Scenario definitions +scenarios <- list( + list(seed=42, dgp=1, method="dr", cg="nevertreated", key="s42_dgp1_dr_nt"), + list(seed=42, dgp=1, method="ipw", cg="nevertreated", key="s42_dgp1_ipw_nt"), + list(seed=42, dgp=1, method="reg", cg="nevertreated", key="s42_dgp1_reg_nt"), + list(seed=42, dgp=1, method="dr", cg="notyettreated", key="s42_dgp1_dr_nyt"), + list(seed=42, dgp=1, method="ipw", cg="notyettreated", key="s42_dgp1_ipw_nyt"), + list(seed=42, dgp=1, method="reg", cg="notyettreated", key="s42_dgp1_reg_nyt"), + list(seed=123, dgp=1, method="dr", cg="nevertreated", key="s123_dgp1_dr_nt"), + list(seed=123, dgp=1, method="dr", cg="notyettreated", key="s123_dgp1_dr_nyt"), + list(seed=99, dgp=1, method="dr", cg="nevertreated", key="s99_dgp1_dr_nt"), + list(seed=99, dgp=1, method="dr", cg="notyettreated", key="s99_dgp1_dr_nyt") +) + +for (sc in scenarios) { + cat(sprintf(" Running scenario: %s ...\n", sc$key)) + + set.seed(sc$seed) + dgp <- gen_dgp_mult_periods(size = 500, dgp_type = sc$dgp) + data <- dgp$data + + # Save data CSV (one per seed+dgp combo, reused across methods) + data_key <- sprintf("s%d_dgp%d", sc$seed, sc$dgp) + csv_path <- file.path(output_dir, sprintf("staggered_ddd_data_%s.csv", data_key)) + if (!file.exists(csv_path)) { + fwrite(data, csv_path) + cat(sprintf(" Saved data: %s\n", csv_path)) + } + + # Run DDD estimation + res <- tryCatch({ + ddd(yname = "y", tname = "time", idname = "id", + gname = "state", pname = "partition", + xformla = ~1, # no covariates for cross-validation + data = data, + control_group = sc$cg, + base_period = "varying", + est_method = sc$method, + panel = TRUE) + }, error = function(e) { + cat(sprintf(" ERROR: %s\n", e$message)) + return(NULL) + }) + + if (is.null(res)) next + + # Group-time results + gt_results <- data.frame( + group = res$groups, + period = res$periods, + att = res$ATT, + se = res$se + ) + + # Event study aggregation + agg_es <- tryCatch({ + agg_ddd(res, type = "eventstudy") + }, error = function(e) { + cat(sprintf(" Event study agg failed: %s\n", e$message)) + NULL + }) + + es_results <- NULL + overall_att_es <- NA + overall_se_es <- NA + if (!is.null(agg_es)) { + a <- agg_es$aggte_ddd + es_results <- data.frame( + event_time = a$egt, + att = a$att.egt, + se = a$se.egt + ) + overall_att_es <- a$overall.att + overall_se_es <- a$overall.se + } + + # Simple aggregation + agg_simple <- tryCatch({ + agg_ddd(res, type = "simple") + }, error = function(e) { + cat(sprintf(" Simple agg failed: %s\n", e$message)) + NULL + }) + + overall_att_simple <- NA + overall_se_simple <- NA + if (!is.null(agg_simple)) { + a <- agg_simple$aggte_ddd + overall_att_simple <- a$overall.att + overall_se_simple <- a$overall.se + } + + # Store results + results[[sc$key]] <- list( + seed = sc$seed, + dgp_type = sc$dgp, + est_method = sc$method, + control_group = sc$cg, + n = res$n, + gt_att = as.list(gt_results$att), + gt_se = as.list(gt_results$se), + gt_groups = as.list(gt_results$group), + gt_periods = as.list(gt_results$period), + overall_att_simple = overall_att_simple, + overall_se_simple = overall_se_simple, + overall_att_es = overall_att_es, + overall_se_es = overall_se_es, + es_event_times = if (!is.null(es_results)) as.list(es_results$event_time) else NULL, + es_att = if (!is.null(es_results)) as.list(es_results$att) else NULL, + es_se = if (!is.null(es_results)) as.list(es_results$se) else NULL + ) + + cat(sprintf(" GT ATT: %s\n", paste(round(res$ATT, 4), collapse=", "))) + cat(sprintf(" Overall ATT (simple): %.4f\n", overall_att_simple)) +} + +# Save all results as JSON +json_path <- file.path(output_dir, "staggered_ddd_r_results.json") +writeLines(toJSON(results, auto_unbox = TRUE, pretty = TRUE, digits = 10), json_path) +cat(sprintf("\nResults saved to: %s\n", json_path)) +cat("Done.\n") diff --git a/benchmarks/data/synthetic/staggered_ddd_r_results.json b/benchmarks/data/synthetic/staggered_ddd_r_results.json new file mode 100644 index 00000000..fbaeece7 --- /dev/null +++ b/benchmarks/data/synthetic/staggered_ddd_r_results.json @@ -0,0 +1,502 @@ +{ + "s42_dgp1_dr_nt": { + "seed": 42, + "dgp_type": 1, + "est_method": "dr", + "control_group": "nevertreated", + "n": 500, + "gt_att": [ + -4.2895749855, + -8.0409130843, + 14.4068707444, + 40.2656811791 + ], + "gt_se": [ + 11.6417919549, + 23.2527788237, + 10.3942031714, + 10.4325028214 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": 13.4654780225, + "overall_se_simple": 12.9490509622, + "overall_att_es": 7.0746899675, + "overall_se_es": 15.8100875545, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + 14.4068707444, + 22.1902930193, + -8.0409130843 + ], + "es_se": [ + 10.3942031714, + 9.5837019291, + 23.2527788237 + ] + }, + "s42_dgp1_ipw_nt": { + "seed": 42, + "dgp_type": 1, + "est_method": "ipw", + "control_group": "nevertreated", + "n": 500, + "gt_att": [ + -4.2895749855, + -8.0409130843, + 14.4068707444, + 40.2656811791 + ], + "gt_se": [ + 11.6417919549, + 23.2527788237, + 10.3942031714, + 10.4325028214 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": 13.4654780225, + "overall_se_simple": 12.9490509622, + "overall_att_es": 7.0746899675, + "overall_se_es": 15.8100875545, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + 14.4068707444, + 22.1902930193, + -8.0409130843 + ], + "es_se": [ + 10.3942031714, + 9.5837019291, + 23.2527788237 + ] + }, + "s42_dgp1_reg_nt": { + "seed": 42, + "dgp_type": 1, + "est_method": "reg", + "control_group": "nevertreated", + "n": 500, + "gt_att": [ + -4.2895749855, + -8.0409130843, + 14.4068707444, + 40.2656811791 + ], + "gt_se": [ + 11.6417919549, + 23.2527788237, + 10.3942031714, + 10.4325028214 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": 13.4654780225, + "overall_se_simple": 12.9490509622, + "overall_att_es": 7.0746899675, + "overall_se_es": 15.8100875545, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + 14.4068707444, + 22.1902930193, + -8.0409130843 + ], + "es_se": [ + 10.3942031714, + 9.5837019291, + 23.2527788237 + ] + }, + "s42_dgp1_dr_nyt": { + "seed": 42, + "dgp_type": 1, + "est_method": "dr", + "control_group": "notyettreated", + "n": 500, + "gt_att": [ + -12.5200130879, + -8.0409130843, + 14.4068707444, + 40.2656811791 + ], + "gt_se": [ + 10.0235582937, + 23.2527788237, + 10.3942031714, + 10.4325028214 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": 11.090149379, + "overall_se_simple": 11.7530219528, + "overall_att_es": 5.4052083369, + "overall_se_es": 15.0614586108, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + 14.4068707444, + 18.8513297581, + -8.0409130843 + ], + "es_se": [ + 10.3942031714, + 7.5310394332, + 23.2527788237 + ] + }, + "s42_dgp1_ipw_nyt": { + "seed": 42, + "dgp_type": 1, + "est_method": "ipw", + "control_group": "notyettreated", + "n": 500, + "gt_att": [ + -12.5200130879, + -8.0409130843, + 14.4068707444, + 40.2656811791 + ], + "gt_se": [ + 10.0235582937, + 23.2527788237, + 10.3942031714, + 10.4325028214 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": 11.090149379, + "overall_se_simple": 11.7530219528, + "overall_att_es": 5.4052083369, + "overall_se_es": 15.0614586108, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + 14.4068707444, + 18.8513297581, + -8.0409130843 + ], + "es_se": [ + 10.3942031714, + 7.5310394332, + 23.2527788237 + ] + }, + "s42_dgp1_reg_nyt": { + "seed": 42, + "dgp_type": 1, + "est_method": "reg", + "control_group": "notyettreated", + "n": 500, + "gt_att": [ + -12.5200130879, + -8.0409130843, + 14.4068707444, + 40.2656811791 + ], + "gt_se": [ + 10.0235582937, + 23.2527788237, + 10.3942031714, + 10.4325028214 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": 11.090149379, + "overall_se_simple": 11.7530219528, + "overall_att_es": 5.4052083369, + "overall_se_es": 15.0614586108, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + 14.4068707444, + 18.8513297581, + -8.0409130843 + ], + "es_se": [ + 10.3942031714, + 7.5310394332, + 23.2527788237 + ] + }, + "s123_dgp1_dr_nt": { + "seed": 123, + "dgp_type": 1, + "est_method": "dr", + "control_group": "nevertreated", + "n": 500, + "gt_att": [ + -30.4816411381, + -60.301430309, + -0.14060031214, + 25.6269104574 + ], + "gt_se": [ + 12.3967456892, + 24.817021537, + 11.0895798505, + 11.1181835728 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": -17.6122115371, + "overall_se_simple": 14.3483869454, + "overall_att_es": -29.6152926099, + "overall_se_es": 17.100452042, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + -0.14060031214, + 1.0708450892, + -60.301430309 + ], + "es_se": [ + 11.0895798505, + 10.4311952126, + 24.817021537 + ] + }, + "s123_dgp1_dr_nyt": { + "seed": 123, + "dgp_type": 1, + "est_method": "dr", + "control_group": "notyettreated", + "n": 500, + "gt_att": [ + -30.3925860449, + -60.301430309, + -0.14060031214, + 25.6269104574 + ], + "gt_se": [ + 10.2250248431, + 24.817021537, + 11.0895798505, + 11.1181835728 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": -17.5851012281, + "overall_se_simple": 12.8229170295, + "overall_att_es": -29.5958050039, + "overall_se_es": 16.0890826653, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + -0.14060031214, + 1.1098203011, + -60.301430309 + ], + "es_se": [ + 11.0895798505, + 7.8080793406, + 24.817021537 + ] + }, + "s99_dgp1_dr_nt": { + "seed": 99, + "dgp_type": 1, + "est_method": "dr", + "control_group": "nevertreated", + "n": 500, + "gt_att": [ + -25.7895616581, + -51.484602289, + 9.2177688646, + 34.9176618395 + ], + "gt_se": [ + 11.9352387779, + 23.8073598259, + 11.5336505891, + 11.4634356921 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": -10.5056185503, + "overall_se_simple": 13.788768701, + "overall_att_es": -21.8424440009, + "overall_se_es": 16.3348770902, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + 9.2177688646, + 7.7997142873, + -51.484602289 + ], + "es_se": [ + 11.5336505891, + 10.1564818631, + 23.8073598259 + ] + }, + "s99_dgp1_dr_nyt": { + "seed": 99, + "dgp_type": 1, + "est_method": "dr", + "control_group": "notyettreated", + "n": 500, + "gt_att": [ + -30.3503639125, + -51.484602289, + 9.2177688646, + 34.9176618395 + ], + "gt_se": [ + 10.4930515811, + 23.8073598259, + 11.5336505891, + 11.4634356921 + ], + "gt_groups": [ + 2, + 2, + 3, + 3 + ], + "gt_periods": [ + 2, + 3, + 2, + 3 + ], + "overall_att_simple": -11.913866264, + "overall_se_simple": 12.6267951373, + "overall_att_es": -22.861100342, + "overall_se_es": 15.5804347567, + "es_event_times": [ + -1, + 0, + 1 + ], + "es_att": [ + 9.2177688646, + 5.7624016051, + -51.484602289 + ], + "es_se": [ + 11.5336505891, + 8.0448174952, + 23.8073598259 + ] + } +} diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index 773e4f46..e2eecb4e 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -88,6 +88,7 @@ generate_factor_data, generate_panel_data, generate_staggered_data, + generate_staggered_ddd_data, make_post_indicator, make_treatment_indicator, rank_control_units, @@ -140,6 +141,12 @@ TripleDifferenceResults, triple_difference, ) +from diff_diff.staggered_triple_diff import ( + StaggeredTripleDifference, +) +from diff_diff.staggered_triple_diff_results import ( + StaggeredTripleDiffResults, +) from diff_diff.continuous_did import ( ContinuousDiD, ContinuousDiDResults, @@ -197,6 +204,7 @@ BJS = ImputationDiD Gardner = TwoStageDiD DDD = TripleDifference +SDDD = StaggeredTripleDifference Stacked = StackedDiD Bacon = BaconDecomposition EDiD = EfficientDiD @@ -227,6 +235,7 @@ "BJS", "Gardner", "DDD", + "SDDD", "Stacked", "Bacon", # Bacon Decomposition @@ -254,6 +263,8 @@ "two_stage_did", "TripleDifferenceResults", "triple_difference", + "StaggeredTripleDifference", + "StaggeredTripleDiffResults", "TROPResults", "trop", "StackedDiDResults", @@ -303,6 +314,7 @@ "generate_ddd_data", "generate_panel_data", "generate_event_study_data", + "generate_staggered_ddd_data", "generate_continuous_did_data", "create_event_time", "aggregate_to_cohorts", diff --git a/diff_diff/prep.py b/diff_diff/prep.py index 43d47b57..ced3143a 100644 --- a/diff_diff/prep.py +++ b/diff_diff/prep.py @@ -25,6 +25,7 @@ generate_ddd_data, generate_panel_data, generate_event_study_data, + generate_staggered_ddd_data, ) # Constants for rank_control_units diff --git a/diff_diff/prep_dgp.py b/diff_diff/prep_dgp.py index 2aab32c2..5323d3dd 100644 --- a/diff_diff/prep_dgp.py +++ b/diff_diff/prep_dgp.py @@ -939,3 +939,136 @@ def _att_func(d): ) return pd.DataFrame(records) + + +def generate_staggered_ddd_data( + n_units: int = 200, + n_periods: int = 8, + cohort_periods: Optional[List[int]] = None, + never_enabled_frac: float = 0.25, + eligibility_frac: float = 0.5, + treatment_effect: float = 3.0, + dynamic_effects: bool = False, + effect_growth: float = 0.1, + eligibility_trend: float = 0.3, + noise_sd: float = 0.5, + add_covariates: bool = False, + seed: Optional[int] = None, +) -> pd.DataFrame: + """ + Generate synthetic data for staggered triple difference (DDD) analysis. + + Creates a balanced panel with staggered enabling times and a binary + eligibility dimension. Treatment occurs when a unit is both enabled + (t >= S_i) and eligible (Q_i = 1). DDD-CPT holds by construction. + + Parameters + ---------- + n_units : int, default=200 + Number of units. + n_periods : int, default=8 + Number of time periods (1-indexed). + cohort_periods : list of int, optional + Enabling periods. Default: [4, 6]. + never_enabled_frac : float, default=0.25 + Fraction of never-enabled units. + eligibility_frac : float, default=0.5 + Fraction of eligible units (Q=1) within each cohort. + treatment_effect : float, default=3.0 + True ATT for treated units. + dynamic_effects : bool, default=False + If True, effects grow over time since enabling. + effect_growth : float, default=0.1 + Per-period effect growth rate when dynamic_effects=True. + eligibility_trend : float, default=0.3 + Differential time trend for eligible vs ineligible units. + Same across all enabling groups (preserves DDD-CPT). + noise_sd : float, default=0.5 + Standard deviation of idiosyncratic noise. + add_covariates : bool, default=False + If True, add covariates x1 (continuous) and x2 (binary). + seed : int, optional + Random seed. + + Returns + ------- + pd.DataFrame + Columns: unit, period, outcome, first_treat, eligibility, treated, + true_effect. Also x1, x2 if add_covariates=True. + """ + rng = np.random.default_rng(seed) + + if cohort_periods is None: + cohort_periods = [4, 6] + + # Assign units to cohorts + n_never = int(n_units * never_enabled_frac) + n_treated_total = n_units - n_never + n_per_cohort = n_treated_total // len(cohort_periods) + + unit_cohort = np.zeros(n_units, dtype=float) + idx = n_never + for i, g in enumerate(cohort_periods): + n_g = n_per_cohort if i < len(cohort_periods) - 1 else n_treated_total - idx + n_never + unit_cohort[idx : idx + n_g] = g + idx += n_g + + # Assign eligibility (within each cohort, fraction eligible) + unit_elig = np.zeros(n_units, dtype=int) + for g_val in [0.0] + [float(g) for g in cohort_periods]: + mask = unit_cohort == g_val + n_g = int(np.sum(mask)) + if n_g == 0: + continue + n_eligible = max(1, min(int(n_g * eligibility_frac), n_g)) + indices = np.where(mask)[0] + eligible_idx = rng.choice(indices, size=n_eligible, replace=False) + unit_elig[eligible_idx] = 1 + + # Unit fixed effects + unit_fe = rng.normal(0, 2.0, size=n_units) + + # Covariates + x1 = rng.normal(0, 1, size=n_units) if add_covariates else None + x2 = rng.choice([0, 1], size=n_units) if add_covariates else None + + # Generate panel + records = [] + for i in range(n_units): + g_i = unit_cohort[i] + q_i = unit_elig[i] + for t in range(1, n_periods + 1): + # Base: unit FE + time trend + eligibility-time interaction + gamma_t = 0.1 * t + y = unit_fe[i] + gamma_t + 1.0 * q_i + eligibility_trend * q_i * gamma_t + + if add_covariates: + y += 0.5 * x1[i] + 0.3 * x2[i] + + # Treatment effect: enabled AND eligible + treated = int(g_i > 0 and t >= g_i and q_i == 1) + true_eff = 0.0 + if treated: + true_eff = treatment_effect + if dynamic_effects: + true_eff *= 1 + effect_growth * (t - g_i) + y += true_eff + + y += rng.normal(0, noise_sd) + + row = { + "unit": i, + "period": t, + "outcome": y, + "first_treat": int(g_i) if g_i > 0 else 0, + "eligibility": q_i, + "treated": treated, + "true_effect": true_eff, + } + if add_covariates: + row["x1"] = x1[i] + row["x2"] = x2[i] + + records.append(row) + + return pd.DataFrame(records) diff --git a/diff_diff/staggered_triple_diff.py b/diff_diff/staggered_triple_diff.py new file mode 100644 index 00000000..710be4b7 --- /dev/null +++ b/diff_diff/staggered_triple_diff.py @@ -0,0 +1,1215 @@ +""" +Staggered Triple Difference (DDD) estimator. + +Implements Ortiz-Villavicencio & Sant'Anna (2025) for staggered adoption +settings with an eligibility dimension, combining group-time DDD effects +via GMM-optimal weighting. + +Core pairwise DiD computation matches R's triplediff::compute_did() exactly +(Riesz/Hajek normalization, separate M1/M3 OR corrections, hessian = (X'WX)^{-1}*n). +""" + +import warnings +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from diff_diff.linalg import ( + _check_propensity_diagnostics, + solve_logit, +) +from diff_diff.staggered_aggregation import ( + CallawaySantAnnaAggregationMixin, +) +from diff_diff.staggered_bootstrap import ( + CallawaySantAnnaBootstrapMixin, +) +from diff_diff.staggered_triple_diff_results import StaggeredTripleDiffResults +from diff_diff.utils import safe_inference + +__all__ = [ + "StaggeredTripleDifference", + "StaggeredTripleDiffResults", +] + +# Type alias for pre-computed structures +PrecomputedData = Dict[str, Any] + + +class StaggeredTripleDifference( + CallawaySantAnnaBootstrapMixin, + CallawaySantAnnaAggregationMixin, +): + """ + Staggered Triple Difference (DDD) estimator. + + Computes group-time average treatment effects ATT(g,t) for settings + with staggered adoption and a binary eligibility dimension, using the + three-DiD decomposition of Ortiz-Villavicencio & Sant'Anna (2025). + + Multiple comparison groups are combined via GMM-optimal (inverse-variance) + weighting. Event study, group, and overall aggregations are supported. + + Parameters + ---------- + estimation_method : str, default="dr" + Estimation method: "dr" (doubly robust), "ipw" (inverse probability + weighting), or "reg" (regression adjustment). + alpha : float, default=0.05 + Significance level. + anticipation : int, default=0 + Number of anticipation periods. + base_period : str, default="varying" + Base period selection: "varying" (consecutive comparisons) or + "universal" (always vs g-1-anticipation). + n_bootstrap : int, default=0 + Number of multiplier bootstrap repetitions. 0 disables bootstrap. + bootstrap_weights : str, default="rademacher" + Bootstrap weight distribution: "rademacher", "mammen", or "webb". + seed : int or None, default=None + Random seed for reproducibility. + cband : bool, default=True + Whether to compute simultaneous confidence bands. + pscore_trim : float, default=0.01 + Propensity score trimming bound. + cluster : str or None, default=None + Column name for cluster-robust standard errors. + rank_deficient_action : str, default="warn" + Action for rank-deficient design matrices: "warn", "error", "silent". + + References + ---------- + Ortiz-Villavicencio, M. & Sant'Anna, P.H.C. (2025). "Better Understanding + Triple Differences Estimators." arXiv:2505.09942. + """ + + def __init__( + self, + estimation_method: str = "dr", + control_group: str = "notyettreated", + alpha: float = 0.05, + anticipation: int = 0, + base_period: str = "varying", + n_bootstrap: int = 0, + bootstrap_weights: str = "rademacher", + seed: Optional[int] = None, + cband: bool = True, + pscore_trim: float = 0.01, + cluster: Optional[str] = None, + rank_deficient_action: str = "warn", + ): + if estimation_method not in ["dr", "ipw", "reg"]: + raise ValueError( + f"estimation_method must be 'dr', 'ipw', or 'reg', " + f"got '{estimation_method}'" + ) + if control_group not in ["nevertreated", "notyettreated"]: + raise ValueError( + f"control_group must be 'nevertreated' or 'notyettreated', " + f"got '{control_group}'" + ) + if not (0 < pscore_trim < 0.5): + raise ValueError(f"pscore_trim must be in (0, 0.5), got {pscore_trim}") + if bootstrap_weights not in ["rademacher", "mammen", "webb"]: + raise ValueError( + f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', " + f"got '{bootstrap_weights}'" + ) + if rank_deficient_action not in ["warn", "error", "silent"]: + raise ValueError( + f"rank_deficient_action must be 'warn', 'error', or 'silent', " + f"got '{rank_deficient_action}'" + ) + if base_period not in ["varying", "universal"]: + raise ValueError( + f"base_period must be 'varying' or 'universal', " + f"got '{base_period}'" + ) + + self.estimation_method = estimation_method + self.control_group = control_group + self.alpha = alpha + self.anticipation = anticipation + self.base_period = base_period + self.n_bootstrap = n_bootstrap + self.bootstrap_weights = bootstrap_weights + self.bootstrap_weight_type = bootstrap_weights + self.seed = seed + self.cband = cband + self.pscore_trim = pscore_trim + self.cluster = cluster + self.rank_deficient_action = rank_deficient_action + + self.is_fitted_ = False + self.results_: Optional[StaggeredTripleDiffResults] = None + + def get_params(self) -> Dict[str, Any]: + """Get estimator parameters (sklearn-compatible).""" + return { + "estimation_method": self.estimation_method, + "control_group": self.control_group, + "alpha": self.alpha, + "anticipation": self.anticipation, + "base_period": self.base_period, + "n_bootstrap": self.n_bootstrap, + "bootstrap_weights": self.bootstrap_weights, + "seed": self.seed, + "cband": self.cband, + "pscore_trim": self.pscore_trim, + "cluster": self.cluster, + "rank_deficient_action": self.rank_deficient_action, + } + + def set_params(self, **params) -> "StaggeredTripleDifference": + """Set estimator parameters (sklearn-compatible).""" + valid_params = self.get_params() + for key, value in params.items(): + if key not in valid_params: + raise ValueError(f"Unknown parameter: {key}") + setattr(self, key, value) + if "bootstrap_weights" in params: + self.bootstrap_weight_type = params["bootstrap_weights"] + return self + + # ------------------------------------------------------------------ + # fit() + # ------------------------------------------------------------------ + + def fit( + self, + data: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + eligibility: str, + covariates: Optional[List[str]] = None, + aggregate: Optional[str] = None, + balance_e: Optional[int] = None, + survey_design: object = None, + ) -> StaggeredTripleDiffResults: + """ + Fit the staggered triple difference estimator. + + Parameters + ---------- + data : pd.DataFrame + Panel data. + outcome : str + Outcome variable column name. + unit : str + Unit identifier column name. + time : str + Time period column name. + first_treat : str + Column with the enabling period for each unit's group. + Use 0 or np.inf for never-enabled units. + eligibility : str + Binary eligibility indicator column (0/1, time-invariant). + covariates : list of str, optional + Covariate column names. + aggregate : str, optional + Aggregation method: "event_study", "group", "simple", or "all". + balance_e : int, optional + Event time to balance on for event study. + survey_design : object, optional + Survey design specification (not yet supported). + + Returns + ------- + StaggeredTripleDiffResults + """ + if survey_design is not None: + raise NotImplementedError( + "Survey design support for staggered DDD is planned for a " + "future release." + ) + if aggregate is not None and aggregate not in [ + "event_study", "group", "simple", "all", + ]: + raise ValueError( + f"aggregate must be 'event_study', 'group', 'simple', or 'all', " + f"got '{aggregate}'" + ) + + df = data.copy() + self._validate_inputs( + df, outcome, unit, time, first_treat, eligibility, covariates + ) + + if self.cluster is not None: + warnings.warn( + "cluster parameter is accepted but cluster-robust analytical SEs " + "are not yet implemented for staggered DDD. Use n_bootstrap > 0 " + "for unit-level clustered inference via multiplier bootstrap.", + UserWarning, + stacklevel=2, + ) + + if first_treat != "first_treat": + df["first_treat"] = df[first_treat] + df["first_treat"] = df["first_treat"].replace([np.inf, float("inf")], 0) + + precomputed = self._precompute_structures( + df, outcome, unit, time, eligibility, covariates + ) + + treatment_groups = precomputed["treatment_groups"] + time_periods = precomputed["time_periods"] + all_units = precomputed["all_units"] + time_to_col = precomputed["time_to_col"] + unit_cohorts = precomputed["unit_cohorts"] + eligibility_per_unit = precomputed["eligibility_per_unit"] + n_units = len(all_units) + + pscore_cache: Dict = {} + cho_cache: Dict = {} + + group_time_effects: Dict[Tuple, Dict[str, Any]] = {} + influence_func_info: Dict[Tuple, Dict[str, Any]] = {} + comparison_group_counts: Dict[Tuple, int] = {} + gmm_weights_store: Dict[Tuple, Dict] = {} + + for g in treatment_groups: + # In universal mode, skip the reference period (t == g-1-anticipation) + # so it's omitted from GT estimation. The event-study mixin injects + # a synthetic reference row with effect=0, matching CS behavior. + if self.base_period == "universal": + universal_base = g - 1 - self.anticipation + valid_periods = [t for t in time_periods if t != universal_base] + else: + valid_periods = time_periods + + for t in valid_periods: + base_period_val = self._get_base_period(g, t) + if base_period_val is None: + continue + if base_period_val not in time_to_col: + warnings.warn( + f"Base period {base_period_val} for (g={g}, t={t}) is " + "outside the observed panel. Skipping this cell.", + UserWarning, stacklevel=2, + ) + continue + if t not in time_to_col: + continue + + has_never_enabled = bool(np.any(unit_cohorts == 0)) + + if self.control_group == "nevertreated": + # Only use never-enabled cohort as comparison + valid_gc = [0] if has_never_enabled else [] + else: + # Use all valid comparison cohorts (not-yet-treated + never) + # Threshold accounts for anticipation: cohorts that start + # treatment within the anticipation window are contaminated. + nyt_threshold = max(t, base_period_val) + self.anticipation + valid_gc = [ + gc for gc in treatment_groups + if gc > nyt_threshold and gc != g + ] + if has_never_enabled: + valid_gc = [0] + valid_gc + + if not valid_gc: + warnings.warn( + f"No valid comparison groups for (g={g}, t={t}), skipping.", + UserWarning, stacklevel=2, + ) + continue + + treated_mask = (unit_cohorts == g) & (eligibility_per_unit == 1) + n_treated = int(np.sum(treated_mask)) + if n_treated == 0: + continue + + att_vec = [] + inf_raw = [] # unrescaled IFs + gc_labels = [] + gc_cell_sizes = [] # size_gt_ctrl per surviving gc + + for gc in valid_gc: + result = self._compute_ddd_gt_gc( + precomputed, g, gc, t, base_period_val, + covariates, pscore_cache, cho_cache, + ) + if result is None: + continue + att_gc, inf_gc, size_gt_ctrl = result + if not np.isfinite(att_gc): + continue + + att_vec.append(att_gc) + inf_raw.append(inf_gc) + gc_labels.append(gc) + gc_cell_sizes.append(size_gt_ctrl) + + if not att_vec: + continue + + # Compute size_gt from SURVIVING comparison cohorts only + # (not from all initially valid gc's) + surviving_units = treated_mask.copy() + for gc in gc_labels: + surviving_units |= (unit_cohorts == gc) | (unit_cohorts == g) + size_gt = int(np.sum(surviving_units)) + + # Apply IF rescaling now that size_gt is known + inf_matrix = [] + for inf_gc, size_gt_ctrl in zip(inf_raw, gc_cell_sizes): + if size_gt_ctrl > 0: + inf_gc = inf_gc * (size_gt / size_gt_ctrl) + inf_matrix.append(inf_gc) + + att_gmm, inf_gmm, gmm_w, se_gt = self._combine_gmm( + np.array(att_vec), np.array(inf_matrix), n_units, + ) + + if not np.isfinite(att_gmm): + continue + + # R's single-gc SE uses size_gt in denominator, not n_total. + # For multi-gc (GMM), the size_gt factor is already in Omega + # via the per-gc rescaling, so n_total is correct. + if len(gc_labels) == 1: + se_gt = float(np.sqrt(np.sum(inf_gmm**2) / size_gt**2)) + + if not np.isfinite(se_gt) or se_gt <= 0: + se_gt = np.nan + + t_stat, p_value, conf_int = safe_inference( + att_gmm, se_gt, alpha=self.alpha + ) + + # Rescale IF for mixin compatibility. + # R stores IF * (n/size_gt) in inf_func_mat, then uses + # SE = sqrt(sum(IF^2)/n^2) = sqrt(sum(psi^2)) with psi = IF/n. + # We need psi = IF_rescaled / n so mixin's sqrt(sum(psi^2)) works. + # IF is already at size_gt/size_gt_ctrl scale from above. + # Apply the final n/size_gt factor, then divide by n for mixin. + inf_gmm_rescaled = inf_gmm * (n_units / size_gt) + inf_gmm_scaled = inf_gmm_rescaled / n_units + + treated_idx = np.where(treated_mask)[0] + treated_inf = inf_gmm_scaled[treated_idx] + nonzero_mask = (inf_gmm_scaled != 0) & ~treated_mask + control_idx = np.where(nonzero_mask)[0] + control_inf = inf_gmm_scaled[control_idx] + n_control = int(np.sum(nonzero_mask)) + + group_time_effects[(g, t)] = { + "effect": att_gmm, + "se": se_gt, + "t_stat": t_stat, + "p_value": p_value, + "conf_int": conf_int, + "n_treated": n_treated, + "n_control": n_control, + } + influence_func_info[(g, t)] = { + "treated_idx": treated_idx, + "control_idx": control_idx, + "treated_units": all_units[treated_idx], + "control_units": all_units[control_idx], + "treated_inf": treated_inf, + "control_inf": control_inf, + } + comparison_group_counts[(g, t)] = len(gc_labels) + gmm_weights_store[(g, t)] = dict(zip(gc_labels, gmm_w.tolist())) + + if not group_time_effects: + raise ValueError( + "No valid group-time effects could be computed. " + "Check that the data has sufficient variation in treatment " + "timing and eligibility." + ) + + # For aggregation: use eligible-treated-only cohort assignments so + # WIF weights match the point estimate weights (n_treated per cohort, + # i.e. P(S=g, Q=1)). This matches the paper's Eq 4.13 which defines + # aggregation weights over the treated population (G_i defined only + # for Q=1 units). Ineligible units get cohort=0 so they don't + # contribute to pg for any treatment group. + # Both precomputed["unit_cohorts"] AND df["first_treat"] must be + # zeroed for ineligible units because the WIF code reads both. + precomputed_agg = dict(precomputed) + cohorts_for_agg = precomputed["unit_cohorts"].copy() + cohorts_for_agg[eligibility_per_unit == 0] = 0 + precomputed_agg["unit_cohorts"] = cohorts_for_agg + + df_agg = df.copy() + df_agg.loc[df_agg[eligibility] == 0, "first_treat"] = 0 + + # Overall ATT via aggregation mixin + overall_att, overall_se = self._aggregate_simple( + group_time_effects, influence_func_info, df_agg, unit, precomputed_agg + ) + overall_t_stat, overall_p_value, overall_conf_int = safe_inference( + overall_att, overall_se, alpha=self.alpha + ) + + # Aggregations + event_study_effects = None + group_effects = None + if aggregate in ("event_study", "all"): + event_study_effects = self._aggregate_event_study( + group_time_effects, influence_func_info, + treatment_groups, time_periods, balance_e, + df_agg, unit, precomputed_agg, + ) + if aggregate in ("group", "all"): + group_effects = self._aggregate_by_group( + group_time_effects, influence_func_info, + treatment_groups, precomputed_agg, df_agg, unit, + ) + + # Bootstrap + bootstrap_results = None + cband_crit_value = None + if self.n_bootstrap > 0: + bootstrap_results = self._run_multiplier_bootstrap( + group_time_effects, influence_func_info, + aggregate, balance_e, + treatment_groups, time_periods, + df_agg, unit, precomputed_agg, self.cband, + ) + if bootstrap_results is not None: + overall_se = bootstrap_results.overall_att_se + overall_t_stat, overall_p_value, overall_conf_int = safe_inference( + overall_att, overall_se, alpha=self.alpha + ) + overall_conf_int = bootstrap_results.overall_att_ci + overall_p_value = bootstrap_results.overall_att_p_value + if bootstrap_results.cband_crit_value is not None: + cband_crit_value = bootstrap_results.cband_crit_value + + # Update group-time effects with bootstrap SEs + if bootstrap_results.group_time_ses: + for gt_key in group_time_effects: + if gt_key in bootstrap_results.group_time_ses: + group_time_effects[gt_key]["se"] = ( + bootstrap_results.group_time_ses[gt_key] + ) + group_time_effects[gt_key]["conf_int"] = ( + bootstrap_results.group_time_cis[gt_key] + ) + group_time_effects[gt_key]["p_value"] = ( + bootstrap_results.group_time_p_values[gt_key] + ) + t_val, _, _ = safe_inference( + group_time_effects[gt_key]["effect"], + bootstrap_results.group_time_ses[gt_key], + alpha=self.alpha, + ) + group_time_effects[gt_key]["t_stat"] = t_val + + if event_study_effects and bootstrap_results.event_study_ses: + for e_key in event_study_effects: + if e_key in bootstrap_results.event_study_ses: + event_study_effects[e_key]["se"] = ( + bootstrap_results.event_study_ses[e_key] + ) + event_study_effects[e_key]["conf_int"] = ( + bootstrap_results.event_study_cis[e_key] + ) + event_study_effects[e_key]["p_value"] = ( + bootstrap_results.event_study_p_values[e_key] + ) + t_val, _, _ = safe_inference( + event_study_effects[e_key]["effect"], + bootstrap_results.event_study_ses[e_key], + alpha=self.alpha, + ) + event_study_effects[e_key]["t_stat"] = t_val + if cband_crit_value is not None: + bs_se = bootstrap_results.event_study_ses[e_key] + eff = event_study_effects[e_key]["effect"] + event_study_effects[e_key]["cband_conf_int"] = ( + eff - cband_crit_value * bs_se, + eff + cband_crit_value * bs_se, + ) + + # Update group effects with bootstrap SEs + if ( + group_effects + and bootstrap_results.group_effect_ses is not None + and bootstrap_results.group_effect_cis is not None + and bootstrap_results.group_effect_p_values is not None + ): + grp_keys = [ + g for g in group_effects + if g in bootstrap_results.group_effect_ses + ] + for g_key in grp_keys: + group_effects[g_key]["se"] = ( + bootstrap_results.group_effect_ses[g_key] + ) + group_effects[g_key]["conf_int"] = ( + bootstrap_results.group_effect_cis[g_key] + ) + group_effects[g_key]["p_value"] = ( + bootstrap_results.group_effect_p_values[g_key] + ) + t_val, _, _ = safe_inference( + group_effects[g_key]["effect"], + bootstrap_results.group_effect_ses[g_key], + alpha=self.alpha, + ) + group_effects[g_key]["t_stat"] = t_val + + n_treated_units = int(np.sum( + (unit_cohorts > 0) & (eligibility_per_unit == 1) + )) + n_control_units = n_units - n_treated_units + n_never_enabled = int(np.sum(unit_cohorts == 0)) + n_eligible = int(np.sum(eligibility_per_unit == 1)) + n_ineligible = int(np.sum(eligibility_per_unit == 0)) + + self.results_ = StaggeredTripleDiffResults( + group_time_effects=group_time_effects, + overall_att=overall_att, + overall_se=overall_se, + overall_t_stat=overall_t_stat, + overall_p_value=overall_p_value, + overall_conf_int=overall_conf_int, + groups=treatment_groups, + time_periods=time_periods, + n_obs=len(df), + n_treated_units=n_treated_units, + n_control_units=n_control_units, + n_never_enabled=n_never_enabled, + n_eligible=n_eligible, + n_ineligible=n_ineligible, + alpha=self.alpha, + control_group=self.control_group, + base_period=self.base_period, + estimation_method=self.estimation_method, + event_study_effects=event_study_effects, + group_effects=group_effects, + bootstrap_results=bootstrap_results, + cband_crit_value=cband_crit_value, + pscore_trim=self.pscore_trim, + comparison_group_counts=comparison_group_counts, + gmm_weights=gmm_weights_store, + ) + self.is_fitted_ = True + return self.results_ + + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ + + def _validate_inputs( + self, df: pd.DataFrame, outcome: str, unit: str, time: str, + first_treat: str, eligibility: str, covariates: Optional[List[str]], + ) -> None: + """Validate input data.""" + required_cols = [outcome, unit, time, first_treat, eligibility] + if covariates: + required_cols.extend(covariates) + missing = [c for c in required_cols if c not in df.columns] + if missing: + raise ValueError(f"Missing columns: {missing}") + + elig_vals = df[eligibility].dropna().unique() + if not set(elig_vals).issubset({0, 1, 0.0, 1.0}): + raise ValueError( + f"Eligibility column '{eligibility}' must be binary (0/1). " + f"Found values: {sorted(elig_vals)}" + ) + elig_by_unit = df.groupby(unit)[eligibility].nunique() + varying = elig_by_unit[elig_by_unit > 1] + if len(varying) > 0: + raise ValueError( + f"Eligibility must be time-invariant within units. " + f"Found {len(varying)} units with varying eligibility." + ) + for col in [outcome, first_treat, eligibility]: + if df[col].isna().any(): + raise ValueError(f"Column '{col}' contains missing values.") + + # Reject non-finite outcomes (Inf/-Inf) + if not np.all(np.isfinite(df[outcome])): + raise ValueError( + f"Column '{outcome}' contains non-finite values (Inf/-Inf). " + "All outcome values must be finite." + ) + + # Reject non-finite covariates + if covariates: + for cov in covariates: + if df[cov].isna().any(): + raise ValueError(f"Covariate '{cov}' contains missing values.") + if not np.all(np.isfinite(df[cov])): + raise ValueError( + f"Covariate '{cov}' contains non-finite values." + ) + if df[eligibility].nunique() < 2: + raise ValueError( + "Need both eligible (Q=1) and ineligible (Q=0) units. " + f"Only found Q={df[eligibility].unique()[0]}." + ) + + # Check unique (unit, time) pairs — no duplicate rows + dup = df.duplicated(subset=[unit, time], keep=False) + if dup.any(): + raise ValueError( + f"Duplicate (unit, time) rows found. " + f"{int(dup.sum())} duplicates detected. Panel must have unique rows." + ) + + # Check balanced panel — every unit observed in exactly the global period set + global_periods = set(df[time].unique()) + n_global_periods = len(global_periods) + unit_period_sets = df.groupby(unit)[time].apply(set) + mismatched = unit_period_sets[unit_period_sets != global_periods] + if len(mismatched) > 0: + raise ValueError( + "Unbalanced panel detected. All units must be observed in " + f"all {n_global_periods} periods. " + f"Found {len(mismatched)} units with different period sets." + ) + + # Check time-invariant first_treat + ft_by_unit = df.groupby(unit)[first_treat].nunique() + varying_ft = ft_by_unit[ft_by_unit > 1] + if len(varying_ft) > 0: + raise ValueError( + f"first_treat must be time-invariant within units. " + f"Found {len(varying_ft)} units with varying first_treat." + ) + + # Check time-invariant covariates + if covariates: + for cov in covariates: + cov_nunique = df.groupby(unit)[cov].nunique() + varying_cov = cov_nunique[cov_nunique > 1] + if len(varying_cov) > 0: + raise ValueError( + f"Covariate '{cov}' must be time-invariant within units. " + f"Found {len(varying_cov)} units with varying values." + ) + + # ------------------------------------------------------------------ + # Precomputation + # ------------------------------------------------------------------ + + def _precompute_structures( + self, df: pd.DataFrame, outcome: str, unit: str, time: str, + eligibility: str, covariates: Optional[List[str]], + ) -> PrecomputedData: + """Build precomputed structures for efficient computation.""" + all_units = np.array(sorted(df[unit].unique())) + time_periods = sorted(df[time].unique()) + n_units = len(all_units) + n_periods = len(time_periods) + + unit_to_idx = {u: i for i, u in enumerate(all_units)} + time_to_col = {t: j for j, t in enumerate(time_periods)} + + outcome_matrix = np.full((n_units, n_periods), np.nan) + for _, row in df.iterrows(): + u_idx = unit_to_idx[row[unit]] + t_idx = time_to_col[row[time]] + outcome_matrix[u_idx, t_idx] = row[outcome] + + unit_df = df.groupby(unit).first().reindex(all_units) + unit_cohorts = unit_df["first_treat"].values.astype(float) + eligibility_per_unit = unit_df[eligibility].values.astype(int) + + treatment_groups = sorted([g for g in np.unique(unit_cohorts) if g > 0]) + + covariate_matrix = None + if covariates: + cov_wide = {} + for cov in covariates: + cov_vals = np.full(n_units, np.nan) + for u_id, idx in unit_to_idx.items(): + u_data = df.loc[df[unit] == u_id, cov] + if len(u_data) > 0: + cov_vals[idx] = u_data.iloc[0] + cov_wide[cov] = cov_vals + covariate_matrix = np.column_stack(list(cov_wide.values())) + + return { + "all_units": all_units, + "unit_to_idx": unit_to_idx, + "time_periods": time_periods, + "time_to_col": time_to_col, + "outcome_matrix": outcome_matrix, + "unit_cohorts": unit_cohorts, + "eligibility_per_unit": eligibility_per_unit, + "treatment_groups": treatment_groups, + "covariate_matrix": covariate_matrix, + "n_units": n_units, + "n_periods": n_periods, + "survey_weights": None, + "resolved_survey_unit": None, + } + + # ------------------------------------------------------------------ + # Base period + # ------------------------------------------------------------------ + + def _get_base_period(self, g: Any, t: Any) -> Optional[Any]: + """Determine base period for a (g, t) pair.""" + if self.base_period == "universal": + return g - 1 - self.anticipation + else: + if t < g - self.anticipation: + return t - 1 + else: + return g - 1 - self.anticipation + + # ------------------------------------------------------------------ + # Three-DiD DDD for one (g, g_c, t) triple + # ------------------------------------------------------------------ + + def _compute_ddd_gt_gc( + self, precomputed: PrecomputedData, g: Any, g_c: Any, t: Any, + base_period_val: Any, covariates: Optional[List[str]], + pscore_cache: Dict, cho_cache: Dict, + ) -> Optional[Tuple[float, np.ndarray, int]]: + """ + Compute DDD ATT for one (g, g_c, t) triple. + + Returns (att_ddd, inf_full_n_units, size_gt_ctrl) or None. + """ + outcome_matrix = precomputed["outcome_matrix"] + time_to_col = precomputed["time_to_col"] + unit_cohorts = precomputed["unit_cohorts"] + eligibility_per_unit = precomputed["eligibility_per_unit"] + covariate_matrix = precomputed["covariate_matrix"] + n_units = precomputed["n_units"] + + t_col = time_to_col[t] + b_col = time_to_col[base_period_val] + + # Four sub-groups within this (g, g_c) cell + treated_mask = (unit_cohorts == g) & (eligibility_per_unit == 1) # subgroup 4 + sub_a_mask = (unit_cohorts == g) & (eligibility_per_unit == 0) # subgroup 3 + sub_b_mask = (unit_cohorts == g_c) & (eligibility_per_unit == 1) # subgroup 2 + sub_c_mask = (unit_cohorts == g_c) & (eligibility_per_unit == 0) # subgroup 1 + + n_treated = int(np.sum(treated_mask)) + n_a = int(np.sum(sub_a_mask)) + n_b = int(np.sum(sub_b_mask)) + n_c = int(np.sum(sub_c_mask)) + + if n_treated == 0 or n_a == 0 or n_b == 0 or n_c == 0: + empty = [] + if n_treated == 0: + empty.append(f"(S={g},Q=1)") + if n_a == 0: + empty.append(f"(S={g},Q=0)") + if n_b == 0: + empty.append(f"(S={g_c},Q=1)") + if n_c == 0: + empty.append(f"(S={g_c},Q=0)") + warnings.warn( + f"Empty subgroup(s) {', '.join(empty)} for " + f"(g={g}, g_c={g_c}, t={t}). " + "Comparison unidentified, skipping.", + UserWarning, stacklevel=3, + ) + return None + + if min(n_treated, n_a, n_b, n_c) < 5: + warnings.warn( + f"Small cell size for (g={g}, g_c={g_c}, t={t}). " + "Estimates may be unreliable.", UserWarning, stacklevel=3, + ) + + # Outcome changes + delta_y_all = outcome_matrix[:, t_col] - outcome_matrix[:, b_col] + valid = np.isfinite(delta_y_all) + for m in [treated_mask, sub_a_mask, sub_b_mask, sub_c_mask]: + if not np.all(valid[m]): + return None + + # Three pairwise DiDs, each on a 2-cell subset + # DiD_A: subgroup 4 vs 3 (treated-eligible vs treated-ineligible) + pair_a_mask = treated_mask | sub_a_mask + did_a = self._run_pairwise_did( + delta_y_all, pair_a_mask, treated_mask, sub_a_mask, + covariate_matrix, pscore_cache, (g, g, 0, base_period_val), + cho_cache, ("a", g, g, base_period_val), + ) + + # DiD_B: subgroup 4 vs 2 (treated-eligible vs control-eligible) + pair_b_mask = treated_mask | sub_b_mask + did_b = self._run_pairwise_did( + delta_y_all, pair_b_mask, treated_mask, sub_b_mask, + covariate_matrix, pscore_cache, (g, g_c, 1, base_period_val), + cho_cache, ("b", g, g_c, base_period_val), + ) + + # DiD_C: subgroup 4 vs 1 (treated-eligible vs control-ineligible) + pair_c_mask = treated_mask | sub_c_mask + did_c = self._run_pairwise_did( + delta_y_all, pair_c_mask, treated_mask, sub_c_mask, + covariate_matrix, pscore_cache, (g, g_c, 0, base_period_val), + cho_cache, ("c", g, g_c, base_period_val), + ) + + if did_a is None or did_b is None or did_c is None: + return None + + att_a, inf_a = did_a + att_b, inf_b = did_b + att_c, inf_c = did_c + + att_ddd = att_a + att_b - att_c + + # Three-DiD IF combination: w_j = n_cell / n_pair_j (R's att_dr convention) + n_cell = n_treated + n_a + n_b + n_c + n_pair_a = n_treated + n_a + n_pair_b = n_treated + n_b + n_pair_c = n_treated + n_c + w_3 = n_cell / n_pair_a if n_pair_a > 0 else 1.0 + w_2 = n_cell / n_pair_b if n_pair_b > 0 else 1.0 + w_1 = n_cell / n_pair_c if n_pair_c > 0 else 1.0 + + # Scatter pair-level IFs into n_units-length vector + inf_full = np.zeros(n_units) + pair_a_idx = np.where(pair_a_mask)[0] + pair_b_idx = np.where(pair_b_mask)[0] + pair_c_idx = np.where(pair_c_mask)[0] + + inf_full[pair_a_idx] += w_3 * inf_a + inf_full[pair_b_idx] += w_2 * inf_b + inf_full[pair_c_idx] -= w_1 * inf_c + + size_gt_ctrl = n_cell + return att_ddd, inf_full, size_gt_ctrl + + # ------------------------------------------------------------------ + # Pairwise DiD (matches R's compute_did) + # ------------------------------------------------------------------ + + def _run_pairwise_did( + self, + delta_y_all: np.ndarray, + pair_mask: np.ndarray, + treated_mask: np.ndarray, + control_mask: np.ndarray, + covariate_matrix: Optional[np.ndarray], + pscore_cache: Dict, + pscore_key: Any, + cho_cache: Dict, + cho_key: Any, + ) -> Optional[Tuple[float, np.ndarray]]: + """ + Compute a single pairwise DiD ATT and IF on a 2-cell subset. + + Matches R's triplediff::compute_did() formulation exactly: + Riesz/Hajek normalization, PS + OR IF corrections. + + Returns (att, inf_func) where inf_func has length n_pair, + ordered by pair_mask indices. Returns None if insufficient data. + """ + pair_idx = np.where(pair_mask)[0] + n_pair = len(pair_idx) + if n_pair == 0: + return None + + delta_y = delta_y_all[pair_idx] + PA4 = treated_mask[pair_idx].astype(float) + PAa = control_mask[pair_idx].astype(float) + + n_t = int(np.sum(PA4)) + n_c = int(np.sum(PAa)) + if n_t == 0 or n_c == 0: + return None + + has_covariates = ( + covariate_matrix is not None + and self.estimation_method != "none" + ) + + # Build covariate matrix with intercept for the pair + covX = None + if has_covariates: + X_pair = covariate_matrix[pair_idx] + covX = np.column_stack([np.ones(n_pair), X_pair]) + + # Compute nuisance parameters based on estimation method + pscore = None + hessian = None + or_delta = np.zeros(n_pair) + + if self.estimation_method in ("ipw", "dr") and covX is not None: + pscore, hessian = self._compute_pscore( + PA4, covX, pscore_cache, pscore_key + ) + + if self.estimation_method in ("reg", "dr") and covX is not None: + or_delta = self._compute_or( + delta_y, PAa, covX, cho_cache, cho_key + ) + + # Compute ATT and IF (R's compute_did formulation) + return self._compute_did_panel( + delta_y, PA4, PAa, covX, pscore, hessian, or_delta + ) + + # ------------------------------------------------------------------ + # Core DR/IPW/RA computation (matches R's compute_did exactly) + # ------------------------------------------------------------------ + + def _compute_did_panel( + self, + delta_y: np.ndarray, + PA4: np.ndarray, + PAa: np.ndarray, + covX: Optional[np.ndarray], + pscore: Optional[np.ndarray], + hessian: Optional[np.ndarray], + or_delta: np.ndarray, + ) -> Tuple[float, np.ndarray]: + """ + Pairwise DiD ATT and influence function. + Matches R's triplediff::compute_did() line-by-line. + + Parameters + ---------- + delta_y : outcome changes for 2-cell subset (n_pair,) + PA4 : treated indicator (n_pair,) + PAa : control indicator (n_pair,) + covX : covariate matrix with intercept (n_pair, p) or None + pscore : propensity scores (n_pair,) or None + hessian : (X'WX)^{-1} * n_pair or None + or_delta : OR predictions (n_pair,), zeros if no covariates + + Returns + ------- + (att, inf_func) where inf_func has length n_pair. + """ + n_pair = len(delta_y) + est = self.estimation_method + + # Riesz representers (R lines 243-250) + if est == "reg" or pscore is None: + w_treat = PA4.copy() + w_control = PAa.copy() + else: + w_treat = PA4.copy() + pscore_safe = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) + w_control = pscore_safe * PAa / (1 - pscore_safe) + + # DR ATT via Hajek normalization (R lines 251-256) + resid = delta_y - or_delta + riesz_treat = w_treat * resid + riesz_control = w_control * resid + + mean_w_treat = np.mean(w_treat) + mean_w_control = np.mean(w_control) + + if mean_w_treat <= 0 or mean_w_control <= 0: + return float("nan"), np.zeros(n_pair) + + att_treat = np.mean(riesz_treat) / mean_w_treat + att_control = np.mean(riesz_control) / mean_w_control + dr_att = att_treat - att_control + + # Base IF (R lines 302-304) + inf_treat_did = riesz_treat - w_treat * att_treat + inf_control_did = riesz_control - w_control * att_control + + # PS correction (R lines 262-273) — IPW and DR only + inf_control_pscore = 0.0 + if est != "reg" and hessian is not None and covX is not None: + M2 = np.mean( + (w_control * (resid - att_control))[:, None] * covX, axis=0 + ) + score_ps = (PA4 - pscore_safe)[:, None] * covX + asy_lin_rep_ps = score_ps @ hessian + inf_control_pscore = asy_lin_rep_ps @ M2 + + # OR correction (R lines 278-300) — reg and DR only + inf_treat_or = 0.0 + inf_cont_or = 0.0 + if est != "ipw" and covX is not None: + M1 = np.mean(w_treat[:, None] * covX, axis=0) + M3 = np.mean(w_control[:, None] * covX, axis=0) + + or_x = PAa[:, None] * covX + or_ex = (PAa * resid)[:, None] * covX + XpX = or_x.T @ covX / n_pair + + try: + asy_linear_or = (np.linalg.solve(XpX, or_ex.T)).T + except np.linalg.LinAlgError: + asy_linear_or = (np.linalg.lstsq(XpX, or_ex.T, rcond=None)[0]).T + + inf_treat_or = -(asy_linear_or @ M1) + inf_cont_or = -(asy_linear_or @ M3) + + # Final IF assembly (R lines 307-310) + inf_control = (inf_control_did + inf_control_pscore + inf_cont_or) / mean_w_control + inf_treat = (inf_treat_did + inf_treat_or) / mean_w_treat + inf_func = inf_treat - inf_control + + return float(dr_att), inf_func + + # ------------------------------------------------------------------ + # Nuisance parameter computation + # ------------------------------------------------------------------ + + def _compute_pscore( + self, PA4: np.ndarray, covX: np.ndarray, + pscore_cache: Dict, pscore_key: Any, + ) -> Tuple[np.ndarray, np.ndarray]: + """Fit logistic P(PA4=1|X). Returns (pscore, hessian). + + hessian = (X'WX)^{-1} * n_pair, matching R's convention. + """ + cached = pscore_cache.get(pscore_key) + n_pair = len(PA4) + + if cached is not None: + beta_logistic = cached + z = np.dot(covX, beta_logistic) + z = np.clip(z, -500, 500) + pscore = 1 / (1 + np.exp(-z)) + else: + X_no_intercept = covX[:, 1:] # solve_logit adds its own intercept + try: + beta_logistic, pscore = solve_logit( + X_no_intercept, PA4, + rank_deficient_action=self.rank_deficient_action, + ) + _check_propensity_diagnostics(pscore, self.pscore_trim) + # Zero-fill NaN coefficients (from rank-deficient columns) + # before caching, so cache reuse doesn't propagate NaN. + beta_clean = np.where(np.isfinite(beta_logistic), beta_logistic, 0.0) + pscore_cache[pscore_key] = beta_clean + except (np.linalg.LinAlgError, ValueError): + if self.rank_deficient_action == "error": + raise + warnings.warn( + "Propensity score estimation failed. " + "Falling back to unconditional.", + UserWarning, stacklevel=5, + ) + pscore = np.full(n_pair, np.mean(PA4)) + pscore = np.clip(pscore, self.pscore_trim, 1 - self.pscore_trim) + # No hessian for unconditional fallback + return pscore, None + + pscore = np.clip(pscore, 1e-6, 1 - 1e-6) + + # Hessian: (X'WX)^{-1} * n (matching R's compute_pscore) + W = pscore * (1 - pscore) + XWX = covX.T @ (W[:, None] * covX) + try: + hessian = np.linalg.inv(XWX) * n_pair + except np.linalg.LinAlgError: + hessian = np.linalg.lstsq(XWX, np.eye(XWX.shape[0]), rcond=None)[0] * n_pair + + return pscore, hessian + + def _compute_or( + self, delta_y: np.ndarray, PAa: np.ndarray, covX: np.ndarray, + cho_cache: Dict, cho_key: Any, + ) -> np.ndarray: + """Fit OLS on control outcome changes. Returns or_delta for all pair units. + + Honors self.rank_deficient_action for collinear covariates. + """ + from diff_diff.linalg import solve_ols as _solve_ols + + control_mask = PAa > 0 + n_c = int(np.sum(control_mask)) + if n_c == 0: + return np.zeros(len(delta_y)) + + X_control = covX[control_mask] + y_control = delta_y[control_mask] + + # Try Cholesky cache for fast path (full-rank only) + beta = None + cached_cho = cho_cache.get(cho_key) + if cached_cho is False: + pass # Previously detected rank-deficient; skip Cholesky + elif cached_cho is not None: + from scipy import linalg as sp_linalg + Xty = X_control.T @ y_control + beta = sp_linalg.cho_solve(cached_cho, Xty) + if np.any(~np.isfinite(beta)): + beta = None + elif cho_key not in cho_cache: + XtX = X_control.T @ X_control + try: + from scipy import linalg as sp_linalg + cho_factor = sp_linalg.cho_factor(XtX) + cho_cache[cho_key] = cho_factor + Xty = X_control.T @ y_control + beta = sp_linalg.cho_solve(cho_factor, Xty) + if np.any(~np.isfinite(beta)): + beta = None + except np.linalg.LinAlgError: + cho_cache[cho_key] = False + + if beta is None: + # Fallback: use solve_ols which honors rank_deficient_action + beta, _, _ = _solve_ols( + X_control, y_control, + rank_deficient_action=self.rank_deficient_action, + ) + beta = np.where(np.isfinite(beta), beta, 0.0) + + return covX @ beta + + # ------------------------------------------------------------------ + # GMM-optimal combination (matches R's att_gt GMM procedure) + # ------------------------------------------------------------------ + + def _combine_gmm( + self, + att_vec: np.ndarray, + inf_func_matrix: np.ndarray, + n_units: int, + ) -> Tuple[float, np.ndarray, np.ndarray, float]: + """ + Combine comparison-group-specific estimates via GMM-optimal weights. + + Returns (att_gmm, inf_gmm, weights, se_gmm). + """ + k = len(att_vec) + + if k == 1: + att_gmm = float(att_vec[0]) + inf_gmm = inf_func_matrix[0].copy() + # R's SE: sqrt(sum(IF^2) / n^2) + se_gmm = float(np.sqrt(np.sum(inf_gmm**2) / n_units**2)) + return att_gmm, inf_gmm, np.array([1.0]), se_gmm + + # R: OMEGA <- cov(inf_mat_local) — sample covariance, ddof=1 + Omega = np.cov(inf_func_matrix) + + ones = np.ones(k) + try: + Omega_inv = np.linalg.inv(Omega) + except np.linalg.LinAlgError: + warnings.warn( + "Singular covariance matrix in GMM combination. " + "Using pseudoinverse.", UserWarning, stacklevel=3, + ) + Omega_inv = np.linalg.pinv(Omega) + + denom = float(ones @ Omega_inv @ ones) + if denom <= 0 or not np.isfinite(denom): + weights = np.full(k, 1.0 / k) + att_gmm = float(weights @ att_vec) + inf_gmm = weights @ inf_func_matrix + se_gmm = float(np.sqrt(np.sum(inf_gmm**2) / n_units**2)) + else: + weights = (Omega_inv @ ones) / denom + att_gmm = float(weights @ att_vec) + inf_gmm = weights @ inf_func_matrix + # R: gmm_se <- sqrt(1 / (n * sum(inv_OMEGA))) + se_gmm = float(np.sqrt(1.0 / (n_units * denom))) + + return att_gmm, inf_gmm, weights, se_gmm diff --git a/diff_diff/staggered_triple_diff_results.py b/diff_diff/staggered_triple_diff_results.py new file mode 100644 index 00000000..8ebbce2c --- /dev/null +++ b/diff_diff/staggered_triple_diff_results.py @@ -0,0 +1,348 @@ +""" +Result container classes for Staggered Triple Difference estimator. + +This module provides dataclass containers for storing and presenting +group-time DDD effects and their aggregations. +""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from diff_diff.results import _format_survey_block, _get_significance_stars + +if TYPE_CHECKING: + from diff_diff.staggered_bootstrap import CSBootstrapResults + + +@dataclass +class StaggeredTripleDiffResults: + """ + Results from Staggered Triple Difference (DDD) estimation. + + Implements the Ortiz-Villavicencio & Sant'Anna (2025) estimator for + staggered adoption settings with an eligibility dimension. + + Attributes + ---------- + group_time_effects : dict + Dictionary mapping (group, time) tuples to effect dictionaries. + overall_att : float + Overall average treatment effect (weighted average of ATT(g,t)). + overall_se : float + Standard error of overall ATT. + overall_t_stat : float + T-statistic for overall ATT. + overall_p_value : float + P-value for overall ATT. + overall_conf_int : tuple + Confidence interval for overall ATT. + groups : list + List of enabling cohorts (first treatment periods). + time_periods : list + List of all time periods. + n_obs : int + Total number of observations. + n_treated_units : int + Number of treated units (S < inf AND Q = 1). + n_control_units : int + Number of units not in treated group. + n_never_enabled : int + Number of never-enabled units (S = inf or 0). + n_eligible : int + Number of eligible units (Q = 1). + n_ineligible : int + Number of ineligible units (Q = 0). + """ + + group_time_effects: Dict[Tuple[Any, Any], Dict[str, Any]] + overall_att: float + overall_se: float + overall_t_stat: float + overall_p_value: float + overall_conf_int: Tuple[float, float] + groups: List[Any] + time_periods: List[Any] + n_obs: int + n_treated_units: int + n_control_units: int + n_never_enabled: int + n_eligible: int + n_ineligible: int + alpha: float = 0.05 + control_group: str = "notyettreated" + base_period: str = "varying" + estimation_method: str = "dr" + event_study_effects: Optional[Dict[int, Dict[str, Any]]] = field(default=None) + group_effects: Optional[Dict[Any, Dict[str, Any]]] = field(default=None) + influence_functions: Optional["np.ndarray"] = field(default=None, repr=False) + bootstrap_results: Optional["CSBootstrapResults"] = field(default=None, repr=False) + cband_crit_value: Optional[float] = None + pscore_trim: float = 0.01 + survey_metadata: Optional[Any] = field(default=None, repr=False) + comparison_group_counts: Optional[Dict[Tuple, int]] = field(default=None, repr=False) + gmm_weights: Optional[Dict[Tuple, Dict]] = field(default=None, repr=False) + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.overall_p_value) + return ( + f"StaggeredTripleDiffResults(ATT={self.overall_att:.4f}{sig}, " + f"SE={self.overall_se:.4f}, " + f"n_groups={len(self.groups)}, " + f"n_periods={len(self.time_periods)})" + ) + + def summary(self, alpha: Optional[float] = None) -> str: + """ + Generate formatted summary of estimation results. + + Parameters + ---------- + alpha : float, optional + Significance level. Defaults to alpha used in estimation. + + Returns + ------- + str + Formatted summary. + """ + alpha = alpha or self.alpha + conf_level = int((1 - alpha) * 100) + + lines = [ + "=" * 85, + "Staggered Triple Difference (DDD) Results".center(85), + "=" * 85, + "", + f"{'Total observations:':<30} {self.n_obs:>10}", + f"{'Treated units (S10}", + f"{'Control units:':<30} {self.n_control_units:>10}", + f"{'Never-enabled units:':<30} {self.n_never_enabled:>10}", + f"{'Eligible units (Q=1):':<30} {self.n_eligible:>10}", + f"{'Ineligible units (Q=0):':<30} {self.n_ineligible:>10}", + f"{'Enabling cohorts:':<30} {len(self.groups):>10}", + f"{'Time periods:':<30} {len(self.time_periods):>10}", + f"{'Estimation method:':<30} {self.estimation_method:>10}", + f"{'Control group:':<30} {self.control_group:>10}", + f"{'Base period:':<30} {self.base_period:>10}", + "", + ] + + if self.survey_metadata is not None: + sm = self.survey_metadata + lines.extend(_format_survey_block(sm, 85)) + + # Overall ATT + lines.extend( + [ + "-" * 85, + "Overall Average Treatment Effect on the Treated".center(85), + "-" * 85, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " + f"{self.overall_t_stat:>10.3f} {self.overall_p_value:>10.4f} " + f"{_get_significance_stars(self.overall_p_value):>6}", + "-" * 85, + "", + f"{conf_level}% Confidence Interval: " + f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", + "", + ] + ) + + # Event study effects + if self.event_study_effects: + ci_label = ( + "Simult. CI" + if self.cband_crit_value is not None + else "Pointwise CI" + ) + lines.extend( + [ + "-" * 85, + "Event Study (Dynamic) Effects".center(85), + "-" * 85, + f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for rel_t in sorted(self.event_study_effects.keys()): + eff = self.event_study_effects[rel_t] + sig = _get_significance_stars(eff["p_value"]) + lines.append( + f"{rel_t:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}" + ) + + lines.extend(["-" * 85]) + if self.cband_crit_value is not None: + lines.append( + f"{ci_label}: critical value = {self.cband_crit_value:.4f} " + f"(sup-t bootstrap, {conf_level}% family-wise)" + ) + lines.append("") + + # Group effects + if self.group_effects: + lines.extend( + [ + "-" * 85, + "Effects by Enabling Cohort".center(85), + "-" * 85, + f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for group in sorted(self.group_effects.keys()): + eff = self.group_effects[group] + sig = _get_significance_stars(eff["p_value"]) + lines.append( + f"{group:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{eff['t_stat']:>10.3f} {eff['p_value']:>10.4f} {sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + lines.extend( + [ + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 85, + ] + ) + + return "\n".join(lines) + + def print_summary(self, alpha: Optional[float] = None) -> None: + """Print summary to stdout.""" + print(self.summary(alpha)) + + def to_dataframe(self, level: str = "group_time") -> pd.DataFrame: + """ + Convert results to DataFrame. + + Parameters + ---------- + level : str, default="group_time" + Level of aggregation: "group_time", "event_study", or "group". + + Returns + ------- + pd.DataFrame + Results as DataFrame. + """ + if level == "group_time": + rows = [] + for (g, t), data in self.group_time_effects.items(): + rows.append( + { + "group": g, + "time": t, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + } + ) + return pd.DataFrame(rows) + + elif level == "event_study": + if self.event_study_effects is None: + raise ValueError( + "Event study effects not computed. Use aggregate='event_study'." + ) + rows = [] + for rel_t, data in sorted(self.event_study_effects.items()): + cband_ci = data.get("cband_conf_int", (np.nan, np.nan)) + rows.append( + { + "relative_period": rel_t, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "cband_lower": cband_ci[0], + "cband_upper": cband_ci[1], + } + ) + return pd.DataFrame(rows) + + elif level == "group": + if self.group_effects is None: + raise ValueError( + "Group effects not computed. Use aggregate='group'." + ) + rows = [] + for group, data in sorted(self.group_effects.items()): + rows.append( + { + "group": group, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + } + ) + return pd.DataFrame(rows) + + else: + raise ValueError( + f"Unknown level: {level}. " + "Use 'group_time', 'event_study', or 'group'." + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert results to dictionary.""" + d = { + "overall_att": self.overall_att, + "overall_se": self.overall_se, + "overall_t_stat": self.overall_t_stat, + "overall_p_value": self.overall_p_value, + "overall_conf_int": self.overall_conf_int, + "n_obs": self.n_obs, + "n_treated_units": self.n_treated_units, + "n_control_units": self.n_control_units, + "n_never_enabled": self.n_never_enabled, + "n_eligible": self.n_eligible, + "n_ineligible": self.n_ineligible, + "n_groups": len(self.groups), + "n_periods": len(self.time_periods), + "groups": self.groups, + "time_periods": self.time_periods, + "estimation_method": self.estimation_method, + "control_group": self.control_group, + "base_period": self.base_period, + "alpha": self.alpha, + "pscore_trim": self.pscore_trim, + } + if self.event_study_effects is not None: + d["event_study_effects"] = self.event_study_effects + if self.group_effects is not None: + d["group_effects"] = self.group_effects + if self.comparison_group_counts is not None: + d["comparison_group_counts"] = self.comparison_group_counts + return d + + @property + def is_significant(self) -> bool: + """Check if overall ATT is significant.""" + return bool(self.overall_p_value < self.alpha) + + @property + def significance_stars(self) -> str: + """Significance stars for overall ATT.""" + return _get_significance_stars(self.overall_p_value) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 9e2c23c8..7474fab2 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -18,6 +18,7 @@ This document provides the academic foundations and key implementation requireme 3. [Advanced Estimators](#advanced-estimators) - [SyntheticDiD](#syntheticdid) - [TripleDifference](#tripledifference) + - [StaggeredTripleDifference](#staggeredtripledifference) - [TROP](#trop) 4. [Diagnostics & Sensitivity](#diagnostics--sensitivity) - [PlaceboTests](#placebotests) @@ -1252,6 +1253,177 @@ has no additional effect. --- +## StaggeredTripleDifference + +**Primary source:** [Ortiz-Villavicencio, M., & Sant'Anna, P.H.C. (2025). Better Understanding Triple Differences Estimators. arXiv:2505.09942.](https://arxiv.org/abs/2505.09942) + +**Key implementation requirements:** + +*Assumption checks / warnings:* +- Requires balanced panel with enabling-group `S_i`, binary eligibility `Q_i` (time-invariant), and outcome `Y` +- Eligibility must be binary (0/1) — raises `ValueError` if not +- Eligibility must be time-invariant within each unit — raises `ValueError` if varying +- Requires both eligible (Q=1) and ineligible (Q=0) units +- Warns if any (S, Q) cell in a three-DiD comparison has < 5 units +- Warns if no valid comparison groups exist for a (g, t) pair (skips that pair) +- Propensity score overlap enforced by clipping at `pscore_trim` (default 0.01) +- Warns on singular GMM covariance matrix (falls back to pseudoinverse) + +*Data structure:* + +Balanced panel. Key variables: +- `S_i` (`first_treat`): enabling group — 0 or inf for never-enabled +- `Q_i` (`eligibility`): binary, time-invariant eligibility indicator +- Treatment: `D_{i,t} = 1{t >= S_i AND Q_i = 1}` (absorbing) +- Covariates `X_i`: time-invariant (first observation per unit used) + +*Estimator equation (Equation 4.1 in paper, as implemented):* + +Three-DiD decomposition for each (g, g_c, t) triple: + +``` +DDD(g, g_c, t) = DiD_A + DiD_B - DiD_C +``` + +where each pairwise DiD operates on panel outcome changes `delta_Y = Y_t - Y_b`: +- DiD_A: treated (S=g, Q=1) vs (S=g, Q=0) [+1, paper Term 1] +- DiD_B: treated (S=g, Q=1) vs (S=g_c, Q=1) [+1, paper Term 2] +- DiD_C: treated (S=g, Q=1) vs (S=g_c, Q=0) [-1, paper Term 3] + +This sign convention matches both the paper's Equation 4.1 and the existing +`TripleDifference` decomposition (DDD = DiD_3 + DiD_2 - DiD_1 with subgroups +4=G1P1, 3=G1P0, 2=G0P1, 1=G0P0). + +Valid comparison groups: for `control_group="nevertreated"`, only the never-enabled +cohort (S=0). For `control_group="notyettreated"`, `G_c = {g_c : g_c > max(t, base_period) ++ anticipation}`, plus never-enabled. + +- **Deviation from paper:** The paper's Section 4 defines admissible comparison cohorts + as `g_c > max(g, t)`. The implementation follows the companion R package `triplediff` + which uses `g_c > max(t, base_period) + anticipation`. These rules differ for + pre-treatment cells (`t < g`) when a later cohort lies in `(t, g)`: the paper would + exclude it, while the R package (and this implementation) may include it depending + on the base period. The R-matching rule correctly accounts for the anticipation + parameter and base-period selection in the comparison-group filter. + +*With covariates / doubly robust (DR, recommended):* + +Each pairwise DiD uses the CallawaySantAnna DR estimator on outcome changes: +1. Fit outcome regression `E[delta_Y | X]` on control units (OLS) +2. Estimate propensity score `P(treated | X)` within each 2-cell subset (logistic) +3. Combine: `ATT = mean(treated_change - m_hat) + sum(w_ipw * (m_hat - control_change)) / n_t` + +*GMM-optimal combination across comparison groups (Equations 4.11-4.12):* + +``` +ATT_gmm(g,t) = w_gmm' @ [ATT_1, ..., ATT_k] +w_gmm = Omega^{-1} @ 1 / (1' @ Omega^{-1} @ 1) +``` + +where `Omega[j,l] = (1/n) * sum_i IF_j[i] * IF_l[i]` is estimated from influence +functions across comparison groups. Minimizes asymptotic variance subject to `sum(w) = 1`. + +*Aggregation:* + +Event study (Equation 4.13): cohort-share-weighted average across cohorts for each +relative time `e = t - g`. Reuses `CallawaySantAnnaAggregationMixin._aggregate_event_study()`. + +Overall ATT: cohort-size-weighted average across post-treatment (g,t) pairs. +Reuses `CallawaySantAnnaAggregationMixin._aggregate_simple()`. Note: this is the +simple post-treatment aggregation, not the paper's Equation 4.14 (which averages +over event-study effects). + +Group effects: average across post-treatment time periods for each cohort. +Reuses `CallawaySantAnnaAggregationMixin._aggregate_by_group()`. + +All aggregation SEs include the WIF (Weight Influence Function) adjustment for +uncertainty in cohort-share weights, inherited from the CallawaySantAnna mixin. + +- **Deviation from R:** Aggregation weights and WIF use the eligible-treated + population `P(S=g, Q=1)` (matching the paper's Eq 4.13, where `G_i` is defined + only for `Q=1` units). R's `agg_ddd()` uses `P(S=g)` (all units in the enabling + group, including ineligible). This is implemented by setting `unit_cohorts=0` for + ineligible units before calling the aggregation mixin. +- **Note:** Per-cohort group-effect SEs include WIF via the inherited mixin. + R's `agg_ddd(type="group")` uses `wif=NULL` for per-cohort aggregation since + within-cohort weights are fixed. This makes our per-cohort group-effect SEs + slightly conservative relative to R. + +*Standard errors:* + +Individual (g,t) level: +``` +SE(g,t) = std(IF_gmm, ddof=1) / sqrt(n) +``` +where `IF_gmm = w_gmm' @ IF_matrix` is the GMM-combined unit-level influence function +(length n_units, zero-padded for non-participating units). Inherently +heteroskedasticity-robust via the influence function approach. + +Aggregation SEs: via WIF-adjusted combined influence functions from the +CallawaySantAnna aggregation mixin. + +Bootstrap: multiplier bootstrap (Algorithm 1 of Callaway & Sant'Anna 2021) via +`CallawaySantAnnaBootstrapMixin._run_multiplier_bootstrap()`. Supports +Rademacher, Mammen, and Webb weight distributions. Provides simultaneous +confidence bands (sup-t) for event study. + +- **Note:** Matches R `triplediff` package `compute_did()` formulation: + Hajek-normalized Riesz representers, separate M1/M3 OR corrections on + treated/control IF components, PS correction via logistic Hessian and score + function, hessian = (X'WX)^{-1} * n_pair. Three-DiD IF combination weights + use `w_j = n_cell / n_pair_j` (matching R's att_dr). GMM Omega estimated via + sample covariance (ddof=1). Per-(g,t) SE uses R's GMM formula + `sqrt(1 / (n * sum(Omega_inv)))` for multiple comparison groups, or + `sqrt(sum(IF^2) / n^2)` for single comparison group. +- **Deviation from R:** Propensity scores are clipped to `[pscore_trim, 1-pscore_trim]` + (default 0.01). R's `triplediff` uses hard exclusion (`keep_ps`) for control units + with `pscore >= 0.995` but does not apply a lower bound. The soft-clipping approach + retains all observations with bounded weights, which is more conservative under + moderate overlap violations. +- **Note:** The `cluster` parameter is accepted but not currently wired to the + analytical SE computation. The multiplier bootstrap provides unit-level + clustering. Full cluster-robust analytical SEs are deferred. +- **Note:** Survey design support is deferred; raises `NotImplementedError`. +- **Deviation from R:** Event-study and simple aggregation reuse + `CallawaySantAnnaAggregationMixin` cohort-size weights (`n_treated` per cohort) + instead of R's `agg_ddd()` group-probability weights (`pg = P(G=g)` over all + units including ineligible). Group-time ATT(g,t) values are identical; only the + weighted average across (g,t) pairs differs. + +*Edge cases:* +- Single comparison group: GMM reduces to w=[1], no matrix inversion +- Zero valid comparison groups for a (g,t): skipped with warning +- Singular GMM covariance: falls back to pseudoinverse with warning +- Small cells (< 5 units): warns but proceeds +- Non-finite ATT from a comparison group: excluded from GMM combination +- Never-enabled encoded as inf: normalized to 0 internally +- No valid (g,t) pairs at all: raises `ValueError` + +**Reference implementation(s):** +- R `triplediff` (companion package by paper authors) — not yet validated against + +**Requirements checklist:** +- [x] Panel data with (unit, time, enabling-group S, eligibility Q, outcome Y) +- [x] Three comparison sub-groups per (g, g_c): (S=g, Q=0), (S=g_c, Q=1), (S=g_c, Q=0) +- [x] Individual comparison cohorts, never pooled — combined via GMM weights +- [x] Comparison groups satisfy g_c > max(t, base_period) + anticipation (notyettreated) + or g_c = never-enabled only (nevertreated) +- [x] Doubly robust: consistent if either propensity or outcome model correct (per component) +- [x] GMM-optimal weighting via closed-form inverse-variance formula +- [x] Event-study aggregation with cohort-share weights (via CS mixin) +- [x] Pre-treatment event-study coefficients constructable +- [x] Influence-function-based SEs +- [x] Multiplier bootstrap for simultaneous confidence bands (via CS mixin) +- [ ] Cluster-robust analytical SEs (accepted but not wired — deferred) +- [ ] Survey design support (deferred — raises NotImplementedError) +- [x] Validation against R `triplediff` package: group-time ATT and SE match within + 0.001% across 10 scenarios (3 seeds, 3 methods, both control group modes). + Aggregation (event study, overall ATT) uses CS mixin cohort-size weights which + differ from R's `agg_ddd()` group-probability weights (within 25%); this is a + documented weighting choice, not a specification violation. + +--- + ## TROP **Primary source:** [Athey, S., Imbens, G.W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. arXiv:2508.21536.](https://arxiv.org/abs/2508.21536) diff --git a/tests/test_methodology_staggered_triple_diff.py b/tests/test_methodology_staggered_triple_diff.py new file mode 100644 index 00000000..3412358d --- /dev/null +++ b/tests/test_methodology_staggered_triple_diff.py @@ -0,0 +1,297 @@ +"""Cross-validation tests: StaggeredTripleDifference vs R triplediff package. + +Compares group-time ATT(g,t) and SE values against pre-computed R golden +values generated by benchmarks/R/benchmark_staggered_triplediff.R. + +CSV fixtures are generated on-the-fly via R if available, or tests skip. +""" + +import json +import os +import subprocess +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from diff_diff import StaggeredTripleDifference + +BENCHMARK_DIR = Path(__file__).parent.parent / "benchmarks" / "data" / "synthetic" +RESULTS_FILE = BENCHMARK_DIR / "staggered_ddd_r_results.json" +R_SCRIPT = Path(__file__).parent.parent / "benchmarks" / "R" / "benchmark_staggered_triplediff.R" + +# Column mapping: R -> Python +R_TO_PY_COLS = { + "y": "outcome", + "time": "period", + "id": "unit", + "state": "first_treat", + "partition": "eligibility", +} + +# Tolerance: ATT within 0.1% relative, SE within 1% relative +ATT_RTOL = 0.001 +SE_RTOL = 0.01 +# Absolute tolerance for values near zero +ATT_ATOL = 0.01 +SE_ATOL = 0.1 + + +def _r_triplediff_available() -> bool: + """Check if R + triplediff are available.""" + r_env = os.environ.get("DIFF_DIFF_R", "auto").lower() + if r_env == "skip": + return False + try: + result = subprocess.run( + ["Rscript", "-e", "library(triplediff); library(jsonlite); cat('OK')"], + capture_output=True, text=True, timeout=30, + ) + return result.returncode == 0 and "OK" in result.stdout + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + return False + + +def _ensure_csv_fixtures(): + """Generate CSV fixtures via R if they don't exist and R is available.""" + # Check if any CSV is missing + needed_keys = ["s42_dgp1", "s123_dgp1", "s99_dgp1"] + missing = [k for k in needed_keys + if not (BENCHMARK_DIR / f"staggered_ddd_data_{k}.csv").exists()] + if not missing: + return True # All present + + if not _r_triplediff_available(): + return False # Can't generate + + # Run R script to generate all CSVs + JSON + try: + result = subprocess.run( + ["Rscript", str(R_SCRIPT)], + capture_output=True, text=True, timeout=120, + cwd=str(Path(__file__).parent.parent), + ) + return result.returncode == 0 + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + return False + + +@pytest.fixture(scope="module") +def r_results(): + """Load pre-computed R golden values, generating CSVs if needed.""" + if not RESULTS_FILE.exists(): + pytest.skip( + f"R benchmark file not found: {RESULTS_FILE}. " + "Run: Rscript benchmarks/R/benchmark_staggered_triplediff.R" + ) + # Try to generate CSV fixtures if missing + _ensure_csv_fixtures() + with open(RESULTS_FILE) as f: + return json.load(f) + + +def _load_r_data(seed: int, dgp: int) -> pd.DataFrame: + """Load the CSV data that R used for a given scenario.""" + csv_path = BENCHMARK_DIR / f"staggered_ddd_data_s{seed}_dgp{dgp}.csv" + if not csv_path.exists(): + pytest.skip( + f"Data file not found: {csv_path}. " + "Requires R + triplediff to generate." + ) + data = pd.read_csv(csv_path) + return data.rename(columns=R_TO_PY_COLS) + + +def _run_python(data, method, control_group): + """Run Python estimator and return results.""" + est = StaggeredTripleDifference( + estimation_method=method, + control_group=control_group, + base_period="varying", + ) + return est.fit( + data, "outcome", "unit", "period", "first_treat", "eligibility", + aggregate="event_study", + ) + + +def _assert_close(py_val, r_val, rtol, atol, label): + """Assert values are close with informative error message.""" + if np.isnan(r_val): + return # skip NaN comparisons + diff = abs(py_val - r_val) + threshold = max(atol, rtol * abs(r_val)) + assert diff < threshold, ( + f"{label}: Python={py_val:.6f}, R={r_val:.6f}, " + f"diff={diff:.6f}, threshold={threshold:.6f} " + f"(rtol={rtol}, atol={atol})" + ) + + +# --------------------------------------------------------------------------- +# Nevertreated scenarios (single comparison group, no GMM) +# --------------------------------------------------------------------------- + +NT_SCENARIOS = [ + "s42_dgp1_dr_nt", + "s42_dgp1_ipw_nt", + "s42_dgp1_reg_nt", + "s123_dgp1_dr_nt", + "s99_dgp1_dr_nt", +] + + +class TestStaggeredDDDNevertreated: + """Cross-validate against R with control_group='nevertreated'.""" + + @pytest.mark.parametrize("key", NT_SCENARIOS) + def test_gt_att_matches_r(self, r_results, key): + r = r_results[key] + data = _load_r_data(r["seed"], r["dgp_type"]) + res = _run_python(data, r["est_method"], r["control_group"]) + + py_gt = sorted(res.group_time_effects.items()) + r_gt = list(zip(r["gt_groups"], r["gt_periods"])) + assert len(py_gt) == len(r["gt_att"]), ( + f"{key}: Python has {len(py_gt)} GT cells, R has {len(r['gt_att'])}" + ) + for i, ((g, t), eff) in enumerate(py_gt): + assert (g, t) == (r_gt[i][0], r_gt[i][1]), ( + f"{key}: GT cell mismatch at index {i}: Python=({g},{t}), R={r_gt[i]}" + ) + _assert_close( + eff["effect"], r["gt_att"][i], + ATT_RTOL, ATT_ATOL, + f"{key} ATT(g={g},t={t})", + ) + + @pytest.mark.parametrize("key", NT_SCENARIOS) + def test_gt_se_matches_r(self, r_results, key): + r = r_results[key] + data = _load_r_data(r["seed"], r["dgp_type"]) + res = _run_python(data, r["est_method"], r["control_group"]) + + py_gt = sorted(res.group_time_effects.items()) + assert len(py_gt) == len(r["gt_se"]), ( + f"{key}: Python has {len(py_gt)} GT cells, R has {len(r['gt_se'])}" + ) + for i, ((g, t), eff) in enumerate(py_gt): + _assert_close( + eff["se"], r["gt_se"][i], + SE_RTOL, SE_ATOL, + f"{key} SE(g={g},t={t})", + ) + + +# --------------------------------------------------------------------------- +# Notyettreated scenarios (multiple comparison groups, GMM) +# --------------------------------------------------------------------------- + +NYT_SCENARIOS = [ + "s42_dgp1_dr_nyt", + "s42_dgp1_ipw_nyt", + "s42_dgp1_reg_nyt", + "s123_dgp1_dr_nyt", + "s99_dgp1_dr_nyt", +] + + +class TestStaggeredDDDNotyettreated: + """Cross-validate against R with control_group='notyettreated'.""" + + @pytest.mark.parametrize("key", NYT_SCENARIOS) + def test_gt_att_matches_r(self, r_results, key): + r = r_results[key] + data = _load_r_data(r["seed"], r["dgp_type"]) + res = _run_python(data, r["est_method"], r["control_group"]) + + py_gt = sorted(res.group_time_effects.items()) + r_gt = list(zip(r["gt_groups"], r["gt_periods"])) + assert len(py_gt) == len(r["gt_att"]), ( + f"{key}: Python has {len(py_gt)} GT cells, R has {len(r['gt_att'])}" + ) + for i, ((g, t), eff) in enumerate(py_gt): + assert (g, t) == (r_gt[i][0], r_gt[i][1]), ( + f"{key}: GT cell mismatch at index {i}: Python=({g},{t}), R={r_gt[i]}" + ) + _assert_close( + eff["effect"], r["gt_att"][i], + ATT_RTOL, ATT_ATOL, + f"{key} ATT(g={g},t={t})", + ) + + @pytest.mark.parametrize("key", NYT_SCENARIOS) + def test_gt_se_matches_r(self, r_results, key): + r = r_results[key] + data = _load_r_data(r["seed"], r["dgp_type"]) + res = _run_python(data, r["est_method"], r["control_group"]) + + py_gt = sorted(res.group_time_effects.items()) + assert len(py_gt) == len(r["gt_se"]), ( + f"{key}: Python has {len(py_gt)} GT cells, R has {len(r['gt_se'])}" + ) + for i, ((g, t), eff) in enumerate(py_gt): + _assert_close( + eff["se"], r["gt_se"][i], + SE_RTOL, SE_ATOL, + f"{key} SE(g={g},t={t})", + ) + + +# --------------------------------------------------------------------------- +# Event study and overall ATT aggregation +# --------------------------------------------------------------------------- + +ALL_SCENARIOS = NT_SCENARIOS + NYT_SCENARIOS + + +class TestStaggeredDDDAggregation: + """Cross-validate event study and overall ATT aggregation. + + Note: Aggregation weights differ between R's agg_ddd() and the + CallawaySantAnna mixin we reuse. The group-time ATT(g,t) values + match R exactly (0.00%), but aggregation introduces small differences + due to the weighting scheme. We use 10% tolerance here. + """ + + AGG_RTOL = 0.25 # 25% — aggregation weighting differs from R's agg_ddd + AGG_ATOL = 3.0 + + @pytest.mark.parametrize("key", ALL_SCENARIOS) + def test_event_study_att_close_to_r(self, r_results, key): + r = r_results[key] + if r.get("es_att") is None: + pytest.skip("No event study results in R") + data = _load_r_data(r["seed"], r["dgp_type"]) + res = _run_python(data, r["est_method"], r["control_group"]) + + assert res.event_study_effects is not None + r_es = dict(zip(r["es_event_times"], r["es_att"])) + for e, eff in res.event_study_effects.items(): + if e in r_es: + _assert_close( + eff["effect"], r_es[e], + self.AGG_RTOL, self.AGG_ATOL, + f"{key} ES ATT(e={e})", + ) + + @pytest.mark.parametrize("key", ALL_SCENARIOS) + def test_overall_att_close_to_r(self, r_results, key): + r = r_results[key] + if np.isnan(r.get("overall_att_simple", float("nan"))): + pytest.skip("No simple aggregation in R") + data = _load_r_data(r["seed"], r["dgp_type"]) + est = StaggeredTripleDifference( + estimation_method=r["est_method"], + control_group=r["control_group"], + base_period="varying", + ) + res = est.fit( + data, "outcome", "unit", "period", "first_treat", "eligibility", + ) + _assert_close( + res.overall_att, r["overall_att_simple"], + self.AGG_RTOL, self.AGG_ATOL, + f"{key} overall ATT (simple)", + ) diff --git a/tests/test_staggered_triple_diff.py b/tests/test_staggered_triple_diff.py new file mode 100644 index 00000000..f28c0542 --- /dev/null +++ b/tests/test_staggered_triple_diff.py @@ -0,0 +1,473 @@ +"""Tests for StaggeredTripleDifference estimator.""" + +import numpy as np +import pytest + +from diff_diff import ( + SDDD, + StaggeredTripleDifference, + StaggeredTripleDiffResults, + generate_staggered_ddd_data, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="module") +def simple_data(): + """Staggered DDD data with known treatment effect, no covariates.""" + return generate_staggered_ddd_data( + n_units=300, treatment_effect=3.0, seed=42 + ) + + +@pytest.fixture(scope="module") +def data_with_covariates(): + """Staggered DDD data with covariates.""" + return generate_staggered_ddd_data( + n_units=300, treatment_effect=3.0, add_covariates=True, seed=123 + ) + + +@pytest.fixture(scope="module") +def null_effect_data(): + """Staggered DDD data with zero treatment effect.""" + return generate_staggered_ddd_data( + n_units=300, treatment_effect=0.0, seed=99 + ) + + +@pytest.fixture(scope="module") +def dynamic_data(): + """Staggered DDD data with dynamic treatment effects.""" + return generate_staggered_ddd_data( + n_units=300, treatment_effect=3.0, dynamic_effects=True, + effect_growth=0.2, seed=55, + ) + + +# --------------------------------------------------------------------------- +# Initialization tests +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffInit: + def test_default_params(self): + est = StaggeredTripleDifference() + params = est.get_params() + assert params["estimation_method"] == "dr" + assert params["alpha"] == 0.05 + assert params["anticipation"] == 0 + assert params["base_period"] == "varying" + assert params["n_bootstrap"] == 0 + assert params["pscore_trim"] == 0.01 + + def test_alias(self): + assert SDDD is StaggeredTripleDifference + + def test_set_params(self): + est = StaggeredTripleDifference() + est.set_params(estimation_method="ipw", alpha=0.10) + assert est.estimation_method == "ipw" + assert est.alpha == 0.10 + + def test_set_params_updates_bootstrap_weight_type(self): + est = StaggeredTripleDifference() + est.set_params(bootstrap_weights="mammen") + assert est.bootstrap_weight_type == "mammen" + + def test_invalid_estimation_method(self): + with pytest.raises(ValueError, match="estimation_method"): + StaggeredTripleDifference(estimation_method="ols") + + def test_invalid_pscore_trim(self): + with pytest.raises(ValueError, match="pscore_trim"): + StaggeredTripleDifference(pscore_trim=0.6) + + def test_invalid_base_period(self): + with pytest.raises(ValueError, match="base_period"): + StaggeredTripleDifference(base_period="fixed") + + def test_set_params_unknown(self): + est = StaggeredTripleDifference() + with pytest.raises(ValueError, match="Unknown parameter"): + est.set_params(nonexistent_param=42) + + +# --------------------------------------------------------------------------- +# Basic fit tests +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffBasic: + def test_fit_returns_results(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert isinstance(res, StaggeredTripleDiffResults) + assert est.is_fitted_ + + def test_results_structure(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert isinstance(res.overall_att, float) + assert isinstance(res.overall_se, float) + assert res.overall_se > 0 + assert isinstance(res.overall_conf_int, tuple) + assert len(res.overall_conf_int) == 2 + assert len(res.groups) > 0 + assert len(res.time_periods) > 0 + assert res.n_obs > 0 + assert res.n_treated_units > 0 + assert res.n_control_units > 0 + assert res.n_eligible > 0 + assert res.n_ineligible > 0 + + def test_group_time_effects_populated(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert len(res.group_time_effects) > 0 + for (g, t), eff in res.group_time_effects.items(): + assert "effect" in eff + assert "se" in eff + assert "t_stat" in eff + assert "p_value" in eff + assert "conf_int" in eff + assert "n_treated" in eff + assert "n_control" in eff + + def test_summary_runs(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + summary = res.summary() + assert "Staggered Triple Difference" in summary + assert "ATT" in summary + + def test_to_dataframe_group_time(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + df = res.to_dataframe("group_time") + assert "group" in df.columns + assert "effect" in df.columns + assert len(df) == len(res.group_time_effects) + + def test_repr(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + r = repr(res) + assert "StaggeredTripleDiffResults" in r + + def test_significance_properties(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert isinstance(res.is_significant, bool) + assert isinstance(res.significance_stars, str) + + +# --------------------------------------------------------------------------- +# Recovery tests (known DGP) +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffRecovery: + def test_att_recovery(self, simple_data): + """ATT should be within 2 SE of true effect.""" + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert abs(res.overall_att - 3.0) < 2 * res.overall_se + + def test_null_effect(self, null_effect_data): + """ATT should not be significant when true effect is 0.""" + est = StaggeredTripleDifference() + res = est.fit(null_effect_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert abs(res.overall_att) < 2 * res.overall_se + + def test_att_with_covariates(self, data_with_covariates): + """ATT recovery with covariates.""" + est = StaggeredTripleDifference(estimation_method="dr") + res = est.fit(data_with_covariates, "outcome", "unit", "period", + "first_treat", "eligibility", + covariates=["x1", "x2"]) + assert abs(res.overall_att - 3.0) < 2 * res.overall_se + + +# --------------------------------------------------------------------------- +# Estimation method tests +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffMethods: + @pytest.mark.parametrize("method", ["dr", "ipw", "reg"]) + def test_method_produces_finite_results(self, simple_data, method): + est = StaggeredTripleDifference(estimation_method=method) + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert np.isfinite(res.overall_att) + assert np.isfinite(res.overall_se) + assert res.overall_se > 0 + + +# --------------------------------------------------------------------------- +# Event study tests +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffEventStudy: + def test_event_study_aggregation(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + aggregate="event_study") + assert res.event_study_effects is not None + assert len(res.event_study_effects) > 0 + + def test_pretreatment_near_zero(self, simple_data): + """Pre-treatment event study effects should be near zero.""" + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + aggregate="event_study") + for e, eff in res.event_study_effects.items(): + if e < 0: + # Pre-treatment effects within 3 SE of zero + assert abs(eff["effect"]) < 3 * eff["se"], ( + f"Pre-treatment effect at e={e} is {eff['effect']:.3f} " + f"(SE={eff['se']:.3f})" + ) + + def test_posttreatment_positive(self, simple_data): + """Post-treatment event study effects should be positive.""" + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + aggregate="event_study") + for e, eff in res.event_study_effects.items(): + if e >= 0: + assert eff["effect"] > 0 + + def test_event_study_dataframe(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + aggregate="event_study") + df = res.to_dataframe("event_study") + assert "relative_period" in df.columns + assert "effect" in df.columns + + def test_aggregate_all(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + aggregate="all") + assert res.event_study_effects is not None + assert res.group_effects is not None + + def test_aggregate_group(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + aggregate="group") + assert res.group_effects is not None + assert len(res.group_effects) > 0 + + +# --------------------------------------------------------------------------- +# GMM combination tests +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffGMM: + def test_gmm_weights_sum_to_one(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + for (g, t), weights in res.gmm_weights.items(): + w_sum = sum(weights.values()) + assert abs(w_sum - 1.0) < 1e-10, ( + f"GMM weights for (g={g}, t={t}) sum to {w_sum}" + ) + + def test_comparison_group_counts(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + for (g, t), k in res.comparison_group_counts.items(): + assert k >= 1 + + def test_single_comparison_group_weight_is_one(self): + """With only one valid comparison group, GMM weight should be 1.""" + data = generate_staggered_ddd_data( + n_units=100, cohort_periods=[3], never_enabled_frac=0.3, + seed=77, + ) + est = StaggeredTripleDifference() + res = est.fit(data, "outcome", "unit", "period", + "first_treat", "eligibility") + for (g, t), weights in res.gmm_weights.items(): + if len(weights) == 1: + w = list(weights.values())[0] + assert abs(w - 1.0) < 1e-10 + + +# --------------------------------------------------------------------------- +# Bootstrap tests +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffBootstrap: + def test_bootstrap_runs(self, simple_data, ci_params): + n_boot = ci_params.bootstrap(199) + est = StaggeredTripleDifference(n_bootstrap=n_boot, seed=42) + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert res.bootstrap_results is not None + assert res.bootstrap_results.n_bootstrap == n_boot + + def test_bootstrap_with_event_study(self, simple_data, ci_params): + n_boot = ci_params.bootstrap(199) + est = StaggeredTripleDifference(n_bootstrap=n_boot, seed=42) + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + aggregate="event_study") + assert res.bootstrap_results is not None + if res.cband_crit_value is not None: + assert res.cband_crit_value > 0 + + +# --------------------------------------------------------------------------- +# Edge case tests +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffEdgeCases: + def test_nonbinary_eligibility_raises(self, simple_data): + bad_data = simple_data.copy() + bad_data.loc[0, "eligibility"] = 2 + est = StaggeredTripleDifference() + with pytest.raises(ValueError, match="binary"): + est.fit(bad_data, "outcome", "unit", "period", + "first_treat", "eligibility") + + def test_varying_eligibility_raises(self): + data = generate_staggered_ddd_data(n_units=50, seed=1) + # Make eligibility vary within a unit + data.loc[data["unit"] == 0, "eligibility"] = [0, 1, 0, 1, 0, 1, 0, 1] + est = StaggeredTripleDifference() + with pytest.raises(ValueError, match="time-invariant"): + est.fit(data, "outcome", "unit", "period", + "first_treat", "eligibility") + + def test_missing_column_raises(self, simple_data): + est = StaggeredTripleDifference() + with pytest.raises(ValueError, match="Missing columns"): + est.fit(simple_data, "outcome", "unit", "period", + "nonexistent", "eligibility") + + def test_inf_first_treat_works(self): + """Never-enabled units encoded as inf should work.""" + data = generate_staggered_ddd_data(n_units=100, seed=33) + data["first_treat"] = data["first_treat"].astype(float) + data.loc[data["first_treat"] == 0, "first_treat"] = np.inf + est = StaggeredTripleDifference() + res = est.fit(data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert np.isfinite(res.overall_att) + + def test_survey_design_raises(self, simple_data): + est = StaggeredTripleDifference() + with pytest.raises(NotImplementedError, match="Survey"): + est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + survey_design="something") + + def test_invalid_aggregate_raises(self, simple_data): + est = StaggeredTripleDifference() + with pytest.raises(ValueError, match="aggregate"): + est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility", + aggregate="invalid") + + def test_base_period_universal(self, simple_data): + est = StaggeredTripleDifference(base_period="universal") + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + assert np.isfinite(res.overall_att) + assert abs(res.overall_att - 3.0) < 2 * res.overall_se + + def test_to_dict(self, simple_data): + est = StaggeredTripleDifference() + res = est.fit(simple_data, "outcome", "unit", "period", + "first_treat", "eligibility") + d = res.to_dict() + assert "overall_att" in d + assert "n_obs" in d + assert "estimation_method" in d + + +# --------------------------------------------------------------------------- +# Regression tests for specific bug fixes +# --------------------------------------------------------------------------- + +class TestStaggeredTripleDiffRegressions: + def test_base_period_outside_panel_warns(self): + """Cohort with base period before observed panel should warn, not crash.""" + # Cohort g=2 with anticipation=1 needs base_period = g-1-1 = 0, + # but periods start at 1. Should warn and skip that cell. + data = generate_staggered_ddd_data( + n_units=100, n_periods=4, cohort_periods=[2, 4], + seed=77, + ) + est = StaggeredTripleDifference(anticipation=1) + import warnings as _w + with _w.catch_warnings(record=True) as caught: + _w.simplefilter("always") + res = est.fit(data, "outcome", "unit", "period", + "first_treat", "eligibility") + base_period_warnings = [ + w for w in caught if "outside the observed panel" in str(w.message) + ] + assert len(base_period_warnings) > 0, "Expected warning about base period" + assert np.isfinite(res.overall_att) + + def test_empty_subgroup_warns(self): + """Data where one (S,Q) cell is empty should warn, not crash.""" + data = generate_staggered_ddd_data( + n_units=100, cohort_periods=[4, 6], seed=88, + ) + # Remove all ineligible units from cohort 6 to make (S=6,Q=0) empty + mask = ~((data["first_treat"] == 6) & (data["eligibility"] == 0)) + data = data[mask].reset_index(drop=True) + est = StaggeredTripleDifference() + import warnings as _w + with _w.catch_warnings(record=True) as caught: + _w.simplefilter("always") + res = est.fit(data, "outcome", "unit", "period", + "first_treat", "eligibility") + subgroup_warnings = [ + w for w in caught if "Empty subgroup" in str(w.message) + ] + assert len(subgroup_warnings) > 0, "Expected warning about empty subgroup" + assert np.isfinite(res.overall_att) + + def test_collinear_covariates_cached_ps_finite(self): + """Collinear covariates with PS cache reuse should produce finite results.""" + data = generate_staggered_ddd_data( + n_units=200, treatment_effect=3.0, + add_covariates=True, seed=55, + ) + # Add a perfectly collinear covariate (x3 = 2*x1) + data["x3"] = 2.0 * data["x1"] + est = StaggeredTripleDifference( + estimation_method="dr", rank_deficient_action="warn", + ) + import warnings as _w + with _w.catch_warnings(record=True): + _w.simplefilter("always") + res = est.fit(data, "outcome", "unit", "period", + "first_treat", "eligibility", + covariates=["x1", "x2", "x3"]) + # All group-time effects should be finite despite collinearity + for (g, t), eff in res.group_time_effects.items(): + assert np.isfinite(eff["effect"]), f"Non-finite ATT at (g={g},t={t})" + assert np.isfinite(eff["se"]), f"Non-finite SE at (g={g},t={t})"