Skip to content

Commit 77e2c65

Browse files
authored
Merge pull request #84 from igerber/fix/trop-paper-conformance
Fix/trop paper conformance
2 parents b7957a8 + fa69c5f commit 77e2c65

5 files changed

Lines changed: 621 additions & 199 deletions

File tree

diff_diff/trop.py

Lines changed: 184 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ class _PrecomputedStructures(TypedDict):
6363
control_obs: List[Tuple[int, int]]
6464
"""List of (t, i) tuples for valid control observations."""
6565
control_unit_idx: np.ndarray
66-
"""Array of control unit indices."""
66+
"""Array of never-treated unit indices (for backward compatibility)."""
67+
D: np.ndarray
68+
"""Treatment indicator matrix (n_periods x n_units) for dynamic control sets."""
69+
Y: np.ndarray
70+
"""Outcome matrix (n_periods x n_units)."""
6771
n_units: int
6872
"""Number of units."""
6973
n_periods: int
@@ -529,6 +533,8 @@ def _precompute_structures(
529533
"treated_observations": treated_observations,
530534
"control_obs": control_obs,
531535
"control_unit_idx": control_unit_idx,
536+
"D": D,
537+
"Y": Y,
532538
"n_units": n_units,
533539
"n_periods": n_periods,
534540
}
@@ -778,16 +784,14 @@ def fit(
778784
# Prepare inputs for Rust function
779785
control_mask_u8 = control_mask.astype(np.uint8)
780786
time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
781-
unit_dist_matrix = self._precomputed["unit_dist_matrix"]
782-
control_unit_idx_i64 = control_unit_idx.astype(np.int64)
783787

784788
lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
785789
lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
786790
lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
787791

788792
best_lt, best_lu, best_ln, best_score = _rust_loocv_grid_search(
789-
Y, D.astype(np.float64), control_mask_u8, control_unit_idx_i64,
790-
unit_dist_matrix, time_dist_matrix,
793+
Y, D.astype(np.float64), control_mask_u8,
794+
time_dist_matrix,
791795
lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
792796
self.max_loocv_samples, self.max_iter, self.tol,
793797
self.seed if self.seed is not None else 0
@@ -953,10 +957,16 @@ def _compute_observation_weights(
953957
"""
954958
Compute observation-specific weight matrix for treated observation (i, t).
955959
956-
Following the paper's Algorithm 2 (page 27):
960+
Following the paper's Algorithm 2 (page 27) and Equation 2 (page 7):
957961
- Time weights θ_s^{i,t} = exp(-λ_time × |t - s|)
958962
- Unit weights ω_j^{i,t} = exp(-λ_unit × dist_unit_{-t}(j, i))
959963
964+
IMPORTANT (Issue A fix): The paper's objective sums over ALL observations
965+
where (1 - W_js) is non-zero, which includes pre-treatment observations of
966+
eventually-treated units since W_js = 0 for those. This method computes
967+
weights for ALL units where D[t, j] = 0 at the target period, not just
968+
never-treated units.
969+
960970
Uses pre-computed structures when available for efficiency.
961971
962972
Parameters
@@ -974,7 +984,8 @@ def _compute_observation_weights(
974984
lambda_unit : float
975985
Unit weight decay parameter.
976986
control_unit_idx : np.ndarray
977-
Indices of control units.
987+
Indices of never-treated units (for backward compatibility, but not
988+
used for weight computation - we use D matrix directly).
978989
n_units : int
979990
Number of units.
980991
n_periods : int
@@ -991,21 +1002,30 @@ def _compute_observation_weights(
9911002
# time_dist_matrix[t, s] = |t - s|
9921003
time_weights = np.exp(-lambda_time * self._precomputed["time_dist_matrix"][t, :])
9931004

994-
# Unit weights from pre-computed unit distance matrix
1005+
# Unit weights - computed for ALL units where D[t, j] = 0
1006+
# (Issue A fix: includes pre-treatment obs of eventually-treated units)
9951007
unit_weights = np.zeros(n_units)
1008+
D_stored = self._precomputed["D"]
1009+
Y_stored = self._precomputed["Y"]
1010+
1011+
# Valid control units at time t: D[t, j] == 0
1012+
valid_control_at_t = D_stored[t, :] == 0
9961013

9971014
if lambda_unit == 0:
9981015
# Uniform weights when lambda_unit = 0
999-
unit_weights[:] = 1.0
1016+
# All units not treated at time t get weight 1
1017+
unit_weights[valid_control_at_t] = 1.0
10001018
else:
1001-
# Use pre-computed distances: unit_dist_matrix[j, i] = dist(j, i)
1002-
dist_matrix = self._precomputed["unit_dist_matrix"]
1003-
for j in control_unit_idx:
1004-
dist = dist_matrix[j, i]
1005-
if np.isinf(dist):
1006-
unit_weights[j] = 0.0
1007-
else:
1008-
unit_weights[j] = np.exp(-lambda_unit * dist)
1019+
# Use observation-specific distances with target period excluded
1020+
# (Issue B fix: compute exact per-observation distance)
1021+
for j in range(n_units):
1022+
if valid_control_at_t[j] and j != i:
1023+
# Compute distance excluding target period t
1024+
dist = self._compute_unit_distance_for_obs(Y_stored, D_stored, j, i, t)
1025+
if np.isinf(dist):
1026+
unit_weights[j] = 0.0
1027+
else:
1028+
unit_weights[j] = np.exp(-lambda_unit * dist)
10091029

10101030
# Treated unit i gets weight 1
10111031
unit_weights[i] = 1.0
@@ -1018,19 +1038,25 @@ def _compute_observation_weights(
10181038
dist_time = np.abs(np.arange(n_periods) - t)
10191039
time_weights = np.exp(-lambda_time * dist_time)
10201040

1021-
# Unit distance: pairwise RMSE from each control j to treated i
1041+
# Unit weights - computed for ALL units where D[t, j] = 0
1042+
# (Issue A fix: includes pre-treatment obs of eventually-treated units)
10221043
unit_weights = np.zeros(n_units)
10231044

1045+
# Valid control units at time t: D[t, j] == 0
1046+
valid_control_at_t = D[t, :] == 0
1047+
10241048
if lambda_unit == 0:
10251049
# Uniform weights when lambda_unit = 0
1026-
unit_weights[:] = 1.0
1050+
unit_weights[valid_control_at_t] = 1.0
10271051
else:
1028-
for j in control_unit_idx:
1029-
dist = self._compute_unit_distance_for_obs(Y, D, j, i, t)
1030-
if np.isinf(dist):
1031-
unit_weights[j] = 0.0
1032-
else:
1033-
unit_weights[j] = np.exp(-lambda_unit * dist)
1052+
for j in range(n_units):
1053+
if valid_control_at_t[j] and j != i:
1054+
# Compute distance excluding target period t (Issue B fix)
1055+
dist = self._compute_unit_distance_for_obs(Y, D, j, i, t)
1056+
if np.isinf(dist):
1057+
unit_weights[j] = 0.0
1058+
else:
1059+
unit_weights[j] = np.exp(-lambda_unit * dist)
10341060

10351061
# Treated unit i gets weight 1 (or could be omitted since we fit on controls)
10361062
# We include treated unit's own observation for model fitting
@@ -1102,6 +1128,101 @@ def _soft_threshold_svd(
11021128

11031129
return result
11041130

1131+
def _weighted_nuclear_norm_solve(
1132+
self,
1133+
Y: np.ndarray,
1134+
W: np.ndarray,
1135+
L_init: np.ndarray,
1136+
alpha: np.ndarray,
1137+
beta: np.ndarray,
1138+
lambda_nn: float,
1139+
max_inner_iter: int = 20,
1140+
) -> np.ndarray:
1141+
"""
1142+
Solve weighted nuclear norm problem using iterative weighted soft-impute.
1143+
1144+
Issue C fix: Implements the weighted nuclear norm optimization from the
1145+
paper's Equation 2 (page 7). The full objective is:
1146+
min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*
1147+
1148+
This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
1149+
L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))
1150+
1151+
where W ⊙ denotes element-wise multiplication with normalized weights.
1152+
1153+
IMPORTANT: For observations with W=0 (treated observations), we keep
1154+
L values from the previous iteration rather than setting L = R, which
1155+
would absorb the treatment effect.
1156+
1157+
Parameters
1158+
----------
1159+
Y : np.ndarray
1160+
Outcome matrix (n_periods x n_units).
1161+
W : np.ndarray
1162+
Weight matrix (n_periods x n_units), non-negative. W=0 indicates
1163+
observations that should not be used for fitting (treated obs).
1164+
L_init : np.ndarray
1165+
Initial estimate of L matrix.
1166+
alpha : np.ndarray
1167+
Current unit fixed effects estimate.
1168+
beta : np.ndarray
1169+
Current time fixed effects estimate.
1170+
lambda_nn : float
1171+
Nuclear norm regularization parameter.
1172+
max_inner_iter : int, default=20
1173+
Maximum inner iterations for the proximal algorithm.
1174+
1175+
Returns
1176+
-------
1177+
np.ndarray
1178+
Updated L matrix estimate.
1179+
"""
1180+
# Compute target residual R = Y - α - β
1181+
R = Y - alpha[np.newaxis, :] - beta[:, np.newaxis]
1182+
1183+
# Handle invalid values
1184+
R = np.nan_to_num(R, nan=0.0, posinf=0.0, neginf=0.0)
1185+
1186+
# For observations with W=0 (treated obs), keep L_init instead of R
1187+
# This prevents L from absorbing the treatment effect
1188+
valid_obs_mask = W > 0
1189+
R_masked = np.where(valid_obs_mask, R, L_init)
1190+
1191+
if lambda_nn <= 0:
1192+
# No regularization - just return masked residual
1193+
# Use soft-thresholding with threshold=0 which returns the input
1194+
return R_masked
1195+
1196+
# Normalize weights so max is 1 (for step size stability)
1197+
W_max = np.max(W)
1198+
if W_max > 0:
1199+
W_norm = W / W_max
1200+
else:
1201+
W_norm = W
1202+
1203+
# Initialize L
1204+
L = L_init.copy()
1205+
1206+
# Proximal gradient iteration with weighted soft-impute
1207+
# This solves: min_L ||W^{1/2} ⊙ (R - L)||_F^2 + λ||L||_*
1208+
# Using: L_{k+1} = prox_{λ/η}(L_k + W ⊙ (R - L_k))
1209+
# where η is the step size (we use η = 1 with normalized weights)
1210+
for _ in range(max_inner_iter):
1211+
L_old = L.copy()
1212+
1213+
# Gradient step: L_k + W ⊙ (R - L_k)
1214+
# For W=0 observations, this keeps L_k unchanged
1215+
gradient_step = L + W_norm * (R_masked - L)
1216+
1217+
# Proximal step: soft-threshold singular values
1218+
L = self._soft_threshold_svd(gradient_step, lambda_nn)
1219+
1220+
# Check convergence
1221+
if np.max(np.abs(L - L_old)) < self.tol:
1222+
break
1223+
1224+
return L
1225+
11051226
def _estimate_model(
11061227
self,
11071228
Y: np.ndarray,
@@ -1205,14 +1326,13 @@ def _estimate_model(
12051326
beta_numerator = np.sum(weighted_R_minus_alpha, axis=1) # (n_periods,)
12061327
beta = np.where(time_has_obs, beta_numerator / safe_time_denom, 0.0)
12071328

1208-
# Step 2: Update L with nuclear norm penalty
1209-
# Following Equation 2 (page 7): L = prox_{λ_nn||·||_*}(Y - α - β)
1210-
# The proximal operator for nuclear norm is soft-thresholding of SVD
1211-
R_for_L = Y_safe - alpha[np.newaxis, :] - beta[:, np.newaxis]
1212-
# Impute invalid observations with current L for stable SVD
1213-
R_for_L = np.where(valid_mask, R_for_L, L)
1214-
1215-
L = self._soft_threshold_svd(R_for_L, lambda_nn)
1329+
# Step 2: Update L with weighted nuclear norm penalty
1330+
# Issue C fix: Use weighted soft-impute to properly account for
1331+
# observation weights in the nuclear norm optimization.
1332+
# Following Equation 2 (page 7): min_L Σ W_{ti}(Y - α - β - L)² + λ||L||_*
1333+
L = self._weighted_nuclear_norm_solve(
1334+
Y_safe, W_masked, L, alpha, beta, lambda_nn, max_inner_iter=10
1335+
)
12161336

12171337
# Check convergence
12181338
alpha_diff = np.max(np.abs(alpha - alpha_old))
@@ -1388,21 +1508,15 @@ def _bootstrap_variance(
13881508
# Try Rust backend for parallel bootstrap (5-15x speedup)
13891509
if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
13901510
and self._precomputed is not None and Y is not None
1391-
and D is not None and control_unit_idx is not None):
1511+
and D is not None):
13921512
try:
1393-
# Prepare inputs
1394-
treated_observations = self._precomputed["treated_observations"]
1395-
treated_t = np.array([t for t, i in treated_observations], dtype=np.int64)
1396-
treated_i = np.array([i for t, i in treated_observations], dtype=np.int64)
13971513
control_mask = self._precomputed["control_mask"]
1514+
time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
13981515

13991516
bootstrap_estimates, se = _rust_bootstrap_trop_variance(
14001517
Y, D.astype(np.float64),
14011518
control_mask.astype(np.uint8),
1402-
control_unit_idx.astype(np.int64),
1403-
treated_t, treated_i,
1404-
self._precomputed["unit_dist_matrix"],
1405-
self._precomputed["time_dist_matrix"].astype(np.int64),
1519+
time_dist_matrix,
14061520
lambda_time, lambda_unit, lambda_nn,
14071521
self.n_bootstrap, self.max_iter, self.tol,
14081522
self.seed if self.seed is not None else 0
@@ -1422,14 +1536,38 @@ def _bootstrap_variance(
14221536

14231537
# Python implementation (fallback)
14241538
rng = np.random.default_rng(self.seed)
1425-
all_units = data[unit].unique()
1426-
n_units_data = len(all_units)
1539+
1540+
# Issue D fix: Stratified bootstrap sampling
1541+
# Paper's Algorithm 3 (page 27) specifies sampling N_0 control rows
1542+
# and N_1 treated rows separately to preserve treatment ratio
1543+
unit_ever_treated = data.groupby(unit)[treatment].max()
1544+
treated_units = np.array(unit_ever_treated[unit_ever_treated == 1].index)
1545+
control_units = np.array(unit_ever_treated[unit_ever_treated == 0].index)
1546+
1547+
n_treated_units = len(treated_units)
1548+
n_control_units = len(control_units)
14271549

14281550
bootstrap_estimates_list = []
14291551

14301552
for _ in range(self.n_bootstrap):
1431-
# Sample units with replacement
1432-
sampled_units = rng.choice(all_units, size=n_units_data, replace=True)
1553+
# Stratified sampling: sample control and treated units separately
1554+
# This preserves the treatment ratio in each bootstrap sample
1555+
if n_control_units > 0:
1556+
sampled_control = rng.choice(
1557+
control_units, size=n_control_units, replace=True
1558+
)
1559+
else:
1560+
sampled_control = np.array([], dtype=control_units.dtype)
1561+
1562+
if n_treated_units > 0:
1563+
sampled_treated = rng.choice(
1564+
treated_units, size=n_treated_units, replace=True
1565+
)
1566+
else:
1567+
sampled_treated = np.array([], dtype=treated_units.dtype)
1568+
1569+
# Combine stratified samples
1570+
sampled_units = np.concatenate([sampled_control, sampled_treated])
14331571

14341572
# Create bootstrap sample with unique unit IDs
14351573
boot_data = pd.concat([

0 commit comments

Comments (0)