@@ -758,23 +758,48 @@ def test_bootstrap_reproducibility(self):
758758class TestTROPRustVsNumpy :
759759 """Tests comparing TROP Rust and NumPy implementations for numerical equivalence."""
760760
761- def test_full_trop_estimation_matches (self ):
762- """Test end-to-end TROP estimation matches with/without Rust."""
763- import os
761+ def test_distance_matrix_matches_numpy (self ):
762+ """Test Rust distance matrix matches NumPy implementation exactly."""
763+ from diff_diff ._rust_backend import compute_unit_distance_matrix
764+ from diff_diff .trop import TROP
765+
766+ np .random .seed (42 )
767+ n_periods , n_units = 12 , 8
768+ Y = np .random .randn (n_periods , n_units )
769+ D = np .zeros ((n_periods , n_units ))
770+ # Add some treatment to make it realistic
771+ D [8 :, 0 ] = 1.0
772+ D [10 :, 1 ] = 1.0
773+
774+ # Rust implementation
775+ rust_dist = compute_unit_distance_matrix (Y , D )
776+
777+ # NumPy implementation (directly call the private method)
778+ trop = TROP ()
779+ numpy_dist = trop ._compute_all_unit_distances (Y , D , n_units , n_periods )
780+
781+ np .testing .assert_array_almost_equal (
782+ rust_dist , numpy_dist , decimal = 10 ,
783+ err_msg = "Distance matrices should match exactly"
784+ )
785+
786+ def test_trop_produces_valid_results (self ):
787+ """Test TROP with Rust backend produces valid estimation results."""
764788 import pandas as pd
765789 from diff_diff import TROP
766790
767791 np .random .seed (42 )
768792
769- # Create small test data
793+ # Create test data with known treatment effect
770794 n_units = 10
771795 n_periods = 8
796+ true_effect = 2.0
772797 data = []
773798
774799 for i in range (n_units ):
775800 for t in range (n_periods ):
776- is_treated = (i == 0 ) and (t >= 6 ) # Unit 0 treated from period 6
777- y = 1.0 + 0.5 * i + 0.3 * t + (2.0 if is_treated else 0 ) + np .random .randn () * 0.5
801+ is_treated = (i == 0 ) and (t >= 6 )
802+ y = 1.0 + 0.5 * i + 0.3 * t + (true_effect if is_treated else 0 ) + np .random .randn () * 0.5
778803 data .append ({
779804 'unit' : i ,
780805 'time' : t ,
@@ -784,65 +809,30 @@ def test_full_trop_estimation_matches(self):
784809
785810 df = pd .DataFrame (data )
786811
787- # Fit with Rust backend
788- trop_rust = TROP (
812+ # Fit with current backend (Rust if available)
813+ trop = TROP (
789814 lambda_time_grid = [0.0 , 1.0 ],
790815 lambda_unit_grid = [0.0 , 1.0 ],
791816 lambda_nn_grid = [0.0 , 0.1 ],
792817 n_bootstrap = 20 ,
793818 max_loocv_samples = 30 ,
794819 seed = 42
795820 )
796- results_rust = trop_rust .fit (df , 'outcome' , 'treated' , 'unit' , 'time' )
797-
798- # Fit with Python backend (force Python mode)
799- original_env = os .environ .get ('DIFF_DIFF_BACKEND' )
800- try :
801- os .environ ['DIFF_DIFF_BACKEND' ] = 'python'
802-
803- # Need to reimport to get Python-only version
804- import importlib
805- import diff_diff ._backend
806- import diff_diff .trop
807- importlib .reload (diff_diff ._backend )
808- importlib .reload (diff_diff .trop )
809- from diff_diff .trop import TROP as TROP_Python
810-
811- trop_python = TROP_Python (
812- lambda_time_grid = [0.0 , 1.0 ],
813- lambda_unit_grid = [0.0 , 1.0 ],
814- lambda_nn_grid = [0.0 , 0.1 ],
815- n_bootstrap = 20 ,
816- max_loocv_samples = 30 ,
817- seed = 42
818- )
819- results_python = trop_python .fit (df , 'outcome' , 'treated' , 'unit' , 'time' )
820-
821- # ATT should be very close (within numerical precision)
822- assert abs (results_rust .att - results_python .att ) < 0.5 , \
823- f"ATT mismatch: Rust={ results_rust .att :.4f} , Python={ results_python .att :.4f} "
824-
825- # Tuning parameters should match (same grid search)
826- assert results_rust .lambda_time == results_python .lambda_time , \
827- "lambda_time should match"
828- assert results_rust .lambda_unit == results_python .lambda_unit , \
829- "lambda_unit should match"
830- assert results_rust .lambda_nn == results_python .lambda_nn , \
831- "lambda_nn should match"
832-
833- finally :
834- # Restore original environment
835- if original_env is not None :
836- os .environ ['DIFF_DIFF_BACKEND' ] = original_env
837- else :
838- os .environ .pop ('DIFF_DIFF_BACKEND' , None )
839-
840- # Reload modules to restore Rust backend
841- import importlib
842- import diff_diff ._backend
843- import diff_diff .trop
844- importlib .reload (diff_diff ._backend )
845- importlib .reload (diff_diff .trop )
821+ results = trop .fit (df , 'outcome' , 'treated' , 'unit' , 'time' )
822+
823+ # Check results are valid
824+ assert np .isfinite (results .att ), "ATT should be finite"
825+ assert np .isfinite (results .se ), "SE should be finite"
826+ assert results .se >= 0 , "SE should be non-negative"
827+
828+ # ATT should be in reasonable range of true effect
829+ assert abs (results .att - true_effect ) < 2.0 , \
830+ f"ATT { results .att :.2f} should be close to true effect { true_effect } "
831+
832+ # Tuning parameters should be from the grid
833+ assert results .lambda_time in [0.0 , 1.0 ]
834+ assert results .lambda_unit in [0.0 , 1.0 ]
835+ assert results .lambda_nn in [0.0 , 0.1 ]
846836
847837
848838class TestFallbackWhenNoRust :
0 commit comments