Skip to content

Commit

Permalink
Merge pull request #3 from aldder/sequentialfeatureselection_earlystop
Browse files Browse the repository at this point in the history
  • Loading branch information
aldder committed Feb 4, 2022
2 parents 86a5124 + d5595a9 commit b2ad0d4
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion mlxtend/feature_selection/sequential_feature_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __init__(self, estimator, k_features=1,
if not isinstance(early_stop_rounds, int) or early_stop_rounds < 0:
raise ValueError('Number of early stopping round should be '
'an integer value greater than or equal to 0.'
'Got %d' % early_stop_rounds)
'Got %s' % early_stop_rounds)

self.early_stop_rounds = early_stop_rounds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ def test_custom_feature_names():
n_jobs=1)

sfs1 = sfs1.fit(X, y, custom_feature_names=(
'sepal length', 'sepal width', 'petal length', 'petal width'))
'sepal length', 'sepal width', 'petal length', 'petal width'))
assert sfs1.k_feature_idx_ == (1, 3)
assert sfs1.k_feature_names_ == ('sepal width', 'petal width')
assert sfs1.subsets_[2]['feature_names'] == ('sepal width',
Expand All @@ -1000,13 +1000,12 @@ def test_run_forward_earlystop():
k_features='best',
forward=True,
floating=False,
early_stop=True,
early_stop_rounds=esr,
verbose=0)
sfs.fit(X_iris_with_noise, y_iris)
assert len(sfs.subsets_) < X_iris_with_noise.shape[1]
assert all([sfs.subsets_[list(sfs.subsets_)[-esr-1]]['avg_score']
>= sfs.subsets_[i]['avg_score'] for i in sfs.subsets_.keys()])
assert all([sfs.k_score_ >= sfs.subsets_[i]['avg_score']
for i in sfs.subsets_])


def test_run_backward_earlystop():
Expand All @@ -1024,10 +1023,9 @@ def test_run_backward_earlystop():
k_features='best',
forward=False,
floating=False,
early_stop=True,
early_stop_rounds=esr,
verbose=0)
sfs.fit(X_iris_with_noise, y_iris)
assert len(sfs.subsets_) > 1
assert all([sfs.subsets_[list(sfs.subsets_)[-esr-1]]['avg_score']
>= sfs.subsets_[i]['avg_score'] for i in sfs.subsets_.keys()])
assert all([sfs.k_score_ >= sfs.subsets_[i]['avg_score']
for i in sfs.subsets_])

0 comments on commit b2ad0d4

Please sign in to comment.