Skip to content

Commit 5466bec

Browse files
authored
Specify HLG layer name during delayed object creation in fit (#898)
1 parent 8cb4c2f commit 5466bec

File tree

6 files changed

+19
-10
lines changed

6 files changed

+19
-10
lines changed

dask_ml/_compat.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
DASK_2_26_0 = DASK_VERSION >= packaging.version.parse("2.26.0")
2323
DASK_2_28_0 = DASK_VERSION > packaging.version.parse("2.27.0")
2424
DASK_2021_02_0 = DASK_VERSION >= packaging.version.parse("2021.02.0")
25+
DASK_2022_01_0 = DASK_VERSION > packaging.version.parse("2021.12.0")
2526
DISTRIBUTED_2_5_0 = DISTRIBUTED_VERSION > packaging.version.parse("2.5.0")
2627
DISTRIBUTED_2_11_0 = DISTRIBUTED_VERSION > packaging.version.parse("2.10.0") # dev
2728
DISTRIBUTED_2021_02_0 = DISTRIBUTED_VERSION >= packaging.version.parse("2021.02.0")

dask_ml/_partial.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from dask.highlevelgraph import HighLevelGraph
1111
from toolz import partial
1212

13+
from ._compat import DASK_2022_01_0
14+
1315
logger = logging.getLogger(__name__)
1416

1517

@@ -125,7 +127,11 @@ def fit(
125127
if y is not None:
126128
dependencies.append(y)
127129
new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)
128-
value = Delayed((name, nblocks - 1), new_dsk)
130+
131+
if DASK_2022_01_0:
132+
value = Delayed((name, nblocks - 1), new_dsk, layer=name)
133+
else:
134+
value = Delayed((name, nblocks - 1), new_dsk)
129135

130136
if compute:
131137
return value.compute()

dask_ml/model_selection/methods.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import warnings
44
from collections import defaultdict
5-
from distutils.version import LooseVersion
65
from threading import Lock
76
from timeit import default_timer
87

98
import numpy as np
9+
import packaging.version
1010
from dask.base import normalize_token
1111
from scipy import sparse
1212
from scipy.stats import rankdata
@@ -20,7 +20,7 @@
2020

2121
# Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop
2222
# support for scikit-learn < 0.18.1 or numpy < 1.12.0.
23-
if LooseVersion(np.__version__) < "1.12.0":
23+
if packaging.version.parse(np.__version__) < packaging.version.parse("1.12.0"):
2424

2525
class MaskedArray(np.ma.MaskedArray):
2626
# Before numpy 1.12, np.ma.MaskedArray object is not picklable

dask_ml/model_selection/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
import copy
22
import warnings
3-
from distutils.version import LooseVersion
43
from itertools import compress
54

65
import dask
76
import dask.array as da
87
import dask.dataframe as dd
98
import numpy as np
9+
import packaging.version
1010
import scipy.sparse as sp
1111
from dask.base import tokenize
1212
from dask.delayed import Delayed, delayed
1313
from sklearn.utils.validation import _is_arraylike, indexable
1414

1515
from ..utils import _num_samples
1616

17-
if LooseVersion(dask.__version__) > "0.15.4":
17+
if packaging.version.parse(dask.__version__) > packaging.version.parse("0.15.4"):
1818
from dask.base import is_dask_collection
1919
else:
2020
from dask.base import Base

dask_ml/preprocessing/data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import multiprocessing
55
import numbers
66
from collections import OrderedDict
7-
from distutils.version import LooseVersion
87
from typing import Any, List, Optional, Sequence, Union
98

109
import dask.array as da
1110
import dask.dataframe as dd
1211
import numpy as np
12+
import packaging.version
1313
import pandas as pd
1414
import sklearn.preprocessing
1515
from dask import compute
@@ -26,8 +26,8 @@
2626
from .._typing import ArrayLike, DataFrameType, NDArrayOrScalar, SeriesType
2727
from ..base import DaskMLBaseMixin
2828

29-
_PANDAS_VERSION = LooseVersion(pd.__version__)
30-
_HAS_CTD = _PANDAS_VERSION >= "0.21.0"
29+
_PANDAS_VERSION = packaging.version.parse(pd.__version__)
30+
_HAS_CTD = _PANDAS_VERSION >= packaging.version.parse("0.21.0")
3131
BOUNDS_THRESHOLD = 1e-7
3232

3333

tests/linear_model/test_glm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@ def test_fit(fit_intercept, solver):
6464
)
6565
def test_fit_solver(solver):
6666
import dask_glm
67-
from distutils.version import LooseVersion
67+
import packaging.version
6868

69-
if LooseVersion(dask_glm.__version__) <= "0.2.0":
69+
if packaging.version.parse(dask_glm.__version__) <= packaging.version.parse(
70+
"0.2.0"
71+
):
7072
pytest.skip("FutureWarning for dask config.")
7173

7274
X, y = make_classification(n_samples=100, n_features=5, chunks=50)

0 commit comments

Comments
 (0)