25
25
26
26
import numpy as np
27
27
from dipy .core .gradients import gradient_table_from_bvals_bvecs
28
- from joblib import Parallel , delayed
29
28
30
29
from nifreeze .data .dmri import (
31
30
DEFAULT_CLIP_PERCENTILE ,
35
34
from nifreeze .model .base import BaseModel , ExpectationModel
36
35
37
36
38
- def _exec_fit (model , data , chunk = None ):
39
- retval = model .fit (data )
40
- return retval , chunk
41
-
42
-
43
- def _exec_predict (model , chunk = None , ** kwargs ):
44
- """Propagate model parameters and call predict."""
45
- return np .squeeze (model .predict (** kwargs )), chunk
46
-
47
-
48
37
class BaseDWIModel (BaseModel ):
49
38
"""Interface and default methods for DWI models."""
50
39
51
40
__slots__ = {
52
41
"_model_class" : "Defining a model class, DIPY models are instantiated automagically" ,
53
42
"_modelargs" : "Arguments acceptable by the underlying DIPY-like model." ,
54
- "_models " : "List with one or more (if parallel execution) model instances " ,
43
+ "_model_fit " : "Fitted model" ,
55
44
}
56
45
57
46
def __init__ (self , dataset : DWI , ** kwargs ):
@@ -81,8 +70,6 @@ def __init__(self, dataset: DWI, **kwargs):
81
70
def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ):
82
71
"""Fit the model chunk-by-chunk asynchronously"""
83
72
84
- n_jobs = n_jobs or 1
85
-
86
73
if self ._locked_fit is not None :
87
74
return n_jobs
88
75
@@ -110,25 +97,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
110
97
class_name ,
111
98
)(gtab , ** kwargs )
112
99
113
- # One single CPU - linear execution (full model)
114
- if n_jobs == 1 :
115
- _modelfit , _ = _exec_fit (model , data )
116
- self ._models = [_modelfit ]
117
- return 1
118
-
119
- # Split data into chunks of group of slices
120
- data_chunks = np .array_split (data , n_jobs )
121
-
122
- self ._models = [None ] * n_jobs
123
-
124
- # Parallelize process with joblib
125
- with Parallel (n_jobs = n_jobs ) as executor :
126
- results = executor (
127
- delayed (_exec_fit )(model , dchunk , i ) for i , dchunk in enumerate (data_chunks )
128
- )
129
- for submodel , rindex in results :
130
- self ._models [rindex ] = submodel
131
-
100
+ self ._model_fit = model .fit (
101
+ data ,
102
+ engine = "serial" if n_jobs == 1 else "joblib" ,
103
+ n_jobs = n_jobs ,
104
+ )
132
105
return n_jobs
133
106
134
107
def fit_predict (self , index : int | None = None , ** kwargs ):
@@ -142,13 +115,14 @@ def fit_predict(self, index: int | None = None, **kwargs):
142
115
143
116
"""
144
117
145
- n_models = self ._fit (
118
+ self ._fit (
146
119
index ,
147
120
n_jobs = kwargs .pop ("n_jobs" ),
148
121
** kwargs ,
149
122
)
150
123
151
124
if index is None :
125
+ self ._locked_fit = True
152
126
return None
153
127
154
128
brainmask = self ._dataset .brainmask
@@ -163,29 +137,12 @@ def fit_predict(self, index: int | None = None, **kwargs):
163
137
if S0 is not None :
164
138
S0 = S0 [brainmask , ...] if brainmask is not None else S0 .reshape (- 1 )
165
139
166
- if n_models == 1 :
167
- predicted , _ = _exec_predict (
168
- self ._models [0 ], ** (kwargs | {"gtab" : gradient , "S0" : S0 })
140
+ predicted = np .squeeze (
141
+ self ._model_fit .predict (
142
+ gtab = gradient ,
143
+ S0 = S0 ,
169
144
)
170
- else :
171
- S0 = np .array_split (S0 , n_models ) if S0 is not None else np .full (n_models , None )
172
-
173
- predicted = [None ] * n_models
174
-
175
- # Parallelize process with joblib
176
- with Parallel (n_jobs = n_models ) as executor :
177
- results = executor (
178
- delayed (_exec_predict )(
179
- model ,
180
- chunk = i ,
181
- ** (kwargs | {"gtab" : gradient , "S0" : S0 [i ]}),
182
- )
183
- for i , model in enumerate (self ._models )
184
- )
185
- for subprediction , index in results :
186
- predicted [index ] = subprediction
187
-
188
- predicted = np .hstack (predicted )
145
+ )
189
146
190
147
if brainmask is not None :
191
148
retval = np .zeros_like (brainmask , dtype = "float32" )
0 commit comments