Skip to content

Commit e5250de

Browse files
committed
Add random test for logistic regression
1 parent c79d3e8 commit e5250de

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

src/linear/logistic_regression.rs

+25
Original file line numberDiff line numberDiff line change
@@ -903,4 +903,29 @@ mod tests {
903903

904904
assert!(reg_coeff_sum < coeff);
905905
}
906+
#[cfg_attr(
907+
all(target_arch = "wasm32", not(target_os = "wasi")),
908+
wasm_bindgen_test::wasm_bindgen_test
909+
)]
910+
#[test]
911+
fn lr_fit_predict_random() {
912+
let x: DenseMatrix<f32> = DenseMatrix::rand(52181, 94);
913+
let y1: Vec<i32> = vec![1; 2181];
914+
let y2: Vec<i32> = vec![0; 50000];
915+
let y: Vec<i32> = y1.into_iter().chain(y2.into_iter()).collect();
916+
917+
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
918+
let lr_reg = LogisticRegression::fit(
919+
&x,
920+
&y,
921+
LogisticRegressionParameters::default().with_alpha(1.0),
922+
)
923+
.unwrap();
924+
925+
let y_hat = lr.predict(&x).unwrap();
926+
let y_hat_reg = lr_reg.predict(&x).unwrap();
927+
928+
assert_eq!(y.len(), y_hat.len());
929+
assert_eq!(y.len(), y_hat_reg.len());
930+
}
906931
}

0 commit comments

Comments
 (0)