Skip to content

Commit 1c642ce

Browse files
committed
remove default random forest as estimator for InstanceHardnessCV
1 parent 636dc5b commit 1c642ce

File tree

4 files changed

+12
-11
lines changed

4 files changed

+12
-11
lines changed

doc/cross_validation.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ Now, we do the same using an `InstanceHardnessCV` splitter. We use provide our
101101
classifier to the splitter to calculate instance hardness and distribute samples
102102
with large instance hardness equally over the folds.
103103

104-
>>> ih_cv = InstanceHardnessCV(n_splits=5, estimator=clf,
104+
>>> ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5,
105105
... random_state=random_state)
106106
>>> ih_result = cross_validate(clf, X, y, cv=ih_cv, scoring="average_precision")
107107

examples/cross_validation/plot_instance_hardness_cv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
skf_result = cross_validate(clf, X, y, cv=skf_cv, scoring="average_precision")
6666

6767
# %%
68-
ih_cv = InstanceHardnessCV(n_splits=5, estimator=clf, random_state=10)
68+
ih_cv = InstanceHardnessCV(estimator=clf, n_splits=5, random_state=10)
6969
ih_result = cross_validate(clf, X, y, cv=ih_cv, scoring="average_precision")
7070

7171
# %%

imblearn/cross_validation/_cross_validation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ class InstanceHardnessCV:
1313
1414
Parameters
1515
----------
16+
estimator : estimator object
17+
Classifier to be used to estimate instance hardness of the samples.
18+
This classifier should implement `predict_proba`.
19+
1620
n_splits : int, default=5
1721
Number of folds. Must be at least 2.
1822
19-
estimator : classifier, default=None
20-
Classifier used to determine instance hardness. Defaults to
21-
RandomForestClassifier when set to `None`
22-
2323
random_state : int, RandomState instance, default=None
2424
Determines random_state for reproducible results across multiple calls.
2525
@@ -31,14 +31,14 @@ class InstanceHardnessCV:
3131
>>> from sklearn.linear_model import LogisticRegression
3232
>>> X, y = make_classification(weights=[0.9, 0.1], class_sep=2,
3333
... n_informative=3, n_redundant=1, flip_y=0.05, n_samples=1000, random_state=10)
34-
>>> ih_cv = InstanceHardnessCV(n_splits=5, random_state=10)
3534
>>> estimator = LogisticRegression(random_state=10)
35+
>>> ih_cv = InstanceHardnessCV(estimator=estimator, n_splits=5,random_state=10)
3636
>>> cv_result = cross_validate(estimator, X, y, cv=ih_cv)
3737
>>> print(f"Standard deviation of test_scores: {cv_result['test_score'].std():.3f}")
3838
Standard deviation of test_scores: 0.004
3939
"""
4040

41-
def __init__(self, n_splits=5, estimator=None, random_state=None):
41+
def __init__(self, estimator, n_splits=5, random_state=None):
4242
self.n_splits = n_splits
4343
self.estimator = estimator
4444
self.random_state = random_state

imblearn/cross_validation/tests/test_instance_hardness.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919

2020

2121
def test_instancehardness_cv():
22-
ih_cv = InstanceHardnessCV(random_state=10)
2322
clf = LogisticRegression(random_state=10)
23+
ih_cv = InstanceHardnessCV(estimator=clf, random_state=10)
2424
cv_result = cross_validate(clf, X, y, cv=ih_cv)
25-
assert_array_equal(cv_result['test_score'], [0.965, 0.965, 0.96, 0.965, 0.955])
25+
assert_array_equal(cv_result['test_score'], [0.975, 0.965, 0.96, 0.955, 0.965])
2626

2727

2828
@pytest.mark.parametrize("n_splits", [2, 3, 4])
2929
def test_instancehardness_cv_n_splits(n_splits):
30-
ih_cv = InstanceHardnessCV(n_splits=n_splits, random_state=10)
30+
clf = LogisticRegression(random_state=10)
31+
ih_cv = InstanceHardnessCV(estimator=clf, n_splits=n_splits, random_state=10)
3132
assert ih_cv.get_n_splits() == n_splits

0 commit comments

Comments
 (0)