Skip to content

Commit fa69c5f

Browse files
igerberclaude
andcommitted
Clean up unused parameters from TROP Rust API
Address code review feedback to remove unused API parameters: Rust backend (trop.rs): - Remove control_unit_idx and unit_dist_matrix from loocv_grid_search - Remove control_unit_idx, treated_obs_t/i, unit_dist_matrix from bootstrap_trop_variance - Remove unit_dist_boot computation in bootstrap (no longer needed) - Remove control_units and unit_dist from internal functions Python (trop.py): - Update _rust_loocv_grid_search call to use new signature - Update _rust_bootstrap_trop_variance call to use new signature - Remove unused variable preparation for removed parameters Tests (test_rust_backend.py): - Update test calls to use new API signatures - Remove unused variable assignments The precomputed unit_dist_matrix is no longer needed by the Rust backend since per-observation distances are computed dynamically to properly exclude the target period per Equation 3 of the paper. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 2449a9e commit fa69c5f

3 files changed

Lines changed: 16 additions & 77 deletions

File tree

diff_diff/trop.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -784,16 +784,14 @@ def fit(
784784
# Prepare inputs for Rust function
785785
control_mask_u8 = control_mask.astype(np.uint8)
786786
time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
787-
unit_dist_matrix = self._precomputed["unit_dist_matrix"]
788-
control_unit_idx_i64 = control_unit_idx.astype(np.int64)
789787

790788
lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
791789
lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
792790
lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)
793791

794792
best_lt, best_lu, best_ln, best_score = _rust_loocv_grid_search(
795-
Y, D.astype(np.float64), control_mask_u8, control_unit_idx_i64,
796-
unit_dist_matrix, time_dist_matrix,
793+
Y, D.astype(np.float64), control_mask_u8,
794+
time_dist_matrix,
797795
lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
798796
self.max_loocv_samples, self.max_iter, self.tol,
799797
self.seed if self.seed is not None else 0
@@ -1510,21 +1508,15 @@ def _bootstrap_variance(
15101508
# Try Rust backend for parallel bootstrap (5-15x speedup)
15111509
if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
15121510
and self._precomputed is not None and Y is not None
1513-
and D is not None and control_unit_idx is not None):
1511+
and D is not None):
15141512
try:
1515-
# Prepare inputs
1516-
treated_observations = self._precomputed["treated_observations"]
1517-
treated_t = np.array([t for t, i in treated_observations], dtype=np.int64)
1518-
treated_i = np.array([i for t, i in treated_observations], dtype=np.int64)
15191513
control_mask = self._precomputed["control_mask"]
1514+
time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
15201515

