Skip to content

Commit fd1160a

Browse files
igerber and claude committed
Add survey support for Phase 4 estimators (ImputationDiD, TwoStageDiD, CallawaySantAnna)
- Weighted solve_logit(): survey weights enter IRLS as w_survey * mu*(1-mu)
- ImputationDiD: weighted iterative FE, survey-weighted ATT aggregation, weighted conservative variance (Theorem 3), survey df for inference
- TwoStageDiD: weighted iterative FE, weighted Stage 2 OLS, weighted GMM sandwich variance with survey weights in both stages
- CallawaySantAnna: survey-weighted regression, IPW (via weighted solve_logit), and DR methods with explicit influence functions; survey-weighted WIF in aggregation; Cholesky cache bypassed under survey weights
- Unblock TripleDifference IPW/DR with survey (weighted solve_logit now available)
- 38 new tests in test_survey_phase4.py covering all estimators + scale invariance
- Update survey-roadmap.md, REGISTRY.md with Phase 4 status and deviation notes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b75138c commit fd1160a

13 files changed

Lines changed: 1928 additions & 378 deletions

diff_diff/imputation.py

Lines changed: 203 additions & 41 deletions
Large diffs are not rendered by default.

diff_diff/imputation_results.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,8 @@ class ImputationDiDResults:
139139
bootstrap_results: Optional[ImputationBootstrapResults] = field(default=None, repr=False)
140140
# Internal: stores data needed for pretrend_test()
141141
_estimator_ref: Optional[Any] = field(default=None, repr=False)
142+
# Survey design metadata (SurveyMetadata instance from diff_diff.survey)
143+
survey_metadata: Optional[Any] = field(default=None, repr=False)
142144

143145
def __repr__(self) -> str:
144146
"""Concise string representation."""
@@ -182,6 +184,27 @@ def summary(self, alpha: Optional[float] = None) -> str:
182184
"",
183185
]
184186

