Skip to content

Commit 70e09bd

Browse files
igerberclaude
andcommitted
Fix P0 control-group bug, stale CS API calls, and SyntheticDiD translation
- Fix always-treated snippet that silently dropped never-treated controls by matching actual estimator logic (first_treat > 0 && <= min_period) - Replace stale results.bootstrap()/aggregate()/att with correct CS API (n_bootstrap= at constructor, aggregate= at fit time, overall_att) - Fix SyntheticDiD R comparison to derive ever-treated indicator and post_periods from data instead of passing time-varying treatment Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0ef224b commit 70e09bd

3 files changed

Lines changed: 33 additions & 20 deletions

File tree

docs/python_comparison.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ Staggered DiD (Callaway-Sant'Anna)
408408
unit='unit',
409409
time='time',
410410
first_treat='first_treat',
411-
covariates=['x1', 'x2']
411+
covariates=['x1', 'x2'],
412+
aggregate='event_study'
412413
)
413414
event_study = results.event_study_effects
414415

docs/r_comparison.rst

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ staggered DiD. Here's how to translate common operations:
114114
115115
.. code-block:: python
116116
117-
# Python
117+
# Python (unlike R's aggte(), aggregation is requested at fit time)
118+
results = cs.fit(data, outcome='Y', time='period', unit='id',
119+
first_treat='G', aggregate='all')
118120
overall_att = results.overall_att # Simple aggregation
119121
event_study = results.event_study_effects # Dynamic
120122
by_group = results.group_effects # By cohort
@@ -191,14 +193,20 @@ The synthdid package implements Arkhangelsky et al. (2021):
191193
# Python
192194
from diff_diff import SyntheticDiD
193195
196+
# SyntheticDiD requires a time-invariant ever-treated indicator
197+
data['ever_treated'] = data.groupby('unit')['treatment'].transform('max')
198+
199+
# Derive post-treatment periods from treatment timing
200+
post_periods = sorted(data.loc[data['treatment'] == 1, 'time'].unique())
201+
194202
sdid = SyntheticDiD()
195203
results = sdid.fit(
196204
data,
197205
outcome='Y',
198206
unit='unit',
199207
time='time',
200-
treatment='treatment',
201-
post_periods=[T0, T0+1, T0+2]
208+
treatment='ever_treated',
209+
post_periods=post_periods
202210
)
203211
204212
Key Differences

docs/troubleshooting.rst

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,14 @@ Staggered Adoption Issues
205205
print(data.groupby('first_treat')['unit_id'].nunique())
206206
207207
# Use bootstrap for better inference
208-
results = cs.fit(data, ...)
209-
bootstrap_results = results.bootstrap(n_bootstrap=999)
208+
cs = CallawaySantAnna(n_bootstrap=999)
209+
results = cs.fit(data, outcome='y', unit='unit_id',
210+
time='period', first_treat='first_treat',
211+
aggregate='event_study')
210212
211-
# Aggregate to get more precise estimates
212-
event_study = results.aggregate('event_time')
213-
overall_att = results.att # Aggregated ATT
213+
# Access aggregated results
214+
print(results.overall_att) # Overall ATT
215+
print(results.event_study_effects) # Event study effects
214216
215217
Visualization Issues
216218
--------------------
@@ -232,9 +234,11 @@ Visualization Issues
232234
# Specify reference period explicitly
233235
plot_event_study(results, reference_period=-1)
234236
235-
# For CallawaySantAnna, aggregate first
236-
event_study = results.aggregate('event_time')
237-
plot_event_study(event_study)
237+
# For CallawaySantAnna, fit with aggregate='event_study'
238+
results = cs.fit(data, outcome='y', unit='unit_id',
239+
time='period', first_treat='first_treat',
240+
aggregate='event_study')
241+
plot_event_study(results)
238242
239243
"Plot doesn't show in Jupyter"
240244
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -510,15 +514,15 @@ pre-treatment outcomes exist to construct counterfactuals.
510514

511515
.. code-block:: python
512516
513-
# Identify always-treated units
514-
always_treated = data.groupby('unit_id').apply(
515-
lambda g: (g['period'] >= g['first_treat']).all()
516-
)
517-
print(f"Always-treated units: {always_treated.sum()}")
517+
# Identify always-treated units (treated at or before first observed period)
518+
# Exclude never-treated (first_treat == 0) which are the control group
519+
unit_ft = data.groupby('unit_id')['first_treat'].first()
520+
min_period = data['period'].min()
521+
always_treated = unit_ft[(unit_ft > 0) & (unit_ft <= min_period)]
522+
print(f"Always-treated units: {len(always_treated)}")
518523
519-
# Drop always-treated units
520-
keep_units = always_treated[~always_treated].index
521-
data = data[data['unit_id'].isin(keep_units)]
524+
# Drop always-treated units (keep never-treated controls)
525+
data = data[~data['unit_id'].isin(always_treated.index)]
522526
523527
"Horizons not identified without never-treated units"
524528
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

0 commit comments

Comments
 (0)