Skip to content

Commit 1e811ce

Browse files
authored
Skip _check_param_grid import for scikit-learn > 1.0.2 (#901)
* Only import _check_param_grid for sklearn <= 1.0.2 (test-upstream) * Call _get_param_iterator in bad param grid tests (test-upstream) * Resolve different errors raised in > 1.0.2 (test-upstream)
1 parent 5466bec commit 1e811ce

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

dask_ml/model_selection/_search.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
is_classifier,
2424
)
2525
from sklearn.exceptions import NotFittedError
26-
from sklearn.model_selection._search import BaseSearchCV, _check_param_grid
26+
from sklearn.model_selection._search import BaseSearchCV
2727
from sklearn.model_selection._split import (
2828
BaseShuffleSplit,
2929
KFold,
@@ -71,6 +71,12 @@
7171

7272
__all__ = ["GridSearchCV", "RandomizedSearchCV"]
7373

74+
# scikit-learn > 1.0.2 removed _check_param_grid
75+
if SK_VERSION <= packaging.version.parse("1.0.2"):
76+
from sklearn.model_selection._search import _check_param_grid
77+
else:
78+
_check_param_grid = None
79+
7480
if SK_VERSION <= packaging.version.parse("0.21.dev0"):
7581

7682
_RETURN_TRAIN_SCORE_DEFAULT = "warn"
@@ -1600,8 +1606,8 @@ def __init__(
16001606
n_jobs=n_jobs,
16011607
cache_cv=cache_cv,
16021608
)
1603-
1604-
_check_param_grid(param_grid)
1609+
if _check_param_grid:
1610+
_check_param_grid(param_grid)
16051611
self.param_grid = param_grid
16061612

16071613
def _get_param_iterator(self):

tests/model_selection/dask_searchcv/test_model_selection_sklearn.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -371,28 +371,34 @@ def test_grid_search_one_grid_point():
371371

372372

373373
def test_grid_search_bad_param_grid():
374+
# passing a non-iterable param grid raises a TypeError in scikit-learn > 1.0.2
375+
iterable_err = (
376+
ValueError if SK_VERSION <= packaging.version.parse("1.0.2") else TypeError
377+
)
378+
374379
param_dict = {"C": 1.0}
375380
clf = SVC()
376381

377-
with pytest.raises(ValueError):
378-
dcv.GridSearchCV(clf, param_dict)
382+
with pytest.raises(iterable_err):
383+
dcv.GridSearchCV(clf, param_dict)._get_param_iterator()
379384

380385
param_dict = {"C": []}
381386
clf = SVC()
382387

383388
with pytest.raises(ValueError):
384-
dcv.GridSearchCV(clf, param_dict)
389+
dcv.GridSearchCV(clf, param_dict)._get_param_iterator()
385390

386391
param_dict = {"C": "1,2,3"}
387392
clf = SVC()
388393

389-
with pytest.raises(ValueError):
390-
dcv.GridSearchCV(clf, param_dict)
394+
with pytest.raises(iterable_err):
395+
dcv.GridSearchCV(clf, param_dict)._get_param_iterator()
391396

392397
param_dict = {"C": np.ones(6).reshape(3, 2)}
393398
clf = SVC()
399+
394400
with pytest.raises(ValueError):
395-
dcv.GridSearchCV(clf, param_dict)
401+
dcv.GridSearchCV(clf, param_dict)._get_param_iterator()
396402

397403

398404
def test_grid_search_sparse():

0 commit comments

Comments
 (0)