@@ -256,23 +256,27 @@ def mlpipeline_plot_cm(result_path, mode, model, x_test, y_test, iteration):
256
256
return
257
257
258
258
def mlpipeline_plot_pr (result_path , mode , model , x_test , y_test , iteration ):
259
- #class_names = ['Negative Cohort','Positive Cohort']
260
259
261
260
# Since results are stored differently depending on validation type there are two versions of the code that plots the matrix
262
261
if mode == 'loo' or mode == 'cv' or mode == "cv_and_test" :
263
- # print(x_test)
264
- # print(y_test)
265
- # print(len(x_test))
266
- # print(len(y_test))
267
- #x_test, y_test = cl.prepare_dataset(x_test, y_test)
268
- # cm = confusion_matrix(x_test, y_test, labels=[0, 1])
262
+
269
263
precision , recall , _ = precision_recall_curve (x_test , y_test )
264
+ baseline = sum (x_test ) / len (x_test )
270
265
disp = PrecisionRecallDisplay (precision = precision , recall = recall )
266
+ disp = PrecisionRecallDisplay .from_predictions (x_test , y_test )
271
267
disp = disp .plot ()
268
+ plt .plot ([0 , 1 ], [baseline , baseline ], color = 'orange' , label = "Baseline ( AP = " + str (baseline ) + " )" , linestyle = '--' )
269
+ plt .ylim (0 , 1 )
272
270
else :
273
271
precision , recall , _ = precision_recall_curve (x_test , y_test )
272
+ plt .plot (0 , baseline , color = 'orange' , linestyle = '--' )
274
273
disp = PrecisionRecallDisplay (precision = precision , recall = recall )
275
274
disp = disp .plot ()
275
+ ax = plt .gca ()
276
+ baseline = sum (x_test ) / len (x_test )
277
+ plt .ylim (0 , 1 )
278
+ print (baseline )
279
+
276
280
if iteration == "hold_out" :
277
281
disp .ax_ .set_title ('Precision-Recall Curve Iteration: ' + iteration )
278
282
else :
@@ -428,7 +432,7 @@ def save_results(result_path, mode, feature_selection, classifiers, iterations,
428
432
save_roc (result_path + "/" + mode + "/" + i + "/" + classifier + "/" , fpr , tpr , thresholds )
429
433
mlpipeline_plot_cm (result_path + "/" + mode + "/" + i + "/" + classifier + "/" , mode , models , true_label , pred , 0 )
430
434
#HERE
431
- mlpipeline_plot_pr (result_path + "/" + mode + "/" + i + "/" + classifier + "/" , mode , models , true_label , pred , 0 )
435
+ mlpipeline_plot_pr (result_path + "/" + mode + "/" + i + "/" + classifier + "/" , mode , models , true_label , pred_proba , 0 )
432
436
433
437
elif mode == "cv_and_test" :
434
438
@@ -453,7 +457,7 @@ def save_results(result_path, mode, feature_selection, classifiers, iterations,
453
457
save_roc (result_path + "/" + mode + "/" + i + "/" + classifier + "/" , fpr , tpr , thresholds )
454
458
mlpipeline_plot_cm (result_path + "/" + mode + "/" + i + "/" + classifier + "/" , "cv" , models , true_label , pred , 0 )
455
459
#HERE
456
- mlpipeline_plot_pr (result_path + "/" + mode + "/" + i + "/" + classifier + "/" , "cv" , models , true_label , pred , 0 )
460
+ mlpipeline_plot_pr (result_path + "/" + mode + "/" + i + "/" + classifier + "/" , "cv" , models , true_label , pred_proba , 0 )
457
461
458
462
#hold_out
459
463
thresholds = make_thresholds (10000 )
@@ -465,7 +469,7 @@ def save_results(result_path, mode, feature_selection, classifiers, iterations,
465
469
save_roc (result_path + "/" + mode + "/" + i + "/" + classifier + "/hold_out/" , fpr , tpr , thresholds )
466
470
mlpipeline_plot_cm (result_path + "/" + mode + "/" + i + "/" + classifier + "/hold_out/" , mode , models [i ][classifier ][0 ], results ["holdout" ][i ][classifier ]['true_label' ], results ["holdout" ][i ][classifier ]['pred' ], "hold_out" )
467
471
#HERE
468
- mlpipeline_plot_pr (result_path + "/" + mode + "/" + i + "/" + classifier + "/hold_out/" , mode , models [i ][classifier ][0 ], results ["holdout" ][i ][classifier ]['true_label' ], results ["holdout" ][i ][classifier ]['pred ' ], "hold_out" )
472
+ mlpipeline_plot_pr (result_path + "/" + mode + "/" + i + "/" + classifier + "/hold_out/" , mode , models [i ][classifier ][0 ], results ["holdout" ][i ][classifier ]['true_label' ], results ["holdout" ][i ][classifier ]['pred_prob ' ], "hold_out" )
469
473
#individual_sample_report(results, datasets, i, classifier, iterations, labels, result_path + "/" + mode + "/" + i + "/" + classifier + "/")
470
474
471
475
else :
0 commit comments