Skip to content

Commit 5930286

Browse files
igerber and claude committed
Fix CS df key name, re-read df after aggregation, propagate ContinuousDiD df
- Fix precomputed key: use "df_survey", not "survey_df", in CS aggregation.
- CS: re-read df_survey from precomputed after aggregation so that overall ATT inference uses the updated n_valid - 1 when replicate columns are dropped.
- ContinuousDiD: track _rep_n_valid across replicate IF calls and use min(n_valid) - 1 as df_survey in the analytical SE return.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 5434007 commit 5930286

3 files changed

Lines changed: 12 additions & 3 deletions

File tree

diff_diff/continuous_did.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,10 +1237,13 @@ def _compute_analytical_se(
12371237
from diff_diff.survey import compute_replicate_if_variance
12381238

12391239
_w_rep = unit_resolved.weights
1240+
_rep_n_valid = unit_resolved.n_replicates # track effective count
12401241

12411242
def _rep_se(if_vals):
1243+
nonlocal _rep_n_valid
12421244
psi_scaled = _w_rep * if_vals
1243-
v, _nv = compute_replicate_if_variance(psi_scaled, unit_resolved)
1245+
v, nv = compute_replicate_if_variance(psi_scaled, unit_resolved)
1246+
_rep_n_valid = min(_rep_n_valid, nv) # worst-case valid count
12441247
return float(np.sqrt(max(v, 0.0))) if np.isfinite(v) else np.nan
12451248

12461249
overall_att_se = _rep_se(if_att_glob)
@@ -1282,7 +1285,11 @@ def _rep_se(if_vals):
12821285
acrt_d_se = np.sqrt(np.sum(if_acrt_d**2, axis=0))
12831286

12841287
# Return unit-level survey df and resolved design for metadata recomputation
1285-
unit_df_survey = unit_resolved.df_survey if resolved_survey is not None else None
1288+
# Use effective replicate df if available (from _rep_se calls)
1289+
if resolved_survey is not None and hasattr(resolved_survey, 'uses_replicate_variance') and resolved_survey.uses_replicate_variance:
1290+
unit_df_survey = _rep_n_valid - 1 if _rep_n_valid > 1 else None
1291+
else:
1292+
unit_df_survey = unit_resolved.df_survey if resolved_survey is not None else None
12861293

12871294
return {
12881295
"overall_att_se": overall_att_se,

diff_diff/staggered.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,6 +1493,8 @@ def fit(
14931493
overall_att, overall_se = self._aggregate_simple(
14941494
group_time_effects, influence_func_info, df, unit, precomputed
14951495
)
1496+
# Re-read df_survey in case replicate aggregation updated it
1497+
df_survey = precomputed.get("df_survey")
14961498
overall_t, overall_p, overall_ci = safe_inference(
14971499
overall_att,
14981500
overall_se,

diff_diff/staggered_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def _compute_aggregated_se_with_wif(
479479
variance, n_valid_rep = compute_replicate_if_variance(psi_total, resolved_survey)
480480
# Update precomputed survey df to reflect valid replicate count
481481
if precomputed is not None and n_valid_rep < resolved_survey.n_replicates:
482-
precomputed["survey_df"] = n_valid_rep - 1 if n_valid_rep > 1 else None
482+
precomputed["df_survey"] = n_valid_rep - 1 if n_valid_rep > 1 else None
483483
if np.isnan(variance):
484484
return np.nan
485485
return np.sqrt(max(variance, 0.0))

0 commit comments

Comments
 (0)