Skip to content

Commit 63322c1

Browse files
committed
Add test for logistic regression
1 parent b734e8c commit 63322c1

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

src/linear/logistic_regression.rs

+18
Original file line numberDiff line numberDiff line change
@@ -931,4 +931,22 @@ mod tests {
931931
assert_eq!(y.len(), y_hat.len());
932932
assert_eq!(y.len(), y_hat_reg.len());
933933
}
934+
935+
#[test]
936+
fn test_logit() {
937+
let x: &DenseMatrix<f64> = &DenseMatrix::rand(52181, 94);
938+
let y1: Vec<u32> = vec![1; 2181];
939+
let y2: Vec<u32> = vec![0; 50000];
940+
let y: &Vec<u32> = &(y1.into_iter().chain(y2.into_iter()).collect());
941+
println!("y vec height: {:?}", y.len());
942+
println!("x matrix shape: {:?}", x.shape());
943+
944+
let lr = LogisticRegression::fit(x, y, Default::default()).unwrap();
945+
let y_hat = lr.predict(&x).unwrap();
946+
947+
println!("y_hat shape: {:?}", y_hat.shape());
948+
949+
assert_eq!(y_hat.shape(), 52181);
950+
951+
}
934952
}

0 commit comments

Comments
 (0)