
Commit 77de23c

igerber and claude committed
Address PR review feedback: vectorize distances, add TypedDict, extract constants
Based on PR #76 code review recommendations:

1. **Issue 1.2 - Vectorize unit distance computation**: Replace the nested Python loop in `_compute_all_unit_distances()` with fully vectorized numpy operations using broadcasting. Eliminates the O(n²) Python loop in favor of optimized numpy/BLAS operations.
2. **Add TypedDict for `_precomputed` structure**: Add a `_PrecomputedStructures` TypedDict for better IDE support and documentation of the pre-computed data structures.
3. **Extract magic numbers to class constants**:
   - `DEFAULT_LOOCV_MAX_SAMPLES = 100` (configurable via the `max_loocv_samples` param)
   - `CONVERGENCE_TOL_SVD = 1e-10` (SVD truncation tolerance)
4. **Add `max_loocv_samples` parameter**: Make the LOOCV subsample size configurable instead of hardcoded to 100. Updated `__init__`, `get_params`, and docstrings.
5. **Add algorithm reference comments**: Enhanced the alternating minimization documentation with paper equation references (Equation 2, Algorithm 1).

All 33 tests pass.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent d17dc87 commit 77de23c

1 file changed

Lines changed: 105 additions & 32 deletions

diff_diff/trop.py

```diff
@@ -25,10 +25,42 @@
 import pandas as pd
 from scipy import stats
 
+try:
+    from typing import TypedDict
+except ImportError:
+    from typing_extensions import TypedDict
+
 from diff_diff.results import _get_significance_stars
 from diff_diff.utils import compute_confidence_interval, compute_p_value
 
 
+class _PrecomputedStructures(TypedDict):
+    """Type definition for pre-computed structures used across LOOCV iterations.
+
+    These structures are computed once in `_precompute_structures()` and reused
+    to avoid redundant computation during LOOCV and final estimation.
+    """
+
+    unit_dist_matrix: np.ndarray
+    """Pairwise unit distance matrix (n_units x n_units)."""
+    time_dist_matrix: np.ndarray
+    """Time distance matrix where [t, s] = |t - s| (n_periods x n_periods)."""
+    control_mask: np.ndarray
+    """Boolean mask for control observations (D == 0)."""
+    treated_mask: np.ndarray
+    """Boolean mask for treated observations (D == 1)."""
+    treated_observations: List[Tuple[int, int]]
+    """List of (t, i) tuples for treated observations."""
+    control_obs: List[Tuple[int, int]]
+    """List of (t, i) tuples for valid control observations."""
+    control_unit_idx: np.ndarray
+    """Array of control unit indices."""
+    n_units: int
+    """Number of units."""
+    n_periods: int
+    """Number of time periods."""
+
+
 @dataclass
 class TROPResults:
     """
```

```diff
@@ -327,6 +359,11 @@ class TROP:
         Method for variance estimation: 'bootstrap' or 'jackknife'.
     n_bootstrap : int, default=200
         Number of replications for variance estimation.
+    max_loocv_samples : int, default=100
+        Maximum control observations to use in LOOCV for tuning parameter
+        selection. Subsampling is used for computational tractability as
+        noted in the paper. Increase for more precise tuning at the cost
+        of computational time.
     seed : int, optional
         Random seed for reproducibility.
 
```

```diff
@@ -357,6 +394,23 @@ class TROP:
     Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
     """
 
+    # Class constants
+    DEFAULT_LOOCV_MAX_SAMPLES: int = 100
+    """Maximum control observations to use in LOOCV (for computational tractability).
+
+    As noted in the paper's footnote, LOOCV is subsampled for computational
+    tractability. This constant controls the maximum number of control observations
+    used in each LOOCV evaluation. Increase for more precise tuning at the cost
+    of computational time.
+    """
+
+    CONVERGENCE_TOL_SVD: float = 1e-10
+    """Tolerance for singular value truncation in soft-thresholding.
+
+    Singular values below this threshold after soft-thresholding are treated
+    as zero to improve numerical stability.
+    """
+
     def __init__(
         self,
         lambda_time_grid: Optional[List[float]] = None,
```

```diff
@@ -367,6 +421,7 @@ def __init__(
         alpha: float = 0.05,
         variance_method: str = 'bootstrap',
         n_bootstrap: int = 200,
+        max_loocv_samples: int = 100,
         seed: Optional[int] = None,
     ):
         # Default grids from paper
```

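A short usage sketch of the new constructor argument (the values below are illustrative; all other parameters keep their defaults):

```python
from diff_diff.trop import TROP

# Hypothetical example: allow up to 500 control observations per LOOCV evaluation
# for more precise tuning-parameter selection, at extra computational cost.
model = TROP(max_loocv_samples=500, seed=42)
```
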
```diff
@@ -379,6 +434,7 @@ def __init__(
         self.alpha = alpha
         self.variance_method = variance_method
         self.n_bootstrap = n_bootstrap
+        self.max_loocv_samples = max_loocv_samples
         self.seed = seed
 
         # Validate parameters
```

```diff
@@ -395,7 +451,7 @@ def __init__(
         self._optimal_lambda: Optional[Tuple[float, float, float]] = None
 
         # Pre-computed structures (set during fit)
-        self._precomputed: Optional[Dict[str, Any]] = None
+        self._precomputed: Optional[_PrecomputedStructures] = None
 
     def _precompute_structures(
         self,
```

```diff
@@ -404,7 +460,7 @@ def _precompute_structures(
         control_unit_idx: np.ndarray,
         n_units: int,
         n_periods: int,
-    ) -> Dict[str, Any]:
+    ) -> _PrecomputedStructures:
         """
         Pre-compute data structures that are reused across LOOCV and estimation.
 
```

```diff
@@ -428,7 +484,7 @@ def _precompute_structures(
 
         Returns
         -------
-        dict
+        _PrecomputedStructures
             Pre-computed structures for efficient reuse.
         """
         # Compute pairwise unit distances (for all observation-specific weights)
```

```diff
@@ -481,6 +537,9 @@ def _compute_all_unit_distances(
         observations, which provides a good approximation. The exact per-observation
         distances are refined when needed.
 
+        Uses vectorized numpy operations with masked arrays for O(n²) complexity
+        but with highly optimized inner loops via numpy/BLAS.
+
         Parameters
         ----------
         Y : np.ndarray
```

```diff
@@ -500,27 +559,36 @@ def _compute_all_unit_distances(
         # Mask for valid observations: control periods only (D=0), non-NaN
         valid_mask = (D == 0) & ~np.isnan(Y)
 
-        # Initialize distance matrix
-        dist_matrix = np.full((n_units, n_units), np.inf)
+        # Replace invalid values with NaN for masked computation
+        Y_masked = np.where(valid_mask, Y, np.nan)
+
+        # Transpose to (n_units, n_periods) for easier broadcasting
+        Y_T = Y_masked.T  # (n_units, n_periods)
+
+        # Compute pairwise squared differences using broadcasting
+        # Y_T[:, np.newaxis, :] has shape (n_units, 1, n_periods)
+        # Y_T[np.newaxis, :, :] has shape (1, n_units, n_periods)
+        # diff has shape (n_units, n_units, n_periods)
+        diff = Y_T[:, np.newaxis, :] - Y_T[np.newaxis, :, :]
+        sq_diff = diff ** 2
+
+        # Count valid (non-NaN) observations per pair
+        # A difference is valid only if both units have valid observations
+        valid_diff = ~np.isnan(sq_diff)
+        n_valid = np.sum(valid_diff, axis=2)  # (n_units, n_units)
 
-        # Compute pairwise distances using vectorized operations
-        # Y has shape (n_periods, n_units)
-        # We want sqrt(mean((Y[:, i] - Y[:, j])^2)) for valid periods
+        # Compute sum of squared differences (treating NaN as 0)
+        sq_diff_sum = np.nansum(sq_diff, axis=2)  # (n_units, n_units)
 
-        # For each pair of units, find periods where both are valid
-        for i in range(n_units):
-            valid_i = valid_mask[:, i]
-            for j in range(i, n_units):
-                valid_j = valid_mask[:, j]
-                both_valid = valid_i & valid_j
+        # Compute RMSE distance: sqrt(sum / n_valid)
+        # Avoid division by zero
+        with np.errstate(divide='ignore', invalid='ignore'):
+            dist_matrix = np.sqrt(sq_diff_sum / n_valid)
 
-                if np.any(both_valid):
-                    sq_diff = (Y[both_valid, i] - Y[both_valid, j]) ** 2
-                    dist = np.sqrt(np.mean(sq_diff))
-                    dist_matrix[i, j] = dist
-                    dist_matrix[j, i] = dist
+        # Set pairs with no valid observations to inf
+        dist_matrix = np.where(n_valid > 0, dist_matrix, np.inf)
 
-        # Set diagonal to 0
+        # Ensure diagonal is 0 (same unit distance)
         np.fill_diagonal(dist_matrix, 0.0)
 
         return dist_matrix
```

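For readers who want the vectorized distance computation in isolation, here is a self-contained sketch of the same broadcasting idea, checked against the original nested loop on random data. The function name and test data are illustrative, not part of the library:

```python
import numpy as np

def pairwise_rmse_distances(Y, D):
    """Pairwise RMSE distances over periods where both units are valid controls.
    Y and D have shape (n_periods, n_units); invalid entries become NaN."""
    valid = (D == 0) & ~np.isnan(Y)
    Y_masked = np.where(valid, Y, np.nan)
    Y_T = Y_masked.T                                    # (n_units, n_periods)
    sq = (Y_T[:, None, :] - Y_T[None, :, :]) ** 2       # (n_units, n_units, n_periods)
    n_valid = np.sum(~np.isnan(sq), axis=2)              # periods where both units are valid
    with np.errstate(divide='ignore', invalid='ignore'):
        dist = np.sqrt(np.nansum(sq, axis=2) / n_valid)
    dist = np.where(n_valid > 0, dist, np.inf)            # no overlap -> infinite distance
    np.fill_diagonal(dist, 0.0)
    return dist

# Check against the O(n²) Python loop on random data with missing values
rng = np.random.default_rng(0)
Y = rng.normal(size=(12, 5))
Y[rng.random(Y.shape) < 0.2] = np.nan
D = np.zeros_like(Y)                                       # all observations are controls here

fast = pairwise_rmse_distances(Y, D)

slow = np.full((5, 5), np.inf)
valid = ~np.isnan(Y)
for i in range(5):
    for j in range(5):
        both = valid[:, i] & valid[:, j]
        if both.any():
            slow[i, j] = np.sqrt(np.mean((Y[both, i] - Y[both, j]) ** 2))
np.fill_diagonal(slow, 0.0)

assert np.allclose(fast, slow)
```
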
```diff
@@ -970,7 +1038,7 @@ def _soft_threshold_svd(
         s_thresh = np.maximum(s - threshold, 0)
 
         # Use truncated reconstruction with only non-zero singular values
-        nonzero_mask = s_thresh > 1e-10
+        nonzero_mask = s_thresh > self.CONVERGENCE_TOL_SVD
         if not np.any(nonzero_mask):
             return np.zeros_like(M)
 
```

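For context on what `CONVERGENCE_TOL_SVD` gates, here is a self-contained sketch of singular-value soft-thresholding (the proximal operator of the nuclear norm) with the same kind of truncation tolerance. The function name, default tolerance, and demo matrix are illustrative, not the library's code:

```python
import numpy as np

def soft_threshold_svd(M, threshold, tol=1e-10):
    """Soft-threshold the singular values of M; drop components below tol."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    s_thresh = np.maximum(s - threshold, 0.0)
    keep = s_thresh > tol                 # analogous to CONVERGENCE_TOL_SVD
    if not np.any(keep):
        return np.zeros_like(M)
    # Truncated reconstruction using only the surviving singular values
    return (U[:, keep] * s_thresh[keep]) @ Vt[keep, :]

# A nearly rank-one matrix: small singular values fall below the threshold
rng = np.random.default_rng(1)
M = np.outer(np.arange(1.0, 5.0), np.ones(3)) + 0.01 * rng.normal(size=(4, 3))
L = soft_threshold_svd(M, threshold=0.5)
print(np.linalg.matrix_rank(L))  # typically 1 after thresholding
```
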
```diff
@@ -1065,35 +1133,39 @@ def _estimate_model(
         # Replace NaN in Y with 0 for computation (mask handles exclusion)
         Y_safe = np.where(np.isnan(Y), 0.0, Y)
 
-        # Alternating minimization
+        # Alternating minimization following Algorithm 1 (page 9)
+        # Minimize: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
         for _ in range(self.max_iter):
             alpha_old = alpha.copy()
             beta_old = beta.copy()
             L_old = L.copy()
 
-            # Step 1: Update α and β (weighted means) - VECTORIZED
+            # Step 1: Update α and β (weighted least squares)
+            # Following Equation 2 (page 7), fix L and solve for α, β
             # R = Y - L (residual without fixed effects)
             R = Y_safe - L
 
-            # Alpha update: α_i = Σ_t W_{ti} (R_{ti} - β_t) / Σ_t W_{ti}
-            # Compute weighted sum of (R - β) per unit
+            # Alpha update (unit fixed effects):
+            # α_i = argmin_α Σ_t W_{ti}(R_{ti} - α - β_t)²
+            # Solution: α_i = Σ_t W_{ti}(R_{ti} - β_t) / Σ_t W_{ti}
             R_minus_beta = R - beta[:, np.newaxis]  # (n_periods, n_units)
             weighted_R_minus_beta = W_masked * R_minus_beta
             alpha_numerator = np.sum(weighted_R_minus_beta, axis=0)  # (n_units,)
             alpha = np.where(unit_has_obs, alpha_numerator / safe_unit_denom, 0.0)
 
-            # Beta update: β_t = Σ_i W_{ti} (R_{ti} - α_i) / Σ_i W_{ti}
-            # Compute weighted sum of (R - α) per period
+            # Beta update (time fixed effects):
+            # β_t = argmin_β Σ_i W_{ti}(R_{ti} - α_i - β)²
+            # Solution: β_t = Σ_i W_{ti}(R_{ti} - α_i) / Σ_i W_{ti}
             R_minus_alpha = R - alpha[np.newaxis, :]  # (n_periods, n_units)
             weighted_R_minus_alpha = W_masked * R_minus_alpha
             beta_numerator = np.sum(weighted_R_minus_alpha, axis=1)  # (n_periods,)
             beta = np.where(time_has_obs, beta_numerator / safe_time_denom, 0.0)
 
-            # Step 2: Update L with nuclear norm penalty - VECTORIZED
-            # R_for_L = Y - α - β where valid, else L (impute missing)
-            # Vectorized: broadcast alpha and beta
+            # Step 2: Update L with nuclear norm penalty
+            # Following Equation 2 (page 7): L = prox_{λ_nn||·||_*}(Y - α - β)
+            # The proximal operator for nuclear norm is soft-thresholding of SVD
             R_for_L = Y_safe - alpha[np.newaxis, :] - beta[:, np.newaxis]
-            # Impute invalid observations with current L
+            # Impute invalid observations with current L for stable SVD
             R_for_L = np.where(valid_mask, R_for_L, L)
 
             L = self._soft_threshold_svd(R_for_L, lambda_nn)
```

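To make the new `Solution:` comments concrete, a small numerical check (shapes follow the (n_periods, n_units) layout used in the diff; the data are random and purely illustrative) that the weighted-mean formula is indeed the minimizer of the weighted least-squares objective for a single unit:

```python
import numpy as np

rng = np.random.default_rng(0)
R = rng.normal(size=(6, 4))       # residual Y - L, shape (n_periods, n_units)
W = rng.random(size=(6, 4))       # observation weights
beta = rng.normal(size=6)         # time fixed effects

# Closed-form update: α_i = Σ_t W_ti (R_ti - β_t) / Σ_t W_ti
alpha_closed = np.sum(W * (R - beta[:, None]), axis=0) / np.sum(W, axis=0)

# Brute-force the objective Σ_t W_ti (R_ti - α - β_t)² over a fine grid for one unit
i = 2
grid = np.linspace(-10, 10, 200001)
loss = np.sum(W[:, [i]] * (R[:, [i]] - beta[:, None] - grid[None, :]) ** 2, axis=0)
assert abs(grid[np.argmin(loss)] - alpha_closed[i]) < 1e-3
```
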
```diff
@@ -1168,7 +1240,7 @@ def _loocv_score_obs_specific(
 
         # Subsample for computational tractability (as noted in paper's footnote)
         rng = np.random.default_rng(self.seed)
-        max_loocv = min(100, len(control_obs))
+        max_loocv = min(self.max_loocv_samples, len(control_obs))
         if len(control_obs) > max_loocv:
             indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
             control_obs = [control_obs[idx] for idx in indices]
```

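The subsampling step on its own, as a runnable sketch; the list of (t, i) pairs is fabricated for illustration, and the logic mirrors the diff above:

```python
import numpy as np

control_obs = [(t, i) for t in range(30) for i in range(10)]  # 300 hypothetical (t, i) pairs
max_loocv_samples = 100   # the new, configurable cap (was hardcoded to 100)
seed = 42

rng = np.random.default_rng(seed)
max_loocv = min(max_loocv_samples, len(control_obs))
if len(control_obs) > max_loocv:
    # Draw without replacement so no control observation is scored twice
    indices = rng.choice(len(control_obs), size=max_loocv, replace=False)
    control_obs = [control_obs[idx] for idx in indices]

print(len(control_obs))  # 100
```
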
```diff
@@ -1463,6 +1535,7 @@ def get_params(self) -> Dict[str, Any]:
             "alpha": self.alpha,
             "variance_method": self.variance_method,
             "n_bootstrap": self.n_bootstrap,
+            "max_loocv_samples": self.max_loocv_samples,
             "seed": self.seed,
         }
 
```

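And a quick, hypothetical round trip confirming the new parameter is reported by `get_params()`:

```python
from diff_diff.trop import TROP

model = TROP(max_loocv_samples=250)
assert model.get_params()["max_loocv_samples"] == 250
```
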
