Skip to content

Commit 67a87e3

Browse files
igerberclaude
andcommitted
Fix ContinuousDiD bread normalization and panel-level survey metadata from PR #226 review (round 4)
- ContinuousDiD: normalize WLS bread by weighted treated mass (not raw count) for consistency with downstream IF score denominators; fixes ACRT_glob/ATT(d)/ACRT(d) survey SEs when subgroup-average weights differ - ContinuousDiD/EfficientDiD: recompute survey_metadata from unit-level ResolvedSurveyDesign so reported effective_n/n_psu/df_survey match the inference actually run (not the panel-level overcount) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 1166024 commit 67a87e3

2 files changed

Lines changed: 32 additions & 4 deletions

File tree

diff_diff/continuous_did.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,15 @@ def fit(
512512
# Survey df for t-distribution inference (unit-level, not panel-level)
513513
_survey_df = analytic.get("df_survey")
514514

515+
# Recompute survey_metadata from unit-level design so reported
516+
# effective_n/n_psu/df_survey match the inference actually run
517+
_unit_resolved = analytic.get("unit_resolved")
518+
if _unit_resolved is not None:
519+
from diff_diff.survey import compute_survey_metadata
520+
521+
raw_w_unit = _unit_resolved.weights
522+
survey_metadata = compute_survey_metadata(_unit_resolved, raw_w_unit)
523+
515524
overall_att_t, overall_att_p, overall_att_ci = safe_inference(
516525
overall_att, overall_att_se, self.alpha, df=_survey_df
517526
)
@@ -948,11 +957,14 @@ def _compute_dose_response_gt(
948957
# Store bootstrap info for influence function computation
949958
# bread = (Psi'WPsi / n_treated)^{-1} when survey, (Psi'Psi / n_treated)^{-1} otherwise
950959
if w_treated is not None:
960+
w_treated_sum = float(np.sum(w_treated))
951961
PtWP = Psi.T @ (Psi * w_treated[:, np.newaxis])
962+
# Normalize bread by weighted mass (not raw count) for consistency
963+
# with downstream IF score denominators that also use weighted mass
952964
try:
953-
bread = np.linalg.inv(PtWP / n_treated)
965+
bread = np.linalg.inv(PtWP / w_treated_sum)
954966
except np.linalg.LinAlgError:
955-
bread = np.linalg.pinv(PtWP / n_treated)
967+
bread = np.linalg.pinv(PtWP / w_treated_sum)
956968
else:
957969
PtP = Psi.T @ Psi
958970
try:
@@ -1220,7 +1232,7 @@ def _compute_analytical_se(
12201232
att_d_se = np.sqrt(np.sum(if_att_d**2, axis=0))
12211233
acrt_d_se = np.sqrt(np.sum(if_acrt_d**2, axis=0))
12221234

1223-
# Return unit-level survey df when available (for t-distribution inference)
1235+
# Return unit-level survey df and resolved design for metadata recomputation
12241236
unit_df_survey = unit_resolved.df_survey if resolved_survey is not None else None
12251237

12261238
return {
@@ -1229,6 +1241,7 @@ def _compute_analytical_se(
12291241
"att_d_se": att_d_se,
12301242
"acrt_d_se": acrt_d_se,
12311243
"df_survey": unit_df_survey,
1244+
"unit_resolved": unit_resolved if resolved_survey is not None else None,
12321245
}
12331246

12341247
def _run_bootstrap(

diff_diff/efficient_did.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -667,11 +667,26 @@ def fit(
667667
omega_condition_numbers=stored_cond if stored_cond else None,
668668
influence_functions=None, # can store full EIF matrix if needed
669669
bootstrap_results=bootstrap_results,
670-
survey_metadata=survey_metadata,
670+
survey_metadata=(
671+
self._recompute_unit_survey_metadata(survey_metadata)
672+
if survey_metadata is not None
673+
else None
674+
),
671675
)
672676
self.is_fitted_ = True
673677
return self.results_
674678

679+
def _recompute_unit_survey_metadata(self, panel_metadata):
680+
"""Recompute survey metadata from unit-level design if available."""
681+
if self._unit_resolved_survey is not None:
682+
from diff_diff.survey import compute_survey_metadata
683+
684+
return compute_survey_metadata(
685+
self._unit_resolved_survey,
686+
self._unit_resolved_survey.weights,
687+
)
688+
return panel_metadata
689+
675690
# -- Survey SE helpers ----------------------------------------------------
676691

677692
def _compute_survey_eif_se(self, eif_vals: np.ndarray) -> float:

0 commit comments

Comments
 (0)