@@ -390,24 +390,18 @@ def _validate_weights(weights, weight_type, n):
390390 """Validate weights array and weight_type for solve_ols/LinearRegression."""
391391 if weight_type not in _VALID_WEIGHT_TYPES :
392392 raise ValueError (
393- f"weight_type must be one of { _VALID_WEIGHT_TYPES } , "
394- f"got '{ weight_type } '"
393+ f"weight_type must be one of { _VALID_WEIGHT_TYPES } , " f"got '{ weight_type } '"
395394 )
396395 if weights is not None :
397396 weights = np .asarray (weights , dtype = np .float64 )
398397 if weights .shape [0 ] != n :
399- raise ValueError (
400- f"weights length ({ weights .shape [0 ]} ) must match "
401- f"X rows ({ n } )"
402- )
398+ raise ValueError (f"weights length ({ weights .shape [0 ]} ) must match " f"X rows ({ n } )" )
403399 if np .any (np .isnan (weights )):
404400 raise ValueError ("Weights contain NaN values" )
405401 if np .any (np .isinf (weights )):
406402 raise ValueError ("Weights contain Inf values" )
407403 if np .any (weights < 0 ):
408- raise ValueError (
409- "Weights must be non-negative"
410- )
404+ raise ValueError ("Weights must be non-negative" )
411405 if weight_type == "fweight" :
412406 fractional = weights - np .round (weights )
413407 if np .any (np .abs (fractional ) > 1e-10 ):
@@ -693,13 +687,9 @@ def solve_ols(
693687 weights = weights ,
694688 weight_type = weight_type ,
695689 )
696- vcov_out = _expand_vcov_with_nan (
697- vcov_reduced , _original_X .shape [1 ], kept_cols
698- )
690+ vcov_out = _expand_vcov_with_nan (vcov_reduced , _original_X .shape [1 ], kept_cols )
699691 else :
700- vcov_out = np .full (
701- (_original_X .shape [1 ], _original_X .shape [1 ]), np .nan
702- )
692+ vcov_out = np .full ((_original_X .shape [1 ], _original_X .shape [1 ]), np .nan )
703693 else :
704694 vcov_out = _compute_robust_vcov_numpy (
705695 _original_X ,
@@ -1122,6 +1112,7 @@ def solve_logit(
11221112 tol : float = 1e-8 ,
11231113 check_separation : bool = True ,
11241114 rank_deficient_action : str = "warn" ,
1115+ weights : Optional [np .ndarray ] = None ,
11251116) -> Tuple [np .ndarray , np .ndarray ]:
11261117 """
11271118 Fit logistic regression via IRLS (Fisher scoring).
@@ -1147,6 +1138,13 @@ def solve_logit(
11471138 - "warn": Emit warning and drop columns (default)
11481139 - "error": Raise ValueError
11491140 - "silent": Drop columns silently
1141+ weights : np.ndarray, optional
1142+ Survey/observation weights of shape (n_samples,). When provided,
1143+ the IRLS working weights become ``weights * mu * (1 - mu)``
1144+ instead of ``mu * (1 - mu)``. This produces the survey-weighted
1145+ maximum likelihood estimator, matching R's ``svyglm(family=binomial)``.
1146+ When None (default), behavior is identical to unweighted logistic
1147+ regression.
11501148
11511149 Returns
11521150 -------
@@ -1203,11 +1201,16 @@ def solve_logit(
12031201 mu = np .clip (mu , 1e-10 , 1 - 1e-10 )
12041202
12051203 # Working weights and working response
1206- w = mu * (1.0 - mu )
1207- z = eta + (y - mu ) / w
1204+ w_irls = mu * (1.0 - mu )
1205+ z = eta + (y - mu ) / w_irls
1206+
1207+ if weights is not None :
1208+ w_total = weights * w_irls
1209+ else :
1210+ w_total = w_irls
12081211
12091212 # Weighted least squares: solve (X'WX) beta = X'Wz
1210- sqrt_w = np .sqrt (w )
1213+ sqrt_w = np .sqrt (w_total )
12111214 Xw = X_solve * sqrt_w [:, None ]
12121215 zw = z * sqrt_w
12131216 beta_new , _ , _ , _ = np .linalg .lstsq (Xw , zw , rcond = None )
@@ -1593,10 +1596,7 @@ def fit(
15931596 _use_survey_vcov = self .survey_design .needs_survey_vcov
15941597 # Canonicalize weights from survey_design to ensure consistency
15951598 # between coefficient estimation and survey vcov computation
1596- if (
1597- self .weights is not None
1598- and self .weights is not self .survey_design .weights
1599- ):
1599+ if self .weights is not None and self .weights is not self .survey_design .weights :
16001600 warnings .warn (
16011601 "Explicit weights= differ from survey_design.weights. "
16021602 "Using survey_design weights for both coefficient "
@@ -1609,9 +1609,7 @@ def fit(
16091609 self .weight_type = self .survey_design .weight_type
16101610
16111611 if self .weights is not None :
1612- self .weights = _validate_weights (
1613- self .weights , self .weight_type , X .shape [0 ]
1614- )
1612+ self .weights = _validate_weights (self .weights , self .weight_type , X .shape [0 ])
16151613
16161614 # Inject cluster as PSU for survey variance when no PSU specified.
16171615 # Use a local variable to avoid mutating self.survey_design, which
@@ -1622,7 +1620,9 @@ def fit(
16221620 and _effective_survey_design is not None
16231621 and _use_survey_vcov
16241622 ):
1625- from diff_diff .survey import ResolvedSurveyDesign as _RSD , _inject_cluster_as_psu
1623+ from diff_diff .survey import ResolvedSurveyDesign as _RSD
1624+ from diff_diff .survey import _inject_cluster_as_psu
1625+
16261626 if isinstance (_effective_survey_design , _RSD ) and _effective_survey_design .psu is None :
16271627 _effective_survey_design = _inject_cluster_as_psu (
16281628 _effective_survey_design , effective_cluster_ids
@@ -1864,9 +1864,7 @@ def get_inference(
18641864 # Use project-standard NaN-safe inference (returns all-NaN when SE <= 0)
18651865 from diff_diff .utils import safe_inference
18661866
1867- t_stat , p_value , conf_int = safe_inference (
1868- coef , se , alpha = effective_alpha , df = effective_df
1869- )
1867+ t_stat , p_value , conf_int = safe_inference (coef , se , alpha = effective_alpha , df = effective_df )
18701868
18711869 return InferenceResult (
18721870 coefficient = coef ,
0 commit comments