2525import pandas as pd
2626from scipy import stats
2727
28+ try :
29+ from typing import TypedDict
30+ except ImportError :
31+ from typing_extensions import TypedDict
32+
2833from diff_diff .results import _get_significance_stars
2934from diff_diff .utils import compute_confidence_interval , compute_p_value
3035
3136
37+ class _PrecomputedStructures (TypedDict ):
38+ """Type definition for pre-computed structures used across LOOCV iterations.
39+
40+ These structures are computed once in `_precompute_structures()` and reused
41+ to avoid redundant computation during LOOCV and final estimation.
42+ """
43+
44+ unit_dist_matrix : np .ndarray
45+ """Pairwise unit distance matrix (n_units x n_units)."""
46+ time_dist_matrix : np .ndarray
47+ """Time distance matrix where [t, s] = |t - s| (n_periods x n_periods)."""
48+ control_mask : np .ndarray
49+ """Boolean mask for control observations (D == 0)."""
50+ treated_mask : np .ndarray
51+ """Boolean mask for treated observations (D == 1)."""
52+ treated_observations : List [Tuple [int , int ]]
53+ """List of (t, i) tuples for treated observations."""
54+ control_obs : List [Tuple [int , int ]]
55+ """List of (t, i) tuples for valid control observations."""
56+ control_unit_idx : np .ndarray
57+ """Array of control unit indices."""
58+ n_units : int
59+ """Number of units."""
60+ n_periods : int
61+ """Number of time periods."""
62+
63+
3264@dataclass
3365class TROPResults :
3466 """
@@ -327,6 +359,11 @@ class TROP:
327359 Method for variance estimation: 'bootstrap' or 'jackknife'.
328360 n_bootstrap : int, default=200
329361 Number of replications for variance estimation.
362+ max_loocv_samples : int, default=100
363+ Maximum control observations to use in LOOCV for tuning parameter
364+ selection. Subsampling is used for computational tractability as
365+ noted in the paper. Increase for more precise tuning at the cost
366+ of computational time.
330367 seed : int, optional
331368 Random seed for reproducibility.
332369
@@ -357,6 +394,23 @@ class TROP:
357394 Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
358395 """
359396
397+ # Class constants
398+ DEFAULT_LOOCV_MAX_SAMPLES : int = 100
399+ """Maximum control observations to use in LOOCV (for computational tractability).
400+
401+ As noted in the paper's footnote, LOOCV is subsampled for computational
402+ tractability. This constant controls the maximum number of control observations
403+ used in each LOOCV evaluation. Increase for more precise tuning at the cost
404+ of computational time.
405+ """
406+
407+ CONVERGENCE_TOL_SVD : float = 1e-10
408+ """Tolerance for singular value truncation in soft-thresholding.
409+
410+ Singular values below this threshold after soft-thresholding are treated
411+ as zero to improve numerical stability.
412+ """
413+
360414 def __init__ (
361415 self ,
362416 lambda_time_grid : Optional [List [float ]] = None ,
@@ -367,6 +421,7 @@ def __init__(
367421 alpha : float = 0.05 ,
368422 variance_method : str = 'bootstrap' ,
369423 n_bootstrap : int = 200 ,
424+ max_loocv_samples : int = 100 ,
370425 seed : Optional [int ] = None ,
371426 ):
372427 # Default grids from paper
@@ -379,6 +434,7 @@ def __init__(
379434 self .alpha = alpha
380435 self .variance_method = variance_method
381436 self .n_bootstrap = n_bootstrap
437+ self .max_loocv_samples = max_loocv_samples
382438 self .seed = seed
383439
384440 # Validate parameters
@@ -395,7 +451,7 @@ def __init__(
395451 self ._optimal_lambda : Optional [Tuple [float , float , float ]] = None
396452
397453 # Pre-computed structures (set during fit)
398- self ._precomputed : Optional [Dict [ str , Any ] ] = None
454+ self ._precomputed : Optional [_PrecomputedStructures ] = None
399455
400456 def _precompute_structures (
401457 self ,
@@ -404,7 +460,7 @@ def _precompute_structures(
404460 control_unit_idx : np .ndarray ,
405461 n_units : int ,
406462 n_periods : int ,
407- ) -> Dict [ str , Any ] :
463+ ) -> _PrecomputedStructures :
408464 """
409465 Pre-compute data structures that are reused across LOOCV and estimation.
410466
@@ -428,7 +484,7 @@ def _precompute_structures(
428484
429485 Returns
430486 -------
431- dict
487+ _PrecomputedStructures
432488 Pre-computed structures for efficient reuse.
433489 """
434490 # Compute pairwise unit distances (for all observation-specific weights)
@@ -481,6 +537,9 @@ def _compute_all_unit_distances(
481537 observations, which provides a good approximation. The exact per-observation
482538 distances are refined when needed.
483539
540+ Uses vectorized numpy operations with masked arrays for O(n²) complexity
541+ but with highly optimized inner loops via numpy/BLAS.
542+
484543 Parameters
485544 ----------
486545 Y : np.ndarray
@@ -500,27 +559,36 @@ def _compute_all_unit_distances(
500559 # Mask for valid observations: control periods only (D=0), non-NaN
501560 valid_mask = (D == 0 ) & ~ np .isnan (Y )
502561
503- # Initialize distance matrix
504- dist_matrix = np .full ((n_units , n_units ), np .inf )
562+ # Replace invalid values with NaN for masked computation
563+ Y_masked = np .where (valid_mask , Y , np .nan )
564+
565+ # Transpose to (n_units, n_periods) for easier broadcasting
566+ Y_T = Y_masked .T # (n_units, n_periods)
567+
568+ # Compute pairwise squared differences using broadcasting
569+ # Y_T[:, np.newaxis, :] has shape (n_units, 1, n_periods)
570+ # Y_T[np.newaxis, :, :] has shape (1, n_units, n_periods)
571+ # diff has shape (n_units, n_units, n_periods)
572+ diff = Y_T [:, np .newaxis , :] - Y_T [np .newaxis , :, :]
573+ sq_diff = diff ** 2
574+
575+ # Count valid (non-NaN) observations per pair
576+ # A difference is valid only if both units have valid observations
577+ valid_diff = ~ np .isnan (sq_diff )
578+ n_valid = np .sum (valid_diff , axis = 2 ) # (n_units, n_units)
505579
506- # Compute pairwise distances using vectorized operations
507- # Y has shape (n_periods, n_units)
508- # We want sqrt(mean((Y[:, i] - Y[:, j])^2)) for valid periods
580+ # Compute sum of squared differences (treating NaN as 0)
581+ sq_diff_sum = np .nansum (sq_diff , axis = 2 ) # (n_units, n_units)
509582
510- # For each pair of units, find periods where both are valid
511- for i in range (n_units ):
512- valid_i = valid_mask [:, i ]
513- for j in range (i , n_units ):
514- valid_j = valid_mask [:, j ]
515- both_valid = valid_i & valid_j
583+ # Compute RMSE distance: sqrt(sum / n_valid)
584+ # Avoid division by zero
585+ with np .errstate (divide = 'ignore' , invalid = 'ignore' ):
586+ dist_matrix = np .sqrt (sq_diff_sum / n_valid )
516587
517- if np .any (both_valid ):
518- sq_diff = (Y [both_valid , i ] - Y [both_valid , j ]) ** 2
519- dist = np .sqrt (np .mean (sq_diff ))
520- dist_matrix [i , j ] = dist
521- dist_matrix [j , i ] = dist
588+ # Set pairs with no valid observations to inf
589+ dist_matrix = np .where (n_valid > 0 , dist_matrix , np .inf )
522590
523- # Set diagonal to 0
591+ # Ensure diagonal is 0 (same unit distance)
524592 np .fill_diagonal (dist_matrix , 0.0 )
525593
526594 return dist_matrix
@@ -970,7 +1038,7 @@ def _soft_threshold_svd(
9701038 s_thresh = np .maximum (s - threshold , 0 )
9711039
9721040 # Use truncated reconstruction with only non-zero singular values
973- nonzero_mask = s_thresh > 1e-10
1041+ nonzero_mask = s_thresh > self . CONVERGENCE_TOL_SVD
9741042 if not np .any (nonzero_mask ):
9751043 return np .zeros_like (M )
9761044
@@ -1065,35 +1133,39 @@ def _estimate_model(
10651133 # Replace NaN in Y with 0 for computation (mask handles exclusion)
10661134 Y_safe = np .where (np .isnan (Y ), 0.0 , Y )
10671135
1068- # Alternating minimization
1136+ # Alternating minimization following Algorithm 1 (page 9)
1137+ # Minimize: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
10691138 for _ in range (self .max_iter ):
10701139 alpha_old = alpha .copy ()
10711140 beta_old = beta .copy ()
10721141 L_old = L .copy ()
10731142
1074- # Step 1: Update α and β (weighted means) - VECTORIZED
1143+ # Step 1: Update α and β (weighted least squares)
1144+ # Following Equation 2 (page 7), fix L and solve for α, β
10751145 # R = Y - L (residual without fixed effects)
10761146 R = Y_safe - L
10771147
1078- # Alpha update: α_i = Σ_t W_{ti} (R_{ti} - β_t) / Σ_t W_{ti}
1079- # Compute weighted sum of (R - β) per unit
1148+ # Alpha update (unit fixed effects):
1149+ # α_i = argmin_α Σ_t W_{ti}(R_{ti} - α - β_t)²
1150+ # Solution: α_i = Σ_t W_{ti}(R_{ti} - β_t) / Σ_t W_{ti}
10801151 R_minus_beta = R - beta [:, np .newaxis ] # (n_periods, n_units)
10811152 weighted_R_minus_beta = W_masked * R_minus_beta
10821153 alpha_numerator = np .sum (weighted_R_minus_beta , axis = 0 ) # (n_units,)
10831154 alpha = np .where (unit_has_obs , alpha_numerator / safe_unit_denom , 0.0 )
10841155
1085- # Beta update: β_t = Σ_i W_{ti} (R_{ti} - α_i) / Σ_i W_{ti}
1086- # Compute weighted sum of (R - α) per period
1156+ # Beta update (time fixed effects):
1157+ # β_t = argmin_β Σ_i W_{ti}(R_{ti} - α_i - β)²
1158+ # Solution: β_t = Σ_i W_{ti}(R_{ti} - α_i) / Σ_i W_{ti}
10871159 R_minus_alpha = R - alpha [np .newaxis , :] # (n_periods, n_units)
10881160 weighted_R_minus_alpha = W_masked * R_minus_alpha
10891161 beta_numerator = np .sum (weighted_R_minus_alpha , axis = 1 ) # (n_periods,)
10901162 beta = np .where (time_has_obs , beta_numerator / safe_time_denom , 0.0 )
10911163
1092- # Step 2: Update L with nuclear norm penalty - VECTORIZED
1093- # R_for_L = Y - α - β where valid, else L (impute missing )
1094- # Vectorized: broadcast alpha and beta
1164+ # Step 2: Update L with nuclear norm penalty
1165+ # Following Equation 2 (page 7): L = prox_{λ_nn||·||_*}( Y - α - β)
1166+ # The proximal operator for nuclear norm is soft-thresholding of SVD
10951167 R_for_L = Y_safe - alpha [np .newaxis , :] - beta [:, np .newaxis ]
1096- # Impute invalid observations with current L
1168+ # Impute invalid observations with current L for stable SVD
10971169 R_for_L = np .where (valid_mask , R_for_L , L )
10981170
10991171 L = self ._soft_threshold_svd (R_for_L , lambda_nn )
@@ -1168,7 +1240,7 @@ def _loocv_score_obs_specific(
11681240
11691241 # Subsample for computational tractability (as noted in paper's footnote)
11701242 rng = np .random .default_rng (self .seed )
1171- max_loocv = min (100 , len (control_obs ))
1243+ max_loocv = min (self . max_loocv_samples , len (control_obs ))
11721244 if len (control_obs ) > max_loocv :
11731245 indices = rng .choice (len (control_obs ), size = max_loocv , replace = False )
11741246 control_obs = [control_obs [idx ] for idx in indices ]
@@ -1463,6 +1535,7 @@ def get_params(self) -> Dict[str, Any]:
14631535 "alpha" : self .alpha ,
14641536 "variance_method" : self .variance_method ,
14651537 "n_bootstrap" : self .n_bootstrap ,
1538+ "max_loocv_samples" : self .max_loocv_samples ,
14661539 "seed" : self .seed ,
14671540 }
14681541
0 commit comments