|
| 1 | +use crate::{ |
| 2 | + api::Predictor, |
| 3 | + error::{Failed, FailedError}, |
| 4 | + linalg::Matrix, |
| 5 | + math::num::RealNumber, |
| 6 | +}; |
| 7 | + |
| 8 | +use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult}; |
| 9 | + |
1 | 10 | /// grid search results. |
2 | 11 | #[derive(Clone, Debug)] |
3 | 12 | pub struct GridSearchResult<T: RealNumber, I: Clone> { |
@@ -60,58 +69,61 @@ where |
60 | 69 |
|
61 | 70 | #[cfg(test)] |
62 | 71 | mod tests { |
63 | | - use crate::linear::logistic_regression::{ |
64 | | - LogisticRegression, LogisticRegressionSearchParameters, |
65 | | -}; |
| 72 | + use crate::{ |
| 73 | + linalg::naive::dense_matrix::DenseMatrix, |
| 74 | + linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters}, |
| 75 | + metrics::accuracy, |
| 76 | + model_selection::{hyper_tuning::grid_search, KFold}, |
| 77 | + }; |
66 | 78 |
|
67 | | - #[test] |
68 | | - fn test_grid_search() { |
69 | | - let x = DenseMatrix::from_2d_array(&[ |
70 | | - &[5.1, 3.5, 1.4, 0.2], |
71 | | - &[4.9, 3.0, 1.4, 0.2], |
72 | | - &[4.7, 3.2, 1.3, 0.2], |
73 | | - &[4.6, 3.1, 1.5, 0.2], |
74 | | - &[5.0, 3.6, 1.4, 0.2], |
75 | | - &[5.4, 3.9, 1.7, 0.4], |
76 | | - &[4.6, 3.4, 1.4, 0.3], |
77 | | - &[5.0, 3.4, 1.5, 0.2], |
78 | | - &[4.4, 2.9, 1.4, 0.2], |
79 | | - &[4.9, 3.1, 1.5, 0.1], |
80 | | - &[7.0, 3.2, 4.7, 1.4], |
81 | | - &[6.4, 3.2, 4.5, 1.5], |
82 | | - &[6.9, 3.1, 4.9, 1.5], |
83 | | - &[5.5, 2.3, 4.0, 1.3], |
84 | | - &[6.5, 2.8, 4.6, 1.5], |
85 | | - &[5.7, 2.8, 4.5, 1.3], |
86 | | - &[6.3, 3.3, 4.7, 1.6], |
87 | | - &[4.9, 2.4, 3.3, 1.0], |
88 | | - &[6.6, 2.9, 4.6, 1.3], |
89 | | - &[5.2, 2.7, 3.9, 1.4], |
90 | | - ]); |
91 | | - let y = vec![ |
92 | | - 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., |
93 | | - ]; |
| 79 | + #[test] |
| 80 | + fn test_grid_search() { |
| 81 | + let x = DenseMatrix::from_2d_array(&[ |
| 82 | + &[5.1, 3.5, 1.4, 0.2], |
| 83 | + &[4.9, 3.0, 1.4, 0.2], |
| 84 | + &[4.7, 3.2, 1.3, 0.2], |
| 85 | + &[4.6, 3.1, 1.5, 0.2], |
| 86 | + &[5.0, 3.6, 1.4, 0.2], |
| 87 | + &[5.4, 3.9, 1.7, 0.4], |
| 88 | + &[4.6, 3.4, 1.4, 0.3], |
| 89 | + &[5.0, 3.4, 1.5, 0.2], |
| 90 | + &[4.4, 2.9, 1.4, 0.2], |
| 91 | + &[4.9, 3.1, 1.5, 0.1], |
| 92 | + &[7.0, 3.2, 4.7, 1.4], |
| 93 | + &[6.4, 3.2, 4.5, 1.5], |
| 94 | + &[6.9, 3.1, 4.9, 1.5], |
| 95 | + &[5.5, 2.3, 4.0, 1.3], |
| 96 | + &[6.5, 2.8, 4.6, 1.5], |
| 97 | + &[5.7, 2.8, 4.5, 1.3], |
| 98 | + &[6.3, 3.3, 4.7, 1.6], |
| 99 | + &[4.9, 2.4, 3.3, 1.0], |
| 100 | + &[6.6, 2.9, 4.6, 1.3], |
| 101 | + &[5.2, 2.7, 3.9, 1.4], |
| 102 | + ]); |
| 103 | + let y = vec![ |
| 104 | + 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., |
| 105 | + ]; |
94 | 106 |
|
95 | | - let cv = KFold { |
96 | | - n_splits: 5, |
97 | | - ..KFold::default() |
98 | | - }; |
| 107 | + let cv = KFold { |
| 108 | + n_splits: 5, |
| 109 | + ..KFold::default() |
| 110 | + }; |
99 | 111 |
|
100 | | - let parameters = LogisticRegressionSearchParameters { |
101 | | - alpha: vec![0., 1.], |
102 | | - ..Default::default() |
103 | | - }; |
| 112 | + let parameters = LogisticRegressionSearchParameters { |
| 113 | + alpha: vec![0., 1.], |
| 114 | + ..Default::default() |
| 115 | + }; |
104 | 116 |
|
105 | | - let results = grid_search( |
106 | | - LogisticRegression::fit, |
107 | | - &x, |
108 | | - &y, |
109 | | - parameters.into_iter(), |
110 | | - cv, |
111 | | - &accuracy, |
112 | | - ) |
113 | | - .unwrap(); |
| 117 | + let results = grid_search( |
| 118 | + LogisticRegression::fit, |
| 119 | + &x, |
| 120 | + &y, |
| 121 | + parameters.into_iter(), |
| 122 | + cv, |
| 123 | + &accuracy, |
| 124 | + ) |
| 125 | + .unwrap(); |
114 | 126 |
|
115 | | - assert!([0., 1.].contains(&results.parameters.alpha)); |
116 | | - } |
| 127 | + assert!([0., 1.].contains(&results.parameters.alpha)); |
| 128 | + } |
117 | 129 | } |
0 commit comments