15211516
bootstrap_estimates, se = _rust_bootstrap_trop_variance(
15221517
Y, D.astype(np.float64),
15231518
control_mask.astype(np.uint8),
1524-
control_unit_idx.astype(np.int64),
1525-
treated_t, treated_i,
1526-
self._precomputed["unit_dist_matrix"],
1527-
self._precomputed["time_dist_matrix"].astype(np.int64),
1519+
time_dist_matrix,
15281520
lambda_time, lambda_unit, lambda_nn,
15291521
self.n_bootstrap, self.max_iter, self.tol,
15301522
self.seed if self.seed is not None else 0

rust/src/trop.rs

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,13 @@ fn compute_pair_distance(
172172
/// # Returns
173173
/// (best_lambda_time, best_lambda_unit, best_lambda_nn, best_score)
174174
#[pyfunction]
175-
#[pyo3(signature = (y, d, control_mask, control_unit_idx, unit_dist_matrix, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))]
175+
#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_loocv_samples, max_iter, tol, seed))]
176176
#[allow(clippy::too_many_arguments)]
177177
pub fn loocv_grid_search<'py>(
178178
_py: Python<'py>,
179179
y: PyReadonlyArray2<'py, f64>,
180180
d: PyReadonlyArray2<'py, f64>,
181181
control_mask: PyReadonlyArray2<'py, u8>,
182-
control_unit_idx: PyReadonlyArray1<'py, i64>,
183-
unit_dist_matrix: PyReadonlyArray2<'py, f64>,
184182
time_dist_matrix: PyReadonlyArray2<'py, i64>,
185183
lambda_time_grid: PyReadonlyArray1<'py, f64>,
186184
lambda_unit_grid: PyReadonlyArray1<'py, f64>,
@@ -193,19 +191,11 @@ pub fn loocv_grid_search<'py>(
193191
let y_arr = y.as_array();
194192
let d_arr = d.as_array();
195193
let control_mask_arr = control_mask.as_array();
196-
let control_unit_idx_arr = control_unit_idx.as_array();
197-
let unit_dist_arr = unit_dist_matrix.as_array();
198194
let time_dist_arr = time_dist_matrix.as_array();
199195
let lambda_time_vec: Vec<f64> = lambda_time_grid.as_array().to_vec();
200196
let lambda_unit_vec: Vec<f64> = lambda_unit_grid.as_array().to_vec();
201197
let lambda_nn_vec: Vec<f64> = lambda_nn_grid.as_array().to_vec();
202198

203-
// Convert control_unit_idx to Vec<usize>
204-
let control_units: Vec<usize> = control_unit_idx_arr
205-
.iter()
206-
.map(|&idx| idx as usize)
207-
.collect();
208-
209199
// Get control observations for LOOCV
210200
let control_obs = get_control_observations(
211201
&y_arr,
@@ -232,8 +222,6 @@ pub fn loocv_grid_search<'py>(
232222
&y_arr,
233223
&d_arr,
234224
&control_mask_arr,
235-
&control_units,
236-
&unit_dist_arr,
237225
&time_dist_arr,
238226
&control_obs,
239227
lambda_time,
@@ -293,8 +281,6 @@ fn loocv_score_for_params(
293281
y: &ArrayView2<f64>,
294282
d: &ArrayView2<f64>,
295283
control_mask: &ArrayView2<u8>,
296-
control_units: &[usize],
297-
unit_dist: &ArrayView2<f64>,
298284
time_dist: &ArrayView2<i64>,
299285
control_obs: &[(usize, usize)],
300286
lambda_time: f64,
@@ -311,7 +297,6 @@ fn loocv_score_for_params(
311297

312298
for &(t, i) in control_obs {
313299
// Compute observation-specific weight matrix
314-
// Issue A+B fix: pass y and d for dynamic control sets and per-obs distances
315300
let weight_matrix = compute_weight_matrix(
316301
y,
317302
d,
@@ -321,8 +306,6 @@ fn loocv_score_for_params(
321306
t,
322307
lambda_time,
323308
lambda_unit,
324-
control_units,
325-
unit_dist,
326309
time_dist,
327310
);
328311

@@ -410,8 +393,6 @@ fn compute_weight_matrix(
410393
target_period: usize,
411394
lambda_time: f64,
412395
lambda_unit: f64,
413-
_control_units: &[usize], // Kept for API compatibility but not used
414-
_unit_dist: &ArrayView2<f64>, // Not used - we compute per-observation distances
415396
time_dist: &ArrayView2<i64>,
416397
) -> Array2<f64> {
417398
// Time weights for this target period: θ_s = exp(-λ_time × |t - s|)
@@ -707,17 +688,13 @@ fn max_abs_diff_2d(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
707688
/// # Returns
708689
/// (bootstrap_estimates, standard_error)
709690
#[pyfunction]
710-
#[pyo3(signature = (y, d, control_mask, control_unit_idx, treated_obs_t, treated_obs_i, unit_dist_matrix, time_dist_matrix, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))]
691+
#[pyo3(signature = (y, d, control_mask, time_dist_matrix, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))]
711692
#[allow(clippy::too_many_arguments)]
712693
pub fn bootstrap_trop_variance<'py>(
713694
py: Python<'py>,
714695
y: PyReadonlyArray2<'py, f64>,
715696
d: PyReadonlyArray2<'py, f64>,
716697
control_mask: PyReadonlyArray2<'py, u8>,
717-
control_unit_idx: PyReadonlyArray1<'py, i64>,
718-
treated_obs_t: PyReadonlyArray1<'py, i64>,
719-
treated_obs_i: PyReadonlyArray1<'py, i64>,
720-
unit_dist_matrix: PyReadonlyArray2<'py, f64>,
721698
time_dist_matrix: PyReadonlyArray2<'py, i64>,
722699
lambda_time: f64,
723700
lambda_unit: f64,
@@ -730,17 +707,11 @@ pub fn bootstrap_trop_variance<'py>(
730707
let y_arr = y.as_array().to_owned();
731708
let d_arr = d.as_array().to_owned();
732709
let control_mask_arr = control_mask.as_array().to_owned();
733-
let unit_dist_arr = unit_dist_matrix.as_array().to_owned();
734710
let time_dist_arr = time_dist_matrix.as_array().to_owned();
735711

736712
let n_units = y_arr.ncols();
737713
let n_periods = y_arr.nrows();
738714

739-
// Note: control_unit_idx, treated_obs_t, treated_obs_i are passed for API
740-
// compatibility but not used directly - each bootstrap iteration recomputes
741-
// control units and treated observations from the resampled data.
742-
let _ = (control_unit_idx, treated_obs_t, treated_obs_i);
743-
744715
// Issue D fix: Identify treated and control units for stratified sampling
745716
// Following paper's Algorithm 3 (page 27): sample N_0 control and N_1 treated separately
746717
let mut original_treated_units: Vec<usize> = Vec::new();
@@ -784,18 +755,13 @@ pub fn bootstrap_trop_variance<'py>(
784755
let mut y_boot = Array2::<f64>::zeros((n_periods, n_units));
785756
let mut d_boot = Array2::<f64>::zeros((n_periods, n_units));
786757
let mut control_mask_boot = Array2::<u8>::zeros((n_periods, n_units));
787-
let mut unit_dist_boot = Array2::<f64>::zeros((n_units, n_units));
788758

789759
for (new_idx, &old_idx) in sampled_units.iter().enumerate() {
790760
for t in 0..n_periods {
791761
y_boot[[t, new_idx]] = y_arr[[t, old_idx]];
792762
d_boot[[t, new_idx]] = d_arr[[t, old_idx]];
793763
control_mask_boot[[t, new_idx]] = control_mask_arr[[t, old_idx]];
794764
}
795-
796-
for (new_j, &old_j) in sampled_units.iter().enumerate() {
797-
unit_dist_boot[[new_idx, new_j]] = unit_dist_arr[[old_idx, old_j]];
798-
}
799765
}
800766

801767
// Get treated observations in bootstrap sample
@@ -829,7 +795,6 @@ pub fn bootstrap_trop_variance<'py>(
829795
let mut tau_values = Vec::with_capacity(boot_treated.len());
830796

831797
for (t, i) in boot_treated {
832-
// Issue A+B fix: pass y and d for dynamic control sets and per-obs distances
833798
let weight_matrix = compute_weight_matrix(
834799
&y_boot.view(),
835800
&d_boot.view(),
@@ -839,8 +804,6 @@ pub fn bootstrap_trop_variance<'py>(
839804
t,
840805
lambda_time,
841806
lambda_unit,
842-
&boot_control_units,
843-
&unit_dist_boot.view(),
844807
&time_dist_arr.view(),
845808
);
846809

tests/test_rust_backend.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -657,11 +657,8 @@ def test_loocv_grid_search_returns_valid_params(self):
657657
D[6:, 0] = 1.0
658658

659659
control_mask = (D == 0).astype(np.uint8)
660-
control_unit_idx = np.array([1, 2, 3, 4, 5], dtype=np.int64)
661660

662-
# Compute distance matrices
663-
from diff_diff._rust_backend import compute_unit_distance_matrix
664-
unit_dist = compute_unit_distance_matrix(Y, D)
661+
# Compute time distance matrix
665662
time_dist = np.abs(
666663
np.arange(n_periods)[:, np.newaxis] - np.arange(n_periods)[np.newaxis, :]
667664
).astype(np.int64)
@@ -671,8 +668,7 @@ def test_loocv_grid_search_returns_valid_params(self):
671668
lambda_nn = np.array([0.0, 0.1], dtype=np.float64)
672669

673670
best_lt, best_lu, best_ln, score = loocv_grid_search(
674-
Y, D, control_mask, control_unit_idx,
675-
unit_dist, time_dist,
671+
Y, D, control_mask, time_dist,
676672
lambda_time, lambda_unit, lambda_nn,
677673
50, 100, 1e-6, 42
678674
)
@@ -685,7 +681,7 @@ def test_loocv_grid_search_returns_valid_params(self):
685681

686682
def test_bootstrap_variance_shape(self):
687683
"""Test bootstrap returns correct shapes."""
688-
from diff_diff._rust_backend import bootstrap_trop_variance, compute_unit_distance_matrix
684+
from diff_diff._rust_backend import bootstrap_trop_variance
689685

690686
np.random.seed(42)
691687
n_periods, n_units = 8, 6
@@ -694,20 +690,15 @@ def test_bootstrap_variance_shape(self):
694690
D[6:, 0] = 1.0 # Treat unit 0 in last 2 periods
695691

696692
control_mask = (D == 0).astype(np.uint8)
697-
control_unit_idx = np.array([1, 2, 3, 4, 5], dtype=np.int64)
698-
treated_t = np.array([6, 7], dtype=np.int64)
699-
treated_i = np.array([0, 0], dtype=np.int64)
700693

701-
unit_dist = compute_unit_distance_matrix(Y, D)
694+
# Compute time distance matrix
702695
time_dist = np.abs(
703696
np.arange(n_periods)[:, np.newaxis] - np.arange(n_periods)[np.newaxis, :]
704697
).astype(np.int64)
705698

706699
n_bootstrap = 20
707700
estimates, se = bootstrap_trop_variance(
708-
Y, D, control_mask, control_unit_idx,
709-
treated_t, treated_i,
710-
unit_dist, time_dist,
701+
Y, D, control_mask, time_dist,
711702
1.0, 1.0, 0.1, # lambda values
712703
n_bootstrap, 100, 1e-6, 42
713704
)
@@ -718,7 +709,7 @@ def test_bootstrap_variance_shape(self):
718709

719710
def test_bootstrap_reproducibility(self):
720711
"""Test bootstrap is reproducible with same seed."""
721-
from diff_diff._rust_backend import bootstrap_trop_variance, compute_unit_distance_matrix
712+
from diff_diff._rust_backend import bootstrap_trop_variance
722713

723714
np.random.seed(42)
724715
n_periods, n_units = 8, 6
@@ -727,26 +718,19 @@ def test_bootstrap_reproducibility(self):
727718
D[6:, 0] = 1.0
728719

729720
control_mask = (D == 0).astype(np.uint8)
730-
control_unit_idx = np.array([1, 2, 3, 4, 5], dtype=np.int64)
731-
treated_t = np.array([6, 7], dtype=np.int64)
732-
treated_i = np.array([0, 0], dtype=np.int64)
733721

734-
unit_dist = compute_unit_distance_matrix(Y, D)
722+
# Compute time distance matrix
735723
time_dist = np.abs(
736724
np.arange(n_periods)[:, np.newaxis] - np.arange(n_periods)[np.newaxis, :]
737725
).astype(np.int64)
738726

739727
# Run twice with same seed
740728
est1, se1 = bootstrap_trop_variance(
741-
Y, D, control_mask, control_unit_idx,
742-
treated_t, treated_i,
743-
unit_dist, time_dist,
729+
Y, D, control_mask, time_dist,
744730
1.0, 1.0, 0.1, 20, 100, 1e-6, 42
745731
)
746732
est2, se2 = bootstrap_trop_variance(
747-
Y, D, control_mask, control_unit_idx,
748-
treated_t, treated_i,
749-
unit_dist, time_dist,
733+
Y, D, control_mask, time_dist,
750734
1.0, 1.0, 0.1, 20, 100, 1e-6, 42
751735
)
752736

0 commit comments

Comments
 (0)