[ENH] Adds a check for consistent output for predict and predict_proba #2824
@@ -374,3 +374,10 @@ def check_classifier_output(estimator, datatype):
     # check predict proba (all classifiers have predict_proba by default)
     y_proba = estimator.predict_proba(FULL_TEST_DATA_DICT[datatype]["test"][0])
     _assert_predict_probabilities(y_proba, datatype, n_classes=len(unique_labels))
+
+    y_pred_proba = np.argmax(y_proba, axis=1)
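The labels returned by predict will not necessarily equal the argmax of predict_proba, because argmax yields column indices rather than class labels. Use the classes and/or class_dictionary attribute to map between the two.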
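To illustrate the reviewer's point, here is a minimal sketch of mapping argmax indices back to class labels. The array name `classes_` follows the scikit-learn naming convention and is an assumption here; the attribute name on the estimators in this PR may differ, and the probabilities are made-up illustrative values.

```python
import numpy as np

# Hypothetical class labels, ordered like the columns of predict_proba.
# The name `classes_` is an assumption borrowed from scikit-learn conventions.
classes_ = np.array(["cat", "dog", "fish"])

# Illustrative probabilities for three test cases (rows sum to 1).
y_proba = np.array(
    [
        [0.1, 0.7, 0.2],
        [0.6, 0.3, 0.1],
        [0.2, 0.2, 0.6],
    ]
)

# argmax gives column indices (0, 1, 2, ...), not the labels themselves.
col_idx = np.argmax(y_proba, axis=1)

# Indexing the class array with those indices recovers labels that can be
# compared directly with the output of predict().
labels_from_proba = classes_[col_idx]
```

With string or non-contiguous integer labels, comparing `col_idx` directly against the output of `predict` would fail even for a correct estimator, which is why the mapping step is needed.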
@@ -374,3 +374,10 @@ def check_classifier_output(estimator, datatype):
     # check predict proba (all classifiers have predict_proba by default)
     y_proba = estimator.predict_proba(FULL_TEST_DATA_DICT[datatype]["test"][0])
     _assert_predict_probabilities(y_proba, datatype, n_classes=len(unique_labels))
+
+    y_pred_proba = np.argmax(y_proba, axis=1)
+    _assert_predict_labels(y_pred_proba, datatype, unique_labels=unique_labels)
Is line 379 necessary?
I just added it because we were checking both `y_pred` and `y_proba`. But since `y_pred_proba` is derived from `y_proba`, I don't think it is necessary.
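One way the consistency check discussed in this thread could look, as a sketch only and not the code merged in this PR: compare the labels from `predict` with the labels implied by `predict_proba`, skipping rows where the top probability is tied. The helper name `assert_predict_matches_proba` and the `classes` argument are hypothetical, and the tie-skipping rule is an assumption about how ties should be treated.

```python
import numpy as np


def assert_predict_matches_proba(y_pred, y_proba, classes):
    """Assert that predicted labels agree with the most probable class.

    Hypothetical helper: `classes` holds the labels in the column order
    of `y_proba`.
    """
    y_pred = np.asarray(y_pred)
    y_proba = np.asarray(y_proba)
    classes = np.asarray(classes)

    # Map argmax column indices back to class labels.
    labels_from_proba = classes[np.argmax(y_proba, axis=1)]

    # Rows with a tied maximum are skipped (an assumption): an estimator
    # may break ties differently from argmax.
    top = y_proba.max(axis=1, keepdims=True)
    untied = np.isclose(y_proba, top).sum(axis=1) == 1

    mismatches = np.flatnonzero(untied & (y_pred != labels_from_proba))
    assert mismatches.size == 0, (
        f"predict and predict_proba disagree at rows {mismatches.tolist()}"
    )


# Illustrative data: the second row is a tie, so it is not compared.
classes = np.array([2, 5, 9])
y_proba = np.array([[0.8, 0.1, 0.1], [0.5, 0.5, 0.0], [0.1, 0.2, 0.7]])
assert_predict_matches_proba(np.array([2, 5, 9]), y_proba, classes)
```

Because the labels derived from `y_proba` are drawn from the class array by construction, a separate label-validity assertion on them adds little, which matches the conclusion reached above.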
Reference Issues/PRs
Fixed #2802
What does this implement/fix? Explain your changes.
Adds a check to ensure that the outputs produced by `predict()` and `predict_proba()` are consistent for the `classification` and `clustering` estimators.
Does your contribution introduce a new dependency? If yes, which one?
No