@@ -1075,7 +1075,8 @@ fn compute_joint_weights(
10751075 let n_pre = n_periods. saturating_sub ( treated_periods) ;
10761076
10771077 // Compute average treated trajectory
1078- let mut average_treated = Array1 :: < f64 > :: zeros ( n_periods) ;
1078+ // Initialize to NaN so periods with all-NaN treated data stay NaN (excluded from RMSE)
1079+ let mut average_treated = Array1 :: < f64 > :: from_elem ( n_periods, f64:: NAN ) ;
10791080 if !treated_unit_idx. is_empty ( ) {
10801081 for t in 0 ..n_periods {
10811082 let mut sum = 0.0 ;
@@ -1089,6 +1090,7 @@ fn compute_joint_weights(
10891090 if count > 0 {
10901091 average_treated[ t] = sum / count as f64 ;
10911092 }
1093+ // If count == 0, average_treated[t] stays NaN (correctly excluded)
10921094 }
10931095 }
10941096
@@ -1163,7 +1165,8 @@ fn solve_joint_no_lowrank(
11631165
11641166 for t in 0 ..n_periods {
11651167 for i in 0 ..n_units {
1166- let w = delta[ [ t, i] ] ;
1168+ // NaN outcomes get zero weight (not imputed to 0.0 with active weight)
1169+ let w = if y[ [ t, i] ] . is_finite ( ) { delta[ [ t, i] ] } else { 0.0 } ;
11671170 let y_ti = if y[ [ t, i] ] . is_finite ( ) { y[ [ t, i] ] } else { 0.0 } ;
11681171
11691172 sum_w += w;
@@ -1196,7 +1199,8 @@ fn solve_joint_no_lowrank(
11961199 if sum_w_by_unit[ i] > 1e-10 {
11971200 let mut num = 0.0 ;
11981201 for t in 0 ..n_periods {
1199- let w = delta[ [ t, i] ] ;
1202+ // NaN outcomes get zero weight
1203+ let w = if y[ [ t, i] ] . is_finite ( ) { delta[ [ t, i] ] } else { 0.0 } ;
12001204 let y_ti = if y[ [ t, i] ] . is_finite ( ) { y[ [ t, i] ] } else { 0.0 } ;
12011205 num += w * ( y_ti - mu - beta[ t] - tau * d[ [ t, i] ] ) ;
12021206 }
@@ -1209,7 +1213,8 @@ fn solve_joint_no_lowrank(
12091213 if sum_w_by_period[ t] > 1e-10 {
12101214 let mut num = 0.0 ;
12111215 for i in 0 ..n_units {
1212- let w = delta[ [ t, i] ] ;
1216+ // NaN outcomes get zero weight
1217+ let w = if y[ [ t, i] ] . is_finite ( ) { delta[ [ t, i] ] } else { 0.0 } ;
12131218 let y_ti = if y[ [ t, i] ] . is_finite ( ) { y[ [ t, i] ] } else { 0.0 } ;
12141219 num += w * ( y_ti - mu - alpha[ i] - tau * d[ [ t, i] ] ) ;
12151220 }
@@ -1222,7 +1227,8 @@ fn solve_joint_no_lowrank(
12221227 let mut denom_tau = 0.0 ;
12231228 for t in 0 ..n_periods {
12241229 for i in 0 ..n_units {
1225- let w = delta[ [ t, i] ] ;
1230+ // NaN outcomes get zero weight
1231+ let w = if y[ [ t, i] ] . is_finite ( ) { delta[ [ t, i] ] } else { 0.0 } ;
12261232 let y_ti = if y[ [ t, i] ] . is_finite ( ) { y[ [ t, i] ] } else { 0.0 } ;
12271233 let d_ti = d[ [ t, i] ] ;
12281234 if d_ti > 0.5 { // Only treated observations contribute
@@ -1239,7 +1245,8 @@ fn solve_joint_no_lowrank(
12391245 let mut num_mu = 0.0 ;
12401246 for t in 0 ..n_periods {
12411247 for i in 0 ..n_units {
1242- let w = delta[ [ t, i] ] ;
1248+ // NaN outcomes get zero weight
1249+ let w = if y[ [ t, i] ] . is_finite ( ) { delta[ [ t, i] ] } else { 0.0 } ;
12431250 let y_ti = if y[ [ t, i] ] . is_finite ( ) { y[ [ t, i] ] } else { 0.0 } ;
12441251 num_mu += w * ( y_ti - alpha[ i] - beta[ t] - tau * d[ [ t, i] ] ) ;
12451252 }
@@ -1279,21 +1286,20 @@ fn solve_joint_with_lowrank(
12791286 let l_old = l. clone ( ) ;
12801287
12811288 // Step 1: Fix L, solve for (mu, alpha, beta, tau)
1282- // Adjusted outcome: Y - L
1289+ // Adjusted outcome: Y - L (preserve NaN so solve_joint_no_lowrank masks weights)
12831290 let y_adj = Array2 :: from_shape_fn ( ( n_periods, n_units) , |( t, i) | {
1284- let y_ti = if y[ [ t, i] ] . is_finite ( ) { y[ [ t, i] ] } else { 0.0 } ;
1285- y_ti - l[ [ t, i] ]
1291+ y[ [ t, i] ] - l[ [ t, i] ] // NaN - finite = NaN (preserves NaN info)
12861292 } ) ;
12871293
12881294 let ( mu, alpha, beta, tau) = solve_joint_no_lowrank ( & y_adj. view ( ) , d, delta) ?;
12891295
12901296 // Step 2: Fix (mu, alpha, beta, tau), update L
1291- // Residual: R = Y - mu - alpha - beta - tau*D
1297+ // Residual: R = Y - mu - alpha - beta - tau*D (preserve NaN)
12921298 let mut r = Array2 :: < f64 > :: zeros ( ( n_periods, n_units) ) ;
12931299 for t in 0 ..n_periods {
12941300 for i in 0 ..n_units {
1295- let y_ti = if y [ [ t , i ] ] . is_finite ( ) { y [ [ t , i ] ] } else { 0.0 } ;
1296- r[ [ t, i] ] = y_ti - mu - alpha[ i] - beta[ t] - tau * d[ [ t, i] ] ;
1301+ // NaN - finite = NaN (will be masked in gradient step)
1302+ r[ [ t, i] ] = y [ [ t , i ] ] - mu - alpha[ i] - beta[ t] - tau * d[ [ t, i] ] ;
12971303 }
12981304 }
12991305
@@ -1302,15 +1308,20 @@ fn solve_joint_with_lowrank(
13021308 let eta = if delta_max > 0.0 { 1.0 / delta_max } else { 1.0 } ;
13031309
13041310 // gradient_step = L + eta * delta * (R - L)
1311+ // NaN outcomes get zero weight so they don't affect gradient
13051312 let mut gradient_step = Array2 :: < f64 > :: zeros ( ( n_periods, n_units) ) ;
13061313 for t in 0 ..n_periods {
13071314 for i in 0 ..n_units {
1315+ // Mask delta for NaN outcomes
1316+ let delta_ti = if y[ [ t, i] ] . is_finite ( ) { delta[ [ t, i] ] } else { 0.0 } ;
13081317 let delta_norm = if delta_max > 0.0 {
1309- delta [ [ t , i ] ] / delta_max
1318+ delta_ti / delta_max
13101319 } else {
1311- delta [ [ t , i ] ]
1320+ delta_ti
13121321 } ;
1313- gradient_step[ [ t, i] ] = l[ [ t, i] ] + delta_norm * ( r[ [ t, i] ] - l[ [ t, i] ] ) ;
1322+ // r[[t,i]] may be NaN, but delta_norm=0 for NaN obs, so contribution=0
1323+ let r_contrib = if r[ [ t, i] ] . is_finite ( ) { r[ [ t, i] ] } else { 0.0 } ;
1324+ gradient_step[ [ t, i] ] = l[ [ t, i] ] + delta_norm * ( r_contrib - l[ [ t, i] ] ) ;
13141325 }
13151326 }
13161327
@@ -1324,10 +1335,9 @@ fn solve_joint_with_lowrank(
13241335 }
13251336 }
13261337
1327- // Final solve with converged L
1338+ // Final solve with converged L (preserve NaN so solve_joint_no_lowrank masks weights)
13281339 let y_adj = Array2 :: from_shape_fn ( ( n_periods, n_units) , |( t, i) | {
1329- let y_ti = if y[ [ t, i] ] . is_finite ( ) { y[ [ t, i] ] } else { 0.0 } ;
1330- y_ti - l[ [ t, i] ]
1340+ y[ [ t, i] ] - l[ [ t, i] ] // NaN - finite = NaN (preserves NaN info)
13311341 } ) ;
13321342 let ( mu, alpha, beta, tau) = solve_joint_no_lowrank ( & y_adj. view ( ) , d, delta) ?;
13331343
0 commit comments