1
+ import numpy as np
2
+ from sklearn .tree import DecisionTreeClassifier
3
+ from sklearn .model_selection import train_test_split
4
+ from sklearn .model_selection import StratifiedKFold
5
+ from sklearn .model_selection import LeaveOneOut
6
+ from sklearn .metrics import confusion_matrix
7
+ from sklearn .metrics import classification_report
8
+
9
+
10
+ def jho (feat , label , opts ):
11
+ ho = 0.3 # ratio of testing set
12
+
13
+ if 'ho' in opts :
14
+ ho = opts ['ho' ]
15
+
16
+ # number of instances
17
+ num_data = np .size (feat , 0 )
18
+ label = label .reshape (num_data ) # Solve bug
19
+
20
+ # prepare data
21
+ xtrain , xtest , ytrain , ytest = train_test_split (feat , label , test_size = ho , stratify = label )
22
+ # train model
23
+ mdl = DecisionTreeClassifier (criterion = "gini" )
24
+ mdl .fit (xtrain , ytrain )
25
+
26
+ # prediction
27
+ ypred = mdl .predict (xtest )
28
+ # confusion matric
29
+ uni = np .unique (ytest )
30
+ confmat = confusion_matrix (ytest , ypred , labels = uni )
31
+ # report
32
+ report = classification_report (ytest , ypred )
33
+ # accuracy
34
+ acc = np .sum (ytest == ypred ) / np .size (xtest ,0 )
35
+
36
+ print ("Accuracy (DT_HO):" , 100 * acc )
37
+
38
+ dt = {'acc' : acc , 'con' : confmat , 'r' : report }
39
+
40
+ return dt
41
+
42
+
43
+ def jkfold (feat , label , opts ):
44
+ kfold = 10 # number of k in kfold
45
+
46
+ if 'kfold' in opts :
47
+ kfold = opts ['kfold' ]
48
+
49
+ # number of instances
50
+ num_data = np .size (feat , 0 )
51
+ # define selected features
52
+ x_data = feat
53
+ y_data = label .reshape (num_data ) # Solve bug
54
+
55
+ fold = StratifiedKFold (n_splits = kfold )
56
+ fold .get_n_splits (x_data , y_data )
57
+
58
+ ytest2 = []
59
+ ypred2 = []
60
+ Afold2 = []
61
+ for train_idx , test_idx in fold .split (x_data , y_data ):
62
+ xtrain = x_data [train_idx ,:]
63
+ ytrain = y_data [train_idx ]
64
+ xtest = x_data [test_idx ,:]
65
+ ytest = y_data [test_idx ]
66
+ # train model
67
+ mdl = DecisionTreeClassifier (criterion = "gini" )
68
+ mdl .fit (xtrain , ytrain )
69
+ # prediction
70
+ ypred = mdl .predict (xtest )
71
+ # accuracy
72
+ Afold = np .sum (ytest == ypred ) / np .size (xtest ,0 )
73
+
74
+ ytest2 = np .concatenate ((ytest2 , ytest ), axis = 0 )
75
+ ypred2 = np .concatenate ((ypred2 , ypred ), axis = 0 )
76
+ Afold2 .append (Afold )
77
+
78
+ # average accuracy
79
+ Afold2 = np .array (Afold2 )
80
+ acc = np .mean (Afold2 )
81
+ # confusion matric
82
+ uni = np .unique (ytest2 )
83
+ confmat = confusion_matrix (ytest2 , ypred2 , labels = uni )
84
+ # report
85
+ report = classification_report (ytest2 , ypred2 )
86
+
87
+ print ("Accuracy (DT_K-fold):" , 100 * acc )
88
+
89
+ dt = {'acc' : acc , 'con' : confmat , 'r' : report }
90
+
91
+ return dt
92
+
93
+
94
+ def jloo (feat , label , opts ):
95
+
96
+ # number of instances
97
+ num_data = np .size (feat , 0 )
98
+ # define selected features
99
+ x_data = feat
100
+ y_data = label .reshape (num_data ) # Solve bug
101
+
102
+ loo = LeaveOneOut ()
103
+ loo .get_n_splits (x_data )
104
+
105
+ ytest2 = []
106
+ ypred2 = []
107
+ Afold2 = []
108
+ for train_idx , test_idx in loo .split (x_data ):
109
+ xtrain = x_data [train_idx ,:]
110
+ ytrain = y_data [train_idx ]
111
+ xtest = x_data [test_idx ,:]
112
+ ytest = y_data [test_idx ]
113
+ # train model
114
+ mdl = DecisionTreeClassifier (criterion = "gini" )
115
+ mdl .fit (xtrain , ytrain )
116
+ # prediction
117
+ ypred = mdl .predict (xtest )
118
+ # accuracy
119
+ Afold = np .sum (ytest == ypred ) / np .size (xtest ,0 )
120
+
121
+ ytest2 = np .concatenate ((ytest2 , ytest ), axis = 0 )
122
+ ypred2 = np .concatenate ((ypred2 , ypred ), axis = 0 )
123
+ Afold2 .append (Afold )
124
+
125
+ # average accuracy
126
+ Afold2 = np .array (Afold2 )
127
+ acc = np .mean (Afold2 )
128
+ # confusion matric
129
+ uni = np .unique (ytest2 )
130
+ confmat = confusion_matrix (ytest2 , ypred2 , labels = uni )
131
+ # report
132
+ report = classification_report (ytest2 , ypred2 )
133
+
134
+ print ("Accuracy (DT_LOO):" , 100 * acc )
135
+
136
+ dt = {'acc' : acc , 'con' : confmat , 'r' : report }
137
+
138
+ return dt
139
+
140
+
141
+
142
+
143
+
144
+
0 commit comments