Skip to content

Commit bd55f6e

Browse files
authored
Merge pull request #409 from bsc-wdc/gs_sklearn_estimators_merge
Ready to merge sklearn estimators in GridSearch
2 parents d4d98e1 + a48262c commit bd55f6e

File tree

3 files changed

+109
-3
lines changed

3 files changed

+109
-3
lines changed

dislib/model_selection/_search.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
from dislib.model_selection._split import infer_cv
1515
from dislib.model_selection._validation import check_scorer, \
16-
validate_score, aggregate_score_dicts, fit, score_func
16+
validate_score, aggregate_score_dicts, fit, score_func, \
17+
sklearn_fit, sklearn_score
1718

1819

1920
class BaseSearchCV(ABC):
@@ -55,6 +56,29 @@ def fit(self, x, y=None, **fit_params):
5556
all_candidate_params = []
5657
all_out = []
5758

59+
def evaluate_candidates_sklearn(candidate_params):
60+
"""Evaluate some parameters"""
61+
candidate_params = list(candidate_params)
62+
63+
validation_data = []
64+
fits = []
65+
for parameters, (train, validation) in product(candidate_params,
66+
cv.split(x, y)):
67+
validation_data.append(validation)
68+
fits.append(sklearn_fit(clone(base_estimator), train,
69+
parameters=parameters,
70+
fit_params=fit_params))
71+
out = [sklearn_score(estimator, validation, scorer=scorers) for
72+
estimator, validation in zip(fits, validation_data)]
73+
74+
out = compss_wait_on(out)
75+
76+
nonlocal n_splits
77+
n_splits = cv.get_n_splits()
78+
79+
all_candidate_params.extend(candidate_params)
80+
all_out.extend(out)
81+
5882
def evaluate_candidates(candidate_params):
5983
"""Evaluate some parameters"""
6084
candidate_params = list(candidate_params)
@@ -75,8 +99,10 @@ def evaluate_candidates(candidate_params):
7599

76100
all_candidate_params.extend(candidate_params)
77101
all_out.extend(out)
78-
79-
self._run_search(evaluate_candidates)
102+
if 'sklearn' in str(type(estimator)):
103+
self._run_search(evaluate_candidates_sklearn)
104+
else:
105+
self._run_search(evaluate_candidates)
80106

81107
for params_result in all_out:
82108
scores = params_result[0]
@@ -110,6 +136,9 @@ def evaluate_candidates(candidate_params):
110136
if self.refit:
111137
self.best_estimator_ = clone(base_estimator).set_params(
112138
**self.best_params_)
139+
if 'sklearn' in str(type(estimator)):
140+
x = x.collect()
141+
y = y.collect()
113142
self.best_estimator_.fit(x, y, **fit_params)
114143

115144
# Store the only scorer not as a dict for single metric evaluation

dislib/model_selection/_validation.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import numbers
22

3+
from dislib.data.array import Array
4+
from pycompss.api.task import task
5+
from pycompss.api.parameter import INOUT, Depth, Type, COLLECTION_IN
6+
37
import numpy as np
48

59

@@ -18,6 +22,41 @@ def score_func(estimator, validation_ds, scorer):
1822
return [test_scores]
1923

2024

25+
@task(est=INOUT, blocks_x={Type: COLLECTION_IN, Depth: 2},
      blocks_y={Type: COLLECTION_IN, Depth: 2})
def fit_sklearn_estimator(est, blocks_x, blocks_y, **fit_params):
    """Task: merge the sample and label blocks and fit ``est`` on them.

    ``blocks_x`` and ``blocks_y`` are depth-2 block collections that are
    merged with ``Array._merge_blocks`` before calling ``est.fit``.
    Extra keyword arguments are forwarded to ``est.fit`` unchanged.
    Returns whatever ``est.fit`` returns.
    """
    samples = Array._merge_blocks(blocks_x)
    labels = Array._merge_blocks(blocks_y)
    return est.fit(samples, labels, **fit_params)
31+
32+
33+
@task(blocks_x={Type: COLLECTION_IN, Depth: 2},
      blocks_y={Type: COLLECTION_IN, Depth: 2},
      returns=1)
