Skip to content

Commit 6de04d0

Browse files
committed
fixed precision recall plots, added baseline curve
1 parent b1f1fcd commit 6de04d0

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

validation.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -256,23 +256,27 @@ def mlpipeline_plot_cm(result_path, mode, model, x_test, y_test, iteration):
256256
return
257257

258258
def mlpipeline_plot_pr(result_path, mode, model, x_test, y_test, iteration):
259-
#class_names = ['Negative Cohort','Positive Cohort']
260259

261260
# Since results are stored differently depending on the validation type, there are two versions of the code that plots the precision-recall curve
262261
if mode == 'loo' or mode == 'cv' or mode == "cv_and_test":
263-
# print(x_test)
264-
# print(y_test)
265-
# print(len(x_test))
266-
# print(len(y_test))
267-
#x_test, y_test = cl.prepare_dataset(x_test, y_test)
268-
# cm = confusion_matrix(x_test, y_test, labels=[0, 1])
262+
269263
precision, recall, _ = precision_recall_curve(x_test, y_test)
264+
baseline = sum(x_test) / len(x_test)
270265
disp = PrecisionRecallDisplay(precision=precision, recall=recall)
266+
disp = PrecisionRecallDisplay.from_predictions(x_test, y_test)
271267
disp = disp.plot()
268+
plt.plot([0, 1], [baseline, baseline], color='orange', label= "Baseline ( AP = " + str(baseline) + " )", linestyle='--')
269+
plt.ylim(0, 1)
272270
else:
273271
precision, recall, _ = precision_recall_curve(x_test, y_test)
272+
plt.plot(0, baseline, color='orange', linestyle='--')
274273
disp = PrecisionRecallDisplay(precision=precision, recall=recall)
275274
disp = disp.plot()
275+
ax = plt.gca()
276+
baseline = sum(x_test) / len(x_test)
277+
plt.ylim(0, 1)
278+
print(baseline)
279+
276280
if iteration == "hold_out":
277281
disp.ax_.set_title('Precision-Recall Curve Iteration: ' + iteration)
278282
else:
@@ -428,7 +432,7 @@ def save_results(result_path, mode, feature_selection, classifiers, iterations,
428432
save_roc(result_path + "/" + mode + "/" + i + "/" + classifier + "/", fpr, tpr, thresholds)
429433
mlpipeline_plot_cm(result_path + "/" + mode + "/" + i + "/" + classifier + "/", mode, models, true_label, pred, 0)
430434
#HERE
431-
mlpipeline_plot_pr(result_path + "/" + mode + "/" + i + "/" + classifier + "/", mode, models, true_label, pred, 0)
435+
mlpipeline_plot_pr(result_path + "/" + mode + "/" + i + "/" + classifier + "/", mode, models, true_label, pred_proba, 0)
432436

433437
elif mode == "cv_and_test":
434438

@@ -453,7 +457,7 @@ def save_results(result_path, mode, feature_selection, classifiers, iterations,
453457
save_roc(result_path + "/" + mode + "/" + i + "/" + classifier + "/", fpr, tpr, thresholds)
454458
mlpipeline_plot_cm(result_path + "/" + mode + "/" + i + "/" + classifier + "/", "cv", models, true_label, pred, 0)
455459
#HERE
456-
mlpipeline_plot_pr(result_path + "/" + mode + "/" + i + "/" + classifier + "/", "cv", models, true_label, pred, 0)
460+
mlpipeline_plot_pr(result_path + "/" + mode + "/" + i + "/" + classifier + "/", "cv", models, true_label, pred_proba, 0)
457461

458462
#hold_out
459463
thresholds = make_thresholds(10000)
@@ -465,7 +469,7 @@ def save_results(result_path, mode, feature_selection, classifiers, iterations,
465469
save_roc(result_path + "/" + mode + "/" + i + "/" + classifier + "/hold_out/", fpr, tpr, thresholds)
466470
mlpipeline_plot_cm(result_path + "/" + mode + "/" + i + "/" + classifier + "/hold_out/", mode, models[i][classifier][0], results["holdout"][i][classifier]['true_label'], results["holdout"][i][classifier]['pred'], "hold_out")
467471
#HERE
468-
mlpipeline_plot_pr(result_path + "/" + mode + "/" + i + "/" + classifier + "/hold_out/", mode, models[i][classifier][0], results["holdout"][i][classifier]['true_label'], results["holdout"][i][classifier]['pred'], "hold_out")
472+
mlpipeline_plot_pr(result_path + "/" + mode + "/" + i + "/" + classifier + "/hold_out/", mode, models[i][classifier][0], results["holdout"][i][classifier]['true_label'], results["holdout"][i][classifier]['pred_prob'], "hold_out")
469473
#individual_sample_report(results, datasets, i, classifier, iterations, labels, result_path + "/" + mode + "/" + i + "/" + classifier + "/")
470474

471475
else:

0 commit comments

Comments
 (0)