Skip to content

Commit 090dd27

Browse files
Aaron Richter authored
#386 LogisticRegression.predict_proba should return (n, 2) for binary (#760)
* #386 LogisticRegression.predict_proba should return (n, 2) for binary classification
* empty commit to trigger CI
1 parent db2e7d5 commit 090dd27

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

dask_ml/linear_model/glm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from ..metrics import r2_score
1717
from ..utils import check_array
18+
from .utils import lr_prob_stack
1819

1920
_base_doc = textwrap.dedent(
2021
"""\
@@ -228,7 +229,7 @@ def decision_function(self, X):
228229
229230
Returns
230231
-------
231-
T : array-like, shape = [n_samples, n_classes]
232+
T : array-like, shape = [n_samples,]
232233
The confidence score of the sample for each class in the model.
233234
"""
234235
X_ = self._check_array(X)
@@ -246,7 +247,7 @@ def predict(self, X):
246247
C : array, shape = [n_samples,]
247248
Predicted class labels for each sample
248249
"""
249-
return self.predict_proba(X) > 0.5 # TODO: verify, multi_class broken
250+
return self.predict_proba(X)[:, 1] > 0.5 # TODO: verify, multi_class broken
250251

251252
def predict_proba(self, X):
252253
"""Probability estimates for samples in X.
@@ -260,7 +261,9 @@ def predict_proba(self, X):
260261
T : array-like, shape = [n_samples, n_classes]
261262
The probability of the sample for each class in the model.
262263
"""
263-
return sigmoid(self.decision_function(X))
264+
# TODO: more work needed here to support multi_class
265+
prob = sigmoid(self.decision_function(X))
266+
return lr_prob_stack(prob)
264267

265268
def score(self, X, y):
266269
"""The mean accuracy on the given data and labels

dask_ml/linear_model/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,13 @@ def add_intercept(X): # noqa: F811
5959
if "intercept" in columns:
6060
raise ValueError("'intercept' column already in 'X'")
6161
return X.assign(intercept=1)[["intercept"] + list(columns)]
62+
63+
64+
@dispatch(np.ndarray) # noqa: F811
65+
def lr_prob_stack(prob): # noqa: F811
66+
return np.vstack([1 - prob, prob]).T
67+
68+
69+
@dispatch(da.Array) # noqa: F811
70+
def lr_prob_stack(prob): # noqa: F811
71+
return da.vstack([1 - prob, prob]).T

tests/linear_model/test_glm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,11 @@ def test_dataframe_warns_about_chunks(fit_intercept):
193193
clf.fit(X.values, y.values)
194194
clf.fit(X.to_dask_array(), y.to_dask_array())
195195
clf.fit(X.to_dask_array(lengths=True), y.to_dask_array(lengths=True))
196+
197+
198+
def test_logistic_predict_proba_shape():
199+
X, y = make_classification(n_samples=100, n_features=5, chunks=50)
200+
lr = LogisticRegression()
201+
lr.fit(X, y)
202+
prob = lr.predict_proba(X)
203+
assert prob.shape == (100, 2)

0 commit comments

Comments
 (0)