def score_sklearn_estimator(est, scorer, blocks_x, blocks_y):
    """Task: merge the validation blocks and score ``est`` on them.

    ``blocks_x`` and ``blocks_y`` are depth-2 block collections that are
    merged with ``Array._merge_blocks``; the merged data is then passed,
    with ``scorer``, to ``_score``, whose result is returned.
    """
    samples = Array._merge_blocks(blocks_x)
    labels = Array._merge_blocks(blocks_y)
    return _score(est, samples, labels, scorer)
40+
41+
42+
def sklearn_fit(estimator, train_ds, parameters, fit_params):
    """Configure ``estimator`` and launch its fit on a training split.

    ``parameters`` (when not ``None``) is applied with ``set_params``;
    ``train_ds`` is unpacked as an ``(x, y)`` pair whose ``_blocks`` are
    handed to ``fit_sklearn_estimator`` together with ``fit_params``.
    Returns the value produced by ``fit_sklearn_estimator``.
    """
    if parameters is not None:
        estimator.set_params(**parameters)

    samples, labels = train_ds
    fitted = fit_sklearn_estimator(estimator, samples._blocks,
                                   labels._blocks, **fit_params)
    return fitted
50+
51+
52+
def sklearn_score(estimator, validation_ds, scorer):
    """Score ``estimator`` on a validation split.

    ``validation_ds`` is unpacked as an ``(x, y)`` pair whose ``_blocks``
    are passed to ``score_sklearn_estimator``. The result is wrapped in a
    one-element list, mirroring the ``[test_scores]`` shape returned by
    ``score_func``.
    """
    samples, labels = validation_ds
    return [score_sklearn_estimator(estimator, scorer,
                                    samples._blocks, labels._blocks)]
58+
59+
2160
def _score(estimator, x, y, scorers):
2261
"""Return a dict of scores"""
2362
scores = {}

tests/test_gridsearch.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from sklearn import clone, datasets
44

5+
from sklearn.ensemble import RandomForestClassifier as SklearnRF
56
import dislib as ds
67
from dislib.classification import CascadeSVM, RandomForestClassifier
78
from dislib.cluster import DBSCAN, KMeans, GaussianMixture
@@ -79,6 +80,43 @@ def test_fit(self):
7980
self.assertTrue(hasattr(searcher, 'scorer_'))
8081
self.assertEqual(searcher.n_splits_, 5)
8182

83+
def test_fit_sk(self):
    """Tests GridSearchCV fit() with a scikit-learn estimator."""
    x_np, y_np = datasets.load_iris(return_X_y=True)
    x = ds.array(x_np, (30, 4))
    y = ds.array(y_np[:, np.newaxis], (30, 1))

    param_grid = {'n_estimators': (2, 4),
                  'max_depth': range(3, 5)}
    rf = SklearnRF()

    searcher = GridSearchCV(rf, param_grid)
    searcher.fit(x, y)

    # cv_results_ must expose one column per parameter, the aggregated
    # scores, and one score column for each of the 5 splits.
    expected_keys = {'param_max_depth', 'param_n_estimators', 'params',
                     'mean_test_score', 'std_test_score',
                     'rank_test_score'}
    split_keys = {'split%d_test_score' % i for i in range(5)}
    expected_keys.update(split_keys)
    self.assertSetEqual(set(searcher.cv_results_.keys()), expected_keys)

    # Every (max_depth, n_estimators) combination must appear exactly
    # once, in any order.
    expected_params = [(3, 2), (3, 4), (4, 2), (4, 4)]
    found_params = [(params['max_depth'], params['n_estimators'])
                    for params in searcher.cv_results_['params']]
    self.assertCountEqual(found_params, expected_params)

    self.assertTrue(hasattr(searcher, 'best_estimator_'))
    self.assertTrue(hasattr(searcher, 'best_score_'))
    self.assertTrue(hasattr(searcher, 'best_params_'))
    self.assertTrue(hasattr(searcher, 'best_index_'))
    self.assertTrue(hasattr(searcher, 'scorer_'))
    self.assertEqual(searcher.n_splits_, 5)
119+
82120
def test_fit_2(self):
83121
"""Tests GridSearchCV fit() with different data."""
84122
x_np, y_np = datasets.load_breast_cancer(return_X_y=True)

0 commit comments

Comments
 (0)