We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b734e8c commit 63322c1Copy full SHA for 63322c1
src/linear/logistic_regression.rs
@@ -931,4 +931,22 @@ mod tests {
931
assert_eq!(y.len(), y_hat.len());
932
assert_eq!(y.len(), y_hat_reg.len());
933
}
934
+
935
+ #[test]
936
+ fn test_logit() {
937
+ let x: &DenseMatrix<f64> = &DenseMatrix::rand(52181, 94);
938
+ let y1: Vec<u32> = vec![1; 2181];
939
+ let y2: Vec<u32> = vec![0; 50000];
940
+ let y: &Vec<u32> = &(y1.into_iter().chain(y2.into_iter()).collect());
941
+ println!("y vec height: {:?}", y.len());
942
+ println!("x matrix shape: {:?}", x.shape());
943
944
+ let lr = LogisticRegression::fit(x, y, Default::default()).unwrap();
945
+ let y_hat = lr.predict(&x).unwrap();
946
947
+ println!("y_hat shape: {:?}", y_hat.shape());
948
949
+ assert_eq!(y_hat.shape(), 52181);
950
951
+ }
952
0 commit comments