Skip to content

Commit 45839b8

Browse files
authored
Merge pull request #120 from nipreps/fix/reenable-single-fit-models
ENH: Re-allow "locking" of models with first fit
2 parents 5ee55de + 5e23050 commit 45839b8

File tree

5 files changed

+89
-36
lines changed

5 files changed

+89
-36
lines changed

src/nifreeze/cli/run.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ def main(argv=None) -> None:
6969

7070
prev_model: Estimator | None = None
7171
for _model in args.models:
72+
single_fit = _model.lower().startswith("single")
7273
estimator: Estimator = Estimator(
73-
_model,
74+
_model.lower().replace("single", ""),
7475
prev=prev_model,
76+
single_fit=single_fit,
7577
)
7678
prev_model = estimator
7779

src/nifreeze/estimator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,21 @@ def run(self, dataset: DatasetT, **kwargs) -> DatasetT:
7070
class Estimator:
7171
"""Orchestrates components for a single estimation step."""
7272

73-
__slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
73+
__slots__ = ("_model", "_single_fit", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
7474

7575
def __init__(
7676
self,
7777
model: BaseModel | str,
7878
strategy: str = "random",
7979
prev: Estimator | Filter | None = None,
8080
model_kwargs: dict | None = None,
81+
single_fit: bool = False,
8182
**kwargs,
8283
):
8384
self._model = model
8485
self._prev = prev
8586
self._strategy = strategy
87+
self._single_fit = single_fit
8688
self._model_kwargs = model_kwargs or {}
8789
self._align_kwargs = kwargs or {}
8890

@@ -115,11 +117,16 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
115117
# Initialize model
116118
if isinstance(self._model, str):
117119
# Factory creates the appropriate model and pipes arguments
118-
self._model = ModelFactory.init(
120+
model = ModelFactory.init(
119121
model=self._model,
120122
dataset=dataset,
121123
**self._model_kwargs,
122124
)
125+
else:
126+
model = self._model
127+
128+
if self._single_fit:
129+
model.fit_predict(None, n_jobs=n_jobs)
123130

124131
kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
125132
kwargs = self._align_kwargs | kwargs
@@ -145,7 +152,7 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
145152

146153
# fit the model
147154
test_set = dataset[i]
148-
predicted = self._model.fit_predict( # type: ignore[union-attr]
155+
predicted = model.fit_predict( # type: ignore[union-attr]
149156
i,
150157
n_jobs=n_jobs,
151158
)

src/nifreeze/model/base.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,46 +87,59 @@ class BaseModel:
8787
8888
"""
8989

90-
__slots__ = ("_dataset",)
90+
__slots__ = ("_dataset", "_locked_fit")
9191

9292
def __init__(self, dataset, **kwargs):
9393
"""Base initialization."""
9494

95+
self._locked_fit = None
9596
self._dataset = dataset
9697
# Warn if mask not present
9798
if dataset.brainmask is None:
9899
warn(mask_absence_warn_msg, stacklevel=2)
99100

100101
@abstractmethod
101-
def fit_predict(self, index, **kwargs) -> np.ndarray:
102-
"""Fit and predict the indicate index of the dataset (abstract signature)."""
102+
def fit_predict(self, index: int | None = None, **kwargs) -> np.ndarray:
103+
"""
104+
Fit and predict the indicated index of the dataset (abstract signature).
105+
106+
If ``index`` is ``None``, then the model is executed in *single-fit mode*, meaning
107+
that it will be run only once on all the data available.
108+
Please note that all the predictions of this model will suffer from data leakage
109+
from the original volume.
110+
111+
Parameters
112+
----------
113+
index : :obj:`int` or ``None``
114+
The index to predict.
115+
If ``None``, no prediction will be executed.
116+
117+
"""
103118
raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.")
104119

105120

106121
class TrivialModel(BaseModel):
107122
"""A trivial model that returns a given map always."""
108123

109-
__slots__ = ("_predicted",)
110-
111124
def __init__(self, dataset, predicted=None, **kwargs):
112125
"""Implement object initialization."""
113126

114127
super().__init__(dataset, **kwargs)
115-
self._predicted = (
128+
self._locked_fit = (
116129
predicted
117130
if predicted is not None
118131
# Infer from dataset if not provided at initialization
119132
else getattr(dataset, "reference", getattr(dataset, "bzero", None))
120133
)
121134

122-
if self._predicted is None:
135+
if self._locked_fit is None:
123136
raise TypeError("This model requires the predicted map at initialization")
124137

125138
def fit_predict(self, *_, **kwargs):
126139
"""Return the reference map."""
127140

128141
# No need to check fit (if not fitted, has raised already)
129-
return self._predicted
142+
return self._locked_fit
130143

131144

132145
class ExpectationModel(BaseModel):
@@ -139,7 +152,7 @@ def __init__(self, dataset, stat="median", **kwargs):
139152
super().__init__(dataset, **kwargs)
140153
self._stat = stat
141154

142-
def fit_predict(self, index: int, **kwargs):
155+
def fit_predict(self, index: int | None = None, **kwargs):
143156
"""
144157
Return the expectation map.
145158
@@ -149,12 +162,20 @@ def fit_predict(self, index: int, **kwargs):
149162
The volume index that is left-out in fitting, and then predicted.
150163
151164
"""
165+
166+
if self._locked_fit is not None:
167+
return self._locked_fit
168+
152169
# Select the summary statistic
153170
avg_func = getattr(np, kwargs.pop("stat", self._stat))
154171

155172
# Create index mask
156173
index_mask = np.ones(len(self._dataset), dtype=bool)
157-
index_mask[index] = False
158174

159-
# Calculate the average
160-
return avg_func(self._dataset[index_mask][0], axis=-1)
175+
if index is not None:
176+
index_mask[index] = False
177+
# Calculate the average
178+
return avg_func(self._dataset[index_mask][0], axis=-1)
179+
180+
self._locked_fit = avg_func(self._dataset[index_mask][0], axis=-1)
181+
return self._locked_fit

src/nifreeze/model/dmri.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class BaseDWIModel(BaseModel):
5151
__slots__ = {
5252
"_model_class": "Defining a model class, DIPY models are instantiated automagically",
5353
"_modelargs": "Arguments acceptable by the underlying DIPY-like model.",
54+
"_models": "List with one or more (if parallel execution) model instances",
5455
}
5556

5657
def __init__(self, dataset: DWI, **kwargs):
@@ -77,13 +78,21 @@ def __init__(self, dataset: DWI, **kwargs):
7778

7879
super().__init__(dataset, **kwargs)
7980

80-
def _fit(self, index, n_jobs=None, **kwargs):
81+
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
8182
"""Fit the model chunk-by-chunk asynchronously"""
83+
8284
n_jobs = n_jobs or 1
8385

86+
if self._locked_fit is not None:
87+
return n_jobs
88+
8489
brainmask = self._dataset.brainmask
8590
idxmask = np.ones(len(self._dataset), dtype=bool)
86-
idxmask[index] = False
91+
92+
if index is not None:
93+
idxmask[index] = False
94+
else:
95+
self._locked_fit = True
8796

8897
data, _, gtab = self._dataset[idxmask]
8998
# Select voxels within mask or just unravel 3D if no mask
@@ -96,14 +105,15 @@ def _fit(self, index, n_jobs=None, **kwargs):
96105

97106
if model_str:
98107
module_name, class_name = model_str.rsplit(".", 1)
99-
self._model = getattr(
108+
model = getattr(
100109
import_module(module_name),
101110
class_name,
102111
)(gtab, **kwargs)
103112

104113
# One single CPU - linear execution (full model)
105114
if n_jobs == 1:
106-
self._model, _ = _exec_fit(self._model, data)
115+
_modelfit, _ = _exec_fit(model, data)
116+
self._models = [_modelfit]
107117
return 1
108118

109119
# Split data into chunks of group of slices
@@ -114,15 +124,14 @@ def _fit(self, index, n_jobs=None, **kwargs):
114124
# Parallelize process with joblib
115125
with Parallel(n_jobs=n_jobs) as executor:
116126
results = executor(
117-
delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks)
127+
delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
118128
)
119129
for submodel, rindex in results:
120130
self._models[rindex] = submodel
121131

122-
self._model = None # Preempt further actions on the model
123132
return n_jobs
124133

125-
def fit_predict(self, index: int, **kwargs):
134+
def fit_predict(self, index: int | None = None, **kwargs):
126135
"""
127136
Predict asynchronously chunk-by-chunk the diffusion signal.
128137
@@ -133,8 +142,14 @@ def fit_predict(self, index: int, **kwargs):
133142
134143
"""
135144

136-
n_models = self._fit(index, **kwargs)
137-
kwargs.pop("n_jobs")
145+
n_models = self._fit(
146+
index,
147+
n_jobs=kwargs.pop("n_jobs"),
148+
**kwargs,
149+
)
150+
151+
if index is None:
152+
return None
138153

139154
brainmask = self._dataset.brainmask
140155
gradient = self._dataset.gradients[:, index]
@@ -149,9 +164,10 @@ def fit_predict(self, index: int, **kwargs):
149164
S0 = S0[brainmask, ...] if brainmask is not None else S0.reshape(-1)
150165

151166
if n_models == 1:
152-
predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0}))
167+
predicted, _ = _exec_predict(
168+
self._models[0], **(kwargs | {"gtab": gradient, "S0": S0})
169+
)
153170
else:
154-
print(n_models, S0)
155171
S0 = np.array_split(S0, n_models) if S0 is not None else np.full(n_models, None)
156172

157173
predicted = [None] * n_models
@@ -221,9 +237,12 @@ def __init__(
221237
self._th_high = th_high
222238
self._detrend = detrend
223239

224-
def fit_predict(self, index, *_, **kwargs):
240+
def fit_predict(self, index: int | None = None, *_, **kwargs):
225241
"""Return the average map."""
226242

243+
if index is None:
244+
raise RuntimeError(f"Model {self.__class__.__name__} does not allow locking.")
245+
227246
bvalues = self._dataset.gradients[:, -1]
228247
bcenter = bvalues[index]
229248

src/nifreeze/model/pet.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
class PETModel(BaseModel):
3737
"""A PET imaging realignment model based on B-Spline approximation."""
3838

39-
__slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl")
39+
__slots__ = ("_t", "_x", "_xlim", "_order", "_n_ctrl")
4040

4141
def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs):
4242
"""
@@ -76,13 +76,17 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs):
7676
# B-Spline knots
7777
self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32")
7878

79-
self._coeff = None
80-
81-
def _fit(self, n_jobs=None, **kwargs):
79+
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
8280
"""Fit the model."""
8381
from scipy.interpolate import BSpline
8482
from scipy.sparse.linalg import cg
8583

84+
if self._locked_fit is not None:
85+
return n_jobs
86+
87+
if index is not None:
88+
raise NotImplementedError("Fitting with held-out data is not supported")
89+
8690
timepoints = kwargs.get("timepoints", None) or self._x
8791
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
8892

@@ -101,15 +105,15 @@ def _fit(self, n_jobs=None, **kwargs):
101105
with Parallel(n_jobs=n_jobs or min(cpu_count() or 1, 8)) as executor:
102106
results = executor(delayed(cg)(ATdotA, AT @ v) for v in data)
103107

104-
self._coeff = np.array([r[0] for r in results])
108+
self._locked_fit = np.array([r[0] for r in results])
105109

106110
def fit_predict(self, index: int | None = None, **kwargs):
107111
"""Return the corrected volume using B-spline interpolation."""
108112
from scipy.interpolate import BSpline
109113

110114
# Fit the BSpline basis on all data
111-
if self._coeff is None:
112-
self._fit(n_jobs=kwargs.pop("n_jobs", None))
115+
if self._locked_fit is None:
116+
self._fit(index, n_jobs=kwargs.pop("n_jobs", None), **kwargs)
113117

114118
if index is None: # If no index, just fit the data.
115119
return None
@@ -120,7 +124,7 @@ def fit_predict(self, index: int | None = None, **kwargs):
120124

121125
# A is 1 (num. timepoints) x C (num. coeff)
122126
# self._locked_fit (the fitted coefficients) is V (num. voxels) x K - 4
123-
predicted = np.squeeze(A @ self._coeff.T)
127+
predicted = np.squeeze(A @ self._locked_fit.T)
124128

125129
brainmask = self._dataset.brainmask
126130
datashape = self._dataset.dataobj.shape[:3]

0 commit comments

Comments
 (0)