Skip to content
This repository was archived by the owner on Dec 9, 2022. It is now read-only.

Commit 1d2e7c3

Browse files
committed
added pr util
1 parent 60234fd commit 1d2e7c3

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

flask_app/utils.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from sklearn import svm, datasets
1010
from sklearn.model_selection import train_test_split
11-
from sklearn.metrics import confusion_matrix
11+
from sklearn.metrics import confusion_matrix, precision_recall_curve
1212
from sklearn.utils.multiclass import unique_labels
1313

1414

@@ -119,4 +119,37 @@ def get_probabilities(self, body:str, title:str):
119119
# get predictions
120120
probs = self.model.predict(x=[vec_body, vec_title]).tolist()[0]
121121

122-
return {k:v for k,v in zip(self.class_names, probs)}
122+
return {k:v for k,v in zip(self.class_names, probs)}
123+
124+
125+
def plot_precision_recall_vs_threshold(y, y_hat, class_names, precision_threshold):
126+
"plot precision recall curves focused on precision."
127+
# credit: https://github.com/ageron/handson-ml/blob/master/03_classification.ipynb
128+
assert len(class_names)-1 <= y_hat.shape[-1], 'number of class names must equal number of classes in the data'
129+
assert y.shape == y_hat.shape, 'shape of ground_truth and predictions must be the same.'
130+
131+
for class_name in class_names:
132+
class_int = class_names.index(class_name)
133+
precisions, recalls, thresholds = precision_recall_curve(y[:, class_int], y_hat[:, class_int])
134+
135+
# get the first index of the precision that meets the threshold
136+
precision_idx = np.argmax(precisions >= precision_threshold)
137+
# find the exact probability at that threshold
138+
prob_thresh = thresholds[precision_idx]
139+
# find the exact recall at that threshold
140+
recall_at_thresh = recalls[precision_idx]
141+
142+
plt.figure(figsize=(8, 4))
143+
plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
144+
plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
145+
plt.axhline(y=precision_threshold, label=f'{precision_threshold:.2f}', linewidth=1)
146+
plt.xlabel("Threshold", fontsize=11)
147+
plt.legend(loc="lower left", fontsize=10)
148+
plt.title(f'Precision vs. Recall For Label: {class_name}')
149+
plt.ylim([0, 1])
150+
plt.xlim([0, 1])
151+
plt.show()
152+
print(f'Label "{class_name}" @ {precision_threshold:.2f} precision:')
153+
print(f' Cutoff: {prob_thresh:.2f}')
154+
print(f' Recall: {recall_at_thresh:.2f}')
155+
print('\n')

0 commit comments

Comments
 (0)