8
8
9
9
from sklearn import svm , datasets
10
10
from sklearn .model_selection import train_test_split
11
- from sklearn .metrics import confusion_matrix
11
+ from sklearn .metrics import confusion_matrix , precision_recall_curve
12
12
from sklearn .utils .multiclass import unique_labels
13
13
14
14
@@ -119,4 +119,37 @@ def get_probabilities(self, body:str, title:str):
119
119
# get predictions
120
120
probs = self .model .predict (x = [vec_body , vec_title ]).tolist ()[0 ]
121
121
122
- return {k :v for k ,v in zip (self .class_names , probs )}
122
+ return {k :v for k ,v in zip (self .class_names , probs )}
123
+
124
+
125
+ def plot_precision_recall_vs_threshold (y , y_hat , class_names , precision_threshold ):
126
+ "plot precision recall curves focused on precision."
127
+ # credit: https://github.com/ageron/handson-ml/blob/master/03_classification.ipynb
128
+ assert len (class_names )- 1 <= y_hat .shape [- 1 ], 'number of class names must equal number of classes in the data'
129
+ assert y .shape == y_hat .shape , 'shape of ground_truth and predictions must be the same.'
130
+
131
+ for class_name in class_names :
132
+ class_int = class_names .index (class_name )
133
+ precisions , recalls , thresholds = precision_recall_curve (y [:, class_int ], y_hat [:, class_int ])
134
+
135
+ # get the first index of the precision that meets the threshold
136
+ precision_idx = np .argmax (precisions >= precision_threshold )
137
+ # find the exact probability at that threshold
138
+ prob_thresh = thresholds [precision_idx ]
139
+ # find the exact recall at that threshold
140
+ recall_at_thresh = recalls [precision_idx ]
141
+
142
+ plt .figure (figsize = (8 , 4 ))
143
+ plt .plot (thresholds , precisions [:- 1 ], "b--" , label = "Precision" , linewidth = 2 )
144
+ plt .plot (thresholds , recalls [:- 1 ], "g-" , label = "Recall" , linewidth = 2 )
145
+ plt .axhline (y = precision_threshold , label = f'{ precision_threshold :.2f} ' , linewidth = 1 )
146
+ plt .xlabel ("Threshold" , fontsize = 11 )
147
+ plt .legend (loc = "lower left" , fontsize = 10 )
148
+ plt .title (f'Precision vs. Recall For Label: { class_name } ' )
149
+ plt .ylim ([0 , 1 ])
150
+ plt .xlim ([0 , 1 ])
151
+ plt .show ()
152
+ print (f'Label "{ class_name } " @ { precision_threshold :.2f} precision:' )
153
+ print (f' Cutoff: { prob_thresh :.2f} ' )
154
+ print (f' Recall: { recall_at_thresh :.2f} ' )
155
+ print ('\n ' )
0 commit comments