@@ -63,7 +63,11 @@ class _PrecomputedStructures(TypedDict):
6363 control_obs : List [Tuple [int , int ]]
6464 """List of (t, i) tuples for valid control observations."""
6565 control_unit_idx : np .ndarray
66- """Array of control unit indices."""
66+ """Array of never-treated unit indices (for backward compatibility)."""
67+ D : np .ndarray
68+ """Treatment indicator matrix (n_periods x n_units) for dynamic control sets."""
69+ Y : np .ndarray
70+ """Outcome matrix (n_periods x n_units)."""
6771 n_units : int
6872 """Number of units."""
6973 n_periods : int
@@ -529,6 +533,8 @@ def _precompute_structures(
529533 "treated_observations" : treated_observations ,
530534 "control_obs" : control_obs ,
531535 "control_unit_idx" : control_unit_idx ,
536+ "D" : D ,
537+ "Y" : Y ,
532538 "n_units" : n_units ,
533539 "n_periods" : n_periods ,
534540 }
@@ -778,16 +784,14 @@ def fit(
778784 # Prepare inputs for Rust function
779785 control_mask_u8 = control_mask .astype (np .uint8 )
780786 time_dist_matrix = self ._precomputed ["time_dist_matrix" ].astype (np .int64 )
781- unit_dist_matrix = self ._precomputed ["unit_dist_matrix" ]
782- control_unit_idx_i64 = control_unit_idx .astype (np .int64 )
783787
784788 lambda_time_arr = np .array (self .lambda_time_grid , dtype = np .float64 )
785789 lambda_unit_arr = np .array (self .lambda_unit_grid , dtype = np .float64 )
786790 lambda_nn_arr = np .array (self .lambda_nn_grid , dtype = np .float64 )
787791
788792 best_lt , best_lu , best_ln , best_score = _rust_loocv_grid_search (
789- Y , D .astype (np .float64 ), control_mask_u8 , control_unit_idx_i64 ,
790- unit_dist_matrix , time_dist_matrix ,
793+ Y , D .astype (np .float64 ), control_mask_u8 ,
794+ time_dist_matrix ,
791795 lambda_time_arr , lambda_unit_arr , lambda_nn_arr ,
792796 self .max_loocv_samples , self .max_iter , self .tol ,
793797 self .seed if self .seed is not None else 0
@@ -953,10 +957,16 @@ def _compute_observation_weights(
953957 """
954958 Compute observation-specific weight matrix for treated observation (i, t).
955959
956- Following the paper's Algorithm 2 (page 27):
960+ Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7) :
957961 - Time weights θ_s^{i,t} = exp(-λ_time × |t - s|)
958962 - Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i))
959963
964+ IMPORTANT (Issue A fix): The paper's objective sums over ALL observations
965+ where (1 - W_js) is non-zero, which includes pre-treatment observations of
966+ eventually-treated units since W_js = 0 for those. This method computes
967+ weights for ALL units where D[t, j] = 0 at the target period, not just
968+ never-treated units.
969+
960970 Uses pre-computed structures when available for efficiency.
961971
962972 Parameters
@@ -974,7 +984,8 @@ def _compute_observation_weights(
974984 lambda_unit : float
975985 Unit weight decay parameter.
976986 control_unit_idx : np.ndarray
977- Indices of control units.
987+ Indices of never-treated units (for backward compatibility, but not
988+ used for weight computation - we use D matrix directly).
978989 n_units : int
979990 Number of units.
980991 n_periods : int
@@ -991,21 +1002,30 @@ def _compute_observation_weights(
9911002 # time_dist_matrix[t, s] = |t - s|
9921003 time_weights = np .exp (- lambda_time * self ._precomputed ["time_dist_matrix" ][t , :])
9931004
994- # Unit weights from pre-computed unit distance matrix
1005+ # Unit weights - computed for ALL units where D[t, j] = 0
1006+ # (Issue A fix: includes pre-treatment obs of eventually-treated units)
9951007 unit_weights = np .zeros (n_units )
1008+ D_stored = self ._precomputed ["D" ]
1009+ Y_stored = self ._precomputed ["Y" ]
1010+
1011+ # Valid control units at time t: D[t, j] == 0
1012+ valid_control_at_t = D_stored [t , :] == 0
9961013
9971014 if lambda_unit == 0 :
9981015 # Uniform weights when lambda_unit = 0
999- unit_weights [:] = 1.0
1016+ # All units not treated at time t get weight 1
1017+ unit_weights [valid_control_at_t ] = 1.0
10001018 else :
1001- # Use pre-computed distances: unit_dist_matrix[j, i] = dist(j, i)
1002- dist_matrix = self ._precomputed ["unit_dist_matrix" ]
1003- for j in control_unit_idx :
1004- dist = dist_matrix [j , i ]
1005- if np .isinf (dist ):
1006- unit_weights [j ] = 0.0
1007- else :
1008- unit_weights [j ] = np .exp (- lambda_unit * dist )
1019+ # Use observation-specific distances with target period excluded
1020+ # (Issue B fix: compute exact per-observation distance)
1021+ for j in range (n_units ):
1022+ if valid_control_at_t [j ] and j != i :
1023+ # Compute distance excluding target period t
1024+ dist = self ._compute_unit_distance_for_obs (Y_stored , D_stored , j , i , t )
1025+ if np .isinf (dist ):
1026+ unit_weights [j ] = 0.0
1027+ else :
1028+ unit_weights [j ] = np .exp (- lambda_unit * dist )
10091029
10101030 # Treated unit i gets weight 1
10111031 unit_weights [i ] = 1.0
@@ -1018,19 +1038,25 @@ def _compute_observation_weights(
10181038 dist_time = np .abs (np .arange (n_periods ) - t )
10191039 time_weights = np .exp (- lambda_time * dist_time )
10201040
1021- # Unit distance: pairwise RMSE from each control j to treated i
1041+ # Unit weights - computed for ALL units where D[t, j] = 0
1042+ # (Issue A fix: includes pre-treatment obs of eventually-treated units)
10221043 unit_weights = np .zeros (n_units )
10231044
1045+ # Valid control units at time t: D[t, j] == 0
1046+ valid_control_at_t = D [t , :] == 0
1047+
10241048 if lambda_unit == 0 :
10251049 # Uniform weights when lambda_unit = 0
1026- unit_weights [: ] = 1.0
1050+ unit_weights [valid_control_at_t ] = 1.0
10271051 else :
1028- for j in control_unit_idx :
1029- dist = self ._compute_unit_distance_for_obs (Y , D , j , i , t )
1030- if np .isinf (dist ):
1031- unit_weights [j ] = 0.0
1032- else :
1033- unit_weights [j ] = np .exp (- lambda_unit * dist )
1052+ for j in range (n_units ):
1053+ if valid_control_at_t [j ] and j != i :
1054+ # Compute distance excluding target period t (Issue B fix)
1055+ dist = self ._compute_unit_distance_for_obs (Y , D , j , i , t )
1056+ if np .isinf (dist ):
1057+ unit_weights [j ] = 0.0
1058+ else :
1059+ unit_weights [j ] = np .exp (- lambda_unit * dist )
10341060
10351061 # Treated unit i gets weight 1 (or could be omitted since we fit on controls)
10361062 # We include treated unit's own observation for model fitting
@@ -1102,6 +1128,101 @@ def _soft_threshold_svd(
11021128
11031129 return result
11041130
1131+ def _weighted_nuclear_norm_solve (
1132+ self ,
1133+ Y : np .ndarray ,
1134+ W : np .ndarray ,
1135+ L_init : np .ndarray ,
1136+ alpha : np .ndarray ,
1137+ beta : np .ndarray ,
1138+ lambda_nn : float ,
1139+ max_inner_iter : int = 20 ,
1140+ ) -> np .ndarray :
1141+ """
1142+ Solve weighted nuclear norm problem using iterative weighted soft-impute.
1143+
1144+ Issue C fix: Implements the weighted nuclear norm optimization from the
1145+ paper's Equation 2 (page 7). The full objective is:
1146+ min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*
1147+
1148+ This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
1149+ L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))
1150+
1151+ where W ⊙ denotes element-wise multiplication with normalized weights.
1152+
1153+ IMPORTANT: For observations with W=0 (treated observations), we keep
1154+ L values from the previous iteration rather than setting L = R, which
1155+ would absorb the treatment effect.
1156+
1157+ Parameters
1158+ ----------
1159+ Y : np.ndarray
1160+ Outcome matrix (n_periods x n_units).
1161+ W : np.ndarray
1162+ Weight matrix (n_periods x n_units), non-negative. W=0 indicates
1163+ observations that should not be used for fitting (treated obs).
1164+ L_init : np.ndarray
1165+ Initial estimate of L matrix.
1166+ alpha : np.ndarray
1167+ Current unit fixed effects estimate.
1168+ beta : np.ndarray
1169+ Current time fixed effects estimate.
1170+ lambda_nn : float
1171+ Nuclear norm regularization parameter.
1172+ max_inner_iter : int, default=20
1173+ Maximum inner iterations for the proximal algorithm.
1174+
1175+ Returns
1176+ -------
1177+ np.ndarray
1178+ Updated L matrix estimate.
1179+ """
1180+ # Compute target residual R = Y - α - β
1181+ R = Y - alpha [np .newaxis , :] - beta [:, np .newaxis ]
1182+
1183+ # Handle invalid values
1184+ R = np .nan_to_num (R , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
1185+
1186+ # For observations with W=0 (treated obs), keep L_init instead of R
1187+ # This prevents L from absorbing the treatment effect
1188+ valid_obs_mask = W > 0
1189+ R_masked = np .where (valid_obs_mask , R , L_init )
1190+
1191+ if lambda_nn <= 0 :
1192+ # No regularization - just return masked residual
1193+ # Use soft-thresholding with threshold=0 which returns the input
1194+ return R_masked
1195+
1196+ # Normalize weights so max is 1 (for step size stability)
1197+ W_max = np .max (W )
1198+ if W_max > 0 :
1199+ W_norm = W / W_max
1200+ else :
1201+ W_norm = W
1202+
1203+ # Initialize L
1204+ L = L_init .copy ()
1205+
1206+ # Proximal gradient iteration with weighted soft-impute
1207+ # This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_*
1208+ # Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k))
1209+ # where η is the step size (we use η = 1 with normalized weights)
1210+ for _ in range (max_inner_iter ):
1211+ L_old = L .copy ()
1212+
1213+ # Gradient step: L_k + W ⊙ (R - L_k)
1214+ # For W=0 observations, this keeps L_k unchanged
1215+ gradient_step = L + W_norm * (R_masked - L )
1216+
1217+ # Proximal step: soft-threshold singular values
1218+ L = self ._soft_threshold_svd (gradient_step , lambda_nn )
1219+
1220+ # Check convergence
1221+ if np .max (np .abs (L - L_old )) < self .tol :
1222+ break
1223+
1224+ return L
1225+
11051226 def _estimate_model (
11061227 self ,
11071228 Y : np .ndarray ,
@@ -1205,14 +1326,13 @@ def _estimate_model(
12051326 beta_numerator = np .sum (weighted_R_minus_alpha , axis = 1 ) # (n_periods,)
12061327 beta = np .where (time_has_obs , beta_numerator / safe_time_denom , 0.0 )
12071328
1208- # Step 2: Update L with nuclear norm penalty
1209- # Following Equation 2 (page 7): L = prox_{λ_nn||·||_*}(Y - α - β)
1210- # The proximal operator for nuclear norm is soft-thresholding of SVD
1211- R_for_L = Y_safe - alpha [np .newaxis , :] - beta [:, np .newaxis ]
1212- # Impute invalid observations with current L for stable SVD
1213- R_for_L = np .where (valid_mask , R_for_L , L )
1214-
1215- L = self ._soft_threshold_svd (R_for_L , lambda_nn )
1329+ # Step 2: Update L with weighted nuclear norm penalty
1330+ # Issue C fix: Use weighted soft-impute to properly account for
1331+ # observation weights in the nuclear norm optimization.
1332+ # Following Equation 2 (page 7): min_L Σ W_{ti}(Y - α - β - L)² + λ||L||_*
1333+ L = self ._weighted_nuclear_norm_solve (
1334+ Y_safe , W_masked , L , alpha , beta , lambda_nn , max_inner_iter = 10
1335+ )
12161336
12171337 # Check convergence
12181338 alpha_diff = np .max (np .abs (alpha - alpha_old ))
@@ -1388,21 +1508,15 @@ def _bootstrap_variance(
13881508 # Try Rust backend for parallel bootstrap (5-15x speedup)
13891509 if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
13901510 and self ._precomputed is not None and Y is not None
1391- and D is not None and control_unit_idx is not None ):
1511+ and D is not None ):
13921512 try :
1393- # Prepare inputs
1394- treated_observations = self ._precomputed ["treated_observations" ]
1395- treated_t = np .array ([t for t , i in treated_observations ], dtype = np .int64 )
1396- treated_i = np .array ([i for t , i in treated_observations ], dtype = np .int64 )
13971513 control_mask = self ._precomputed ["control_mask" ]
1514+ time_dist_matrix = self ._precomputed ["time_dist_matrix" ].astype (np .int64 )
13981515
13991516 bootstrap_estimates , se = _rust_bootstrap_trop_variance (
14001517 Y , D .astype (np .float64 ),
14011518 control_mask .astype (np .uint8 ),
1402- control_unit_idx .astype (np .int64 ),
1403- treated_t , treated_i ,
1404- self ._precomputed ["unit_dist_matrix" ],
1405- self ._precomputed ["time_dist_matrix" ].astype (np .int64 ),
1519+ time_dist_matrix ,
14061520 lambda_time , lambda_unit , lambda_nn ,
14071521 self .n_bootstrap , self .max_iter , self .tol ,
14081522 self .seed if self .seed is not None else 0
@@ -1422,14 +1536,38 @@ def _bootstrap_variance(
14221536
14231537 # Python implementation (fallback)
14241538 rng = np .random .default_rng (self .seed )
1425- all_units = data [unit ].unique ()
1426- n_units_data = len (all_units )
1539+
1540+ # Issue D fix: Stratified bootstrap sampling
1541+ # Paper's Algorithm 3 (page 27) specifies sampling N_0 control rows
1542+ # and N_1 treated rows separately to preserve treatment ratio
1543+ unit_ever_treated = data .groupby (unit )[treatment ].max ()
1544+ treated_units = np .array (unit_ever_treated [unit_ever_treated == 1 ].index )
1545+ control_units = np .array (unit_ever_treated [unit_ever_treated == 0 ].index )
1546+
1547+ n_treated_units = len (treated_units )
1548+ n_control_units = len (control_units )
14271549
14281550 bootstrap_estimates_list = []
14291551
14301552 for _ in range (self .n_bootstrap ):
1431- # Sample units with replacement
1432- sampled_units = rng .choice (all_units , size = n_units_data , replace = True )
1553+ # Stratified sampling: sample control and treated units separately
1554+ # This preserves the treatment ratio in each bootstrap sample
1555+ if n_control_units > 0 :
1556+ sampled_control = rng .choice (
1557+ control_units , size = n_control_units , replace = True
1558+ )
1559+ else :
1560+ sampled_control = np .array ([], dtype = control_units .dtype )
1561+
1562+ if n_treated_units > 0 :
1563+ sampled_treated = rng .choice (
1564+ treated_units , size = n_treated_units , replace = True
1565+ )
1566+ else :
1567+ sampled_treated = np .array ([], dtype = treated_units .dtype )
1568+
1569+ # Combine stratified samples
1570+ sampled_units = np .concatenate ([sampled_control , sampled_treated ])
14331571
14341572 # Create bootstrap sample with unique unit IDs
14351573 boot_data = pd .concat ([
0 commit comments