187+
# Survey design info
188+
if self.survey_metadata is not None:
189+
sm = self.survey_metadata
190+
lines.extend(
191+
[
192+
"-" * 85,
193+
"Survey Design".center(85),
194+
"-" * 85,
195+
f"{'Weight type:':<30} {sm.weight_type:>10}",
196+
]
197+
)
198+
if sm.n_strata is not None:
199+
lines.append(f"{'Strata:':<30} {sm.n_strata:>10}")
200+
if sm.n_psu is not None:
201+
lines.append(f"{'PSU/Cluster:':<30} {sm.n_psu:>10}")
202+
lines.append(f"{'Effective sample size:':<30} {sm.effective_n:>10.1f}")
203+
lines.append(f"{'Design effect (DEFF):':<30} {sm.design_effect:>10.2f}")
204+
if sm.df_survey is not None:
205+
lines.append(f"{'Survey d.f.:':<30} {sm.df_survey:>10}")
206+
lines.extend(["-" * 85, ""])
207+
185208
# Overall ATT
186209
lines.extend(
187210
[

diff_diff/linalg.py

Lines changed: 27 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -390,24 +390,18 @@ def _validate_weights(weights, weight_type, n):
390390
"""Validate weights array and weight_type for solve_ols/LinearRegression."""
391391
if weight_type not in _VALID_WEIGHT_TYPES:
392392
raise ValueError(
393-
f"weight_type must be one of {_VALID_WEIGHT_TYPES}, "
394-
f"got '{weight_type}'"
393+
f"weight_type must be one of {_VALID_WEIGHT_TYPES}, " f"got '{weight_type}'"
395394
)
396395
if weights is not None:
397396
weights = np.asarray(weights, dtype=np.float64)
398397
if weights.shape[0] != n:
399-
raise ValueError(
400-
f"weights length ({weights.shape[0]}) must match "
401-
f"X rows ({n})"
402-
)
398+
raise ValueError(f"weights length ({weights.shape[0]}) must match " f"X rows ({n})")
403399
if np.any(np.isnan(weights)):
404400
raise ValueError("Weights contain NaN values")
405401
if np.any(np.isinf(weights)):
406402
raise ValueError("Weights contain Inf values")
407403
if np.any(weights < 0):
408-
raise ValueError(
409-
"Weights must be non-negative"
410-
)
404+
raise ValueError("Weights must be non-negative")
411405
if weight_type == "fweight":
412406
fractional = weights - np.round(weights)
413407
if np.any(np.abs(fractional) > 1e-10):
@@ -693,13 +687,9 @@ def solve_ols(
693687
weights=weights,
694688
weight_type=weight_type,
695689
)
696-
vcov_out = _expand_vcov_with_nan(
697-
vcov_reduced, _original_X.shape[1], kept_cols
698-
)
690+
vcov_out = _expand_vcov_with_nan(vcov_reduced, _original_X.shape[1], kept_cols)
699691
else:
700-
vcov_out = np.full(
701-
(_original_X.shape[1], _original_X.shape[1]), np.nan
702-
)
692+
vcov_out = np.full((_original_X.shape[1], _original_X.shape[1]), np.nan)
703693
else:
704694
vcov_out = _compute_robust_vcov_numpy(
705695
_original_X,
@@ -1122,6 +1112,7 @@ def solve_logit(
11221112
tol: float = 1e-8,
11231113
check_separation: bool = True,
11241114
rank_deficient_action: str = "warn",
1115+
weights: Optional[np.ndarray] = None,
11251116
) -> Tuple[np.ndarray, np.ndarray]:
11261117
"""
11271118
Fit logistic regression via IRLS (Fisher scoring).
@@ -1147,6 +1138,13 @@ def solve_logit(
11471138
- "warn": Emit warning and drop columns (default)
11481139
- "error": Raise ValueError
11491140
- "silent": Drop columns silently
1141+
weights : np.ndarray, optional
1142+
Survey/observation weights of shape (n_samples,). When provided,
1143+
the IRLS working weights become ``weights * mu * (1 - mu)``
1144+
instead of ``mu * (1 - mu)``. This produces the survey-weighted
1145+
maximum likelihood estimator, matching R's ``svyglm(family=binomial)``.
1146+
When None (default), behavior is identical to unweighted logistic
1147+
regression.
11501148
11511149
Returns
11521150
-------
@@ -1203,11 +1201,16 @@ def solve_logit(
12031201
mu = np.clip(mu, 1e-10, 1 - 1e-10)
12041202

12051203
# Working weights and working response
1206-
w = mu * (1.0 - mu)
1207-
z = eta + (y - mu) / w
1204+
w_irls = mu * (1.0 - mu)
1205+
z = eta + (y - mu) / w_irls
1206+
1207+
if weights is not None:
1208+
w_total = weights * w_irls
1209+
else:
1210+
w_total = w_irls
12081211

12091212
# Weighted least squares: solve (X'WX) beta = X'Wz
1210-
sqrt_w = np.sqrt(w)
1213+
sqrt_w = np.sqrt(w_total)
12111214
Xw = X_solve * sqrt_w[:, None]
12121215
zw = z * sqrt_w
12131216
beta_new, _, _, _ = np.linalg.lstsq(Xw, zw, rcond=None)
@@ -1593,10 +1596,7 @@ def fit(
15931596
_use_survey_vcov = self.survey_design.needs_survey_vcov
15941597
# Canonicalize weights from survey_design to ensure consistency
15951598
# between coefficient estimation and survey vcov computation
1596-
if (
1597-
self.weights is not None
1598-
and self.weights is not self.survey_design.weights
1599-
):
1599+
if self.weights is not None and self.weights is not self.survey_design.weights:
16001600
warnings.warn(
16011601
"Explicit weights= differ from survey_design.weights. "
16021602
"Using survey_design weights for both coefficient "
@@ -1609,9 +1609,7 @@ def fit(
16091609
self.weight_type = self.survey_design.weight_type
16101610

16111611
if self.weights is not None:
1612-
self.weights = _validate_weights(
1613-
self.weights, self.weight_type, X.shape[0]
1614-
)
1612+
self.weights = _validate_weights(self.weights, self.weight_type, X.shape[0])
16151613

16161614
# Inject cluster as PSU for survey variance when no PSU specified.
16171615
# Use a local variable to avoid mutating self.survey_design, which
@@ -1622,7 +1620,9 @@ def fit(
16221620
and _effective_survey_design is not None
16231621
and _use_survey_vcov
16241622
):
1625-
from diff_diff.survey import ResolvedSurveyDesign as _RSD, _inject_cluster_as_psu
1623+
from diff_diff.survey import ResolvedSurveyDesign as _RSD
1624+
from diff_diff.survey import _inject_cluster_as_psu
1625+
16261626
if isinstance(_effective_survey_design, _RSD) and _effective_survey_design.psu is None:
16271627
_effective_survey_design = _inject_cluster_as_psu(
16281628
_effective_survey_design, effective_cluster_ids
@@ -1864,9 +1864,7 @@ def get_inference(
18641864
# Use project-standard NaN-safe inference (returns all-NaN when SE <= 0)
18651865
from diff_diff.utils import safe_inference
18661866

1867-
t_stat, p_value, conf_int = safe_inference(
1868-
coef, se, alpha=effective_alpha, df=effective_df
1869-
)
1867+
t_stat, p_value, conf_int = safe_inference(coef, se, alpha=effective_alpha, df=effective_df)
18701868

18711869
return InferenceResult(
18721870
coefficient=coef,

0 commit comments

Comments
 (0)