Skip to content

Commit d615ae1

Browse files
committed
enh: force parallelization even for dti and dki
1 parent 70c84d7 commit d615ae1

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

src/nifreeze/model/_dipy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import numpy as np
3030
from dipy.core.gradients import GradientTable
3131
from dipy.reconst.base import ReconstModel
32+
from dipy.reconst.multi_voxel import multi_voxel_fit
3233
from sklearn.gaussian_process import GaussianProcessRegressor
3334

3435
from nifreeze.model.gpr import (
@@ -38,6 +39,11 @@
3839
)
3940

4041

42+
@multi_voxel_fit
43+
def multi_fit(obj, data, *, mask=None, **kwargs):
44+
return obj.fit(data, *obj.args, mask=mask)
45+
46+
4147
def gp_prediction(
4248
model: GaussianProcessRegressor,
4349
gtab: GradientTable | np.ndarray,

src/nifreeze/model/dmri.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,21 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
123123
class_name,
124124
)(gtab, **kwargs)
125125

126-
self._model_fit = model.fit(
127-
data,
128-
engine="serial" if n_jobs == 1 else "joblib",
129-
n_jobs=n_jobs,
130-
)
126+
try:
127+
self._model_fit = model.fit(
128+
data,
129+
engine="serial" if n_jobs == 1 else "joblib",
130+
n_jobs=n_jobs,
131+
)
132+
except TypeError:
133+
from nifreeze.model._dipy import multi_fit
134+
135+
self._model_fit = multi_fit(
136+
model,
137+
data,
138+
engine="serial" if n_jobs == 1 else "ray",
139+
n_jobs=n_jobs,
140+
)
131141
return n_jobs
132142

133143
def fit_predict(self, index: int | None = None, **kwargs):

0 commit comments

Comments
 (0)