
Commit 0cb775c

enh: use DIPY's parallelization

1 parent 45839b8

File tree: pyproject.toml, src/nifreeze/model/dmri.py

2 files changed (+15, -57 lines)

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -21,13 +21,14 @@ license = "Apache-2.0"
 requires-python = ">=3.10"
 dependencies = [
     "attrs",
-    "dipy>=1.5.0",
+    "dipy>=1.10.0",
     "joblib",
     "nipype>= 1.5.1,<2.0",
     "nitransforms>=22.0.0,<24",
     "nireports",
     "numpy>=1.21.3",
     "nest-asyncio>=1.5.1",
+    "ray",
     "scikit-image>=0.15.0",
     "scikit_learn>=1.3.0",
     "scipy>=1.8.0",

src/nifreeze/model/dmri.py

Lines changed: 13 additions & 56 deletions
@@ -25,7 +25,6 @@
 
 import numpy as np
 from dipy.core.gradients import gradient_table_from_bvals_bvecs
-from joblib import Parallel, delayed
 
 from nifreeze.data.dmri import (
     DEFAULT_CLIP_PERCENTILE,
@@ -35,23 +34,13 @@
 from nifreeze.model.base import BaseModel, ExpectationModel
 
 
-def _exec_fit(model, data, chunk=None):
-    retval = model.fit(data)
-    return retval, chunk
-
-
-def _exec_predict(model, chunk=None, **kwargs):
-    """Propagate model parameters and call predict."""
-    return np.squeeze(model.predict(**kwargs)), chunk
-
-
 class BaseDWIModel(BaseModel):
     """Interface and default methods for DWI models."""
 
     __slots__ = {
         "_model_class": "Defining a model class, DIPY models are instantiated automagically",
         "_modelargs": "Arguments acceptable by the underlying DIPY-like model.",
-        "_models": "List with one or more (if parallel execution) model instances",
+        "_model_fit": "Fitted model",
     }
 
     def __init__(self, dataset: DWI, **kwargs):
@@ -81,8 +70,6 @@ def __init__(self, dataset: DWI, **kwargs):
     def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
         """Fit the model chunk-by-chunk asynchronously"""
 
-        n_jobs = n_jobs or 1
-
         if self._locked_fit is not None:
             return n_jobs
 
@@ -110,25 +97,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
             class_name,
         )(gtab, **kwargs)
 
-        # One single CPU - linear execution (full model)
-        if n_jobs == 1:
-            _modelfit, _ = _exec_fit(model, data)
-            self._models = [_modelfit]
-            return 1
-
-        # Split data into chunks of group of slices
-        data_chunks = np.array_split(data, n_jobs)
-
-        self._models = [None] * n_jobs
-
-        # Parallelize process with joblib
-        with Parallel(n_jobs=n_jobs) as executor:
-            results = executor(
-                delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
-            )
-        for submodel, rindex in results:
-            self._models[rindex] = submodel
-
+        self._model_fit = model.fit(
+            data,
+            engine="serial" if n_jobs == 1 else "joblib",
+            n_jobs=n_jobs,
+        )
         return n_jobs
 
     def fit_predict(self, index: int | None = None, **kwargs):
@@ -142,13 +115,14 @@ def fit_predict(self, index: int | None = None, **kwargs):
 
         """
 
-        n_models = self._fit(
+        self._fit(
            index,
            n_jobs=kwargs.pop("n_jobs"),
            **kwargs,
        )
 
         if index is None:
+            self._locked_fit = True
             return None
 
         brainmask = self._dataset.brainmask
@@ -163,29 +137,12 @@
         if S0 is not None:
             S0 = S0[brainmask, ...] if brainmask is not None else S0.reshape(-1)
 
-        if n_models == 1:
-            predicted, _ = _exec_predict(
-                self._models[0], **(kwargs | {"gtab": gradient, "S0": S0})
+        predicted = np.squeeze(
+            self._model_fit.predict(
+                gtab=gradient,
+                S0=S0,
             )
-        else:
-            S0 = np.array_split(S0, n_models) if S0 is not None else np.full(n_models, None)
-
-            predicted = [None] * n_models
-
-            # Parallelize process with joblib
-            with Parallel(n_jobs=n_models) as executor:
-                results = executor(
-                    delayed(_exec_predict)(
-                        model,
-                        chunk=i,
-                        **(kwargs | {"gtab": gradient, "S0": S0[i]}),
-                    )
-                    for i, model in enumerate(self._models)
-                )
-            for subprediction, index in results:
-                predicted[index] = subprediction
-
-            predicted = np.hstack(predicted)
+        )
 
         if brainmask is not None:
             retval = np.zeros_like(brainmask, dtype="float32")
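
Taken together, the change removes the hand-rolled joblib chunking (_exec_fit, _exec_predict, and the self._models list) and delegates parallelism to DIPY: one fit over the whole dataset with an engine/n_jobs selection, and one predict on the fitted object stored in the new _model_fit slot. A minimal sketch of the resulting code path, with a hypothetical function name and signature, assuming a DIPY model class whose fit accepts the engine and n_jobs keywords as in the diff above:

import numpy as np
from dipy.core.gradients import gradient_table_from_bvals_bvecs

def fit_and_predict(model_class, bvals, bvecs, data, n_jobs=1, **model_kwargs):
    """Fit once over the full dataset and predict, delegating parallelism to DIPY."""
    gtab = gradient_table_from_bvals_bvecs(bvals, bvecs)
    model = model_class(gtab, **model_kwargs)

    # DIPY dispatches the voxel-wise fit to the requested backend instead of
    # nifreeze splitting the data into chunks and driving joblib itself.
    model_fit = model.fit(
        data,
        engine="serial" if n_jobs == 1 else "joblib",
        n_jobs=n_jobs,
    )

    # A single fitted object predicts the whole volume; np.squeeze mirrors the
    # squeeze previously applied per chunk in _exec_predict.
    return np.squeeze(model_fit.predict(gtab=gtab, S0=None))

Compared with the previous approach, this avoids re-stacking chunked predictions with np.hstack and keeps exactly one fitted model per dataset.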
